11import numpy as np
22import warnings
3- from typing import List , Union
3+ from typing import List , Union , Literal
44
55from ..utilities import TypeChangeWarning
6- from ..base import NoiseCovLinearOperator , InvNoiseCovLinearOperator
7- from ..math import DTypeFloat
6+ from ..base import LinearOperator , NoiseCovLinearOperator , InvNoiseCovLinearOperator
7+ from ..math import DTypeFloat , cg
88from ..mpi import MPI_RAISE_EXCEPTION
9-
10- import scipy .sparse .linalg
9+ from ..core import InvNoiseCovLO_Circulant
1110
1211from brahmap import MPI_UTILS
1312
@@ -19,7 +18,7 @@ def __init__(
1918 self ,
2019 size : int ,
2120 input : Union [np .ndarray , List ],
22- input_type : str = "power_spectrum" ,
21+ input_type : Literal [ "covariance" , "power_spectrum" ] = "power_spectrum" ,
2322 dtype : DTypeFloat = np .float64 ,
2423 ):
2524 input = np .asarray (a = input , dtype = dtype )
@@ -108,9 +107,12 @@ def __init__(
108107 self ,
109108 size : int ,
110109 input : Union [np .ndarray , List ],
111- input_type : str = "power_spectrum" ,
112- precond_op = None ,
110+ input_type : Literal ["covariance" , "power_spectrum" ] = "power_spectrum" ,
111+ precond_op : Union [
112+ LinearOperator , Literal [None , "Strang" , "TChan" , "RChan" , "KK2" ]
113+ ] = None ,
113114 precond_maxiter = 50 ,
115+ precond_rtol = 1.0e-10 ,
114116 precond_atol = 1.0e-10 ,
115117 precond_callback = None ,
116118 dtype : DTypeFloat = np .float64 ,
@@ -122,11 +124,61 @@ def __init__(
122124 dtype = dtype ,
123125 )
124126
127+ self .__precond_rtol = precond_rtol
125128 self .__precond_atol = precond_atol
126129 self .__precond_maxiter = precond_maxiter
127- self .__precond_op = precond_op
128130 self .__precond_callback = precond_callback
129131
132+ if precond_op is None :
133+ self .__precond_op = None
134+ elif isinstance (precond_op , LinearOperator ) or isinstance (
135+ precond_op , np .ndarray
136+ ):
137+ self .__precond_op = precond_op
138+ elif precond_op in ["Strang" , "TChan" , "RChan" , "KK2" ]:
139+ if input_type == "power_spectrum" :
140+ cov = np .fft .ifft (input ).real [:size ]
141+ else :
142+ cov = input [:size ]
143+
144+ if precond_op == "Strang" :
145+ temp_size = int (np .floor (cov .size / 2 ))
146+ if cov .size % 2 == 0 :
147+ new_cov = np .concatenate (
148+ [cov [:temp_size ], cov [1 : temp_size + 1 ][::- 1 ]]
149+ )
150+ else :
151+ new_cov = np .concatenate (
152+ [cov [: temp_size + 1 ], cov [1 : temp_size + 1 ][::- 1 ]]
153+ )
154+ elif precond_op == "TChan" :
155+ new_cov = np .empty_like (cov )
156+ new_cov [0 ] = cov [0 ]
157+ n = cov .size
158+ for idx in range (1 , n ):
159+ new_cov [idx ] = ((n - idx ) * cov [idx ] + idx * cov [n - idx ]) / n
160+ elif precond_op == "RChan" :
161+ new_cov = np .roll (np .flip (cov ), 1 )
162+ new_cov += cov
163+ new_cov [0 ] = cov [0 ]
164+ elif precond_op == "KK2" : # Circulant but not symmetric
165+ new_cov = np .roll (np .flip (cov ), 1 )
166+ new_cov [0 ] = 0
167+ new_cov = cov - new_cov
168+
169+ self .__precond_op = InvNoiseCovLO_Circulant (
170+ size = size ,
171+ input = new_cov ,
172+ input_type = "covariance" ,
173+ dtype = dtype ,
174+ )
175+ else :
176+ MPI_RAISE_EXCEPTION (
177+ condition = True ,
178+ exception = ValueError ,
179+ message = "Invalid preconditioner operator provided!" ,
180+ )
181+
130182 super (InvNoiseCovLO_Toeplitz01 , self ).__init__ (
131183 nargin = size ,
132184 matvec = self ._mult ,
@@ -157,13 +209,15 @@ def _mult(self, vec: np.ndarray):
157209 )
158210 vec = vec .astype (dtype = self .dtype , copy = False )
159211
160- prod , _ = scipy . sparse . linalg . gmres (
212+ prod , _ = cg (
161213 A = self .__toeplitz_op ,
162214 b = vec ,
215+ rtol = self .__precond_rtol ,
163216 atol = self .__precond_atol ,
164217 maxiter = self .__precond_maxiter ,
165218 M = self .__precond_op ,
166219 callback = self .__precond_callback ,
220+ parallel = False ,
167221 )
168222
169223 return prod
0 commit comments