Commit f2ef5fd: nonempty_output generation param
minimaxir committed Apr 17, 2021
1 parent 429b39f commit f2ef5fd
Showing 1 changed file with 106 additions and 88 deletions.
aitextgen/aitextgen.py: 194 changes (106 additions & 88 deletions)
@@ -282,7 +282,7 @@ def generate(
         normalize_key: bool = True,
         use_cache: bool = True,
         lstrip: bool = True,
-        special_tokens: List[int] = None,
+        nonempty_output: bool = True,
         **kwargs,
     ) -> Optional[str]:
         """
@@ -329,102 +329,120 @@ def generate(
         gen_max_length = model_max_length(self.model.config)
         max_length = min(gen_max_length, max_length)
 
-        outputs = self.model.generate(
-            input_ids=input_ids,
-            min_length=min_length,
-            max_length=max_length,
-            temperature=temperature,
-            do_sample=do_sample,
-            num_return_sequences=n,
-            pad_token_id=pad_token_id,
-            use_cache=use_cache,
-            **kwargs,
-        )
+        while True:
+            outputs = self.model.generate(
+                input_ids=input_ids,
+                min_length=min_length,
+                max_length=max_length,
+                temperature=temperature,
+                do_sample=do_sample,
+                num_return_sequences=n,
+                pad_token_id=pad_token_id,
+                use_cache=use_cache,
+                **kwargs,
+            )
 
-        # Reset seed if used
-        if seed:
-            reset_seed()
-
-        if special_tokens is None:
-            special_tokens = [self.tokenizer.bos_token_id, self.tokenizer.eos_token_id]
-
-        # Schema token handling
-        if schema:
-            schema_tokens = getattr(self.model.config, "schema_tokens")
-            schema_return = getattr(self.model.config, "schema_return")
-            schema_tokens_enc = self.tokenizer(text=schema_tokens)["input_ids"]
+            # Schema token handling
+            if schema:
+                schema_tokens = getattr(self.model.config, "schema_tokens")
+                schema_return = getattr(self.model.config, "schema_return")
+                schema_tokens_enc = self.tokenizer(text=schema_tokens)["input_ids"]
 
-            nonalphanum_pattern = re.compile(r"[\W_]+", re.UNICODE)
+                nonalphanum_pattern = re.compile(r"[\W_]+", re.UNICODE)
 
-            outputs = outputs.tolist()
-            gen_texts = []
-            for output in outputs:
-                gen_text_dict = {}
+                outputs = outputs.tolist()
+                gen_texts = []
+                for output in outputs:
+                    gen_text_dict = {}
 
-                # Get indices of each schema token within the text
-                schema_token_indices = [
-                    (schema_tokens[i], find_index_of_subset(output, token_enc))
-                    for i, token_enc in enumerate(schema_tokens_enc)
-                ]
-                schema_token_indices.sort(key=lambda x: x[1])
+                    # Get indices of each schema token within the text
+                    schema_token_indices = [
+                        (schema_tokens[i], find_index_of_subset(output, token_enc))
+                        for i, token_enc in enumerate(schema_tokens_enc)
+                    ]
+                    schema_token_indices.sort(key=lambda x: x[1])
 
-                for i, token_tuple in enumerate(schema_token_indices):
-                    start_index = token_tuple[1]
-                    end_index = (
-                        schema_token_indices[i + 1][1] - 1
-                        if i + 1 < len(schema_token_indices)
-                        else None
-                    )
-                    key = (
-                        nonalphanum_pattern.sub("", token_tuple[0])
-                        if normalize_key
-                        else token_tuple[0]
-                    )
+                    for i, token_tuple in enumerate(schema_token_indices):
+                        start_index = token_tuple[1]
+                        end_index = (
+                            schema_token_indices[i + 1][1] - 1
+                            if i + 1 < len(schema_token_indices)
+                            else None
+                        )
+                        key = (
+                            nonalphanum_pattern.sub("", token_tuple[0])
+                            if normalize_key
+                            else token_tuple[0]
+                        )
 
-                    gen_text_dict[key] = self.tokenizer.decode(
-                        output[start_index:end_index], skip_special_tokens=True
-                    )
+                        gen_text_dict[key] = self.tokenizer.decode(
+                            output[start_index:end_index], skip_special_tokens=True
+                        )
 
-                # remove fields not in schema_return
-                if schema_return:
-                    if len(schema_return) == 1:
-                        gen_text_dict = gen_text_dict[schema_return[0]]
-                    for key in gen_text_dict.keys():
-                        if key not in schema_return:
-                            gen_text_dict.pop(key, None)
+                    # remove fields not in schema_return
+                    if schema_return:
+                        if len(schema_return) == 1:
+                            gen_text_dict = gen_text_dict[schema_return[0]]
+                        for key in gen_text_dict.keys():
+                            if key not in schema_return:
+                                gen_text_dict.pop(key, None)
 
-                gen_texts.append(gen_text_dict)
+                    gen_texts.append(gen_text_dict)
+
+                # Reset seed if used
+                if seed:
+                    reset_seed()
 
-            if not return_as_list:
-                print(*gen_texts, sep="\n" + "=" * 10 + "\n")
-            else:
-                if n > 1:
-                    return gen_texts
-                else:
-                    return gen_texts[0]
+                if not return_as_list:
+                    print(*gen_texts, sep="\n" + "=" * 10 + "\n")
+                else:
+                    if n > 1:
+                        return gen_texts
+                    else:
+                        return gen_texts[0]
 
-        # Typical use case
-        else:
-            gen_texts = self.tokenizer.batch_decode(outputs, skip_special_tokens=True)
-
-            # Handle stripping tokenization spaces w/ regex
-            if lstrip:
-                gen_texts = [re.sub(r"^\s+", "", text) for text in gen_texts]
-
-            if not return_as_list:
-                if prompt:
-                    # Bold the prompt if printing to console
-                    gen_texts = [
-                        text.replace(prompt_text, f"\033[1m{prompt_text}\033[0m", 1)
-                        for text in gen_texts
-                    ]
-
-                if n > 1:
-                    print(*gen_texts, sep="\n" + "=" * 10 + "\n")
-                else:
-                    print(gen_texts[0])
-            else:
-                return gen_texts
+            # Typical use case
+            else:
+                gen_texts = self.tokenizer.batch_decode(
+                    outputs, skip_special_tokens=True
+                )
+
+                # Handle stripping tokenization spaces w/ regex
+                if lstrip:
+                    gen_texts = [re.sub(r"^\s+", "", text) for text in gen_texts]
+
+                if nonempty_output:
+                    if min_length:
+                        gen_texts = list(
+                            filter(lambda x: len(x) > min_length, gen_texts)
+                        )
+                    else:
+                        gen_texts = list(filter(lambda x: len(x) > 0, gen_texts))
+
+                    # if there is no generated text after cleanup, try again.
+                    if len(gen_texts) == 0:
+                        continue
+
+                # Reset seed if used
+                if seed:
+                    reset_seed()
+
+                if not return_as_list:
+                    if prompt:
+                        # Bold the prompt if printing to console
+                        gen_texts = [
+                            text.replace(prompt_text, f"\033[1m{prompt_text}\033[0m", 1)
+                            for text in gen_texts
+                        ]
+
+                    if n > 1:
+                        print(*gen_texts, sep="\n" + "=" * 10 + "\n")
+                    else:
+                        print(gen_texts[0])
+                    break
+                else:
+                    return gen_texts
 
     def generate_one(self, **kwargs) -> None:
         """
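For context, here is a minimal usage sketch of the changed generate() call. It is illustrative only and not part of this commit: the no-argument aitextgen() constructor is assumed to load the library's default small GPT-2 model, and the keyword arguments shown (n, prompt, max_length, return_as_list, nonempty_output) match parameters referenced in the diff above. With nonempty_output=True (the new default), generation is retried until at least one decoded text survives the empty-output filter.

from aitextgen import aitextgen

# Assumption: aitextgen() with no arguments loads the library's default small GPT-2 model.
ai = aitextgen()

# nonempty_output=True (the new default) re-runs generation until at least one
# decoded text is non-empty (or longer than min_length, when min_length is set).
texts = ai.generate(
    n=3,
    prompt="The meaning of life is",
    max_length=64,
    return_as_list=True,
    nonempty_output=True,
)
print(texts)

Passing nonempty_output=False restores the previous behavior: the decoded texts are returned or printed as-is, even if every one of them is empty.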
