如何将 y_true 作为字典传递给自定义损失函数而不改变?

Pav*_*nov 5 python customization keras tensorflow loss-function

我需要使用tf.keras.*. 但:

  • 空白类并不像 tf 期望的那样为零,但是(num_classes - 1)相反。
  • 输入图像的宽度事先未知(不同批次的宽度不同)。

我想利用tf.nn.ctc_loss其中有一些很好的论点:blank_index。所以我做了一个简单的包装器来计算 CTC 损失:

    class CTCLossWrapper(tf.keras.losses.Loss):
        def __init__(self, blank_class: int, reduction: str = tf.keras.losses.Reduction.AUTO, name: str = 'ctc_loss'):
            super().__init__(reduction=reduction, name=name)
            self.blank_class = blank_class
        
        def call(self, y_true, y_pred):
            output = y_true['output']
            targets, target_lenghts = output['targets'], output['target_lengths']
            y_pred = tf.math.log(tf.transpose(y_pred, perm=[1, 0, 2]) + K.epsilon())
            max_input_len = K.cast(K.shape(y_pred)[1], dtype='int32')
            input_lengths = tf.ones((K.shape(y_pred)[0]), dtype='int32') * max_input_len
            return tf.nn.ctc_loss(
                labels=targets,
                logits=y_pred,
                label_length=target_lenghts,
                logit_length=input_lengths,
                blank_index=self.blank_class
            )
Run Code Online (Sandbox Code Playgroud)

我还编写了一个简单的生成器函数,它生成训练样本:

def generator(dataset, batch_size: int, shuffle=False):
    indexes = np.arange(len(dataset))
    while True:
        if shuffle:
            indexes = np.random.permutation(indexes)
        for i in range(0, len(dataset), batch_size):
            # Get next batch
            batch = dataset[indexes[i:i+batch_size]]
            images, image_widths = batch['images'], batch['image_widths']
            targets, target_lengths = batch['targets'], batch['target_lengths']
            # Re-arrange dimensions (B, H, W, C) -> (B, W, H, C)
            # Important Note: width=W and height=H are swapped from typical Keras convention
            # because width is the time dimension when it gets fed into the RNN
            images = np.transpose(images, axes=(0, 2, 1, 3)).astype(np.float32) / 255.0
            # Change zero target length to 1 due to invalid implementation of ctc_batch_cost in keras
            target_lengths[target_lengths == 0] = 1
            # Add singleton dimension
            # image_widths = image_widths[:, np.newaxis]
            # target_lengths = target_lengths[:, np.newaxis]
            # Construct output value
            outputs = {
                'images': images,  # (batch_size, max_image_width, 32, 1)
                'image_widths': image_widths,  # (batch_size,)
                'targets': targets,  # (batch_size, max_target_len)
                'target_lengths': target_lengths,  # (batch_size,)
            }
            yield images, dict(output=outputs)
Run Code Online (Sandbox Code Playgroud)

正如您所看到的,生成器不仅输出(x, y_true)4 个值:

  • 输入图像,
  • 输入图像宽度,
  • 目标序列,
  • 每个目标序列的长度。

之所以如此是因为tf.nn.ctc_loss还需要至少 4 个参数才能工作。

我的计划是将输入图像传递为x并将所有 4 个值的字典传递为y_true传递。

然后当然我使用我的CTCLossWrapperand编译模型blank_class

    model.compile(
        optimizer=Adam(),
        loss=CTCLossWrapper(blank_class=blank_class),
    )

Run Code Online (Sandbox Code Playgroud)

之后我可以通过以下方式开始训练:

    model.fit(
        x=generator(train_dataset, batch_size=batch_size, shuffle=True),
        steps_per_epoch=int(len(train_dataset) // batch_size),
        epochs=200
    )

Run Code Online (Sandbox Code Playgroud)

问题是,当CTCLossWrapper调用 my 时,它不会得到 dict() 作为y_true。它仅从中获取一个张量。

如何避免或关闭张量流预处理并以y_true与数据集提供的形式相同的形式获取值?