Skip to content

Commit 479c620

Browse files
Implemented a basic CLI
commit-id:b2136058
1 parent a4337d0 commit 479c620

File tree

12 files changed

+487
-4
lines changed

12 files changed

+487
-4
lines changed

packages/child-lab-cli/hello.py

-1
This file was deleted.

packages/child-lab-cli/pyproject.toml

+21-1
Original file line numberDiff line numberDiff line change
@@ -4,4 +4,24 @@ version = "0.1.0"
44
description = "Add your description here"
55
readme = "README.md"
66
requires-python = ">=3.12.7"
7-
dependencies = []
7+
dependencies = [
8+
"click>=8.1.8",
9+
"pyserde>=0.23.0",
10+
"torch>=2.6.0",
11+
"tqdm>=4.67.1",
12+
"transformation-buffer",
13+
"video-io",
14+
"depth-pro>=0.1.0",
15+
]
16+
17+
[tool.uv.sources]
18+
depth-pro = { git = "https://github.com/child-lab-uj/depth-pro.git" }
19+
transformation-buffer = { workspace = true }
20+
video-io = { workspace = true }
21+
22+
[build-system]
23+
requires = ["hatchling"]
24+
build-backend = "hatchling.build"
25+
26+
[project.scripts]
27+
child-lab = "child_lab_cli:cli"
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
import click
2+
3+
from .commands import (
4+
calibrate,
5+
estimate_transformations,
6+
generate_pointcloud,
7+
process,
8+
visualize,
9+
)
10+
11+
12+
@click.group('child-lab')
13+
def cli() -> None: ...
14+
15+
16+
cli.add_command(calibrate)
17+
cli.add_command(estimate_transformations)
18+
cli.add_command(generate_pointcloud)
19+
cli.add_command(process)
20+
cli.add_command(visualize)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
from .calibrate import calibrate
2+
from .estimate_transformations import estimate_transformations
3+
from .generate_pointcloud import generate_pointcloud
4+
from .process import process
5+
from .visualize import visualize
6+
7+
__all__ = [
8+
'calibrate',
9+
'estimate_transformations',
10+
'generate_pointcloud',
11+
'process',
12+
'visualize',
13+
]
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
from pathlib import Path
2+
3+
import click
4+
5+
6+
@click.command('calibrate')
7+
@click.argument('workspace', type=Path)
8+
@click.argument('videos', type=Path, nargs=-1)
9+
@click.option('--square-size', type=float, help='Board square size in centimeters')
10+
@click.option(
11+
'--inner-board-corners',
12+
nargs=2,
13+
type=int,
14+
help="Number of chessboard's inner corners in rows and columns that calibration algorithm should locate",
15+
)
16+
@click.option(
17+
'--max-samples',
18+
type=int,
19+
required=False,
20+
help='Maximal number of board samples to collect',
21+
)
22+
@click.option(
23+
'--max-speed',
24+
type=float,
25+
required=False,
26+
help='Maximal speed the board can move with to be captured, in pixels per second',
27+
)
28+
@click.option(
29+
'--min-distance',
30+
type=float,
31+
required=False,
32+
help='Minimal distance between new observation and the previous observations to be captured',
33+
)
34+
def calibrate(
35+
workspace: Path,
36+
videos: list[Path],
37+
square_size: float,
38+
inner_board_corners: tuple[int, int],
39+
max_samples: int | None,
40+
max_speed: float | None,
41+
min_distance: float | None,
42+
) -> None:
43+
pass
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
from pathlib import Path
2+
3+
import click
4+
5+
6+
@click.command('estimate-transformations')
7+
@click.argument('workspace', type=Path)
8+
@click.argument('videos', type=Path, nargs=-1)
9+
@click.option('--marker-dictionary', type=str, help='Dictionary to detect markers from')
10+
@click.option('--marker-size', type=float, help='Marker size in centimeters')
11+
@click.option(
12+
'--device',
13+
type=str,
14+
required=False,
15+
help='Torch device to use for tensor computations',
16+
)
17+
@click.option(
18+
'--checkpoint',
19+
type=Path,
20+
required=False,
21+
help='File containing serialized Buffer to load and place new transformations in',
22+
)
23+
@click.option(
24+
'--skip',
25+
type=int,
26+
required=False,
27+
help='Seconds of videos to skip at the beginning',
28+
)
29+
def estimate_transformations(
30+
workspace: Path,
31+
videos: list[Path],
32+
marker_dictionary: str,
33+
marker_size: float,
34+
device: str | None,
35+
checkpoint: Path | None,
36+
skip: int | None,
37+
) -> None:
38+
pass
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,120 @@
1+
from collections.abc import Generator
2+
from pathlib import Path
3+
4+
import click
5+
import torch
6+
import tqdm
7+
from depth_pro import Config, DepthPro
8+
from jaxtyping import UInt8
9+
from serde.yaml import from_yaml
10+
from torchvision.transforms import Compose, ConvertImageDtype, Normalize, Resize
11+
from video_io.calibration import Calibration
12+
from video_io.reader import Reader
13+
14+
from ..workspace import Workspace
15+
16+
17+
@click.command('generate-pointcloud')
18+
@click.argument('workspace_root', type=Path)
19+
@click.argument('video_name', type=str)
20+
@click.option(
21+
'--checkpoint',
22+
type=Path,
23+
required=True,
24+
help='DepthPro checkpoint to build the depth estimator from',
25+
)
26+
@click.option(
27+
'--batch-size',
28+
type=int,
29+
default=16,
30+
required=False,
31+
help='Number of frames to process as a single batch',
32+
)
33+
@click.option(
34+
'--device',
35+
type=str,
36+
required=False,
37+
help='Torch device to use for tensor computations',
38+
)
39+
def generate_pointcloud(
40+
workspace_root: Path,
41+
video_name: str,
42+
checkpoint: Path,
43+
batch_size: int,
44+
device: str | None,
45+
) -> None:
46+
workspace = Workspace.in_directory(workspace_root)
47+
output = workspace.output / 'points'
48+
49+
for video in workspace.calibrated_videos():
50+
if video.name == video_name:
51+
break
52+
else:
53+
raise FileNotFoundError(f'Video {video_name} not found in {workspace.input}.')
54+
55+
calibration = from_yaml(Calibration, video.calibration.read_text())
56+
57+
main_device = torch.device(device or 'cpu')
58+
59+
reader = Reader(video.location, torch.device('cpu'))
60+
depth_estimator = DepthEstimator(checkpoint, main_device)
61+
62+
click.echo('Depth estimator created!')
63+
64+
def batched_frames(
65+
reader: Reader,
66+
batch_size: int,
67+
) -> Generator[torch.Tensor, None, None]:
68+
while (frames := reader.read_batch(batch_size)) is not None:
69+
yield frames
70+
71+
progress_bar = tqdm.tqdm(
72+
range(0, reader.metadata.frames, batch_size),
73+
desc='Processing batches of frames',
74+
)
75+
76+
for i, frames in enumerate(batched_frames(reader, batch_size)):
77+
depths = depth_estimator.predict(frames.to(main_device))
78+
perspective_points = calibration.unproject_depth(depths)
79+
80+
torch.save(perspective_points, output / f'points_{i}.pt')
81+
82+
progress_bar.update()
83+
84+
click.echo('Done!')
85+
86+
87+
# TODO: Delete this from here
88+
class DepthEstimator:
89+
device: torch.device
90+
model: DepthPro
91+
model_config: Config
92+
to_model: Compose
93+
94+
def __init__(self, checkpoint: Path, device: torch.device) -> None:
95+
self.device = device
96+
97+
config = Config(checkpoint=checkpoint)
98+
self.model_config = config
99+
self.model = DepthPro(config, device, torch.half)
100+
101+
self.to_model = Compose( # type: ignore[no-untyped-call]
102+
[
103+
ConvertImageDtype(torch.half),
104+
Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]), # type: ignore[no-untyped-call]
105+
]
106+
)
107+
108+
def predict(
109+
self,
110+
frame_batch: UInt8[torch.Tensor, 'batch 3 height width'],
111+
) -> torch.Tensor:
112+
*_, height, width = frame_batch.shape
113+
114+
result = self.model.predict(self.to_model(frame_batch))
115+
116+
return (
117+
Resize((height, width)) # type: ignore[no-untyped-call]
118+
.forward(result.depth)
119+
.to(torch.float32)
120+
)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
from pathlib import Path
2+
3+
import click
4+
5+
6+
@click.command('process')
7+
@click.argument('workspace', type=Path)
8+
@click.argument('videos', type=Path, nargs=-1)
9+
@click.option(
10+
'--device',
11+
type=str,
12+
required=False,
13+
help='Torch device to use for tensor computations',
14+
)
15+
@click.option(
16+
'--skip',
17+
type=int,
18+
required=False,
19+
help='Seconds of videos to skip at the beginning',
20+
)
21+
@click.option(
22+
'--dynamic-transformations',
23+
type=bool,
24+
is_flag=True,
25+
default=False,
26+
help='Compute camera transformations on the fly, using heuristic algorithms',
27+
)
28+
# @click_trap()
29+
def process(
30+
workspace: Path,
31+
videos: list[Path],
32+
device: str | None,
33+
skip: int | None,
34+
dynamic_transformations: bool,
35+
) -> None:
36+
pass
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
from pathlib import Path
2+
3+
import click
4+
5+
from ..workspace import Workspace
6+
7+
# 1. Load a workspace
8+
# 3. Load a single video
9+
# 2. Estimate or load the depth
10+
# 3. Launch a `viser` server
11+
# 4. Draw the point clouds
12+
# 5. Load the transformation buffer
13+
# 6. Draw those frames of reference which are reachable
14+
# from the video's frame of reference
15+
16+
17+
@click.command('visualize')
18+
@click.argument('workspace_root', type=Path)
19+
@click.argument('video_name', type=str)
20+
@click.option(
21+
'--device',
22+
type=str,
23+
required=False,
24+
help='Torch device to use for tensor computations',
25+
)
26+
def visualize(
27+
workspace_root: Path,
28+
video_name: str,
29+
device: str | None,
30+
) -> None:
31+
workspace = Workspace.in_directory(workspace_root)
32+
33+
for video in workspace.calibrated_videos():
34+
if video.name == video_name:
35+
break
36+
else:
37+
raise FileNotFoundError(f'Video {video_name} not found in {workspace.input}.')
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
from .model import CalibratedVideo, NonCalibratedVideo, Workspace, WorkspaceModelError
2+
3+
__all__ = ['CalibratedVideo', 'NonCalibratedVideo', 'Workspace', 'WorkspaceModelError']

0 commit comments

Comments
 (0)