|
| 1 | +import random |
| 2 | +import numpy as np |
| 3 | +from itertools import combinations |
| 4 | +import json |
| 5 | + |
| 6 | +from osdsynth.processor.instruction_template import * |
| 7 | +from osdsynth.processor.prompt_utils import * |
| 8 | +from osdsynth.processor.pointcloud import human_like_distance, calculate_distances_between_point_clouds |
| 9 | + |
| 10 | + |
def left_predicate(A, B):
    """Build a sentence stating whether object A is to the left of object B.

    Args:
        A: Mapping with a "caption" string and a "pcd" object exposing
            ``get_center()``.
        B: Mapping with the same layout as ``A``.

    Returns:
        A template drawn at random from ``left_true_responses`` when the
        x-coordinate of A's center is greater than that of B's center, and
        from ``left_false_responses`` otherwise, with "[A]" and "[B]"
        replaced by the lower-cased captions.
    """
    name_a = A["caption"].lower()
    name_b = B["caption"].lower()

    x_a = A["pcd"].get_center()[0]
    x_b = B["pcd"].get_center()[0]

    # NOTE(review): "left" is decided by the larger x value; this presumes a
    # frame where x grows towards the left - confirm against the point clouds.
    if x_a > x_b:
        candidates = left_true_responses
    else:
        candidates = left_false_responses

    sentence = random.choice(candidates)
    return sentence.replace("[A]", name_a).replace("[B]", name_b)
| 28 | + |
| 29 | + |
def below_predicate(A, B):
    """Build a sentence stating whether object A is below object B.

    Args:
        A: Mapping with a "caption" string and a "pcd" object exposing
            ``get_center()``.
        B: Mapping with the same layout as ``A``.

    Returns:
        A template drawn at random from ``below_true_responses`` when the
        y-coordinate of A's center is smaller than that of B's center, and
        from ``below_false_responses`` otherwise, with "[A]" and "[B]"
        replaced by the lower-cased captions.
    """
    name_a = A["caption"].lower()
    name_b = B["caption"].lower()

    y_a = A["pcd"].get_center()[1]
    y_b = B["pcd"].get_center()[1]

    # Index 1 of the center is treated as the vertical coordinate.
    if y_a < y_b:
        candidates = below_true_responses
    else:
        candidates = below_false_responses

    sentence = random.choice(candidates)
    return sentence.replace("[A]", name_a).replace("[B]", name_b)
| 48 | + |
| 49 | + |
def short_predicate(A, B):
    """Build a sentence stating whether object A is shorter than object B.

    Height is taken as the y-extent (index 1) of each object's axis-aligned
    bounding box.

    Args:
        A: Mapping with a "caption" string and a "pcd" object exposing
            ``get_axis_aligned_bounding_box()``.
        B: Mapping with the same layout as ``A``.

    Returns:
        A template drawn at random from ``short_true_responses`` when A's
        height is strictly smaller than B's, and from
        ``short_false_responses`` otherwise, with "[A]" and "[B]" replaced
        by the lower-cased captions.
    """
    name_a = A["caption"].lower()
    name_b = B["caption"].lower()

    box_a = A["pcd"].get_axis_aligned_bounding_box()
    box_b = B["pcd"].get_axis_aligned_bounding_box()
    a_is_shorter = box_a.get_extent()[1] < box_b.get_extent()[1]

    candidates = short_true_responses if a_is_shorter else short_false_responses
    sentence = random.choice(candidates)
    return sentence.replace("[A]", name_a).replace("[B]", name_b)
| 68 | + |
| 69 | + |
def thin_predicate(A, B):
    """Build a sentence stating whether object A is thinner than object B.

    Width is taken as the x-extent (index 0) of each object's axis-aligned
    bounding box.

    Args:
        A: Mapping with a "caption" string and a "pcd" object exposing
            ``get_axis_aligned_bounding_box()``.
        B: Mapping with the same layout as ``A``.

    Returns:
        A template drawn at random from ``thin_true_responses`` when A's
        width is strictly smaller than B's, and from ``thin_false_responses``
        otherwise, with "[A]" and "[B]" replaced by the lower-cased captions.
    """
    name_a = A["caption"].lower()
    name_b = B["caption"].lower()

    box_a = A["pcd"].get_axis_aligned_bounding_box()
    box_b = B["pcd"].get_axis_aligned_bounding_box()
    a_is_thinner = box_a.get_extent()[0] < box_b.get_extent()[0]

    candidates = thin_true_responses if a_is_thinner else thin_false_responses
    sentence = random.choice(candidates)
    return sentence.replace("[A]", name_a).replace("[B]", name_b)
