@@ -1182,7 +1182,8 @@ def _encode(self, x: torch.Tensor) -> torch.Tensor:
1182
1182
1183
1183
frame_batch_size = self.num_sample_frames_batch_size
1184
1184
# Note: We expect the number of frames to be either `1` or `frame_batch_size * k` or `frame_batch_size * k + 1` for some k.
1185
- num_batches = num_frames // frame_batch_size if num_frames > 1 else 1
1185
+ # As the extra single frame is handled inside the loop, it is not required to round up here.
1186
+ num_batches = max(num_frames // frame_batch_size, 1)
1186
1187
conv_cache = None
1187
1188
enc = []
1188
1189
@@ -1330,7 +1331,8 @@ def tiled_encode(self, x: torch.Tensor) -> torch.Tensor:
1330
1331
row = []
1331
1332
for j in range(0, width, overlap_width):
1332
1333
# Note: We expect the number of frames to be either `1` or `frame_batch_size * k` or `frame_batch_size * k + 1` for some k.
1333
- num_batches = num_frames // frame_batch_size if num_frames > 1 else 1
1334
+ # As the extra single frame is handled inside the loop, it is not required to round up here.
1335
+ num_batches = max(num_frames // frame_batch_size, 1)
1334
1336
conv_cache = None
1335
1337
time = []
1336
1338
@@ -1409,7 +1411,7 @@ def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[Decod
1409
1411
for i in range(0, height, overlap_height):
1410
1412
row = []
1411
1413
for j in range(0, width, overlap_width):
1412
- num_batches = num_frames // frame_batch_size
1414
+ num_batches = max(num_frames // frame_batch_size, 1)
1413
1415
conv_cache = None
1414
1416
time = []
1415
1417
0 commit comments