如何使用 jax vmap 进行嵌套循环?

akk*_*kkh 5 python performance jit vectorization jax

我想使用 vmap 对该代码进行矢量化以提高性能。

def matrix(dataA, dataB):
    return jnp.array([[func(a, b) for b in dataB] for a in dataA])
matrix(data, data)
Run Code Online (Sandbox Code Playgroud)

我试过这个:

def f(x, y):
    return func(x, y)
mapped = jax.vmap(f)
mapped(data, data)
Run Code Online (Sandbox Code Playgroud)

但这只给出了对角线条目。

基本上我有一个向量data = [1,2,3,4,5](示例),我想得到一个矩阵,使得(i, j)矩阵的每个条目都是f(data[i], data[j])。因此,得到的矩阵形状将为(len(data), len(data))

jak*_*vdp 4

jax.vmap一次映射一组轴。如果要映射两组独立的轴,可以通过嵌套两个vmap转换来实现:

mapped = jax.vmap(jax.vmap(f, in_axes=(None, 0)), in_axes=(0, None))
result = mapped(data, data)
Run Code Online (Sandbox Code Playgroud)