
API reference

causalprog package.

algorithms

Algorithms.

do

Algorithms for applying do to a graph.

do(graph, node, value, label=None)

Apply do to a graph.

Parameters:

- graph (Graph): The graph to apply do to. This will be copied. [required]
- node (str): The label of the node to apply do to. [required]
- value (float): The value to set the node to. [required]
- label (str | None): The label of the new graph. Default: None

Returns:

- Graph: A copy of the graph with do applied.

Source code in src/causalprog/algorithms/do.py
def do(graph: Graph, node: str, value: float, label: str | None = None) -> Graph:
    """
    Apply do to a graph.

    Args:
        graph: The graph to apply do to. This will be copied.
        node: The label of the node to apply do to.
        value: The value to set the node to.
        label: The label of the new graph

    Returns:
        A copy of the graph with do applied

    """
    if label is None:
        label = f"{graph.label}|do({node}={value})"

    nodes = {n.label: deepcopy(n) for n in graph.nodes if n.label != node}

    # Search through the old graph, identifying nodes that had parameters which were
    # defined by the node being fixed in the DO operation.
    # We recreate these nodes, but replace each such parameter we encounter with
    # a constant parameter that takes the fixed value given as an input.
    for n in nodes.values():
        params = tuple(n.parameters.keys())
        for parameter_name in params:
            if n.parameters[parameter_name] == node:
                # Swap the parameter to a constant parameter, giving it the fixed value
                n.constant_parameters[parameter_name] = value
                # Remove the parameter from the node's record of non-constant parameters
                n.parameters.pop(parameter_name)

    # Recursively remove nodes that are predecessors of removed nodes
    nodes_to_remove: tuple[str, ...] = (node,)
    while len(nodes_to_remove) > 0:
        nodes_to_remove = removable_nodes(graph, nodes)
        for n in removable_nodes(graph, nodes):
            nodes.pop(n)

    # Check for nodes that are predecessors of both a removed node and a remaining node
    # and throw an error if one of these is found
    for n in nodes:
        _, excluded = get_included_excluded_successors(graph, nodes, n)
        if len(excluded) > 0:
            msg = (
                "Node that is predecessor of node set by do and "
                f'nodes that are not removed found ("{n}")'
            )
            raise ValueError(msg)

    g = Graph(label=f"{label}|do[{node}={value}]")
    for n in nodes.values():
        g.add_node(n)

    # Any nodes whose counterparts connect to other nodes in the network need
    # to mimic these links.
    for edge in graph.edges:
        if edge[0].label in nodes and edge[1].label in nodes:
            g.add_edge(edge[0].label, edge[1].label)

    return g
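
A minimal usage sketch (not part of the library source): it assumes a causalprog Graph named g, containing a node labelled "X", has already been constructed; the node label and intervention value below are placeholders.

from causalprog.algorithms.do import do

# g is an existing causalprog Graph with a node labelled "X"
# (construction of g is omitted here).
intervened = do(g, "X", 1.0)  # copy of g in which "X" is fixed to the value 1.0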

get_included_excluded_successors(graph, node_list, successors_of)

Split successors of a node into nodes included and not included in a list.

Split the successors of a node into a list of nodes that are included in the input node list and a list of nodes that are not in the list.

Parameters:

- graph (Graph): The graph. [required]
- node_list (dict[str, Node]): A dictionary of nodes, indexed by label. [required]
- successors_of (str): The node to check the successors of. [required]

Returns:

- tuple[tuple[str, ...], tuple[str, ...]]: Lists of included and excluded nodes.

Source code in src/causalprog/algorithms/do.py
def get_included_excluded_successors(
    graph: Graph, node_list: dict[str, Node], successors_of: str
) -> tuple[tuple[str, ...], tuple[str, ...]]:
    """
    Split successors of a node into nodes included and not included in a list.

    Split the successors of a node into a list of nodes that are included in
    the input node list and a list of nodes that are not in the list.

    Args:
        graph: The graph
        node_list: A dictionary of nodes, indexed by label
        successors_of: The node to check the successors of

    Returns:
        Lists of included and excluded nodes

    """
    included = []
    excluded = []
    for n in graph.successors[graph.get_node(successors_of)]:
        if n.label in node_list:
            included.append(n)
        else:
            excluded.append(n)
    return tuple(included), tuple(excluded)

removable_nodes(graph, nodes)

Generate list of nodes that can be removed from the graph.

Parameters:

- graph (Graph): The graph. [required]
- nodes (dict[str, Node]): A dictionary of nodes, indexed by label. [required]

Returns:

- tuple[str, ...]: List of labels of removable nodes.

Source code in src/causalprog/algorithms/do.py
def removable_nodes(graph: Graph, nodes: dict[str, Node]) -> tuple[str, ...]:
    """
    Generate list of nodes that can be removed from the graph.

    Args:
        graph: The graph
        nodes: A dictionary of nodes, indexed by label

    Returns:
        List of labels of removable nodes

    """
    removable: list[str] = []
    for n in nodes:
        included, excluded = get_included_excluded_successors(graph, nodes, n)
        if len(excluded) > 0 and len(included) == 0:
            removable.append(n)
    return tuple(removable)

moments

Algorithms for estimating the expectation and standard deviation.

expectation(graph, outcome_node_label, samples, *, parameter_values=None, rng_key)

Estimate the expectation of (a random variable attached to) a node in a graph.

Source code in src/causalprog/algorithms/moments.py
def expectation(
    graph: Graph,
    outcome_node_label: str,
    samples: int,
    *,
    parameter_values: dict[str, float] | None = None,
    rng_key: jax.Array,
) -> float:
    """Estimate the expectation of (a random variable attached to) a node in a graph."""
    return moment(
        1,
        graph,
        outcome_node_label,
        samples,
        rng_key=rng_key,
        parameter_values=parameter_values,
    )

moment(order, graph, outcome_node_label, samples, *, parameter_values=None, rng_key)

Estimate a moment of (a random variable attached to) a node in a graph.

Source code in src/causalprog/algorithms/moments.py
def moment(
    order: int,
    graph: Graph,
    outcome_node_label: str,
    samples: int,
    *,
    parameter_values: dict[str, float] | None = None,
    rng_key: jax.Array,
) -> float:
    """Estimate a moment of (a random variable attached to) a node in a graph."""
    return (
        sum(
            sample(
                graph,
                outcome_node_label,
                samples,
                rng_key=rng_key,
                parameter_values=parameter_values,
            )
            ** order
        )
        / samples
    )
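
In other words, moment(k, ...) computes the empirical raw moment of order \(k\) from the \(N\) (= samples) draws \(x_1, \dots, x_N\) of the outcome node:

\[ \hat{m}_k = \frac{1}{N} \sum_{i=1}^{N} x_i^{k}. \]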

sample(graph, outcome_node_label, samples, *, parameter_values=None, rng_key)

Sample data from (a random variable attached to) a node in a graph.

Source code in src/causalprog/algorithms/moments.py
def sample(
    graph: Graph,
    outcome_node_label: str,
    samples: int,
    *,
    parameter_values: dict[str, float] | None = None,
    rng_key: jax.Array,
) -> npt.NDArray[float]:
    """Sample data from (a random variable attached to) a node in a graph."""
    nodes = graph.roots_down_to_outcome(outcome_node_label)

    values: dict[str, npt.NDArray[float]] = {}
    keys = jax.random.split(rng_key, len(nodes))

    for node, key in zip(nodes, keys, strict=False):
        values[node.label] = node.sample(
            parameter_values if parameter_values else {},
            values,
            samples,
            rng_key=key,
        )
    return values[outcome_node_label]

standard_deviation(graph, outcome_node_label, samples, *, parameter_values=None, rng_key, rng_key_first_moment=None)

Estimate the standard deviation of (a RV attached to) a node in a graph.

Source code in src/causalprog/algorithms/moments.py
def standard_deviation(
    graph: Graph,
    outcome_node_label: str,
    samples: int,
    *,
    parameter_values: dict[str, float] | None = None,
    rng_key: jax.Array,
    rng_key_first_moment: jax.Array | None = None,
) -> float:
    """Estimate the standard deviation of (a RV attached to) a node in a graph."""
    return (
        moment(
            2,
            graph,
            outcome_node_label,
            samples,
            rng_key=rng_key,
            parameter_values=parameter_values,
        )
        - moment(
            1,
            graph,
            outcome_node_label,
            samples,
            rng_key=rng_key if rng_key_first_moment is None else rng_key_first_moment,
            parameter_values=parameter_values,
        )
        ** 2
    ) ** 0.5
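
standard_deviation therefore combines the first two raw moments, returning \(\sqrt{\hat{m}_2 - \hat{m}_1^2}\). A minimal usage sketch (not part of the library source), assuming a causalprog Graph g with an outcome node labelled "Y" has already been constructed; the label and sample count are placeholders.

import jax

from causalprog.algorithms.moments import expectation, standard_deviation

key = jax.random.PRNGKey(0)
mean_estimate = expectation(g, "Y", 10_000, rng_key=key)
std_estimate = standard_deviation(g, "Y", 10_000, rng_key=key)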

backend

Helper functionality for incorporating different backends.

causal_problem

Classes for defining causal problems.

CausalEstimand

Bases: _CPComponent

A Causal Estimand.

The causal estimand is the function that we want to minimise (and maximise) as part of a causal problem. It should be a scalar-valued function of the random variables appearing in a graph.

Source code in src/causalprog/causal_problem/components.py
class CausalEstimand(_CPComponent):
    """
    A Causal Estimand.

    The causal estimand is the function that we want to minimise (and maximise)
    as part of a causal problem. It should be a scalar-valued function of the
    random variables appearing in a graph.
    """

CausalProblem

Defines a causal problem.

Source code in src/causalprog/causal_problem/causal_problem.py
class CausalProblem:
    """Defines a causal problem."""

    _underlying_graph: Graph
    causal_estimand: CausalEstimand
    constraints: list[Constraint]

    @property
    def _ordered_components(self) -> list[_CPComponent]:
        """Internal ordering for components of the causal problem."""
        return [*self.constraints, self.causal_estimand]

    def __init__(
        self,
        graph: Graph,
        *constraints: Constraint,
        causal_estimand: CausalEstimand,
    ) -> None:
        """Create a new causal problem."""
        self._underlying_graph = graph
        self.causal_estimand = causal_estimand
        self.constraints = list(constraints)

    def _associate_models_to_components(
        self, n_samples: int
    ) -> tuple[list[Predictive], list[int]]:
        """
        Create models to be used by components of the problem.

        Depending on how many constraints (and the causal estimand) require effect
        handlers to wrap `self._underlying_graph.model`, we will need to create several
        predictive models to sample from. However, we also want to minimise the number
        of such models we have to make, in order to minimise the time we spend
        actually computing samples.

        As such, in this method we determine:
        - How many models we will need to build, by grouping the constraints and the
          causal estimand by the handlers they use.
        - Build these models, returning them in a list called `models`.
        - Build another list that maps the index of components in
          `self._ordered_components` to the index of the model in `models` that they
          use. The causal estimand is by convention the component at index -1 of this
          returned list.

        Args:
            n_samples: Value to be passed to `numpyro.Predictive`'s `num_samples`
                argument for each of the models that are constructed from the underlying
                graph.

        Returns:
            list[Predictive]: List of Predictive models, whose elements contain all the
                models needed by the components.
            list[int]: Mapping of component indexes (as per `self._ordered_components`)
                to the index of the model in the first return argument that the
                component uses.

        """
        models: list[Predictive] = []
        grouped_component_indexes: list[list[int]] = []
        for index, component in enumerate(self._ordered_components):
            # Determine if this constraint uses the same handlers as those of any of
            # the other sets.
            belongs_to_existing_group = False
            for group in grouped_component_indexes:
                # Pull any element from the group to compare models to.
                # Items in a group are known to have the same model, so we can just
                # pull out the first one.
                group_element = self._ordered_components[group[0]]
                # Check if the current constraint can also use this model.
                if component.can_use_same_model_as(group_element):
                    group.append(index)
                    belongs_to_existing_group = True
                    break

            # If the component does not fit into any existing group, create a new
            # group for it. And add the model corresponding to the group to the
            # list of models.
            if not belongs_to_existing_group:
                grouped_component_indexes.append([index])

                models.append(
                    Predictive(
                        component.apply_effects(self._underlying_graph.model),
                        num_samples=n_samples,
                    )
                )

        # Now "invert" the grouping, creating a mapping that maps the index of a
        # component to the (index of the) model it uses.
        component_index_to_model_index = []
        for index in range(len(self._ordered_components)):
            for group_index, group in enumerate(grouped_component_indexes):
                if index in group:
                    component_index_to_model_index.append(group_index)
                    break
        # All indexes should belong to at least one group (worst case scenario,
        # their own individual group). Thus, it is safe to do the above to create
        # the mapping from component index -> model (group) index.
        return models, component_index_to_model_index

    def lagrangian(
        self, n_samples: int = 1000, *, maximum_problem: bool = False
    ) -> Callable[[dict[str, npt.ArrayLike], npt.ArrayLike, jax.Array], npt.ArrayLike]:
        """
        Return a function that evaluates the Lagrangian of this `CausalProblem`.

        Following the
        [KKT theorem](https://en.wikipedia.org/wiki/Karush%E2%80%93Kuhn%E2%80%93Tucker_conditions),
        given the causal estimand and the constraints we can assemble a Lagrangian and
        seek its stationary points, to in turn identify minimisers of the constrained
        optimisation problem that we started with.

        The Lagrangian returned is a mathematical function of its first two arguments.
        The first argument is the same dictionary of parameters that is passed to models
        like `Graph.model`, and is the values the parameters (represented by the
        `ParameterNode`s) are taking. The second argument is a 1D vector of Lagrange
        multipliers, whose length is equal to the number of constraints.

        The remaining argument of the Lagrangian is the PRNG Key that should be used
        when drawing samples.

        Note that our current implementation assumes there are no equality constraints
        being imposed (in which case, we would need a 3-argument Lagrangian function).

        Args:
            n_samples: The number of random samples to be drawn when estimating the
                value of functions of the RVs.
            maximum_problem: If passed as `True`, assemble the Lagrangian for the
                maximisation problem. Otherwise assemble that for the minimisation
                problem (default behaviour).

        Returns:
            The Lagrangian, as a function of the model parameters, Lagrange multipliers,
                and PRNG key.

        """
        maximisation_prefactor = -1.0 if maximum_problem else 1.0

        # Build association between self.constraints and the model-samples that each
        # one needs to use. We do this here, since once it is constructed, it is
        # fixed, and doesn't need to be done each time we call the Lagrangian.
        models, component_to_index_mapping = self._associate_models_to_components(
            n_samples
        )

        def _inner(
            parameter_values: dict[str, npt.ArrayLike],
            l_mult: jax.Array,
            rng_key: jax.Array,
        ) -> npt.ArrayLike:
            # Draw samples from all models
            all_samples = tuple(
                sample_model(model, rng_key, parameter_values) for model in models
            )

            value = maximisation_prefactor * self.causal_estimand(all_samples[-1])
            # TODO: https://github.com/UCL/causalprog/issues/87
            value += sum(
                l_mult[i] * c(all_samples[component_to_index_mapping[i]])
                for i, c in enumerate(self.constraints)
            )
            return value

        return _inner

__init__(graph, *constraints, causal_estimand)

Create a new causal problem.

Source code in src/causalprog/causal_problem/causal_problem.py
def __init__(
    self,
    graph: Graph,
    *constraints: Constraint,
    causal_estimand: CausalEstimand,
) -> None:
    """Create a new causal problem."""
    self._underlying_graph = graph
    self.causal_estimand = causal_estimand
    self.constraints = list(constraints)

lagrangian(n_samples=1000, *, maximum_problem=False)

Return a function that evaluates the Lagrangian of this CausalProblem.

Following the KKT theorem, given the causal estimand and the constraints we can assemble a Lagrangian and seek its stationary points, to in turn identify minimisers of the constrained optimisation problem that we started with.

The Lagrangian returned is a mathematical function of its first two arguments. The first argument is the same dictionary of parameters that is passed to models like Graph.model, and is the values the parameters (represented by the ParameterNodes) are taking. The second argument is a 1D vector of Lagrange multipliers, whose length is equal to the number of constraints.

The remaining argument of the Lagrangian is the PRNG Key that should be used when drawing samples.

Note that our current implementation assumes there are no equality constraints being imposed (in which case, we would need a 3-argument Lagrangian function).

Parameters:

- n_samples (int): The number of random samples to be drawn when estimating the value of functions of the RVs. Default: 1000
- maximum_problem (bool): If passed as True, assemble the Lagrangian for the maximisation problem. Otherwise assemble that for the minimisation problem (default behaviour). Default: False

Returns:

- Callable[[dict[str, ArrayLike], ArrayLike, Array], ArrayLike]: The Lagrangian, as a function of the model parameters, Lagrange multipliers, and PRNG key.

Source code in src/causalprog/causal_problem/causal_problem.py
def lagrangian(
    self, n_samples: int = 1000, *, maximum_problem: bool = False
) -> Callable[[dict[str, npt.ArrayLike], npt.ArrayLike, jax.Array], npt.ArrayLike]:
    """
    Return a function that evaluates the Lagrangian of this `CausalProblem`.

    Following the
    [KKT theorem](https://en.wikipedia.org/wiki/Karush%E2%80%93Kuhn%E2%80%93Tucker_conditions),
    given the causal estimand and the constraints we can assemble a Lagrangian and
    seek its stationary points, to in turn identify minimisers of the constrained
    optimisation problem that we started with.

    The Lagrangian returned is a mathematical function of its first two arguments.
    The first argument is the same dictionary of parameters that is passed to models
    like `Graph.model`, and is the values the parameters (represented by the
    `ParameterNode`s) are taking. The second argument is a 1D vector of Lagrange
    multipliers, whose length is equal to the number of constraints.

    The remaining argument of the Lagrangian is the PRNG Key that should be used
    when drawing samples.

    Note that our current implementation assumes there are no equality constraints
    being imposed (in which case, we would need a 3-argument Lagrangian function).

    Args:
        n_samples: The number of random samples to be drawn when estimating the
            value of functions of the RVs.
        maximum_problem: If passed as `True`, assemble the Lagrangian for the
            maximisation problem. Otherwise assemble that for the minimisation
            problem (default behaviour).

    Returns:
        The Lagrangian, as a function of the model parameters, Lagrange multipliers,
            and PRNG key.

    """
    maximisation_prefactor = -1.0 if maximum_problem else 1.0

    # Build association between self.constraints and the model-samples that each
    # one needs to use. We do this here, since once it is constructed, it is
    # fixed, and doesn't need to be done each time we call the Lagrangian.
    models, component_to_index_mapping = self._associate_models_to_components(
        n_samples
    )

    def _inner(
        parameter_values: dict[str, npt.ArrayLike],
        l_mult: jax.Array,
        rng_key: jax.Array,
    ) -> npt.ArrayLike:
        # Draw samples from all models
        all_samples = tuple(
            sample_model(model, rng_key, parameter_values) for model in models
        )

        value = maximisation_prefactor * self.causal_estimand(all_samples[-1])
        # TODO: https://github.com/UCL/causalprog/issues/87
        value += sum(
            l_mult[i] * c(all_samples[component_to_index_mapping[i]])
            for i, c in enumerate(self.constraints)
        )
        return value

    return _inner
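
A minimal sketch of assembling a problem and evaluating its Lagrangian (not part of the library source). It assumes a causalprog Graph named graph, a CausalEstimand named estimand and a Constraint named constraint have already been built; the parameter name "mu" is a placeholder for a ParameterNode label in the graph.

import jax
import jax.numpy as jnp

from causalprog.causal_problem.causal_problem import CausalProblem

problem = CausalProblem(graph, constraint, causal_estimand=estimand)
lagrangian = problem.lagrangian(n_samples=500)

parameter_values = {"mu": 0.0}               # values for the graph's ParameterNodes
l_mult = jnp.ones(len(problem.constraints))  # one Lagrange multiplier per constraint
value = lagrangian(parameter_values, l_mult, jax.random.PRNGKey(0))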

Constraint

Bases: _CPComponent

A Constraint that forms part of a causal problem.

Constraints of a causal problem are derived properties of RVs for which we have observed data. The causal estimand is minimised (or maximised) subject to the predicted values of the constraints being close to their observed values in the data.

Adding a constraint \(g(\theta)\) to a causal problem (where \(\theta\) are the parameters of the causal problem) essentially imposes an additional constraint on the minimisation problem;

\[ g(\theta) - g_{\text{data}} \leq \epsilon, \]

where \(g_{\text{data}}\) is the observed data value for the quantity \(g\), and \(\epsilon\) is some tolerance.

Source code in src/causalprog/causal_problem/components.py
class Constraint(_CPComponent):
    r"""
    A Constraint that forms part of a causal problem.

    Constraints of a causal problem are derived properties of RVs for which we
    have observed data. The causal estimand is minimised (or maximised) subject
    to the predicted values of the constraints being close to their observed
    values in the data.

    Adding a constraint $g(\theta)$ to a causal problem (where $\theta$ are the
    parameters of the causal problem) essentially imposes an additional
    constraint on the minimisation problem;

    $$ g(\theta) - g_{\text{data}} \leq \epsilon, $$

    where $g_{\text{data}}$ is the observed data value for the quantity $g$,
    and $\epsilon$ is some tolerance.
    """

    data: npt.ArrayLike
    tolerance: npt.ArrayLike
    _outer_norm: Callable[[npt.ArrayLike], float]

    def __init__(
        self,
        *effect_handlers: ModelMask,
        model_quantity: Callable[..., npt.ArrayLike],
        outer_norm: Callable[[npt.ArrayLike], float] | None = None,
        data: npt.ArrayLike = 0.0,
        tolerance: float = 1.0e-6,
    ) -> None:
        r"""
        Create a new constraint.

        Constraints have the form

        $$ c(\theta) :=
        \mathrm{norm}\left( g(\theta)
        - g_{\mathrm{data}} \right)
        - \epsilon $$

        where;
        - $\mathrm{norm}$ is the outer norm of the constraint (`outer_norm`),
        - $g(\theta)$ is the model quantity involved in the constraint
            (`model_quantity`),
        - $g_{\mathrm{data}}$ is the observed data (`data`),
        - $\epsilon$ is the tolerance in the data (`tolerance`).

        In a causal problem, each constraint appears as the condition $c(\theta)\leq 0$
        in the minimisation / maximisation (hence the inclusion of the $-\epsilon$
        term within $c(\theta)$ itself).

        $g$ should be a (possibly vector-valued) function that acts on (a subset of)
        samples from the random variables of the causal problem. It must accept
        variable keyword-arguments only, and should access the samples for each random
        variable by indexing via the RV names (node labels). It should return the
        model quantity as computed from the samples, that $g_{\mathrm{data}}$ observed.

        $g_{\mathrm{data}}$ should be a fixed value whose shape is broadcast-able with
        the return shape of $g$. It defaults to $0$ if not explicitly set.

        $\mathrm{norm}$ should be a suitable norm to take on the difference between the
        model quantity as predicted by the samples ($g$) and the observed data
        ($g_{\mathrm{data}}$). It must return a scalar value. The default is the 2-norm.
        """
        super().__init__(*effect_handlers, do_with_samples=model_quantity)

        if outer_norm is None:
            self._outer_norm = jnp.linalg.vector_norm
        else:
            self._outer_norm = outer_norm

        self.data = data
        self.tolerance = tolerance

    def __call__(self, samples: dict[str, npt.ArrayLike]) -> npt.ArrayLike:
        """
        Evaluate the constraint, given RV samples.

        Args:
            samples: Mapping of RV (node) labels to drawn samples.

        Returns:
            Value of the constraint.

        """
        return (
            self._outer_norm(self._do_with_samples(**samples) - self.data)
            - self.tolerance
        )

__call__(samples)

Evaluate the constraint, given RV samples.

Parameters:

- samples (dict[str, ArrayLike]): Mapping of RV (node) labels to drawn samples. [required]

Returns:

- ArrayLike: Value of the constraint.

Source code in src/causalprog/causal_problem/components.py
def __call__(self, samples: dict[str, npt.ArrayLike]) -> npt.ArrayLike:
    """
    Evaluate the constraint, given RV samples.

    Args:
        samples: Mapping of RV (node) labels to drawn samples.

    Returns:
        Value of the constraint.

    """
    return (
        self._outer_norm(self._do_with_samples(**samples) - self.data)
        - self.tolerance
    )

__init__(*effect_handlers, model_quantity, outer_norm=None, data=0.0, tolerance=1e-06)

Create a new constraint.

Constraints have the form

\[ c(\theta) := \mathrm{norm}\left( g(\theta) - g_{\mathrm{data}} \right) - \epsilon \]

where;

- \(\mathrm{norm}\) is the outer norm of the constraint (outer_norm),
- \(g(\theta)\) is the model quantity involved in the constraint (model_quantity),
- \(g_{\mathrm{data}}\) is the observed data (data),
- \(\epsilon\) is the tolerance in the data (tolerance).

In a causal problem, each constraint appears as the condition \(c(\theta)\leq 0\) in the minimisation / maximisation (hence the inclusion of the \(-\epsilon\) term within \(c(\theta)\) itself).

\(g\) should be a (possibly vector-valued) function that acts on (a subset of) samples from the random variables of the causal problem. It must accept variable keyword-arguments only, and should access the samples for each random variable by indexing via the RV names (node labels). It should return the model quantity, computed from the samples, that corresponds to the observed value \(g_{\mathrm{data}}\).

\(g_{\mathrm{data}}\) should be a fixed value whose shape is broadcast-able with the return shape of \(g\). It defaults to \(0\) if not explicitly set.

\(\mathrm{norm}\) should be a suitable norm to take on the difference between the model quantity as predicted by the samples (\(g\)) and the observed data (\(g_{\mathrm{data}}\)). It must return a scalar value. The default is the 2-norm.

Source code in src/causalprog/causal_problem/components.py
def __init__(
    self,
    *effect_handlers: ModelMask,
    model_quantity: Callable[..., npt.ArrayLike],
    outer_norm: Callable[[npt.ArrayLike], float] | None = None,
    data: npt.ArrayLike = 0.0,
    tolerance: float = 1.0e-6,
) -> None:
    r"""
    Create a new constraint.

    Constraints have the form

    $$ c(\theta) :=
    \mathrm{norm}\left( g(\theta)
    - g_{\mathrm{data}} \right)
    - \epsilon $$

    where;
    - $\mathrm{norm}$ is the outer norm of the constraint (`outer_norm`),
    - $g(\theta)$ is the model quantity involved in the constraint
        (`model_quantity`),
    - $g_{\mathrm{data}}$ is the observed data (`data`),
    - $\epsilon$ is the tolerance in the data (`tolerance`).

    In a causal problem, each constraint appears as the condition $c(\theta)\leq 0$
    in the minimisation / maximisation (hence the inclusion of the $-\epsilon$
    term within $c(\theta)$ itself).

    $g$ should be a (possibly vector-valued) function that acts on (a subset of)
    samples from the random variables of the causal problem. It must accept
    variable keyword-arguments only, and should access the samples for each random
    variable by indexing via the RV names (node labels). It should return the
    model quantity as computed from the samples, that $g_{\mathrm{data}}$ observed.

    $g_{\mathrm{data}}$ should be a fixed value whose shape is broadcast-able with
    the return shape of $g$. It defaults to $0$ if not explicitly set.

    $\mathrm{norm}$ should be a suitable norm to take on the difference between the
    model quantity as predicted by the samples ($g$) and the observed data
    ($g_{\mathrm{data}}$). It must return a scalar value. The default is the 2-norm.
    """
    super().__init__(*effect_handlers, do_with_samples=model_quantity)

    if outer_norm is None:
        self._outer_norm = jnp.linalg.vector_norm
    else:
        self._outer_norm = outer_norm

    self.data = data
    self.tolerance = tolerance
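
A minimal sketch of a constraint (not part of the library source) that ties the sample mean of a random variable labelled "Y" to an observed value of 0.5; the node label, data value and tolerance are placeholders.

import jax.numpy as jnp

from causalprog.causal_problem.components import Constraint

constraint = Constraint(
    # g(theta): keyword-only access to samples, indexed by node label
    model_quantity=lambda **samples: jnp.mean(samples["Y"]),
    data=0.5,        # g_data, the observed value
    tolerance=1e-3,  # epsilon
)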

HandlerToApply dataclass

Specifies a handler that needs to be applied to a model at runtime.

Source code in src/causalprog/causal_problem/handlers.py
@dataclass
class HandlerToApply:
    """Specifies a handler that needs to be applied to a model at runtime."""

    handler: EffectHandler
    options: dict[str, Any] = field(default_factory=dict)

    @classmethod
    def from_pair(cls, pair: Sequence) -> "HandlerToApply":
        """
        Create an instance from an effect handler and its options.

        The two objects should be passed in as the elements of a container of length
        2. They can be passed in any order;
        - One element must be a dictionary, which will be interpreted as the `options`
            for the effect handler.
        - The other element must be callable, and will be interpreted as the `handler`
            itself.

        Args:
            pair: Container of two elements, one being the effect handler callable and
                the other being the options to pass to it (as a dictionary).

        Returns:
            Class instance corresponding to the effect handler and options passed.

        """
        if len(pair) != 2:  # noqa: PLR2004
            msg = (
                f"{cls.__name__} can only be constructed from a container of 2 elements"
            )
            raise ValueError(msg)

        # __post_init__ will catch cases when the incorrect types for one or both items
        # is passed, so we can just naively if-else here.
        handler: EffectHandler
        options: dict
        if callable(pair[0]):
            handler = pair[0]
            options = pair[1]
        else:
            handler = pair[1]
            options = pair[0]

        return cls(handler=handler, options=options)

    def __post_init__(self) -> None:
        """
        Validate set attributes.

        - The handler is a callable object.
        - The options have been passed as a dictionary of keyword-value pairs.
        """
        if not callable(self.handler):
            msg = f"{type(self.handler).__name__} is not callable."
            raise TypeError(msg)
        if not isinstance(self.options, dict):
            msg = (
                "Options should be dictionary mapping option arguments to values "
                f"(got {type(self.options).__name__})."
            )
            raise TypeError(msg)

    def __eq__(self, other: object) -> bool:
        """
        Equality operation.

        `HandlerToApply`s are considered equal if they use the same handler function and
        provide the same options to this function.

        Comparison to other types returns `False`.
        """
        return (
            isinstance(other, HandlerToApply)
            and self.handler is other.handler
            and self.options == other.options
        )

__eq__(other)

Equality operation.

HandlerToApply instances are considered equal if they use the same handler function and provide the same options to this function.

Comparison to other types returns False.

Source code in src/causalprog/causal_problem/handlers.py
def __eq__(self, other: object) -> bool:
    """
    Equality operation.

    `HandlerToApply`s are considered equal if they use the same handler function and
    provide the same options to this function.

    Comparison to other types returns `False`.
    """
    return (
        isinstance(other, HandlerToApply)
        and self.handler is other.handler
        and self.options == other.options
    )

__post_init__()

Validate set attributes.

  • The handler is a callable object.
  • The options have been passed as a dictionary of keyword-value pairs.

Source code in src/causalprog/causal_problem/handlers.py
def __post_init__(self) -> None:
    """
    Validate set attributes.

    - The handler is a callable object.
    - The options have been passed as a dictionary of keyword-value pairs.
    """
    if not callable(self.handler):
        msg = f"{type(self.handler).__name__} is not callable."
        raise TypeError(msg)
    if not isinstance(self.options, dict):
        msg = (
            "Options should be dictionary mapping option arguments to values "
            f"(got {type(self.options).__name__})."
        )
        raise TypeError(msg)

from_pair(pair) classmethod

Create an instance from an effect handler and its options.

The two objects should be passed in as the elements of a container of length 2. They can be passed in any order;

- One element must be a dictionary, which will be interpreted as the options for the effect handler.
- The other element must be callable, and will be interpreted as the handler itself.

Parameters:

- pair (Sequence): Container of two elements, one being the effect handler callable and the other being the options to pass to it (as a dictionary). [required]

Returns:

- HandlerToApply: Class instance corresponding to the effect handler and options passed.

Source code in src/causalprog/causal_problem/handlers.py
@classmethod
def from_pair(cls, pair: Sequence) -> "HandlerToApply":
    """
    Create an instance from an effect handler and its options.

    The two objects should be passed in as the elements of a container of length
    2. They can be passed in any order;
    - One element must be a dictionary, which will be interpreted as the `options`
        for the effect handler.
    - The other element must be callable, and will be interpreted as the `handler`
        itself.

    Args:
        pair: Container of two elements, one being the effect handler callable and
            the other being the options to pass to it (as a dictionary).

    Returns:
        Class instance corresponding to the effect handler and options passed.

    """
    if len(pair) != 2:  # noqa: PLR2004
        msg = (
            f"{cls.__name__} can only be constructed from a container of 2 elements"
        )
        raise ValueError(msg)

    # __post_init__ will catch cases when the incorrect types for one or both items
    # is passed, so we can just naively if-else here.
    handler: EffectHandler
    options: dict
    if callable(pair[0]):
        handler = pair[0]
        options = pair[1]
    else:
        handler = pair[1]
        options = pair[0]

    return cls(handler=handler, options=options)
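
A minimal sketch (not part of the library source), using numpyro's do effect handler as the wrapped callable; the intervention data is a placeholder.

from numpyro.handlers import do as numpyro_do

from causalprog.causal_problem.handlers import HandlerToApply

handler = HandlerToApply.from_pair((numpyro_do, {"data": {"X": 1.0}}))
# The order of the pair does not matter: the dictionary is taken as the
# options, the callable as the handler.
same = HandlerToApply.from_pair(({"data": {"X": 1.0}}, numpyro_do))
assert handler == same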

causal_problem

Classes for representing causal problems.

sample_model(model, rng_key, parameter_values)

Draw samples from the predictive model.

Parameters:

- model (Predictive): Predictive model to draw samples from. [required]
- rng_key (Array): PRNG Key to use in pseudorandom number generation. [required]
- parameter_values (dict[str, ArrayLike]): Model parameter values to substitute. [required]

Returns:

- dict[str, ArrayLike]: dict of samples, with RV labels as keys and sample values (jax.Arrays) as values.

Source code in src/causalprog/causal_problem/causal_problem.py
def sample_model(
    model: Predictive, rng_key: jax.Array, parameter_values: dict[str, npt.ArrayLike]
) -> dict[str, npt.ArrayLike]:
    """
    Draw samples from the predictive model.

    Args:
        model: Predictive model to draw samples from.
        rng_key: PRNG Key to use in pseudorandom number generation.
        parameter_values: Model parameter values to substitute.

    Returns:
        `dict` of samples, with RV labels as keys and sample values (`jax.Array`s) as
            values.

    """
    return jax.vmap(lambda pv, key: model(key, **pv), in_axes=(None, 0))(
        parameter_values,
        jax.random.split(rng_key, model.num_samples),
    )
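
A minimal sketch (not part of the library source), assuming a causalprog Graph named graph has already been built; "mu" is a placeholder for one of its ParameterNode labels.

import jax
from numpyro.infer import Predictive

from causalprog.causal_problem.causal_problem import sample_model

model = Predictive(graph.model, num_samples=1000)
samples = sample_model(model, jax.random.PRNGKey(0), {"mu": 0.0})
# samples maps node labels to arrays of drawn values, one entry per node.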

components

Classes for defining causal estimands and constraints of causal problems.

CausalEstimand

Bases: _CPComponent

A Causal Estimand.

The causal estimand is the function that we want to minimise (and maximise) as part of a causal problem. It should be a scalar-valued function of the random variables appearing in a graph.

Source code in src/causalprog/causal_problem/components.py
16
17
18
19
20
21
22
23
class CausalEstimand(_CPComponent):
    """
    A Causal Estimand.

    The causal estimand is the function that we want to minimise (and maximise)
    as part of a causal problem. It should be a scalar-valued function of the
    random variables appearing in a graph.
    """

Constraint

Bases: _CPComponent

A Constraint that forms part of a causal problem.

Constraints of a causal problem are derived properties of RVs for which we have observed data. The causal estimand is minimised (or maximised) subject to the predicted values of the constraints being close to their observed values in the data.

Adding a constraint \(g(\theta)\) to a causal problem (where \(\theta\) are the parameters of the causal problem) essentially imposes an additional constraint on the minimisation problem;

\[ g(\theta) - g_{\text{data}} \leq \epsilon, \]

where \(g_{\text{data}}\) is the observed data value for the quantity \(g\), and \(\epsilon\) is some tolerance.

Source code in src/causalprog/causal_problem/components.py
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
class Constraint(_CPComponent):
    r"""
    A Constraint that forms part of a causal problem.

    Constraints of a causal problem are derived properties of RVs for which we
    have observed data. The causal estimand is minimised (or maximised) subject
    to the predicted values of the constraints being close to their observed
    values in the data.

    Adding a constraint $g(\theta)$ to a causal problem (where $\theta$ are the
    parameters of the causal problem) essentially imposes an additional
    constraint on the minimisation problem;

    $$ g(\theta) - g_{\text{data}} \leq \epsilon, $$

    where $g_{\text{data}}$ is the observed data value for the quantity $g$,
    and $\epsilon$ is some tolerance.
    """

    data: npt.ArrayLike
    tolerance: npt.ArrayLike
    _outer_norm: Callable[[npt.ArrayLike], float]

    def __init__(
        self,
        *effect_handlers: ModelMask,
        model_quantity: Callable[..., npt.ArrayLike],
        outer_norm: Callable[[npt.ArrayLike], float] | None = None,
        data: npt.ArrayLike = 0.0,
        tolerance: float = 1.0e-6,
    ) -> None:
        r"""
        Create a new constraint.

        Constraints have the form

        $$ c(\theta) :=
        \mathrm{norm}\left( g(\theta)
        - g_{\mathrm{data}} \right)
        - \epsilon $$

        where;
        - $\mathrm{norm}$ is the outer norm of the constraint (`outer_norm`),
        - $g(\theta)$ is the model quantity involved in the constraint
            (`model_quantity`),
        - $g_{\mathrm{data}}$ is the observed data (`data`),
        - $\epsilon$ is the tolerance in the data (`tolerance`).

        In a causal problem, each constraint appears as the condition $c(\theta)\leq 0$
        in the minimisation / maximisation (hence the inclusion of the $-\epsilon$
        term within $c(\theta)$ itself).

        $g$ should be a (possibly vector-valued) function that acts on (a subset of)
        samples from the random variables of the causal problem. It must accept
        variable keyword-arguments only, and should access the samples for each random
        variable by indexing via the RV names (node labels). It should return the
        model quantity, computed from the samples, whose observed value is $g_{\mathrm{data}}$.

        $g_{\mathrm{data}}$ should be a fixed value whose shape is broadcast-able with
        the return shape of $g$. It defaults to $0$ if not explicitly set.

        $\mathrm{norm}$ should be a suitable norm to take on the difference between the
        model quantity as predicted by the samples ($g$) and the observed data
        ($g_{\mathrm{data}}$). It must return a scalar value. The default is the 2-norm.
        """
        super().__init__(*effect_handlers, do_with_samples=model_quantity)

        if outer_norm is None:
            self._outer_norm = jnp.linalg.vector_norm
        else:
            self._outer_norm = outer_norm

        self.data = data
        self.tolerance = tolerance

    def __call__(self, samples: dict[str, npt.ArrayLike]) -> npt.ArrayLike:
        """
        Evaluate the constraint, given RV samples.

        Args:
            samples: Mapping of RV (node) labels to drawn samples.

        Returns:
            Value of the constraint.

        """
        return (
            self._outer_norm(self._do_with_samples(**samples) - self.data)
            - self.tolerance
        )
__call__(samples)

Evaluate the constraint, given RV samples.

Parameters:

Name Type Description Default
samples dict[str, ArrayLike]

Mapping of RV (node) labels to drawn samples.

required

Returns:

Type Description
ArrayLike

Value of the constraint.

Source code in src/causalprog/causal_problem/components.py
def __call__(self, samples: dict[str, npt.ArrayLike]) -> npt.ArrayLike:
    """
    Evaluate the constraint, given RV samples.

    Args:
        samples: Mapping of RV (node) labels to drawn samples.

    Returns:
        Value of the constraint.

    """
    return (
        self._outer_norm(self._do_with_samples(**samples) - self.data)
        - self.tolerance
    )
__init__(*effect_handlers, model_quantity, outer_norm=None, data=0.0, tolerance=1e-06)

Create a new constraint.

Constraints have the form

\[ c(\theta) := \mathrm{norm}\left( g(\theta) - g_{\mathrm{data}} \right) - \epsilon \]

where:

  • \(\mathrm{norm}\) is the outer norm of the constraint (outer_norm),
  • \(g(\theta)\) is the model quantity involved in the constraint (model_quantity),
  • \(g_{\mathrm{data}}\) is the observed data (data),
  • \(\epsilon\) is the tolerance in the data (tolerance).

In a causal problem, each constraint appears as the condition \(c(\theta)\leq 0\) in the minimisation / maximisation (hence the inclusion of the \(-\epsilon\) term within \(c(\theta)\) itself).

\(g\) should be a (possibly vector-valued) function that acts on (a subset of) samples from the random variables of the causal problem. It must accept variable keyword-arguments only, and should access the samples for each random variable by indexing via the RV names (node labels). It should return the model quantity, computed from the samples, whose observed value is \(g_{\mathrm{data}}\).

\(g_{\mathrm{data}}\) should be a fixed value whose shape is broadcast-able with the return shape of \(g\). It defaults to \(0\) if not explicitly set.

\(\mathrm{norm}\) should be a suitable norm to take on the difference between the model quantity as predicted by the samples (\(g\)) and the observed data (\(g_{\mathrm{data}}\)). It must return a scalar value. The default is the 2-norm.
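
As a minimal sketch of constructing and evaluating a constraint (the RV label "Y", the observed value and the tolerance below are illustrative assumptions, and the import path is inferred from the source location shown underneath):

import jax.numpy as jnp

from causalprog.causal_problem.components import Constraint

# Constrain the sample mean of the RV labelled "Y" to lie within 1e-3 of an
# observed value of 1.5.
mean_of_y = Constraint(
    model_quantity=lambda **samples: jnp.mean(samples["Y"]),
    data=1.5,
    tolerance=1.0e-3,
)

# Evaluating on drawn samples returns norm(g - g_data) - tolerance;
# a value <= 0 means the constraint is satisfied.
value = mean_of_y({"Y": jnp.array([1.4, 1.5, 1.6])})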

Source code in src/causalprog/causal_problem/components.py
def __init__(
    self,
    *effect_handlers: ModelMask,
    model_quantity: Callable[..., npt.ArrayLike],
    outer_norm: Callable[[npt.ArrayLike], float] | None = None,
    data: npt.ArrayLike = 0.0,
    tolerance: float = 1.0e-6,
) -> None:
    r"""
    Create a new constraint.

    Constraints have the form

    $$ c(\theta) :=
    \mathrm{norm}\left( g(\theta)
    - g_{\mathrm{data}} \right)
    - \epsilon $$

    where;
    - $\mathrm{norm}$ is the outer norm of the constraint (`outer_norm`),
    - $g(\theta)$ is the model quantity involved in the constraint
        (`model_quantity`),
    - $g_{\mathrm{data}}$ is the observed data (`data`),
    - $\epsilon$ is the tolerance in the data (`tolerance`).

    In a causal problem, each constraint appears as the condition $c(\theta)\leq 0$
    in the minimisation / maximisation (hence the inclusion of the $-\epsilon$
    term within $c(\theta)$ itself).

    $g$ should be a (possibly vector-valued) function that acts on (a subset of)
    samples from the random variables of the causal problem. It must accept
    variable keyword-arguments only, and should access the samples for each random
    variable by indexing via the RV names (node labels). It should return the
    model quantity, computed from the samples, whose observed value is $g_{\mathrm{data}}$.

    $g_{\mathrm{data}}$ should be a fixed value whose shape is broadcast-able with
    the return shape of $g$. It defaults to $0$ if not explicitly set.

    $\mathrm{norm}$ should be a suitable norm to take on the difference between the
    model quantity as predicted by the samples ($g$) and the observed data
    ($g_{\mathrm{data}}$). It must return a scalar value. The default is the 2-norm.
    """
    super().__init__(*effect_handlers, do_with_samples=model_quantity)

    if outer_norm is None:
        self._outer_norm = jnp.linalg.vector_norm
    else:
        self._outer_norm = outer_norm

    self.data = data
    self.tolerance = tolerance

handlers

Container class for specifying effect handlers that need to be applied at runtime.

HandlerToApply dataclass

Specifies a handler that needs to be applied to a model at runtime.
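
As a minimal sketch of building an instance (my_handler below is a hypothetical stand-in for a real effect handler, and the import path is inferred from the source location shown underneath):

from causalprog.causal_problem.handlers import HandlerToApply

def my_handler(model, **options):
    # Hypothetical effect handler used purely for illustration.
    return model

# The callable and its options dictionary may be supplied in either order;
# both calls construct equal instances.
h1 = HandlerToApply.from_pair((my_handler, {"strength": 0.5}))
h2 = HandlerToApply.from_pair(({"strength": 0.5}, my_handler))
assert h1 == h2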

Source code in src/causalprog/causal_problem/handlers.py
@dataclass
class HandlerToApply:
    """Specifies a handler that needs to be applied to a model at runtime."""

    handler: EffectHandler
    options: dict[str, Any] = field(default_factory=dict)

    @classmethod
    def from_pair(cls, pair: Sequence) -> "HandlerToApply":
        """
        Create an instance from an effect handler and its options.

        The two objects should be passed in as the elements of a container of length
        2. They can be passed in any order;
        - One element must be a dictionary, which will be interpreted as the `options`
            for the effect handler.
        - The other element must be callable, and will be interpreted as the `handler`
            itself.

        Args:
            pair: Container of two elements, one being the effect handler callable and
                the other being the options to pass to it (as a dictionary).

        Returns:
            Class instance corresponding to the effect handler and options passed.

        """
        if len(pair) != 2:  # noqa: PLR2004
            msg = (
                f"{cls.__name__} can only be constructed from a container of 2 elements"
            )
            raise ValueError(msg)

        # __post_init__ will catch cases when the incorrect types for one or both items
        # is passed, so we can just naively if-else here.
        handler: EffectHandler
        options: dict
        if callable(pair[0]):
            handler = pair[0]
            options = pair[1]
        else:
            handler = pair[1]
            options = pair[0]

        return cls(handler=handler, options=options)

    def __post_init__(self) -> None:
        """
        Validate set attributes.

        - The handler is a callable object.
        - The options have been passed as a dictionary of keyword-value pairs.
        """
        if not callable(self.handler):
            msg = f"{type(self.handler).__name__} is not callable."
            raise TypeError(msg)
        if not isinstance(self.options, dict):
            msg = (
                "Options should be dictionary mapping option arguments to values "
                f"(got {type(self.options).__name__})."
            )
            raise TypeError(msg)

    def __eq__(self, other: object) -> bool:
        """
        Equality operation.

        `HandlerToApply`s are considered equal if they use the same handler function and
        provide the same options to this function.

        Comparison to other types returns `False`.
        """
        return (
            isinstance(other, HandlerToApply)
            and self.handler is other.handler
            and self.options == other.options
        )
__eq__(other)

Equality operation.

HandlerToApplys are considered equal if they use the same handler function and provide the same options to this function.

Comparison to other types returns False.

Source code in src/causalprog/causal_problem/handlers.py
def __eq__(self, other: object) -> bool:
    """
    Equality operation.

    `HandlerToApply`s are considered equal if they use the same handler function and
    provide the same options to this function.

    Comparison to other types returns `False`.
    """
    return (
        isinstance(other, HandlerToApply)
        and self.handler is other.handler
        and self.options == other.options
    )
__post_init__()

Validate set attributes.

  • The handler is a callable object.
  • The options have been passed as a dictionary of keyword-value pairs.
Source code in src/causalprog/causal_problem/handlers.py
def __post_init__(self) -> None:
    """
    Validate set attributes.

    - The handler is a callable object.
    - The options have been passed as a dictionary of keyword-value pairs.
    """
    if not callable(self.handler):
        msg = f"{type(self.handler).__name__} is not callable."
        raise TypeError(msg)
    if not isinstance(self.options, dict):
        msg = (
            "Options should be dictionary mapping option arguments to values "
            f"(got {type(self.options).__name__})."
        )
        raise TypeError(msg)
from_pair(pair) classmethod

Create an instance from an effect handler and its options.

The two objects should be passed in as the elements of a container of length 2. They can be passed in any order:

  • One element must be a dictionary, which will be interpreted as the options for the effect handler.
  • The other element must be callable, and will be interpreted as the handler itself.

Parameters:

Name Type Description Default
pair Sequence

Container of two elements, one being the effect handler callable and the other being the options to pass to it (as a dictionary).

required

Returns:

Type Description
HandlerToApply

Class instance corresponding to the effect handler and options passed.

Source code in src/causalprog/causal_problem/handlers.py
@classmethod
def from_pair(cls, pair: Sequence) -> "HandlerToApply":
    """
    Create an instance from an effect handler and its options.

    The two objects should be passed in as the elements of a container of length
    2. They can be passed in any order;
    - One element must be a dictionary, which will be interpreted as the `options`
        for the effect handler.
    - The other element must be callable, and will be interpreted as the `handler`
        itself.

    Args:
        pair: Container of two elements, one being the effect handler callable and
            the other being the options to pass to it (as a dictionary).

    Returns:
        Class instance corresponding to the effect handler and options passed.

    """
    if len(pair) != 2:  # noqa: PLR2004
        msg = (
            f"{cls.__name__} can only be constructed from a container of 2 elements"
        )
        raise ValueError(msg)

    # __post_init__ will catch cases when the incorrect types for one or both items
    # is passed, so we can just naively if-else here.
    handler: EffectHandler
    options: dict
    if callable(pair[0]):
        handler = pair[0]
        options = pair[1]
    else:
        handler = pair[1]
        options = pair[0]

    return cls(handler=handler, options=options)

graph

Creation and storage of graphs.

graph

Graph storage.

Graph

Bases: Labelled

A directed acyclic graph that represents a causality tree.
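
A minimal sketch of building a graph and its model, using the DistributionNode and ParameterNode classes documented further down this page (import paths are inferred from the source locations shown on this page; the node labels and distribution are illustrative assumptions):

import numpyro
import numpyro.distributions as dist

from causalprog.graph.graph import Graph
from causalprog.graph.node.distribution import DistributionNode
from causalprog.graph.node.parameter import ParameterNode

graph = Graph(label="example")
graph.add_node(ParameterNode(label="mu"))
graph.add_node(
    DistributionNode(
        dist.Normal,
        label="Y",
        parameters={"loc": "mu"},
        constant_parameters={"scale": 1.0},
    )
)
graph.add_edge("mu", "Y")

# Build the model, supplying a value for every ParameterNode; a seed handler is
# needed because the model draws samples via numpyro.sample.
with numpyro.handlers.seed(rng_seed=0):
    sites = graph.model(mu=0.0)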

Source code in src/causalprog/graph/graph.py
class Graph(Labelled):
    """A directed acyclic graph that represents a causality tree."""

    _nodes_by_label: dict[str, Node]

    def __init__(self, *, label: str, graph: nx.DiGraph | None = None) -> None:
        """
        Create a graph.

        Args:
            label: A label to identify the graph
            graph: A networkx graph to base this graph on

        """
        super().__init__(label=label)
        self._nodes_by_label = {}
        if graph is None:
            graph = nx.DiGraph()

        self._graph = graph
        for node in graph.nodes:
            self._nodes_by_label[node.label] = node

    def get_node(self, label: str) -> Node:
        """
        Get a node from its label.

        Args:
            label: The label

        Returns:
            The node

        """
        node = self._nodes_by_label.get(label, None)
        if not node:
            msg = f'Node not found with label "{label}"'
            raise KeyError(msg)
        return node

    def add_node(self, node: Node) -> None:
        """
        Add a node to the graph.

        Args:
            node: The node to add

        """
        if node.label in self._nodes_by_label:
            msg = f"Duplicate node label: {node.label}"
            raise ValueError(msg)
        self._nodes_by_label[node.label] = node
        self._graph.add_node(node)

    def add_edge(self, start_node: Node | str, end_node: Node | str) -> None:
        """
        Add a directed edge to the graph.

        Adding an edge between nodes not currently in the graph,
        will cause said nodes to be added to the graph along with
        the edge.

        Args:
            start_node: The node that the edge points from
            end_node: The node that the edge points to

        """
        if isinstance(start_node, str):
            start_node = self.get_node(start_node)
        if isinstance(end_node, str):
            end_node = self.get_node(end_node)
        if start_node.label not in self._nodes_by_label:
            self.add_node(start_node)
        if end_node.label not in self._nodes_by_label:
            self.add_node(end_node)
        for node_to_check in (start_node, end_node):
            if node_to_check != self._nodes_by_label[node_to_check.label]:
                msg = "Invalid node: {node_to_check}"
                raise ValueError(msg)
        self._graph.add_edge(start_node, end_node)

    @property
    def parameter_nodes(self) -> tuple[Node, ...]:
        """
        Returns all parameter nodes in the graph.

        The returned tuple uses the `ordered_nodes` property to obtain the parameter
        nodes so that a natural "fixed order" is given to the parameters. When parameter
        values are given as inputs to the causal estimand and / or constraint functions,
        they will ideally be given as a single vector of parameter values, in which case
        a fixed ordering for the parameters is necessary to make an association to the
        components of the given input vector.

        Returns:
            Parameter nodes

        """
        return tuple(node for node in self.ordered_nodes if node.is_parameter)

    @property
    def predecessors(self) -> dict[Node, tuple[Node, ...]]:
        """
        Get predecessors of every node.

        Returns:
            Mapping of each Node to its predecessor Nodes

        """
        return {node: tuple(self._graph.predecessors(node)) for node in self.nodes}

    @property
    def successors(self) -> dict[Node, tuple[Node, ...]]:
        """
        Get successors of every node.

        Returns:
            Mapping of each Node to its successor Nodes.

        """
        return {node: tuple(self._graph.successors(node)) for node in self.nodes}

    @property
    def nodes(self) -> tuple[Node, ...]:
        """
        Get the nodes of the graph, with no enforced ordering.

        Returns:
            A list of all the nodes in the graph.

        See Also:
            ordered_nodes: Fetch an ordered list of the nodes in the graph.

        """
        return tuple(self._graph.nodes())

    @property
    def edges(self) -> tuple[tuple[Node, Node], ...]:
        """
        Get the edges of the graph.

        Returns:
            A tuple of all the edges in the graph.

        """
        return tuple(self._graph.edges())

    @property
    def ordered_nodes(self) -> tuple[Node, ...]:
        """
        Nodes ordered so that each node appears after its dependencies.

        Returns:
            A list of all the nodes, ordered such that each node
                appears after all its dependencies.

        """
        if not nx.is_directed_acyclic_graph(self._graph):
            msg = "Graph is not acyclic."
            raise RuntimeError(msg)
        return tuple(nx.topological_sort(self._graph))

    @property
    def ordered_dist_nodes(self) -> tuple[Node, ...]:
        """
        `DistributionNode`s in dependency order.

        Each `DistributionNode` in the returned list appears after all its
        dependencies. Order is derived from `self.ordered_nodes`, selecting
        only those nodes where `is_distribution` is `True`.
        """
        return tuple(node for node in self.ordered_nodes if node.is_distribution)

    def roots_down_to_outcome(
        self,
        outcome_node_label: str,
    ) -> tuple[Node, ...]:
        """
        Get ordered list of nodes that outcome depends on.

        Nodes are ordered so that each node appears after its dependencies.

        Args:
            outcome_node_label: The label of the outcome node

        Returns:
            A list of the nodes, ordered from root nodes to the outcome Node.

        """
        outcome = self.get_node(outcome_node_label)
        ancestors = nx.ancestors(self._graph, outcome)
        return tuple(
            node for node in self.ordered_nodes if node == outcome or node in ancestors
        )

    def model(self, **parameter_values: npt.ArrayLike) -> dict[str, npt.ArrayLike]:
        """
        Model corresponding to the `Graph`'s structure.

        The model created takes the values of the parameter nodes as keyword
        arguments. Names of the keyword arguments should match the labels of the
        `ParameterNode`s, and their values should be the values of those parameters.

        The method returns a dictionary recording the model sites that are created.
        This means that the model can be 'extended' further by defining additional
        sites in a wrapper around this method.

        Args:
            parameter_values: Names of the keyword arguments should match the labels
                of the `ParameterNode`s, and their values should be the values of those
                parameters.

        Returns:
            Mapping of non-`ParameterNode` `Node` labels to the site objects created
                for these nodes.

        """
        # Confirm that all `ParameterNode`s have been assigned a value.
        for node in self.parameter_nodes:
            if node.label not in parameter_values:
                msg = f"ParameterNode '{node.label}' not assigned"
                raise KeyError(msg)

        # Build model sequentially, using the node_order to inform the
        # construction process.
        node_record: dict[str, npt.ArrayLike] = {}
        for node in self.ordered_dist_nodes:
            node_record[node.label] = node.create_model_site(
                **parameter_values,  # All nodes require knowledge of the parameters
                **node_record,  # and any dependent nodes we have already visited
            )

        return node_record
edges property

Get the edges of the graph.

Returns:

Type Description
tuple[tuple[Node, Node], ...]

A tuple of all the edges in the graph.

nodes property

Get the nodes of the graph, with no enforced ordering.

Returns:

Type Description
tuple[Node, ...]

A list of all the nodes in the graph.

See Also

ordered_nodes: Fetch an ordered list of the nodes in the graph.

ordered_dist_nodes property

DistributionNodes in dependency order.

Each DistributionNode in the returned list appears after all its dependencies. Order is derived from self.ordered_nodes, selecting only those nodes where is_distribution is True.

ordered_nodes property

Nodes ordered so that each node appears after its dependencies.

Returns:

Type Description
tuple[Node, ...]

A list of all the nodes, ordered such that each node appears after all its dependencies.

parameter_nodes property

Returns all parameter nodes in the graph.

The returned tuple uses the ordered_nodes property to obtain the parameter nodes so that a natural "fixed order" is given to the parameters. When parameter values are given as inputs to the causal estimand and / or constraint functions, they will ideally be given as a single vector of parameter values, in which case a fixed ordering for the parameters is necessary to make an association to the components of the given input vector.

Returns:

Type Description
tuple[Node, ...]

Parameter nodes

predecessors property

Get predecessors of every node.

Returns:

Type Description
dict[Node, tuple[Node, ...]]

Mapping of each Node to its predecessor Nodes

successors property

Get successors of every node.

Returns:

Type Description
dict[Node, tuple[Node, ...]]

Mapping of each Node to its successor Nodes.

__init__(*, label, graph=None)

Create a graph.

Parameters:

Name Type Description Default
label str

A label to identify the graph

required
graph DiGraph | None

A networkx graph to base this graph on

None
Source code in src/causalprog/graph/graph.py
def __init__(self, *, label: str, graph: nx.DiGraph | None = None) -> None:
    """
    Create a graph.

    Args:
        label: A label to identify the graph
        graph: A networkx graph to base this graph on

    """
    super().__init__(label=label)
    self._nodes_by_label = {}
    if graph is None:
        graph = nx.DiGraph()

    self._graph = graph
    for node in graph.nodes:
        self._nodes_by_label[node.label] = node
add_edge(start_node, end_node)

Add a directed edge to the graph.

Adding an edge between nodes not currently in the graph will cause those nodes to be added to the graph along with the edge.

Parameters:

Name Type Description Default
start_node Node | str

The node that the edge points from

required
end_node Node | str

The node that the edge points to

required
Source code in src/causalprog/graph/graph.py
def add_edge(self, start_node: Node | str, end_node: Node | str) -> None:
    """
    Add a directed edge to the graph.

    Adding an edge between nodes not currently in the graph,
    will cause said nodes to be added to the graph along with
    the edge.

    Args:
        start_node: The node that the edge points from
        end_node: The node that the edge points to

    """
    if isinstance(start_node, str):
        start_node = self.get_node(start_node)
    if isinstance(end_node, str):
        end_node = self.get_node(end_node)
    if start_node.label not in self._nodes_by_label:
        self.add_node(start_node)
    if end_node.label not in self._nodes_by_label:
        self.add_node(end_node)
    for node_to_check in (start_node, end_node):
        if node_to_check != self._nodes_by_label[node_to_check.label]:
            msg = "Invalid node: {node_to_check}"
            raise ValueError(msg)
    self._graph.add_edge(start_node, end_node)
add_node(node)

Add a node to the graph.

Parameters:

Name Type Description Default
node Node

The node to add

required
Source code in src/causalprog/graph/graph.py
def add_node(self, node: Node) -> None:
    """
    Add a node to the graph.

    Args:
        node: The node to add

    """
    if node.label in self._nodes_by_label:
        msg = f"Duplicate node label: {node.label}"
        raise ValueError(msg)
    self._nodes_by_label[node.label] = node
    self._graph.add_node(node)
get_node(label)

Get a node from its label.

Parameters:

Name Type Description Default
label str

The label

required

Returns:

Type Description
Node

The node

Source code in src/causalprog/graph/graph.py
def get_node(self, label: str) -> Node:
    """
    Get a node from its label.

    Args:
        label: The label

    Returns:
        The node

    """
    node = self._nodes_by_label.get(label, None)
    if not node:
        msg = f'Node not found with label "{label}"'
        raise KeyError(msg)
    return node
model(**parameter_values)

Model corresponding to the Graph's structure.

The model created takes the values of the parameter nodes as keyword arguments. Names of the keyword arguments should match the labels of the ParameterNodes, and their values should be the values of those parameters.

The method returns a dictionary recording the model sites that are created. This means that the model can be 'extended' further by defining additional sites in a wrapper around this method.

Parameters:

Name Type Description Default
parameter_values ArrayLike

Names of the keyword arguments should match the labels of the ParameterNodes, and their values should be the values of those parameters.

{}

Returns:

Type Description
dict[str, ArrayLike]

Mapping of non-ParameterNode Node labels to the site objects created for these nodes.

Source code in src/causalprog/graph/graph.py
def model(self, **parameter_values: npt.ArrayLike) -> dict[str, npt.ArrayLike]:
    """
    Model corresponding to the `Graph`'s structure.

    The model created takes the values of the parameter nodes as keyword
    arguments. Names of the keyword arguments should match the labels of the
    `ParameterNode`s, and their values should be the values of those parameters.

    The method returns a dictionary recording the model sites that are created.
    This means that the model can be 'extended' further by defining additional
    sites in a wrapper around this method.

    Args:
        parameter_values: Names of the keyword arguments should match the labels
            of the `ParameterNode`s, and their values should be the values of those
            parameters.

    Returns:
        Mapping of non-`ParameterNode` `Node` labels to the site objects created
            for these nodes.

    """
    # Confirm that all `ParameterNode`s have been assigned a value.
    for node in self.parameter_nodes:
        if node.label not in parameter_values:
            msg = f"ParameterNode '{node.label}' not assigned"
            raise KeyError(msg)

    # Build model sequentially, using the node_order to inform the
    # construction process.
    node_record: dict[str, npt.ArrayLike] = {}
    for node in self.ordered_dist_nodes:
        node_record[node.label] = node.create_model_site(
            **parameter_values,  # All nodes require knowledge of the parameters
            **node_record,  # and any dependent nodes we have already visited
        )

    return node_record
roots_down_to_outcome(outcome_node_label)

Get ordered list of nodes that outcome depends on.

Nodes are ordered so that each node appears after its dependencies.

Parameters:

Name Type Description Default
outcome_node_label str

The label of the outcome node

required

Returns:

Type Description
tuple[Node, ...]

A list of the nodes, ordered from root nodes to the outcome Node.

Source code in src/causalprog/graph/graph.py
def roots_down_to_outcome(
    self,
    outcome_node_label: str,
) -> tuple[Node, ...]:
    """
    Get ordered list of nodes that outcome depends on.

    Nodes are ordered so that each node appears after its dependencies.

    Args:
        outcome_node_label: The label of the outcome node

    Returns:
        A list of the nodes, ordered from root nodes to the outcome Node.

    """
    outcome = self.get_node(outcome_node_label)
    ancestors = nx.ancestors(self._graph, outcome)
    return tuple(
        node for node in self.ordered_nodes if node == outcome or node in ancestors
    )

node

Graph nodes.

base

Base graph node.

Node

Bases: Labelled

An abstract node in a graph.

Source code in src/causalprog/graph/node/base.py
class Node(Labelled):
    """An abstract node in a graph."""

    def __init__(
        self,
        *,
        label: str,
        is_parameter: bool = False,
        is_distribution: bool = False,
    ) -> None:
        """
        Initialise.

        Parameters (equivalently `ParameterNode`s) represent Nodes that do not have
        random variables attached. Instead, these nodes represent values that are passed
        to nodes that _do_ have distributions attached, and the value of the "parameter"
        node is used as a fixed value when constructing the dependent node's
        distribution. The set of parameter nodes is the collection of "parameter"s over
        which one should want to optimise the causal estimand (subject to any
        constraints), and as such the value that a "parameter node" passes to its
        dependent nodes will vary as the optimiser runs and explores the solution space.

        Note that a "constant parameter" is distinct from a "parameter" in the sense
        that a constant parameter is _not_ added to the collection of parameters over
        which we will want to optimise (it is a hard-coded, fixed value).

        Distributions (equivalently `DistributionNode`s) are Nodes that represent
        random variables described by probability distributions.

        Args:
            label: A unique label to identify the node
            is_parameter: Is the node a parameter?
            is_distribution: Is the node a distribution?

        """
        super().__init__(label=label)
        self._is_parameter = is_parameter
        self._is_distribution = is_distribution

    @abstractmethod
    def sample(
        self,
        parameter_values: dict[str, float],
        sampled_dependencies: dict[str, npt.NDArray[float]],
        samples: int,
        *,
        rng_key: jax.Array,
    ) -> float:
        """
        Sample a value from the node.

        Args:
            parameter_values: Values to be taken by parameters
            sampled_dependencies: Values taken by dependencies of this node
            samples: Number of samples
            rng_key: Random key

        Returns:
            Sample value of this node

        """

    @abstractmethod
    def copy(self) -> Node:
        """
        Make a copy of a node.

        Some inner objects stored inside the node may not be copied when this is called.
        Modifying some inner objects of a copy made using this may affect the original
        node.

        Returns:
            A copy of the node

        """

    @property
    def is_parameter(self) -> bool:
        """
        Identify if the node is a parameter.

        Returns:
            True if the node is a parameter

        """
        return self._is_parameter

    @property
    def is_distribution(self) -> bool:
        """
        Identify if the node is a distribution.

        Returns:
            True if the node is a distribution

        """
        return self._is_distribution

    @property
    @abstractmethod
    def constant_parameters(self) -> dict[str, float]:
        """
        Named constants that this node depends on.

        Returns:
            A dictionary of the constant parameter names (keys) and their corresponding
            values

        """

    @property
    @abstractmethod
    def parameters(self) -> dict[str, str]:
        """
        Mapping of distribution parameter names to the nodes they are represented by.

        Returns:
            Mapping of distribution parameters (keys) to the corresponding label of the
            node that represents this parameter (value).

        """
constant_parameters abstractmethod property

Named constants that this node depends on.

Returns:

Type Description
dict[str, float]

A dictionary of the constant parameter names (keys) and their corresponding values

is_distribution property

Identify if the node is a distribution.

Returns:

Type Description
bool

True if the node is a distribution

is_parameter property

Identify if the node is a parameter.

Returns:

Type Description
bool

True if the node is a parameter

parameters abstractmethod property

Mapping of distribution parameter names to the nodes they are represented by.

Returns:

Type Description
dict[str, str]

Mapping of distribution parameters (keys) to the corresponding label of the node that represents this parameter (value).

__init__(*, label, is_parameter=False, is_distribution=False)

Initialise.

Parameters (equivalently ParameterNodes) represent Nodes that do not have random variables attached. Instead, these nodes represent values that are passed to nodes that do have distributions attached, and the value of the "parameter" node is used as a fixed value when constructing the dependent node's distribution. The set of parameter nodes is the collection of "parameter"s over which one should want to optimise the causal estimand (subject to any constraints), and as such the value that a "parameter node" passes to its dependent nodes will vary as the optimiser runs and explores the solution space.

Note that a "constant parameter" is distinct from a "parameter" in the sense that a constant parameter is not added to the collection of parameters over which we will want to optimise (it is a hard-coded, fixed value).

Distributions (equivalently DistributionNodes) are Nodes that represent random variables described by probability distributions.

Parameters:

Name Type Description Default
label str

A unique label to identify the node

required
is_parameter bool

Is the node a parameter?

False
is_distribution bool

Is the node a distribution?

False
Source code in src/causalprog/graph/node/base.py
def __init__(
    self,
    *,
    label: str,
    is_parameter: bool = False,
    is_distribution: bool = False,
) -> None:
    """
    Initialise.

    Parameters (equivalently `ParameterNode`s) represent Nodes that do not have
    random variables attached. Instead, these nodes represent values that are passed
    to nodes that _do_ have distributions attached, and the value of the "parameter"
    node is used as a fixed value when constructing the dependent node's
    distribution. The set of parameter nodes is the collection of "parameter"s over
    which one should want to optimise the causal estimand (subject to any
    constraints), and as such the value that a "parameter node" passes to its
    dependent nodes will vary as the optimiser runs and explores the solution space.

    Note that a "constant parameter" is distinct from a "parameter" in the sense
    that a constant parameter is _not_ added to the collection of parameters over
    which we will want to optimise (it is a hard-coded, fixed value).

    Distributions (equivalently `DistributionNode`s) are Nodes that represent
    random variables described by probability distributions.

    Args:
        label: A unique label to identify the node
        is_parameter: Is the node a parameter?
        is_distribution: Is the node a distribution?

    """
    super().__init__(label=label)
    self._is_parameter = is_parameter
    self._is_distribution = is_distribution
copy() abstractmethod

Make a copy of a node.

Some inner objects stored inside the node may not be copied when this is called. Modifying some inner objects of a copy made using this may affect the original node.

Returns:

Type Description
Node

A copy of the node

Source code in src/causalprog/graph/node/base.py
@abstractmethod
def copy(self) -> Node:
    """
    Make a copy of a node.

    Some inner objects stored inside the node may not be copied when this is called.
    Modifying some inner objects of a copy made using this may affect the original
    node.

    Returns:
        A copy of the node

    """
sample(parameter_values, sampled_dependencies, samples, *, rng_key) abstractmethod

Sample a value from the node.

Parameters:

Name Type Description Default
parameter_values dict[str, float]

Values to be taken by parameters

required
sampled_dependencies dict[str, NDArray[float]]

Values taken by dependencies of this node

required
samples int

Number of samples

required
rng_key Array

Random key

required

Returns:

Type Description
float

Sample value of this node

Source code in src/causalprog/graph/node/base.py
@abstractmethod
def sample(
    self,
    parameter_values: dict[str, float],
    sampled_dependencies: dict[str, npt.NDArray[float]],
    samples: int,
    *,
    rng_key: jax.Array,
) -> float:
    """
    Sample a value from the node.

    Args:
        parameter_values: Values to be taken by parameters
        sampled_dependencies: Values taken by dependencies of this node
        samples: Number of samples
        rng_key: Random key

    Returns:
        Sample value of this node

    """

distribution

Graph nodes representing distributions.

DistributionNode

Bases: Node

A node containing a distribution.
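
A minimal sketch of the distinction between parameters (supplied at runtime by other nodes) and constant_parameters (hard-coded values); the labels and distribution below are illustrative assumptions, and the import path is inferred from the source location shown underneath:

import numpyro.distributions as dist

from causalprog.graph.node.distribution import DistributionNode

# "loc" is taken from the node labelled "mu" at runtime, while "scale" is fixed.
y = DistributionNode(
    dist.Normal,
    label="Y",
    parameters={"loc": "mu"},
    constant_parameters={"scale": 1.0},
)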

Source code in src/causalprog/graph/node/distribution.py
class DistributionNode(Node):
    """A node containing a distribution."""

    def __init__(
        self,
        distribution: type,
        *,
        label: str,
        parameters: dict[str, str] | None = None,
        constant_parameters: dict[str, float] | None = None,
    ) -> None:
        """
        Initialise.

        Args:
            distribution: The distribution
            label: A unique label to identify the node
            parameters: A dictionary of parameters
            constant_parameters: A dictionary of constant parameters

        """
        self._dist = distribution
        self._constant_parameters = constant_parameters if constant_parameters else {}
        self._parameters = parameters if parameters else {}
        super().__init__(label=label, is_distribution=True)

    @override
    def sample(
        self,
        parameter_values: dict[str, float],
        sampled_dependencies: dict[str, npt.NDArray[float]],
        samples: int,
        *,
        rng_key: jax.Array,
    ) -> npt.NDArray[float]:
        d = self._dist(
            # Pass in node values derived from construction so far
            **{
                native_name: sampled_dependencies[node_name]
                for native_name, node_name in self.parameters.items()
            },
            # Pass in any constant parameters this node sets
            **self.constant_parameters,
        )
        return numpyro.sample(
            self.label,
            d,
            rng_key=rng_key,
            sample_shape=(samples,) if d.batch_shape == () and samples > 1 else (),
        )

    @override
    def copy(self) -> Node:
        return DistributionNode(
            self._dist,
            label=self.label,
            parameters=dict(self._parameters),
            constant_parameters=dict(self._constant_parameters.items()),
        )

    @override
    def __repr__(self) -> str:
        r = f'DistributionNode({self._dist.__name__}, label="{self.label}"'
        if len(self._parameters) > 0:
            r += f", parameters={self._parameters}"
        if len(self._constant_parameters) > 0:
            r += f", constant_parameters={self._constant_parameters}"
        return r + ")"

    @override
    @property
    def constant_parameters(self) -> dict[str, float]:
        return self._constant_parameters

    @override
    @property
    def parameters(self) -> dict[str, str]:
        return self._parameters

    def create_model_site(self, **dependent_nodes: jax.Array) -> npt.ArrayLike:
        """
        Create a model site for the (conditional) distribution attached to this node.

        `dependent_nodes` should contain keyword arguments mapping dependent node names
        to the values that those nodes are taking (`ParameterNode`s), or the sampling
        object for those nodes (`DistributionNode`s). These are passed to
        `self._dist` as keyword arguments to construct the sample-able object
        representing this node.
        """
        return numpyro.sample(
            self.label,
            self._dist(
                # Pass in node values derived from construction so far
                **{
                    native_name: dependent_nodes[node_name]
                    for native_name, node_name in self.parameters.items()
                },
                # Pass in any constant parameters this node sets
                **self.constant_parameters,
            ),
        )
__init__(distribution, *, label, parameters=None, constant_parameters=None)

Initialise.

Parameters:

Name Type Description Default
distribution type

The distribution

required
label str

A unique label to identify the node

required
parameters dict[str, str] | None

A dictionary of parameters

None
constant_parameters dict[str, float] | None

A dictionary of constant parameters

None
Source code in src/causalprog/graph/node/distribution.py
def __init__(
    self,
    distribution: type,
    *,
    label: str,
    parameters: dict[str, str] | None = None,
    constant_parameters: dict[str, float] | None = None,
) -> None:
    """
    Initialise.

    Args:
        distribution: The distribution
        label: A unique label to identify the node
        parameters: A dictionary of parameters
        constant_parameters: A dictionary of constant parameters

    """
    self._dist = distribution
    self._constant_parameters = constant_parameters if constant_parameters else {}
    self._parameters = parameters if parameters else {}
    super().__init__(label=label, is_distribution=True)
create_model_site(**dependent_nodes)

Create a model site for the (conditional) distribution attached to this node.

dependent_nodes should contain keyword arguments mapping dependent node names to the values that those nodes are taking (ParameterNodes), or the sampling object for those nodes (DistributionNodes). These are passed to self._dist as keyword arguments to construct the sample-able object representing this node.

Source code in src/causalprog/graph/node/distribution.py
def create_model_site(self, **dependent_nodes: jax.Array) -> npt.ArrayLike:
    """
    Create a model site for the (conditional) distribution attached to this node.

    `dependent_nodes` should contain keyword arguments mapping dependent node names
    to the values that those nodes are taking (`ParameterNode`s), or the sampling
    object for those nodes (`DistributionNode`s). These are passed to
    `self._dist` as keyword arguments to construct the sample-able object
    representing this node.
    """
    return numpyro.sample(
        self.label,
        self._dist(
            # Pass in node values derived from construction so far
            **{
                native_name: dependent_nodes[node_name]
                for native_name, node_name in self.parameters.items()
            },
            # Pass in any constant parameters this node sets
            **self.constant_parameters,
        ),
    )

parameter

Graph nodes representing parameters.

ParameterNode

Bases: Node

A node containing a parameter.

ParameterNodes differ from DistributionNodes in that they do not have an attached distribution, but rather represent a parameter that contributes to the shape of one (or more) DistributionNodes.

The collection of parameters described by ParameterNodes forms the set of variables that will be optimised over in the corresponding CausalProblem.

ParameterNodes should not be used to encode constant values used by DistributionNodes. Such constant values should be given to the necessary DistributionNodes directly as constant_parameters.
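
A minimal sketch of that distinction (import paths inferred from the source locations on this page; the labels are illustrative assumptions):

import numpyro.distributions as dist

from causalprog.graph.node.distribution import DistributionNode
from causalprog.graph.node.parameter import ParameterNode

# "mean" is to be optimised over, so it is a ParameterNode feeding "loc" of "X".
mean = ParameterNode(label="mean")

# The standard deviation is a fixed constant, so it is a constant_parameter of
# "X" rather than a ParameterNode.
x = DistributionNode(
    dist.Normal,
    label="X",
    parameters={"loc": "mean"},
    constant_parameters={"scale": 2.0},
)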

Source code in src/causalprog/graph/node/parameter.py
class ParameterNode(Node):
    """
    A node containing a parameter.

    `ParameterNode`s differ from `DistributionNode`s in that they do not have an
    attached distribution, but rather represent a parameter that contributes
    to the shape of one (or more) `DistributionNode`s.

    The collection of parameters described by `ParameterNode`s forms the set of
    variables that will be optimised over in the corresponding `CausalProblem`.

    `ParameterNode`s should not be used to encode constant values used by
    `DistributionNode`s. Such constant values should be given to the necessary
    `DistributionNode`s directly as `constant_parameters`.
    """

    def __init__(self, *, label: str) -> None:
        """
        Initialise.

        Args:
            label: A unique label to identify the node

        """
        super().__init__(label=label, is_parameter=True)

    @override
    def sample(
        self,
        parameter_values: dict[str, float],
        sampled_dependencies: dict[str, npt.ArrayLike],
        samples: int,
        *,
        rng_key: jax.Array,
    ) -> npt.ArrayLike:
        if self.label not in parameter_values:
            msg = f"Missing input for parameter node: {self.label}."
            raise ValueError(msg)
        return jnp.full(samples, parameter_values[self.label])

    @override
    def copy(self) -> Node:
        return ParameterNode(label=self.label)

    @override
    def __repr__(self) -> str:
        return f'ParameterNode(label="{self.label}")'

    @override
    @property
    def constant_parameters(self) -> dict[str, float]:
        return {}

    @override
    @property
    def parameters(self) -> dict[str, str]:
        return {}
__init__(*, label)

Initialise.

Parameters:

Name Type Description Default
label str

A unique label to identify the node

required
Source code in src/causalprog/graph/node/parameter.py
def __init__(self, *, label: str) -> None:
    """
    Initialise.

    Args:
        label: A unique label to identify the node

    """
    super().__init__(label=label, is_parameter=True)

solvers

Solvers for Causal Problems.

sgd

Minimisation via Stochastic Gradient Descent.

stochastic_gradient_descent(obj_fn, initial_guess, *, convergence_criteria=None, fn_args=None, fn_kwargs=None, learning_rate=0.1, maxiter=1000, optimiser=None, tolerance=1e-08)

Minimise a function of one argument using Stochastic Gradient Descent (SGD).

The obj_fn provided will be minimised over its first argument. If you wish to minimise a function over a different argument, or multiple arguments, wrap it in a suitable lambda expression that has the correct call signature. For example, to minimise a function f(x, y, z) over y and z, use g = lambda yz, x: f(x, yz[0], yz[1]), and pass g in as obj_fn. Note that you will also need to provide a constant value for x via fn_args or fn_kwargs.

The fn_args and fn_kwargs arguments can be used to supply additional parameters that need to be passed to obj_fn, but which should be held constant.

SGD terminates when the convergence_criteria is found to be smaller than the tolerance. That is, when convergence_criteria(objective_value, gradient_value) <= tolerance is found to be True, the algorithm considers a minimum to have been found. The default condition under which the algorithm terminates is when the norm of the gradient at the current argument value is smaller than the provided tolerance.

The optimiser to use can be selected by passing in a suitable optax optimiser via the optimiser argument. By default, optax.adam is used with the supplied learning_rate. Providing an explicit value for optimiser will result in the learning_rate argument being ignored.

Parameters:

Name Type Description Default
obj_fn Callable[[PyTree], ArrayLike]

Function to be minimised over its first argument.

required
initial_guess PyTree

Initial guess for the minimising argument.

required
convergence_criteria Callable[[PyTree, PyTree], ArrayLike] | None

The quantity that will be tested against tolerance, to determine whether the method has converged to a minimum. It should be a callable that takes the current value of obj_fn as its 1st argument, and the current value of the gradient of obj_fn as its 2nd argument. The default criterion is the l2-norm of the gradient.

None
fn_args tuple | None

Positional arguments to be passed to obj_fn, and held constant.

None
fn_kwargs dict | None

Keyword arguments to be passed to obj_fn, and held constant.

None
learning_rate float

Default learning rate (or step size) to use when using the default optimiser. No effect if optimiser is provided explicitly.

0.1
maxiter int

Maximum number of iterations to perform. An error will be reported if this number of iterations is exceeded.

1000
optimiser GradientTransformationExtraArgs | None

The optax optimiser to use during the update step.

None
tolerance float

tolerance used when determining if a minimum has been found.

1e-08

Returns:

Type Description
SolverResult

A SolverResult containing the minimising argument of obj_fn, the value of obj_fn at the minimum, the gradient of obj_fn at the minimum, and the number of iterations performed.
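
A minimal sketch of minimising a simple quadratic (the import path is inferred from the source location shown underneath; the objective and settings are illustrative assumptions):

import jax.numpy as jnp

from causalprog.solvers.sgd import stochastic_gradient_descent

# Minimise f(x) = ||x - 2||^2, starting from x = 0.
result = stochastic_gradient_descent(
    lambda x: jnp.sum((x - 2.0) ** 2),
    jnp.array([0.0]),
    learning_rate=0.1,
    maxiter=2000,
    tolerance=1.0e-3,
)
print(result.successful, result.fn_args, result.iters)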

Source code in src/causalprog/solvers/sgd.py
def stochastic_gradient_descent(
    obj_fn: Callable[[PyTree], npt.ArrayLike],
    initial_guess: PyTree,
    *,
    convergence_criteria: Callable[[PyTree, PyTree], npt.ArrayLike] | None = None,
    fn_args: tuple | None = None,
    fn_kwargs: dict | None = None,
    learning_rate: float = 1.0e-1,
    maxiter: int = 1000,
    optimiser: optax.GradientTransformationExtraArgs | None = None,
    tolerance: float = 1.0e-8,
) -> SolverResult:
    """
    Minimise a function of one argument using Stochastic Gradient Descent (SGD).

    The `obj_fn` provided will be minimised over its first argument. If you wish to
    minimise a function over a different argument, or multiple arguments, wrap it in a
    suitable `lambda` expression that has the correct call signature. For example, to
    minimise a function `f(x, y, z)` over `y` and `z`, use
    `g = lambda yz, x: f(x, yz[0], yz[1])`, and pass `g` in as `obj_fn`. Note that
    you will also need to provide a constant value for `x` via `fn_args` or `fn_kwargs`.

    The `fn_args` and `fn_kwargs` arguments can be used to supply additional parameters that
    need to be passed to `obj_fn`, but which should be held constant.

    SGD terminates when the `convergence_criteria` is found to be smaller than the
    `tolerance`. That is, when
    `convergence_criteria(objective_value, gradient_value) <= tolerance` is found to
    be `True`, the algorithm considers a minimum to have been found. The default
    condition under which the algorithm terminates is when the norm of the gradient
    at the current argument value is smaller than the provided `tolerance`.

    The optimiser to use can be selected by passing in a suitable `optax` optimiser
    via the `optimiser` argument. By default, `optax.adam` is used with the supplied
    `learning_rate`. Providing an explicit value for `optimiser` will result in the
    `learning_rate` argument being ignored.

    Args:
        obj_fn: Function to be minimised over its first argument.
        initial_guess: Initial guess for the minimising argument.
        convergence_criteria: The quantity that will be tested against `tolerance`, to
            determine whether the method has converged to a minimum. It should be a
            `callable` that takes the current value of `obj_fn` as its 1st argument, and
            the current value of the gradient of `obj_fn` as its 2nd argument. The
            default criterion is the l2-norm of the gradient.
        fn_args: Positional arguments to be passed to `obj_fn`, and held constant.
        fn_kwargs: Keyword arguments to be passed to `obj_fn`, and held constant.
        learning_rate: Default learning rate (or step size) to use when using the
            default `optimiser`. No effect if `optimiser` is provided explicitly.
        maxiter: Maximum number of iterations to perform. An error will be reported if
            this number of iterations is exceeded.
        optimiser: The `optax` optimiser to use during the update step.
        tolerance: `tolerance` used when determining if a minimum has been found.

    Returns:
        A `SolverResult` containing the minimising argument of `obj_fn`, the value
        and gradient of `obj_fn` at that argument, and the number of iterations
        performed.

    """
    if not fn_args:
        fn_args = ()
    if not fn_kwargs:
        fn_kwargs = {}
    if not convergence_criteria:
        convergence_criteria = lambda _, dx: jnp.sqrt(l2_normsq(dx))  # noqa: E731
    if not optimiser:
        optimiser = optax.adam(learning_rate)

    def objective(x: npt.ArrayLike) -> npt.ArrayLike:
        return obj_fn(x, *fn_args, **fn_kwargs)

    def is_converged(x: npt.ArrayLike, dx: npt.ArrayLike) -> bool:
        return convergence_criteria(x, dx) < tolerance

    converged = False

    opt_state = optimiser.init(initial_guess)
    current_params = deepcopy(initial_guess)
    gradient = jax.grad(objective)

    for iters_used in range(maxiter + 1):
        objective_value = objective(current_params)
        gradient_value = gradient(current_params)

        if converged := is_converged(objective_value, gradient_value):
            break

        updates, opt_state = optimiser.update(gradient_value, opt_state)
        current_params = optax.apply_updates(current_params, updates)

    reason_msg = (
        f"Did not converge after {iters_used} iterations" if not converged else ""
    )

    return SolverResult(
        fn_args=current_params,
        grad_val=gradient_value,
        iters=iters_used,
        maxiter=maxiter,
        obj_val=objective_value,
        reason=reason_msg,
        successful=converged,
    )
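
For illustration, a minimal usage sketch: the import path causalprog.solvers.sgd is inferred from the source file location above, and the objective, values and settings are purely illustrative, not part of the package. It minimises f(x, y) = (x - y)**2 over x, holding y fixed via fn_args, and passes an explicit optax optimiser (so learning_rate is ignored).

import jax
import jax.numpy as jnp
import optax

from causalprog.solvers.sgd import stochastic_gradient_descent

# Use double precision so the tight default tolerance (1.0e-8) is attainable.
jax.config.update("jax_enable_x64", True)

# Minimise f(x, y) = (x - y)**2 over x, holding y = 2.0 fixed via fn_args.
# Passing an explicit optimiser means the learning_rate argument is ignored.
result = stochastic_gradient_descent(
    lambda x, y: jnp.sum((x - y) ** 2),
    jnp.array(0.0),
    fn_args=(2.0,),
    optimiser=optax.sgd(1.0e-1),
)

print(result.successful)  # True once the gradient norm drops below tolerance
print(result.fn_args)     # approximately 2.0
print(result.iters)       # number of iterations used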

solver_result

Container class for outputs from solver methods.

SolverResult dataclass

Container class for outputs from solver methods.

Instances of this class provide a container for useful information that comes out of running one of the solver methods on a causal problem.

Attributes:

Name Type Description
fn_args PyTree

Argument to the objective function at the final iteration (the solution, if successful is True).

grad_val PyTree

Value of the gradient of the objective function at the fn_args.

iters int

Number of iterations performed.

maxiter int

Maximum number of iterations the solver was permitted to perform.

obj_val ArrayLike

Value of the objective function at fn_args.

reason str

Human-readable string explaining success or reasons for solver failure.

successful bool

True if solver converged, in which case fn_args is the argument to the objective function at the solution of the problem being solved. False otherwise.

Source code in src/causalprog/solvers/solver_result.py
@dataclass(frozen=True)
class SolverResult:
    """
    Container class for outputs from solver methods.

    Instances of this class provide a container for useful information that
    comes out of running one of the solver methods on a causal problem.

    Attributes:
        fn_args: Argument to the objective function at the final iteration (the
            solution, if `successful` is `True`).
        grad_val: Value of the gradient of the objective function at the `fn_args`.
        iters: Number of iterations performed.
        maxiter: Maximum number of iterations the solver was permitted to perform.
        obj_val: Value of the objective function at `fn_args`.
        reason: Human-readable string explaining success or reasons for solver failure.
        successful: `True` if solver converged, in which case `fn_args` is the
            argument to the objective function at the solution of the problem being
            solved. `False` otherwise.

    """

    fn_args: PyTree
    grad_val: PyTree
    iters: int
    maxiter: int
    obj_val: npt.ArrayLike
    reason: str
    successful: bool
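
Continuing the hypothetical SGD sketch above, a returned SolverResult might be inspected as follows:

# `result` is the SolverResult returned by the illustrative call to
# stochastic_gradient_descent shown earlier.
if result.successful:
    print(f"Converged in {result.iters} of {result.maxiter} permitted iterations.")
    print(f"Objective value {result.obj_val} at argument {result.fn_args}.")
else:
    print(result.reason)  # e.g. "Did not converge after 1000 iterations"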

utils

Utility classes and methods.

norms

Misc collection of norm-like functions for PyTree structures.

l2_normsq(x)

Square of the l2-norm of a PyTree.

This is effectively "sum(elements**2 in leaf for leaf in x)".

Source code in src/causalprog/utils/norms.py
def l2_normsq(x: PyTree) -> npt.ArrayLike:
    """
    Square of the l2-norm of a PyTree.

    This is effectively the sum of the squares of every element in every leaf of `x`.
    """
    leaves, _ = jax.tree_util.tree_flatten(x)
    return sum(jax.numpy.sum(leaf**2) for leaf in leaves)
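
For illustration, a hypothetical call on a nested PyTree (the import path causalprog.utils.norms is inferred from the source file location above):

import jax.numpy as jnp

from causalprog.utils.norms import l2_normsq

# The squares of every element in every leaf are summed: 1 + 4 + 9 + 25 = 39.
x = {"a": jnp.array([1.0, 2.0]), "b": {"c": jnp.array([3.0, 5.0])}}
print(l2_normsq(x))  # 39.0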

translator

Helper class to keep the codebase backend-agnostic.

Our frontend (user-facing) classes each use a syntax that is consistent across the package codebase. By contrast, the various backends that we want to support each have their own syntaxes and call signatures for the corresponding functionality. As such, we need a helper class that stores this "translation" information, allowing the user to interact with the package in a standard way while still being free to choose their own backend if desired.

Translator

Bases: ABC

Maps syntax of a backend function to our frontend syntax.

Different backends have different syntax for drawing samples from the distributions they support. In order to map these different syntaxes to our backend-agnostic framework, we need a container class to map the names we have chosen for our frontend methods to those used by their corresponding backend method.

A Translator allows us to identify whether a user-provided backend object is compatible with one of our frontend wrapper classes (and thus, call signatures). It also allows users to write their own translators for any custom backends that we do not explicitly support.

The use case for a Translator is as follows. Suppose that we have a frontend class C that needs to provide a method do_something. C stores a reference to a backend object obj that can provide the functionality of do_something via one of its methods, obj.backend_method. However, there is no guarantee that the signature of do_something maps identically to that of obj.backend_method. A Translator allows us to encode a mapping of obj.backend_method's arguments to those of do_something.

Source code in src/causalprog/utils/translator.py
class Translator(ABC):
    """
    Maps syntax of a backend function to our frontend syntax.

    Different backends have different syntax for drawing samples from the distributions
    they support. In order to map these different syntaxes to our backend-agnostic
    framework, we need a container class to map the names we have chosen for our
    frontend methods to those used by their corresponding backend method.

    A ``Translator`` allows us to identify whether a user-provided backend object is
    compatible with one of our frontend wrapper classes (and thus, call signatures). It
    also allows users to write their own translators for any custom backends that we do
    not explicitly support.

    The use case for a ``Translator`` is as follows. Suppose that we have a frontend
    class ``C`` that needs to provide a method ``do_something``. ``C`` stores a
    reference to a backend object ``obj`` that can provide the functionality of
    ``do_something`` via one of its methods, ``obj.backend_method``. However, there is
    no guarantee that the signature of ``do_something`` maps identically to that of
    ``obj.backend_method``. A ``Translator`` allows us to encode a mapping of
    ``obj.backend_method``'s arguments to those of ``do_something``.
    """

    backend_method: str
    corresponding_backend_arg: dict[str, str]

    @property
    @abstractmethod
    def _frontend_method(self) -> str:
        """Name of the frontend method that the backend is to be translated into."""

    @property
    @abstractmethod
    def compulsory_frontend_args(self) -> set[str]:
        """Arguments that are required by the frontend function."""

    @property
    def compulsory_backend_args(self) -> set[str]:
        """Arguments that are required to be taken by the backend function."""
        return {
            self.corresponding_backend_arg[arg_name]
            for arg_name in self.compulsory_frontend_args
        }

    def __init__(
        self, backend_method: str | None = None, **front_args_to_back_args: str
    ) -> None:
        """
        Create a new Translator.

        Args:
            backend_method (str): Name of the backend method that the instance
                translates.
            **front_args_to_back_args (str): Mapping of frontend argument names to the
                corresponding backend argument names.

        """
        # Assume backend name is identical to frontend name if not provided explicitly
        self.backend_method = (
            backend_method if backend_method else self._frontend_method
        )

        # This should really be immutable after we fill defaults!
        self.corresponding_backend_arg = dict(front_args_to_back_args)
        # Assume compulsory frontend args that are not given translations
        # retain their name in the backend.
        for arg in self.compulsory_frontend_args:
            if arg not in self.corresponding_backend_arg:
                self.corresponding_backend_arg[arg] = arg

    def translate_args(self, **kwargs: Any) -> dict[str, Any]:  # noqa: ANN401
        """
        Translate frontend arguments (with values) to backend arguments.

        Essentially transforms frontend keyword arguments into their backend keyword
        arguments, preserving the value assigned to each argument.
        """
        return {
            self.corresponding_backend_arg[arg_name]: arg_value
            for arg_name, arg_value in kwargs.items()
        }

    def validate_compatible(self, obj: object) -> None:
        """
        Determine if ``obj`` provides a compatible backend method.

        ``obj`` must provide a callable whose name matches ``self.backend_method``,
        and the callable referenced must take arguments matching the names specified in
        ``self.compulsory_backend_args``.

        Args:
            obj (object): Object to check possesses a method that can be translated into
                frontend syntax.

        """
        # Check that obj does provide a method of matching name
        if not hasattr(obj, self.backend_method):
            msg = f"{obj} has no method '{self.backend_method}'."
            raise AttributeError(msg)
        if not callable(getattr(obj, self.backend_method)):
            msg = f"'{self.backend_method}' attribute of {obj} is not callable."
            raise TypeError(msg)

        # Check that this method will be callable with the information given.
        method_params = inspect.signature(getattr(obj, self.backend_method)).parameters
        # The arguments that will be passed are actually taken by the method.
        for compulsory_arg in self.compulsory_backend_args:
            if compulsory_arg not in method_params:
                msg = (
                    f"'{self.backend_method}' does not "
                    f"take argument '{compulsory_arg}'."
                )
                raise TypeError(msg)
        # The method does not _require_ any additional arguments
        method_requires = {
            name for name, p in method_params.items() if p.default is p.empty
        }
        if not method_requires.issubset(self.compulsory_backend_args):
            args_not_accounted_for = method_requires - self.compulsory_backend_args
            raise TypeError(
                f"'{self.backend_method}' not provided compulsory arguments "
                "(missing " + ", ".join(args_not_accounted_for) + ")"
            )
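
A minimal sketch of a concrete subclass (all names here are illustrative assumptions, not part of the package; the import path causalprog.utils.translator is inferred from the source file location above). It maps a hypothetical frontend sample(sample_shape=...) method onto a backend rvs(size=...) method.

from causalprog.utils.translator import Translator


class SampleTranslator(Translator):
    """Hypothetical translator for a frontend `sample` method."""

    @property
    def _frontend_method(self) -> str:
        return "sample"

    @property
    def compulsory_frontend_args(self) -> set[str]:
        return {"sample_shape"}


class Backend:
    """Hypothetical backend whose sampling method is `rvs(size)`."""

    def rvs(self, size: int) -> list[float]:
        return [0.0] * size


translator = SampleTranslator(backend_method="rvs", sample_shape="size")
backend = Backend()

translator.validate_compatible(backend)  # passes: `rvs` is callable and takes `size`
kwargs = translator.translate_args(sample_shape=3)  # {'size': 3}
print(getattr(backend, translator.backend_method)(**kwargs))  # [0.0, 0.0, 0.0]
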
compulsory_backend_args property

Arguments that are required to be taken by the backend function.

compulsory_frontend_args abstractmethod property

Arguments that are required by the frontend function.

__init__(backend_method=None, **front_args_to_back_args)

Create a new Translator.

Parameters:

Name Type Description Default
backend_method str

Name of the backend method that the instance translates.

None
**front_args_to_back_args str

Mapping of frontend argument names to the corresponding backend argument names.

{}
Source code in src/causalprog/utils/translator.py
def __init__(
    self, backend_method: str | None = None, **front_args_to_back_args: str
) -> None:
    """
    Create a new Translator.

    Args:
        backend_method (str): Name of the backend method that the instance
            translates.
        **front_args_to_back_args (str): Mapping of frontend argument names to the
            corresponding backend argument names.

    """
    # Assume backend name is identical to frontend name if not provided explicitly
    self.backend_method = (
        backend_method if backend_method else self._frontend_method
    )

    # This should really be immutable after we fill defaults!
    self.corresponding_backend_arg = dict(front_args_to_back_args)
    # Assume compulsory frontend args that are not given translations
    # retain their name in the backend.
    for arg in self.compulsory_frontend_args:
        if arg not in self.corresponding_backend_arg:
            self.corresponding_backend_arg[arg] = arg
translate_args(**kwargs)

Translate frontend arguments (with values) to backend arguments.

Essentially transforms frontend keyword arguments into their backend keyword arguments, preserving the value assigned to each argument.

Source code in src/causalprog/utils/translator.py
87
88
89
90
91
92
93
94
95
96
97
def translate_args(self, **kwargs: Any) -> dict[str, Any]:  # noqa: ANN401
    """
    Translate frontend arguments (with values) to backend arguments.

    Essentially transforms frontend keyword arguments into their backend keyword
    arguments, preserving the value assigned to each argument.
    """
    return {
        self.corresponding_backend_arg[arg_name]: arg_value
        for arg_name, arg_value in kwargs.items()
    }
validate_compatible(obj)

Determine if obj provides a compatible backend method.

obj must provide a callable whose name matches self.backend_method, and the callable referenced must take arguments matching the names specified in self.compulsory_backend_args.

Parameters:

Name Type Description Default
obj object

Object to check possesses a method that can be translated into frontend syntax.

required
Source code in src/causalprog/utils/translator.py
def validate_compatible(self, obj: object) -> None:
    """
    Determine if ``obj`` provides a compatible backend method.

    ``obj`` must provide a callable whose name matches ``self.backend_method``,
    and the callable referenced must take arguments matching the names specified in
    ``self.compulsory_backend_args``.

    Args:
        obj (object): Object to check possesses a method that can be translated into
            frontend syntax.

    """
    # Check that obj does provide a method of matching name
    if not hasattr(obj, self.backend_method):
        msg = f"{obj} has no method '{self.backend_method}'."
        raise AttributeError(msg)
    if not callable(getattr(obj, self.backend_method)):
        msg = f"'{self.backend_method}' attribute of {obj} is not callable."
        raise TypeError(msg)

    # Check that this method will be callable with the information given.
    method_params = inspect.signature(getattr(obj, self.backend_method)).parameters
    # The arguments that will be passed are actually taken by the method.
    for compulsory_arg in self.compulsory_backend_args:
        if compulsory_arg not in method_params:
            msg = (
                f"'{self.backend_method}' does not "
                f"take argument '{compulsory_arg}'."
            )
            raise TypeError(msg)
    # The method does not _require_ any additional arguments
    method_requires = {
        name for name, p in method_params.items() if p.default is p.empty
    }
    if not method_requires.issubset(self.compulsory_backend_args):
        args_not_accounted_for = method_requires - self.compulsory_backend_args
        raise TypeError(
            f"'{self.backend_method}' not provided compulsory arguments "
            "(missing " + ", ".join(args_not_accounted_for) + ")"
        )
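
Continuing the hypothetical SampleTranslator sketch above, an incompatible backend (one whose method requires an argument the translator does not account for) is rejected:

class BadBackend:
    """Hypothetical backend whose sampling method needs an extra argument."""

    def rvs(self, size: int, random_state: int) -> list[float]:
        return [0.0] * size


try:
    translator.validate_compatible(BadBackend())
except TypeError as err:
    print(err)  # 'rvs' not provided compulsory arguments (missing random_state)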