1+ try :
2+ from warnings import deprecated
3+ except ImportError :
4+ from warnings import warn
5+ deprecated = None
6+
17from .enlisting import Enlist
28from .tape import get_working_tape , stop_annotating
39
410
5- def compute_gradient (J , m , options = None , tape = None , adj_value = 1.0 ):
11+ def compute_derivative (J , m , tape = None , adj_value = 1.0 , apply_riesz = False ):
612 """
7- Compute the gradient of J with respect to the initialisation value of m,
13+ Compute the derivative of J with respect to the initialisation value of m,
814 that is the value of m at its creation.
915
1016 Args:
11- J (AdjFloat ): The objective functional.
17+ J (OverloadedType ): The objective functional.
1218 m (list or instance of Control): The (list of) controls.
13- options (dict): A dictionary of options. To find a list of available options
14- have a look at the specific control type.
1519 tape: The tape to use. Default is the current tape.
20+ adj_value: The adjoint value to the result. Required if the functional
21+ is not scalar-valued, or if the functional is not the final stage
22+ in the computation of an outer functional.
23+ apply_riesz: If True, apply the Riesz map of each control in order
24+ to return a primal gradient rather than a derivative in the
25+ dual space.
1626
1727 Returns:
18- OverloadedType: The derivative with respect to the control. Should be an instance of the same type as
19- the control.
28+ OverloadedType: The derivative with respect to the control.
29+ If apply_riesz is False, should be an instance of the type dual
30+ to that of the control. If apply_riesz is True should have the
31+ same type as the control.
2032 """
21- options = options or {}
2233 tape = tape or get_working_tape ()
2334 tape .reset_variables ()
2435 J .block_variable .adj_value = adj_value
@@ -30,51 +41,126 @@ def compute_gradient(J, m, options=None, tape=None, adj_value=1.0):
3041 with marked_controls (m ):
3142 tape .evaluate_adj (markings = True )
3243
33- grads = [i .get_derivative (options = options ) for i in m ]
44+ grads = [i .get_derivative (apply_riesz = apply_riesz ) for i in m ]
3445 return m .delist (grads )
3546
3647
37- def compute_hessian (J , m , m_dot , options = None , tape = None ):
def compute_gradient(J, m, tape=None, adj_value=1.0, apply_riesz=True):
    """
    Compute the gradient of J with respect to the initialisation value of m,
    that is the value of m at its creation.

    This function is deprecated in favour of :func:`compute_derivative`, to
    which it forwards all of its arguments unchanged. Note that
    ``apply_riesz`` defaults to True here.

    Args:
        J (OverloadedType): The objective functional.
        m (list or instance of Control): The (list of) controls.
        tape: The tape to use. Default is the current tape.
        adj_value: The adjoint value to the result. Required if the functional
            is not scalar-valued, or if the functional is not the final stage
            in the computation of an outer functional.
        apply_riesz: If True, apply the Riesz map of each control in order
            to return a primal gradient rather than a derivative in the
            dual space.

    Returns:
        OverloadedType: The gradient with respect to the control.
            If apply_riesz is False, should be an instance of the type dual
            to that of the control. If apply_riesz is True should have the
            same type as the control.
    """
    if deprecated is None:
        # ``warnings.deprecated`` could not be imported, so the warning is
        # raised by hand. stacklevel=2 attributes it to the line that called
        # compute_gradient rather than to this module.
        warn("compute_gradient is deprecated in favour of compute_derivative.",
             FutureWarning, stacklevel=2)

    # Forward by keyword so the call stays correct even if the parameter
    # order of compute_derivative changes.
    return compute_derivative(J, m, tape=tape, adj_value=adj_value,
                              apply_riesz=apply_riesz)
77+
78+
if deprecated is not None:
    # ``warnings.deprecated`` is available: wrap compute_gradient so that the
    # deprecation is reported at call time. The category is given explicitly
    # because the decorator defaults to DeprecationWarning, whereas the manual
    # fallback inside compute_gradient raises FutureWarning; both code paths
    # should report the same category.
    compute_gradient = deprecated(
        "compute_gradient is deprecated in favour of compute_derivative.",
        category=FutureWarning,
    )(compute_gradient)
83+
84+
def compute_hessian(J, m, m_dot, hessian_input=None, tape=None, evaluate_tlm=True, apply_riesz=False):
    """
    Evaluate the action of the Hessian of J on the direction m_dot, at the
    current value of m.

    Args:
        J (AdjFloat): The objective functional.
        m (list or instance of Control): The (list of) controls.
        m_dot (list or instance of the control type): The direction in which
            the Hessian is applied.
        hessian_input (OverloadedType): The value from which the Hessian
            accumulation starts once the TLM calculation is done. A zero
            initialised value is used if None.
        tape: The tape to use. Default is the current tape.
        evaluate_tlm (bool): Whether to run the forward (TLM) part of the
            Hessian calculation. If False, the tape is assumed to already
            hold the required TLM values.
        apply_riesz: If True, apply the Riesz map of each control in order
            to return the (primal) Riesz representer of the Hessian action.

    Returns:
        OverloadedType: The action of the Hessian in the direction m_dot.
            If apply_riesz is False, should be an instance of the type dual
            to that of the control. If apply_riesz is True should have the
            same type as the control.
    """
    tape = tape or get_working_tape()

    if evaluate_tlm:
        # Forward sweep: populate the tape with the TLM values that the
        # Hessian sweep below reads.
        compute_tlm(J, m, m_dot, tape)

    tape.reset_hessian_values()

    # Seed the Hessian accumulation on the functional's block variable.
    functional_output = J.block_variable.output
    J.block_variable.hessian_value = (
        functional_output._ad_init_zero(dual=True)
        if hessian_input is None
        else functional_output._ad_init_object(hessian_input)
    )

    m = Enlist(m)
    with stop_annotating(), \
            tape.marked_control_dependents(m), \
            tape.marked_functional_dependencies(J):
        tape.evaluate_hessian(markings=True)

    hessian_actions = [
        control.get_hessian(apply_riesz=apply_riesz) for control in m
    ]
    return m.delist(hessian_actions)
133+
134+
def compute_tlm(J, m, m_dot, tape=None):
    """
    Evaluate the tangent linear model of J in the direction m_dot, at the
    current value of m.

    Args:
        J (OverloadedType): The objective functional.
        m (list or instance of Control): The (list of) controls.
        m_dot (list or instance of the control type): The direction in which
            the tangent linear model is evaluated.
        tape: The tape to use. Default is the current tape.

    Returns:
        OverloadedType: The action of the tangent linear model with respect
            to the control in direction m_dot. Should be an instance of the
            same type as the functional.
    """
    tape = tape or get_working_tape()
    tape.reset_tlm_values()

    m = Enlist(m)
    m_dot = Enlist(m_dot)

    # Seed every control with its matching component of the direction.
    for control, direction in zip(m, m_dot):
        control.tlm_value = direction

    with stop_annotating(), tape.marked_control_dependents(m):
        tape.evaluate_tlm(markings=True)

    tlm_result = J.block_variable.tlm_value
    return J._ad_init_object(tlm_result)
78164
79165
80166def solve_adjoint (J , tape = None , adj_value = 1.0 ):
0 commit comments