
Commit 16715d4

committed Dec 14, 2024
[Add] data pipeline w/ llm rephrase & geocalib support
1 parent b10a202 commit 16715d4

File tree: 9 files changed, +909 −20 lines
 

‎dataset_pipeline/README.md

+51 −10

@@ -5,14 +5,11 @@
 #### Environment
 
 ```sh
-conda create -n osd_pipeline anaconda python=3.10
+conda create -n osd_pipeline python=3.10 -y
 conda activate osd_pipeline
 
 ##### Install Pytorch according to your own setup #####
-# For example, if you have a GPU with CUDA 12.1
 pip install torch==2.2.2 torchvision==0.17.2 torchaudio==2.2.2 --index-url https://download.pytorch.org/whl/cu121
-# This is optional if you prefer to system built-in nvcc.
-conda install -c nvidia cuda-toolkit -y
 
 # We use mmengine for config management
 pip install -U openmim
@@ -25,19 +22,35 @@ pip install https://github.com/zju3dv/Wis3D/releases/download/2.0.0/wis3d-2.0.0-
 pip install 'git+https://github.com/facebookresearch/detectron2.git'
 
 # Some other libraries
-pip install iopath pyequilib==0.3.0 albumentations einops
+pip install iopath pyequilib==0.3.0 albumentations einops open3d imageio
 pip install mmcv==2.0.0 -f https://download.openmmlab.com/mmcv/dist/cu116/torch1.13/index.html
-
 ```
 
 #### Install Grounded-SAM package
 
-```
-mkdir external && cd osdsynth/external
+```sh
+mkdir osdsynth/external && cd osdsynth/external
 git clone https://github.com/IDEA-Research/Grounded-Segment-Anything.git
 ```
 
-Follow the instructions on the original [repo](https://github.com/IDEA-Research/Grounded-Segment-Anything#install-without-docker). Our pipeline has been tested with the codebase at this [commit](https://github.com/open-mmlab/mmengine/commit/85c83ba61689907fb1775713622b1b146d82277b). Grounded-SAM codebase at later commits may require some adaptations. If you encounter problems installing the RAM package, try upgrade your `setuptools` version to the latest version.
+Follow the instructions on the original [repo](https://github.com/IDEA-Research/Grounded-Segment-Anything#install-without-docker) to build Segment Anything, Grounding DINO, and RAM, respectively. Our pipeline has been tested with the codebase at this [commit](https://github.com/IDEA-Research/Grounded-Segment-Anything/tree/126abe633ffe333e16e4a0a4e946bc1003caf757).
+
+```sh
+cd Grounded-Segment-Anything/
+
+# Install Segment Anything
+python -m pip install -e segment_anything
+
+# Install Grounding DINO
+pip install --no-build-isolation -e GroundingDINO
+
+# Install RAM
+git clone https://github.com/xinyu1205/recognize-anything.git
+pip install -r ./recognize-anything/requirements.txt
+pip install setuptools --upgrade
+pip install -e ./recognize-anything/
+```
+
 
 #### Install Perspective Fields package
 
@@ -56,10 +69,11 @@ sh ./scripts/download_all_weights.sh
 
 ### Inference
 
+#### Template-based QA
 To specify the folder containing the images for testing, use the `--input` argument. You can also adjust the settings in `configs/v2.py` to better suit your images, like modifying the SAM thresholds or tweaking the DBSCAN hyperparameters.
 
 ```sh
-python run.py --config configs/v2.py --input PATH_TO_INPUT --vis
+python run_template_qa.py --config configs/v2.py --input PATH_TO_INPUT --vis True
 ```
 
 The results are saved in two formats. One is in JSON, where the Open3D bounding boxes are serialized. If you'd like to recreate the Open3D bounding box object for each detection, you can use the following code:
@@ -73,6 +87,33 @@ bbox = o3d.geometry.AxisAlignedBoundingBox(
 
 The other format is compatible with Wis3D point clouds. You can use the instructions below to visualize these results.
 
+
+#### LLM-rephrased QA
+
+**Step 1:** Generate template-based descriptions with the following command; this saves a `llm_prompts.json` file in the output JSON folder.
+
+```sh
+python run_template_facts.py --config configs/v2.py --input PATH_TO_INPUT --vis True
+```
+
+**Step 2:** Prepare a clean environment and install sglang.
+```sh
+conda create -n sglang python=3.10 -y
+conda activate sglang
+
+pip install --upgrade pip
+pip install "sglang[all]"
+
+pip install flashinfer -i https://flashinfer.ai/whl/cu121/torch2.4/
+```
+
+**Step 3:** Run the LLM rephrasing; the script currently uses Llama-3.1-70B.
+```sh
+export HF_TOKEN=<key>
+python run_llm.py --llm-prompts-path /PATH/SAMPLE_llm_prompts.json --port 3000 --gpus 8
+
+```
+
 ### Wis3D Visualization
 
 ```sh

‎dataset_pipeline/configs/v2.py

+2

@@ -94,3 +94,5 @@
 wid3d_interval = 1
 
 use_clip = False
+
+perspective_model_variant = "perspective_fields"
New file: +98 lines

@@ -0,0 +1,98 @@
# Class related params
class_set = "ram"
add_bg_classes = False
accumu_classes = False
exp_suffix = None
rm_bg_classes = True

add_classes = []
remove_classes = [
    "room",
    "kitchen",
    "office",
    "house",
    "home",
    "building",
    "corner",
    "shadow",
    "carpet",
    "photo",
    "sea",
    "shade",
    "stall",
    "space",
    "aquarium",
    "apartment",
    "image",
    "city",
    "blue",
    "skylight",
    "hallway",
    "bureau",
    "modern",
    "salon",
    "doorway",
    "wall lamp",
    "scene",
    "sun",
    "sky",
    "smile",
    "cloudy",
    "comfort",
    "white",
    "black",
    "red",
    "green",
    "blue",
    "yellow",
    "purple",
    "pink",
    "stand",
    "wear",
    "area",
    "shine",
    "lay",
    "walk",
    "lead",
    "bite",
    "sing",
]
bg_classes = ["wall", "floor", "ceiling"]

# Sam related params
sam_variant = "sam-hq"

# Tag2text related params
specified_tags = "None"

# Grounding Dino related params
box_threshold = 0.25
text_threshold = 0.2
nms_threshold = 0.5

# LLaVa related params
masking_option = "none"

# Selection criteria on the 2D masks
mask_area_threshold = 25  # masks with a pixel area smaller than this will be skipped
mask_conf_threshold = 0.3  # masks with a lower confidence score will be skipped (default 0.2)
max_bbox_area_ratio = 0.75  # boxes with a larger area ratio than this will be skipped
skip_bg = False
min_points_threshold = 16  # projected and sampled point clouds with fewer points will be skipped
min_points_threshold_after_denoise = 10

# Point cloud processing
downsample_voxel_size = 0.025
dbscan_remove_noise = True
dbscan_eps = 0.2  # v1 used 0.2
dbscan_min_points = 10

# Bounding-box related
spatial_sim_type = "overlap"  # "iou", "giou", "overlap"

save_interval = 1
wid3d_interval = 1

use_clip = False

perspective_model_variant = "geo_calib"
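
Both configs are consumed through mmengine; a minimal usage sketch (not part of the commit) of how the run scripts read these values, assuming mmengine is installed and the configs above are on disk:

```python
# Minimal sketch: Config.fromfile exposes every top-level assignment as an attribute,
# which is how the pipeline scripts pick up e.g. the perspective model variant.
from mmengine import Config

cfg = Config.fromfile("configs/v2.py")
print(cfg.use_clip)                   # False
print(cfg.perspective_model_variant)  # "perspective_fields" here, "geo_calib" in the new config above
```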
New file: +308 lines

@@ -0,0 +1,308 @@
import random
import numpy as np
from itertools import combinations
import json

from osdsynth.processor.instruction_template import *
from osdsynth.processor.prompt_utils import *
from osdsynth.processor.pointcloud import human_like_distance, calculate_distances_between_point_clouds


def left_predicate(A, B):
    true_responses = left_true_responses
    false_responses = left_false_responses

    A_desc, A_cloud = A["caption"], A["pcd"]
    B_desc, B_cloud = B["caption"], B["pcd"]
    A_desc, B_desc = A_desc.lower(), B_desc.lower()

    A_pos = A_cloud.get_center()
    B_pos = B_cloud.get_center()

    is_left = A_pos[0] > B_pos[0]  # Compare X coordinates

    response_template = random.choice(true_responses if is_left else false_responses)
    answer = response_template.replace("[A]", A_desc).replace("[B]", B_desc)

    return answer


def below_predicate(A, B):
    true_responses = below_true_responses
    false_responses = below_false_responses

    A_desc, A_cloud = A["caption"], A["pcd"]
    B_desc, B_cloud = B["caption"], B["pcd"]
    A_desc, B_desc = A_desc.lower(), B_desc.lower()

    A_pos = A_cloud.get_center()
    B_pos = B_cloud.get_center()

    is_below = A_pos[1] < B_pos[1]

    response_template = random.choice(true_responses if is_below else false_responses)

    answer = response_template.replace("[A]", A_desc).replace("[B]", B_desc)

    return answer


def short_predicate(A, B):
    true_responses = short_true_responses
    false_responses = short_false_responses

    A_desc, A_cloud = A["caption"], A["pcd"]
    B_desc, B_cloud = B["caption"], B["pcd"]
    A_desc, B_desc = A_desc.lower(), B_desc.lower()

    height_A = A_cloud.get_axis_aligned_bounding_box().get_extent()[1]
    height_B = B_cloud.get_axis_aligned_bounding_box().get_extent()[1]

    is_shorter = height_A < height_B

    response_template = random.choice(true_responses if is_shorter else false_responses)

    answer = response_template.replace("[A]", A_desc).replace("[B]", B_desc)

    return answer


def thin_predicate(A, B):
    true_responses = thin_true_responses
    false_responses = thin_false_responses

    A_desc, A_cloud = A["caption"], A["pcd"]
    B_desc, B_cloud = B["caption"], B["pcd"]
    A_desc, B_desc = A_desc.lower(), B_desc.lower()

    width_A = A_cloud.get_axis_aligned_bounding_box().get_extent()[0]
    width_B = B_cloud.get_axis_aligned_bounding_box().get_extent()[0]

    is_thinner = width_A < width_B

    response_template = random.choice(true_responses if is_thinner else false_responses)

    answer = response_template.replace("[A]", A_desc).replace("[B]", B_desc)

    return answer


def small_predicate(A, B):
    true_responses = small_true_responses
    false_responses = small_false_responses

    A_desc, A_cloud = A["caption"], A["pcd"]
    B_desc, B_cloud = B["caption"], B["pcd"]
    A_desc, B_desc = A_desc.lower(), B_desc.lower()

    extent_A = A_cloud.get_axis_aligned_bounding_box().get_extent()
    volume_A = extent_A[0] * extent_A[1] * extent_A[2]

    extent_B = B_cloud.get_axis_aligned_bounding_box().get_extent()
    volume_B = extent_B[0] * extent_B[1] * extent_B[2]

    is_smaller = volume_A < volume_B

    response_template = random.choice(true_responses if is_smaller else false_responses)

    answer = response_template.replace("[A]", A_desc).replace("[B]", B_desc)

    return answer


def front_predicate(A, B):
    true_responses = front_true
    false_responses = front_false

    A_desc, A_cloud = A["caption"], A["pcd"]
    B_desc, B_cloud = B["caption"], B["pcd"]
    A_desc, B_desc = A_desc.lower(), B_desc.lower()

    # Calculate the minimum z-value for both A and B
    A_min_z = A_cloud.get_min_bound()[2]
    B_min_z = B_cloud.get_min_bound()[2]
    # Determine if A is in front of B based on the minimum z-value
    is_in_front = A_min_z < B_min_z

    response_template = random.choice(true_responses if is_in_front else false_responses)

    answer = response_template.replace("[A]", A_desc).replace("[B]", B_desc)

    return answer


# Distance prompts


def generate_spatial_reasoning_data(A, B, human_readable_dist, template_answers):
    A_desc, B_desc = A["caption"].lower(), B["caption"].lower()

    answer_template = random.choice(template_answers)

    # Replace placeholders with actual values
    answer = answer_template.replace("[A]", A_desc).replace("[B]", B_desc).replace("[X]", human_readable_dist)

    # Add to the dataset
    return answer


def vertical_distance_data(A, B, use_center=True):
    template_answers = vertical_distance_answers

    # Get the bounding boxes for both A and B
    A_box = A["pcd"].get_axis_aligned_bounding_box()
    B_box = B["pcd"].get_axis_aligned_bounding_box()

    if use_center:
        A_center = A_box.get_axis_aligned_bounding_box().get_center()
        B_center = B_box.get_axis_aligned_bounding_box().get_center()
        vertical_distance = abs(A_center[1] - B_center[1])
    else:
        # Determine the highest and lowest points (in terms of y-value) of each object
        A_min_y, A_max_y = A_box.get_min_bound()[1], A_box.get_max_bound()[1]
        B_min_y, B_max_y = B_box.get_min_bound()[1], B_box.get_max_bound()[1]

        # Make A the higher object; swap if it is the other way around
        if A_min_y < B_min_y:
            # This means B is above A, swap the values
            A_min_y, A_max_y, B_min_y, B_max_y = B_min_y, B_max_y, A_min_y, A_max_y

        # The vertical distance is the gap between the lowest point of the higher object (A_min_y)
        # and the highest point of the lower object (B_max_y); zero if the extents overlap.
        vertical_distance = A_min_y - B_max_y if A_min_y > B_max_y else 0

    human_readable_dist = human_like_distance(vertical_distance)

    return generate_spatial_reasoning_data(A, B, human_readable_dist, template_answers)


def distance(A, B):
    template_answers = distance_template_answers
    distance = calculate_distances_between_point_clouds(A["pcd"], B["pcd"])
    return generate_spatial_reasoning_data(
        A,
        B,
        distance,
        template_answers,
    )


def horizontal_distance_data(A, B, use_center=True):
    template_answers = horizontal_distance_answers

    # Extract bounding boxes for A and B
    A_box = A["pcd"].get_axis_aligned_bounding_box()
    B_box = B["pcd"].get_axis_aligned_bounding_box()

    if use_center:
        A_center = A_box.get_center()
        B_center = B_box.get_center()
        horizontal_distance = np.sqrt((A_center[0] - B_center[0]) ** 2)
    else:
        # Extract min and max bounds for A and B on the x axis
        A_min, A_max = A_box.get_min_bound(), A_box.get_max_bound()
        B_min, B_max = B_box.get_min_bound(), B_box.get_max_bound()

        # Calculate the shortest horizontal (x-axis) distance between the two boxes
        horizontal_distance = max(A_min[0] - B_max[0], B_min[0] - A_max[0], 0)

    human_readable_dist = human_like_distance(horizontal_distance)

    return generate_spatial_reasoning_data(A, B, human_readable_dist, template_answers)


def width_data(A, B=None):
    A_desc = A["caption"].lower()

    template_answers = width_answers

    width = A["pcd"].get_axis_aligned_bounding_box().get_extent()[0]

    human_readable_width = human_like_distance(width)
    answer_template = random.choice(template_answers)

    answer = answer_template.replace("[A]", A_desc).replace("[X]", human_readable_width)

    return answer


def height_data(A, B=None):
    A_desc = A["caption"].lower()

    template_answers = height_answers

    height = A["pcd"].get_axis_aligned_bounding_box().get_extent()[1]

    human_readable_height = human_like_distance(height)
    answer_template = random.choice(template_answers)

    answer = answer_template.replace("[A]", A_desc).replace("[X]", human_readable_height)

    return answer


def direction(A, B):
    template_responses = direction_responses

    A_desc, A_cloud = A["caption"], A["pcd"]
    B_desc, B_cloud = B["caption"], B["pcd"]
    A_desc, B_desc = A_desc.lower(), B_desc.lower()

    A_pos = (A_cloud.get_center()[0], A_cloud.get_center()[2])  # Only x, z
    B_pos = (B_cloud.get_center()[0], B_cloud.get_center()[2])  # Only x, z

    clock_position = calculate_angle_clockwise(A_pos, B_pos)

    answer_template = random.choice(template_responses)

    answer = answer_template.replace("[X]", str(int(clock_position))).replace("[A]", A_desc).replace("[B]", B_desc)

    return answer


class PromptGenerator:
    def __init__(self, cfg, logger, device):
        """Initialize the class."""
        self.cfg = cfg
        self.logger = logger
        self.device = device
        self.vis = True

    def evaluate_predicates_on_pairs(self, detections):

        all_combinations = list(combinations(range(len(detections)), 2))
        random.shuffle(all_combinations)
        selected_combinations = all_combinations[:3]
        object_pairs = [(detections[i], detections[j]) for i, j in selected_combinations]

        all_prompt_variants = [
            # direction,
            left_predicate,
            thin_predicate,
            small_predicate,
            front_predicate,
            below_predicate,
            short_predicate,
            vertical_distance_data,
            horizontal_distance_data,
            width_data,
            height_data,
            distance,
        ]

        results = []

        for A, B in object_pairs:

            to_remove = set()  # A set to hold items to remove

            # Remove all items in `to_remove` from `all_prompt_variants`, if present
            all_prompt_variants = [item for item in all_prompt_variants if item not in to_remove]

            # selected_predicates_choices = all_prompt_variants
            selected_predicates_choices = random.sample(all_prompt_variants, 3)

            for prompt_func in selected_predicates_choices:
                results.append(prompt_func(A, B))

        return results
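
A hypothetical usage sketch of the prompt generator above (not part of the commit), assuming the `osdsynth` package and Open3D are importable. The detection dicts only need a `caption` string and an Open3D point cloud under `pcd`; `cfg`, `logger`, and `device` are stored but not used by `evaluate_predicates_on_pairs`.

```python
# Hypothetical sketch: build two fake detections and sample template facts from them.
import numpy as np
import open3d as o3d

from osdsynth.processor.instruction import PromptGenerator


def fake_detection(caption, center):
    """Return a detection dict with a small random point blob around `center`."""
    pcd = o3d.geometry.PointCloud()
    points = np.random.rand(100, 3) * 0.3 + np.asarray(center)
    pcd.points = o3d.utility.Vector3dVector(points)
    return {"caption": caption, "pcd": pcd}


detections = [
    fake_detection("<region0>", (0.0, 0.0, 2.0)),
    fake_detection("<region1>", (0.8, 0.4, 3.0)),
]

prompter = PromptGenerator(cfg=None, logger=None, device="cpu")
facts = prompter.evaluate_predicates_on_pairs(detections)  # three randomly chosen facts per pair
for fact in facts:
    print(fact)
```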
New file: +115 lines

@@ -0,0 +1,115 @@
direction_responses = [
    "[B] is roughly at [X] o'clock from [A].",
    "From [A], [B] is around the [X] o'clock direction.",
]

height_answers = [
    "The height of [A] is [X].",
    "[A] is [X] tall.",
    "[A] is [X] in height.",
]

width_answers = [
    "The width of [A] is [X].",
    "[A] is [X] wide.",
    "[A] is [X] in width.",
]

horizontal_distance_answers = [
    "[A] and [B] are [X] apart horizontally.",
    "[A] is [X] away from [B] horizontally.",
    "A horizontal distance of [X] exists between [A] and [B].",
    "[A] is [X] from [B] horizontally.",
    "Horizontally, [A] and [B] are [X] apart.",
    "[A] and [B] are [X] apart horizontally from each other.",
    "The horizontal distance of [A] from [B] is [X].",
]

vertical_distance_answers = [
    "[A] and [B] are [X] apart vertically.",
    "[A] is [X] away from [B] vertically.",
    "A vertical distance of [X] exists between [A] and [B].",
    "[A] is [X] from [B] vertically.",
    "[A] and [B] are [X] apart vertically from each other.",
    "Vertically, [A] and [B] are [X] apart.",
    "The vertical distance of [A] from [B] is [X].",
]

front_true = [
    "[A] is closer to the viewer than [B].",
    "[A] is in front of [B].",
]

front_false = [
    "[A] is further from the viewer than [B].",
    "[A] is behind [B].",
]

small_true_responses = [
    "[A] is smaller than [B].",
    "[A] has a smaller size compared to [B].",
    "[A] occupies less space than [B].",
]

small_false_responses = [
    "[A] is bigger than [B].",
    "[A] has a larger size compared to [B].",
    "[A] is larger in size than [B].",
]

thin_true_responses = [
    "[A] is thinner than [B].",
    "[A] has a lesser width compared to [B].",
    "[A]'s width is less than [B]'s.",
]

thin_false_responses = [
    "[A] is wider than [B].",
    "[A]'s width surpasses [B]'s width.",
    "[A]'s width is larger than [B]'s.",
]

short_true_responses = [
    "[A] is shorter than [B].",
    "[A] has a lesser height compared to [B].",
    "[A] is not as tall as [B].",
]

short_false_responses = [
    "[A] is taller than [B].",
    "[A] has a greater height compared to [B].",
    "[A] is much taller than [B].",
]

below_true_responses = [
    "[A] is below [B].",
    "[A] is positioned under [B].",
    "[A] is located below [B].",
]

below_false_responses = [
    "[A] is above [B].",
    "[A] is positioned over [B].",
    "[A] is located above [B].",
]

left_true_responses = [
    "[A] is to the left of [B].",
    "[A] is positioned on the left side of [B].",
    "You'll find [A] to the left of [B].",
]

left_false_responses = [
    "[A] is to the right of [B].",
    "[A] is positioned on the right side of [B].",
    "You'll find [A] to the right of [B].",
]

distance_template_answers = [
    "[A] and [B] are [X] apart.",
    "[A] is [X] away from [B].",
    "A distance of [X] exists between [A] and [B].",
    "[A] is [X] from [B].",
    "[A] and [B] are [X] apart from each other.",
    "The distance of [A] from [B] is [X].",
]

‎dataset_pipeline/osdsynth/processor/pointcloud.py

+30 −10

@@ -29,7 +29,16 @@ def __init__(self, cfg, logger, device, init_models=True):
 
         if init_models:
             # Initialize the perspective_fields_model
-            self.perspective_fields_model = get_perspective_fields_model(cfg, device)
+            if self.cfg.perspective_model_variant == "perspective_fields":
+                print("Using Perspective Fields")
+                self.perspective_fields_model = get_perspective_fields_model(cfg, device)
+            elif self.cfg.perspective_model_variant == "geo_calib":
+                from geocalib import GeoCalib
+
+                print("Using GeoCalib")
+                self.perspective_fields_model = GeoCalib(weights="distorted").to(device)
+            else:
+                raise ValueError(f"perspective_model_variant: {self.cfg.perspective_model_variant} not implemented")
 
             # Initialize the Camera Intrinsics Model
             self.wilde_camera_model = torch.hub.load("ShngJZ/WildCamera", "WildCamera", pretrained=True).to(device)
@@ -45,16 +54,27 @@ def process(self, filename, image_bgr, detections_list):
         image_rgb = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB)
         image_rgb_pil = Image.fromarray(image_rgb)
 
-        # Run Perspective Fields, this returns the pitch, roll
-        (
-            vis_perspective_fields,
-            perspective_fields,
-        ) = run_perspective_fields_model(self.perspective_fields_model, image_bgr)
+        if self.cfg.perspective_model_variant == "perspective_fields":
+            # Run Perspective Fields, which returns the pitch and roll
+            (
+                vis_perspective_fields,
+                perspective_fields,
+            ) = run_perspective_fields_model(self.perspective_fields_model, image_bgr)
+            roll, pitch = perspective_fields["roll"], perspective_fields["pitch"]
+
+        elif self.cfg.perspective_model_variant == "geo_calib":
+            from geocalib.utils import rad2deg
+
+            # Load the image as a tensor in range [0, 1] with shape [C, H, W]
+            image_geo = torch.tensor((image_rgb.transpose((2, 0, 1))) / 255.0, dtype=torch.float).to(self.device)
+            geo_results = self.perspective_fields_model.calibrate(image_geo, camera_model="simple_radial")
+            roll, pitch = rad2deg(geo_results["gravity"].rp).unbind(-1)
+            roll, pitch = roll.item(), pitch.item()
 
-        # Perspective Fields to Rotation Matrix
+        # Perspective to Rotation Matrix
         perspective_R = create_rotation_matrix(
-            roll=perspective_fields["roll"],
-            pitch=perspective_fields["pitch"],
+            roll=roll,
+            pitch=pitch,
             yaw=0,
             degrees=True,
         )
@@ -79,7 +99,7 @@ def process(self, filename, image_bgr, detections_list):
 
         if self.vis:
             wis3d = Wis3D(self.cfg.wis3d_folder, filename)
-            # wis3d.add_point_cloud(vertices=pts3d.reshape((-1, 3)), colors=image_rgb.reshape(-1, 3), name="pts3d")
+            wis3d.add_point_cloud(vertices=pts3d.reshape((-1, 3)), colors=image_rgb.reshape(-1, 3), name="pts3d")
             wis3d.add_point_cloud(
                 vertices=cano_pts3d.reshape((-1, 3)), colors=image_rgb.reshape(-1, 3), name="cano_pts3d"
             )
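
For reference, a standalone sketch of the GeoCalib branch added above (illustrative only, assuming the `geocalib` package is installed); it reuses the calls from the diff but feeds a random placeholder tensor instead of a real frame:

```python
# Estimate roll/pitch in degrees from a single RGB tensor with GeoCalib.
import torch
from geocalib import GeoCalib
from geocalib.utils import rad2deg

device = "cuda" if torch.cuda.is_available() else "cpu"
model = GeoCalib(weights="distorted").to(device)

image = torch.rand(3, 480, 640, device=device)  # placeholder RGB image in [0, 1], shape [C, H, W]
results = model.calibrate(image, camera_model="simple_radial")
roll, pitch = rad2deg(results["gravity"].rp).unbind(-1)
print(f"roll={roll.item():.2f} deg, pitch={pitch.item():.2f} deg")
```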

‎dataset_pipeline/run_llm.py

+148

@@ -0,0 +1,148 @@
import argparse
import time
import warnings
import re
import json

from sglang import function, system, gen, set_default_backend, RuntimeEndpoint
from sglang.utils import (
    execute_shell_command,
    wait_for_server,
)

# Suppress all warnings
warnings.filterwarnings("ignore")


response_regex = r"\{" + r' "Question": "[\w\d\s<>?,.!]{1,512}",' + r' "Answer": "[\w\d\s<>?,.!]{1,512}"' + r"\}"


@function
def rephrase_qa(s, question_1):
    s += system(
        r"""
You are a helpful assistant tasked with generating spatial reasoning-based questions and answers from provided descriptions of scenes.
Always craft a question without directly revealing specific details from the description.
Always generate questions related to the description using <regionX>.
The description should always be used to answer and not leak into the question.
When mentioning the objects or regions, use <regionX> instead of the objects or regions.
Speak from the observer's perspective.
Always make sure all the description objects or regions are mentioned with <regionX> in the question.
Only mention each <regionX> once.

Here are several examples:

[Objects]: <region4> sofa, <region1> chair. [Description]: The path between the <region4> and <region1> is 1.5 meters.
"Question": You are a cleaning robot that is 1 meter wide. Now you are standing in a living room and see the image; you want to move from here to the door that leads to the backyard. Do you think I can go through the path between the <region4> and <region1>?
"Answer": The path between <region4> and <region1> is 1.5 meters, so yes, the robot can go through the path between <region4> and <region1> since it is wider than the robot's width.

[Objects]: <region2> apple, <region3> orange. [Description]: <region2> is positioned on the left side of <region3>.
"Question": You see two fruits, an apple in <region2> and an orange in <region3>. Which one is more on the left side?
"Answer": The apple in <region2> is more on the left.

[Objects]: <region3> desk, <region6> bed. [Description]: <region3> is further to the viewer than <region6>.
"Question": You are exploring a bedroom and walking towards <region3> and <region6>. Which one will you reach first?
"Answer": You will reach the bed first because it is closer to you than the desk, which is further away.

[Objects]: <region0> book. [Description]: <region0> is 50 cm in width.
"Question": You are a librarian currently standing in front of a 40 cm width bookshelf, and you see <region0> that you want to place on the shelf. Can you determine if <region0> will fit on the shelf?
"Answer": <region0> is 50 cm in width, so the shelf is not wide enough to hold a book of that size. Please find a larger shelf.

Now it's your turn!

"""
    )
    s += question_1
    s += gen("json_output", max_tokens=1024, regex=response_regex)


def process_prompt(prompt, rephrase_qa, max_retries=5):
    for attempt in range(max_retries):
        try:
            llama_response = rephrase_qa.run(prompt, temperature=0.2)
            response_string = llama_response["json_output"]

            # Clean and parse the response
            cleaned_string = response_string.strip()
            cleaned_string = "".join(char for char in cleaned_string if ord(char) >= 32 or char == "\n")
            cleaned_string = re.sub(r"\s+", " ", cleaned_string)
            cleaned_string = cleaned_string.replace("'", '"')
            json_response = json.loads(cleaned_string)

            question, answer = json_response["Question"], json_response["Answer"]

            # Clean up question/answer
            question = question[2:] if question and question[:2] == ". " else question
            answer = answer[2:] if answer and answer[:2] == ". " else answer

            # Validate region tags
            prompt_tags = {x for x in prompt.split() if x.startswith("<region") and x.endswith(">")}
            question_tags = {x for x in question.split() if x.startswith("<region") and x.endswith(">")}
            answer_tags = {x for x in answer.split() if x.startswith("<region") and x.endswith(">")}

            # Check if all validations pass
            if prompt_tags == question_tags and prompt_tags == answer_tags:
                if all(question.count(tag) == 1 for tag in prompt_tags):
                    print(f"Prompt: {prompt}")
                    print(f"Question: {question}")
                    print(f"Answer: {answer}")
                    print("---------------")
                    return True, question, answer
                else:
                    print(f"Attempt {attempt + 1}: skipping because a <regionX> tag appeared more than once in the question")
            else:
                print(f"Attempt {attempt + 1}: skipping because the <regionX> tags are mismatched between prompt and question/answer")

        except Exception as e:
            print(f"Attempt {attempt + 1} failed with error: {str(e)}")

    print(f"Failed to get valid output after {max_retries} attempts")
    return False, None, None


def main(args):
    """Main function to control the flow of the program."""

    # Launch the sglang backend
    server_process = execute_shell_command(
        f"python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3.1-70B-Instruct --port {args.port} --host 0.0.0.0 --tp {args.gpus}"
    )
    wait_for_server(f"http://localhost:{args.port}")
    set_default_backend(RuntimeEndpoint(f"http://localhost:{args.port}"))

    # Read the llm_prompts json
    with open(args.llm_prompts_path, "r") as f:
        llm_prompts = json.load(f)

    conversations = []
    for prompt in llm_prompts:
        success, question, answer = process_prompt(prompt, rephrase_qa)
        if success:
            conversations.append((question, answer))

    for sample in conversations:
        print(f"Q: {sample[0]}")
        print(f"A: {sample[1]}")
        print("-----------------------")


def parse_args():
    """Command-line argument parser."""
    parser = argparse.ArgumentParser(description="Rephrase template-based facts into QA pairs with an LLM.")
    parser.add_argument("--config", default="configs/v2.py", help="Annotation config file path.")
    parser.add_argument("--port", default=3000, help="Port for the sglang server.")
    parser.add_argument("--gpus", default=8, help="Number of GPUs (tensor-parallel size).")
    parser.add_argument(
        "--llm-prompts-path",
        default="./demo_out/20241125_175649/json/indoor_llm_prompts.json",
        help="Path to the llm_prompts JSON file.",
    )

    return parser.parse_args()


if __name__ == "__main__":
    args = parse_args()
    timestamp = time.strftime("%Y%m%d_%H%M%S", time.localtime())
    args.timestamp = timestamp
    main(args)
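
Illustrative only: `response_regex` above constrains the model to a single flat JSON object with `Question` and `Answer` fields, which `process_prompt` then parses and validates against the prompt's `<regionX>` tags; a hypothetical, regex-conformant response could look like this:

```python
# Hypothetical model output (made-up objects and region indices).
import json

raw = (
    '{ "Question": "You see a mug in <region1> and a kettle in <region5> nearby, '
    'which one is closer to you?",'
    ' "Answer": "The mug in <region1> is closer to the viewer than the kettle in <region5> is." }'
)
parsed = json.loads(raw)
print(parsed["Question"])
print(parsed["Answer"])
```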
New file: +157 lines

@@ -0,0 +1,157 @@
import argparse
import glob
import os
import random
import time
import warnings
import re
import json

import cv2
import numpy as np
from mmengine import Config
from osdsynth.processor.captions import CaptionImage
from osdsynth.processor.pointcloud import PointCloudReconstruction
from osdsynth.processor.instruction import PromptGenerator

# from osdsynth.processor.filter import FilterImage
from osdsynth.processor.segment import SegmentImage
from osdsynth.utils.logger import SkipImageException, save_detection_list_to_json, setup_logger
from tqdm import tqdm

# Suppress all warnings
warnings.filterwarnings("ignore")


def main(args):
    """Main function to control the flow of the program."""
    # Parse arguments
    cfg = Config.fromfile(args.config)
    exp_name = args.name if args.name else args.timestamp

    # Create log folder
    cfg.log_folder = os.path.join(args.log_dir, exp_name)
    os.makedirs(os.path.abspath(cfg.log_folder), exist_ok=True)

    # Create Wis3D folder
    cfg.vis = args.vis
    cfg.wis3d_folder = os.path.join(args.log_dir, "Wis3D")
    os.makedirs(os.path.abspath(cfg.wis3d_folder), exist_ok=True)

    # Init the logger and log some basic info
    cfg.log_file = os.path.join(cfg.log_folder, f"{exp_name}_{args.timestamp}.log")
    logger = setup_logger()  # cfg.log_file
    logger.info(f"Config:\n{cfg.pretty_text}")

    # Dump config to log
    cfg.dump(os.path.join(cfg.log_folder, os.path.basename(args.config)))

    # Create output folder
    cfg.exp_dir = os.path.join(args.output_dir, exp_name)
    os.makedirs(os.path.abspath(cfg.exp_dir), exist_ok=True)

    # Create folder for output json
    cfg.json_folder = os.path.join(cfg.exp_dir, "json")
    os.makedirs(os.path.abspath(cfg.json_folder), exist_ok=True)

    global_data = glob.glob(f"{args.input}/*.jpg") + glob.glob(f"{args.input}/*.png")
    device = "cuda"

    annotate(cfg, global_data, logger, device)


def annotate(cfg, global_data, logger, device):

    random.shuffle(global_data)

    segmenter = SegmentImage(cfg, logger, device)
    reconstructor = PointCloudReconstruction(cfg, logger, device)
    captioner = CaptionImage(cfg, logger, device)
    prompter = PromptGenerator(cfg, logger, device)

    for i, filepath in tqdm(enumerate(global_data), ncols=25):
        filename = filepath.split("/")[-1].split(".")[0]
        print(f"Processing file: {filename}")

        progress_file_path = os.path.join(cfg.log_folder, f"{filename}.progress")
        if os.path.exists(progress_file_path) and cfg.check_exist:
            continue

        image_bgr = cv2.imread(filepath)
        image_bgr = cv2.resize(image_bgr, (int(640 / (image_bgr.shape[0]) * (image_bgr.shape[1])), 640))

        try:

            # Run the tagging model and get open-world detections
            vis_som, detection_list = segmenter.process(image_bgr)

            # Lift 2D to 3D; 3D bbox information is added to detection_list
            detection_list = reconstructor.process(filename, image_bgr, detection_list)

            # Get a LLaVA local caption for each region; currently a <region> placeholder is used instead
            detection_list = captioner.process_local_caption(detection_list)

            # Save detection list to json
            detection_list_path = os.path.join(cfg.json_folder, f"{filename}.json")
            save_detection_list_to_json(detection_list, detection_list_path)

            # Generate instructions (facts) based on templates
            facts = prompter.evaluate_predicates_on_pairs(detection_list)

            batched_llm_prompts = prepare_llm_prompts(facts, detection_list)

            llm_prompts_path = os.path.join(cfg.json_folder, f"{filename}_llm_prompts.json")
            with open(llm_prompts_path, "w") as f:
                json.dump(batched_llm_prompts, f, indent=2)

            for llm_prompt in batched_llm_prompts:
                print(f"{llm_prompt}")
                print("-----------------------")

        except SkipImageException as e:
            # A skip-image condition was met
            logger.info(f"Skipping processing {filename}: {e}.")
            continue


def prepare_llm_prompts(facts, detection_list):
    region_to_tag_list = []
    batched_instructions = []
    for qa_idx, instruction in enumerate(facts):
        i_regions = re.findall(r"<region(\d+)>", instruction)
        region_to_tag = {int(region): detection_list[int(region)]["class_name"] for region in i_regions}
        region_to_tag_list.append(region_to_tag)

        object_reference = []
        for r_id, (region, tag) in enumerate(region_to_tag.items()):
            object_reference.append(f"<region{region}> {tag}")
        object_reference = ", ".join(object_reference)

        new_instruction = f"[Objects]: {object_reference}. [Description]: {instruction}"
        batched_instructions.append(new_instruction)

    return batched_instructions


def parse_args():
    """Command-line argument parser."""
    parser = argparse.ArgumentParser(description="Generate 3D SceneGraph for an image.")
    parser.add_argument("--config", default="configs/v2.py", help="Annotation config file path.")
    parser.add_argument(
        "--input",
        default="./demo_images",
        help="Path to the input: a JSON file or a folder of images.",
    )
    parser.add_argument("--output-dir", default="./demo_out", help="Path to save the scene-graph JSON files.")
    parser.add_argument("--name", required=False, default=None, help="Experiment name; otherwise the timestamp is used.")
    parser.add_argument("--log-dir", default="./demo_out/log", help="Path to save logs and visualization results.")
    parser.add_argument("--vis", required=False, default=True, help="Wis3D visualization for reconstructed point clouds.")
    parser.add_argument("--overwrite", required=False, action="store_true", help="Overwrite previous results.")
    return parser.parse_args()


if __name__ == "__main__":
    args = parse_args()
    timestamp = time.strftime("%Y%m%d_%H%M%S", time.localtime())
    args.timestamp = timestamp
    main(args)
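
For reference, `prepare_llm_prompts` above writes `{filename}_llm_prompts.json` as a flat JSON list of "[Objects]: ... [Description]: ..." strings, which `run_llm.py` loads and rephrases one by one. A hypothetical two-entry file (object names, regions, and measurements are made up) could be produced like this:

```python
# Hypothetical contents of a *_llm_prompts.json file, mirroring the format built above.
import json

example_llm_prompts = [
    "[Objects]: <region2> chair, <region7> table. "
    "[Description]: <region2> is positioned on the left side of <region7>.",
    "[Objects]: <region0> lamp. [Description]: The height of <region0> is 1.2 meters.",
]

with open("indoor_llm_prompts.json", "w") as f:
    json.dump(example_llm_prompts, f, indent=2)
```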
File renamed without changes.
