diff --git a/caption.py b/caption.py index 145499014..8ca071905 100644 --- a/caption.py +++ b/caption.py @@ -197,7 +197,7 @@ def visualize_att(image_path, seq, alphas, rev_word_map, smooth=True): args = parser.parse_args() # Load model - checkpoint = torch.load(args.model) + checkpoint = torch.load(args.model, map_location=str(device)) decoder = checkpoint['decoder'] decoder = decoder.to(device) decoder.eval()