| 88 | + |
| 89 | + |
def small_predicate(A, B):
    """Build a sentence stating whether object A is smaller than object B.

    Size is the volume of each object's axis-aligned bounding box, computed
    as the product of its three extents.

    Args:
        A: Mapping with a "caption" string and a "pcd" object exposing
            ``get_axis_aligned_bounding_box()``.
        B: Mapping with the same layout as ``A``.

    Returns:
        A template drawn at random from ``small_true_responses`` when A's
        box volume is strictly smaller than B's, and from
        ``small_false_responses`` otherwise, with "[A]" and "[B]" replaced
        by the lower-cased captions.
    """
    name_a = A["caption"].lower()
    name_b = B["caption"].lower()

    size_a = A["pcd"].get_axis_aligned_bounding_box().get_extent()
    box_volume_a = size_a[0] * size_a[1] * size_a[2]

    size_b = B["pcd"].get_axis_aligned_bounding_box().get_extent()
    box_volume_b = size_b[0] * size_b[1] * size_b[2]

    if box_volume_a < box_volume_b:
        candidates = small_true_responses
    else:
        candidates = small_false_responses

    sentence = random.choice(candidates)
    return sentence.replace("[A]", name_a).replace("[B]", name_b)
| 111 | + |
| 112 | + |
def front_predicate(A, B):
    """Build a sentence stating whether object A is in front of object B.

    Args:
        A: Mapping with a "caption" string and a "pcd" object exposing
            ``get_min_bound()``.
        B: Mapping with the same layout as ``A``.

    Returns:
        A template drawn at random from ``front_true`` when the minimum
        z-value of A is strictly smaller than that of B, and from
        ``front_false`` otherwise, with "[A]" and "[B]" replaced by the
        lower-cased captions.
    """
    name_a = A["caption"].lower()
    name_b = B["caption"].lower()

    # Compare the smallest z-value of each cloud.
    # NOTE(review): presumes a smaller z means closer to the viewer - confirm.
    nearest_z_a = A["pcd"].get_min_bound()[2]
    nearest_z_b = B["pcd"].get_min_bound()[2]

    if nearest_z_a < nearest_z_b:
        candidates = front_true
    else:
        candidates = front_false

    sentence = random.choice(candidates)
    return sentence.replace("[A]", name_a).replace("[B]", name_b)
| 132 | + |
| 133 | + |
| 134 | +# Distance prompts |
| 135 | + |
| 136 | + |
def generate_spatial_reasoning_data(A, B, human_readable_dist, template_answers):
    """Fill a randomly chosen answer template for a pair of objects.

    Args:
        A: Mapping with a "caption" string for the first object.
        B: Mapping with a "caption" string for the second object.
        human_readable_dist: Text substituted for the "[X]" placeholder.
        template_answers: Non-empty sequence of template strings that may
            contain the "[A]", "[B]" and "[X]" placeholders.

    Returns:
        The selected template with "[A]" and "[B]" replaced by the
        lower-cased captions and "[X]" replaced by ``human_readable_dist``.
    """
    name_a = A["caption"].lower()
    name_b = B["caption"].lower()

    sentence = random.choice(template_answers)

    # Substitution order is fixed: [A] first, then [B], then [X].
    substitutions = (
        ("[A]", name_a),
        ("[B]", name_b),
        ("[X]", human_readable_dist),
    )
    for placeholder, value in substitutions:
        sentence = sentence.replace(placeholder, value)

    return sentence
