Skip to content

API reference

causalprog package.

algorithms

Algorithms.

do

Algorithms for applying do to a graph.

do(graph, node, value, label=None)

Apply do to a graph.

Parameters:

Name Type Description Default
graph Graph

The graph to apply do to. This will be copied.

required
node str

The label of the node to apply do to.

required
value float

The value to set the node to.

required
label str | None

The label of the new graph

None

Returns:

Type Description
Graph

A copy of the graph with do applied

Source code in src/causalprog/algorithms/do.py
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
def do(graph: Graph, node: str, value: float, label: str | None = None) -> Graph:
    """
    Apply do to a graph.

    The node labelled `node` is replaced by a `ConstantNode` holding `value`.
    Nodes whose successors are all either `node` itself or nodes that have
    already been dropped are dropped too; this is repeated until no further
    node can be dropped. The nodes that are kept are deep copies of the nodes
    of `graph`, and every edge of `graph` whose two endpoints are both present
    in the result is recreated.

    Args:
        graph: The graph to apply do to. This will be copied.
        node: The label of the node to apply do to.
        value: The value to set the node to.
        label: The label of the new graph. If `None`, a label is generated from
            the label of `graph`, `node` and `value` (with any `.` in the
            formatted value replaced by `_`).

    Returns:
        A copy of the graph with do applied

    Raises:
        ValueError: If a node that is kept has a successor that is not kept,
            i.e. it is a predecessor of both a removed node (or the node set
            by do) and a node that remains.

    """
    if label is None:
        label = f"{graph.label}_do_{node}__" + f"{value}".replace(".", "_")

    nodes = {n.label: deepcopy(n) for n in graph.nodes if n.label != node}

    # Recursively remove nodes that are predecessors of removed nodes. Each pass
    # computes the removable labels once; the loop ends when a pass finds none.
    while nodes_to_remove := removable_nodes(graph, nodes):
        for n in nodes_to_remove:
            nodes.pop(n)

    # Check for nodes that are predecessors of both a removed node and a remaining node
    # and throw an error if one of these is found
    for n in nodes:
        _, excluded = get_included_excluded_successors(graph, nodes, n)
        if len(excluded) > 0:
            msg = (
                "Node that is predecessor of node set by do and "
                f'nodes that are not removed found ("{n}")'
            )
            raise ValueError(msg)

    nodes[node] = ConstantNode(label=node, value=value)

    # `label` is already final here: either the caller's choice or the generated
    # default. Appending the "_do_..." suffix again would duplicate it.
    g = Graph(label=label)
    for n in nodes.values():
        g.add_node(n)

    # Any nodes whose counterparts connect to other nodes in the network need
    # to mimic these links.
    for edge in graph.edges:
        if edge[0].label in nodes and edge[1].label in nodes:
            g.add_edge(edge[0].label, edge[1].label)

    return g

get_included_excluded_successors(graph, node_list, successors_of)

Split successors of a node into nodes included and not included in a list.

Split the successors of a node into a list of nodes that are included in the input node list and a list of nodes that are not in the list.

Parameters:

Name Type Description Default
graph Graph

The graph

required
node_list dict[str, Node]

A dictionary of nodes, indexed by label

required
successors_of str

The node to check the successors of

required

Returns:

Type Description
tuple[tuple[str, ...], tuple[str, ...]]

Lists of included and excluded nodes

Source code in src/causalprog/algorithms/do.py
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
def get_included_excluded_successors(
    graph: Graph, node_list: dict[str, Node], successors_of: str
) -> tuple[tuple[Node, ...], tuple[Node, ...]]:
    """
    Split successors of a node into nodes included and not included in a list.

    Split the successors of a node into a tuple of nodes whose labels are keys
    of the input node list and a tuple of nodes whose labels are not.

    Args:
        graph: The graph
        node_list: A dictionary of nodes, indexed by label
        successors_of: The label of the node to check the successors of

    Returns:
        Two tuples, `(included, excluded)`. Their elements are the successor
        objects as stored in `graph.successors` (not their labels), in the
        order in which `graph.successors` yields them.

    """
    included = []
    excluded = []
    # `graph.successors` is keyed by node object, so the label is resolved first.
    for n in graph.successors[graph.get_node(successors_of)]:
        if n.label in node_list:
            included.append(n)
        else:
            excluded.append(n)
    return tuple(included), tuple(excluded)

removable_nodes(graph, nodes)

Generate list of nodes that can be removed from the graph.

Parameters:

Name Type Description Default
graph Graph

The graph

required
nodes dict[str, Node]

A dictionary of nodes, indexed by label

required

Returns:

Type Description
tuple[str, ...]

List of labels of removable nodes

Source code in src/causalprog/algorithms/do.py
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
def removable_nodes(graph: Graph, nodes: dict[str, Node]) -> tuple[str, ...]:
    """
    Find the labels in `nodes` whose successors have all been dropped.

    A label is reported when the corresponding node has at least one successor
    in `graph`, and none of its successors has a label that is a key of `nodes`.

    Args:
        graph: The graph
        nodes: A dictionary of nodes, indexed by label

    Returns:
        Labels of the removable nodes, in the iteration order of `nodes`

    """

    def _has_only_excluded_successors(candidate: str) -> bool:
        kept, dropped = get_included_excluded_successors(graph, nodes, candidate)
        return len(dropped) > 0 and len(kept) == 0

    return tuple(
        candidate for candidate in nodes if _has_only_excluded_successors(candidate)
    )

moments

Algorithms for estimating the expectation and standard deviation.

expectation(graph, outcome_node_label, samples, *, parameter_values=None, rng_key)

Estimate the expectation of (a random variable attached to) a node in a graph.

Source code in src/causalprog/algorithms/moments.py
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
def expectation(
    graph: Graph,
    outcome_node_label: str,
    samples: int,
    *,
    parameter_values: dict[str, float] | None = None,
    rng_key: jax.Array,
) -> float:
    """
    Estimate the expectation of (a random variable attached to) a node in a graph.

    This is the first moment, as computed by `moment` with an order of 1.

    Args:
        graph: The graph containing the node.
        outcome_node_label: Label of the node whose expectation is estimated.
        samples: Number of samples to draw.
        parameter_values: Parameter values forwarded to `moment`.
        rng_key: PRNG key forwarded to `moment`.

    Returns:
        The estimated expectation.

    """
    first_order = 1
    return moment(
        first_order,
        graph,
        outcome_node_label,
        samples,
        parameter_values=parameter_values,
        rng_key=rng_key,
    )

moment(order, graph, outcome_node_label, samples, *, parameter_values=None, rng_key)

Estimate a moment of (a random variable attached to) a node in a graph.

Source code in src/causalprog/algorithms/moments.py
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
def moment(
    order: int,
    graph: Graph,
    outcome_node_label: str,
    samples: int,
    *,
    parameter_values: dict[str, float] | None = None,
    rng_key: jax.Array,
) -> float:
    """
    Estimate a moment of (a random variable attached to) a node in a graph.

    The estimate is the sum of the drawn samples, each raised to the power
    `order`, divided by `samples`.

    Args:
        order: The power each sample is raised to.
        graph: The graph containing the node.
        outcome_node_label: Label of the node whose moment is estimated.
        samples: Number of samples to draw; also the divisor of the sum.
        parameter_values: Parameter values forwarded to `sample`.
        rng_key: PRNG key forwarded to `sample`.

    Returns:
        The estimated moment.

    """
    draws = sample(
        graph,
        outcome_node_label,
        samples,
        rng_key=rng_key,
        parameter_values=parameter_values,
    )
    powered_draws = draws**order
    return sum(powered_draws) / samples

sample(graph, outcome_node_label, samples, *, parameter_values=None, rng_key)

Sample data from (a random variable attached to) a node in a graph.

Source code in src/causalprog/algorithms/moments.py
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
def sample(
    graph: Graph,
    outcome_node_label: str,
    samples: int,
    *,
    parameter_values: dict[str, float] | None = None,
    rng_key: jax.Array,
) -> npt.NDArray[float]:
    """
    Sample data from (a random variable attached to) a node in a graph.

    The nodes returned by `graph.roots_down_to_outcome` are sampled one after
    another, in the order given. Each node receives its own key, obtained by
    splitting `rng_key`, together with the values sampled so far (keyed by node
    label).

    Args:
        graph: The graph containing the node.
        outcome_node_label: Label of the node whose samples are returned.
        samples: Number of samples requested from each node.
        parameter_values: Parameter values passed to every node; an empty
            dictionary is passed when this is `None` (or empty).
        rng_key: PRNG key that is split into one key per node.

    Returns:
        The values sampled for the node labelled `outcome_node_label`.

    """
    ordered_nodes = graph.roots_down_to_outcome(outcome_node_label)
    per_node_keys = jax.random.split(rng_key, len(ordered_nodes))

    sampled: dict[str, npt.NDArray[float]] = {}
    for position, current in enumerate(ordered_nodes):
        sampled[current.label] = current.sample(
            parameter_values or {},
            sampled,
            samples,
            rng_key=per_node_keys[position],
        )

    return sampled[outcome_node_label]

standard_deviation(graph, outcome_node_label, samples, *, parameter_values=None, rng_key, rng_key_first_moment=None)

Estimate the standard deviation of (a RV attached to) a node in a graph.

Source code in src/causalprog/algorithms/moments.py
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
def standard_deviation(
    graph: Graph,
    outcome_node_label: str,
    samples: int,
    *,
    parameter_values: dict[str, float] | None = None,
    rng_key: jax.Array,
    rng_key_first_moment: jax.Array | None = None,
) -> float:
    """
    Estimate the standard deviation of (a RV attached to) a node in a graph.

    The estimate is the square root of the second moment minus the square of
    the first moment, both obtained from `moment`. The difference is not
    clamped before the square root is taken.

    Args:
        graph: The graph containing the node.
        outcome_node_label: Label of the node whose standard deviation is
            estimated.
        samples: Number of samples to draw for each moment.
        parameter_values: Parameter values forwarded to `moment`.
        rng_key: PRNG key used for the second moment.
        rng_key_first_moment: PRNG key used for the first moment. When `None`,
            `rng_key` is used for the first moment as well.

    Returns:
        The estimated standard deviation.

    """
    first_moment_key = rng_key if rng_key_first_moment is None else rng_key_first_moment

    second_moment = moment(
        2,
        graph,
        outcome_node_label,
        samples,
        rng_key=rng_key,
        parameter_values=parameter_values,
    )
    first_moment = moment(
        1,
        graph,
        outcome_node_label,
        samples,
        rng_key=first_moment_key,
        parameter_values=parameter_values,
    )

    variance = second_moment - first_moment**2
    return variance**0.5

backend

Helper functionality for incorporating different backends.

causal_problem

Classes for defining causal problems.

CausalEstimand

Bases: _CPComponent

A Causal Estimand.

The causal estimand is the function that we want to minimise (and maximise) as part of a causal problem. It should be a scalar-valued function of the random variables appearing in a graph.

Source code in src/causalprog/causal_problem/components.py
16
17
18
19
20
21
22
23
class CausalEstimand(_CPComponent):
    """
    A Causal Estimand.

    The causal estimand is the function that is minimised (or maximised) as
    part of a causal problem. It should be a scalar-valued function of the
    random variables appearing in a graph.

    The class defines no members of its own beyond those it inherits from
    `_CPComponent`.
    """

CausalProblem

Defines a causal problem.

Source code in src/causalprog/causal_problem/causal_problem.py
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
class CausalProblem:
    """
    Defines a causal problem.

    A causal problem combines an underlying graph, a causal estimand and any
    number of constraints. `lagrangian` assembles these into a single function
    of the model parameters, the Lagrange multipliers and a PRNG key.
    """

    # Graph whose `model` is wrapped by each component's effect handlers.
    _underlying_graph: Graph
    # The objective of the problem.
    causal_estimand: CausalEstimand
    # The constraints, in the order they were passed to the constructor.
    constraints: list[Constraint]

    @property
    def _ordered_components(self) -> list[_CPComponent]:
        """
        Internal ordering for components of the causal problem.

        The constraints come first, in their stored order, and the causal
        estimand is always the final element.
        """
        return [*self.constraints, self.causal_estimand]

    def __init__(
        self,
        graph: Graph,
        *constraints: Constraint,
        causal_estimand: CausalEstimand,
    ) -> None:
        """
        Create a new causal problem.

        Args:
            graph: The graph underlying the problem.
            *constraints: The constraints of the problem.
            causal_estimand: The causal estimand of the problem.

        """
        self._underlying_graph = graph
        self.causal_estimand = causal_estimand
        self.constraints = list(constraints)

    def _associate_models_to_components(
        self, n_samples: int
    ) -> tuple[list[Predictive], list[int]]:
        """
        Create models to be used by components of the problem.

        Depending on how many constraints (and the causal estimand) require effect
        handlers to wrap `self._underlying_graph.model`, we will need to create several
        predictive models to sample from. However, we also want to minimise the number
        of such models we have to make, in order to minimise the time we spend
        actually computing samples.

        As such, in this method we determine:
        - How many models we will need to build, by grouping the constraints and the
          causal estimand by the handlers they use.
        - Build these models, returning them in a list called `models`.
        - Build another list that maps the index of components in
          `self._ordered_components` to the index of the model in `models` that they
          use. The causal estimand is by convention the component at index -1 of this
          returned list.

        Args:
            n_samples: Value to be passed to `numpyro.Predictive`'s `num_samples`
                argument for each of the models that are constructed from the underlying
                graph.

        Returns:
            list[Predictive]: List of Predictive models, whose elements contain all the
                models needed by the components.
            list[int]: Mapping of component indexes (as per `self._ordered_components`)
                to the index of the model in the first return argument that the
                component uses.

        """
        models: list[Predictive] = []
        # The first component that required each model. Every later component is
        # compared against these representatives only, since all components that
        # share a model are known to be compatible with its representative.
        representatives: list[_CPComponent] = []
        component_index_to_model_index: list[int] = []

        for component in self._ordered_components:
            for model_index, representative in enumerate(representatives):
                if component.can_use_same_model_as(representative):
                    component_index_to_model_index.append(model_index)
                    break
            else:
                # No existing model fits this component: build a new one, and record
                # the component as the representative of that model.
                component_index_to_model_index.append(len(models))
                representatives.append(component)
                models.append(
                    Predictive(
                        component.apply_effects(self._underlying_graph.model),
                        num_samples=n_samples,
                    )
                )

        return models, component_index_to_model_index

    def lagrangian(
        self, n_samples: int = 1000, *, maximum_problem: bool = False
    ) -> Callable[[dict[str, npt.ArrayLike], npt.ArrayLike, jax.Array], npt.ArrayLike]:
        """
        Return a function that evaluates the Lagrangian of this `CausalProblem`.

        Following the
        [KKT theorem](https://en.wikipedia.org/wiki/Karush%E2%80%93Kuhn%E2%80%93Tucker_conditions),
        given the causal estimand and the constraints we can assemble a Lagrangian and
        seek its stationary points, to in turn identify minimisers of the constrained
        optimisation problem that we started with.

        The Lagrangian returned is a mathematical function of its first two arguments.
        The first argument is the same dictionary of parameters that is passed to models
        like `Graph.model`, and is the values the parameters (represented by the
        `ParameterNode`s) are taking. The second argument is a 1D vector of Lagrange
        multipliers, whose length is equal to the number of constraints.

        The remaining argument of the Lagrangian is the PRNG Key that should be used
        when drawing samples. The same key is used for every model.

        Note that our current implementation assumes there are no equality constraints
        being imposed (in which case, we would need a 3-argument Lagrangian function).

        Args:
            n_samples: The number of random samples to be drawn when estimating the
                value of functions of the RVs.
            maximum_problem: If passed as `True`, assemble the Lagrangian for the
                maximisation problem. Otherwise assemble that for the minimisation
                problem (default behaviour).

        Returns:
            The Lagrangian, as a function of the model parameters, Lagrange multipliers,
                and PRNG key.

        """
        maximisation_prefactor = -1.0 if maximum_problem else 1.0

        # Build association between the components and the model-samples that each
        # one needs to use. We do this here, since once it is constructed, it is
        # fixed, and doesn't need to be done each time we call the Lagrangian.
        models, component_to_index_mapping = self._associate_models_to_components(
            n_samples
        )
        # The causal estimand is the final entry of `self._ordered_components`, so
        # the final entry of the mapping identifies the model it samples from. This
        # is not necessarily the last model in `models`: the estimand may share a
        # model that was created for an earlier constraint.
        estimand_model_index = component_to_index_mapping[-1]

        def _inner(
            parameter_values: dict[str, npt.ArrayLike],
            l_mult: jax.Array,
            rng_key: jax.Array,
        ) -> npt.ArrayLike:
            # Draw samples from all models
            all_samples = tuple(
                sample_model(model, rng_key, parameter_values) for model in models
            )

            value = maximisation_prefactor * self.causal_estimand(
                all_samples[estimand_model_index]
            )
            # TODO: https://github.com/UCL/causalprog/issues/87
            value += sum(
                l_mult[i] * c(all_samples[component_to_index_mapping[i]])
                for i, c in enumerate(self.constraints)
            )
            return value

        return _inner

__init__(graph, *constraints, causal_estimand)

Create a new causal problem.

Source code in src/causalprog/causal_problem/causal_problem.py
52
53
54
55
56
57
58
59
60
61
def __init__(
    self,
    graph: Graph,
    *constraints: Constraint,
    causal_estimand: CausalEstimand,
) -> None:
    """
    Create a new causal problem.

    Args:
        graph: The graph underlying the problem.
        *constraints: The constraints of the problem; stored as a list.
        causal_estimand: The causal estimand of the problem.

    """
    self.constraints = [*constraints]
    self.causal_estimand = causal_estimand
    self._underlying_graph = graph

lagrangian(n_samples=1000, *, maximum_problem=False)

Return a function that evaluates the Lagrangian of this CausalProblem.

Following the KKT theorem, given the causal estimand and the constraints we can assemble a Lagrangian and seek its stationary points, to in turn identify minimisers of the constrained optimisation problem that we started with.

The Lagrangian returned is a mathematical function of its first two arguments. The first argument is the same dictionary of parameters that is passed to models like Graph.model, and is the values the parameters (represented by the ParameterNodes) are taking. The second argument is a 1D vector of Lagrange multipliers, whose length is equal to the number of constraints.

The remaining argument of the Lagrangian is the PRNG Key that should be used when drawing samples.

Note that our current implementation assumes there are no equality constraints being imposed (in which case, we would need a 3-argument Lagrangian function).

Parameters:

Name Type Description Default
n_samples int

The number of random samples to be drawn when estimating the value of functions of the RVs.

1000
maximum_problem bool

If passed as True, assemble the Lagrangian for the maximisation problem. Otherwise assemble that for the minimisation problem (default behaviour).

False

Returns:

Type Description
Callable[[dict[str, ArrayLike], ArrayLike, Array], ArrayLike]

The Lagrangian, as a function of the model parameters, Lagrange multipliers, and PRNG key.

Source code in src/causalprog/causal_problem/causal_problem.py
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
def lagrangian(
    self, n_samples: int = 1000, *, maximum_problem: bool = False
) -> Callable[[dict[str, npt.ArrayLike], npt.ArrayLike, jax.Array], npt.ArrayLike]:
    """
    Return a function that evaluates the Lagrangian of this `CausalProblem`.

    Following the
    [KKT theorem](https://en.wikipedia.org/wiki/Karush%E2%80%93Kuhn%E2%80%93Tucker_conditions),
    given the causal estimand and the constraints we can assemble a Lagrangian and
    seek its stationary points, to in turn identify minimisers of the constrained
    optimisation problem that we started with.

    The Lagrangian returned is a mathematical function of its first two arguments.
    The first argument is the same dictionary of parameters that is passed to models
    like `Graph.model`, and is the values the parameters (represented by the
    `ParameterNode`s) are taking. The second argument is a 1D vector of Lagrange
    multipliers, whose length is equal to the number of constraints.

    The remaining argument of the Lagrangian is the PRNG Key that should be used
    when drawing samples. The same key is used for every model.

    Note that our current implementation assumes there are no equality constraints
    being imposed (in which case, we would need a 3-argument Lagrangian function).

    Args:
        n_samples: The number of random samples to be drawn when estimating the
            value of functions of the RVs.
        maximum_problem: If passed as `True`, assemble the Lagrangian for the
            maximisation problem. Otherwise assemble that for the minimisation
            problem (default behaviour).

    Returns:
        The Lagrangian, as a function of the model parameters, Lagrange multipliers,
            and PRNG key.

    """
    maximisation_prefactor = -1.0 if maximum_problem else 1.0

    # Build association between the components and the model-samples that each
    # one needs to use. We do this here, since once it is constructed, it is
    # fixed, and doesn't need to be done each time we call the Lagrangian.
    models, component_to_index_mapping = self._associate_models_to_components(
        n_samples
    )
    # The final entry of the mapping belongs to the causal estimand. It names the
    # model the estimand samples from, which is not necessarily the last model in
    # `models`: the estimand may share a model created for an earlier constraint.
    estimand_model_index = component_to_index_mapping[-1]

    def _inner(
        parameter_values: dict[str, npt.ArrayLike],
        l_mult: jax.Array,
        rng_key: jax.Array,
    ) -> npt.ArrayLike:
        # Draw samples from all models
        all_samples = tuple(
            sample_model(model, rng_key, parameter_values) for model in models
        )

        value = maximisation_prefactor * self.causal_estimand(
            all_samples[estimand_model_index]
        )
        # TODO: https://github.com/UCL/causalprog/issues/87
        value += sum(
            l_mult[i] * c(all_samples[component_to_index_mapping[i]])
            for i, c in enumerate(self.constraints)
        )
        return value

    return _inner

