数据集API,迭代器和tf.contrib.data.rejection_resample

Sha*_*rny 4 iterator tensorflow

[在@mrry评论之后编辑#1] 我正在使用(伟大和惊人的)数据集API以及tf.contrib.data.rejection_resample来为输入训练管道设置特定的分布函数.

在将tf.contrib.data.rejection_resample添加到input_fn之前,我使用了一次性Iterator.唉,当开始使用后者时,我尝试使用dataset.make_initializable_iterator() - 这是因为我们引入了管道状态变量,并且在输入管道中的所有变量都是init之后需要初始化迭代器.正如@mrry在这里写的那样.

我将input_fn传递给估算器并由实验包装.

问题是 - 在哪里挂钩迭代器的init?如果我尝试:

dataset = dataset.batch(batch_size)
if self.balance:
   dataset = tf.contrib.data.rejection_resample(dataset, self.class_mapping_function, self.dist_target)
   iterator = dataset.make_initializable_iterator()
   tf.add_to_collection(tf.GraphKeys.TABLE_INITIALIZERS, iterator.initializer)
else:
   iterator = dataset.make_one_shot_iterator() 

image_batch, label_batch = iterator.get_next()
print (image_batch) 
Run Code Online (Sandbox Code Playgroud)

和映射功能:

def class_mapping_function(self, feature, label):
    """
    returns a a function to be used with dataset.map() to return class numeric ID
    The function is mapping a nested structure of tensors (having shapes and types defined by dataset.output_shapes
    and dataset.output_types) to a scalar tf.int32 tensor. Values should be in [0, num_classes).
    """
    # For simplicity, trying to return the label itself as I assume its numeric...

    return tf.cast(label, tf.int32)  # <-- I guess this is the bug
Run Code Online (Sandbox Code Playgroud)

迭代器不会像单击迭代器那样接收Tensor形状.

例如.使用One Shot迭代器运行,迭代器获得正确的形状:

Tensor("train_input_fn/IteratorGetNext:0", shape=(?, 100, 100, 3), dtype=float32, device=/device:CPU:0)
Run Code Online (Sandbox Code Playgroud)

但是当使用可初始化的迭代器时,它缺少张量形状信息:

Tensor("train_input_fn/IteratorGetNext:0", shape=(?,), dtype=int32, device=/device:CPU:0)
Run Code Online (Sandbox Code Playgroud)

任何帮助将非常感激!

[ 编辑#2 ] - 跟随@mrry注释它看起来像是另一个数据集]也许真正的问题不是迭代器的init序列,而是tf.contrib.data.rejection_resample使用的返回tf.int32的映射函数.但后来我想知道应该如何定义映射函数?例如,将数据集形状保持为(?,100,100,3)...

[ 编辑#3 ]:从rejection_resample的实现

class_values_ds = dataset.map(class_func)
Run Code Online (Sandbox Code Playgroud)

因此,class_func将采用数据集并返回tf.int32的数据集.

Sha*_*rny 12

在@mrry响应之后,我可以想出如何使用数据集API和tf.contrib.data.rejection_resample(使用TF1.3)的解决方案.

目标

给定具有某种分布的特征/标签数据集,让输入管道将分布重塑为特定目标分布.

数值例子

让我们假设我们正在构建一个网络,将某些功能分类为10个类之一.并假设我们只有100个功能,随机分布标签.
30个标记为1类的功能,5个标记为2类的功能,依此类推.在培训期间,我们不希望优先选择1级而不是2级,因此我们希望每个小型批次都能为所有课程保持统一分布.

解决方案

使用tf.contrib.data.rejection_resample将允许为输入管道设置特定分布.

在文档中,它说tf.contrib.data.rejection_resample将采用

(1)数据集 - 您要平衡的数据集

(2)class_func - 这是一个仅从原始数据集生成新数字标签数据集的函数

(3)target_dist - 一个向量大小的类的数量,以特定所需的新分布.

(4)一些更多可选值 - 暂时跳过

并且正如文档所说,它返回一个`数据集.

事实证明,输入数据集的形状与输出数据集形状不同.因此,返回的数据集(在TF1.3中实现)应由用户过滤,如下所示:

    balanced_dataset = tf.contrib.data.rejection_resample(input_dataset,
                                                          self.class_mapping_function,
                                                          self.target_distribution)

    # Return to the same Dataset shape as was the original input
    balanced_dataset = balanced_dataset.map(lambda _, data: (data))
Run Code Online (Sandbox Code Playgroud)

关于迭代器类型的一个注释.正如@mrry 在这里解释的那样,当在管道中使用有状态对象时,应该使用可初始化的迭代器而不是单一的迭代器.请注意,在使用可初始化的迭代器时,您应该将init_op添加到TABLE_INITIALIZERS,否则您将收到此错误:"GetNext()失败,因为迭代器尚未初始化."

代码示例:

# Creating the iterator, that allows to access elements from the dataset
if self.use_balancing:
    # For balancing function, we use stateful variables in the sense that they hold current dataset distribution
    # and calculate next distribution according to incoming examples.
    # For dataset pipeline that have state, one_shot iterator will not work, and we are forced to use
    # initializable iterator
    # This should be relaxed in the future.
    # https://stackoverflow.com/questions/44374083/tensorflow-cannot-capture-a-stateful-node-by-value-in-tf-contrib-data-api
    iterator = dataset.make_initializable_iterator()
    tf.add_to_collection(tf.GraphKeys.TABLE_INITIALIZERS, iterator.initializer)

else:
    iterator = dataset.make_one_shot_iterator()

image_batch, label_batch = iterator.get_next()
Run Code Online (Sandbox Code Playgroud)

它有用吗?

是.这是来自Tensorboard的2张图片,在输入管道标签上收集直方图.原始输入标签均匀分布.方案A:试图实现以下10级分布:0.1,0.4,0.05,0.05,0.05,0.05,0.05,0.05,0.1,0.1]

结果如下:

在此输入图像描述

方案B:试图实现以下10级分布:[0.1,0.1,0.05,0.05,0.05,0.05,0.05,0.05,0.4,0.1]

结果如下:

在此输入图像描述