Skip to content

Commit 1042705

Browse files
authored
Merge pull request #51 from HERA-Team/avoid-build-solver
Add option to save memory when using LinProductSolver infrastructure but not its solvers
2 parents 7255e7a + 18c9ab0 commit 1042705

3 files changed

Lines changed: 17 additions & 3 deletions

File tree

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,3 +8,4 @@ dist/*
88
coverage.xml
99
*.DS_Store
1010
*/_version.py
11+
src/linsolve/_version.py

src/linsolve/linsolve.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -856,7 +856,7 @@ def taylor_expand(terms, consts=None, prepend="d"):
856856
# XXX make a version of linproductsolver that taylor expands in e^{a+bi} form
857857
# see https://github.com/HERA-Team/linsolve/issues/15
858858
class LinProductSolver:
859-
def __init__(self, data, sol0, wgts={}, sparse=False, **kwargs):
859+
def __init__(self, data, sol0, wgts={}, sparse=False, build_solver=True, **kwargs):
860860
"""Set up a nonlinear system of equations of the form a*b + c*d = 1.0.
861861
862862
Linearize via Taylor expansion and solve iteratively using the Gauss-Newton
@@ -882,6 +882,9 @@ def __init__(self, data, sol0, wgts={}, sparse=False, **kwargs):
882882
sparse : bool
883883
If True, represents A matrix sparsely (though AtA, Aty end up dense)
884884
May be faster for certain systems of equations.
885+
build_solver : bool
886+
Advanced users can turn this off to save memory when using only LinProductSolver
887+
infrastructure but not solve() or solve_iteratively(), as in omnical.
885888
**kwargs: keyword arguments of constants (python variables in keys of data that
886889
are not to be solved for)
887890
"""
@@ -894,8 +897,14 @@ def __init__(self, data, sol0, wgts={}, sparse=False, **kwargs):
894897
self.init_kwargs, self.sols_kwargs = constants, deepcopy(constants)
895898
self.sols_kwargs.update(sol0)
896899
self.all_terms, self.taylors, self.taylor_keys = self.gen_taylors()
897-
self.build_solver(sol0)
898-
self.dtype = self.ls.dtype
900+
if build_solver:
901+
self.build_solver(sol0)
902+
self.dtype = self.ls.dtype
903+
else:
904+
self.sol0 = sol0
905+
self.dtype = infer_dtype(list(self.data.values())
906+
+ list(self.sol0.values())
907+
+ list(self.wgts.values()))
899908

900909
def gen_taylors(self, keys=None):
901910
"""Perform Taylor expansion, and map eq. keys to taylor expansion keys."""

tests/test_linsolve.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -490,6 +490,10 @@ def test_init(self):
490490
np.testing.assert_almost_equal(eval(k), 0.002)
491491
assert len(ls.ls.prms) == 3
492492

493+
ls = linsolve.LinProductSolver(d, sol0, w, sparse=self.sparse, build_solver=False)
494+
assert not hasattr(ls, "ls")
495+
assert ls.dtype == np.complex64
496+
493497
def test_real_solve(self):
494498
x, y, z = 1.0, 2.0, 3.0
495499
keys = ["x*y", "x*z", "y*z"]

0 commit comments

Comments
 (0)