Skip to content

Commit 1d728a7

Browse files
authored
Merge pull request #2558 from bghira/bugfix/ltx2-audio-alignment
LTX-2 audio fps should come from --framerate, not the dataset
2 parents 42ac3bb + df195f6 commit 1d728a7

2 files changed

Lines changed: 20 additions & 4 deletions

File tree

simpletuner/helpers/caching/vae.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1167,7 +1167,7 @@ def _align_audio_waveform_to_video(self, waveform, sample_rate, metadata: dict,
11671167
video_config = source_config.get("video") or {}
11681168

11691169
target_num_frames = video_config.get("num_frames") or video_meta.get("num_frames")
1170-
fps = video_meta.get("fps") or getattr(StateTracker.get_args(), "framerate", None) or 25
1170+
fps = getattr(StateTracker.get_args(), "framerate", None)
11711171
if not target_num_frames or not fps:
11721172
return waveform, metadata
11731173

simpletuner/helpers/metadata/backends/discovery.py

Lines changed: 19 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -622,6 +622,22 @@ def _process_audio_sample(
622622

623623
num_channels, num_samples = waveform.shape[0], waveform.shape[1]
624624
duration_seconds = float(num_samples) / float(sample_rate) if sample_rate else None
625+
626+
# For source_from_video, compute target duration based on video num_frames/framerate
627+
# to match what _align_audio_waveform_to_video will do during VAE caching
628+
bucket_duration = duration_seconds
629+
if source_from_video:
630+
source_dataset_id = self.dataset_config.get("source_dataset_id") or self.audio_config.get(
631+
"source_dataset_id"
632+
)
633+
if source_dataset_id:
634+
source_config = StateTracker.get_data_backend_config(data_backend_id=source_dataset_id) or {}
635+
video_config = source_config.get("video") or {}
636+
target_num_frames = video_config.get("num_frames")
637+
fps = getattr(StateTracker.get_args(), "framerate", None)
638+
if target_num_frames and fps:
639+
bucket_duration = float(target_num_frames) / float(fps)
640+
625641
audio_metadata = self._build_audio_metadata_entry(
626642
sample_path=image_path_str,
627643
sample_rate=sample_rate,
@@ -636,16 +652,16 @@ def _process_audio_sample(
636652
audio_metadata[key] = value
637653

638654
max_duration = self.audio_max_duration_seconds
639-
if max_duration is not None and duration_seconds and duration_seconds > max_duration:
655+
if max_duration is not None and bucket_duration and bucket_duration > max_duration:
640656
logger.debug(
641-
f"Audio sample {image_path_str} duration {duration_seconds:.2f}s exceeds "
657+
f"Audio sample {image_path_str} duration {bucket_duration:.2f}s exceeds "
642658
f"limit {max_duration:.2f}s. Skipping."
643659
)
644660
skipped = statistics.setdefault("skipped", {})
645661
skipped["too_long"] = skipped.get("too_long", 0) + 1
646662
return aspect_ratio_bucket_indices
647663

648-
bucket_key, truncated_duration = self._compute_audio_bucket(duration_seconds)
664+
bucket_key, truncated_duration = self._compute_audio_bucket(bucket_duration)
649665
audio_metadata["original_duration_seconds"] = duration_seconds
650666
if truncated_duration is not None:
651667
audio_metadata["duration_seconds"] = truncated_duration

0 commit comments

Comments
 (0)