Constraint

Bases: _CPComponent

A Constraint that forms part of a causal problem.

Constraints of a causal problem are derived properties of RVs for which we have observed data. The causal estimand is minimised (or maximised) subject to the predicted values of the constraints being close to their observed values in the data.

Adding a constraint \(g(\theta)\) to a causal problem (where \(\theta\) are the parameters of the causal problem) essentially imposes an additional constraint on the minimisation problem:

\[ \mathrm{norm}\left( g(\theta) - g_{\text{data}} \right) \leq \epsilon, \]

where \(g_{\text{data}}\) is the observed data value for the quantity \(g\), \(\mathrm{norm}\) is the outer norm of the constraint, and \(\epsilon\) is some tolerance.

Source code in src/causalprog/causal_problem/components.py
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
class Constraint(_CPComponent):
    r"""
    A Constraint that forms part of a causal problem.

    A constraint is a derived property of the RVs for which observed data is
    available. The causal estimand is minimised (or maximised) subject to the
    value of that property predicted by the model being close to its observed
    value.

    Adding a constraint on the model quantity $g(\theta)$ (where $\theta$ are
    the parameters of the causal problem) imposes the condition

    $$ \mathrm{norm}\left( g(\theta) - g_{\text{data}} \right) \leq \epsilon, $$

    where $g_{\text{data}}$ is the observed data value for the quantity $g$,
    $\mathrm{norm}$ is the outer norm of the constraint, and $\epsilon$ is
    some tolerance.
    """

    data: npt.ArrayLike
    tolerance: npt.ArrayLike
    _outer_norm: Callable[[npt.ArrayLike], float]

    def __init__(
        self,
        *effect_handlers: ModelMask,
        model_quantity: Callable[..., npt.ArrayLike],
        outer_norm: Callable[[npt.ArrayLike], float] | None = None,
        data: npt.ArrayLike = 0.0,
        tolerance: float = 1.0e-6,
    ) -> None:
        r"""
        Create a new constraint.

        A constraint is the function

        $$ c(\theta) :=
        \mathrm{norm}\left( g(\theta)
        - g_{\mathrm{data}} \right)
        - \epsilon $$

        in which
        - $\mathrm{norm}$ is the outer norm of the constraint (`outer_norm`),
        - $g(\theta)$ is the model quantity involved in the constraint
            (`model_quantity`),
        - $g_{\mathrm{data}}$ is the observed data (`data`),
        - $\epsilon$ is the tolerance in the data (`tolerance`).

        Within a causal problem the constraint enters the minimisation /
        maximisation as the condition $c(\theta)\leq 0$, which is why the
        $-\epsilon$ term is part of $c(\theta)$ itself.

        $g$ should be a (possibly vector-valued) function that acts on (a subset
        of) samples from the random variables of the causal problem. It is called
        with the samples unpacked as keyword arguments, so it must accept variable
        keyword-arguments only, and should look up the samples for each random
        variable via the RV names (node labels). It should return the model
        quantity, computed from the samples, whose observed value is
        $g_{\mathrm{data}}$.

        $g_{\mathrm{data}}$ should be a fixed value whose shape is broadcast-able
        with the return shape of $g$. It defaults to $0$ if not explicitly set.

        $\mathrm{norm}$ should be a suitable norm to take on the difference between
        the model quantity predicted from the samples ($g$) and the observed data
        ($g_{\mathrm{data}}$). It must return a scalar value. When it is not given,
        `jnp.linalg.vector_norm` is used (the 2-norm).
        """
        super().__init__(*effect_handlers, do_with_samples=model_quantity)

        self.data = data
        self.tolerance = tolerance
        self._outer_norm = (
            jnp.linalg.vector_norm if outer_norm is None else outer_norm
        )

    def __call__(self, samples: dict[str, npt.ArrayLike]) -> npt.ArrayLike:
        """
        Evaluate the constraint, given RV samples.

        Args:
            samples: Mapping of RV (node) labels to drawn samples. It is unpacked
                into keyword arguments for the model quantity.

        Returns:
            Value of the constraint: the outer norm of the difference between the
            model quantity and the data, minus the tolerance.

        """
        model_value = self._do_with_samples(**samples)
        misfit = self._outer_norm(model_value - self.data)
        return misfit - self.tolerance

__call__(samples)

Evaluate the constraint, given RV samples.

Parameters:

Name Type Description Default
samples dict[str, ArrayLike]

Mapping of RV (node) labels to drawn samples.

required

Returns:

Type Description
ArrayLike

Value of the constraint.

Source code in src/causalprog/causal_problem/components.py
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
def __call__(self, samples: dict[str, npt.ArrayLike]) -> npt.ArrayLike:
    """
    Evaluate the constraint, given RV samples.

    Args:
        samples: Mapping of RV (node) labels to drawn samples. It is unpacked
            into keyword arguments for the model quantity.

    Returns:
        Value of the constraint: the outer norm of the difference between the
        model quantity and the data, minus the tolerance.

    """
    model_value = self._do_with_samples(**samples)
    misfit = self._outer_norm(model_value - self.data)
    return misfit - self.tolerance

__init__(*effect_handlers, model_quantity, outer_norm=None, data=0.0, tolerance=1e-06)

Create a new constraint.

Constraints have the form

\[ c(\theta) := \mathrm{norm}\left( g(\theta) - g_{\mathrm{data}} \right) - \epsilon \]

where:

- \(\mathrm{norm}\) is the outer norm of the constraint (outer_norm),
- \(g(\theta)\) is the model quantity involved in the constraint (model_quantity),
- \(g_{\mathrm{data}}\) is the observed data (data),
- \(\epsilon\) is the tolerance in the data (tolerance).

In a causal problem, each constraint appears as the condition \(c(\theta)\leq 0\) in the minimisation / maximisation (hence the inclusion of the \(-\epsilon\) term within \(c(\theta)\) itself).

\(g\) should be a (possibly vector-valued) function that acts on (a subset of) samples from the random variables of the causal problem. It must accept variable keyword-arguments only, and should access the samples for each random variable by indexing via the RV names (node labels). It should return the model quantity, as computed from the samples, whose observed value is \(g_{\mathrm{data}}\).

\(g_{\mathrm{data}}\) should be a fixed value whose shape is broadcast-able with the return shape of \(g\). It defaults to \(0\) if not explicitly set.

\(\mathrm{norm}\) should be a suitable norm to take on the difference between the model quantity as predicted by the samples (\(g\)) and the observed data (\(g_{\mathrm{data}}\)). It must return a scalar value. The default is the 2-norm.

Source code in src/causalprog/causal_problem/components.py
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
def __init__(
    self,
    *effect_handlers: ModelMask,
    model_quantity: Callable[..., npt.ArrayLike],
    outer_norm: Callable[[npt.ArrayLike], float] | None = None,
    data: npt.ArrayLike = 0.0,
    tolerance: float = 1.0e-6,
) -> None:
    r"""
    Create a new constraint.

    A constraint is the function

    $$ c(\theta) :=
    \mathrm{norm}\left( g(\theta)
    - g_{\mathrm{data}} \right)
    - \epsilon $$

    in which
    - $\mathrm{norm}$ is the outer norm of the constraint (`outer_norm`),
    - $g(\theta)$ is the model quantity involved in the constraint
        (`model_quantity`),
    - $g_{\mathrm{data}}$ is the observed data (`data`),
    - $\epsilon$ is the tolerance in the data (`tolerance`).

    Within a causal problem the constraint enters the minimisation /
    maximisation as the condition $c(\theta)\leq 0$, which is why the
    $-\epsilon$ term is part of $c(\theta)$ itself.

    $g$ should be a (possibly vector-valued) function that acts on (a subset of)
    samples from the random variables of the causal problem. It must accept
    variable keyword-arguments only, and should look up the samples for each
    random variable via the RV names (node labels). It should return the model
    quantity, computed from the samples, whose observed value is
    $g_{\mathrm{data}}$.

    $g_{\mathrm{data}}$ should be a fixed value whose shape is broadcast-able with
    the return shape of $g$. It defaults to $0$ if not explicitly set.

    $\mathrm{norm}$ should be a suitable norm to take on the difference between the
    model quantity predicted from the samples ($g$) and the observed data
    ($g_{\mathrm{data}}$). It must return a scalar value. When it is not given,
    `jnp.linalg.vector_norm` is used (the 2-norm).
    """
    super().__init__(*effect_handlers, do_with_samples=model_quantity)

    self.data = data
    self.tolerance = tolerance
    self._outer_norm = jnp.linalg.vector_norm if outer_norm is None else outer_norm

HandlerToApply dataclass

Specifies a handler that needs to be applied to a model at runtime.

Source code in src/causalprog/causal_problem/handlers.py
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
@dataclass
class HandlerToApply:
    """
    Pairing of an effect handler with the options it should be given.

    Instances describe a handler that needs to be applied to a model at runtime.
    The attributes are validated on construction (see `__post_init__`).

    Attributes:
        handler: The effect handler; must be callable.
        options: Options for the handler, as a dictionary. Defaults to an empty
            dictionary.

    """

    handler: EffectHandler
    options: dict[str, Any] = field(default_factory=dict)

    @classmethod
    def from_pair(cls, pair: Sequence) -> "HandlerToApply":
        """
        Build an instance from a two-element container.

        The container must hold the effect handler and its options, in either
        order;
        - the callable element is taken to be the `handler`,
        - the remaining element is taken to be the `options` dictionary.

        Args:
            pair: Container of exactly two elements, one being the effect handler
                callable and the other the options to pass to it (as a dictionary).

        Returns:
            Class instance corresponding to the effect handler and options passed.

        Raises:
            ValueError: If `pair` does not contain exactly two elements.
            TypeError: Raised by `__post_init__` if the elements do not have the
                expected types.

        """
        expected_length = 2
        if len(pair) != expected_length:
            msg = (
                f"{cls.__name__} can only be constructed from a container of 2 elements"
            )
            raise ValueError(msg)

        first, second = pair[0], pair[1]
        # Type validation is left to __post_init__, so it is enough to check which
        # of the two elements is callable and treat the other one as the options.
        if callable(first):
            return cls(handler=first, options=second)
        return cls(handler=second, options=first)

    def __post_init__(self) -> None:
        """
        Validate the attributes set by the dataclass-generated `__init__`.

        Raises:
            TypeError: If `handler` is not callable, or if `options` is not a
                dictionary.

        """
        if not callable(self.handler):
            handler_type = type(self.handler).__name__
            msg = f"{handler_type} is not callable."
            raise TypeError(msg)

        if isinstance(self.options, dict):
            return

        options_type = type(self.options).__name__
        msg = (
            "Options should be dictionary mapping option arguments to values "
            f"(got {options_type})."
        )
        raise TypeError(msg)

    def __eq__(self, other: object) -> bool:
        """
        Equality operation.

        Two `HandlerToApply`s are equal when they refer to the very same handler
        object (identity comparison) and their options dictionaries compare equal.

        Comparison to any other type returns `False`.
        """
        if not isinstance(other, HandlerToApply):
            return False
        same_handler = self.handler is other.handler
        return same_handler and self.options == other.options

__eq__(other)

Equality operation.

HandlerToApplys are considered equal if they use the same handler function and provide the same options to this function.

Comparison to other types returns False.

Source code in src/causalprog/causal_problem/handlers.py
74
75
76
77
78
79
80
81
82
83
84
85
86
87
def __eq__(self, other: object) -> bool:
    """
    Equality operation.

    Two `HandlerToApply`s are equal when they refer to the very same handler
    object (identity comparison) and their options dictionaries compare equal.

    Comparison to any other type returns `False`.
    """
    if not isinstance(other, HandlerToApply):
        return False
    same_handler = self.handler is other.handler
    return same_handler and self.options == other.options

__post_init__()

Validate set attributes.

  • The handler is a callable object.
  • The options have been passed as a dictionary of keyword-value pairs.
Source code in src/causalprog/causal_problem/handlers.py
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
def __post_init__(self) -> None:
    """
    Validate the attributes that have been set on the instance.

    Raises:
        TypeError: If `handler` is not callable, or if `options` is not a
            dictionary.

    """
    if not callable(self.handler):
        handler_type = type(self.handler).__name__
        msg = f"{handler_type} is not callable."
        raise TypeError(msg)

    if isinstance(self.options, dict):
        return

    options_type = type(self.options).__name__
    msg = (
        "Options should be dictionary mapping option arguments to values "
        f"(got {options_type})."
    )
    raise TypeError(msg)

from_pair(pair) classmethod

Create an instance from an effect handler and its options.

The two objects should be passed in as the elements of a container of length 2. They can be passed in any order:

  • One element must be a dictionary, which will be interpreted as the options for the effect handler.
  • The other element must be callable, and will be interpreted as the handler itself.

Parameters:

Name Type Description Default
pair Sequence

Container of two elements, one being the effect handler callable and the other being the options to pass to it (as a dictionary).

required

Returns:

Type Description
HandlerToApply

Class instance corresponding to the effect handler and options passed.

Source code in src/causalprog/causal_problem/handlers.py
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
@classmethod
def from_pair(cls, pair: Sequence) -> "HandlerToApply":
    """
    Build an instance from a two-element container.

    The container must hold the effect handler and its options, in either order;
    - the callable element is taken to be the `handler`,
    - the remaining element is taken to be the `options` dictionary.

    Args:
        pair: Container of exactly two elements, one being the effect handler
            callable and the other the options to pass to it (as a dictionary).

    Returns:
        Class instance corresponding to the effect handler and options passed.

    Raises:
        ValueError: If `pair` does not contain exactly two elements.

    """
    expected_length = 2
    if len(pair) != expected_length:
        msg = (
            f"{cls.__name__} can only be constructed from a container of 2 elements"
        )
        raise ValueError(msg)

    first, second = pair[0], pair[1]
    # Type validation is left to the constructor (the original comment names
    # __post_init__), so it is enough to check which of the two elements is
    # callable and treat the other one as the options.
    if callable(first):
        return cls(handler=first, options=second)
    return cls(handler=second, options=first)

causal_problem

Classes for representing causal problems.

CausalProblem

Defines a causal problem.

Source code in src/causalprog/causal_problem/causal_problem.py
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
class CausalProblem:
    """
    Defines a causal problem.

    A causal problem bundles an underlying graph with a causal estimand and a
    (possibly empty) collection of constraints.
    """

    _underlying_graph: Graph
    causal_estimand: CausalEstimand
    constraints: list[Constraint]

    @property
    def _ordered_components(self) -> list[_CPComponent]:
        """
        Internal ordering for components of the causal problem.

        The constraints come first (in the order they are stored), and the causal
        estimand is always the final element.
        """
        return [*self.constraints, self.causal_estimand]

    def __init__(
        self,
        graph: Graph,
        *constraints: Constraint,
        causal_estimand: CausalEstimand,
    ) -> None:
        """
        Create a new causal problem.

        Args:
            graph: The graph underlying the problem.
            *constraints: Constraints of the problem; stored as a list.
            causal_estimand: The causal estimand of the problem.

        """
        self._underlying_graph = graph
        self.causal_estimand = causal_estimand
        self.constraints = list(constraints)

    def _associate_models_to_components(
        self, n_samples: int
    ) -> tuple[list[Predictive], list[int]]:
        """
        Create models to be used by components of the problem.

        Depending on how many constraints (and the causal estimand) require effect
        handlers to wrap `self._underlying_graph.model`, we will need to create several
        predictive models to sample from. However, we also want to minimise the number
        of such models we have to make, in order to minimise the time we spend
        actually computing samples.

        As such, in this method we determine:
        - How many models we will need to build, by grouping the constraints and the
          causal estimand by the handlers they use.
        - Build these models, returning them in a list called `models`.
        - Build another list that maps the index of components in
          `self._ordered_components` to the index of the model in `models` that they
          use. The causal estimand is by convention the component at index -1 of this
          returned list.

        Note that the causal estimand does NOT necessarily use the last model in
        `models`: if it can share a model with an earlier component, its entry in the
        mapping points at that earlier model.

        Args:
            n_samples: Value to be passed to `numpyro.Predictive`'s `num_samples`
                argument for each of the models that are constructed from the underlying
                graph.

        Returns:
            list[Predictive]: List of Predictive models, whose elements contain all the
                models needed by the components.
            list[int]: Mapping of component indexes (as per `self._ordered_components`)
                to the index of the model in the first return argument that the
                component uses.

        """
        models: list[Predictive] = []
        # representatives[k] is the first component that required models[k]. All
        # components sharing a model are interchangeable for comparison purposes,
        # so comparing against the representative is sufficient.
        representatives: list[_CPComponent] = []
        component_index_to_model_index: list[int] = []

        for component in self._ordered_components:
            for model_index, representative in enumerate(representatives):
                if component.can_use_same_model_as(representative):
                    component_index_to_model_index.append(model_index)
                    break
            else:
                # No existing model is suitable: build a new one for this component.
                component_index_to_model_index.append(len(models))
                representatives.append(component)
                models.append(
                    Predictive(
                        component.apply_effects(self._underlying_graph.model),
                        num_samples=n_samples,
                    )
                )

        return models, component_index_to_model_index

    def lagrangian(
        self, n_samples: int = 1000, *, maximum_problem: bool = False
    ) -> Callable[[dict[str, npt.ArrayLike], npt.ArrayLike, jax.Array], npt.ArrayLike]:
        """
        Return a function that evaluates the Lagrangian of this `CausalProblem`.

        Following the
        [KKT theorem](https://en.wikipedia.org/wiki/Karush%E2%80%93Kuhn%E2%80%93Tucker_conditions),
        given the causal estimand and the constraints we can assemble a Lagrangian and
        seek its stationary points, to in turn identify minimisers of the constrained
        optimisation problem that we started with.

        The Lagrangian returned is a mathematical function of its first two arguments.
        The first argument is the same dictionary of parameters that is passed to models
        like `Graph.model`, and is the values the parameters (represented by the
        `ParameterNode`s) are taking. The second argument is a 1D vector of Lagrange
        multipliers, whose length is equal to the number of constraints.

        The remaining argument of the Lagrangian is the PRNG Key that should be used
        when drawing samples.

        Note that our current implementation assumes there are no equality constraints
        being imposed (in which case, we would need a 3-argument Lagrangian function).

        Args:
            n_samples: The number of random samples to be drawn when estimating the
                value of functions of the RVs.
            maximum_problem: If passed as `True`, assemble the Lagrangian for the
                maximisation problem. Otherwise assemble that for the minimisation
                problem (default behaviour).

        Returns:
            The Lagrangian, as a function of the model parameters, Lagrange multipliers,
                and PRNG key.

        """
        maximisation_prefactor = -1.0 if maximum_problem else 1.0

        # Build association between the components and the model-samples that each
        # one needs to use. We do this here, since once it is constructed, it is
        # fixed, and doesn't need to be done each time we call the Lagrangian.
        models, component_to_index_mapping = self._associate_models_to_components(
            n_samples
        )
        # The causal estimand is the last component, but the model it uses is not
        # necessarily the last model (it may share one with an earlier constraint),
        # so its model must be looked up through the mapping.
        estimand_model_index = component_to_index_mapping[-1]

        def _inner(
            parameter_values: dict[str, npt.ArrayLike],
            l_mult: jax.Array,
            rng_key: jax.Array,
        ) -> npt.ArrayLike:
            # Draw samples from all models
            all_samples = tuple(
                sample_model(model, rng_key, parameter_values) for model in models
            )

            value = maximisation_prefactor * self.causal_estimand(
                all_samples[estimand_model_index]
            )
            # TODO: https://github.com/UCL/causalprog/issues/87
            value += sum(
                l_mult[i] * c(all_samples[component_to_index_mapping[i]])
                for i, c in enumerate(self.constraints)
            )
            return value

        return _inner
__init__(graph, *constraints, causal_estimand)

Create a new causal problem.

Source code in src/causalprog/causal_problem/causal_problem.py
52
53
54
55
56
57
58
59
60
61
def __init__(
    self,
    graph: Graph,
    *constraints: Constraint,
    causal_estimand: CausalEstimand,
) -> None:
    """
    Create a new causal problem.

    Args:
        graph: The graph underlying the problem.
        *constraints: Constraints of the problem; stored as a list.
        causal_estimand: The causal estimand of the problem.

    """
    # Positional constraints arrive as a tuple; keep them as a list.
    self.constraints = [*constraints]
    self.causal_estimand = causal_estimand
    self._underlying_graph = graph
lagrangian(n_samples=1000, *, maximum_problem=False)

Return a function that evaluates the Lagrangian of this CausalProblem.

Following the KKT theorem, given the causal estimand and the constraints we can assemble a Lagrangian and seek its stationary points, to in turn identify minimisers of the constrained optimisation problem that we started with.

The Lagrangian returned is a mathematical function of its first two arguments. The first argument is the same dictionary of parameters that is passed to models like Graph.model, and is the values the parameters (represented by the ParameterNodes) are taking. The second argument is a 1D vector of Lagrange multipliers, whose length is equal to the number of constraints.

The remaining argument of the Lagrangian is the PRNG Key that should be used when drawing samples.

