现在基于tensorflow-char-rnn我开始一个word-rnn项目来预测下一个单词.但我发现我的火车数据集的速度太慢了.这是我的培训细节:
机器细节:
在我的测试中,训练数据1个时代的时间需要17天!这真的太慢了,然后我将seq2seq.rnn_decoder更改为tf.nn.dynamic_rnn,但时间仍然是17天.
我想找到太慢的原因是由我的代码引起的,或者它总是那么慢?因为我听到一些传言称Tensorflow rnn比其他DL Framework慢.
这是我的型号代码:
class SeqModel():
def __init__(self, config, infer=False):
self.args = config
if infer:
config.batch_size = 1
config.seq_length = 1
if config.model == 'rnn':
cell_fn = rnn_cell.BasicRNNCell
elif config.model == 'gru':
cell_fn = rnn_cell.GRUCell
elif config.model == 'lstm':
cell_fn = rnn_cell.BasicLSTMCell
else:
raise Exception("model type not supported: {}".format(config.model))
cell = cell_fn(config.hidden_size)
self.cell = cell = rnn_cell.MultiRNNCell([cell] * config.num_layers)
self.input_data = tf.placeholder(tf.int32, …Run Code Online (Sandbox Code Playgroud)