Rez*_*ian 6 methods call lstm tensorflow recurrent-neural-network
我知道 __call__ 是什么,但让我困惑的是,像 BasicRNNCell 或 tf.nn.rnn_cell.MultiRNNCell 这样的类有这个 \'call\' 方法而不是 _call__ 。这个简单的调用方法是什么?看起来像是同一件事,如果不是,那么我没有看到它被调用。\n我在没有任何线索的情况下找到了这个解释。你能澄清一下吗?
\n\n“调用函数是单元逻辑所在的位置。RNNCell\xe2\x80\x99s __call_ 方法将包装您的调用方法并帮助确定范围和其他后勤工作。”\n示例:
\n\ndef call(self, inputs, state):\n\n total_hidden_size = sum(c._h_above_size for c in self._cells)\n\n # split out the part of the input that stores values of ha\n raw_inp = inputs[:, :-total_hidden_size] # [B, I]\n raw_h_aboves = inputs[:, -total_hidden_size:] # [B, sum(ha_l)]\n\n ha_splits = [c._h_above_size for c in self._cells]\n h_aboves = array_ops.split(value=raw_h_aboves,\n num_or_size_splits=ha_splits, axis=1)\n\n z_below = tf.ones([tf.shape(inputs)[0], 1]) # [B, 1]\n raw_inp = array_ops.concat([raw_inp, z_below], axis=1) # [B, I + 1]\nRun Code Online (Sandbox Code Playgroud)\n
小智 3
在tensorflow2.0中,如果通过子类化tf.keras.Model来定义网络,则需要在call()中实现模型的前向传递。
https://www.tensorflow.org/api_docs/python/tf/keras/Model
| 归档时间: |
|
| 查看次数: |
1875 次 |
| 最近记录: |