@@ -53,36 +53,43 @@ def forward(self, x):
5353
5454 return x
5555
56+ @staticmethod
57+ def add_model_specific_args (parent_parser ):
5658
57- if __name__ == "__main__" :
58- parser = ArgumentParser (description = "PyTorch Autolog Mnist Example" )
59+ parser = ArgumentParser (parents = [parent_parser ], add_help = False )
5960
60- # Add trainer specific arguments
61+ # Add trainer specific arguments
62+ parser .add_argument (
63+ "--tracking_uri" , type = str , default = "http://localhost:5000/" , help = "mlflow tracking uri"
64+ )
65+ parser .add_argument (
66+ "--max_epochs" , type = int , default = 20 , help = "number of epochs to run (default: 20)"
67+ )
68+ parser .add_argument (
69+ "--gpus" , type = int , default = 0 , help = "Number of gpus - by default runs on CPU"
70+ )
71+ parser .add_argument (
72+ "--accelerator" ,
73+ type = str ,
74+ default = None ,
75+ help = "accelerator - (default: None)" ,
76+ )
77+ return parser
6178
62- parser .add_argument (
63- "--tracking_uri" , type = str , default = "http://localhost:5000/" , help = "mlflow tracking uri"
64- )
65- parser .add_argument (
66- "--max_epochs" , type = int , default = 20 , help = "number of epochs to run (default: 20)"
67- )
68- parser .add_argument (
69- "--gpus" , type = int , default = 0 , help = "Number of gpus - by default runs on CPU"
70- )
71- parser .add_argument (
72- "--accelerator" ,
73- type = str ,
74- default = None ,
75- help = "accelerator - (default: None)" ,
76- )
77- parser = LightningMNISTClassifier .add_model_specific_args (parent_parser = parser )
79+
80+ if __name__ == "__main__" :
81+ parent_parser = ArgumentParser (description = "PyTorch Autolog Mnist Example" )
82+
83+ parser = LightningMNISTClassifier .add_model_specific_args (parent_parser = parent_parser )
7884
7985 mlflow .pytorch .autolog () # just add this line and your Autologging should work!
8086
8187 args = parser .parse_args ()
8288
8389 args = parser .parse_args ()
8490 dict_args = vars (args )
85- #mlflow.set_tracking_uri(dict_args['tracking_uri'])
91+
92+ # mlflow.set_tracking_uri(dict_args['tracking_uri'])
8693
8794 model = LightningMNISTClassifier (** dict_args )
8895 early_stopping = EarlyStopping (monitor = "val_loss" , mode = "min" , verbose = True )
@@ -93,15 +100,15 @@ def forward(self, x):
93100 verbose = True ,
94101 monitor = "val_loss" ,
95102 mode = "min" ,
96- prefix = "" ,
103+ prefix = ""
97104 )
98105 lr_logger = LearningRateMonitor ()
99106
100107 trainer = pl .Trainer .from_argparse_args (
101108 args ,
102109 callbacks = [lr_logger , early_stopping ],
103110 checkpoint_callback = checkpoint_callback ,
104- train_percent_check = 0.1 ,
111+ # train_percent_check=0.1,
105112 )
106113 trainer .fit (model )
107114 trainer .test ()
0 commit comments