cod*_*tle 5 python keras tensorflow
一些注意事项:我使用的是tensorflow 2.3.0、python 3.8.2和numpy 1.18.5(但不确定这是否重要)
我正在编写一个自定义层,它在内部存储形状为 (a, b) 的不可训练张量 N,其中 a, b 是已知值(该张量是在 init 期间创建的)。当调用输入张量时,它会展平输入张量,展平其存储的张量,并将两者连接在一起。不幸的是,我似乎无法弄清楚如何在连接过程中保留未知的批量维度。这是最少的代码:
import tensorflow as tf
from tensorflow.keras.layers import Layer, Flatten
class CustomLayer(Layer):
def __init__(self, N): # N is a tensor of shape (a, b), where a, b > 1
super(CustomLayer, self).__init__()
self.N = self.add_weight(name="N", shape=N.shape, trainable=False, initializer=lambda *args, **kwargs: N)
# correct me if I'm wrong in using this initializer approach, but for some reason, when I
# just do self.N = N, this variable would disappear when I saved and loaded the model
def build(self, input_shape):
pass # my reasoning is that all the necessary stuff is handled in init
def call(self, input_tensor):
input_flattened = Flatten()(input_tensor)
N_flattened = Flatten()(self.N)
return tf.concat((input_flattened, N_flattened), axis=-1)
Run Code Online (Sandbox Code Playgroud)
我注意到的第一个问题是,Flatten()(self.N)将返回与原始形状 (a, b) 相同的张量self.N,因此,返回值的形状为 (a, num_input_tensor_values+b)。我的理由是,第一个维度 a 被视为批量大小。我修改了call函数:
def call(self, input_tensor):
input_flattened = Flatten()(input_tensor)
N = tf.expand_dims(self.N, axis=0) # N would now be shape (1, a, b)
N_flattened = Flatten()(N)
return tf.concat((input_flattened, N_flattened), axis=-1)
Run Code Online (Sandbox Code Playgroud)
这将返回一个形状为 (1, num_input_vals + a*b) 的张量,这很好,但现在批量维度永久为 1,当我开始使用这一层训练模型时我意识到它只适用于批量大小为 1。这在模型摘要中也非常明显 - 如果我将此层放在输入之后,然后添加一些其他层,则输出张量的第一个维度将类似于None, 1, 1, 1, 1...。有没有办法存储这个内部张量并call在保留可变批量大小的同时使用它?(例如,批量大小为 4 时,相同展平 N 的副本将连接到 4 个展平输入张量中每一个的末尾。)
您必须拥有N与输入中的样本一样多的展平向量,因为您要连接到每个样本。可以将其想象为将行配对并连接它们。如果只有一个N向量,则只能连接一对。为了解决这个问题,您应该根据批次中的样品数量tf.tile()来重复N多次。
例子:
def call(self, input_tensor):
input_flattened = Flatten()(input_tensor) # input_flattened shape: (None, ..)
N = tf.expand_dims(self.N, axis=0) # N shape: (1, a, b)
N_flattened = Flatten()(N) # N_flattened shape: (1, a*b)
N_tiled = tf.tile(N_flattened, [tf.shape(input_tensor)[0], 1]) # repeat along the first dim as many times, as there are samples and leave the second dim alone
return tf.concat((input_flattened, N_tiled), axis=-1)
Run Code Online (Sandbox Code Playgroud)
| 归档时间: |
|
| 查看次数: |
1414 次 |
| 最近记录: |