Skip to content

Commit 04e9b4c

Browse files
authored
Implement model wrapper for AI content generation (#58)
### Issue <!-- Please link the GitHub issues related to this PR, if available --> ### Description This PR introduces the new ModelWrapper feature, which encapsulates model logic using the pydantic-ai Agent. The ModelWrapper enables flexible integration with multiple LLM providers and centralizes AI-powered content generation. Key changes include: - Addition of the `ModelWrapper` class in `struct_module/model_wrapper.py` - Integration of ModelWrapper into `FileItem` and related modules - Improved test and workflow configuration to support AI model usage - Ensured environment variable handling for API keys in CI and local development These changes improve modularity, testability, and future extensibility for AI-powered features. ### Checklist - [x] I have read the [contributing guidelines](https://github.com/httpdss/struct/blob/main/README.md#contributing). - [x] My code follows the code style of this project. - [x] I have performed a self-review of my own code. - [x] I have added tests that prove my fix is effective or that my feature works. - [x] New and existing unit tests pass locally with my changes. - [x] I have updated the documentation accordingly. ### Screenshots (if applicable) <!-- Add screenshots to illustrate the changes made in the pull request --> ### Additional Comments - The main feature is the ModelWrapper, which standardizes model usage across the codebase. - The test workflow now sets a dummy `OPENAI_API_KEY` for CI compatibility.
1 parent ee6537a commit 04e9b4c

File tree

6 files changed

+72
-39
lines changed

6 files changed

+72
-39
lines changed

.devcontainer/devcontainer.json

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
"image": "mcr.microsoft.com/devcontainers/python:3",
44
"features": {
55
"ghcr.io/devcontainers/features/python:1": {},
6-
"ghcr.io/gvatsal60/dev-container-features/pre-commit": {},
6+
"ghcr.io/gvatsal60/dev-container-features/pre-commit:1": {}
77
},
88
"postCreateCommand": "bash ./scripts/devcontainer_start.sh",
99
"customizations": {

.env.example

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1 +1,29 @@
1+
# The AI model to use, e.g., "openai:gpt-4.1"
2+
# check for available models at https://ai.pydantic.dev/models/
3+
AI_MODEL=your-model-name-here
4+
5+
# To use openai's API, you need to set your API key in the environment variable OPENAI_API_KEY.
16
OPENAI_API_KEY=your-api-key-here
7+
8+
# To use Anthropy's API, you need to set your API key in the environment variable ANTHROPY_API_KEY.
9+
ANTHROPY_API_KEY=your-anthropy-api-key-here
10+
11+
# To use Gemini's API, you need to set your API key in the environment variable GEMINI_API_KEY.
12+
GEMINI_API_KEY=your-gemini-api-key-here
13+
14+
# To use Google's API, you need to set your API key in the environment variable GOOGLE_API_KEY.
15+
GOOGLE_API_KEY=your-google-api-key-here
16+
17+
# Bedrock API key
18+
AWS_ACCESS_KEY_ID=your-aws-access-key-id-here
19+
AWS_SECRET_ACCESS_KEY=your-aws-secret-access-key-here
20+
AWS_DEFAULT_REGION=your-aws-region-here
21+
22+
# To use Cohere's API, you need to set your API key in the environment variable CO_API_KEY.
23+
CO_API_KEY=your-cohere-api-key-here
24+
25+
# To use Groq's API, you need to set your API key in the environment variable GROQ_API_KEY.
26+
GROQ_API_KEY=your-groq-api-key-here
27+
28+
# To use Mistral's API, you need to set your API key in the environment variable MISTRAL_API_KEY.
29+
MISTRAL_API_KEY=your-mistral-api-key-here

.github/workflows/test-script.yaml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,9 @@ on:
88
branches:
99
- main
1010

11+
env:
12+
OPENAI_API_KEY: "my-test-key"
13+
1114
jobs:
1215
build:
1316

requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,3 +10,4 @@ boto3
1010
google-cloud
1111
google-api-core
1212
cachetools
13+
pydantic-ai

struct_module/file_item.py

Lines changed: 7 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -4,16 +4,13 @@
44
import shutil
55
import logging
66
import time
7-
from openai import OpenAI
87
from dotenv import load_dotenv
98
from struct_module.template_renderer import TemplateRenderer
109
from struct_module.content_fetcher import ContentFetcher
10+
from struct_module.model_wrapper import ModelWrapper
1111

1212
load_dotenv()
1313

14-
openai_api_key = os.getenv("OPENAI_API_KEY")
15-
openai_model = os.getenv("OPENAI_MODEL")
16-
1714
class FileItem:
1815
def __init__(self, properties):
1916
self.logger = logging.getLogger(__name__)
@@ -32,11 +29,9 @@ def __init__(self, properties):
3229

3330
self.system_prompt = properties.get("system_prompt") or properties.get("global_system_prompt")
3431
self.user_prompt = properties.get("user_prompt")
35-
self.openai_client = None
3632
self.mappings = properties.get("mappings", {})
3733

38-
if openai_api_key:
39-
self._configure_openai()
34+
self.model_wrapper = ModelWrapper(self.logger)
4035

4136
self.template_renderer = TemplateRenderer(
4237
self.config_variables,
@@ -45,55 +40,29 @@ def __init__(self, properties):
4540
self.mappings
4641
)
4742

48-
def _configure_openai(self):
49-
self.openai_client = OpenAI(api_key=openai_api_key)
50-
if not openai_model:
51-
self.logger.debug("OpenAI model not found. Using default model.")
52-
self.openai_model = "gpt-4.1"
53-
else:
54-
self.logger.debug(f"Using OpenAI model: {openai_model}")
55-
self.openai_model = openai_model
56-
5743
def _get_file_directory(self):
5844
return os.path.dirname(self.name)
5945

6046
def process_prompt(self, dry_run=False, existing_content=None):
6147
if self.user_prompt:
62-
if not self.openai_client or not openai_api_key:
63-
self.logger.warning("Skipping processing prompt as OpenAI API key is not set.")
64-
return
6548

6649
if not self.system_prompt:
6750
system_prompt = "You are a software developer working on a project. You need to create a file with the following content:"
6851
else:
6952
system_prompt = self.system_prompt
7053

71-
# If existing_content is provided, append it to the user prompt
7254
user_prompt = self.user_prompt
7355
if existing_content:
7456
user_prompt += f"\n\nCurrent file content (if any):\n```\n{existing_content}\n```\n\nPlease modify existing content so that it meets the new requirements. Your output should be plain text, without any code blocks or formatting. Do not include any explanations or comments. Just provide the final content of the file."
7557

7658
self.logger.debug(f"Using system prompt: {system_prompt}")
7759
self.logger.debug(f"Using user prompt: {user_prompt}")
7860

79-
if dry_run:
80-
self.logger.info("[DRY RUN] Would generate content using OpenAI API.")
81-
self.content = "[DRY RUN] Generating content using OpenAI"
82-
return
83-
84-
if self.openai_client and openai_api_key:
85-
completion = self.openai_client.chat.completions.create(
86-
model=self.openai_model,
87-
messages=[
88-
{"role": "system", "content": system_prompt},
89-
{"role": "user", "content": user_prompt}
90-
]
91-
)
92-
93-
self.content = completion.choices[0].message.content
94-
else:
95-
self.content = "OpenAI API key not found. Skipping content generation."
96-
self.logger.warning("Skipping processing prompt as OpenAI API key is not set.")
61+
self.content = self.model_wrapper.generate_content(
62+
system_prompt,
63+
user_prompt,
64+
dry_run=dry_run
65+
)
9766
self.logger.debug(f"Generated content: \n\n{self.content}")
9867

9968
def fetch_content(self):

struct_module/model_wrapper.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
import os
2+
import logging
3+
from dotenv import load_dotenv
4+
from pydantic_ai import Agent
5+
6+
load_dotenv()
7+
8+
class ModelWrapper:
9+
"""
10+
Wraps model logic using pydantic-ai Agent, allowing use of multiple LLM providers.
11+
"""
12+
def __init__(self, logger=None):
13+
self.logger = logger or logging.getLogger(__name__)
14+
self.model_name = os.getenv("AI_MODEL") or "openai:gpt-4.1"
15+
self.agent = Agent(model=self.model_name)
16+
self.logger.debug(f"Configured Agent with model: {self.model_name}")
17+
18+
def generate_content(self, system_prompt, user_prompt, dry_run=False):
19+
if not self.agent:
20+
self.logger.warning("No agent configured. Skipping content generation.")
21+
return "No agent configured. Skipping content generation."
22+
if dry_run:
23+
self.logger.info("[DRY RUN] Would generate content using AI agent.")
24+
return "[DRY RUN] Generating content using AI agent"
25+
prompt = f"{user_prompt}"
26+
try:
27+
self.agent.system_prompt = system_prompt
28+
result = self.agent.run_sync(prompt)
29+
return result.output
30+
except Exception as e:
31+
self.logger.error(f"AI agent generation failed: {e}")
32+
return f"AI agent generation failed: {e}"

0 commit comments

Comments
 (0)