diff --git a/predict.py b/predict.py index 4c280c5..3a8f0a1 100644 --- a/predict.py +++ b/predict.py @@ -514,7 +514,8 @@ def bf16_predict( ) timesteps = get_schedule( num_inference_steps, - x.shape[1], + # equivalent to inp["img"].shape[1], needs to be here for prompt strength in img2img + (x.shape[-1] * x.shape[-2]) // 4, shift=self.shift, )