liblaf.jarp.lax

Control-flow wrappers with automatic Python fallbacks.

Use liblaf.jarp.lax when you want to try the corresponding jax.lax primitive first but still rerun the same callbacks eagerly in Python if JAX raises jax.errors.JAXTypeError or jax.errors.JAXIndexError. These are the tracing and indexing errors that Python-only code commonly triggers.

Classes:

  • LaxWrapper

    Call a JAX primitive first and cache Python fallback signatures.

Functions:

  • cond

    Choose between two branches, then retry eagerly if JAX rejects them.

  • fori_loop

    Run a counted loop, then retry in Python if JAX rejects the body.

  • lax_wrapper

    Decorate an eager fallback with a LaxWrapper.

  • switch

    Choose one branch by index, then retry eagerly if JAX rejects it.

  • while_loop

    Run a loop, then retry in Python if JAX rejects the callbacks.

LaxWrapper

Call a JAX primitive first and cache Python fallback signatures.

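LaxWrapper powers the public helpers in liblaf.jarp.lax. When the wrapped JAX primitive has function metadata, LaxWrapper copies it: the functools.WRAPPER_ASSIGNMENTS attributes such as __name__ and __doc__, plus the contents of __dict__. Each call tries the primitive unless the call's argument metadata has already failed once. After a supported JAX error, LaxWrapper records that metadata in success_cache, so later calls with the same metadata skip directly to the Python fallback. Callable objects without ordinary function metadata, such as instances that define __call__, are also accepted. Their missing attributes are skipped.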

Examples:

>>> from liblaf.jarp.lax import LaxWrapper
>>> class Wrapped:
...     def __call__(self, value):
...         return value + 1
>>> wrapper = LaxWrapper(Wrapped(), lambda value: value - 1)
>>> wrapper(2)
3

Parameters:

  • __wrapped__ (Callable[P, T]) –

    JAX primitive or compatible callable tried first.

  • fallback (Callable[P, T]) –

    Eager Python implementation called when the primitive raises jax.errors.JAXTypeError or jax.errors.JAXIndexError.

  • success_cache (dict[AuxData, bool], default: empty dict) –

    Maps the metadata of an argument pytree to False once a call with that metadata has failed in JAX. Later calls with the same metadata go straight to the fallback.

Attributes:

__wrapped__ class-attribute instance-attribute

__wrapped__: Callable[P, T] = static(alias='__wrapped__')

fallback class-attribute instance-attribute

fallback: Callable[P, T] = static()

success_cache class-attribute instance-attribute

success_cache: dict[AuxData, bool] = field(factory=dict)

Methods:

__attrs_post_init__

__attrs_post_init__() -> None
Source code in src/liblaf/jarp/lax/_wrapper.py
def __attrs_post_init__(self) -> None:
    for attr in functools.WRAPPER_ASSIGNMENTS:
        try:
            value = getattr(self.__wrapped__, attr)
        except AttributeError:
            pass
        else:
            object.__setattr__(self, attr, value)
    for attr in functools.WRAPPER_UPDATES:
        getattr(self, attr).update(getattr(self.__wrapped__, attr, {}))
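The metadata copy above uses the same loop as functools.update_wrapper, except that it does not set __wrapped__, which LaxWrapper already stores as a field. The sketch below uses a hypothetical MetadataProxy class, which is not part of liblaf.jarp, to show the same behavior on an ordinary class. It covers both a plain function and a callable instance that lacks __name__.

```python
import functools
from typing import Any, Callable


class MetadataProxy:
    """Hypothetical wrapper that copies metadata the way LaxWrapper does."""

    def __init__(self, wrapped: Callable[..., Any]) -> None:
        self.__wrapped__ = wrapped
        # Copy __module__, __name__, __qualname__, __doc__, ... when present.
        # Callable instances usually lack __name__, so that attribute is skipped.
        for attr in functools.WRAPPER_ASSIGNMENTS:
            try:
                value = getattr(wrapped, attr)
            except AttributeError:
                pass
            else:
                setattr(self, attr, value)
        # Merge the wrapped object's __dict__ (functools.WRAPPER_UPDATES) into ours.
        for attr in functools.WRAPPER_UPDATES:
            getattr(self, attr).update(getattr(wrapped, attr, {}))

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        return self.__wrapped__(*args, **kwargs)


def add_one(value: int) -> int:
    """Return value + 1."""
    return value + 1


class Wrapped:
    def __call__(self, value: int) -> int:
        return value + 1


proxied_function = MetadataProxy(add_one)
proxied_object = MetadataProxy(Wrapped())
print(proxied_function.__name__, proxied_function(2))  # add_one 3
print(hasattr(proxied_object, "__name__"), proxied_object(2))  # False 3
```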

__call__

__call__(*args: P.args, **kwargs: P.kwargs) -> T
Source code in src/liblaf/jarp/lax/_wrapper.py
def __call__(self, *args: P.args, **kwargs: P.kwargs) -> T:
    __tracebackhide__ = True
    if self.success_cache:
        _inputs_data, inputs_meta = tree.partition((args, kwargs))
        if self.success_cache.get(inputs_meta) is False:
            return self.fallback(*args, **kwargs)
    try:
        return self.__wrapped__(*args, **kwargs)
    except (jax.errors.JAXTypeError, jax.errors.JAXIndexError):
        logger.exception("", stacklevel=2)
    _inputs_data, inputs_meta = tree.partition((args, kwargs))
    self.success_cache[inputs_meta] = False
    return self.fallback(*args, **kwargs)
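The __call__ logic has four steps: skip signatures that already failed, try the primitive, record the failing signature, and fall back. The sketch below reproduces those steps without JAX. It is a hypothetical CachingFallback class, not liblaf.jarp code. PrimitiveRejected stands in for the two JAX error types. call_signature stands in for the metadata half of tree.partition and approximates it with argument types. LaxWrapper additionally logs the exception, which this sketch omits.

```python
from typing import Any, Callable, Hashable


class PrimitiveRejected(Exception):
    """Hypothetical stand-in for jax.errors.JAXTypeError / JAXIndexError."""


def call_signature(args: tuple, kwargs: dict) -> Hashable:
    """Hypothetical stand-in for tree.partition metadata: argument types only."""
    return (
        tuple(type(arg) for arg in args),
        tuple(sorted((key, type(value)) for key, value in kwargs.items())),
    )


class CachingFallback:
    """Try `primitive`; on rejection, remember the signature and use `fallback`."""

    def __init__(self, primitive: Callable[..., Any], fallback: Callable[..., Any]) -> None:
        self.primitive = primitive
        self.fallback = fallback
        self.success_cache: dict[Hashable, bool] = {}
        self.primitive_calls = 0  # instrumentation for this demo only

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        # Only compute the signature when something has failed before.
        if self.success_cache and self.success_cache.get(call_signature(args, kwargs)) is False:
            return self.fallback(*args, **kwargs)
        try:
            self.primitive_calls += 1
            return self.primitive(*args, **kwargs)
        except PrimitiveRejected:
            pass  # LaxWrapper logs the exception here
        self.success_cache[call_signature(args, kwargs)] = False
        return self.fallback(*args, **kwargs)


def strict_double(x: Any) -> Any:
    if isinstance(x, str):
        raise PrimitiveRejected("strings are not supported")
    return 2 * x


double = CachingFallback(strict_double, lambda x: x + x)
print(double(3), double("ab"), double("cd"), double.primitive_calls)  # 6 abab cdcd 2
```

The third call goes straight to the fallback because its cache key matches the failed second call, even though the primitive never saw "cd". In the same way, LaxWrapper treats any two calls whose partitioned metadata compare equal as the same signature.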

cond

cond[*Ts, T](
    pred: ScalarLike,
    true_fun: Callable[[*Ts], T],
    false_fun: Callable[[*Ts], T],
    *operands: *Ts,
) -> T

Choose between two branches, then retry eagerly if JAX rejects them.

The wrapper first calls jax.lax.cond. If that raises jax.errors.JAXTypeError or jax.errors.JAXIndexError, it logs the exception and reruns the selected branch in plain Python.

Parameters:

  • pred (ScalarLike) –

    Scalar predicate. Python truthiness decides which branch runs on the fallback path.

  • true_fun (Callable[[*Ts], T]) –

    Branch evaluated when pred is true.

  • false_fun (Callable[[*Ts], T]) –

    Branch evaluated when pred is false.

  • *operands (*Ts, default: () ) –

    Positional operands forwarded to the selected branch.

Returns:

  • T

    The value returned by the selected branch.

Source code in src/liblaf/jarp/lax/_control.py
@lax_wrapper(jax.lax.cond)
def cond[*Ts, T](
    pred: ScalarLike,
    true_fun: Callable[[*Ts], T],
    false_fun: Callable[[*Ts], T],
    *operands: *Ts,
) -> T:
    """Choose between two branches, then retry eagerly if JAX rejects them.

    The wrapper first calls [`jax.lax.cond`][jax.lax.cond]. If that raises
    [`jax.errors.JAXTypeError`][jax.errors.JAXTypeError] or
    [`jax.errors.JAXIndexError`][jax.errors.JAXIndexError], it logs the
    exception and reruns the selected branch in plain Python.

    Args:
        pred: Scalar predicate. Python truthiness decides which branch runs on
            the fallback path.
        true_fun: Branch evaluated when `pred` is true.
        false_fun: Branch evaluated when `pred` is false.
        *operands: Positional operands forwarded to the selected branch.

    Returns:
        The value returned by the selected branch.
    """
    if pred:
        return true_fun(*operands)
    return false_fun(*operands)
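The fallback matters when a branch indexes a Python list with its operand, because that indexing needs a concrete integer. jax.lax.cond traces its branches with abstract operands, so such a branch typically raises a tracer conversion error, which is a subclass of jax.errors.JAXTypeError. The eager path below is a hypothetical standalone copy of cond's fallback. eager_cond and WEIGHTS are illustrative names, and the code runs without JAX.

```python
from typing import Any, Callable, TypeVar

T = TypeVar("T")

WEIGHTS = [0.5, 1.0, 2.0]  # hypothetical Python-side lookup table


def eager_cond(pred: Any, true_fun: Callable[..., T], false_fun: Callable[..., T], *operands: Any) -> T:
    # Mirrors cond's fallback: Python truthiness picks the branch, and only
    # that branch is evaluated.
    if pred:
        return true_fun(*operands)
    return false_fun(*operands)


def lookup(i: int) -> float:
    # Needs a concrete int; a tracer here typically fails inside jax.lax.cond.
    return WEIGHTS[i]


def missing(i: int) -> float:
    return -1.0


print(eager_cond(1 < len(WEIGHTS), lookup, missing, 1))  # 1.0
print(eager_cond(5 < len(WEIGHTS), lookup, missing, 5))  # -1.0
```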

fori_loop

fori_loop[T](
    lower: int,
    upper: int,
    body_fun: Callable[[int, T], T],
    init_val: T,
    **kwargs: Any,
) -> T

Run a counted loop, then retry in Python if JAX rejects the body.

The wrapper first calls jax.lax.fori_loop. If that raises jax.errors.JAXTypeError or jax.errors.JAXIndexError, it logs the exception and runs an ordinary Python for loop instead.

Parameters:

  • lower (int) –

    Inclusive loop lower bound.

  • upper (int) –

    Exclusive loop upper bound.

  • body_fun (Callable[[int, T], T]) –

    Callback that receives the iteration index and current loop value, then returns the next loop value.

  • init_val (T) –

    Initial loop value.

  • **kwargs (Any, default: {} ) –

    Extra keyword arguments forwarded to jax.lax.fori_loop on the first attempt. They are ignored on the Python fallback path.

Returns:

  • T

    The final loop value.

Source code in src/liblaf/jarp/lax/_control.py
@lax_wrapper(jax.lax.fori_loop)
def fori_loop[T](
    lower: int,
    upper: int,
    body_fun: Callable[[int, T], T],
    init_val: T,
    **kwargs: Any,
) -> T:
    """Run a counted loop, then retry in Python if JAX rejects the body.

    The wrapper first calls [`jax.lax.fori_loop`][jax.lax.fori_loop]. If that
    raises [`jax.errors.JAXTypeError`][jax.errors.JAXTypeError] or
    [`jax.errors.JAXIndexError`][jax.errors.JAXIndexError], it logs the
    exception and runs an ordinary Python `for` loop instead.

    Args:
        lower: Inclusive loop lower bound.
        upper: Exclusive loop upper bound.
        body_fun: Callback that receives the iteration index and current loop
            value, then returns the next loop value.
        init_val: Initial loop value.
        **kwargs: Extra keyword arguments forwarded to
            [`jax.lax.fori_loop`][jax.lax.fori_loop] on the first attempt.
            They are ignored on the Python fallback path.

    Returns:
        The final loop value.
    """
    del kwargs
    val: T = init_val
    for i in range(lower, upper):
        val: T = body_fun(i, val)
    return val
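A standalone copy of the fallback loop makes two details concrete. First, an empty range (lower >= upper) returns init_val unchanged. Second, extra keyword arguments such as unroll are discarded. The body uses a Python conditional on the loop index. Code like this typically makes jax.lax.fori_loop raise a tracer conversion error, because the index is traced there. eager_fori_loop and count_multiples_of_three are hypothetical names.

```python
from typing import Any, Callable, TypeVar

T = TypeVar("T")


def eager_fori_loop(lower: int, upper: int, body_fun: Callable[[int, T], T], init_val: T, **kwargs: Any) -> T:
    # Same as fori_loop's fallback: keyword arguments meant for
    # jax.lax.fori_loop (such as unroll) are dropped.
    del kwargs
    val = init_val
    for i in range(lower, upper):
        val = body_fun(i, val)
    return val


def count_multiples_of_three(i: int, count: int) -> int:
    # A Python conditional on the index needs a concrete bool.
    return count + 1 if i % 3 == 0 else count


print(eager_fori_loop(0, 10, count_multiples_of_three, 0, unroll=2))  # 4
print(eager_fori_loop(5, 5, count_multiples_of_three, 7))  # 7 (empty range)
```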

lax_wrapper

lax_wrapper[**P, T](
    wrapped: Callable[..., Any],
) -> Callable[[Callable[P, T]], LaxWrapper[P, T]]

Decorate an eager fallback with a LaxWrapper.

Parameters:

  • wrapped (Callable[..., Any]) –

    JAX primitive or compatible callable to try first.

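Returns:

  • Callable[[Callable[P, T]], LaxWrapper[P, T]]

    A decorator that turns the fallback function into a LaxWrapper.
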
Source code in src/liblaf/jarp/lax/_wrapper.py
def lax_wrapper[**P, T](
    wrapped: Callable[..., Any],  # jax's typing is not precise, so we loosen it here
) -> Callable[[Callable[P, T]], LaxWrapper[P, T]]:
    """Decorate an eager fallback with a [`LaxWrapper`][liblaf.jarp.lax.LaxWrapper].

    Args:
        wrapped: JAX primitive or compatible callable to try first.

    Returns:
        A decorator that turns the fallback function into a `LaxWrapper`.
    """
    return functools.partial(LaxWrapper, wrapped)
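lax_wrapper binds the primitive immediately with functools.partial. The decorated function then arrives later as the remaining positional argument, which becomes the fallback. The sketch below shows that decorator shape with a hypothetical, stripped-down TryFirst class that has no caching and no logging. Rejected stands in for the JAX errors, and try_first, int_only_double, and double_any are illustrative names.

```python
import functools
from typing import Any, Callable


class Rejected(Exception):
    """Hypothetical stand-in for the JAX errors that trigger the fallback."""


class TryFirst:
    """Hypothetical, minimal LaxWrapper: try `wrapped`, else call `fallback`."""

    def __init__(self, wrapped: Callable[..., Any], fallback: Callable[..., Any]) -> None:
        self.wrapped = wrapped
        self.fallback = fallback

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        try:
            return self.wrapped(*args, **kwargs)
        except Rejected:
            return self.fallback(*args, **kwargs)


def try_first(wrapped: Callable[..., Any]) -> Callable[[Callable[..., Any]], TryFirst]:
    # Same shape as lax_wrapper: bind the primitive now, receive the fallback later.
    return functools.partial(TryFirst, wrapped)


def int_only_double(x: Any) -> Any:
    if not isinstance(x, int):
        raise Rejected(type(x).__name__)
    return 2 * x


@try_first(int_only_double)
def double_any(x: Any) -> Any:
    """Eager fallback used when int_only_double rejects its input."""
    return x + x


print(double_any(4), double_any("ab"))  # 8 abab
```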

switch

switch[*Ts, T](
    index: ArrayLike,
    branches: Sequence[Callable[[*Ts], T]],
    *operands: *Ts,
) -> T

Choose one branch by index, then retry eagerly if JAX rejects it.

The wrapper first calls jax.lax.switch. If that raises jax.errors.JAXTypeError or jax.errors.JAXIndexError, it logs the exception, clamps index into the valid range, and dispatches in plain Python.

Parameters:

  • index (ArrayLike) –

    Branch index. The fallback path clamps the value into the valid range before dispatch.

  • branches (Sequence[Callable[[*Ts], T]]) –

    Candidate branch functions.

  • *operands (*Ts, default: () ) –

    Positional operands forwarded to the selected branch.

Returns:

  • T

    The value returned by the selected branch.

Source code in src/liblaf/jarp/lax/_control.py
@lax_wrapper(jax.lax.switch)
def switch[*Ts, T](
    index: ArrayLike, branches: Sequence[Callable[[*Ts], T]], *operands: *Ts
) -> T:
    """Choose one branch by index, then retry eagerly if JAX rejects it.

    The wrapper first calls [`jax.lax.switch`][jax.lax.switch]. If that raises
    [`jax.errors.JAXTypeError`][jax.errors.JAXTypeError] or
    [`jax.errors.JAXIndexError`][jax.errors.JAXIndexError], it logs the
    exception, clamps `index` into the valid range, and dispatches in plain
    Python.

    Args:
        index: Branch index. The fallback path clamps the value into the valid
            range before dispatch.
        branches: Candidate branch functions.
        *operands: Positional operands forwarded to the selected branch.

    Returns:
        The value returned by the selected branch.
    """
    index: Array = jax.lax.clamp(index, 0, len(branches) - 1)
    return branches[cast("int", index)](*operands)
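The clamping step matters because a plain Python list would otherwise wrap around on negative indices and raise IndexError on large ones. With clamping, the fallback matches the documented out-of-bounds behavior of jax.lax.switch. The sketch below is a minimal pure-Python version, using the hypothetical name eager_switch and plain min/max in place of jax.lax.clamp.

```python
from typing import Any, Callable, Sequence, TypeVar

T = TypeVar("T")


def eager_switch(index: int, branches: Sequence[Callable[..., T]], *operands: Any) -> T:
    # Clamp into [0, len(branches) - 1] instead of wrapping or raising.
    clamped = min(max(int(index), 0), len(branches) - 1)
    return branches[clamped](*operands)


branches = [lambda x: x + 1, lambda x: x * 10, lambda x: -x]
print([eager_switch(i, branches, 3) for i in (-2, 0, 1, 2, 7)])  # [4, 4, 30, -3, -3]
```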

while_loop

while_loop[T](
    cond_fun: Callable[[T], BooleanNumeric],
    body_fun: Callable[[T], T],
    init_val: T,
) -> T

Run a loop, then retry in Python if JAX rejects the callbacks.

The wrapper first calls jax.lax.while_loop. If that raises jax.errors.JAXTypeError or jax.errors.JAXIndexError, it logs the exception and reruns the loop eagerly in Python.

Parameters:

  • cond_fun (Callable[[T], BooleanNumeric]) –

    Predicate evaluated on the loop state.

  • body_fun (Callable[[T], T]) –

    Function that produces the next loop state.

  • init_val (T) –

    Initial loop state.

Returns:

  • T

    The final loop state.

Source code in src/liblaf/jarp/lax/_control.py
@lax_wrapper(jax.lax.while_loop)
def while_loop[T](
    cond_fun: Callable[[T], BooleanNumeric], body_fun: Callable[[T], T], init_val: T
) -> T:
    """Run a loop, then retry in Python if JAX rejects the callbacks.

    The wrapper first calls [`jax.lax.while_loop`][jax.lax.while_loop]. If
    that raises [`jax.errors.JAXTypeError`][jax.errors.JAXTypeError] or
    [`jax.errors.JAXIndexError`][jax.errors.JAXIndexError], it logs the
    exception and reruns the loop eagerly in Python.

    Args:
        cond_fun: Predicate evaluated on the loop state.
        body_fun: Function that produces the next loop state.
        init_val: Initial loop state.

    Returns:
        The final loop state.
    """
    val: T = init_val
    while cond_fun(val):
        val: T = body_fun(val)
    return val
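The loop state can be any value the callbacks agree on, such as a tuple. The body below branches with a Python if on the state. Under jax.lax.while_loop the state is traced, so a body like this typically raises a tracer conversion error there, while the eager fallback runs it directly. eager_while_loop, not_done, and collatz_step are hypothetical names in this standalone sketch.

```python
from typing import Callable, TypeVar

T = TypeVar("T")


def eager_while_loop(cond_fun: Callable[[T], object], body_fun: Callable[[T], T], init_val: T) -> T:
    # Same as while_loop's fallback: plain Python iteration on the state.
    val = init_val
    while cond_fun(val):
        val = body_fun(val)
    return val


def not_done(state: tuple[int, int]) -> bool:
    n, _steps = state
    return n != 1


def collatz_step(state: tuple[int, int]) -> tuple[int, int]:
    n, steps = state
    # Python branching on the loop state needs concrete values.
    if n % 2 == 0:
        return n // 2, steps + 1
    return 3 * n + 1, steps + 1


print(eager_while_loop(not_done, collatz_step, (6, 0)))  # (1, 8)
```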