@@ -16,7 +16,7 @@ class ConstantNormInterpolator:
16
16
_type_
17
17
_description_
18
18
"""
19
- def __init__ (self , interpolator : DiscreteInterpolator ):
19
+ def __init__ (self , interpolator : DiscreteInterpolator , basetype ):
20
20
"""Initialise the constant norm inteprolator
21
21
with a discrete interpolator.
22
22
@@ -25,10 +25,12 @@ def __init__(self, interpolator: DiscreteInterpolator):
25
25
interpolator : DiscreteInterpolator
26
26
The discrete interpolator to add constant norm to.
27
27
"""
28
+ self .basetype = basetype
28
29
self .interpolator = interpolator
29
30
self .support = interpolator .support
30
31
self .random_subset = False
31
32
self .norm_length = 1.0
33
+ self .n_iterations = 20
32
34
def add_constant_norm (self , w :float ):
33
35
"""Add a constraint to the interpolator to constrain the norm of the gradient
34
36
to be a set value
@@ -74,27 +76,33 @@ def solve_system(
74
76
tol : Optional [float ] = None ,
75
77
solver_kwargs : dict = {},
76
78
) -> bool :
77
- """
79
+ """Solve the system of equations iteratively for the constant norm interpolator.
78
80
79
81
Parameters
80
82
----------
81
83
solver : Optional[Union[Callable[[sparse.csr_matrix, np.ndarray], np.ndarray], str]], optional
82
- _description_ , by default None
84
+ Solver function or name , by default None
83
85
tol : Optional[float], optional
84
- _description_ , by default None
86
+ Tolerance for the solver , by default None
85
87
solver_kwargs : dict, optional
86
- _description_ , by default {}
88
+ Additional arguments for the solver , by default {}
87
89
88
90
Returns
89
91
-------
90
92
bool
91
- _description_
93
+ Success status of the solver
92
94
"""
93
95
success = True
94
- for i in range (20 ):
96
+ for i in range (self . n_iterations ):
95
97
if i > 0 :
96
98
self .add_constant_norm (w = (0.1 * i ) ** 2 + 0.01 )
97
- success = self .interpolator .solve_system (solver = solver , tol = tol , solver_kwargs = solver_kwargs )
99
+ # Ensure the interpolator is cast to P1Interpolator before calling solve_system
100
+ if isinstance (self .interpolator , self .basetype ):
101
+ success = self .basetype .solve_system (self .interpolator , solver = solver , tol = tol , solver_kwargs = solver_kwargs )
102
+ else :
103
+ raise TypeError ("self.interpolator is not an instance of P1Interpolator" )
104
+ if not success :
105
+ break
98
106
return success
99
107
100
108
class ConstantNormP1Interpolator (P1Interpolator , ConstantNormInterpolator ):
@@ -116,7 +124,7 @@ def __init__(self, support):
116
124
_description_
117
125
"""
118
126
P1Interpolator .__init__ (self , support )
119
- ConstantNormInterpolator .__init__ (self , self )
127
+ ConstantNormInterpolator .__init__ (self , self , P1Interpolator )
120
128
121
129
def solve_system (
122
130
self ,
@@ -129,24 +137,19 @@ def solve_system(
129
137
Parameters
130
138
----------
131
139
solver : Optional[Union[Callable[[sparse.csr_matrix, np.ndarray], np.ndarray], str]], optional
132
- _description_ , by default None
140
+ Solver function or name , by default None
133
141
tol : Optional[float], optional
134
- _description_ , by default None
142
+ Tolerance for the solver , by default None
135
143
solver_kwargs : dict, optional
136
- _description_ , by default {}
144
+ Additional arguments for the solver , by default {}
137
145
138
146
Returns
139
147
-------
140
148
bool
141
- _description_
149
+ Success status of the solver
142
150
"""
143
- success = True
144
- for i in range (20 ):
151
+ return ConstantNormInterpolator .solve_system (self , solver = solver , tol = tol , solver_kwargs = solver_kwargs )
145
152
146
- if i > 0 :
147
- self .add_constant_norm (w = (0.1 * i ) ** 2 + 0.01 )
148
- success = P1Interpolator .solve_system (self , solver , tol , solver_kwargs )
149
- return success
150
153
class ConstantNormFDIInterpolator (FiniteDifferenceInterpolator , ConstantNormInterpolator ):
151
154
"""Constant norm interpolator using finite difference base interpolator
152
155
@@ -166,7 +169,7 @@ def __init__(self, support):
166
169
_description_
167
170
"""
168
171
FiniteDifferenceInterpolator .__init__ (self , support )
169
- ConstantNormInterpolator .__init__ (self , self )
172
+ ConstantNormInterpolator .__init__ (self , self , FiniteDifferenceInterpolator )
170
173
def solve_system (
171
174
self ,
172
175
solver : Optional [Union [Callable [[sparse .csr_matrix , np .ndarray ], np .ndarray ], str ]] = None ,
@@ -178,20 +181,15 @@ def solve_system(
178
181
Parameters
179
182
----------
180
183
solver : Optional[Union[Callable[[sparse.csr_matrix, np.ndarray], np.ndarray], str]], optional
181
- _description_ , by default None
184
+ Solver function or name , by default None
182
185
tol : Optional[float], optional
183
- _description_ , by default None
186
+ Tolerance for the solver , by default None
184
187
solver_kwargs : dict, optional
185
- _description_ , by default {}
188
+ Additional arguments for the solver , by default {}
186
189
187
190
Returns
188
191
-------
189
192
bool
190
- _description_
193
+ Success status of the solver
191
194
"""
192
- success = True
193
- for i in range (20 ):
194
- if i > 0 :
195
- self .add_constant_norm (w = (0.1 * i ) ** 2 + 0.01 )
196
- success = FiniteDifferenceInterpolator .solve_system (self , solver , tol , solver_kwargs )
197
- return success
195
+ return ConstantNormInterpolator .solve_system (self , solver = solver , tol = tol , solver_kwargs = solver_kwargs )
0 commit comments