-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathdrifts.py
112 lines (91 loc) · 2.71 KB
/
drifts.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
"""
Nicholas M. Boffi
7/29/22
Drift terms for score-based transport modeling.
"""
from jax import vmap
from jax.lax import stop_gradient
import jax.numpy as np
from typing import Callable, Tuple
def _difference(a, b):
    """Return ``a - b``; the elementwise kernel lifted by ``compute_particle_diffs``."""
    return a - b


# Pairwise difference table: compute_particle_diffs(xs, ys)[i, j] == xs[i] - ys[j].
# The inner vmap maps over the leading axis of the first argument (stacked on
# output axis 0); the outer vmap maps over the leading axis of the second
# argument and stacks those results along output axis 1.
_diffs_over_first = vmap(_difference, in_axes=(0, None), out_axes=0)
compute_particle_diffs = vmap(_diffs_over_first, in_axes=(None, 0), out_axes=1)
def active_swimmer(
    xv: np.ndarray,
    t: float,
    gamma: float
) -> np.ndarray:
    """Active swimmer example: drift for a (position, velocity) pair.

    Args:
        xv: length-2 state; unpacked as ``(x, v)``.
        t: time; unused because the drift does not depend on it.
        gamma: damping rate applied to the velocity.

    Returns:
        ``[v - x**3, -gamma * v]`` as an array.
    """
    del t  # time-independent drift
    pos, vel = xv
    pos_rate = -pos**3 + vel
    vel_rate = -gamma*vel
    return np.array([pos_rate, vel_rate])
def harmonic_trap(
    x: np.ndarray,
    t: float,
    compute_mut: Callable[[float], np.ndarray],
    N: int,
    d: int
) -> np.ndarray:
    """Forcing for particles in a harmonic trap with harmonic repulsion.

    With ``x`` viewed as ``(N, d)`` particle positions, ``mu = compute_mut(t)``
    the trap center, and ``xbar`` the particle mean, particle i feels

        -0.5 * x_i + mu - 0.5 * xbar
        == -(x_i - mu) + 0.5 * (x_i - xbar),

    i.e. a unit-strength pull toward ``mu`` plus a push away from the centroid.

    Returns:
        Flattened array of length ``N * d``.
    """
    center = compute_mut(t)
    positions = x.reshape((N, d))
    centroid = positions.mean(axis=0)
    forces = -0.5*positions + center[None, :] - 0.5*centroid[None, :]
    return np.ravel(forces)
def gaussian_interaction(
    xs: np.ndarray,
    A: float,
    r: float,
) -> np.ndarray:
    """Mean-field Gaussian pair force on each particle.

    Args:
        xs: particle positions, one row per particle (``anharmonic_gaussian``
            passes an ``(N, d)`` array).
        A: interaction amplitude.
        r: interaction length scale.

    Returns:
        Array whose row i is ``A / r**2`` times the mean over j of
        ``(xs[i] - xs[j]) * exp(-|xs[i] - xs[j]|**2 / (2 * r**2))``.
        The mean runs over all j, including j == i (which contributes zero).
        For ``A > 0`` each term points from particle j toward particle i.
    """
    pair_diffs = compute_particle_diffs(xs, xs)  # pair_diffs[i, j] = xs[i] - xs[j]
    sq_dists = np.sum(pair_diffs**2, axis=2)
    weights = np.exp(-sq_dists / (2*r**2))
    weighted_diffs = pair_diffs * weights[:, :, None]
    return A/(r**2)*np.mean(weighted_diffs, axis=1)
def anharmonic_gaussian(
    x: np.ndarray,
    t: float,
    A: float,
    r: float,
    B: float,
    N: int,
    d: int,
    compute_mut: Callable[[float], np.ndarray],
    print_info: bool = False
) -> np.ndarray:
    """Gaussian short-range force with anharmonic attraction to a moving center.

    With ``x`` viewed as ``(N, d)`` particle positions and
    ``diff_i = x_i - compute_mut(t)``, particle i feels

        -B * |diff_i|**2 * diff_i  +  gaussian_interaction(x, A, r)[i].

    Args:
        x: flattened particle positions of length ``N * d``.
        t: time, passed to ``compute_mut``.
        A: Gaussian interaction amplitude.
        r: Gaussian interaction length scale.
        B: strength of the quartic (anharmonic) trap.
        N: number of particles.
        d: spatial dimension.
        compute_mut: returns the length-``d`` trap center at time ``t``.
        print_info: if True, print the interaction and trap contributions.

    Returns:
        Flattened force array of length ``N * d``.
    """
    particle_pos = x.reshape((N, d))
    diff = particle_pos - compute_mut(t)[None, :]
    # Sum of squares rather than ``np.linalg.norm(diff, axis=1)**2``: in JAX the
    # gradient of ``norm`` is NaN at zero, which would make derivatives of this
    # drift NaN whenever a particle sits exactly at the trap center. This also
    # matches how ``anharmonic_harmonic`` computes the same quantity.
    sq_dists = np.sum(diff**2, axis=1)
    trap_force = -B*diff*sq_dists[:, None]

    # repulsive interaction
    interaction = gaussian_interaction(particle_pos, A, r)
    particle_forces = trap_force + interaction

    if print_info:
        print('interaction:', interaction)
        print('potential:', trap_force)
        print()

    return particle_forces.ravel()
def anharmonic(
    x: np.ndarray,
    t: float,
    compute_mut: Callable[[float], np.ndarray]
) -> np.ndarray:
    """Single particle in an anharmonic (quartic) trap.

    Returns ``-|x - mu|**2 * (x - mu)`` with ``mu = compute_mut(t)``, where the
    squared length is computed as ``(x - mu) @ (x - mu)``.
    """
    offset = x - compute_mut(t)
    sq_len = offset @ offset
    return -offset * sq_len
def anharmonic_harmonic(
    x: np.ndarray,
    t: float,
    A: float,
    B: float,
    N: int,
    d: int,
    compute_mut: Callable[[float], np.ndarray],
) -> np.ndarray:
    """Harmonically-interacting particles in an anharmonic trap.

    With ``x`` viewed as ``(N, d)`` positions, ``mu = compute_mut(t)`` and
    ``xbar`` the particle mean, particle i feels

        -B * |x_i - mu|**2 * (x_i - mu)  +  A * (x_i - xbar).

    Returns:
        Flattened force array of length ``N * d``.
    """
    positions = x.reshape((N, d))
    offsets = positions - compute_mut(t)[None, :]
    sq_lens = np.sum(offsets**2, axis=1)
    centroid = np.mean(positions, axis=0)
    trap = -B*offsets*sq_lens[:, None]
    coupling = A*(positions - centroid[None, :])
    return np.ravel(trap + coupling)