yan*_*hen 13 pickle keras tensorflow tensorflow2.0
I have trained a TextVectorization layer (see below), and I want to save it to disk so that I can reload it next time. I have tried both pickle and joblib.dump(), but neither works.
import tensorflow as tf
from tensorflow.keras.layers.experimental.preprocessing import TextVectorization
# text_clean: the asker's list of cleaned text strings
text_dataset = tf.data.Dataset.from_tensor_slices(text_clean)
vectorizer = TextVectorization(max_tokens=100000, output_mode='tf-idf', ngrams=None)
vectorizer.adapt(text_dataset.batch(1024))
The error I get is:
InvalidArgumentError: Cannot convert a Tensor of dtype resource to a NumPy array
How can I save it?
muj*_*iga 15
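Instead of pickling the object, pickle its config and weights. Later, unpickle them, use the config to recreate the object, and load the saved weights into it. The official documentation is here.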
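A TensorFlow-free analogy may make the idea concrete. This is a minimal illustration, not TensorFlow code: `ToyVectorizer` is a hypothetical class, and its `threading.Lock` merely stands in for the resource handles a Keras layer holds (the error in the question complains about a tensor of dtype resource). Pickling the object itself fails, while pickling its config and weights, which are plain Python data, round-trips cleanly.

```python
import pickle
import threading


class ToyVectorizer:
    """Hypothetical stand-in for a Keras layer (not the real TextVectorization).

    It owns an unpicklable runtime handle (a lock, playing the role of a
    TF resource tensor) plus plain-Python state that pickle can handle.
    """

    def __init__(self, max_tokens):
        self.max_tokens = max_tokens
        self.vocab = []
        self._handle = threading.Lock()  # unpicklable, like a resource handle

    def adapt(self, texts):
        # Count whitespace-separated tokens
        counts = {}
        for text in texts:
            for token in text.split():
                counts[token] = counts.get(token, 0) + 1
        # Most frequent first (ties broken alphabetically), capped at max_tokens
        ranked = sorted(counts, key=lambda t: (-counts[t], t))
        self.vocab = ranked[: self.max_tokens]

    def get_config(self):
        return {"max_tokens": self.max_tokens}

    @classmethod
    def from_config(cls, config):
        return cls(**config)

    def get_weights(self):
        return [list(self.vocab)]

    def set_weights(self, weights):
        self.vocab = list(weights[0])


v = ToyVectorizer(max_tokens=10)
v.adapt(["this is some clean text", "some more text", "even some more text"])

try:
    pickle.dumps(v)  # fails: the instance dict contains the lock
except TypeError as err:
    print("pickling the object failed:", err)

# Pickling only config + weights works, because both are plain data
blob = pickle.dumps({"config": v.get_config(), "weights": v.get_weights()})
state = pickle.loads(blob)
restored = ToyVectorizer.from_config(state["config"])
restored.set_weights(state["weights"])
print(restored.vocab == v.vocab)
```

The real layer differs in what its config and weights contain, but the principle the answer relies on is the same: serialize the state, then rebuild the object from it.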
import pickle
import tensorflow as tf
from tensorflow.keras.layers.experimental.preprocessing import TextVectorization

text_dataset = tf.data.Dataset.from_tensor_slices([
    "this is some clean text",
    "some more text",
    "even some more text"])
# Fit a TextVectorization layer
vectorizer = TextVectorization(max_tokens=10, output_mode='tf-idf', ngrams=None)
vectorizer.adapt(text_dataset.batch(1024))
# Vector for the word "this"
print(vectorizer("this"))

# Pickle the config and weights (not the layer object itself)
with open("tv_layer.pkl", "wb") as f:
    pickle.dump({'config': vectorizer.get_config(),
                 'weights': vectorizer.get_weights()}, f)
print("*" * 10)

# Later you can unpickle and use
# `config` to create the object and
# `weights` to load the trained weights.
with open("tv_layer.pkl", "rb") as f:
    from_disk = pickle.load(f)
new_v = TextVectorization.from_config(from_disk['config'])
# You have to call `adapt` with some dummy data (BUG in Keras)
new_v.adapt(tf.data.Dataset.from_tensor_slices(["xyz"]))
new_v.set_weights(from_disk['weights'])
# Let's see the vector for the word "this"
print(new_v("this"))
Output:
tf.Tensor(
[[0. 0. 0. 0. 0.91629076 0.
0. 0. 0. 0. ]], shape=(1, 10), dtype=float32)
**********
tf.Tensor(
[[0. 0. 0. 0. 0.91629076 0.
0. 0. 0. 0. ]], shape=(1, 10), dtype=float32)
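If you need this in more than one place, the save and load steps above can be wrapped in a pair of helpers. This is a sketch under assumptions: `save_layer_state` and `load_layer_state` are hypothetical names, and they only rely on the `get_config` / `from_config` / `get_weights` / `set_weights` calls the answer already uses. The optional `warmup` callable exists so the dummy `adapt()` workaround can run before `set_weights`.

```python
import pickle


def save_layer_state(layer, path):
    """Pickle a Keras-style layer's config and weights (not the layer itself)."""
    with open(path, "wb") as f:
        pickle.dump({"config": layer.get_config(),
                     "weights": layer.get_weights()}, f)


def load_layer_state(layer_cls, path, warmup=None):
    """Rebuild a layer via layer_cls.from_config, then restore its weights.

    warmup is an optional callable run on the fresh layer before
    set_weights, e.g. the dummy adapt() call needed for TextVectorization.
    """
    with open(path, "rb") as f:
        state = pickle.load(f)
    layer = layer_cls.from_config(state["config"])
    if warmup is not None:
        warmup(layer)
    layer.set_weights(state["weights"])
    return layer
```

With the answer's objects, usage would be `save_layer_state(vectorizer, "tv_layer.pkl")` and later `new_v = load_layer_state(TextVectorization, "tv_layer.pkl", warmup=lambda layer: layer.adapt(tf.data.Dataset.from_tensor_slices(["xyz"])))`. Passing the class in, rather than hard-coding TextVectorization, keeps the helpers usable for other layers that support the same four methods.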
Views: 7192