
Jaxpr tree


Utility for representing JAX expressions (jaxprs) as a tree.

Builds a tree from a JAX jaxpr in which each node is a primitive call; primitives that nest jaxprs (e.g. jit, while, cond, scan) have children built from their nested jaxprs. It also supports computing the number of non-structural (e.g. non-JIT) primitive operations in the subtree of each node, which can be used to check whether the size of the traced computation changes with the input sizes.
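To make the idea concrete, here is a minimal sketch of how such a tree could be built from JAX's public jaxpr structures (`eqns`, `eqn.primitive.name`, `eqn.params`). It is not the jaxpr-tree implementation: `Node`, `nested_jaxprs`, `build_tree` and `leaf_count` are hypothetical names chosen for illustration. Nested jaxprs are detected by duck typing, assuming any parameter value that exposes `eqns` (directly, or via `.jaxpr` for a closed jaxpr) is a nested jaxpr.

```python
from dataclasses import dataclass, field

import jax
import jax.numpy as jnp


@dataclass
class Node:
    """Hypothetical tree node (not the jaxpr-tree class)."""

    label: str
    is_block: bool  # True for the root and for nested-jaxpr "param blocks"
    children: list["Node"] = field(default_factory=list)


def nested_jaxprs(params):
    """Yield (tag, jaxpr) for each equation parameter that holds a jaxpr."""
    for name, value in params.items():
        is_seq = isinstance(value, (tuple, list))  # e.g. cond's `branches`
        for i, v in enumerate(value if is_seq else (value,)):
            inner = getattr(v, "jaxpr", v)  # unwrap a ClosedJaxpr
            if hasattr(inner, "eqns"):
                yield (f"{name}[{i}]" if is_seq else name), inner


def build_tree(jaxpr, label="(root)"):
    """One child per equation; equations that nest jaxprs get block children."""
    jaxpr = getattr(jaxpr, "jaxpr", jaxpr)  # accept a ClosedJaxpr or a Jaxpr
    node = Node(label, is_block=True)
    for eqn in jaxpr.eqns:
        child = Node(eqn.primitive.name, is_block=False)
        for tag, inner in nested_jaxprs(eqn.params):
            child.children.append(build_tree(inner, f"(param block)  {tag}"))
        node.children.append(child)
    return node


def leaf_count(node):
    """Count primitive nodes without nested jaxprs in this subtree."""
    if not node.is_block and not node.children:
        return 1
    return sum(leaf_count(c) for c in node.children)


def f(x):
    return jax.jit(jnp.sin)(x) + 1.0


tree = build_tree(jax.make_jaxpr(f)(1.0))
print([c.label for c in tree.children])  # a jit call ("pjit" in older JAX), then "add"
print(leaf_count(tree))  # 2: sin inside the jit block, plus add
```

Counting only primitive nodes that have no nested jaxpr is what makes structural calls such as `jit` or `cond` invisible in the count, while the work inside them is still included.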

[!WARNING] This tool was created with the help of large language model coding assistants as a quick prototype to help diagnose what was causing JIT compilation time to grow with input size in a particular project. While some basic verification of the code has been performed, there is no guarantee that it will give sensible results in all cases.

Getting started

Prerequisites

jaxpr-tree requires Python 3.12+.

Installation

We recommend installing in a project-specific virtual environment created using an environment management tool such as uv. To install the latest development version of jaxpr-tree using uv in the currently active environment, run

uv pip install git+https://github.com/UCL/jaxpr-tree.git

Alternatively, create a local clone of the repository with

git clone https://github.com/UCL/jaxpr-tree.git

and then, from the root of the cloned repository, install in editable mode by running

uv pip install -e .

Example usage

import jax
import jaxpr_tree

@jax.jit
def inner(x):
    return jax.numpy.cos(x)

def outer(x):
    return jax.lax.cond(
        x < 0,
        lambda: 2 * inner(x) + 1.0,
        lambda: -x
    )

x = 0.5
jaxpr = jax.make_jaxpr(outer)(x)
tree = jaxpr_tree.jaxpr_to_tree(jaxpr)
tree.compute_subtree_leaf_counts()
print(tree)

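which prints the tree below. Because `jax.lax.cond` converts the boolean predicate to an integer branch index (the `convert_element_type` node) and stores the false branch first, `branches[0]` holds `-x` and `branches[1]` holds the true branch.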

(root)  leaves=6
    ├── lt
    ├── convert_element_type
    └── cond  leaves=4
        ├── (param block)  branches[0]
        │   └── neg
        └── (param block)  branches[1]  leaves=3
            ├── jit
            │   └── (param block)  jaxpr
            │       └── cos
            ├── mul
            └── add
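To illustrate the kind of comparison these counts are intended for, here is a sketch that does not use jaxpr-tree at all. A hypothetical `count_primitives` helper approximates the same idea (count equations with no nested jaxpr, recursing into those that have one) and is applied to two ways of writing the same reduction, traced at several input sizes. The Python loop in the hypothetical `unrolled_sum` runs during tracing, so its count grows with the input length, while `vectorized_sum` traces to the same program at every size. Exact counts depend on the JAX version, so only the trend is meaningful.

```python
import jax
import jax.numpy as jnp


def count_primitives(jaxpr):
    """Hypothetical helper: count equations that contain no nested jaxpr,
    recursing into the nested jaxprs of those that do (jit, cond, scan, ...)."""
    jaxpr = getattr(jaxpr, "jaxpr", jaxpr)  # unwrap a ClosedJaxpr
    total = 0
    for eqn in jaxpr.eqns:
        nested = []
        for value in eqn.params.values():
            for v in (value if isinstance(value, (tuple, list)) else (value,)):
                inner = getattr(v, "jaxpr", v)
                if hasattr(inner, "eqns"):
                    nested.append(inner)
        total += sum(count_primitives(j) for j in nested) if nested else 1
    return total


def unrolled_sum(x):
    # The Python loop runs at trace time, emitting one copy of the body per element.
    total = 0.0
    for i in range(x.shape[0]):
        total = total + jnp.sin(x[i])
    return total


def vectorized_sum(x):
    # Whole-array operations: the traced program does not depend on x.shape[0].
    return jnp.sum(jnp.sin(x))


counts = {}
for n in (4, 16, 64):
    x = jnp.ones(n)
    counts[n] = (
        count_primitives(jax.make_jaxpr(unrolled_sum)(x)),
        count_primitives(jax.make_jaxpr(vectorized_sum)(x)),
    )
    print(n, *counts[n])  # size, unrolled count, vectorized count
```

A count that grows with input size, as for `unrolled_sum`, is one plausible cause of JIT compilation time growing with input size; expressing such loops with array operations or `jax.lax` control flow (`fori_loop`, `scan`) keeps the traced program size fixed.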

Running tests

Tests can be run across all compatible Python versions in isolated environments using tox by running the following from the root of the repository

tox

To run tests manually in a Python environment with pytest installed, run

pytest tests

again from the root of the repository.

Building documentation

The MkDocs HTML documentation can be built locally by running

tox -e docs

from the root of the repository. The built documentation will be written to the site directory.

Alternatively, to build and preview the documentation locally in a Python environment with the optional docs dependencies installed, run

mkdocs serve