Liu*_*uka 4 python jit deep-learning jax
我对贾克斯有以下疑问。我将使用官方optax 文档中的一个示例来说明它:
def fit(params: optax.Params, optimizer: optax.GradientTransformation) -> optax.Params:
opt_state = optimizer.init(params)
@jax.jit
def step(params, opt_state, batch, labels):
loss_value, grads = jax.value_and_grad(loss)(params, batch, labels)
updates, opt_state = optimizer.update(grads, opt_state, params)
params = optax.apply_updates(params, updates)
return params, opt_state, loss_value
for i, (batch, labels) in enumerate(zip(TRAINING_DATA, LABELS)):
params, opt_state, loss_value = step(params, opt_state, batch, labels)
if i % 100 == 0:
print(f'step {i}, loss: {loss_value}')
return params
# Finally, we can fit our parametrized function using the Adam optimizer
# provided by optax.
optimizer = optax.adam(learning_rate=1e-2)
params = fit(initial_params, optimizer)
Run Code Online (Sandbox Code Playgroud)
在此示例中,函数step使用该变量optimizer,尽管该变量未在函数参数内传递(因为该函数正在被抖动并且optax.GradientTransformation不是受支持的类型)。但是,同一函数使用其他变量作为参数传递(即params, opt_state, batch, labels)。我知道 jax 函数需要是纯函数才能进行 jitted,但是输入(只读)变量又如何呢?如果我通过函数参数传递变量来访问变量,或者因为它位于step函数作用域中而直接访问它,有什么区别吗?如果这个变量不是常量而是在单独的step调用之间修改了怎么办?如果直接访问,它们是否被视为静态参数?或者它们只是被扔掉,因此不会考虑对这些参数的修改?
更具体地说,让我们看下面的例子:
def fit(params: optax.Params, optimizer: optax.GradientTransformation) -> optax.Params:
opt_state = optimizer.init(params)
extra_learning_rate = 0.1
@jax.jit
def step(params, opt_state, batch, labels):
loss_value, grads = jax.value_and_grad(loss)(params, batch, labels)
updates, opt_state = optimizer.update(grads, opt_state, params)
updates *= extra_learning_rate # not really valid code, but you get the idea
params = optax.apply_updates(params, updates)
return params, opt_state, loss_value
for i, (batch, labels) in enumerate(zip(TRAINING_DATA, LABELS)):
extra_learning_rate = 0.1
params, opt_state, loss_value = step(params, opt_state, batch, labels)
extra_learning_rate = 0.01 # does this affect the next `step` call?
params, opt_state, loss_value = step(params, opt_state, batch, labels)
return params
Run Code Online (Sandbox Code Playgroud)
与
def fit(params: optax.Params, optimizer: optax.GradientTransformation) -> optax.Params:
opt_state = optimizer.init(params)
extra_learning_rate = 0.1
@jax.jit
def step(params, opt_state, batch, labels, extra_lr):
loss_value, grads = jax.value_and_grad(loss)(params, batch, labels)
updates, opt_state = optimizer.update(grads, opt_state, params)
updates *= extra_lr # not really valid code, but you get the idea
params = optax.apply_updates(params, updates)
return params, opt_state, loss_value
for i, (batch, labels) in enumerate(zip(TRAINING_DATA, LABELS)):
extra_learning_rate = 0.1
params, opt_state, loss_value = step(params, opt_state, batch, labels, extra_learning_rate)
extra_learning_rate = 0.01 # does this now affect the next `step` call?
params, opt_state, loss_value = step(params, opt_state, batch, labels, extra_learning_rate)
return params
Run Code Online (Sandbox Code Playgroud)
从我有限的实验来看,它们的表现有所不同,因为第二次step调用在全局情况下没有使用新的学习率,也没有发生“重新调整”,但是我想知道是否有我需要的标准实践/规则要知道。我正在编写一个以性能为基础的库,我不想因为我做错了事情而错过一些 jit 优化。
在 JIT 跟踪期间,JAX 将全局值视为正在跟踪的函数的隐式参数。您可以在表示该函数的jaxpr中看到这一点。
以下是两个返回等效结果的简单函数,一个具有隐式参数,一个具有显式参数:
import jax
import jax.numpy as jnp
def f_explicit(a, b):
return a + b
def f_implicit(b):
return a_global + b
a_global = jnp.arange(5.0)
b = jnp.ones(5)
print(jax.make_jaxpr(f_explicit)(a_global, b))
# { lambda ; a:f32[5] b:f32[5]. let c:f32[5] = add a b in (c,) }
print(jax.make_jaxpr(f_implicit)(b))
# { lambda a:f32[5]; b:f32[5]. let c:f32[5] = add a b in (c,) }
Run Code Online (Sandbox Code Playgroud)
请注意,两个 jaxpr 中的唯一区别是,在 中f_implicit,a变量位于分号之前:这是jaxpr表示形式指示参数通过闭包传递而不是通过显式参数传递的方式。但这两个函数生成的计算将是相同的。
也就是说,需要注意的一个区别是,当闭包传递的参数是可散列常量时,它将在跟踪函数内被视为静态static_argnums(与显式参数通过或static_argnames在 内标记为静态时类似jax.jit):
a_global = 1.0
print(jax.make_jaxpr(f_implicit)(b))
# { lambda ; a:f32[5]. let b:f32[5] = add 1.0 a in (b,) }
Run Code Online (Sandbox Code Playgroud)
请注意,在 jaxpr 表示中,常量值直接作为操作的参数插入add。为 JIT 编译函数获得相同结果的显式方法如下所示:
from functools import partial
@partial(jax.jit, static_argnames=['a'])
def f_explicit(a, b):
return a + b
Run Code Online (Sandbox Code Playgroud)
| 归档时间: |
|
| 查看次数: |
1053 次 |
| 最近记录: |