K.gradients(loss,input_img)[0]返回“无”。(带有Tensorflow后端的Keras CNN可视化)

Jex*_*xus 8 python neural-network deep-learning keras tensorflow

我使用在Tensorflow后端使用Keras训练的CNN模型。我想通过本教程可视化我的CNN过滤器:https : //blog.keras.io/how-convolutional-neural-networks-see-the-world.html

from keras import backend as K
from keras.models import load_model
import numpy as np

model = load_model('my_cnn_model.h5')
input_img = np.load('my_picture.npy')

# get the symbolic outputs of each "key" layer (we gave them unique names).
layer_dict = dict([(layer.name, layer) for layer in model.layers])

layer_name = 'block5_conv3'
filter_index = 0  # can be any integer from 0 to 511, as there are 512 filters in that layer

# build a loss function that maximizes the activation
# of the nth filter of the layer considered
layer_output = layer_dict[layer_name].output
loss = K.mean(layer_output[:, :, :, filter_index])

# compute the gradient of the input picture wrt this loss
grads = K.gradients(loss, input_img)[0]

# normalization trick: we normalize the gradient
grads /= (K.sqrt(K.mean(K.square(grads))) + 1e-5)

# this function returns the loss and grads given the input picture
iterate = K.function([input_img], [loss, grads])
Run Code Online (Sandbox Code Playgroud)

但是,当代码执行到这一行时:
grads = K.gradients(loss, input_img)[0]
我发现它只返回None对象,因此什么也不会执行。

我寻找一些解决方案。有人说input_img应该是tensorflow的Tensor类型:https : //github.com/keras-team/keras/issues/5455

但是当我尝试将img转换为Tensor时,问题仍然存在。
我在上面的链接中尝试了解决方案,但仍然失败。

也有人说存在此问题,因为您的CNN模型不可区分。 https://github.com/keras-team/keras/issues/8478

但是我的模型仅使用ReLU和Sigmoid(在输出层)的激活功能。这个问题真的是由不可微问题引起的吗?

谁能帮我?非常感谢你!

Mat*_*gro 9

如果您有一个Model实例,则要考虑损耗相对于输入的梯度,您应该执行以下操作:

grads = K.gradients(loss, model.input)[0]
Run Code Online (Sandbox Code Playgroud)

model.input包含代表模型输入的符号张量。使用普通的numpy数组没有意义,因为TensorFlow不知道如何将其连接到计算图,并返回None作为梯度。

然后,您还应该将iterate函数重写为:

iterate = K.function([model.input], [loss, grads])
Run Code Online (Sandbox Code Playgroud)

  • 很好地补充此答案以及如何使用该功能。我们需要`sess = K.get_session()`和`results = sess.run(iterate,feed_dict = {model.input:numpyBatchWithData})`。 (2认同)