-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathvisual.py
92 lines (78 loc) · 3.25 KB
/
visual.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
from concurrent.futures import ProcessPoolExecutor, as_completed
import json
import pathlib

import imageio.v2 as imageio
import numpy as np
import pyvista as pv
import scipy.spatial
import scipy.special
from tqdm import tqdm
# RGB colours used when drawing the cell surface and the receptor points.
CELL_COLOR = (172, 188, 189)
RECEPTOR_COLOR = (255, 155, 133)

# Checkpoints are read relative to the current working directory.
ckpt_folder = pathlib.Path('outputs/checkpoints')
if not ckpt_folder.exists():
    raise FileNotFoundError("No checkpoints found")


def _checkpoint_sort_key(path):
    """Order checkpoints by the integer after the first ``_`` in the stem.

    This matches the numeric frame ordering used when assembling the
    animation. Stems without such an integer keep plain name order and sort
    after the numbered ones.
    """
    parts = path.stem.split("_")
    try:
        return (0, int(parts[1]), path.name)
    except (IndexError, ValueError):
        return (1, 0, path.name)


# Numeric rather than lexicographic order, so that e.g. "ckpt_10" comes after
# "ckpt_9" and selected_ckpts[-1] really is the last checkpoint.
ckpts = sorted(ckpt_folder.glob("*.npz"), key=_checkpoint_sort_key)
if not ckpts:
    raise FileNotFoundError(f"No .npz checkpoints found in {ckpt_folder}")
selected_ckpts = [ckpts[0], ckpts[-1]]

# Spherical-harmonic parameters describing the cell shape.
with open(ckpt_folder.parent / 'params.json', 'r', encoding='utf-8') as f:
    params = json.load(f)
harm_order = params["harm_order"]
harm_degree = params["harm_degree"]

# Rendered per-checkpoint PNG frames are written here.
frames_folder = ckpt_folder.parent / "frames"
frames_folder.mkdir(exist_ok=True)
def to_cartesian(theta, phi, alpha=1, order=None, degree=None):
    """Map spherical angles onto a sphere deformed by one spherical harmonic.

    The radius in each direction is ``1 + alpha * Re(Y)``, where ``Y`` is the
    spherical harmonic of the given order and degree. ``alpha=0`` therefore
    yields the unit sphere.

    Args:
        theta: Polar angle(s) in radians, measured from the +z axis.
        phi: Azimuthal angle(s) in radians, measured in the x-y plane from +x.
            Must broadcast against ``theta``.
        alpha: Amplitude of the harmonic deformation.
        order: Harmonic order ``m``. Defaults to the module-level
            ``harm_order``.
        degree: Harmonic degree ``n``. Defaults to the module-level
            ``harm_degree``.

    Returns:
        Array of shape ``broadcast(theta, phi).shape + (3,)`` holding x, y, z.
    """
    if order is None:
        order = harm_order
    if degree is None:
        degree = harm_degree
    # SciPy 1.15 deprecates sph_harm(m, n, azimuth, polar) in favour of
    # sph_harm_y(n, m, polar, azimuth). Use the new API when it exists.
    sph_harm_y = getattr(scipy.special, "sph_harm_y", None)
    if sph_harm_y is not None:
        Y = sph_harm_y(degree, order, theta, phi)
    else:
        Y = scipy.special.sph_harm(order, degree, phi, theta)
    r = 1 + alpha * np.real(Y)
    sin_theta = np.sin(theta)
    x = r * sin_theta * np.cos(phi)
    y = r * sin_theta * np.sin(phi)
    z = r * np.cos(theta)
    return np.stack([x, y, z], axis=-1)
