Posts by tsv*_*iko

How to filter a tf.data.Dataset by a specific value?

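I create a dataset by reading TFRecords and then map the values. I want to filter the dataset for a specific value, but because each element is a dict of tensors, I can't get the actual value of a tensor or check it with tf.cond() / tf.equal. How can I do this?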

import tensorflow as tf  # TensorFlow 1.x API (tf.contrib.data, make_one_shot_iterator)

def mapping_func(serialized_example):
    # Parse each serialized tf.train.Example; 'label' is a string tensor of shape [1]
    feature = {'label': tf.FixedLenFeature([1], tf.string)}
    features = tf.parse_single_example(serialized_example, features=feature)
    return features

def filter_func(features):
    # this doesn't work: in TF 1.x, `==` on a Tensor compares Python object
    # identity, not tensor values, so it just returns the Python bool False
    # result = features['label'] == 'some_label_value'
    # neither this
    result = tf.reshape(tf.equal(features['label'], 'some_label_value'), [])
    return result

def main():
    file_names = ["/var/data/file1.tfrecord", "/var/data/file2.tfrecord"]
    dataset = tf.contrib.data.TFRecordDataset(file_names)
    dataset = dataset.map(mapping_func)
    dataset = dataset.shuffle(buffer_size=10000)
    dataset = dataset.filter(filter_func)  # predicate must return a scalar tf.bool
    dataset = dataset.repeat()
    iterator = dataset.make_one_shot_iterator()
    sample = iterator.get_next()

if __name__ == "__main__":
    main()
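For comparison, here is a minimal, self-contained sketch of the same pipeline, assuming TensorFlow 2.x with eager execution (the question uses the older TF 1.x `tf.contrib.data` API). It writes a tiny TFRecord file with made-up labels, parses it with a `[1]`-shaped string feature as in the question, and filters with a predicate that returns a scalar `tf.bool`: `tf.equal` compares the values inside the graph, so no Python-side value of the tensor is needed. The helper names `write_example_tfrecord`, `make_filter_func` and `filtered_labels` are hypothetical, introduced only for this illustration.

```python
import os
import tempfile

import tensorflow as tf  # assumes TensorFlow 2.x


def write_example_tfrecord(path, labels):
    """Hypothetical helper: write one tf.train.Example per label (bytes feature 'label')."""
    with tf.io.TFRecordWriter(path) as writer:
        for label in labels:
            example = tf.train.Example(
                features=tf.train.Features(
                    feature={
                        "label": tf.train.Feature(
                            bytes_list=tf.train.BytesList(value=[label.encode("utf-8")])
                        )
                    }
                )
            )
            writer.write(example.SerializeToString())


def mapping_func(serialized_example):
    # Shape [1] mirrors the question's FixedLenFeature([1], tf.string).
    feature = {"label": tf.io.FixedLenFeature([1], tf.string)}
    return tf.io.parse_single_example(serialized_example, features=feature)


def make_filter_func(target):
    """Hypothetical helper: build a predicate that keeps elements whose label equals target."""
    def filter_func(features):
        # features["label"] has shape [1]; take element 0 to get a scalar string
        # tensor, then compare values with tf.equal -> scalar tf.bool.
        return tf.equal(features["label"][0], target)
    return filter_func


def filtered_labels(file_names, target):
    """Hypothetical helper: return the decoded labels that survive the filter."""
    dataset = tf.data.TFRecordDataset(file_names)
    dataset = dataset.map(mapping_func)
    dataset = dataset.filter(make_filter_func(target))
    # In eager mode the dataset can be iterated directly in Python.
    return [features["label"][0].numpy().decode("utf-8") for features in dataset]


if __name__ == "__main__":
    path = os.path.join(tempfile.mkdtemp(), "sample.tfrecord")
    write_example_tfrecord(path, ["cat", "dog", "cat", "bird"])
    print(filtered_labels([path], "cat"))  # ['cat', 'cat']
```

Indexing with `[0]` and the question's `tf.reshape(..., [])` both turn the `[1]`-shaped comparison into the scalar boolean that `Dataset.filter` expects; indexing is used here only because it reads a little more directly.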

python tensorflow tensorflow-datasets
