liblaf.jarp
liblaf.jarp keeps mixed JAX PyTrees usable across function boundaries,
attrs, and NVIDIA Warp. The package is intentionally small: it focuses on
filtered call wrappers, PyTree-friendly class definitions, round-trippable
flattening, control-flow and ordered-condition helpers, and Warp array,
callable, and struct adapters that fit JAX-first workflows.
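```python
import jax.numpy as jnp

from liblaf import jarp


@jarp.define
class Batch:
    values: object = jarp.array()  # array field: the data JAX operates on
    label: str = jarp.static()  # static field: plain Python metadata


@jarp.filter_jit
def normalize(batch: Batch) -> Batch:
    # Center the array data and carry the static label through unchanged.
    centered = batch.values - jnp.mean(batch.values)
    return Batch(values=centered, label=batch.label)
```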
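For readers coming from plain JAX, the example above roughly corresponds to a dataclass registered with `jax.tree_util.register_dataclass` and compiled with `jax.jit`. The sketch below shows that analogue under one stated assumption: that `jarp.array()` and `jarp.static()` behave like JAX's data fields and metadata fields respectively. It is a comparison aid only, not how `jarp.define` or `jarp.filter_jit` are implemented, and it needs a JAX release recent enough to provide `register_dataclass`.

```python
import dataclasses
import functools

import jax
import jax.numpy as jnp


# Hypothetical plain-JAX analogue of the jarp Batch example.
# Assumption: "values" plays the role of a jarp.array() field (a traced leaf)
# and "label" the role of a jarp.static() field (metadata kept in the treedef).
@functools.partial(
    jax.tree_util.register_dataclass,
    data_fields=["values"],
    meta_fields=["label"],
)
@dataclasses.dataclass
class Batch:
    values: jax.Array
    label: str


@jax.jit
def normalize(batch: Batch) -> Batch:
    # Only batch.values is traced; batch.label stays a concrete Python str.
    centered = batch.values - jnp.mean(batch.values)
    return Batch(values=centered, label=batch.label)


batch = Batch(values=jnp.array([1.0, 2.0, 3.0]), label="demo")
out = normalize(batch)

# The label is not a leaf, so the PyTree exposes a single array leaf.
print(out.label, out.values, len(jax.tree_util.tree_leaves(out)))
```

Because `label` lives in the treedef, `jax.jit` treats it as static: passing a different label produces a different treedef and triggers a retrace, which is the usual trade-off for static fields.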
Read by Workflow
- Getting started for installation, the core field specifiers, and the first filtered call wrapper.
- Call wrappers for `filter_jit`, `fallback_jit`, `LaxWrapper`, and the `jarp.lax` helpers.
- PyTree workflows for `define`, `auto`, `ravel`, `PyTreeProxy`, import-time prelude registrations, and custom registration helpers.
- Warp interop for `to_warp`, `struct`, generic Warp adapters, and dtype helpers.
- API reference for exact signatures and module-level details.
- Benchmarks for the current wrapper-overhead and PyTree registration measurements.
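To make "round-trippable flattening" concrete, plain JAX ships `jax.flatten_util.ravel_pytree`, which packs the leaves of an array-only PyTree into one 1-D vector and returns the inverse function. The sketch below uses that JAX utility, not `jarp.ravel`, whose signature and handling of non-array leaves are not assumed here; the `params` tree is a hypothetical example.

```python
import jax
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree

# Hypothetical nested PyTree of float32 arrays (dicts and tuples are PyTree nodes).
params = {
    "weight": jnp.ones((2, 2)),
    "bias": jnp.zeros(2),
    "extra": (jnp.arange(3.0),),
}

# Flatten every leaf into a single 1-D vector and get the inverse back.
flat, unravel = ravel_pytree(params)
print(flat.shape)  # (9,)

# Round trip: the inverse rebuilds the original structure and values.
restored = unravel(flat)
same_structure = jax.tree_util.tree_structure(restored) == jax.tree_util.tree_structure(params)
same_values = all(
    jnp.array_equal(a, b)
    for a, b in zip(jax.tree_util.tree_leaves(restored), jax.tree_util.tree_leaves(params))
)
print(same_structure, same_values)
```

`ravel_pytree` expects every leaf to be array-like, so a tree holding, say, a `str` label would not flatten this way; handling such mixed trees is the kind of case a PyTree-focused `ravel` may need to address.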