Rnn神经网络预测返回意外预测

Mar*_*tin 5 java deep-learning deeplearning4j rnn

我正在试图配置RNN神经网络,以预测5种不同类型的文本实体.我正在使用下一个配置:

    MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
            .seed(seed)
            .iterations(100)
            .updater(Updater.ADAM)  //To configure: .updater(Adam.builder().beta1(0.9).beta2(0.999).build())
            .regularization(true).l2(1e-5)
            .weightInit(WeightInit.XAVIER)
            .gradientNormalization(GradientNormalization.ClipElementWiseAbsoluteValue).gradientNormalizationThreshold(1.0)
            .learningRate(2e-2)
            .trainingWorkspaceMode(WorkspaceMode.SEPARATE).inferenceWorkspaceMode(WorkspaceMode.SEPARATE)   //https://deeplearning4j.org/workspaces
            .list()
            .layer(0, new GravesLSTM.Builder().nIn(500).nOut(3)
                    .activation(Activation.TANH).build())
            .layer(1, new RnnOutputLayer.Builder(LossFunctions.LossFunction.MCXENT).activation(Activation.SOFTMAX)        //MCXENT + softmax for classification
                    .nIn(3).nOut(5).build())
            .pretrain(false).backprop(true).build();
  MultiLayerNetwork net = new MultiLayerNetwork(conf);
  net.init();
Run Code Online (Sandbox Code Playgroud)

我训练它然后我评估它.有用.不过我用的时候:

 int[] prediction = net.predict(features);
Run Code Online (Sandbox Code Playgroud)

有时它会回归并出现意想不到的预测.它返回正确的预测为1,2 .... 5但有时它返回数字为9,14,12 ...这个数字不对应于已识别的预测/标签.

为什么此配置会返回意外输出?

Ada*_*son 2

不要使用 net.predict。将 net.output 与 Nd4j.argMax(outputOfNeuralNet,-1) 结合使用;不应使用 Net.predict(它主要与 2d 一起使用)。