Note that our current implementation assumes there are no equality constraints being imposed (in which case, we would need a 3-argument Lagrangian function).

Parameters:

Name Type Description Default
n_samples int

The number of random samples to be drawn when estimating the value of functions of the RVs.

1000
maximum_problem bool

If passed as True, assemble the Lagrangian for the maximisation problem. Otherwise assemble that for the minimisation problem (default behaviour).

False

Returns:

Type Description
Callable[[dict[str, ArrayLike], ArrayLike, Array], ArrayLike]

The Lagrangian, as a function of the model parameters, Lagrange multipliers, and PRNG key.

Source code in src/causalprog/causal_problem/causal_problem.py
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
def lagrangian(
    self, n_samples: int = 1000, *, maximum_problem: bool = False
) -> Callable[[dict[str, npt.ArrayLike], npt.ArrayLike, jax.Array], npt.ArrayLike]:
    """
    Return a function that evaluates the Lagrangian of this `CausalProblem`.

    Following the
    [KKT theorem](https://en.wikipedia.org/wiki/Karush%E2%80%93Kuhn%E2%80%93Tucker_conditions),
    given the causal estimand and the constraints we can assemble a Lagrangian and
    seek its stationary points, to in turn identify minimisers of the constrained
    optimisation problem that we started with.

    The Lagrangian returned is a mathematical function of its first two arguments.
    The first argument is the same dictionary of parameters that is passed to models
    like `Graph.model`, and is the values the parameters (represented by the
    `ParameterNode`s) are taking. The second argument is a 1D vector of Lagrange
    multipliers, whose length is equal to the number of constraints.

    The remaining argument of the Lagrangian is the PRNG Key that should be used
    when drawing samples.

    Note that our current implementation assumes there are no equality constraints
    being imposed (in which case, we would need a 3-argument Lagrangian function).

    Args:
        n_samples: The number of random samples to be drawn when estimating the
            value of functions of the RVs.
        maximum_problem: If passed as `True`, assemble the Lagrangian for the
            maximisation problem. Otherwise assemble that for the minimisation
            problem (default behaviour).

    Returns:
        The Lagrangian, as a function of the model parameters, Lagrange multipliers,
            and PRNG key.

    """
    maximisation_prefactor = -1.0 if maximum_problem else 1.0

    # Build association between the components and the model-samples that each
    # one needs to use. We do this here, since once it is constructed, it is
    # fixed, and doesn't need to be done each time we call the Lagrangian.
    models, component_to_index_mapping = self._associate_models_to_components(
        n_samples
    )
    # The causal estimand is the last component, but the model it uses is not
    # necessarily the last model (it may share one with an earlier constraint),
    # so its model must be looked up through the mapping.
    estimand_model_index = component_to_index_mapping[-1]

    def _inner(
        parameter_values: dict[str, npt.ArrayLike],
        l_mult: jax.Array,
        rng_key: jax.Array,
    ) -> npt.ArrayLike:
        # Draw samples from all models
        all_samples = tuple(
            sample_model(model, rng_key, parameter_values) for model in models
        )

        value = maximisation_prefactor * self.causal_estimand(
            all_samples[estimand_model_index]
        )
        # TODO: https://github.com/UCL/causalprog/issues/87
        value += sum(
            l_mult[i] * c(all_samples[component_to_index_mapping[i]])
            for i, c in enumerate(self.constraints)
        )
        return value

    return _inner

sample_model(model, rng_key, parameter_values)

Draw samples from the predictive model.

Parameters:

Name Type Description Default
model Predictive

Predictive model to draw samples from.

required
rng_key Array

PRNG Key to use in pseudorandom number generation.

required
parameter_values dict[str, ArrayLike]

Model parameter values to substitute.

required

Returns:

Type Description
dict[str, ArrayLike]

dict of samples, with RV labels as keys and sample values (jax.Arrays) as values.

Source code in src/causalprog/causal_problem/causal_problem.py
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
def sample_model(
    model: Predictive, rng_key: jax.Array, parameter_values: dict[str, npt.ArrayLike]
) -> dict[str, npt.ArrayLike]:
    """
    Draw samples from the predictive model.

    The PRNG key is split into `model.num_samples` sub-keys, and the model is
    evaluated (via `jax.vmap`) once per sub-key. The same parameter values are
    used for every evaluation.

    Args:
        model: Predictive model to draw samples from.
        rng_key: PRNG Key to use in pseudorandom number generation.
        parameter_values: Model parameter values to substitute. They are passed to
            the model as keyword arguments.

    Returns:
        `dict` of samples, with RV labels as keys and sample values (`jax.Array`s) as
            values.

    """

    def _single_draw(values: dict[str, npt.ArrayLike], key: jax.Array):
        # One evaluation of the model, for one PRNG key.
        return model(key, **values)

    # NOTE(review): `model` is evaluated once per sub-key; if each evaluation
    # itself draws `num_samples` samples, the output carries two sample axes.
    # Confirm this is intended.
    per_draw_keys = jax.random.split(rng_key, model.num_samples)
    # Map over the keys only (axis 0); the parameter values are not mapped.
    batched_draw = jax.vmap(_single_draw, in_axes=(None, 0))
    return batched_draw(parameter_values, per_draw_keys)

components

Classes for defining causal estimands and constraints of causal problems.

CausalEstimand

Bases: _CPComponent

A Causal Estimand.

The causal estimand is the function that we want to minimise (and maximise) as part of a causal problem. It should be a scalar-valued function of the random variables appearing in a graph.

Source code in src/causalprog/causal_problem/components.py
16
17
18
19
20
21
22
23
class CausalEstimand(_CPComponent):
    """
    A Causal Estimand.

    The causal estimand is the objective function of a causal problem: it is the
    quantity that is minimised (and maximised) subject to the problem's
    constraints. It should be a scalar-valued function of the random variables
    appearing in a graph.

    This class adds no behaviour of its own; everything is inherited from
    `_CPComponent`.
    """

Constraint

Bases: _CPComponent

A Constraint that forms part of a causal problem.

Constraints of a causal problem are derived properties of RVs for which we have observed data. The causal estimand is minimised (or maximised) subject to the predicted values of the constraints being close to their observed values in the data.

Adding a constraint \(g(\theta)\) to a causal problem (where \(\theta\) are the parameters of the causal problem) essentially imposes an additional constraint on the minimisation problem;

\[ \mathrm{norm}\left( g(\theta) - g_{\text{data}} \right) \leq \epsilon, \]

where \(g_{\text{data}}\) is the observed data value for the quantity \(g\), and \(\epsilon\) is some tolerance.

Source code in src/causalprog/causal_problem/components.py
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
class Constraint(_CPComponent):
    r"""
    A Constraint that forms part of a causal problem.

    Constraints of a causal problem are derived properties of RVs for which we
    have observed data. The causal estimand is minimised (or maximised) subject
    to the predicted values of the constraints being close to their observed
    values in the data.

    Adding a constraint on a model quantity $g(\theta)$ to a causal problem (where
    $\theta$ are the parameters of the causal problem) essentially imposes an
    additional constraint on the minimisation problem;

    $$ \mathrm{norm}\left( g(\theta) - g_{\text{data}} \right) \leq \epsilon, $$

    where $g_{\text{data}}$ is the observed data value for the quantity $g$,
    $\mathrm{norm}$ is the outer norm of the constraint, and $\epsilon$ is some
    tolerance.
    """

    data: npt.ArrayLike
    tolerance: npt.ArrayLike
    _outer_norm: Callable[[npt.ArrayLike], float]

    def __init__(
        self,
        *effect_handlers: ModelMask,
        model_quantity: Callable[..., npt.ArrayLike],
        outer_norm: Callable[[npt.ArrayLike], float] | None = None,
        data: npt.ArrayLike = 0.0,
        tolerance: float = 1.0e-6,
    ) -> None:
        r"""
        Create a new constraint.

        A constraint is the function

        $$ c(\theta) :=
        \mathrm{norm}\left( g(\theta)
        - g_{\mathrm{data}} \right)
        - \epsilon $$

        in which
        - $\mathrm{norm}$ is the outer norm of the constraint (`outer_norm`),
        - $g(\theta)$ is the model quantity involved in the constraint
            (`model_quantity`),
        - $g_{\mathrm{data}}$ is the observed data (`data`),
        - $\epsilon$ is the tolerance in the data (`tolerance`).

        Within a causal problem the constraint enters as the condition
        $c(\theta)\leq 0$ of the minimisation / maximisation, which is why the
        $-\epsilon$ term is part of $c(\theta)$ itself.

        $g$ should be a (possibly vector-valued) function that acts on (a subset
        of) samples from the random variables of the causal problem. It must accept
        variable keyword-arguments only, and should access the samples for each
        random variable by indexing via the RV names (node labels). It should
        return the model quantity, as computed from the samples, that
        $g_{\mathrm{data}}$ is an observation of.

        $g_{\mathrm{data}}$ should be a fixed value whose shape is broadcast-able
        with the return shape of $g$. It defaults to $0$ if not explicitly set.

        $\mathrm{norm}$ should be a suitable norm to take on the difference between
        the model quantity as predicted by the samples ($g$) and the observed data
        ($g_{\mathrm{data}}$). It must return a scalar value. When not given,
        `jnp.linalg.vector_norm` is used (the 2-norm).
        """
        super().__init__(*effect_handlers, do_with_samples=model_quantity)

        # Fall back to the default norm only when the caller did not supply one.
        self._outer_norm = (
            jnp.linalg.vector_norm if outer_norm is None else outer_norm
        )
        self.data = data
        self.tolerance = tolerance

    def __call__(self, samples: dict[str, npt.ArrayLike]) -> npt.ArrayLike:
        """
        Evaluate the constraint, given RV samples.

        The samples are passed to the model quantity as keyword arguments; the
        data is subtracted from the result, the outer norm of that difference is
        taken, and finally the tolerance is subtracted.

        Args:
            samples: Mapping of RV (node) labels to drawn samples.

        Returns:
            Value of the constraint.

        """
        predicted_quantity = self._do_with_samples(**samples)
        misfit = self._outer_norm(predicted_quantity - self.data)
        return misfit - self.tolerance
__call__(samples)

Evaluate the constraint, given RV samples.

Parameters:

Name Type Description Default
samples dict[str, ArrayLike]

Mapping of RV (node) labels to drawn samples.

required

Returns:

Type Description
ArrayLike

Value of the constraint.

Source code in src/causalprog/causal_problem/components.py
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
def __call__(self, samples: dict[str, npt.ArrayLike]) -> npt.ArrayLike:
    """
    Evaluate the constraint, given RV samples.

    The samples are passed to the model quantity as keyword arguments; the data
    is subtracted from the result, the outer norm of that difference is taken,
    and finally the tolerance is subtracted.

    Args:
        samples: Mapping of RV (node) labels to drawn samples.

    Returns:
        Value of the constraint.

    """
    predicted_quantity = self._do_with_samples(**samples)
    misfit = self._outer_norm(predicted_quantity - self.data)
    return misfit - self.tolerance
__init__(*effect_handlers, model_quantity, outer_norm=None, data=0.0, tolerance=1e-06)

Create a new constraint.

Constraints have the form

\[ c(\theta) := \mathrm{norm}\left( g(\theta) - g_{\mathrm{data}} \right) - \epsilon \]

where:

  • \(\mathrm{norm}\) is the outer norm of the constraint (outer_norm),
  • \(g(\theta)\) is the model quantity involved in the constraint (model_quantity),
  • \(g_{\mathrm{data}}\) is the observed data (data),
  • \(\epsilon\) is the tolerance in the data (tolerance).

In a causal problem, each constraint appears as the condition \(c(\theta)\leq 0\) in the minimisation / maximisation (hence the inclusion of the \(-\epsilon\) term within \(c(\theta)\) itself).

\(g\) should be a (possibly vector-valued) function that acts on (a subset of) samples from the random variables of the causal problem. It must accept variable keyword-arguments only, and should access the samples for each random variable by indexing via the RV names (node labels). It should return the model quantity, as computed from the samples, that \(g_{\mathrm{data}}\) is an observation of.

\(g_{\mathrm{data}}\) should be a fixed value whose shape is broadcast-able with the return shape of \(g\). It defaults to \(0\) if not explicitly set.

\(\mathrm{norm}\) should be a suitable norm to take on the difference between the model quantity as predicted by the samples (\(g\)) and the observed data (\(g_{\mathrm{data}}\)). It must return a scalar value. The default is the 2-norm.

Source code in src/causalprog/causal_problem/components.py
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
def __init__(
    self,
    *effect_handlers: ModelMask,
    model_quantity: Callable[..., npt.ArrayLike],
    outer_norm: Callable[[npt.ArrayLike], float] | None = None,
    data: npt.ArrayLike = 0.0,
    tolerance: float = 1.0e-6,
) -> None:
    r"""
    Create a new constraint.

    A constraint is the function

    $$ c(\theta) :=
    \mathrm{norm}\left( g(\theta)
    - g_{\mathrm{data}} \right)
    - \epsilon $$

    in which
    - $\mathrm{norm}$ is the outer norm of the constraint (`outer_norm`),
    - $g(\theta)$ is the model quantity involved in the constraint
        (`model_quantity`),
    - $g_{\mathrm{data}}$ is the observed data (`data`),
    - $\epsilon$ is the tolerance in the data (`tolerance`).

    Within a causal problem the constraint enters as the condition
    $c(\theta)\leq 0$ of the minimisation / maximisation, which is why the
    $-\epsilon$ term is part of $c(\theta)$ itself.

    $g$ should be a (possibly vector-valued) function that acts on (a subset of)
    samples from the random variables of the causal problem. It must accept
    variable keyword-arguments only, and should access the samples for each random
    variable by indexing via the RV names (node labels). It should return the
    model quantity, as computed from the samples, that $g_{\mathrm{data}}$ is an
    observation of.

    $g_{\mathrm{data}}$ should be a fixed value whose shape is broadcast-able with
    the return shape of $g$. It defaults to $0$ if not explicitly set.

    $\mathrm{norm}$ should be a suitable norm to take on the difference between the
    model quantity as predicted by the samples ($g$) and the observed data
    ($g_{\mathrm{data}}$). It must return a scalar value. When not given,
    `jnp.linalg.vector_norm` is used (the 2-norm).
    """
    super().__init__(*effect_handlers, do_with_samples=model_quantity)

    # Fall back to the default norm only when the caller did not supply one.
    self._outer_norm = jnp.linalg.vector_norm if outer_norm is None else outer_norm
    self.data = data
    self.tolerance = tolerance

handlers

Container class for specifying effect handlers that need to be applied at runtime.

HandlerToApply dataclass

Specifies a handler that needs to be applied to a model at runtime.

Source code in src/causalprog/causal_problem/handlers.py
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
@dataclass
class HandlerToApply:
    """Specifies a handler that needs to be applied to a model at runtime."""

    # The effect handler itself; validated to be callable in `__post_init__`.
    handler: EffectHandler
    # Options stored alongside the handler; validated to be a `dict`.
    options: dict[str, Any] = field(default_factory=dict)

    @classmethod
    def from_pair(cls, pair: Sequence) -> "HandlerToApply":
        """
        Build an instance from a two-element container.

        The container must hold exactly one callable (taken as the `handler`) and
        one dictionary (taken as the `options`). The order of the two elements
        does not matter: if the first element is callable it is used as the
        handler, otherwise the second element is.

        Args:
            pair: Container of two elements, one being the effect handler callable and
                the other being the options to pass to it (as a dictionary).

        Returns:
            Class instance corresponding to the effect handler and options passed.

        Raises:
            ValueError: If `pair` does not contain exactly two elements.

        """
        if len(pair) != 2:  # noqa: PLR2004
            msg = (
                f"{cls.__name__} can only be constructed from a container of 2 elements"
            )
            raise ValueError(msg)

        # No type checking is needed here beyond `callable`: `__post_init__`
        # rejects a non-callable handler or non-dict options after construction.
        first, second = pair[0], pair[1]
        if callable(first):
            return cls(handler=first, options=second)
        return cls(handler=second, options=first)

    def __post_init__(self) -> None:
        """
        Check the attributes that were just set.

        Raises:
            TypeError: If the handler is not callable, or if the options are not
                a dictionary.

        """
        handler, options = self.handler, self.options
        if not callable(handler):
            msg = f"{type(handler).__name__} is not callable."
            raise TypeError(msg)
        if isinstance(options, dict):
            return
        msg = (
            "Options should be dictionary mapping option arguments to values "
            f"(got {type(options).__name__})."
        )
        raise TypeError(msg)

    def __eq__(self, other: object) -> bool:
        """
        Compare with another object.

        Two `HandlerToApply` instances are equal when they hold the very same
        handler object (identity, not equality) and equal options.

        Any object that is not a `HandlerToApply` compares as `False`.
        """
        if not isinstance(other, HandlerToApply):
            return False
        return self.handler is other.handler and self.options == other.options
__eq__(other)

Equality operation.

HandlerToApplys are considered equal if they use the same handler function and provide the same options to this function.

Comparison to other types returns False.

Source code in src/causalprog/causal_problem/handlers.py
74
75
76
77
78
79
80
81
82
83
84
85
86
87
def __eq__(self, other: object) -> bool:
    """
    Compare with another object.

    Two `HandlerToApply` instances are equal when they hold the very same
    handler object (identity, not equality) and equal options.

    Any object that is not a `HandlerToApply` compares as `False`.
    """
    if not isinstance(other, HandlerToApply):
        return False
    return self.handler is other.handler and self.options == other.options
__post_init__()

Validate set attributes.

  • The handler is a callable object.
  • The options have been passed as a dictionary of keyword-value pairs.
Source code in src/causalprog/causal_problem/handlers.py
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
def __post_init__(self) -> None:
    """
    Check the attributes that were just set.

    Raises:
        TypeError: If the handler is not callable, or if the options are not
            a dictionary.

    """
    handler, options = self.handler, self.options
    if not callable(handler):
        msg = f"{type(handler).__name__} is not callable."
        raise TypeError(msg)
    if isinstance(options, dict):
        return
    msg = (
        "Options should be dictionary mapping option arguments to values "
        f"(got {type(options).__name__})."
    )
    raise TypeError(msg)
from_pair(pair) classmethod

Create an instance from an effect handler and its options.

The two objects should be passed in as the elements of a container of length 2. They can be passed in any order:

  • One element must be a dictionary, which will be interpreted as the options for the effect handler.
  • The other element must be callable, and will be interpreted as the handler itself.

Parameters:

Name Type Description Default
pair Sequence

Container of two elements, one being the effect handler callable and the other being the options to pass to it (as a dictionary).

required

Returns:

Type Description
HandlerToApply

Class instance corresponding to the effect handler and options passed.

Source code in src/causalprog/causal_problem/handlers.py
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
@classmethod
def from_pair(cls, pair: Sequence) -> "HandlerToApply":
    """
    Build an instance from a two-element container.

    The container must hold exactly one callable (taken as the `handler`) and
    one dictionary (taken as the `options`). The order of the two elements
    does not matter: if the first element is callable it is used as the
    handler, otherwise the second element is.

    Args:
        pair: Container of two elements, one being the effect handler callable and
            the other being the options to pass to it (as a dictionary).

    Returns:
        Class instance corresponding to the effect handler and options passed.

    Raises:
        ValueError: If `pair` does not contain exactly two elements.

    """
    if len(pair) != 2:  # noqa: PLR2004
        msg = (
            f"{cls.__name__} can only be constructed from a container of 2 elements"
        )
        raise ValueError(msg)

    # Only `callable` is checked here; the comment in the original notes that
    # `__post_init__` catches wrongly-typed items once the instance is built.
    first, second = pair[0], pair[1]
    if callable(first):
        return cls(handler=first, options=second)
    return cls(handler=second, options=first)

graph

Creation and storage of graphs.

graph

Graph storage.

Graph

Bases: Labelled

A directed acyclic graph that represents a causality tree.

