Skip to content

Commit

Permalink
Add special_tokens param to generate()
Browse files (browse the repository at this point in the history)
  • Loading branch information
minimaxir committed Mar 5, 2021
1 parent 869e06a commit 3c76f0a
Showing 1 changed file with 6 additions and 2 deletions.
8 changes: 6 additions & 2 deletions aitextgen/aitextgen.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -265,6 +265,7 @@ def generate(
normalize_key: bool = True,
use_cache: bool = True,
lstrip: bool = True,
special_tokens: List[str] = None,
**kwargs,
) -> Optional[str]:
"""
Expand Down Expand Up @@ -324,6 +325,9 @@ def generate(
if seed:
reset_seed()

if special_tokens is None:
special_tokens = [self.tokenizer.bos_token_id, self.tokenizer.eos_token_id]

# Schema token handling
if schema:
schema_tokens = getattr(self.model.config, "schema_tokens")
Expand Down Expand Up @@ -360,7 +364,7 @@ def generate(
gen_text = skip_special_tokens(
output[start_index:end_index],
self.get_device(),
[self.tokenizer.bos_token_id, self.tokenizer.eos_token_id],
special_tokens,
)

gen_text_dict[key] = self.tokenizer.decode(gen_text)
Expand Down Expand Up @@ -390,7 +394,7 @@ def generate(
skip_special_tokens(
text,
self.get_device(),
[self.tokenizer.bos_token_id, self.tokenizer.eos_token_id],
special_tokens,
)
for text in outputs
]
Expand Down

0 comments on commit 3c76f0a

Please sign in to comment.