Skip to content

Commit

Permalink
Merge pull request #544 from kentaroy47/pytorch-1.0
Browse files Browse the repository at this point in the history
fix for pytorch1.1
  • Loading branch information
jwyang authored May 9, 2019
2 parents 7589307 + 96a4037 commit 0a1a74b
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 12 deletions.
9 changes: 5 additions & 4 deletions demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,10 +278,11 @@ def _get_image_blob(im):
im_data_pt = im_data_pt.permute(0, 3, 1, 2)
im_info_pt = torch.from_numpy(im_info_np)

im_data.data.resize_(im_data_pt.size()).copy_(im_data_pt)
im_info.data.resize_(im_info_pt.size()).copy_(im_info_pt)
gt_boxes.data.resize_(1, 1, 5).zero_()
num_boxes.data.resize_(1).zero_()
with torch.no_grad():
im_data.resize_(im_data_pt.size()).copy_(im_data_pt)
im_info.resize_(im_info_pt.size()).copy_(im_info_pt)
gt_boxes.resize_(1, 1, 5).zero_()
num_boxes.resize_(1).zero_()

# pdb.set_trace()
det_tic = time.time()
Expand Down
9 changes: 5 additions & 4 deletions test_net.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,10 +228,11 @@ def parse_args():
for i in range(num_images):

data = next(data_iter)
im_data.data.resize_(data[0].size()).copy_(data[0])
im_info.data.resize_(data[1].size()).copy_(data[1])
gt_boxes.data.resize_(data[2].size()).copy_(data[2])
num_boxes.data.resize_(data[3].size()).copy_(data[3])
with torch.no_grad():
im_data.resize_(data[0].size()).copy_(data[0])
im_info.resize_(data[1].size()).copy_(data[1])
gt_boxes.resize_(data[2].size()).copy_(data[2])
num_boxes.resize_(data[3].size()).copy_(data[3])

det_tic = time.time()
rois, cls_prob, bbox_pred, \
Expand Down
9 changes: 5 additions & 4 deletions trainval_net.py
Original file line number Diff line number Diff line change
Expand Up @@ -308,10 +308,11 @@ def __len__(self):
data_iter = iter(dataloader)
for step in range(iters_per_epoch):
data = next(data_iter)
im_data.data.resize_(data[0].size()).copy_(data[0])
im_info.data.resize_(data[1].size()).copy_(data[1])
gt_boxes.data.resize_(data[2].size()).copy_(data[2])
num_boxes.data.resize_(data[3].size()).copy_(data[3])
with torch.no_grad():
im_data.resize_(data[0].size()).copy_(data[0])
im_info.resize_(data[1].size()).copy_(data[1])
gt_boxes.resize_(data[2].size()).copy_(data[2])
num_boxes.resize_(data[3].size()).copy_(data[3])

fasterRCNN.zero_grad()
rois, cls_prob, bbox_pred, \
Expand Down

0 comments on commit 0a1a74b

Please sign in to comment.