我正在尝试保存一个 TensorFlow 模型,其中包括一些标签的后期处理。
给定一些分类标签,我有兴趣训练一个模型(例如, a tf.keras.Sequential),其中我以前对标签应用了 One-hot 编码。这是模型的样子:
model = tf.keras.Sequential([
tf.keras.layers.DenseFeatures(transform_features),
tf.keras.layers.Dense(64, activation='relu'),
tf.keras.layers.Dense(64, activation='relu'),
tf.keras.layers.Dense(1)
])
model.compile(loss='categorical_crossentropy',optimizer='adam')
history = model.fit(train_data, epochs=5)
Run Code Online (Sandbox Code Playgroud)
其中transform_features是 的列表tf.feature_columns,train_data是tf.data.Dataset包含训练数据的(train_X,train_y)。
一旦训练了模型,我想应用一些后期处理。我想将这个后处理添加到一个新的(或相同的)TensorFlow 模型中,这样当我要求对这个模型进行预测时(例如在 BigQuery 中使用导入的 TensorFlow 模型进行预测),它会给我解码的最终标签。
我正在考虑制作第一个模型,如之前所示,训练后,向模型添加以下层:
from tf.keras.layers import Lambda
model.add(Lambda(lambda x: tf.argmax(x, axis=-1)))
Run Code Online (Sandbox Code Playgroud)
但我不知道如何“合并”这两个不同的模型并将它们保存为相同的 TensorFlow SavedModel 格式(使用tf.saved_model.save(model, MODEL_PATH))。有没有什么方法可以让一个云在 Tensorflow 中进行这种后处理?
谢谢
Tensorflow 提供了一种构建自定义层的方法,该层运行称为Lambda 层的自定义函数。要查看图层示例argmax,请参阅此答案。
然而,另一种方法是使用 更高级并提供更多灵活性的子类层。 keras.layers.Layer
子类层的示例:
乘以比例因子
class ScaleLayer(tf.keras.layers.Layer):
def __init__(self):
super(ScaleLayer, self).__init__()
self.scale = tf.Variable(1.)
def call(self, inputs):
return inputs * self.scale
Run Code Online (Sandbox Code Playgroud)
检索最高值的索引。
class argmax_layer(Layer):
def __init__(self):
super(argmax_layer, self).__init__()
def call(self, inputs):
return tf.math.argmax(inputs, axis=1)
Run Code Online (Sandbox Code Playgroud)
这是用于分类任务的 CNN 架构的完整代码,我在其中添加了argmaxLayer.
import tensorflow as tf
from keras.models import Sequential
from tensorflow.keras.layers import Layer, Conv2D, BatchNormalization, MaxPooling2D,Flatten,Dropout,Dense
from tensorflow.keras import optimizers
import numpy as np
class argmax_layer(Layer):
def __init__(self):
super(argmax_layer, self).__init__()
def call(self, inputs):
return tf.math.argmax(inputs, axis=1)
def cnn_model(image_x=100,image_y=100,num_classes=10):
model = Sequential()
model.add(Conv2D(32, (5,5), input_shape=(image_x, image_y, 1), activation='relu'))
model.add(BatchNormalization())
model.add(MaxPooling2D(pool_size=(10, 10), strides=(10, 10), padding='same'))
model.add(Flatten())
model.add(Dense(1024, activation='relu'))
model.add(BatchNormalization())
model.add(Dropout(0.6))
model.add(Dense(num_classes, activation='softmax'))
sgd = optimizers.SGD(lr=1e-2)
model.compile(loss='categorical_crossentropy', optimizer=sgd, metrics=['accuracy'])
return model
########################
# Model summary
model = cnn_model()
model.add(argmax_layer())
model.summary()
#################### TEST
input = np.random.random((5,100,100,1)) # 5 samples
print("Output:", model.predict(input))
Run Code Online (Sandbox Code Playgroud)
型号总结:
Model: "sequential"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
conv2d (Conv2D) (None, 96, 96, 32) 832
_________________________________________________________________
batch_normalization (BatchNo (None, 96, 96, 32) 128
_________________________________________________________________
max_pooling2d (MaxPooling2D) (None, 10, 10, 32) 0
_________________________________________________________________
flatten (Flatten) (None, 3200) 0
_________________________________________________________________
dense (Dense) (None, 1024) 3277824
_________________________________________________________________
batch_normalization_1 (Batch (None, 1024) 4096
_________________________________________________________________
dropout (Dropout) (None, 1024) 0
_________________________________________________________________
dense_1 (Dense) (None, 10) 10250
_________________________________________________________________
argmax_layer (argmax_layer) (None,) 0
=================================================================
Total params: 3,293,130
Trainable params: 3,291,018
Non-trainable params: 2,112
________________________________
Run Code Online (Sandbox Code Playgroud)
| 归档时间: |
|
| 查看次数: |
755 次 |
| 最近记录: |