Ami*_*rHJ 15 python functional-programming tensorflow tensorflow-datasets
我正在使用Tensorflow v1.3中的Dataset API.这很棒.可以使用此处描述的函数映射数据集.我很想知道如何传递一个具有附加参数的函数,例如arg1:
def _parse_function(example_proto, arg1):
features = {"image": tf.FixedLenFeature((), tf.string, default_value=""),
"label": tf.FixedLenFeature((), tf.int32, default_value=0)}
parsed_features = tf.parse_single_example(example_proto, features)
return parsed_features["image"], parsed_features["label"]
Run Code Online (Sandbox Code Playgroud)
当然,
dataset = dataset.map(_parse_function)
Run Code Online (Sandbox Code Playgroud)
因为没有办法传递,所以行不通arg1.
mik*_*ola 22
下面是一个使用lambda表达式来包装我们想要传递参数的函数的示例:
import tensorflow as tf
def fun(x, arg):
return x * arg
my_arg = tf.constant(2, dtype=tf.int64)
ds = tf.data.Dataset.range(5)
ds = ds.map(lambda x: fun(x, my_arg))
Run Code Online (Sandbox Code Playgroud)
在上面,提供的函数的签名map必须与我们的数据集的内容匹配.所以我们必须编写我们的lambda表达式来匹配它.这里很简单,因为数据集中只包含一个元素,x其中包含0到4范围内的元素.
如有必要,您可以从数据集外部传递任意数量的外部参数:ds = ds.map(lambda x: my_other_fun(x, arg1, arg2, arg3)等等.
为了验证上述工作,我们可以观察到映射确实将每个数据集元素乘以2:
iterator = ds.make_initializable_iterator()
next_x = iterator.get_next()
with tf.Session() as sess:
sess.run(iterator.initializer)
while True:
try:
print(sess.run(next_x))
except tf.errors.OutOfRangeError:
break
Run Code Online (Sandbox Code Playgroud)
输出:
0
2
4
6
8
Run Code Online (Sandbox Code Playgroud)
小智 5
您还可以使用Partial函数来包装参数:
def _parse_function(arg1, example_proto):
features = {"image": tf.FixedLenFeature((), tf.string, default_value=""),
"label": tf.FixedLenFeature((), tf.int32, default_value=0)}
parsed_features = tf.parse_single_example(example_proto, features)
return parsed_features["image"], parsed_features["label"]
Run Code Online (Sandbox Code Playgroud)
更改函数的参数顺序以适应偏向性,然后您可以使用如下参数值包装函数:
from functools import partial
arg1 = ...
dataset = dataset.map(partial(_parse_function, arg1))
Run Code Online (Sandbox Code Playgroud)
| 归档时间: |
|
| 查看次数: |
4794 次 |
| 最近记录: |