
Filtered Call Wrappers

This benchmark compares the steady-state invocation overhead of three call wrappers around a no-op function, measured after warmup. For jax.jit and equinox.filter_jit, the figure is the overhead of calling the compiled function. For liblaf.jarp.filter_jit, it reflects the cost of partitioning mixed inputs and recombining outputs around a callable of the same shape.
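A minimal timing sketch of the warmup-then-measure procedure described above. This is not the benchmark's source: `measure_call_overhead` and `identity` are hypothetical names, the identity function stands in for the no-op, and the real harness may use a different timer, statistic, or repetition count.

```python
import time
from collections.abc import Callable
from typing import Any

import jax
import jax.numpy as jnp


def measure_call_overhead(
    fn: Callable[..., Any], *args: Any, warmup: int = 10, repeats: int = 1_000
) -> float:
    """Mean wall-clock time per call, in microseconds, measured after warmup."""
    # Warmup: the first call traces and compiles; later calls hit the cache.
    for _ in range(warmup):
        jax.block_until_ready(fn(*args))
    start = time.perf_counter()
    for _ in range(repeats):
        # Block on the result so asynchronous dispatch is not mistaken for speed.
        jax.block_until_ready(fn(*args))
    return (time.perf_counter() - start) / repeats * 1e6


def identity(tree):
    # Stand-in for the benchmark's no-op function (an assumption).
    return tree


# A small nested PyTree of arrays, standing in for a "complex PyTree" input.
tree = {"a": jnp.zeros(3), "b": [jnp.ones(2), jnp.ones((2, 2))]}
jitted = jax.jit(identity)
print(f"jax.jit: {measure_call_overhead(jitted, tree):.2f} µs per call")
```

Absolute numbers depend on hardware and backend, so only the relative ordering of wrappers measured on the same machine is meaningful.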


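Per-call time after warmup (lower is better). "No PyTree" calls the function with no inputs; "Complex PyTree" passes a nested PyTree.

Method               No PyTree   Complex PyTree
jax.jit                9.11 µs        266.30 µs
jarp.filter_jit       11.64 µs        321.06 µs
equinox.filter_jit   319.96 µs        326.68 µs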

JAX

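jax.jit provides the compiled-call baseline. It also imposes the strictest input requirements: every argument leaf must be a valid JAX type, such as an array or a numeric scalar, unless it is declared static with static_argnums or static_argnames; otherwise JAX raises a TypeError when the function is called.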
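A small sketch of that strictness: jax.jit rejects a plain string argument, and the manual workaround is to mark the argument static. The function `scale` and its `label` argument are illustrative only.

```python
import jax
import jax.numpy as jnp


def scale(x, label):
    # `label` drives Python control flow only; it is never used as an array.
    return x * 2.0 if label == "double" else x


# A plain string is not a valid JAX type, so jax.jit rejects it.
rejected = False
try:
    jax.jit(scale)(jnp.ones(3), "double")
except TypeError:
    rejected = True
print("string argument rejected:", rejected)

# The manual workaround: declare the non-array argument static.
scale_jit = jax.jit(scale, static_argnames="label")
print(scale_jit(jnp.ones(3), "double"))
```

Each distinct static value compiles its own version of the function, which is the bookkeeping a filtering wrapper automates.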

JARP

jarp.filter_jit adds a lightweight filtering step that partitions each call's arguments into:

  • Dynamic leaves: JAX arrays and None placeholders.
  • Static leaves: Other values, which are stored as metadata and stitched back into the original call shape.
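The split can be sketched with jax.tree_util. This is a hypothetical illustration of the idea, not jarp's implementation: `partition` and `combine` are made-up names, and only jax.Array leaves count as dynamic here (a NumPy array, for example, would be treated as static).

```python
import jax
import jax.numpy as jnp


def partition(tree):
    """Split a PyTree into dynamic leaves, static leaves, and its structure.

    Hypothetical sketch: only jax.Array leaves are dynamic, every other leaf
    is static. Each side holds None in the slots the other side owns, so both
    line up with the flattened leaf order.
    """
    leaves, treedef = jax.tree_util.tree_flatten(tree)
    dynamic = [leaf if isinstance(leaf, jax.Array) else None for leaf in leaves]
    static = tuple(None if isinstance(leaf, jax.Array) else leaf for leaf in leaves)
    return dynamic, static, treedef


def combine(dynamic, static, treedef):
    """Stitch dynamic and static leaves back into the original structure."""
    leaves = [s if d is None else d for d, s in zip(dynamic, static)]
    return jax.tree_util.tree_unflatten(treedef, leaves)


params = {"w": jnp.ones(2), "name": "layer0", "scale": 3}
dynamic, static, treedef = partition(params)
restored = combine(dynamic, static, treedef)
print(static)  # ('layer0', 3, None): dict keys are flattened in sorted order
```

Keeping the static side as a tuple makes it hashable, which matters once it has to be passed to jax.jit as a static argument.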

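That lets users pass mixed PyTrees through a single jit boundary without partitioning them by hand. In this benchmark the convenience costs about 2.5 µs per call over jax.jit with no inputs (11.64 µs versus 9.11 µs) and about 55 µs, roughly 20%, with the complex PyTree (321.06 µs versus 266.30 µs).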
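A hypothetical `filter_jit_sketch` decorator shows how such a split lets a mixed PyTree cross one jit boundary: array leaves are traced, while the remaining leaves and the tree structure are passed to jax.jit as static arguments. It is a simplified illustration, not jarp's or Equinox's implementation. Static leaves must be hashable, only inputs are filtered (outputs must already be JAX types), and every new combination of static values triggers a recompile.

```python
import functools

import jax
import jax.numpy as jnp


def filter_jit_sketch(fn):
    """Hypothetical filter_jit-style wrapper built on jax.jit."""

    @functools.partial(jax.jit, static_argnums=(1, 2))
    def compiled(dynamic, static, treedef):
        # Rebuild the original (args, kwargs) inside the traced function.
        leaves = [s if d is None else d for d, s in zip(dynamic, static)]
        args, kwargs = jax.tree_util.tree_unflatten(treedef, leaves)
        return fn(*args, **kwargs)

    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        # Flatten all inputs once, then route each leaf to one side.
        leaves, treedef = jax.tree_util.tree_flatten((args, kwargs))
        is_array = [isinstance(leaf, jax.Array) for leaf in leaves]
        dynamic = [leaf if a else None for leaf, a in zip(leaves, is_array)]
        static = tuple(None if a else leaf for leaf, a in zip(leaves, is_array))
        return compiled(dynamic, static, treedef)

    return wrapper


@filter_jit_sketch
def apply(x, cfg):
    # `cfg` mixes a string and a Python float; both end up static.
    return x * cfg["factor"] if cfg["mode"] == "mul" else x + cfg["factor"]


cfg = {"mode": "mul", "factor": 2.0}
print(apply(jnp.ones(3), cfg))
```

Repeated calls with an equal `cfg` reuse the cached compilation; changing `cfg["mode"]` or `cfg["factor"]` compiles a new version. The per-call flatten, split, and hashing work in `wrapper` is the kind of overhead the table above measures on top of jax.jit.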

Equinox

equinox.filter_jit is the closest comparison point among mixed-PyTree call wrappers. In this microbenchmark its no-input overhead is much higher (319.96 µs versus 11.64 µs for jarp.filter_jit, roughly 27×), while in the complex-PyTree case the two are close (326.68 µs versus 321.06 µs).

Test Environment

python==3.14.3
jax==0.10.0
liblaf-jarp==0.1.10.dev9+g99e249b88
equinox==0.13.8