Sam*_*uel 6 auto-vectorization jax
我试图使用vmap基于 JAX 文档的最小工作示例来了解 JAX 的自动矢量化功能。
我不明白如何in_axes正确使用。在下面的示例中,我可以设置in_axes=(None, 0)或in_axes=(None, 1)导致相同的结果。为什么会这样?
为什么我必须使用in_axes=(None, 0)而不是类似的东西in_axes=(0, )?
import jax.numpy as jnp
from jax import vmap
def predict(params, input_vec):
assert input_vec.ndim == 1
activations = input_vec
for W, b in params:
outputs = jnp.dot(W, activations) + b
activations = jnp.tanh(outputs)
return outputs
if __name__ == "__main__":
# Parameters
dims = [2, 3, 5]
input_dims = dims[0]
batch_size = 2
# Weights
params = list()
for dims_in, dims_out in zip(dims, dims[1:]):
params.append((jnp.ones((dims_out, dims_in)), jnp.ones((dims_out,))))
# Input data
input_batch = jnp.ones((batch_size, input_dims))
# With vmap
predictions = vmap(predict, in_axes=(None, 0))(params, input_batch)
print(predictions)
Run Code Online (Sandbox Code Playgroud)
in_axes=(None, 0) means that the first argument (here params) will not be mapped, while the second argument (here input_vec) will be mapped along axis 0.
In the example below I can set
in_axes=(None, 0)orin_axes=(None, 1)leading to the same results. Why is that the case?
这是因为input_vec是一个由 1 组成的 2x2 矩阵,因此无论您沿轴 0 还是轴 1 映射,输入向量都是由 1 组成的长度为 2 的向量。在更一般的情况下,这两个规范并不等效,您可以通过 (1) 使 与batch_size不同input_dims[0],或 (2) 用非常量值填充数组来看到这一点。
为什么我必须使用
in_axes=(None, 0)而不是类似的东西in_axes=(0, )?
如果您in_axes=(0, )为具有两个参数的函数进行设置,则会收到错误,因为in_axes元组的长度必须与传递给函数的参数数量相匹配。也就是说,可以传递一个标量in_axes=0作为 的简写in_axes=(0, 0),但是对于您的函数,这会导致形状错误,因为 中 的数组的前导维度与params的前导维度不匹配input_vec。
| 归档时间: |
|
| 查看次数: |
1698 次 |
| 最近记录: |