Skip to content

Commit

Permalink
Merge pull request #2 from dhruvbaldawa/improve-podcast-quality
Browse files Browse the repository at this point in the history
  • Loading branch information
dhruvbaldawa authored Dec 5, 2024
2 parents 26560fe + beb646a commit cc34582
Show file tree
Hide file tree
Showing 13 changed files with 650 additions and 503 deletions.
20 changes: 10 additions & 10 deletions gyandex/cli/genpod.py → gyandex/cli/podgen.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import argparse
import asyncio
import hashlib
import os
from collections import namedtuple

from dotenv import load_dotenv
from rich.console import Console
Expand All @@ -11,9 +11,9 @@
from gyandex.podgen.engine.publisher import PodcastPublisher, PodcastMetadata
from gyandex.podgen.feed.models import PodcastDB
from gyandex.podgen.config.loader import load_config
from gyandex.podgen.engine.synthesizer import TTSEngine
from gyandex.podgen.engine.workflows import create_script
from gyandex.podgen.speech.factory import get_text_to_speech_engine
from gyandex.podgen.storage.factory import get_storage
from gyandex.podgen.workflows.factory import get_workflow


def main():
Expand All @@ -29,7 +29,6 @@ def main():
return
console = Console()
config = load_config(args.config_path)
model = get_model(config.llm)

# Load the content
with console.status('[bold green] Loading content...[/bold green]'):
Expand All @@ -38,13 +37,14 @@ def main():

# Analyze the content
with console.status('[bold green] Crafting the script...[/bold green]'):
script = create_script(model, document) # attach callback to see the progress
console.log(f'Script completed for "{script.title}". Script contains {len(script.segments)} segments...')
workflow = get_workflow(config)
script = asyncio.run(workflow.generate_script(document))
console.log(f'Script completed for "{script.title}". Script contains {len(script.dialogues)} segments...')

# Generate the podcast audio
with console.status('[bold green] Generating audio...[/bold green]'):
tts_engine = TTSEngine()
audio_segments = [tts_engine.process_segment(segment) for segment in script.segments]
tts_engine = get_text_to_speech_engine(config.tts)
audio_segments = [tts_engine.process_segment(dialogue) for dialogue in script.dialogues]

# Create output directory
output_dir = f"generated_podcasts/{config.feed.slug}"
Expand Down Expand Up @@ -82,5 +82,5 @@ def main():
description=script.description,
)
)
console.log(f"Feed published at {urls['feed_url']}")
console.log(f"Episode published at {urls['episode_url']}")
console.print(f"Feed published at {urls['feed_url']}")
console.print(f"Episode published at {urls['episode_url']}")
35 changes: 24 additions & 11 deletions gyandex/podgen/config/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,16 +29,30 @@ class GoogleGenerativeAILLMConfig(BaseModel):
google_api_key: str


class VoiceProfile(BaseModel):
voice_id: str
speaking_rate: float
pitch: int
class AlexandriaWorkflowConfig(BaseModel):
name: Literal["alexandria"]
outline: Union[GoogleGenerativeAILLMConfig]
script: Union[GoogleGenerativeAILLMConfig]
verbose: Optional[bool] = False


class TTSConfig(BaseModel):
provider: str
default_voice: str
voices: Dict[str, VoiceProfile]
class Gender(Enum):
MALE = "male"
FEMALE = "female"
NON_BINARY = "non-binary"


class Participant(BaseModel):
name: str
voice: str
gender: Gender
personality: Optional[str] = ''
language_code: Optional[str] = "en-US"


class GoogleCloudTTSConfig(BaseModel):
provider: Literal["google-cloud"]
participants: List[Participant]


class S3StorageConfig(BaseModel):
Expand Down Expand Up @@ -76,8 +90,7 @@ class ContentStructure(BaseModel):
class PodcastConfig(BaseModel):
version: str
content: ContentConfig
# @TODO: Rethink this because I would like to use multiple LLMs for optimizing costs
llm: Union[GoogleGenerativeAILLMConfig] = Field(discriminator="provider")
tts: TTSConfig
workflow: Union[AlexandriaWorkflowConfig] = Field(discriminator="name")
tts: Union[GoogleCloudTTSConfig] = Field(discriminator="provider")
storage: Union[S3StorageConfig] = Field(discriminator="provider")
feed: FeedConfig
131 changes: 0 additions & 131 deletions gyandex/podgen/engine/workflows.py

This file was deleted.

