vik*_*sit 10 python tensorflow tensorboard tensorflow-estimator
在培训预先预测的估算器时,打印精度指标以及丢失的最简单方法是什么?
大多数教程和文档似乎都解决了当您创建自定义估算器时的问题 - 如果打算使用其中一个可用的估算器,这似乎有点过头了.
tf.contrib.learn有一些(现已弃用)Monitor钩子.TF现在建议使用钩子API,但看起来它实际上没有任何可以利用标签和预测来生成准确度数字的东西.
你试过tf.contrib.estimator.add_metrics(estimator, metric_fn)(doc)吗?它需要一个初始化的估算器(可以预先设置),并为其添加定义的度量metric_fn.
用法示例:
def custom_metric(labels, predictions):
# This function will be called by the Estimator, passing its predictions.
# Let's suppose you want to add the "mean" metric...
# Accessing the class predictions (careful, the key name may change from one canned Estimator to another)
predicted_classes = predictions["class_ids"]
# Defining the metric (value and update tensors):
custom_metric = tf.metrics.mean(labels, predicted_classes, name="custom_metric")
# Returning as a dict:
return {"custom_metric": custom_metric}
# Initializing your canned Estimator:
classifier = tf.estimator.DNNClassifier(feature_columns=columns_feat, hidden_units=[10, 10], n_classes=NUM_CLASSES)
# Adding your custom metrics:
classifier = tf.contrib.estimator.add_metrics(classifier, custom_metric)
# Training/Evaluating:
tf.logging.set_verbosity(tf.logging.INFO) # Just to have some logs to display for demonstration
train_spec = tf.estimator.TrainSpec(input_fn=lambda:your_train_dataset_function(),
max_steps=TRAIN_STEPS)
eval_spec=tf.estimator.EvalSpec(input_fn=lambda:your_test_dataset_function(),
steps=EVAL_STEPS,
start_delay_secs=EVAL_DELAY,
throttle_secs=EVAL_INTERVAL)
tf.estimator.train_and_evaluate(classifier, train_spec, eval_spec)
Run Code Online (Sandbox Code Playgroud)
日志:
...
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Evaluation [20/200]
INFO:tensorflow:Evaluation [40/200]
...
INFO:tensorflow:Evaluation [200/200]
INFO:tensorflow:Finished evaluation at 2018-04-19-09:23:03
INFO:tensorflow:Saving dict for global step 1: accuracy = 0.5668, average_loss = 0.951766, custom_metric = 1.2442, global_step = 1, loss = 95.1766
...
Run Code Online (Sandbox Code Playgroud)
如您所见,custom_metric将按默认指标和损失返回.
| 归档时间: |
|
| 查看次数: |
4227 次 |
| 最近记录: |