Reverse-over-forward AD #162

Open · wants to merge 28 commits into base: master

Commits (28):
2880f2b  Reverse-over-forward configuration controls (jrmaddison, Jul 10, 2024)
19e7a1d  Reverse-over-forward AD (jrmaddison, Jul 10, 2024)
857438f  Test setup (jrmaddison, Jul 10, 2024)
04c48d6  In-place assignment, restore old values in reverse-over-forward AD (jrmaddison, Jul 10, 2024)
27d6355  Test setup fix (jrmaddison, Jul 10, 2024)
d123af3  BlockVariable.restore_output fix (jrmaddison, Jul 10, 2024)
ad76a17  Move will_add_as_output call until after tangent-linear operations (jrmaddison, Jul 10, 2024)
8b0d18c  Handle zero case to avoid unnecessary higher-order processing (jrmaddison, Jul 10, 2024)
111fb9f  restore_output fixes (jrmaddison, Jul 11, 2024)
74e45fe  Limit reverse-over-forward to second order (jrmaddison, Jul 11, 2024)
adcad2d  Reverse-over-forward AD: ExpBlock (jrmaddison, Jul 11, 2024)
3724a74  Reverse-over-forward AD: LogBlock (jrmaddison, Jul 11, 2024)
7e77ccd  Reverse-over-forward AD: AddBlock (jrmaddison, Jul 11, 2024)
c3ad8ea  Reverse-over-forward AD: NegBlock (jrmaddison, Jul 11, 2024)
f585e18  Reverse-over-forward AD: SubBlock (jrmaddison, Jul 11, 2024)
f537007  Expand AddBlock and SubBlock reverse-over-forward tests (jrmaddison, Jul 11, 2024)
ec60d22  Reverse-over-forward AD: MulBlock (jrmaddison, Jul 11, 2024)
664b320  Reverse-over-forward AD: PowBlock (jrmaddison, Jul 11, 2024)
1ad6323  Add PosBlock, use to fix a bug in AdjFloat reverse-over-forward AD (jrmaddison, Jul 11, 2024)
42b1d60  Reverse-over-forward AD: DivBlock (jrmaddison, Jul 11, 2024)
b535ae0  Reverse-over-forward AD: MinBlock and MaxBlock (jrmaddison, Jul 11, 2024)
7ff113d  == -> np.allclose (jrmaddison, Jul 11, 2024)
7deeb81  More reverse-over-forward testing (jrmaddison, Jul 12, 2024)
8a9c334  Bugfix (jrmaddison, Jul 12, 2024)
1d84ccf  Extra AdjFloat.__truediv__ testing (jrmaddison, Jul 12, 2024)
8fa8a0a  Minor __pos__ fixes (jrmaddison, Jul 12, 2024)
05a8365  Documentation fixes (jrmaddison, Jul 12, 2024)
ec6e2d5  Test updates (jrmaddison, Jul 12, 2024)
5 changes: 4 additions & 1 deletion pyadjoint/__init__.py
@@ -10,7 +10,10 @@
from .block import Block
from .tape import (Tape,
set_working_tape, get_working_tape, no_annotations,
                   annotate_tape, stop_annotating, pause_annotation, continue_annotation,
                   no_reverse_over_forward, reverse_over_forward_enabled,
                   stop_reverse_over_forward, pause_reverse_over_forward,
                   continue_reverse_over_forward)
from .adjfloat import AdjFloat, exp, log
from .reduced_functional import ReducedFunctional
from .drivers import compute_gradient, compute_hessian, solve_adjoint
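Reviewer note: for orientation, a speculative usage sketch of the new controls. Only the exported names above come from the diff; it is an assumption that reverse-over-forward annotation is off by default, that `continue_reverse_over_forward()`/`pause_reverse_over_forward()` toggle it analogously to the existing annotation controls, and that a direction is seeded via `block_variable.tlm_value`.

```python
# Speculative sketch -- control names from the exports above; the seeding
# and driver conventions are assumptions, not part of this diff.
from pyadjoint import (AdjFloat, Control, compute_gradient,
                       continue_reverse_over_forward, exp,
                       pause_reverse_over_forward)

continue_reverse_over_forward()  # assumed: start annotating tangent-linears

x = AdjFloat(2.0)
x.block_variable.tlm_value = AdjFloat(1.0)  # seed the direction dx = 1
y = exp(x) * x  # each taped operation also records its tangent-linear

pause_reverse_over_forward()

# Reverse-mode AD over the taped tangent-linear then yields a Hessian
# action, i.e. grad_x(dy/dx . dx).
ddJ = compute_gradient(y.block_variable.tlm_value, Control(x))
```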
127 changes: 127 additions & 0 deletions pyadjoint/adjfloat.py
@@ -64,6 +64,10 @@ def __div__(self, other):
def __truediv__(self, other):
return DivBlock(self, other)

@annotate_operator
def __pos__(self):
return PosBlock(self)

@annotate_operator
def __neg__(self):
return NegBlock(self)
@@ -188,6 +192,14 @@ def evaluate_tlm_component(self, inputs, tlm_inputs, block_variable, idx, prepared=None):
input0 = inputs[0]
return _exp(input0) * tlm_input

def solve_tlm(self):
x, = self.get_outputs()
a, = self.get_dependencies()
if a.tlm_value is None:
x.tlm_value = None
else:
x.tlm_value = exp(a.output) * a.tlm_value

def evaluate_hessian_component(self, inputs, hessian_inputs, adj_inputs, block_variable, idx,
relevant_dependencies, prepared=None):
input0 = inputs[0]
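Reviewer note: `solve_tlm` here implements the forward rule for `x = exp(a)`, namely `dx = exp(a) da`, written with overloaded operations so the tangent-linear is itself taped and can later be swept in reverse. A plain-float sanity check of the rule (no taping involved, just the calculus):

```python
import math

def tlm_exp(a, da):
    # Tangent-linear of x = exp(a): dx = exp(a) * da
    return math.exp(a) * da

a, da, eps = 0.7, 1.0, 1e-7
fd = (math.exp(a + eps * da) - math.exp(a)) / eps  # finite-difference reference
assert abs(tlm_exp(a, da) - fd) < 1e-5
```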
@@ -213,6 +225,14 @@ def evaluate_tlm_component(self, inputs, tlm_inputs, block_variable, idx, prepared=None):
input0 = inputs[0]
return tlm_input / input0

def solve_tlm(self):
x, = self.get_outputs()
a, = self.get_dependencies()
if a.tlm_value is None:
x.tlm_value = None
else:
x.tlm_value = a.tlm_value / a.output

def evaluate_hessian_component(self, inputs, hessian_inputs, adj_inputs, block_variable, idx,
relevant_dependencies, prepared=None):
input0 = inputs[0]
@@ -285,6 +305,14 @@ def evaluate_tlm_component(self, inputs, tlm_inputs, block_variable, idx, prepared=None):
idx = 0 if inputs[0] <= inputs[1] else 1
return tlm_inputs[idx]

    def solve_tlm(self):
        x, = self.get_outputs()
        a, b = self.get_dependencies()
        # Guard: the selected branch may itself have no tangent input.
        if a.output <= b.output:
            x.tlm_value = None if a.tlm_value is None else +a.tlm_value
        else:
            x.tlm_value = None if b.tlm_value is None else +b.tlm_value

def evaluate_hessian_component(self, inputs, hessian_inputs, adj_inputs, block_variable, idx,
relevant_dependencies, prepared=None):
return self.evaluate_adj_component(inputs, hessian_inputs, block_variable, idx, prepared)
@@ -307,6 +335,14 @@ def evaluate_adj_component(self, inputs, adj_inputs, block_variable, idx, prepared=None):
else:
return 0.

    def solve_tlm(self):
        x, = self.get_outputs()
        a, b = self.get_dependencies()
        # Guard: the selected branch may itself have no tangent input.
        if a.output >= b.output:
            x.tlm_value = None if a.tlm_value is None else +a.tlm_value
        else:
            x.tlm_value = None if b.tlm_value is None else +b.tlm_value

def evaluate_tlm_component(self, inputs, tlm_inputs, block_variable, idx, prepared=None):
idx = 0 if inputs[0] >= inputs[1] else 1
return tlm_inputs[idx]
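Reviewer note: for both `MinBlock` and `MaxBlock` the tangent-linear simply forwards the tangent of the active argument; at a tie the `<=`/`>=` tests select the first argument, a conventional subgradient choice at the point of non-differentiability. In plain floats:

```python
# Plain-float mirror of the branch selection above.
def tlm_min(a, b, da, db):
    return da if a <= b else db

assert tlm_min(1.0, 2.0, 10.0, 20.0) == 10.0
assert tlm_min(3.0, 2.0, 10.0, 20.0) == 20.0
assert tlm_min(2.0, 2.0, 10.0, 20.0) == 10.0  # tie: the first argument wins
```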
@@ -379,6 +415,16 @@ def evaluate_tlm(self, markings=False):
float.__pow__(base_value, exponent_value))
output.add_tlm_output(exponent_adj)

def solve_tlm(self):
x, = self.get_outputs()
a, b = self.get_dependencies()
terms = []
if a.tlm_value is not None:
terms.append(b.output * (a.output ** (b.output - 1)) * a.tlm_value)
if b.tlm_value is not None:
terms.append(log(a.output) * (a.output ** b.output) * b.tlm_value)
x.tlm_value = None if len(terms) == 0 else sum(terms[1:], start=terms[0])

def evaluate_hessian(self, markings=False):
output = self.get_outputs()[0]
hessian_input = output.hessian_value
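Reviewer note: `PowBlock.solve_tlm` combines the two partial rules, `d(a**b) = b * a**(b-1) * da + log(a) * a**b * db`, accumulating only the terms whose tangent input is present. A quick plain-float check of the combined rule:

```python
import math

def tlm_pow(a, b, da, db):
    # Tangent-linear of x = a ** b:
    #   dx = b * a**(b - 1) * da + log(a) * a**b * db
    return b * a ** (b - 1) * da + math.log(a) * a ** b * db

a, b, da, db, eps = 1.3, 2.2, 0.5, -0.25, 1e-7
fd = ((a + eps * da) ** (b + eps * db) - a ** b) / eps
assert abs(tlm_pow(a, b, da, db) - fd) < 1e-5
```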
@@ -442,6 +488,17 @@ def evaluate_tlm_component(self, inputs, tlm_inputs, block_variable, idx, prepared=None):
tlm_output += tlm_input
return tlm_output

def solve_tlm(self):
x, = self.get_outputs()
terms = tuple(dep.tlm_value for dep in self.get_dependencies()
if dep.tlm_value is not None)
if len(terms) == 0:
x.tlm_value = None
elif len(terms) == 1:
x.tlm_value = +terms[0]
else:
x.tlm_value = sum(terms[1:], start=terms[0])

def evaluate_hessian_component(self, inputs, hessian_inputs, adj_inputs, block_variable, idx,
relevant_dependencies, prepared=None):
return hessian_inputs[0]
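Reviewer note: the `None`-propagation in `AddBlock.solve_tlm` is the pattern used throughout this PR: `None` stands for a structurally zero tangent, a single live term is copied (the unary `+` creates a fresh taped variable; see `PosBlock` below), and multiple terms are summed starting from the first term rather than from the integer `0`, so no spurious constant enters the tape. A plain-float mirror of the pattern:

```python
from typing import Optional, Sequence

def accumulate_tlm(terms: Sequence[Optional[float]]) -> Optional[float]:
    # None == structural zero; a single term is copied; otherwise sum
    # without introducing an integer 0 start value.
    live = [t for t in terms if t is not None]
    if not live:
        return None
    if len(live) == 1:
        return +live[0]
    return sum(live[1:], start=live[0])

assert accumulate_tlm([None, None]) is None
assert accumulate_tlm([2.0, None, 3.0]) == 5.0
```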
@@ -466,6 +523,19 @@ def evaluate_tlm(self, markings=False):
if tlm_input_1 is not None:
output.add_tlm_output(float.__neg__(tlm_input_1))

def solve_tlm(self):
x, = self.get_outputs()
a, b = self.get_dependencies()
if a.tlm_value is None:
if b.tlm_value is None:
x.tlm_value = None
else:
x.tlm_value = -b.tlm_value
elif b.tlm_value is None:
x.tlm_value = +a.tlm_value
else:
x.tlm_value = a.tlm_value - b.tlm_value

def evaluate_hessian(self, markings=False):
hessian_input = self.get_outputs()[0].hessian_value
if hessian_input is None:
@@ -494,6 +564,16 @@ def evaluate_tlm_component(self, inputs, tlm_inputs, block_variable, idx, prepared=None):
tlm_output += float.__mul__(tlm_input, self.terms[j].saved_output)
return tlm_output

def solve_tlm(self):
x, = self.get_outputs()
a, b = self.get_dependencies()
terms = []
if a.tlm_value is not None:
terms.append(b.output * a.tlm_value)
if b.tlm_value is not None:
terms.append(a.output * b.tlm_value)
x.tlm_value = None if len(terms) == 0 else sum(terms[1:], start=terms[0])

def evaluate_hessian_component(self, inputs, hessian_inputs, adj_inputs, block_variable, idx,
relevant_dependencies, prepared=None):
adj_input = adj_inputs[0]
@@ -542,6 +622,16 @@ def evaluate_tlm(self, markings=False):
))
))

def solve_tlm(self):
x, = self.get_outputs()
a, b = self.get_dependencies()
terms = []
if a.tlm_value is not None:
terms.append(a.tlm_value / b.output)
if b.tlm_value is not None:
terms.append((-a.output / (b.output ** 2)) * b.tlm_value)
x.tlm_value = None if len(terms) == 0 else sum(terms[1:], start=terms[0])

def evaluate_hessian(self, markings=False):
output = self.get_outputs()[0]
hessian_input = output.hessian_value
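Reviewer note: `DivBlock.solve_tlm` applies the quotient rule `d(a/b) = da/b - (a/b**2) * db` with the same term-accumulation pattern. A plain-float check:

```python
def tlm_div(a, b, da, db):
    # Tangent-linear of x = a / b: dx = da / b - (a / b**2) * db
    return da / b + (-a / b ** 2) * db

a, b, da, db, eps = 1.7, 0.9, 0.3, -0.4, 1e-7
fd = ((a + eps * da) / (b + eps * db) - a / b) / eps
assert abs(tlm_div(a, b, da, db) - fd) < 1e-5
```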
@@ -588,6 +678,35 @@ def evaluate_hessian(self, markings=False):
denominator.add_hessian_output(float.__mul__(numerator.tlm_value, mixed))


class PosBlock(FloatOperatorBlock):
operator = staticmethod(float.__pos__)
symbol = "+"

def evaluate_adj_component(self, inputs, adj_inputs, block_variable, idx, prepared=None):
return float.__pos__(adj_inputs[0])

def evaluate_tlm_component(self, inputs, tlm_inputs, block_variable, idx, prepared=None):
return float.__pos__(tlm_inputs[0])

def solve_tlm(self):
x, = self.get_outputs()
a, = self.get_dependencies()
if a.tlm_value is None:
x.tlm_value = None
else:
x.tlm_value = +a.tlm_value

def evaluate_hessian(self, markings=False):
hessian_input = self.get_outputs()[0].hessian_value
if hessian_input is None:
return

self.terms[0].add_hessian_output(float.__pos__(hessian_input))

def __str__(self):
return f"{self.symbol} {self.terms[0]}"


class NegBlock(FloatOperatorBlock):
operator = staticmethod(float.__neg__)
symbol = "-"
@@ -598,6 +717,14 @@ def evaluate_adj_component(self, inputs, adj_inputs, block_variable, idx, prepared=None):
def evaluate_tlm_component(self, inputs, tlm_inputs, block_variable, idx, prepared=None):
return float.__neg__(tlm_inputs[0])

def solve_tlm(self):
x, = self.get_outputs()
a, = self.get_dependencies()
if a.tlm_value is None:
x.tlm_value = None
else:
x.tlm_value = -a.tlm_value

def evaluate_hessian(self, markings=False):
hessian_input = self.get_outputs()[0].hessian_value
if hessian_input is None:
40 changes: 37 additions & 3 deletions pyadjoint/block.py
@@ -1,6 +1,8 @@
from contextlib import ExitStack
from html import escape

from .tape import no_annotations, reverse_over_forward_enabled, stop_reverse_over_forward


class Block(object):
"""Base class for all Tape Block types.
@@ -11,15 +13,19 @@ class Block(object):
Abstract methods
:func:`evaluate_adj`

Args:
        n_outputs (int): The number of outputs. Reverse-over-forward AD uses
            this to detect when the block's final output has been added.
"""
__slots__ = ['_dependencies', '_outputs', 'block_helper']
pop_kwargs_keys = []

    def __init__(self, ad_block_tag=None, *, n_outputs=1):
self._dependencies = []
self._outputs = []
self.block_helper = None
self.tag = ad_block_tag
self._n_outputs = n_outputs

@classmethod
def pop_kwargs(cls, kwargs):
@@ -71,9 +77,28 @@ def add_output(self, obj):
obj (:class:`BlockVariable`): The object to be added.

"""

if reverse_over_forward_enabled() and len(self._outputs) >= self._n_outputs:
raise RuntimeError("Unexpected output")

self._outputs.append(obj)

if reverse_over_forward_enabled():
if len(self._outputs) == self._n_outputs:
if any(dep.tlm_value is not None for dep in self.get_dependencies()):
with ExitStack() as stack:
for dep in self.get_dependencies():
stack.enter_context(dep.restore_output())
with stop_reverse_over_forward():
self.solve_tlm()
else:
for x in self.get_outputs():
x.tlm_value = None
elif len(self._outputs) > self._n_outputs:
raise RuntimeError("Unexpected output")

obj.will_add_as_output()
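Reviewer note on the control flow above: once the block's final output has been added, the tape-time values of all dependencies are restored simultaneously (the `ExitStack` enters one `restore_output` context per dependency and unwinds them in reverse on exit), and `solve_tlm` then runs under `stop_reverse_over_forward`, so the tangent-linear operations are taped once without recursively spawning their own tangent-linears, limiting the scheme to second order. A standalone sketch of the `ExitStack` composition:

```python
from contextlib import ExitStack, contextmanager

@contextmanager
def restoring(name, log):
    log.append(f"restore {name}")
    try:
        yield
    finally:
        log.append(f"revert {name}")

log, deps = [], ["a", "b"]
with ExitStack() as stack:
    for dep in deps:
        stack.enter_context(restoring(dep, log))
    log.append("solve_tlm")  # runs while every dependency is restored
assert log == ["restore a", "restore b", "solve_tlm", "revert b", "revert a"]
```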

def get_outputs(self):
"""Returns the list of block outputs.

@@ -255,6 +280,15 @@ def evaluate_tlm_component(self, inputs, tlm_inputs, block_variable, idx, prepared=None):
"""
raise NotImplementedError("evaluate_tlm_component is not implemented for Block-type: {}".format(type(self)))

def solve_tlm(self):
"""This method should be overridden if using reverse-over-forward AD.

Perform a tangent-linear operation, storing results in the `tlm_value`
attributes of relevant `BlockVariable` objects.
"""

raise NotImplementedError(f"solve_tlm is not implemented for Block-type: {type(self)}")

@no_annotations
def evaluate_hessian(self, markings=False):
outputs = self.get_outputs()
25 changes: 24 additions & 1 deletion pyadjoint/block_variable.py
@@ -1,4 +1,6 @@
from contextlib import contextmanager

from .tape import no_annotations, get_working_tape, stop_annotating


class BlockVariable(object):
@@ -93,3 +95,24 @@ def checkpoint(self, value):
if self.is_control:
return
self._checkpoint = value

@contextmanager
def restore_output(self):
"""Return a context manager which can be used to temporarily restore
the value of `self.output` to `self.block_variable.saved_output`.

Returns:
The context manager.
"""

if self.output is self.saved_output:
yield
else:
with stop_annotating():
old_value = self.output._ad_copy()
self.output._ad_assign(self.saved_output)
try:
yield
finally:
with stop_annotating():
self.output._ad_assign(old_value)
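Reviewer note: a standalone mirror of the save/assign/yield/reinstate shape above may make the contract easier to see. `Box` and its `_ad_*` methods are stand-ins for a mutable overloaded type, not pyadjoint API:

```python
from contextlib import contextmanager

class Box:
    """Stand-in for a mutable overloaded value (not pyadjoint API)."""

    def __init__(self, value):
        self.value = value

    def _ad_copy(self):
        return Box(self.value)

    def _ad_assign(self, other):
        self.value = other.value

@contextmanager
def restore(output, saved_output):
    # Same shape as BlockVariable.restore_output: copy the current value,
    # assign the saved one, and reinstate the copy even if the body raises.
    if output is saved_output:
        yield
    else:
        old = output._ad_copy()
        output._ad_assign(saved_output)
        try:
            yield
        finally:
            output._ad_assign(old)

x, x_saved = Box(3.0), Box(2.0)
with restore(x, x_saved):
    assert x.value == 2.0  # tape-time value visible inside the context
assert x.value == 3.0  # current value reinstated on exit
```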
12 changes: 12 additions & 0 deletions pyadjoint/overloaded_type.py
@@ -285,6 +285,18 @@ def _ad_to_list(m):
"""
raise NotImplementedError

def _ad_assign(self, other):
"""This method must be overridden for mutable types.

In-place assignment.

Args:
            other (object): The object to assign to `self`, with the same type as
`self`.
"""

raise NotImplementedError
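Reviewer note: the key contract is that `_ad_assign` mutates `self` in place, preserving object identity, in contrast to `_ad_copy`, which allocates. A hedged sketch for a hypothetical numpy-backed subclass:

```python
import numpy as np

class ArrayType:
    """Hypothetical numpy-backed overloaded type (illustration only)."""

    def __init__(self, arr):
        self.arr = np.asarray(arr, dtype=float)

    def _ad_copy(self):
        return ArrayType(self.arr.copy())

    def _ad_assign(self, other):
        self.arr[...] = other.arr  # overwrite the data, keep the identity

x, y = ArrayType([1.0, 2.0]), ArrayType([3.0, 4.0])
x._ad_assign(y)
assert (x.arr == y.arr).all() and x.arr is not y.arr
```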

def _ad_copy(self):
"""This method must be overridden.