| 147 | + |
| 148 | + |
def vertical_distance_data(A, B, use_center=True):
    """Describe the vertical (y-axis) separation between two objects.

    Args:
        A: Mapping with a "caption" string and a "pcd" object exposing
            ``get_axis_aligned_bounding_box()``.
        B: Mapping with the same layout as ``A``.
        use_center: When true, measure the absolute difference between the
            y-coordinates of the two bounding-box centers. When false,
            measure the empty gap between the boxes along y, which is 0
            when their y-ranges overlap or touch.

    Returns:
        A template drawn at random from ``vertical_distance_answers`` with
        "[A]", "[B]" and "[X]" filled in, where "[X]" is the result of
        ``human_like_distance`` applied to the measured distance.
    """
    box_a = A["pcd"].get_axis_aligned_bounding_box()
    box_b = B["pcd"].get_axis_aligned_bounding_box()

    if use_center:
        # The boxes are already axis-aligned, so their centers are read
        # directly, the same way horizontal_distance_data does it.
        vertical_distance = abs(box_a.get_center()[1] - box_b.get_center()[1])
    else:
        a_min_y, a_max_y = box_a.get_min_bound()[1], box_a.get_max_bound()[1]
        b_min_y, b_max_y = box_b.get_min_bound()[1], box_b.get_max_bound()[1]

        # At most one of the two gaps can be positive (whichever box sits
        # entirely above the other); overlapping ranges yield 0.
        vertical_distance = max(a_min_y - b_max_y, b_min_y - a_max_y, 0)

    human_readable_dist = human_like_distance(vertical_distance)

    return generate_spatial_reasoning_data(A, B, human_readable_dist, vertical_distance_answers)
| 177 | + |
| 178 | + |
def distance(A, B):
    """Describe the distance between the point clouds of two objects.

    Args:
        A: Mapping with a "caption" string and a "pcd" point cloud.
        B: Mapping with the same layout as ``A``.

    Returns:
        A template drawn at random from ``distance_template_answers`` with
        "[A]", "[B]" and "[X]" filled in, where "[X]" is whatever
        ``calculate_distances_between_point_clouds`` returns for the pair.
    """
    # NOTE(review): the value is substituted with str.replace downstream, so
    # the helper is presumably expected to return text - confirm.
    separation = calculate_distances_between_point_clouds(A["pcd"], B["pcd"])
    return generate_spatial_reasoning_data(A, B, separation, distance_template_answers)
| 188 | + |
| 189 | + |
def horizontal_distance_data(A, B, use_center=True):
    """Describe the horizontal (x-axis) separation between two objects.

    Only the x-coordinate is used; the z-coordinate does not take part in
    either measurement.

    Args:
        A: Mapping with a "caption" string and a "pcd" object exposing
            ``get_axis_aligned_bounding_box()``.
        B: Mapping with the same layout as ``A``.
        use_center: When true, measure the distance between the
            x-coordinates of the two bounding-box centers. When false,
            measure the empty gap between the boxes along x, which is 0
            when their x-ranges overlap or touch.

    Returns:
        A template drawn at random from ``horizontal_distance_answers`` with
        "[A]", "[B]" and "[X]" filled in, where "[X]" is the result of
        ``human_like_distance`` applied to the measured distance.
    """
    box_a = A["pcd"].get_axis_aligned_bounding_box()
    box_b = B["pcd"].get_axis_aligned_bounding_box()

    if not use_center:
        lo_a, hi_a = box_a.get_min_bound(), box_a.get_max_bound()
        lo_b, hi_b = box_b.get_min_bound(), box_b.get_max_bound()

        # Gap along x for each possible ordering of the boxes; 0 on overlap.
        gap_a_past_b = lo_a[0] - hi_b[0]
        gap_b_past_a = lo_b[0] - hi_a[0]
        horizontal_distance = max(gap_a_past_b, gap_b_past_a, 0)
    else:
        center_a = box_a.get_center()
        center_b = box_b.get_center()
        # Root of the squared x-difference, i.e. its magnitude.
        horizontal_distance = np.sqrt((center_a[0] - center_b[0]) ** 2)

    readable = human_like_distance(horizontal_distance)

    return generate_spatial_reasoning_data(A, B, readable, horizontal_distance_answers)
| 212 | + |
| 213 | + |
def width_data(A, B=None):
    """Describe the width of a single object.

    Width is the x-extent (index 0) of the object's axis-aligned bounding
    box.

    Args:
        A: Mapping with a "caption" string and a "pcd" object exposing
            ``get_axis_aligned_bounding_box()``.
        B: Unused; accepted so the function can be called with the same
            two-argument signature as the pairwise prompt functions.

    Returns:
        A template drawn at random from ``width_answers`` with "[A]"
        replaced by the lower-cased caption and "[X]" by the result of
        ``human_like_distance`` applied to the width.
    """
    name_a = A["caption"].lower()

    box_extent = A["pcd"].get_axis_aligned_bounding_box().get_extent()
    readable_width = human_like_distance(box_extent[0])

    sentence = random.choice(width_answers)
    return sentence.replace("[A]", name_a).replace("[X]", readable_width)
