-
Notifications
You must be signed in to change notification settings - Fork 299
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: Move functions and class definitions to library
- Loading branch information
Showing
13 changed files
with
298 additions
and
262 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,4 @@ | ||
.env | ||
venv | ||
.venv | ||
.venv | ||
*.pyc |
Empty file.
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
from .section_writer import generate_section | ||
from .structure_writer import generate_book_structure | ||
from .title_writer import generate_book_title |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,43 @@ | ||
""" | ||
Agent to generate book section content | ||
""" | ||
|
||
from ..inference import GenerationStatistics | ||
|
||
def generate_section(prompt: str, additional_instructions: str, model: str, groq_provider): | ||
stream = groq_provider.chat.completions.create( | ||
model=model, | ||
messages=[ | ||
{ | ||
"role": "system", | ||
"content": "You are an expert writer. Generate a long, comprehensive, structured chapter for the section provided. If additional instructions are provided, consider them very important. Only output the content.", | ||
}, | ||
{ | ||
"role": "user", | ||
"content": f"Generate a long, comprehensive, structured chapter. Use the following section and important instructions:\n\n<section_title>{prompt}</section_title>\n\n<additional_instructions>{additional_instructions}</additional_instructions>", | ||
}, | ||
], | ||
temperature=0.3, | ||
max_tokens=8000, | ||
top_p=1, | ||
stream=True, | ||
stop=None, | ||
) | ||
|
||
for chunk in stream: | ||
tokens = chunk.choices[0].delta.content | ||
if tokens: | ||
yield tokens | ||
if x_groq := chunk.x_groq: | ||
if not x_groq.usage: | ||
continue | ||
usage = x_groq.usage | ||
statistics_to_return = GenerationStatistics( | ||
input_time=usage.prompt_time, | ||
output_time=usage.completion_time, | ||
input_tokens=usage.prompt_tokens, | ||
output_tokens=usage.completion_tokens, | ||
total_time=usage.total_time, | ||
model_name=model, | ||
) | ||
yield statistics_to_return |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,41 @@ | ||
""" | ||
Agent to generate book structure | ||
""" | ||
|
||
from ..inference import GenerationStatistics | ||
|
||
def generate_book_structure(prompt: str, additional_instructions: str, model: str, groq_provider): | ||
""" | ||
Returns book structure content as well as total tokens and total time for generation. | ||
""" | ||
completion = groq_provider.chat.completions.create( | ||
model=model, | ||
messages=[ | ||
{ | ||
"role": "system", | ||
"content": 'Write in JSON format:\n\n{"Title of section goes here":"Description of section goes here",\n"Title of section goes here":{"Title of section goes here":"Description of section goes here","Title of section goes here":"Description of section goes here","Title of section goes here":"Description of section goes here"}}', | ||
}, | ||
{ | ||
"role": "user", | ||
"content": f"Write a comprehensive structure, omiting introduction and conclusion sections (forward, author's note, summary), for a long (>300 page) book. It is very important that use the following subject and additional instructions to write the book. \n\n<subject>{prompt}</subject>\n\n<additional_instructions>{additional_instructions}</additional_instructions>", | ||
}, | ||
], | ||
temperature=0.3, | ||
max_tokens=8000, | ||
top_p=1, | ||
stream=False, | ||
response_format={"type": "json_object"}, | ||
stop=None, | ||
) | ||
|
||
usage = completion.usage | ||
statistics_to_return = GenerationStatistics( | ||
input_time=usage.prompt_time, | ||
output_time=usage.completion_time, | ||
input_tokens=usage.prompt_tokens, | ||
output_tokens=usage.completion_tokens, | ||
total_time=usage.total_time, | ||
model_name=model, | ||
) | ||
|
||
return statistics_to_return, completion.choices[0].message.content |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,30 @@ | ||
""" | ||
Agent to generate book title | ||
""" | ||
|
||
from ..inference import GenerationStatistics | ||
|
||
def generate_book_title(prompt: str, model: str, groq_provider): | ||
""" | ||
Generate a book title using AI. | ||
""" | ||
completion = groq_provider.chat.completions.create( | ||
model="llama3-70b-8192", | ||
messages=[ | ||
{ | ||
"role": "system", | ||
"content": "Generate suitable book titles for the provided topics. There is only one generated book title! Don't give any explanation or add any symbols, just write the title of the book. The requirement for this title is that it must be between 7 and 25 words long, and it must be attractive enough!" | ||
}, | ||
{ | ||
"role": "user", | ||
"content": f"Generate a book title for the following topic. There is only one generated book title! Don't give any explanation or add any symbols, just write the title of the book. The requirement for this title is that it must be at least 7 words and 25 words long, and it must be attractive enough:\n\n{prompt}" | ||
} | ||
], | ||
temperature=0.7, | ||
max_tokens=100, | ||
top_p=1, | ||
stream=False, | ||
stop=None, | ||
) | ||
|
||
return completion.choices[0].message.content.strip() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
from .stats import GenerationStatistics |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,64 @@ | ||
""" | ||
Class for tracking and displaying inference statistics | ||
""" | ||
|
||
class GenerationStatistics: | ||
def __init__( | ||
self, | ||
model_name, | ||
input_time=0, | ||
output_time=0, | ||
input_tokens=0, | ||
output_tokens=0, | ||
total_time=0, | ||
): | ||
self.model_name = model_name | ||
self.input_time = input_time | ||
self.output_time = output_time | ||
self.input_tokens = input_tokens | ||
self.output_tokens = output_tokens | ||
self.total_time = ( | ||
total_time # Sum of queue, prompt (input), and completion (output) times | ||
) | ||
|
||
def get_input_speed(self): | ||
""" | ||
Tokens per second calculation for input | ||
""" | ||
if self.input_time != 0: | ||
return self.input_tokens / self.input_time | ||
else: | ||
return 0 | ||
|
||
def get_output_speed(self): | ||
""" | ||
Tokens per second calculation for output | ||
""" | ||
if self.output_time != 0: | ||
return self.output_tokens / self.output_time | ||
else: | ||
return 0 | ||
|
||
def add(self, other): | ||
""" | ||
Add statistics from another GenerationStatistics object to this one. | ||
""" | ||
if not isinstance(other, GenerationStatistics): | ||
raise TypeError("Can only add GenerationStatistics objects") | ||
|
||
self.input_time += other.input_time | ||
self.output_time += other.output_time | ||
self.input_tokens += other.input_tokens | ||
self.output_tokens += other.output_tokens | ||
self.total_time += other.total_time | ||
|
||
def __str__(self): | ||
return ( | ||
f"\n## {self.get_output_speed():.2f} T/s ⚡\nRound trip time: {self.total_time:.2f}s Model: {self.model_name}\n\n" | ||
f"| Metric | Input | Output | Total |\n" | ||
f"|-----------------|----------------|-----------------|----------------|\n" | ||
f"| Speed (T/s) | {self.get_input_speed():.2f} | {self.get_output_speed():.2f} | {(self.input_tokens + self.output_tokens) / self.total_time if self.total_time != 0 else 0:.2f} |\n" | ||
f"| Tokens | {self.input_tokens} | {self.output_tokens} | {self.input_tokens + self.output_tokens} |\n" | ||
f"| Inference Time (s) | {self.input_time:.2f} | {self.output_time:.2f} | {self.total_time:.2f} |" | ||
) | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
from .markdown import create_markdown_file | ||
from .pdf import create_pdf_file |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
""" | ||
Functions to manage markdown content | ||
""" | ||
|
||
from io import BytesIO | ||
|
||
def create_markdown_file(content: str) -> BytesIO: | ||
""" | ||
Create a Markdown file from the provided content. | ||
""" | ||
markdown_file = BytesIO() | ||
markdown_file.write(content.encode("utf-8")) | ||
markdown_file.seek(0) | ||
return markdown_file |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,84 @@ | ||
""" | ||
Functions to manage pdf content | ||
""" | ||
|
||
from io import BytesIO | ||
from markdown import markdown | ||
|
||
def create_pdf_file(content: str) -> BytesIO: | ||
""" | ||
Create a PDF file from the provided Markdown content. | ||
Converts Markdown to styled HTML, then HTML to PDF. | ||
""" | ||
|
||
html_content = markdown(content, extensions=["extra", "codehilite"]) | ||
|
||
styled_html = f""" | ||
<html> | ||
<head> | ||
<style> | ||
@page {{ | ||
margin: 2cm; | ||
}} | ||
body {{ | ||
font-family: Arial, sans-serif; | ||
line-height: 1.6; | ||
font-size: 12pt; | ||
}} | ||
h1, h2, h3, h4, h5, h6 {{ | ||
color: #333366; | ||
margin-top: 1em; | ||
margin-bottom: 0.5em; | ||
}} | ||
p {{ | ||
margin-bottom: 0.5em; | ||
}} | ||
code {{ | ||
background-color: #f4f4f4; | ||
padding: 2px 4px; | ||
border-radius: 4px; | ||
font-family: monospace; | ||
font-size: 0.9em; | ||
}} | ||
pre {{ | ||
background-color: #f4f4f4; | ||
padding: 1em; | ||
border-radius: 4px; | ||
white-space: pre-wrap; | ||
word-wrap: break-word; | ||
}} | ||
blockquote {{ | ||
border-left: 4px solid #ccc; | ||
padding-left: 1em; | ||
margin-left: 0; | ||
font-style: italic; | ||
}} | ||
table {{ | ||
border-collapse: collapse; | ||
width: 100%; | ||
margin-bottom: 1em; | ||
}} | ||
th, td {{ | ||
border: 1px solid #ddd; | ||
padding: 8px; | ||
text-align: left; | ||
}} | ||
th {{ | ||
background-color: #f2f2f2; | ||
}} | ||
input, textarea {{ | ||
border-color: #4A90E2 !important; | ||
}} | ||
</style> | ||
</head> | ||
<body> | ||
{html_content} | ||
</body> | ||
</html> | ||
""" | ||
|
||
pdf_buffer = BytesIO() | ||
HTML(string=styled_html).write_pdf(pdf_buffer) | ||
pdf_buffer.seek(0) | ||
|
||
return pdf_buffer |
Oops, something went wrong.