export_reconfusion_example.py
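# A sketch of a module docstring; the dataset layout it describes is inferred from
# the fields the script reads below.
"""Export a folder of per-frame camera JSON files and images to a transforms.json
plus a ReconFusion-style train/test split.

The dataset folder is expected to contain, for each frame, a <name>.json file with a
normalized intrinsic matrix "K" and a camera-to-world matrix "c2w", next to a matching
<name>.png image. The script writes transforms.json with per-frame intrinsics in pixel
units and OpenGL-convention poses, and, when run as a script, a
train_test_split_<n>.json whose training frames are chosen by K-means clustering over
camera positions and viewing directions.
"""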
import argparse
import json
import os
import numpy as np
from PIL import Image
try:
    from sklearn.cluster import KMeans  # type: ignore[import]
except ImportError:
    print("Please install scikit-learn to use this script.")
    exit(1)
# Define the folder containing the image and JSON files
subfolder = "/path/to/your/dataset"
output_file = os.path.join(subfolder, "transforms.json")
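# The per-frame JSON files in this folder are assumed to look roughly like the sketch
# below (the field names "K" and "c2w" match the reads in the loop further down; the
# values are purely illustrative, with K normalized by image width/height as implied
# by the fx/fy/cx/cy scaling):
#
# {
#   "K":   [[0.75, 0.00, 0.50],
#           [0.00, 1.00, 0.50],
#           [0.00, 0.00, 1.00]],
#   "c2w": [[r11, r12, r13, tx],
#           [r21, r22, r23, ty],
#           [r31, r32, r33, tz],
#           [0.0, 0.0, 0.0, 1.0]]
# }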
# List to hold the frames
frames = []
# Iterate over the files in the folder
for file in sorted(os.listdir(subfolder)):
    if file.endswith(".json"):
        # Read the JSON file containing camera extrinsics and intrinsics
        json_path = os.path.join(subfolder, file)
        with open(json_path, "r") as f:
            data = json.load(f)

        # Read the corresponding image file
        image_file = file.replace(".json", ".png")
        image_path = os.path.join(subfolder, image_file)
        if not os.path.exists(image_path):
            print(f"Image file not found for {file}, skipping...")
            continue
        with Image.open(image_path) as img:
            w, h = img.size

        # Convert the normalized intrinsic matrix K to pixel units
        K = data["K"]
        fx = K[0][0] * w
        fy = K[1][1] * h
        cx = K[0][2] * w
        cy = K[1][2] * h

        # Extract the camera-to-world transformation matrix and flip the y and z
        # axes to adjust for the OpenGL convention
        transform_matrix = np.array(data["c2w"])
        transform_matrix[..., [1, 2]] *= -1

        # Add the frame data
        frames.append(
            {
                "fl_x": fx,
                "fl_y": fy,
                "cx": cx,
                "cy": cy,
                "w": w,
                "h": h,
                "file_path": f"./{os.path.relpath(image_path, subfolder)}",
                "transform_matrix": transform_matrix.tolist(),
            }
        )
# Create the output dictionary
transforms_data = {"orientation_override": "none", "frames": frames}
# Write to the transforms.json file
with open(output_file, "w") as f:
json.dump(transforms_data, f, indent=4)
print(f"transforms.json generated at {output_file}")
# Train-test split: training frames are picked with K-means clustering, the
# remaining frames become test frames subsampled with a stride
def create_train_test_split(frames, n, output_path, stride):
    # Prepare the data for K-means: concatenate each camera's 3D position with
    # its normalized viewing direction
    positions = []
    for frame in frames:
        transform_matrix = np.array(frame["transform_matrix"])
        position = transform_matrix[:3, 3]  # 3D camera position
        direction = transform_matrix[:3, 2] / np.linalg.norm(
            transform_matrix[:3, 2]
        )  # Normalized 3D viewing direction
        positions.append(np.concatenate([position, direction]))
    positions = np.array(positions)

    # Apply K-means clustering with n clusters
    kmeans = KMeans(n_clusters=n, random_state=42)
    kmeans.fit(positions)
    centers = kmeans.cluster_centers_

    # For each cluster center, take the closest frame as a training frame
    train_ids = []
    for center in centers:
        distances = np.linalg.norm(positions - center, axis=1)
        train_ids.append(int(np.argmin(distances)))  # Convert to Python int

    # Remaining indices become test_ids, subsampled with the stride
    all_indices = set(range(len(frames)))
    remaining_indices = sorted(all_indices - set(train_ids))
    test_ids = [int(idx) for idx in remaining_indices[::stride]]  # Convert to Python int

    # Write the split file
    split_data = {"train_ids": sorted(train_ids), "test_ids": test_ids}
    with open(output_path, "w") as f:
        json.dump(split_data, f, indent=4)
    print(f"Train-test split file generated at {output_path}")
# Parse arguments
if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Generate train-test split JSON file using K-means clustering."
    )
    parser.add_argument(
        "--n",
        type=int,
        required=True,
        help="Number of frames to include in the training set.",
    )
    parser.add_argument(
        "--stride",
        type=int,
        default=1,
        help="Stride for subsampling the test frames (training frames are chosen by K-means).",
    )
    args = parser.parse_args()

    # Create train-test split
    train_test_split_path = os.path.join(subfolder, f"train_test_split_{args.n}.json")
    create_train_test_split(frames, args.n, train_test_split_path, args.stride)
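# Example invocation (the argument values and the resulting ids are illustrative):
#
#   python export_reconfusion_example.py --n 9 --stride 10
#
# This regenerates transforms.json in the dataset folder and writes
# train_test_split_9.json next to it, roughly of the form
#
#   {"train_ids": [0, 14, 27, ...], "test_ids": [1, 11, 21, ...]}
#
# where train_ids are the frames closest to the K-means cluster centers and test_ids
# are the remaining frames subsampled with the given stride.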