Skip to content

Commit 29e7df7

Browse files
committed
feat: implement spline collocation methods
1 parent 4f2a415 commit 29e7df7

File tree

1 file changed

+94
-0
lines changed

1 file changed

+94
-0
lines changed
Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,94 @@
1+
# SPDX-FileCopyrightText: 2024 Alexandru Fikl <[email protected]>
2+
# SPDX-License-Identifier: MIT
3+
4+
from __future__ import annotations
5+
6+
from dataclasses import dataclass
7+
from functools import cached_property
8+
from typing import Any, Iterator, NamedTuple
9+
10+
import numpy as np
11+
12+
from pycaputo.derivatives import CaputoDerivative, Side
13+
from pycaputo.events import Event, StepCompleted
14+
from pycaputo.history import History, ProductIntegrationHistory
15+
from pycaputo.logging import get_logger
16+
from pycaputo.stepping import (
17+
FractionalDifferentialEquationMethod,
18+
evolve,
19+
make_initial_condition,
20+
)
21+
from pycaputo.utils import Array, StateFunctionT
22+
23+
logger = get_logger(__name__)
24+
25+
26+
class AdvanceResult(NamedTuple):
27+
"""Result of :func:`~pycaputo.stepping.advance` for
28+
:class:`ProductIntegrationMethod` subclasses."""
29+
30+
y: Array
31+
"""Estimated solution at the next time step."""
32+
trunc: Array
33+
"""Estimated truncation error at the next time step."""
34+
storage: Array
35+
"""Array to add to the history storage."""
36+
37+
38+
@dataclass(frozen=True)
39+
class SplineCollocationMethod(FractionalDifferentialEquationMethod[StateFunctionT]):
40+
"""A spline collocation method for"""
41+
42+
@cached_property
43+
def d(self) -> tuple[CaputoDerivative, ...]:
44+
return tuple([
45+
CaputoDerivative(alpha=alpha, side=Side.Left)
46+
for alpha in self.derivative_order
47+
])
48+
49+
50+
@make_initial_condition.register(SplineCollocationMethod)
51+
def _make_initial_condition_caputo(
52+
m: SplineCollocationMethod[StateFunctionT],
53+
) -> Array:
54+
return m.y0[0]
55+
56+
57+
@evolve.register(SplineCollocationMethod)
58+
def _evolve_pi(
59+
m: SplineCollocationMethod[StateFunctionT],
60+
*,
61+
history: History[Any] | None = None,
62+
dtinit: float | None = None,
63+
) -> Iterator[Event]:
64+
from pycaputo.controller import estimate_initial_time_step
65+
66+
if history is None:
67+
history = ProductIntegrationHistory.empty_like(m.y0[0])
68+
69+
# initialize
70+
c = m.control
71+
n = 0
72+
t = c.tstart
73+
74+
# determine the initial condition
75+
yprev = make_initial_condition(m)
76+
history.append(t, m.source(t, yprev))
77+
78+
# determine initial time step
79+
if dtinit is None:
80+
dt = estimate_initial_time_step(
81+
t, yprev, m.source, m.smallest_derivative_order, trunc=m.order + 1
82+
)
83+
else:
84+
dt = dtinit
85+
86+
yield StepCompleted(
87+
t=t,
88+
iteration=n,
89+
dt=dt,
90+
y=yprev,
91+
eest=0.0,
92+
q=1.0,
93+
trunc=np.zeros_like(yprev),
94+
)

0 commit comments

Comments
 (0)