Skip to content

Commit 848e798

Browse files
Added resolution scaling
used code from lllyasviel#158
1 parent 5e14b91 commit 848e798

File tree

1 file changed

+24
-18
lines changed

1 file changed

+24
-18
lines changed

diffusers_helper/bucket_tools.py

Lines changed: 24 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,30 +1,36 @@
11
bucket_options = {
2-
640: [
3-
(416, 960),
4-
(448, 864),
5-
(480, 832),
6-
(512, 768),
7-
(544, 704),
8-
(576, 672),
9-
(608, 640),
10-
(640, 608),
11-
(672, 576),
12-
(704, 544),
13-
(768, 512),
14-
(832, 480),
15-
(864, 448),
16-
(960, 416),
17-
],
2+
(416, 960),
3+
(448, 864),
4+
(480, 832),
5+
(512, 768),
6+
(544, 704),
7+
(576, 672),
8+
(608, 640),
9+
(640, 640),# square
10+
(640, 608),
11+
(672, 576),
12+
(704, 544),
13+
(768, 512),
14+
(832, 480),
15+
(864, 448),
16+
(960, 416),
1817
}
1918

2019

2120
def find_nearest_bucket(h, w, resolution=640):
2221
min_metric = float('inf')
2322
best_bucket = None
24-
for (bucket_h, bucket_w) in bucket_options[resolution]:
23+
for (bucket_h, bucket_w) in bucket_options:
2524
metric = abs(h * bucket_w - w * bucket_h)
2625
if metric <= min_metric:
2726
min_metric = metric
2827
best_bucket = (bucket_h, bucket_w)
29-
return best_bucket
3028

29+
if resolution != 640:
30+
scale_factor = resolution / 640.0
31+
scaled_height = round(best_bucket[0] * scale_factor / 16) * 16
32+
scaled_width = round(best_bucket[1] * scale_factor / 16) * 16
33+
best_bucket = (scaled_height, scaled_width)
34+
35+
print(f'--------------> Resolution: {best_bucket[1]} x {best_bucket[0]} for input res {resolution}')
36+
return best_bucket

0 commit comments

Comments
 (0)