Pav*_*nov 5 python customization keras tensorflow loss-function
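I need to use tf.keras.*, but its CTC loss (ctc_batch_cost) fixes the blank class at (num_classes - 1), and I need a different index instead. I want to use tf.nn.ctc_loss, which has a useful argument for this: blank_index. So I wrote a simple wrapper to compute the CTC loss: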
    import tensorflow as tf
    from tensorflow.keras import backend as K


    class CTCLossWrapper(tf.keras.losses.Loss):
        def __init__(self, blank_class: int, reduction: str = tf.keras.losses.Reduction.AUTO, name: str = 'ctc_loss'):
            super().__init__(reduction=reduction, name=name)
            self.blank_class = blank_class

        def call(self, y_true, y_pred):
            output = y_true['output']
            targets, target_lengths = output['targets'], output['target_lengths']
            # (batch, time, classes) -> (time, batch, classes): tf.nn.ctc_loss is time-major by default
            y_pred = tf.math.log(tf.transpose(y_pred, perm=[1, 0, 2]) + K.epsilon())
            # After the transpose, axis 0 is time and axis 1 is batch
            max_input_len = K.cast(K.shape(y_pred)[0], dtype='int32')
            input_lengths = tf.ones((K.shape(y_pred)[1],), dtype='int32') * max_input_len
            return tf.nn.ctc_loss(
                labels=targets,
                logits=y_pred,
                label_length=target_lengths,
                logit_length=input_lengths,
                blank_index=self.blank_class
            )
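To make the shapes that tf.nn.ctc_loss expects concrete, here is a minimal standalone sketch that calls it directly with blank_index. The sizes, the random logits, the labels, and blank_class = 0 are all made-up values for illustration and are not taken from my real model:

```python
import tensorflow as tf

# All sizes and values below are assumptions chosen only for illustration.
batch_size, time_steps, num_classes = 2, 8, 5
blank_class = 0

tf.random.set_seed(0)
# Time-major logits: (time, batch, classes)
logits = tf.random.normal((time_steps, batch_size, num_classes))
# Dense, padded labels; entries beyond label_length are ignored
labels = tf.constant([[1, 2, 3], [4, 1, 0]], dtype=tf.int32)
label_length = tf.constant([3, 2], dtype=tf.int32)
# Every sample uses the full time axis
logit_length = tf.fill([batch_size], time_steps)

loss = tf.nn.ctc_loss(
    labels=labels,
    logits=logits,
    label_length=label_length,
    logit_length=logit_length,
    logits_time_major=True,
    blank_index=blank_class,
)
print(loss.shape)  # one loss value per sample in the batch
```

So besides the predictions, the loss needs the labels and both length vectors, which is why a single y_true tensor is not enough in my setup.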
I also wrote a simple generator function that produces training samples:
    import numpy as np


    def generator(dataset, batch_size: int, shuffle=False):
        indexes = np.arange(len(dataset))
        while True:
            if shuffle:
                indexes = np.random.permutation(indexes)
            for i in range(0, len(dataset), batch_size):
                # Get next batch
                batch = dataset[indexes[i:i+batch_size]]
                images, image_widths = batch['images'], batch['image_widths']
                targets, target_lengths = batch['targets'], batch['target_lengths']
                # Re-arrange dimensions (B, H, W, C) -> (B, W, H, C)
                # Important Note: width=W and height=H are swapped from typical Keras convention
                # because width is the time dimension when it gets fed into the RNN
                images = np.transpose(images, axes=(0, 2, 1, 3)).astype(np.float32) / 255.0
                # Change zero target length to 1 due to invalid implementation of ctc_batch_cost in keras
                target_lengths[target_lengths == 0] = 1
                # Add singleton dimension
                # image_widths = image_widths[:, np.newaxis]
                # target_lengths = target_lengths[:, np.newaxis]
                # Construct output value
                outputs = {
                    'images': images,  # (batch_size, max_image_width, 32, 1)
                    'image_widths': image_widths,  # (batch_size,)
                    'targets': targets,  # (batch_size, max_target_len)
                    'target_lengths': target_lengths,  # (batch_size,)
                }
                yield images, dict(output=outputs)
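As you can see, the generator yields more than a plain (x, y_true) pair: y_true carries 4 values (images, image_widths, targets, and target_lengths).
This is because tf.nn.ctc_loss needs at least 4 arguments to work.
My plan is to pass the input images as x and a dict of all 4 values as y_true.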
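To show what one yielded item looks like, here is a small NumPy-only sketch that mirrors a single iteration of the generator body. ToyDataset and make_batch are hypothetical names invented for this illustration, and the random data and sizes are assumptions; my real dataset is different, but it is indexed the same way:

```python
import numpy as np


class ToyDataset:
    """Hypothetical stand-in for the real dataset: supports len() and index arrays."""

    def __init__(self, n=6, height=32, width=40, max_target_len=5):
        rng = np.random.default_rng(0)
        self.images = rng.integers(0, 256, size=(n, height, width, 1), dtype=np.uint8)
        self.image_widths = np.full(n, width, dtype=np.int32)
        self.targets = rng.integers(1, 10, size=(n, max_target_len)).astype(np.int32)
        self.target_lengths = rng.integers(0, max_target_len + 1, size=n).astype(np.int32)

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        return {
            'images': self.images[idx],
            'image_widths': self.image_widths[idx],
            'targets': self.targets[idx],
            'target_lengths': self.target_lengths[idx],
        }


def make_batch(dataset, idx):
    """Mirror one iteration of the generator body for the given indices."""
    batch = dataset[idx]
    # (B, H, W, C) -> (B, W, H, C), scaled to [0, 1]
    images = np.transpose(batch['images'], axes=(0, 2, 1, 3)).astype(np.float32) / 255.0
    target_lengths = batch['target_lengths'].copy()
    target_lengths[target_lengths == 0] = 1
    outputs = {
        'images': images,
        'image_widths': batch['image_widths'],
        'targets': batch['targets'],
        'target_lengths': target_lengths,
    }
    return images, dict(output=outputs)


x, y_true = make_batch(ToyDataset(), np.arange(4))
print(x.shape)
print({name: value.shape for name, value in y_true['output'].items()})
```

The point is that y_true is a nested dict of arrays with different shapes, not a single array.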
Then, of course, I compile the model with my CTCLossWrapper and blank_class:
    from tensorflow.keras.optimizers import Adam

    model.compile(
        optimizer=Adam(),
        loss=CTCLossWrapper(blank_class=blank_class),
    )
After that I can start training with:
    model.fit(
        x=generator(train_dataset, batch_size=batch_size, shuffle=True),
        steps_per_epoch=int(len(train_dataset) // batch_size),
        epochs=200
    )
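The problem is that when my CTCLossWrapper is called, it does not receive a dict() as y_true. It gets only a single tensor from it.
How can I avoid or turn off TensorFlow's preprocessing and get y_true in the same form in which the dataset provides it?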