Commit ee57769

Removed NLLBTokenizer.

visheratin committed Dec 8, 2023
1 parent 9b6b13f
Showing 2 changed files with 15 additions and 73 deletions.
11 changes: 2 additions & 9 deletions src/open_clip/factory.py
@@ -18,7 +18,7 @@
 from .pretrained import is_pretrained_cfg, get_pretrained_cfg, download_pretrained,\
     list_pretrained_tags_by_model, download_pretrained_from_hf
 from .transform import image_transform_v2, AugmentationCfg, PreprocessCfg, merge_preprocess_dict, merge_preprocess_kwargs
-from .tokenizer import HFTokenizer, NLLBTokenizer, SimpleTokenizer, DEFAULT_CONTEXT_LENGTH
+from .tokenizer import HFTokenizer, SimpleTokenizer, DEFAULT_CONTEXT_LENGTH

 HF_HUB_PREFIX = 'hf-hub:'
 _MODEL_CONFIG_PATHS = [Path(__file__).parent / f"model_configs/"]
@@ -110,18 +110,11 @@ def get_tokenizer(
     context_length = text_config.get('context_length', DEFAULT_CONTEXT_LENGTH)

     if 'hf_tokenizer_name' in text_config:
-        if model_name.startswith("nllb"):
-            tokenizer = NLLBTokenizer(
+        tokenizer = HFTokenizer(
             text_config['hf_tokenizer_name'],
             context_length=context_length,
             **tokenizer_kwargs,
         )
-        else:
-            tokenizer = HFTokenizer(
-                text_config['hf_tokenizer_name'],
-                context_length=context_length,
-                **tokenizer_kwargs,
-            )
     else:
         tokenizer = SimpleTokenizer(
             context_length=context_length,
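With this change, any model whose text config includes hf_tokenizer_name resolves to the generic HFTokenizer; there is no NLLB-specific branch on the model name anymore. A minimal caller-side sketch (the nllb-clip-base tag is illustrative; any HF-configured model behaves the same):

    import open_clip

    # NLLB models now go through the same path as every other HF tokenizer.
    tokenizer = open_clip.get_tokenizer('nllb-clip-base')
    input_ids = tokenizer(['a photo of a cat'])  # tensor of shape (1, context_length)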
77 changes: 13 additions & 64 deletions src/open_clip/tokenizer.py
@@ -9,6 +9,7 @@
 import string
 from functools import lru_cache, partial
 from typing import Callable, List, Optional, Union
+import warnings

 import ftfy
 import numpy as np
@@ -402,9 +403,17 @@ def __init__(
             context_length: Optional[int] = DEFAULT_CONTEXT_LENGTH,
             clean: str = 'whitespace',
             strip_sep_token: bool = False,
+            language: Optional[str] = None,
     ):
         from transformers import AutoTokenizer
         self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
+        if language is not None:
+            set_lang_fn = getattr(self.tokenizer, 'set_src_lang_special_tokens', None)
+            if callable(set_lang_fn):
+                set_lang_fn(language)
+                self.set_lang_fn = set_lang_fn
+            else:
+                warnings.warn(f'Cannot set language for tokenizer {tokenizer_name}.')
         self.context_length = context_length
         self.clean_fn = get_clean_fn(clean)
         self.strip_sep_token = strip_sep_token
@@ -438,6 +447,10 @@ def __call__(self, texts: Union[str, List[str]], context_length: Optional[int] =
             )

         return input_ids

+    def set_language(self, src_lang):
+        if hasattr(self, 'set_lang_fn'):
+            self.set_lang_fn(src_lang)
+

 class SigLipTokenizer:
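Taken together, these hunks fold the NLLB language handling into HFTokenizer: a language passed at construction is applied through the underlying tokenizer's set_src_lang_special_tokens (when it exists), and set_language() switches it later. A usage sketch, assuming an NLLB checkpoint; facebook/nllb-200-distilled-600M and the FLORES-200 codes fra_Latn/deu_Latn are illustrative:

    from open_clip.tokenizer import HFTokenizer

    tokenizer = HFTokenizer(
        'facebook/nllb-200-distilled-600M',  # NLLB tokenizers expose set_src_lang_special_tokens
        context_length=77,
        language='fra_Latn',                 # source-language special tokens set up front
    )
    french_ids = tokenizer(['une photo de chat'])

    # One language per batch now; switch between batches as needed.
    tokenizer.set_language('deu_Latn')
    german_ids = tokenizer(['ein Foto von einer Katze'])

For tokenizers without set_src_lang_special_tokens, the constructor only emits the warning above and later set_language() calls fall through silently, so non-NLLB models are unaffected.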
@@ -495,67 +508,3 @@ def __call__(self, texts: Union[str, List[str]], context_length: Optional[int] =
             truncation=True,
         )
         return output.input_ids
-
-
-class NLLBTokenizer:
-    """HuggingFace tokenizer wrapper for NLLB models"""
-
-    def __init__(
-        self,
-        tokenizer_name: str,
-        context_length: Optional[int] = DEFAULT_CONTEXT_LENGTH,
-        clean: str = "whitespace",
-    ):
-        from transformers import AutoTokenizer
-
-        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
-        self.context_length = context_length
-        self.clean_fn = get_clean_fn(clean)
-
-    def save_pretrained(self, dest):
-        self.tokenizer.save_pretrained(dest)
-
-    def __call__(
-        self,
-        texts: Union[str, List[str]],
-        langs: Union[str, List[str], None],
-        context_length: Optional[int] = None,
-    ) -> torch.Tensor:
-        import warnings
-
-        if isinstance(texts, str):
-            texts = [texts]
-
-        context_length = context_length or self.context_length
-        assert (
-            context_length
-        ), "Please set a valid context length in class init or call."
-
-        # same cleaning as for default tokenizer, except lowercasing
-        # adding lower (for case-sensitive tokenizers) will make it more robust but less sensitive to nuance
-        texts = [self.clean_fn(text) for text in texts]
-        if langs is None:
-            warnings.warn("No languages provided, assuming all texts are in English.")
-            input_ids = self.tokenizer.batch_encode_plus(
-                texts,
-                return_tensors="pt",
-                max_length=context_length,
-                padding="max_length",
-                truncation=True,
-            ).input_ids
-        else:
-            assert len(texts) == len(langs), "Please provide a language for each text."
-            text_input_ids = []
-            for i, text in enumerate(texts):
-                self.tokenizer.set_src_lang_special_tokens(langs[i])
-                text_input_ids.append(
-                    self.tokenizer.batch_encode_plus(
-                        [text],
-                        return_tensors="pt",
-                        max_length=context_length,
-                        padding="max_length",
-                        truncation=True,
-                    ).input_ids
-                )
-            input_ids = torch.stack(text_input_ids).squeeze()
-        return input_ids
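One behavioral difference worth noting: the removed class accepted a per-text langs list and reset the special tokens inside a loop, while HFTokenizer applies a single source language to a whole batch. A hedged migration sketch for callers that had mixed-language batches (grouping by language is one obvious replacement, not something this commit provides):

    # Before: input_ids = nllb_tokenizer(texts, langs=['fra_Latn', 'deu_Latn'])
    # After: group texts by language, one set_language() call per group.
    batches = {'fra_Latn': ['une photo de chat'], 'deu_Latn': ['ein Foto']}
    ids_by_lang = {}
    for lang, group in batches.items():
        tokenizer.set_language(lang)  # the HFTokenizer from the sketch above
        ids_by_lang[lang] = tokenizer(group)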
