import torch

from utils import create_input_tensors

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# NOTE: ViterbiDecoder, which main() uses below, is defined in this same module
# (inference.py); its body is not reproduced in this section.


def main(sentence="does this thing work",
         checkpoint_path='BEST_checkpoint_lm_lstm_crf.pth.tar'):
    """
    Tag a whitespace-tokenised sentence with a saved model and print one "word [tag]" line per word.

    :param sentence: raw text to tag; it is split on whitespace
    :param checkpoint_path: path of the checkpoint file, which must hold the keys
        'model', 'word_map', 'tag_map' and 'char_map'
    :return: None; results are written to stdout
    """
    # NOTE(security): torch.load unpickles the file and this checkpoint stores a whole model
    # object, so only load checkpoints that come from a trusted source.
    # map_location lets a checkpoint saved on a GPU be loaded on a CPU-only machine and puts the
    # model on the same device the inputs are moved to below.
    checkpoint = torch.load(checkpoint_path, map_location=device)
    model = checkpoint['model']
    word_map = checkpoint['word_map']
    tag_map = checkpoint['tag_map']
    char_map = checkpoint['char_map']

    model.eval()

    parsed = sentence.split()
    if not parsed:
        # Nothing to tag; avoid pushing a zero-length sequence through the model.
        return

    # create_input_tensors and the model both take tag sequences, so every word gets the first
    # key of tag_map as a placeholder.
    # NOTE(review): presumably these placeholders do not influence crf_scores - confirm against
    # the model's forward().
    placeholder_tags = [next(iter(tag_map))] * len(parsed)
    (wmaps, cmaps_f, cmaps_b, cmarkers_f, cmarkers_b, tmaps,
     wmap_lengths, cmap_lengths) = create_input_tensors([parsed], [placeholder_tags],
                                                        word_map, char_map, tag_map)

    max_word_len = max(wmap_lengths.tolist())
    max_char_len = max(cmap_lengths.tolist())

    # Inverse lookup: tag index -> tag name
    rev_tag_map = {v: k for k, v in tag_map.items()}

    # Reduce batch's padded length to maximum in-batch sequence
    # This saves some compute on nn.Linear layers (RNNs are unaffected, since they don't compute over the pads)
    wmaps = wmaps[:, :max_word_len].to(device)
    cmaps_f = cmaps_f[:, :max_char_len].to(device)
    cmaps_b = cmaps_b[:, :max_char_len].to(device)
    cmarkers_f = cmarkers_f[:, :max_word_len].to(device)
    cmarkers_b = cmarkers_b[:, :max_word_len].to(device)
    tmaps = tmaps[:, :max_word_len].to(device)
    wmap_lengths = wmap_lengths.to(device)
    cmap_lengths = cmap_lengths.to(device)

    # Forward prop. Gradients are never needed here, so skip building the autograd graph.
    with torch.no_grad():
        crf_scores, _, _, wmap_lengths_sorted, _, _ = model(cmaps_f,
                                                            cmaps_b,
                                                            cmarkers_f,
                                                            cmarkers_b,
                                                            wmaps,
                                                            tmaps,
                                                            wmap_lengths,
                                                            cmap_lengths)

        # Decoding is done on the CPU copies of the scores and lengths
        crf_scores = crf_scores.to('cpu')
        wmap_lengths_sorted = wmap_lengths_sorted.to('cpu')

        decoder = ViterbiDecoder(tag_map)
        output = decoder.decode(crf_scores, wmap_lengths_sorted)

    # The batch holds a single sentence, so row 0 is its tag sequence. zip() stops at the shorter
    # of the two, which drops any decoded positions beyond the real words.
    for word, tag_idx in zip(parsed, output[0].tolist()):
        print("{} [{}]".format(word, rev_tag_map[tag_idx]))


if __name__ == '__main__':
    main()