diff --git a/train.py b/train.py index 3c3513a3..77f96ead 100644 --- a/train.py +++ b/train.py @@ -90,10 +90,6 @@ def weights_init(m): crnn = crnn.CRNN(opt.imgH, nc, nclass, opt.nh) crnn.apply(weights_init) -if opt.pretrained != '': - print('loading pretrained model from %s' % opt.pretrained) - crnn.load_state_dict(torch.load(opt.pretrained)) -print(crnn) image = torch.FloatTensor(opt.batchSize, 3, opt.imgH, opt.imgH) text = torch.IntTensor(opt.batchSize * 5) @@ -105,6 +101,13 @@ def weights_init(m): image = image.cuda() criterion = criterion.cuda() +if opt.pretrained != '': + print('loading pretrained model from %s' % opt.pretrained) + crnn.load_state_dict(torch.load(opt.pretrained)) +print(crnn) + + + image = Variable(image) text = Variable(text) length = Variable(length)