Skip to content

Cleanup2 #15

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 7 commits into
base: cleanup
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/black_n_pylint.yml
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ jobs:
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install black lyncs_setuptools[pylint]
pip install black lyncs_setuptools[all]

- name: Applying black formatting
run: |
Expand Down
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
[![build & test](https://img.shields.io/github/workflow/status/Lyncs-API/lyncs.quda/build%20&%20test?logo=github&logoColor=white)](https://github.com/Lyncs-API/lyncs.quda/actions)
-->
[![license](https://img.shields.io/github/license/Lyncs-API/lyncs.quda?logo=github&logoColor=white)](https://github.com/Lyncs-API/lyncs.quda/blob/master/LICENSE)
[![pylint](https://img.shields.io/badge/pylint%20score-8.5%2F10-yellowgreen?logo=python&logoColor=white)](http://pylint.pycqa.org/)
[![pylint](https://img.shields.io/badge/pylint%20score-7.9%2F10-yellow?logo=python&logoColor=white)](http://pylint.pycqa.org/)
[![black](https://img.shields.io/badge/code%20style-black-000000.svg?logo=codefactor&logoColor=white)](https://github.com/ambv/black)

[QUDA](http://lattice.github.io/quda/) is a library for performing calculations in lattice QCD on GPUs.
Expand Down
2 changes: 1 addition & 1 deletion lyncs_quda/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ class Array:
def __init__(self, typename, size, elems=None):
self._qarray = lib.quda.array[typename, size]()

if elems != None:
if elems is not None:
if isiterable(elems):
if len(elems) > size:
raise ValueError()
Expand Down
10 changes: 7 additions & 3 deletions lyncs_quda/clover_field.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,19 @@
"CloverField",
]

import numpy
from cppyy.gbl.std import vector
from functools import cache
import numpy

from lyncs_cppyy import make_shared, to_pointer
from .lib import lib, cupy
from .lattice_field import LatticeField
from .gauge_field import GaugeField
from .enums import QudaParity, QudaTwistFlavorType, QudaCloverFieldOrder, QudaFieldCreate
from .enums import (
QudaParity,
QudaTwistFlavorType,
QudaCloverFieldOrder,
QudaFieldCreate,
)

# TODO list
# We want dimension of (cu/num)py array to reflect parity and order
Expand Down
11 changes: 5 additions & 6 deletions lyncs_quda/dirac.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,9 @@

from functools import wraps
from dataclasses import dataclass, field
from numpy import sqrt
from typing import Union
from lyncs_cppyy import make_shared, nullptr
from .gauge_field import gauge, GaugeField
from .gauge_field import GaugeField
from .clover_field import CloverField
from .spinor_field import spinor
from .lib import lib
Expand All @@ -20,7 +19,7 @@
QudaMatPCType,
QudaDagType,
QudaParity,
)
)


@dataclass(frozen=True)
Expand Down Expand Up @@ -222,13 +221,13 @@ def action(self, phi, **params):
"""

if not self.full:
if "CLOVER" in self.type and self.symm == True:
if "CLOVER" in self.type and self.symm:
raise ValueError("Preconditioned matrix should be asymmetric")
if "CLOVER" not in self.type and self.symm != True:
if "CLOVER" not in self.type and not self.symm:
raise ValueError(
"Preconditioned matrix should be symmetric for non-clover type Dirac matrix"
)
if "CLOVER" in self.type and self.computeTrLog != True:
if "CLOVER" in self.type and not self.computeTrLog:
raise ValueError(
"computeTrLog should be set True in the preconditioned case"
)
Expand Down
20 changes: 9 additions & 11 deletions lyncs_quda/gauge_field.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,11 @@
from collections import defaultdict
from functools import cache
import numpy
from lyncs_cppyy import make_shared, lib as tmp, to_pointer, array_to_pointers
from lyncs_cppyy import make_shared, to_pointer, array_to_pointers
from lyncs_utils import prod, isiterable
from .lib import lib, cupy
from .array import Array
from .lattice_field import LatticeField, backend
from .spinor_field import spinor
from .lattice_field import LatticeField
from .time_profile import default_profiler
from .enums import (
QudaReconstructType,
Expand Down Expand Up @@ -121,8 +120,8 @@ def new(self, reconstruct=None, geometry=None, **kwargs):
kwargs["dofs"][0],
val // 2 if self.iscomplex else val,
)
except ValueError:
raise ValueError(f"Invalid reconstruct {reconstruct}")
except ValueError as VE:
raise VE(f"Invalid reconstruct {reconstruct}")
out = super().new(**kwargs)
is_momentum = kwargs.get("is_momentum", self.is_momentum)
out.is_momentum = is_momentum
Expand Down Expand Up @@ -197,9 +196,8 @@ def ncol(self):
def order(self):
"Data order of the field"
dofs = self.dofs_per_link
if self.precision != "double" and (
dofs == 8 or dofs == 12
): # if FLOAT8 defined, if prec=half/quarter and recon=8, FLOAT8
if self.precision != "double" and dofs in (8, 12):
# if FLOAT8 defined, if prec=half/quarter and recon=8, FLOAT8
return "FLOAT4"
return "FLOAT2"

Expand Down Expand Up @@ -346,7 +344,7 @@ def extended_field(self, sites=1):
self.ptr, self.quda_params, numpy.array(sites, dtype="int32")
)
)
elif self.location == "CUDA":
if self.location == "CUDA":
"Returns cudaGaugeField"
return make_shared(
lib.createExtendedGauge(
Expand Down Expand Up @@ -531,7 +529,7 @@ def gaussian(self, epsilon=1, seed=None):
seed = seed or int(time() * 1e9)
lib.gaugeGauss(self.quda_field, seed, epsilon)

def uniform(self, epsilon=1, seed=None):
def uniform(self, seed=None):
"""
Generates Uniform distributed SU(N) field.
"""
Expand Down Expand Up @@ -756,7 +754,7 @@ def compute_paths(
if not len(paths) == len(coeffs):
raise ValueError("Paths and coeffs must have the same length")
else:
assert coeffs == None, "coeffs not used in case of not sum_paths"
assert coeffs is None, "coeffs not used in case of not sum_paths"

# Preparing fnc
if insertion is not None:
Expand Down
5 changes: 0 additions & 5 deletions lyncs_quda/lattice_field.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
"LatticeField",
]

from array import array
from contextlib import contextmanager
from functools import cache
import numpy
Expand All @@ -16,10 +15,6 @@
from .lib import lib, cupy
from .array import lat_dims

from lyncs_cppyy import to_pointer
import ctypes
import traceback


def get_precision(dtype):
if dtype in ["float64", "complex128"]:
Expand Down
10 changes: 5 additions & 5 deletions lyncs_quda/lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,10 @@
from os import environ
from pathlib import Path
from array import array
from appdirs import user_data_dir
from math import prod
from lyncs_cppyy import Lib, nullptr, cppdef
from lyncs_cppyy.ll import addressof, to_pointer
from lyncs_utils import static_property, lazy_import
from appdirs import user_data_dir
from lyncs_cppyy import Lib, cppdef
from lyncs_utils import lazy_import
from . import __path__
from .config import QUDA_MPI, QUDA_GITVERSION, QUDA_PRECISION, QUDA_RECONSTRUCT

Expand Down Expand Up @@ -86,7 +85,7 @@ def device_id(self, value):
f"device_id cannot be changed: current={self.device_id}, given={value}"
)
if not isinstance(value, int):
raise TypeError(f"Unsupported type for device: {type(device)}")
raise TypeError(f"Unsupported type for device ID: {type(value)}")
self._device_id = value

def get_current_device(self):
Expand Down Expand Up @@ -262,6 +261,7 @@ def __del__(self):
PATHS = list(__path__)

headers = [
"comm_quda.h",
"quda.h",
"gauge_field.h",
"gauge_tools.h",
Expand Down
2 changes: 1 addition & 1 deletion lyncs_quda/solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
"Solver",
]

from functools import wraps, cache
from functools import wraps
from warnings import warn
from lyncs_cppyy import nullptr, make_shared
from .dirac import Dirac, DiracMatrix
Expand Down
12 changes: 6 additions & 6 deletions lyncs_quda/spinor_field.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
"""
Interface to gauge_field.h
Interface to color_spinor_field.h
"""

__all__ = [
Expand All @@ -8,7 +8,7 @@
"SpinorField",
]

from functools import reduce, cache
from functools import cache
from time import time
from lyncs_cppyy import make_shared
from lyncs_cppyy.ll import to_pointer
Expand All @@ -22,7 +22,7 @@
QudaPCType,
QudaFieldCreate,
QudaNoiseType,
)
)

"""
NOTE:
Expand Down Expand Up @@ -124,7 +124,7 @@ def site_order(self):
def site_order(self, value):
if value is None:
value = "NONE"
values = f"Possible values are NONE, EVEN_ODD, ODD_EVEN"
values = "Possible values are NONE, EVEN_ODD, ODD_EVEN"
if not isinstance(value, str):
raise TypeError("Expected a string. " + values)
value = value.upper()
Expand Down Expand Up @@ -215,15 +215,15 @@ def norm1(self, parity=None):
"L1 norm of the field"
if parity == "EVEN":
return lib.blas.norm1(self.quda_field.Even())
elif parity == "ODD":
if parity == "ODD":
return lib.blas.norm1(self.quda_field.Odd())
return lib.blas.norm1(self.quda_field)

def norm2(self, parity=None):
"L2 norm of the field"
if parity == "EVEN":
return lib.blas.norm2(self.quda_field.Even())
elif parity == "ODD":
if parity == "ODD":
return lib.blas.norm2(self.quda_field.Odd())
return lib.blas.norm2(self.quda_field)

Expand Down
63 changes: 63 additions & 0 deletions patches/error.patch
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
diff --git a/include/comm_quda.h b/include/comm_quda.h
index e99d412..f221eab 100644
--- a/include/comm_quda.h
+++ b/include/comm_quda.h
@@ -4,6 +4,7 @@
#include <quda_constants.h>
#include <quda_api.h>
#include <array.h>
+#include <stdexcept>

#ifdef __cplusplus
extern "C" {
@@ -420,4 +421,7 @@ namespace quda
bool commAsyncReduction();
void commAsyncReductionSet(bool global_reduce);

+ class QudaException : public std::runtime_error {
+ using std::runtime_error::runtime_error;
+ };
} // namespace quda
diff --git a/lib/comm_common.cpp b/lib/comm_common.cpp
index 2c58188..330b0e0 100644
--- a/lib/comm_common.cpp
+++ b/lib/comm_common.cpp
@@ -1,6 +1,9 @@
#include <unistd.h> // for gethostname()
#include <assert.h>
#include <limits>
+#include <cstdlib>
+#include <exception>
+#include <string>

#include <quda_internal.h>
#include <communicator_quda.h>
@@ -130,7 +133,12 @@ namespace quda

return topo;
}
-
+
+ static int current_status;
+ void abort(){
+ comm_abort_(current_status);
+ }
+
void comm_abort(int status)
{
#ifdef HOST_DEBUG
@@ -142,7 +150,13 @@ namespace quda
backward::Printer p;
p.print(st, getOutputFile());
#endif
- comm_abort_(status);
+ static bool called = false;
+ if (not called) {
+ std::set_terminate(abort);
+ called = true;
+ }
+ current_status = status;
+ throw QudaException("QUDA Error: exit code " + std::to_string(status));
}

} // namespace quda
19 changes: 19 additions & 0 deletions test/test_error.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
from lyncs_quda import gauge, gauge_scalar

import pytest
from lyncs_quda.testing import (
fixlib as lib,
lattice_loop,
)


@lattice_loop
def test_error(lib, lattice):
gf = gauge(lattice)
gs = gauge_scalar(lattice)

gf.quda_field.copy(gs.quda_field)
with pytest.raises(lib.std.runtime_error):
gf.quda_field.copy(gs.quda_field)
gf.zero()
assert gf == 0
6 changes: 2 additions & 4 deletions test/test_evenodd.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import pytest
from lyncs_quda import evenodd, continous, to_quda, from_quda

# from lyncs_quda.lib import fixlib as lib
from lyncs_quda.lib import fixlib as lib
import numpy as np


Expand All @@ -20,17 +20,15 @@ def outer(request):
return request.param


def test_evenodd(shape, inner, outer):
def test_evenodd(lib, shape, inner, outer):
tile = np.array([1, -1])
for i in range(1, len(shape)):
tile = np.array([tile, tile * -1])

arr = np.tile(tile, shape)
shape = arr.shape
out = evenodd(arr)

assert (continous(out) == arr).all()

out = out.flatten()
n = out.shape[0] // 2

Expand Down