TensorFlow 数据集 API 中的 IDE 断点映射 py_function?

gol*_*enk 13 python tensorflow tensorflow-datasets tensorflow2.0

我正在使用Tensorflow Dataset API来准备我的数据以输入到我的网络中。在这个过程中,我有一些自定义的 Python 函数,它们使用tf.py_function. 我希望能够调试进入这些函数的数据以及这些函数内的数据会发生什么。当 apy_function被调用时,它会回调到主要的 Python 进程(根据这个答案)。由于此函数在 Python 中,并且在主进程中,我希望常规 IDE 断点能够在此进程中停止。但是,情况似乎并非如此(下面的示例中断点不会停止执行)。有没有办法py_function在 Dataset 使用的a 中放入断点map

断点不停止执行的示例

import tensorflow as tf

def add_ten(example, label):
    example_plus_ten = example + 10  # Breakpoint here.
    return example_plus_ten, label

examples = [10, 20, 30, 40, 50, 60, 70, 80]
labels =   [ 0,  0,  1,  1,  1,  1,  0,  0]

examples_dataset = tf.data.Dataset.from_tensor_slices(examples)
labels_dataset = tf.data.Dataset.from_tensor_slices(labels)
dataset = tf.data.Dataset.zip((examples_dataset, labels_dataset))
dataset = dataset.map(map_func=lambda example, label: tf.py_function(func=add_ten, inp=[example, label],
                                                                     Tout=[tf.int32, tf.int32]))
dataset = dataset.batch(2)
example_and_label = next(iter(dataset))
Run Code Online (Sandbox Code Playgroud)

Dan*_*aun 11

Tensorflow 2.0 implementation of tf.data.Dataset opens a C threads for each call without notifying your debugger. Use pydevd's to manually set a tracing function that will connect to your default debugger server and start feeding it the debug data.

import pydevd
pydevd.settrace()
Run Code Online (Sandbox Code Playgroud)

Example with your code:

import tensorflow as tf
import pydevd

def add_ten(example, label):
    pydevd.settrace(suspend=False)
    example_plus_ten = example + 10  # Breakpoint here.
    return example_plus_ten, label

examples = [10, 20, 30, 40, 50, 60, 70, 80]
labels =   [ 0,  0,  1,  1,  1,  1,  0,  0]

examples_dataset = tf.data.Dataset.from_tensor_slices(examples)
labels_dataset = tf.data.Dataset.from_tensor_slices(labels)
dataset = tf.data.Dataset.zip((examples_dataset, labels_dataset))
dataset = dataset.map(map_func=lambda example, label: tf.py_function(func=add_ten, inp=[example, label],
                                                                     Tout=[tf.int32, tf.int32]))
dataset = dataset.batch(2)
example_and_label = next(iter(dataset))
Run Code Online (Sandbox Code Playgroud)

Note: If you are using IDE which already bundles pydevd (such as PyDev or PyCharm) you do not have to install pydevd separately, it will picked up during the debug session.

  • @DanielBraun:我将 PyCharm 与tensorflow 2.4.1(在 Windows 中)一起使用,在上面的示例中设置断点不起作用。但是,如果我理解正确的话,它应该停止,因为 pycharm 已经捆绑了 pydevd,或者? (2认同)