K.ctc_batch_cost() 中 input_length 的意思是什么

Lin*_*Ina 5 python keras tensorflow

我已经下载了一个使用 Keras 的 ocr 代码,它应用了 CRNN 网络并使用 CTC 损失作为损失函数。但是,我对 CTC 损失真的很K.ctc_batch_cost()陌生,只是在使用 时遇到了麻烦,尤其是 input_length 的含义。在 keras 的文档中,

tf.keras.backend.ctc_batch_cost( y_true, y_pred, input_length, label_length ) 的参数

  1. y_true:包含真实标签的张量(样本,max_string_length)。
  2. y_pred: 张量 (samples, time_steps, num_categories) 包含 softmax 的预测或输出。
  3. input_length:张量 (samples, 1) 包含 y_pred 中每个批次项目的序列长度。
  4. label_length:张量 (samples, 1) 包含 y_true 中每个批次项目的序列长度。

    但是,我的问题是 input_length 的含义是什么?那是 LSTM 输出的维度吗?

小智 0

一个示例的 CTC 损失是在 2D 阵列 (T,C) 上计算的。C 必须等于字符数 + 1(空白字符)。C 包含某个时间戳处字符的概率分布。T 将是时间戳的数量。

T 的长度应为 2* max_string_length。长度为 T 的 y_true 的所有可能编码都将用于负对数损失计算。

它通常是前一层输出的形状。