
API reference

Jaxpr tree analysis for JAX traced computations.

Builds a tree from a JAX jaxpr where each node is a primitive call; primitives that nest a jaxpr (e.g. jit, while, cond, scan) have children from their nested jaxprs. Supports computing the number of non-structural (e.g. non-JIT) primitive operations in the subtree of each node.

To identify which parts of a computation grow with an input, trace the function with different input values (e.g. via jax.make_jaxpr(f)(*args) with different args), build a tree for each jaxpr (e.g. with jaxpr_to_tree), call compute_subtree_leaf_counts on each tree, then compare the counts across traces (e.g. by matching nodes via path).
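The comparison step can be sketched as follows: index every node by its path, then report nodes present in both traces whose subtree leaf count grew. This is a minimal sketch, not part of the library. `Node` is a hypothetical stand-in holding only the JaxprNode fields used here (so it runs without JAX or this package), and `index_by_path` and `growing_nodes` are illustrative names. With real trees you would presumably pass roots built by jaxpr_to_tree after running compute_subtree_leaf_counts on each.

```python
from __future__ import annotations

from dataclasses import dataclass, field


@dataclass
class Node:
    """Hypothetical stand-in for JaxprNode, keeping only the fields used here."""

    primitive: str
    path: list[tuple[int | None, str | None]] = field(default_factory=list)
    subtree_leaf_count: int | None = None
    children: list[Node] = field(default_factory=list)


def index_by_path(root: Node) -> dict[tuple, Node]:
    """Map each node's path (converted to a hashable tuple) to the node."""
    index: dict[tuple, Node] = {}
    stack = [root]
    while stack:
        node = stack.pop()
        index[tuple(node.path)] = node
        stack.extend(node.children)
    return index


def growing_nodes(small: Node, large: Node) -> list[tuple[tuple, str, int, int]]:
    """Return (path, primitive, count_small, count_large) for every node whose
    subtree leaf count is larger in the second trace.

    Only paths present in both traces with the same primitive are compared;
    nodes that exist in just one trace are skipped.
    """
    a, b = index_by_path(small), index_by_path(large)
    grown = []
    for path, node_a in a.items():
        node_b = b.get(path)
        if node_b is None or node_b.primitive != node_a.primitive:
            continue
        count_a = node_a.subtree_leaf_count or 0
        count_b = node_b.subtree_leaf_count or 0
        if count_b > count_a:
            grown.append((path, node_a.primitive, count_a, count_b))
    return grown
```

Because each path step includes an eqn_index, this kind of matching assumes the equations leading up to a node keep the same indices in both traces; if an earlier equation appears only in one trace, later nodes will no longer line up.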

Use format_jaxpr_tree to pretty-print the tree in the terminal.

JaxprNode dataclass

A node in the jaxpr tree representing one primitive call (equation).

Attributes:

- primitive (str): Name of the primitive (e.g. "add", "jit", "while").
- eqn_index (int | None): Index of this equation in the parent jaxpr's eqns, or None for the root.
- nested_param_key (str | None): For a child node, the param key its nested jaxpr came from.
- children (list[JaxprNode]): Child nodes (from nested jaxprs); empty for leaf equations.
- params_summary (list[str]): Param keys that contain nested jaxprs (for inspection).
- invars_count (int): Number of input variables to this primitive.
- outvars_count (int): Number of output variables.
- subtree_leaf_count (int | None): Set by compute_subtree_leaf_counts; the number of non-structural primitives in this node's subtree (including this node if it is a leaf).
- path (list[tuple[int | None, str | None]]): Optional path from the root for cross-trace matching, e.g. [(eqn_index, key), ...].

Source code in src/jaxpr_tree/__init__.py
```python
@dataclass
class JaxprNode:
    """
    A node in the jaxpr tree representing one primitive call (equation).

    Attributes:
        primitive: Name of the primitive (e.g. "add", "jit", "while").
        eqn_index: Index of this equation in the parent jaxpr's eqns, or None for root.
        nested_param_key: For a child node, which param key this nested jaxpr came from.
        children: Child nodes (from nested jaxprs); empty for leaf equations.
        params_summary: List of param keys that contain nested jaxprs (for inspection).
        invars_count: Number of input variables to this primitive.
        outvars_count: Number of output variables.
        subtree_leaf_count: Set by compute_subtree_leaf_counts; number of non-structural
            primitives in this node's subtree (including this node if it is a leaf).
        path: Optional path from root for cross-trace matching, e.g. [(eqn_index, key), ...].

    """

    primitive: str
    eqn_index: int | None = None
    nested_param_key: str | None = None
    children: list[JaxprNode] = field(default_factory=list)
    params_summary: list[str] = field(default_factory=list)
    invars_count: int = 0
    outvars_count: int = 0
    subtree_leaf_count: int | None = None
    path: list[tuple[int | None, str | None]] = field(default_factory=list)

    def __str__(self) -> str:
        return format_jaxpr_tree(self)

    def __repr__(self) -> str:
        return str(self)

    def compute_subtree_leaf_counts(self) -> None:
        """Compute the subtree leaf count for this node and all children."""
        compute_subtree_leaf_counts(self)
```

compute_subtree_leaf_counts()

Compute the subtree leaf count for this node and all children.

Source code in src/jaxpr_tree/__init__.py
```python
def compute_subtree_leaf_counts(self) -> None:
    """Compute the subtree leaf count for this node and all children."""
    compute_subtree_leaf_counts(self)
```
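This method delegates to a module-level compute_subtree_leaf_counts whose source is not shown on this page. A minimal sketch of the counting rule described for subtree_leaf_count could look like the block below. It assumes that every childless node is a non-structural primitive and that structural nodes (jit, while, nested-jaxpr blocks) only aggregate their children. `Node` and `count_leaves` are hypothetical names. Under this rule a nested jaxpr with no equations would count as 1, and the real function may handle such edge cases differently.

```python
from __future__ import annotations

from dataclasses import dataclass, field


@dataclass
class Node:
    """Hypothetical stand-in for JaxprNode, keeping only the fields used here."""

    primitive: str
    children: list[Node] = field(default_factory=list)
    subtree_leaf_count: int | None = None


def count_leaves(node: Node) -> int:
    """Fill in subtree_leaf_count for node and all descendants (post-order).

    Assumed rule: a node with no children is a non-structural primitive and
    counts as 1; a node with children (jit, while, nested-jaxpr blocks, ...)
    contributes nothing itself and only sums its children's counts.
    """
    if not node.children:
        node.subtree_leaf_count = 1
    else:
        node.subtree_leaf_count = sum(count_leaves(child) for child in node.children)
    return node.subtree_leaf_count
```

The recursion is post-order on purpose: a parent's count is only known once all of its children have been counted, and every node's field is filled in along the way.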

assign_paths(node, path_prefix=None)

Assign path to each node for cross-trace matching (in-place).

Path is a sequence of (eqn_index, nested_param_key) from root. Nodes already have path set by jaxpr_to_tree; this function can be used to re-assign or normalize paths if needed.

Source code in src/jaxpr_tree/__init__.py
```python
def assign_paths(
    node: JaxprNode,
    path_prefix: list[tuple[int | None, str | None]] | None = None,
) -> None:
    """
    Assign path to each node for cross-trace matching (in-place).

    Path is a sequence of (eqn_index, nested_param_key) from root. Nodes already
    have path set by jaxpr_to_tree; this function can be used to re-assign or
    normalize paths if needed.
    """
    if path_prefix is None:
        path_prefix = []
    node.path = list(path_prefix)
    for child in node.children:
        if child.primitive == _NESTED_BLOCK_PRIMITIVE:
            child_path = path_prefix + [(None, child.nested_param_key)]
            assign_paths(child, child_path)
            for grandchild in child.children:
                assign_paths(
                    grandchild,
                    child_path + [(grandchild.eqn_index, None)],
                )
        else:
            assign_paths(child, path_prefix + [(child.eqn_index, None)])
```
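To see what the resulting paths look like, the sketch below re-creates the same scheme on a hypothetical `Node` stand-in, using a placeholder `NESTED_BLOCK` marker (the real value of _NESTED_BLOCK_PRIMITIVE is not shown here). It folds the explicit grandchild loop of the listing above into the recursion. Judging from the code shown, that yields identical paths whenever a nested-block node's children are ordinary equation nodes.

```python
from __future__ import annotations

from dataclasses import dataclass, field

# Hypothetical marker; the real value of _NESTED_BLOCK_PRIMITIVE is not shown here.
NESTED_BLOCK = "<nested>"


@dataclass
class Node:
    """Hypothetical stand-in for JaxprNode, keeping only the fields used here."""

    primitive: str
    eqn_index: int | None = None
    nested_param_key: str | None = None
    children: list[Node] = field(default_factory=list)
    path: list[tuple[int | None, str | None]] = field(default_factory=list)


def assign_paths(node: Node, prefix: tuple = ()) -> None:
    """Set node.path, appending one step per level of the tree.

    Equation nodes contribute (eqn_index, None); nested-block nodes contribute
    (None, nested_param_key).
    """
    node.path = list(prefix)
    for child in node.children:
        if child.primitive == NESTED_BLOCK:
            step = (None, child.nested_param_key)
        else:
            step = (child.eqn_index, None)
        assign_paths(child, (*prefix, step))
```

For a root whose equation 0 is a jit wrapping a nested jaxpr under the param key "jaxpr", the second equation inside that jaxpr ends up with the path [(0, None), (None, "jaxpr"), (1, None)].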

format_jaxpr_tree(node, *, show_eqn_index=False, show_inouts=False, max_depth=None, indent='    ')

Pretty-print the JaxprNode tree for terminal output.

Returns a multiline string with one line per node. Each line shows the primitive name and, when set (e.g. after calling compute_subtree_leaf_counts), the subtree leaf count (non-structural primitive count). Tree branches use Unicode box-drawing characters (├──, └──, │).

Parameters:

- node (JaxprNode, required): Root of the tree (e.g. from jaxpr_to_tree).
- show_eqn_index (bool, default False): If True, append the equation index (e.g. #1) when it is not None.
- show_inouts (bool, default False): If True, append (in=N, out=M) from the invars/outvars counts.
- max_depth (int | None, default None): If set, stop recursing beyond this depth and show a placeholder.
- indent (str, default '    '): String used for one level of indentation (4 spaces by default).

Returns:

- str: Multiline string suitable for print(format_jaxpr_tree(root)).

Source code in src/jaxpr_tree/__init__.py
```python
def format_jaxpr_tree(
    node: JaxprNode,
    *,
    show_eqn_index: bool = False,
    show_inouts: bool = False,
    max_depth: int | None = None,
    indent: str = "    ",
) -> str:
    """
    Pretty-print the JaxprNode tree for terminal output.

    Returns a multiline string with one line per node. Each line shows the
    primitive name and, when set (e.g. after :func:`compute_subtree_leaf_counts`),
    the subtree leaf count (non-structural primitive count). Tree branches use
    Unicode box-drawing characters (├──, └──, │).

    Args:
        node: Root of the tree (e.g. from :func:`jaxpr_to_tree`).
        show_eqn_index: If True, append equation index (e.g. #1) when not None.
        show_inouts: If True, append (in=N, out=M) from invars/outvars counts.
        max_depth: If set, stop recursing beyond this depth and show a placeholder.
        indent: String used for one level of indent (default 4 spaces).

    Returns:
        Multiline string suitable for print(format_jaxpr_tree(root)).

    """
    lines = _format_jaxpr_tree_lines(
        node,
        depth=0,
        prefix="",
        is_last=True,
        show_eqn_index=show_eqn_index,
        show_inouts=show_inouts,
        max_depth=max_depth,
        indent=indent,
    )
    return "\n".join(lines)
```
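The helper _format_jaxpr_tree_lines is not shown on this page. As a rough illustration of how box-drawing output of this kind is usually produced, here is a minimal sketch. The `Node` stand-in, the `format_tree` name, and the "name [count]" label layout are all assumptions, not this library's actual output format. Each child line gets "├── " or "└── " depending on whether it is the last sibling, and its descendants inherit "│   " or four spaces as the continuation prefix.

```python
from __future__ import annotations

from dataclasses import dataclass, field


@dataclass
class Node:
    """Hypothetical stand-in for JaxprNode, keeping only the fields used here."""

    primitive: str
    children: list[Node] = field(default_factory=list)
    subtree_leaf_count: int | None = None


def format_tree(node: Node) -> str:
    """Render one line per node using ├──, └── and │ connectors."""

    def label(n: Node) -> str:
        # Show the count only once it has been computed (assumed label layout).
        if n.subtree_leaf_count is None:
            return n.primitive
        return f"{n.primitive} [{n.subtree_leaf_count}]"

    lines = [label(node)]

    def walk(n: Node, prefix: str) -> None:
        for i, child in enumerate(n.children):
            last = i == len(n.children) - 1
            lines.append(prefix + ("└── " if last else "├── ") + label(child))
            # Siblings still to come need a vertical bar in the continuation prefix.
            walk(child, prefix + ("    " if last else "│   "))

    walk(node, "")
    return "\n".join(lines)
```

For a root (count 3) containing a jit (count 2) over add and mul, followed by sin, this prints:

```
root [3]
├── jit [2]
│   ├── add [1]
│   └── mul [1]
└── sin [1]
```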

get_nested_jaxprs(eqn)

Collect all nested jaxprs from an equation's params.

Inspects known param keys for nested jaxprs. Handles both ClosedJaxpr and bare Jaxpr; for 'branches', each branch jaxpr is returned with a key like "branches[0]".

Parameters:

- eqn (JaxprEqn, required): A JAX equation (primitive call).

Returns:

- list[tuple[Jaxpr, str]]: List of (jaxpr, param_key) pairs, where param_key identifies the param the jaxpr came from (e.g. "body_jaxpr", "branches[1]").

Source code in src/jaxpr_tree/__init__.py
```python
def get_nested_jaxprs(eqn: core.JaxprEqn) -> list[tuple[core.Jaxpr, str]]:
    """
    Collect all nested jaxprs from an equation's params.

    Inspects known param keys for nested jaxprs. Handles both ClosedJaxpr and
    bare Jaxpr; for 'branches', each branch jaxpr is returned with a key
    like "branches[0]".

    Args:
        eqn: A JAX equation (primitive call).

    Returns:
        List of (jaxpr, param_key) where param_key identifies the param
        (e.g. "body_jaxpr", "branches[1]").

    """
    out: list[tuple[core.Jaxpr, str]] = []
    params = getattr(eqn, "params", {}) or {}

    for key in _CALL_JAXPR_KEYS:
        if key not in params:
            continue
        v = params[key]
        jaxpr = _get_jaxpr_from_param_value(v)
        if jaxpr is not None:
            out.append((jaxpr, key))

    if _BRANCHES_KEY in params:
        branches = params[_BRANCHES_KEY]
        if isinstance(branches, (list, tuple)):
            for i, branch in enumerate(branches):
                jaxpr = _get_jaxpr_from_param_value(branch)
                if jaxpr is not None:
                    out.append((jaxpr, f"{_BRANCHES_KEY}[{i}]"))

    return out
```
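To show the shape of the returned list without needing JAX, the sketch below mimics this logic with duck typing and exercises it on SimpleNamespace fakes. The key tuple CALL_JAXPR_KEYS and the helper names `unwrap` and `nested_jaxprs` are assumptions, because the real _CALL_JAXPR_KEYS and _get_jaxpr_from_param_value are not shown on this page. A ClosedJaxpr-like value is recognized by its .jaxpr and .consts attributes, and a bare Jaxpr-like value by its .eqns attribute.

```python
from __future__ import annotations

from types import SimpleNamespace
from typing import Any

# Hypothetical key list; the real _CALL_JAXPR_KEYS is not shown on this page.
CALL_JAXPR_KEYS = ("jaxpr", "call_jaxpr", "fun_jaxpr", "cond_jaxpr", "body_jaxpr")
BRANCHES_KEY = "branches"


def unwrap(value: Any) -> Any | None:
    """Duck-typed stand-in: ClosedJaxpr-like -> its .jaxpr, Jaxpr-like -> itself."""
    if hasattr(value, "jaxpr") and hasattr(value, "consts"):
        return value.jaxpr
    if hasattr(value, "eqns"):
        return value
    return None


def nested_jaxprs(eqn: Any) -> list[tuple[Any, str]]:
    """Return (jaxpr, param_key) pairs, expanding branches into branches[i]."""
    params = getattr(eqn, "params", {}) or {}
    out = []
    for key in CALL_JAXPR_KEYS:
        if key in params and (jaxpr := unwrap(params[key])) is not None:
            out.append((jaxpr, key))
    branches = params.get(BRANCHES_KEY)
    if isinstance(branches, (list, tuple)):
        for i, branch in enumerate(branches):
            if (jaxpr := unwrap(branch)) is not None:
                out.append((jaxpr, f"{BRANCHES_KEY}[{i}]"))
    return out


if __name__ == "__main__":
    # Fake cond equation with two ClosedJaxpr-like branches.
    branch = SimpleNamespace(jaxpr=SimpleNamespace(eqns=[]), consts=[])
    cond_eqn = SimpleNamespace(params={"branches": (branch, branch)})
    print([key for _, key in nested_jaxprs(cond_eqn)])  # ['branches[0]', 'branches[1]']
```

Params that are not jaxprs (such as the integer cond_nconsts on a while equation) are never inspected, because only the listed keys and branches are looked up.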