forked from McCoyGroup/PyDVR
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathWavefunctions.py
144 lines (117 loc) · 4.57 KB
/
Wavefunctions.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
"""
Provides a DVRWavefunction class that inherits from the base Psience wavefunction
"""
from Psience.Wavefun import Wavefunction, Wavefunctions
class DVRWavefunction(Wavefunction):
    """A single wavefunction sampled on a DVR grid.

    ``self.data`` holds the wavefunction values at the grid points; the grid
    itself is read from ``self.opts['grid']`` when it is not passed explicitly.
    """

    def plot(self, figure=None, grid=None, **opts):
        """Plots the wavefunction over its grid.

        A one-dimensional grid is drawn with ``McUtils.Plots.Plot``; a
        multidimensional mesh is drawn with ``McUtils.Plots.Plot3D`` after
        reshaping the flat data onto the mesh.

        :param figure: existing figure to draw into; a new plot is created when ``None``
        :type figure: object | None
        :param grid: grid points to plot against; defaults to ``self.opts['grid']``
        :type grid: np.ndarray | None
        :param opts: extra options forwarded to the plotting call
        :type opts: dict
        :return: the new plot object, or whatever ``figure.plot`` returns
        :rtype: object
        """
        import numpy as np
        if grid is None:
            grid = self.opts['grid']
        grid = np.asarray(grid)
        dim = grid.ndim
        if dim > 1 and grid.shape[-1] == dim - 1:
            # A mesh of points shaped (n_1, ..., n_d, d): move the coordinate
            # axis to the front so grid[i] is the i-th coordinate array.
            grid = np.moveaxis(grid, -1, 0)
            dim = grid.shape[0]
            if dim == 1:
                # An (n, 1) mesh is really a 1D grid; plot it as one instead
                # of handing Plot3D a single coordinate array.
                grid = grid[0]
        if dim == 1:
            if figure is None:
                from McUtils.Plots import Plot
                return Plot(grid, self.data, **opts)
            return figure.plot(grid, self.data, **opts)
        # Multidimensional: the flat data must be laid out on the mesh shape.
        values = self.data.reshape(grid[0].shape)
        if figure is None:
            from McUtils.Plots import Plot3D
            return Plot3D(*grid, values, **opts)
        return figure.plot(*grid, values, **opts)

    def expectation(self, op, other):
        """Computes the matrix element of operator ``op`` between ``self`` and ``other``.

        ``op`` is applied to ``self.data`` and the result is dotted with
        ``other``; no complex conjugation is applied.

        :param op: callable applied to this wavefunction's data
        :type op: callable
        :param other: the other wavefunction, or its raw data array
        :type other: Wavefunction | np.ndarray
        :return: ``dot(op(self.data), other_data)``
        :rtype: np.ndarray | scalar
        """
        import numpy as np
        wf = op(self.data)
        if not isinstance(other, np.ndarray):
            other = other.data
        return np.dot(wf, other)

    def probability_density(self):
        """Computes the probability density ``|psi|**2`` of this wavefunction.

        Identical to squaring for real data; also correct for complex data.

        :return: elementwise ``|self.data|**2``
        :rtype: np.ndarray
        """
        import numpy as np
        return np.abs(self.data) ** 2
