diff --git a/hubconf.py b/hubconf.py index c65c768..49393c0 100644 --- a/hubconf.py +++ b/hubconf.py @@ -1,6 +1,7 @@ # import the necessary packages import torch from pyimagesearch import mlp +import os # define entry point/callable function # to initialize and return model @@ -13,6 +14,10 @@ def custom_model(): # initialize the model # load weights from path # returns model + repo_dir = os.path.dirname(__file__) model = mlp.get_training_model() - model.load_state_dict(torch.load("model_wt.pth")) - return model \ No newline at end of file + model_path = os.path.join(repo_dir, "output", "model_wt.pth") + if not os.path.exists(model_path): + raise FileNotFoundError(f"Model file not found at {model_path}") + model.load_state_dict(torch.load(model_path)) + return model