@@ -622,6 +622,22 @@ def _process_audio_sample(
622622
623623 num_channels , num_samples = waveform .shape [0 ], waveform .shape [1 ]
624624 duration_seconds = float (num_samples ) / float (sample_rate ) if sample_rate else None
625+
626+ # For source_from_video, compute target duration based on video num_frames/framerate
627+ # to match what _align_audio_waveform_to_video will do during VAE caching
628+ bucket_duration = duration_seconds
629+ if source_from_video :
630+ source_dataset_id = self .dataset_config .get ("source_dataset_id" ) or self .audio_config .get (
631+ "source_dataset_id"
632+ )
633+ if source_dataset_id :
634+ source_config = StateTracker .get_data_backend_config (data_backend_id = source_dataset_id ) or {}
635+ video_config = source_config .get ("video" ) or {}
636+ target_num_frames = video_config .get ("num_frames" )
637+ fps = getattr (StateTracker .get_args (), "framerate" , None )
638+ if target_num_frames and fps :
639+ bucket_duration = float (target_num_frames ) / float (fps )
640+
625641 audio_metadata = self ._build_audio_metadata_entry (
626642 sample_path = image_path_str ,
627643 sample_rate = sample_rate ,
@@ -636,16 +652,16 @@ def _process_audio_sample(
636652 audio_metadata [key ] = value
637653
638654 max_duration = self .audio_max_duration_seconds
639- if max_duration is not None and duration_seconds and duration_seconds > max_duration :
655+ if max_duration is not None and bucket_duration and bucket_duration > max_duration :
640656 logger .debug (
641- f"Audio sample { image_path_str } duration { duration_seconds :.2f} s exceeds "
657+ f"Audio sample { image_path_str } duration { bucket_duration :.2f} s exceeds "
642658 f"limit { max_duration :.2f} s. Skipping."
643659 )
644660 skipped = statistics .setdefault ("skipped" , {})
645661 skipped ["too_long" ] = skipped .get ("too_long" , 0 ) + 1
646662 return aspect_ratio_bucket_indices
647663
648- bucket_key , truncated_duration = self ._compute_audio_bucket (duration_seconds )
664+ bucket_key , truncated_duration = self ._compute_audio_bucket (bucket_duration )
649665 audio_metadata ["original_duration_seconds" ] = duration_seconds
650666 if truncated_duration is not None :
651667 audio_metadata ["duration_seconds" ] = truncated_duration
0 commit comments