Vib*_*wal · python, hdf5, large-data, keras
I'm having trouble writing a data_generator to use with fit_generator in Keras. I have an HDF5 file containing 4-dimensional numpy arrays stored as separate datasets (3-D data, with one extra dimension added by processing).
Each dataset has shape (xxx, 512, 512, 1), where xxx is the number of slices in that particular dataset. I have a large amount of data (500 3-D images with ~300 slices each, ~50 GB in total, far more than the RAM I have).
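For data laid out like this, h5py can slice a dataset lazily, reading only the requested slices from disk. A minimal sketch (the file and dataset names here are illustrative, not the ones from the question):

```python
import numpy as np
import h5py

# Build a small HDF5 file shaped like the one described: one 4-D
# dataset per volume, (n_slices, 512, 512, 1). Names are made up.
with h5py.File("example.hdf5", "w") as f:
    f.create_dataset("vol_000", data=np.zeros((8, 512, 512, 1), dtype=np.float32))

# h5py datasets support lazy slicing: indexing the dataset object
# reads only that slab from disk, so the full file never sits in RAM.
with h5py.File("example.hdf5", "r") as f:
    dset = f["vol_000"]   # handle only; no data loaded yet
    batch = dset[2:4]     # reads just slices 2-3 into memory
    print(batch.shape)    # (2, 512, 512, 1)
```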
When I run the code, it prints

Epoch 1/250

and then RAM starts filling up until I get a memory error. How can I write a data_generator for this? What other approaches could I look into?

Here is my code:
import scipy.io as sio
import numpy as np
import os
import pandas
import pickle
import random
import keras
from keras.models import Sequential, Model
from keras.layers import Conv2D, MaxPool2D, GlobalAveragePooling2D
from keras.layers import Dense, Dropout, Flatten, Activation
from keras.utils import np_utils
from keras.applications.resnet50 import ResNet50
from keras.preprocessing import image
from keras.applications.resnet50 import preprocess_input, decode_predictions
from keras.backend import tf as ktf
import h5py

with open('categs/database.pickle', 'rb') as f:
    a = pickle.load(f)

def generate_data(a, batch_size):
    while True:
        key = random.choice(list(a.keys()))
        # X_train = np.load(key + '.npy')
        with open('categs/' + key + '.pickle', 'rb') as labels, h5py.File('datafile.hdf5', mode='r') as f:
            X_train = np.asarray(f[key])
            Y_train = pickle.load(labels)
            X_train1 = np.zeros((X_train.shape[0], X_train.shape[1], X_train.shape[2], 3))
            for i in range(1, X_train.shape[0] - 1):
                X_train1[i, :, :, 1] = X_train[i, :, :, 0]
                X_train1[i, :, :, 0] = X_train[i - 1, :, :, 0]
                X_train1[i, :, :, 2] = X_train[i + 1, :, :, 0]
            X_train = X_train1
            X_train1 = []
            Y_train = np_utils.to_categorical(Y_train)
            yield X_train, Y_train

base_model = ResNet50(include_top=False, weights='imagenet', input_tensor=None, input_shape=(512, 512, 3), pooling=None)
x = base_model.output
x = GlobalAveragePooling2D()(x)
x = Dense(1024, activation='relu')(x)
predictions = Dense(7, activation='softmax')(x)
model = Model(inputs=base_model.input, outputs=predictions)

for layer in base_model.layers:
    layer.trainable = False

model.compile(loss='categorical_crossentropy',
              optimizer='adam',
              metrics=['accuracy'])

model.fit_generator(generator=generate_data(a, batch_size=1), steps_per_epoch=1, epochs=250, workers=6)
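The adjacent-slice stacking loop in the generator (previous/current/next slice into channels 0/1/2) can also be written with numpy slicing, with no Python loop; a sketch of the same transform, assuming the (n, 512, 512, 1) layout from the question:

```python
import numpy as np

def stack_neighbors(x):
    """Turn (n, H, W, 1) grayscale slices into (n, H, W, 3) by placing
    the previous/current/next slice in channels 0/1/2. The first and
    last output slices stay all-zero, mirroring the loop above."""
    out = np.zeros(x.shape[:3] + (3,), dtype=x.dtype)
    out[1:-1, :, :, 0] = x[:-2, :, :, 0]   # previous slice
    out[1:-1, :, :, 1] = x[1:-1, :, :, 0]  # current slice
    out[1:-1, :, :, 2] = x[2:, :, :, 0]    # next slice
    return out
```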
I also tried saving each 3-D file as a separate numpy array and reading the data from disk at random, but the same thing happens.
Viewed: 805 times