n0o*_*der 2 python-3.x tensorflow tensorflow-datasets tensorflow2.0 tf.data.dataset
我正在使用张量流 2.3.0
我有一个 python 数据生成器-
import tensorflow as tf
import numpy as np
vocab = [1,2,3,4,5]
def create_generator():
'generates a random number from 0 to len(vocab)-1'
count = 0
while count < 4:
x = np.random.randint(0, len(vocab))
yield x
count +=1
Run Code Online (Sandbox Code Playgroud)
我把它变成了一个 tf.data.Dataset 对象
gen = tf.data.Dataset.from_generator(create_generator,
args=[],
output_types=tf.int32,
output_shapes = (), )
Run Code Online (Sandbox Code Playgroud)
现在我想使用map方法对项目进行子采样,这样 tf 生成器永远不会输出任何偶数。
def subsample(x):
'remove item if it is present in an even number [2,4]'
'''
#TODO
'''
return x
gen = gen.map(subsample)
Run Code Online (Sandbox Code Playgroud)
如何使用map方法实现这一点?
很快,您无法使用map
. Map 函数对数据集的每个元素应用一些转换。您想要的是检查某个谓词的每个元素,并仅获取满足谓词的那些元素。
而那个函数是filter()
.
所以你可以这样做:
gen = gen.filter(lambda x: x % 2 != 0)
Run Code Online (Sandbox Code Playgroud)
更新:
如果您想使用自定义函数而不是lambda
,您可以执行以下操作:
def filter_func(x):
if x**2 < 500:
return True
return False
gen = gen.filter(filter_func)
Run Code Online (Sandbox Code Playgroud)
如果将此函数传递给filter
所有平方小于 500 的数字,则将返回。
归档时间: |
|
查看次数: |
534 次 |
最近记录: |