-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathutils.py
68 lines (51 loc) · 1.97 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
"""Utility functions."""
import numpy as np
import torch
from PIL import Image
from torch import Tensor
# Module-level image size constant; it is not referenced by the functions in this module.
# NOTE(review): presumably (height, width) in pixels — confirm against the callers that use it.
IMAGE_SIZE = (252, 378)
def dice_coeff(
    pred: Tensor,
    target: Tensor,
    reduce_batch_first: bool = False,
    epsilon: float = 1e-6,
):
    """Return the mean Dice coefficient between ``pred`` and ``target``.

    Args:
        pred: predicted mask tensor; must have the same size as ``target``.
        target: ground-truth mask tensor.
        reduce_batch_first: if True (only allowed for 3-D input), the sums run
            over all three dimensions, giving a single Dice score for the whole
            stack. Otherwise the sums run over the last two dimensions, giving
            one score per 2-D mask.
        epsilon: smoothing term added to both numerator and denominator.

    Returns:
        0-dim tensor holding the mean of the computed Dice score(s).

    Raises:
        AssertionError: (when assertions are enabled) if the sizes differ, or
            if ``reduce_batch_first`` is requested for non-3-D input.
    """
    assert pred.size() == target.size()
    assert pred.dim() == 3 or not reduce_batch_first

    if reduce_batch_first and pred.dim() != 2:
        reduce_dims = (-1, -2, -3)
    else:
        reduce_dims = (-1, -2)

    overlap = (pred * target).sum(dim=reduce_dims) * 2
    total = pred.sum(dim=reduce_dims) + target.sum(dim=reduce_dims)
    # When both masks are empty the denominator is 0; swap in the overlap
    # (also 0 for non-negative masks) so the score becomes epsilon / epsilon = 1.
    total = torch.where(total == 0, overlap, total)

    scores = (overlap + epsilon) / (total + epsilon)
    return scores.mean()
def load_mask(mask_path):
    """Load the segmentation mask stored at ``mask_path`` as an integer numpy array.

    Inputs:
        mask_path (str): the path from which the segmentation mask will be read.
            It should have the format "/PATH/TO/LOAD/DIR/XXXX_mask.png".
    Outputs:
        mask (np.array): integer array of the image's pixel values. If the
            largest value exceeds 1, every value is floor-divided by 255, so a
            mask stored as 0/255 becomes 0/1 (any value below 255 maps to 0).
    """
    # Open the image as a context manager so the file handle is released
    # deterministically instead of relying on Pillow's implicit close-on-load.
    with Image.open(mask_path) as img:
        # astype() returns a new, owned array, so it stays valid after close.
        mask = np.asarray(img).astype(int)
    # Masks saved with 255 as the foreground value are rescaled to 0/1.
    if mask.max() > 1:
        mask = mask // 255
    return mask
def compute_iou(pred_mask, gt_mask, eps=1e-6):
    """Compute the intersection-over-union (IoU) of two binary masks.

    Inputs:
        pred_mask (np.array or torch.Tensor): predicted mask, values 0 or 1.
        gt_mask (np.array or torch.Tensor): ground-truth mask of the same shape,
            values 0 or 1.
        eps (float): smoothing term added to numerator and denominator so that
            two empty masks score 1 instead of 0/0.
    Outputs:
        iou_score (torch.Tensor): 0-dim float tensor; call ``.item()`` to get
            a Python float.
    Notes:
        Both masks are cast to integers (non-integer values are truncated) and
        combined with bitwise AND/OR, so the result is only meaningful for 0/1
        masks. Sums run over every element, giving one global IoU even for
        batched input.
    """
    # ``.long()`` is a torch method, so numpy (or other array-like) inputs are
    # converted first; tensors pass through torch.as_tensor without a copy.
    pred = torch.as_tensor(pred_mask).long()
    gt = torch.as_tensor(gt_mask).long()

    intersection = (pred & gt).sum()  # zero if either mask is empty
    union = (pred | gt).sum()  # zero only if both masks are empty
    # Smooth the division by eps to avoid 0/0.
    iou_score = (intersection + eps) / (union + eps)
    return iou_score