Post by SON*_*ARK

TensorFlow confusion matrix with one-hot encoded labels

I'm using an RNN for multi-class classification. Here is the main part of my RNN code:

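import tensorflow as tf                # TensorFlow 1.x graph-mode API
from tensorflow.contrib import rnn

# X, Y, weights, biases, input_size, num_unit, lstm_size and learning_rate
# are defined elsewhere in the script.

def RNN(x, weights, biases):
    # Split the input along axis 1 into a list of input_size tensors,
    # one per time step, as static_rnn expects.
    x = tf.unstack(x, input_size, 1)
    # Create a separate LSTM cell per layer; [cell] * lstm_size reuses one
    # cell object for every layer, which TF 1.x typically rejects.
    cells = [rnn.BasicLSTMCell(num_unit, forget_bias=1.0, state_is_tuple=True)
             for _ in range(lstm_size)]
    stacked_lstm = rnn.MultiRNNCell(cells, state_is_tuple=True)
    outputs, states = tf.nn.static_rnn(stacked_lstm, x, dtype=tf.float32)

    # Project the last time step's output onto the class logits.
    return tf.matmul(outputs[-1], weights) + biases

logits = RNN(X, weights, biases)
prediction = tf.nn.softmax(logits)

cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=logits, labels=Y))
optimizer = tf.train.GradientDescentOptimizer(learning_rate=learning_rate)
train_op = optimizer.minimize(cost)

# Compare predicted class indices with the true ones (argmax of the one-hot label).
correct_pred = tf.equal(tf.argmax(prediction, 1), tf.argmax(Y, 1))
accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))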
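The `correct_pred` / `accuracy` lines compare class indices obtained with argmax. The sketch below mirrors that computation in plain Python rather than TensorFlow. The helper names and sample numbers are hypothetical. It also illustrates that softmax is monotonic, so applying it does not change the argmax. That means `tf.argmax(logits, 1)` should pick the same classes as `tf.argmax(prediction, 1)`.

```python
import math
from typing import List, Sequence


def softmax(logits: Sequence[float]) -> List[float]:
    """Numerically stable softmax over one row of logits."""
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def argmax(row: Sequence[float]) -> int:
    """Index of the largest entry in one row (the first one if there is a tie)."""
    return max(range(len(row)), key=lambda i: row[i])


def accuracy(logits_batch: Sequence[Sequence[float]],
             one_hot_labels: Sequence[Sequence[float]]) -> float:
    """Fraction of rows whose predicted class matches the one-hot label."""
    correct = [argmax(softmax(z)) == argmax(y)
               for z, y in zip(logits_batch, one_hot_labels)]
    return sum(correct) / len(correct)
```

For a batch where two of three rows have their largest logit at the labelled index, `accuracy` returns 2/3. This is the same quantity the `accuracy` op above averages over a batch.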

I need to classify every input into one of 6 classes, and each class label is one-hot encoded, like this:

happy = [1, 0, 0, 0, 0, 0]
angry = [0, 1, 0, 0, 0, 0] …
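A confusion matrix works on integer class indices rather than one-hot vectors. Both the labels and the softmax outputs therefore have to be reduced with argmax first. In TensorFlow 1.x that would be roughly `tf.confusion_matrix(tf.argmax(Y, 1), tf.argmax(prediction, 1), num_classes=6)`, or `tf.math.confusion_matrix` in later releases. It puts true classes on the rows and predicted classes on the columns.

The plain-Python sketch below mirrors that layout under those assumptions. The function names and sample data are hypothetical; only happy (index 0) and angry (index 1) come from the post.

```python
from typing import List, Sequence

NUM_CLASSES = 6  # the post's six classes; index 0 = happy, index 1 = angry


def argmax(row: Sequence[float]) -> int:
    """Index of the largest entry in one row (the first one if there is a tie)."""
    return max(range(len(row)), key=lambda i: row[i])


def confusion_matrix(one_hot_labels: Sequence[Sequence[float]],
                     probabilities: Sequence[Sequence[float]],
                     num_classes: int = NUM_CLASSES) -> List[List[int]]:
    """Count (true class, predicted class) pairs.

    Rows index the true class and columns the predicted class, the same
    layout tf.confusion_matrix uses.
    """
    matrix = [[0] * num_classes for _ in range(num_classes)]
    for label_row, prob_row in zip(one_hot_labels, probabilities):
        true_idx = argmax(label_row)  # one-hot label -> class index
        pred_idx = argmax(prob_row)   # softmax output -> predicted class index
        matrix[true_idx][pred_idx] += 1
    return matrix


if __name__ == "__main__":
    # Hypothetical batch: one happy sample and two angry samples.
    labels = [[1, 0, 0, 0, 0, 0],
              [0, 1, 0, 0, 0, 0],
              [0, 1, 0, 0, 0, 0]]
    probs = [[0.70, 0.10, 0.05, 0.05, 0.05, 0.05],
             [0.60, 0.20, 0.05, 0.05, 0.05, 0.05],
             [0.10, 0.80, 0.025, 0.025, 0.025, 0.025]]
    for row in confusion_matrix(labels, probs):
        print(row)
```

Entry `[i][j]` counts samples of class i predicted as class j. Correct predictions therefore sit on the diagonal, and the trace divided by the number of samples equals the accuracy computed by the `accuracy` op.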

confusion-matrix tensorflow one-hot-encoding multiclass-classification
