Rethink vocabulary API #3097
Description
The way we do vocab stuff with BERT / other pretrained wordpiece tokenizers is a little crazy and leads to subtle bugs. We need to get the BERT wordpiece vocab into our Vocabulary object so that we can pass it to the model and use it there if necessary. But if you save a vocab object that contains this wordpiece vocabulary, loading it fails with this error:
File "tmp/tmp.py", line 30, in predict
predictor = Predictor.from_path(path, 'masked_lm_predictor')
File "/Users/mattg/clone/allennlp/allennlp/predictors/predictor.py", line 253, in from_path
return Predictor.from_archive(load_archive(archive_path, cuda_device=cuda_device), predictor_name)
File "/Users/mattg/clone/allennlp/allennlp/models/archival.py", line 230, in load_archive
cuda_device=cuda_device)
File "/Users/mattg/clone/allennlp/allennlp/models/model.py", line 327, in load
return cls.by_name(model_type)._load(config, serialization_dir, weights_file, cuda_device)
File "/Users/mattg/clone/allennlp/allennlp/models/model.py", line 256, in _load
vocab = Vocabulary.by_name(vocab_choice).from_files(vocab_dir)
File "/Users/mattg/clone/allennlp/allennlp/data/vocabulary.py", line 324, in from_files
vocab.set_from_file(filename, is_padded, namespace=namespace)
File "/Users/mattg/clone/allennlp/allennlp/data/vocabulary.py", line 378, in set_from_file
assert self._oov_token in self._token_to_index[namespace], "OOV token not found!"
AssertionError: OOV token not found!
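For context, the assertion is not specific to BERT: it fires for any padded namespace whose saved token file no longer contains AllenNLP's default OOV token (`@@UNKNOWN@@`). Below is a minimal, hypothetical sketch (not code from the issue; the `"bert"` namespace name and token list are made up) that fills a namespace's private `_token_to_index` / `_index_to_token` dicts roughly the way the wordpiece indexer does, and should reproduce the failure:

```python
import tempfile

from allennlp.data import Vocabulary

vocab = Vocabulary()

# Mimic the wordpiece indexer: copy a pretrained wordpiece vocab directly into
# a namespace's dictionaries.  BERT uses "[UNK]" rather than AllenNLP's
# "@@UNKNOWN@@", and its indices start at 0, so the default padding/OOV
# entries that would normally sit at indices 0 and 1 get overwritten.
bert_wordpieces = {"[PAD]": 0, "[UNK]": 1, "[CLS]": 2, "[SEP]": 3, "the": 4, "##ing": 5}
for wordpiece, index in bert_wordpieces.items():
    vocab._token_to_index["bert"][wordpiece] = index
    vocab._index_to_token["bert"][index] = wordpiece

vocab_dir = tempfile.mkdtemp()
vocab.save_to_files(vocab_dir)  # bert.txt is written without "@@UNKNOWN@@"

# "bert" is a padded namespace, so from_files() expects the default OOV token
# to be present and raises "AssertionError: OOV token not found!".
Vocabulary.from_files(vocab_dir)
```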
Somehow our existing BERT models are not actually saving the BERT vocabulary, even though it appears in the vocab object after indexing data. If I'm understanding things right, this is because we only count vocab items on our first pass through the data, and don't actually index anything when constructing the vocab. It's only after we've saved the vocab that we index stuff:
allennlp/allennlp/training/trainer_pieces.py, lines 67 to 77 at 417a757:

```python
# Initializing the model can have side effect of expanding the vocabulary
vocab.save_to_files(os.path.join(serialization_dir, "vocabulary"))

iterator = DataIterator.from_params(params.pop("iterator"))
iterator.index_with(model.vocab)
validation_iterator_params = params.pop("validation_iterator", None)
if validation_iterator_params:
    validation_iterator = DataIterator.from_params(validation_iterator_params)
    validation_iterator.index_with(model.vocab)
else:
    validation_iterator = None
```
This means that we are basically only avoiding the above bug because of an accident of timing in when we save the vocab. We should probably figure out a better solution for this.
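To make that timing concrete, here is a hedged sketch (hypothetical usage, not code from the repo; the model name and `index_name` are just illustrative) of the side effect in question: the wordpiece vocab only gets copied into the `Vocabulary` the first time the indexer is actually used, so the save in `trainer_pieces.py` happens to run while the vocabulary is still "clean", and any save performed after indexing would produce a directory that fails to load as above.

```python
from allennlp.data import Token, Vocabulary
from allennlp.data.token_indexers import PretrainedBertIndexer

indexer = PretrainedBertIndexer(pretrained_model="bert-base-uncased")
vocab = Vocabulary()

# At this point the vocabulary contains no wordpiece namespace, so saving it
# here (as trainer_pieces.py happens to do) yields files that load back fine.

# The first call to tokens_to_indices() copies the full BERT wordpiece vocab
# into `vocab` under the indexer's namespace as a side effect ...
indexer.tokens_to_indices([Token("hello")], vocab, "bert")

# ... and from this point on, a save_to_files() / from_files() round trip on
# this vocabulary hits the "OOV token not found!" assertion shown above.
```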