在尝试使用 tf.data 配合 Keras API 批量生成数据时,我遇到了一个奇怪的问题:程序不断抛出错误,提示 training_data(训练数据)已经用完。
TensorFlow 2.1
import numpy as np
import nibabel
import tensorflow as tf
from tensorflow.keras.layers import Conv3D, MaxPooling3D
from tensorflow.keras.layers import Dense
from tensorflow.keras.layers import Dropout
from tensorflow.keras.layers import Flatten
from tensorflow.keras import Model
import os
import random
"""Configure GPUs to prevent OOM errors"""
gpus = tf.config.experimental.list_physical_devices('GPU')
for gpu in gpus:
    # Turn on memory growth for each detected GPU.
    # NOTE(review): TensorFlow expects this to be set before the GPUs are
    # initialized, so keep this loop ahead of any model/tensor creation.
    tf.config.experimental.set_memory_growth(gpu, True)

# Retrieve file names
ad_files = os.listdir("/home/asdf/OASIS/3D/ad/")
cn_files = os.listdir("/home/asdf/OASIS/3D/cn/")

# Accumulators for the "ad" and "cn" groups; both start empty here.
sub_id_ad = []
sub_id_cn = []
"""OASIS AD: 178 Subjects, 278 3T MRIs"""
"""OASIS …Run Code Online (Sandbox Code Playgroud)