Source code in src/causalprog/graph/graph.py
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
class Graph(Labelled):
    """A directed acyclic graph that represents a causality tree."""

    # Lookup table from node label to the node stored under that label.
    _nodes_by_label: dict[str, Node]

    def __init__(self, *, label: str, graph: nx.DiGraph | None = None) -> None:
        """
        Create a graph.

        Args:
            label: A label to identify the graph
            graph: A networkx graph to base this graph on. If omitted, an empty
                `nx.DiGraph` is created.

        """
        super().__init__(label=label)
        self._nodes_by_label = {}
        if graph is None:
            graph = nx.DiGraph()

        self._graph = graph
        # Index any nodes already present in the supplied graph by their label.
        for node in graph.nodes:
            self._nodes_by_label[node.label] = node

    def get_node(self, label: str) -> Node:
        """
        Get a node from its label.

        Args:
            label: The label

        Returns:
            The node

        Raises:
            KeyError: If no node is stored under `label`.

        """
        node = self._nodes_by_label.get(label)
        # Compare against None explicitly so that a stored node which happens to
        # be falsy is still returned rather than reported as missing.
        if node is None:
            msg = f'Node not found with label "{label}"'
            raise KeyError(msg)
        return node

    def add_node(self, node: Node) -> None:
        """
        Add a node to the graph.

        A `ComponentNode` is additionally connected to its (single) parent by an
        edge from the parent to the new node.

        Args:
            node: The node to add

        Raises:
            ValueError: If a node with the same label is already present, or if
                a `ComponentNode` does not have exactly one parent.

        """
        if node.label in self._nodes_by_label:
            msg = f"Duplicate node label: {node.label}"
            raise ValueError(msg)
        self._nodes_by_label[node.label] = node
        self._graph.add_node(node)
        if isinstance(node, ComponentNode):
            if len(node.parents) != 1:
                msg = "ComponentNode should have exactly one parent."
                raise ValueError(msg)
            self.add_edge(node.parents[0], node.label)

    def add_edge(self, start_node: Node | str, end_node: Node | str) -> None:
        """
        Add a directed edge to the graph.

        Adding an edge between nodes not currently in the graph,
        will cause said nodes to be added to the graph along with
        the edge. Nodes given as strings are looked up by label and must
        therefore already be in the graph.

        Args:
            start_node: The node that the edge points from
            end_node: The node that the edge points to

        Raises:
            KeyError: If a node is given by label and no such node exists.
            ValueError: If a given node shares its label with a node in the
                graph but does not compare equal to it.

        """
        if isinstance(start_node, str):
            start_node = self.get_node(start_node)
        if isinstance(end_node, str):
            end_node = self.get_node(end_node)
        if start_node.label not in self._nodes_by_label:
            self.add_node(start_node)
        if end_node.label not in self._nodes_by_label:
            self.add_node(end_node)
        # Guard against a different node object masquerading under the label of
        # a node that the graph already holds.
        for node_to_check in (start_node, end_node):
            if node_to_check != self._nodes_by_label[node_to_check.label]:
                msg = f"Invalid node: {node_to_check}"
                raise ValueError(msg)
        self._graph.add_edge(start_node, end_node)

    @property
    def root_nodes(self) -> tuple[Node, ...]:
        """
        Returns all root nodes in the graph.

        Root nodes are nodes with no parents.

        The returned tuple uses the `ordered_nodes` property to obtain the root
        nodes so that a natural "fixed order" is given to the roots. When root
        values are given as inputs to the causal estimand and / or constraint functions,
        they will ideally be given as a single vector of root values, in which case
        a fixed ordering for the roots is necessary to make an association to the
        components of the given input vector.

        Returns:
            Root nodes

        """
        return tuple(node for node in self.ordered_nodes if len(node.parents) == 0)

    @property
    def leaf_nodes(self) -> tuple[Node, ...]:
        """
        Returns all leaf nodes in the graph.

        Leaf nodes are nodes with no children, i.e. nodes whose label is not
        listed among the parents of any node in the graph.

        The returned tuple uses the `ordered_nodes` property to obtain the leaf
        nodes so that a natural "fixed order" is given to the leaves.

        Returns:
            Leaf nodes

        """
        # Collect every label that is some node's parent once, so membership
        # tests below are constant-time instead of repeated list removals.
        parent_labels = {parent for node in self.nodes for parent in node.parents}
        return tuple(
            self.get_node(node.label)
            for node in self.ordered_nodes
            if node.label not in parent_labels
        )

    @property
    def predecessors(self) -> dict[Node, tuple[Node, ...]]:
        """
        Get predecessors of every node.

        Returns:
            Mapping of each Node to its predecessor Nodes

        """
        return {node: tuple(self._graph.predecessors(node)) for node in self.nodes}

    @property
    def successors(self) -> dict[Node, tuple[Node, ...]]:
        """
        Get successors of every node.

        Returns:
            Mapping of each Node to its successor Nodes.

        """
        return {node: tuple(self._graph.successors(node)) for node in self.nodes}

    @property
    def nodes(self) -> tuple[Node, ...]:
        """
        Get the nodes of the graph, with no enforced ordering.

        Returns:
            A tuple of all the nodes in the graph.

        See Also:
            ordered_nodes: Fetch an ordered tuple of the nodes in the graph.

        """
        return tuple(self._graph.nodes())

    @property
    def edges(self) -> tuple[tuple[Node, Node], ...]:
        """
        Get the edges of the graph.

        Returns:
            A tuple of all the edges in the graph.

        """
        return tuple(self._graph.edges())

    @property
    def ordered_nodes(self) -> tuple[Node, ...]:
        """
        Nodes ordered so that each node appears after its dependencies.

        Returns:
            A tuple of all the nodes, ordered such that each node
                appears after all its dependencies.

        Raises:
            RuntimeError: If the underlying graph is not acyclic.

        """
        if not nx.is_directed_acyclic_graph(self._graph):
            msg = "Graph is not acyclic."
            raise RuntimeError(msg)
        return tuple(nx.topological_sort(self._graph))

    def roots_down_to_outcome(
        self,
        outcome_node_label: str,
    ) -> tuple[Node, ...]:
        """
        Get ordered tuple of nodes that outcome depends on.

        Nodes are ordered so that each node appears after its dependencies.
        The outcome node itself is included.

        Args:
            outcome_node_label: The label of the outcome node

        Returns:
            A tuple of the nodes, ordered from root nodes to the outcome Node.

        """
        outcome = self.get_node(outcome_node_label)
        ancestors = nx.ancestors(self._graph, outcome)
        return tuple(
            node for node in self.ordered_nodes if node == outcome or node in ancestors
        )

    def model(self, **parameter_values: npt.ArrayLike) -> dict[str, npt.ArrayLike]:
        """
        Model corresponding to the `Graph`'s structure.

        The model created takes values of the nodes that are parameters as keyword
        arguments. Names of the keyword arguments should match the labels of the
        `DataNode`s, and their values should be the values of those parameters.

        The method returns a dictionary recording the model sites that are created.
        This means that the model can be 'extended' further by defining additional
        sites in a wrapper around this method.

        Args:
            parameter_values: Names of the keyword arguments should match the labels
                of the `DataNode`s, and their values should be the values of those
                parameters.

        Returns:
            Mapping of non-`DataNode` `Node` labels to the site objects created
                for these nodes.

        Raises:
            KeyError: If a root node's label is missing from `parameter_values`.

        """
        # Confirm that all `DataNode`s have been assigned a value.
        for node in self.root_nodes:
            if node.label not in parameter_values:
                msg = f"DataNode '{node.label}' not assigned"
                raise KeyError(msg)

        # Build model sequentially, using the node_order to inform the
        # construction process.
        node_record: dict[str, npt.ArrayLike] = {}
        for node in self.ordered_nodes:
            if isinstance(node, DistributionNode):
                node_record[node.label] = node.create_model_site(
                    **parameter_values,  # All nodes require knowledge of the parameters
                    **node_record,  # and any dependent nodes we have already visited
                )

        return node_record
edges property

Get the edges of the graph.

Returns:

Type Description
tuple[tuple[Node, Node], ...]

A tuple of all the edges in the graph.

leaf_nodes property

Returns all leaf nodes in the graph.

Leaf nodes are nodes with no children.

The returned tuple uses the ordered_nodes property to obtain the leaf nodes so that a natural "fixed order" is given to the leaves.

Returns:

Type Description
tuple[Node, ...]

Leaf nodes

nodes property

Get the nodes of the graph, with no enforced ordering.

Returns:

Type Description
tuple[Node, ...]

A list of all the nodes in the graph.

See Also

ordered_nodes: Fetch an ordered list of the nodes in the graph.

ordered_nodes property

Nodes ordered so that each node appears after its dependencies.

Returns:

Type Description
tuple[Node, ...]

A list of all the nodes, ordered such that each node appears after all its dependencies.

predecessors property

Get predecessors of every node.

Returns:

Type Description
dict[Node, tuple[Node, ...]]

Mapping of each Node to its predecessor Nodes

root_nodes property

Returns all root nodes in the graph.

Root nodes are nodes with no parents.

The returned tuple uses the ordered_nodes property to obtain the root nodes so that a natural "fixed order" is given to the roots. When root values are given as inputs to the causal estimand and / or constraint functions, they will ideally be given as a single vector of root values, in which case a fixed ordering for the roots is necessary to make an association to the components of the given input vector.

Returns:

Type Description
tuple[Node, ...]

Root nodes

successors property

Get successors of every node.

Returns:

Type Description
dict[Node, tuple[Node, ...]]

Mapping of each Node to its successor Nodes.

__init__(*, label, graph=None)

Create a graph.

Parameters:

Name Type Description Default
label str

A label to identify the graph

required
graph DiGraph | None

A networkx graph to base this graph on

None
Source code in src/causalprog/graph/graph.py
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
def __init__(self, *, label: str, graph: nx.DiGraph | None = None) -> None:
    """
    Set up the graph, optionally wrapping an existing networkx graph.

    Args:
        label: A label to identify the graph
        graph: A networkx graph to base this graph on. When omitted, a new
            empty `nx.DiGraph` is created.

    """
    super().__init__(label=label)
    self._graph = nx.DiGraph() if graph is None else graph
    # Index every node already present in the underlying graph by its label.
    self._nodes_by_label = {node.label: node for node in self._graph.nodes}
add_edge(start_node, end_node)

Add a directed edge to the graph.

Adding an edge between nodes not currently in the graph, will cause said nodes to be added to the graph along with the edge.

Parameters:

Name Type Description Default
start_node Node | str

The node that the edge points from

required
end_node Node | str

The node that the edge points to

required
Source code in src/causalprog/graph/graph.py
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
def add_edge(self, start_node: Node | str, end_node: Node | str) -> None:
    """
    Add a directed edge to the graph.

    Adding an edge between nodes not currently in the graph,
    will cause said nodes to be added to the graph along with
    the edge. Nodes given as strings are looked up by label and must
    therefore already be in the graph.

    Args:
        start_node: The node that the edge points from
        end_node: The node that the edge points to

    Raises:
        ValueError: If a given node shares its label with a node in the
            graph but does not compare equal to it.

    """
    if isinstance(start_node, str):
        start_node = self.get_node(start_node)
    if isinstance(end_node, str):
        end_node = self.get_node(end_node)
    if start_node.label not in self._nodes_by_label:
        self.add_node(start_node)
    if end_node.label not in self._nodes_by_label:
        self.add_node(end_node)
    # Guard against a different node object masquerading under the label of
    # a node that the graph already holds.
    for node_to_check in (start_node, end_node):
        if node_to_check != self._nodes_by_label[node_to_check.label]:
            # f-string so the offending node actually appears in the message.
            msg = f"Invalid node: {node_to_check}"
            raise ValueError(msg)
    self._graph.add_edge(start_node, end_node)
add_node(node)

Add a node to the graph.

Parameters:

Name Type Description Default
node Node

The node to add

required
Source code in src/causalprog/graph/graph.py
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
def add_node(self, node: Node) -> None:
    """
    Register a node with the graph.

    The node is stored under its label and added to the underlying graph. If
    the node is a `ComponentNode`, an edge from its single parent to the node
    is then added as well.

    Args:
        node: The node to add

    Raises:
        ValueError: If the label is already in use, or if a `ComponentNode`
            does not have exactly one parent.

    """
    registry = self._nodes_by_label
    if node.label in registry:
        msg = f"Duplicate node label: {node.label}"
        raise ValueError(msg)

    registry[node.label] = node
    self._graph.add_node(node)

    # Only component nodes need wiring up to a parent at insertion time.
    if not isinstance(node, ComponentNode):
        return
    parent_labels = node.parents
    if len(parent_labels) != 1:
        msg = "ComponentNode should have exactly one parent."
        raise ValueError(msg)
    self.add_edge(parent_labels[0], node.label)
get_node(label)

Get a node from its label.

Parameters:

Name Type Description Default
label str

The label

required

Returns:

Type Description
Node

The node

Source code in src/causalprog/graph/graph.py
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
def get_node(self, label: str) -> Node:
    """
    Get a node from its label.

    Args:
        label: The label

    Returns:
        The node

    Raises:
        KeyError: If no node is stored under `label`.

    """
    node = self._nodes_by_label.get(label)
    # Compare against None explicitly so that a stored node which happens to
    # be falsy is still returned rather than reported as missing.
    if node is None:
        msg = f'Node not found with label "{label}"'
        raise KeyError(msg)
    return node
model(**parameter_values)

Model corresponding to the Graph's structure.

The model created takes the values of the nodes that are parameters as keyword arguments. Names of the keyword arguments should match the labels of the DataNodes, and their values should be the values of those parameters.

The method returns a dictionary recording the model sites that are created. This means that the model can be 'extended' further by defining additional sites in a wrapper around this method.

Parameters:

Name Type Description Default
parameter_values ArrayLike

Names of the keyword arguments should match the labels of the DataNodes, and their values should be the values of those parameters.

{}

Returns:

Type Description
dict[str, ArrayLike]

Mapping of non-DataNode Node labels to the site objects created for these nodes.

Source code in src/causalprog/graph/graph.py
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
def model(self, **parameter_values: npt.ArrayLike) -> dict[str, npt.ArrayLike]:
    """
    Build the model that mirrors the `Graph`'s structure.

    Parameter values are supplied as keyword arguments whose names match the
    labels of the `DataNode`s. Every root node of the graph must have a value
    supplied in this way.

    The nodes are then visited in dependency order and a model site is created
    for each `DistributionNode`. The sites created are returned, which allows
    a wrapper around this method to extend the model with further sites.

    Args:
        parameter_values: Names of the keyword arguments should match the labels
            of the `DataNode`s, and their values should be the values of those
            parameters.

    Returns:
        Mapping of non-`DataNode` `Node` labels to the site objects created
            for these nodes.

    Raises:
        KeyError: If a root node's label is missing from `parameter_values`.

    """
    # Every root node must have been handed a value before we start building.
    for root in self.root_nodes:
        if root.label in parameter_values:
            continue
        msg = f"DataNode '{root.label}' not assigned"
        raise KeyError(msg)

    # Walk the nodes in dependency order, so each site can be given the sites
    # of the nodes that were visited before it.
    sites: dict[str, npt.ArrayLike] = {}
    for node in self.ordered_nodes:
        if not isinstance(node, DistributionNode):
            continue
        # Each site receives all parameter values plus all sites built so far.
        sites[node.label] = node.create_model_site(**parameter_values, **sites)

    return sites
roots_down_to_outcome(outcome_node_label)

Get ordered list of nodes that outcome depends on.

Nodes are ordered so that each node appears after its dependencies.

Parameters:

Name Type Description Default
outcome_node_label str

The label of the outcome node

required

Returns:

Type Description
tuple[Node, ...]

A list of the nodes, ordered from root nodes to the outcome Node.

Source code in src/causalprog/graph/graph.py
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
def roots_down_to_outcome(
    self,
    outcome_node_label: str,
) -> tuple[Node, ...]:
    """
    Collect, in dependency order, the nodes that the outcome depends on.

    The result contains the outcome node itself together with all of its
    ancestors in the graph, ordered so that each node appears after its
    dependencies.

    Args:
        outcome_node_label: The label of the outcome node

    Returns:
        A tuple of the nodes, ordered from root nodes to the outcome Node.

    """
    outcome = self.get_node(outcome_node_label)
    upstream = nx.ancestors(self._graph, outcome)

    selected = []
    for candidate in self.ordered_nodes:
        if candidate == outcome or candidate in upstream:
            selected.append(candidate)
    return tuple(selected)

node

Graph nodes.

base

Base graph node.

Node

Bases: Labelled

An abstract node in a graph.

Source code in src/causalprog/graph/node/base.py
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
class Node(Labelled):
    """An abstract node in a graph."""

    def __init__(
        self,
        *,
        label: str,
        shape: tuple[int, ...] = (),
    ) -> None:
        """
        Initialise.

        Parameters (equivalently `ParameterNode`s) represent Nodes that do not have
        random variables attached. Instead, these nodes represent values that are passed
        to nodes that _do_ have distributions attached, and the value of the "parameter"
        node is used as a fixed value when constructing the dependent node's
        distribution. The set of parameter nodes is the collection of "parameter"s over
        which one should want to optimise the causal estimand (subject to any
        constraints), and as such the value that a "parameter node" passes to its
        dependent nodes will vary as the optimiser runs and explores the solution space.

        Distributions (equivalently `DistributionNode`s) are Nodes that represent
        random variables described by probability distributions.

        Args:
            label: A unique label to identify the node
            shape: The shape of the node's value for each sample

        """
        super().__init__(label=label)
        self._shape = shape

    def __getitem__(self, indices: int | slice | tuple[int | slice, ...]) -> Node:
        """
        Get a component of this node.

        Args:
            indices: A single integer or slice, or a tuple of them, with at most
                one entry per axis of this node's shape.

        Returns:
            A `ComponentNode` whose parent is this node. Its shape keeps the
            length of every sliced axis, drops every integer-indexed axis, and
            keeps all trailing axes that were not indexed.

        Raises:
            TypeError: If `indices` is not an integer, a slice or a tuple.
            IndexError: If more indices are given than the node has axes, or if
                an integer index lies outside `[-n, n)` for an axis of length
                `n`.

        """
        if isinstance(indices, int | slice):
            indices = (indices,)
        if not isinstance(indices, tuple):
            e = f"Invalid index: {indices}"
            raise TypeError(e)
        if len(indices) > len(self._shape):
            e = "list index out of range"
            raise IndexError(e)
        for i, j in zip(indices, self._shape, strict=False):
            # A negative index below -j is just as out of range as one >= j.
            if isinstance(i, int) and not -j <= i < j:
                e = "list index out of range"
                raise IndexError(e)

        # Imported here rather than at module level, as in the original.
        from causalprog.graph import ComponentNode

        shape: tuple[int, ...] = ()
        for i, s in zip(indices, self._shape, strict=False):
            # Integer indices remove their axis; slices keep it with the
            # number of elements the slice selects.
            if isinstance(i, slice):
                shape += (len(range(*i.indices(s))),)
        shape += self._shape[len(indices) :]

        return ComponentNode(
            self.label,
            indices,
            shape=shape,
            label=f"{self.label}_{_to_string(indices)}",
        )

    @abstractmethod
    def sample(
        self,
        parameter_values: dict[str, float],
        sampled_dependencies: dict[str, npt.NDArray[float]],
        samples: int,
        *,
        rng_key: jax.Array,
    ) -> float:
        """
        Sample a value from the node.

        Args:
            parameter_values: Values to be taken by parameters
            sampled_dependencies: Values taken by dependencies of this node
            samples: Number of samples
            rng_key: Random key

        Returns:
            Sample value of this node

        """

    @abstractmethod
    def evaluate(
        self,
        **given_values: float | npt.NDArray[float],
    ) -> float | npt.NDArray[float]:
        """
        Evaluate the node, given evaluations of its precursor nodes.

        Args:
            given_values: Values for data nodes and values of parents

        Returns:
            Value of this node given `given_values`.

        """

    @abstractmethod
    def copy(self) -> Node:
        """
        Make a copy of a node.

        Some inner objects stored inside the node may not be copied when this is called.
        Modifying some inner objects of a copy made using this may affect the original
        node.

        Returns:
            A copy of the node

        """

    @property
    def shape(self) -> tuple[int, ...]:
        """
        The shape of the node's value for each sample.

        Returns:
            The shape

        """
        return self._shape

    @property
    @abstractmethod
    def parents(self) -> list[str]:
        """
        Nodes that this node depends on the value of.

        Returns:
            List of labels of parent nodes

        """
parents abstractmethod property

Nodes that this node depends on the value of.

Returns:

Type Description
list[str]

List of labels of parent nodes

shape property

The shape of the node's value for each sample.

Returns:

Type Description
tuple[int, ...]

The shape

__getitem__(indices)

Get a component of this node.

Source code in src/causalprog/graph/node/base.py
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
def __getitem__(self, indices: int | slice | tuple[int | slice, ...]) -> Node:
    """Get a component of this node."""
    if isinstance(indices, int | slice):
        indices = (indices,)
    if not isinstance(indices, tuple):
        e = f"Invalid index: {indices}"
        raise TypeError(e)
    if len(indices) > len(self._shape):
        e = "list index out of range"
        raise IndexError(e)
    for i, j in zip(indices, self._shape, strict=False):
        if isinstance(i, int) and i >= j:
            e = "list index out of range"
            raise IndexError(e)

    from causalprog.graph import ComponentNode

    shape: tuple[int, ...] = ()
    for i, s in zip(indices, self._shape, strict=False):
        if isinstance(i, slice):
            shape += (len(range(*i.indices(s))),)
    shape += self._shape[len(indices) :]

    return ComponentNode(
        self.label,
        indices,
        shape=shape,
        label=f"{self.label}_{_to_string(indices)}",
    )
__init__(*, label, shape=())

Initialise.

Parameters (equivalently ParameterNodes) represent Nodes that do not have random variables attached. Instead, these nodes represent values that are passed to nodes that do have distributions attached, and the value of the "parameter" node is used as a fixed value when constructing the dependent node's distribution. The set of parameter nodes is the collection of "parameter"s over which one should want to optimise the causal estimand (subject to any constraints), and as such the value that a "parameter node" passes to its dependent nodes will vary as the optimiser runs and explores the solution space.

Distributions (equivalently DistributionNodes) are Nodes that represent random variables described by probability distributions.

Parameters:

Name Type Description Default
label str

A unique label to identify the node

required
shape tuple[int, ...]

The shape of the node's value for each sample

