Skip to content

Commit 29570f1

Browse files
committed
fix: adding constant norm interpolators
1 parent 1c9a91a commit 29570f1

File tree

3 files changed

+252
-30
lines changed

3 files changed

+252
-30
lines changed

LoopStructural/interpolators/__init__.py

Lines changed: 34 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
44
"""
55

6+
67
__all__ = [
78
"InterpolatorType",
89
"GeologicalInterpolator",
@@ -21,39 +22,13 @@
2122
"StructuredGrid2D",
2223
"P2UnstructuredTetMesh",
2324
]
24-
from enum import IntEnum
25+
from ._interpolatortype import InterpolatorType, add_interpolator_type
26+
2527

2628
from ..utils import getLogger
2729

2830
logger = getLogger(__name__)
2931

30-
31-
class InterpolatorType(IntEnum):
32-
"""
33-
Enum for the different interpolator types
34-
35-
1-9 should cover interpolators with supports
36-
9+ are data supported
37-
"""
38-
39-
BASE = 0
40-
BASE_DISCRETE = 1
41-
FINITE_DIFFERENCE = 2
42-
DISCRETE_FOLD = 3
43-
PIECEWISE_LINEAR = 4
44-
PIECEWISE_QUADRATIC = 5
45-
BASE_DATA_SUPPORTED = 10
46-
SURFE = 11
47-
48-
49-
interpolator_string_map = {
50-
"FDI": InterpolatorType.FINITE_DIFFERENCE,
51-
"PLI": InterpolatorType.PIECEWISE_LINEAR,
52-
"P2": InterpolatorType.PIECEWISE_QUADRATIC,
53-
"P1": InterpolatorType.PIECEWISE_LINEAR,
54-
"DFI": InterpolatorType.DISCRETE_FOLD,
55-
'surfe': InterpolatorType.SURFE,
56-
}
5732
from ..interpolators._geological_interpolator import GeologicalInterpolator
5833
from ..interpolators._discrete_interpolator import DiscreteInterpolator
5934
from ..interpolators.supports import (
@@ -79,7 +54,7 @@ class InterpolatorType(IntEnum):
7954
)
8055
from ..interpolators._p2interpolator import P2Interpolator
8156
from ..interpolators._p1interpolator import P1Interpolator
82-
57+
from ..interpolators._constant_norm import ConstantNormP1Interpolator, ConstantNormFDIInterpolator
8358
try:
8459
from ..interpolators._surfe_wrapper import SurfeRBFInterpolator
8560
except ImportError:
@@ -93,6 +68,24 @@ def __init__(self, *args, **kwargs):
9368
raise ImportError(
9469
"Surfe cannot be imported. Please install Surfe. pip install surfe/ conda install -c loop3d surfe"
9570
)
71+
72+
# Ensure compatibility between the fallback and imported class
73+
SurfeRBFInterpolator = SurfeRBFInterpolator
74+
75+
76+
interpolator_string_map = {
77+
"FDI": InterpolatorType.FINITE_DIFFERENCE,
78+
"PLI": InterpolatorType.PIECEWISE_LINEAR,
79+
"P2": InterpolatorType.PIECEWISE_QUADRATIC,
80+
"P1": InterpolatorType.PIECEWISE_LINEAR,
81+
"DFI": InterpolatorType.DISCRETE_FOLD,
82+
'surfe': InterpolatorType.SURFE,
83+
"FDI_CN": InterpolatorType.FINITE_DIFFERENCE_CONSTANT_NORM,
84+
"P1_CN": InterpolatorType.PIECEWISE_LINEAR_CONSTANT_NORM,
85+
86+
}
87+
88+
# Define the mapping after all imports
9689
interpolator_map = {
9790
InterpolatorType.BASE: GeologicalInterpolator,
9891
InterpolatorType.BASE_DISCRETE: DiscreteInterpolator,
@@ -102,6 +95,8 @@ def __init__(self, *args, **kwargs):
10295
InterpolatorType.PIECEWISE_QUADRATIC: P2Interpolator,
10396
InterpolatorType.BASE_DATA_SUPPORTED: GeologicalInterpolator,
10497
InterpolatorType.SURFE: SurfeRBFInterpolator,
98+
InterpolatorType.PIECEWISE_LINEAR_CONSTANT_NORM: ConstantNormP1Interpolator,
99+
InterpolatorType.FINITE_DIFFERENCE_CONSTANT_NORM: ConstantNormFDIInterpolator,
105100
}
106101

107102
support_interpolator_map = {
@@ -119,9 +114,18 @@ def __init__(self, *args, **kwargs):
119114
3: SupportType.DataSupported,
120115
2: SupportType.DataSupported,
121116
},
117+
InterpolatorType.PIECEWISE_LINEAR_CONSTANT_NORM:{
118+
3: SupportType.TetMesh,
119+
2: SupportType.P1Unstructured2d,
120+
},
121+
InterpolatorType.FINITE_DIFFERENCE_CONSTANT_NORM: {
122+
3: SupportType.StructuredGrid,
123+
2: SupportType.StructuredGrid2D,
124+
}
122125
}
123126

124127
from ._interpolator_factory import InterpolatorFactory
125128
from ._interpolator_builder import InterpolatorBuilder
126129

127-
# from ._api import LoopInterpolator
130+
131+
Lines changed: 196 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,196 @@
1+
import numpy as np
2+
3+
from LoopStructural.interpolators._discrete_interpolator import DiscreteInterpolator
4+
from LoopStructural.interpolators._finite_difference_interpolator import FiniteDifferenceInterpolator
5+
from ._p1interpolator import P1Interpolator
6+
from typing import Optional, Union, Callable
7+
from scipy import sparse
8+
from LoopStructural.utils import rng
9+
10+
class ConstantNormInterpolator:
11+
"""Adds a non linear constraint to an interpolator to constrain
12+
the norm of the gradient to be a set value.
13+
14+
Returns
15+
-------
16+
_type_
17+
_description_
18+
"""
19+
def __init__(self, interpolator: DiscreteInterpolator):
20+
"""Initialise the constant norm inteprolator
21+
with a discrete interpolator.
22+
23+
Parameters
24+
----------
25+
interpolator : DiscreteInterpolator
26+
The discrete interpolator to add constant norm to.
27+
"""
28+
self.interpolator = interpolator
29+
self.support = interpolator.support
30+
self.random_subset = False
31+
self.norm_length = 1.0
32+
def add_constant_norm(self, w:float):
33+
"""Add a constraint to the interpolator to constrain the norm of the gradient
34+
to be a set value
35+
36+
Parameters
37+
----------
38+
w : float
39+
weighting of the constraint
40+
"""
41+
if "constant norm" in self.interpolator.constraints:
42+
_ = self.interpolator.constraints.pop("constant norm")
43+
44+
element_indices = np.arange(self.support.elements.shape[0])
45+
if self.random_subset:
46+
rng.shuffle(element_indices)
47+
element_indices = element_indices[: int(0.1 * self.support.elements.shape[0])]
48+
vertices, gradient, elements, inside = self.support.get_element_gradient_for_location(
49+
self.support.barycentre[element_indices]
50+
)
51+
52+
t_g = gradient[:, :, :]
53+
# t_n = gradient[self.support.shared_element_relationships[:, 1], :, :]
54+
v_t = np.einsum(
55+
"ijk,ik->ij",
56+
t_g,
57+
self.interpolator.c[self.support.elements[elements]],
58+
)
59+
60+
v_t = v_t / np.linalg.norm(v_t, axis=1)[:, np.newaxis]
61+
A1 = np.einsum("ij,ijk->ik", v_t, t_g)
62+
63+
b = np.zeros(A1.shape[0]) + self.norm_length
64+
idc = np.hstack(
65+
[
66+
self.support.elements[elements],
67+
]
68+
)
69+
self.interpolator.add_constraints_to_least_squares(A1, b, idc, w=w, name="constant norm")
70+
71+
def solve_system(
72+
self,
73+
solver: Optional[Union[Callable[[sparse.csr_matrix, np.ndarray], np.ndarray], str]] = None,
74+
tol: Optional[float] = None,
75+
solver_kwargs: dict = {},
76+
) -> bool:
77+
"""
78+
79+
Parameters
80+
----------
81+
solver : Optional[Union[Callable[[sparse.csr_matrix, np.ndarray], np.ndarray], str]], optional
82+
_description_, by default None
83+
tol : Optional[float], optional
84+
_description_, by default None
85+
solver_kwargs : dict, optional
86+
_description_, by default {}
87+
88+
Returns
89+
-------
90+
bool
91+
_description_
92+
"""
93+
for i in range(20):
94+
if i > 0:
95+
self.add_constant_norm(w=(0.1 * i) ** 2 + 0.01)
96+
success = self.interpolator.solve_system(solver=solver, tol=tol, solver_kwargs=solver_kwargs)
97+
return True
98+
99+
class ConstantNormP1Interpolator(P1Interpolator, ConstantNormInterpolator):
100+
"""Constant norm interpolator using P1 base interpolator
101+
102+
Parameters
103+
----------
104+
P1Interpolator : class
105+
The P1Interpolator class.
106+
ConstantNormInterpolator : class
107+
The ConstantNormInterpolator class.
108+
"""
109+
def __init__(self, support):
110+
"""Initialise the constant norm P1 interpolator.
111+
112+
Parameters
113+
----------
114+
support : _type_
115+
_description_
116+
"""
117+
P1Interpolator.__init__(self, support)
118+
ConstantNormInterpolator.__init__(self, self)
119+
120+
def solve_system(
121+
self,
122+
solver: Optional[Union[Callable[[sparse.csr_matrix, np.ndarray], np.ndarray], str]] = None,
123+
tol: Optional[float] = None,
124+
solver_kwargs: dict = {},
125+
) -> bool:
126+
"""Solve the system of equations for the constant norm P1 interpolator.
127+
128+
Parameters
129+
----------
130+
solver : Optional[Union[Callable[[sparse.csr_matrix, np.ndarray], np.ndarray], str]], optional
131+
_description_, by default None
132+
tol : Optional[float], optional
133+
_description_, by default None
134+
solver_kwargs : dict, optional
135+
_description_, by default {}
136+
137+
Returns
138+
-------
139+
bool
140+
_description_
141+
"""
142+
success = True
143+
for i in range(20):
144+
145+
if i > 0:
146+
self.add_constant_norm(w=(0.1 * i) ** 2 + 0.01)
147+
success = P1Interpolator.solve_system(self, solver, tol, solver_kwargs)
148+
return success
149+
class ConstantNormFDIInterpolator(FiniteDifferenceInterpolator, ConstantNormInterpolator):
150+
"""Constant norm interpolator using finite difference base interpolator
151+
152+
Parameters
153+
----------
154+
FiniteDifferenceInterpolator : class
155+
The FiniteDifferenceInterpolator class.
156+
ConstantNormInterpolator : class
157+
The ConstantNormInterpolator class.
158+
"""
159+
def __init__(self, support):
160+
"""Initialise the constant norm finite difference interpolator.
161+
162+
Parameters
163+
----------
164+
support : _type_
165+
_description_
166+
"""
167+
FiniteDifferenceInterpolator.__init__(self, support)
168+
ConstantNormInterpolator.__init__(self, self)
169+
def solve_system(
170+
self,
171+
solver: Optional[Union[Callable[[sparse.csr_matrix, np.ndarray], np.ndarray], str]] = None,
172+
tol: Optional[float] = None,
173+
solver_kwargs: dict = {},
174+
) -> bool:
175+
"""Solve the system of equations for the constant norm finite difference interpolator.
176+
177+
Parameters
178+
----------
179+
solver : Optional[Union[Callable[[sparse.csr_matrix, np.ndarray], np.ndarray], str]], optional
180+
_description_, by default None
181+
tol : Optional[float], optional
182+
_description_, by default None
183+
solver_kwargs : dict, optional
184+
_description_, by default {}
185+
186+
Returns
187+
-------
188+
bool
189+
_description_
190+
"""
191+
success=True
192+
for i in range(20):
193+
if i > 0:
194+
self.add_constant_norm(w=(0.1 * i) ** 2 + 0.01)
195+
success = FiniteDifferenceInterpolator.solve_system(self, solver, tol, solver_kwargs)
196+
return success
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
from enum import Enum
2+
3+
class InterpolatorType(Enum):
4+
"""
5+
Enum for the different interpolator types
6+
7+
Each value is a unique identifier.
8+
"""
9+
10+
BASE = "BASE"
11+
BASE_DISCRETE = "BASE_DISCRETE"
12+
FINITE_DIFFERENCE = "FINITE_DIFFERENCE"
13+
DISCRETE_FOLD = "DISCRETE_FOLD"
14+
PIECEWISE_LINEAR = "PIECEWISE_LINEAR"
15+
PIECEWISE_QUADRATIC = "PIECEWISE_QUADRATIC"
16+
BASE_DATA_SUPPORTED = "BASE_DATA_SUPPORTED"
17+
SURFE = "SURFE"
18+
PIECEWISE_LINEAR_CONSTANT_NORM = "PIECEWISE_LINEAR_CONSTANT_NORM"
19+
FINITE_DIFFERENCE_CONSTANT_NORM = "FINITE_DIFFERENCE_CONSTANT_NORM"
20+
21+
22+

0 commit comments

Comments
 (0)