Skip to content

Commit a3ec0aa

Browse files
authored
Merge pull request #156 from dolfin-adjoint/dham/abstract_reduced_functional
2 parents 7d03a3c + c5aa4e8 commit a3ec0aa

9 files changed

Lines changed: 417 additions & 206 deletions

File tree

pyadjoint/adjfloat.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,12 @@ def __rsub__(self, other):
9393
def __pow__(self, power):
9494
return PowBlock(self, power)
9595

96-
def _ad_convert_type(self, value, options={}):
96+
def _ad_init_zero(self, dual=False):
97+
return type(self)(0.)
98+
99+
def _ad_convert_riesz(self, value, riesz_map=None):
100+
if riesz_map is not None:
101+
raise ValueError(f"Unexpected Riesz map for Adjfloat: {riesz_map}")
97102
return AdjFloat(value)
98103

99104
def _ad_create_checkpoint(self):
@@ -343,7 +348,7 @@ def __init__(self, *args):
343348

344349
def recompute_component(self, inputs, block_variable, idx, prepared):
345350
output = self.operator(*(term.saved_output for term in self.terms))
346-
return self._outputs[0].saved_output._ad_convert_type(output)
351+
return type(self._outputs[0].saved_output)(output)
347352

348353
def __str__(self):
349354
return f"{self.terms[0]} {self.symbol} {self.terms[1]}"

pyadjoint/control.py

Lines changed: 31 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,14 @@
1+
from typing import Any
12
from .overloaded_type import OverloadedType, create_overloaded_object
23
import logging
34

45

56
class Control(object):
67
"""Defines a control variable from an OverloadedType.
78
8-
The control object references a specific node on the Tape.
9-
For mutable OverloadedType instances the Control only represents
10-
the value at the time of initialization.
9+
The control object references a specific node on the Tape. For mutable
10+
OverloadedType instances the Control only represents the value at the time
11+
of initialization.
1112
1213
Example:
1314
Given a mutable OverloadedType instance u.
@@ -25,18 +26,22 @@ class Control(object):
2526
>>> c2.data()
2627
3.0
2728
28-
Now c1 represents the node prior to the add_in_place Block,
29-
while c2 represents the node after the add_in_place Block.
30-
Creating a `ReducedFunctional` with c2 as Control results in
31-
a reduced problem without the add_in_place Block, while a ReducedFunctional
32-
with c1 as Control results in a forward model including the add_in_place.
29+
Now c1 represents the node prior to the add_in_place Block, while c2
30+
represents the node after the add_in_place Block. Creating a
31+
`ReducedFunctional` with c2 as Control results in a reduced problem
32+
without the add_in_place Block, while a ReducedFunctional with c1 as
33+
Control results in a forward model including the add_in_place.
3334
3435
Args:
35-
control (OverloadedType): The OverloadedType instance to define this control from.
36+
control: The OverloadedType instance to define this control from.
37+
riesz_map: Parameters controlling how to find the Riesz representer of
38+
a dual (adjoint) variable to this control. The permitted values are
39+
type-dependent.
3640
3741
"""
38-
def __init__(self, control):
42+
def __init__(self, control: OverloadedType, riesz_map: Any = None):
3943
self.control = control
44+
self.riesz_map = riesz_map
4045
self.block_variable = control.block_variable
4146

4247
def data(self):
@@ -45,17 +50,27 @@ def data(self):
4550
def tape_value(self):
4651
return create_overloaded_object(self.block_variable.saved_output)
4752

48-
def get_derivative(self, options={}):
53+
def get_derivative(self, apply_riesz=False):
4954
if self.block_variable.adj_value is None:
5055
logging.warning("Adjoint value is None, is the functional independent of the control variable?")
51-
return self.control._ad_convert_type(0., options=options)
52-
return self.control._ad_convert_type(self.block_variable.adj_value, options=options)
56+
return self.control._ad_init_zero(dual=not apply_riesz)
57+
elif apply_riesz:
58+
return self.control._ad_convert_riesz(
59+
self.block_variable.adj_value, riesz_map=self.riesz_map)
60+
else:
61+
return self.control._ad_init_object(self.block_variable.adj_value)
5362

