Merge pull request #4 from dhruvbaldawa/code-quality
dhruvbaldawa authored Dec 14, 2024
2 parents cc34582 + 9af0aa3 commit 19947c7
Showing 39 changed files with 760 additions and 933 deletions.
51 changes: 51 additions & 0 deletions .github/workflows/ci.yml
@@ -0,0 +1,51 @@
name: Quality Checks

on:
  pull_request:

jobs:
  test:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v4

      - uses: 'google-github-actions/auth@v2'
        with:
          credentials_json: '${{ secrets.GOOGLE_CREDENTIALS }}'

      - name: Set up Cloud SDK
        uses: google-github-actions/setup-gcloud@v2

      - name: Set up Python
        uses: actions/setup-python@v4
        with:
          python-version: '3.11'
          cache: 'pip'

      - name: Install Poetry
        uses: snok/install-poetry@v1
        with:
          version: latest

      - name: Setup Poetry cache
        uses: actions/cache@v4
        with:
          path: ./.venv
          key: venv-${{ runner.os }}-${{ hashFiles('**/poetry.lock') }}

      - name: Install just
        uses: extractions/setup-just@v1

      - name: Configure Poetry
        run: |
          poetry config virtualenvs.in-project true
          poetry config virtualenvs.create true
      - name: Install dependencies
        run: poetry install --no-interaction

      - name: Run checks
        run: poetry run just check

      - name: Run tests
        run: poetry run just test
16 changes: 16 additions & 0 deletions .pre-commit-config.yaml
@@ -0,0 +1,16 @@
repos:
  - repo: local
    hooks:
      - id: ruff-check
        name: Run ruff check
        entry: poetry run just check
        language: system
        pass_filenames: false
        stages: [pre-commit]

      - id: pytest
        name: Run pytest
        entry: poetry run just test
        language: system
        pass_filenames: false
        stages: [pre-commit]
6 changes: 0 additions & 6 deletions README.md
@@ -27,12 +27,6 @@ poetry install
- Add `GOOGLE_API_KEY` in `.env` by generating the [API key from Google AI Studio](https://aistudio.google.com/app/apikey)
- Login to Google Cloud using [the following instructions](https://cloud.google.com/text-to-speech/docs/create-audio-text-client-libraries)

## Running the Application
```shell
jupyter notebook
```
Run the `main.ipynb` notebook

## License

This project is licensed under the **AGPL v3** for open-source use. For those wishing to use the software in proprietary applications without disclosing source code, a **commercial license** is available.
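Regarding the `GOOGLE_API_KEY` step in the README section above: the CLI pulls environment variables in via python-dotenv (see the `load_dotenv` import in `gyandex/cli/podgen.py` below). A minimal sketch of that pattern, assuming the key lives in a local `.env` file as the README describes:

```python
# Illustrative sketch only: how a GOOGLE_API_KEY placed in .env is typically read.
import os

from dotenv import load_dotenv

load_dotenv()  # loads variables from .env in the working directory into the environment

api_key = os.environ.get("GOOGLE_API_KEY")
if not api_key:
    raise RuntimeError("GOOGLE_API_KEY is not set; add it to .env as described in the README")
```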
5 changes: 5 additions & 0 deletions conftest.py
@@ -0,0 +1,5 @@
pytest_plugins = [
    "gyandex.podgen.storage.test_fixtures",
    "gyandex.podgen.feed.test_fixtures",
    "gyandex.podgen.engine.test_fixtures",
]
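For context, `pytest_plugins` in a root `conftest.py` tells pytest to import these modules and expose any fixtures they define to the whole test suite without per-file imports. A hypothetical sketch of what one such `test_fixtures` module could look like (the fixture name and body are illustrative assumptions, not the project's actual fixtures):

```python
# Hypothetical fixture module of the kind referenced in pytest_plugins above.
# The real modules live under gyandex/podgen/*; this one is for illustration only.
import sqlite3

import pytest


@pytest.fixture
def temp_db(tmp_path):
    """Yield a throwaway SQLite connection backed by a per-test temporary file."""
    conn = sqlite3.connect(tmp_path / "test.db")
    yield conn
    conn.close()
```

Any test collected in the suite could then accept `temp_db` as an argument directly.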
28 changes: 14 additions & 14 deletions gyandex/cli/podgen.py
@@ -6,11 +6,10 @@
from dotenv import load_dotenv
from rich.console import Console

from gyandex.llms.factory import get_model
from gyandex.loaders.factory import load_content
from gyandex.podgen.engine.publisher import PodcastPublisher, PodcastMetadata
from gyandex.podgen.feed.models import PodcastDB
from gyandex.podgen.config.loader import load_config
from gyandex.podgen.engine.publisher import PodcastMetadata, PodcastPublisher
from gyandex.podgen.feed.models import PodcastDB
from gyandex.podgen.speech.factory import get_text_to_speech_engine
from gyandex.podgen.storage.factory import get_storage
from gyandex.podgen.workflows.factory import get_workflow
@@ -31,18 +30,18 @@ def main():
config = load_config(args.config_path)

# Load the content
with console.status('[bold green] Loading content...[/bold green]'):
with console.status("[bold green] Loading content...[/bold green]"):
document = load_content(config.content)
console.log('Content loaded...')
console.log("Content loaded...")

# Analyze the content
with console.status('[bold green] Crafting the script...[/bold green]'):
with console.status("[bold green] Crafting the script...[/bold green]"):
workflow = get_workflow(config)
script = asyncio.run(workflow.generate_script(document))
console.log(f'Script completed for "{script.title}". Script contains {len(script.dialogues)} segments...')

# Generate the podcast audio
with console.status('[bold green] Generating audio...[/bold green]'):
with console.status("[bold green] Generating audio...[/bold green]"):
tts_engine = get_text_to_speech_engine(config.tts)
audio_segments = [tts_engine.process_segment(dialogue) for dialogue in script.dialogues]

@@ -52,17 +51,18 @@

podcast_path = f"{output_dir}/podcast_{hashlib.md5(config.content.source.encode()).hexdigest()}.mp3"
tts_engine.generate_audio_file(audio_segments, podcast_path)
console.log(f'Podcast file {podcast_path} generated...')
console.log(f"Podcast file {podcast_path} generated...")

with console.status('[bold green] Publishing podcast...[/bold green]'):
with console.status("[bold green] Publishing podcast...[/bold green]"):
storage = get_storage(config.storage)
db = PodcastDB(db_path='assets/podcasts.db')
db = PodcastDB(db_path="assets/podcasts.db")
publisher = PodcastPublisher(
storage=storage,
db=db,
base_url=f"https://{storage.custom_domain}", # @FIXME: we need to fallback when custom domain is not available
# @FIXME: we need to fallback when custom domain is not available
base_url=f"https://{storage.custom_domain}",
)
feed_url = publisher.create_feed(
publisher.create_feed(
slug=config.feed.slug,
title=config.feed.title,
email=config.feed.email,
@@ -73,14 +73,14 @@
language=config.feed.language,
categories=",".join(config.feed.categories),
)
console.log('Uploading episode...')
console.log("Uploading episode...")
urls = publisher.add_episode(
feed_slug=config.feed.slug,
audio_file_path=podcast_path,
metadata=PodcastMetadata(
title=script.title,
description=script.description,
)
),
)
console.print(f"Feed published at {urls['feed_url']}")
console.print(f"Episode published at {urls['episode_url']}")
31 changes: 31 additions & 0 deletions gyandex/cli/podgen_test.py
@@ -0,0 +1,31 @@
from unittest.mock import Mock, patch

import pytest

from gyandex.cli.podgen import main


def test_cli_help_command():
"""Tests that help command prints help message and exits"""
# When
with (
patch("argparse.ArgumentParser.parse_args", return_value=Mock(config_path="--help")),
patch("argparse.ArgumentParser.print_help") as mock_help,
):
main()

# Then
mock_help.assert_called_once()


def test_invalid_config_path():
"""Tests handling of invalid configuration file path"""
# Given
invalid_path = "nonexistent.yaml"

# When/Then
with (
pytest.raises(FileNotFoundError),
patch("argparse.ArgumentParser.parse_args", return_value=Mock(config_path=invalid_path)),
):
main()
15 changes: 8 additions & 7 deletions gyandex/llms/factory.py
@@ -10,16 +10,16 @@

class LLMLoggingCallback(BaseCallbackHandler):
    def __init__(self, log_dir="assets"):
        logger = logging.getLogger('llm_logger')
        logger = logging.getLogger("llm_logger")
        logger.setLevel(logging.INFO)

        # Create file handler with timestamp in filename
        timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
        fh = logging.FileHandler(f'{log_dir}/llm_logs_{timestamp}.log')
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        fh = logging.FileHandler(f"{log_dir}/llm_logs_{timestamp}.log")
        fh.setLevel(logging.INFO)

        # Create formatter
        formatter = logging.Formatter('%(asctime)s - %(message)s')
        formatter = logging.Formatter("%(asctime)s - %(message)s")
        fh.setFormatter(formatter)

        logger.addHandler(fh)
@@ -38,14 +38,15 @@ def on_llm_end(self, response, **kwargs):
    def on_llm_error(self, error, **kwargs):
        self.logger.error(f"\n=== ERROR ===\n{str(error)}\n")


# @TODO: Centralize this argument type in a single place
def get_model(config: Union[GoogleGenerativeAILLMConfig], log_dir="assets"):
def get_model(config: Union[GoogleGenerativeAILLMConfig], log_dir="assets"):  # pyright: ignore [reportInvalidTypeArguments]
    if config.provider == "google-generative-ai":
        return GoogleGenerativeAI(
            model=config.model,
            temperature=config.temperature,
            google_api_key=config.google_api_key,
            max_output_tokens=8192,  # @TODO: Move this to config params
            google_api_key=config.google_api_key,  # pyright: ignore [reportCallIssue]
            max_output_tokens=8192,  # @TODO: Move this to config params # pyright: ignore [reportCallIssue]
            callbacks=[LLMLoggingCallback(log_dir)],
        )
    else:
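For orientation, a usage sketch of this factory follows. It mirrors the test file below; the model name and API key are placeholders rather than project defaults:

```python
# Illustrative usage of get_model; the values below are placeholders, not project defaults.
from gyandex.llms.factory import get_model
from gyandex.podgen.config.schema import GoogleGenerativeAILLMConfig

config = GoogleGenerativeAILLMConfig(
    provider="google-generative-ai",
    model="gemini-pro",
    temperature=0.7,
    google_api_key="your-api-key",
)

# LLMLoggingCallback writes llm_logs_<timestamp>.log files under log_dir,
# so the directory should already exist.
llm = get_model(config, log_dir="assets")
```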
20 changes: 7 additions & 13 deletions gyandex/llms/factory_test.py
@@ -1,20 +1,16 @@
import pytest
from unittest.mock import Mock, patch
from datetime import datetime
from langchain_google_genai import GoogleGenerativeAI
from pydantic import ValidationError

from gyandex.llms.factory import get_model, LLMLoggingCallback
from gyandex.podgen.config.schema import GoogleGenerativeAILLMConfig
from ..podgen.config.schema import GoogleGenerativeAILLMConfig
from .factory import get_model


def test_get_model_returns_google_generative_ai():
"""Tests that get_model creates a GoogleGenerativeAI instance with correct config"""
# Given
config = GoogleGenerativeAILLMConfig(
provider="google-generative-ai",
model="gemini-pro",
temperature=0.7,
google_api_key="test-key"
provider="google-generative-ai", model="gemini-pro", temperature=0.7, google_api_key="test-key"
)

# When
Expand All @@ -25,13 +21,11 @@ def test_get_model_returns_google_generative_ai():
assert model.model == "gemini-pro"
assert model.temperature == 0.7


def test_get_model_raises_for_unsupported_provider():
"""Tests that get_model raises NotImplementedError for unsupported providers"""
# When/Then
with pytest.raises(ValidationError):
config = GoogleGenerativeAILLMConfig(
provider="unsupported",
model="test",
temperature=0.5,
google_api_key="test-key"
_ = GoogleGenerativeAILLMConfig(
provider="unsupported", model="test", temperature=0.5, google_api_key="test-key"
)
16 changes: 10 additions & 6 deletions gyandex/loaders/factory.py
@@ -1,4 +1,4 @@
from typing import Optional, Dict, Any
from typing import Any, Dict, Optional

import requests
from pydantic import BaseModel
@@ -20,11 +20,15 @@ def load_content(content_config: ContentConfig) -> Document:


def fetch_url(url) -> Document:
    headers = { "Accept": "application/json" }
    headers = {"Accept": "application/json"}
    response = requests.get(f"https://r.jina.ai/{url}", headers=headers)
    # @TODO: Add error handling
    content = response.json()
    return Document(title=content['data']['title'], content=content['data']['content'], metadata={
        'url': content['data']['url'],
        'description': content['data']['description'],
    })
    return Document(
        title=content["data"]["title"],
        content=content["data"]["content"],
        metadata={
            "url": content["data"]["url"],
            "description": content["data"]["description"],
        },
    )
22 changes: 8 additions & 14 deletions gyandex/loaders/factory_test.py
@@ -1,8 +1,6 @@
import json

import pytest
import responses
from gyandex.loaders.factory import fetch_url

from .factory import fetch_url


@responses.activate
@@ -11,33 +9,28 @@ def test_fetch_url_returns_json_response():
    # Given
    test_url = "test123"
    actual = {"data": {"title": "title", "content": "test content", "url": "url", "description": "description"}}
    responses.add(
        responses.GET,
        f"https://r.jina.ai/{test_url}",
        json=actual,
        status=200
    )
    responses.add(responses.GET, f"https://r.jina.ai/{test_url}", json=actual, status=200)

    # When
    result = fetch_url(test_url)

    # Then
    assert result.content == "test content"
    assert result.title == "title"
    assert result.metadata == { "url": "url", "description": "description" }
    assert result.metadata == {"url": "url", "description": "description"}


@responses.activate
def test_fetch_url_sends_correct_headers():
    """Tests that fetch_url sends the correct Accept header"""
    # Given
    test_url = "test123"
    expected_headers = {"Accept": "application/json"}
    _ = {"Accept": "application/json"}
    responses.add(
        responses.GET,
        f"https://r.jina.ai/{test_url}",
        json={"data": {"title": "title", "content": "test content", "url": "url", "description": "description"}},
        status=200
        status=200,
    )

    # When
@@ -46,6 +39,7 @@ def test_fetch_url_sends_correct_headers():
# Then
assert responses.calls[0].request.headers["Accept"] == "application/json"


@responses.activate
def test_fetch_url_constructs_correct_url():
"""Tests that fetch_url constructs the correct URL with the base and provided path"""
@@ -56,7 +50,7 @@ def test_fetch_url_constructs_correct_url():
responses.GET,
expected_url,
json={"data": {"title": "title", "content": "test content", "url": "url", "description": "description"}},
status=200
status=200,
)

# When
