API reference
Jaxpr tree analysis for JAX traced computations.
Builds a tree from a JAX jaxpr in which each node is a primitive call (equation). Primitives that nest jaxprs in their params (e.g. jit, while, cond, scan) get children built from those nested jaxprs. The package can also compute, for each node, the number of non-structural primitive operations in its subtree, i.e. primitives that are not merely wrappers around a nested jaxpr the way jit is.
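The tree structure described above can be sketched in plain Python. This is a minimal illustration under stated assumptions, not the package's implementation. `Node`, `iter_nested`, `build_tree` and `fake_eqn` are hypothetical names. The sketch duck-types JAX objects: an equation exposes `primitive.name` and `params`, a `Jaxpr` exposes `eqns`, and a `ClosedJaxpr` wraps one in `.jaxpr`. That lets it run on stand-in objects. Unlike the documented `get_nested_jaxprs`, it scans every param value rather than a fixed set of known keys.

```python
from dataclasses import dataclass, field
from types import SimpleNamespace
from typing import Any, Iterator, Optional


@dataclass
class Node:
    """Hypothetical stand-in for JaxprNode, reduced to the fields used here."""
    primitive: str
    eqn_index: Optional[int] = None
    nested_param_key: Optional[str] = None
    children: list["Node"] = field(default_factory=list)


def iter_nested(eqn: Any) -> Iterator[tuple[Any, str]]:
    """Yield (jaxpr, key) for each param value that is a Jaxpr, a ClosedJaxpr,
    or a list/tuple of them (sequence entries get keys like "branches[0]")."""
    for key, value in eqn.params.items():
        if isinstance(value, (list, tuple)):
            candidates = [(v, f"{key}[{i}]") for i, v in enumerate(value)]
        else:
            candidates = [(value, key)]
        for v, k in candidates:
            inner = getattr(v, "jaxpr", v)  # unwrap a ClosedJaxpr; a bare Jaxpr passes through
            if hasattr(inner, "eqns"):
                yield inner, k


def build_tree(jaxpr: Any, primitive: str = "root") -> Node:
    """Create one child per equation; equations with nested jaxprs recurse into them."""
    root = Node(primitive)
    for i, eqn in enumerate(jaxpr.eqns):
        child = Node(eqn.primitive.name, eqn_index=i)
        for inner, key in iter_nested(eqn):
            for grandchild in build_tree(inner).children:
                grandchild.nested_param_key = key  # which param this nested jaxpr came from
                child.children.append(grandchild)
        root.children.append(child)
    return root


# Stand-in objects exposing only the attributes the sketch reads.
def fake_eqn(name: str, **params: Any) -> SimpleNamespace:
    return SimpleNamespace(primitive=SimpleNamespace(name=name), params=params)


body = SimpleNamespace(eqns=[fake_eqn("sin"), fake_eqn("mul")])
closed_body = SimpleNamespace(jaxpr=body, consts=[])  # ClosedJaxpr-like wrapper
top = SimpleNamespace(eqns=[fake_eqn("add"), fake_eqn("while", body_jaxpr=closed_body, body_nconsts=0)])
tree = build_tree(top)
```

Against a real trace, the call would presumably be `build_tree(jax.make_jaxpr(f)(x).jaxpr)`. `jax.make_jaxpr` returns a `ClosedJaxpr`, whose `.jaxpr` holds the equations.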
To identify which parts of a computation grow with an input, trace the function with different arguments (e.g. via `jax.make_jaxpr(f)(*args)` with different `args`). Build the tree for each jaxpr and call `compute_subtree_leaf_counts()` on each root. Then compare the counts across traces, matching nodes by `path`.
Use `format_jaxpr_tree` to pretty-print a tree in the terminal.
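The cross-trace comparison step might look like the following minimal sketch. It is hedged on several assumptions. `Node`, `counts_by_path`, `changed_counts` and `mul` are hypothetical helpers. Paths follow the `[(eqn_index, nested_param_key), ...]` shape given in the attribute table, with the root's path taken to be `[]`. The two trees are written by hand, with counts already filled in, to stand in for two traces of the same function.

```python
from dataclasses import dataclass, field
from typing import Optional

Step = tuple[Optional[int], Optional[str]]


@dataclass
class Node:
    """Hypothetical stand-in with the JaxprNode fields this comparison reads."""
    primitive: str
    path: list[Step] = field(default_factory=list)
    subtree_leaf_count: Optional[int] = None
    children: list["Node"] = field(default_factory=list)


def counts_by_path(root: Node) -> dict[tuple[Step, ...], tuple[str, Optional[int]]]:
    """Map each node's path (as a hashable tuple) to (primitive, subtree_leaf_count)."""
    out = {}
    stack = [root]
    while stack:
        node = stack.pop()
        out[tuple(node.path)] = (node.primitive, node.subtree_leaf_count)
        stack.extend(node.children)
    return out


def changed_counts(small: Node, large: Node) -> dict:
    """Paths present in both traces with the same primitive but a different count."""
    a, b = counts_by_path(small), counts_by_path(large)
    return {
        path: (prim, a[path][1], count)
        for path, (prim, count) in b.items()
        if path in a and a[path][0] == prim and a[path][1] != count
    }


# Hand-written trees: a jit body whose Python loop unrolled into 2, then 4, "mul" equations.
def mul(i: int) -> Node:
    return Node("mul", [(0, None), (i, "jaxpr")], 1)


small = Node("root", [], 3, [Node("jit", [(0, None)], 2, [mul(0), mul(1)]), Node("add", [(1, None)], 1)])
large = Node("root", [], 5, [Node("jit", [(0, None)], 4, [mul(i) for i in range(4)]), Node("add", [(1, None)], 1)])
print(changed_counts(small, large))  # {(): ('root', 3, 5), ((0, None),): ('jit', 2, 4)}
```

Checking the primitive as well as the path guards against a path that happens to line up with an unrelated equation in the other trace. Nodes that exist in only one trace, such as the extra `mul` equations, are left out.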
JaxprNode (dataclass)
A node in the jaxpr tree representing one primitive call (equation).
Attributes:
| Name | Type | Description |
|---|---|---|
| `primitive` | `str` | Name of the primitive (e.g. `"add"`, `"jit"`, `"while"`). |
| `eqn_index` | `int \| None` | Index of this equation in the parent jaxpr's `eqns`, or `None` for the root. |
| `nested_param_key` | `str \| None` | For a child node, the param key its nested jaxpr came from. |
| `children` | `list[JaxprNode]` | Child nodes (from nested jaxprs); empty for leaf equations. |
| `params_summary` | `list[str]` | Param keys that contain nested jaxprs (for inspection). |
| `invars_count` | `int` | Number of input variables to this primitive. |
| `outvars_count` | `int` | Number of output variables. |
| `subtree_leaf_count` | `int \| None` | Set by `compute_subtree_leaf_counts`; the number of non-structural primitives in this node's subtree (including this node if it is a leaf). |
| `path` | `list[tuple[int \| None, str \| None]]` | Optional path from the root for cross-trace matching, e.g. `[(eqn_index, key), ...]`. |
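For readers who want to build nodes by hand (e.g. in tests), the attribute table translates into roughly the following dataclass. This is a hypothetical mirror, not the package's class. The name `JaxprNodeSketch` and all default values are assumptions. `field(default_factory=list)` gives each instance its own list instead of one shared mutable default.

```python
from dataclasses import dataclass, field
from typing import Optional


@dataclass
class JaxprNodeSketch:
    """Hypothetical mirror of the attribute table; defaults are assumed."""
    primitive: str                            # e.g. "add", "jit", "while"
    eqn_index: Optional[int] = None           # None for the root
    nested_param_key: Optional[str] = None    # e.g. "body_jaxpr", "branches[1]"
    children: list["JaxprNodeSketch"] = field(default_factory=list)
    params_summary: list[str] = field(default_factory=list)
    invars_count: int = 0
    outvars_count: int = 0
    subtree_leaf_count: Optional[int] = None  # filled in later by a counting pass
    path: list[tuple[Optional[int], Optional[str]]] = field(default_factory=list)


# A while-loop equation whose cond jaxpr holds "lt" and whose body holds "add".
cond_lt = JaxprNodeSketch("lt", eqn_index=0, nested_param_key="cond_jaxpr", invars_count=2, outvars_count=1)
body_add = JaxprNodeSketch("add", eqn_index=0, nested_param_key="body_jaxpr", invars_count=2, outvars_count=1)
loop = JaxprNodeSketch("while", eqn_index=3, children=[cond_lt, body_add],
                       params_summary=["cond_jaxpr", "body_jaxpr"], invars_count=1, outvars_count=1)
```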
Source code in src/jaxpr_tree/__init__.py
compute_subtree_leaf_counts()
Compute `subtree_leaf_count` for this node and, recursively, for all of its descendants.
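A minimal sketch of the counting rule stated in the attribute table: a leaf counts itself, and any other node sums its children. It runs on a hypothetical stand-in `Node` class, not `JaxprNode`. Returning the count from the method is a convenience of this sketch and may not match the documented method.

```python
from dataclasses import dataclass, field
from typing import Optional


@dataclass
class Node:
    """Hypothetical stand-in holding only what the count needs."""
    primitive: str
    children: list["Node"] = field(default_factory=list)
    subtree_leaf_count: Optional[int] = None

    def compute_subtree_leaf_counts(self) -> int:
        """Post-order pass: a leaf counts as 1; any other node sums its children."""
        if not self.children:
            self.subtree_leaf_count = 1
        else:
            self.subtree_leaf_count = sum(c.compute_subtree_leaf_counts() for c in self.children)
        return self.subtree_leaf_count


# root -> sin, jit(mul, add), while(cond: lt, body: add)
root = Node("root", [
    Node("sin"),
    Node("jit", [Node("mul"), Node("add")]),
    Node("while", [Node("lt"), Node("add")]),
])
root.compute_subtree_leaf_counts()
print(root.subtree_leaf_count)  # 5
```

Structural nodes (`jit`, `while`) contribute nothing themselves, so wrapping the same work in extra jit layers leaves the count unchanged.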
Source code in src/jaxpr_tree/__init__.py
assign_paths(node, path_prefix=None)
Assign `path` to each node in place, for cross-trace matching.
A path is a sequence of `(eqn_index, nested_param_key)` pairs from the root. `jaxpr_to_tree` already sets `path` on every node, so this function is only needed to re-assign or normalize paths.
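Path assignment can be sketched as a pre-order walk. This is a minimal sketch on a hypothetical stand-in `Node`. It assumes one convention not confirmed by the source: the root's path is `path_prefix` (empty by default), and each child appends its own `(eqn_index, nested_param_key)` step to its parent's path.

```python
from dataclasses import dataclass, field
from typing import Optional

Step = tuple[Optional[int], Optional[str]]


@dataclass
class Node:
    """Hypothetical stand-in with the fields path assignment touches."""
    primitive: str
    eqn_index: Optional[int] = None
    nested_param_key: Optional[str] = None
    children: list["Node"] = field(default_factory=list)
    path: list[Step] = field(default_factory=list)


def assign_paths(node: Node, path_prefix: Optional[list[Step]] = None) -> None:
    """Set node.path to the prefix, then give each child prefix + (eqn_index, key)."""
    node.path = list(path_prefix or [])
    for child in node.children:
        assign_paths(child, node.path + [(child.eqn_index, child.nested_param_key)])


root = Node("root", children=[
    Node("add", 0),
    Node("cond", 1, children=[Node("neg", 0, "branches[0]"), Node("exp", 0, "branches[1]")]),
])
assign_paths(root)
```

Both branch equations have `eqn_index` 0, so only the param key tells them apart. That is why each path step records the key alongside the index.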
Source code in src/jaxpr_tree/__init__.py
format_jaxpr_tree(node, *, show_eqn_index=False, show_inouts=False, max_depth=None, indent='    ')
Pretty-print the JaxprNode tree for terminal output.
Returns a multiline string with one line per node. Each line shows the primitive name and, when set (e.g. after calling `compute_subtree_leaf_counts`), the subtree leaf count (the number of non-structural primitives). Tree branches are drawn with Unicode box-drawing characters (├──, └──, │).
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `node` | `JaxprNode` | Root of the tree (e.g. from …). | required |
| `show_eqn_index` | `bool` | If True, append the equation index (e.g. `#1`) when it is not `None`. | `False` |
| `show_inouts` | `bool` | If True, append `(in=N, out=M)` from the invars/outvars counts. | `False` |
| `max_depth` | `int \| None` | If set, stop recursing beyond this depth and show a placeholder. | `None` |
| `indent` | `str` | String used for one level of indentation (default 4 spaces). | `'    '` |
Returns:
| Type | Description |
|---|---|
| `str` | Multiline string suitable for `print(format_jaxpr_tree(root))`. |
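The rendering described above can be approximated with a short recursive formatter. This is a hypothetical sketch, not the package's output format. The name `format_tree`, the ` [count]` label style and the `...` placeholder are assumptions. It supports `show_eqn_index` and `max_depth` but hard-codes 4-column indentation and omits `show_inouts`.

```python
from dataclasses import dataclass, field
from typing import Optional


@dataclass
class Node:
    """Hypothetical stand-in with the fields the formatter reads."""
    primitive: str
    eqn_index: Optional[int] = None
    children: list["Node"] = field(default_factory=list)
    subtree_leaf_count: Optional[int] = None


def format_tree(node: Node, *, show_eqn_index: bool = False, max_depth: Optional[int] = None) -> str:
    """Render one line per node with ├──/└── branches and │ continuation bars."""
    def label(n: Node) -> str:
        text = n.primitive
        if show_eqn_index and n.eqn_index is not None:
            text += f" #{n.eqn_index}"
        if n.subtree_leaf_count is not None:
            text += f" [{n.subtree_leaf_count}]"
        return text

    lines = [label(node)]

    def walk(n: Node, prefix: str, depth: int) -> None:
        if max_depth is not None and depth >= max_depth:
            if n.children:
                lines.append(prefix + "└── ...")  # placeholder for the pruned subtree
            return
        for i, child in enumerate(n.children):
            last = i == len(n.children) - 1
            lines.append(prefix + ("└── " if last else "├── ") + label(child))
            walk(child, prefix + ("    " if last else "│   "), depth + 1)

    walk(node, "", 0)
    return "\n".join(lines)


root = Node("root", children=[
    Node("sin", 0, subtree_leaf_count=1),
    Node("jit", 1, [Node("mul", 0, subtree_leaf_count=1), Node("add", 1, subtree_leaf_count=1)], 2),
], subtree_leaf_count=3)
print(format_tree(root))
# root [3]
# ├── sin [1]
# └── jit [2]
#     ├── mul [1]
#     └── add [1]
```

The continuation prefix carries a `│` bar only under children that have later siblings. Without that rule, the last child's descendants would be drawn with a stray vertical line.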
Source code in src/jaxpr_tree/__init__.py
get_nested_jaxprs(eqn)
Collect all nested jaxprs from an equation's params.
Inspects a fixed set of known param keys for nested jaxprs. Both `ClosedJaxpr` and bare `Jaxpr` values are handled. For `branches`, each branch jaxpr is returned with its own key, such as `"branches[0]"`.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `eqn` | `JaxprEqn` | A JAX equation (primitive call). | required |
Returns:
| Type | Description |
|---|---|
| `list[tuple[Jaxpr, str]]` | List of `(jaxpr, param_key)` pairs, where `param_key` identifies the param (e.g. `"body_jaxpr"`, `"branches[1]"`). |
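A key-driven variant of this lookup might look like the sketch below. The set of keys is an assumption, not the package's actual list. For example, `while` keeps its jaxprs under `cond_jaxpr` and `body_jaxpr`, `cond` under `branches`, and `jit` and `scan` under `jaxpr`, but key names vary across JAX versions. `nested_jaxprs` and `as_bare_jaxpr` are hypothetical names, and the demo uses stand-in objects instead of real equations.

```python
from types import SimpleNamespace
from typing import Any

# Assumed set of param keys that may hold nested jaxprs; the real list may differ.
KNOWN_KEYS = ("jaxpr", "call_jaxpr", "fun_jaxpr", "cond_jaxpr", "body_jaxpr", "branches")


def as_bare_jaxpr(value: Any) -> Any:
    """Unwrap a ClosedJaxpr via .jaxpr; pass a bare Jaxpr through; otherwise None."""
    inner = getattr(value, "jaxpr", value)
    return inner if hasattr(inner, "eqns") else None


def nested_jaxprs(eqn: Any) -> list[tuple[Any, str]]:
    """Return (jaxpr, param_key) pairs, expanding "branches" into "branches[i]" keys."""
    found = []
    for key in KNOWN_KEYS:
        if key not in eqn.params:
            continue
        value = eqn.params[key]
        pairs = [(b, f"{key}[{i}]") for i, b in enumerate(value)] if key == "branches" else [(value, key)]
        for candidate, name in pairs:
            jaxpr = as_bare_jaxpr(candidate)
            if jaxpr is not None:
                found.append((jaxpr, name))
    return found


# Stand-ins: a cond with two ClosedJaxpr-like branches, and an equation with a bare-Jaxpr param.
neg_jaxpr = SimpleNamespace(eqns=["neg"])
exp_jaxpr = SimpleNamespace(eqns=["exp"])
cond_eqn = SimpleNamespace(params={"branches": (SimpleNamespace(jaxpr=neg_jaxpr, consts=[]),
                                                SimpleNamespace(jaxpr=exp_jaxpr, consts=[]))})
bare_eqn = SimpleNamespace(params={"jaxpr": neg_jaxpr, "prevent_cse": True})
print([key for _, key in nested_jaxprs(cond_eqn)])  # ['branches[0]', 'branches[1]']
```

The sketch unwraps `.jaxpr` before checking for `eqns`. A real `ClosedJaxpr` also exposes an `eqns` property, so checking `eqns` first would return the wrapper instead of the bare `Jaxpr`.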
Source code in src/jaxpr_tree/__init__.py