我tf.keras.Model在将自定义循环中经过训练的子类模型 ( )转换为 TFLite 时遇到问题。
假设我们有一个小的 CNN 架构,它使用输入数据 ( x) 和取决于批量大小和其他维度 ( add_info) 的附加信息:
class ContextExtractor(tf.keras.Model):
def __init__(self):
super().__init__()
self.model = self.__get_model()
def call(self, x, training=False, **kwargs):
b, h, w, c = x.shape
add_info = tf.zeros((b, h, w, c), dtype=tf.float32)
features = self.model(tf.concat([x, add_info], axis=-1), training=training)
return features
def __get_model(self):
return self.__get_small_cnn()
def __get_small_cnn(self):
model = tf.keras.Sequential()
model.add(layers.Conv2D(32, (3, 3), strides=(2, 2), padding='same'))
model.add(layers.LeakyReLU(alpha=0.2))
model.add(layers.Conv2D(32, (3, 3), strides=(2, 2), padding='same'))
model.add(layers.LeakyReLU(alpha=0.2))
model.add(layers.Conv2D(64, (3, 3), strides=(2, 2), …Run Code Online (Sandbox Code Playgroud)