Skip to content

Commit 3a3ece3

Browse files
committed
Fix for nms
1 parent 344f0c0 commit 3a3ece3

File tree

2 files changed

+10
-13
lines changed

2 files changed

+10
-13
lines changed

examples/models/run_yolo.py

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,11 @@
33
import torch
44
from PIL import Image
55
from PIL import ImageDraw
6+
import sys
67

8+
if len(sys.argv) < 2:
9+
print("Specify an image as the command line argument")
10+
sys.exit(0)
711

812
labels20 = [
913
"aeroplane", # 0
@@ -40,18 +44,11 @@
4044
r.execute_command('AI.SCRIPTSET', 'yolo-post', 'CPU', script)
4145

4246

43-
44-
45-
46-
# filename = "../img/sample_dog_416.jpg"
47-
filename = "../img/sample_office_416.jpg"
48-
49-
img_jpg = Image.open(filename)
47+
img_jpg = Image.open(sys.argv[1])
5048

5149
# normalize
5250
img = np.array(img_jpg).astype(np.float32)
53-
img -= 128.0
54-
img /= 128.0
51+
img /= 256.0
5552

5653
r.execute_command('AI.TENSORSET', 'in', 'FLOAT', 1, 416, 416, 3, 'BLOB', img.tobytes())
5754

examples/models/yolo_boxes.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -42,14 +42,14 @@ def nms(boxes):
4242
for j in range(boxes.shape[1] - (i+1)):
4343
j_idx = sort_ids[j + i+1]
4444
box_j = boxes[b, j_idx]
45-
if bbox_iou(box_i, box_j) > nms_thresh:
46-
boxes[b, j_idx].zero_()
45+
if float(box_j[4]) > 0. and bbox_iou(box_i, box_j) > nms_thresh:
46+
boxes[b, j_idx] = torch.zeros(7)
4747

4848
return boxes
4949

5050

5151
def get_region_boxes(output):
52-
conf_thresh = 0.25
52+
conf_thresh = 0.2
5353
num_classes = 20
5454
num_anchors = 5
5555
anchor_step = 2
@@ -95,7 +95,7 @@ def get_region_boxes(output):
9595
for cx in range(w):
9696
for i in range(num_anchors):
9797
ind = b*sz_hwa + i*sz_hw + cy*w + cx
98-
det_conf = det_confs[ind]
98+
det_conf = det_confs[ind]
9999
conf = det_confs[ind] * cls_max_confs[ind]
100100

101101
if bool(conf > conf_thresh):

0 commit comments

Comments
 (0)