我已经为多输入 nn 编写了一个生成器,但是在使用 tf.data.Dataset.from_generator() 函数时出现错误,所有数据都在 numpy 中,其中:输入 1 的形状为(16,100,223,3),输入 2 的形状shape(100,223),输入 3 的形状为 (16,),标签的形状为(2,)。数据是所有这些组合的数组
我的代码
def data_generator(train_list, batch_size):
i = 0
j = 0
flag = True
while True:
# inputs = []
# outputs = []
if i < len(train_list):
if flag == True:
train_path = os.path.join(training_dir, train_list[i])
data = np.load(train_path, allow_pickle=True)
flag = False
if j >= len(data):
j = 0
i += 1
flag = True
del data
else:
if len(data[j:]) >= batch_size:
input_1 = data[j:(j+batch_size), 0] …Run Code Online (Sandbox Code Playgroud)