sknn - input dimension mismatch on the second fit

Asked by see*_*equ (score: 6). Tags: python, reinforcement-learning, scikit-learn

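I'm trying to build a neural network that uses reinforcement learning. I chose scikit-neuralnetwork as the library (because it's simple). It seems that fitting twice makes Theano crash.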

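Here is the simplest code that reproduces the crash (note that it doesn't depend on which layers are used, and neither the learning rate nor n_iter matters):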

import numpy as np
from sknn.mlp import Classifier, Layer

clf = Classifier(
    layers=[
        Layer("Softmax")
        ],
    learning_rate=0.001,
    n_iter=1)

clf.fit(np.array([[0.]]), np.array([[0.]])) # Initialize the network for learning

X = np.array([[-1.], [1.]])
Y = np.array([[1.], [0.]])

clf.fit(X, Y) # crash

This is the error I get:

ValueError: Input dimension mis-match. (input[0].shape[1] = 2, input[1].shape[1] = 1)
Apply node that caused the error: Elemwise{Mul}[(0, 1)](y, LogSoftmax.0)
Toposort index: 12
Inputs types: [TensorType(float64, matrix), TensorType(float64, matrix)]
Inputs shapes: [(1L, 2L), (1L, 1L)]
Inputs strides: [(16L, 8L), (8L, 8L)]
Inputs values: [array([[ 1.,  0.]]), array([[ 0.]])]
Outputs clients: [[Sum{axis=[1], acc_dtype=float64}(Elemwise{Mul}[(0, 1)].0)]]
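One way to read the shapes in the traceback: the target y enters the loss with 2 columns, while the LogSoftmax output has only 1. The minimal NumPy sketch below assumes that the classifier one-hot encodes its targets with one column per distinct label present in the data passed to fit. That is an assumption about sknn's internals, not something confirmed here, and one_hot is a hypothetical helper. Under that assumption, the two calls in the question would produce targets of different widths:

```python
import numpy as np

def one_hot(labels):
    """Hypothetical helper: one column per distinct label, sorted ascending."""
    labels = np.asarray(labels).ravel()
    classes = np.unique(labels)
    # Compare every label against every class to build a 0/1 matrix.
    return (labels[:, None] == classes[None, :]).astype(float)

# Labels from the two fit() calls in the question.
y_first = np.array([[0.]])
y_second = np.array([[1.], [0.]])

enc_first = one_hot(y_first)    # only class 0 seen, so 1 column
enc_second = one_hot(y_second)  # classes 0 and 1 seen, so 2 columns

print(enc_first.shape)   # (1, 1)
print(enc_second.shape)  # (2, 2)
# The encoded row for label 0 is [1, 0]; a single row has shape (1, 2).
```

If that assumption holds, an output layer sized during the first fit would have a single unit, matching the (1, 1) LogSoftmax shape in the error, while one row of the second encoding matches the (1, 2) shape of y.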

Tested on Python 2.7.11.

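Does sknn not support fitting more than once, or am I making some silly mistake? If it doesn't support it, how are you supposed to implement reinforcement learning?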

Answer by pim*_*314 (score: 1)

I don't use sknn often, but it is very similar to sklearn, so maybe I can help!

First, calling the fit method reinitializes the weights. If you want to update the weights based on new data, use the partial_fit method instead.
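The difference between the two methods can be illustrated without sknn. The sketch below is a hypothetical, minimal NumPy softmax classifier (TinySoftmax and its n_updates_ counter are invented for illustration; this is not sknn's API or implementation). Here fit() reinitializes the weights before training, while partial_fit() initializes them only on the first call and then keeps updating them. The number of output classes is fixed in the constructor, so a batch that happens to contain only one class does not change the size of the output layer.

```python
import numpy as np

class TinySoftmax(object):
    """Hypothetical minimal softmax classifier illustrating fit vs. partial_fit."""

    def __init__(self, n_features, n_classes, learning_rate=0.1, seed=0):
        self.n_features = n_features
        self.n_classes = n_classes  # fixed up front: output width never changes
        self.learning_rate = learning_rate
        self.rng = np.random.RandomState(seed)
        self.W = None
        self.b = None
        self.n_updates_ = 0  # how many gradient steps the current weights have seen

    def _init_weights(self):
        self.W = self.rng.normal(scale=0.01, size=(self.n_features, self.n_classes))
        self.b = np.zeros(self.n_classes)
        self.n_updates_ = 0

    def _step(self, X, y):
        # One gradient step on the mean cross-entropy loss.
        logits = X.dot(self.W) + self.b
        logits -= logits.max(axis=1, keepdims=True)  # numerical stability
        probs = np.exp(logits)
        probs /= probs.sum(axis=1, keepdims=True)
        targets = np.eye(self.n_classes)[y]  # one-hot, always n_classes wide
        grad = (probs - targets) / len(X)
        self.W -= self.learning_rate * X.T.dot(grad)
        self.b -= self.learning_rate * grad.sum(axis=0)
        self.n_updates_ += 1

    def fit(self, X, y):
        self._init_weights()  # start over: previous training is discarded
        self._step(X, y)
        return self

    def partial_fit(self, X, y):
        if self.W is None:  # initialize only on the very first call
            self._init_weights()
        self._step(X, y)
        return self


# Usage: the first call sees only class 0, the second sees classes 0 and 1.
clf = TinySoftmax(n_features=1, n_classes=2)
clf.partial_fit(np.array([[0.]]), np.array([0]))
clf.partial_fit(np.array([[-1.], [1.]]), np.array([1, 0]))
print(clf.n_updates_)  # 2: the second call continued from the first
clf.fit(np.array([[-1.], [1.]]), np.array([1, 0]))
print(clf.n_updates_)  # 1: fit() threw the earlier updates away
```

Declaring the full set of classes before the first incremental update is also how scikit-learn's partial_fit methods (for example SGDClassifier.partial_fit) work: they require the classes argument on the first call.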

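As for the crash, it happens because the X arrays you pass have different shapes in the first dimension (the number of rows), not in the second.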

from __future__ import print_function  # same print() output on Python 2 and 3

import numpy as np
from sknn.mlp import Classifier, Layer

clf = Classifier(
    layers=[
        Layer("Softmax")
        ],
    learning_rate=0.001,
    n_iter=1)

# Original training data (only printed for comparison; overwritten below)
X = np.array([[0.]])
Y = np.array([[0.]])
print(X.shape, Y.shape)

# Data used for the second fit
X = np.array([[-1.], [1.]])
Y = np.array([[1.], [0.]])
print(X.shape, Y.shape)


# Use the partial_fit method to update the weights
clf.partial_fit(X, Y) # Initialize the network for learning
clf.partial_fit(X, Y) # Update the weights


# Build a larger batch by stacking the data on top of itself
X = np.concatenate((X, X))
Y = np.concatenate((Y, Y))
print(X.shape, Y.shape)

clf.partial_fit(X, Y)

Output:

(1, 1) (1, 1)
(2, 1) (2, 1)
(4, 1) (4, 1)