Skip to content

Commit dca4e7d

Browse files
committed
reduce code duplication
1 parent d3f4cb0 commit dca4e7d

File tree

1 file changed

+28
-30
lines changed

1 file changed

+28
-30
lines changed

LoopStructural/interpolators/_constant_norm.py

Lines changed: 28 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ class ConstantNormInterpolator:
1616
_type_
1717
_description_
1818
"""
19-
def __init__(self, interpolator: DiscreteInterpolator):
19+
def __init__(self, interpolator: DiscreteInterpolator,basetype):
2020
"""Initialise the constant norm inteprolator
2121
with a discrete interpolator.
2222
@@ -25,10 +25,12 @@ def __init__(self, interpolator: DiscreteInterpolator):
2525
interpolator : DiscreteInterpolator
2626
The discrete interpolator to add constant norm to.
2727
"""
28+
self.basetype = basetype
2829
self.interpolator = interpolator
2930
self.support = interpolator.support
3031
self.random_subset = False
3132
self.norm_length = 1.0
33+
self.n_iterations = 20
3234
def add_constant_norm(self, w:float):
3335
"""Add a constraint to the interpolator to constrain the norm of the gradient
3436
to be a set value
@@ -74,27 +76,33 @@ def solve_system(
7476
tol: Optional[float] = None,
7577
solver_kwargs: dict = {},
7678
) -> bool:
77-
"""
79+
"""Solve the system of equations iteratively for the constant norm interpolator.
7880
7981
Parameters
8082
----------
8183
solver : Optional[Union[Callable[[sparse.csr_matrix, np.ndarray], np.ndarray], str]], optional
82-
_description_, by default None
84+
Solver function or name, by default None
8385
tol : Optional[float], optional
84-
_description_, by default None
86+
Tolerance for the solver, by default None
8587
solver_kwargs : dict, optional
86-
_description_, by default {}
88+
Additional arguments for the solver, by default {}
8789
8890
Returns
8991
-------
9092
bool
91-
_description_
93+
Success status of the solver
9294
"""
9395
success = True
94-
for i in range(20):
96+
for i in range(self.n_iterations):
9597
if i > 0:
9698
self.add_constant_norm(w=(0.1 * i) ** 2 + 0.01)
97-
success = self.interpolator.solve_system(solver=solver, tol=tol, solver_kwargs=solver_kwargs)
99+
# Ensure the interpolator is cast to P1Interpolator before calling solve_system
100+
if isinstance(self.interpolator, self.basetype):
101+
success = self.basetype.solve_system(self.interpolator, solver=solver, tol=tol, solver_kwargs=solver_kwargs)
102+
else:
103+
raise TypeError("self.interpolator is not an instance of P1Interpolator")
104+
if not success:
105+
break
98106
return success
99107

100108
class ConstantNormP1Interpolator(P1Interpolator, ConstantNormInterpolator):
@@ -116,7 +124,7 @@ def __init__(self, support):
116124
_description_
117125
"""
118126
P1Interpolator.__init__(self, support)
119-
ConstantNormInterpolator.__init__(self, self)
127+
ConstantNormInterpolator.__init__(self, self, P1Interpolator)
120128

121129
def solve_system(
122130
self,
@@ -129,24 +137,19 @@ def solve_system(
129137
Parameters
130138
----------
131139
solver : Optional[Union[Callable[[sparse.csr_matrix, np.ndarray], np.ndarray], str]], optional
132-
_description_, by default None
140+
Solver function or name, by default None
133141
tol : Optional[float], optional
134-
_description_, by default None
142+
Tolerance for the solver, by default None
135143
solver_kwargs : dict, optional
136-
_description_, by default {}
144+
Additional arguments for the solver, by default {}
137145
138146
Returns
139147
-------
140148
bool
141-
_description_
149+
Success status of the solver
142150
"""
143-
success = True
144-
for i in range(20):
151+
return ConstantNormInterpolator.solve_system(self, solver=solver, tol=tol, solver_kwargs=solver_kwargs)
145152

146-
if i > 0:
147-
self.add_constant_norm(w=(0.1 * i) ** 2 + 0.01)
148-
success = P1Interpolator.solve_system(self, solver, tol, solver_kwargs)
149-
return success
150153
class ConstantNormFDIInterpolator(FiniteDifferenceInterpolator, ConstantNormInterpolator):
151154
"""Constant norm interpolator using finite difference base interpolator
152155
@@ -166,7 +169,7 @@ def __init__(self, support):
166169
_description_
167170
"""
168171
FiniteDifferenceInterpolator.__init__(self, support)
169-
ConstantNormInterpolator.__init__(self, self)
172+
ConstantNormInterpolator.__init__(self, self, FiniteDifferenceInterpolator)
170173
def solve_system(
171174
self,
172175
solver: Optional[Union[Callable[[sparse.csr_matrix, np.ndarray], np.ndarray], str]] = None,
@@ -178,20 +181,15 @@ def solve_system(
178181
Parameters
179182
----------
180183
solver : Optional[Union[Callable[[sparse.csr_matrix, np.ndarray], np.ndarray], str]], optional
181-
_description_, by default None
184+
Solver function or name, by default None
182185
tol : Optional[float], optional
183-
_description_, by default None
186+
Tolerance for the solver, by default None
184187
solver_kwargs : dict, optional
185-
_description_, by default {}
188+
Additional arguments for the solver, by default {}
186189
187190
Returns
188191
-------
189192
bool
190-
_description_
193+
Success status of the solver
191194
"""
192-
success=True
193-
for i in range(20):
194-
if i > 0:
195-
self.add_constant_norm(w=(0.1 * i) ** 2 + 0.01)
196-
success = FiniteDifferenceInterpolator.solve_system(self, solver, tol, solver_kwargs)
197-
return success
195+
return ConstantNormInterpolator.solve_system(self, solver=solver, tol=tol, solver_kwargs=solver_kwargs)

0 commit comments

Comments
 (0)