如何在theano中保存/序列化训练有素的模型?

xag*_*agg 8 python serialization loading save theano

我按照加载和保存的文档保存了模型.

# saving trained model
f = file('models/simple_model.save', 'wb')
cPickle.dump(ca, f, protocol=cPickle.HIGHEST_PROTOCOL)
f.close()
Run Code Online (Sandbox Code Playgroud)

ca是一个训练有素的自动编码器.这是一个类的实例cA.从我构建的脚本和保存模型我可以调用ca.get_reconstructed_input(...),ca.get_hidden_values(...)没有任何问题.

在另一个脚本中,我尝试加载训练的模型.

# loading the trained model
model_file = file('models/simple_model.save', 'rb')
ca = cPickle.load(model_file)
model_file.close()
Run Code Online (Sandbox Code Playgroud)

我收到以下错误.

ca = cPickle.load(model_file)
Run Code Online (Sandbox Code Playgroud)

AttributeError:'module'对象没有属性'cA'

Dan*_*haw 12

执行unpickling的脚本需要知道pickle对象的所有类定义.其他StackOverflow问题还有更多内容(例如,AttributeError:'module'对象没有属性'newperson').

只要您正确导入,您的代码就是正确的cA.鉴于你得到的错误可能不是这样.确保你正在使用from cA import cA而不仅仅是import cA.

或者,您的模型由其参数定义,因此您可以只选择参数值.这可以通过两种方式完成,具体取决于您的观点.

  1. 保存Theano共享变量.这里我们假设这ca.params是Theano共享变量实例的常规Python列表.

    cPickle.dump(ca.params, f, protocol=cPickle.HIGHEST_PROTOCOL)
    
    Run Code Online (Sandbox Code Playgroud)
  2. 保存存储在Theano共享变量中的numpy数组.

    cPickle.dump([param.get_value() for param in ca.params], f, protocol=cPickle.HIGHEST_PROTOCOL)
    
    Run Code Online (Sandbox Code Playgroud)

如果要加载模型,则需要重新初始化参数.例如,然后创建cA该类的新实例

ca.params = cPickle.load(f)
ca.W, ca.b, ca.b_prime = ca.params
Run Code Online (Sandbox Code Playgroud)

要么

ca.params = [theano.shared(param) for param in cPickle.load(f)]
ca.W, ca.b, ca.b_prime = ca.params
Run Code Online (Sandbox Code Playgroud)

请注意,您需要设置params字段和单独的参数字段.