Keras: what to do if the data size is not divisible by batch_size?

aha*_*jib 9 theano deep-learning keras

I'm new to Keras and have just started working through some examples. I'm running into the following problem: I have 4032 samples, use roughly 650 of them for fitting (basically, training the stateful model), and then use the rest to test the model. The problem is that I keep getting this error:

Exception: In a stateful network, you should only pass inputs with a number of samples that can be divided by the batch size.

I understand why I'm getting this error. My question is: what if my data size is not divisible by batch_size? I used to work with LSTMs in Deeplearning4j, and there was no need to deal with this. Is there any way around it?
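To make the constraint concrete: a stateful Keras network requires the number of samples to be an exact multiple of the batch size. A quick sketch, using the 650 training samples from the question and a few assumed (illustrative) batch sizes:

```python
# Check which candidate batch sizes divide the training set evenly.
# 650 and the candidate sizes are illustrative numbers, not Keras defaults.
num_train = 650
for batch_size in (16, 25, 32, 64):
    remainder = num_train % batch_size
    if remainder == 0:
        print(batch_size, "divides evenly")
    else:
        print(batch_size, "leaves", remainder, "samples left over")
```

With these numbers, only a batch size of 25 divides 650 evenly, so any other choice would trigger the error above unless the data is padded or served by a generator.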

Thanks

Pro*_*ies 5

The simplest solution is to use fit_generator instead of fit. I wrote a simple data-loader class that you can subclass for more complex behavior. It looks something like the following; redefine get_next_batch_data for your data to include things like augmentation.

class BatchedLoader():
    def __init__(self, batchsize):
        self.possible_indices = list(range(34))  # [0, 1, 2, ..., N] (say N = 33)
        self.batchsize = batchsize
        self.cur_it = 0
        self.cur_epoch = 0

    def get_batch_indices(self):
        batch_indices = self.possible_indices[self.cur_it : self.cur_it + self.batchsize]
        self.cur_it += self.batchsize
        if len(batch_indices) < self.batchsize:
            # You've reached the end: increase cur_epoch
            # (and shuffle possible_indices if wanted), then top the batch
            # up with the remaining K = batchsize - len(batch_indices)
            # indices from the start, so every batch is full.
            self.cur_epoch += 1
            k = self.batchsize - len(batch_indices)
            batch_indices += self.possible_indices[:k]
            self.cur_it = k
        return batch_indices

    def get_next_batch_data(self):
        batch_indices = self.get_batch_indices()
        # The data points corresponding to those indices are your next batch
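As a standalone sketch of the wrap-around idea above (names and the toy sizes of 10 samples / batch size 4 are illustrative, not part of Keras): a generator like this always yields full batches, even when the dataset size is not divisible by the batch size, which is exactly what fit_generator needs.

```python
# Wrap-around batcher: every yielded batch has exactly `batchsize` indices,
# even though 10 samples do not divide evenly into batches of 4.
def batch_generator(num_samples, batchsize):
    indices = list(range(num_samples))
    cur = 0
    while True:
        batch = indices[cur : cur + batchsize]
        cur += batchsize
        if len(batch) < batchsize:
            # End of epoch: wrap around and top the batch up from the start.
            k = batchsize - len(batch)
            batch += indices[:k]
            cur = k
        yield batch

gen = batch_generator(10, 4)
batches = [next(gen) for _ in range(3)]
# The third batch wraps past the end: [8, 9, 0, 1]
```

In real use you would index your arrays with each yielded batch and pass the generator to model.fit_generator (or, in newer Keras versions, model.fit), so the model itself only ever sees full batches.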