Skip to content

Commit 1e9b681

Browse files
Merge pull request #10 from anand-avinash/toeplitz_precond
Adding preconditioners for the inverse Toeplitz operators
2 parents 1d1dfa4 + 3890738 commit 1e9b681

13 files changed

Lines changed: 400 additions & 270 deletions

brahmap/base/linop.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@
3939

4040
# Default (null) logger.
4141
null_log = logging.getLogger("linop")
42-
null_log.setLevel(logging.INFO)
42+
null_log.setLevel(logging.WARNING)
4343
null_log.addHandler(logging.NullHandler())
4444

4545

brahmap/base/noise_ops.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import numpy as np
2+
from typing import Literal
23

34
from ..base import LinearOperator, BlockDiagonalLinearOperator
45

@@ -12,7 +13,7 @@ def __init__(
1213
self,
1314
nargin: int,
1415
matvec: int,
15-
input_type: str = "covariance",
16+
input_type: Literal["covariance", "power_spectrum"] = "covariance",
1617
dtype: DTypeFloat = np.float64,
1718
**kwargs,
1819
):
@@ -54,7 +55,7 @@ def __init__(
5455
self,
5556
nargin: int,
5657
matvec: int,
57-
input_type: str = "covariance",
58+
input_type: Literal["covariance", "power_spectrum"] = "covariance",
5859
dtype: DTypeFloat = np.float64,
5960
**kwargs,
6061
):

brahmap/core/noise_ops_block_diag.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import numpy as np
2-
from typing import List, Union
2+
from typing import List, Union, Literal, Dict, Any
33

