-
Notifications
You must be signed in to change notification settings - Fork 11
/
Copy pathsegmentation.py
38 lines (32 loc) · 1.12 KB
/
segmentation.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
import torch
from torchmetrics import Metric
class IoU(Metric):
    """Intersection-over-Union metric for multi-class segmentation.

    Per-class intersection and union counts are accumulated across
    ``update`` calls in the ``intersection`` and ``union`` states (each of
    shape ``(n_class,)``). Both states are reduced across processes with
    ``dist_reduce_fx``.

    Args:
        n_class: Number of classes. Only labels in ``range(n_class)`` are
            counted.
        dist_sync_on_step: Forwarded to the ``Metric`` base class.
        dist_reduce_fx: Reduction applied to both states when syncing.
        average: ``"micro"`` (default) returns
            ``sum(intersection) / sum(union)`` pooled over all classes.
            ``"macro"`` returns the mean of the per-class IoUs, taken over
            classes whose union is non-zero (the usual mIoU).

    Raises:
        ValueError: If ``average`` is neither ``"micro"`` nor ``"macro"``.
    """

    def __init__(
        self,
        n_class: int,
        dist_sync_on_step: bool = False,
        dist_reduce_fx: str = "sum",
        average: str = "micro",
    ) -> None:
        if average not in ("micro", "macro"):
            raise ValueError(
                f"average must be 'micro' or 'macro', got {average!r}"
            )
        super().__init__(dist_sync_on_step=dist_sync_on_step)
        self.n_class = n_class
        self.average = average
        self.add_state(
            "intersection",
            default=torch.zeros((n_class,)),
            dist_reduce_fx=dist_reduce_fx,
        )
        self.add_state(
            "union",
            default=torch.zeros((n_class,)),
            dist_reduce_fx=dist_reduce_fx,
        )

    def update(self, preds: torch.Tensor, target: torch.Tensor) -> None:
        """Accumulate per-class intersection/union counts for one batch.

        Args:
            preds: Class scores. The predicted label is ``argmax`` over
                dim 1 (presumably ``(batch, n_class, ...)`` logits; confirm
                against callers).
            target: Class labels, broadcastable against
                ``preds.argmax(dim=1)``. A position whose label is outside
                ``range(n_class)`` never counts as an intersection. It only
                adds to the union of the class predicted there.
        """
        pred_labels = preds.argmax(dim=1)
        # Compare on the predictions' device. This avoids copying target and
        # predictions to the CPU once per class.
        target = target.to(pred_labels.device)
        # Count in int64: a float32 sum is no longer exact past 2**24 elements.
        inter_counts = torch.zeros(
            self.n_class, dtype=torch.long, device=pred_labels.device
        )
        union_counts = torch.zeros_like(inter_counts)
        for index in range(self.n_class):
            gt_mask = target == index
            pred_mask = pred_labels == index
            inter_counts[index] = gt_mask.logical_and(pred_mask).sum()
            union_counts[index] = gt_mask.logical_or(pred_mask).sum()
        # One cast/transfer into the states' dtype and device per batch.
        self.intersection += inter_counts.to(self.intersection)
        self.union += union_counts.to(self.union)

    def compute(self) -> torch.Tensor:
        """Return the accumulated IoU as a scalar tensor.

        ``"micro"``: total intersection divided by total union across all
        classes. The result is NaN if nothing was counted.
        ``"macro"``: mean per-class IoU over classes with a non-zero union.
        The result is NaN if there are no such classes.
        """
        if self.average == "macro":
            present = self.union > 0
            return (self.intersection[present] / self.union[present]).mean()
        return self.intersection.sum() / self.union.sum()