5 python tensorflow tensorflow2.0
我正在实现简单的 RNN。在其中,我想返回列表中每个时间步的输出,稍后我可以将其提供给优化器。我已经建立了工作 rnn 没有@tf.function. 但添加后却@tf.function出现问题
def basic_rnn_cell(self,x,s):#Note :These function are defined in class
s=self.U*x+self.W*s+self.b #U,W,b and all are tf.Variable
y=self.V*s+self.c
return y,s
@tf.function
def rnn(self,X):
outputs=[]
state=self.state
for x in X:
output,state=self.basic_rnn_cell(x,state)
outputs.append(output)
return outputs
Run Code Online (Sandbox Code Playgroud)
这就是我的称呼:
x=np.array([0.01,0.02,0.03],dtype=np.float32)
o.rnn(x)
Run Code Online (Sandbox Code Playgroud)
我得到的错误:
raise errors.InaccessibleTensorError(
tensorflow.python.framework.errors_impl.InaccessibleTensorError: The tensor 'Tensor("while/add_2:0", shape=(), dtype=float32)' cannot be accessed here: it is defined in another function or code block.
Use return values, explicit Python locals or TensorFlow collections to access it. Defined in: FuncGraph(name=while_body_44, id=2538759416224); accessed from: FuncGraph(name=rnn, id=2538758824096).
Run Code Online (Sandbox Code Playgroud)
小智 11
这是因为使用python list临时保存张量对象。内存回收机制会在跟踪该函数后删除您保存的内容,因此无法实现。如果你想保存这些临时张量,你必须使用它tf.TensorArray作为替代。你可以参考这个: https: //www.tensorflow.org/guide/function#loops
| 归档时间: |
|
| 查看次数: |
9120 次 |
| 最近记录: |