diff --git a/examples/text_to_image/train_text_to_image.py b/examples/text_to_image/train_text_to_image.py
index 90dd06d33c5e..a2ea8f17dd74 100644
--- a/examples/text_to_image/train_text_to_image.py
+++ b/examples/text_to_image/train_text_to_image.py
@@ -171,7 +171,13 @@ def log_validation(vae, text_encoder, tokenizer, unet, args, accelerator, weight
             autocast_ctx = torch.autocast(accelerator.device.type)
 
         with autocast_ctx:
-            image = pipeline(args.validation_prompts[i], num_inference_steps=20, generator=generator).images[0]
+            image = pipeline(
+                args.validation_prompts[i],
+                height=args.resolution,
+                width=args.resolution,
+                num_inference_steps=20,
+                generator=generator,
+            ).images[0]
 
         images.append(image)
 
@@ -1150,7 +1156,13 @@ def unwrap_model(model):
 
             for i in range(len(args.validation_prompts)):
                 with torch.autocast("cuda"):
-                    image = pipeline(args.validation_prompts[i], num_inference_steps=20, generator=generator).images[0]
+                    image = pipeline(
+                        args.validation_prompts[i],
+                        height=args.resolution,
+                        width=args.resolution,
+                        num_inference_steps=20,
+                        generator=generator,
+                    ).images[0]
                 images.append(image)
 
         if args.push_to_hub:
diff --git a/examples/text_to_image/train_text_to_image_lora.py b/examples/text_to_image/train_text_to_image_lora.py
index 310a50ac4e9a..9f10ba0586eb 100644
--- a/examples/text_to_image/train_text_to_image_lora.py
+++ b/examples/text_to_image/train_text_to_image_lora.py
@@ -131,7 +131,15 @@ def log_validation(
 
     with autocast_ctx:
         for _ in range(args.num_validation_images):
-            images.append(pipeline(args.validation_prompt, num_inference_steps=30, generator=generator).images[0])
+            images.append(
+                pipeline(
+                    args.validation_prompt,
+                    height=args.resolution,
+                    width=args.resolution,
+                    num_inference_steps=30,
+                    generator=generator,
+                ).images[0]
+            )
 
     for tracker in accelerator.trackers:
         phase_name = "test" if is_final_validation else "validation"
diff --git a/examples/text_to_image/train_text_to_image_lora_sdxl.py b/examples/text_to_image/train_text_to_image_lora_sdxl.py
index 88f5c3cede6e..900e6ea3d611 100644
--- a/examples/text_to_image/train_text_to_image_lora_sdxl.py
+++ b/examples/text_to_image/train_text_to_image_lora_sdxl.py
@@ -138,7 +138,11 @@ def log_validation(
 
     # run inference
     generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed is not None else None
-    pipeline_args = {"prompt": args.validation_prompt}
+    pipeline_args = {
+        "prompt": args.validation_prompt,
+        "height": args.resolution,
+        "width": args.resolution,
+    }
     if torch.backends.mps.is_available():
         autocast_ctx = nullcontext()
     else:
diff --git a/examples/text_to_image/train_text_to_image_sdxl.py b/examples/text_to_image/train_text_to_image_sdxl.py
index 4eafa8f28a19..5fc0d323a7a1 100644
--- a/examples/text_to_image/train_text_to_image_sdxl.py
+++ b/examples/text_to_image/train_text_to_image_sdxl.py
@@ -1258,7 +1258,11 @@ def compute_time_ids(original_size, crops_coords_top_left):
                     if args.seed is not None
                     else None
                 )
-                pipeline_args = {"prompt": args.validation_prompt}
+                pipeline_args = {
+                    "prompt": args.validation_prompt,
+                    "height": args.resolution,
+                    "width": args.resolution,
+                }
 
                 with autocast_ctx:
                     images = [
@@ -1327,7 +1331,13 @@
 
             with autocast_ctx:
                 images = [
-                    pipeline(args.validation_prompt, num_inference_steps=25, generator=generator).images[0]
+                    pipeline(
+                        args.validation_prompt,
+                        height=args.resolution,
+                        width=args.resolution,
+                        num_inference_steps=25,
+                        generator=generator,
+                    ).images[0]
                     for _ in range(args.num_validation_images)
                 ]
 
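Why the explicit height/width matters: when these arguments are omitted, the Stable Diffusion pipelines fall back to the UNet's default sample size (unet.config.sample_size * vae_scale_factor), so validation images come out at the base model's resolution rather than at the --resolution the model was just fine-tuned on. The following standalone sketch illustrates the difference; it is not part of the diff, and the model id, prompt, and 768px resolution are illustrative assumptions.

import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")
generator = torch.Generator(device="cuda").manual_seed(0)

# Without height/width, the pipeline defaults to
# unet.config.sample_size * vae_scale_factor = 64 * 8 = 512 for SD 1.5,
# regardless of the resolution used during fine-tuning.
image_default = pipe("a photo of a dog", num_inference_steps=20, generator=generator).images[0]

# Passing height/width explicitly, as the patched validation code now does,
# keeps validation images at the training resolution (here 768 as an example).
image_at_train_res = pipe(
    "a photo of a dog",
    height=768,
    width=768,
    num_inference_steps=20,
    generator=generator,
).images[0]

print(image_default.size, image_at_train_res.size)  # (512, 512) (768, 768)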