File renamed without changes.
12 changes: 12 additions & 0 deletions gyandex/podgen/speech/factory.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
from typing import Union

from .google_cloud import GoogleTTSEngine
from ..config.schema import GoogleCloudTTSConfig


# @TODO: Centralize this type and move this to a common place
def get_text_to_speech_engine(tts_config: Union[GoogleCloudTTSConfig]):
if tts_config.provider == "google-cloud":
return GoogleTTSEngine(tts_config.participants)
else:
raise NotImplementedError(f"Unsupported TTS provider: {tts_config.provider}")
Original file line number Diff line number Diff line change
@@ -1,35 +1,40 @@
from io import BytesIO
from typing import List, Optional, Dict, Any
from typing import List, Optional, Dict, Any, Union

from google.cloud import texttospeech
from pydub import AudioSegment

from gyandex.podgen.engine.workflows import PodcastSegment # @TODO: Pull this out of workflows
from ..config.schema import Participant, Gender
from ..workflows.types import ScriptSegment # @TODO: Pull this out of workflows


class TTSEngine:
# @TODO: Accept configuration to tweak the voice
def __init__(self):
class GoogleTTSEngine:
def __init__(self, participants: List[Participant]):
self.client = texttospeech.TextToSpeechClient()
self.voices = {
'HOST1': texttospeech.VoiceSelectionParams(
language_code='en-US',
name='en-US-Journey-D',
ssml_gender=texttospeech.SsmlVoiceGender.MALE
),
'HOST2': texttospeech.VoiceSelectionParams(
language_code='en-US',
name='en-US-Journey-O',
ssml_gender=texttospeech.SsmlVoiceGender.FEMALE
)
}
self.voices = self.generate_voice_profile(participants)
self.audio_config = texttospeech.AudioConfig(
audio_encoding=texttospeech.AudioEncoding.MP3,
effects_profile_id=['headphone-class-device']
)

def process_segment(self, segment: PodcastSegment) -> bytes:
# ssml = self.generate_ssml(segment)
def generate_voice_profile(self, participants: List[Participant]) -> Dict[str, Any]:
def resolve_gender(gender: Gender):
if gender == Gender.FEMALE:
return texttospeech.SsmlVoiceGender.FEMALE
elif gender == Gender.MALE:
return texttospeech.SsmlVoiceGender.MALE
return texttospeech.SsmlVoiceGender.NEUTRAL

return {
participant.name: texttospeech.VoiceSelectionParams(
language_code=participant.language_code,
name=participant.voice,
ssml_gender=resolve_gender(participant.gender),
)
for participant in participants
}

def process_segment(self, segment: ScriptSegment) -> bytes:
return self.synthesize_speech(segment.text, segment.speaker)

def synthesize_speech(self, text: str, speaker: str) -> bytes:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,13 @@
from io import BytesIO
from pydub import AudioSegment
from google.cloud import texttospeech
from gyandex.podgen.engine.synthesizer import TTSEngine
from gyandex.podgen.engine.workflows import PodcastSegment
from gyandex.podgen.processors.tts import GoogleTTSEngine
from gyandex.podgen.engine.workflows import ScriptSegment

def test_tts_engine_initialization():
"""Tests that TTSEngine initializes with correct voice configurations"""
# Given/When
engine = TTSEngine()
engine = GoogleTTSEngine()

# Then
assert 'HOST1' in engine.voices
Expand All @@ -20,7 +20,7 @@ def test_tts_engine_initialization():
def test_synthesize_speech_for_host1(mock_client):
"""Tests speech synthesis for HOST1 voice"""
# Given
engine = TTSEngine()
engine = GoogleTTSEngine()
mock_response = Mock()
mock_response.audio_content = b"test_audio_content"
mock_client.return_value.synthesize_speech.return_value = mock_response
Expand All @@ -36,8 +36,8 @@ def test_synthesize_speech_for_host1(mock_client):
def test_process_segment(mock_client):
"""Tests processing of a complete podcast segment"""
# Given
engine = TTSEngine()
segment = PodcastSegment(text="Test segment", speaker="HOST1")
engine = GoogleTTSEngine()
segment = ScriptSegment(dialogue="Test segment", speaker="HOST1")
mock_response = Mock()
mock_response.audio_content = b"test_audio_content"
mock_client.return_value.synthesize_speech.return_value = mock_response
Expand Down
Empty file.
Loading

0 comments on commit cc34582

Please sign in to comment.