From 1df12112c93c3efc78e8101a436557409fd9cc77 Mon Sep 17 00:00:00 2001 From: Max Woolf Date: Sun, 17 May 2020 13:52:39 -0700 Subject: [PATCH] More gen param tweaks --- aitextgen/aitextgen.py | 2 +- aitextgen/train.py | 5 ++++- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/aitextgen/aitextgen.py b/aitextgen/aitextgen.py index af956de..4e76eb7 100644 --- a/aitextgen/aitextgen.py +++ b/aitextgen/aitextgen.py @@ -234,7 +234,7 @@ def generate( self, n: int = 1, prompt: str = None, - max_length: int = 200, + max_length: int = 1024, temperature: float = 0.7, do_sample: bool = True, return_as_list: bool = False, diff --git a/aitextgen/train.py b/aitextgen/train.py index 3f66842..5ca1d1d 100644 --- a/aitextgen/train.py +++ b/aitextgen/train.py @@ -191,7 +191,10 @@ def generate_sample_text(self, trainer, pl_module): gen_length = min(pl_module.model.config.n_positions, 256) outputs = pl_module.model.generate( - max_length=gen_length, do_sample=True, num_return_sequences=self.n_generate + max_length=gen_length, + do_sample=True, + num_return_sequences=self.n_generate, + temperature=0.7, ) gen_texts = [ pl_module.tokenizer.decode(output, skip_special_tokens=True)