如何在tf.data.Dataset.map()中使用Keras的predict_on_batch?

And*_*ers 3 python keras tensorflow tensorflow-datasets tensorflow2.0

我想找到一种在内部使用 Keras的predict_on_batch方法tf.data.Dataset.map()TF2.0.

假设我有一个 numpy 数据集

n_data = 10**5
my_data    = np.random.random((n_data,10,1))
my_targets = np.random.randint(0,2,(n_data,1))

data = ({'x_input':my_data}, {'target':my_targets})
Run Code Online (Sandbox Code Playgroud)

和一个tf.keras模型

x_input = Input((None,1), name = 'x_input')
RNN     = SimpleRNN(100,  name = 'RNN')(x_input)
dense   = Dense(1, name = 'target')(RNN)

my_model = Model(inputs = [x_input], outputs = [dense])
my_model.compile(optimizer='SGD', loss = 'binary_crossentropy')
Run Code Online (Sandbox Code Playgroud)

我可以创建一个dataset批处理

dataset = tf.data.Dataset.from_tensor_slices(data)
dataset = dataset.batch(10)
prediction_dataset = dataset.map(transform_predictions)
Run Code Online (Sandbox Code Playgroud)

其中transform_predictions是用户定义的函数,用于获取预测predict_on_batch

def transform_predictions(inputs, outputs):
    predictions = my_model.predict_on_batch(inputs)
    # predictions = do_transformations_here(predictions)
    return predictions
Run Code Online (Sandbox Code Playgroud)

这给出了一个错误predict_on_batch

AttributeError: 'Tensor' object has no attribute 'numpy'

据我了解,predict_on_batch需要一个 numpy 数组,并且它正在从数据集中获取一个张量对象。

似乎一种可能的解决方案是包装predict_on_batch在“tf.py_function”中,尽管我也无法使其工作。

有谁知道如何做到这一点?

Man*_*han 5

Dataset.map() 返回 <class 'tensorflow.python.framework.ops.Tensor'> 没有 numpy() 方法。

迭代 Dataset 返回 <class 'tensorflow.python.framework.ops.EagerTensor'> 具有 numpy() 方法。

将急切的张量提供给 Predict() 系列方法效果很好。

你可以尝试这样的事情:

dataset = tf.data.Dataset.from_tensor_slices(data)
dataset = dataset.batch(10)

for x,y in dataset:
    predictions = my_model.predict_on_batch(x['x_input'])
    #or 
    predictions = my_model.predict_on_batch(x)
Run Code Online (Sandbox Code Playgroud)