diff --git a/maxperf.py b/maxperf.py
index b40e841..8124d13 100644
--- a/maxperf.py
+++ b/maxperf.py
@@ -1,3 +1,4 @@
+import os
 import sys
 import PIL
 from PyQt5 import QtWidgets, QtCore
@@ -19,7 +20,19 @@
 
 mw = None
 batchSize = 10
-prompts = ['Evil space kitty', 'Cute dog in hat, H.R. Giger style', 'Horse wearing a tie', 'Cartoon pig', 'Donkey on Mars', 'Cute kitties baked in a cake', 'Boxing chickens on farm, Maxfield Parish style', 'Future spaceship', 'A city of the past', 'Jabba the Hut wearing jewelery']
+
+custom_prompts_path = "prompts.txt"
+if os.path.exists(custom_prompts_path):
+    with open(custom_prompts_path, "r") as file:
+        lines = file.readlines()
+    prompts = [line.strip() for line in lines]
+
+else:
+    prompts = ['Evil space kitty', 'Cute dog in hat, H.R. Giger style', 'Horse wearing a tie', 'Cartoon pig', 'Donkey on Mars', 'Cute kitties baked in a cake', 'Boxing chickens on farm, Maxfield Parish style', 'Future spaceship', 'A city of the past', 'Jabba the Hut wearing jewelery']
+
+print(f"Using the following prompts:", *prompts, sep='\n')
+
+prompts_len = len(prompts)
 
 def dwencode(pipe, prompts, batchSize: int, nTokens: int):
     tokenizer = pipe.tokenizer
@@ -244,14 +257,21 @@ def genit(mode, prompts, batchSize, nSteps):
     return images
 
 if __name__ == '__main__':
+
+
     if len(sys.argv) == 2:
         batchSize = int(sys.argv[1])
-        if batchSize > 10:
-            print('Batchsize must not be greater than 10.')
-        prompts = prompts[:batchSize]
+
+        if batchSize > prompts_len:
+            prompts=prompts * (1 + batchSize // prompts_len)
+            print(prompts_len, prompts)
+
+
     else:
-        batchSize = 10
-        prompts = ['Evil space kitty', 'Cute dog in hat, H.R. Giger style', 'Horse wearing a tie', 'Cartoon pig', 'Donkey on Mars', 'Cute kitties baked in a cake', 'Boxing chickens on farm, Maxfield Parish style', 'Future spaceship', 'A city of the past', 'Jabba the Hut wearing jewelery']
+        batchSize = prompts_len
+
+    prompts = prompts[:batchSize]
+
     app = QApplication(sys.argv)
     mw = MainWindow()
     mw.show()
diff --git a/prompts.txt b/prompts.txt
new file mode 100644
index 0000000..e4609c6
--- /dev/null
+++ b/prompts.txt
@@ -0,0 +1 @@
+a fluffy cat meme
\ No newline at end of file
diff --git a/requirements.txt b/requirements.txt
index 63e5bfb..26ce93e 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -21,4 +21,5 @@ regex
 accelerate
 omegaconf
 piexif
+olefile
 pathvalidate
\ No newline at end of file
diff --git a/src/stable_diffusion_base.py b/src/stable_diffusion_base.py
index 56b144b..12297b1 100644
--- a/src/stable_diffusion_base.py
+++ b/src/stable_diffusion_base.py
@@ -87,7 +87,7 @@ def setup_torch_compilation(self):
         self.pipe.text_encoder = torch.compile(self.pipe.text_encoder, mode='max-autotune')
         self.pipe.unet = torch.compile(self.pipe.unet, mode='max-autotune')
         self.pipe.vae = torch.compile(self.pipe.vae, mode='max-autotune')
-        self.perform_warmup()
+        # self.perform_warmup()
 
     def perform_warmup(self):
         self._logger.info(