如何在Keras Lambda层中有条件地缩放值?

yan*_*hen 2 python machine-learning keras tensorflow tensor

输入张量rnn_pv是形状(?, 48, 1)。我想缩放此张量中的每个元素,因此我尝试使用Lambda如下图层:

rnn_pv_scale = Lambda(lambda x: 1 if x >=1000 else x/1000.0 )(rnn_pv)
Run Code Online (Sandbox Code Playgroud)

但这带来了错误:

TypeError: Using a `tf.Tensor` as a Python `bool` is not allowed. Use `if t is not None:` instead of `if t:` to test if a tensor is defined, and use TensorFlow ops such as tf.cond to execute subgraphs conditioned on the value of a tensor.
Run Code Online (Sandbox Code Playgroud)

那么实现此功能的正确方法是什么?

tod*_*day 6

您不能在模型定义中使用Python控制流语句(如if-else语句)执行条件操作。相反,您需要使用Keras后端中定义的方法。由于您使用TensorFlow作为后端,因此可以tf.where()用来实现以下目的:

import tensorflow as tf

scaled = Lambda(lambda x: tf.where(x >= 1000, tf.ones_like(x), x/1000.))(input_tensor)
Run Code Online (Sandbox Code Playgroud)

或者,要支持所有后端,您可以创建一个掩码来执行此操作:

from keras import backend as K

def rescale(x):
    mask = K.cast(x >= 1000., dtype=K.floatx())
    return mask + (x/1000.0) * (1-mask)

#...
scaled = Lambda(rescale)(input_tensor)
Run Code Online (Sandbox Code Playgroud)

更新:支持所有后端的另一种方法是使用K.switch方法:

from keras import backend as K

scaled = Lambda(lambda x: K.switch(x >= 1000., K.ones_like(x), x / 1000.))(input_tensor)
Run Code Online (Sandbox Code Playgroud)