|
| 1 | +# SPDX-FileCopyrightText: 2024 Alexandru Fikl <[email protected]> |
| 2 | +# SPDX-License-Identifier: MIT |
| 3 | + |
| 4 | +from __future__ import annotations |
| 5 | + |
| 6 | +from collections.abc import Iterator |
| 7 | +from dataclasses import dataclass |
| 8 | +from functools import cached_property |
| 9 | +from typing import Any, NamedTuple |
| 10 | + |
| 11 | +import numpy as np |
| 12 | + |
| 13 | +from pycaputo.derivatives import CaputoDerivative |
| 14 | +from pycaputo.events import Event, StepCompleted |
| 15 | +from pycaputo.history import History, ProductIntegrationHistory |
| 16 | +from pycaputo.logging import get_logger |
| 17 | +from pycaputo.stepping import ( |
| 18 | + FractionalDifferentialEquationMethod, |
| 19 | + evolve, |
| 20 | + make_initial_condition, |
| 21 | +) |
| 22 | +from pycaputo.typing import Array, StateFunctionT |
| 23 | + |
| 24 | +logger = get_logger(__name__) |
| 25 | + |
| 26 | + |
| 27 | +class AdvanceResult(NamedTuple): |
| 28 | + """Result of :func:`~pycaputo.stepping.advance` for |
| 29 | + :class:`ProductIntegrationMethod` subclasses.""" |
| 30 | + |
| 31 | + y: Array |
| 32 | + """Estimated solution at the next time step.""" |
| 33 | + trunc: Array |
| 34 | + """Estimated truncation error at the next time step.""" |
| 35 | + storage: Array |
| 36 | + """Array to add to the history storage.""" |
| 37 | + |
| 38 | + |
| 39 | +@dataclass(frozen=True) |
| 40 | +class SplineCollocationMethod( |
| 41 | + FractionalDifferentialEquationMethod[CaputoDerivative, StateFunctionT] |
| 42 | +): |
| 43 | + """A spline collocation method for the Caputo fractional derivative.""" |
| 44 | + |
| 45 | + if __debug__: |
| 46 | + |
| 47 | + def __post_init__(self) -> None: |
| 48 | + super().__post_init__() |
| 49 | + |
| 50 | + if not all(isinstance(d, CaputoDerivative) for d in self.ds): |
| 51 | + raise TypeError(f"Expected 'CaputoDerivative' operators: {self.ds}") |
| 52 | + |
| 53 | + @cached_property |
| 54 | + def derivative_order(self) -> tuple[float, ...]: |
| 55 | + return tuple([d.alpha for d in self.ds]) |
| 56 | + |
| 57 | + @cached_property |
| 58 | + def alpha(self) -> Array: |
| 59 | + return np.array([d.alpha for d in self.ds]) |
| 60 | + |
| 61 | + |
| 62 | +@make_initial_condition.register(SplineCollocationMethod) |
| 63 | +def _make_initial_condition_caputo_spline_collocation( # type: ignore[misc] |
| 64 | + m: SplineCollocationMethod[StateFunctionT], |
| 65 | +) -> Array: |
| 66 | + return m.y0[0] |
| 67 | + |
| 68 | + |
| 69 | +@evolve.register(SplineCollocationMethod) |
| 70 | +def _evolve_caputo_spline_collocation( # type: ignore[misc] |
| 71 | + m: SplineCollocationMethod[StateFunctionT], |
| 72 | + *, |
| 73 | + history: History[Any] | None = None, |
| 74 | + dtinit: float | None = None, |
| 75 | +) -> Iterator[Event]: |
| 76 | + from pycaputo.controller import estimate_initial_time_step |
| 77 | + |
| 78 | + if history is None: |
| 79 | + history = ProductIntegrationHistory.empty_like(m.y0[0]) |
| 80 | + |
| 81 | + # initialize |
| 82 | + c = m.control |
| 83 | + n = 0 |
| 84 | + t = c.tstart |
| 85 | + |
| 86 | + # determine the initial condition |
| 87 | + yprev = make_initial_condition(m) |
| 88 | + history.append(t, m.source(t, yprev)) |
| 89 | + |
| 90 | + # determine initial time step |
| 91 | + if dtinit is None: |
| 92 | + dt = estimate_initial_time_step( |
| 93 | + t, yprev, m.source, m.smallest_derivative_order, trunc=m.order + 1 |
| 94 | + ) |
| 95 | + else: |
| 96 | + dt = dtinit |
| 97 | + |
| 98 | + yield StepCompleted( |
| 99 | + t=t, |
| 100 | + iteration=n, |
| 101 | + dt=dt, |
| 102 | + y=yprev, |
| 103 | + eest=0.0, |
| 104 | + q=1.0, |
| 105 | + trunc=np.zeros_like(yprev), |
| 106 | + ) |
0 commit comments