()
Source code in src/causalprog/graph/node/base.py
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
def __init__(
    self,
    *,
    label: str,
    shape: tuple[int, ...] = (),
) -> None:
    """
    Initialise.

    Nodes come in two broad flavours.

    A parameter (equivalently, a `ParameterNode`) is a Node with no random
    variable attached. It stands for a value that is handed to nodes which _do_
    carry a distribution, where it is used as a fixed value when the dependent
    node's distribution is constructed. Taken together, the parameter nodes are
    the "parameter"s over which the causal estimand is to be optimised (subject
    to any constraints), so the value a "parameter node" passes on changes as the
    optimiser runs and explores the solution space.

    A distribution (equivalently, a `DistributionNode`) is a Node that represents
    a random variable described by a probability distribution.

    Args:
        label: A unique label to identify the node
        shape: The shape of the node's value for each sample

    """
    super().__init__(label=label)
    self._shape = shape
copy() abstractmethod

Make a copy of a node.

Some inner objects stored inside the node may not be copied when this is called. Modifying some inner objects of a copy made using this may affect the original node.

Returns:

Type Description
Node

A copy of the node

Source code in src/causalprog/graph/node/base.py
136
137
138
139
140
141
142
143
144
145
146
147
148
@abstractmethod
def copy(self) -> Node:
    """
    Make a copy of a node.

    The copy is not guaranteed to be deep: some inner objects stored inside the
    node may be shared between the original and the copy rather than duplicated.
    Modifying such an object through the copy may therefore also affect the
    original node.

    Returns:
        A copy of the node

    """
evaluate(**given_values) abstractmethod

Evaluate the node, given evaluations of its precursor nodes.

Parameters:

Name Type Description Default
given_values float | NDArray[float]

Values for data nodes and values of parents

{}

Returns:

Type Description
float | NDArray[float]

Value of this node given given_values.

Source code in src/causalprog/graph/node/base.py
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
@abstractmethod
def evaluate(
    self,
    **given_values: float | npt.NDArray[float],
) -> float | npt.NDArray[float]:
    """
    Evaluate the node, given evaluations of its precursor nodes.

    Args:
        given_values: Values for data nodes and values of parents, passed as
            keyword arguments (presumably keyed by node label -- confirm against
            the concrete implementations).

    Returns:
        Value of this node given `given_values`.

    """
sample(parameter_values, sampled_dependencies, samples, *, rng_key) abstractmethod

Sample a value from the node.

Parameters:

Name Type Description Default
parameter_values dict[str, float]

Values to be taken by parameters

required
sampled_dependencies dict[str, NDArray[float]]

Values taken by dependencies of this node

required
samples int

Number of samples

required
rng_key Array

Random key

required

Returns:

Type Description
float

Sample value of this node

Source code in src/causalprog/graph/node/base.py
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
@abstractmethod
def sample(
    self,
    parameter_values: dict[str, float],
    sampled_dependencies: dict[str, npt.NDArray[float]],
    samples: int,
    *,
    rng_key: jax.Array,
) -> float:
    """
    Sample a value from the node.

    Args:
        parameter_values: Values to be taken by parameters
        sampled_dependencies: Values taken by dependencies of this node
        samples: Number of samples
        rng_key: Random key

    Returns:
        Sample value of this node

    NOTE(review): the return annotation is `float`, while `samples` suggests
    one value per requested sample -- confirm the intended return type.

    """

component

Graph nodes representing components of other nodes.

ComponentNode

Bases: Node

A node representing a component of another node.

Source code in src/causalprog/graph/node/component.py
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
class ComponentNode(Node):
    """A node representing a component of another node."""

    def __init__(
        self,
        parent_node_label: str,
        component: int | tuple[int, ...],
        *,
        shape: tuple[int, ...] = (),
        label: str,
    ) -> None:
        """
        Initialise.

        Args:
            parent_node_label: The node to take a component of
            component: The index/indices of the component
            label: A unique label to identify the node

        """
        self._component = (
            (component,) if isinstance(component, int) else tuple(component)
        )
        self._parent_node_label = parent_node_label
        super().__init__(shape=shape, label=label)

    @override
    def sample(
        self,
        parameter_values: dict[str, float],
        sampled_dependencies: dict[str, npt.NDArray[float]],
        samples: int,
        *,
        rng_key: jax.Array,
    ) -> npt.NDArray[float]:
        return sampled_dependencies[self._parent_node_label][:, *self._component]

    @override
    def evaluate(
        self,
        **given_values: float | npt.NDArray[float],
    ) -> float | npt.NDArray[float]:
        parent_value = given_values[self._parent_node_label]
        return parent_value[*self._component]  # type: ignore[index]

    @override
    def copy(self) -> Node:
        return ComponentNode(
            self._parent_node_label,
            self._component,
            label=self.label,
            shape=self.shape,
        )

    @override
    def __repr__(self) -> str:
        r = f'ComponentNode("{self._parent_node_label}", component={self._component}'
        if len(self.shape) > 0:
            r += f", shape={self.shape}"
        if len(self._parameters) > 0:
            r += f", parameters={self._parameters}"
        if len(self._constant_parameters) > 0:
            r += f", constant_parameters={self._constant_parameters}"
        return r

    @override
    @property
    def parents(self) -> list[str]:
        return [self._parent_node_label]
__init__(parent_node_label, component, *, shape=(), label)

Initialise.

Parameters:

Name Type Description Default
parent_node_label str

The node to take a component of

required
component int | tuple[int, ...]

The index/indices of the component

required
label str

A unique label to identify the node

required
Source code in src/causalprog/graph/node/component.py
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
def __init__(
    self,
    parent_node_label: str,
    component: int | tuple[int, ...],
    *,
    shape: tuple[int, ...] = (),
    label: str,
) -> None:
    """
    Initialise.

    Args:
        parent_node_label: The label of the node whose component is taken
        component: The index/indices of the component; a single integer is
            stored as a one-element tuple
        shape: The shape of this component's value for each sample
        label: A unique label to identify the node

    """
    if isinstance(component, int):
        self._component = (component,)
    else:
        self._component = tuple(component)
    self._parent_node_label = parent_node_label
    super().__init__(shape=shape, label=label)

constant

Graph nodes representing constants.

ConstantNode

Bases: Node

A node representing a constant.

Source code in src/causalprog/graph/node/constant.py
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
class ConstantNode(Node):
    """A node representing a constant."""

    def __init__(self, *, label: str, value: float | npt.NDArray[float]) -> None:
        """
        Initialise.

        Args:
            label: A unique label to identify the node
            value: The value of this constant. Values without a `shape`
                attribute (e.g. Python floats and ints) are treated as scalars.

        """
        self._value = value
        # Plain Python numbers (including ints) have no `shape`; treat them as
        # scalars instead of failing on the attribute lookup.
        super().__init__(
            shape=value.shape if hasattr(value, "shape") else (), label=label
        )

    @override
    def sample(
        self,
        parameter_values: dict[str, float],
        sampled_dependencies: dict[str, npt.NDArray[float]],
        samples: int,
        *,
        rng_key: jax.Array,
    ) -> npt.NDArray[float]:
        # One copy of the constant per sample. For a scalar constant this is the
        # shape `(samples,)`; for an array constant the value's own shape follows
        # the sample axis.
        return jnp.full((samples, *self.shape), self._value)

    @override
    def evaluate(
        self,
        **given_values: float | npt.NDArray[float],
    ) -> float | npt.NDArray[float]:
        # A constant does not depend on any given values
        return self._value

    @override
    def copy(self) -> Node:
        # The stored value is shared with the copy, not duplicated
        return ConstantNode(label=self.label, value=self._value)

    @override
    def __repr__(self) -> str:
        return f"ConstantNode({self._value})"

    @override
    @property
    def parents(self) -> list[str]:
        return []
__init__(*, label, value)

Initialise.

Parameters:

Name Type Description Default
label str

A unique label to identify the node

required
value float | NDArray[float]

The value of this constant

required
Source code in src/causalprog/graph/node/constant.py
20
21
22
23
24
25
26
27
28
29
30
31
32
def __init__(self, *, label: str, value: float | npt.NDArray[float]) -> None:
    """
    Initialise.

    Args:
        label: A unique label to identify the node
        value: The value of this constant. Values without a `shape` attribute
            (e.g. Python floats and ints) are treated as scalars.

    """
    self._value = value
    # Plain Python numbers (including ints) have no `shape`; treat them as
    # scalars instead of failing on the attribute lookup.
    super().__init__(
        shape=value.shape if hasattr(value, "shape") else (), label=label
    )

data

Graph nodes representing known or unknown data.

DataNode

Bases: Node

A node containing non-stochastic data.

DataNodes should not be used to encode constant values used by DistributionNodes. Such constant values should be given to the necessary DistributionNodes directly as constant_parameters.

Source code in src/causalprog/graph/node/data.py
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
class DataNode(Node):
    """
    A node containing non-stochastic data.

    `DataNode`s should not be used to encode constant values used by
    `DistributionNode`s. Such constant values should be given to the necessary
    `DistributionNode`s directly as `constant_parameters`.
    """

    def __init__(self, *, shape: tuple[int, ...] = (), label: str) -> None:
        """
        Initialise.

        Args:
            shape: The shape that values given for this node must have
            label: A unique label to identify the node

        """
        super().__init__(label=label, shape=shape)

    @override
    def sample(
        self,
        parameter_values: dict[str, float],
        sampled_dependencies: dict[str, npt.ArrayLike],
        samples: int,
        *,
        rng_key: jax.Array,
    ) -> npt.ArrayLike:
        # The node's value is looked up under its own label and repeated once
        # per sample
        if self.label not in parameter_values:
            msg = f"Missing input for node: {self.label}."
            raise ValueError(msg)
        return jnp.full(samples, parameter_values[self.label])

    @override
    def evaluate(
        self,
        **given_values: float | npt.NDArray[float],
    ) -> float | npt.NDArray[float]:
        if self.label not in given_values:
            msg = f"Missing input for node: {self.label}."
            raise ValueError(msg)
        value = given_values[self.label]
        # Values without a `shape` attribute are treated as scalars
        if self.shape != (value.shape if hasattr(value, "shape") else ()):
            msg = f"Invalid value for node: {self.label}"
            raise ValueError(msg)
        return value

    @override
    def copy(self) -> Node:
        # Carry the shape over; otherwise the copy would default to a scalar
        # shape and reject values that are valid for the original node.
        return DataNode(label=self.label, shape=self.shape)

    @override
    def __repr__(self) -> str:
        return f'DataNode(label="{self.label}")'

    @override
    @property
    def parents(self) -> list[str]:
        return []
__init__(*, shape=(), label)

Initialise.

Parameters:

Name Type Description Default
label str

A unique label to identify the node

required
Source code in src/causalprog/graph/node/data.py
20
21
22
23
24
25
26
27
28
def __init__(self, *, shape: tuple[int, ...] = (), label: str) -> None:
    """
    Initialise.

    Args:
        shape: The shape of the node's value
        label: A unique label to identify the node

    """
    super().__init__(shape=shape, label=label)

distribution

Graph nodes representing distributions.

DistributionNode

Bases: Node

A node containing a distribution.

Source code in src/causalprog/graph/node/distribution.py
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
class DistributionNode(Node):
    """A node containing a distribution."""

    def __init__(
        self,
        distribution: type,
        *,
        label: str,
        shape: tuple[int, ...] = (),
        parameters: dict[str, str] | None = None,
        constant_parameters: dict[str, float] | None = None,
    ) -> None:
        """
        Initialise.

        Args:
            distribution: The distribution
            label: A unique label to identify the node
            shape: The shape of the value for each sample
            parameters: A dictionary of parameters, mapping each argument name
                of `distribution` to the label of the node that supplies it
            constant_parameters: A dictionary of constant parameters, mapping
                argument names of `distribution` to fixed values

        """
        self._dist = distribution
        self._constant_parameters = constant_parameters or {}
        self._parameters = parameters or {}
        super().__init__(label=label, shape=shape)

    @override
    def sample(
        self,
        parameter_values: dict[str, float],
        sampled_dependencies: dict[str, npt.NDArray[float]],
        samples: int,
        *,
        rng_key: jax.Array,
    ) -> npt.NDArray[float]:
        d = self._dist(
            # Pass in node values derived from construction so far
            **{
                native_name: sampled_dependencies[node_name]
                for native_name, node_name in self._parameters.items()
            },
            # Pass in any constant parameters this node sets
            **self._constant_parameters,
        )

        # Only request a leading sample axis when the distribution carries no
        # batch shape of its own and more than one sample is wanted.
        return numpyro.sample(
            self.label,
            d,
            rng_key=rng_key,
            sample_shape=(samples, *self.shape)
            if d.batch_shape == () and samples > 1
            else self.shape,
        )

    @override
    def evaluate(
        self,
        **given_values: float | npt.NDArray[float],
    ) -> float | npt.NDArray[float]:
        msg = "Cannot evaluate a DistributionNode"
        raise RuntimeError(msg)

    @override
    def copy(self) -> Node:
        # The parameter dictionaries are copied; the distribution is shared
        return DistributionNode(
            self._dist,
            label=self.label,
            shape=self.shape,
            parameters=dict(self._parameters),
            constant_parameters=dict(self._constant_parameters.items()),
        )

    @override
    def __repr__(self) -> str:
        r = f'DistributionNode({self._dist.__name__}, label="{self.label}"'
        if len(self._parameters) > 0:
            r += f", parameters={self._parameters}"
        if len(self.shape) > 0:
            r += f", shape={self.shape}"
        if len(self._constant_parameters) > 0:
            r += f", constant_parameters={self._constant_parameters}"
        # Close the parenthesis opened at the start of the representation
        return r + ")"

    @override
    @property
    def parents(self) -> list[str]:
        # NOTE(review): `sample` and `create_model_site` read `_parameters` as
        # {distribution argument name: node label}, yet the keys (argument
        # names) and the constant-parameter names are returned here -- confirm
        # this is intended.
        return [*self._parameters.keys(), *self._constant_parameters.keys()]

    def create_model_site(self, **dependent_nodes: jax.Array) -> npt.ArrayLike:
        """
        Create a model site for the (conditional) distribution attached to this node.

        `dependent_nodes` should contain keyword arguments mapping dependent node names
        to the values that those nodes are taking (`ParameterNode`s), or the sampling
        object for those nodes (`DistributionNode`s). These are passed to
        `self._dist` as keyword arguments to construct the sample-able object
        representing this node.
        """
        return numpyro.sample(
            self.label,
            self._dist(
                # Pass in node values derived from construction so far
                **{
                    native_name: dependent_nodes[node_name]
                    for native_name, node_name in self._parameters.items()
                },
                # Pass in any constant parameters this node sets
                **self._constant_parameters,
            ),
        )
__init__(distribution, *, label, shape=(), parameters=None, constant_parameters=None)

Initialise.

Parameters:

Name Type Description Default
distribution type

The distribution

required
label str

A unique label to identify the node

required
shape tuple[int, ...]

The shape of the value for each sample

()
parameters dict[str, str] | None

A dictionary of parameters

None
constant_parameters dict[str, float] | None

A dictionary of constant parameters

None
Source code in src/causalprog/graph/node/distribution.py
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
def __init__(
    self,
    distribution: type,
    *,
    label: str,
    shape: tuple[int, ...] = (),
    parameters: dict[str, str] | None = None,
    constant_parameters: dict[str, float] | None = None,
) -> None:
    """
    Initialise.

    Args:
        distribution: The distribution
        label: A unique label to identify the node
        shape: The shape of the value for each sample
        parameters: A dictionary of parameters; a missing or empty dictionary
            is replaced by a fresh empty one
        constant_parameters: A dictionary of constant parameters; a missing or
            empty dictionary is replaced by a fresh empty one

    """
    self._dist = distribution
    self._parameters = parameters if parameters else {}
    self._constant_parameters = constant_parameters if constant_parameters else {}
    super().__init__(label=label, shape=shape)
create_model_site(**dependent_nodes)

Create a model site for the (conditional) distribution attached to this node.

dependent_nodes should contain keyword arguments mapping dependent node names to the values that those nodes are taking (ParameterNodes), or the sampling object for those nodes (DistributionNodes). These are passed to self._dist as keyword arguments to construct the sample-able object representing this node.

Source code in src/causalprog/graph/node/distribution.py
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
def create_model_site(self, **dependent_nodes: jax.Array) -> npt.ArrayLike:
    """
    Create a model site for the (conditional) distribution attached to this node.

    Each keyword in `dependent_nodes` names a dependent node and carries either
    the value that node is taking (`ParameterNode`s) or the sampling object for
    that node (`DistributionNode`s). They are forwarded to `self._dist`, together
    with this node's constant parameters, as keyword arguments to build the
    sample-able object representing this node.
    """
    site_name = self.label

    # Node values derived from construction so far, keyed by the argument name
    # that the distribution expects
    node_arguments = {}
    for native_name, node_name in self._parameters.items():
        node_arguments[native_name] = dependent_nodes[node_name]

    # Constant parameters set by this node are passed alongside the node values
    distribution = self._dist(**node_arguments, **self._constant_parameters)
    return numpyro.sample(site_name, distribution)

random_variables

Graph nodes representing random variables.

ContinuousRandomVariableNode

Bases: RandomVariableNode

A node containing a continuous random variable (RV).

Source code in src/causalprog/graph/node/random_variables.py
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
class ContinuousRandomVariableNode(RandomVariableNode):
    """A random variable (RV) node whose RV is continuous."""

    @override
    def __repr__(self) -> str:
        return f'ContinuousRandomVariableNode(label="{self.label}")'

    @override
    def is_valid_value(self, value: float | npt.NDArray[float]) -> bool:
        # Every value is accepted
        return True

    @override
    def copy(self) -> Node:
        duplicate = ContinuousRandomVariableNode(
            label=self.label,
            shape=self.shape,
            compute=self._compute,
        )
        return duplicate
DiscreteRandomVariableNode

Bases: RandomVariableNode

A node containing a discrete random variable (RV).

Source code in src/causalprog/graph/node/random_variables.py
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
class DiscreteRandomVariableNode(RandomVariableNode):
    """A random variable (RV) node whose RV is discrete."""

    def __init__(
        self,
        *,
        values: list[float] | list[npt.NDArray[float]],
        shape: tuple[int, ...] = (),
        label: str,
        compute: typing.Callable | None = None,
    ) -> None:
        """
        Initialise.

        Args:
            values: The values that the RV can take
            shape: The shape of the output of the RV
            label: A unique label to identify the node
            compute: A function to compute node's value from given values of parents

        """
        super().__init__(shape=shape, label=label, compute=compute)
        self._values = values

    @property
    def possible_values(self) -> list[float] | list[npt.NDArray[float]]:
        """The values that this RV can take."""
        return self._values

    @override
    def __repr__(self) -> str:
        return f'DiscreteRandomVariableNode(label="{self.label}")'

    @override
    def is_valid_value(self, value: float | npt.NDArray[float]) -> bool:
        # Valid as soon as the value is numerically close to one possible value
        for candidate in self._values:
            if np.allclose(candidate, value):
                return True
        return False

    @override
    def copy(self) -> Node:
        # The list of possible values is shared with the copy, not duplicated
        duplicate = DiscreteRandomVariableNode(
            label=self.label,
            shape=self.shape,
            values=self._values,
            compute=self._compute,
        )
        return duplicate
possible_values property

The values that this RV can take.

__init__(*, values, shape=(), label, compute=None)

Initialise.

Parameters:

Name Type Description Default
shape tuple[int, ...]

The shape of the output of the RV

()
label str

A unique label to identify the node

required
compute Callable | None

A function to compute node's value from given values of parents

None
Source code in src/causalprog/graph/node/random_variables.py
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
def __init__(
    self,
    *,
    values: list[float] | list[npt.NDArray[float]],
    shape: tuple[int, ...] = (),
    label: str,
    compute: typing.Callable | None = None,
) -> None:
    """
    Initialise.

    Args:
        values: The values that the RV can take
        shape: The shape of the output of the RV
        label: A unique label to identify the node
        compute: A function to compute node's value from given values of parents

    """
    super().__init__(shape=shape, label=label, compute=compute)
    self._values = values
RandomVariableNode

Bases: Node

A node containing a random variable (RV).

Source code in src/causalprog/graph/node/random_variables.py
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
class RandomVariableNode(Node):
    """A node containing a random variable (RV)."""

    def __init__(
        self,
        *,
        shape: tuple[int, ...] = (),
        label: str,
        compute: typing.Callable | None = None,
    ) -> None:
        """
        Initialise.

        Args:
            shape: The shape of the output of the RV
            label: A unique label to identify the node
            compute: A function to compute node's value from given values of
                parents. Its argument names are taken as the parents of the node.

        """
        super().__init__(label=label, shape=shape)
        # The parents are the argument names of `compute`, if there is one
        self._parents = (
            [] if compute is None else list(inspect.signature(compute).parameters)
        )
        self._compute = compute

    @override
    def sample(
        self,
        parameter_values: dict[str, float],
        sampled_dependencies: dict[str, npt.ArrayLike],
        samples: int,
        *,
        rng_key: jax.Array,
    ) -> npt.ArrayLike:
        raise NotImplementedError

    @override
    def evaluate(
        self,
        **given_values: float | npt.NDArray[float],
    ) -> float | npt.NDArray[float]:
        if self.label not in given_values:
            # No value supplied for this node: compute it from the parents
            if self._compute is None:
                msg = f"Missing input for node: {self.label}."
                raise ValueError(msg)
            parent_values = {name: given_values[name] for name in self._parents}
            return self._compute(**parent_values)

        # A value supplied directly for this node is validated and returned
        supplied = given_values[self.label]
        self.assert_is_valid_value(supplied)
        return supplied

    @override
    @property
    def parents(self) -> list[str]:
        return self._parents

    @abstractmethod
    def is_valid_value(self, value: float | npt.NDArray[float]) -> bool:
        """Check if a value is valid for this node."""

    def assert_is_valid_value(self, value: float | npt.NDArray[float]) -> None:
        """Raise a `ValueError` unless a value is valid and has this node's shape."""
        if self.is_valid_value(value):
            # Values without a `shape` attribute are treated as scalars
            value_shape = value.shape if hasattr(value, "shape") else ()
            if self.shape == value_shape:
                return
            msg = f"Invalid value for node: {self.label}"
        else:
            msg = (
                f"Invalid value for {self.__class__.__name__}: "
                f"{self.label} cannot be {value}"
            )
        raise ValueError(msg)
