How to fix "Object arrays cannot be loaded when allow_pickle=False" in the sketch_rnn algorithm

Dun*_*rry 11 python machine-learning

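When I run sketch_rnn.ipynb in my Jupyter notebook, loading the environment to load the trained dataset returns the error "Object arrays cannot be loaded when allow_pickle=False".

This is code that the Google developers already used when developing the sketch_rnn algorithm, and it even runs in Google Colab. I have run it on Google Colab myself in the past, but it does not seem to work in my own Jupyter notebook.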

from magenta.models.sketch_rnn.sketch_rnn_train import *
from magenta.models.sketch_rnn.model import *
from magenta.models.sketch_rnn.utils import *
from magenta.models.sketch_rnn.rnn import * 

model_params.batch_size = 1
eval_model_params = sketch_rnn_model.copy_hparams(model_params)
eval_model_params.use_input_dropout = 0
eval_model_params.use_recurrent_dropout = 0
eval_model_params.use_output_dropout = 0
eval_model_params.is_training = 0
sample_model_params = sketch_rnn_model.copy_hparams(eval_model_params)
sample_model_params.max_seq_len = 1
return [model_params, eval_model_params, sample_model_params]


[train_set, valid_set, test_set, hps_model, eval_hps_model, 
sample_hps_model] = load_env_compatible(data_dir, model_dir)

I expected the output to be

INFO:tensorflow:Downloading http://github.com/hardmaru/sketch-rnn-datasets/raw/master/aaron_sheep/aaron_sheep.npz
INFO:tensorflow:Loaded 7400/300/300 from aaron_sheep.npz
INFO:tensorflow:Dataset combined: 8000 (7400/300/300), avg len 125
INFO:tensorflow:model_params.max_seq_len 250.
total images <= max_seq_len is 7400
total images <= max_seq_len is 300
total images <= max_seq_len is 300
INFO:tensorflow:normalizing_scale_factor 18.5198.

But instead it gives me

ValueError: Object arrays cannot be loaded when allow_pickle=False

Mad*_*mik 24

Pass allow_pickle=True as one of the arguments to np.load().
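As a rough illustration of this answer, the sketch below reproduces the error on a small object array and then loads it with allow_pickle=True. The file name, the array contents, and the "train" key are made up for the demo and are not taken from the sketch_rnn dataset code.

```python
import os
import tempfile

import numpy as np

# Hypothetical data: two stroke arrays of different lengths, which forces
# NumPy to store them in a dtype=object array (and therefore to pickle them).
sketches = np.empty(2, dtype=object)
sketches[0] = np.zeros((3, 3))
sketches[1] = np.zeros((5, 3))

with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, "demo.npz")
    np.savez(path, train=sketches)

    # With the default arguments, NumPy >= 1.16.3 refuses to unpickle
    # the object array when it is accessed.
    try:
        with np.load(path) as data:
            _ = data["train"]
        default_load_failed = False
    except ValueError as err:
        default_load_failed = True
        print("default load failed:", err)

    # Opting in explicitly lets the object array load.
    with np.load(path, allow_pickle=True) as data:
        loaded = data["train"]

print(loaded.shape, loaded[1].shape)
```

Only pass allow_pickle=True for files you trust, since unpickling can execute code.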

  • This is the simplest of all the approaches. It works so well because it addresses the problem exactly as intended. (2 upvotes)

Sal*_*ngo 14

This code fixed the problem on my end.

# Downgrade numpy to fix the problem
!pip install numpy==1.16.2
import numpy as np
print(np.__version__)

I simply downgraded numpy, because the problem was caused by some internal conflict.

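  • @DuncanJerry You're welcome. You can mark it as the accepted answer to help others find it quickly. (2 upvotes)
  • Changed in version 1.16.3: the default was made False in response to CVE-2019-6446. See: https://numpy.org/devdocs/reference/generated/numpy.load.html#numpy.load (2 upvotes)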
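To make the version boundary from the comment above concrete, here is a small sketch that checks whether a given NumPy version string falls on the side where np.load defaults to allow_pickle=False. The helper name pickle_disabled_by_default is hypothetical; in a notebook you would pass it np.__version__.

```python
import re


def pickle_disabled_by_default(numpy_version: str) -> bool:
    """Return True if this NumPy version defaults np.load to allow_pickle=False.

    The 1.16.3 cutoff comes from the NumPy documentation note on CVE-2019-6446.
    """
    match = re.match(r"(\d+)\.(\d+)\.(\d+)", numpy_version)
    if match is None:
        raise ValueError(f"unrecognized version string: {numpy_version!r}")
    return tuple(int(part) for part in match.groups()) >= (1, 16, 3)


print(pickle_disabled_by_default("1.16.2"))  # the downgraded version
print(pickle_disabled_by_default("1.16.3"))  # first version with the new default
```

This explains why the downgrade to 1.16.2 makes the error go away: that release still unpickles object arrays by default.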

Bry*_*n W 7

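So I believe this has only surfaced because of a change numpy made to load(). If you look at the line where the error occurs, it refers to something like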

    with np.load(path) as f:
        x_train, labels_train = f['x_train'], f['y_train']
        x_test, labels_test = f['x_test'], f['y_test']

But the Keras source code, for example at line 58 of https://github.com/keras-team/keras/blob/master/keras/datasets/imdb.py

now uses

    with np.load(path, allow_pickle=True) as f:
        x_train, labels_train = f['x_train'], f['y_train']
        x_test, labels_test = f['x_test'], f['y_test']

That is, np.load(path) becomes np.load(path, allow_pickle=True).

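From a brief read, the restriction on pickles appears to be security-related, because a pickle can contain arbitrary Python code that runs when it is loaded. (Possibly similar to how SQL injection is carried out.)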
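To see why loading a pickle counts as running code, consider this harmless sketch using only the standard pickle module. The class RunsOnLoad is made up for the demo; it asks the unpickler to call the builtin sum, but a malicious file could name any importable callable instead.

```python
import pickle


class RunsOnLoad:
    """Hypothetical class whose pickle tells the loader to call a function."""

    def __reduce__(self):
        # The pickle stores "call sum([1, 2, 3])"; the call happens at load time.
        return (sum, ([1, 2, 3],))


payload = pickle.dumps(RunsOnLoad())

# Loading does not rebuild a RunsOnLoad instance; it executes the stored call.
result = pickle.loads(payload)
print(result)  # 6
```

An .npy or .npz file holding an object array embeds a pickle like this, which is why NumPy now requires an explicit opt-in.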

After updating np.load with the new argument list, it worked for my project.
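When the np.load call lives inside library code you cannot edit (as with the magenta loader in the question), one possible workaround is to patch np.load temporarily. The following is only a sketch: the context manager np_load_allowing_pickle is a hypothetical helper, not part of NumPy or magenta, and it only affects code that looks up np.load through the numpy module at call time.

```python
import contextlib
import functools
import os
import tempfile

import numpy as np


@contextlib.contextmanager
def np_load_allowing_pickle():
    """Temporarily make np.load default to allow_pickle=True (hypothetical helper)."""
    original_load = np.load
    # partial only changes the default; an explicit allow_pickle=... still wins.
    np.load = functools.partial(original_load, allow_pickle=True)
    try:
        yield
    finally:
        # Always restore the original, even if loading raised.
        np.load = original_load


# Usage: the np.load(path) call below stands in for library code we cannot edit.
original_np_load = np.load
strokes = np.empty(1, dtype=object)
strokes[0] = [1, 2, 3]

with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, "strokes.npy")
    np.save(path, strokes)
    with np_load_allowing_pickle():
        restored = np.load(path)

patch_reverted = np.load is original_np_load
print(restored[0], patch_reverted)
```

In the notebook from the question, the call to load_env_compatible(data_dir, model_dir) would go inside the with block. Restoring np.load afterwards keeps the safer default for the rest of the session.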