bbox CUDA error out of bounds #353

Description

@LennyAharon

To deal with the CUDA out-of-bounds error for the bounding box code, I ended up using this code in utils.py:

```python
import numpy as np  # these imports are assumed to already be at the top of utils.py
import torch
from torchtyping import TensorType


def evaluate_heatmaps_at_location(
    heatmaps: TensorType["batch", "num_keypoints", "heatmap_height", "heatmap_width"],
    locs: TensorType["batch", "num_keypoints", 2],
    sigma: float = 1.25,  # sigma used for generating heatmaps
    num_stds: int = 2,  # num standard deviations of pixels to compute confidence
) -> TensorType["batch", "num_keypoints"]:
    """Evaluate 4D heatmaps using a 3D location tensor (last dim is x, y coords). Since
    the model outputs heatmaps with a standard deviation of sigma, confidence will be
    spread across neighboring pixels. To account for this, confidence is computed by
    taking all pixels within two standard deviations of the predicted pixel."""
    pix_to_consider = int(np.floor(sigma * num_stds))  # get all pixels within num_stds.
    num_pad = pix_to_consider
    heatmaps_padded = torch.zeros(
        (
            heatmaps.shape[0],
            heatmaps.shape[1],
            heatmaps.shape[2] + num_pad * 2,
            heatmaps.shape[3] + num_pad * 2,
        ),
        device=heatmaps.device,
    )
    heatmaps_padded[:, :, num_pad:-num_pad, num_pad:-num_pad] = heatmaps
    i = torch.arange(heatmaps_padded.shape[0], device=heatmaps_padded.device).reshape(
        -1, 1, 1, 1
    )
    j = torch.arange(heatmaps_padded.shape[1], device=heatmaps_padded.device).reshape(
        1, -1, 1, 1
    )
    # Handle NaN values: replace with 0 and clamp to prevent index out of bounds
    nan_mask = torch.isnan(locs).any(dim=-1)
    locs_safe = torch.where(torch.isnan(locs), torch.zeros_like(locs), locs)
    k = torch.clamp(
        locs_safe[:, :, None, 1, None].type(torch.int64) + num_pad, 0, heatmaps_padded.shape[2] - 1
    )
    m = torch.clamp(
        locs_safe[:, :, 0, None, None].type(torch.int64) + num_pad, 0, heatmaps_padded.shape[3] - 1
    )
    offsets = list(np.arange(-pix_to_consider, pix_to_consider + 1))
    vals_all = []
    for offset in offsets:
        for offset_2 in offsets:
            k_offset = torch.clamp(k + offset, 0, heatmaps_padded.shape[2] - 1)
            m_offset = torch.clamp(m + offset_2, 0, heatmaps_padded.shape[3] - 1)
            vals_all.append(heatmaps_padded[i, j, k_offset, m_offset].squeeze(-1).squeeze(-1))
    vals = torch.stack(vals_all, 0).sum(0)
    vals[nan_mask] = 0.0  # Set confidence to zero for NaN locations
    return vals


def normalized_to_bbox(
    keypoints: TensorType["batch", "num_keypoints", "xy":2],
    bbox: TensorType["batch", "xyhw":4]
) -> TensorType["batch", "num_keypoints", "xy":2]:
    """Scale keypoints from normalized (0-1) coordinates to absolute coordinates using bbox (x, y, h, w)."""
    if keypoints.shape[0] == bbox.shape[0]:
        # normal batch
        bbox_width = bbox[:, 3].unsqueeze(1)
        bbox_x = bbox[:, 0].unsqueeze(1)
        bbox_height = bbox[:, 2].unsqueeze(1)
        bbox_y = bbox[:, 1].unsqueeze(1)
        
        keypoints[:, :, 0] *= bbox_width  # scale x by box width
        keypoints[:, :, 0] += bbox_x  # add bbox x offset
        keypoints[:, :, 1] *= bbox_height  # scale y by box height
        keypoints[:, :, 1] += bbox_y  # add bbox y offset
    else:
        # context batch; we don't have predictions for first/last two frames
        bbox_slice = bbox[2:-2]
        keypoints[:, :, 0] *= bbox_slice[:, 3].unsqueeze(1)  # scale x by box width
        keypoints[:, :, 0] += bbox_slice[:, 0].unsqueeze(1)  # add bbox x offset
        keypoints[:, :, 1] *= bbox_slice[:, 2].unsqueeze(1)  # scale y by box height
        keypoints[:, :, 1] += bbox_slice[:, 1].unsqueeze(1)  # add bbox y offset
    
    return keypoints


def convert_bbox_coords(
    batch_dict: (
        HeatmapLabeledBatchDict
        | MultiviewHeatmapLabeledBatchDict
        | MultiviewUnlabeledBatchDict
        | UnlabeledBatchDict
    ),
    predicted_keypoints: TensorType["batch", "num_targets"],
) -> TensorType["batch", "num_targets"]:
    """Transform keypoints from bbox coordinates to absolute frame coordinates."""
    num_targets = predicted_keypoints.shape[1]
    num_keypoints = num_targets // 2
    # reshape from (batch, n_targets) back to (batch, n_key, 2), in x,y order
    predicted_keypoints = predicted_keypoints.reshape((-1, num_keypoints, 2))
    # divide by image dims to get 0-1 normalized coordinates
    if "images" in batch_dict.keys():
        img_shape = batch_dict["images"].shape
        predicted_keypoints[:, :, 0] /= img_shape[-1]  # -1 dim is width "x"
        predicted_keypoints[:, :, 1] /= img_shape[-2]  # -2 dim is height "y"
    else:  # we have unlabeled dict, 'frames' instead of 'images'
        frames_shape = batch_dict["frames"].shape
        predicted_keypoints[:, :, 0] /= frames_shape[-1]  # -1 dim is width "x"
        predicted_keypoints[:, :, 1] /= frames_shape[-2]  # -2 dim is height "y"
    
    # multiply and add by bbox dims (x,y,h,w)
    has_num_views = "num_views" in batch_dict.keys()
    is_multiview_flag = batch_dict.get("is_multiview", False)
    
    if (
        (has_num_views and int(batch_dict["num_views"].max()) > 1)
        or is_multiview_flag
    ):
        # the first check is for labeled batches while is_multiview is for unlabeled batches
        # For MultiviewUnlabeledBatchDict, we need to infer num_views from bbox shape
        if has_num_views:
            unique = batch_dict["num_views"].unique()
            if len(unique) != 1:
                raise ValueError(
                    f"each batch element must contain the same number of views; "
                    f"found elements with {unique} views"
                )
            num_views = int(unique)
        else:
            # Infer from bbox shape: bbox has shape [seq_len, num_views * 4]
            num_views = batch_dict["bbox"].shape[1] // 4

        num_keypoints_per_view = num_keypoints // num_views

        if batch_dict["bbox"].shape[1] < num_views * 4:
            raise ValueError(
                f"bbox shape mismatch: expected at least {num_views * 4} columns "
                f"(num_views={num_views} * 4), but got {batch_dict['bbox'].shape[1]}"
            )

        for v in range(num_views):
            idx_beg = num_keypoints_per_view * v
            idx_end = idx_beg + num_keypoints_per_view
            bbox_start = 4 * v
            bbox_end = 4 * (v + 1)
            
            bbox_slice = batch_dict["bbox"][:, bbox_start:bbox_end]
            kp_slice = predicted_keypoints[:, idx_beg:idx_end, :]

            predicted_keypoints[:, idx_beg:idx_end, :] = normalized_to_bbox(
                kp_slice,
                bbox_slice,
            )
    else:
        predicted_keypoints = normalized_to_bbox(predicted_keypoints, batch_dict["bbox"])

    # return new keypoints, reshaped to (batch, num_targets)
    result = predicted_keypoints.reshape((-1, num_targets))
    return result
```