54-
def get_hessian(self, options={}):
63+
def get_hessian(self, apply_riesz=False):
5564
if self.block_variable.hessian_value is None:
5665
logging.warning("Hessian value is None, is the functional independent of the control variable?")
57-
return self.control._ad_convert_type(0., options=options)
58-
return self.control._ad_convert_type(self.block_variable.hessian_value, options=options)
66+
return self.control._ad_init_zero(dual=not apply_riesz)
67+
elif apply_riesz:
68+
return self.control._ad_convert_riesz(
69+
self.block_variable.hessian_value, riesz_map=self.riesz_map)
70+
else:
71+
return self.control._ad_init_object(
72+
self.block_variable.hessian_value
73+
)
5974

6075
def update(self, value):
6176
# In the future we might want to call a static method

pyadjoint/drivers.py

Lines changed: 114 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,24 +1,35 @@
1+
try:
2+
from warnings import deprecated
3+
except ImportError:
4+
from warnings import warn
5+
deprecated = None
6+
17
from .enlisting import Enlist
28
from .tape import get_working_tape, stop_annotating
39

410

5-
def compute_gradient(J, m, options=None, tape=None, adj_value=1.0):
11+
def compute_derivative(J, m, tape=None, adj_value=1.0, apply_riesz=False):
612
"""
7-
Compute the gradient of J with respect to the initialisation value of m,
13+
Compute the derivative of J with respect to the initialisation value of m,
814
that is the value of m at its creation.
915
1016
Args:
11-
J (AdjFloat): The objective functional.
17+
J (OverloadedType): The objective functional.
1218
m (list or instance of Control): The (list of) controls.
13-
options (dict): A dictionary of options. To find a list of available options
14-
have a look at the specific control type.
1519
tape: The tape to use. Default is the current tape.
20+
adj_value: The adjoint value to the result. Required if the functional
21+
is not scalar-valued, or if the functional is not the final stage
22+
in the computation of an outer functional.
23+
apply_riesz: If True, apply the Riesz map of each control in order
24+
to return a primal gradient rather than a derivative in the
25+
dual space.
1626
1727
Returns:
18-
OverloadedType: The derivative with respect to the control. Should be an instance of the same type as
19-
the control.
28+
OverloadedType: The derivative with respect to the control.
29+
If apply_riesz is False, should be an instance of the type dual
30+
to that of the control. If apply_riesz is True should have the
31+
same type as the control.
2032
"""
21-
options = options or {}
2233
tape = tape or get_working_tape()
2334
tape.reset_variables()
2435
J.block_variable.adj_value = adj_value
@@ -30,51 +41,126 @@ def compute_gradient(J, m, options=None, tape=None, adj_value=1.0):
3041
with marked_controls(m):
3142
tape.evaluate_adj(markings=True)
3243

33-
grads = [i.get_derivative(options=options) for i in m]
44+
grads = [i.get_derivative(apply_riesz=apply_riesz) for i in m]
3445
return m.delist(grads)
3546

3647

