tsv*_*iko 9 python tensorflow tensorflow-datasets
我通过读取TFRecords来创建数据集,我映射了值,我想过滤数据集中的特定值,但由于结果是带有张量的dict,我无法获得张量的实际值或检查它用tf.cond()/ tf.equal.我怎样才能做到这一点?
def mapping_func(serialized_example):
feature = { 'label': tf.FixedLenFeature([1], tf.string) }
features = tf.parse_single_example(serialized_example, features=feature)
return features
def filter_func(features):
# this doesn't work
#result = features['label'] == 'some_label_value'
# neither this
result = tf.reshape(tf.equal(features['label'], 'some_label_value'), [])
return result
def main():
file_names = ["/var/data/file1.tfrecord", "/var/data/file2.tfrecord"]
dataset = tf.contrib.data.TFRecordDataset(file_names)
dataset = dataset.map(mapping_func)
dataset = dataset.shuffle(buffer_size=10000)
dataset = dataset.filter(filter_func)
dataset = dataset.repeat()
iterator = dataset.make_one_shot_iterator()
sample = iterator.get_next()
Run Code Online (Sandbox Code Playgroud)
我正在回答我自己的问题。我发现了问题!
我需要做的是这样tf.unstack()的标签:
label = tf.unstack(features['label'])
label = label[0]
Run Code Online (Sandbox Code Playgroud)
在我把它交给tf.equal():
result = tf.reshape(tf.equal(label, 'some_label_value'), [])
Run Code Online (Sandbox Code Playgroud)
我想问题是标签被定义为一个数组,其中包含一个 string 类型的元素tf.FixedLenFeature([1], tf.string),因此为了获得第一个和单个元素,我必须解压缩它(这会创建一个列表),然后获取索引为 0 的元素,如果我错了纠正我。
| 归档时间: |
|
| 查看次数: |
3664 次 |
| 最近记录: |