liblaf.jarp
Utilities for mixed JAX PyTrees and NVIDIA Warp interop.
The top-level package re-exports the filtered call wrappers filter_jit and fallback_jit, the most common helpers from liblaf.jarp.tree, and Warp integration utilities such as to_warp, struct, jax_callable, and jax_kernel. Import liblaf.jarp.lax, liblaf.jarp.tree, or liblaf.jarp.warp directly when you need the full submodule APIs.
Modules:
- lax – Control-flow wrappers with automatic Python fallbacks.
- tree – Helpers for defining, flattening, and transforming JAX PyTrees.
- warp – Interop helpers between JAX arrays and NVIDIA Warp.
Classes:
- Partial – Store a partially applied callable as a PyTree-aware proxy.
- PyTreeProxy – Wrap an arbitrary object and flatten the wrapped value as a PyTree.
- Structure – Record how to flatten and rebuild a PyTree's dynamic leaves.
Functions:
- array – Create a data field whose default is normalized to a JAX array.
- auto – Create a field whose PyTree role is chosen from the runtime value.
- cond – Choose between two branches, then retry eagerly if JAX rejects them.
- define – Define an attrs class and optionally register it as a PyTree.
- fallback_jit – Wrap a callable and cache Python fallbacks for failing metadata shapes.
- field – Create an attrs field using jarp's static metadata convention.
- filter_jit – Wrap a callable with liblaf.jarp data-versus-metadata partitioning.
- fori_loop – Run a counted loop, then retry in Python if JAX rejects the body.
- frozen – Define a frozen attrs class and register it as a data PyTree.
- frozen_static – Define a frozen attrs class and register it as a static PyTree.
- jax_callable – Wrap warp.jax_experimental.jax_callable with optional dtype dispatch.
- jax_kernel – Wrap warp.jax_experimental.jax_kernel with optional overload lookup.
- partial – Partially apply a callable and keep bound values visible to JAX trees.
- ravel – Flatten a PyTree's dynamic leaves into one vector.
- static – Create a field that is always treated as static metadata.
- struct – Decorate a class as a Warp struct.
- switch – Choose one branch by index, then retry eagerly if JAX rejects it.
- to_warp – Convert a supported array object into a warp.array.
- while_loop – Run a loop, then retry in Python if JAX rejects the callbacks.
Attributes:
- __commit_id__ (str | None)
- __version__ (str)
- __version_tuple__ (tuple[int | str, ...])
Partial
Bases: PartialCallableObjectProxy
Store a partially applied callable as a PyTree-aware proxy.
Bound arguments and keyword arguments flatten as PyTree children, while the wrapped callable itself is partitioned between dynamic data and static metadata when needed.
Examples:
>>> import jax
>>> import jax.numpy as jnp
>>> from liblaf import jarp
>>> def add(left, right):
...     return left + right
>>> part = jarp.partial(add, jnp.array([1, 2]))
>>> leaves, _treedef = jax.tree.flatten(part)
>>> [leaf.tolist() for leaf in leaves]
[[1, 2]]
>>> part(jnp.array([3, 4])).tolist()
[4, 6]
Methods:
- __call__
Attributes:
- __wrapped__ (Callable[..., T])
Source code in src/liblaf/jarp/tree/prelude/_partial.py
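To make the flattening rule concrete, here is a minimal sketch in plain JAX of a partial object registered with jax.tree_util.register_pytree_node_class. PartialSketch is a hypothetical name, not part of liblaf.jarp, and it simplifies the behavior described above: the wrapped callable is always kept as static auxiliary data instead of being partitioned into data and metadata.

```python
import jax
import jax.numpy as jnp


# Hypothetical stand-in for a PyTree-aware partial (not part of liblaf.jarp).
@jax.tree_util.register_pytree_node_class
class PartialSketch:
    def __init__(self, func, *args, **kwargs):
        self.func = func
        self.args = args
        self.kwargs = kwargs

    def __call__(self, *args, **kwargs):
        # Bound positional arguments come first; call-time keywords override bound ones.
        return self.func(*self.args, *args, **{**self.kwargs, **kwargs})

    def tree_flatten(self):
        # Bound values become PyTree children; the callable is static auxiliary data.
        return (self.args, self.kwargs), self.func

    @classmethod
    def tree_unflatten(cls, func, children):
        args, kwargs = children
        return cls(func, *args, **kwargs)


def add(left, right):
    return left + right


part = PartialSketch(add, jnp.array([1, 2]))
# The bound array is visible to tree utilities...
leaves = jax.tree_util.tree_leaves(part)
# ...and the object can be passed straight through jax.jit.
result = jax.jit(lambda p, x: p(x))(part, jnp.array([3, 4]))
```

Because the callable sits in the auxiliary data, it must be hashable for jax.jit to cache the trace; plain functions satisfy this.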
PyTreeProxy
Bases: BaseObjectProxy
Wrap an arbitrary object and flatten the wrapped value as a PyTree.
The proxy itself stays transparent while JAX sees the wrapped object's PyTree structure.
Attributes:
- __wrapped__ (T)
Structure
Record how to flatten and rebuild a PyTree's dynamic leaves.
Instances are returned by ravel and capture the
original tree definition, the static leaves that were removed from the flat
vector, and the offsets needed to reconstruct each dynamic leaf.
Parameters:
- dtype (str | type[Any] | dtype | SupportsDType)
- meta_leaves (tuple[Any, ...])
- offsets (tuple[int, ...])
- shapes (tuple[Shape | None, ...])
- treedef (PyTreeDef)
Methods:
- ravel – Flatten a compatible tree or flatten an array directly.
- unravel – Rebuild the original tree shape from a flat vector.
Attributes:
- dtype (DTypeLike)
- is_leaf (bool) – Return whether the recorded tree was a single leaf.
- meta_leaves (tuple[Any, ...])
- offsets (tuple[int, ...])
- shapes (tuple[Shape | None, ...])
- treedef (PyTreeDef)
Structure.ravel
Flatten a compatible tree or flatten an array directly.
Parameters:
- tree (T | Array) – A tree with the same structure and static leaves used to build this Structure, or a JAX array that should be flattened directly.
Returns:
- Array1D – A one-dimensional array containing the dynamic leaves.
Source code in src/liblaf/jarp/tree/_ravel.py
Structure.unravel
Rebuild the original tree shape from a flat vector.
Parameters:
- flat (T | Array) – One-dimensional data produced by ravel, or a tree that already matches the recorded structure.
- dtype (DTypeLike | None, default: None) – Optional dtype override applied to the flat array before it is split and reshaped.
Returns:
- T – A tree with the same structure and static metadata as the original input to ravel.
Source code in src/liblaf/jarp/tree/_ravel.py
array
array(
*,
default: T = ...,
validator: _ValidatorArgType[T] | None = ...,
repr: _ReprArgType = ...,
hash: bool | None = ...,
init: bool = ...,
metadata: Mapping[Any, Any] | None = ...,
converter: _ConverterType
| list[_ConverterType]
| tuple[_ConverterType, ...]
| None = ...,
factory: Callable[[], T] | None = ...,
kw_only: bool | None = ...,
eq: _EqOrderType | None = ...,
order: _EqOrderType | None = ...,
on_setattr: _OnSetAttrArgType | None = ...,
alias: str | None = ...,
type: type | None = ...,
static: FieldType | bool | None = ...,
) -> Array
Create a data field whose default is normalized to a JAX array.
When default is a concrete array-like value, array rewrites it into
a factory so each instance receives its own array object.
Parameters:
- default (T, default: ...)
- validator (_ValidatorArgType[T] | None, default: ...)
- repr (_ReprArgType, default: ...)
- hash (bool | None, default: ...)
- init (bool, default: ...)
- metadata (Mapping[Any, Any] | None, default: ...)
- converter (_ConverterType | list[_ConverterType] | tuple[_ConverterType, ...] | None, default: ...)
- factory (Callable[[], T] | None, default: ...)
- kw_only (bool | None, default: ...)
- eq (_EqOrderType | None, default: ...)
- order (_EqOrderType | None, default: ...)
- on_setattr (_OnSetAttrArgType | None, default: ...)
- alias (str | None, default: ...)
- type (type | None, default: ...)
- static (FieldType | bool | None, default: ...)
Source code in src/liblaf/jarp/tree/attrs/_field_specifiers.py
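The default-to-factory rewrite can be sketched with the standard library's dataclasses in place of attrs. array_field_sketch and Params are hypothetical names, and the sketch covers only the default handling, not the validator, converter, or static options listed above.

```python
import dataclasses

import jax
import jax.numpy as jnp


def array_field_sketch(default):
    """Hypothetical helper: turn a concrete array-like default into a per-instance factory."""
    # jnp.array copies its input, so every instance receives a distinct array object.
    return dataclasses.field(default_factory=lambda: jnp.array(default))


@dataclasses.dataclass
class Params:  # hypothetical example class
    weights: jax.Array = array_field_sketch([0.0, 1.0])


first, second = Params(), Params()
```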
cond
cond[*Ts, T](
pred: ScalarLike,
true_fun: Callable[[*Ts], T],
false_fun: Callable[[*Ts], T],
*operands: *Ts,
) -> T
Choose between two branches, then retry eagerly if JAX rejects them.
The wrapper first calls jax.lax.cond. If that raises
jax.errors.JAXTypeError or
jax.errors.JAXIndexError, it logs the
exception and reruns the selected branch in plain Python.
Parameters:
- pred (ScalarLike) – Scalar predicate. Python truthiness decides which branch runs on the fallback path.
- true_fun (Callable[[*Ts], T]) – Branch evaluated when pred is true.
- false_fun (Callable[[*Ts], T]) – Branch evaluated when pred is false.
- *operands (*Ts, default: ()) – Positional operands forwarded to the selected branch.
Returns:
- T – The value returned by the selected branch.
Source code in src/liblaf/jarp/lax/_control.py
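A minimal sketch of the compiled-first, eager-fallback pattern, written against plain JAX rather than liblaf.jarp's implementation. cond_with_fallback is a hypothetical name. The sketch assumes that jax.errors.JAXTypeError and jax.errors.JAXIndexError subclass the built-in TypeError and IndexError, and it catches those broader built-ins so it does not depend on the JAX error classes being exported.

```python
import logging

import jax
import jax.numpy as jnp

logger = logging.getLogger(__name__)


def cond_with_fallback(pred, true_fun, false_fun, *operands):
    """Hypothetical helper: try jax.lax.cond, else run the selected branch eagerly."""
    try:
        return jax.lax.cond(pred, true_fun, false_fun, *operands)
    except (TypeError, IndexError) as exc:
        # Assumed to cover JAXTypeError / JAXIndexError; broader than the documented clause.
        logger.warning("jax.lax.cond failed, running the branch eagerly: %s", exc)
        branch = true_fun if pred else false_fun  # Python truthiness picks the branch
        return branch(*operands)


x = jnp.arange(3)
# Branches with mismatched output shapes are rejected by jax.lax.cond...
eager = cond_with_fallback(True, lambda v: v[:2], lambda v: v, x)
# ...while compatible branches stay on the compiled path.
compiled = cond_with_fallback(False, lambda v: v + 1, lambda v: v - 1, jnp.array(1))
```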
define
Define an attrs class and optionally register it as a PyTree.
Parameters:
- maybe_cls (T | None, default: None) – Class being decorated. When omitted, return a configured decorator.
- **kwargs (Any, default: {}) – Options forwarded to attrs.define, plus pytree to control JAX registration: pytree="data" registers fields with fieldz semantics, "static" registers the whole instance as a static value, and "none" leaves the class unregistered.
Returns:
- Any – The decorated class or a class decorator.
Source code in src/liblaf/jarp/tree/attrs/_define.py
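The difference between pytree="data" and pytree="static" can be approximated in plain JAX with stdlib dataclasses and jax.tree_util.register_pytree_node. Point, Config, and apply are hypothetical names; the sketch makes every field of the data class a child and does not model per-field static markers or fieldz.

```python
import dataclasses

import jax
import jax.numpy as jnp


@dataclasses.dataclass(frozen=True)
class Point:  # hypothetical "data" class: every field becomes a leaf
    x: jax.Array
    y: jax.Array


jax.tree_util.register_pytree_node(
    Point,
    lambda p: ((p.x, p.y), None),  # children: all fields, no auxiliary data
    lambda _, children: Point(*children),
)


@dataclasses.dataclass(frozen=True)
class Config:  # hypothetical "static" class: the whole instance is metadata
    name: str
    scale: float


jax.tree_util.register_pytree_node(
    Config,
    lambda c: ((), c),  # no children; the hashable instance itself is the aux data
    lambda aux, _: aux,
)


@jax.jit
def apply(cfg, p):
    # cfg.scale is a plain Python float here, never a tracer.
    return Point(p.x * cfg.scale, p.y * cfg.scale)


out = apply(Config("double", 2.0), Point(jnp.ones(2), jnp.zeros(2)))
```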
fallback_jit
fallback_jit(
fun: None = None, **kwargs: Unpack[FilterJitOptions]
) -> IdentityFunction
Wrap a callable and cache Python fallbacks for failing metadata shapes.
The wrapper first uses the same partitioned call path as
filter_jit. If that path raises
jax.errors.JAXTypeError or
jax.errors.JAXIndexError, the exception is
logged, the current static-metadata signature is marked as unsupported,
and the original callable is invoked directly in Python. Later calls with
the same static metadata skip the partitioned path and reuse the Python
fallback immediately.
Examples:
>>> import jax.numpy as jnp
>>> from liblaf import jarp
>>> @jarp.fallback_jit
... def add_one(value):
...     return value + 1
>>> int(add_one(jnp.array(2)))
3
Parameters:
- fun (Callable[P, T] | None, default: None) – Callable to wrap. When omitted, return a configured decorator.
- keep_unused (bool, default: ...)
- device (Any | None, default: ...)
- backend (str | None, default: ...)
- inline (bool, default: ...)
Returns:
- Callable – The wrapped callable, or a decorator that produces one.
Source code in src/liblaf/jarp/_jit/_fallback_jit.py
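A simplified sketch of the caching behavior, assuming positional arguments only and treating every positional jax.Array as data and everything else as hashable metadata. fallback_jit_sketch and transform are hypothetical names, and this partitioning is much coarser than the one liblaf.jarp performs.

```python
import functools
import logging

import jax
import jax.numpy as jnp

logger = logging.getLogger(__name__)


def fallback_jit_sketch(fun):
    """Hypothetical helper: jit with arrays as data, remember metadata that fails."""
    failed = set()  # static-metadata signatures known to be rejected by JAX

    @functools.lru_cache(maxsize=None)
    def compiled(static_argnums):
        return jax.jit(fun, static_argnums=static_argnums)

    @functools.wraps(fun)
    def wrapper(*args):
        # Simplification: positional jax.Array arguments are data, the rest is metadata.
        static_argnums = tuple(i for i, a in enumerate(args) if not isinstance(a, jax.Array))
        signature = tuple((i, args[i]) for i in static_argnums)  # assumes hashable metadata
        if signature not in failed:
            try:
                return compiled(static_argnums)(*args)
            except (TypeError, IndexError) as exc:
                logger.warning("jit failed for %r, using Python: %s", signature, exc)
                failed.add(signature)
        return fun(*args)  # cached fallback: later calls skip the jit attempt

    wrapper.failed = failed
    return wrapper


@fallback_jit_sketch
def transform(x, mode):  # hypothetical example function
    if mode == "relu":
        return jnp.maximum(x, 0.0)
    if x.sum() > 0:  # data-dependent Python branch: fails under tracing
        return x
    return -x
```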
field
field(**kwargs) -> Any
Create an attrs field using jarp's static metadata convention.
Source code in src/liblaf/jarp/tree/attrs/_field_specifiers.py
filter_jit
filter_jit(
fun: None = None, **kwargs: Unpack[FilterJitOptions]
) -> IdentityFunction
Wrap a callable with liblaf.jarp data-versus-metadata partitioning.
The wrapper partitions the callable and each invocation's arguments with
partition, rebuilds the original call shape,
and partitions the return value again before handing it back. This keeps
JAX arrays on the dynamic side of the partition while preserving ordinary
Python metadata such as strings, bound methods, or configuration objects.
Examples:
>>> import jax.numpy as jnp
>>> from liblaf import jarp
>>> @jarp.filter_jit
... def scale(value, *, label):
...     assert label == "active"
...     return value * 2
>>> scale(jnp.array([1, 2]), label="active").tolist()
[2, 4]
Parameters:
- fun (F | None, default: None) – Callable to wrap. When omitted, return a configured decorator.
- keep_unused (bool, default: ...)
- device (Any | None, default: ...)
- backend (str | None, default: ...)
- inline (bool, default: ...)
Returns:
- The wrapped callable, or a configured decorator when fun is omitted.
Source code in src/liblaf/jarp/_jit/_filter_jit.py
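A rough sketch of the partitioning idea using jax.jit directly: every PyTree leaf that is a jax.Array is traced, and every other leaf rides along as a static argument. filter_jit_sketch is a hypothetical name. Unlike the behavior described above, it does not partition the return value, so the wrapped function must return JAX-compatible values, and metadata leaves must be hashable.

```python
import functools

import jax
import jax.numpy as jnp


def filter_jit_sketch(fun):
    """Hypothetical helper: trace only jax.Array leaves, pass other leaves as static."""

    @functools.partial(jax.jit, static_argnums=(1, 2, 3))
    def run(dynamic, treedef, mask, static):
        # Re-interleave traced array leaves with the untouched metadata leaves.
        arrays = iter(dynamic)
        leaves = [next(arrays) if is_data else leaf for is_data, leaf in zip(mask, static)]
        args, kwargs = jax.tree_util.tree_unflatten(treedef, leaves)
        return fun(*args, **kwargs)

    @functools.wraps(fun)
    def wrapper(*args, **kwargs):
        leaves, treedef = jax.tree_util.tree_flatten((args, kwargs))
        mask = tuple(isinstance(leaf, jax.Array) for leaf in leaves)
        dynamic = [leaf for leaf, is_data in zip(leaves, mask) if is_data]
        # Metadata leaves become static jit arguments, so they must be hashable.
        static = tuple(None if is_data else leaf for leaf, is_data in zip(leaves, mask))
        return run(dynamic, treedef, mask, static)

    return wrapper


@filter_jit_sketch
def scale(value, *, label):
    assert label == "active"  # label is a plain str here, never a tracer
    return value * 2
```

Since non-array leaves become part of the static cache key, calling with a different label or other metadata triggers a fresh trace.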
fori_loop
fori_loop[T](
lower: int,
upper: int,
body_fun: Callable[[int, T], T],
init_val: T,
**kwargs: Any,
) -> T
Run a counted loop, then retry in Python if JAX rejects the body.
The wrapper first calls jax.lax.fori_loop. If that
raises jax.errors.JAXTypeError or
jax.errors.JAXIndexError, it logs the
exception and runs an ordinary Python for loop instead.
Parameters:
- lower (int) – Inclusive loop lower bound.
- upper (int) – Exclusive loop upper bound.
- body_fun (Callable[[int, T], T]) – Callback that receives the iteration index and current loop value, then returns the next loop value.
- init_val (T) – Initial loop value.
- **kwargs (Any, default: {}) – Extra keyword arguments forwarded to jax.lax.fori_loop on the first attempt. They are ignored on the Python fallback path.
Returns:
- T – The final loop value.
Source code in src/liblaf/jarp/lax/_control.py
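The fallback applied to a counted loop can be sketched in plain JAX as follows. fori_loop_with_fallback is a hypothetical name, and its except clause catches the built-in TypeError and IndexError, assuming the JAX error classes named above derive from them.

```python
import logging

import jax

logger = logging.getLogger(__name__)


def fori_loop_with_fallback(lower, upper, body_fun, init_val, **kwargs):
    """Hypothetical helper: try jax.lax.fori_loop, else run a plain Python for loop."""
    try:
        return jax.lax.fori_loop(lower, upper, body_fun, init_val, **kwargs)
    except (TypeError, IndexError) as exc:
        logger.warning("jax.lax.fori_loop failed, looping in Python: %s", exc)
        val = init_val
        for i in range(lower, upper):  # kwargs are not used on this path
            val = body_fun(i, val)
        return val


# int(i) on a traced index fails, so this loop falls back to Python.
indices = fori_loop_with_fallback(0, 3, lambda i, acc: acc + [int(i)], [])
# A fixed-shape numeric carry stays on the compiled path.
total = fori_loop_with_fallback(0, 5, lambda i, acc: acc + i, 0)
```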
frozen
Define a frozen attrs class and register it as a data PyTree.
This is the common choice for immutable structures whose array fields should participate in JAX transformations.
Source code in src/liblaf/jarp/tree/attrs/_define.py
frozen_static
Define a frozen attrs class and register it as a static PyTree.
Use this for immutable helper objects that should be treated as static metadata instead of flattening into JAX leaves.
Source code in src/liblaf/jarp/tree/attrs/_define.py
jax_callable
jax_callable(
func: _FfiCallableFunction,
*,
generic: Literal[False] = False,
**kwargs: Unpack[JaxCallableOptions],
) -> FfiCallableProtocol
jax_callable(
*,
generic: Literal[False] = False,
**kwargs: Unpack[JaxCallableOptions],
) -> Callable[[_FfiCallableFunction], FfiCallableProtocol]
jax_callable(
func: _FfiCallableFactory,
*,
generic: Literal[True],
**kwargs: Unpack[JaxCallableOptions],
) -> _FfiCallable
jax_callable(
*,
generic: Literal[True],
**kwargs: Unpack[JaxCallableOptions],
) -> Callable[[_FfiCallableFactory], _FfiCallable]
Wrap warp.jax_experimental.jax_callable with optional dtype dispatch.
When generic=True, func is treated as a factory keyed by the Warp
scalar dtypes inferred from the runtime JAX arguments. The factory output is
cached, so repeated calls with the same dtype signature reuse the same Warp
callable.
Parameters:
- func (Callable | None, default: None) – Warp callable function or factory. When omitted, return a decorator.
- generic (bool, default: False) – When true, func is treated as a factory that receives Warp scalar dtypes inferred from the runtime JAX arguments and returns a concrete Warp callable implementation.
- num_outputs (int, default: ...)
- graph_mode (GraphMode, default: ...)
- vmap_method (VmapMethod | None, default: ...)
- output_dims (dict[str, ShapeLike] | None, default: ...)
- in_out_argnames (Iterable[str], default: ...)
- stage_in_argnames (Iterable[str], default: ...)
- stage_out_argnames (Iterable[str], default: ...)
- graph_cache_max (int | None, default: ...)
- module_preload_mode (ModulePreloadMode, default: ...)
- has_side_effect (bool, default: ...)
Returns:
- Any – A callable compatible with JAX tracing, or a decorator producing one. The callable returns the output arrays produced by Warp's FFI wrapper.
Source code in src/liblaf/jarp/warp/_jax_callable.py
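The caching half of generic=True can be sketched without Warp. In this hypothetical dtype_dispatch_sketch, JAX dtypes stand in for Warp scalar dtypes, and a single dtype promoted across all arguments with jnp.result_type is assumed; how liblaf.jarp actually infers the dtypes is not specified here.

```python
import functools

import jax.numpy as jnp


def dtype_dispatch_sketch(factory):
    """Hypothetical helper: build one implementation per inferred dtype and cache it."""

    @functools.lru_cache(maxsize=None)
    def build(dtype):
        return factory(dtype)

    def call(*arrays):
        # Simplification: one scalar dtype, promoted across all arguments.
        dtype = jnp.result_type(*arrays)
        return build(dtype)(*arrays)

    return call


built = []  # records each factory invocation


def make_add(dtype):  # hypothetical factory
    built.append(str(dtype))
    return lambda a, b: (a + b).astype(dtype)


add = dtype_dispatch_sketch(make_add)
add(jnp.ones(2, jnp.float32), jnp.ones(2, jnp.float32))
add(jnp.zeros(2, jnp.float32), jnp.ones(2, jnp.float32))  # reuses the float32 build
add(jnp.ones(2, jnp.int32), jnp.ones(2, jnp.int32))
```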
jax_kernel
jax_kernel(
*,
arg_types_factory: Callable[[WarpScalarDType], ArgTypes]
| None = None,
**kwargs: Unpack[JaxKernelOptions],
) -> Callable[[Callable], FfiKernelProtocol]
jax_kernel(
kernel: Callable,
*,
arg_types_factory: Callable[[WarpScalarDType], ArgTypes]
| None = None,
**kwargs: Unpack[JaxKernelOptions],
) -> FfiKernelProtocol
Wrap warp.jax_experimental.jax_kernel with optional overload lookup.
When arg_types_factory is provided, the wrapper infers Warp scalar
dtypes from the runtime JAX arguments, builds an overload signature, and
resolves the corresponding Warp kernel before dispatch.
Parameters:
- kernel (Callable | None, default: None) – Warp kernel to expose to JAX. When omitted, return a decorator.
- arg_types_factory (Callable[[WarpScalarDType], ArgTypes] | None, default: None) – Optional callback that maps runtime Warp scalar dtypes to the overloaded kernel argument types expected by warp.overload.
- num_outputs (int, default: ...)
- vmap_method (VmapMethod, default: ...)
- launch_dims (ShapeLike | None, default: ...)
- output_dims (ShapeLike | dict[str, ShapeLike] | None, default: ...)
- in_out_argnames (Iterable[str], default: ...)
- module_preload_mode (ModulePreloadMode, default: ...)
- enable_backward (bool, default: ...)
Returns:
- Any – A callable compatible with JAX tracing, or a decorator producing one. The callable returns the output arrays produced by Warp's FFI wrapper.
Source code in src/liblaf/jarp/warp/_jax_kernel.py
partial
Partially apply a callable and keep bound values visible to JAX trees.
ravel
Flatten a PyTree's dynamic leaves into one vector.
Non-array leaves are treated as static metadata and preserved in the
returned Structure instead of being
concatenated into the flat array.
Examples:
>>> import jax.numpy as jnp
>>> from liblaf import jarp
>>> flat, structure = jarp.ravel({"x": jnp.array([1.0, 2.0]), "tag": "train"})
>>> flat.tolist()
[1.0, 2.0]
>>> rebuilt = structure.unravel(jnp.array([3.0, 4.0]))
>>> rebuilt["x"].tolist(), rebuilt["tag"]
([3.0, 4.0], 'train')
Parameters:
- tree (T) – PyTree to flatten.
Returns:
A tuple (flat, structure):
- Array – flat, a one-dimensional JAX array.
- Structure[T] – structure, which can rebuild compatible trees later.
Source code in src/liblaf/jarp/tree/_ravel.py
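A plain-JAX sketch of the flatten-and-record idea behind ravel and Structure. ravel_sketch and unravel_sketch are hypothetical names, the record is a bare tuple rather than a Structure, and only jax.Array leaves are assumed to be dynamic.

```python
from itertools import accumulate

import jax
import jax.numpy as jnp


def ravel_sketch(tree):
    """Hypothetical helper: concatenate array leaves, keep other leaves as metadata."""
    leaves, treedef = jax.tree_util.tree_flatten(tree)
    mask = [isinstance(leaf, jax.Array) for leaf in leaves]
    arrays = [leaf for leaf, is_data in zip(leaves, mask) if is_data]
    static = [None if is_data else leaf for leaf, is_data in zip(leaves, mask)]
    # Start offset of every array leaf in the flat vector, plus the total length.
    bounds = [0, *accumulate(a.size for a in arrays)]
    flat = jnp.concatenate([a.ravel() for a in arrays]) if arrays else jnp.zeros((0,))
    record = (treedef, mask, static, [a.shape for a in arrays], bounds)
    return flat, record


def unravel_sketch(flat, record):
    """Hypothetical helper: split flat at the recorded offsets and restore metadata leaves."""
    treedef, mask, static, shapes, bounds = record
    chunks = iter(
        flat[start:stop].reshape(shape)
        for start, stop, shape in zip(bounds[:-1], bounds[1:], shapes)
    )
    leaves = [next(chunks) if is_data else leaf for is_data, leaf in zip(mask, static)]
    return jax.tree_util.tree_unflatten(treedef, leaves)


flat, record = ravel_sketch({"x": jnp.array([1.0, 2.0]), "tag": "train"})
rebuilt = unravel_sketch(jnp.array([3.0, 4.0]), record)
```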
static
static(**kwargs) -> Any
Create a field that is always treated as static metadata.
Source code in src/liblaf/jarp/tree/attrs/_field_specifiers.py
struct
struct[T: type](cls: T) -> T
Decorate a class as a Warp struct.
Plain classes are forwarded to warp.struct. Classes that define
__annotations_factory__(dtype) stay generic: MyStruct[wp.float64]
builds and caches a specialized Warp struct from the factory annotations,
while MyStruct() instantiates MyStruct[liblaf.jarp.warp.types.floating]
so the default follows JAX's active precision mode.
Parameters:
- cls (T) – Class to decorate.
Returns:
- T – The Warp struct for plain classes, or the original generic class with dtype subscription and default construction hooks installed.
Source code in src/liblaf/jarp/warp/_struct.py
switch
Choose one branch by index, then retry eagerly if JAX rejects it.
The wrapper first calls jax.lax.switch. If that raises
jax.errors.JAXTypeError or
jax.errors.JAXIndexError, it logs the
exception, clamps index into the valid range, and dispatches in plain
Python.
Parameters:
- index (ArrayLike) – Branch index. The fallback path clamps the value into the valid range before dispatch.
- branches (Sequence[Callable[[*Ts], T]]) – Candidate branch functions.
- *operands (*Ts, default: ()) – Positional operands forwarded to the selected branch.
Returns:
- T – The value returned by the selected branch.
Source code in src/liblaf/jarp/lax/_control.py
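A sketch of the switch fallback with explicit index clamping, in plain JAX. switch_with_fallback is a hypothetical name, and its except clause catches the built-in TypeError and IndexError, assuming the JAX error classes named above subclass them.

```python
import logging

import jax
import jax.numpy as jnp

logger = logging.getLogger(__name__)


def switch_with_fallback(index, branches, *operands):
    """Hypothetical helper: try jax.lax.switch, else clamp the index and call eagerly."""
    try:
        return jax.lax.switch(index, branches, *operands)
    except (TypeError, IndexError) as exc:
        logger.warning("jax.lax.switch failed, dispatching in Python: %s", exc)
        clamped = min(max(int(index), 0), len(branches) - 1)
        return branches[clamped](*operands)


x = jnp.arange(3)
# These branches return different shapes, so jax.lax.switch rejects them.
ragged = [lambda v: v[:1], lambda v: v, lambda v: v[:2]]
high = switch_with_fallback(7, ragged, x)  # clamped to the last branch
low = switch_with_fallback(-3, ragged, x)  # clamped to the first branch
compiled = switch_with_fallback(1, [lambda v: v + 1, lambda v: v * 10], jnp.array(2))
```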
to_warp
Convert a supported array object into a warp.array.
The dispatcher supports existing Warp arrays, NumPy arrays, and JAX arrays.
A dtype hint may be a concrete Warp dtype or a tuple that describes a
vector or matrix dtype inferred from the trailing dimensions of arr.
Use (-1, Any) for vector inference and (-1, -1, Any) for matrix
inference when the element type should follow the source array.
Parameters:
- arr (array | ndarray | Array) – Array object to convert.
- *_args (Any, default: ()) – Reserved for singledispatch compatibility.
- **_kwargs (Any, default: {}) – Reserved for singledispatch compatibility.
Returns:
- array – A Warp array view or converted array, depending on the source type.
Raises:
- TypeError – If arr uses an unsupported type.
Source code in src/liblaf/jarp/warp/_to_warp.py
while_loop
while_loop[T](
cond_fun: Callable[[T], BooleanNumeric],
body_fun: Callable[[T], T],
init_val: T,
) -> T
Run a loop, then retry in Python if JAX rejects the callbacks.
The wrapper first calls jax.lax.while_loop. If
that raises jax.errors.JAXTypeError or
jax.errors.JAXIndexError, it logs the
exception and reruns the loop eagerly in Python.
Parameters:
- cond_fun (Callable[[T], BooleanNumeric]) – Predicate evaluated on the loop state.
- body_fun (Callable[[T], T]) – Function that produces the next loop state.
- init_val (T) – Initial loop state.
Returns:
- T – The final loop state.
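The while-loop fallback can be sketched in plain JAX as well. while_loop_with_fallback is a hypothetical name, and its except clause catches the built-in TypeError and IndexError, assuming the JAX error classes named above derive from them. The doubling example fails under jax.lax.while_loop because the loop state changes shape between iterations.

```python
import logging

import jax
import jax.numpy as jnp

logger = logging.getLogger(__name__)


def while_loop_with_fallback(cond_fun, body_fun, init_val):
    """Hypothetical helper: try jax.lax.while_loop, else loop eagerly in Python."""
    try:
        return jax.lax.while_loop(cond_fun, body_fun, init_val)
    except (TypeError, IndexError) as exc:
        logger.warning("jax.lax.while_loop failed, looping in Python: %s", exc)
        val = init_val
        while cond_fun(val):
            val = body_fun(val)
        return val


# The state doubles in length every step, which a traced loop cannot express.
grown = while_loop_with_fallback(
    lambda xs: xs.shape[0] < 4, lambda xs: jnp.concatenate([xs, xs]), jnp.ones(1)
)
# A fixed-shape scalar state stays on the compiled path.
doubled = while_loop_with_fallback(lambda v: v < 10, lambda v: v * 2, 1)
```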