我trainable=False在我的所有图层中设置,通过ModelAPI 实现,但我想验证它是否有效.model.count_params()返回参数的总数,但除了查看最后几行之外,还有什么方法可以获得可训练参数的总数model.summary()?
tuo*_*tik 22
from keras import backend as K
trainable_count = int(
np.sum([K.count_params(p) for p in set(model.trainable_weights)]))
non_trainable_count = int(
np.sum([K.count_params(p) for p in set(model.non_trainable_weights)]))
print('Total params: {:,}'.format(trainable_count + non_trainable_count))
print('Trainable params: {:,}'.format(trainable_count))
print('Non-trainable params: {:,}'.format(non_trainable_count))
Run Code Online (Sandbox Code Playgroud)
上面的代码片段可以在layer_utils.print_summary()定义的末尾发现,即summary()调用.
Dan*_*bak 12
对于TensorFlow 2.0:
import tensorflow.keras.backend as K
trainable_count = np.sum([K.count_params(w) for w in model.trainable_weights])
non_trainable_count = np.sum([K.count_params(w) for w in model.non_trainable_weights])
print('Total params: {:,}'.format(trainable_count + non_trainable_count))
print('Trainable params: {:,}'.format(trainable_count))
print('Non-trainable params: {:,}'.format(non_trainable_count))
Run Code Online (Sandbox Code Playgroud)
| 归档时间: |
|
| 查看次数: |
5375 次 |
| 最近记录: |