Skip to content

Commit

Permalink
[pre-commit.ci] auto fixes from pre-commit.com hooks
Browse files Browse the repository at this point in the history
for more information, see https://pre-commit.ci
  • Loading branch information
pre-commit-ci[bot] committed May 24, 2024
1 parent 306e425 commit a3c4d9f
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 32 deletions.
58 changes: 32 additions & 26 deletions iodata/overlap.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def factorial2(n, exact=False):
return 0.0
else:
return scipy.special.factorial2(n, exact=exact)

# Handle float inputs
elif isinstance(n, float):
if n == -1.0 or n == 0.0:
Expand All @@ -60,7 +60,7 @@ def factorial2(n, exact=False):
return 0.0
else:
return scipy.special.factorial2(int(n), exact=exact)

# Handle array inputs
elif isinstance(n, np.ndarray):
result = np.zeros_like(n, dtype=float)
Expand All @@ -75,9 +75,12 @@ def factorial2(n, exact=False):


# pylint: disable=too-many-nested-blocks,too-many-statements,too-many-branches
def compute_overlap(obasis0: MolecularBasis, atcoords0: np.ndarray,
obasis1: Optional[MolecularBasis] = None,
atcoords1: Optional[np.ndarray] = None,) -> np.ndarray:
def compute_overlap(
obasis0: MolecularBasis,
atcoords0: np.ndarray,
obasis1: Optional[MolecularBasis] = None,
atcoords1: Optional[np.ndarray] = None,
) -> np.ndarray:
r"""Compute overlap matrix for the given molecular basis set(s).
.. math::
Expand Down Expand Up @@ -109,28 +112,30 @@ def compute_overlap(obasis0: MolecularBasis, atcoords0: np.ndarray,
The matrix with overlap integrals, ``shape=(obasis0.nbasis, obasis1.nbasis)``.
"""
if obasis0.primitive_normalization != 'L2':
raise ValueError('The overlap integrals are only implemented for L2 '
'normalization.')
if obasis0.primitive_normalization != "L2":
raise ValueError("The overlap integrals are only implemented for L2 " "normalization.")

# Get a segmented basis, for simplicity
obasis0 = obasis0.get_segmented()

