在TensorFlow训练的模型中获取一些权重的值

S.A*_*EEN 62 tensorflow

我已经使用TensorFlow训练了一个ConvNet模型,我希望在图层中获得特定的权重.例如在torch7中,我只是访问model.modules[2].weights.得到第2层的权重.我如何在TensorFlow中做同样的事情?

mrr*_*rry 90

在TensorFlow中,训练的权重由tf.Variable对象表示.如果您创建了一个tf.Variable名为v-yourself的-eg,则可以通过调用sess.run(v)(where sessis a tf.Session)将其值作为NumPy数组获取.

如果您当前没有指向该指针的指针,则tf.Variable可以通过调用获取当前图形中的可训练变量列表tf.trainable_variables().此函数返回tf.Variable当前图形中所有可训练对象的列表,您可以通过匹配v.name属性来选择所需的对象.例如:

# Desired variable is called "tower_2/filter:0".
var = [v for v in tf.trainable_variables() if v.name == "tower_2/filter:0"][0]
Run Code Online (Sandbox Code Playgroud)

  • 它取决于用于加载模型的机制.如果你使用较新的`tf.train.import_meta_graph()`那么`tf.trainable_variables()`应该可行.如果你使用较低级别的`tf.import_graph_def()`函数,那么你应该在`return_elements`可选参数中传递变量的名称,并返回一个张量(然后你可以传递给`sess.run( )`. (5认同)
  • 啊,该模型的问题是变量没有有意义的名称.你有两个选择.(注意:您必须使用与训练网络相同的会话来检索权重.)1.在训练后执行`session.run(layer1_weights)`以获取变量的值(这是第一个)在我的回答中的建议).2.在`layer1_weights = tf.Variable(...)`之后添加一个`print`语句,找出该变量的TensorFlow名称(例如`print(layer1_weights.name)`),然后使用该字符串查找`tf.trainable_variables()`中的变量. (5认同)

Ten*_*ort 11

2.0 兼容答案:如果我们使用 构建模型Keras Sequential API,我们可以使用下面提到的代码获取模型的权重:

!pip install tensorflow==2.1

from tf.keras import Sequential

model = Sequential()

model.add(Conv2D(filters=conv1_fmaps, kernel_size=conv1_ksize,
                         strides=conv1_stride, padding=conv1_pad,
                         activation=tf.nn.relu, input_shape=(height, width, channels),
                    data_format='channels_last'))

model.add(MaxPool2D(pool_size = (2,2), strides= (2,2), padding="VALID"))

model.add(Dropout(0.25))

model.add(Flatten())

model.add(Dense(units = 32, activation = 'relu'))

model.add(Dense(units = 10, activation = 'softmax'))

model.summary()

print(model.trainable_variables) 
Run Code Online (Sandbox Code Playgroud)

最后一条语句print(model.trainable_variables)将返回模型的权重,如下所示:

    [<tf.Variable 'conv2d/kernel:0' shape=(3, 3, 1, 32) dtype=float32>,
 <tf.Variable 'conv2d/bias:0' shape=(32,) dtype=float32>, <tf.Variable 
'dense/kernel:0' shape=(6272, 32) dtype=float32>, <tf.Variable 'dense/bias:0' 
shape=(32,) dtype=float32>, <tf.Variable 'dense_1/kernel:0' shape=(32, 10) 
dtype=float32>, <tf.Variable 'dense_1/bias:0' shape=(10,) dtype=float32>]
Run Code Online (Sandbox Code Playgroud)


sag*_*gzz 6

因此,如果逐步执行此代码,您将首先获得已使用/可训练变量的列表。然后,您可以将它们排序在一个列表中,在其中将权重矩阵/列表排序为变量名,例如,如何处理该信息。

vars = tf.trainable_variables()
print(vars) #some infos about variables...
vars_vals = sess.run(vars)
for var, val in zip(vars, vars_vals):
    print("var: {}, value: {}".format(var.name, val)) #...or sort it in a list....
Run Code Online (Sandbox Code Playgroud)

  • 始终要归功于它的归属。该[issue]中给出了ans(https://github.com/google/prettytensor/issues/6#issuecomment-380919368) (4认同)
  • 另外,这里的“sess”是什么?答案应该是独立的.. (3认同)