Can't pickle a TensorFlow object in Python - TypeError: can't pickle _thread._local objects

Al-*_*Hag 6 python pickle tensorflow

I want to pickle the History object returned by Keras model.fit() (running on TensorFlow), but I get an error.

import gzip
import numpy as np
import os
import pickle
import tensorflow as tf
from tensorflow import keras


with gzip.open('mnist.pkl.gz', 'rb') as f:
    train_set, test_set = pickle.load(f, encoding='latin1')

X_train = np.asarray(train_set[0])
y_train = np.asarray(train_set[1])

X_test = np.asarray(test_set[0])
y_test = np.asarray(test_set[1])

X_valid, X_train = X_train[:5000]/255.0, X_train[5000:]/255.0
y_valid, y_train = y_train[:5000], y_train[5000:]

class_names = ['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat',
               'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle Boot']

model = keras.models.Sequential()
model.add(keras.layers.Flatten(input_shape=[28,28]))
model.add(keras.layers.Dense(300, activation = 'relu'))
model.add(keras.layers.Dense(100, activation = 'relu'))
model.add(keras.layers.Dense(10, activation = 'softmax'))
model.summary()

model.compile(loss='sparse_categorical_crossentropy',
              optimizer='sgd',
              metrics=['accuracy'])

history = model.fit(X_train, y_train, epochs=1,
                    validation_data =(X_valid, y_valid))

if not os.path.isdir('models'):
    os.mkdir('models')

model.save('models/basic.h5')
with open('models/basic_history.pickle', 'wb') as f:
    pickle.dump(history, f)

It gives me the following error:

Traceback (most recent call last):
  File "main.py", line 69, in <module>
    pickle.dump(history, f)
TypeError: can't pickle _thread._local objects
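The same TypeError can be reproduced without TensorFlow. pickle serializes the whole object graph reachable from the object you pass it, and a threading.local instance anywhere in that graph makes the dump fail. The sketch below assumes (as in tf.keras 2.x) that History, like other Keras callbacks, keeps a reference to the model in its model attribute. HypotheticalCallback is a made-up stand-in for History, and the metric values are placeholders, not training results.

```python
import pickle
import threading


class HypotheticalCallback:
    """Hypothetical stand-in for keras.callbacks.History.

    It stores plain metrics plus a reference to an object with
    thread-local state, similar to a callback referencing its model.
    """

    def __init__(self):
        self.history = {"loss": [0.5], "accuracy": [0.8]}  # placeholder values
        self.model = threading.local()  # stands in for the model's TF internals


def try_pickle(obj):
    """Return True if obj can be pickled, False if pickle raises TypeError."""
    try:
        pickle.dumps(obj)
        return True
    except TypeError as err:
        print(f"TypeError: {err}")
        return False


cb = HypotheticalCallback()
whole_object_ok = try_pickle(cb)          # fails: a threading.local is reachable
metrics_only_ok = try_pickle(cb.history)  # succeeds: a dict of lists of floats
```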

PS: To run the code, download the fashion_mnist data: https://s3.amazonaws.com/img-datasets/mnist.pkl.g

Al-*_*Hag 6

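As Karl suggested, the History object itself cannot be pickled: it keeps a reference to the Keras model (history.model), and the model holds TensorFlow internals such as _thread._local objects. Its history attribute, however, is a plain dict mapping each metric name to a list of per-epoch values, and that can be pickled: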

with open('models/basic_history.pickle', 'wb') as f:
    pickle.dump(history.history, f)
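Reading the metrics back later works the same way with pickle.load. This is a minimal sketch: save_history and load_history are hypothetical helper names, and the dict holds placeholder values in place of a real history.history. The key names assume TF 2.x naming for metrics=['accuracy'] with validation_data (older Keras versions used 'acc' and 'val_acc').

```python
import os
import pickle
import tempfile


def save_history(history_dict, path):
    """Pickle a Keras-style history dict (metric name -> per-epoch values)."""
    os.makedirs(os.path.dirname(path), exist_ok=True)
    with open(path, "wb") as f:
        pickle.dump(history_dict, f)


def load_history(path):
    """Load a history dict previously written by save_history."""
    with open(path, "rb") as f:
        return pickle.load(f)


# Placeholder values standing in for history.history after model.fit();
# with a real model you would pass history.history instead.
history_dict = {
    "loss": [0.62],
    "accuracy": [0.84],
    "val_loss": [0.41],
    "val_accuracy": [0.89],
}

# Write into a temporary directory so the sketch does not touch ./models.
path = os.path.join(tempfile.mkdtemp(), "models", "basic_history.pickle")
save_history(history_dict, path)
restored = load_history(path)
print(restored["val_accuracy"])
```

Because the dict only contains strings and lists of numbers, json would also work here; pickle is kept to match the answer above.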