class DVRWavefunctions(Wavefunctions):
    """A batch of DVR wavefunctions stored column-wise.

    State ``i`` is ``wavefunctions[:, i]`` with energy ``energies[i]``.
    Most evaluations are most efficient done in batch for DVR wavefunctions,
    so this batch object is the main entry point.
    """

    def __init__(self, energies=None, wavefunctions=None,
                 wavefunction_class=None,
                 rephase=True,
                 **opts
                 ):
        """
        :param energies: energies of the states, one per column of ``wavefunctions``
        :type energies: np.ndarray | None
        :param wavefunctions: wavefunction values, one column per state
        :type wavefunctions: np.ndarray | None
        :param wavefunction_class: class used to wrap single states; defaults to ``DVRWavefunction``
        :type wavefunction_class: type | None
        :param rephase: whether to apply the sign convention derived from column 0
        :type rephase: bool
        :param opts: extra options (e.g. ``grid``) forwarded to the base class
        :type opts: dict
        """
        import numpy as np
        if wavefunction_class is None:
            wavefunction_class = DVRWavefunction
        if rephase and wavefunctions is not None:
            # Multiply every grid point (row) by the sign of column 0
            # (presumably the ground state) at that point, which turns
            # column 0 into |column 0|.
            # NOTE(review): this is a per-grid-point sign, not a per-state
            # phase; it only amounts to a global phase when column 0 has no
            # sign changes -- confirm this is intended.
            phase_gs = np.sign(wavefunctions[:, 0])
            # np.sign returns 0 where column 0 vanishes exactly, which would
            # wipe that grid point out of every state; treat it as +1 instead.
            phase_gs = np.where(phase_gs == 0, 1, phase_gs)
            wavefunctions = wavefunctions * phase_gs[:, np.newaxis]
        super().__init__(wavefunctions=wavefunctions, energies=energies, wavefunction_class=wavefunction_class, **opts)

    def __getitem__(self, item):
        """Returns a single wavefunction for an index, or a sub-batch for a slice.

        Python's fallback iteration protocol can use this method when no
        ``__iter__`` is defined.

        :param item: state index or slice of state indices
        :type item: int | slice
        :return: a ``wavefunction_class`` instance, or a new batch of this type
        :rtype: Wavefunction | DVRWavefunctions
        """
        if isinstance(item, slice):
            # The stored wavefunctions were already rephased on construction.
            # Rephasing again would use whatever state is first in the slice
            # (possibly one with nodes) as the sign reference and flip the
            # states point by point, so it is disabled here.
            return type(self)(
                energies=self.energies[item],
                wavefunctions=self.wavefunctions[:, item],
                wavefunction_class=self.wavefunction_class,
                rephase=False,
                **self.opts
            )
        return self.wavefunction_class(self.energies[item], self.wavefunctions[:, item], parent=self, **self.opts)

    def plot(self, figure=None, graphics_class=None, plot_class=None, plot_style=None, **opts):
        """Plots the batch by delegating to the base-class ``plot``.

        :param figure: existing figure to draw into, or ``None``
        :type figure: object | None
        :param graphics_class: forwarded to the base-class ``plot``
        :type graphics_class: type | None
        :param plot_class: currently unused; accepted for backward compatibility
        :type plot_class: type | None
        :param plot_style: forwarded to the base-class ``plot``
        :type plot_style: dict | None
        :param opts: extra options forwarded to the base-class ``plot``
        :type opts: dict
        :return: whatever the base-class ``plot`` returns
        :rtype: object
        """
        return super().plot(
            figure=figure,
            graphics_class=graphics_class,
            plot_style=plot_style,
            **opts
        )

    def expectation(self, op, other):
        """Computes the matrix of ``op`` between the states of ``other`` and ``self``.

        Entry ``[i, j]`` is ``dot(other[:, i], op(self.wavefunctions)[:, j])``;
        no complex conjugation is applied.

        :param op: callable applied to the full ``wavefunctions`` array
        :type op: callable
        :param other: the other batch, or its raw wavefunction array
        :type other: DVRWavefunctions | np.ndarray
        :return: ``other.T @ op(self.wavefunctions)``
        :rtype: np.ndarray
        """
        import numpy as np
        wfs = op(self.wavefunctions)
        if not isinstance(other, np.ndarray):
            other = other.wavefunctions
        return np.dot(other.T, wfs)

    def probability_density(self):
        """Computes the probability densities ``|psi|**2`` of every state.

        Identical to squaring for real data; also correct for complex data.

        :return: elementwise ``|self.wavefunctions|**2``
        :rtype: np.ndarray
        """
        import numpy as np
        return np.abs(self.wavefunctions) ** 2