37-
def compute_hessian(J, m, m_dot, options=None, tape=None):
48+
def compute_gradient(J, m, tape=None, adj_value=1.0, apply_riesz=True):
49+
"""
50+
Compute the gradient of J with respect to the initialisation value of m,
51+
that is the value of m at its creation.
52+
53+
This function is deprecated in favour of :func:`compute_derivative`.
54+
55+
Args:
56+
J (OverloadedType): The objective functional.
57+
m (list or instance of Control): The (list of) controls.
58+
tape: The tape to use. Default is the current tape.
59+
adj_value: The adjoint value to the result. Required if the functional
60+
is not scalar-valued, or if the functional is not the final stage
61+
in the computation of an outer functional.
62+
apply_riesz: If True, apply the Riesz map of each control in order
63+
to return a primal gradient rather than a derivative in the
64+
dual space.
65+
66+
Returns:
67+
OverloadedType: The gradient with respect to the control.
68+
If apply_riesz is False, should be an instance of the type dual
69+
to that of the control. If apply_riesz is True should have the
70+
same type as the control.
71+
"""
72+
if deprecated is None:
73+
warn("compute_gradient is deprecated in favour of compute_derivative.",
74+
FutureWarning)
75+
76+
return compute_derivative(J, m, tape, adj_value, apply_riesz)
77+
78+
79+
if deprecated is not None:
80+
compute_gradient = deprecated(
81+
"compute_gradient is deprecated in favour of compute_derivative."
82+
)(compute_gradient)
83+
84+
85+
def compute_hessian(J, m, m_dot, hessian_input=None, tape=None, evaluate_tlm=True, apply_riesz=False):
3886
"""
3987
Compute the Hessian of J in a direction m_dot at the current value of m
4088
4189
Args:
4290
J (AdjFloat): The objective functional.
4391
m (list or instance of Control): The (list of) controls.
44-
m_dot (list or instance of the control type): The direction in which to compute the Hessian.
45-
options (dict): A dictionary of options. To find a list of available options
46-
have a look at the specific control type.
92+
m_dot (list or instance of the control type): The direction in which to
93+
compute the Hessian.
94+
hessian_input (OverloadedType): The value to start the Hessian accumulation
95+
from after the TLM calculation. Uses zero initialised value if None.
4796
tape: The tape to use. Default is the current tape.
97+
apply_riesz: If True, apply the Riesz map of each control in order
98+
to return the (primal) Riesz representer of the Hessian
99+
action.
100+
evaluate_tlm (bool): Whether or not to compute the forward (TLM) part of
101+
the Hessian calculation. If False, assumes that the tape has already
102+
been populated with the required TLM values.
48103
49104
Returns:
50-
OverloadedType: The second derivative with respect to the control in direction m_dot. Should be an instance of
51-
the same type as the control.
105+
OverloadedType: The action of the Hessian in the direction m_dot.
106+
If apply_riesz is False, should be an instance of the type dual
107+
to that of the control. If apply_riesz is true should have the
108+
same type as the control.
52109
"""
53110
tape = tape or get_working_tape()
54-
options = options or {}
55111

56-
tape.reset_tlm_values()
112+
# fill the relevant tlm values on the tape
113+
if evaluate_tlm:
114+
compute_tlm(J, m, m_dot, tape)
115+
57116
tape.reset_hessian_values()
58117

59-
m = Enlist(m)
60-
m_dot = Enlist(m_dot)
61-
for i, value in enumerate(m_dot):
62-
m[i].tlm_value = m_dot[i]
118+
if hessian_input is None:
119+
J.block_variable.hessian_value = (
120+
J.block_variable.output._ad_init_zero(dual=True))
121+
else:
122+
J.block_variable.hessian_value = (
123+
J.block_variable.output._ad_init_object(hessian_input))
63124

125+
m = Enlist(m)
64126
with stop_annotating():
65127
with tape.marked_control_dependents(m):
66-
tape.evaluate_tlm(markings=True)
128+
with tape.marked_functional_dependencies(J):
129+
tape.evaluate_hessian(markings=True)
130+
131+
r = [v.get_hessian(apply_riesz=apply_riesz) for v in m]
132+
return m.delist(r)
133+
134+
135+
def compute_tlm(J, m, m_dot, tape=None):
136+
"""
137+
Compute the tangent linear model of J in a direction m_dot at the current value of m
138+
139+
Args:
140+
J (OverloadedType): The objective functional.
141+
m (list or instance of Control): The (list of) controls.
142+
m_dot (list or instance of the control type): The direction in which to
143+
compute the tangent linear model.
144+
tape: The tape to use. Default is the current tape.
67145
68-
J.block_variable.hessian_value = J.block_variable.output._ad_convert_type(
69-
0., options={'riesz_representation': None})
146+
Returns:
147+
OverloadedType: The action of the tangent linear model with respect to the control
148+
in direction m_dot. Should be an instance of the same type as the functional.
149+
"""
150+
tape = tape or get_working_tape()
151+
tape.reset_tlm_values()
152+
153+
m = Enlist(m)
154+
m_dot = Enlist(m_dot)
155+
156+
for mi, mdi in zip(m, m_dot):
157+
mi.tlm_value = mdi
70158

