To deal with a CUDA out-of-bounds error related to the bounding box, I ended up using this code in utils.py:
import numpy as np
import torch
from torchtyping import TensorType


def evaluate_heatmaps_at_location(
    heatmaps: TensorType["batch", "num_keypoints", "heatmap_height", "heatmap_width"],
    locs: TensorType["batch", "num_keypoints", 2],
    sigma: float = 1.25,  # sigma used for generating heatmaps
    num_stds: int = 2,  # num standard deviations of pixels to compute confidence
) -> TensorType["batch", "num_keypoints"]:
    """Evaluate 4D heatmaps using a 3D location tensor (last dim is x, y coords).

    Since the model outputs heatmaps with a standard deviation of sigma, confidence is
    spread across neighboring pixels. To account for this, confidence is computed by
    summing all pixels within num_stds standard deviations of the predicted pixel.
    """
    pix_to_consider = int(np.floor(sigma * num_stds))  # all pixels within num_stds
    num_pad = pix_to_consider
    heatmaps_padded = torch.zeros(
        (
            heatmaps.shape[0],
            heatmaps.shape[1],
            heatmaps.shape[2] + num_pad * 2,
            heatmaps.shape[3] + num_pad * 2,
        ),
        device=heatmaps.device,
    )
    heatmaps_padded[:, :, num_pad:-num_pad, num_pad:-num_pad] = heatmaps
    i = torch.arange(heatmaps_padded.shape[0], device=heatmaps_padded.device).reshape(
        -1, 1, 1, 1
    )
    j = torch.arange(heatmaps_padded.shape[1], device=heatmaps_padded.device).reshape(
        1, -1, 1, 1
    )
    # handle NaN locations: replace with 0 and clamp to prevent index out of bounds
    nan_mask = torch.isnan(locs).any(dim=-1)
    locs_safe = torch.where(torch.isnan(locs), torch.zeros_like(locs), locs)
    k = torch.clamp(
        locs_safe[:, :, 1, None, None].type(torch.int64) + num_pad,  # y -> row index
        0, heatmaps_padded.shape[2] - 1,
    )
    m = torch.clamp(
        locs_safe[:, :, 0, None, None].type(torch.int64) + num_pad,  # x -> col index
        0, heatmaps_padded.shape[3] - 1,
    )
    offsets = list(np.arange(-pix_to_consider, pix_to_consider + 1))
    vals_all = []
    for offset in offsets:
        for offset_2 in offsets:
            k_offset = torch.clamp(k + offset, 0, heatmaps_padded.shape[2] - 1)
            m_offset = torch.clamp(m + offset_2, 0, heatmaps_padded.shape[3] - 1)
            vals_all.append(
                heatmaps_padded[i, j, k_offset, m_offset].squeeze(-1).squeeze(-1)
            )
    vals = torch.stack(vals_all, 0).sum(0)
    vals[nan_mask] = 0.0  # set confidence to zero for NaN locations
    return vals
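
For context, here is a minimal sketch of how I exercise the NaN handling (shapes and values are made up for illustration; it assumes the function above is importable):

```python
import torch

# toy inputs: 2 frames, 3 keypoints, 48x48 heatmaps
heatmaps = torch.rand(2, 3, 48, 48)
locs = torch.tensor([
    [[10.0, 20.0], [47.0, 47.0], [float("nan"), float("nan")]],
    [[0.0, 0.0], [5.0, 30.0], [12.0, 12.0]],
])
conf = evaluate_heatmaps_at_location(heatmaps, locs)
print(conf.shape)  # torch.Size([2, 3])
print(conf[0, 2])  # tensor(0.) -- confidence zeroed for the NaN keypoint
```

The padding plus clamping is what keeps the gather in bounds even when a predicted location sits on the image border or is NaN, which is where the CUDA device-side error showed up for me.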

def normalized_to_bbox(
    keypoints: TensorType["batch", "num_keypoints", "xy":2],
    bbox: TensorType["batch", "xyhw":4],
) -> TensorType["batch", "num_keypoints", "xy":2]:
    """Map keypoints from normalized (0-1) crop coordinates to absolute coordinates.

    Note: bbox columns are ordered (x, y, h, w); keypoints are modified in place.
    """
    if keypoints.shape[0] == bbox.shape[0]:
        # normal batch
        bbox_width = bbox[:, 3].unsqueeze(1)
        bbox_x = bbox[:, 0].unsqueeze(1)
        bbox_height = bbox[:, 2].unsqueeze(1)
        bbox_y = bbox[:, 1].unsqueeze(1)
        keypoints[:, :, 0] *= bbox_width  # scale x by box width
        keypoints[:, :, 0] += bbox_x  # add bbox x offset
        keypoints[:, :, 1] *= bbox_height  # scale y by box height
        keypoints[:, :, 1] += bbox_y  # add bbox y offset
    else:
        # context batch; we don't have predictions for the first/last two frames
        bbox_slice = bbox[2:-2]
        keypoints[:, :, 0] *= bbox_slice[:, 3].unsqueeze(1)  # scale x by box width
        keypoints[:, :, 0] += bbox_slice[:, 0].unsqueeze(1)  # add bbox x offset
        keypoints[:, :, 1] *= bbox_slice[:, 2].unsqueeze(1)  # scale y by box height
        keypoints[:, :, 1] += bbox_slice[:, 1].unsqueeze(1)  # add bbox y offset
    return keypoints
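
A quick numeric check of the coordinate mapping (values are made up; bbox columns are (x, y, h, w), and the clone is there because the function modifies keypoints in place):

```python
import torch

keypoints = torch.tensor([[[0.5, 0.5], [0.0, 1.0]]])  # normalized (x, y), shape (1, 2, 2)
bbox = torch.tensor([[100.0, 50.0, 200.0, 400.0]])    # (x, y, h, w), shape (1, 4)
out = normalized_to_bbox(keypoints.clone(), bbox)
# x: 0.5 * 400 + 100 = 300,  y: 0.5 * 200 + 50 = 150
print(out[0, 0])  # tensor([300., 150.])
print(out[0, 1])  # tensor([100., 250.])
```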

def convert_bbox_coords(
    batch_dict: (
        HeatmapLabeledBatchDict
        | MultiviewHeatmapLabeledBatchDict
        | MultiviewUnlabeledBatchDict
        | UnlabeledBatchDict
    ),
    predicted_keypoints: TensorType["batch", "num_targets"],
) -> TensorType["batch", "num_targets"]:
    """Transform keypoints from bbox coordinates to absolute frame coordinates."""
    num_targets = predicted_keypoints.shape[1]
    num_keypoints = num_targets // 2
    # reshape from (batch, n_targets) back to (batch, n_keypoints, 2), in x,y order
    predicted_keypoints = predicted_keypoints.reshape((-1, num_keypoints, 2))
    # divide by image dims to get 0-1 normalized coordinates
    if "images" in batch_dict.keys():
        img_shape = batch_dict["images"].shape
        predicted_keypoints[:, :, 0] /= img_shape[-1]  # -1 dim is width "x"
        predicted_keypoints[:, :, 1] /= img_shape[-2]  # -2 dim is height "y"
    else:  # unlabeled dict has 'frames' instead of 'images'
        frames_shape = batch_dict["frames"].shape
        predicted_keypoints[:, :, 0] /= frames_shape[-1]  # -1 dim is width "x"
        predicted_keypoints[:, :, 1] /= frames_shape[-2]  # -2 dim is height "y"
    # multiply by bbox dims and add bbox offsets (bbox is x,y,h,w)
    has_num_views = "num_views" in batch_dict.keys()
    is_multiview_flag = batch_dict.get("is_multiview", False)
    if (has_num_views and int(batch_dict["num_views"].max()) > 1) or is_multiview_flag:
        # the first check covers labeled batches; is_multiview covers unlabeled batches
        # for MultiviewUnlabeledBatchDict, infer num_views from the bbox shape
        if has_num_views:
            unique = batch_dict["num_views"].unique()
            if len(unique) != 1:
                raise ValueError(
                    f"each batch element must contain the same number of views; "
                    f"found elements with {unique} views"
                )
            num_views = int(unique)
        else:
            # infer from bbox shape: bbox has shape (seq_len, num_views * 4)
            num_views = batch_dict["bbox"].shape[1] // 4
        num_keypoints_per_view = num_keypoints // num_views
        if batch_dict["bbox"].shape[1] < num_views * 4:
            raise ValueError(
                f"bbox shape mismatch: expected at least {num_views * 4} columns "
                f"(num_views={num_views} * 4), but got {batch_dict['bbox'].shape[1]}"
            )
        for v in range(num_views):
            idx_beg = num_keypoints_per_view * v
            idx_end = idx_beg + num_keypoints_per_view
            bbox_start = 4 * v
            bbox_end = 4 * (v + 1)
            bbox_slice = batch_dict["bbox"][:, bbox_start:bbox_end]
            kp_slice = predicted_keypoints[:, idx_beg:idx_end, :]
            predicted_keypoints[:, idx_beg:idx_end, :] = normalized_to_bbox(
                kp_slice,
                bbox_slice,
            )
    else:
        predicted_keypoints = normalized_to_bbox(predicted_keypoints, batch_dict["bbox"])
    # return new keypoints, reshaped to (batch, num_targets)
    return predicted_keypoints.reshape((-1, num_targets))
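
And a single-view usage sketch for convert_bbox_coords; the plain dict below is only a stand-in for UnlabeledBatchDict to show the expected shapes (2 frames, 3 keypoints -> 6 targets), not the real batch object:

```python
import torch

batch_dict = {
    "frames": torch.zeros(2, 3, 256, 256),  # (batch, channels, height, width)
    "bbox": torch.tensor([
        [10.0, 20.0, 128.0, 128.0],
        [10.0, 20.0, 128.0, 128.0],
    ]),  # (x, y, h, w) per frame
}
preds = torch.rand(2, 6) * 256  # keypoints in crop pixel coords, (batch, num_targets)
abs_preds = convert_bbox_coords(batch_dict, preds)
print(abs_preds.shape)  # torch.Size([2, 6])
```

With this batch, the keypoints are normalized by the 256x256 frame size and then mapped into the 128x128 crop at offset (10, 20), so all outputs land inside that bounding box.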