diff --git a/tracker/model/network.py b/tracker/model/network.py index c5f179d..70b7e92 100644 --- a/tracker/model/network.py +++ b/tracker/model/network.py @@ -142,7 +142,7 @@ def init_hyperparameters(self, config, model_path=None, map_location=None): if model_path is not None: # load the model and key/value/hidden dimensions with some hacks # config is updated with the loaded parameters - model_weights = torch.load(model_path, map_location=map_location) + model_weights = torch.load(model_path, map_location="cpu") self.key_dim = model_weights['key_proj.key_proj.weight'].shape[0] self.value_dim = model_weights['value_encoder.fuser.block2.conv2.weight'].shape[0] self.disable_hidden = 'decoder.hidden_update.transform.weight' not in model_weights