# Handle optional arguments
if obasis1 is None:
if atcoords1 is not None:
raise TypeError("When no second basis is given, no second second "
"array of atomic coordinates is expected.")
raise TypeError(
"When no second basis is given, no second second "
"array of atomic coordinates is expected."
)
obasis1 = obasis0
atcoords1 = atcoords0
identical = True
else:
if obasis1.primitive_normalization != 'L2':
raise ValueError('The overlap integrals are only implemented for L2 '
'normalization.')
if obasis1.primitive_normalization != "L2":
raise ValueError("The overlap integrals are only implemented for L2 " "normalization.")
if atcoords1 is None:
raise TypeError("When a second basis is given, a second second "
"array of atomic coordinates is expected.")
raise TypeError(
"When a second basis is given, a second second "
"array of atomic coordinates is expected."
)
# Get a segmented basis, for simplicity
obasis1 = obasis1.get_segmented()
identical = False
Expand All @@ -140,13 +145,13 @@ def compute_overlap(obasis0: MolecularBasis, atcoords0: np.ndarray,

# Compute the normalization constants of the Cartesian primitives, with the
# contraction coefficients multiplied in.
scales0 = [_compute_cart_shell_normalizations(shell) * shell.coeffs
for shell in obasis0.shells]
scales0 = [_compute_cart_shell_normalizations(shell) * shell.coeffs for shell in obasis0.shells]
if identical:
scales1 = scales0
else:
scales1 = [_compute_cart_shell_normalizations(shell) * shell.coeffs
for shell in obasis1.shells]
scales1 = [
_compute_cart_shell_normalizations(shell) * shell.coeffs for shell in obasis1.shells
]

n_max = max(np.max(shell.angmoms) for shell in obasis0.shells)
if not identical:
Expand Down Expand Up @@ -221,9 +226,9 @@ def compute_overlap(obasis0: MolecularBasis, atcoords0: np.ndarray,
# END of Cartesian coordinate system (if going to pure coordinates)

# cart to pure
if shell0.kinds[0] == 'p':
if shell0.kinds[0] == "p":
shell_overlap = np.dot(tfs[shell0.angmoms[0]], shell_overlap)
if shell1.kinds[0] == 'p':
if shell1.kinds[0] == "p":
shell_overlap = np.dot(shell_overlap, tfs[shell1.angmoms[0]].T)

# store lower triangular result
Expand Down Expand Up @@ -272,12 +277,12 @@ def compute_overlap_gaussian_1d(self, x1, x2, n1, n2, two_at):
pf_i = self.binomials[n1][i] * x1 ** (n1 - i)
for j in range(i % 2, n2 + 1, 2):
m = i + j
integ = self.facts[m] / two_at ** (m / 2) # TODO // 2
integ = self.facts[m] / two_at ** (m / 2) # TODO // 2
value += pf_i * self.binomials[n2][j] * x2 ** (n2 - j) * integ
return value


def _compute_cart_shell_normalizations(shell: 'Shell') -> np.ndarray:
def _compute_cart_shell_normalizations(shell: "Shell") -> np.ndarray:
"""Return normalization constants for the primitives in a given shell.
Parameters
Expand All @@ -292,7 +297,7 @@ def _compute_cart_shell_normalizations(shell: 'Shell') -> np.ndarray:
shell is pure.
"""
shell = attr.evolve(shell, kinds=['c'] * shell.ncon)
shell = attr.evolve(shell, kinds=["c"] * shell.ncon)
result = []
for angmom in shell.angmoms:
for exponent in shell.exponents:
Expand All @@ -319,5 +324,6 @@ def gob_cart_normalization(alpha: np.ndarray, n: np.ndarray) -> np.ndarray:
The normalization constant for the gaussian cartesian basis.
"""
return np.sqrt((4 * alpha) ** sum(n) * (2 * alpha / np.pi) ** 1.5
/ np.prod(factorial2(2 * n - 1)))
return np.sqrt(
(4 * alpha) ** sum(n) * (2 * alpha / np.pi) ** 1.5 / np.prod(factorial2(2 * n - 1))
)
11 changes: 5 additions & 6 deletions iodata/test/test_factorial2.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@

factorial2 = overlap.factorial2

class TestFactorial2(unittest.TestCase):

class TestFactorial2(unittest.TestCase):
def test_integer_arguments(self):
self.assertEqual(factorial2(0, exact=True), 1)
self.assertEqual(factorial2(1, exact=True), 1)
Expand All @@ -22,14 +22,12 @@ def test_float_arguments(self):

def test_integer_array_argument(self):
np.testing.assert_array_equal(
factorial2(np.array([0, 1, 2, 3]), exact=True),
np.array([1, 1, 2, 3])
factorial2(np.array([0, 1, 2, 3]), exact=True), np.array([1, 1, 2, 3])
)

def test_float_array_argument(self):
np.testing.assert_array_almost_equal(
factorial2(np.array([0.0, 1.0, 2.0, 3.0]), exact=False),
np.array([1.0, 1.0, 2.0, 3.0])
factorial2(np.array([0.0, 1.0, 2.0, 3.0]), exact=False), np.array([1.0, 1.0, 2.0, 3.0])
)

def test_special_cases_exact(self):
Expand All @@ -40,5 +38,6 @@ def test_special_cases_not_exact(self):
np.testing.assert_almost_equal(factorial2(-1.0, exact=False), 1.0)
np.testing.assert_almost_equal(factorial2(-2.0, exact=False), 0.0)

if __name__ == '__main__':

if __name__ == "__main__":
unittest.main()

0 comments on commit a3c4d9f

Please sign in to comment.