diff --git a/pyadjoint/__init__.py b/pyadjoint/__init__.py index 0baeb05b..b923bfa7 100644 --- a/pyadjoint/__init__.py +++ b/pyadjoint/__init__.py @@ -10,7 +10,10 @@ from .block import Block from .tape import (Tape, set_working_tape, get_working_tape, no_annotations, - annotate_tape, stop_annotating, pause_annotation, continue_annotation) + annotate_tape, stop_annotating, pause_annotation, continue_annotation, + no_reverse_over_forward, reverse_over_forward_enabled, + stop_reverse_over_forward, pause_reverse_over_forward, + continue_reverse_over_forward) from .adjfloat import AdjFloat, exp, log from .reduced_functional import ReducedFunctional from .drivers import compute_gradient, compute_hessian, solve_adjoint diff --git a/pyadjoint/adjfloat.py b/pyadjoint/adjfloat.py index e92e2a05..8f303348 100644 --- a/pyadjoint/adjfloat.py +++ b/pyadjoint/adjfloat.py @@ -64,6 +64,10 @@ def __div__(self, other): def __truediv__(self, other): return DivBlock(self, other) + @annotate_operator + def __pos__(self): + return PosBlock(self) + @annotate_operator def __neg__(self): return NegBlock(self) @@ -188,6 +192,14 @@ def evaluate_tlm_component(self, inputs, tlm_inputs, block_variable, idx, prepar input0 = inputs[0] return _exp(input0) * tlm_input + def solve_tlm(self): + x, = self.get_outputs() + a, = self.get_dependencies() + if a.tlm_value is None: + x.tlm_value = None + else: + x.tlm_value = exp(a.output) * a.tlm_value + def evaluate_hessian_component(self, inputs, hessian_inputs, adj_inputs, block_variable, idx, relevant_dependencies, prepared=None): input0 = inputs[0] @@ -213,6 +225,14 @@ def evaluate_tlm_component(self, inputs, tlm_inputs, block_variable, idx, prepar input0 = inputs[0] return tlm_input / input0 + def solve_tlm(self): + x, = self.get_outputs() + a, = self.get_dependencies() + if a.tlm_value is None: + x.tlm_value = None + else: + x.tlm_value = a.tlm_value / a.output + def evaluate_hessian_component(self, inputs, hessian_inputs, adj_inputs, block_variable, 
idx, relevant_dependencies, prepared=None): input0 = inputs[0] @@ -285,6 +305,14 @@ def evaluate_tlm_component(self, inputs, tlm_inputs, block_variable, idx, prepar idx = 0 if inputs[0] <= inputs[1] else 1 return tlm_inputs[idx] + def solve_tlm(self): + x, = self.get_outputs() + a, b = self.get_dependencies() + # NOTE(review): solve_tlm runs when ANY dependency has a tangent, so the selected one may still be None. + tlm_value = a.tlm_value if a.output <= b.output else b.tlm_value + # Unary plus copies the tangent so the output holds an independent value. + x.tlm_value = None if tlm_value is None else +tlm_value + def evaluate_hessian_component(self, inputs, hessian_inputs, adj_inputs, block_variable, idx, relevant_dependencies, prepared=None): return self.evaluate_adj_component(inputs, hessian_inputs, block_variable, idx, prepared) @@ -307,6 +335,14 @@ def evaluate_adj_component(self, inputs, adj_inputs, block_variable, idx, prepar else: return 0. + def solve_tlm(self): + x, = self.get_outputs() + a, b = self.get_dependencies() + # NOTE(review): solve_tlm runs when ANY dependency has a tangent, so the selected one may still be None. + tlm_value = a.tlm_value if a.output >= b.output else b.tlm_value + # Unary plus copies the tangent so the output holds an independent value. + x.tlm_value = None if tlm_value is None else +tlm_value + def evaluate_tlm_component(self, inputs, tlm_inputs, block_variable, idx, prepared=None): idx = 0 if inputs[0] >= inputs[1] else 1 return tlm_inputs[idx] @@ -379,6 +415,16 @@ def evaluate_tlm(self, markings=False): float.__pow__(base_value, exponent_value)) output.add_tlm_output(exponent_adj) + def solve_tlm(self): + x, = self.get_outputs() + a, b = self.get_dependencies() + terms = [] + if a.tlm_value is not None: + terms.append(b.output * (a.output ** (b.output - 1)) * a.tlm_value) + if b.tlm_value is not None: + terms.append(log(a.output) * (a.output ** b.output) * b.tlm_value) + x.tlm_value = None if len(terms) == 0 else sum(terms[1:], start=terms[0]) + def evaluate_hessian(self, markings=False): output = self.get_outputs()[0] hessian_input = output.hessian_value @@ -442,6 +488,17 @@ def evaluate_tlm_component(self, inputs, tlm_inputs, block_variable, idx, prepar tlm_output += tlm_input return tlm_output + def solve_tlm(self): + x, = self.get_outputs() + terms = tuple(dep.tlm_value for dep in self.get_dependencies() + if dep.tlm_value is not None) + if 
len(terms) == 0: + x.tlm_value = None + elif len(terms) == 1: + x.tlm_value = +terms[0] + else: + x.tlm_value = sum(terms[1:], start=terms[0]) + def evaluate_hessian_component(self, inputs, hessian_inputs, adj_inputs, block_variable, idx, relevant_dependencies, prepared=None): return hessian_inputs[0] @@ -466,6 +523,19 @@ def evaluate_tlm(self, markings=False): if tlm_input_1 is not None: output.add_tlm_output(float.__neg__(tlm_input_1)) + def solve_tlm(self): + x, = self.get_outputs() + a, b = self.get_dependencies() + if a.tlm_value is None: + if b.tlm_value is None: + x.tlm_value = None + else: + x.tlm_value = -b.tlm_value + elif b.tlm_value is None: + x.tlm_value = +a.tlm_value + else: + x.tlm_value = a.tlm_value - b.tlm_value + def evaluate_hessian(self, markings=False): hessian_input = self.get_outputs()[0].hessian_value if hessian_input is None: @@ -494,6 +564,16 @@ def evaluate_tlm_component(self, inputs, tlm_inputs, block_variable, idx, prepar tlm_output += float.__mul__(tlm_input, self.terms[j].saved_output) return tlm_output + def solve_tlm(self): + x, = self.get_outputs() + a, b = self.get_dependencies() + terms = [] + if a.tlm_value is not None: + terms.append(b.output * a.tlm_value) + if b.tlm_value is not None: + terms.append(a.output * b.tlm_value) + x.tlm_value = None if len(terms) == 0 else sum(terms[1:], start=terms[0]) + def evaluate_hessian_component(self, inputs, hessian_inputs, adj_inputs, block_variable, idx, relevant_dependencies, prepared=None): adj_input = adj_inputs[0] @@ -542,6 +622,16 @@ def evaluate_tlm(self, markings=False): )) )) + def solve_tlm(self): + x, = self.get_outputs() + a, b = self.get_dependencies() + terms = [] + if a.tlm_value is not None: + terms.append(a.tlm_value / b.output) + if b.tlm_value is not None: + terms.append((-a.output / (b.output ** 2)) * b.tlm_value) + x.tlm_value = None if len(terms) == 0 else sum(terms[1:], start=terms[0]) + def evaluate_hessian(self, markings=False): output = self.get_outputs()[0] 
hessian_input = output.hessian_value @@ -588,6 +678,35 @@ def evaluate_hessian(self, markings=False): denominator.add_hessian_output(float.__mul__(numerator.tlm_value, mixed)) +class PosBlock(FloatOperatorBlock): + operator = staticmethod(float.__pos__) + symbol = "+" + + def evaluate_adj_component(self, inputs, adj_inputs, block_variable, idx, prepared=None): + return float.__pos__(adj_inputs[0]) + + def evaluate_tlm_component(self, inputs, tlm_inputs, block_variable, idx, prepared=None): + return float.__pos__(tlm_inputs[0]) + + def solve_tlm(self): + x, = self.get_outputs() + a, = self.get_dependencies() + if a.tlm_value is None: + x.tlm_value = None + else: + x.tlm_value = +a.tlm_value + + def evaluate_hessian(self, markings=False): + hessian_input = self.get_outputs()[0].hessian_value + if hessian_input is None: + return + + self.terms[0].add_hessian_output(float.__pos__(hessian_input)) + + def __str__(self): + return f"{self.symbol} {self.terms[0]}" + + class NegBlock(FloatOperatorBlock): operator = staticmethod(float.__neg__) symbol = "-" @@ -598,6 +717,14 @@ def evaluate_adj_component(self, inputs, adj_inputs, block_variable, idx, prepar def evaluate_tlm_component(self, inputs, tlm_inputs, block_variable, idx, prepared=None): return float.__neg__(tlm_inputs[0]) + def solve_tlm(self): + x, = self.get_outputs() + a, = self.get_dependencies() + if a.tlm_value is None: + x.tlm_value = None + else: + x.tlm_value = -a.tlm_value + def evaluate_hessian(self, markings=False): hessian_input = self.get_outputs()[0].hessian_value if hessian_input is None: diff --git a/pyadjoint/block.py b/pyadjoint/block.py index 29dbc70b..9d8ebdc2 100644 --- a/pyadjoint/block.py +++ b/pyadjoint/block.py @@ -1,6 +1,8 @@ -from .tape import no_annotations +from contextlib import ExitStack from html import escape +from .tape import no_annotations, reverse_over_forward_enabled, stop_reverse_over_forward + class Block(object): """Base class for all Tape Block types. 
@@ -11,15 +13,19 @@ class Block(object): Abstract methods :func:`evaluate_adj` + Args: + n_outputs (int): The number of outputs. Required for + reverse-over-forward AD. """ __slots__ = ['_dependencies', '_outputs', 'block_helper'] pop_kwargs_keys = [] - def __init__(self, ad_block_tag=None): + def __init__(self, ad_block_tag=None, *, n_outputs=1): self._dependencies = [] self._outputs = [] self.block_helper = None self.tag = ad_block_tag + self._n_outputs = n_outputs @classmethod def pop_kwargs(cls, kwargs): @@ -71,9 +77,28 @@ def add_output(self, obj): obj (:class:`BlockVariable`): The object to be added. """ - obj.will_add_as_output() + + if reverse_over_forward_enabled() and len(self._outputs) >= self._n_outputs: + raise RuntimeError("Unexpected output") + self._outputs.append(obj) + if reverse_over_forward_enabled(): + if len(self._outputs) == self._n_outputs: + if any(dep.tlm_value is not None for dep in self.get_dependencies()): + with ExitStack() as stack: + for dep in self.get_dependencies(): + stack.enter_context(dep.restore_output()) + with stop_reverse_over_forward(): + self.solve_tlm() + else: + for x in self.get_outputs(): + x.tlm_value = None + elif len(self._outputs) > self._n_outputs: + raise RuntimeError("Unexpected output") + + obj.will_add_as_output() + def get_outputs(self): """Returns the list of block outputs. @@ -255,6 +280,15 @@ def evaluate_tlm_component(self, inputs, tlm_inputs, block_variable, idx, prepar """ raise NotImplementedError("evaluate_tlm_component is not implemented for Block-type: {}".format(type(self))) + def solve_tlm(self): + """This method should be overridden if using reverse-over-forward AD. + + Perform a tangent-linear operation, storing results in the `tlm_value` + attributes of relevant `BlockVariable` objects. 
+ """ + + raise NotImplementedError(f"solve_tlm is not implemented for Block-type: {type(self)}") + @no_annotations def evaluate_hessian(self, markings=False): outputs = self.get_outputs() diff --git a/pyadjoint/block_variable.py b/pyadjoint/block_variable.py index e252b0ea..bffd9c2e 100644 --- a/pyadjoint/block_variable.py +++ b/pyadjoint/block_variable.py @@ -1,4 +1,6 @@ -from .tape import no_annotations, get_working_tape +from contextlib import contextmanager + +from .tape import no_annotations, get_working_tape, stop_annotating class BlockVariable(object): @@ -93,3 +95,24 @@ def checkpoint(self, value): if self.is_control: return self._checkpoint = value + + @contextmanager + def restore_output(self): + """Return a context manager which can be used to temporarily restore + the value of `self.output` to `self.saved_output`. + + Returns: + The context manager. + """ + + if self.output is self.saved_output: + yield + else: + with stop_annotating(): + old_value = self.output._ad_copy() + self.output._ad_assign(self.saved_output) + try: + yield + finally: + with stop_annotating(): + self.output._ad_assign(old_value) diff --git a/pyadjoint/overloaded_type.py b/pyadjoint/overloaded_type.py index 0f96667e..d2f9b183 100644 --- a/pyadjoint/overloaded_type.py +++ b/pyadjoint/overloaded_type.py @@ -285,6 +285,18 @@ def _ad_to_list(m): """ raise NotImplementedError + def _ad_assign(self, other): + """This method must be overridden for mutable types. + + In-place assignment. + + Args: + other (object): The object to assign to `self`, with the same type as + `self`. + """ + + raise NotImplementedError + def _ad_copy(self): """This method must be overridden. 
diff --git a/pyadjoint/tape.py b/pyadjoint/tape.py index 63853d44..f7944dcb 100644 --- a/pyadjoint/tape.py +++ b/pyadjoint/tape.py @@ -11,6 +11,7 @@ _working_tape = None _annotation_enabled = False +_reverse_over_forward_enabled = False def get_working_tape(): @@ -136,6 +137,68 @@ def annotate_tape(kwargs=None): return annotate +def pause_reverse_over_forward(): + """Disable reverse-over-forward AD. + """ + + global _reverse_over_forward_enabled + _reverse_over_forward_enabled = False + + +def continue_reverse_over_forward(): + """Enable reverse-over-forward AD. + + Returns: + bool: True + """ + + global _reverse_over_forward_enabled + _reverse_over_forward_enabled = True + # Following continue_annotation behavior + return _reverse_over_forward_enabled + + +@contextmanager +def stop_reverse_over_forward(): + """Return a context manager used to temporarily disable + reverse-over-forward AD. + + Returns: + The context manager. + """ + + global _reverse_over_forward_enabled + reverse_over_forward_enabled = _reverse_over_forward_enabled + _reverse_over_forward_enabled = False + try: + yield + finally: + _reverse_over_forward_enabled = reverse_over_forward_enabled + + +def no_reverse_over_forward(function): + """Decorator to temporarily disable reverse-over-forward AD for the + decorated callable. + + Args: + function (callable): The callable. + Returns: + callable: Callable for which reverse-over-forward AD is disabled. + """ + + return stop_reverse_over_forward()(function) + + +def reverse_over_forward_enabled(): + """Return whether reverse-over-forward AD is enabled. + + Returns: + bool: Whether reverse-over-forward AD is enabled. 
+ """ + + return _reverse_over_forward_enabled + + def _find_relevant_nodes(tape, controls): # This function is just a stripped down Block.optimize_for_controls blocks = tape.get_blocks() diff --git a/tests/pyadjoint/test_reverse_over_forward.py b/tests/pyadjoint/test_reverse_over_forward.py new file mode 100644 index 00000000..9d086a52 --- /dev/null +++ b/tests/pyadjoint/test_reverse_over_forward.py @@ -0,0 +1,290 @@ +from contextlib import contextmanager + +import numpy as np +import pytest + +from pyadjoint import * + + +@pytest.fixture(autouse=True) +def _(): + get_working_tape().clear_tape() + continue_annotation() + continue_reverse_over_forward() + yield + get_working_tape().clear_tape() + pause_annotation() + pause_reverse_over_forward() + + +@pytest.mark.parametrize("a_val", [2.0, -2.0]) +@pytest.mark.parametrize("tlm_a_val", [3.5, -3.5]) +def test_exp(a_val, tlm_a_val): + a = AdjFloat(a_val) + a.block_variable.tlm_value = AdjFloat(tlm_a_val) + x = exp(a) + stop_annotating() + _ = compute_gradient(x.block_variable.tlm_value, Control(a)) + adj_value = a.block_variable.adj_value + assert np.allclose(adj_value, exp(a_val) * tlm_a_val) + + +@pytest.mark.parametrize("a_val", [2.0, 3.0]) +@pytest.mark.parametrize("tlm_a_val", [3.5, -3.5]) +def test_log(a_val, tlm_a_val): + a = AdjFloat(a_val) + a.block_variable.tlm_value = AdjFloat(tlm_a_val) + x = log(a) + stop_annotating() + _ = compute_gradient(x.block_variable.tlm_value, Control(a)) + adj_value = a.block_variable.adj_value + assert np.allclose(adj_value, -tlm_a_val / (a_val ** 2)) + + +@pytest.mark.parametrize("a_val", [2.0, 3.0]) +@pytest.mark.parametrize("tlm_a_val", [3.5, -3.5]) +@pytest.mark.parametrize("c", [0, 1]) +def test_min_left(a_val, tlm_a_val, c): + a = AdjFloat(a_val) + a.block_variable.tlm_value = AdjFloat(tlm_a_val) + b = AdjFloat(a_val + c) + x = min(a, b) + assert x.block_variable.tlm_value == tlm_a_val + y = x ** 3 + stop_annotating() + _ = compute_gradient(y.block_variable.tlm_value, 
Control(a)) + adj_value = a.block_variable.adj_value + assert np.allclose(adj_value, 6 * a_val * tlm_a_val) + + + +@pytest.mark.parametrize("b_val", [2.0, 3.0]) +@pytest.mark.parametrize("tlm_b_val", [3.5, -3.5]) +def test_min_right(b_val, tlm_b_val): + a = AdjFloat(b_val + 1) + b = AdjFloat(b_val) + b.block_variable.tlm_value = AdjFloat(tlm_b_val) + x = min(a, b) + assert x.block_variable.tlm_value == tlm_b_val + y = x ** 3 + stop_annotating() + _ = compute_gradient(y.block_variable.tlm_value, Control(b)) + adj_value = b.block_variable.adj_value + assert np.allclose(adj_value, 6 * b_val * tlm_b_val) + + +@pytest.mark.parametrize("a_val", [2.0, 3.0]) +@pytest.mark.parametrize("tlm_a_val", [3.5, -3.5]) +@pytest.mark.parametrize("c", [0, -1]) +def test_max_left(a_val, tlm_a_val, c): + a = AdjFloat(a_val) + a.block_variable.tlm_value = AdjFloat(tlm_a_val) + b = AdjFloat(a_val + c) + x = max(a, b) + assert x.block_variable.tlm_value == tlm_a_val + y = x ** 3 + stop_annotating() + _ = compute_gradient(y.block_variable.tlm_value, Control(a)) + adj_value = a.block_variable.adj_value + assert np.allclose(adj_value, 6 * a_val * tlm_a_val) + + +@pytest.mark.parametrize("b_val", [2.0, 3.0]) +@pytest.mark.parametrize("tlm_b_val", [3.5, -3.5]) +def test_max_right(b_val, tlm_b_val): + a = AdjFloat(b_val - 1) + b = AdjFloat(b_val) + b.block_variable.tlm_value = AdjFloat(tlm_b_val) + x = max(a, b) + assert x.block_variable.tlm_value == tlm_b_val + y = x ** 3 + stop_annotating() + _ = compute_gradient(y.block_variable.tlm_value, Control(b)) + adj_value = b.block_variable.adj_value + assert np.allclose(adj_value, 6 * b_val * tlm_b_val) + + +@pytest.mark.parametrize("a_val", [2.0, 3.0]) +@pytest.mark.parametrize("tlm_a_val", [3.5, -3.5, None]) +@pytest.mark.parametrize("b_val", [4.25, 5.25]) +@pytest.mark.parametrize("tlm_b_val", [5.8125, -5.8125, None]) +def test_pow(a_val, tlm_a_val, b_val, tlm_b_val): + a = AdjFloat(a_val) + if tlm_a_val is not None: + a.block_variable.tlm_value = 
AdjFloat(tlm_a_val) + b = AdjFloat(b_val) + if tlm_b_val is not None: + b.block_variable.tlm_value = AdjFloat(tlm_b_val) + x = a ** b + if tlm_a_val is None and tlm_b_val is None: + assert x.block_variable.tlm_value is None + else: + assert (x.block_variable.tlm_value == + b_val * (a_val ** (b_val - 1)) * (0.0 if tlm_a_val is None else tlm_a_val) + + log(a_val) * (a_val ** b_val) * (0.0 if tlm_b_val is None else tlm_b_val)) + if tlm_a_val is not None or tlm_b_val is not None: + _ = compute_gradient(x.block_variable.tlm_value, (Control(a), Control(b))) + assert np.allclose( + a.block_variable.adj_value, + b_val * (b_val - 1) * (a_val ** (b_val - 2)) * (0.0 if tlm_a_val is None else tlm_a_val) + + (1 + b_val * log(a_val)) * (a_val ** (b_val - 1)) * (0.0 if tlm_b_val is None else tlm_b_val)) + assert np.allclose( + b.block_variable.adj_value, + (1 + b_val * log(a_val)) * (a_val ** (b_val - 1)) * (0.0 if tlm_a_val is None else tlm_a_val) + + (log(a_val) ** 2) * (a_val ** b_val) * (0.0 if tlm_b_val is None else tlm_b_val)) + + +@pytest.mark.parametrize("a_val", [2.0, -2.0]) +@pytest.mark.parametrize("tlm_a_val", [3.5, -3.5, None]) +@pytest.mark.parametrize("b_val", [4.25, -4.25]) +@pytest.mark.parametrize("tlm_b_val", [5.8125, -5.8125, None]) +def test_add(a_val, tlm_a_val, b_val, tlm_b_val): + a = AdjFloat(a_val) + if tlm_a_val is not None: + a.block_variable.tlm_value = AdjFloat(tlm_a_val) + b = AdjFloat(b_val) + if tlm_b_val is not None: + b.block_variable.tlm_value = AdjFloat(tlm_b_val) + x = a + b + if tlm_a_val is None and tlm_b_val is None: + assert x.block_variable.tlm_value is None + else: + assert (x.block_variable.tlm_value == + (0.0 if tlm_a_val is None else tlm_a_val) + + (0.0 if tlm_b_val is None else tlm_b_val)) + y = x ** 3 + stop_annotating() + if tlm_a_val is not None or tlm_b_val is not None: + _ = compute_gradient(y.block_variable.tlm_value, (Control(a), Control(b))) + assert np.allclose( + a.block_variable.adj_value, + 6 * (a_val + b_val) * (0.0 if 
tlm_a_val is None else tlm_a_val) + + 6 * (a_val + b_val) * (0.0 if tlm_b_val is None else tlm_b_val)) + assert np.allclose( + b.block_variable.adj_value, + 6 * (a_val + b_val) * (0.0 if tlm_a_val is None else tlm_a_val) + + 6 * (a_val + b_val) * (0.0 if tlm_b_val is None else tlm_b_val)) + + +@pytest.mark.parametrize("a_val", [2.0, -2.0]) +@pytest.mark.parametrize("tlm_a_val", [3.5, -3.5, None]) +@pytest.mark.parametrize("b_val", [4.25, -4.25]) +@pytest.mark.parametrize("tlm_b_val", [5.8125, -5.8125, None]) +def test_sub(a_val, tlm_a_val, b_val, tlm_b_val): + a = AdjFloat(a_val) + if tlm_a_val is not None: + a.block_variable.tlm_value = AdjFloat(tlm_a_val) + b = AdjFloat(b_val) + if tlm_b_val is not None: + b.block_variable.tlm_value = AdjFloat(tlm_b_val) + x = a - b + if tlm_a_val is None and tlm_b_val is None: + assert x.block_variable.tlm_value is None + else: + assert (x.block_variable.tlm_value == + (0.0 if tlm_a_val is None else tlm_a_val) + - (0.0 if tlm_b_val is None else tlm_b_val)) + y = x ** 3 + stop_annotating() + if tlm_a_val is not None or tlm_b_val is not None: + _ = compute_gradient(y.block_variable.tlm_value, (Control(a), Control(b))) + assert np.allclose( + a.block_variable.adj_value, + 6 * (a_val - b_val) * (0.0 if tlm_a_val is None else tlm_a_val) + - 6 * (a_val - b_val) * (0.0 if tlm_b_val is None else tlm_b_val)) + assert np.allclose( + b.block_variable.adj_value, + - 6 * (a_val - b_val) * (0.0 if tlm_a_val is None else tlm_a_val) + + 6 * (a_val - b_val) * (0.0 if tlm_b_val is None else tlm_b_val)) + + +@pytest.mark.parametrize("a_val", [2.0, -2.0]) +@pytest.mark.parametrize("tlm_a_val", [3.5, -3.5, None]) +@pytest.mark.parametrize("b_val", [4.25, -4.25]) +@pytest.mark.parametrize("tlm_b_val", [5.8125, -5.8125, None]) +def test_mul(a_val, tlm_a_val, b_val, tlm_b_val): + a = AdjFloat(a_val) + if tlm_a_val is not None: + a.block_variable.tlm_value = AdjFloat(tlm_a_val) + b = AdjFloat(b_val) + if tlm_b_val is not None: + 
b.block_variable.tlm_value = AdjFloat(tlm_b_val) + x = a * b + if tlm_a_val is None and tlm_b_val is None: + assert x.block_variable.tlm_value is None + else: + assert (x.block_variable.tlm_value == + b_val * (0.0 if tlm_a_val is None else tlm_a_val) + + a_val * (0.0 if tlm_b_val is None else tlm_b_val)) + stop_annotating() + if tlm_a_val is not None or tlm_b_val is not None: + _ = compute_gradient(x.block_variable.tlm_value, (Control(a), Control(b))) + if tlm_b_val is None: + assert a.block_variable.adj_value is None + else: + assert np.allclose( + a.block_variable.adj_value, tlm_b_val) + if tlm_a_val is None: + assert b.block_variable.adj_value is None + else: + assert np.allclose( + b.block_variable.adj_value, tlm_a_val) + + +@pytest.mark.parametrize("a_val", [2.0, -2.0]) +@pytest.mark.parametrize("tlm_a_val", [3.5, -3.5, None]) +@pytest.mark.parametrize("b_val", [4.25, -4.25]) +@pytest.mark.parametrize("tlm_b_val", [5.8125, -5.8125, None]) +def test_div(a_val, tlm_a_val, b_val, tlm_b_val): + a = AdjFloat(a_val) + if tlm_a_val is not None: + a.block_variable.tlm_value = AdjFloat(tlm_a_val) + b = AdjFloat(b_val) + if tlm_b_val is not None: + b.block_variable.tlm_value = AdjFloat(tlm_b_val) + x = (a ** 2) / b + if tlm_a_val is None and tlm_b_val is None: + assert x.block_variable.tlm_value is None + else: + assert (x.block_variable.tlm_value == + (2 * a_val / b_val) * (0.0 if tlm_a_val is None else tlm_a_val) + - ((a_val ** 2) / (b_val ** 2)) * (0.0 if tlm_b_val is None else tlm_b_val)) + stop_annotating() + if tlm_a_val is not None or tlm_b_val is not None: + _ = compute_gradient(x.block_variable.tlm_value, (Control(a), Control(b))) + assert np.allclose( + a.block_variable.adj_value, + (2 / b_val) * (0.0 if tlm_a_val is None else tlm_a_val) + - 2 * a_val / (b_val ** 2) * (0.0 if tlm_b_val is None else tlm_b_val)) + assert np.allclose( + b.block_variable.adj_value, + - 2 * a_val / (b_val ** 2) * (0.0 if tlm_a_val is None else tlm_a_val) + + 2 * (a_val ** 2) / 
(b_val ** 3) * (0.0 if tlm_b_val is None else tlm_b_val)) + + +@pytest.mark.parametrize("a_val", [2.0, -2.0]) +@pytest.mark.parametrize("tlm_a_val", [3.5, -3.5]) +def test_pos(a_val, tlm_a_val): + a = AdjFloat(a_val) + a.block_variable.tlm_value = AdjFloat(tlm_a_val) + x = +a + assert x.block_variable.tlm_value == tlm_a_val + y = x ** 3 + stop_annotating() + _ = compute_gradient(y.block_variable.tlm_value, Control(a)) + adj_value = a.block_variable.adj_value + assert np.allclose(adj_value, 6 * a_val * tlm_a_val) + + +@pytest.mark.parametrize("a_val", [2.0, -2.0]) +@pytest.mark.parametrize("tlm_a_val", [3.5, -3.5]) +def test_neg(a_val, tlm_a_val): + a = AdjFloat(a_val) + a.block_variable.tlm_value = AdjFloat(tlm_a_val) + x = -a + assert x.block_variable.tlm_value == -tlm_a_val + y = x ** 3 + stop_annotating() + _ = compute_gradient(y.block_variable.tlm_value, Control(a)) + adj_value = a.block_variable.adj_value + assert np.allclose(adj_value, -6 * a_val * tlm_a_val) diff --git a/tests/pyadjoint/test_tape.py b/tests/pyadjoint/test_tape.py new file mode 100644 index 00000000..20b88189 --- /dev/null +++ b/tests/pyadjoint/test_tape.py @@ -0,0 +1,38 @@ +import pytest + +from pyadjoint import * # noqa: F403 + + +@pytest.fixture(autouse=True) +def _(): + pause_reverse_over_forward() + yield + pause_reverse_over_forward() + + +def test_reverse_over_forward_configuration(): + assert not reverse_over_forward_enabled() + + continue_reverse_over_forward() + assert reverse_over_forward_enabled() + pause_reverse_over_forward() + assert not reverse_over_forward_enabled() + + continue_reverse_over_forward() + assert reverse_over_forward_enabled() + with stop_reverse_over_forward(): + assert not reverse_over_forward_enabled() + assert reverse_over_forward_enabled() + pause_reverse_over_forward() + assert not reverse_over_forward_enabled() + + @no_reverse_over_forward + def test(): + assert not reverse_over_forward_enabled() + + continue_reverse_over_forward() + assert 
reverse_over_forward_enabled() + test() + assert reverse_over_forward_enabled() + pause_reverse_over_forward() + assert not reverse_over_forward_enabled()