Custom PyTree Nodes
This benchmark measures flatten and unflatten performance for custom PyTree nodes with seven data fields and three static fields.
| Method | Converter | Flatten | Unflatten | Total |
|---|---|---|---|---|
| jarp.define | no | 0.48 µs | 0.30 µs | 0.78 µs |
| jarp.define | yes | 0.49 µs | 0.65 µs | 1.14 µs |
| jtu.register_dataclass | no | 0.42 µs | 0.77 µs | 1.19 µs |
| jtu.register_dataclass | yes | 0.41 µs | 25.23 µs | 25.65 µs |
| equinox.Module | no | 0.89 µs | 0.76 µs | 1.65 µs |
| equinox.Module | yes | 0.90 µs | 0.79 µs | 1.69 µs |
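
Per-call timings like those in the table can be collected with the standard library's `timeit` module. The sketch below is not the harness used for this benchmark. The `time_per_call_us` helper and the dict-based `flatten` / `unflatten` stand-ins are hypothetical, and it assumes that Total is simply Flatten plus Unflatten.

```python
import timeit


def time_per_call_us(fn, number=100_000, repeat=5):
    """Return the best observed time per call of fn(), in microseconds."""
    timer = timeit.Timer(fn)
    best = min(timer.repeat(repeat=repeat, number=number))
    return best / number * 1e6


# Hypothetical stand-ins for a node's flatten and unflatten callbacks.
def flatten(node):
    return tuple(node.values()), tuple(node.keys())


def unflatten(keys, values):
    return dict(zip(keys, values))


# Seven data fields, mirroring the node shape described above.
node = {f"field_{i}": float(i) for i in range(7)}
children, keys = flatten(node)

flatten_us = time_per_call_us(lambda: flatten(node), number=10_000)
unflatten_us = time_per_call_us(lambda: unflatten(keys, children), number=10_000)
total_us = flatten_us + unflatten_us  # assumption: Total = Flatten + Unflatten
print(f"flatten {flatten_us:.2f} µs, unflatten {unflatten_us:.2f} µs, total {total_us:.2f} µs")
```

Taking the minimum over several repeats reduces noise from other processes. The lambda wrappers add a small constant overhead to every measurement, so differences between methods are more meaningful than absolute values.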
jarp.define
jarp gets its unflatten speed from generated Python callbacks for each
registered class.
- Bypassing `__init__`: The generated unflatten function creates a new instance with `object.__new__` and populates fields directly. Converters and validators do not rerun during unflattening.
- Assignment strategy: For simple classes, generated code can assign fields
directly. When attribute setters may interfere, it can use
`object.__setattr__` instead.
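
The two points above can be sketched in plain Python. This is a minimal illustration and not jarp's actual generated code: the `Point` class, its float converter, and the `flatten_point` / `unflatten_point` helpers are hypothetical names, and the sketch assumes a frozen-style class whose `__setattr__` blocks ordinary assignment.

```python
class Point:
    """Hypothetical node: x and y are data fields, name is a static field."""

    init_calls = 0  # counts how often __init__ (and its converter) runs

    def __init__(self, x, y, name):
        type(self).init_calls += 1
        # Converter: coerce inputs to float on construction.
        object.__setattr__(self, "x", float(x))
        object.__setattr__(self, "y", float(y))
        object.__setattr__(self, "name", name)

    def __setattr__(self, key, value):
        # Simulates a frozen class: normal attribute assignment is blocked.
        raise AttributeError("Point is immutable")


def flatten_point(p):
    # Children (data fields) first, then auxiliary (static) data.
    return (p.x, p.y), (p.name,)


def unflatten_point(aux, children):
    # Allocate without calling __init__, so the converter does not rerun.
    obj = object.__new__(Point)
    # object.__setattr__ sidesteps the class's own __setattr__.
    object.__setattr__(obj, "x", children[0])
    object.__setattr__(obj, "y", children[1])
    object.__setattr__(obj, "name", aux[0])
    return obj


p = Point(1, 2, "a")
children, aux = flatten_point(p)
q = unflatten_point(aux, children)
```

After this runs, `Point.init_calls` is still 1, because only the original construction went through `__init__`. A flatten function returning `(children, aux)` and an unflatten function taking `(aux, children)` is the shape that `jax.tree_util.register_pytree_node` expects.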
jax.tree_util.register_dataclass
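When converters are present, `jax.tree_util.register_dataclass` (`jtu` in the
table) reconstructs objects by calling `__init__`, which reruns converters such
as `jnp.asarray` every time a tree node is restored. In this benchmark that
raises its unflatten time from 0.77 µs to 25.23 µs.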
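
The effect of rebuilding through the constructor can be seen in a small sketch that does not depend on JAX. `Params`, `to_floats`, and `unflatten_via_init` are hypothetical names, and the counting converter only stands in for something like `jnp.asarray`.

```python
from dataclasses import dataclass

converter_calls = 0


def to_floats(values):
    # Stand-in for a converter such as jnp.asarray (assumption for illustration).
    global converter_calls
    converter_calls += 1
    return tuple(float(v) for v in values)


@dataclass
class Params:
    weights: tuple
    label: str = "default"

    def __post_init__(self):
        # The converter runs every time __init__ runs.
        self.weights = to_floats(self.weights)


def unflatten_via_init(aux, children):
    # Rebuilding through the constructor reruns the converter on each unflatten.
    return Params(weights=children[0], label=aux[0])


p = Params([1, 2, 3])
for _ in range(5):
    p = unflatten_via_init((p.label,), (p.weights,))
```

Here the converter runs once for the initial construction and once more for each of the five unflatten calls, even though the data is already in its converted form. With a cheap converter the cost is small. With an array conversion it can dominate the unflatten time.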
equinox.Module
Equinox timings are nearly the same with and without converters (1.65 µs vs.
1.69 µs total), but its baseline is higher than that of the generated jarp
callbacks in this benchmark.