小编aev*_*ert的帖子

通过TensorFlow中的嵌入包装器提供数据

我正在研究文本摘要网络,需要实现一个编码器来使用tf.nn.seq2seq.embedding_attention_decoder.作为其中的一部分,我需要将不同批次的序列编码为表示向量,但最内层编码不会通过.

这是一个简化的片段,给出了同样的错误:

import tensorflow as tf                                               

single_cell = tf.nn.rnn_cell.GRUCell(1024)                            
sentence_cell = tf.nn.rnn_cell.EmbeddingWrapper(single_cell,
                                                embedding_classes = 40000)                                                     
batch = [tf.placeholder(tf.int32, [1,1]) for _ in range(250)]       
(_ , state) = tf.nn.rnn(sentence_cell, batch, dtype= tf.int32)
Run Code Online (Sandbox Code Playgroud)

这会因以下堆栈跟踪而失败:

Traceback (most recent call last):                                                                                                                                                                                                     
  File "/home/ubuntu/workspace/example.py", line 6, in <module>                                                                                                                                                                        
    (_ , state) = tf.nn.rnn(sentence_cell, batch, dtype= tf.int32)                                                                                                                                                                     
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/ops/rnn.py", line 126, in rnn                                                                                                                                         
    (output, state) = call_cell()                                                                                                                                                                                                      
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/ops/rnn.py", line 119, in <lambda>                                                                                                                                    
    call_cell = lambda: cell(input_, state)                                                                                                                                                                                            
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/ops/rnn_cell.py", line 616, in __call__ …
Run Code Online (Sandbox Code Playgroud)

python tensorflow

5
推荐指数
0
解决办法
2196
查看次数

标签 统计

python ×1

tensorflow ×1