假设我有一堆定义如下的摘要:
loss = ...
tf.scalar_summary("loss", loss)
# ...
summaries = tf.merge_all_summaries()
Run Code Online (Sandbox Code Playgroud)
我可以summaries
在训练数据的每几步评估张量,并将结果传递给a SummaryWriter
.结果将是嘈杂的摘要,因为它们仅在一个批次上计算.
但是,我想计算整个验证数据集的摘要.当然,我无法将验证数据集作为单个批次传递,因为它太大了.因此,我将获得每个验证批次的摘要输出.
有没有办法对这些摘要进行平均,以便看起来好像是在整个验证集上计算了摘要?
Tom*_*Tom 45
在Python中对度量进行平均,并为每个均值创建一个新的Summary对象.这是我做的:
accuracies = []
# Calculate your measure over as many batches as you need
for batch in validation_set:
accuracies.append(sess.run([training_op]))
# Take the mean of you measure
accuracy = np.mean(accuracies)
# Create a new Summary object with your measure
summary = tf.Summary()
summary.value.add(tag="%sAccuracy" % prefix, simple_value=accuracy)
# Add it to the Tensorboard summary writer
# Make sure to specify a step parameter to get nice graphs over time
summary_writer.add_summary(summary, global_step)
Run Code Online (Sandbox Code Playgroud)
MZH*_*ZHm 10
我会避免计算图表外的平均值.
您可以使用tf.train.ExponentialMovingAverage:
ema = tf.train.ExponentialMovingAverage(decay=my_decay_value, zero_debias=True)
maintain_ema_op = ema.apply(your_losses_list)
# Create an op that will update the moving averages after each training step.
with tf.control_dependencies([your_original_train_op]):
train_op = tf.group(maintain_ema_op)
Run Code Online (Sandbox Code Playgroud)
然后,使用:
sess.run(train_op)
Run Code Online (Sandbox Code Playgroud)
这将调用,maintain_ema_op
因为它被定义为控件依赖项.
为了获得指数移动平均线,请使用:
moving_average = ema.average(an_item_from_your_losses_list_above)
Run Code Online (Sandbox Code Playgroud)
并使用以下方法检索其值:
value = sess.run(moving_average)
Run Code Online (Sandbox Code Playgroud)
这会计算计算图表中的移动平均值.
我认为让tensorflow进行计算总是更好.
看看流媒体指标.它们具有更新功能以提供当前批次的信息和函数以获取平均摘要.看起来有点像这样:
accuracy = ...
streaming_accuracy, streaming_accuracy_update = tf.contrib.metrics.streaming_mean(accuracy)
streaming_accuracy_scalar = tf.summary.scalar('streaming_accuracy', streaming_accuracy)
# set up your session etc.
for i in iterations:
for b in batches:
sess.run([streaming_accuracy_update], feed_dict={...})
streaming_summ = sess.run(streaming_accuracy_scalar)
writer.add_summary(streaming_summary, i)
Run Code Online (Sandbox Code Playgroud)
另请参阅tensorflow文档:https://www.tensorflow.org/versions/master/api_guides/python/contrib.metrics
这个问题: 如何在tensorflow中累积汇总统计数据
归档时间: |
|
查看次数: |
10590 次 |
最近记录: |