如何使用 JAX 打印

Ali*_*_Sh 2 jit jax

我有一个 JAX 布尔数组,想要打印一条与True之和相结合的语句:

import jax
import jax.numpy as jnp
from jax.experimental.host_callback import id_print

@jax.jit
def overlaps_jax():
    mask_cp = jnp.array([True, False, False, True, False, True, False, True, True])
    id_print(jnp.sum(mask_cp))

overlaps_jax()
Run Code Online (Sandbox Code Playgroud)

中有 5 个Truemask_cp;我想打印为:

With jax accelerator
There are 5 true bools
Run Code Online (Sandbox Code Playgroud)

由于这个函数是jitted 的,我尝试使用 来打印它id_print,但我不能。id_print(jnp.sum(mask_cp))将打印5,但我无法将其与字符串一起使用。我已经尝试过以下方法:

id_print(jnp.sum(mask_cp))
# print:
# 5

id_print("\nWith jax accelerator\nThere are " + jnp.sum(mask_cp) + " true bools\n")
# error:
# TypeError: can only concatenate str (not "DynamicJaxprTracer") to str

print("\nWith jax accelerator\nThere are {} true bools\n".format(jnp.sum(mask_cp)))
# print:
# With jax accelerator
# There are Traced<ShapedArray(int64[])>with<DynamicJaxprTrace(level=0/1)> true bools
Run Code Online (Sandbox Code Playgroud)

我怎样才能在这段代码中打印这样的语句?

jak*_*vdp 6

请注意,这id_print是实验性的,其 API 和功能可能会发生变化。也就是说,我不相信id_print有能力添加这样的文本,但你可以通过更通用的方式来做到这一点host_callback.call

import jax
import jax.numpy as jnp
from jax.experimental.host_callback import call

@jax.jit
def overlaps_jax():
    mask_cp = jnp.array([True, False, False, True, False, True, False, True, True])
    call(lambda x: print(f"There are {x} true bools"), jnp.sum(mask_cp))

overlaps_jax()
Run Code Online (Sandbox Code Playgroud)

输出是

There are 5 true bools
Run Code Online (Sandbox Code Playgroud)