Visualizing the weights of a trained neural network in Keras

liw*_*wei 4 python deep-learning keras tensorflow

Hello, I have trained an autoencoder network with a 96*96*32 convolutional layer.


Now I have obtained the weights of my model, which is named autoencoder:

layer = autoencoder.layers[1]
W = layer.get_weights()

Since W is a list, please help me sort its elements and visualize the trained kernels. I guess there should be 32 kernels of size 96×96.


When I type

len(W)

it gives 2, so I have 2 arrays.


The top array has 9 sub-arrays, each containing 32 numbers. The last array has 32 elements, so those must be the biases.

[array([[[[-6.56146603e-03, -1.51752336e-02, -3.76937017e-02,
           -4.55160812e-03,  1.26366820e-02, -2.97747254e-02,
            3.76312323e-02, -1.56892575e-02,  2.03932393e-02,
            3.29606095e-03,  3.76580656e-02,  6.99581252e-03,
           -4.97130565e-02,  3.63005586e-02,  3.70187908e-02,
            2.63699284e-03,  4.42482866e-02,  8.26128479e-03,
            3.44854854e-02,  1.94760375e-02,  3.91177870e-02,
           -6.67006942e-03,  5.64308763e-02, -1.55166145e-02,
           -3.46037326e-03, -3.14556211e-02, -2.31548538e-03,
            5.77888393e-04,  2.17472352e-02, -8.16953406e-02,
            1.54041937e-02, -3.55066173e-02]],

         ...

         [[-2.71476209e-02,  6.84098527e-02, -2.91980486e-02,
           ...
            2.10835561e-02, -4.13369946e-02]]]], dtype=float32),
 array([-1.1922461e-03, -2.0752363e-04,  1.1357996e-05,  1.6377015e-05,
        -2.5950783e-04,  1.9307183e-05, -1.5572178e-06, -1.3648998e-03,
        -8.6763187e-04,  4.4856939e-04,  2.7988455e-03, -7.7398616e-04,
        -5.1178242e-04, -6.8265648e-04,  1.8571866e-04, -7.1992702e-04,
        -5.5880222e-04, -3.6114815e-04, -9.7678707e-04,  2.6443407e-03,
         1.1190268e-03, -1.0251488e-03, -1.1638318e-03,  7.1209669e-04,
         4.9417594e-04,  2.3746442e-04, -4.8552561e-04,  1.4480414e-03,
        -1.8445569e-05,  4.2989667e-04,  1.0579359e-04, -3.2821635e-04],
       dtype=float32)]
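A quick way to confirm what the two arrays are is to print their shapes (a minimal check, reusing the W returned by layer.get_weights() above; the expected shapes follow from the 320 parameters of conv2d_1 in the summary below):

print(len(W))      # 2: one kernel array and one bias array
print(W[0].shape)  # expected (3, 3, 1, 32): 3x3 kernels, 1 input channel, 32 filters
print(W[1].shape)  # expected (32,): one bias per filter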

Summary of the first few layers of the model:

Layer (type)                    Output Shape         Param #     Connected to
==================================================================================================
input_1 (InputLayer)            (None, 96, 96, 1)    0
__________________________________________________________________________________________________
conv2d_1 (Conv2D)               (None, 96, 96, 32)   320         input_1[0][0]
__________________________________________________________________________________________________
batch_normalization_1 (BatchNor (None, 96, 96, 32)   128         conv2d_1[0][0]

Now, how do I sort them and visualize them?


I am using Keras.


Thanks.


Zab*_*azi 6

Generally, if you use a Dense layer, that length of 2 corresponds to one weight array and one bias array.

Since I don't know the type of your layer, I've added an example that explains the shapes for Dense and Conv2D layers.

The first level always corresponds to [weights, biases]. The second-level shapes differ: the bias is always a 1-D array with one entry per output unit or filter, the Dense weights have shape (input_dim, output_dim), and the Conv2D weights have shape (kernel_height, kernel_width, input_channels, number_of_filters).

from tensorflow.keras.layers import *
from tensorflow.keras.models import *
import numpy as np

i1 = Input(shape=(32,32,3))
c1 = Conv2D(32, 3)(i1)
f1 = Flatten()(c1)
d1 = Dense(5)(f1)

m = Model(i1, d1)

m.summary()

y = m(np.zeros((1, 32, 32, 3)))

print(m.layers)
cw1 = np.array(m.layers[1].get_weights())
print(cw1.shape) # (2,): 1 weight array, 1 bias array
print(cw1[0].shape) # (3, 3, 3, 32): 3x3 kernels, 3 input channels, 32 filters
print(cw1[1].shape) # (32,): one bias per filter

cw1 = np.array(m.layers[2].get_weights())
print(cw1.shape) # Flatten has no weights, so this is empty -> (0,)

cw1 = np.array(m.layers[3].get_weights())
print(cw1.shape) # 2 -> 1 weight, 1 bias
print(cw1[0].shape) # 28800 inputs, 5 outputs, 28800 by 5 weight matrix
print(cw1[1].shape) # 5 biases
Model: "model_13"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_14 (InputLayer)        [(None, 32, 32, 3)]       0         
_________________________________________________________________
conv2d_13 (Conv2D)           (None, 30, 30, 32)        896       
_________________________________________________________________
flatten_13 (Flatten)         (None, 28800)             0         
_________________________________________________________________
dense_13 (Dense)             (None, 5)                 144005    
=================================================================
Total params: 144,901
Trainable params: 144,901
Non-trainable params: 0
_________________________________________________________________
[<tensorflow.python.keras.engine.input_layer.InputLayer object at 0x7fb8ce3bb828>, <tensorflow.python.keras.layers.convolutional.Conv2D object at 0x7fb8ce5fd6d8>, <tensorflow.python.keras.layers.core.Flatten object at 0x7fb8ce3bb940>, <tensorflow.python.keras.layers.core.Dense object at 0x7fb8ce3bbb70>]
(2,)
(3, 3, 3, 32)
(32,)
(0,)
(2,)
(28800, 5)
(5,)


How you visualize them depends entirely on the dimensionality.

If it is 1-D:

import matplotlib.pyplot as plt
plt.plot(weight)
plt.show()

If it is 2-D:

import matplotlib.pyplot as plt
plt.imshow(weight)
plt.show()

If it is 3-D:

You can pick one channel (or one filter) and plot only that part.


# plotting the 32 conv filters
import matplotlib.pyplot as plt
cw1 = np.array(m.layers[1].get_weights())
for i in range(32):
  plt.imshow(cw1[0][:,:,:,i])
  plt.show()
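For the model in the question there is only a single input channel, so each kernel slice W[:, :, 0, i] is a 3x3 image. A minimal sketch for laying all 32 of them out in one grid, assuming conv2d_1 is autoencoder.layers[1] as in the posted summary:

import matplotlib.pyplot as plt

W, b = autoencoder.layers[1].get_weights()   # W: (3, 3, 1, 32), b: (32,)

fig, axes = plt.subplots(4, 8, figsize=(12, 6))
for i, ax in enumerate(axes.flat):
    ax.imshow(W[:, :, 0, i], cmap='gray')    # the single input channel of filter i
    ax.set_title(f'filter {i}', fontsize=8)
    ax.axis('off')
plt.show()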