44
from ..base import (
55
BaseBlockDiagNoiseCovLinearOperator,
@@ -15,8 +15,9 @@ def __init__(
1515
operator,
1616
block_size: Union[np.ndarray, List],
1717
block_input: List[Union[np.ndarray, List]],
18-
input_type: str = "power_spectrum",
18+
input_type: Literal["covariance", "power_spectrum"] = "power_spectrum",
1919
dtype: DTypeFloat = np.float64,
20+
extra_kwargs: Dict[str, Any] = {},
2021
):
2122
MPI_RAISE_EXCEPTION(
2223
condition=(len(block_size) != len(block_input)),
@@ -31,6 +32,7 @@ def __init__(
3132
block_input=block_input,
3233
input_type=input_type,
3334
dtype=dtype,
35+
extra_kwargs=extra_kwargs,
3436
)
3537

3638
super(BlockDiagNoiseCovLO, self).__init__(
@@ -44,6 +46,7 @@ def __build_blocks(
4446
block_size,
4547
input_type,
4648
dtype,
49+
extra_kwargs,
4750
):
4851
block_list = []
4952
for idx, input in enumerate(block_input):
@@ -52,6 +55,7 @@ def __build_blocks(
5255
input=input,
5356
input_type=input_type,
5457
dtype=dtype,
58+
**extra_kwargs,
5559
)
5660
block_list.append(block_op)
5761
return block_list
@@ -63,13 +67,15 @@ def __init__(
6367
operator,
6468
block_size: Union[np.ndarray, List],
6569
block_input: List[Union[np.ndarray, List]],
66-
input_type: str = "power_spectrum",
70+
input_type: Literal["covariance", "power_spectrum"] = "power_spectrum",
6771
dtype: DTypeFloat = np.float64,
72+
extra_kwargs: Dict[str, Any] = {},
6873
):
6974
super(BlockDiagInvNoiseCovLO, self).__init__(
7075
operator,
7176
block_size,
7277
block_input,
7378
input_type,
7479
dtype,
80+
extra_kwargs,
7581
)

brahmap/core/noise_ops_circulant.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import numpy as np
22
import warnings
3-
from typing import List, Union
3+
from typing import List, Union, Literal
44

55
from ..utilities import TypeChangeWarning
66
from ..base import NoiseCovLinearOperator, InvNoiseCovLinearOperator
@@ -15,7 +15,7 @@ def __init__(
1515
self,
1616
size: int,
1717
input: Union[np.ndarray, List],
18-
input_type: str = "power_spectrum",
18+
input_type: Literal["covariance", "power_spectrum"] = "power_spectrum",
1919
dtype: DTypeFloat = np.float64,
2020
):
2121
input = np.asarray(a=input, dtype=dtype)
@@ -84,7 +84,7 @@ def __init__(
8484
self,
8585
size: int,
8686
input: Union[np.ndarray, List],
87-
input_type: str = "power_spectrum",
87+
input_type: Literal["covariance", "power_spectrum"] = "power_spectrum",
8888
dtype: DTypeFloat = np.float64,
8989
):
9090
input = np.asarray(a=input, dtype=dtype)

brahmap/core/noise_ops_diagonal.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import numpy as np
22
import warnings
33
from numbers import Number
4-
from typing import List, Union
4+
from typing import List, Union, Literal
55

66

77
from ..utilities import TypeChangeWarning
@@ -20,7 +20,7 @@ def __init__(
2020
self,
2121
size: int,
2222
input: Union[np.ndarray, List, DTypeFloat] = 1.0,
23-
input_type="covariance",
23+
input_type: Literal["covariance", "power_spectrum"] = "covariance",
2424
dtype: DTypeFloat = np.float64,
2525
):
2626
if isinstance(input, Number) and input_type == "covariance":
@@ -95,7 +95,7 @@ def __init__(
9595
self,
9696
size: int,
9797
input: Union[np.ndarray, List, DTypeFloat] = 1.0,
98-
input_type="covariance",
98+
input_type: Literal["covariance", "power_spectrum"] = "covariance",
9999
dtype: DTypeFloat = np.float64,
100100
):
101101
if isinstance(input, Number) and input_type == "covariance":

brahmap/core/noise_ops_toeplitz.py

Lines changed: 64 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,12 @@
11
import numpy as np
22
import warnings
3-
from typing import List, Union
3+
from typing import List, Union, Literal
44

55
from ..utilities import TypeChangeWarning
6-
from ..base import NoiseCovLinearOperator, InvNoiseCovLinearOperator
7-
from ..math import DTypeFloat
6+
from ..base import LinearOperator, NoiseCovLinearOperator, InvNoiseCovLinearOperator
7+
from ..math import DTypeFloat, cg
88
from ..mpi import MPI_RAISE_EXCEPTION
9-
10-
import scipy.sparse.linalg
9+
from ..core import InvNoiseCovLO_Circulant
1110

1211
from brahmap import MPI_UTILS
1312

@@ -19,7 +18,7 @@ def __init__(
1918
self,
2019
size: int,
2120
input: Union[np.ndarray, List],
22-
input_type: str = "power_spectrum",
21+
input_type: Literal["covariance", "power_spectrum"] = "power_spectrum",
2322
dtype: DTypeFloat = np.float64,
2423
):
2524
input = np.asarray(a=input, dtype=dtype)
@@ -108,9 +107,12 @@ def __init__(
108107
self,
109108
size: int,
110109
input: Union[np.ndarray, List],
111-
input_type: str = "power_spectrum",
112-
precond_op=None,
110+
input_type: Literal["covariance", "power_spectrum"] = "power_spectrum",
111+
precond_op: Union[
112+
LinearOperator, Literal[None, "Strang", "TChan", "RChan", "KK2"]
113+
] = None,
113114
precond_maxiter=50,
115+
precond_rtol=1.0e-10,
114116
precond_atol=1.0e-10,
115117
precond_callback=None,
116118
dtype: DTypeFloat = np.float64,
@@ -122,11 +124,61 @@ def __init__(
122124
dtype=dtype,
123125
)
124126

127+
self.__precond_rtol = precond_rtol
125128
self.__precond_atol = precond_atol
126129
self.__precond_maxiter = precond_maxiter
127-
self.__precond_op = precond_op
128130
self.__precond_callback = precond_callback
129131

132+
if precond_op is None:
133+
self.__precond_op = None
134+
elif isinstance(precond_op, LinearOperator) or isinstance(
135+
precond_op, np.ndarray
136+
):
137+
self.__precond_op = precond_op
138+
elif precond_op in ["Strang", "TChan", "RChan", "KK2"]:
139+
if input_type == "power_spectrum":
140+
cov = np.fft.ifft(input).real[:size]
141+
else:
142+
cov = input[:size]
143+
144+
if precond_op == "Strang":
145+
temp_size = int(np.floor(cov.size / 2))
146+
if cov.size % 2 == 0:
147+
new_cov = np.concatenate(
148+
[cov[:temp_size], cov[1 : temp_size + 1][::-1]]
149+
)
150+
else:
151+
new_cov = np.concatenate(
152+
[cov[: temp_size + 1], cov[1 : temp_size + 1][::-1]]
153+
)
154+
elif precond_op == "TChan":
155+
new_cov = np.empty_like(cov)
156+
new_cov[0] = cov[0]
157+
n = cov.size
158+
for idx in range(1, n):
159+
new_cov[idx] = ((n - idx) * cov[idx] + idx * cov[n - idx]) / n
160+
elif precond_op == "RChan":
161+
new_cov = np.roll(np.flip(cov), 1)
162+
new_cov += cov
163+
new_cov[0] = cov[0]
164+
elif precond_op == "KK2": # Circulant but not symmetric
165+
new_cov = np.roll(np.flip(cov), 1)
166+
new_cov[0] = 0
167+
new_cov = cov - new_cov
168+
169+
self.__precond_op = InvNoiseCovLO_Circulant(
170+
size=size,
171+
input=new_cov,
172+
input_type="covariance",
173+
dtype=dtype,
174+
)
175+
else:
176+
MPI_RAISE_EXCEPTION(
177+
condition=True,
178+
exception=ValueError,
179+
message="Invalid preconditioner operator provided!",
180+
)
181+
130182
super(InvNoiseCovLO_Toeplitz01, self).__init__(
131183
nargin=size,
132184
matvec=self._mult,
@@ -157,13 +209,15 @@ def _mult(self, vec: np.ndarray):
157209
)
158210
vec = vec.astype(dtype=self.dtype, copy=False)
159211

160-
prod, _ = scipy.sparse.linalg.gmres(
212+
prod, _ = cg(
161213
A=self.__toeplitz_op,
162214
b=vec,
215+
rtol=self.__precond_rtol,
163216
atol=self.__precond_atol,
164217
maxiter=self.__precond_maxiter,
165218
M=self.__precond_op,
166219
callback=self.__precond_callback,
220+
parallel=False,
167221
)
168222

169223
return prod

brahmap/lbsim/lbsim_noise_operators.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import List, Union
1+
from typing import List, Union, Literal, Dict, Any
22

33
import numpy as np
44
import litebird_sim as lbs
@@ -63,7 +63,7 @@ def __init__(
6363
self,
6464
obs: Union[lbs.Observation, List[lbs.Observation]],
6565
input: Union[dict, Union[np.ndarray, List]],
66-
input_type: str = "power_spectrum",
66+
input_type: Literal["covariance", "power_spectrum"] = "power_spectrum",
6767
dtype=np.float64,
6868
):
6969
if isinstance(obs, lbs.Observation):
@@ -142,9 +142,10 @@ def __init__(
142142
self,
143143
obs: Union[lbs.Observation, List[lbs.Observation]],
144144
input: Union[dict, Union[np.ndarray, List]],
145-
input_type: str = "power_spectrum",
145+
input_type: Literal["covariance", "power_spectrum"] = "power_spectrum",
146146
operator=InvNoiseCovLO_Toeplitz01,
147147
dtype=np.float64,
148+
extra_kwargs: Dict[str, Any] = {},
148149
):
149150
if isinstance(obs, lbs.Observation):
150151
obs_list = [obs]
@@ -184,6 +185,7 @@ def __init__(
184185
block_input=block_input,
185186
input_type=input_type,
186187
dtype=dtype,
188+
extra_kwargs=extra_kwargs,
187189
)
188190

189191
def _resize_input(self, new_size, input, input_type, dtype):

brahmap/math/linalg.py

Lines changed: 32 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from typing import Callable
12
import numpy as np
23
import scipy
34
import scipy.sparse
@@ -13,10 +14,37 @@ def parallel_norm(x: np.ndarray):
1314
return ret
1415

1516

16-
def cg(A, b, x0=None, atol=1.0e-12, maxiter=100, M=None, callback=None):
17-
A, M, x, b, postprocess = scipy.sparse.linalg._isolve.utils.make_system(A, M, x0, b)
17+
def cg(
18+
A,
19+
b,
20+
x0=None,
21+
rtol=1.0e-12,
22+
atol=1.0e-12,
23+
maxiter=100,
24+
M=None,
25+
callback=None,
26+
parallel=True,
27+
):
28+
A, M, x, b, postprocess = scipy.sparse.linalg._isolve.utils.make_system(
29+
A,
30+
M,
31+
x0,
32+
b,
33+
)
34+
35+
if parallel:
36+
norm_function: Callable = parallel_norm
37+
else:
38+
norm_function: Callable = np.linalg.norm
39+
40+
b_norm = norm_function(b)
1841

19-
b_norm = parallel_norm(b)
42+
atol, _ = scipy.sparse.linalg._isolve.iterative._get_atol_rtol(
43+
"cg",
44+
b_norm,
45+
atol,
46+
rtol,
47+
)
2048

2149
if b_norm == 0:
2250
return postprocess(b), 0
@@ -51,7 +79,7 @@ def cg(A, b, x0=None, atol=1.0e-12, maxiter=100, M=None, callback=None):
5179
r -= alpha * q
5280
rho_prev = rho_cur
5381

54-
norm_residual = parallel_norm(r) / b_norm
82+
norm_residual = norm_function(r) / b_norm
5583

5684
if callback:
5785
callback(x, r, norm_residual)

0 commit comments

Comments
 (0)