按 id 过滤 Tensorflow 数据集

Nov*_*rog 2 python tensorflow

问题

我正在尝试基于包含我希望子集化的索引的 numpy 数组来过滤 Tensorflow 2.4 数据集。该数据集有 1000 张形状为 (28,28,1) 的图像。

玩具示例代码

m_X_ds = tf.data.Dataset.from_tensor_slices(list(range(1, 21))).shuffle(10, reshuffle_each_iteration=False)
arr = np.array([3, 4, 5])
m_X_ds = tf.gather(m_X_ds, arr)  # This is the offending code
Run Code Online (Sandbox Code Playgroud)

错误信息

ValueError: Attempt to convert a value (<ShuffleDataset shapes: (), types: tf.int32>) with an unsupported type (<class 'tensorflow.python.data.ops.dataset_ops.ShuffleDataset'>) to a Tensor.
Run Code Online (Sandbox Code Playgroud)

迄今为止的研究

我发现了这个这个,但它们是特定于它们的用例的,而我正在寻找一种更通用的子集方法(即基于外部派生的索引数组)。

我对 Tensorflow 数据集非常陌生,迄今为止发现学习曲线相当陡峭。希望能得到一些帮助。提前致谢!

seb*_*-sz 8

长话短说

建议使用选项 C,定义如下。

完整答案

创建该tf.data.Dataset对象是为了不必将所有对象加载到内存中。因此,tf.gather默认情况下使用不会起作用。您可以选择以下三个选项:

选项A:将ds加载到内存+tf.gather

如果您想使用收集,则必须将整个数据集加载到内存中,并选择一个子集:

m_X_ds = list(m_X_ds)  # Load into memory.
m_X_ds = tf.gather(m_X_ds, arr))  # Gather as usual.
print(m_X_ds)  
# Example result: <tf.Tensor: shape=(3,), dtype=int32, numpy=array([8, 6, 2], dtype=int32)>
Run Code Online (Sandbox Code Playgroud)

请注意,这并不总是可能的,尤其是对于庞大的数据集。

选项 B:迭代数据集,并过滤不需要的样本

您还可以迭代数据集并手动选择具有所需索引的样本。这可以通过过滤器tf.py_function的组合来实现

m_X_ds = m_X_ds.enumerate()  # Create index,value pairs in the dataset.

# Create filter function:
def filter_fn(idx, value):
    return idx in arr

# The above is not going to work in graph mode
# We are wrapping it with py_function to execute it eagerly
def py_function_filter(idx, value):
    return tf.py_function(filter_fn, (idx, value), tf.bool)

# Filter the dataset as usual:
filtered_ds = m_X_ds.filter(py_function_filter)
Run Code Online (Sandbox Code Playgroud)

选项 C:将选项 B 与 tf.lookup.StaticHashTable 结合起来

选项 B 很好,除了在转换图张量 -> 急切张量 -> 图张量时你可以预期性能会受到影响。tf.py_function很有用,但要付出代价。

相反,我们可以声明一个字典,其中所需的索引将返回 true,不存在的索引可能返回 false。这个字典可能看起来像这样

my_table = {3: True, 4: True, 5: True}.
Run Code Online (Sandbox Code Playgroud)

我们不能使用张量作为字典键,但我们可以声明张量流的哈希表来让我们检查“好”索引。

m_X_ds = m_X_ds.enumerate()  # Do not repeat this if executed in Option B.

keys_tensor = tf.constant(arr)
vals_tensor = tf.ones_like(keys_tensor)  # Ones will be casted to True.

table = tf.lookup.StaticHashTable(
    tf.lookup.KeyValueTensorInitializer(keys_tensor, vals_tensor),
    default_value=0)  # If index not in table, return 0.


def hash_table_filter(index, value):
    table_value = table.lookup(index)  # 1 if index in arr, else 0.
    index_in_arr =  tf.cast(table_value, tf.bool) # 1 -> True, 0 -> False
    return index_in_arr

filtered_ds = m_X_ds.filter(hash_table_filter)
Run Code Online (Sandbox Code Playgroud)

无论选项 B 或 C,剩下的就是从 fileterd 数据集中删除索引。我们可以使用带有 lambda 函数的简单映射:

final_ds = filtered_ds.map(lambda idx,value: value)
for entry in final_ds:
    print(entry)

# tf.Tensor(7, shape=(), dtype=int32)
# tf.Tensor(13, shape=(), dtype=int32)
# tf.Tensor(6, shape=(), dtype=int32)
Run Code Online (Sandbox Code Playgroud)

祝你好运。