From 094f0b9dc3c4906fe7506481f4d37475bf84050d Mon Sep 17 00:00:00 2001 From: ngshya Date: Wed, 15 Jan 2020 12:04:04 +0100 Subject: [PATCH] added map_location=str(device) in torch.load() --- caption.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/caption.py b/caption.py index 145499014..8ca071905 100644 --- a/caption.py +++ b/caption.py @@ -197,7 +197,7 @@ def visualize_att(image_path, seq, alphas, rev_word_map, smooth=True): args = parser.parse_args() # Load model - checkpoint = torch.load(args.model) + checkpoint = torch.load(args.model, map_location=str(device)) decoder = checkpoint['decoder'] decoder = decoder.to(device) decoder.eval()