Skip to content

Commit

Permalink
Correct default GPT-2 tokenizer behavior (#111)
Browse files Browse the repository at this point in the history
  • Loading branch information
minimaxir committed Apr 17, 2021
1 parent 6cff5f4 commit 429b39f
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 37 deletions.
5 changes: 5 additions & 0 deletions aitextgen/TokenDataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,11 @@ def __init__(
eos_token=eos_token,
unk_token=unk_token,
pad_token=pad_token,
+ verbose=False,
+ )
+ # https://github.com/huggingface/transformers/issues/10202
+ tokenizer.add_special_tokens(
+ {"additional_special_tokens": ["<|endoftext|>"]}
)

# If a cache path is provided, load it.
Expand Down
29 changes: 9 additions & 20 deletions aitextgen/aitextgen.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@
model_max_length,
reset_seed,
set_seed,
- skip_special_tokens,
)

logger = logging.getLogger("aitextgen")
Expand Down Expand Up @@ -249,7 +248,13 @@ def __init__(
eos_token=self.eos_token,
unk_token=self.unk_token,
pad_token=self.pad_token,
+ verbose=False,
)
+ if not custom_tokenizer:
+ # https://github.com/huggingface/transformers/issues/10202
+ self.tokenizer.add_special_tokens(
+ {"additional_special_tokens": ["<|endoftext|>"]}
+ )

self.tokenizer.padding_side = "left"

Expand Down Expand Up @@ -376,14 +381,10 @@ def generate(
else token_tuple[0]
)

- gen_text = skip_special_tokens(
- output[start_index:end_index],
- self.get_device(),
- special_tokens,
+ gen_text_dict[key] = self.tokenizer.decode(
+ output[start_index:end_index], skip_special_tokens=True
)

- gen_text_dict[key] = self.tokenizer.decode(gen_text)
-
# remove fields not in schema_return
if schema_return:
if len(schema_return) == 1:
Expand All @@ -404,19 +405,7 @@ def generate(

# Typical use case
else:
- # Handle special token stripping at the PyTorch level
- gen_texts = [
- skip_special_tokens(
- text,
- self.get_device(),
- special_tokens,
- )
- for text in outputs
- ]
- if n > 1:
- gen_texts = self.tokenizer.batch_decode(gen_texts)
- else:
- gen_texts = [self.tokenizer.decode(gen_texts[0])]
+ gen_texts = self.tokenizer.batch_decode(outputs, skip_special_tokens=True)

# Handle stripping tokenization spaces w/ regex
if lstrip:
Expand Down
18 changes: 1 addition & 17 deletions aitextgen/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,23 +211,7 @@ def generate_sample_text(self, trainer, pl_module):
pad_token_id=pad_token_id,
)

- special_token_id_tensor = torch.unique(
- torch.as_tensor(
- [pl_module.tokenizer.bos_token_id, pl_module.tokenizer.eos_token_id]
- )
- ).to(pl_module.model.device.type)
-
- outputs = [
- output[
- ~output.unsqueeze(1).eq(special_token_id_tensor.unsqueeze(1)).any(1)
- ].tolist()
- for output in outputs
- ]
-
- if self.n_generate > 1:
- gen_texts = pl_module.tokenizer.batch_decode(outputs)
- else:
- gen_texts = [pl_module.tokenizer.decode(outputs[0])]
+ gen_texts = pl_module.tokenizer.batch_decode(outputs, skip_special_tokens=True)

for text in gen_texts:
self.main_progress_bar.write("=" * 10)
Expand Down

0 comments on commit 429b39f

Please sign in to comment.