pybrain:如何打印网络(节点和权重)

Dr *_*ban 11 python neural-network pybrain

最后我设法从一个文件训练一个网络:)现在我想打印节点和权重,特别是权重,因为我想用pybrain训练网络,然后在其他地方实现NN将使用它.

我需要一种方法来打印节点之间的层,节点和权重,以便我可以轻松地复制它.到目前为止,我看到我可以使用n ['in']来访问图层,然后例如我可以这样做:

dir(n ['in'])[' class ',' delattr ',' dict ',' doc ',' format ',' getattribute ',' hash ',' init ',' module ',' new ', ' reduce ',' reduce_ex ',' repr ',' setattr ',' sizeof ',' str ',' subclasshook ',' weakref ','_ backwardImplementation','_ forwardImplementation','_ generateName','_ getName','_ growBuffers ','_ name','_ nameIds','_resetBuffers','_ setName','activate','activateOnDataset','argdict','backActivate','backward','bufferlist','dim','forward', 'getName','indim','inputbuffer','inputerror','name','offset','outdim','outputbuffer','outputerror','paramdim','reset','sequential','setArgs ','setName','shift','whichNeuron']

但我不知道如何在这里访问权重.还有params属性,例如我的网络是2 4 1有偏见,它说:

n.params array([ - 0.8167133,1.00077451,-0.7591257,-1.1150532,-1.58789386,0.11625991,0.98547457,-0.99397871,-1.8324281,-2.42200963,1.90617387,1.93741167,-2.88433965,0.27449852,-1.52606976,2.339446258,3.01359547])

很难说是什么,至少在重量连接哪些节点.这就是我所需要的一切.

sch*_*aul 21

有许多方法可以访问网络的内部,即通过其"模块"列表或其"连接"字典.参数存储在这些连接或模块中.例如,以下内容应打印任意网络的所有此信息:

for mod in net.modules:
    print("Module:", mod.name)
    if mod.paramdim > 0:
        print("--parameters:", mod.params)
    for conn in net.connections[mod]:
        print("-connection to", conn.outmod.name)
        if conn.paramdim > 0:
             print("- parameters", conn.params)
    if hasattr(net, "recurrentConns"):
        print("Recurrent connections")
        for conn in net.recurrentConns:
            print("-", conn.inmod.name, " to", conn.outmod.name)
            if conn.paramdim > 0:
                print("- parameters", conn.params)
Run Code Online (Sandbox Code Playgroud)

如果你想要更细粒度的东西(在神经元级而不是层级),你将不得不进一步分解这些参数向量 - 或者,从单神经元层构建你的网络.


小智 11

试试这个,它对我有用:

def pesos_conexiones(n):
    for mod in n.modules:
        for conn in n.connections[mod]:
            print conn
            for cc in range(len(conn.params)):
                print conn.whichBuffers(cc), conn.params[cc]
Run Code Online (Sandbox Code Playgroud)

结果应该是:

<FullConnection 'co1': 'hidden1' -> 'out'>
(0, 0) -0.926912942354
(1, 0) -0.964135087592
<FullConnection 'ci1': 'in' -> 'hidden1'>
(0, 0) -1.22895643048
(1, 0) 2.97080368887
(2, 0) -0.0182867906276
(3, 0) 0.4292544603
(4, 0) 0.817440427069
(0, 1) 1.90099230604
(1, 1) 1.83477578625
(2, 1) -0.285569867513
(3, 1) 0.592193396226
(4, 1) 1.13092061631
Run Code Online (Sandbox Code Playgroud)