如何使用TF1.3中的新数据集api映射具有附加参数的函数?

Ami*_*rHJ 15 python functional-programming tensorflow tensorflow-datasets

我正在使用Tensorflow v1.3中的Dataset API.这很棒.可以使用此处描述的函数映射数据集.我很想知道如何传递一个具有附加参数的函数,例如arg1:

def _parse_function(example_proto, arg1):
  features = {"image": tf.FixedLenFeature((), tf.string, default_value=""),
              "label": tf.FixedLenFeature((), tf.int32, default_value=0)}
  parsed_features = tf.parse_single_example(example_proto, features)
  return parsed_features["image"], parsed_features["label"]
Run Code Online (Sandbox Code Playgroud)

当然,

dataset = dataset.map(_parse_function)
Run Code Online (Sandbox Code Playgroud)

因为没有办法传递,所以行不通arg1.

mik*_*ola 22

下面是一个使用lambda表达式来包装我们想要传递参数的函数的示例:

import tensorflow as tf
def fun(x, arg):
    return x * arg

my_arg = tf.constant(2, dtype=tf.int64)
ds = tf.data.Dataset.range(5)
ds = ds.map(lambda x: fun(x, my_arg))
Run Code Online (Sandbox Code Playgroud)

在上面,提供的函数的签名map必须与我们的数据集的内容匹配.所以我们必须编写我们的lambda表达式来匹配它.这里很简单,因为数据集中只包含一个元素,x其中包含0到4范围内的元素.

如有必要,您可以从数据集外部传递任意数量的外部参数:ds = ds.map(lambda x: my_other_fun(x, arg1, arg2, arg3)等等.

为了验证上述工作,我们可以观察到映射确实将每个数据集元素乘以2:

iterator = ds.make_initializable_iterator()
next_x = iterator.get_next()
with tf.Session() as sess:
    sess.run(iterator.initializer)

    while True:
      try:
        print(sess.run(next_x))
      except tf.errors.OutOfRangeError:
        break
Run Code Online (Sandbox Code Playgroud)

输出:

0
2
4
6
8
Run Code Online (Sandbox Code Playgroud)


小智 5

您还可以使用Partial函数来包装参数:

def _parse_function(arg1, example_proto):
  features = {"image": tf.FixedLenFeature((), tf.string, default_value=""),
              "label": tf.FixedLenFeature((), tf.int32, default_value=0)}
  parsed_features = tf.parse_single_example(example_proto, features)
  return parsed_features["image"], parsed_features["label"]
Run Code Online (Sandbox Code Playgroud)

更改函数的参数顺序以适应偏向性,然后您可以使用如下参数值包装函数:

from functools import partial

arg1 = ...
dataset = dataset.map(partial(_parse_function, arg1))
Run Code Online (Sandbox Code Playgroud)

  • functools.partial 可以被tensorflow的图执行理解/转换吗? (2认同)