Skip to content

Commit

Permalink
merge
Browse files Browse the repository at this point in the history
  • Loading branch information
PINTO0309 committed Oct 4, 2023
1 parent 4a1b482 commit 487247b
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 16 deletions.
3 changes: 1 addition & 2 deletions 307_YOLOv7/post_process_gen_tools/make_box_gather_nd.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,9 +48,8 @@
selected_indices,
batch_dims=0,
)
gathered_boxes_casted = tf.cast(gathered_boxes, dtype=tf.int64)

model = tf.keras.models.Model(inputs=[boxes, selected_indices], outputs=[gathered_boxes_casted])
model = tf.keras.models.Model(inputs=[boxes, selected_indices], outputs=[gathered_boxes])
model.summary()
output_path = 'saved_model_postprocess'
tf.saved_model.save(model, output_path)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,9 @@ def __init__(self):
super(Model, self).__init__()

def forward(self, x):
batch_nums = x[:, 0:1] # batch number
class_nums = x[:, 1:2] # class ids
box_nums = x[:, [0,2]] # batch number + box number
batch_nums = x[:, 0:1].to(torch.float32) # batch number
class_nums = x[:, 1:2].to(torch.float32) # class ids
box_nums = x[:, [0,2]].to(torch.float32) # batch number + box number
return batch_nums, class_nums, box_nums


Expand Down
24 changes: 13 additions & 11 deletions 307_YOLOv7/post_process_gen_tools/make_nms_outputs_merge.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,9 @@ class Model(nn.Module):
def __init__(self):
super(Model, self).__init__()

def forward(self, batch, classid, x1y1x2y2):
batchno_classid_x1y1x2y2_cat = torch.cat([batch, classid, x1y1x2y2], dim=1)
return batchno_classid_x1y1x2y2_cat
def forward(self, batch, classid, x1y1x2y2, score):
batchno_classid_x1y1x2y2_score_cat = torch.cat([batch, classid, x1y1x2y2, score], dim=1)
return batchno_classid_x1y1x2y2_score_cat


if __name__ == "__main__":
Expand All @@ -29,27 +29,29 @@ def forward(self, batch, classid, x1y1x2y2):

model = Model()

MODEL = f'24_nms_batchno_classid_x1y1x2y2_cat'
MODEL = f'24_nms_batchno_classid_x1y1x2y2_score_cat'

onnx_file = f"{MODEL}.onnx"
OPSET=args.opset

x1 = torch.ones([1, 1], dtype=torch.int64)
x2 = torch.ones([1, 1], dtype=torch.int64)
x3 = torch.ones([1, 4], dtype=torch.int64)
x1 = torch.ones([1, 1], dtype=torch.float32)
x2 = torch.ones([1, 1], dtype=torch.float32)
x3 = torch.ones([1, 4], dtype=torch.float32)
x4 = torch.ones([1, 1], dtype=torch.float32)

torch.onnx.export(
model,
args=(x1,x2,x3),
args=(x1,x2,x3,x4),
f=onnx_file,
opset_version=OPSET,
input_names=['cat_batch','cat_classid','cat_x1y1x2y2'],
output_names=['batchno_classid_x1y1x2y2'],
input_names=['cat_batch','cat_classid','cat_x1y1x2y2','cat_score'],
output_names=['batchno_classid_x1y1x2y2_score'],
dynamic_axes={
'cat_batch': {0: 'N'},
'cat_classid': {0: 'N'},
'cat_x1y1x2y2': {0: 'N'},
'batchno_classid_x1y1x2y2': {0: 'N'},
'cat_score': {0: 'N'},
'batchno_classid_x1y1x2y2_score': {0: 'N'},
}
)
model_onnx1 = onnx.load(onnx_file)
Expand Down

0 comments on commit 487247b

Please sign in to comment.