Skip to content

Commit ce19f54

Browse files
committed
feat: implement spline collocation methods
1 parent 881a301 commit ce19f54

File tree

1 file changed

+106
-0
lines changed

1 file changed

+106
-0
lines changed
Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,106 @@
1+
# SPDX-FileCopyrightText: 2024 Alexandru Fikl <[email protected]>
2+
# SPDX-License-Identifier: MIT
3+
4+
from __future__ import annotations
5+
6+
from collections.abc import Iterator
7+
from dataclasses import dataclass
8+
from functools import cached_property
9+
from typing import Any, NamedTuple
10+
11+
import numpy as np
12+
13+
from pycaputo.derivatives import CaputoDerivative
14+
from pycaputo.events import Event, StepCompleted
15+
from pycaputo.history import History, ProductIntegrationHistory
16+
from pycaputo.logging import get_logger
17+
from pycaputo.stepping import (
18+
FractionalDifferentialEquationMethod,
19+
evolve,
20+
make_initial_condition,
21+
)
22+
from pycaputo.typing import Array, StateFunctionT
23+
24+
logger = get_logger(__name__)
25+
26+
27+
class AdvanceResult(NamedTuple):
28+
"""Result of :func:`~pycaputo.stepping.advance` for
29+
:class:`ProductIntegrationMethod` subclasses."""
30+
31+
y: Array
32+
"""Estimated solution at the next time step."""
33+
trunc: Array
34+
"""Estimated truncation error at the next time step."""
35+
storage: Array
36+
"""Array to add to the history storage."""
37+
38+
39+
@dataclass(frozen=True)
40+
class SplineCollocationMethod(
41+
FractionalDifferentialEquationMethod[CaputoDerivative, StateFunctionT]
42+
):
43+
"""A spline collocation method for the Caputo fractional derivative."""
44+
45+
if __debug__:
46+
47+
def __post_init__(self) -> None:
48+
super().__post_init__()
49+
50+
if not all(isinstance(d, CaputoDerivative) for d in self.ds):
51+
raise TypeError(f"Expected 'CaputoDerivative' operators: {self.ds}")
52+
53+
@cached_property
54+
def derivative_order(self) -> tuple[float, ...]:
55+
return tuple([d.alpha for d in self.ds])
56+
57+
@cached_property
58+
def alpha(self) -> Array:
59+
return np.array([d.alpha for d in self.ds])
60+
61+
62+
@make_initial_condition.register(SplineCollocationMethod)
63+
def _make_initial_condition_caputo_spline_collocation( # type: ignore[misc]
64+
m: SplineCollocationMethod[StateFunctionT],
65+
) -> Array:
66+
return m.y0[0]
67+
68+
69+
@evolve.register(SplineCollocationMethod)
70+
def _evolve_caputo_spline_collocation( # type: ignore[misc]
71+
m: SplineCollocationMethod[StateFunctionT],
72+
*,
73+
history: History[Any] | None = None,
74+
dtinit: float | None = None,
75+
) -> Iterator[Event]:
76+
from pycaputo.controller import estimate_initial_time_step
77+
78+
if history is None:
79+
history = ProductIntegrationHistory.empty_like(m.y0[0])
80+
81+
# initialize
82+
c = m.control
83+
n = 0
84+
t = c.tstart
85+
86+
# determine the initial condition
87+
yprev = make_initial_condition(m)
88+
history.append(t, m.source(t, yprev))
89+
90+
# determine initial time step
91+
if dtinit is None:
92+
dt = estimate_initial_time_step(
93+
t, yprev, m.source, m.smallest_derivative_order, trunc=m.order + 1
94+
)
95+
else:
96+
dt = dtinit
97+
98+
yield StepCompleted(
99+
t=t,
100+
iteration=n,
101+
dt=dt,
102+
y=yprev,
103+
eest=0.0,
104+
q=1.0,
105+
trunc=np.zeros_like(yprev),
106+
)

0 commit comments

Comments
 (0)