
当对带有 `custom_vjp` 的函数调用 `vmap` 后再使用 `vjp`,若直接覆写原函数名会导致前向传播中递归调用错误的 vmapped 版本,从而引发 cotangent 形状不匹配的错误;正确做法是保留原始函数不变,仅对新变量赋值 vmapped 版本。
在 JAX 中,custom_vjp 允许用户自定义前向和反向传播逻辑,而 vmap 则用于向量化操作。二者结合使用时需格外注意函数绑定与作用域问题——最典型的错误是将 vmap 结果直接赋值给原函数名(如 test_func = vmap(test_func, ...)),这会破坏 custom_vjp 前向函数(test_func_fwd)内部对原始未向量化函数的预期调用。
回顾问题代码:在 test_func_fwd 中,primal_out = test_func(f, primal) 这一行本意是调用原始标量版 test_func,但由于 test_func 已被重新绑定为 vmap 版本,实际执行的是 vmap(test_func)(f, primal)。该调用将输入 primal(形状 (10, 3))沿 batch 轴展开,导致前向输出变为 (10,),但 custom_vjp 的反向逻辑仍按标量语义构造残差(residual = 2. * primal * primal_out),其中 primal 是 (10, 3) 而 primal_out 是 (10,),广播后 residual 变为 (10, 3)。最终 vjp 拉回(pullback)函数接收到的 cotangent 是 (10,)(对应输出形状),却试图与 (10, 3) 的梯度做运算,JAX 在校验阶段即抛出误导性错误:“cotangent shape (10,) must match primal input shape (10, 3)”。
✅ 正确解法是避免污染原始函数名,显式命名 vmapped 版本:
# ✅ 保留 test_func 不变,创建新变量 test_func_mapped = vmap(test_func, in_axes=(None, 0)) # 使用 test_func_mapped 进行 vjp primal, f_vjp = vjp(partial(test_func_mapped, f), jnp.ones((10, 3))) cotangent = jnp.ones(10) cotangent_out = f_vjp(cotangent) # 输出形状为 (10, 3),符合预期
⚠️ 注意事项:
- custom_vjp 的前向函数(fwd)必须严格调用原始未修饰的函数,不可依赖全局变量动态变化;
- 若需多层嵌套(如 vmap + jit + custom_vjp),建议始终采用“函数工厂”模式:先定义基础函数,再按需封装,避免就地覆写;
- 可通过 jax.make_jaxpr 或 jax.eval_shape 验证前向输出形状是否符合 custom_vjp 设计假设。
总结:JAX 的函数式特性要求开发者对绑定关系保持显式控制。vmap 不应覆盖原始 custom_vjp 函数,而应作为独立转换结果参与后续计算——这是保障梯度逻辑正确性的关键约定。










