Skip to content

Commit 306e425

Browse files
Fix #309 for the factorial2 function
Handling of int float and arrays
1 parent aba83f1 commit 306e425

File tree

1 file changed

+55
-38
lines changed

1 file changed

+55
-38
lines changed

iodata/overlap.py

+55-38
Original file line numberDiff line numberDiff line change
@@ -32,32 +32,52 @@
3232

3333

3434
def factorial2(n, exact=False):
35-
"""Wrap scipy.special.factorial2 to return 1.0 when the input is -1.
35+
"""Wrap scipy.special.factorial2 to return 1.0 when the input is -1 and handle float arrays.
3636
3737
This is a temporary workaround while we wait for Scipy's update.
3838
To learn more, see https://github.com/scipy/scipy/issues/18409.
3939
4040
Parameters
4141
----------
42-
n : int or np.ndarray
42+
n : int, float, or np.ndarray
4343
Values to calculate n!! for. If n={0, -1}, the return value is 1.
4444
For n < -1, the return value is 0.
4545
"""
46-
# Scipy 1.11.x returns an integer when n is an integer, but 1.10.x returns an array,
47-
# so np.array(n) is passed to make sure the output is always an array.
48-
out = scipy.special.factorial2(np.array(n), exact=exact)
49-
out[out <= 0] = 1.0
50-
out[out <= -2] = 0.0
51-
return out
46+
# Handle integer inputs
47+
if isinstance(n, (int, np.integer)):
48+
if n == -1 or n == 0:
49+
return 1.0
50+
elif n < -1:
51+
return 0.0
52+
else:
53+
return scipy.special.factorial2(n, exact=exact)
54+
55+
# Handle float inputs
56+
elif isinstance(n, float):
57+
if n == -1.0 or n == 0.0:
58+
return 1.0
59+
elif n < -1.0:
60+
return 0.0
61+
else:
62+
return scipy.special.factorial2(int(n), exact=exact)
63+
64+
# Handle array inputs
65+
elif isinstance(n, np.ndarray):
66+
result = np.zeros_like(n, dtype=float)
67+
for i, val in np.ndenumerate(n):
68+
if val == -1.0 or val == 0.0:
69+
result[i] = 1.0
70+
elif val < -1.0:
71+
result[i] = 0.0
72+
else:
73+
result[i] = scipy.special.factorial2(int(val), exact=exact)
74+
return result
5275

5376

5477
# pylint: disable=too-many-nested-blocks,too-many-statements,too-many-branches
55-
def compute_overlap(
56-
obasis0: MolecularBasis,
57-
atcoords0: np.ndarray,
58-
obasis1: Optional[MolecularBasis] = None,
59-
atcoords1: Optional[np.ndarray] = None,
60-
) -> np.ndarray:
78+
def compute_overlap(obasis0: MolecularBasis, atcoords0: np.ndarray,
79+
obasis1: Optional[MolecularBasis] = None,
80+
atcoords1: Optional[np.ndarray] = None,) -> np.ndarray:
6181
r"""Compute overlap matrix for the given molecular basis set(s).
6282
6383
.. math::
@@ -89,30 +109,28 @@ def compute_overlap(
89109
The matrix with overlap integrals, ``shape=(obasis0.nbasis, obasis1.nbasis)``.
90110
91111
"""
92-
if obasis0.primitive_normalization != "L2":
93-
raise ValueError("The overlap integrals are only implemented for L2 " "normalization.")
112+
if obasis0.primitive_normalization != 'L2':
113+
raise ValueError('The overlap integrals are only implemented for L2 '
114+
'normalization.')
94115

95116
# Get a segmented basis, for simplicity
96117
obasis0 = obasis0.get_segmented()
97118

98119
# Handle optional arguments
99120
if obasis1 is None:
100121
if atcoords1 is not None:
101-
raise TypeError(
102-
"When no second basis is given, no second second "
103-
"array of atomic coordinates is expected."
104-
)
122+
raise TypeError("When no second basis is given, no second second "
123+
"array of atomic coordinates is expected.")
105124
obasis1 = obasis0
106125
atcoords1 = atcoords0
107126
identical = True
108127
else:
109-
if obasis1.primitive_normalization != "L2":
110-
raise ValueError("The overlap integrals are only implemented for L2 " "normalization.")
128+
if obasis1.primitive_normalization != 'L2':
129+
raise ValueError('The overlap integrals are only implemented for L2 '
130+
'normalization.')
111131
if atcoords1 is None:
112-
raise TypeError(
113-
"When a second basis is given, a second second "
114-
"array of atomic coordinates is expected."
115-
)
132+
raise TypeError("When a second basis is given, a second second "
133+
"array of atomic coordinates is expected.")
116134
# Get a segmented basis, for simplicity
117135
obasis1 = obasis1.get_segmented()
118136
identical = False
@@ -122,13 +140,13 @@ def compute_overlap(
122140

123141
# Compute the normalization constants of the Cartesian primitives, with the
124142
# contraction coefficients multiplied in.
125-
scales0 = [_compute_cart_shell_normalizations(shell) * shell.coeffs for shell in obasis0.shells]
143+
scales0 = [_compute_cart_shell_normalizations(shell) * shell.coeffs
144+
for shell in obasis0.shells]
126145
if identical:
127146
scales1 = scales0
128147
else:
129-
scales1 = [
130-
_compute_cart_shell_normalizations(shell) * shell.coeffs for shell in obasis1.shells
131-
]
148+
scales1 = [_compute_cart_shell_normalizations(shell) * shell.coeffs
149+
for shell in obasis1.shells]
132150

133151
n_max = max(np.max(shell.angmoms) for shell in obasis0.shells)
134152
if not identical:
@@ -203,9 +221,9 @@ def compute_overlap(
203221
# END of Cartesian coordinate system (if going to pure coordinates)
204222

205223
# cart to pure
206-
if shell0.kinds[0] == "p":
224+
if shell0.kinds[0] == 'p':
207225
shell_overlap = np.dot(tfs[shell0.angmoms[0]], shell_overlap)
208-
if shell1.kinds[0] == "p":
226+
if shell1.kinds[0] == 'p':
209227
shell_overlap = np.dot(shell_overlap, tfs[shell1.angmoms[0]].T)
210228

211229
# store lower triangular result
@@ -254,12 +272,12 @@ def compute_overlap_gaussian_1d(self, x1, x2, n1, n2, two_at):
254272
pf_i = self.binomials[n1][i] * x1 ** (n1 - i)
255273
for j in range(i % 2, n2 + 1, 2):
256274
m = i + j
257-
integ = self.facts[m] / two_at ** (m / 2) # TODO // 2
275+
integ = self.facts[m] / two_at ** (m / 2) # TODO // 2
258276
value += pf_i * self.binomials[n2][j] * x2 ** (n2 - j) * integ
259277
return value
260278

261279

262-
def _compute_cart_shell_normalizations(shell: "Shell") -> np.ndarray:
280+
def _compute_cart_shell_normalizations(shell: 'Shell') -> np.ndarray:
263281
"""Return normalization constants for the primitives in a given shell.
264282
265283
Parameters
@@ -274,7 +292,7 @@ def _compute_cart_shell_normalizations(shell: "Shell") -> np.ndarray:
274292
shell is pure.
275293
276294
"""
277-
shell = attr.evolve(shell, kinds=["c"] * shell.ncon)
295+
shell = attr.evolve(shell, kinds=['c'] * shell.ncon)
278296
result = []
279297
for angmom in shell.angmoms:
280298
for exponent in shell.exponents:
@@ -301,6 +319,5 @@ def gob_cart_normalization(alpha: np.ndarray, n: np.ndarray) -> np.ndarray:
301319
The normalization constant for the gaussian cartesian basis.
302320
303321
"""
304-
return np.sqrt(
305-
(4 * alpha) ** sum(n) * (2 * alpha / np.pi) ** 1.5 / np.prod(factorial2(2 * n - 1))
306-
)
322+
return np.sqrt((4 * alpha) ** sum(n) * (2 * alpha / np.pi) ** 1.5
323+
/ np.prod(factorial2(2 * n - 1)))

0 commit comments

Comments
 (0)