Jak*_*old 5 python autodiff tensorflow eager-execution
考虑以下函数
def foo(x):
with tf.GradientTape() as tape:
tape.watch(x)
y = x**2 + x + 4
return tape.gradient(y, x)
Run Code Online (Sandbox Code Playgroud)
tape.watch(x)如果函数被称为 as foo(tf.constant(3.14)),则调用是必要的,但当它直接传入变量时则不需要,例如foo(tf.Variable(3.14))。
现在我的问题是,tape.watch(x)即使tf.Variable在直接传入的情况下也调用安全吗?还是会因为变量已经被自动监视然后再次手动监视而发生一些奇怪的事情?编写这样可以同时接受tf.Tensor和的通用函数的正确方法是什么tf.Variable?
它应该是安全的。一方面,文档tf.GradientTape.watch说:
确保
tensor此磁带正在跟踪。
“确保”似乎暗示它将确保它被跟踪,以防万一。事实上,文档没有给出任何迹象表明在同一个对象上使用它两次应该是一个问题(尽管如果他们明确表示不会有什么坏处)。
但无论如何,我们可以深入源代码进行检查。最后,调用watch一个变量(如果它不是一个变量,但路径略有不同,则答案最终相同)归结为C++WatchVariable中的GradientTape类的方法:
void WatchVariable(PyObject* v) {
tensorflow::Safe_PyObjectPtr handle(PyObject_GetAttrString(v, "handle"));
if (handle == nullptr) {
return;
}
tensorflow::int64 id = FastTensorId(handle.get());
if (!PyErr_Occurred()) {
this->Watch(id);
}
tensorflow::mutex_lock l(watched_variables_mu_);
auto insert_result = watched_variables_.emplace(id, v);
if (insert_result.second) {
// Only increment the reference count if we aren't already watching this
// variable.
Py_INCREF(v);
}
}
Run Code Online (Sandbox Code Playgroud)
该方法的后半部分显示了被监视的变量被添加到watched_variables_,即 a std::set,因此再次添加一些东西不会做任何事情。这实际上是在稍后检查以确保 Python 引用计数是正确的。前半部分基本上称为Watch:
template <typename Gradient, typename BackwardFunction, typename TapeTensor>
void GradientTape<Gradient, BackwardFunction, TapeTensor>::Watch(
int64 tensor_id) {
tensor_tape_.emplace(tensor_id, -1);
}
Run Code Online (Sandbox Code Playgroud)
tensor_tape_是一个地图(特别是 a tensorflow::gtl:FlatMap,几乎与标准 C++ 地图相同),所以如果tensor_id已经存在,这将不起作用。
因此,即使没有明确说明,一切都表明它应该没有问题。
| 归档时间: |
|
| 查看次数: |
2380 次 |
| 最近记录: |