aha*_*jib 9 theano deep-learning keras
我是Keras的新手,刚刚开始研究一些例子.我正在处理以下问题:我有4032个样本并且使用大约650个样本用于拟合或基本上训练状态,然后使用其余的来测试模型.问题是我不断收到以下错误:
Exception: In a stateful network, you should only pass inputs with a number of samples that can be divided by the batch size.
Run Code Online (Sandbox Code Playgroud)
我理解为什么我会收到此错误,我的问题是,如果我的数据大小不能被整除batch_size怎么办?我曾经和Deeplearning4j LSTM一起工作,没有必要处理这个问题.反正有没有解决这个问题?
谢谢
最简单的解决方案是使用 fit_generator 而不是 fit。我写了一个简单的数据加载器类,可以继承它来做更复杂的事情。它看起来像这样,将 get_next_batch_data 重新定义为您的数据包括增强等内容。
class BatchedLoader():
def __init__(self):
self.possible_indices = [0,1,2,...N] #(say N = 33)
self.cur_it = 0
self.cur_epoch = 0
def get_batch_indices(self):
batch_indices = self.possible_indices [cur_it : cur_it + batchsize]
# If len(batch_indices) < batchsize, the you've reached the end
# In that case, reset cur_it to 0 and increase cur_epoch and shuffle possible_indices if wanted
# And add remaining K = batchsize - len(batch_indices) to batch_indices
def get_next_batch_data(self):
# batch_indices = self.get_batch_indices()
# The data points corresponding to those indices will be your next batch data
Run Code Online (Sandbox Code Playgroud)
| 归档时间: |
|
| 查看次数: |
2488 次 |
| 最近记录: |