 # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
 import argparse
 import os
+import onnx
 import torch
 
 from detectron2.checkpoint import DetectionCheckpointer
 from detectron2.config import get_cfg
 from detectron2.data import build_detection_test_loader
 from detectron2.evaluation import COCOEvaluator, inference_on_dataset, print_csv_format
-from detectron2.export import add_export_config, export_caffe2_model
+from detectron2.export import Caffe2Tracer, add_export_config
 from detectron2.modeling import build_model
 from detectron2.utils.logger import setup_logger
 
@@ -28,10 +29,16 @@ def setup_cfg(args):
 
 
 if __name__ == "__main__":
-    parser = argparse.ArgumentParser(description="Convert a model to Caffe2")
+    parser = argparse.ArgumentParser(description="Convert a model using caffe2 tracing.")
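+    # the new --format flag selects which backend the Caffe2Tracer export below targets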
+    parser.add_argument(
+        "--format",
+        choices=["caffe2", "onnx", "torchscript"],
+        help="output format",
+        default="caffe2",
+    )
     parser.add_argument("--config-file", default="", metavar="FILE", help="path to config file")
     parser.add_argument("--run-eval", action="store_true")
-    parser.add_argument("--output", help="output directory for the converted caffe2 model")
+    parser.add_argument("--output", help="output directory for the converted model")
     parser.add_argument(
         "opts",
         help="Modify config options using the command-line",
@@ -41,6 +48,7 @@ def setup_cfg(args):
     args = parser.parse_args()
     logger = setup_logger()
     logger.info("Command line arguments: " + str(args))
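+    # every export path below writes its files into args.output, so create the directory up front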
+    os.makedirs(args.output, exist_ok=True)
 
     cfg = setup_cfg(args)
 
@@ -53,13 +61,35 @@ def setup_cfg(args):
     first_batch = next(iter(data_loader))
 
     # convert and save caffe2 model
-    caffe2_model = export_caffe2_model(cfg, torch_model, first_batch)
-    caffe2_model.save_protobuf(args.output)
-    # draw the caffe2 graph
-    caffe2_model.save_graph(os.path.join(args.output, "model.svg"), inputs=first_batch)
+    tracer = Caffe2Tracer(cfg, torch_model, first_batch)
+    if args.format == "caffe2":
+        caffe2_model = tracer.export_caffe2()
+        caffe2_model.save_protobuf(args.output)
+        # draw the caffe2 graph
+        caffe2_model.save_graph(os.path.join(args.output, "model.svg"), inputs=first_batch)
+    elif args.format == "onnx":
+        onnx_model = tracer.export_onnx()
+        onnx.save(onnx_model, os.path.join(args.output, "model.onnx"))
+    elif args.format == "torchscript":
+        script_model = tracer.export_torchscript()
+        script_model.save(os.path.join(args.output, "model.ts"))
+
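+        # the dumps below are human-readable debugging aids, not loadable model files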
+        # Recursively print IR of all modules
+        with open(os.path.join(args.output, "model_ts_IR.txt"), "w") as f:
+            try:
+                f.write(script_model._actual_script_module._c.dump_to_str(True, False, False))
+            except AttributeError:
+                pass
+        # Print IR of the entire graph (all submodules inlined)
+        with open(os.path.join(args.output, "model_ts_IR_inlined.txt"), "w") as f:
+            f.write(str(script_model.inlined_graph))
+        # Print the model structure in pytorch style
+        with open(os.path.join(args.output, "model.txt"), "w") as f:
+            f.write(str(script_model))
 
     # run evaluation with the converted model
     if args.run_eval:
+        assert args.format == "caffe2", "Python inference in other format is not yet supported."
         dataset = cfg.DATASETS.TEST[0]
         data_loader = build_detection_test_loader(cfg, dataset)
         # NOTE: hard-coded evaluator. change to the evaluator for your dataset