在TensorFlow图中使用if条件

jea*_*ean 2 random deep-learning tensorflow

在tensorflow CIFAR-10 教程cifar10_inputs.py线174它是说你应该随机化操作的顺序random_contrast和random_brightness更好的数据增强.

为此,我想到的第一件事是从0和1之间的均匀分布中绘制一个随机变量:p_order.并做:

if p_order>0.5:
  distorted_image=tf.image.random_contrast(image)
  distorted_image=tf.image.random_brightness(distorted_image)
else:
  distorted_image=tf.image.random_brightness(image)
  distorted_image=tf.image.random_contrast(distorted_image)
Run Code Online (Sandbox Code Playgroud)

但是获取p_order有两种可能的选择:

1)使用numpy不满意我,因为我想要纯TF和TF阻止其用户混合numpy和tensorflow

2)使用TF,但是因为p_order只能在tf.Session()中进行评估,所以我真的不知道是否应该这样做:

with tf.Session() as sess2:
  p_order_tensor=tf.random_uniform([1,],0.,1.)
  p_order=float(p_order_tensor.eval())
Run Code Online (Sandbox Code Playgroud)

所有这些操作都在函数体内,并从另一个具有不同会话/图形的脚本运行.或者我可以将其他脚本中的图形作为参数传递给此函数,但我很困惑.即使像tensorflow这样的函数或者例如推论似乎以全局方式定义图形而没有明确地将其作为输出返回,这对我来说有点难以理解.

Oli*_*rot 12

你可以使用tf.cond(pred, fn1, fn2, name=None)(见文档).此函数允许您使用predTensorFlow图形内部的布尔值(无需调用self.eval()sess.run()因此不需要会话).

以下是如何使用它的示例:

def fn1():
    distorted_image=tf.image.random_contrast(image)
    distorted_image=tf.image.random_brightness(distorted_image)
    return distorted_image
def fn2():
    distorted_image=tf.image.random_brightness(image)
    distorted_image=tf.image.random_contrast(distorted_image)
    return distorted_image

# Uniform variable in [0,1)
p_order = tf.random_uniform(shape=[], minval=0., maxval=1., dtype=tf.float32)
pred = tf.less(p_order, 0.5)

distorted_image = tf.cond(pred, fn1, fn2)
Run Code Online (Sandbox Code Playgroud)