Mis*_*hal 5 keras tensorflow tf.keras tensorflow2.0
我有一个张量字典数据集,以及使用子类化 API 定义的以下模型:
class Model(tf.keras.Model):
def __init__(self):
super().__init__()
self._movie_full_dense = tf.keras.layers.Dense(
units=40, activation=tf.keras.layers.Activation("relu"))
self._user_full_dense = tf.keras.layers.Dense(
units=40, activation=tf.keras.layers.Activation("relu"))
self._full_dense = tf.keras.layers.Dense(
units=1, activation=tf.keras.layers.Activation("sigmoid"))
def call(self, features):
movie_embedding = self._movie_full_dense(features['movie'])
user_embedding = self._user_full_dense(features['user'])
combined = tf.concat([movie_embedding, user_embedding], 1)
output = self._full_dense(combined)
return output
Run Code Online (Sandbox Code Playgroud)
我想使用函数式 API 来实现它。但我不知道如何定义输入?即,以下功能的等效项是什么?
self._movie_full_dense(features['movie'])
Run Code Online (Sandbox Code Playgroud)
import tensorflow as tf
print(tf.version.VERSION)
toy_data = {'movie': [[0], [1], [0], [1]], 'user': [[10], [12], [12], [10]]}
dataset = tf.data.Dataset.from_tensor_slices(toy_data).batch(2)
for x in dataset:
print(x)
def make_model():
inp_movie = tf.keras.Input(shape=(1,))
inp_user = tf.keras.Input(shape=(1,))
movie_embedding = tf.keras.layers.Dense(
units=40, activation=tf.keras.layers.Activation("relu"))(inp_movie)
user_embedding = tf.keras.layers.Dense(
units=40, activation=tf.keras.layers.Activation("relu"))(inp_user)
combined = tf.concat([movie_embedding, user_embedding], 1)
output = tf.keras.layers.Dense(
units=1, activation=tf.keras.layers.Activation("sigmoid"))(combined)
model = tf.keras.Model(inputs=[inp_movie, inp_user], outputs=output)
return model
model = make_model()
for x in dataset:
print(model(x))
Run Code Online (Sandbox Code Playgroud)
这有效。请注意,传递给inputs调用参数的迭代tf.keras.Model必须按照与您将使用的字典相同的顺序进行排序,字典按其键排序,movie然后user。所以使用inputs={'a': inp_movie, 'b': inp_user}orinputs={'movie': inp_movie, 'user': inp_user}也有效,而inputs=[inp_user, inp_movie]不会。
您可以使用此代码来测试这种交互:
def make_test_model():
inp_movie = tf.keras.Input(shape=(1,))
inp_user = tf.keras.Input(shape=(1,))
model = tf.keras.Model(inputs={'a': inp_movie, 'b': inp_user}, outputs=inp_movie)
return model
def make_test_model_2():
inp_movie = tf.keras.Input(shape=(1,))
inp_user = tf.keras.Input(shape=(1,))
model = tf.keras.Model(inputs=[inp_user, inp_movie], outputs=inp_movie)
return model
model_test = make_test_model()
model_test_2 = make_test_model_2()
for x in dataset:
print(model_test(x))
for x in dataset:
print(model_test_2(x))
Run Code Online (Sandbox Code Playgroud)
您还可以Input使用字典的键命名图层,并给出按图层名称排序的图层inputs列表作为参数。Input这使您可以在模型中添加或删除输入,而不必担心inputs每次都重写您的参数。所以这就是我要做的:
def make_model_2():
input_list = []
inp_movie = tf.keras.Input(shape=(1,), name='movie')
input_list.append(inp_movie)
inp_user = tf.keras.Input(shape=(1,), name='user')
input_list.append(inp_user)
movie_embedding = tf.keras.layers.Dense(
units=40, activation=tf.keras.layers.Activation("relu"))(inp_movie)
user_embedding = tf.keras.layers.Dense(
units=40, activation=tf.keras.layers.Activation("relu"))(inp_user)
combined = tf.concat([movie_embedding, user_embedding], 1)
output = tf.keras.layers.Dense(
units=1, activation=tf.keras.layers.Activation("sigmoid"))(combined)
input_list.sort(key=lambda inp: inp._keras_history.layer.name)
model = tf.keras.Model(inputs=input_list, outputs=output)
return model
Run Code Online (Sandbox Code Playgroud)
这是一个测试它是否有效的方法:
def make_test_model_3(boolean):
input_list = []
inp_movie = tf.keras.Input(shape=(1,), name='movie')
inp_user = tf.keras.Input(shape=(1,), name='user')
if boolean:
input_list.append(inp_movie)
input_list.append(inp_user)
else:
input_list.append(inp_user)
input_list.append(inp_movie)
input_list.sort(key=lambda inp: inp._keras_history.layer.name)
model = tf.keras.Model(inputs=input_list, outputs=inp_movie)
return model
model_test_3_0= make_test_model_3(True)
model_test_3_1= make_test_model_3(False)
for x in dataset:
print(model_test_3_0(x))
for x in dataset:
print(model_test_3_1(x))
Run Code Online (Sandbox Code Playgroud)
2020年2月20日编辑:
make_model不适用于 tf2.1.0,但make_model_2仍然可以。我在 GitHub 上提出了一个关于这种向后不兼容性的问题。如果您有兴趣,这是链接。回想一下,如果您打算继续使用 tf2.0.0,这两个函数都可以使用。