Jon*_*n E 4 python lstm keras tensorflow
在运行加密货币RNN的senddex教程脚本时,请在此处链接
但尝试训练模型时遇到错误。我的tensorflow版本是2.0.0,我正在运行python 3.6。尝试训练模型时,出现以下错误:
File "C:\python36-64\lib\site-packages\tensorflow_core\python\keras\engine\training.py", line 734, in fit
use_multiprocessing=use_multiprocessing)
File "C:\python36-64\lib\site-packages\tensorflow_core\python\keras\engine\training_v2.py", line 224, in fit
distribution_strategy=strategy)
File "C:\python36-64\lib\site-packages\tensorflow_core\python\keras\engine\training_v2.py", line 497, in _process_training_inputs
adapter_cls = data_adapter.select_data_adapter(x, y)
File "C:\python36-64\lib\site-packages\tensorflow_core\python\keras\engine\data_adapter.py", line 628, in select_data_adapter
_type_name(x), _type_name(y)))
ValueError: Failed to find data adapter that can handle input: <class 'numpy.ndarray'>, (<class 'list'> containing values of types {"<class 'numpy.float64'>"})
Run Code Online (Sandbox Code Playgroud)
任何建议将不胜感激!
小智 36
您可以通过在调用之前将标签转换为数组来避免此错误model.fit()
:
train_x = np.asarray(train_x)
train_y = np.asarray(train_y)
validation_x = np.asarray(validation_x)
validation_y = np.asarray(validation_y)
Run Code Online (Sandbox Code Playgroud)
小智 9
Have you checked whether your training/testing data and training/testing labels are all numpy arrays? It might be that you're mixing numpy arrays with lists.
如果在处理从类继承的自定义生成器时遇到此问题,则keras.utils.Sequence
可能必须确保不要混合使用 aKeras
或tensorflow - Keras
-import。
当您必须切换到以前的tensorflow
版本以实现兼容性(例如使用cuDNN
)时,这种情况尤其可能发生。
例如,如果您将它与tensorflow
-version > 2 ...
from keras.utils import Sequence
class generatorClass(Sequence):
def __init__(self, x_set, y_set, batch_size):
...
def __len__(self):
...
def __getitem__(self, idx):
return ...
Run Code Online (Sandbox Code Playgroud)
...但您实际上尝试将此生成器安装在tensorflow
-version < 2 中,您必须确保Sequence
从此版本导入-class,例如:
keras = tf.compat.v1.keras
Sequence = keras.utils.Sequence
class generatorClass(Sequence):
...
Run Code Online (Sandbox Code Playgroud)
我有一个类似的问题。在我的情况下,我使用的是tf.keras.Sequential
模型而不是keras
生成器是一个问题。
错误的:
from keras.preprocessing.sequence import TimeseriesGenerator
gen = TimeseriesGenerator(...)
Run Code Online (Sandbox Code Playgroud)
正确的:
gen = tf.keras.preprocessing.sequence.TimeseriesGenerator(...)
Run Code Online (Sandbox Code Playgroud)