3vo*_*voC 7 unit-testing tensorflow
我有一些测试可以使用图形和会话.我还想用热切模式编写一些小测试来轻松测试一些功能.例如:
def test_normal_execution():
matrix_2x4 = np.array([[1, 2, 3, 4], [6, 7, 8, 9]])
dataset = tf.data.Dataset.from_tensor_slices(matrix_2x4)
iterator = dataset.make_one_shot_iterator()
first_elem = iterator.get_next()
with tf.Session() as sess:
result = sess.run(first_elem)
assert (result == [1, 2, 3, 4]).all()
sess.close()
Run Code Online (Sandbox Code Playgroud)
在另一个文件中:
def test_eager_execution():
matrix_2x4 = np.array([[1, 2, 3, 4], [6, 7, 8, 9]])
tf.enable_eager_execution()
dataset = tf.data.Dataset.from_tensor_slices(matrix_2x4)
iterator = dataset.__iter__()
first_elem = iterator.next()
assert (first_elem.numpy() == [1, 2, 3, 4]).all()
Run Code Online (Sandbox Code Playgroud)
有办法吗?ValueError: tf.enable_eager_execution must be called at program startup.当我试图急切地执行测试时,我得到了.我正在使用pytest我的测试.
编辑:
在接受的响应的帮助下,我创建了一个装饰器,它与eager模式和pytest的灯具很好地配合:
def run_eagerly(func):
@functools.wraps(func)
def eager_fun(*args, **kwargs):
with tf.Session() as sess:
sess.run(tfe.py_func(func, inp=list(kwargs.values()), Tout=[]))
return eager_fun
Run Code Online (Sandbox Code Playgroud)
需要注意的是,tf.contrib命名空间中的任何内容都可能会在发行版之间发生变化,您可以使用它来装饰您的测试@tf.contrib.eager.run_test_in_graph_and_eager_modes.其他一些项目,如TensorFlow Probability 似乎也在使用它.
对于非测试,需要注意的一些事项是:
tf.contrib.eager.defun:当您启用了预先执行但希望将某些计算"编译"到图形中以从内存和/或性能优化中受益时,这非常有用.tf.contrib.eager.py_func:当没有启用急切执行但希望在图形中执行某些计算作为Python函数时,它很有用.有人可能会质疑不允许tf.enable_eager_execution()撤消电话的原因.这个想法是库作者不应该调用它,只有最终用户才能调用它main().减少了库以不兼容的方式编写的可能性(其中一个库中的函数禁用急切执行并返回符号张量,而另一个库中的函数启用急切执行并期望具体的值张量.这会使库混合成问题).
希望有所帮助
有一种在图形环境中使用急切执行的官方方法。但我不确定这对您来说是否足够好和方便,因为您需要编写相当多的代码来包装和运行测试函数。不管怎样,这是你的例子,至少应该有效:
import numpy as np
import tensorflow as tf
def test_normal_execution():
matrix_2x4 = np.array([[1, 2, 3, 4], [6, 7, 8, 9]])
dataset = tf.data.Dataset.from_tensor_slices(matrix_2x4)
iterator = dataset.make_one_shot_iterator()
first_elem = iterator.get_next()
with tf.Session() as sess:
result = sess.run(first_elem)
assert (result == [1, 2, 3, 4]).all()
sess.close()
def test_eager_execution():
matrix_2x4 = np.array([[1, 2, 3, 4], [6, 7, 8, 9]])
dataset = tf.data.Dataset.from_tensor_slices(matrix_2x4)
iterator = dataset.__iter__()
first_elem = iterator.next()
assert (first_elem.numpy() == [1, 2, 3, 4]).all()
test_normal_execution()
# test_eager_execution() # Instead, you have to use the following three lines.
with tf.Session() as sess:
tfe = tf.contrib.eager
sess.run(tfe.py_func(test_eager_execution, [], []))
Run Code Online (Sandbox Code Playgroud)
| 归档时间: |
|
| 查看次数: |
879 次 |
| 最近记录: |