diff --git a/Cargo.toml b/Cargo.toml index 658182d..f2dcaf2 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -31,15 +31,16 @@ thiserror = "2.0.11" tokio = { version = "1.43", features = ["full"] } futures = "0.3" parking_lot = "0.12.3" -tracing = "0.1" +maplit = "1.0.2" # CLI and UI - +structopt = "0.3.26" colored = "3.0.0" console = { version = "0.15.10", default-features = false } indicatif = { version = "0.17.11", default-features = false } log = "0.4.25" env_logger = { version = "0.11.6", default-features = false } +tracing = "0.1" # Git integration git2 = { version = "0.20.0", default-features = false } @@ -74,9 +75,7 @@ syntect = { version = "5.2", default-features = false, features = [ pulldown-cmark = "0.12" comrak = "0.35" textwrap = "0.16" -structopt = "0.3.26" mustache = "0.9.0" -maplit = "1.0.2" [dev-dependencies] tempfile = "3.16.0" diff --git a/finetune.md b/finetune.md new file mode 100644 index 0000000..ca4be33 --- /dev/null +++ b/finetune.md @@ -0,0 +1,85 @@ +# Finetune.rs Workflow + +Here's a summary of the workflow in `finetune.rs`: + +- Uses GPT4o-mini model for OpenAI +- Generates training data in JSONL format for fine-tuning +- Splits data into training and verification sets + +1. **Initialize and Setup** + + - Creates empty train and verify files + - Sets up thread pool for parallel processing + - Initializes progress bars and counters + - Loads system prompt from `resources/prompt.md` + +2. **Collect Commit History** + + - Opens local git repository + - Walks through commit history + - Filters commits based on: + - Message length (20-500 chars) + - Non-merge commits only + - Diff size within limits (default 5000 chars) + - Collects valid commits up to 3x target number + - Shuffles commits for randomization + +3. 
**Process Commits in Parallel** + + - Spawns worker threads based on CPU count or user setting + - Each worker processes a subset of commits + - For each commit: + - Checks for duplicate messages + - Rates commit quality (0.0-1.0) + - Cleans up commit message + - Tracks approved commits with progress bar + - Stops when target number reached + +4. **Clean and Rate Commit Messages** + + - Cleanup process: + - Takes first line only + - Removes ticket references and tags + - Ensures proper capitalization + - Drops type prefixes + - Keeps messages short and meaningful + - Quality rating based on: + - Message format and clarity + - Diff alignment + - Present tense and active voice + - Description accuracy + +5. **Generate Training Data** + + - Creates JSONL entries with: + - System prompt + - Diff as user input + - Cleaned message as assistant output + - Splits data: + - 50% for training + - 50% for verification + - Prevents duplicate messages + - Validates cleaned messages + +6. **Track Progress and Results** + - Shows real-time progress: + - Commit collection progress + - Message cleaning progress + - Approval status + - Reports final statistics: + - Total commits processed + - Training examples count + - Verification examples count + - Distribution between files + +Key Features: + +- Parallel processing for better performance +- Double quality check (original and cleaned messages) +- Duplicate prevention at multiple stages +- Progress visualization with spinners and bars +- Verbose mode for detailed logging + +The key difference from optimize.rs is that finetune.rs focuses on generating high-quality training data for fine-tuning, while optimize.rs focuses on improving the system prompt itself. 
+ +Note: Run sync, not async diff --git a/resources/prompt.md b/resources/prompt.md index 14b9685..085cd7a 100644 --- a/resources/prompt.md +++ b/resources/prompt.md @@ -6,17 +6,24 @@ The character limit for the commit message is: {{max_length}} -Please follow these guidelines when generating the commit message: - -1. Analyze the diff carefully, focusing on lines marked with + or -. -2. Identify the files changed and the nature of the changes (added, modified, or deleted). -3. Determine the most significant change if multiple changes are present. -4. Create a clear, present-tense summary of the change in the imperative mood. -5. Ensure the commit message is within the specified character limit. -6. For binary files or unreadable diffs: - - Use the format "Add/Update/Delete binary file " - - Include file size in parentheses if available - - For multiple binary files, list them separated by commas +Please adhere to the following enhanced guidelines: + +- **Structure**: Begin with a clear, present-tense summary of the change in the non-conventional commit format. Use a single-line summary for the change, followed by a blank line. As a best practice, consider including only one bullet point detailing context if essential, but refrain from excessive elaboration. + +- **Content**: Commit messages must strictly describe the lines marked with + or - in the diff. Avoid including surrounding context, unmarked lines, or irrelevant details. Explicitly refrain from mentioning implications, reasoning, motivations, or any external context not explicitly reflected in the diff. Make sure to avoid any interpretations or assumptions beyond what is clearly stated. + +- **Changes**: Clearly articulate what was added, removed, or modified based solely on what is visible in the diff. Use phrases such as "Based only on the changes visible in the diff, this commit..." to emphasize an evidence-based approach while outlining changes directly. 
+ +- **Consistency**: Ensure uniformity in tense, punctuation, and capitalization throughout the message. Use present tense and imperative form, such as "Add x to y" instead of "Added x to y". + +- **Clarity & Brevity**: Craft messages that are clear and easy to understand, succinctly capturing the essence of the changes. Limit the message to the specified character limit while ensuring enough detail is provided on the primary action taken. Avoid jargon; provide plain definitions for any necessary technical terms. + +- **Binary Files**: For binary files or unreadable diffs: + - Use the format "Add/Update/Delete binary file " + - Include file size in parentheses if available + - For multiple binary files, list them separated by commas + +- **Accuracy & Hallucination Prevention**: Rigorously reflect only the changes visible in the diff. Avoid any speculation or inclusion of content not substantiated by the diff. Restate the necessity for messages to focus exclusively on aspects evident in the diff and to completely avoid extrapolation or assumptions about motivations or implications. Before generating the final commit message, please analyze the diff and but keep your thought process to your self: @@ -25,11 +32,11 @@ Before generating the final commit message, please analyze the diff and but keep 3. Identify any binary files or unreadable diffs separately. 4. Determine the most significant change if multiple changes are present. 5. Consider the impact of each change and its relevance to the overall commit message. -6. Brainstorm keywords that could be used in the commit message. -7. Propose three potential single-line summaries based on the breakdown. -8. Count the characters in each proposed summary, ensuring they meet the specified character limit. -9. Select the best summary that accurately reflects the most significant change and meets the character limit. -10. Prefixes such as `refactor:`, `fix` should be removed +6. 
Review the message to ensure it: + - Accurately reflects only the changes in the diff + - Follows the structure and formatting guidelines + - Contains no external context or assumptions + - Is clear and understandable to other developers After your analysis, provide only the final commit message as output. Ensure it is clear, concise, and accurately reflects the content of the diff while adhering to the character limit. Do not include any additional text or explanations in your final output. diff --git a/src/bin/hook.rs b/src/bin/hook.rs index 818e247..e5c686c 100644 --- a/src/bin/hook.rs +++ b/src/bin/hook.rs @@ -120,7 +120,7 @@ impl Args { .context("Failed to get patch")?; let response = commit::generate(patch.to_string(), remaining_tokens, model).await?; - std::fs::write(&self.commit_msg_file, response.response.trim())?; + std::fs::write(&self.commit_msg_file, response.trim())?; pb.finish_and_clear(); @@ -193,7 +193,7 @@ impl Args { .context("Failed to get patch")?; let response = commit::generate(patch.to_string(), remaining_tokens, model).await?; - std::fs::write(&self.commit_msg_file, response.response.trim())?; + std::fs::write(&self.commit_msg_file, response.trim())?; pb.finish_and_clear(); @@ -205,22 +205,12 @@ impl Args { #[tokio::main] async fn main() -> Result<()> { - if std::env::var("RUST_LOG").is_ok() { - env_logger::init(); - } + env_logger::init(); - let time = std::time::Instant::now(); let args = Args::from_args(); - - if log::log_enabled!(log::Level::Debug) { - log::debug!("Arguments: {:?}", args); - } - - if let Err(err) = args.execute().await { - eprintln!("{} ({:?})", err, time.elapsed()); + if let Err(e) = args.execute().await { + eprintln!("Error: {}", e); exit(1); - } else if log::log_enabled!(log::Level::Debug) { - log::debug!("Completed in {:?}", time.elapsed()); } Ok(()) diff --git a/src/commit.rs b/src/commit.rs index c8aa540..bbceefc 100644 --- a/src/commit.rs +++ b/src/commit.rs @@ -31,16 +31,15 @@ fn get_instruction_template() -> 
Result { /// * `Result` - The number of tokens used or an error pub fn get_instruction_token_count(model: &Model) -> Result { profile!("Calculate instruction tokens"); - let template = get_instruction_template()?; - model.count_tokens(&template) + model.count_tokens(&get_instruction_template()?) } -/// Creates an OpenAI request for commit message generation. +/// Creates a commit request for the OpenAI API. /// /// # Arguments -/// * `diff` - The git diff to generate a commit message for -/// * `max_tokens` - Maximum number of tokens allowed for the response -/// * `model` - The AI model to use for generation +/// * `diff` - The diff to generate a commit message for +/// * `max_tokens` - The maximum number of tokens to use +/// * `model` - The model to use for generation /// /// # Returns /// * `Result` - The prepared request @@ -55,25 +54,20 @@ fn create_commit_request(diff: String, max_tokens: usize, model: Model) -> Resul }) } -/// Generates a commit message using the AI model. +/// Generates a commit message for the given patch. 
/// /// # Arguments -/// * `diff` - The git diff to generate a commit message for -/// * `max_tokens` - Maximum number of tokens allowed for the response -/// * `model` - The AI model to use for generation +/// * `patch` - The patch to generate a commit message for +/// * `remaining_tokens` - The maximum number of tokens to use +/// * `model` - The model to use for generation /// /// # Returns -/// * `Result` - The generated commit message or an error -/// -/// # Errors -/// Returns an error if: -/// - max_tokens is 0 -/// - OpenAI API call fails -pub async fn generate(patch: String, remaining_tokens: usize, model: Model) -> Result { +/// * `Result` - The generated commit message or an error +pub async fn generate(patch: String, remaining_tokens: usize, model: Model) -> Result { profile!("Generate commit message"); - if remaining_tokens == 0 { - bail!("Maximum token count must be greater than zero") + if patch.is_empty() { + bail!("No changes to commit"); } let request = create_commit_request(patch, remaining_tokens, model)?; diff --git a/src/filesystem.rs b/src/filesystem.rs index 980bb67..113e6a5 100644 --- a/src/filesystem.rs +++ b/src/filesystem.rs @@ -173,14 +173,16 @@ impl Filesystem { /// * `Result` - The initialized filesystem or an error pub fn new() -> Result { // Get current directory - let current_dir = env::current_dir().context(ERR_CURRENT_DIR)?; + let current_dir = { env::current_dir().context(ERR_CURRENT_DIR)? }; // Get executable path - let git_ai_bin_path = env::current_exe().context("Failed to get current executable")?; + let git_ai_bin_path = { env::current_exe().context("Failed to get current executable")? 
}; // Open git repository - let repo = Repository::open_ext(¤t_dir, Flags::empty(), Vec::<&Path>::new()) - .with_context(|| format!("Failed to open repository at {}", current_dir.display()))?; + let repo = { + Repository::open_ext(¤t_dir, Flags::empty(), Vec::<&Path>::new()) + .with_context(|| format!("Failed to open repository at {}", current_dir.display()))? + }; // Get git path and ensure it's absolute let git_path = { diff --git a/src/finetune.rs b/src/finetune.rs new file mode 100644 index 0000000..25b74fc --- /dev/null +++ b/src/finetune.rs @@ -0,0 +1,445 @@ +use std::fs; +use std::io::Write; +use std::sync::Arc; +use std::collections::HashSet; + +use anyhow::{Context, Result}; +use colored::*; +use git2::{DiffOptions, Repository}; +use indicatif::{ProgressBar, ProgressStyle}; +use rand::prelude::*; +use serde::{Deserialize, Serialize}; +use structopt::StructOpt; +use tokio::sync::{mpsc, Mutex}; +use tokio::task; +use num_cpus; + +use crate::model::Model; +use crate::openai; + +/// Represents command-line arguments for fine-tuning +#[derive(Debug, Clone, Deserialize, Serialize, StructOpt)] +pub struct FinetuneArgs { + #[structopt(long, default_value = "resources/prompt.md")] + pub prompt_file: String, + + #[structopt(long, default_value = "finetune_train.jsonl")] + pub train_file: String, + + #[structopt(long, default_value = "finetune_verify.jsonl")] + pub verify_file: String, + + #[structopt(long, default_value = "50")] + pub num_commits: u32, + + #[structopt(long)] + pub parallel_requests: Option, + + #[structopt(long, default_value = "0.8")] + pub quality_threshold: f32, + + #[structopt(long)] + pub verbose: bool, + + #[structopt(long, default_value = "5000")] + pub max_diff_size: usize +} + +#[derive(Debug, Serialize, Deserialize)] +struct Message { + role: String, + content: String +} + +#[derive(Debug, Serialize, Deserialize)] +struct TrainingExample { + messages: Vec +} + +/// Track the types of changes in a commit +#[derive(Debug)] +struct 
CommitChangeTypes { + #[allow(dead_code)] + has_additions: bool, + #[allow(dead_code)] + has_deletions: bool, + #[allow(dead_code)] + has_modifications: bool, + #[allow(dead_code)] + has_renames: bool, + #[allow(dead_code)] + has_file_mode_changes: bool +} + +/// Simple container for commit info +#[derive(Debug)] +struct CommitInfo { + message: String, + diff: String, + #[allow(dead_code)] + change_types: CommitChangeTypes +} + +pub async fn run(args: FinetuneArgs) -> Result<()> { + println!("šŸ”„ Starting fine-tuning data export..."); + + // Reset (truncate) the output files + fs::write(&args.train_file, "")?; + fs::write(&args.verify_file, "")?; + + // Track seen messages to prevent duplicates + let seen_messages = Arc::new(Mutex::new(HashSet::new())); + + // 1. Load system prompt + let prompt_content = + fs::read_to_string(&args.prompt_file).with_context(|| format!("Failed to read prompt file: {}", args.prompt_file))?; + + // 2. Open local repository and setup commit processing + println!("šŸ“š Collecting commit history..."); + let repo = Repository::open(".")?; + let mut revwalk = repo.revwalk()?; + revwalk.push_head()?; + + let mut total_checked = 0; + let mut valid_commits = 0; + let mut commit_data = Vec::new(); + + let collect_pb = ProgressBar::new_spinner(); + collect_pb.set_style( + ProgressStyle::default_spinner() + .template("{spinner:.green} Processing commits: {pos} found ({msg})") + .unwrap() + ); + + // Process commits as we find them + for oid in revwalk { + total_checked += 1; + if let Ok(id) = oid { + if let Ok(commit) = repo.find_commit(id) { + let message = commit.message().unwrap_or(""); + if (20..500).contains(&message.len()) && commit.parent_count() == 1 { + let parent = commit.parent(0)?; + let parent_tree = parent.tree()?; + let commit_tree = commit.tree()?; + let mut diff_opts = DiffOptions::new(); + let diff = repo.diff_tree_to_tree(Some(&parent_tree), Some(&commit_tree), Some(&mut diff_opts))?; + + let mut diff_text = String::new(); + 
let mut total_diff_size = 0; + let mut should_skip = false; + + diff.print(git2::DiffFormat::Patch, |_, _, line| { + if let Ok(content) = std::str::from_utf8(line.content()) { + total_diff_size += content.len(); + if total_diff_size <= args.max_diff_size { + diff_text.push(line.origin()); + diff_text.push_str(content); + } else { + should_skip = true; + } + } + true + })?; + + if !should_skip { + commit_data.push((message.to_string(), diff_text)); + valid_commits += 1; + collect_pb.set_position(valid_commits as u64); + collect_pb.set_message(format!("latest: {:.40}...", message)); + } + } + } + } + if valid_commits >= args.num_commits as usize * 3 { + break; + } + } + + if args.verbose { + println!(" Checked {} commits, found {} valid ones", total_checked, valid_commits); + } + collect_pb.finish_with_message(format!("Found {} commits to process", valid_commits)); + + // Shuffle the collected commits for randomization + let mut rng = rand::rngs::ThreadRng::default(); + commit_data.shuffle(&mut rng); + let commit_data = Arc::new(commit_data); + + // Setup processing channel + let num_workers = args.parallel_requests.unwrap_or_else(num_cpus::get); + let (tx, mut rx) = mpsc::channel(num_workers * 2); + let approved_commits = Arc::new(Mutex::new(0usize)); + let threshold = args.quality_threshold; + + // Create progress bar for approved commits + let process_pb = ProgressBar::new(args.num_commits as u64); + process_pb.set_style( + ProgressStyle::default_bar() + .template("{spinner:.green} [{bar:40.cyan/blue}] {pos}/{len} approved ({eta})") + .unwrap() + .progress_chars("#>-") + ); + + // Spawn workers for quality checking + let mut workers = Vec::new(); + for worker_id in 0..num_workers { + let tx = tx.clone(); + let approved = Arc::clone(&approved_commits); + let seen = Arc::clone(&seen_messages); + let pb = process_pb.clone(); + let verbose = args.verbose; + let target_commits = args.num_commits; + let commit_data = Arc::clone(&commit_data); + let start_idx = worker_id 
* commit_data.len() / num_workers; + let end_idx = ((worker_id + 1) * commit_data.len() / num_workers).min(commit_data.len()); + + let worker = task::spawn(async move { + for (message, diff) in commit_data[start_idx..end_idx].iter() { + let current_approved = { + let count = approved.lock().await; + *count + }; + if current_approved >= target_commits as usize { + break; + } + let is_duplicate = { + let mut seen = seen.lock().await; + if seen.contains(message) { + true + } else { + seen.insert(message.clone()); + false + } + }; + if !is_duplicate { + if let Ok(score) = rate_commit_quality(&CommitInfo { + message: message.clone(), + diff: diff.clone(), + change_types: CommitChangeTypes { + has_additions: false, + has_deletions: false, + has_modifications: false, + has_renames: false, + has_file_mode_changes: false + } + }) + .await + { + if score >= threshold { + if let Ok(cleaned_message) = cleanup_commit_message(message).await { + let mut count = approved.lock().await; + *count += 1; + pb.set_position(*count as u64); + if verbose { + println!("āœ“ {} (score: {:.2})", cleaned_message.bright_green(), score); + } + if tx.send((message.clone(), diff.clone())).await.is_err() { + break; + } + } + } + } + } + } + }); + workers.push(worker); + } + drop(tx); + + // Process approved commits + let mut approved_count = 0; + let train_size = args.num_commits / 2; + let mut train_file = fs::OpenOptions::new() + .create(true) + .append(true) + .open(&args.train_file)?; + let mut verify_file = fs::OpenOptions::new() + .create(true) + .append(true) + .open(&args.verify_file)?; + + while let Some((message, diff)) = rx.recv().await { + if approved_count >= args.num_commits as usize { + break; + } + let cleaned_message = cleanup_commit_message(&message).await?; + if cleaned_message.trim().is_empty() { + continue; + } + let is_duplicate = { + let mut seen = seen_messages.lock().await; + if seen.contains(&cleaned_message) { + true + } else { + seen.insert(cleaned_message.clone()); + 
false + } + }; + if is_duplicate { + continue; + } + // Run scoring on the cleaned output + let cleaned_score = rate_cleaned_commit_message(&cleaned_message).await?; + if args.verbose { + println!("Cleaned: {} (score: {:.2})", cleaned_message, cleaned_score); + } + let example = TrainingExample { + messages: vec![ + Message { + role: "system".to_string(), + content: prompt_content.clone() + }, + Message { role: "user".to_string(), content: diff }, + Message { + role: "assistant".to_string(), + content: cleaned_message + }, + ] + }; + let json = serde_json::to_string(&example)?; + if approved_count < train_size as usize { + writeln!(train_file, "{}", json)?; + } else { + writeln!(verify_file, "{}", json)?; + } + approved_count += 1; + } + + for worker in workers { + worker.await?; + } + process_pb.finish(); + + println!("\n✨ Successfully exported {} training examples:", approved_count); + println!(" - {} training examples in {}", train_size, args.train_file); + println!(" - {} verification examples in {}", args.num_commits - train_size, args.verify_file); + + Ok(()) +} + +/// Cleanup commit message using GPT4oMini +async fn cleanup_commit_message(original_msg: &str) -> Result { + if original_msg.trim().is_empty() { + return Ok(String::new()); + } + let first_line = original_msg + .lines() + .next() + .unwrap_or("") + .trim() + .trim_start_matches("```") + .trim_end_matches("```") + .trim_start_matches("plaintext") + .trim_start_matches("git") + .trim(); + let system_prompt = "\ +You are an expert at cleaning up git commit messages. \ +Your task is to:\n\ +1. Remove any ticket references or extraneous tags\n\ +2. Keep it short, focusing on meaningful description\n\ +3. Do not end the message with a period\n\ +4. Always start with a capitalized verb (Add, Fix, Update, etc)\n\ +5. Drop the type prefix if it is present\n\ +6. 
Return ONLY the cleaned message without any formatting or backticks"; + let req = openai::Request { + system: system_prompt.to_string(), + prompt: first_line.to_string(), + max_tokens: 100, + model: Model::GPT4oMini + }; + let response = openai::call(req).await?; + let cleaned = response + .trim() + .trim_start_matches("```") + .trim_end_matches("```") + .trim_start_matches("plaintext") + .trim_start_matches("git") + .trim() + .to_string(); + if cleaned.is_empty() + || cleaned.to_lowercase().contains("please") + || cleaned.to_lowercase().contains("provide") + || cleaned.to_lowercase().contains("didn't") + || cleaned.to_lowercase().contains("error") + || cleaned.to_lowercase().contains("missing") + || cleaned.to_lowercase().contains("sorry") + || cleaned.to_lowercase().contains("unable") + || cleaned.to_lowercase().contains("could not") + || cleaned.to_lowercase().contains("cannot") + || cleaned.to_lowercase().contains("failed") + || cleaned.len() > 100 + { + return Ok(String::new()); + } + let message = if cleaned.contains(": ") { + let parts: Vec<&str> = cleaned.splitn(2, ": ").collect(); + parts.get(1).unwrap_or(&cleaned.as_str()).trim().to_string() + } else { + cleaned + }; + let mut chars = message.chars(); + Ok(if let Some(first_char) = chars.next() { + if first_char.is_lowercase() { + first_char.to_uppercase().collect::() + chars.as_str() + } else { + message + } + } else { + message + }) +} + +/// Rate commit quality using GPT4oMini +async fn rate_commit_quality(commit_info: &CommitInfo) -> Result { + let system_prompt = "\ +You are an expert at evaluating git commit quality. Your task is to rate this commit from 0.0 to 1.0 based on: + +1. Commit Message Quality (50% of score): + - Is the first line concise (under 72 chars)? + - If present, is the body descriptive and separated by blank line? + - Is the message present tense? + - Is the message written in the active voice? + - Is the message clear and concise? + +2. 
Diff Alignment (50% of score): + - Does the message accurately describe the changes in the diff? + - Are all significant changes reflected in the message? + - Is the scope of changes consistent with the message? + +Scoring Guide: +- 0.0-0.3: Poor quality (wrong format, unclear or misleading, conventional commit format) +- 0.4-0.6: Mediocre quality (basic description) +- 0.7-0.8: Good quality (follows format, clear message, mostly aligned with changes) +- 0.9-1.0: Excellent (perfect format and description of changes) + +Return ONLY a number between 0.0 and 1.0"; + let prompt = format!( + "Evaluate this commit:\n\nCommit Message:\n{}\n\nCode Changes:\n{}\n\nScore (0.0-1.0):", + commit_info.message, commit_info.diff + ); + let req = openai::Request { + system: system_prompt.to_string(), + prompt, + max_tokens: 10, + model: Model::GPT4oMini + }; + let response = openai::call(req).await?; + let score = response.trim().parse::().unwrap_or(0.0); + Ok(score.clamp(0.0, 1.0)) +} + +/// Rate cleaned commit message quality using GPT4oMini +async fn rate_cleaned_commit_message(cleaned_message: &str) -> Result { + let system_prompt = "\ +You are an expert at evaluating cleaned git commit messages. Rate the quality of this commit message on a scale from 0.0 to 1.0, based solely on clarity, conciseness, and adherence to conventional commit style guidelines. 
Return ONLY a number between 0.0 and 1.0."; + let prompt = format!("Cleaned Commit Message:\n{}\nScore (0.0-1.0):", cleaned_message); + let req = openai::Request { + system: system_prompt.to_string(), + prompt, + max_tokens: 10, + model: Model::GPT4oMini + }; + let response = openai::call(req).await?; + let score = response.trim().parse::().unwrap_or(0.0); + Ok(score.clamp(0.0, 1.0)) +} diff --git a/src/install.rs b/src/install.rs new file mode 100644 index 0000000..dc32602 --- /dev/null +++ b/src/install.rs @@ -0,0 +1,18 @@ +use anyhow::{bail, Result}; +use ai::filesystem::Filesystem; + +#[allow(dead_code)] +pub fn run() -> Result<()> { + let fs = Filesystem::new()?; + let hook_bin = fs.git_ai_hook_bin_path()?; + let hook_file = fs.prepare_commit_msg_path()?; + + if hook_file.exists() { + bail!("Hook already exists at {}, please run 'git ai hook reinstall'", hook_file); + } + + hook_file.symlink(&hook_bin)?; + println!("šŸ”— Hook symlinked successfully to {}", hook_file); + + Ok(()) +} diff --git a/src/lib.rs b/src/lib.rs index ce41b1f..001f84f 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,11 +1,20 @@ +#[macro_export] +macro_rules! 
profile { + ($name:expr) => {{ + let _span = tracing::span!(tracing::Level::DEBUG, $name); + let _enter = _span.enter(); + }}; +} + pub mod commit; pub mod config; +pub mod filesystem; pub mod hook; -pub mod style; pub mod model; -pub mod filesystem; pub mod openai; pub mod profiling; +pub mod style; +pub mod finetune; // Re-exports pub use profiling::Profile; diff --git a/src/main.rs b/src/main.rs index b0f978e..5a2908f 100644 --- a/src/main.rs +++ b/src/main.rs @@ -4,6 +4,7 @@ mod filesystem; use structopt::StructOpt; use anyhow::Result; use dotenv::dotenv; +use ai::finetune::{self, FinetuneArgs}; use crate::config::App; use crate::filesystem::Filesystem; @@ -14,7 +15,9 @@ enum Cli { #[structopt(about = "Installs the git-ai hook")] Hook(HookSubcommand), #[structopt(about = "Sets or gets configuration values")] - Config(ConfigSubcommand) + Config(ConfigSubcommand), + #[structopt(about = "Exports training data for fine-tuning")] + Finetune(FinetuneArgs) } #[derive(StructOpt)] @@ -222,6 +225,9 @@ async fn main() -> Result<()> { } }, }, + Cli::Finetune(args) => { + finetune::run(args).await?; + } } Ok(()) diff --git a/src/openai.rs b/src/openai.rs index 3dab00e..c4611f8 100644 --- a/src/openai.rs +++ b/src/openai.rs @@ -1,13 +1,13 @@ use async_openai::types::{ChatCompletionRequestSystemMessageArgs, ChatCompletionRequestUserMessageArgs, CreateChatCompletionRequestArgs}; use async_openai::config::OpenAIConfig; use async_openai::Client; -use async_openai::error::OpenAIError; use anyhow::{anyhow, Context, Result}; -use colored::*; -use crate::{commit, config, profile}; +use crate::{config, profile}; use crate::model::Model; +const MAX_CONTEXT_LENGTH: usize = 128000; +const BUFFER_TOKENS: usize = 30000; // Large buffer for safety const MAX_ATTEMPTS: usize = 3; #[derive(Debug, Clone, PartialEq)] @@ -24,10 +24,81 @@ pub struct Request { } /// Generates an improved commit message using the provided prompt and diff -pub async fn generate_commit_message(diff: &str) -> Result { 
+pub async fn generate_commit_message(diff: &str, prompt: &str, file_context: &str, author: &str, date: &str) -> Result { profile!("Generate commit message"); - let response = commit::generate(diff.into(), 256, Model::GPT4oMini).await?; - Ok(response.response.trim().to_string()) + let system_prompt = format!( + "You are an expert at writing clear, concise git commit messages. \ + Your task is to generate a commit message for the following code changes.\n\n\ + {}\n\n\ + Consider:\n\ + - Author: {}\n\ + - Date: {}\n\ + - Files changed: {}\n", + prompt, author, date, file_context + ); + + let response = call(Request { + system: system_prompt, + prompt: format!("Generate a commit message for this diff:\n\n{}", diff), + max_tokens: 256, + model: Model::GPT4oMini + }) + .await?; + + Ok(response.trim().to_string()) +} + +/// Scores a commit message against the original using AI evaluation +pub async fn score_commit_message(message: &str, original: &str) -> Result { + profile!("Score commit message"); + let system_prompt = "You are an expert at evaluating git commit messages. Score the following commit message on these criteria: + - Accuracy (0-1): How well does it describe the actual changes? + - Clarity (0-1): How clear and understandable is the message? + - Brevity (0-1): Is it concise while being informative? + - Categorization (0-1): Does it properly categorize the type of change? 
+ + Return ONLY a JSON object containing these scores and brief feedback."; + + let response = call(Request { + system: system_prompt.to_string(), + prompt: format!("Original commit message:\n{}\n\nGenerated commit message:\n{}", original, message), + max_tokens: 512, + model: Model::GPT4oMini + }) + .await?; + + // Parse the JSON response to get the overall score + let parsed: serde_json::Value = serde_json::from_str(&response).context("Failed to parse scoring response as JSON")?; + + let accuracy = parsed["accuracy"].as_f64().unwrap_or(0.0) as f32; + let clarity = parsed["clarity"].as_f64().unwrap_or(0.0) as f32; + let brevity = parsed["brevity"].as_f64().unwrap_or(0.0) as f32; + let categorization = parsed["categorization"].as_f64().unwrap_or(0.0) as f32; + + Ok((accuracy + clarity + brevity + categorization) / 4.0) +} + +/// Optimizes a prompt based on performance metrics +pub async fn optimize_prompt(current_prompt: &str, performance_metrics: &str) -> Result { + profile!("Optimize prompt"); + let system_prompt = "You are an expert at optimizing prompts for AI systems. \ + Your task is to improve a prompt used for generating git commit messages \ + based on performance metrics. 
Return ONLY the improved prompt text."; + + let response = call(Request { + system: system_prompt.to_string(), + prompt: format!( + "Current prompt:\n{}\n\nPerformance metrics:\n{}\n\n\ + Suggest an improved version of this prompt that addresses any weaknesses \ + shown in the metrics while maintaining its strengths.", + current_prompt, performance_metrics + ), + max_tokens: 1024, + model: Model::GPT4oMini + }) + .await?; + + Ok(response.trim().to_string()) } fn truncate_to_fit(text: &str, max_tokens: usize, model: &Model) -> Result { @@ -88,21 +159,44 @@ fn truncate_to_fit(text: &str, max_tokens: usize, model: &Model) -> Result Result { +/// Calls the OpenAI API with the given request +/// +/// # Arguments +/// * `request` - The request to send to the API +/// +/// # Returns +/// * `Result` - The response from the API or an error +pub async fn call(request: Request) -> Result { profile!("OpenAI API call"); - let api_key = config::APP.openai_api_key.clone().context(format!( - "{} OpenAI API key not found.\n Run: {}", - "ERROR:".bold().bright_red(), - "git-ai config set openai-api-key ".yellow() - ))?; + + let api_key = config::APP + .openai_api_key + .as_ref() + .ok_or_else(|| anyhow!("OpenAI API key not found. Please set OPENAI_API_KEY environment variable."))?; let config = OpenAIConfig::new().with_api_key(api_key); let client = Client::with_config(config); - // Calculate available tokens using model's context size + for attempt in 1..=MAX_ATTEMPTS { + match try_call(&client, &request).await { + Ok(response) => return Ok(response), + Err(e) if attempt < MAX_ATTEMPTS => { + log::warn!("Attempt {} failed: {}. 
Retrying...", attempt, e); + continue; + } + Err(e) => return Err(e) + } + } + + Err(anyhow!("Failed after {} attempts", MAX_ATTEMPTS)) +} + +pub async fn try_call(client: &Client, request: &Request) -> Result { + profile!("OpenAI request/response"); + + // Calculate available tokens for content let system_tokens = request.model.count_tokens(&request.system)?; - let model_context_size = request.model.context_size(); - let available_tokens = model_context_size.saturating_sub(system_tokens + request.max_tokens as usize); + let available_tokens = MAX_CONTEXT_LENGTH.saturating_sub(system_tokens + BUFFER_TOKENS + request.max_tokens as usize); // Truncate prompt if needed let truncated_prompt = truncate_to_fit(&request.prompt, available_tokens, &request.model)?; @@ -112,7 +206,7 @@ pub async fn call(request: Request) -> Result { .model(request.model.to_string()) .messages([ ChatCompletionRequestSystemMessageArgs::default() - .content(request.system) + .content(request.system.clone()) .build()? .into(), ChatCompletionRequestUserMessageArgs::default() @@ -122,56 +216,16 @@ pub async fn call(request: Request) -> Result { ]) .build()?; - { - profile!("OpenAI request/response"); - let response = match client.chat().create(request).await { - Ok(response) => response, - Err(err) => { - let error_msg = match err { - OpenAIError::ApiError(e) => - format!( - "{} {}\n {}\n\nDetails:\n {}\n\nSuggested Actions:\n 1. {}\n 2. {}\n 3. {}", - "ERROR:".bold().bright_red(), - "OpenAI API error:".bright_white(), - e.message.dimmed(), - "Failed to create chat completion.".dimmed(), - "Ensure your OpenAI API key is valid".yellow(), - "Check your account credits".yellow(), - "Verify OpenAI service availability".yellow() - ), - OpenAIError::Reqwest(e) => - format!( - "{} {}\n {}\n\nDetails:\n {}\n\nSuggested Actions:\n 1. {}\n 2. 
{}", - "ERROR:".bold().bright_red(), - "Network error:".bright_white(), - e.to_string().dimmed(), - "Failed to connect to OpenAI API.".dimmed(), - "Check your internet connection".yellow(), - "Verify OpenAI service availability".yellow() - ), - _ => - format!( - "{} {}\n {}\n\nDetails:\n {}\n\nSuggested Actions:\n 1. {}", - "ERROR:".bold().bright_red(), - "Unexpected error:".bright_white(), - err.to_string().dimmed(), - "An unexpected error occurred while calling OpenAI API.".dimmed(), - "Please report this issue on GitHub".yellow() - ), - }; - return Err(anyhow!(error_msg)); - } - }; + let response = client.chat().create(request).await?; - let content = response - .choices - .first() - .context("No response choices available")? - .message - .content - .clone() - .context("Response content is empty")?; + let content = response + .choices + .first() + .context("No response choices available")? + .message + .content + .clone() + .context("Response content is empty")?; - Ok(Response { response: content }) - } + Ok(content) } diff --git a/src/profiling.rs b/src/profiling.rs index 37e4482..e52e661 100644 --- a/src/profiling.rs +++ b/src/profiling.rs @@ -1,34 +1,26 @@ -use std::time::{Duration, Instant}; +use std::time::Instant; -use colored::Colorize; +use tracing::debug; pub struct Profile { - start: Instant, - name: String + name: String, + start: Instant } impl Profile { - pub fn new(name: impl Into) -> Self { - Self { start: Instant::now(), name: name.into() } - } - - pub fn elapsed(&self) -> Duration { - self.start.elapsed() + pub fn new(name: &str) -> Self { + Self { name: name.to_string(), start: Instant::now() } } } impl Drop for Profile { fn drop(&mut self) { - if log::log_enabled!(log::Level::Debug) { - let duration = self.elapsed(); - eprintln!("{}: {:.2?}", self.name.blue(), duration); - } + let elapsed = self.start.elapsed(); + debug!("{} took {:?}", self.name, elapsed); } } -#[macro_export] -macro_rules! 
profile { - ($name:expr) => { - let _profile = $crate::Profile::new($name); - }; +pub fn span(name: &str) -> Profile { + debug!("Starting {}", name); + Profile::new(name) } diff --git a/src/reinstall.rs b/src/reinstall.rs new file mode 100644 index 0000000..f572d3e --- /dev/null +++ b/src/reinstall.rs @@ -0,0 +1,32 @@ +use console::Emoji; +use anyhow::Result; +use ai::filesystem::Filesystem; +use colored::*; + +#[allow(dead_code)] +const EMOJI: Emoji<'_, '_> = Emoji("šŸ”—", ""); + +#[allow(dead_code)] +pub fn run() -> Result<()> { + let fs = Filesystem::new()?; + let hook_bin = fs.git_ai_hook_bin_path()?; + let hook_file = fs.prepare_commit_msg_path()?; + + if !fs.git_hooks_path().exists() { + fs.git_hooks_path().create_dir_all()?; + } + + if hook_file.exists() { + log::debug!("Removing existing hook file: {}", hook_file); + hook_file.delete()?; + } + + hook_file.symlink(&hook_bin)?; + + println!( + "{EMOJI} Hook symlinked successfully to {}", + hook_file.relative_path()?.to_string().italic() + ); + + Ok(()) +} diff --git a/src/uninstall.rs b/src/uninstall.rs new file mode 100644 index 0000000..09f6b3b --- /dev/null +++ b/src/uninstall.rs @@ -0,0 +1,45 @@ +use std::path::{Path, PathBuf}; +use std::{env, fs}; + +use anyhow::{bail, Context, Result}; +use ai::style::Styled; +use colored::Colorize; +use console::Emoji; +use git2::{Repository, RepositoryOpenFlags as Flags}; +use thiserror::Error; + +#[derive(Error, Debug)] +#[allow(dead_code)] +pub enum InstallError { + #[error("Failed to get current directory")] + CurrentDir, + #[error("Failed to open repository")] + OpenRepo, + #[error("Hook already exists: {0:?}")] + HookExists(PathBuf) +} + +#[allow(dead_code)] +const EMOJI: Emoji<'_, '_> = Emoji("šŸ”—", ""); + +#[allow(dead_code)] +pub fn run() -> Result<()> { + let current_dir = env::current_dir().context(InstallError::CurrentDir)?; + let repo = Repository::open_ext(current_dir, Flags::empty(), Vec::<&Path>::new()).context(InstallError::OpenRepo)?; + + let 
hook_dir = PathBuf::from(repo.path()).join("hooks");
+  let hook_file = hook_dir.join("prepare-commit-msg");
+
+  if !hook_file.exists() {
+    bail!(InstallError::HookExists(hook_file));
+  }
+
+  fs::remove_file(&hook_file).context("Failed to remove hook file")?;
+
+  println!(
+    "{EMOJI} Hook uninstalled successfully from {}",
+    hook_file.relative_path().display().to_string().italic()
+  );
+
+  Ok(())
+}