
liblaf.jarp.tree.attrs

attrs helpers for classes that should behave like JAX PyTrees.

The decorators and field specifiers in liblaf.jarp.tree.attrs wrap attrs while recording which fields should flatten as dynamic data, remain static metadata, or be decided from the runtime value.

Classes:

  • FieldType

    Describe how a field participates in PyTree flattening.

  • PyTreeType

    Choose how a class should participate in JAX PyTree flattening.

Functions:

  • array

    Create a data field whose default is normalized to a JAX array.

  • auto

    Create a field whose PyTree role is chosen from the runtime value.

  • define

    Define an attrs class and optionally register it as a PyTree.

  • field

    Create an attrs field using jarp's static metadata convention.

  • frozen

    Define a frozen attrs class and register it as a data PyTree.

  • frozen_static

    Define a frozen attrs class and register it as a static PyTree.

  • register_fieldz

    Register an attrs class with JAX using field metadata.

  • static

    Create a field that is always treated as static metadata.

FieldType

Bases: StrEnum

Describe how a field participates in PyTree flattening.

Methods:

  • __bool__

    Report whether the field type marks a field as static metadata.

Attributes:

AUTO class-attribute instance-attribute

AUTO = auto()

DATA class-attribute instance-attribute

DATA = auto()

META class-attribute instance-attribute

META = auto()

__bool__

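__bool__() -> bool

Return True only for META. AUTO and DATA are falsy, so a FieldType stored as the static metadata value is truthy exactly when the field is always static metadata. This mirrors the static=True convention of jax.tree_util.register_dataclass.
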
Source code in src/liblaf/jarp/tree/attrs/_field_specifiers.py
def __bool__(self) -> bool:
    match self:
        case FieldType.META:
            return True
        case FieldType.AUTO | FieldType.DATA:
            # for consistency with `jax.tree_util.register_dataclass`
            return False
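
The truthiness rule matters when code reads the static metadata entry the way jax.tree_util.register_dataclass reads a static=True flag. The sketch below is a simplified, stdlib-only stand-in rather than jarp's implementation: it redefines FieldType with string values assumed to match what enum.auto() produces on a StrEnum (the lower-cased member name), and split_by_truthiness is a hypothetical helper.

```python
import enum


class FieldType(str, enum.Enum):
    # Stand-in for liblaf.jarp.tree.attrs.FieldType; the string values are an
    # assumption matching StrEnum's auto() (lower-cased member names).
    AUTO = "auto"
    DATA = "data"
    META = "meta"

    def __bool__(self) -> bool:
        # Only META is truthy, so it behaves like `static=True`.
        return self is FieldType.META


def split_by_truthiness(fields: dict[str, dict]) -> tuple[list[str], list[str]]:
    """Hypothetical reader: a field is metadata iff its "static" entry is truthy."""
    data: list[str] = []
    meta: list[str] = []
    for name, metadata in fields.items():
        (meta if metadata.get("static", False) else data).append(name)
    return data, meta


fields = {
    "x": {},                          # no marker: data
    "n": {"static": FieldType.META},  # truthy: metadata
    "w": {"static": FieldType.AUTO},  # falsy: data to a bool-only reader
    "flag": {"static": True},         # plain bool, as written by `static`
}
print(split_by_truthiness(fields))  # (['x', 'w'], ['n', 'flag'])
```

A reader that only checks truthiness therefore treats AUTO fields as data; deciding them per value is left to register_fieldz.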

PyTreeType

Bases: StrEnum

Choose how a class should participate in JAX PyTree flattening.

Attributes:

DATA class-attribute instance-attribute

DATA = auto()

NONE class-attribute instance-attribute

NONE = auto()

STATIC class-attribute instance-attribute

STATIC = auto()
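
define converts its pytree option with PyTreeType(...) and then dispatches on the result. The stdlib-only sketch below models that step; the enum is a stand-in whose values are assumed to be the lower-cased member names, and registration_for is a hypothetical helper that returns a description instead of registering anything.

```python
import enum


class PyTreeType(str, enum.Enum):
    # Stand-in for liblaf.jarp.tree.attrs.PyTreeType; values assumed to be the
    # lower-cased member names that StrEnum's auto() generates.
    DATA = "data"
    NONE = "none"
    STATIC = "static"


def registration_for(pytree: "str | PyTreeType") -> str:
    """Hypothetical helper mirroring the dispatch at the end of `define`."""
    kind = PyTreeType(pytree)  # accepts either the string or the member
    if kind is PyTreeType.DATA:
        return "register_fieldz(cls)"
    if kind is PyTreeType.STATIC:
        return "jax.tree_util.register_static(cls)"
    return "left unregistered"


print(registration_for("data"))             # register_fieldz(cls)
print(registration_for(PyTreeType.STATIC))  # jax.tree_util.register_static(cls)
print(registration_for("none"))             # left unregistered
```

Because the conversion accepts both forms, pytree="static" and pytree=PyTreeType.STATIC are interchangeable; an unknown string raises ValueError.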

array

array(
    *,
    default: T = ...,
    validator: _ValidatorArgType[T] | None = ...,
    repr: _ReprArgType = ...,
    hash: bool | None = ...,
    init: bool = ...,
    metadata: Mapping[Any, Any] | None = ...,
    converter: _ConverterType
    | list[_ConverterType]
    | tuple[_ConverterType, ...]
    | None = ...,
    factory: Callable[[], T] | None = ...,
    kw_only: bool | None = ...,
    eq: _EqOrderType | None = ...,
    order: _EqOrderType | None = ...,
    on_setattr: _OnSetAttrArgType | None = ...,
    alias: str | None = ...,
    type: type | None = ...,
    static: FieldType | bool | None = ...,
) -> Array

Create a data field whose default is normalized to a JAX array.

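When default is a concrete array-like value (anything other than None or an attrs.Factory) and no factory is given, array converts it once with jnp.asarray and replaces it with a factory that returns the converted array. JAX arrays are immutable, so instances can safely share that default.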

Parameters (forwarded to field, which stores static in the field metadata and passes the rest to attrs.field):

  • default (T, default: ... ) –
  • validator (_ValidatorArgType[T] | None, default: ... ) –
  • repr (_ReprArgType, default: ... ) –
  • hash (bool | None, default: ... ) –
  • init (bool, default: ... ) –
  • metadata (Mapping[Any, Any] | None, default: ... ) –
  • converter (_ConverterType | list[_ConverterType] | tuple[_ConverterType, ...] | None, default: ... ) –
  • factory (Callable[[], T] | None, default: ... ) –
  • kw_only (bool | None, default: ... ) –
  • eq (_EqOrderType | None, default: ... ) –
  • order (_EqOrderType | None, default: ... ) –
  • on_setattr (_OnSetAttrArgType | None, default: ... ) –
  • alias (str | None, default: ... ) –
  • type (type | None, default: ... ) –
  • static (FieldType | bool | None, default: ... ) –
Source code in src/liblaf/jarp/tree/attrs/_field_specifiers.py
def array(**kwargs: Unpack[FieldOptions[Any]]) -> Array:
    """Create a data field whose default is normalized to a JAX array.

    When `default` is a concrete array-like value, `array` converts it once
    and rewrites it into a factory that returns the converted array.
    """
    if "default" in kwargs and "factory" not in kwargs:
        default: Any = kwargs["default"]
        if not (default is None or isinstance(default, attrs.Factory)):  # ty:ignore[invalid-argument-type]
            default: Array = jnp.asarray(default)
            kwargs.pop("default")
            kwargs["factory"] = lambda: default
    return field(**kwargs)  # ty:ignore[no-matching-overload]
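
The rewrite rule is easiest to see in isolation. This stdlib-only sketch copies the branch logic from array but substitutes a hypothetical Factory class for attrs.Factory and tuple for jnp.asarray, so it runs without attrs or JAX.

```python
from typing import Any, Callable


class Factory:
    # Hypothetical stand-in for attrs.Factory.
    def __init__(self, factory: Callable[[], Any]) -> None:
        self.factory = factory


def rewrite_array_default(
    kwargs: dict[str, Any], asarray: Callable[[Any], Any]
) -> dict[str, Any]:
    """Mirror the rewrite in `array`: a concrete default becomes a factory."""
    kwargs = dict(kwargs)  # don't mutate the caller's dict
    if "default" in kwargs and "factory" not in kwargs:
        default = kwargs["default"]
        if not (default is None or isinstance(default, Factory)):
            converted = asarray(default)  # converted once, at definition time
            del kwargs["default"]
            kwargs["factory"] = lambda: converted
    return kwargs


out = rewrite_array_default({"default": [1.0, 2.0]}, asarray=tuple)
print(sorted(out))       # ['factory']
print(out["factory"]())  # (1.0, 2.0)
# The factory hands back the same converted object on every call.
print(out["factory"]() is out["factory"]())  # True
print(rewrite_array_default({"default": None}, asarray=tuple))  # {'default': None}
```

None, an explicit Factory, or an explicit factory argument all pass through unchanged.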

auto

auto(**kwargs) -> Any

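Create a field whose PyTree role is chosen from the runtime value. It is field with static defaulting to FieldType.AUTO; when the class is registered with register_fieldz, each value is routed to dynamic data or static metadata by the filter_spec predicate (is_data by default).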

Source code in src/liblaf/jarp/tree/attrs/_field_specifiers.py
@_wraps(attrs.field)
def auto(**kwargs) -> Any:
    """Create a field whose PyTree role is chosen from the runtime value."""
    kwargs.setdefault("static", FieldType.AUTO)
    return field(**kwargs)
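
The same declared auto field can land on either side of the split, depending on what each instance holds. The sketch below models only that routing; split_auto and is_number are hypothetical helpers, and is_number is not jarp's default is_data, whose definition is not shown here.

```python
from typing import Any, Callable


def split_auto(
    values: dict[str, Any], auto_fields: list[str], filter_spec: Callable[[Any], bool]
) -> tuple[dict[str, Any], dict[str, Any]]:
    """Hypothetical model: route each auto field by testing its current value."""
    data = {k: values[k] for k in auto_fields if filter_spec(values[k])}
    meta = {k: values[k] for k in auto_fields if not filter_spec(values[k])}
    return data, meta


def is_number(x: Any) -> bool:
    # Hypothetical predicate standing in for a real data test.
    return isinstance(x, (int, float)) and not isinstance(x, bool)


print(split_auto({"scale": 2.0}, ["scale"], is_number))     # ({'scale': 2.0}, {})
print(split_auto({"scale": "unit"}, ["scale"], is_number))  # ({}, {'scale': 'unit'})
```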

define

define[T: type](
    cls: T, /, **kwargs: Unpack[DefineOptions]
) -> T
define[T: type](
    cls: None = None, **kwargs: Unpack[DefineOptions]
) -> Callable[[T], T]

Define an attrs class and optionally register it as a PyTree.

Parameters:

  • maybe_cls (T | None, default: None ) –

    Class being decorated. When omitted, return a configured decorator.

  • **kwargs (Any, default: {} ) –

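    Options forwarded to attrs.define, plus pytree to control JAX registration. pytree="data" registers the class with register_fieldz, so each field flattens according to its static metadata; "static" registers the whole instance as a static value with jax.tree_util.register_static; "none" leaves the class unregistered. Combining "static" with a class that is not frozen emits a warning.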

Returns:

  • Any

    The decorated class or a class decorator.

Source code in src/liblaf/jarp/tree/attrs/_define.py
def define[T: type](maybe_cls: T | None = None, **kwargs: Any) -> Any:
    """Define an `attrs` class and optionally register it as a PyTree.

    Args:
        maybe_cls: Class being decorated. When omitted, return a configured
            decorator.
        **kwargs: Options forwarded to [`attrs.define`][attrs.define], plus
            `pytree` to control JAX registration. `pytree="data"`
            registers fields with `fieldz` semantics, `"static"` registers
            the whole instance as a static value, and `"none"` leaves the
            class unregistered.

    Returns:
        The decorated class or a class decorator.
    """
    if maybe_cls is None:
        return functools.partial(define, **kwargs)
    pytree: PyTreeType = PyTreeType(kwargs.pop("pytree", None))
    frozen: bool = kwargs.get("frozen", False)
    if pytree is PyTreeType.STATIC and not frozen:
        warnings.warn(
            "Defining a static class that is not frozen may lead to unexpected behavior.",
            stacklevel=2,
        )
    cls: T = attrs.define(maybe_cls, **kwargs)  # ty:ignore[invalid-assignment]
    match pytree:
        case PyTreeType.DATA:
            register_fieldz(cls)
        case PyTreeType.STATIC:
            jtu.register_static(cls)
    return cls
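
define supports both @define and @define(...) through the maybe_cls is None branch, which binds the options with functools.partial. The stdlib-only sketch below keeps that control flow and the non-frozen warning but replaces attrs and JAX registration with a hypothetical REGISTERED list. Treating an omitted pytree as "data" is an assumption here, based on frozen documenting a data PyTree without passing pytree.

```python
import functools
import warnings
from typing import Any

REGISTERED: list[tuple[str, str]] = []  # hypothetical registry for the demo


def define(maybe_cls: Any = None, **kwargs: Any) -> Any:
    """Sketch of define's control flow, without attrs or JAX."""
    if maybe_cls is None:
        # Called as @define(...): return a decorator with the options bound.
        return functools.partial(define, **kwargs)
    pytree = kwargs.pop("pytree", "data")  # assumption: omitted means "data"
    if pytree == "static" and not kwargs.get("frozen", False):
        warnings.warn(
            "Defining a static class that is not frozen may lead to unexpected behavior.",
            stacklevel=2,
        )
    if pytree != "none":
        REGISTERED.append((maybe_cls.__name__, pytree))
    return maybe_cls


@define
class A: ...


@define(pytree="none")
class B: ...


with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")

    @define(pytree="static")
    class C: ...

print(REGISTERED)   # [('A', 'data'), ('C', 'static')]
print(len(caught))  # 1
```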

field

field(**kwargs) -> Any

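Create an attrs field using jarp's static metadata convention. A static keyword argument is moved into the field's metadata under the "static" key, merged with any metadata you pass (an existing "static" entry in that mapping takes precedence); every other argument goes to attrs.field unchanged.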

Source code in src/liblaf/jarp/tree/attrs/_field_specifiers.py
@_wraps(attrs.field)
def field(**kwargs) -> Any:
    """Create an `attrs` field using jarp's `static` metadata convention."""
    if "static" in kwargs:
        kwargs["metadata"] = {
            "static": kwargs.pop("static"),
            **(kwargs.get("metadata") or {}),
        }
    return attrs.field(**kwargs)
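
The merge order in field decides what happens when both static and a metadata mapping are given. This sketch copies the merge verbatim but returns the rewritten keyword arguments instead of calling attrs.field; merge_static is a hypothetical name.

```python
from typing import Any


def merge_static(kwargs: dict[str, Any]) -> dict[str, Any]:
    """Copy of the metadata merge in `field`, returning kwargs instead of calling attrs."""
    kwargs = dict(kwargs)
    if "static" in kwargs:
        kwargs["metadata"] = {
            "static": kwargs.pop("static"),
            # Caller metadata is unpacked last, so its keys win on conflict.
            **(kwargs.get("metadata") or {}),
        }
    return kwargs


print(merge_static({"static": True, "metadata": {"unit": "m"}}))
# {'metadata': {'static': True, 'unit': 'm'}}
print(merge_static({"default": 0}))  # {'default': 0}
```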

frozen

frozen[T: type](
    cls: T, /, **kwargs: Unpack[DefineOptions]
) -> T
frozen[T: type](
    cls: None = None, /, **kwargs: Unpack[DefineOptions]
) -> Callable[[T], T]

Define a frozen attrs class and register it as a data PyTree.

This is the common choice for immutable structures whose array fields should participate in JAX transformations.

Source code in src/liblaf/jarp/tree/attrs/_define.py
def frozen[T: type](maybe_cls: T | None = None, **kwargs: Any) -> Any:
    """Define a frozen `attrs` class and register it as a data PyTree.

    This is the common choice for immutable structures whose array fields
    should participate in JAX transformations.
    """
    _warnings_hide = True
    if maybe_cls is None:
        return functools.partial(frozen, **kwargs)
    kwargs.setdefault("frozen", True)
    return define(maybe_cls, **kwargs)

frozen_static

frozen_static[T: type](
    cls: T, /, **kwargs: Unpack[DefineOptions]
) -> T
frozen_static[T: type](
    cls: None = None, /, **kwargs: Unpack[DefineOptions]
) -> Callable[[T], T]

Define a frozen attrs class and register it as a static PyTree.

Use this for immutable helper objects that should be treated as static metadata instead of flattening into JAX leaves.

Source code in src/liblaf/jarp/tree/attrs/_define.py
def frozen_static[T: type](maybe_cls: T | None = None, **kwargs: Any) -> Any:
    """Define a frozen `attrs` class and register it as a static PyTree.

    Use this for immutable helper objects that should be treated as static
    metadata instead of flattening into JAX leaves.
    """
    _warnings_hide = True
    if maybe_cls is None:
        return functools.partial(frozen_static, **kwargs)
    kwargs.setdefault("frozen", True)
    kwargs.setdefault("pytree", PyTreeType.STATIC)
    return define(maybe_cls, **kwargs)
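
frozen and frozen_static only fill in defaults with dict.setdefault before delegating to define, so explicit keyword arguments always win. The sketch below isolates that layering; frozen_options and frozen_static_options are hypothetical helpers that return the options instead of defining a class.

```python
from typing import Any


def frozen_options(**kwargs: Any) -> dict[str, Any]:
    # Same setdefault layering as `frozen` before it delegates to `define`.
    kwargs.setdefault("frozen", True)
    return kwargs


def frozen_static_options(**kwargs: Any) -> dict[str, Any]:
    # `frozen_static` adds one more default on top.
    kwargs.setdefault("frozen", True)
    kwargs.setdefault("pytree", "static")
    return kwargs


print(frozen_options())                      # {'frozen': True}
print(frozen_static_options())               # {'frozen': True, 'pytree': 'static'}
print(frozen_static_options(pytree="none"))  # {'pytree': 'none', 'frozen': True}
```

One consequence: frozen_static(frozen=False) still passes pytree="static", so define would emit its non-frozen warning.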

register_fieldz

register_fieldz[T: type](
    cls: T,
    data_fields: Sequence[str] | None = None,
    meta_fields: Sequence[str] | None = None,
    auto_fields: Sequence[str] | None = None,
    *,
    filter_spec: Callable[[Any], bool] = is_data,
    bypass_setattr: bool | None = None,
) -> T

Register an attrs class with JAX using field metadata.

Field groups default to the metadata written by array, auto, and static. Pass explicit field lists when you need to register a class that was not declared with liblaf.jarp field helpers.

Parameters:

  • cls (T) –

    Class to register.

  • data_fields (Sequence[str] | None, default: None ) –

    Field names that are always treated as dynamic children.

  • meta_fields (Sequence[str] | None, default: None ) –

    Field names that are always treated as static metadata.

  • auto_fields (Sequence[str] | None, default: None ) –

    Field names filtered at runtime with filter_spec.

  • filter_spec (Callable[[Any], bool], default: is_data ) –

    Predicate used to split auto_fields into dynamic data or metadata.

  • bypass_setattr (bool | None, default: None ) –

    Whether generated unflattening code should use object.__setattr__ instead of normal attribute assignment.

Returns:

  • T

    The same class object, for decorator-style usage.

Source code in src/liblaf/jarp/tree/attrs/_register.py
def register_fieldz[T: type](
    cls: T,
    data_fields: Sequence[str] | None = None,
    meta_fields: Sequence[str] | None = None,
    auto_fields: Sequence[str] | None = None,
    *,
    filter_spec: Callable[[Any], bool] = is_data,
    bypass_setattr: bool | None = None,
) -> T:
    """Register an `attrs` class with JAX using field metadata.

    Field groups default to the metadata written by
    [`array`][liblaf.jarp.tree.array], [`auto`][liblaf.jarp.tree.auto], and
    [`static`][liblaf.jarp.tree.static]. Pass explicit field lists when you
    need to register a class that was not declared with `liblaf.jarp` field
    helpers.

    Args:
        cls: Class to register.
        data_fields: Field names that are always treated as dynamic children.
        meta_fields: Field names that are always treated as static metadata.
        auto_fields: Field names filtered at runtime with `filter_spec`.
        filter_spec: Predicate used to split `auto_fields` into dynamic data
            or metadata.
        bypass_setattr: Whether generated unflattening code should use
            [`object.__setattr__`][object.__setattr__] instead of normal
            attribute assignment.

    Returns:
        The same class object, for decorator-style usage.
    """
    if data_fields is None:
        data_fields: list[str] = _filter_field_names(cls, FieldType.DATA)
    if meta_fields is None:
        meta_fields: list[str] = _filter_field_names(cls, FieldType.META)
    if auto_fields is None:
        auto_fields: list[str] = _filter_field_names(cls, FieldType.AUTO)
    register_generic(
        cls,
        data_fields,
        meta_fields,
        auto_fields,
        filter_spec=filter_spec,
        bypass_setattr=bypass_setattr,
    )
    return cls
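
register_fieldz delegates the actual JAX registration to register_generic, which this page does not document. The stdlib-only model below only illustrates the three field groups and how they end up on each side of a flatten. group_fields and flatten are hypothetical; in particular, how the real grouping normalizes a static=True marker (written by static) is an assumption.

```python
from types import SimpleNamespace
from typing import Any, Callable


def group_fields(metadata_by_field: dict[str, dict]) -> dict[str, list[str]]:
    """Hypothetical stand-in for the default field grouping in register_fieldz.

    Assumed normalization: a missing or falsy "static" marker means data,
    True means meta, and the strings "data", "meta", "auto" name their group.
    """
    groups: dict[str, list[str]] = {"data": [], "meta": [], "auto": []}
    for name, metadata in metadata_by_field.items():
        marker = metadata.get("static")
        if marker is True:
            key = "meta"
        elif not marker:
            key = "data"
        else:
            key = str(marker)
        groups[key].append(name)
    return groups


def flatten(obj: Any, groups: dict[str, list[str]], filter_spec: Callable[[Any], bool]):
    """Hypothetical flatten: children hold dynamic values, aux holds static ones."""
    children = {name: getattr(obj, name) for name in groups["data"]}
    aux = {name: getattr(obj, name) for name in groups["meta"]}
    for name in groups["auto"]:
        value = getattr(obj, name)
        # Auto fields are routed per value, like filter_spec in register_fieldz.
        (children if filter_spec(value) else aux)[name] = value
    return children, aux


groups = group_fields({"x": {}, "name": {"static": True}, "scale": {"static": "auto"}})
obj = SimpleNamespace(x=1.0, name="ball", scale=2.0)
children, aux = flatten(obj, groups, filter_spec=lambda v: isinstance(v, float))
print(groups)         # {'data': ['x'], 'meta': ['name'], 'auto': ['scale']}
print(children, aux)  # {'x': 1.0, 'scale': 2.0} {'name': 'ball'}
```

Passing explicit data_fields, meta_fields, or auto_fields corresponds to skipping group_fields and supplying the lists directly.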

static

static(**kwargs) -> Any

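Create a field that is treated as static metadata. It defaults static to True in the field metadata, matching the convention of jax.tree_util.register_dataclass.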

Source code in src/liblaf/jarp/tree/attrs/_field_specifiers.py
@_wraps(attrs.field)
def static(**kwargs) -> Any:
    """Create a field that is always treated as static metadata."""
    # for consistency with `jax.tree_util.register_dataclass`
    kwargs.setdefault("static", True)
    return field(**kwargs)
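
Side by side, the specifiers differ only in the static marker they write. The sketch below uses hypothetical stand-ins that return the resulting metadata instead of an attrs field, with the string "auto" standing in for FieldType.AUTO; since FieldType.META is truthy, True and META both read as static.

```python
from typing import Any


def field(**kwargs: Any) -> dict[str, Any]:
    # Stand-in for jarp's `field`: returns the metadata instead of an attrs field.
    metadata = dict(kwargs.get("metadata") or {})
    if "static" in kwargs:
        metadata = {"static": kwargs["static"], **metadata}
    return metadata


def static(**kwargs: Any) -> dict[str, Any]:
    kwargs.setdefault("static", True)  # same default as jarp's `static`
    return field(**kwargs)


def auto(**kwargs: Any) -> dict[str, Any]:
    kwargs.setdefault("static", "auto")  # stand-in for FieldType.AUTO
    return field(**kwargs)


print(static())                     # {'static': True}
print(auto(metadata={"doc": "k"}))  # {'static': 'auto', 'doc': 'k'}
print(field())                      # {}
```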