Keras:模型中可训练参数的数量

Pra*_*aha 10 python keras

trainable=False在我的所有图层中设置,通过ModelAPI 实现,但我想验证它是否有效.model.count_params()返回参数的总数,但除了查看最后几行之外,还有什么方法可以获得可训练参数的总数model.summary()

tuo*_*tik 22

from keras import backend as K

trainable_count = int(
    np.sum([K.count_params(p) for p in set(model.trainable_weights)]))
non_trainable_count = int(
    np.sum([K.count_params(p) for p in set(model.non_trainable_weights)]))

print('Total params: {:,}'.format(trainable_count + non_trainable_count))
print('Trainable params: {:,}'.format(trainable_count))
print('Non-trainable params: {:,}'.format(non_trainable_count))
Run Code Online (Sandbox Code Playgroud)

上面的代码片段可以在layer_utils.print_summary()定义的末尾发现,即summary()调用.

  • 如果您使用的是 TensorFlow 2,则现在可以在“from tensorflow.python.keras.utils.layer_utils import count_params”中找到“count_params”函数 (4认同)

Dan*_*bak 12

对于TensorFlow 2.0

import tensorflow.keras.backend as K

trainable_count = np.sum([K.count_params(w) for w in model.trainable_weights])
non_trainable_count = np.sum([K.count_params(w) for w in model.non_trainable_weights])

print('Total params: {:,}'.format(trainable_count + non_trainable_count))
print('Trainable params: {:,}'.format(trainable_count))
print('Non-trainable params: {:,}'.format(non_trainable_count))
Run Code Online (Sandbox Code Playgroud)