From 2ef4334689cf63a0e302b0c7f0f56fb6a7b21a44 Mon Sep 17 00:00:00 2001 From: sy3394 Date: Thu, 8 Jun 2023 18:39:34 +0300 Subject: [PATCH 1/6] first working version --- lyncs_quda/__init__.py | 1 + lyncs_quda/dirac.py | 43 ++++++++++++- lyncs_quda/lib.py | 23 ++++++- lyncs_quda/multigrid.py | 123 +++++++++++++++++++++++++++++++++++++ lyncs_quda/solver.py | 34 +++++++--- lyncs_quda/spinor_field.py | 10 +++ lyncs_quda/struct.py | 14 +++-- lyncs_quda/structs.py | 1 - test/test_lib.py | 1 + test/test_multigrid.py | 39 ++++++++++++ test/test_structs.py | 2 + 11 files changed, 272 insertions(+), 19 deletions(-) create mode 100644 lyncs_quda/multigrid.py create mode 100644 test/test_multigrid.py diff --git a/lyncs_quda/__init__.py b/lyncs_quda/__init__.py index 5986219..0873f3a 100644 --- a/lyncs_quda/__init__.py +++ b/lyncs_quda/__init__.py @@ -14,3 +14,4 @@ from .dirac import * from .solver import * from .evenodd import * +from .multigrid import * diff --git a/lyncs_quda/dirac.py b/lyncs_quda/dirac.py index a6c723d..2aa8038 100644 --- a/lyncs_quda/dirac.py +++ b/lyncs_quda/dirac.py @@ -14,11 +14,13 @@ from .clover_field import CloverField from .spinor_field import spinor from .lib import lib +from .structs import QudaGaugeParam from .enums import ( QudaDiracType, QudaMatPCType, QudaDagType, QudaParity, + QudaDslashType, ) @@ -66,6 +68,18 @@ def type(self): return "CLOVER" + PC return "TWISTED_CLOVER" + PC + @property + @QudaDslashType + def dslash_type(self): + if "coarse" in self.type: + return "INVALID" + dslash_type = str(self.type).replace("pc","") + dslash_type = dslash_type.replace("gauge_","") + if "clover" == dslash_type: dslash_type += "_wilson" + if "mobius" in dslash_type: dslash_type = dslash_type.replace("domain_wall","dwf") + + return dslash_type + @property @QudaMatPCType def matPCtype(self): @@ -345,7 +359,34 @@ def force(self, *phis, out=None, mult=2, coeffs=None, **params): return out - + #TODO: needs to be more automatic + def setGaugeParam(self, **gauge_options): + g_param = QudaGaugeParam() + + print(self.type, self.dslash_type) + #TODO: prepare default params for other type of dirac op + if "wilson" in self.dslash_type or "clover" in self.dslash_type or "twisted" in self.dslash_type: + lib.setWilsonGaugeParam(g_param.quda) + elif "staggered" in self.type: + lib.setStaggeredGaugeParam(g_param.quda) + else: + lib.setGaugeParam(g_param.quda) + + g_param.location = int(self.gauge.location) + g_param.X = self.gauge.local_lattice + g_param.anisotropy = self.gauge.quda_field.Anisotropy() + g_param.tadpole_coeff = self.gauge.quda_field.Tadpole() + g_param.type = int(self.gauge.link_type) + g_param.gauge_order = int(self.gauge.order) + g_param.t_boundary = int(self.gauge.t_boundary) + g_param.cpu_prec = int(self.gauge.precision) + g_param.cuda_prec = int(self.gauge.precision) + g_param.update(gauge_options) + print(getattr(g_param._quda_params,"gauge_order"), g_param.type) + print(self.gauge.quda_field.Gauge_p(), int(self.gauge.order))#g_param.gauge_order) + lib.loadGaugeQuda(self.gauge.quda_field.Gauge_p(), g_param.quda) + + GaugeField.Dirac = wraps(Dirac)(lambda *args, **kwargs: Dirac(*args, **kwargs)) diff --git a/lyncs_quda/lib.py b/lyncs_quda/lib.py index c9d550d..b9cd09a 100644 --- a/lyncs_quda/lib.py +++ b/lyncs_quda/lib.py @@ -223,6 +223,24 @@ def copy_struct(self): ) return self.lyncs_quda_copy_struct + @property + def set_mg_eig_param(self): + try: + return self.lyncs_quda_set_mg_eig_param + except AttributeError: + cppdef( + """ + template + void lyncs_quda_set_mg_eig_param(T** ptr_array, T param, int i, bool is_null=false) { + if ( i < n) { + if (is_null) ptr_array[i] = nullptr; + else ptr_array[i] = ¶m; + } + } + """ + ) + return self.lyncs_quda_set_mg_eig_param + def save_tuning(self): if self.tune_enabled: self.saveTuneCache() @@ -276,13 +294,14 @@ def __del__(self): "array.h", "momentum.h", "tune_quda.h", + "host_utils.h", + "command_line_params.h", ] - lib = QudaLib( path=PATHS, header=headers, - library=["libquda.so"] + libs, + library=["libquda.so", "libquda_test.so"] + libs, namespace=["quda", "lyncs_quda"], defined={"QUDA_PRECISION": QUDA_PRECISION, "QUDA_RECONSTRUCT": QUDA_RECONSTRUCT}, ) diff --git a/lyncs_quda/multigrid.py b/lyncs_quda/multigrid.py new file mode 100644 index 0000000..74a2d1f --- /dev/null +++ b/lyncs_quda/multigrid.py @@ -0,0 +1,123 @@ +""" +Interface to multigrid_solver +""" + +__all__ = ["MultigridPreconditioner"] + +from cppyy import bind_object +from lyncs_cppyy import nullptr +from lyncs_utils import isiterable +from .lib import lib +from .enums import QudaInverterType, QudaPrecision, QudaSolveType +from .structs import QudaInvertParam, QudaMultigridParam, QudaEigParam + +class MultigridPreconditioner: + __slots__ = ["_mg_solver", "mg_param", "inv_param"] + + def __init__(self, D, inv_options={}, mg_options={}, eig_options={}, is_eig=False): + self._mg_solver = None + self.mg_param, self.inv_param = self.prepareParams(D, inv_options=inv_options, mg_options=mg_options, eig_options=eig_options, is_eig=is_eig) + self.setMG_solver(self.mg_param) + + @property + @QudaInverterType + def inv_type_precondition(self): + return "MG_INVERTER" + + @property + def preconditioner(self): + return self._mg_solver + + def prepareParams(self, D, g_options={}, inv_options={}, mg_options={}, eig_options={}, is_eig=False): + # INPUT: D is a Dirac instance + # is_eig is a list of bools indicating whether eigsolver is used to generate + # near null-vectors at each level + inv_param = QudaInvertParam() + mg_param = QudaMultigridParam() + mg_param.invert_param = inv_param.quda + + # set* are defined in set_params.cpp, setting params to vals according to the ones defined globally + # <- command_line_params.cpp: contains some default values for those global vars, some set to invalid + # <- host_utils.h provides funcs to set global vars to some meaningful vals, according to vals in command_line... + # <- misc.h implemented in misc.cpp + + # sets fields to default values + #app = lib.make_app() + #lib.add_multigrid_option_group(app) + #app.parse(1,"--solve_type 2") + # Set internal global vars to their default vals + dslash_type = D.dslash_type #inv_param.dslash_type.upper() + solve_type = QudaSolveType["direct"] if D.full else QudaSolveType["direct_pc"] + lib.dslash_type = int(dslash_type) + lib.solve_type = int(solve_type) + lib.setQudaPrecisions() + lib.setQudaDefaultMgTestParams() + lib.setQudaMgSolveTypes() + + + + # Set param vals to the default vals and update according to the user's specification + D.setGaugeParam(gauge_options=g_options) + lib.setMultigridParam(mg_param.quda) + if not D.full: inv_param.matpc_type = int(D.matPCtype) + inv_param.dagger = int(D.dagger) + inv_param.cpu_prec = int(D.precision) # quda.h says this is supposed to be the prec of input fermion field + inv_param.cuda_prec = int(D.precision) + if "clover" in D.type: + inv_param.compute_clover = False + inv_param.clover_cpu_prec = int(D.clover.precision) + inv_param.clover_cuda_prec = int(D.clover.precision) + inv_param.clover_order = int(D.clover.order) + inv_param.clover_location = int(D.clover.location) + inv_param.clover_csw = D.clover.csw + inv_param.clover_coeff = D.clover.coeff + inv_param.clover_rho = D.clover.rho + inv_param.compute_clover = False + inv_param.compute_clover_inverse = False + inv_param.return_clover = False + inv_param.return_clover_inverse = False + inv_param.update(inv_options) + mg_param.update(mg_options) + if "clover" in D.type: + print("mult init clover") + D.clover.clover_field + D.clover.inverse_field + lib.loadCloverQuda(D.clover.quda_field.V(), D.clover.quda_field.V(True), inv_param.quda) + mg_param.invert_param = inv_param.quda #not sure if this is necessary? + + # Only these fermions are supported with MG + print(dslash_type, type(dslash_type)) + if dslash_type != "WILSON" and dslash_type != "CLOVER_WILSON" and dslash_type != "TWISTED_MASS" and dslash_type != "TWISTED_CLOVER": + raise ValueError(f"dslash_type {dslash_type} not supported for MG") + # Only these solve types are supported with MG + if solve_type != "DIRECT" and solve_type != "DIRECT_PC": + raise ValueError(f"Solve_type {solve_type} not supported with MG. Please use QUDA_DIRECT_SOLVE or QUDA_DIRECT_PC_SOLVE") + print(type(mg_param)) + if not isiterable(is_eig): + is_eig = [is_eig]*mg_param.n_level + for i, eig in enumerate(is_eig): + eig_param = QudaEigParam() + if eig: + lib.setMultigridEigParam(eig_param.quda) + eig_param.update(eig_options) + lib.set_mg_eig_param["QudaEigParam", lib.QUDA_MAX_MG_LEVEL](mg_param.eig_param, eig_param.quda, i) + else: + print(mg_param.eig_param) + lib.set_mg_eig_param["QudaEigParam", lib.QUDA_MAX_MG_LEVEL](mg_param.eig_param, eig_param.quda, i, is_null=True) #?to_pointer + addressof? + + return mg_param, inv_param + + def setMG_solver(self, mg_param): + if self._mg_solver is None: + self._mg_solver = lib.newMultigridQuda(mg_param.quda) + else: + self.updateMG_solver(mg_param) + + def updateMG_solver(self, mg_param): + lib.updateMultigridQuda(self._mg_solver, mg_param.quda) + + def destroyMG_solver(self): + lib.destroyMultigridQuda(self._mg_solver) + self._mg_solver = None + + diff --git a/lyncs_quda/solver.py b/lyncs_quda/solver.py index c3eba68..2b5605a 100644 --- a/lyncs_quda/solver.py +++ b/lyncs_quda/solver.py @@ -9,16 +9,17 @@ from functools import wraps from warnings import warn +from cppyy import bind_object from lyncs_cppyy import nullptr, make_shared from .dirac import Dirac, DiracMatrix -from .enums import QudaInverterType, QudaPrecision, QudaResidualType, QudaBoolean +from .enums import QudaInverterType, QudaPrecision, QudaResidualType, QudaBoolean, QudaSolutionType from .lib import lib from .spinor_field import spinor from .time_profile import default_profiler, TimeProfile -def solve(mat, rhs, out=None, **kwargs): - return Solver(mat)(rhs, out, **kwargs) +def solve(mat, rhs, out=None, precon=None, **kwargs): + return Solver(mat, precon=precon)(rhs, out, **kwargs) class Solver: @@ -79,11 +80,11 @@ class Solver: def _init_params(): return lib.SolverParam() - def __init__(self, mat, **kwargs): + def __init__(self, mat, precon=None, **kwargs): self._params = self._init_params() self._solver = None self._profiler = None - self._precon = None + self.preconditioner = precon self.mat = mat params = type(self).default_params.copy() @@ -180,7 +181,10 @@ def preconditioner(self, value): self._params.inv_type_precondition = int(QudaInverterType["INVALID"]) self._params.preconditioner = nullptr else: - raise NotImplementedError + self._precon = value + self._params.inv_type_precondition = int(self._precon.inv_type_precondition) + self._params.preconditioner = self._precon.preconditioner + def _update_return_residual(self, old, new): assert self._params.return_residual == new @@ -223,17 +227,27 @@ def swap(self, **params): del params[key] return params - def __call__(self, rhs, out=None, warning=True, **kwargs): + def __call__(self, rhs, out=None, warning=True, solution_typ=None, **kwargs): rhs = spinor(rhs) out = rhs.prepare_out(out) kwargs = self.swap(**kwargs) + print("solver!!!", rhs.gamma_basis, out.gamma_basis) # ASSUME: QUDA_FULL_SITE_SUBSET if self.mat.dirac.full: self.quda(out.quda_field, rhs.quda_field) - elif self.mat.dirac.even: - self.quda(out.quda_field.Even(), rhs.quda_field.Even()) + elif solution_typ is not None: + # Computes the full inverse based on the e-o preconditioned matrix + in_, out_ = bind_object(nullptr, "quda::ColorSpinorField"), bind_object(nullptr, "quda::ColorSpinorField") + styp = int(QudaSolutionType[solution_typ]) + self.mat.dirac.quda_dirac.prepare(in_, out_, out.quda_field, rhs.quda_field, styp) + self.quda(out_, in_) + self.mat.dirac.quda_dirac.reconstruct(out.quda_field, rhs.quda_field, styp) else: - self.quda(out.quda_field.Odd(), rhs.quda_field.Odd()) + # Computes the inverse of the Schur complement of the matpc type + if self.mat.dirac.even: + self.quda(out.quda_field.Even(), rhs.quda_field.Even()) + else: + self.quda(out.quda_field.Odd(), rhs.quda_field.Odd()) self.swap(**kwargs) if self.true_res > self.tol: diff --git a/lyncs_quda/spinor_field.py b/lyncs_quda/spinor_field.py index befd590..cc14ec1 100644 --- a/lyncs_quda/spinor_field.py +++ b/lyncs_quda/spinor_field.py @@ -14,6 +14,7 @@ from lyncs_cppyy.ll import to_pointer from .lib import lib from .lattice_field import LatticeField +from .enum import EnumValue from .enums import ( QudaGammaBasis, QudaFieldOrder, @@ -64,6 +65,11 @@ def __init__(self, *args, gamma_basis=None, site_order="EO", **kwargs): self.gamma_basis = gamma_basis self.site_order = site_order + def _prepare(self, field, **kwargs): + kwargs.setdefault("gamma_basis", self.gamma_basis) + kwargs.setdefault("site_order", self.site_order) + return super()._prepare(field, **kwargs) + @property def ncolor(self): "Number of colors of the field" @@ -93,6 +99,8 @@ def gamma_basis(self, value): if value is None: value = "UKQCD" values = f"Possible values are {SpinorField.gammas}" + if isinstance(value, EnumValue): + value = str(value) if not isinstance(value, str): raise TypeError("Expected a string. " + values) if not value.upper() in values: @@ -125,6 +133,8 @@ def site_order(self, value): if value is None: value = "NONE" values = "Possible values are NONE, EVEN_ODD, ODD_EVEN" + if isinstance(value, EnumValue): + value = str(value) if not isinstance(value, str): raise TypeError("Expected a string. " + values) value = value.upper() diff --git a/lyncs_quda/struct.py b/lyncs_quda/struct.py index 046d238..92ca604 100644 --- a/lyncs_quda/struct.py +++ b/lyncs_quda/struct.py @@ -108,7 +108,7 @@ def __init__(self, *args, **kwargs): # temporal fix: newQudaMultigridParam does not assign a default value to n_level if "Multigrid" in type(self).__name__: n = getattr(self._quda_params, "n_level") - n = lib.QUDA_MAX_MG_LEVEL if n < 0 or n > lib.QUDA_MAX_MG_LEVEL else n + n = 2 if n < 0 or n > lib.QUDA_MAX_MG_LEVEL else n setattr(self._quda_params, "n_level", n) for arg in args: @@ -135,8 +135,9 @@ def _assign(self, key, val): val = to_code(val, typ) cur = getattr(self._quda_params, key) - if "[" in self._types[key] and not hasattr(cur, "shape"):# not sure if this is needed for cppyy3.0.0 + if "[" in self._types[key] and not hasattr(cur, "shape"): # safeguard against hectic behavior of cppyy + # QudaEigParam *eig_param[QUDA_MAX_MG_LEVEL] is somehow turned into QudaEigParam ** raise RuntimeError("cppyy is not happy for now. Try again!") @@ -166,9 +167,11 @@ def _assign(self, key, val): assert hasattr(cur, "shape") shape = tuple([getattr(lib, macro) for macro in typ.split(" ") if "QUDA_" in macro or macro.isnumeric()]) #not necessary for cppyy3.0.0? cur.reshape(shape) #? not necessary for cppyy3.0.0? - if "*" in typ: - for i in range(shape[0]): - val = to_pointer(addressof(val), ctype = typ[:-typ.index("[")].strip()) + if "*" in typ: + if hasattr(val, "__len__"): + val = [to_pointer(addressof(v), ctype = typ[:-typ.index("[")].strip()) for v in val] + else: + val = to_pointer(addressof(v), ctype = typ[:-typ.index("[")].strip()) is_string = True if "char" in typ else False if is_string: setitems(cur, b"\0") # for printing @@ -199,6 +202,7 @@ def __setattr__(self, key, val): super().__setattr__(key, val) def __str__(self): + #TODO: cannot print param vals of type char* unless it is set to something as some bytes go beyond ascii range return str(dict(self.items())) @property diff --git a/lyncs_quda/structs.py b/lyncs_quda/structs.py index 09fe839..89dcdb9 100644 --- a/lyncs_quda/structs.py +++ b/lyncs_quda/structs.py @@ -1,4 +1,3 @@ - "List of QUDA parameter structures" # NOTE: This file is automathically generated by setup.py diff --git a/test/test_lib.py b/test/test_lib.py index 11b39ee..4e1e6db 100644 --- a/test/test_lib.py +++ b/test/test_lib.py @@ -8,6 +8,7 @@ def test_device_count(lib): def test_init(lib): assert lib.initialized + lib.setMultigridParam() lib.end_quda() assert not lib.initialized lib.init_quda() diff --git a/test/test_multigrid.py b/test/test_multigrid.py new file mode 100644 index 0000000..c854e57 --- /dev/null +++ b/test/test_multigrid.py @@ -0,0 +1,39 @@ +from lyncs_quda import gauge, spinor, MultigridPreconditioner +from lyncs_quda.testing import ( + fixlib as lib, + lattice_loop, + device_loop, + dtype_loop, + gamma_loop, +) + + +@device_loop # enables device +@lattice_loop # enables lattice +@gamma_loop # enables gamma +def test_solve_mg_random(lib, lattice, device, gamma, dtype=None): + gf = gauge(lattice, dtype=dtype, device=device) + gf.gaussian() + dirac = gf.Dirac(kappa=0.01, csw=1, computeTrLog=True, full=True) + rhs = spinor(lattice, dtype=dtype, device=device, gamma_basis=gamma) + rhs.uniform() + mat = dirac.M + prec = MultigridPreconditioner(mat.dirac) + out = mat.solve(rhs, precon=prec, delta=1e-4) # this value allowed convergence for all cases + res = mat(out) + res -= rhs + res = res.norm() / rhs.norm() + prec.destroyMG_solver() + assert res < 1e-9 + + if gamma == "UKQCD": # precond Dirac op works only with + dirac = gf.Dirac(kappa=0.01, csw=1, computeTrLog=True, full=False) + mat = dirac.M + prec = MultigridPreconditioner(mat.dirac) + print(prec.inv_param.solution_type) + pout = mat.solve(rhs, precon=prec, solution_typ=prec.inv_param.solution_type, delta=1e-4) + res = out-pout + res = res.norm() / pout.norm() + prec.destroyMG_solver() + assert res < 1e-9 + diff --git a/test/test_structs.py b/test/test_structs.py index 38fda97..6579e85 100644 --- a/test/test_structs.py +++ b/test/test_structs.py @@ -29,6 +29,8 @@ def test_assign_something(lib): # ptr to strct class works mp.n_level = 3 # This is supposed to be set explicitly mp.invert_param = ip.quda + print(ep.quda) + lib.set_mg_eig_param["QudaEigParam", lib.QUDA_MAX_MG_LEVEL](mp.eig_param, ep.quda, 0) ip.split_grid = list(range(lib.QUDA_MAX_DIM)) ip.madwf_param_infile = "hi I'm here!" mp.geo_block_size = [[i+j+1 for j in range(lib.QUDA_MAX_DIM)] for i in range(lib.QUDA_MAX_MG_LEVEL)] From 52dc68910b14025177c64429ad7f766948640cab Mon Sep 17 00:00:00 2001 From: sy3394 Date: Mon, 12 Jun 2023 20:02:45 +0300 Subject: [PATCH 2/6] cleanup --- lyncs_quda/dirac.py | 3 --- lyncs_quda/enum.py | 1 - lyncs_quda/lib.py | 4 ++-- lyncs_quda/multigrid.py | 22 ++++++++-------------- lyncs_quda/solver.py | 1 - lyncs_quda/spinor_field.py | 12 ++++-------- lyncs_quda/struct.py | 5 +++-- patches/copy_util_files.patch | 11 +++++++++++ post_build.py | 14 ++++++++++++++ test/test_enums.py | 1 - test/test_multigrid.py | 3 +-- test/test_structs.py | 1 + 12 files changed, 44 insertions(+), 34 deletions(-) create mode 100644 patches/copy_util_files.patch diff --git a/lyncs_quda/dirac.py b/lyncs_quda/dirac.py index 2aa8038..30819ca 100644 --- a/lyncs_quda/dirac.py +++ b/lyncs_quda/dirac.py @@ -363,7 +363,6 @@ def force(self, *phis, out=None, mult=2, coeffs=None, **params): def setGaugeParam(self, **gauge_options): g_param = QudaGaugeParam() - print(self.type, self.dslash_type) #TODO: prepare default params for other type of dirac op if "wilson" in self.dslash_type or "clover" in self.dslash_type or "twisted" in self.dslash_type: lib.setWilsonGaugeParam(g_param.quda) @@ -382,8 +381,6 @@ def setGaugeParam(self, **gauge_options): g_param.cpu_prec = int(self.gauge.precision) g_param.cuda_prec = int(self.gauge.precision) g_param.update(gauge_options) - print(getattr(g_param._quda_params,"gauge_order"), g_param.type) - print(self.gauge.quda_field.Gauge_p(), int(self.gauge.order))#g_param.gauge_order) lib.loadGaugeQuda(self.gauge.quda_field.Gauge_p(), g_param.quda) diff --git a/lyncs_quda/enum.py b/lyncs_quda/enum.py index 29b3fc1..3d41bf9 100644 --- a/lyncs_quda/enum.py +++ b/lyncs_quda/enum.py @@ -55,7 +55,6 @@ def items(cls): return cls._values.items() def clean(cls, rep): - # should turn everything into upper for consistency "Strips away prefix and suffix from key" "See enums.py to find what is prefix and suffix for a given enum value" if isinstance(rep, EnumValue): diff --git a/lyncs_quda/lib.py b/lyncs_quda/lib.py index b9cd09a..7f1963b 100644 --- a/lyncs_quda/lib.py +++ b/lyncs_quda/lib.py @@ -294,8 +294,8 @@ def __del__(self): "array.h", "momentum.h", "tune_quda.h", - "host_utils.h", - "command_line_params.h", + "utils/host_utils.h", + "utils/command_line_params.h", ] lib = QudaLib( diff --git a/lyncs_quda/multigrid.py b/lyncs_quda/multigrid.py index 74a2d1f..de7a221 100644 --- a/lyncs_quda/multigrid.py +++ b/lyncs_quda/multigrid.py @@ -41,20 +41,14 @@ def prepareParams(self, D, g_options={}, inv_options={}, mg_options={}, eig_opti # <- host_utils.h provides funcs to set global vars to some meaningful vals, according to vals in command_line... # <- misc.h implemented in misc.cpp - # sets fields to default values - #app = lib.make_app() - #lib.add_multigrid_option_group(app) - #app.parse(1,"--solve_type 2") # Set internal global vars to their default vals - dslash_type = D.dslash_type #inv_param.dslash_type.upper() + dslash_type = D.dslash_type solve_type = QudaSolveType["direct"] if D.full else QudaSolveType["direct_pc"] lib.dslash_type = int(dslash_type) lib.solve_type = int(solve_type) lib.setQudaPrecisions() lib.setQudaDefaultMgTestParams() lib.setQudaMgSolveTypes() - - # Set param vals to the default vals and update according to the user's specification D.setGaugeParam(gauge_options=g_options) @@ -79,20 +73,17 @@ def prepareParams(self, D, g_options={}, inv_options={}, mg_options={}, eig_opti inv_param.update(inv_options) mg_param.update(mg_options) if "clover" in D.type: - print("mult init clover") D.clover.clover_field D.clover.inverse_field lib.loadCloverQuda(D.clover.quda_field.V(), D.clover.quda_field.V(True), inv_param.quda) mg_param.invert_param = inv_param.quda #not sure if this is necessary? # Only these fermions are supported with MG - print(dslash_type, type(dslash_type)) if dslash_type != "WILSON" and dslash_type != "CLOVER_WILSON" and dslash_type != "TWISTED_MASS" and dslash_type != "TWISTED_CLOVER": raise ValueError(f"dslash_type {dslash_type} not supported for MG") # Only these solve types are supported with MG if solve_type != "DIRECT" and solve_type != "DIRECT_PC": raise ValueError(f"Solve_type {solve_type} not supported with MG. Please use QUDA_DIRECT_SOLVE or QUDA_DIRECT_PC_SOLVE") - print(type(mg_param)) if not isiterable(is_eig): is_eig = [is_eig]*mg_param.n_level for i, eig in enumerate(is_eig): @@ -102,18 +93,21 @@ def prepareParams(self, D, g_options={}, inv_options={}, mg_options={}, eig_opti eig_param.update(eig_options) lib.set_mg_eig_param["QudaEigParam", lib.QUDA_MAX_MG_LEVEL](mg_param.eig_param, eig_param.quda, i) else: - print(mg_param.eig_param) - lib.set_mg_eig_param["QudaEigParam", lib.QUDA_MAX_MG_LEVEL](mg_param.eig_param, eig_param.quda, i, is_null=True) #?to_pointer + addressof? + lib.set_mg_eig_param["QudaEigParam", lib.QUDA_MAX_MG_LEVEL](mg_param.eig_param, eig_param.quda, i, is_null=True) return mg_param, inv_param - def setMG_solver(self, mg_param): + def setMG_solver(self, mg_param=None): + if mg_param is None: + mg_param = self.mg_param if self._mg_solver is None: self._mg_solver = lib.newMultigridQuda(mg_param.quda) else: self.updateMG_solver(mg_param) - def updateMG_solver(self, mg_param): + def updateMG_solver(self, mg_param=None): + if mg_param is None: + mg_param = self.mg_param lib.updateMultigridQuda(self._mg_solver, mg_param.quda) def destroyMG_solver(self): diff --git a/lyncs_quda/solver.py b/lyncs_quda/solver.py index 2b5605a..6f88576 100644 --- a/lyncs_quda/solver.py +++ b/lyncs_quda/solver.py @@ -231,7 +231,6 @@ def __call__(self, rhs, out=None, warning=True, solution_typ=None, **kwargs): rhs = spinor(rhs) out = rhs.prepare_out(out) kwargs = self.swap(**kwargs) - print("solver!!!", rhs.gamma_basis, out.gamma_basis) # ASSUME: QUDA_FULL_SITE_SUBSET if self.mat.dirac.full: self.quda(out.quda_field, rhs.quda_field) diff --git a/lyncs_quda/spinor_field.py b/lyncs_quda/spinor_field.py index cc14ec1..c5ec20c 100644 --- a/lyncs_quda/spinor_field.py +++ b/lyncs_quda/spinor_field.py @@ -99,13 +99,10 @@ def gamma_basis(self, value): if value is None: value = "UKQCD" values = f"Possible values are {SpinorField.gammas}" - if isinstance(value, EnumValue): - value = str(value) - if not isinstance(value, str): - raise TypeError("Expected a string. " + values) - if not value.upper() in values: + value = str(QudaGammaBasis[value]).upper() + if not value in values: raise ValueError("Invalid gamma. " + values) - self._gamma_basis = value.upper() + self._gamma_basis = value @property @QudaFieldOrder @@ -134,10 +131,9 @@ def site_order(self, value): value = "NONE" values = "Possible values are NONE, EVEN_ODD, ODD_EVEN" if isinstance(value, EnumValue): - value = str(value) + value = str(value).upper() if not isinstance(value, str): raise TypeError("Expected a string. " + values) - value = value.upper() if value in ["NONE", "LEX", "LEXICOGRAPHIC"]: value = "LEXICOGRAPHIC" elif value in ["EO", "EVEN_ODD"]: diff --git a/lyncs_quda/struct.py b/lyncs_quda/struct.py index 92ca604..415d2fb 100644 --- a/lyncs_quda/struct.py +++ b/lyncs_quda/struct.py @@ -104,7 +104,9 @@ def __init__(self, *args, **kwargs): if not getattr(self._quda_params, key) in enm.values(): val = list(enm.values())[-1] self._assign(key, val) - + if "char" in self._types[key]: + self._assign(key, b"\0") + # temporal fix: newQudaMultigridParam does not assign a default value to n_level if "Multigrid" in type(self).__name__: n = getattr(self._quda_params, "n_level") @@ -202,7 +204,6 @@ def __setattr__(self, key, val): super().__setattr__(key, val) def __str__(self): - #TODO: cannot print param vals of type char* unless it is set to something as some bytes go beyond ascii range return str(dict(self.items())) @property diff --git a/patches/copy_util_files.patch b/patches/copy_util_files.patch new file mode 100644 index 0000000..186ae1b --- /dev/null +++ b/patches/copy_util_files.patch @@ -0,0 +1,11 @@ +diff --git a/CMakeLists.txt b/CMakeLists.txt +index 181353712..7639d17ab 100644 +--- a/CMakeLists.txt ++++ b/CMakeLists.txt +@@ -697,3 +697,6 @@ include(CTest) + add_subdirectory(lib) + add_subdirectory(tests) + add_subdirectory(doc) ++ ++# move files in tests/utils for lyncs_quda ++install(DIRECTORY ${QUDA_SOURCE_DIR}/tests/utils FILES_MATCHING PATTERN "*.h" TYPE INCLUDE) diff --git a/post_build.py b/post_build.py index caa1f11..fe04557 100644 --- a/post_build.py +++ b/post_build.py @@ -2,6 +2,7 @@ import json import fileinput import subprocess +import shutil from pathlib import Path from tempfile import TemporaryDirectory from os.path import commonprefix @@ -32,7 +33,20 @@ def patch_include(builder, ext): continue print(line, end="") +def patch_utils(builder, ext): + install_dir = builder.get_install_dir(ext) + "/include/utils" + path = install_dir + "/command_line_params.h" + with fileinput.FileInput(str(path), inplace=True, backup=".bak") as fp: + for fline in fp: + line = str(fline) + #TODO: better way to remove QUDAApp related lines + if (line.strip().startswith("#include") and "CLI11" in line.strip()) or (fp.filelineno() > 14 and fp.filelineno() < 154): + print("", end="") + continue + print(line, end="") + + # PATCH 2: generates enums.py ENUM_OUTPUT = """ diff --git a/test/test_enums.py b/test/test_enums.py index 52fd609..66cd3d0 100644 --- a/test/test_enums.py +++ b/test/test_enums.py @@ -8,7 +8,6 @@ def test_enums(): enum = getattr(enums, enum) assert issubclass(enum, Enum) - for key, val in enum.items(): assert key in enum assert val in enum diff --git a/test/test_multigrid.py b/test/test_multigrid.py index c854e57..b248b1e 100644 --- a/test/test_multigrid.py +++ b/test/test_multigrid.py @@ -26,11 +26,10 @@ def test_solve_mg_random(lib, lattice, device, gamma, dtype=None): prec.destroyMG_solver() assert res < 1e-9 - if gamma == "UKQCD": # precond Dirac op works only with + if gamma == "UKQCD": # precond Dirac op works only with UKQCD basis dirac = gf.Dirac(kappa=0.01, csw=1, computeTrLog=True, full=False) mat = dirac.M prec = MultigridPreconditioner(mat.dirac) - print(prec.inv_param.solution_type) pout = mat.solve(rhs, precon=prec, solution_typ=prec.inv_param.solution_type, delta=1e-4) res = out-pout res = res.norm() / pout.norm() diff --git a/test/test_structs.py b/test/test_structs.py index 6579e85..70babeb 100644 --- a/test/test_structs.py +++ b/test/test_structs.py @@ -26,6 +26,7 @@ def test_assign_something(lib): ip = structs.QudaInvertParam() ep = structs.QudaEigParam() + print(mp) # ptr to strct class works mp.n_level = 3 # This is supposed to be set explicitly mp.invert_param = ip.quda From 378617a79289eddfaa8c8f88029853c0084bccac Mon Sep 17 00:00:00 2001 From: sy3394 Date: Tue, 13 Jun 2023 16:23:46 +0300 Subject: [PATCH 3/6] CLI11 related fix --- lyncs_quda/lib.py | 5 +++++ patches/copy_util_files.patch | 5 +++-- post_build.py | 13 ------------- test/test_lib.py | 1 - 4 files changed, 8 insertions(+), 16 deletions(-) diff --git a/lyncs_quda/lib.py b/lyncs_quda/lib.py index 7f1963b..acdfbb1 100644 --- a/lyncs_quda/lib.py +++ b/lyncs_quda/lib.py @@ -10,6 +10,7 @@ ] import atexit +import cppyy from os import environ from pathlib import Path from array import array @@ -306,6 +307,10 @@ def __del__(self): defined={"QUDA_PRECISION": QUDA_PRECISION, "QUDA_RECONSTRUCT": QUDA_RECONSTRUCT}, ) lib.MPI = MPI +# TODO: need to change "load" function of Lib from lyncs_cppyy to avoid the line below +# NOTE: This assumes: Python3.x & not runned from Jupyter notebook +# alternative: os.path.dirname(os.path.abspath(__file__)) +cppyy.add_include_path(str(Path(__file__).parent.absolute()) + "/include/externals") # used? try: diff --git a/patches/copy_util_files.patch b/patches/copy_util_files.patch index 186ae1b..a290de3 100644 --- a/patches/copy_util_files.patch +++ b/patches/copy_util_files.patch @@ -1,11 +1,12 @@ diff --git a/CMakeLists.txt b/CMakeLists.txt -index 181353712..7639d17ab 100644 +index 181353712..7681e949b 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt -@@ -697,3 +697,6 @@ include(CTest) +@@ -697,3 +697,7 @@ include(CTest) add_subdirectory(lib) add_subdirectory(tests) add_subdirectory(doc) + +# move files in tests/utils for lyncs_quda +install(DIRECTORY ${QUDA_SOURCE_DIR}/tests/utils FILES_MATCHING PATTERN "*.h" TYPE INCLUDE) ++install(DIRECTORY ${QUDA_SOURCE_DIR}/include/externals TYPE INCLUDE) diff --git a/post_build.py b/post_build.py index fe04557..720e198 100644 --- a/post_build.py +++ b/post_build.py @@ -33,19 +33,6 @@ def patch_include(builder, ext): continue print(line, end="") -def patch_utils(builder, ext): - install_dir = builder.get_install_dir(ext) + "/include/utils" - path = install_dir + "/command_line_params.h" - with fileinput.FileInput(str(path), inplace=True, backup=".bak") as fp: - for fline in fp: - line = str(fline) - #TODO: better way to remove QUDAApp related lines - if (line.strip().startswith("#include") and "CLI11" in line.strip()) or (fp.filelineno() > 14 and fp.filelineno() < 154): - print("", end="") - continue - print(line, end="") - - # PATCH 2: generates enums.py diff --git a/test/test_lib.py b/test/test_lib.py index 4e1e6db..11b39ee 100644 --- a/test/test_lib.py +++ b/test/test_lib.py @@ -8,7 +8,6 @@ def test_device_count(lib): def test_init(lib): assert lib.initialized - lib.setMultigridParam() lib.end_quda() assert not lib.initialized lib.init_quda() From 8a285804e2d09f82cb8648df5fe24fcc070d80ad Mon Sep 17 00:00:00 2001 From: sy3394 Date: Tue, 13 Jun 2023 17:52:04 +0300 Subject: [PATCH 4/6] slightly better automatization --- lyncs_quda/multigrid.py | 36 ++++++++++++++++-------------------- lyncs_quda/solver.py | 2 +- test/test_multigrid.py | 2 -- 3 files changed, 17 insertions(+), 23 deletions(-) diff --git a/lyncs_quda/multigrid.py b/lyncs_quda/multigrid.py index de7a221..afc83b8 100644 --- a/lyncs_quda/multigrid.py +++ b/lyncs_quda/multigrid.py @@ -12,22 +12,27 @@ from .structs import QudaInvertParam, QudaMultigridParam, QudaEigParam class MultigridPreconditioner: - __slots__ = ["_mg_solver", "mg_param", "inv_param"] + __slots__ = ["_quda", "mg_param", "inv_param"] def __init__(self, D, inv_options={}, mg_options={}, eig_options={}, is_eig=False): - self._mg_solver = None + self._quda = None self.mg_param, self.inv_param = self.prepareParams(D, inv_options=inv_options, mg_options=mg_options, eig_options=eig_options, is_eig=is_eig) - self.setMG_solver(self.mg_param) @property @QudaInverterType def inv_type_precondition(self): return "MG_INVERTER" + #TODO: absorb updateMG_solver into this property and delete the function + # This will reqiure detecting the change of mg and inv param structs from + # the last update or creation of QUDA multigrid_solver object @property - def preconditioner(self): - return self._mg_solver + def quda(self): + if self._quda is None: + self._quda = lib.newMultigridQuda(self.mg_param.quda) + return self._quda + # TODO: can also accept structs? def prepareParams(self, D, g_options={}, inv_options={}, mg_options={}, eig_options={}, is_eig=False): # INPUT: D is a Dirac instance # is_eig is a list of bools indicating whether eigsolver is used to generate @@ -97,21 +102,12 @@ def prepareParams(self, D, g_options={}, inv_options={}, mg_options={}, eig_opti return mg_param, inv_param - def setMG_solver(self, mg_param=None): - if mg_param is None: - mg_param = self.mg_param - if self._mg_solver is None: - self._mg_solver = lib.newMultigridQuda(mg_param.quda) - else: - self.updateMG_solver(mg_param) - - def updateMG_solver(self, mg_param=None): - if mg_param is None: - mg_param = self.mg_param - lib.updateMultigridQuda(self._mg_solver, mg_param.quda) + def updateMG_solver(self): + lib.updateMultigridQuda(self._quda, self.mg_param.quda) - def destroyMG_solver(self): - lib.destroyMultigridQuda(self._mg_solver) - self._mg_solver = None + def __del__(self): + if self._quda is not None: + lib.destroyMultigridQuda(self._quda) + self._quda = None diff --git a/lyncs_quda/solver.py b/lyncs_quda/solver.py index 6f88576..057c2de 100644 --- a/lyncs_quda/solver.py +++ b/lyncs_quda/solver.py @@ -183,7 +183,7 @@ def preconditioner(self, value): else: self._precon = value self._params.inv_type_precondition = int(self._precon.inv_type_precondition) - self._params.preconditioner = self._precon.preconditioner + self._params.preconditioner = self._precon.quda def _update_return_residual(self, old, new): diff --git a/test/test_multigrid.py b/test/test_multigrid.py index b248b1e..575d57c 100644 --- a/test/test_multigrid.py +++ b/test/test_multigrid.py @@ -23,7 +23,6 @@ def test_solve_mg_random(lib, lattice, device, gamma, dtype=None): res = mat(out) res -= rhs res = res.norm() / rhs.norm() - prec.destroyMG_solver() assert res < 1e-9 if gamma == "UKQCD": # precond Dirac op works only with UKQCD basis @@ -33,6 +32,5 @@ def test_solve_mg_random(lib, lattice, device, gamma, dtype=None): pout = mat.solve(rhs, precon=prec, solution_typ=prec.inv_param.solution_type, delta=1e-4) res = out-pout res = res.norm() / pout.norm() - prec.destroyMG_solver() assert res < 1e-9 From 1b9a05ae08905b0844c0577eee92104b06d4200e Mon Sep 17 00:00:00 2001 From: sy3394 Date: Wed, 14 Jun 2023 15:52:42 +0300 Subject: [PATCH 5/6] auto-update of multigrid_solver --- lyncs_quda/multigrid.py | 6 +++--- lyncs_quda/struct.py | 6 ++++-- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/lyncs_quda/multigrid.py b/lyncs_quda/multigrid.py index afc83b8..8f5ebfc 100644 --- a/lyncs_quda/multigrid.py +++ b/lyncs_quda/multigrid.py @@ -30,6 +30,9 @@ def inv_type_precondition(self): def quda(self): if self._quda is None: self._quda = lib.newMultigridQuda(self.mg_param.quda) + elif self.mg_param.updated or self.inv_param.updated: + print("UPDATEDDD") + lib.updateMultigridQuda(self._quda, self.mg_param.quda) return self._quda # TODO: can also accept structs? @@ -102,9 +105,6 @@ def prepareParams(self, D, g_options={}, inv_options={}, mg_options={}, eig_opti return mg_param, inv_param - def updateMG_solver(self): - lib.updateMultigridQuda(self._quda, self.mg_param.quda) - def __del__(self): if self._quda is not None: lib.destroyMultigridQuda(self._quda) diff --git a/lyncs_quda/struct.py b/lyncs_quda/struct.py index 415d2fb..c03b21c 100644 --- a/lyncs_quda/struct.py +++ b/lyncs_quda/struct.py @@ -116,7 +116,8 @@ def __init__(self, *args, **kwargs): for arg in args: self.update(arg) self.update(kwargs) - + self.updated = False + def keys(self): "List of keys in the structure" return self._types.keys() @@ -202,7 +203,8 @@ def __setattr__(self, key, val): ) else: #should we allow this? super().__setattr__(key, val) - + super().__setattr__("updated", True) + def __str__(self): return str(dict(self.items())) From ca9d5eb2c1d7461a12ab8fe0a7b26cab9b97ff30 Mon Sep 17 00:00:00 2001 From: sy3394 Date: Wed, 14 Jun 2023 22:33:21 +0300 Subject: [PATCH 6/6] cleanup --- lyncs_quda/multigrid.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/lyncs_quda/multigrid.py b/lyncs_quda/multigrid.py index 8f5ebfc..09f9768 100644 --- a/lyncs_quda/multigrid.py +++ b/lyncs_quda/multigrid.py @@ -31,8 +31,9 @@ def quda(self): if self._quda is None: self._quda = lib.newMultigridQuda(self.mg_param.quda) elif self.mg_param.updated or self.inv_param.updated: - print("UPDATEDDD") lib.updateMultigridQuda(self._quda, self.mg_param.quda) + self.mg_param.updated = False + self.inv_param.updated = False return self._quda # TODO: can also accept structs?