Jaxpr tree
Utility for representing JAX expressions (jaxprs) as a tree.
Builds a tree from a JAX jaxpr in which each node is a primitive call; primitives that
nest a jaxpr (e.g. jit, while, cond, scan) have children built from their nested jaxprs.
Supports computing the number of non-structural (e.g. non-JIT) primitive operations
in the subtree of each node, which can be used to check whether the size of the traced
computation changes with the input size.
[!WARNING] This tool was created with the help of large language model coding assistants as a quick prototype to help diagnose what was causing the JIT compilation time to grow with input size in a particular project. While some basic verification of the code has been performed, there are no guarantees it will give sensible results in all cases.
Getting started
Prerequisites
jaxpr-tree requires Python 3.12+.
Installation
We recommend installing in a project-specific virtual environment created using
an environment management tool such as
uv. To install the latest
development version of jaxpr-tree using uv in the currently active
environment, run
uv pip install git+https://github.com/UCL/jaxpr-tree.git
Alternatively, create a local clone of the repository with
git clone https://github.com/UCL/jaxpr-tree.git
and then, from the root of the cloned repository, install in editable mode by running
uv pip install -e .
Example usage
import jax
import jaxpr_tree


@jax.jit
def inner(x):
    return jax.numpy.cos(x)


def outer(x):
    return jax.lax.cond(
        x < 0,
        lambda: 2 * inner(x) + 1.0,
        lambda: -x
    )


x = 0.5
jaxpr = jax.make_jaxpr(outer)(x)
tree = jaxpr_tree.jaxpr_to_tree(jaxpr)
tree.compute_subtree_leaf_counts()
print(tree)
which outputs
(root) leaves=6
├── lt
├── convert_element_type
└── cond leaves=4
    ├── (param block) branches[0]
    │   └── neg
    └── (param block) branches[1] leaves=3
        ├── jit
        │   └── (param block) jaxpr
        │       └── cos
        ├── mul
        └── add
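One way to use these counts is to trace the same function at several input sizes and compare the totals: if the count grows with the input size, the traced program itself is growing, which can make compilation time grow too. The sketch below does this with a small hypothetical `count_leaf_primitives` helper written directly against `jax.make_jaxpr` output (it is not part of the jaxpr-tree API and uses the same duck-typed detection of nested jaxprs as an assumption). It contrasts a Python `for` loop, which JAX unrolls during tracing, with `jax.lax.fori_loop`, whose body is traced into a single nested jaxpr.

```python
import jax
import jax.numpy as jnp


def count_leaf_primitives(jaxpr) -> int:
    """Hypothetical helper: count primitive calls, recursing into nested jaxprs
    instead of counting the structural primitives (jit, cond, scan, ...) holding them."""
    jaxpr = getattr(jaxpr, "jaxpr", jaxpr)  # unwrap ClosedJaxpr -> Jaxpr
    total = 0
    for eqn in jaxpr.eqns:
        # Assumed duck-typed check for jaxprs stored in (tuples of) equation params.
        nested = [
            getattr(v, "jaxpr", v)
            for value in eqn.params.values()
            for v in (value if isinstance(value, (tuple, list)) else (value,))
            if hasattr(getattr(v, "jaxpr", v), "eqns")
        ]
        total += sum(map(count_leaf_primitives, nested)) if nested else 1
    return total


def unrolled_sum(x):
    # Python loop: unrolled at trace time, adding primitives for every element.
    total = jnp.float32(0.0)
    for i in range(x.shape[0]):
        total = total + x[i]
    return total


def loop_sum(x):
    # lax.fori_loop: the body becomes one nested jaxpr, whatever len(x) is.
    return jax.lax.fori_loop(
        0, x.shape[0], lambda i, acc: acc + x[i], jnp.float32(0.0)
    )


for fn in (unrolled_sum, loop_sum):
    counts = [count_leaf_primitives(jax.make_jaxpr(fn)(jnp.ones(n))) for n in (4, 8, 16)]
    print(fn.__name__, counts)
```

We would expect the counts printed for `unrolled_sum` to increase with the input length while those for `loop_sum` stay the same, which is the kind of difference the per-node leaf counts are meant to expose.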
Running tests
Tests can be run across all compatible Python versions in isolated environments
using tox by running
tox
from the root of the repository. To run tests manually in a Python environment with pytest installed, run
pytest tests
again from the root of the repository.
Building documentation
The MkDocs HTML documentation can be built locally by running
tox -e docs
from the root of the repository. The built documentation will be written to the
site directory.
Alternatively, to build and preview the documentation locally, run the following in a
Python environment with the optional docs dependencies installed
mkdocs serve