__init__(*, shape=(), label, compute=None)

Initialise.

Parameters:

Name Type Description Default
shape tuple[int, ...]

The shape of the output of the RV

()
label str

A unique label to identify the node

required
compute Callable | None

A function to compute node's value from given values of parents

None
Source code in src/causalprog/graph/node/random_variables.py
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
def __init__(
    self,
    *,
    shape: tuple[int, ...] = (),
    label: str,
    compute: typing.Callable | None = None,
) -> None:
    """
    Initialise.

    Args:
        shape: The shape of the output of the RV
        label: A unique label to identify the node
        compute: A function to compute node's value from given values of
            parents. Its argument names are taken as the parents of the node.

    """
    super().__init__(label=label, shape=shape)
    # The parents are the argument names of `compute`, if there is one
    self._parents = (
        [] if compute is None else list(inspect.signature(compute).parameters)
    )
    self._compute = compute
assert_is_valid_value(value)

Check if a value is valid for this node.

Source code in src/causalprog/graph/node/random_variables.py
76
77
78
79
80
81
82
83
84
85
86
def assert_is_valid_value(self, value: float | npt.NDArray[float]) -> None:
    """Check if a value is valid for this node."""
    if not self.is_valid_value(value):
        msg = (
            f"Invalid value for {self.__class__.__name__}: "
            f"{self.label} cannot be {value}"
        )
        raise ValueError(msg)
    if self.shape != (value.shape if hasattr(value, "shape") else ()):
        msg = f"Invalid value for node: {self.label}"
        raise ValueError(msg)
is_valid_value(value) abstractmethod

Check if a value is valid for this node.

Source code in src/causalprog/graph/node/random_variables.py
72
73
74
@abstractmethod
def is_valid_value(self, value: float | npt.NDArray[float]) -> bool:
    """
    Check if a value is valid for this node.

    Args:
        value: The candidate value

    Returns:
        `True` if `value` is valid for this node, `False` otherwise

    """

quadrature

Quadrature rules.

GaussianQuadrature

Bases: QuadratureMethod

A Gaussian quadrature rule.

The domain of integration for the points \(p_i\) and weights \(w_i\) is \([-1,1]\). This means that to integrate an integrand \(f\) over the interval \([a,b]\), the approximation

\[ \int_a^b f(x) dx \approx \frac{b - a}{2}\sum_{i} w_i f\left( \frac{b - a}{2} p_i + \frac{b + a}{2}\right) \]

is used.

Source code in src/causalprog/quadrature/gaussian.py
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
class GaussianQuadrature(QuadratureMethod):
    r"""
    A Gaussian quadrature rule.

    The points $p_i$ and weights $w_i$ are stored for the reference interval
    $[-1,1]$. An integrand $f$ is integrated over a general interval $[a,b]$ by
    mapping the points onto $[a,b]$ and using the approximation

    $$
    \int_a^b f(x) dx \approx
    \frac{b - a}{2}\sum_{i} w_i f\left( \frac{b - a}{2} p_i + \frac{b + a}{2}\right).
    $$
    """

    _pts: npt.NDArray
    _wts: npt.NDArray

    @override
    def __init__(self, n_points: int) -> None:
        super().__init__(n_points)
        raw_pts, raw_wts = quadraturerules.single_integral_quadrature(
            quadraturerules.QuadratureRule.GaussLegendre,
            quadraturerules.Domain.Interval,
            n_points,
        )
        # Convert the library's two-column points and its weights to the
        # reference interval [-1, 1] described in the class docstring
        self._pts = raw_pts[:, 1] - raw_pts[:, 0]
        self._wts = raw_wts * 2.0

    def integrate(
        self,
        integrand: Integrand,
        a: float = -1,
        b: float = 1,
        *integrand_args: IntegrandArgs.args,
        **integrand_kwargs: IntegrandArgs.kwargs,
    ) -> float:
        """Integrate the `integrand` over $[a,b]$ via Gaussian quadrature."""
        weighted_sum = 0.0
        for point, weight in self.pts_wts_tuples(a=a, b=b):
            evaluation = integrand(point, *integrand_args, **integrand_kwargs)
            weighted_sum += weight * evaluation

        # Scale by the derivative of the change of variables onto [a, b]
        return weighted_sum * (b - a) / 2

    def points_and_weights(
        self, a: float = -1.0, b: float = 1.0
    ) -> tuple[npt.NDArray, npt.NDArray]:
        """Get quadrature points and weights for performing integration on $[a,b]$."""
        half_width = (b - a) / 2.0
        midpoint = (b + a) / 2.0
        # Only the points are mapped onto [a, b]; the weights are returned as stored
        mapped_pts = self._pts * half_width + midpoint
        return mapped_pts, self._wts

integrate(integrand, a=-1, b=1, *integrand_args, **integrand_kwargs)

Integrate the integrand over \([a,b]\) via Gaussian quadrature.

Source code in src/causalprog/quadrature/gaussian.py
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
def integrate(
    self,
    integrand: Integrand,
    a: float = -1,
    b: float = 1,
    *integrand_args: IntegrandArgs.args,
    **integrand_kwargs: IntegrandArgs.kwargs,
) -> float:
    """
    Integrate the `integrand` over $[a,b]$ via Gaussian quadrature.

    Any extra positional and keyword arguments are forwarded to `integrand`
    after the quadrature point.
    """
    weighted_sum = 0.0
    for point, weight in self.pts_wts_tuples(a=a, b=b):
        evaluation = integrand(point, *integrand_args, **integrand_kwargs)
        weighted_sum += weight * evaluation

    # Scale by the derivative of the change of variables onto [a, b]
    return weighted_sum * (b - a) / 2

points_and_weights(a=-1.0, b=1.0)

Get quadrature points and weights for performing integration on \([a,b]\).

Source code in src/causalprog/quadrature/gaussian.py
59
60
61
62
63
64
65
def points_and_weights(
    self, a: float = -1.0, b: float = 1.0
) -> tuple[npt.NDArray, npt.NDArray]:
    """Get quadrature points and weights for performing integration on $[a,b]$."""
    half_width = (b - a) / 2.0
    midpoint = (b + a) / 2.0
    # Only the points are mapped onto [a, b]; the weights are returned as stored
    mapped_pts = self._pts * half_width + midpoint
    return mapped_pts, self._wts

MonteCarloGaussianQuadrature

Bases: RNGQuadratureMethod

Monte Carlo quadrature, sampled from a standard Gaussian.

Let \(N\) be the number of sample points to be used by the scheme. The quadrature method approximates the integral

\[ \int_a^b f(x) dx \approx \frac{1}{N}\sum_{p_i} \frac{f(p_i)}{\mathcal{P}(p_i)}, \]

where \(p_i\in[a,b]\) are \(N\) samples drawn from a standard Gaussian truncated to \([a,b]\), and \(\mathcal{P}\) is the probability density of that truncated Gaussian.

Source code in src/causalprog/quadrature/monte_carlo.py
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
class MonteCarloGaussianQuadrature(RNGQuadratureMethod):
    r"""
    Monte Carlo quadrature, sampled from a (truncated) standard Gaussian.

    Let $N$ be the number of sample points to be used by the scheme.
    The quadrature method approximates the integral

    $$
    \int_a^b f(x) dx
    \approx \frac{1}{N}\sum_{p_i} \frac{f(p_i)}{\mathcal{P}(p_i)},
    $$

    where $p_i\in[a,b]$ are $N$ samples drawn from a standard Gaussian truncated
    to $[a,b]$, and $\mathcal{P}$ is the PDF of that truncated Gaussian.
    """

    def integrate(
        self,
        integrand: Integrand,
        a: float = -1.0,
        b: float = 1.0,
        *integrand_args: IntegrandArgs.args,
        **integrand_kwargs: IntegrandArgs.kwargs,
    ) -> float:
        """
        Estimate the integral of the `integrand` over $[a,b]$ by Monte-Carlo sampling.

        Each evaluation of the `integrand` at a sample point is divided by the PDF
        value paired with that point, and the results are averaged over `n_points`.
        """
        total = 0.0

        for sample, density in self.pts_wts_tuples(a=a, b=b):
            evaluation = integrand(sample, *integrand_args, **integrand_kwargs)
            total += evaluation / density

        return total / self.n_points

    @override
    def points_and_weights(
        self, a: float = -1.0, b: float = 1.0
    ) -> tuple[npt.NDArray, npt.NDArray]:
        """
        Draw the sample points, and evaluate the sampling PDF at each of them.

        Note that the "weights" returned here are PDF values: `integrate` divides
        by them, rather than multiplying.
        """
        samples = jax.random.truncated_normal(
            self.rng_key, lower=a, upper=b, shape=(self.n_points,)
        )
        densities = jax.scipy.stats.truncnorm.pdf(samples, a, b)
        return samples, densities

integrate(integrand, a=-1.0, b=1.0, *integrand_args, **integrand_kwargs)

Perform Monte-Carlo integration of the integrand over \([a,b]\).

Source code in src/causalprog/quadrature/monte_carlo.py
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
def integrate(
    self,
    integrand: Integrand,
    a: float = -1.0,
    b: float = 1.0,
    *integrand_args: IntegrandArgs.args,
    **integrand_kwargs: IntegrandArgs.kwargs,
) -> float:
    """
    Estimate the integral of the `integrand` over $[a,b]$ by Monte-Carlo sampling.

    Each evaluation of the `integrand` at a sample point is divided by the weight
    paired with that point, and the results are averaged over `n_points`.
    """
    total = 0.0

    for sample, weight in self.pts_wts_tuples(a=a, b=b):
        evaluation = integrand(sample, *integrand_args, **integrand_kwargs)
        total += evaluation / weight

    return total / self.n_points

UniformWeightMonteCarloGaussianQuadrature

Bases: RNGQuadratureMethod

Monte Carlo quadrature, sampled from a standard Gaussian, but using uniform weights.

Let \(N\) be the number of sample points to be used by the scheme. The quadrature method approximates the integral

\[ \int_a^b f(x) dx \approx \frac{1}{N}\sum_{i} f(x_i), \]

where \(x_i\in[a,b]\) are \(N\) samples drawn from a standard Gaussian.

This is effectively computing \(\mathbb{E}[f(X)]\) when \(X\) is distributed according to a truncated normal on \([a, b]\) with mean 0 and standard deviation 1. As one would expect, the above rule for integrating \(f\) is identical to conducting standard Monte-Carlo integration (with Gaussian importance sampling), but on the integrand \(F(x) = f(x)\mathcal{P}(x)\), where \(\mathcal{P}\) is the PDF of a (truncated to \([a, b]\)) normal distribution.

Source code in src/causalprog/quadrature/monte_carlo.py
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
class UniformWeightMonteCarloGaussianQuadrature(RNGQuadratureMethod):
    r"""
    Monte Carlo quadrature, sampled from a standard Gaussian, but using uniform weights.

    Let $N$ be the number of sample points to be used by the scheme.
    The quadrature method approximates the integral

    $$
    \int_a^b f(x) dx
    \approx \frac{1}{N}\sum_{i} f(x_i),
    $$

    where $x_i\in[a,b]$ are $N$ samples drawn from a standard Gaussian truncated
    to $[a,b]$.

    This is effectively computing $\mathbb{E}[f(X)]$ when $X$ is distributed according
    to a truncated normal on $[a, b]$ with mean 0 and standard deviation 1. As one
    would expect, the above rule for integrating $f$ is identical to conducting standard
    Monte-Carlo integration (with Gaussian importance sampling), but on the integrand
    $F(x) = f(x)\mathcal{P}(x)$, where $\mathcal{P}$ is the PDF of a
    (truncated to $[a, b]$) normal distribution.
    """

    def integrate(
        self,
        integrand: Integrand,
        a: float = -1.0,
        b: float = 1.0,
        *integrand_args: IntegrandArgs.args,
        **integrand_kwargs: IntegrandArgs.kwargs,
    ) -> float:
        r"""
        Perform Monte-Carlo integration of the `integrand` over $[a,b]$.

        The `integrand` is evaluated at every sample point via `jax.vmap`, and the
        evaluations are summed and divided by `n_points`.

        In terms of the concrete classes in the codebase; if `P` again represents
        the PDF of a truncated normal distribution, the following are equivalent
        (up to floating-point rounding, for the same `rng_key`, `n_points`, `a`
        and `b`):

        - `UniformWeightMonteCarloGaussianQuadrature.integrate(f, ...)`
        - `MonteCarloGaussianQuadrature.integrate(f * P, ...)`.
        """
        # The weights are uniform, so only the sample points are needed here
        pts, _ = self.points_and_weights(a=a, b=b)
        ptwise_evaluation: jax.Array = jax.vmap(
            lambda x: integrand(x, *integrand_args, **integrand_kwargs)
        )(pts)

        return ptwise_evaluation.sum() / self.n_points

    @override
    def points_and_weights(
        self, a: float = -1.0, b: float = 1.0
    ) -> tuple[npt.NDArray, npt.NDArray]:
        """Draw the sample points, and pair each with the weight `1 / n_points`."""
        pts = jax.random.truncated_normal(
            self.rng_key, lower=a, upper=b, shape=(self.n_points,)
        )
        wts = jax.numpy.full((self.n_points,), 1.0 / self.n_points)
        return pts, wts

integrate(integrand, a=-1.0, b=1.0, *integrand_args, **integrand_kwargs)

Perform Monte-Carlo integration of the integrand over \([a,b]\).

In terms of the concrete classes in the codebase; if P again represents the PDF of a truncated normal distribution, the following are equivalent (for the same rng_key, n_points, a and b):

- UniformWeightMonteCarloGaussianQuadrature.integrate(f, ...)
- MonteCarloGaussianQuadrature.integrate(f * P, ...).

Source code in src/causalprog/quadrature/monte_carlo.py
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
def integrate(
    self,
    integrand: Integrand,
    a: float = -1.0,
    b: float = 1.0,
    *integrand_args: IntegrandArgs.args,
    **integrand_kwargs: IntegrandArgs.kwargs,
) -> float:
    r"""
    Perform Monte-Carlo integration of the `integrand` over $[a,b]$.

    The `integrand` is evaluated at every sample point via `jax.vmap`, and the
    evaluations are summed and divided by `n_points`.

    In terms of the concrete classes in the codebase; if `P` again represents
    the PDF of a truncated normal distribution, the following are equivalent
    (up to floating-point rounding, for the same `rng_key`, `n_points`, `a`
    and `b`):

    - `UniformWeightMonteCarloGaussianQuadrature.integrate(f, ...)`
    - `MonteCarloGaussianQuadrature.integrate(f * P, ...)`.
    """
    # The weights are uniform, so only the sample points are needed here
    pts, _ = self.points_and_weights(a=a, b=b)
    ptwise_evaluation: jax.Array = jax.vmap(
        lambda x: integrand(x, *integrand_args, **integrand_kwargs)
    )(pts)

    return ptwise_evaluation.sum() / self.n_points

base

Base quadrature class.

QuadratureMethod

Bases: ABC

An abstract quadrature method.

All QuadratureMethods are required to provide a means of obtaining the points and weights that they use, accessible through the points_and_weights method of an instance.

Instances also provide an integrate method, to perform numerical integration of an integrand.

Source code in src/causalprog/quadrature/base.py
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
class QuadratureMethod(ABC):
    """
    Abstract base class for quadrature methods.

    Every `QuadratureMethod` must implement `points_and_weights`, which exposes
    the points and weights that the method uses, and `integrate`, which performs
    numerical integration of an integrand.
    """

    def __init__(self, n_points: int) -> None:
        """
        Initialise.

        Args:
            n_points: The number of quadrature points

        """
        self._npts = n_points

    @property
    def n_points(self) -> int:
        """Number of quadrature points."""
        return self._npts

    @abstractmethod
    def integrate(
        self,
        integrand: Integrand,
        a: float = -1.0,
        b: float = 1.0,
        *integrand_args: IntegrandArgs.args,
        **integrand_kwargs: IntegrandArgs.kwargs,
    ) -> float:
        """
        Integrate the `integrand` over `[a,b]` using the `QuadratureMethod`.

        Subclasses should implement specific details.

        Ideally, we would be able to assume that the integrand is vectorised
        in its first argument (Callable[[ArrayLike, ...], ArrayLike]).
        Then we could do without the for-loop in each of the subclass implementations.
        """

    @abstractmethod
    def points_and_weights(
        self, a: float = -1.0, b: float = 1.0
    ) -> tuple[npt.NDArray, npt.NDArray]:
        """Get quadrature points and weights for performing integration on $[a,b]$."""

    def pts_wts_tuples(
        self, a: float = -1.0, b: float = 1.0
    ) -> list[tuple[float, float]]:
        """
        Get `(point, weight)` pairs as a list of tuples.

        Points and weights are paired strictly, so a mismatch in their lengths
        raises a `ValueError`.
        """
        points_weights = self.points_and_weights(a=a, b=b)
        return list(zip(*points_weights, strict=True))
n_points property

Number of quadrature points.

__init__(n_points)

Initialise.

Parameters:

Name Type Description Default
n_points int

The number of quadrature points

required
Source code in src/causalprog/quadrature/base.py
26
27
28
29
30
31
32
33
34
def __init__(self, n_points: int) -> None:
    """
    Initialise.

    Args:
        n_points: The number of quadrature points. Stored on the instance
            as `_npts`, without validation.

    """
    self._npts = n_points
integrate(integrand, a=-1.0, b=1.0, *integrand_args, **integrand_kwargs) abstractmethod

Integrate the integrand over [a,b] using the QuadratureMethod.

Subclasses should implement specific details.

Ideally, we would be able to assume that the integrand is vectorised in its first argument (Callable[[ArrayLike, ...], ArrayLike]). Then we could do without the for-loop in each of the subclass implementations.

Source code in src/causalprog/quadrature/base.py
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
@abstractmethod
def integrate(
    self,
    integrand: Integrand,
    a: float = -1.0,
    b: float = 1.0,
    *integrand_args: IntegrandArgs.args,
    **integrand_kwargs: IntegrandArgs.kwargs,
) -> float:
    """
    Integrate the `integrand` over `[a,b]` using the `QuadratureMethod`.

    Subclasses should implement specific details.

    Ideally, we would be able to assume that the integrand is vectorised
    in its first argument (Callable[[ArrayLike, ...], ArrayLike]).
    Then we could do without the for-loop in each of the subclass implementations.

    Args:
        integrand: The function to integrate.
        a: Lower limit of integration.
        b: Upper limit of integration.
        *integrand_args: Additional positional arguments for the `integrand`.
        **integrand_kwargs: Additional keyword arguments for the `integrand`.

    Returns:
        The approximate value of the integral.

    """
points_and_weights(a=-1.0, b=1.0) abstractmethod

Get quadrature points and weights for performing integration on \([a,b]\).

Source code in src/causalprog/quadrature/base.py
60
61
62
63
64
@abstractmethod
def points_and_weights(
    self, a: float = -1.0, b: float = 1.0
) -> tuple[npt.NDArray, npt.NDArray]:
    """
    Get quadrature points and weights for performing integration on $[a,b]$.

    Args:
        a: Lower limit of integration.
        b: Upper limit of integration.

    Returns:
        The points, followed by the weights. How the weights enter the
        approximation is specific to each subclass.

    """
pts_wts_tuples(a=-1.0, b=1.0)

Get (point, weight) pairs as a list of tuples.

Source code in src/causalprog/quadrature/base.py
66
67
68
69
70
def pts_wts_tuples(
    self, a: float = -1.0, b: float = 1.0
) -> list[tuple[float, float]]:
    """Get `(point, weight)` pairs as a list of tuples."""
    return list(zip(*self.points_and_weights(a=a, b=b), strict=True))

RNGQuadratureMethod

Bases: QuadratureMethod

An abstract quadrature method, that relies on RNG.

The only difference from the base QuadratureMethod class is the requirement that an rng_key be provided to the instance at creation.

Source code in src/causalprog/quadrature/base.py
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
class RNGQuadratureMethod(QuadratureMethod):
    """
    An abstract quadrature method, that relies on RNG.

    The only difference from the base `QuadratureMethod` class is the requirement
    that an `rng_key` be provided to the instance at creation.
    """

    # PRNG key used for sample generation; assigned in `__init__`.
    rng_key: Array

    def __init__(self, n_points: int, *, rng_key: Array) -> None:
        """
        Initialise.

        Args:
            n_points: The number of quadrature points
            rng_key: PRNG key used for sample generation. Keyword-only.

        """
        # The base class stores the number of points
        super().__init__(n_points)

        self.rng_key = rng_key
__init__(n_points, *, rng_key)

Initialise.

Parameters:

Name Type Description Default
n_points int

The number of quadrature points