71159
with stop_annotating():
72160
with tape.marked_control_dependents(m):
73-
with tape.marked_functional_dependencies(J):
74-
tape.evaluate_hessian(markings=True)
161+
tape.evaluate_tlm(markings=True)
75162

76-
r = [v.get_hessian(options=options) for v in m]
77-
return m.delist(r)
163+
return J._ad_init_object(J.block_variable.tlm_value)
78164

79165

80166
def solve_adjoint(J, tape=None, adj_value=1.0):

pyadjoint/optimization/optimization.py

Lines changed: 7 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ def serialise_bounds(rf_np, bounds):
3838
return np.array(bounds_arr).T
3939

4040

41-
def minimize_scipy_generic(rf_np, method, bounds=None, derivative_options=None, **kwargs):
41+
def minimize_scipy_generic(rf_np, method, bounds=None, **kwargs):
4242
"""Interface to the generic minimize method in scipy
4343
4444
"""
@@ -56,18 +56,11 @@ def minimize_scipy_generic(rf_np, method, bounds=None, derivative_options=None,
5656

5757
raise
5858

59-
if method in ["Newton-CG"]:
60-
forget = None
61-
else:
62-
forget = False
63-
64-
project = kwargs.pop("project", False)
65-
6659
m = [p.tape_value() for p in rf_np.controls]
6760
m_global = rf_np.obj_to_array(m)
6861
J = rf_np.__call__
69-
dJ = lambda m: rf_np.derivative(m, forget=forget, project=project, options=derivative_options)
70-
H = rf_np.hessian
62+
dJ = lambda m: rf_np.derivative(apply_riesz=True)
63+
H = lambda x, p: rf_np.hessian(p)
7164

7265
if "options" not in kwargs:
7366
kwargs["options"] = {}
@@ -144,7 +137,7 @@ def jac(x):
144137
return m
145138

146139

147-
def minimize_custom(rf_np, bounds=None, derivative_options=None, **kwargs):
140+
def minimize_custom(rf_np, bounds=None, **kwargs):
148141
""" Interface to the user-provided minimisation method """
149142

150143
try:
@@ -160,7 +153,7 @@ def minimize_custom(rf_np, bounds=None, derivative_options=None, **kwargs):
160153
m_global = rf_np.obj_to_array(m)
161154
J = rf_np.__call__
162155

163-
dJ = lambda m: rf_np.derivative(m, forget=None, options=derivative_options)
156+
dJ = lambda m: rf_np.derivative(m, apply_riesz=True)
164157
H = rf_np.hessian
165158

166159
if bounds is not None:
@@ -263,7 +256,7 @@ def minimize(rf, method='L-BFGS-B', scale=1.0, **kwargs):
263256
return opt
264257

265258

266-
def maximize(rf, method='L-BFGS-B', scale=1.0, derivative_options=None, **kwargs):
259+
def maximize(rf, method='L-BFGS-B', scale=1.0, **kwargs):
267260
""" Solves the maximisation problem with PDE constraint:
268261
269262
max_m func(u, m)
@@ -282,7 +275,6 @@ def maximize(rf, method='L-BFGS-B', scale=1.0, derivative_options=None, **kwargs
282275
* 'method' specifies the optimization method to be used to solve the problem.
283276
The available methods can be listed with the print_optimization_methods function.
284277
* 'scale' is a factor to scale to problem (default: 1.0).
285-
* 'derivative_options' is a dictionary of options that will be passed to the `rf.derivative`.
286278
* 'bounds' is an optional keyword parameter to support control constraints: bounds = (lb, ub).
287279
lb and ub must be of the same type than the parameters m.
288280
@@ -291,7 +283,7 @@ def maximize(rf, method='L-BFGS-B', scale=1.0, derivative_options=None, **kwargs
291283
For detailed information about which arguments are supported for each optimization method,
292284
please refer to the documentation of the optimization algorithm.
293285
"""
294-
return minimize(rf, method, scale=-scale, derivative_options=derivative_options, **kwargs)
286+
return minimize(rf, method, scale=-scale, **kwargs)
295287

296288

297289
minimise = minimize

0 commit comments

Comments
 (0)