50 changes: 41 additions & 9 deletions open_dubbing/audio_processing.py
@@ -16,6 +16,9 @@
 import os
 import warnings
 
+import pysrt
+import re
+
 from typing import Final, Mapping, Sequence
 
 import numpy as np
@@ -37,22 +40,51 @@ def create_pyannote_timestamps(
     audio_file: str,
     pipeline: Pipeline,
     device: str = "cpu",
+    input_srt: str | None = None,
 ) -> Sequence[Mapping[str, float]]:
     """Creates timestamps from a vocals file using Pyannote speaker diarization.
 
     Returns:
       A list of dictionaries containing start and end timestamps for each
       speaker segment.
     """
-    with warnings.catch_warnings():
-        warnings.filterwarnings("ignore", category=UserWarning)
-        if device == "cuda":
-            pipeline.to(torch.device("cuda"))
-        diarization = pipeline(audio_file)
-        utterance_metadata = [
-            {"start": segment.start, "end": segment.end, "speaker_id": speaker}
-            for segment, _, speaker in diarization.itertracks(yield_label=True)
-        ]
+    if not input_srt:
+        with warnings.catch_warnings():
+            warnings.filterwarnings("ignore", category=UserWarning)
+            if device == "cuda":
+                pipeline.to(torch.device("cuda"))
+            diarization = pipeline(audio_file)
+            utterance_metadata = [
+                {"start": segment.start, "end": segment.end, "speaker_id": speaker}
+                for segment, _, speaker in diarization.itertracks(yield_label=True)
+            ]
+        return utterance_metadata
+    else:
+        subs = pysrt.open(input_srt)
+        utterance_metadata = []
+
+        for sub in subs:
+            match = re.match(r'\[(SPEAKER_\d+)\]:', sub.text.strip())
+            if match:
+                speaker_id = match.group(1)
+                start_seconds = (
+                    sub.start.hours * 3600 +
+                    sub.start.minutes * 60 +
+                    sub.start.seconds +
+                    sub.start.milliseconds / 1000
+                )
+                end_seconds = (
+                    sub.end.hours * 3600 +
+                    sub.end.minutes * 60 +
+                    sub.end.seconds +
+                    sub.end.milliseconds / 1000
+                )
+                utterance_metadata.append({
+                    "start": round(start_seconds, 3),
+                    "end": round(end_seconds, 3),
+                    "speaker_id": speaker_id
+                })
+
     return utterance_metadata
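
For reference, a minimal, runnable sketch (not part of this diff) of the SRT shape the branch expects: each cue's text must open with a speaker tag matching the regex above. The cue timings and text are invented for illustration.

import re
import tempfile

import pysrt

# Two cues in the expected "[SPEAKER_NN]: text" form.
SAMPLE_SRT = """\
1
00:00:01,000 --> 00:00:03,500
[SPEAKER_00]: Hello and welcome.

2
00:00:04,000 --> 00:00:06,250
[SPEAKER_01]: Thanks for having me.
"""

with tempfile.NamedTemporaryFile("w", suffix=".srt", delete=False) as f:
    f.write(SAMPLE_SRT)

for sub in pysrt.open(f.name):  # same pysrt.open call the diff uses
    match = re.match(r'\[(SPEAKER_\d+)\]:', sub.text.strip())
    if match:
        # Same hours/minutes/seconds/milliseconds arithmetic as above.
        start = sub.start.hours * 3600 + sub.start.minutes * 60 + sub.start.seconds + sub.start.milliseconds / 1000
        end = sub.end.hours * 3600 + sub.end.minutes * 60 + sub.end.seconds + sub.end.milliseconds / 1000
        print(match.group(1), round(start, 3), round(end, 3))

This prints SPEAKER_00 1.0 3.5 and SPEAKER_01 4.0 6.25, the same start/end/speaker_id triples that create_pyannote_timestamps builds when input_srt is set.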


7 changes: 7 additions & 0 deletions open_dubbing/command_line.py
@@ -196,6 +196,13 @@ def read_parameters():
help="Update the dubbed video produced by a previous execution with the latest changes in utterance_metadata file",
)

parser.add_argument(
"--input_srt",
type=str,
default="",
help=("A .srt file with speaker annotations that overrides audio segmentation, speaker diarization and speech to text"),
)

parser.add_argument(
"--original_subtitles",
action="store_true",
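
A quick, self-contained check (illustrative, not part of this diff) of how the empty-string default interacts with the "if not input_srt:" branches downstream: when the flag is unset, the value stays falsy, so diarization and speech-to-text run exactly as before.

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--input_srt", type=str, default="")

args = parser.parse_args([])
assert not args.input_srt  # default "" is falsy: diarization/STT path runs

args = parser.parse_args(["--input_srt", "video.srt"])
assert args.input_srt  # truthy: SRT-driven path runs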
5 changes: 5 additions & 0 deletions open_dubbing/dubbing.py
@@ -108,6 +108,8 @@ def __init__(
         clean_intermediate_files: bool = False,
         original_subtitles: bool = False,
         dubbed_subtitles: bool = False,
+        input_srt: str | None = None,
+
     ) -> None:
         self._input_file = input_file
         self.output_directory = output_directory
@@ -128,6 +130,7 @@ def __init__(
         self.preprocessing_output = None
         self.original_subtitles = original_subtitles
         self.dubbed_subtitles = dubbed_subtitles
+        self.input_srt = input_srt
 
         if cpu_threads > 0:
             torch.set_num_threads(cpu_threads)
@@ -210,6 +213,7 @@ def run_preprocessing(self) -> None:
             audio_file=audio_file,
             pipeline=self.pyannote_pipeline,
             device=device_pyannote,
+            input_srt=self.input_srt,
         )
         utterance_metadata = audio_processing.run_cut_and_save_audio(
             utterance_metadata=utterance_metadata,
@@ -240,6 +244,7 @@ def run_speech_to_text(self) -> None:
             utterance_metadata=self.utterance_metadata,
             source_language=self.source_language,
             no_dubbing_phrases=[],
+            input_srt=self.input_srt,
         )
         speaker_info = self.stt.predict_gender(
             file=media_file,
1 change: 1 addition & 0 deletions open_dubbing/main.py
@@ -313,6 +313,7 @@ def main():
         clean_intermediate_files=args.clean_intermediate_files,
         original_subtitles=args.original_subtitles,
         dubbed_subtitles=args.dubbed_subtitles,
+        input_srt=args.input_srt,
     )
 
     logger().info(
57 changes: 45 additions & 12 deletions open_dubbing/speech_to_text.py
@@ -15,6 +15,7 @@
 import array
 import os
 import re
+import pysrt
 
 from abc import ABC, abstractmethod
 from typing import Mapping, Sequence
@@ -79,35 +80,67 @@ def _make_sure_single_space(self, sentence: str) -> str:
         fixed = fixed.strip()
         return fixed
 
+    def _srt_time_to_seconds(self, t):
+        return t.hours * 3600 + t.minutes * 60 + t.seconds + t.milliseconds / 1000.0
+
     def transcribe_audio_chunks(
         self,
         *,
         utterance_metadata: Sequence[Mapping[str, float | str]],
         source_language: str,
         no_dubbing_phrases: Sequence[str],
+        input_srt: str | None = None,
     ) -> Sequence[Mapping[str, float | str]]:
 
         logger().debug(f"transcribe_audio_chunks: {source_language}")
         iso_639_1 = self._get_iso_639_1(source_language)
 
+        if input_srt:
+            logger().debug(f"transcribe_audio_chunks: read transcripts from {input_srt}")
+            subs = pysrt.open(input_srt)
+
         updated_utterance_metadata = []
         for item in utterance_metadata:
             new_item = item.copy()
+            path = ""
             try:
-                path = item["path"]
-                duration = item["end"] - item["start"]
-                if self._is_short_audio(duration=duration):
-                    transcribed_text = ""
-                    logger().debug(
-                        f"speech_to_text._is_short_audio. Audio is less than {self.MIN_SECS} second, skipping transcription of '{path}'."
-                    )
+                if not input_srt:
+                    path = item["path"]
+                    duration = item["end"] - item["start"]
+                    if self._is_short_audio(duration=duration):
+                        transcribed_text = ""
+                        logger().debug(
+                            f"speech_to_text._is_short_audio. Audio is less than {self.MIN_SECS} second, skipping transcription of '{path}'."
+                        )
+                    else:
+                        transcribed_text = self._transcribe(
+                            vocals_filepath=path,
+                            source_language_iso_639_1=iso_639_1,
+                        )
+                        transcribed_text = self._make_sure_single_space(transcribed_text)
                 else:
-                    transcribed_text = self._transcribe(
-                        vocals_filepath=path,
-                        source_language_iso_639_1=iso_639_1,
-                    )
-                    transcribed_text = self._make_sure_single_space(transcribed_text)
+                    match = None
+                    time_tolerance = 0.05
+                    target_start = item["start"]
+                    target_end = item["end"]
+                    for sub in subs:
+                        sub_start = self._srt_time_to_seconds(sub.start)
+                        sub_end = self._srt_time_to_seconds(sub.end)
+
+                        if (abs(sub_start - target_start) <= time_tolerance and
+                                abs(sub_end - target_end) <= time_tolerance):
+                            # Remove the [SPEAKER_XX]: tag
+                            match = re.sub(r'^\[SPEAKER_\d{2}\]:\s*', '', sub.text.strip())
+                            break
+
+                    if match is None:
+                        logger().warning(
+                            f"No subtitle match found for time range {target_start:.3f}-{target_end:.3f} seconds "
+                            f"(speaker {item.get('speaker_id', 'UNKNOWN')})"
+                        )
+                        transcribed_text = self._make_sure_single_space("")
+                    else:
+                        transcribed_text = self._make_sure_single_space(match)
             except Exception as e:
                 logger().error(
                     f"speech_to_text.transcribe_audio_chunks. file '{path}', error: '{e}'"
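
A standalone sketch (not from the diff) of the ±50 ms matching rule above, with invented timestamps. Since the utterance metadata is itself derived from the same SRT when input_srt is set, the tolerance mainly absorbs the round(..., 3) applied in audio_processing.py.

TIME_TOLERANCE = 0.05  # seconds, as in the diff


def srt_time_to_seconds(hours, minutes, seconds, milliseconds):
    # Same arithmetic as _srt_time_to_seconds, on plain ints for the demo.
    return hours * 3600 + minutes * 60 + seconds + milliseconds / 1000.0


utterance = {"start": 1.0, "end": 3.5}  # as built by create_pyannote_timestamps

cue_close = (srt_time_to_seconds(0, 0, 1, 20), srt_time_to_seconds(0, 0, 3, 480))
cue_far = (srt_time_to_seconds(0, 0, 1, 200), srt_time_to_seconds(0, 0, 3, 480))

for start, end in (cue_close, cue_far):
    matches = (abs(start - utterance["start"]) <= TIME_TOLERANCE
               and abs(end - utterance["end"]) <= TIME_TOLERANCE)
    print(round(start, 3), round(end, 3), matches)

The first cue matches (both endpoints are within 20 ms); the second is rejected because its start is 200 ms off, which is what triggers the warning branch above.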