def plot_cell(file):
    """Render one checkpoint off-screen to ``frames_folder/<file.stem>.png``.

    Draws the harmonic-deformed cell surface (70% opacity) and the receptor
    positions stored in the checkpoint. The checkpoint's error value is shown
    as a caption.

    Args:
        file: ``pathlib.Path`` to an ``.npz`` checkpoint with ``thetas``,
            ``phis`` and ``error`` entries. ``error`` is presumably a scalar,
            since it is formatted with ``.5g``.
    """
    # Read everything up front so the .npz archive handle is closed promptly.
    with np.load(file) as data:
        receptor_thetas = data["thetas"]
        receptor_phis = data["phis"]
        error = data["error"]

    theta_grid, phi_grid = np.meshgrid(
        np.linspace(0, np.pi, 400),
        np.linspace(0, 2 * np.pi, 400),
    )
    surface_points = to_cartesian(theta_grid, phi_grid)

    pl = pv.Plotter(window_size=[800, 800], off_screen=True)
    try:
        surface_mesh = pv.StructuredGrid(*surface_points.T).extract_geometry()
        pl.add_mesh(surface_mesh, opacity=0.7)
        receptors = pv.PolyData(to_cartesian(receptor_thetas, receptor_phis))
        pl.add_mesh(receptors, point_size=10, render_points_as_spheres=True, color=RECEPTOR_COLOR)
        pl.add_text(f"Error: {error:.5g}", position="upper_edge", font_size=12)
        pl.camera.zoom(1.2)
        pl.screenshot(frames_folder / f'{file.stem}.png')
    finally:
        # Release the render window even when rendering fails in a worker.
        pl.close()
def _frame_index(path):
    """Return the integer step encoded after the first ``_`` of a frame's stem."""
    return int(path.stem.split("_")[1])


def _show_comparison():
    """Open an interactive side-by-side view of the first and last selected checkpoints."""
    theta_grid, phi_grid = np.meshgrid(
        np.linspace(0, np.pi, 500),
        np.linspace(0, 2 * np.pi, 500),
    )
    # The cell surface depends only on the harmonic parameters, so compute it once.
    surface_points = to_cartesian(theta_grid, phi_grid)

    pl = pv.Plotter(shape=(1, 2), window_size=[1600, 800], border=False)
    for col, (label, file) in enumerate(zip(['initial', 'final'], selected_ckpts)):
        pl.subplot(0, col)
        pl.add_text(label, position="upper_left", font_size=18, color='black')
        with np.load(file) as data:
            points = to_cartesian(data["thetas"], data["phis"])
        surface_mesh = pv.StructuredGrid(*surface_points.T).extract_geometry()
        pl.add_mesh(surface_mesh, opacity=1.0, color=CELL_COLOR)
        receptors = pv.PolyData(points)
        pl.add_mesh(receptors, point_size=10, render_points_as_spheres=True, color=RECEPTOR_COLOR)
        pl.view_isometric()
        pl.camera.zoom(1.28)
    pl.link_views()
    pl.show()


def _render_animation():
    """Render every checkpoint to PNG in parallel, then encode the frames as an MP4."""
    with ProcessPoolExecutor() as executor:
        tasks = [executor.submit(plot_cell, file) for file in ckpts]
        for task in tqdm(as_completed(tasks), total=len(tasks), desc="Rendering frames", ncols=80):
            # Re-raise worker failures instead of silently producing a shorter video.
            task.result()

    # Only use frames rendered from the current checkpoints. Stale PNGs left in
    # frames_folder by earlier runs are ignored.
    frames = sorted((frames_folder / f"{file.stem}.png" for file in ckpts), key=_frame_index)
    output = ckpt_folder.parent / "animation.mp4"
    # Stream frames to the encoder rather than holding every image in memory.
    with imageio.get_writer(output, fps=30) as writer:
        for frame in frames:
            writer.append_data(imageio.imread(str(frame)))
    print('Animation saved to', output)


def main():
    """Preview the initial vs. final checkpoint, then render and save the full animation.

    The preview window is interactive, and ``pl.show()`` returns only once it
    is closed. After that, all checkpoints are rendered and written to
    ``animation.mp4`` next to the checkpoint folder.
    """
    _show_comparison()
    _render_animation()
# Script entry point. Besides skipping main() on import, the guard matters
# because ProcessPoolExecutor workers started with the "spawn" method re-import
# this module, and they must not launch main() again.
if __name__ == "__main__":
    main()