-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathtemp.py
executable file
·94 lines (76 loc) · 3.02 KB
/
temp.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
# --- Binary-mask prompts and point prompts per class ------------------------
# For every distinct value in `label`, build:
#   * bin_masks[class_id]     - the class's 0/1 mask converted by
#                               Segmentix.reference_to_sam_mask; it is later fed
#                               to SAM as `mask_input` (presumably low-res
#                               logits shaped (1, H, W) - TODO confirm).
#   * center_points[class_id] - a list of (x, y) pixel coordinates, one per
#                               connected component that passes the size filter;
#                               later used as point prompts.
# NOTE(review): np.unique(label) also yields the background value (if the map
# has one), so the background gets its own mask and points - confirm intended.
MIN_COMPONENT_BBOX_AREA = 200  # skip components whose bounding box is smaller
CONNECTIVITY = 8  # 4 = 4-connectivity, 8 = 8-connectivity

mask_to_logits = Segmentix()
bin_masks = {}
center_points = {}
for bin_mask_id in np.unique(label):
    # Compute the class mask once. The integer mask goes to Segmentix
    # unchanged; OpenCV needs a uint8 copy.
    class_mask = np.where(label == bin_mask_id, 1, 0)
    bin_mask_image = class_mask.astype(np.uint8)
    bin_masks[bin_mask_id] = mask_to_logits.reference_to_sam_mask(class_mask)

    # Connected-component labelling of this class's mask.
    num_labels, labels, stats, _ = cv2.connectedComponentsWithStats(
        bin_mask_image, CONNECTIVITY, cv2.CV_32S
    )
    cps = []
    # Component 0 is the mask's background, so real components start at 1.
    for i in range(1, num_labels):
        width = stats[i, cv2.CC_STAT_WIDTH]
        height = stats[i, cv2.CC_STAT_HEIGHT]
        # Filters on bounding-box area, not pixel count (CC_STAT_AREA).
        if width * height < MIN_COMPONENT_BBOX_AREA:
            continue
        component = np.uint8(labels == i)
        contours, _ = cv2.findContours(
            component, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
        )
        # Centroid of the component's first outer contour, from image moments.
        M = cv2.moments(contours[0])
        if M["m00"] == 0:
            continue
        center_x = int(M["m10"] / M["m00"])
        center_y = int(M["m01"] / M["m00"])
        if bin_mask_image[center_y, center_x]:
            cps.append((center_x, center_y))
        else:
            # Non-convex shapes (rings, C-shapes) can have their centroid
            # outside the mask: fall back to a random pixel of the component.
            ys, xs = np.nonzero(labels == i)
            idx = np.random.randint(len(ys))
            cps.append((xs[idx], ys[idx]))
    center_points[bin_mask_id] = cps
print(center_points)
import os

# Ensure the folder that receives the figures below exists; announce it only
# when it had to be created.
_PROMPT_DIR = "./mask_prompt/"
_dir_missing = not os.path.isdir(_PROMPT_DIR)
if _dir_missing:
    os.makedirs(_PROMPT_DIR)
    print("{} made".format(_PROMPT_DIR))
# --- Run SAM once per class and save diagnostic figures ---------------------
# For each class: positive points are that class's component centres, negative
# points are every other class's centres, and the class mask is `mask_input`.


def _sigmoid(logits):
    """Map logits to probabilities without overflow.

    0.5 * (1 + tanh(x / 2)) equals exp(x) / (exp(x) + 1), but the latter
    evaluates to inf / inf = nan for large positive logits.
    """
    return 0.5 * (1.0 + np.tanh(np.asarray(logits) / 2.0))


for bin_mask_id in bin_masks:
    class_name = class_names[bin_mask_id]
    print(class_name)
    out_prefix = f"./mask_prompt/type_{class_name}"

    pos = center_points[bin_mask_id]
    # Keep center_points' iteration order so the point order is unchanged.
    neg = [
        pt
        for mask_id, pts in center_points.items()
        if mask_id != bin_mask_id
        for pt in pts
    ]

    # Visualise the mask prompt as probabilities.
    plt.imshow(_sigmoid(bin_masks[bin_mask_id])[0], cmap="gray")
    plt.savefig(f"{out_prefix}_mask_prompt.png")
    plt.show()
    plt.close()  # don't let later plots draw onto this figure

    input_points = np.array(pos + neg)
    # Fixed: a trailing comma used to turn this into a 1-tuple.
    input_labels = np.array([1] * len(pos) + [0] * len(neg))
    if len(input_points) == 0:
        # No usable points for any class: prompt with the mask alone instead
        # of passing a shape-(0,) array that SAM cannot transform.
        input_points = None
        input_labels = None

    mask, _, logit = predictor.predict(
        point_coords=input_points,
        point_labels=input_labels,  # fixed: was the undefined `input_labes`
        mask_input=bin_masks[bin_mask_id],
        multimask_output=False,
    )

    plt.figure(figsize=(16, 8))
    plt.imshow(image)
    show_mask(mask, plt.gca())
    if input_points is not None:
        show_points(input_points, input_labels, plt.gca())
    plt.axis('off')
    plt.savefig(f"{out_prefix}_image_with_mask.png")
    plt.show()
    plt.close()

    plt.imshow(_sigmoid(logit)[0], cmap="gray")
    plt.savefig(f"{out_prefix}_logit.png")
    plt.show()
    plt.close()