TensorFlow:从损失函数中排除一类的正确方法

Sta*_*ckd 5 classification machine-learning tensorflow

我有三个类1(活动),0(不活动)和-1(未知)。我想在TensorFlow中构建一个模型,该模型在给定输入的情况下预测活动或不活动。以下是仅通过活动标签和非活动标签计算损失并忽略未知标签的正确方法吗?

logits = tf.reshape(logits, [-1])
labels = tf.reshape(labels, [-1])
index = tf.where(tf.not_equal(labels, tf.constant(-1, dtype=tf.float32)))
logits = tf.gather(logits, index)
labels = tf.gather(labels, index)
entropies = tf.nn.sigmoid_cross_entropy_with_logits(logits, labels)
loss = tf.reduce_mean(entropies)
Run Code Online (Sandbox Code Playgroud)