Art*_*rin 34 python deep-learning tensorflow tf.keras tensorflow2.x
假设我想编写一个符合tf.kerasAPI的自定义优化器类(使用 TensorFlow version>=2.0)。我对记录在案的方法与实现中的方法感到困惑。
tf.keras.optimizers.Optimizer 状态的文档,
### Write a customized optimizer.
If you intend to create your own optimization algorithm, simply inherit from
this class and override the following methods:
- resource_apply_dense (update variable given gradient tensor is dense)
- resource_apply_sparse (update variable given gradient tensor is sparse)
- create_slots (if your optimizer algorithm requires additional variables)
Run Code Online (Sandbox Code Playgroud)
不过,目前的tf.keras.optimizers.Optimizer实现没有定义resource_apply_dense方法,但它确实定义了一个私人的前瞻性_resource_apply_dense方法存根。同样,没有resource_apply_sparseorcreate_slots方法,但有_resource_apply_sparse方法存根和_create_slots方法调用。
在官方tf.keras.optimizers.Optimizer亚类(使用tf.keras.optimizers.Adam作为一个例子),还有_resource_apply_dense,_resource_apply_sparse和_create_slots方法,并且不存在这样的方法没有前导下划线。
在不太正式的tf.keras.optimizers.Optimizer子类中也有类似的前导下划线方法(例如,tfa.optimizers.MovingAverage来自 TensorFlow Addons: _resource_apply_dense, _resource_apply_sparse, _create_slots)。
另一个让我困惑的地方是,一些 TensorFlow Addons 优化器也会覆盖该apply_gradients方法(例如,tfa.optimizers.MovingAverage),而tf.keras.optimizers优化器不会。
而且,我注意到方法调用apply_gradients的tf.keras.optimizers.Optimizer方法,但基类没有方法。因此,似乎必须在优化器子类中定义一个方法,如果该子类没有覆盖._create_slotstf.keras.optimizers.Optimizer_create_slots_create_slotsapply_gradients
子类 a 的正确方法是tf.keras.optimizers.Optimizer什么?具体来说,
tf.keras.optimizers.Optimizer顶部列出的文档是否仅仅意味着覆盖他们提到的方法的前导下划线版本(例如,_resource_apply_dense而不是resource_apply_dense)?如果是这样,是否有任何 API 保证这些看起来私有的方法不会在 TensorFlow 的未来版本中改变它们的行为?这些方法的签名是什么?apply_gradients除了_apply_resource_[dense|sparse]方法之外,什么时候会覆盖?编辑。在 GitHub 上打开的问题:#36449
更新:TF2.2 迫使我清理所有实现 - 所以现在它们可以用作 TF 最佳实践的参考。还在下面添加了一个关于_get_hypervs.的部分_set_hyper。
我已经在所有主要的 TF 和 Keras 版本中实现了Keras AdamW - 我邀请您检查optimizers_v2.py。几点:
OptimizerV2,这实际上是您链接的内容;它是tf.keras优化器的最新和当前基类apply_gradients(或任何其他方法)仅在默认值无法完成给定优化器所需的内容时才被覆盖;在您链接的示例中,它只是原始文件的单行插件_create_slots如果该子类没有覆盖,则似乎必须在优化器子类中定义一个方法apply_gradients” - 两者无关;这是巧合。_resource_apply_dense和 和有_resource_apply_sparse什么区别?后者处理稀疏层 - 例如Embedding- 前者处理其他一切;例子。
_create_slots()?在定义可训练的 tf.Variables 时;例如:权重的一阶和二阶矩(例如 Adam)。它使用add_slot().
_get_hyper对比_set_hyper:它们使设置和获取Python的文字(int,str,等),可调用和张量。它们的存在主要是为了方便:_set_hyper可以通过 检索通过设置的任何内容_get_hyper,避免重复样板代码。我在这里专门针对它进行了问答。
def _create_slots(self, var_list):
"""Create all slots needed by the variables.
Args:
var_list: A list of `Variable` objects.
"""
# No slots needed by default
pass
def _resource_apply_dense(self, grad, handle):
"""Add ops to apply dense gradients to the variable `handle`.
Args:
grad: a `Tensor` representing the gradient.
handle: a `Tensor` of dtype `resource` which points to the variable
to be updated.
Returns:
An `Operation` which updates the value of the variable.
"""
raise NotImplementedError()
def _resource_apply_sparse(self, grad, handle, indices):
"""Add ops to apply sparse gradients to the variable `handle`.
Similar to `_apply_sparse`, the `indices` argument to this method has been
de-duplicated. Optimizers which deal correctly with non-unique indices may
instead override `_resource_apply_sparse_duplicate_indices` to avoid this
overhead.
Args:
grad: a `Tensor` representing the gradient for the affected indices.
handle: a `Tensor` of dtype `resource` which points to the variable
to be updated.
indices: a `Tensor` of integral type representing the indices for
which the gradient is nonzero. Indices are unique.
Returns:
An `Operation` which updates the value of the variable.
"""
raise NotImplementedError()
Run Code Online (Sandbox Code Playgroud)
apply_dense。一方面,如果您确实覆盖它,代码会提到每个副本的 DistributionStrategy 可能是“危险的” # TODO(isaprykin): When using a DistributionStrategy, and when an
# optimizer is created in each replica, it might be dangerous to
# rely on some Optimizer methods. When such methods are called on a
# per-replica optimizer, an exception needs to be thrown. We do
# allow creation per-replica optimizers however, because the
# compute_gradients()->apply_gradients() sequence is safe.
Run Code Online (Sandbox Code Playgroud)
| 归档时间: |
|
| 查看次数: |
3844 次 |
| 最近记录: |