如何将预训练的Keras模型的所有层转换为不同的dtype(从float32到float16)?

Mic*_*ice 6 numpy python-3.x keras

我试图将我的(float32)模型的精度更改为float16,以查看需要付出多少性能。加载模型(base_model)后,我尝试了以下操作:

from keras import backend as K
K.set_floatx('float16')
weights_list = base_model.layers[1].get_weights()
print('Original:')
print(weights_list[0].dtype)
new_weights = [K.cast_to_floatx(weights_list[0])]
print('New Weights:')
print(new_weights[0].dtype)
print('Setting New Weights')
base_model.layers[1].set_weights(new_weights)
new_weights_list = base_model.layers[1].get_weights()
print(new_weights_list[0].dtype)
Run Code Online (Sandbox Code Playgroud)

输出:

Original:
float32
New Weights:
float16
Setting New Weights
float32
Run Code Online (Sandbox Code Playgroud)

使用此代码,一层中的权重将转换为float16,并且模型中的权重将设置为新的权重,但是在使用get_weights之后,数据类型将返回到float32。有没有办法设置图层的dtype?据我所知,K.cast_to_floatx用于numpy数组,而K.cast用于张量。我是否需要使用新的dtype来构建和构建一个全新的空模型,并将重铸权重放入新模型中?

还是有一些更简单的方法来加载所有具有dtype'float32'的图层的模型,并将所有图层转换为具有dtype'float16'的模型?这是mlmodel中引入的功能,因此我认为在Keras中并不是特别困难。

cra*_*ael 4

有同样的问题并得到了这个工作。什么对我不起作用

  • 保存到文件并加载回来
  • 铸造所有权重并重新分配给原始模型

这对我有用

  • 创建相同架构的新模型并手动设置其权重

微量元素:

>>> from keras import backend as K
>>> from keras.models import Sequential
>>> from keras.layers import Dense, Dropout, Activation
>>> import numpy as np
>>> 
>>> def make_model():
...     model = Sequential()
...     model.add(Dense(64, activation='relu', input_dim=20))
...     model.add(Dropout(0.5))
...     model.add(Dense(64, activation='relu'))
...     model.add(Dropout(0.5))
...     model.add(Dense(10, activation='softmax'))
...     return model
... 
>>> K.set_floatx('float64')
>>> model = make_model()
>>> 
>>> K.set_floatx('float32')
>>> ws = model.get_weights()
>>> wsp = [w.astype(K.floatx()) for w in ws]
>>> model_quant = make_model()
>>> model_quant.set_weights(wsp)
>>> xp = x.astype(K.floatx())
>>> 
>>> print(np.unique([w.dtype for w in model.get_weights()]))
[dtype('float64')]
>>> print(np.unique([w.dtype for w in model_quant.get_weights()]))
[dtype('float32')]
Run Code Online (Sandbox Code Playgroud)