JAX 的 vmap 中的 in_axes 关键字

Sam*_*uel 6 auto-vectorization jax

我试图使用vmap基于 JAX 文档的最小工作示例来了解 JAX 的自动矢量化功能。

我不明白如何in_axes正确使用。在下面的示例中,我可以设置in_axes=(None, 0)in_axes=(None, 1)导致相同的结果。为什么会这样?

为什么我必须使用in_axes=(None, 0)而不是类似的东西in_axes=(0, )

import jax.numpy as jnp
from jax import vmap


def predict(params, input_vec):
    assert input_vec.ndim == 1
    activations = input_vec
    for W, b in params:
        outputs = jnp.dot(W, activations) + b
        activations = jnp.tanh(outputs)
    return outputs


if __name__ == "__main__":

    # Parameters
    dims = [2, 3, 5]
    input_dims = dims[0]
    batch_size = 2

    # Weights
    params = list()
    for dims_in, dims_out in zip(dims, dims[1:]):
        params.append((jnp.ones((dims_out, dims_in)), jnp.ones((dims_out,))))

    # Input data
    input_batch = jnp.ones((batch_size, input_dims))

    # With vmap
    predictions = vmap(predict, in_axes=(None, 0))(params, input_batch)
    print(predictions)
Run Code Online (Sandbox Code Playgroud)

jak*_*vdp 4

in_axes=(None, 0) means that the first argument (here params) will not be mapped, while the second argument (here input_vec) will be mapped along axis 0.

In the example below I can set in_axes=(None, 0) or in_axes=(None, 1) leading to the same results. Why is that the case?

这是因为input_vec是一个由 1 组成的 2x2 矩阵,因此无论您沿轴 0 还是轴 1 映射,输入向量都是由 1 组成的长度为 2 的向量。在更一般的情况下,这两个规范并不等效,您可以通过 (1) 使 与batch_size不同input_dims[0],或 (2) 用非常量值填充数组来看到这一点。

为什么我必须使用in_axes=(None, 0)而不是类似的东西in_axes=(0, )

如果您in_axes=(0, )为具有两个参数的函数进行设置,则会收到错误,因为in_axes元组的长度必须与传递给函数的参数数量相匹配。也就是说,可以传递一个标量in_axes=0作为 的简写in_axes=(0, 0),但是对于您的函数,这会导致形状错误,因为 中 的数组的前导维度与params的前导维度不匹配input_vec