Numba可以和Tensorflow一起使用吗?

Ron*_*hen 9 python numpy numba tensorflow

Numba可以用来编译与Tensorflow接口的Python代码吗?即Tensorflow宇宙之外的计算将与Numba一起运行以获得速度.我没有找到任何有关如何执行此操作的资源.

Kam*_*min 7

您可以使用tf.numpy_functiontf.py_func来包装 python 函数并将其用作 TensorFlow 操作。这是我使用的一个例子:

@jit
def dice_coeff_nb(y_true, y_pred):
    "Calculates dice coefficient"
    smooth = np.float32(1)
    y_true_f = np.reshape(y_true, [-1])
    y_pred_f = np.reshape(y_pred, [-1])
    intersection = np.sum(y_true_f * y_pred_f)
    score = (2. * intersection + smooth) / (np.sum(y_true_f) +
                                            np.sum(y_pred_f) + smooth)
    return score

@jit
def dice_loss_nb(y_true, y_pred):
    "Calculates dice loss"
    loss = 1 - dice_coeff_nb(y_true, y_pred)
    return loss

def bce_dice_loss_nb(y_true, y_pred):
    "Adds dice_loss to categorical_crossentropy"
    loss =  tf.numpy_function(dice_loss_nb, [y_true, y_pred], tf.float64) + \
            tf.keras.losses.categorical_crossentropy(y_true, y_pred)
    return loss
Run Code Online (Sandbox Code Playgroud)

然后我使用这个损失函数来训练 tf.keras 模型:

...
model = tf.keras.models.Model(inputs=inputs, outputs=outputs)
model.compile(optimizer='adam', loss=bce_dice_loss_nb)
Run Code Online (Sandbox Code Playgroud)


Ioa*_*des 5

我知道这并不能直接回答你的问题,但这可能是一个不错的选择。Numba 使用即时 (JIT) 编译。因此,您可以按照官方 TensorFlow 文档中的说明进行操作了解如何在 TensorFlow 中使用 JIT(但不在 Numba 生态系统中)。