This repository was archived by the owner on Dec 16, 2022. It is now read-only.

Rethink vocabulary API #3097

@matt-gardner

Description

The way we handle vocabularies with BERT and other pretrained wordpiece tokenizers is a little crazy and leads to subtle bugs. We need to get the BERT wordpiece vocab into our Vocabulary object so that we can pass it to the model and use it there if necessary. But if you save a Vocabulary containing this wordpiece namespace, loading it fails with this error:

  File "tmp/tmp.py", line 30, in predict
    predictor = Predictor.from_path(path, 'masked_lm_predictor')
  File "/Users/mattg/clone/allennlp/allennlp/predictors/predictor.py", line 253, in from_path
    return Predictor.from_archive(load_archive(archive_path, cuda_device=cuda_device), predictor_name)
  File "/Users/mattg/clone/allennlp/allennlp/models/archival.py", line 230, in load_archive
    cuda_device=cuda_device)
  File "/Users/mattg/clone/allennlp/allennlp/models/model.py", line 327, in load
    return cls.by_name(model_type)._load(config, serialization_dir, weights_file, cuda_device)
  File "/Users/mattg/clone/allennlp/allennlp/models/model.py", line 256, in _load
    vocab = Vocabulary.by_name(vocab_choice).from_files(vocab_dir)
  File "/Users/mattg/clone/allennlp/allennlp/data/vocabulary.py", line 324, in from_files
    vocab.set_from_file(filename, is_padded, namespace=namespace)
  File "/Users/mattg/clone/allennlp/allennlp/data/vocabulary.py", line 378, in set_from_file
    assert self._oov_token in self._token_to_index[namespace], "OOV token not found!"
AssertionError: OOV token not found!
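
For reference, here is a minimal repro sketch of what I think is happening (the "bert" namespace name, the toy bert_vocab dict, and the /tmp path are made up for illustration; the Vocabulary calls are the real API). The wordpiece indexer writes pretrained ids straight into the vocab's internal dictionaries, clobbering the default @@PADDING@@/@@UNKNOWN@@ entries at indices 0 and 1:

from allennlp.data import Vocabulary

# Toy stand-in for a pretrained BERT wordpiece vocab; note it uses [UNK],
# not our default @@UNKNOWN@@ OOV token.
bert_vocab = {"[PAD]": 0, "[UNK]": 1, "[CLS]": 2, "[SEP]": 3, "the": 4}

vocab = Vocabulary()
# Mimic what the wordpiece indexer does: inject the pretrained ids directly,
# overwriting @@PADDING@@ at index 0 and @@UNKNOWN@@ at index 1.
for word, idx in bert_vocab.items():
    vocab._token_to_index["bert"][word] = idx
    vocab._index_to_token["bert"][idx] = word

vocab.save_to_files("/tmp/vocabulary")

# The "bert" namespace is padded by default, so set_from_file expects to find
# the OOV token in the saved file; the file only contains [UNK], so this
# raises AssertionError: OOV token not found!
Vocabulary.from_files("/tmp/vocabulary")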

Somehow our existing BERT models are not actually saving the BERT vocabulary, even though it appears in the vocab object after indexing data. If I'm understanding things right, this is because we only count vocab items on our first pass through the data and don't actually index anything while constructing the vocab. Indexing only happens after we've already saved the vocab:

# Initializing the model can have side effect of expanding the vocabulary
vocab.save_to_files(os.path.join(serialization_dir, "vocabulary"))
iterator = DataIterator.from_params(params.pop("iterator"))
iterator.index_with(model.vocab)
validation_iterator_params = params.pop("validation_iterator", None)
if validation_iterator_params:
    validation_iterator = DataIterator.from_params(validation_iterator_params)
    validation_iterator.index_with(model.vocab)
else:
    validation_iterator = None
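
For context, the wordpiece indexer only injects its pretrained vocab into the Vocabulary lazily, on the first call to tokens_to_indices, which only happens while iterating over batches, i.e. after train_model has already written the vocab to disk. Roughly (paraphrasing the indexer, with the actual indexing elided):

def tokens_to_indices(self, tokens, vocabulary, index_name):
    # The pretrained wordpiece vocab is only copied into the Vocabulary
    # here, on first use, long after vocab.save_to_files() has run.
    if not self._added_to_vocabulary:
        self._add_encoding_to_vocabulary(vocabulary)
        self._added_to_vocabulary = True
    ...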

This means we are currently avoiding the above bug only by an accident of timing: the vocab is written to disk before the wordpiece indexer has had a chance to expand it, so the saved files never contain the problematic namespace. We should figure out a better solution for this.
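
One possible band-aid, just a sketch rather than a committed design, would be to register the wordpiece namespace as non-padded when the vocab is constructed, so that saving and loading never expect padding or OOV tokens in it:

# Hypothetical band-aid: treat the "bert" namespace as non-padded so that
# Vocabulary.from_files does not assert that the OOV token is present.
vocab = Vocabulary(non_padded_namespaces=["bert"])

That only papers over the load error, though; it still would not guarantee that the expanded vocabulary actually makes it to disk, so the saved vocab and the model could stay out of sync.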
