Getting Started

liblaf.jarp is for PyTrees that contain both traceable arrays and ordinary Python metadata. The common pattern is to describe that split once with field specifiers, then reuse it everywhere else.

Install

uv add liblaf-jarp

Optional extras install CUDA-enabled JAX wheels that match the environment:

uv add 'liblaf-jarp[cuda12]'
uv add 'liblaf-jarp[cuda13]'

Define A PyTree-Friendly Class

import jax.numpy as jnp
from liblaf import jarp


@jarp.define
class Batch:
    values: object = jarp.array()
    label: str = jarp.static()


@jarp.filter_jit
def normalize(batch: Batch) -> Batch:
    centered = batch.values - jnp.mean(batch.values)
    return Batch(values=centered, label=batch.label)


batch = Batch(values=jnp.array([1.0, 2.0, 3.0]), label="train")
result = normalize(batch)

array() marks values that should stay on the dynamic side of the partition. static() marks metadata that should stay out of the dynamic leaves. auto() is the middle ground: it decides at flatten time whether the current value behaves like data or metadata.

filter_jit uses the same split for ordinary call arguments, so a function can accept strings, callables, or other metadata inside the same tree as JAX arrays without manual tree surgery.
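To make the split concrete, here is a minimal sketch of the same idea in plain JAX. It is not how jarp.filter_jit is implemented: filter_jit_sketch is a hypothetical helper, and the rule that only jax.Array leaves count as dynamic is an assumption made for this example.

```python
import functools

import jax
import jax.numpy as jnp


def filter_jit_sketch(fn):
    """Hypothetical stand-in that jits `fn` over the array leaves only."""

    # Static leaves, the dynamic/static mask, and the tree definition are
    # hashable, so they can be passed to jax.jit as static arguments.
    @functools.partial(jax.jit, static_argnums=(1, 2, 3))
    def run(dynamic, static, mask, treedef):
        dynamic_iter, static_iter = iter(dynamic), iter(static)
        leaves = [
            next(dynamic_iter) if is_dynamic else next(static_iter)
            for is_dynamic in mask
        ]
        return fn(jax.tree_util.tree_unflatten(treedef, leaves))

    def wrapper(tree):
        leaves, treedef = jax.tree_util.tree_flatten(tree)
        # Assumption: only JAX arrays are treated as dynamic data.
        mask = tuple(isinstance(leaf, jax.Array) for leaf in leaves)
        dynamic = [leaf for leaf, keep in zip(leaves, mask) if keep]
        static = tuple(leaf for leaf, keep in zip(leaves, mask) if not keep)
        return run(dynamic, static, mask, treedef)

    return wrapper


@filter_jit_sketch
def center(batch):
    # The string label travels with the tree but is never traced.
    return batch["values"] - jnp.mean(batch["values"])


centered = center({"values": jnp.array([1.0, 2.0, 3.0]), "label": "train"})
```

Because the static leaves are part of the jit cache key in this sketch, calling center with a different label would trigger a fresh compilation, which is the usual cost of treating metadata as static.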

Flatten Mixed Trees Into One Vector

import jax.numpy as jnp
from liblaf import jarp


payload = {"a": jnp.zeros((3,)), "b": jnp.ones((4,)), "static": "foo"}
flat, structure = jarp.ravel(payload)
round_trip = structure.unravel(flat)

flat contains only the dynamic leaves. structure keeps the tree definition, static leaves, and reshape offsets needed to rebuild compatible values later.
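For comparison, the sketch below reproduces the same round trip with jax.flatten_util.ravel_pytree, which only accepts array leaves. The manual filtering of the string leaf is an assumption about what a mixed-tree ravel has to do, not a description of jarp's internals.

```python
import jax
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree

payload = {"a": jnp.zeros((3,)), "b": jnp.ones((4,)), "static": "foo"}

# Flatten the tree and keep only the array leaves for the flat vector.
leaves, treedef = jax.tree_util.tree_flatten(payload)
dynamic = [leaf for leaf in leaves if isinstance(leaf, jax.Array)]
flat, unravel_dynamic = ravel_pytree(dynamic)

# Rebuild: take arrays from the unraveled vector, reuse static leaves as-is.
rebuilt = iter(unravel_dynamic(flat))
rebuilt_leaves = [
    next(rebuilt) if isinstance(leaf, jax.Array) else leaf for leaf in leaves
]
round_trip = jax.tree_util.tree_unflatten(treedef, rebuilt_leaves)
```

Here flat has seven entries (three from "a", four from "b"), and the "static" string is carried alongside rather than inside the vector.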

If you already have a compatible tree, Structure.ravel can flatten it again, and Structure.unravel returns an already-matching tree unchanged. If the recorded value was itself a JAX array, Structure.unravel reshapes a flat vector back to that array's shape.
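The single-array case can be illustrated with plain JAX as well. This sketch uses jax.flatten_util.ravel_pytree rather than Structure.unravel, so it shows the reshaping behavior under the assumption that the two are analogous, not the jarp call itself.

```python
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree

matrix = jnp.arange(6.0).reshape(2, 3)

# The recorded value is a single 2x3 array, so the flat vector has 6 entries.
flat, unravel = ravel_pytree(matrix)

# Any flat vector of the same length is reshaped back to the recorded shape.
restored = unravel(flat * 2.0)
```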

Retry Selected Control-Flow Errors Eagerly

jarp.lax tries jax.lax first. When JAX raises one of the tracing or indexing errors that the wrappers know how to recover from, it reruns the same callbacks in plain Python.

from liblaf import jarp


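# The body indexes a Python list with the loop counter. Under tracing the
# counter is not a concrete integer, so jax.lax.while_loop cannot run this
# body as written; jarp.lax.while_loop is expected to rerun it eagerly.
value = jarp.lax.while_loop(
    lambda state: state[0] < 3,  # continue while the counter is below 3
    lambda state: (state[0] + 1, state[1] + [10, 20, 30][state[0]]),
    (0, 0),  # initial state: (counter, running total)
)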
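The try-then-fall-back pattern described above can be sketched in plain JAX as follows. while_loop_with_fallback is a hypothetical helper, and the set of exceptions it catches is an assumption for illustration; the errors jarp.lax actually recovers from may differ.

```python
import jax


def while_loop_with_fallback(cond_fun, body_fun, init_val):
    """Hypothetical sketch: try jax.lax.while_loop, then rerun in Python."""
    try:
        return jax.lax.while_loop(cond_fun, body_fun, init_val)
    except (
        jax.errors.ConcretizationTypeError,
        jax.errors.TracerIntegerConversionError,
    ):
        # Assumed recoverable errors: the callbacks needed concrete values.
        # Rerun the same callbacks eagerly with ordinary Python control flow.
        state = init_val
        while cond_fun(state):
            state = body_fun(state)
        return state


value = while_loop_with_fallback(
    lambda state: state[0] < 3,
    # Indexing a Python list needs a concrete integer, which tracing lacks.
    lambda state: (state[0] + 1, state[1] + [10, 20, 30][state[0]]),
    (0, 0),
)
```

The trade-off in this sketch is that the eager rerun gives up compilation for that call, so it suits callbacks that are correct in Python but not traceable as written.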

For the control-flow helpers and the cached Python fallback in fallback_jit, continue with Call wrappers.

Next Steps