1
+ import numpy as np
2
+
3
+ from LoopStructural .interpolators ._discrete_interpolator import DiscreteInterpolator
4
+ from LoopStructural .interpolators ._finite_difference_interpolator import FiniteDifferenceInterpolator
5
+ from ._p1interpolator import P1Interpolator
6
+ from typing import Optional , Union , Callable
7
+ from scipy import sparse
8
+ from LoopStructural .utils import rng
9
+
10
+ class ConstantNormInterpolator :
11
+ """Adds a non linear constraint to an interpolator to constrain
12
+ the norm of the gradient to be a set value.
13
+
14
+ Returns
15
+ -------
16
+ _type_
17
+ _description_
18
+ """
19
+ def __init__ (self , interpolator : DiscreteInterpolator ):
20
+ """Initialise the constant norm inteprolator
21
+ with a discrete interpolator.
22
+
23
+ Parameters
24
+ ----------
25
+ interpolator : DiscreteInterpolator
26
+ The discrete interpolator to add constant norm to.
27
+ """
28
+ self .interpolator = interpolator
29
+ self .support = interpolator .support
30
+ self .random_subset = False
31
+ self .norm_length = 1.0
32
+ def add_constant_norm (self , w :float ):
33
+ """Add a constraint to the interpolator to constrain the norm of the gradient
34
+ to be a set value
35
+
36
+ Parameters
37
+ ----------
38
+ w : float
39
+ weighting of the constraint
40
+ """
41
+ if "constant norm" in self .interpolator .constraints :
42
+ _ = self .interpolator .constraints .pop ("constant norm" )
43
+
44
+ element_indices = np .arange (self .support .elements .shape [0 ])
45
+ if self .random_subset :
46
+ rng .shuffle (element_indices )
47
+ element_indices = element_indices [: int (0.1 * self .support .elements .shape [0 ])]
48
+ vertices , gradient , elements , inside = self .support .get_element_gradient_for_location (
49
+ self .support .barycentre [element_indices ]
50
+ )
51
+
52
+ t_g = gradient [:, :, :]
53
+ # t_n = gradient[self.support.shared_element_relationships[:, 1], :, :]
54
+ v_t = np .einsum (
55
+ "ijk,ik->ij" ,
56
+ t_g ,
57
+ self .interpolator .c [self .support .elements [elements ]],
58
+ )
59
+
60
+ v_t = v_t / np .linalg .norm (v_t , axis = 1 )[:, np .newaxis ]
61
+ A1 = np .einsum ("ij,ijk->ik" , v_t , t_g )
62
+
63
+ b = np .zeros (A1 .shape [0 ]) + self .norm_length
64
+ idc = np .hstack (
65
+ [
66
+ self .support .elements [elements ],
67
+ ]
68
+ )
69
+ self .interpolator .add_constraints_to_least_squares (A1 , b , idc , w = w , name = "constant norm" )
70
+
71
+ def solve_system (
72
+ self ,
73
+ solver : Optional [Union [Callable [[sparse .csr_matrix , np .ndarray ], np .ndarray ], str ]] = None ,
74
+ tol : Optional [float ] = None ,
75
+ solver_kwargs : dict = {},
76
+ ) -> bool :
77
+ """
78
+
79
+ Parameters
80
+ ----------
81
+ solver : Optional[Union[Callable[[sparse.csr_matrix, np.ndarray], np.ndarray], str]], optional
82
+ _description_, by default None
83
+ tol : Optional[float], optional
84
+ _description_, by default None
85
+ solver_kwargs : dict, optional
86
+ _description_, by default {}
87
+
88
+ Returns
89
+ -------
90
+ bool
91
+ _description_
92
+ """
93
+ for i in range (20 ):
94
+ if i > 0 :
95
+ self .add_constant_norm (w = (0.1 * i ) ** 2 + 0.01 )
96
+ success = self .interpolator .solve_system (solver = solver , tol = tol , solver_kwargs = solver_kwargs )
97
+ return True
98
+
99
+ class ConstantNormP1Interpolator (P1Interpolator , ConstantNormInterpolator ):
100
+ """Constant norm interpolator using P1 base interpolator
101
+
102
+ Parameters
103
+ ----------
104
+ P1Interpolator : class
105
+ The P1Interpolator class.
106
+ ConstantNormInterpolator : class
107
+ The ConstantNormInterpolator class.
108
+ """
109
+ def __init__ (self , support ):
110
+ """Initialise the constant norm P1 interpolator.
111
+
112
+ Parameters
113
+ ----------
114
+ support : _type_
115
+ _description_
116
+ """
117
+ P1Interpolator .__init__ (self , support )
118
+ ConstantNormInterpolator .__init__ (self , self )
119
+
120
+ def solve_system (
121
+ self ,
122
+ solver : Optional [Union [Callable [[sparse .csr_matrix , np .ndarray ], np .ndarray ], str ]] = None ,
123
+ tol : Optional [float ] = None ,
124
+ solver_kwargs : dict = {},
125
+ ) -> bool :
126
+ """Solve the system of equations for the constant norm P1 interpolator.
127
+
128
+ Parameters
129
+ ----------
130
+ solver : Optional[Union[Callable[[sparse.csr_matrix, np.ndarray], np.ndarray], str]], optional
131
+ _description_, by default None
132
+ tol : Optional[float], optional
133
+ _description_, by default None
134
+ solver_kwargs : dict, optional
135
+ _description_, by default {}
136
+
137
+ Returns
138
+ -------
139
+ bool
140
+ _description_
141
+ """
142
+ success = True
143
+ for i in range (20 ):
144
+
145
+ if i > 0 :
146
+ self .add_constant_norm (w = (0.1 * i ) ** 2 + 0.01 )
147
+ success = P1Interpolator .solve_system (self , solver , tol , solver_kwargs )
148
+ return success
149
+ class ConstantNormFDIInterpolator (FiniteDifferenceInterpolator , ConstantNormInterpolator ):
150
+ """Constant norm interpolator using finite difference base interpolator
151
+
152
+ Parameters
153
+ ----------
154
+ FiniteDifferenceInterpolator : class
155
+ The FiniteDifferenceInterpolator class.
156
+ ConstantNormInterpolator : class
157
+ The ConstantNormInterpolator class.
158
+ """
159
+ def __init__ (self , support ):
160
+ """Initialise the constant norm finite difference interpolator.
161
+
162
+ Parameters
163
+ ----------
164
+ support : _type_
165
+ _description_
166
+ """
167
+ FiniteDifferenceInterpolator .__init__ (self , support )
168
+ ConstantNormInterpolator .__init__ (self , self )
169
+ def solve_system (
170
+ self ,
171
+ solver : Optional [Union [Callable [[sparse .csr_matrix , np .ndarray ], np .ndarray ], str ]] = None ,
172
+ tol : Optional [float ] = None ,
173
+ solver_kwargs : dict = {},
174
+ ) -> bool :
175
+ """Solve the system of equations for the constant norm finite difference interpolator.
176
+
177
+ Parameters
178
+ ----------
179
+ solver : Optional[Union[Callable[[sparse.csr_matrix, np.ndarray], np.ndarray], str]], optional
180
+ _description_, by default None
181
+ tol : Optional[float], optional
182
+ _description_, by default None
183
+ solver_kwargs : dict, optional
184
+ _description_, by default {}
185
+
186
+ Returns
187
+ -------
188
+ bool
189
+ _description_
190
+ """
191
+ success = True
192
+ for i in range (20 ):
193
+ if i > 0 :
194
+ self .add_constant_norm (w = (0.1 * i ) ** 2 + 0.01 )
195
+ success = FiniteDifferenceInterpolator .solve_system (self , solver , tol , solver_kwargs )
196
+ return success
0 commit comments