重置tensorflow流量指标的变量

ted*_*ted 1 python metrics tensorflow

我有一大堆的流指标(中tf.metrics.accuracy定制流媒体micro,macroweightedF1分数).

在训练期间,我得到了下面的那种情节(永远不要过度拟合).

这是因为计算验证集的度量标准,我调用它tf.local_variables_initializer来重置度量标准,并且只有验证集的值.

这意味着2个副作用:

  1. 图像中的尖峰
  2. 在验证之间,即使验证每2个时期发生,训练指标也会保持聚合

我可以通过让不同的张量保持每个度量(train vs val)来部分解决这种情况.但它无法解决2.

因此,我有两个问题:

  • 根据您的经验,这是您期望看到的行为(或不是?解决方案?)
  • 有没有办法让指标只在最后n一批中流?

spinging情节

vij*_*y m 7

如果您在培训之间重置指标,则会出现此行为.如果列车指标是两个不同的操作,则不会对验证指标进行粗略评估.我将举例说明如何使这些指标保持不同以及如何仅重置其中一个指标.


玩具示例:

logits = tf.placeholder(tf.int64, [2,3])
labels = tf.Variable([[0, 1, 0], [1, 0, 1]])

#create two different ops
with tf.name_scope('train'):
   train_acc, train_acc_op = tf.metrics.accuracy(labels=tf.argmax(labels, 1), 
                                                 predictions=tf.argmax(logits,1))
with tf.name_scope('valid'):
   valid_acc, valid_acc_op = tf.metrics.accuracy(labels=tf.argmax(labels, 1), 
                                                 predictions=tf.argmax(logits,1))
Run Code Online (Sandbox Code Playgroud)

训练:

#initialize the local variables has it holds the variables used for metrics calculation.
sess.run(tf.local_variables_initializer())
sess.run(tf.global_variables_initializer())

# initial state
print(sess.run(train_acc, {logits:[[0,1,0],[1,0,1]]}))
print(sess.run(valid_acc, {logits:[[0,1,0],[1,0,1]]}))

#0.0
#0.0
Run Code Online (Sandbox Code Playgroud)

初始状态0.0如预期.

现在调用培训操作指标:

#training loop
for _ in range(10):
    sess.run(train_acc_op, {logits:[[0,1,0],[1,0,1]]})  
print(sess.run(train_acc, {logits:[[0,1,0],[1,0,1]]}))
# 1.0
print(sess.run(valid_acc, {logits:[[0,1,0],[1,0,1]]}))
# 0.0
Run Code Online (Sandbox Code Playgroud)

只有在有效精度仍然有效的情况下才能更新训练准确度0.0.调用有效的操作:

for _ in range(10):
    sess.run(valid_acc_op, {logits:[[0,1,0],[0,1,0]]}) 
print(sess.run(valid_acc, {logits:[[0,1,0],[1,0,1]]}))
#0.5
print(sess.run(train_acc, {logits:[[0,1,0],[1,0,1]]}))
#1.0
Run Code Online (Sandbox Code Playgroud)

在此,有效精度更新为新值,而训练精度保持不变.

让我们只重置验证操作:

stream_vars_valid = [v for v in tf.local_variables() if 'valid/' in v.name]
sess.run(tf.variables_initializer(stream_vars_valid))

print(sess.run(valid_acc, {logits:[[0,1,0],[1,0,1]]}))
#0.0
print(sess.run(train_acc, {logits:[[0,1,0],[1,0,1]]}))
#1.0
Run Code Online (Sandbox Code Playgroud)

当训练精度保持不变时,有效精度重置为零.