当张量中有未知元素时,操纵张量形状的正确方法是什么?

Meh*_*ran 5 python keras tensorflow

假设我有一个形状张量,(None, None, None, 32)我想将其重塑(None, None, 32)为中间维度是原始维度的两个中间维度的乘积。这样做的正确方法是什么?

Dan*_*ler 4

import keras.backend as K

def flatten_pixels(x):
    shape = K.shape(x)
    newShape = K.concatenate([
                                 shape[0:1], 
                                 shape[1:2] * shape[2:3],
                                 shape[3:4]
                             ])

    return K.reshape(x, newShape)
Run Code Online (Sandbox Code Playgroud)

在图层中使用它Lambda

from keras.layers import Lambda

model.add(Lambda(flatten_pixels))
Run Code Online (Sandbox Code Playgroud)

一点知识:

  • K.shape返回张量的“当前”形状,包含数据 - 它包含Tensor所有int维度的值。它仅在运行模型时正确存在,不能在模型定义中使用,只能在运行时计算中使用。
  • K.int_shape将张量的“定义”形状返回为tuple。这意味着变量维度将包含None值。