Skip to content

liblaf.jarp.tree.codegen

Code-generation helpers for high-performance PyTree registrations.

These utilities build specialized flatten and unflatten callbacks for classes whose field layout is known ahead of time.

Classes:

  • PyTreeFunctions

    Container for callbacks passed to register_pytree_node.

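Functions:

  • codegen_flatten

    Build the AST of the generated flatten function.

  • codegen_flatten_with_keys

    Build the AST of the generated flatten_with_keys function.

  • codegen_pytree_functions

    Generate flatten and unflatten callbacks for a class.

  • codegen_unflatten

    Build the AST of the generated unflatten function.

  • register_generic

    Register a class as a PyTree using explicit field groups.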

PyTreeFunctions

Bases: NamedTuple


              flowchart TD
              liblaf.jarp.tree.codegen.PyTreeFunctions[PyTreeFunctions]

              

              click liblaf.jarp.tree.codegen.PyTreeFunctions href "" "liblaf.jarp.tree.codegen.PyTreeFunctions"
            

Container for callbacks passed to register_pytree_node.

Parameters:

Attributes:

flatten instance-attribute

flatten: Callable[[T], tuple[_Children, _AuxData]]

flatten_with_keys instance-attribute

flatten_with_keys: Callable[
    [T], tuple[_ChildrenWithKeys, _AuxData]
]

unflatten instance-attribute

unflatten: Callable[[_AuxData, _Children], T]

codegen_flatten

codegen_flatten(
    data_fields: Sequence[str],
    meta_fields: Sequence[str],
    auto_fields: Sequence[str],
) -> FunctionDef
Source code in src/liblaf/jarp/tree/codegen/_codegen.py
def codegen_flatten(
    data_fields: Sequence[str], meta_fields: Sequence[str], auto_fields: Sequence[str]
) -> FunctionDef:
    """Build the AST of a ``flatten(obj)`` function.

    The generated body starts with the statements produced by
    ``codegen_partition(auto_fields)`` and ends by returning a pair of tuples
    ``(children, aux)``. The children come from ``codegen_children`` and the
    auxiliary data from ``codegen_aux``.

    Args:
        data_fields: Field names always emitted as dynamic children.
        meta_fields: Field names always emitted as auxiliary metadata.
        auto_fields: Field names whose placement is decided at runtime.

    Returns:
        A ``FunctionDef`` node named ``flatten`` with a single ``obj`` argument.
    """
    statements: list[stmt] = [*codegen_partition(auto_fields)]
    child_exprs: list[expr] = codegen_children(data_fields, auto_fields)
    aux_exprs: list[expr] = codegen_aux(meta_fields, auto_fields)
    result: expr = Tuple([Tuple(child_exprs, Load()), Tuple(aux_exprs, Load())], Load())
    statements.append(Return(result))
    return codegen_function_def("flatten", [arg("obj")], statements)

codegen_flatten_with_keys

codegen_flatten_with_keys(
    data_fields: Sequence[str],
    meta_fields: Sequence[str],
    auto_fields: Sequence[str],
) -> FunctionDef
Source code in src/liblaf/jarp/tree/codegen/_codegen.py
def codegen_flatten_with_keys(
    data_fields: Sequence[str], meta_fields: Sequence[str], auto_fields: Sequence[str]
) -> FunctionDef:
    """Build the AST of a ``flatten_with_keys(obj)`` function.

    The generated function has the same prologue as the ``flatten`` function
    (the ``codegen_partition(auto_fields)`` statements). Each child expression
    is paired with a key reference, so the child for field ``name`` becomes
    ``(_{name}_key, child)``. The ``_{name}_key`` names are free variables and
    must exist in the namespace the generated code is executed in.

    Args:
        data_fields: Field names always emitted as dynamic children.
        meta_fields: Field names always emitted as auxiliary metadata.
        auto_fields: Field names whose placement is decided at runtime.

    Returns:
        A ``FunctionDef`` node named ``flatten_with_keys`` with a single
        ``obj`` argument, returning ``(children_with_keys, aux)``.

    Raises:
        ValueError: If ``codegen_children`` returns a different number of
            expressions than there are data and auto fields combined.
    """
    statements: list[stmt] = [*codegen_partition(auto_fields)]
    child_exprs: list[expr] = codegen_children(data_fields, auto_fields)
    aux_exprs: list[expr] = codegen_aux(meta_fields, auto_fields)
    # Pairing is positional: this assumes codegen_children emits data fields
    # first, then auto fields. Meta fields are not paired with a key.
    keyed_names: tuple[str, ...] = (*data_fields, *auto_fields)
    keyed_children: list[expr] = [
        Tuple([Name(f"_{field}_key", Load()), child], Load())
        for field, child in zip(keyed_names, child_exprs, strict=True)
    ]
    result: expr = Tuple(
        [Tuple(keyed_children, Load()), Tuple(aux_exprs, Load())], Load()
    )
    statements.append(Return(result))
    return codegen_function_def("flatten_with_keys", [arg("obj")], statements)

codegen_pytree_functions

codegen_pytree_functions(
    cls: type,
    data_fields: Sequence[str] = (),
    meta_fields: Sequence[str] = (),
    auto_fields: Sequence[str] = (),
    *,
    filter_spec: Callable[[Any], bool] = is_data,
    bypass_setattr: bool | None = None,
) -> PyTreeFunctions

Generate flatten and unflatten callbacks for a class.

Parameters:

  • cls (type) –

    Class whose instances should become PyTree nodes.

  • data_fields (Sequence[str], default: () ) –

    Field names that are always emitted as dynamic children.

  • meta_fields (Sequence[str], default: () ) –

    Field names that are always emitted as auxiliary metadata.

  • auto_fields (Sequence[str], default: () ) –

    Field names filtered at runtime with filter_spec.

  • filter_spec (Callable[[Any], bool], default: is_data ) –

    Predicate used to split auto_fields into dynamic data or metadata.

  • bypass_setattr (bool | None, default: None ) –

    Whether generated unflattening code should use object.__setattr__ instead of normal attribute assignment. If None, this is enabled automatically when cls overrides __setattr__.

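Returns:

  • PyTreeFunctions –

    A PyTreeFunctions tuple containing the flatten, unflatten, and flatten_with_keys callables.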

Source code in src/liblaf/jarp/tree/codegen/_compile.py
def codegen_pytree_functions(
    cls: type,
    data_fields: Sequence[str] = (),
    meta_fields: Sequence[str] = (),
    auto_fields: Sequence[str] = (),
    *,
    filter_spec: Callable[[Any], bool] = is_data,
    bypass_setattr: bool | None = None,
) -> PyTreeFunctions:
    """Generate flatten and unflatten callbacks for a class.

    The three callbacks are emitted as Python source, compiled under a
    per-class filename, and executed in a private namespace. That namespace
    provides ``_cls``, ``_filter_spec``, the ``object`` allocation and
    ``__setattr__`` helpers, and one key object per field.

    Args:
        cls: Class whose instances should become PyTree nodes.
        data_fields: Field names that are always emitted as dynamic children.
        meta_fields: Field names that are always emitted as auxiliary metadata.
        auto_fields: Field names filtered at runtime with `filter_spec`.
        filter_spec: Predicate used to split `auto_fields` into dynamic data
            or metadata.
        bypass_setattr: Whether generated unflattening code should use
            [`object.__setattr__`][object.__setattr__] instead of normal
            attribute assignment. When `None`, it is enabled exactly when
            `cls` defines its own `__setattr__`.

    Returns:
        A [`PyTreeFunctions`][liblaf.jarp.tree.codegen.PyTreeFunctions] tuple
        containing `flatten`, `unflatten`, and `flatten_with_keys` callables.
    """
    if bypass_setattr is None:
        # Only route around __setattr__ when the class actually customizes it.
        bypass_setattr = cls.__setattr__ is not object.__setattr__

    module: ast.Module = ast.fix_missing_locations(
        ast.Module(
            body=[
                codegen_flatten(data_fields, meta_fields, auto_fields),
                codegen_flatten_with_keys(data_fields, meta_fields, auto_fields),
                codegen_unflatten(
                    data_fields, meta_fields, auto_fields, bypass_setattr=bypass_setattr
                ),
            ],
            type_ignores=[],
        )
    )
    source: str = ast.unparse(module)

    # Free names referenced by the generated functions.
    namespace: dict[str, Any] = {
        "_cls": cls,
        "_filter_spec": filter_spec,
        "_object_new": object.__new__,
        "_object_setattr": object.__setattr__,
        **_make_keys((*data_fields, *meta_fields, *auto_fields)),
    }
    filename: str = _make_filename(cls)

    # Compile the unparsed text rather than the AST so that code locations
    # refer to `source`, which is the text handed to _update_linecache below.
    code: types.CodeType = compile(source, filename, "exec")
    exec(code, namespace)  # noqa: S102
    _update_linecache(source, filename)

    flatten, unflatten, flatten_with_keys = (
        _add_dunder(cls, namespace[fn_name])
        for fn_name in ("flatten", "unflatten", "flatten_with_keys")
    )
    return PyTreeFunctions(flatten, unflatten, flatten_with_keys)

codegen_unflatten

codegen_unflatten(
    data_fields: Sequence[str],
    meta_fields: Sequence[str],
    auto_fields: Sequence[str],
    *,
    bypass_setattr: bool = False,
) -> FunctionDef
Source code in src/liblaf/jarp/tree/codegen/_codegen.py
def codegen_unflatten(
    data_fields: Sequence[str],
    meta_fields: Sequence[str],
    auto_fields: Sequence[str],
    *,
    bypass_setattr: bool = False,
) -> FunctionDef:
    """Build the AST of an ``unflatten(aux, children)`` function.

    The generated function has three parts:

    1. It allocates ``obj = _object_new(_cls)``, which skips ``__init__``.
    2. It runs the field-assignment statements.
    3. It returns ``obj``.

    Args:
        data_fields: Field names always emitted as dynamic children.
        meta_fields: Field names always emitted as auxiliary metadata.
        auto_fields: Field names whose placement is decided at runtime.
        bypass_setattr: If true, the field assignments come from
            ``_codegen_unflatten_bypass``; otherwise from
            ``_codegen_unflatten_direct``.

    Returns:
        A ``FunctionDef`` node named ``unflatten`` with arguments
        ``aux`` and ``children``.
    """
    emit_assignments = (
        _codegen_unflatten_bypass if bypass_setattr else _codegen_unflatten_direct
    )
    allocate: stmt = Assign(
        [Name("obj", Store())],
        Call(Name("_object_new", Load()), [Name("_cls", Load())], []),
    )
    statements: list[stmt] = [
        allocate,
        *emit_assignments(data_fields, meta_fields, auto_fields),
        Return(Name("obj", Load())),
    ]
    return codegen_function_def("unflatten", [arg("aux"), arg("children")], statements)

register_generic

register_generic(
    cls: type,
    data_fields: Sequence[str] = (),
    meta_fields: Sequence[str] = (),
    auto_fields: Sequence[str] = (),
    *,
    filter_spec: Callable[[Any], bool] = is_data,
    bypass_setattr: bool | None = None,
) -> None

Register a class as a PyTree using explicit field groups.

Use this lower-level helper when you want to control the flattening layout directly instead of relying on attrs metadata.

Parameters:

  • cls (type) –

    Class to register.

  • data_fields (Sequence[str], default: () ) –

    Field names that are always emitted as dynamic children.

  • meta_fields (Sequence[str], default: () ) –

    Field names that are always emitted as auxiliary metadata.

  • auto_fields (Sequence[str], default: () ) –

    Field names filtered at runtime with filter_spec.

  • filter_spec (Callable[[Any], bool], default: is_data ) –

    Predicate used to split auto_fields into dynamic data or metadata.

  • bypass_setattr (bool | None, default: None ) –

    Whether generated unflattening code should use object.__setattr__ instead of normal attribute assignment. If None, this is enabled automatically when cls overrides __setattr__.

Source code in src/liblaf/jarp/tree/codegen/_compile.py
def register_generic(
    cls: type,
    data_fields: Sequence[str] = (),
    meta_fields: Sequence[str] = (),
    auto_fields: Sequence[str] = (),
    *,
    filter_spec: Callable[[Any], bool] = is_data,
    bypass_setattr: bool | None = None,
) -> None:
    """Register a class as a PyTree using explicit field groups.

    This is a lower-level entry point for callers who want to choose the
    flattening layout themselves rather than derive it from [attrs][]
    metadata. The callbacks come from `codegen_pytree_functions` and are
    registered with `jax.tree_util.register_pytree_node`.

    Args:
        cls: Class to register.
        data_fields: Field names that are always emitted as dynamic children.
        meta_fields: Field names that are always emitted as auxiliary metadata.
        auto_fields: Field names filtered at runtime with `filter_spec`.
        filter_spec: Predicate used to split `auto_fields` into dynamic data
            or metadata.
        bypass_setattr: Whether generated unflattening code should use
            [`object.__setattr__`][object.__setattr__] instead of normal
            attribute assignment. When `None`, it is enabled exactly when
            `cls` defines its own `__setattr__`.
    """
    # The generated tuple is (flatten, unflatten, flatten_with_keys), which
    # lines up with register_pytree_node's positional parameters.
    jtu.register_pytree_node(
        cls,
        *codegen_pytree_functions(
            cls,
            data_fields,
            meta_fields,
            auto_fields,
            filter_spec=filter_spec,
            bypass_setattr=bypass_setattr,
        ),
    )