从 TensorFlow Estimator 模型 (2.0) 保存、加载和预测

Chr*_*eck 5 tensorflow

是否有Estimator在 TF2 中序列化和恢复模型的指南?文档非常参差不齐,其中大部分都没有更新到 TF2。我还没有在任何地方看到一个清晰而完整的例子,任何地方都Estimator被保存,从磁盘加载并用于从新输入进行预测。

TBH,我对这看起来有多复杂感到有些困惑。Estimators 被宣传为拟合标准模型的简单、相对高级的方法,但在生产中使用它们的过程似乎非常神秘。例如,当我通过tf.saved_model.load(export_path)获取一个AutoTrackable对象从磁盘加载模型时:

<tensorflow.python.training.tracking.tracking.AutoTrackable at 0x7fc42e779f60>

不清楚为什么我不Estimator回来。看起来曾经有一个听起来很有用的函数tf.contrib.predictor.from_saved_model,但自从contrib消失了,它似乎不再起作用了(除了它出现在 TFLite 中)。

任何指针都会非常有帮助。如你所见,我有点失落。

小智 6

也许作者不再需要答案,但我能够使用 TensorFlow 2.1 保存和加载 DNNClassifier

# training.py
from pathlib import Path
import tensorflow as tf

....
# Creating the estimator
estimator = tf.estimator.DNNClassifier(
    model_dir= < model_dir >,
    hidden_units = [1000, 500],
    feature_columns = feature_columns,  # this is a list defined earlier
    n_classes = 2,
    optimizer = 'adam'
)

feature_spec = tf.feature_column.make_parse_example_spec(feature_columns)
export_input_fn = tf.estimator.export.build_parsing_serving_input_receiver_fn(feature_spec)
servable_model_path = Path(estimator.export_saved_model( < model_dir >, export_input_fn).decode('utf8'))
print(f'Model saved at {servable_model_path}')

Run Code Online (Sandbox Code Playgroud)

对于加载,您找到了正确的方法,您只需检索predict_fn

# testing.py
import tensorflow as tf
import pandas as pd

def predict_input_fn(test_df):
    '''Convert your dataframe using tf.train.Example() and tf.train.Features()'''
    examples = []
    ....
    return tf.constant(examples)

test_df = pd.read_csv('test.csv', ...)

# Loading the estimator
predict_fn = tf.saved_model.load(<model_dir>).signatures['predict']
# Predict
predictions = predict_fn(examples=predict_input_fn(test_df))
Run Code Online (Sandbox Code Playgroud)

希望这也能帮助其他人(:

  • 我不敢相信用当前的 Estimator API 做一些简单的事情(例如保存和加载模型)是多么麻烦。相比之下,Pytorch 是如此简单和直观。 (2认同)