required
rng_key Array

PRNG key used for sample generation

required
Source code in src/causalprog/quadrature/base.py
83
84
85
86
87
88
89
90
91
92
93
94
def __init__(self, n_points: int, *, rng_key: Array) -> None:
    """
    Initialise.

    Args:
        n_points: The number of quadrature points
        rng_key: PRNG key used for sample generation. Keyword-only.

    """
    # Forward the number of points to the parent class initialiser
    super().__init__(n_points)

    self.rng_key = rng_key

gaussian

Gaussian quadrature.

GaussianQuadrature

Bases: QuadratureMethod

A Gaussian quadrature rule.

The domain of integration for the points \(p_i\) and weights \(w_i\) is \([-1,1]\). This means that to integrate an integrand \(f\) over the interval \([a,b]\), the approximation

\[ \int_a^b f(x) dx \approx \frac{b - a}{2}\sum_{i} w_i f\left( \frac{b - a}{2} p_i + \frac{b + a}{2}\right) \]

is used.

Source code in src/causalprog/quadrature/gaussian.py
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
class GaussianQuadrature(QuadratureMethod):
    r"""
    A Gaussian quadrature rule.

    The domain of integration for the points $p_i$ and weights $w_i$ is $[-1,1]$.
    This means that to integrate an integrand $f$ over the interval $[a,b]$, the
    approximation

    $$
    \int_a^b f(x) dx \approx
    \frac{b - a}{2}\sum_{i} w_i
    f\left( \frac{b - a}{2} p_i + \frac{b + a}{2}\right)
    $$

    is used.
    """

    _pts: npt.NDArray
    _wts: npt.NDArray

    @override
    def __init__(self, n_points: int) -> None:
        super().__init__(n_points)
        pts, wts = quadraturerules.single_integral_quadrature(
            quadraturerules.QuadratureRule.GaussLegendre,
            quadraturerules.Domain.Interval,
            n_points,
        )
        # NOTE(review): presumably each row of `pts` holds a pair of barycentric
        # coordinates on the reference interval, so that their difference lies
        # in [-1, 1] - confirm against the quadraturerules documentation.
        self._pts = pts[:, 1] - pts[:, 0]
        # NOTE(review): presumably the library weights are normalised to the
        # reference interval, and doubling rescales them to [-1, 1] - confirm.
        self._wts = wts * 2.0

    def integrate(
        self,
        integrand: Integrand,
        a: float = -1,
        b: float = 1,
        *integrand_args: IntegrandArgs.args,
        **integrand_kwargs: IntegrandArgs.kwargs,
    ) -> float:
        """
        Integrate the `integrand` over $[a,b]$ via Gaussian quadrature.

        The `integrand` is evaluated once per quadrature point, with
        `integrand_args` and `integrand_kwargs` forwarded unchanged after the
        point. The weighted sum of those evaluations is scaled by `(b - a) / 2`.
        """
        result = 0.0
        for p_i, w_i in self.pts_wts_tuples(a=a, b=b):
            result += w_i * integrand(
                p_i,
                *integrand_args,
                **integrand_kwargs,
            )

        # Scale by the change-of-variables factor for the interval [a, b]
        return result * (b - a) / 2

    def points_and_weights(
        self, a: float = -1.0, b: float = 1.0
    ) -> tuple[npt.NDArray, npt.NDArray]:
        """
        Get quadrature points and weights for performing integration on $[a,b]$.

        The points are mapped onto $[a,b]$, but the weights are returned without
        rescaling; `integrate` applies the factor `(b - a) / 2` itself.
        """
        change_of_vars_derivative = (b - a) / 2.0
        interval_midpoint = (b + a) / 2.0
        return self._pts * change_of_vars_derivative + interval_midpoint, self._wts
integrate(integrand, a=-1, b=1, *integrand_args, **integrand_kwargs)

Integrate the integrand over \([a,b]\) via Gaussian quadrature.

Source code in src/causalprog/quadrature/gaussian.py
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
def integrate(
    self,
    integrand: Integrand,
    a: float = -1,
    b: float = 1,
    *integrand_args: IntegrandArgs.args,
    **integrand_kwargs: IntegrandArgs.kwargs,
) -> float:
    """
    Approximate the integral of the `integrand` over $[a,b]$ by Gaussian quadrature.

    The `integrand` is evaluated once per quadrature point, with `integrand_args`
    and `integrand_kwargs` forwarded unchanged after the point. The weighted sum
    of those evaluations is then scaled by `(b - a) / 2`.
    """
    weighted_sum = 0.0
    for point, weight in self.pts_wts_tuples(a=a, b=b):
        evaluation = integrand(point, *integrand_args, **integrand_kwargs)
        weighted_sum += weight * evaluation

    # Scale by the change-of-variables factor for the interval [a, b]
    return weighted_sum * (b - a) / 2
points_and_weights(a=-1.0, b=1.0)

Get quadrature points and weights for performing integration on \([a,b]\).

Source code in src/causalprog/quadrature/gaussian.py
59
60
61
62
63
64
65
def points_and_weights(
    self, a: float = -1.0, b: float = 1.0
) -> tuple[npt.NDArray, npt.NDArray]:
    """
    Return the quadrature points mapped onto $[a,b]$, along with the stored weights.

    The stored points are scaled by the interval half-width `(b - a) / 2` and then
    shifted by the interval midpoint `(b + a) / 2`. The stored weights are returned
    as they are, without any rescaling.
    """
    half_width = (b - a) / 2.0
    midpoint = (b + a) / 2.0
    mapped_points = self._pts * half_width + midpoint
    return mapped_points, self._wts

monte_carlo

Monte Carlo quadrature.

MonteCarloGaussianQuadrature

Bases: RNGQuadratureMethod

Monte Carlo quadrature, sampled from a standard Gaussian.

Let \(N\) be the number of sample points to be used by the scheme. The quadrature method approximates the integral

\[ \int_a^b f(x) dx \approx \frac{1}{N}\sum_{p_i} \frac{f(p_i)}{\mathcal{P}(p_i)}, \]

where \(p_i\in[a,b]\) are \(N\) samples drawn from a standard Gaussian truncated to \([a,b]\), and \(\mathcal{P}\) is the PDF of that truncated Gaussian.

Source code in src/causalprog/quadrature/monte_carlo.py
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
class MonteCarloGaussianQuadrature(RNGQuadratureMethod):
    r"""
    Monte Carlo quadrature, sampled from a (truncated) standard Gaussian.

    Let $N$ be the number of sample points to be used by the scheme.
    The quadrature method approximates the integral

    $$
    \int_a^b f(x) dx
    \approx \frac{1}{N}\sum_{p_i} \frac{f(p_i)}{\mathcal{P}(p_i)},
    $$

    where $p_i\in[a,b]$ are $N$ samples drawn from a standard Gaussian truncated
    to $[a,b]$, and $\mathcal{P}$ is the PDF of that truncated Gaussian.
    """

    def integrate(
        self,
        integrand: Integrand,
        a: float = -1.0,
        b: float = 1.0,
        *integrand_args: IntegrandArgs.args,
        **integrand_kwargs: IntegrandArgs.kwargs,
    ) -> float:
        """
        Estimate the integral of the `integrand` over $[a,b]$ by Monte-Carlo sampling.

        Each evaluation of the `integrand` at a sample point is divided by the PDF
        value paired with that point, and the results are averaged over `n_points`.
        """
        total = 0.0

        for sample, density in self.pts_wts_tuples(a=a, b=b):
            evaluation = integrand(sample, *integrand_args, **integrand_kwargs)
            total += evaluation / density

        return total / self.n_points

    @override
    def points_and_weights(
        self, a: float = -1.0, b: float = 1.0
    ) -> tuple[npt.NDArray, npt.NDArray]:
        """
        Draw the sample points, and evaluate the sampling PDF at each of them.

        Note that the "weights" returned here are PDF values: `integrate` divides
        by them, rather than multiplying.
        """
        samples = jax.random.truncated_normal(
            self.rng_key, lower=a, upper=b, shape=(self.n_points,)
        )
        densities = jax.scipy.stats.truncnorm.pdf(samples, a, b)
        return samples, densities
integrate(integrand, a=-1.0, b=1.0, *integrand_args, **integrand_kwargs)

Perform Monte-Carlo integration of the integrand over \([a,b]\).

Source code in src/causalprog/quadrature/monte_carlo.py
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
def integrate(
    self,
    integrand: Integrand,
    a: float = -1.0,
    b: float = 1.0,
    *integrand_args: IntegrandArgs.args,
    **integrand_kwargs: IntegrandArgs.kwargs,
) -> float:
    """
    Estimate the integral of the `integrand` over $[a,b]$ by Monte-Carlo sampling.

    Each evaluation of the `integrand` at a sample point is divided by the weight
    paired with that point, and the results are averaged over `n_points`.
    """
    total = 0.0

    for sample, weight in self.pts_wts_tuples(a=a, b=b):
        evaluation = integrand(sample, *integrand_args, **integrand_kwargs)
        total += evaluation / weight

    return total / self.n_points

UniformWeightMonteCarloGaussianQuadrature

Bases: RNGQuadratureMethod

Monte Carlo quadrature, sampled from a standard Gaussian, but using uniform weights.

Let \(N\) be the number of sample points to be used by the scheme. The quadrature method approximates the integral

\[ \int_a^b f(x) dx \approx \frac{1}{N}\sum_{i} f(x_i), \]

where \(x_i\in[a,b]\) are \(N\) samples drawn from a standard Gaussian.

This is effectively computing \(\mathbb{E}[f(X)]\) when \(X\) is distributed according to a truncated normal on \([a, b]\) with mean 0 and standard deviation 1. As one would expect, the above rule for integrating \(f\) is identical to conducting standard Monte-Carlo integration (with Gaussian importance sampling), but on the integrand \(F(x) = f(x)\mathcal{P}(x)\), where \(\mathcal{P}\) is the PDF of a (truncated to \([a, b]\)) normal distribution.

Source code in src/causalprog/quadrature/monte_carlo.py
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
class UniformWeightMonteCarloGaussianQuadrature(RNGQuadratureMethod):
    r"""
    Monte Carlo quadrature, sampled from a standard Gaussian, but using uniform weights.

    Let $N$ be the number of sample points to be used by the scheme.
    The quadrature method approximates the integral

    $$
    \int_a^b f(x) dx
    \approx \frac{1}{N}\sum_{i} f(x_i),
    $$

    where $x_i\in[a,b]$ are $N$ samples drawn from a standard Gaussian truncated
    to $[a,b]$.

    This is effectively computing $\mathbb{E}[f(X)]$ when $X$ is distributed according
    to a truncated normal on $[a, b]$ with mean 0 and standard deviation 1. As one
    would expect, the above rule for integrating $f$ is identical to conducting standard
    Monte-Carlo integration (with Gaussian importance sampling), but on the integrand
    $F(x) = f(x)\mathcal{P}(x)$, where $\mathcal{P}$ is the PDF of a
    (truncated to $[a, b]$) normal distribution.
    """

    def integrate(
        self,
        integrand: Integrand,
        a: float = -1.0,
        b: float = 1.0,
        *integrand_args: IntegrandArgs.args,
        **integrand_kwargs: IntegrandArgs.kwargs,
    ) -> float:
        r"""
        Perform Monte-Carlo integration of the `integrand` over $[a,b]$.

        The `integrand` is evaluated at every sample point via `jax.vmap`, and the
        evaluations are summed and divided by `n_points`.

        In terms of the concrete classes in the codebase; if `P` again represents
        the PDF of a truncated normal distribution, the following are equivalent
        (up to floating-point rounding, for the same `rng_key`, `n_points`, `a`
        and `b`):

        - `UniformWeightMonteCarloGaussianQuadrature.integrate(f, ...)`
        - `MonteCarloGaussianQuadrature.integrate(f * P, ...)`.
        """
        # The weights are uniform, so only the sample points are needed here
        pts, _ = self.points_and_weights(a=a, b=b)
        ptwise_evaluation: jax.Array = jax.vmap(
            lambda x: integrand(x, *integrand_args, **integrand_kwargs)
        )(pts)

        return ptwise_evaluation.sum() / self.n_points

    @override
    def points_and_weights(
        self, a: float = -1.0, b: float = 1.0
    ) -> tuple[npt.NDArray, npt.NDArray]:
        """Draw the sample points, and pair each with the weight `1 / n_points`."""
        pts = jax.random.truncated_normal(
            self.rng_key, lower=a, upper=b, shape=(self.n_points,)
        )
        wts = jax.numpy.full((self.n_points,), 1.0 / self.n_points)
        return pts, wts
integrate(integrand, a=-1.0, b=1.0, *integrand_args, **integrand_kwargs)

Perform Monte-Carlo integration of the integrand over \([a,b]\).

In terms of the concrete classes in the codebase; if P again represents the PDF of a truncated normal distribution, the following are equivalent (for the same rng_key, n_points, a and b):

- UniformWeightMonteCarloGaussianQuadrature.integrate(f, ...)
- MonteCarloGaussianQuadrature.integrate(f * P, ...).

Source code in src/causalprog/quadrature/monte_carlo.py
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
def integrate(
    self,
    integrand: Integrand,
    a: float = -1.0,
    b: float = 1.0,
    *integrand_args: IntegrandArgs.args,
    **integrand_kwargs: IntegrandArgs.kwargs,
) -> float:
    r"""
    Perform Monte-Carlo integration of the `integrand` over $[a,b]$.

    The `integrand` is evaluated at every sample point via `jax.vmap`, and the
    evaluations are summed and divided by `n_points`.

    In terms of the concrete classes in the codebase; if `P` again represents
    the PDF of a truncated normal distribution, the following are equivalent
    (up to floating-point rounding, for the same `rng_key`, `n_points`, `a`
    and `b`):

    - `UniformWeightMonteCarloGaussianQuadrature.integrate(f, ...)`
    - `MonteCarloGaussianQuadrature.integrate(f * P, ...)`.
    """
    # The weights are uniform, so only the sample points are needed here
    pts, _ = self.points_and_weights(a=a, b=b)
    ptwise_evaluation: jax.Array = jax.vmap(
        lambda x: integrand(x, *integrand_args, **integrand_kwargs)
    )(pts)

    return ptwise_evaluation.sum() / self.n_points

solvers

Solvers for Causal Problems.

iteration_result

Container classes for outputs from each iteration of solver methods.

IterationResult dataclass

Result container for iterative solvers with optional history logging.

Stores the latest iterate and if history_logging_interval > 0, update appends snapshots of the iterate to corresponding history lists each time the iteration number is a multiple of history_logging_interval. Instances are mutable but do not allow dynamic attribute creation.

Parameters:

Name Type Description Default
fn_args PyTree

Argument to the objective function at the final iteration (the solution, if the solve was successful).

required
grad_val PyTree

Value of the gradient of the objective function at the fn_args.

required
iters int

Number of iterations performed.

required
obj_val ArrayLike

Value of the objective function at fn_args.

required
iter_history list[int]

List of iteration numbers at which history was logged.

list()
fn_args_history list[PyTree]

List of fn_args at each logged iteration.

list()
grad_val_history list[PyTree]

List of grad_val at each logged iteration.

list()
obj_val_history list[ArrayLike]

List of obj_val at each logged iteration.

list()
history_logging_interval int

Interval at which to log history. If history_logging_interval <= 0, then no history is logged.

0
Source code in src/causalprog/solvers/iteration_result.py
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
@dataclass(frozen=False, slots=True)
class IterationResult:
    """
    Result container for iterative solvers with optional history logging.

    Stores the latest iterate and if `history_logging_interval > 0`, `update` appends
    snapshots of the iterate to corresponding history lists each time the iteration
    number is a multiple of `history_logging_interval`.
    Instances are mutable but do not allow dynamic attribute creation.

    Args:
        fn_args: Argument to the objective function at final iteration (the solution,
            if `successful is `True`).
        grad_val: Value of the gradient of the objective function at the `fn_args`.
        iters: Number of iterations performed.
        obj_val: Value of the objective function at `fn_args`.
        iter_history: List of iteration numbers at which history was logged.
        fn_args_history: List of `fn_args` at each logged iteration.
        grad_val_history: List of `grad_val` at each logged iteration.
        obj_val_history: List of `obj_val` at each logged iteration.
        history_logging_interval: Interval at which to log history. If
            `history_logging_interval <= 0`, then no history is logged.

    """

    fn_args: PyTree
    grad_val: PyTree
    iters: int
    obj_val: npt.ArrayLike
    history_logging_interval: int = 0

    iter_history: list[int] = field(default_factory=list)
    fn_args_history: list[PyTree] = field(default_factory=list)
    grad_val_history: list[PyTree] = field(default_factory=list)
    obj_val_history: list[npt.ArrayLike] = field(default_factory=list)

    _log_enabled: bool = field(init=False, repr=False)

    def __post_init__(self) -> None:
        self._log_enabled = self.history_logging_interval > 0

    def update(
        self,
        current_params: PyTree,
        gradient_value: PyTree,
        iters: int,
        objective_value: npt.ArrayLike,
    ) -> None:
        """
        Update the `IterationResult` object with current iteration data.

        Only updates the history if `history_logging_interval` is positive and
        the current iteration is a multiple of `history_logging_interval`.

        """
        self.fn_args = current_params
        self.grad_val = gradient_value
        self.iters = iters
        self.obj_val = objective_value

        if self._log_enabled and iters % self.history_logging_interval == 0:
            self.iter_history.append(iters)
            self.fn_args_history.append(current_params)
            self.grad_val_history.append(gradient_value)
            self.obj_val_history.append(objective_value)
update(current_params, gradient_value, iters, objective_value)

Update the IterationResult object with current iteration data.

Only updates the history if history_logging_interval is positive and the current iteration is a multiple of history_logging_interval.

Source code in src/causalprog/solvers/iteration_result.py
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
def update(
    self,
    current_params: PyTree,
    gradient_value: PyTree,
    iters: int,
    objective_value: npt.ArrayLike,
) -> None:
    """
    Overwrite the stored iteration data and, when due, append it to the history.

    The "current" attributes (`fn_args`, `grad_val`, `iters`, `obj_val`) are
    always replaced. The history lists only grow when logging is enabled and
    `iters` is an exact multiple of `history_logging_interval`.

    Args:
        current_params: Argument of the objective function at this iteration.
        gradient_value: Gradient of the objective function at this iteration.
        iters: Index of this iteration.
        objective_value: Value of the objective function at this iteration.

    """
    self.fn_args, self.grad_val = current_params, gradient_value
    self.iters, self.obj_val = iters, objective_value

    # Check the flag first: the interval may be zero when logging is disabled,
    # in which case the modulo below must never be evaluated.
    if not self._log_enabled:
        return
    if iters % self.history_logging_interval != 0:
        return

    snapshot = (
        (self.iter_history, iters),
        (self.fn_args_history, current_params),
        (self.grad_val_history, gradient_value),
        (self.obj_val_history, objective_value),
    )
    for history, entry in snapshot:
        history.append(entry)

sgd

Minimisation via Stochastic Gradient Descent.

stochastic_gradient_descent(obj_fn, initial_guess, *, convergence_criteria=None, fn_args=None, fn_kwargs=None, learning_rate=0.1, maxiter=1000, optimiser=None, tolerance=1e-08, history_logging_interval=-1, callbacks=None)

Minimise a function of one argument using Stochastic Gradient Descent (SGD).

The obj_fn provided will be minimised over its first argument. If you wish to minimise a function over a different argument, or multiple arguments, wrap it in a suitable lambda expression that has the correct call signature. For example, to minimise a function f(x, y, z) over y and z, use g = lambda yz, x: f(x, yz[0], yz[1]), and pass g in as obj_fn. Note that you will also need to provide a constant value for x via fn_args or fn_kwargs.

The fn_args and fn_kwargs keys can be used to supply additional parameters that need to be passed to obj_fn, but which should be held constant.

SGD terminates when the convergence_criteria is found to be smaller than the tolerance. That is, when convergence_criteria(objective_value, gradient_value) < tolerance is found to be True, the algorithm considers a minimum to have been found. The default condition under which the algorithm terminates is when the norm of the gradient at the current argument value is smaller than the provided tolerance.

The optimiser to use can be selected by passing in a suitable optax optimiser via the optimiser argument. By default, optax.adam is used with the supplied learning_rate. Providing an explicit value for optimiser will result in the learning_rate argument being ignored.

Parameters:

Name Type Description Default
obj_fn Callable[[PyTree], ArrayLike]

Function to be minimised over its first argument.

required
initial_guess PyTree

Initial guess for the minimising argument.

required
convergence_criteria Callable[[PyTree, PyTree], ArrayLike] | None

The quantity that will be tested against tolerance, to determine whether the method has converged to a minimum. It should be a callable that takes the current value of obj_fn as its 1st argument, and the current value of the gradient of obj_fn as its 2nd argument. The default criteria is the l2-norm of the gradient.

None
fn_args tuple | None

Positional arguments to be passed to obj_fn, and held constant.

None
fn_kwargs dict | None

Keyword arguments to be passed to obj_fn, and held constant.

None
learning_rate float

Default learning rate (or step size) to use when using the default optimiser. No effect if optimiser is provided explicitly.

0.1
maxiter int

Maximum number of iterations to perform. An error will be reported if this number of iterations is exceeded.

1000
optimiser GradientTransformationExtraArgs | None

The optax optimiser to use during the update step.

None
tolerance float

