diff --git a/aitextgen/aitextgen.py b/aitextgen/aitextgen.py
index 10a638e..719e8f6 100644
--- a/aitextgen/aitextgen.py
+++ b/aitextgen/aitextgen.py
@@ -282,7 +282,7 @@ def generate(
         normalize_key: bool = True,
         use_cache: bool = True,
         lstrip: bool = True,
-        special_tokens: List[int] = None,
+        nonempty_output: bool = True,
         **kwargs,
     ) -> Optional[str]:
         """
@@ -329,102 +329,120 @@ def generate(
         gen_max_length = model_max_length(self.model.config)
         max_length = min(gen_max_length, max_length)
 
-        outputs = self.model.generate(
-            input_ids=input_ids,
-            min_length=min_length,
-            max_length=max_length,
-            temperature=temperature,
-            do_sample=do_sample,
-            num_return_sequences=n,
-            pad_token_id=pad_token_id,
-            use_cache=use_cache,
-            **kwargs,
-        )
+        while True:
+            outputs = self.model.generate(
+                input_ids=input_ids,
+                min_length=min_length,
+                max_length=max_length,
+                temperature=temperature,
+                do_sample=do_sample,
+                num_return_sequences=n,
+                pad_token_id=pad_token_id,
+                use_cache=use_cache,
+                **kwargs,
+            )
 
-        # Reset seed if used
-        if seed:
-            reset_seed()
+            # Schema token handling
+            if schema:
+                schema_tokens = getattr(self.model.config, "schema_tokens")
+                schema_return = getattr(self.model.config, "schema_return")
+                schema_tokens_enc = self.tokenizer(text=schema_tokens)["input_ids"]
 
-        if special_tokens is None:
-            special_tokens = [self.tokenizer.bos_token_id, self.tokenizer.eos_token_id]
-
-        # Schema token handling
-        if schema:
-            schema_tokens = getattr(self.model.config, "schema_tokens")
-            schema_return = getattr(self.model.config, "schema_return")
-            schema_tokens_enc = self.tokenizer(text=schema_tokens)["input_ids"]
-
-            nonalphanum_pattern = re.compile(r"[\W_]+", re.UNICODE)
-
-            outputs = outputs.tolist()
-            gen_texts = []
-            for output in outputs:
-                gen_text_dict = {}
-
-                # Get indices of each schema token within the text
-                schema_token_indices = [
-                    (schema_tokens[i], find_index_of_subset(output, token_enc))
-                    for i, token_enc in enumerate(schema_tokens_enc)
-                ]
-                schema_token_indices.sort(key=lambda x: x[1])
-
-                for i, token_tuple in enumerate(schema_token_indices):
-                    start_index = token_tuple[1]
-                    end_index = (
-                        schema_token_indices[i + 1][1] - 1
-                        if i + 1 < len(schema_token_indices)
-                        else None
-                    )
-                    key = (
-                        nonalphanum_pattern.sub("", token_tuple[0])
-                        if normalize_key
-                        else token_tuple[0]
-                    )
+                nonalphanum_pattern = re.compile(r"[\W_]+", re.UNICODE)
 
-                    gen_text_dict[key] = self.tokenizer.decode(
-                        output[start_index:end_index], skip_special_tokens=True
-                    )
+                outputs = outputs.tolist()
+                gen_texts = []
+                for output in outputs:
+                    gen_text_dict = {}
 
-                # remove fields not in schema_return
-                if schema_return:
-                    if len(schema_return) == 1:
-                        gen_text_dict = gen_text_dict[schema_return[0]]
-                    for key in gen_text_dict.keys():
-                        if key not in schema_return:
-                            gen_text_dict.pop(key, None)
+                    # Get indices of each schema token within the text
+                    schema_token_indices = [
+                        (schema_tokens[i], find_index_of_subset(output, token_enc))
+                        for i, token_enc in enumerate(schema_tokens_enc)
+                    ]
+                    schema_token_indices.sort(key=lambda x: x[1])
+
+                    for i, token_tuple in enumerate(schema_token_indices):
+                        start_index = token_tuple[1]
+                        end_index = (
+                            schema_token_indices[i + 1][1] - 1
+                            if i + 1 < len(schema_token_indices)
+                            else None
+                        )
+                        key = (
+                            nonalphanum_pattern.sub("", token_tuple[0])
+                            if normalize_key
+                            else token_tuple[0]
+                        )
+
+                        gen_text_dict[key] = self.tokenizer.decode(
+                            output[start_index:end_index], skip_special_tokens=True
+                        )
+
+                    # remove fields not in schema_return
+                    if schema_return:
+                        if len(schema_return) == 1:
+                            gen_text_dict = gen_text_dict[schema_return[0]]
+                        for key in gen_text_dict.keys():
+                            if key not in schema_return:
+                                gen_text_dict.pop(key, None)
+
+                    gen_texts.append(gen_text_dict)
+
+                # Reset seed if used
+                if seed:
+                    reset_seed()
+
+                if not return_as_list:
+                    print(*gen_texts, sep="\n" + "=" * 10 + "\n")
+                else:
 
-                gen_texts.append(gen_text_dict)
+                    if n > 1:
+                        return gen_texts
+                    else:
+                        return gen_texts[0]
 
-            if not return_as_list:
-                print(*gen_texts, sep="\n" + "=" * 10 + "\n")
+            # Typical use case
             else:
-                if n > 1:
-                    return gen_texts
-                else:
-                    return gen_texts[0]
-
-        # Typical use case
-        else:
-            gen_texts = self.tokenizer.batch_decode(outputs, skip_special_tokens=True)
-
-            # Handle stripping tokenization spaces w/ regex
-            if lstrip:
-                gen_texts = [re.sub(r"^\s+", "", text) for text in gen_texts]
-
-            if not return_as_list:
-                if prompt:
-                    # Bold the prompt if printing to console
-                    gen_texts = [
-                        text.replace(prompt_text, f"\033[1m{prompt_text}\033[0m", 1)
-                        for text in gen_texts
-                    ]
+                gen_texts = self.tokenizer.batch_decode(
+                    outputs, skip_special_tokens=True
+                )
 
-                if n > 1:
-                    print(*gen_texts, sep="\n" + "=" * 10 + "\n")
+                # Handle stripping tokenization spaces w/ regex
+                if lstrip:
+                    gen_texts = [re.sub(r"^\s+", "", text) for text in gen_texts]
+
+                if nonempty_output:
+                    if min_length:
+                        gen_texts = list(
+                            filter(lambda x: len(x) > min_length, gen_texts)
+                        )
+                    else:
+                        gen_texts = list(filter(lambda x: len(x) > 0, gen_texts))
+
+                # if there is no generated text after cleanup, try again.
+                if len(gen_texts) == 0:
+                    continue
+
+                # Reset seed if used
+                if seed:
+                    reset_seed()
+
+                if not return_as_list:
+                    if prompt:
+                        # Bold the prompt if printing to console
+                        gen_texts = [
+                            text.replace(prompt_text, f"\033[1m{prompt_text}\033[0m", 1)
+                            for text in gen_texts
+                        ]
+
+                    if n > 1:
+                        print(*gen_texts, sep="\n" + "=" * 10 + "\n")
+                    else:
+                        print(gen_texts[0])
+                    break
                 else:
-                    print(gen_texts[0])
-            else:
-                return gen_texts
+                    return gen_texts
 
     def generate_one(self, **kwargs) -> None:
         """
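
Usage sketch (not part of the patch): a minimal example of the `nonempty_output` parameter that replaces `special_tokens` above. It assumes the default 124M GPT-2 model that `aitextgen()` loads when no model is specified; the prompt string is a placeholder.

    from aitextgen import aitextgen

    ai = aitextgen()  # assumption: default 124M GPT-2; any loaded model behaves the same

    # With nonempty_output=True (the new default), empty generations are filtered
    # out and generate() keeps looping in its while-True retry until at least one
    # text survives the cleanup.
    ai.generate(n=3, prompt="Hello", max_length=64, nonempty_output=True)

    # nonempty_output=False skips the filtering, so blank strings may come back
    # and no retry is triggered.
    texts = ai.generate(
        n=3,
        prompt="Hello",
        max_length=64,
        return_as_list=True,
        nonempty_output=False,
    )

Note that the retry only re-runs generation when every decoded text is filtered out, so the common case still performs a single `model.generate()` call.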