小编Fra*_*cob的帖子

将 TensorFlow 张量转换为 Numpy 数组

问题描述

我正在尝试在 TensorFlow 2.3.0 中编写自定义损失函数。要计算损失,我需要将y_pred参数转换为 numpy 数组。但是,我找不到将其从<class 'tensorflow.python.framework.ops.Tensor'>numpy 数组转换的方法,即使 TensorFlow 函数似乎可以这样做。

代码示例

def custom_loss(y_true, y_pred):
    print(type(y_pred))
    npa = y_pred.make_ndarray()
    ...
    

if __name__ == '__main__':
    ...
    model.compile(loss=custom_loss, optimizer="adam")
    model.fit(x=train_data, y=train_data, epochs=10)
Run Code Online (Sandbox Code Playgroud)

给出错误信息:AttributeError: 'Tensor' object has no attribute 'make_ndarray 打印y_pred参数类型后:<class 'tensorflow.python.framework.ops.Tensor'>

到目前为止我尝试过的

在寻找解决方案时,我发现这似乎是一个常见问题,并且有一些建议,但到目前为止它们对我不起作用:

1.“...所以只需在 Tensor 对象上调用 .numpy()。”:如何在 TensorFlow 中将张量转换为 numpy 数组?

所以我试过:

def custom_loss(y_true, y_pred):
    npa = y_pred.numpy()
    ...
Run Code Online (Sandbox Code Playgroud)

给我 AttributeError: 'Tensor' object has no attribute 'numpy'

2.“使用tensorflow.Tensor.eval() to convert a tensor to …

python keras tensorflow tensorflow2.0

8
推荐指数
1
解决办法
5389
查看次数

Tensorflow 2.1 错误“当最终确定 GeneratorDataset 迭代器时” - 可能是我的生成器中的内存泄漏,但如何缩小范围?

问题

我在 Centos Linux 下使用 TensorFlow 2.1.0 进行图像分类。随着我的图像训练数据集不断增长,我必须开始使用生成器,因为我没有足够的 RAM 来保存所有图片。我已经根据本教程对生成器进行了编码。

它似乎工作正常,直到我的程序突然被杀死而没有错误消息:

Epoch 6/30
2020-03-08 13:28:11.361785: W tensorflow/core/kernels/data/generator_dataset_op.cc:103] Error occurred when finalizing GeneratorDataset iterator: Cancelled: Operation was cancelled
43/43 [==============================] - 54s 1s/step - loss: 5.6839 - accuracy: 0.4669
Epoch 7/30
2020-03-08 13:29:05.511813: W tensorflow/core/kernels/data/generator_dataset_op.cc:103] Error occurred when finalizing GeneratorDataset iterator: Cancelled: Operation was cancelled
 7/43 [===>..........................] - ETA: 1:04 - loss: 4.3953 - accuracy: 0.5268Killed
Run Code Online (Sandbox Code Playgroud)

看着 linux 的 top 不断增长的内存消耗,我怀疑是内存泄漏?

我试过的

  • 在这里建议,切换到 TF nightly build 版本会有所帮助。对我来说没有,降级到 TF2.0.1 …

python memory-leaks keras tensorflow

6
推荐指数
0
解决办法
2274
查看次数

标签 统计

keras ×2

python ×2

tensorflow ×2

memory-leaks ×1

tensorflow2.0 ×1