Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 35 additions & 7 deletions crates/cloud_llm_client/src/predict_edits_v3.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use chrono::Duration;
use serde::{Deserialize, Serialize};
use std::{
ops::Range,
ops::{Add, Range, Sub},
path::{Path, PathBuf},
sync::Arc,
};
Expand All @@ -18,8 +18,8 @@ pub struct PredictEditsRequest {
pub excerpt_path: Arc<Path>,
/// Within file
pub excerpt_range: Range<usize>,
/// Within `excerpt`
pub cursor_offset: usize,
pub excerpt_line_range: Range<Line>,
pub cursor_point: Point,
/// Within `signatures`
pub excerpt_parent: Option<usize>,
pub signatures: Vec<Signature>,
Expand Down Expand Up @@ -47,12 +47,13 @@ pub struct PredictEditsRequest {
pub enum PromptFormat {
MarkedExcerpt,
LabeledSections,
NumberedLines,
/// Prompt format intended for use via zeta_cli
OnlySnippets,
}

impl PromptFormat {
pub const DEFAULT: PromptFormat = PromptFormat::LabeledSections;
pub const DEFAULT: PromptFormat = PromptFormat::NumberedLines;
}

impl Default for PromptFormat {
Expand All @@ -73,6 +74,7 @@ impl std::fmt::Display for PromptFormat {
PromptFormat::MarkedExcerpt => write!(f, "Marked Excerpt"),
PromptFormat::LabeledSections => write!(f, "Labeled Sections"),
PromptFormat::OnlySnippets => write!(f, "Only Snippets"),
PromptFormat::NumberedLines => write!(f, "Numbered Lines"),
}
}
}
Expand All @@ -97,7 +99,7 @@ pub struct Signature {
pub parent_index: Option<usize>,
/// Range of `text` within the file, possibly truncated according to `text_is_truncated`. The
/// file is implicitly the file that contains the descendant declaration or excerpt.
pub range: Range<usize>,
pub range: Range<Line>,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
Expand All @@ -106,7 +108,7 @@ pub struct ReferencedDeclaration {
pub text: String,
pub text_is_truncated: bool,
/// Range of `text` within file, possibly truncated according to `text_is_truncated`
pub range: Range<usize>,
pub range: Range<Line>,
/// Range within `text`
pub signature_range: Range<usize>,
/// Index within `signatures`.
Expand Down Expand Up @@ -169,10 +171,36 @@ pub struct DebugInfo {
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Edit {
pub path: Arc<Path>,
pub range: Range<usize>,
pub range: Range<Line>,
pub content: String,
}

fn is_default<T: Default + PartialEq>(value: &T) -> bool {
*value == T::default()
}

#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, PartialOrd, Eq, Ord)]
pub struct Point {
pub line: Line,
pub column: u32,
}

#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, PartialOrd, Eq, Ord)]
#[serde(transparent)]
pub struct Line(pub u32);

impl Add for Line {
type Output = Self;

fn add(self, rhs: Self) -> Self::Output {
Self(self.0 + rhs.0)
}
}

impl Sub for Line {
type Output = Self;

fn sub(self, rhs: Self) -> Self::Output {
Self(self.0 - rhs.0)
}
}
Loading
Loading