diff --git a/mnist_visualize.py b/mnist_visualize.py index db0c39e..0b45cc9 100644 --- a/mnist_visualize.py +++ b/mnist_visualize.py @@ -44,7 +44,7 @@ N = 10 for data_batch, target_batch in test_loader: for data, target in zip(data_batch, target_batch): - data_list = target2data_list[target] + data_list = target2data_list[int(target)] if len(data_list) < N: data_list.append(data) total += 1