问题
我正在尝试基于包含我希望子集化的索引的 numpy 数组来过滤 Tensorflow 2.4 数据集。该数据集有 1000 张形状为 (28,28,1) 的图像。
玩具示例代码
m_X_ds = tf.data.Dataset.from_tensor_slices(list(range(1, 21))).shuffle(10, reshuffle_each_iteration=False)
arr = np.array([3, 4, 5])
m_X_ds = tf.gather(m_X_ds, arr) # This is the offending code
Run Code Online (Sandbox Code Playgroud)
错误信息
ValueError: Attempt to convert a value (<ShuffleDataset shapes: (), types: tf.int32>) with an unsupported type (<class 'tensorflow.python.data.ops.dataset_ops.ShuffleDataset'>) to a Tensor.
Run Code Online (Sandbox Code Playgroud)
迄今为止的研究
我发现了这个和这个,但它们是特定于它们的用例的,而我正在寻找一种更通用的子集方法(即基于外部派生的索引数组)。
我对 Tensorflow 数据集非常陌生,迄今为止发现学习曲线相当陡峭。希望能得到一些帮助。提前致谢!
建议使用选项 C,定义如下。
创建该tf.data.Dataset对象是为了不必将所有对象加载到内存中。因此,tf.gather默认情况下使用不会起作用。您可以选择以下三个选项:
如果您想使用收集,则必须将整个数据集加载到内存中,并选择一个子集:
m_X_ds = list(m_X_ds) # Load into memory.
m_X_ds = tf.gather(m_X_ds, arr)) # Gather as usual.
print(m_X_ds)
# Example result: <tf.Tensor: shape=(3,), dtype=int32, numpy=array([8, 6, 2], dtype=int32)>
Run Code Online (Sandbox Code Playgroud)
请注意,这并不总是可能的,尤其是对于庞大的数据集。
您还可以迭代数据集并手动选择具有所需索引的样本。这可以通过过滤器和tf.py_function的组合来实现
m_X_ds = m_X_ds.enumerate() # Create index,value pairs in the dataset.
# Create filter function:
def filter_fn(idx, value):
return idx in arr
# The above is not going to work in graph mode
# We are wrapping it with py_function to execute it eagerly
def py_function_filter(idx, value):
return tf.py_function(filter_fn, (idx, value), tf.bool)
# Filter the dataset as usual:
filtered_ds = m_X_ds.filter(py_function_filter)
Run Code Online (Sandbox Code Playgroud)
选项 B 很好,除了在转换图张量 -> 急切张量 -> 图张量时你可以预期性能会受到影响。tf.py_function很有用,但要付出代价。
相反,我们可以声明一个字典,其中所需的索引将返回 true,不存在的索引可能返回 false。这个字典可能看起来像这样
my_table = {3: True, 4: True, 5: True}.
Run Code Online (Sandbox Code Playgroud)
我们不能使用张量作为字典键,但我们可以声明张量流的哈希表来让我们检查“好”索引。
m_X_ds = m_X_ds.enumerate() # Do not repeat this if executed in Option B.
keys_tensor = tf.constant(arr)
vals_tensor = tf.ones_like(keys_tensor) # Ones will be casted to True.
table = tf.lookup.StaticHashTable(
tf.lookup.KeyValueTensorInitializer(keys_tensor, vals_tensor),
default_value=0) # If index not in table, return 0.
def hash_table_filter(index, value):
table_value = table.lookup(index) # 1 if index in arr, else 0.
index_in_arr = tf.cast(table_value, tf.bool) # 1 -> True, 0 -> False
return index_in_arr
filtered_ds = m_X_ds.filter(hash_table_filter)
Run Code Online (Sandbox Code Playgroud)
无论选项 B 或 C,剩下的就是从 fileterd 数据集中删除索引。我们可以使用带有 lambda 函数的简单映射:
final_ds = filtered_ds.map(lambda idx,value: value)
for entry in final_ds:
print(entry)
# tf.Tensor(7, shape=(), dtype=int32)
# tf.Tensor(13, shape=(), dtype=int32)
# tf.Tensor(6, shape=(), dtype=int32)
Run Code Online (Sandbox Code Playgroud)
祝你好运。
| 归档时间: |
|
| 查看次数: |
2369 次 |
| 最近记录: |