tolerance used when determining if a minimum has been found.

1e-08
history_logging_interval int

Interval (in number of iterations) at which to log the history of optimisation. If history_logging_interval <= 0, no history is logged.

-1
callbacks Callable[[IterationResult], None] | list[Callable[[IterationResult], None]] | None

A callable or list of callables that take an IterationResult as their only argument, and return None. These will be called at the end of each iteration of the optimisation procedure.

None

Returns:

Name Type Description
SolverResult SolverResult

Result of the optimisation procedure.

Source code in src/causalprog/solvers/sgd.py
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
def stochastic_gradient_descent(
    obj_fn: Callable[[PyTree], npt.ArrayLike],
    initial_guess: PyTree,
    *,
    convergence_criteria: Callable[[PyTree, PyTree], npt.ArrayLike] | None = None,
    fn_args: tuple | None = None,
    fn_kwargs: dict | None = None,
    learning_rate: float = 1.0e-1,
    maxiter: int = 1000,
    optimiser: optax.GradientTransformationExtraArgs | None = None,
    tolerance: float = 1.0e-8,
    history_logging_interval: int = -1,
    callbacks: Callable[[IterationResult], None]
    | list[Callable[[IterationResult], None]]
    | None = None,
) -> SolverResult:
    """
    Minimise a function of one argument using Stochastic Gradient Descent (SGD).

    The `obj_fn` provided will be minimised over its first argument. If you wish to
    minimise a function over a different argument, or multiple arguments, wrap it in a
    suitable `lambda` expression that has the correct call signature. For example, to
    minimise a function `f(x, y, z)` over `y` and `z`, use
    `g = lambda yz, x: f(x, yz[0], yz[1])`, and pass `g` in as `obj_fn`. Note that
    you will also need to provide a constant value for `x` via `fn_args` or `fn_kwargs`.

    The `fn_args` and `fn_kwargs` keys can be used to supply additional parameters that
    need to be passed to `obj_fn`, but which should be held constant.

    SGD terminates when the `convergence_criteria` is found to be smaller than the
    `tolerance`. That is, when
    `convergence_criteria(objective_value, gradient_value) < tolerance` is found to
    be `True`, the algorithm considers a minimum to have been found. The default
    condition under which the algorithm terminates is when the norm of the gradient
    at the current argument value is smaller than the provided `tolerance`.

    The optimiser to use can be selected by passing in a suitable `optax` optimiser
    via the `optimiser` argument. By default, `optax.adam` is used with the supplied
    `learning_rate`. Providing an explicit value for `optimiser` will result in the
    `learning_rate` argument being ignored.

    Args:
        obj_fn: Function to be minimised over its first argument.
        initial_guess: Initial guess for the minimising argument.
        convergence_criteria: The quantity that will be tested against `tolerance`, to
            determine whether the method has converged to a minimum. It should be a
            `callable` that takes the current value of `obj_fn` as its 1st argument, and
            the current value of the gradient of `obj_fn` as its 2nd argument. The
            default criteria is the l2-norm of the gradient.
        fn_args: Positional arguments to be passed to `obj_fn`, and held constant.
        fn_kwargs: Keyword arguments to be passed to `obj_fn`, and held constant.
        learning_rate: Default learning rate (or step size) to use when using the
            default `optimiser`. No effect if `optimiser` is provided explicitly.
        maxiter: Maximum number of update steps to perform. If the method has not
            converged once this many steps have been taken, the returned result has
            `successful` set to `False` and `reason` describes the failure.
        optimiser: The `optax` optimiser to use during the update step.
        tolerance: `tolerance` used when determining if a minimum has been found.
        history_logging_interval: Interval (in number of iterations) at which to log
            the history of optimisation. If history_logging_interval <= 0, no
            history is logged.
        callbacks: A `callable` or list of `callables` that take an
            `IterationResult` as their only argument, and return `None`.
            These are called once per iteration, after the `IterationResult` has
            been refreshed with that iteration's values.


    Returns:
        SolverResult: Result of the optimisation procedure.

    """
    if not fn_args:
        fn_args = ()
    if not fn_kwargs:
        fn_kwargs = {}
    if convergence_criteria is None:
        convergence_criteria = lambda _, dx: jnp.sqrt(l2_normsq(dx))  # noqa: E731
    if optimiser is None:
        optimiser = optax.adam(learning_rate)

    callbacks = _normalise_callbacks(callbacks)

    def objective(x: npt.ArrayLike) -> npt.ArrayLike:
        return obj_fn(x, *fn_args, **fn_kwargs)

    def is_converged(x: npt.ArrayLike, dx: npt.ArrayLike) -> bool:
        # Coerce to a plain bool so that `SolverResult.successful` honours its
        # annotation instead of carrying a 0-d array.
        return bool(convergence_criteria(x, dx) < tolerance)

    value_and_grad_fn = jax.jit(jax.value_and_grad(objective))

    # init state
    opt_state = optimiser.init(initial_guess)
    current_params = deepcopy(initial_guess)
    converged = False
    objective_value, gradient_value = value_and_grad_fn(current_params)

    iter_result = IterationResult(
        fn_args=current_params,
        grad_val=gradient_value,
        iters=0,
        obj_val=objective_value,
        history_logging_interval=history_logging_interval,
    )

    # Defined up-front so that a negative `maxiter` (empty loop) still yields a
    # well-formed result rather than an unbound-variable error.
    iters_used = 0
    for current_iter in range(maxiter + 1):
        iters_used = current_iter
        iter_result.update(
            current_params=current_params,
            gradient_value=gradient_value,
            iters=current_iter,
            objective_value=objective_value,
        )

        _run_callbacks(iter_result, callbacks)

        if converged := is_converged(objective_value, gradient_value):
            break
        if current_iter == maxiter:
            # The step budget is spent. Stop here so the returned parameters are
            # exactly the ones that were just logged, passed to the callbacks and
            # tested for convergence, and so that no more than `maxiter` update
            # steps are ever taken.
            break

        # The current parameters are forwarded because some optax transforms
        # (e.g. weight decay) need them; transforms that do not simply ignore them.
        updates, opt_state = optimiser.update(
            gradient_value, opt_state, current_params
        )
        current_params = optax.apply_updates(current_params, updates)

        objective_value, gradient_value = value_and_grad_fn(current_params)

    reason_msg = (
        f"Did not converge after {iters_used} iterations" if not converged else ""
    )

    return SolverResult(
        fn_args=current_params,
        grad_val=gradient_value,
        iters=iters_used,
        maxiter=maxiter,
        obj_val=objective_value,
        reason=reason_msg,
        successful=converged,
        iter_history=iter_result.iter_history,
        fn_args_history=iter_result.fn_args_history,
        grad_val_history=iter_result.grad_val_history,
        obj_val_history=iter_result.obj_val_history,
    )

solver_callbacks

Module for callback functions for solvers.

tqdm_callback(total)

Progress bar callback using tqdm.

Creates a callback function that can be passed to solvers to display a progress bar during optimization. The progress bar updates based on the number of iterations and also displays the current objective value.

Parameters:

Name Type Description Default
total int

Total number of iterations for the progress bar.

required

Returns:

Type Description
Callable[[IterationResult], None]

Callback function that updates the progress bar.

Source code in src/causalprog/solvers/solver_callbacks.py
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
def tqdm_callback(total: int) -> Callable[[IterationResult], None]:
    """
    Build a solver callback that drives a `tqdm` progress bar.

    The returned callable can be handed to a solver. Each time it is invoked with
    an `IterationResult` whose iteration count has grown since the previous
    invocation, the bar is advanced by that difference and its postfix is set to
    the current objective value.

    Args:
        total: Total number of iterations for the progress bar.

    Returns:
        Callback function that updates the progress bar.

    """
    progress = tqdm(total=total)
    previous_iters = 0

    def cb(ir: IterationResult) -> None:
        nonlocal previous_iters

        advanced_by = ir.iters - previous_iters
        if advanced_by <= 0:
            # Nothing new to report (e.g. the very first call at iteration 0).
            return

        progress.update(advanced_by)
        # Show the current objective value next to the bar.
        progress.set_postfix(obj=float(ir.obj_val))
        previous_iters = ir.iters

    return cb

solver_result

Container class for outputs from solver methods.

SolverResult dataclass

Container class for outputs from solver methods.

Instances of this class provide a container for useful information that comes out of running one of the solver methods on a causal problem.

Attributes:

Name Type Description
fn_args PyTree

Argument to the objective function at final iteration (the solution, if successful is True).

grad_val PyTree

Value of the gradient of the objective function at the fn_args.

iters int

Number of iterations performed.

maxiter int

Maximum number of iterations the solver was permitted to perform.

obj_val ArrayLike

Value of the objective function at fn_args.

reason str

Human-readable string explaining success or reasons for solver failure.

successful bool

True if solver converged, in which case fn_args is the argument to the objective function at the solution of the problem being solved. False otherwise.

iter_history list[int]

List of iteration numbers at which history was logged.

fn_args_history list[PyTree]

List of fn_args at each logged iteration.

grad_val_history list[PyTree]

List of grad_val at each logged iteration.

obj_val_history list[ArrayLike]

List of obj_val at each logged iteration.

Source code in src/causalprog/solvers/solver_result.py
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
@dataclass(frozen=True)
class SolverResult:
    """
    Immutable record of what a solver method produced.

    A solver run on a causal problem returns one of these, bundling the final
    state of the optimisation with any history that was logged along the way.

    Attributes:
        fn_args: Argument to the objective function at final iteration (the solution,
            if `successful` is `True`).
        grad_val: Value of the gradient of the objective function at the `fn_args`.
        iters: Number of iterations performed.
        maxiter: Maximum number of iterations the solver was permitted to perform.
        obj_val: Value of the objective function at `fn_args`.
        reason: Human-readable string explaining success or reasons for solver failure.
        successful: `True` if solver converged, in which case `fn_args` is the
            argument to the objective function at the solution of the problem being
            solved. `False` otherwise.
        iter_history: List of iteration numbers at which history was logged.
        fn_args_history: List of `fn_args` at each logged iteration.
        grad_val_history: List of `grad_val` at each logged iteration.
        obj_val_history: List of `obj_val` at each logged iteration.

    """

    # Final state of the solver; every one of these must be supplied.
    fn_args: PyTree
    grad_val: PyTree
    iters: int
    maxiter: int
    obj_val: npt.ArrayLike
    reason: str
    successful: bool

    # Optional logged history; each defaults to its own fresh empty list.
    iter_history: list[int] = field(default_factory=list)
    fn_args_history: list[PyTree] = field(default_factory=list)
    grad_val_history: list[PyTree] = field(default_factory=list)
    obj_val_history: list[npt.ArrayLike] = field(default_factory=list)

utils

Utility classes and methods.

norms

Misc collection of norm-like functions for PyTree structures.

l2_normsq(x)

Square of the l2-norm of a PyTree.

This is effectively "sum(elements**2 in leaf for leaf in x)".

Source code in src/causalprog/utils/norms.py
11
12
13
14
15
16
17
18
def l2_normsq(x: PyTree) -> npt.ArrayLike:
    """
    Return the squared l2-norm of a PyTree.

    Every leaf of `x` is squared elementwise and summed; the per-leaf totals are
    then added together, starting from the integer `0`.
    """
    total = 0
    for leaf in jax.tree_util.tree_leaves(x):
        total = total + jax.numpy.sum(leaf**2)
    return total

translator

Helper class to keep the codebase backend-agnostic.

Our frontend (or user-facing) classes each use a syntax that applies across the package codebase. By contrast, the various backends that we want to support will have different syntaxes and call signatures for the functions that we want to support. As such, we need a helper class that can store this "translation" information, allowing the user to interact with the package in a standard way but also allowing them to choose their own backend if desired.

Translator

Bases: ABC

Maps syntax of a backend function to our frontend syntax.

Different backends have different syntax for drawing samples from the distributions they support. In order to map these different syntaxes to our backend-agnostic framework, we need a container class to map the names we have chosen for our frontend methods to those used by their corresponding backend method.

A Translator allows us to identify whether a user-provided backend object is compatible with one of our frontend wrapper classes (and thus, call signatures). It also allows users to write their own translators for any custom backends that we do not explicitly support.

The use case for a Translator is as follows. Suppose that we have a frontend class C that needs to provide a method do_something. C stores a reference to a backend object obj that can provide the functionality of do_something via one of its methods, obj.backend_method. However, there is no guarantee that the signature of do_something maps identically to that of obj.backend_method. A Translator allows us to encode a mapping of obj.backend_method's arguments to those of do_something.

Source code in src/causalprog/utils/translator.py
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
class Translator(ABC):
    """
    Maps syntax of a backend function to our frontend syntax.

    Different backends have different syntax for drawing samples from the distributions
    they support. In order to map these different syntaxes to our backend-agnostic
    framework, we need a container class to map the names we have chosen for our
    frontend methods to those used by their corresponding backend method.

    A ``Translator`` allows us to identify whether a user-provided backend object is
    compatible with one of our frontend wrapper classes (and thus, call signatures). It
    also allows users to write their own translators for any custom backends that we do
    not explicitly support.

    The use case for a ``Translator`` is as follows. Suppose that we have a frontend
    class ``C`` that needs to provide a method ``do_something``. ``C`` stores a
    reference to a backend object ``obj`` that can provide the functionality of
    ``do_something`` via one of its methods, ``obj.backend_method``. However, there is
    no guarantee that the signature of ``do_something`` maps identically to that of
    ``obj.backend_method``. A ``Translator`` allows us to encode a mapping of
    ``obj.backend_method``s arguments to those of ``do_something``.
    """

    backend_method: str
    corresponding_backend_arg: dict[str, str]

    @property
    @abstractmethod
    def _frontend_method(self) -> str:
        """Name of the frontend method that the backend is to be translated into."""

    @property
    @abstractmethod
    def compulsory_frontend_args(self) -> set[str]:
        """Arguments that are required by the frontend function."""

    @property
    def compulsory_backend_args(self) -> set[str]:
        """Arguments that are required to be taken by the backend function."""
        return {
            self.corresponding_backend_arg[arg_name]
            for arg_name in self.compulsory_frontend_args
        }

    def __init__(
        self, backend_method: str | None = None, **front_args_to_back_args: str
    ) -> None:
        """
        Create a new Translator.

        Args:
            backend_method (str): Name of the backend method that the instance
                translates.
            **front_args_to_back_args (str): Mapping of frontend argument names to the
                corresponding backend argument names.

        """
        # Assume backend name is identical to frontend name if not provided explicitly
        self.backend_method = backend_method or self._frontend_method

        # This should really be immutable after we fill defaults!
        self.corresponding_backend_arg = dict(front_args_to_back_args)
        # Assume compulsory frontend args that are not given translations
        # retain their name in the backend.
        for arg in self.compulsory_frontend_args:
            if arg not in self.corresponding_backend_arg:
                self.corresponding_backend_arg[arg] = arg

    def translate_args(self, **kwargs: Any) -> dict[str, Any]:  # noqa: ANN401
        """
        Translate frontend arguments (with values) to backend arguments.

        Essentially transforms frontend keyword arguments into their backend keyword
        arguments, preserving the value assigned to each argument.
        """
        return {
            self.corresponding_backend_arg[arg_name]: arg_value
            for arg_name, arg_value in kwargs.items()
        }

    def validate_compatible(self, obj: object) -> None:
        """
        Determine if ``obj`` provides a compatible backend method.

        ``obj`` must provide a callable whose name matches ``self.backend_method``,
        and the callable referenced must take arguments matching the names specified in
        ``self.compulsory_backend_args``.

        Args:
            obj (object): Object to check possesses a method that can be translated into
                frontend syntax.

        """
        # Check that obj does provide a method of matching name
        if not hasattr(obj, self.backend_method):
            msg = f"{obj} has no method '{self.backend_method}'."
            raise AttributeError(msg)
        if not callable(getattr(obj, self.backend_method)):
            msg = f"'{self.backend_method}' attribute of {obj} is not callable."
            raise TypeError(msg)

        # Check that this method will be callable with the information given.
        method_params = inspect.signature(getattr(obj, self.backend_method)).parameters
        # The arguments that will be passed are actually taken by the method.
        for compulsory_arg in self.compulsory_backend_args:
            if compulsory_arg not in method_params:
                msg = (
                    f"'{self.backend_method}' does not "
                    f"take argument '{compulsory_arg}'."
                )
                raise TypeError(msg)
        # The method does not _require_ any additional arguments
        method_requires = {
            name for name, p in method_params.items() if p.default is p.empty
        }
        if not method_requires.issubset(self.compulsory_backend_args):
            args_not_accounted_for = method_requires - self.compulsory_backend_args
            raise TypeError(
                f"'{self.backend_method}' not provided compulsory arguments "
                "(missing " + ", ".join(args_not_accounted_for) + ")"
            )
compulsory_backend_args property

Arguments that are required to be taken by the backend function.

compulsory_frontend_args abstractmethod property

Arguments that are required by the frontend function.

__init__(backend_method=None, **front_args_to_back_args)

Create a new Translator.

Parameters:

Name Type Description Default
backend_method str

Name of the backend method that the instance translates.

None
**front_args_to_back_args str

Mapping of frontend argument names to the corresponding backend argument names.

{}
Source code in src/causalprog/utils/translator.py
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
def __init__(
    self, backend_method: str | None = None, **front_args_to_back_args: str
) -> None:
    """
    Create a new Translator.

    Args:
        backend_method (str): Name of the backend method that the instance
            translates.
        **front_args_to_back_args (str): Mapping of frontend argument names to the
            corresponding backend argument names.

    """
    # Assume backend name is identical to frontend name if not provided explicitly
    self.backend_method = backend_method or self._frontend_method

    # This should really be immutable after we fill defaults!
    self.corresponding_backend_arg = dict(front_args_to_back_args)
    # Assume compulsory frontend args that are not given translations
    # retain their name in the backend.
    for arg in self.compulsory_frontend_args:
        if arg not in self.corresponding_backend_arg:
            self.corresponding_backend_arg[arg] = arg
translate_args(**kwargs)

Translate frontend arguments (with values) to backend arguments.

Essentially transforms frontend keyword arguments into their backend keyword arguments, preserving the value assigned to each argument.

Source code in src/causalprog/utils/translator.py
85
86
87
88
89
90
91
92
93
94
95
def translate_args(self, **kwargs: Any) -> dict[str, Any]:  # noqa: ANN401
    """
    Rename frontend keyword arguments to their backend equivalents.

    Each keyword is looked up in ``corresponding_backend_arg`` and re-keyed under
    the backend name it maps to; the value attached to it is carried over untouched.
    A keyword with no entry in that mapping raises ``KeyError``.
    """
    translated: dict[str, Any] = {}
    for frontend_name, value in kwargs.items():
        translated[self.corresponding_backend_arg[frontend_name]] = value
    return translated
validate_compatible(obj)

Determine if obj provides a compatible backend method.

obj must provide a callable whose name matches self.backend_method, and the callable referenced must take arguments matching the names specified in self.compulsory_backend_args.

Parameters:

Name Type Description Default
obj object

Object to check possesses a method that can be translated into frontend syntax.

required
Source code in src/causalprog/utils/translator.py
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
def validate_compatible(self, obj: object) -> None:
    """
    Determine if ``obj`` provides a compatible backend method.

    ``obj`` must provide a callable whose name matches ``self.backend_method``,
    and the callable referenced must take arguments matching the names specified in
    ``self.compulsory_backend_args``. The callable must not *require* any other
    argument; variadic ``*args`` / ``**kwargs`` parameters are never required and
    so are ignored by that check.

    Args:
        obj (object): Object to check possesses a method that can be translated into
            frontend syntax.

    Raises:
        AttributeError: If ``obj`` has no attribute named ``self.backend_method``.
        TypeError: If that attribute is not callable, does not take one of the
            compulsory backend arguments, or requires an argument that is not one
            of the compulsory backend arguments.

    """
    # Check that obj does provide a method of matching name
    if not hasattr(obj, self.backend_method):
        msg = f"{obj} has no method '{self.backend_method}'."
        raise AttributeError(msg)
    method = getattr(obj, self.backend_method)
    if not callable(method):
        msg = f"'{self.backend_method}' attribute of {obj} is not callable."
        raise TypeError(msg)

    # Check that this method will be callable with the information given.
    method_params = inspect.signature(method).parameters
    compulsory_backend_args = self.compulsory_backend_args
    # The arguments that will be passed are actually taken by the method.
    for compulsory_arg in compulsory_backend_args:
        if compulsory_arg not in method_params:
            msg = (
                f"'{self.backend_method}' does not "
                f"take argument '{compulsory_arg}'."
            )
            raise TypeError(msg)
    # The method does not _require_ any additional arguments. ``*args`` and
    # ``**kwargs`` report an empty default but may always be omitted, so they
    # must not be counted as required.
    variadic_kinds = (
        inspect.Parameter.VAR_POSITIONAL,
        inspect.Parameter.VAR_KEYWORD,
    )
    method_requires = {
        name
        for name, p in method_params.items()
        if p.default is p.empty and p.kind not in variadic_kinds
    }
    args_not_accounted_for = method_requires - compulsory_backend_args
    if args_not_accounted_for:
        # Sorted so that the error message is deterministic.
        msg = (
            f"'{self.backend_method}' not provided compulsory arguments "
            "(missing " + ", ".join(sorted(args_not_accounted_for)) + ")"
        )
        raise TypeError(msg)