| 227 | + |
| 228 | + |
def height_data(A, B=None):
    """Describe the height of a single object.

    Height is the y-extent (index 1) of the object's axis-aligned bounding
    box.

    Args:
        A: Mapping with a "caption" string and a "pcd" object exposing
            ``get_axis_aligned_bounding_box()``.
        B: Unused; accepted so the function can be called with the same
            two-argument signature as the pairwise prompt functions.

    Returns:
        A template drawn at random from ``height_answers`` with "[A]"
        replaced by the lower-cased caption and "[X]" by the result of
        ``human_like_distance`` applied to the height.
    """
    name_a = A["caption"].lower()

    box_extent = A["pcd"].get_axis_aligned_bounding_box().get_extent()
    readable_height = human_like_distance(box_extent[1])

    sentence = random.choice(height_answers)
    return sentence.replace("[A]", name_a).replace("[X]", readable_height)
| 242 | + |
| 243 | + |
def direction(A, B):
    """Describe the direction between two objects as a clock position.

    Args:
        A: Mapping with a "caption" string and a "pcd" object exposing
            ``get_center()``.
        B: Mapping with the same layout as ``A``.

    Returns:
        A template drawn at random from ``direction_responses`` with "[X]"
        replaced by the integer-truncated result of
        ``calculate_angle_clockwise`` and "[A]" / "[B]" replaced by the
        lower-cased captions.
    """
    name_a = A["caption"].lower()
    name_b = B["caption"].lower()

    center_a = A["pcd"].get_center()
    center_b = B["pcd"].get_center()

    # Only the x and z components of each center are passed on; y is dropped.
    planar_a = (center_a[0], center_a[2])
    planar_b = (center_b[0], center_b[2])

    clock_value = str(int(calculate_angle_clockwise(planar_a, planar_b)))

    sentence = random.choice(direction_responses)
    return sentence.replace("[X]", clock_value).replace("[A]", name_a).replace("[B]", name_b)
| 261 | + |
| 262 | + |
class PromptGenerator:
    """Generate template-based spatial statements for pairs of detections."""

    def __init__(self, cfg, logger, device):
        """Store the configuration, logger and device handed in by the caller.

        Args:
            cfg: Configuration object; stored as ``self.cfg``.
            logger: Logger object; stored as ``self.logger``.
            device: Device identifier; stored as ``self.device``.
        """
        self.cfg = cfg
        self.logger = logger
        self.device = device
        self.vis = True

    def evaluate_predicates_on_pairs(self, detections, num_pairs=3, num_prompts=3):
        """Produce randomly selected statements for random pairs of detections.

        Args:
            detections: Sequence of mappings, each carrying the "caption" and
                "pcd" entries read by the prompt functions.
            num_pairs: Maximum number of unordered detection pairs to use.
                Fewer are used when fewer pairs exist.
            num_prompts: Number of distinct prompt functions applied to each
                pair, capped at the number of available prompt functions.

        Returns:
            A flat list of answer strings, ``num_prompts`` per selected pair.
            The list is empty when fewer than two detections are given.
        """
        pair_indices = list(combinations(range(len(detections)), 2))
        random.shuffle(pair_indices)
        object_pairs = [(detections[i], detections[j]) for i, j in pair_indices[:num_pairs]]

        prompt_variants = [
            # direction,  # currently disabled
            left_predicate,
            thin_predicate,
            small_predicate,
            front_predicate,
            below_predicate,
            short_predicate,
            vertical_distance_data,
            horizontal_distance_data,
            width_data,
            height_data,
            distance,
        ]
        # random.sample raises if asked for more items than exist.
        prompts_per_pair = min(num_prompts, len(prompt_variants))

        results = []
        for A, B in object_pairs:
            # A fresh random subset of prompt functions is drawn for each pair.
            for prompt_func in random.sample(prompt_variants, prompts_per_pair):
                results.append(prompt_func(A, B))

        return results
0 commit comments