diff --git a/.gitignore b/.gitignore index d6e62a6..6154c91 100644 --- a/.gitignore +++ b/.gitignore @@ -2,3 +2,6 @@ /example-data **.DS_Store **.env + +# SwiftPM scratch/index dirs (editor indexing creates these; cargo builds use target/) +.build/ diff --git a/Cargo.lock b/Cargo.lock index 2ef14f6..da8f6c5 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -751,6 +751,20 @@ dependencies = [ "vob", ] +[[package]] +name = "chat-applefm" +version = "0.1.0" +dependencies = [ + "async-stream", + "async-trait", + "chat-core", + "futures", + "schemars 1.2.1", + "serde", + "serde_json", + "tokio", +] + [[package]] name = "chat-cerebras" version = "0.2.5" @@ -943,6 +957,7 @@ name = "chat-rs" version = "0.5.2" dependencies = [ "async-trait", + "chat-applefm", "chat-cerebras", "chat-claude", "chat-completions", diff --git a/Cargo.toml b/Cargo.toml index f018959..4349068 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -15,6 +15,7 @@ members = [ "providers/mistralrs", "providers/openrouter", "providers/router", + "providers/applefm", ] [workspace.package] @@ -73,6 +74,7 @@ chat-deepseek = { path = "providers/deepseek", version = "0.1.3", optional = chat-mistralrs = { path = "providers/mistralrs", version = "0.1.5", optional = true } chat-openrouter = { path = "providers/openrouter", version = "0.1.0", optional = true } chat-router = { path = "providers/router", version = "0.2.4", optional = true } +chat-applefm = { path = "providers/applefm", version = "0.1.0", optional = true } tokio.workspace = true schemars.workspace = true serde.workspace = true @@ -115,8 +117,9 @@ deepseek = ["dep:chat-deepseek", "dep:chat-completions"] mistralrs = ["dep:chat-mistralrs"] openrouter = ["dep:chat-openrouter"] router = ["dep:chat-router"] +applefm = ["dep:chat-applefm"] python = ["tools-rs/python"] -stream = ["chat-core/stream", "chat-completions?/stream", "chat-responses?/stream", "chat-gemini?/stream", "chat-openai?/stream", "chat-claude?/stream", "chat-ollama?/stream", "chat-huggingface?/stream", 
"chat-cerebras?/stream", "chat-deepseek?/stream", "chat-mistralrs?/stream", "chat-openrouter?/stream", "chat-router?/stream"] +stream = ["chat-core/stream", "chat-completions?/stream", "chat-responses?/stream", "chat-gemini?/stream", "chat-openai?/stream", "chat-claude?/stream", "chat-ollama?/stream", "chat-huggingface?/stream", "chat-cerebras?/stream", "chat-deepseek?/stream", "chat-mistralrs?/stream", "chat-openrouter?/stream", "chat-router?/stream", "chat-applefm?/stream"] tungstenite = ["chat-core/tungstenite"] tokio-tungstenite = ["chat-core/tokio-tungstenite"] @@ -341,6 +344,21 @@ name = "openai-image-generation" path = "./examples/openai/image_generation.rs" required-features = ["openai"] +[[example]] +name = "applefm-availability" +path = "./examples/applefm/availability.rs" +required-features = ["applefm"] + +[[example]] +name = "applefm-completion" +path = "./examples/applefm/completion.rs" +required-features = ["applefm"] + +[[example]] +name = "applefm-stream" +path = "./examples/applefm/stream.rs" +required-features = ["applefm", "stream"] + [[example]] name = "router-embeddings" path = "./examples/router/embeddings.rs" diff --git a/build.rs b/build.rs new file mode 100644 index 0000000..672d14e --- /dev/null +++ b/build.rs @@ -0,0 +1,13 @@ +fn main() { + // chat-applefm embeds a Swift bridge whose concurrency runtime is + // referenced as @rpath/libswift_Concurrency.dylib. Final binaries + // must carry an rpath to the OS Swift runtime; link-args don't + // propagate from dependency build scripts, so the binary-producing + // package emits it. Downstream binary crates using the `applefm` + // feature need this same line in their own build.rs. 
+ if std::env::var_os("CARGO_FEATURE_APPLEFM").is_some() + && std::env::var("CARGO_CFG_TARGET_OS").as_deref() == Ok("macos") + { + println!("cargo:rustc-link-arg=-Wl,-rpath,/usr/lib/swift"); + } +} diff --git a/examples/applefm/availability.rs b/examples/applefm/availability.rs new file mode 100644 index 0000000..0451028 --- /dev/null +++ b/examples/applefm/availability.rs @@ -0,0 +1,21 @@ +//! Probe whether this machine can run the Apple on-device foundation model. +//! +//! Run with: +//! ```sh +//! cargo run --example applefm-availability --features applefm +//! ``` +//! +//! Exits 0 when the model is usable, 1 otherwise (with the reason). + +fn main() { + let probe = chat_rs::applefm::availability(); + if probe.available { + println!("Apple on-device model: AVAILABLE"); + } else { + println!( + "Apple on-device model: UNAVAILABLE — {}", + probe.reason.as_deref().unwrap_or("no reason given") + ); + std::process::exit(1); + } +} diff --git a/examples/applefm/completion.rs b/examples/applefm/completion.rs new file mode 100644 index 0000000..c30e70f --- /dev/null +++ b/examples/applefm/completion.rs @@ -0,0 +1,53 @@ +//! One transcript correction through the Apple on-device model. +//! +//! Run with: +//! ```sh +//! cargo run --example applefm-completion --features applefm +//! # with a LoRA fine-tune: +//! APPLEFM_LORA=path/to/transcripts.fmadapter \ +//! cargo run --example applefm-completion --features applefm +//! 
``` + +use chat_rs::{ChatBuilder, applefm::AppleFMBuilder, parts, types::messages, types::messages::content}; + +#[tokio::main] +async fn main() -> Result<(), Box> { + let probe = chat_rs::applefm::availability(); + if !probe.available { + eprintln!( + "Apple on-device model unavailable: {}", + probe.reason.as_deref().unwrap_or("no reason given") + ); + std::process::exit(1); + } + + let mut builder = AppleFMBuilder::new(); + if let Ok(lora) = std::env::var("APPLEFM_LORA") { + println!("Using LoRA adapter: {lora}"); + builder = builder.with_lora(lora); + } + + let client = builder.build().map_err(|err| err.err)?; + let mut chat = ChatBuilder::new().with_model(client).build(); + + let mut messages = messages::Messages::default(); + messages.push(content::from_system(parts![ + "You correct raw speech transcripts: fix disfluencies, casing and \ + punctuation without changing meaning. Reply with the corrected text only." + ])); + messages.push(content::from_user(parts![ + "so um the meeting is uh moved to thrusday at 3 pee em tell uh tell dario" + ])); + + let response = chat + .complete(&mut messages) + .await + .map_err(|err| err.err)? + .expect_complete(); + + if let Some(text) = response.content.parts.text_response() { + println!("Corrected: {text}"); + } + println!("Metadata: {:?}", response.metadata); + Ok(()) +} diff --git a/examples/applefm/stream.rs b/examples/applefm/stream.rs new file mode 100644 index 0000000..bc468c7 --- /dev/null +++ b/examples/applefm/stream.rs @@ -0,0 +1,57 @@ +//! Stream a transcript correction from the Apple on-device model. +//! +//! Run with: +//! ```sh +//! cargo run --example applefm-stream --features applefm,stream +//! 
``` + +use std::io::Write; + +use chat_rs::{ + ChatBuilder, StreamEvent, + applefm::AppleFMBuilder, + parts, + types::messages::{self, content}, +}; +use futures::StreamExt; + +#[tokio::main] +async fn main() -> Result<(), Box> { + let probe = chat_rs::applefm::availability(); + if !probe.available { + eprintln!( + "Apple on-device model unavailable: {}", + probe.reason.as_deref().unwrap_or("no reason given") + ); + std::process::exit(1); + } + + let client = AppleFMBuilder::new().build().map_err(|err| err.err)?; + let mut chat = ChatBuilder::new().with_model(client).build(); + + let mut messages = messages::Messages::default(); + messages.push(content::from_system(parts![ + "You correct raw speech transcripts: fix disfluencies, casing and \ + punctuation without changing meaning. Reply with the corrected text only." + ])); + messages.push(content::from_user(parts![ + "so um the meeting is uh moved to thrusday at 3 pee em tell uh tell dario" + ])); + + let mut stream = chat.stream(&mut messages).await.map_err(|err| err.err)?; + print!("Corrected: "); + while let Some(event) = stream.next().await { + match event? { + StreamEvent::TextChunk(text) => { + print!("{text}"); + std::io::stdout().flush()?; + } + StreamEvent::Done(response) => { + println!(); + println!("Metadata: {:?}", response.metadata); + } + _ => {} + } + } + Ok(()) +} diff --git a/providers/AGENTS.md b/providers/AGENTS.md index 3c58796..c665a49 100644 --- a/providers/AGENTS.md +++ b/providers/AGENTS.md @@ -13,6 +13,7 @@ All provider crates live under `providers/`. 
Each implements the core traits (`C | `chat-ollama` | `providers/ollama` | — (optional) | Wraps `chat-completions`; adds Ollama's native `/api/pull` and `OLLAMA_HOST` defaults | | `chat-huggingface` | `providers/huggingface` | `HF_TOKEN` | Wraps `chat-completions` for HF Inference Providers Router (`https://router.huggingface.co/v1`) | | `chat-cerebras` | `providers/cerebras` | `CEREBRAS_API_KEY` | Wraps `chat-completions` for Cerebras Inference (`https://api.cerebras.ai/v1`) | +| `chat-applefm` | `providers/applefm` | — (on-device) | Apple FoundationModels via embedded Swift bridge (no HTTP); macOS 26+, optional `.fmadapter` LoRA | ### Wire-spec crates vs provider wrappers diff --git a/providers/applefm/Cargo.toml b/providers/applefm/Cargo.toml new file mode 100644 index 0000000..f1e0ba8 --- /dev/null +++ b/providers/applefm/Cargo.toml @@ -0,0 +1,23 @@ +[package] +name = "chat-applefm" +version = "0.1.0" +description = "Apple on-device foundation model provider for chat-rs, built on the FoundationModels framework." 
+edition.workspace = true +authors.workspace = true +license.workspace = true +repository.workspace = true +rust-version.workspace = true + +[dependencies] +chat-core = { path = "../../core", version = "0.4.0" } +async-trait.workspace = true +schemars.workspace = true +serde.workspace = true +serde_json.workspace = true +tokio = { workspace = true, features = ["rt", "sync"] } +async-stream = { workspace = true, optional = true } +futures = { workspace = true, optional = true } + +[features] +default = [] +stream = ["chat-core/stream", "dep:async-stream", "dep:futures"] diff --git a/providers/applefm/bridge/Package.swift b/providers/applefm/bridge/Package.swift new file mode 100644 index 0000000..6a9d46a --- /dev/null +++ b/providers/applefm/bridge/Package.swift @@ -0,0 +1,17 @@ +// swift-tools-version: 6.0 +import PackageDescription + +let package = Package( + name: "AppleFMBridge", + // Minimum macOS 15 so host binaries still run on machines without the + // FoundationModels framework; every use of the framework is guarded + // with `#available(macOS 26.0, *)` and the framework is weak-linked + // by the crate's build script. + platforms: [.macOS(.v15)], + products: [ + .library(name: "AppleFMBridge", type: .static, targets: ["AppleFMBridge"]) + ], + targets: [ + .target(name: "AppleFMBridge") + ] +) diff --git a/providers/applefm/bridge/Sources/AppleFMBridge/Complete.swift b/providers/applefm/bridge/Sources/AppleFMBridge/Complete.swift new file mode 100644 index 0000000..2648fdb --- /dev/null +++ b/providers/applefm/bridge/Sources/AppleFMBridge/Complete.swift @@ -0,0 +1,66 @@ +// One blocking completion against the on-device model. +// +// Stateless by design: every call builds the model (base or LoRA-adapted), +// a fresh session, and renders the full conversation — correctness first; +// session reuse is a later optimization. Callers invoke this from a +// non-main thread (the Rust side uses spawn_blocking), so blocking on a +// semaphore here is safe. 
+ +import Foundation +#if canImport(FoundationModels) +import FoundationModels +#endif + +func completeJSON(_ requestJSON: String) -> String { + #if canImport(FoundationModels) + if #available(macOS 26.0, *) { + guard let data = requestJSON.data(using: .utf8), + let request = try? JSONDecoder().decode(CompleteRequest.self, from: data) + else { + return errorJSON("decode", "malformed request JSON") + } + return runComplete(request) + } else { + return errorJSON("unavailable", "macOS 26 or later is required") + } + #else + return errorJSON("unavailable", "built against an SDK without the FoundationModels framework") + #endif +} + +#if canImport(FoundationModels) +@available(macOS 26.0, *) +private func runComplete(_ request: CompleteRequest) -> String { + final class Box: @unchecked Sendable { var reply = "" } + let box = Box() + let semaphore = DispatchSemaphore(value: 0) + + Task.detached { + box.reply = await respond(to: request) + semaphore.signal() + } + semaphore.wait() + return box.reply +} + +@available(macOS 26.0, *) +private func respond(to request: CompleteRequest) async -> String { + let session: LanguageModelSession + switch makeSession(for: request) { + case .failure(let body): + return encodeJSON(ErrorReply(error: body)) + case .success(let s): + session = s + } + + do { + let response = try await session.respond( + to: renderPrompt(request.messages), + options: generationOptions(request.options) + ) + return encodeJSON(CompleteReply(text: response.content, finish: "stop")) + } catch { + return errorJSON("generation", String(describing: error)) + } +} +#endif diff --git a/providers/applefm/bridge/Sources/AppleFMBridge/Exports.swift b/providers/applefm/bridge/Sources/AppleFMBridge/Exports.swift new file mode 100644 index 0000000..a86831a --- /dev/null +++ b/providers/applefm/bridge/Sources/AppleFMBridge/Exports.swift @@ -0,0 +1,100 @@ +// The C surface of the bridge. 
+// +// This file is the only place where Rust and Swift meet: each +// `@_cdecl` function below is callable as a plain C function from the +// crate's `src/ffi.rs`. Payloads cross the boundary as JSON strings — +// the same idea as an HTTP provider's wire format, minus the network. +// +// Memory contract: every `char*` returned here is `strdup`-allocated +// and must be released by the caller via `afm_string_free`. + +import Foundation +#if canImport(FoundationModels) +import FoundationModels +#endif + +private func cString(_ s: String) -> UnsafeMutablePointer? { + strdup(s) +} + +private func availabilityJSON() -> String { + #if canImport(FoundationModels) + if #available(macOS 26.0, *) { + switch SystemLanguageModel.default.availability { + case .available: + return #"{"available":true}"# + case .unavailable(let reason): + let why: String + switch reason { + case .deviceNotEligible: + why = "this device is not eligible for Apple Intelligence" + case .appleIntelligenceNotEnabled: + why = "Apple Intelligence is not enabled in System Settings" + case .modelNotReady: + why = "model assets are not downloaded yet; retry once Apple Intelligence finishes setting up" + @unknown default: + why = "the model is unavailable for an unknown reason" + } + return #"{"available":false,"reason":"\#(why)"}"# + @unknown default: + return #"{"available":false,"reason":"unrecognized availability state"}"# + } + } else { + return #"{"available":false,"reason":"macOS 26 or later is required"}"# + } + #else + return #"{"available":false,"reason":"built against an SDK without the FoundationModels framework"}"# + #endif +} + +/// Probe whether the on-device Apple foundation model can be used. +/// Returns `{"available": bool, "reason"?: string}`. +@_cdecl("afm_availability") +public func afm_availability() -> UnsafeMutablePointer? { + cString(availabilityJSON()) +} + +/// Run one completion. Takes a `CompleteRequest` JSON, returns a +/// `CompleteReply` or `ErrorReply` JSON. 
Blocking — call from a worker +/// thread. +@_cdecl("afm_complete") +public func afm_complete(_ requestJSON: UnsafePointer?) -> UnsafeMutablePointer? { + guard let requestJSON else { + return cString(errorJSON("decode", "null request")) + } + return cString(completeJSON(String(cString: requestJSON))) +} + +/// Run one streaming completion. Takes a `CompleteRequest` JSON and a +/// callback invoked once per event JSON (`delta` / `done` / `error`). +/// Blocks until the stream finishes; the callback may be invoked from a +/// different thread and the event pointer is only valid for the duration +/// of each invocation (copy it out). +@_cdecl("afm_respond_stream") +public func afm_respond_stream( + _ requestJSON: UnsafePointer?, + _ onEvent: (@convention(c) (UnsafeMutableRawPointer?, UnsafePointer?) -> Void)?, + _ context: UnsafeMutableRawPointer? +) { + guard let onEvent else { return } + + struct RawContext: @unchecked Sendable { + let pointer: UnsafeMutableRawPointer? + } + let rawContext = RawContext(pointer: context) + let emit: @Sendable (String) -> Void = { event in + event.withCString { onEvent(rawContext.pointer, $0) } + } + + guard let requestJSON else { + emit(encodeJSON(StreamErrorEvent(error: ErrorBody(kind: "decode", message: "null request")))) + return + } + streamJSON(String(cString: requestJSON), emit: emit) +} + +/// Release a string previously returned by this bridge. +@_cdecl("afm_string_free") +public func afm_string_free(_ ptr: UnsafeMutablePointer?) { + free(ptr) +} diff --git a/providers/applefm/bridge/Sources/AppleFMBridge/Session.swift b/providers/applefm/bridge/Sources/AppleFMBridge/Session.swift new file mode 100644 index 0000000..b4acbc7 --- /dev/null +++ b/providers/applefm/bridge/Sources/AppleFMBridge/Session.swift @@ -0,0 +1,75 @@ +// Shared session plumbing for the complete and stream paths: model +// construction (base or LoRA-adapted), availability gating, prompt +// rendering, and generation options. 
+ +import Foundation +#if canImport(FoundationModels) +import FoundationModels + +/// Build the model (applying a `.fmadapter` LoRA when requested) and a +/// session over it. Errors come back as an `ErrorBody` so each caller +/// can wrap them in its own reply shape. +@available(macOS 26.0, *) +func makeSession(for request: CompleteRequest) -> Result { + let model: SystemLanguageModel + if let loraPath = request.lora { + do { + let adapter = try SystemLanguageModel.Adapter( + fileURL: URL(fileURLWithPath: loraPath)) + model = SystemLanguageModel(adapter: adapter) + } catch { + return .failure( + ErrorBody( + kind: "adapter", + message: "could not load .fmadapter at \(loraPath): \(error)")) + } + } else { + model = SystemLanguageModel.default + } + + guard case .available = model.availability else { + return .failure( + ErrorBody( + kind: "unavailable", + message: "the on-device model is not available on this machine")) + } + + if let instructions = request.instructions { + return .success(LanguageModelSession(model: model, instructions: instructions)) + } + return .success(LanguageModelSession(model: model)) +} + +/// Single user turn → pass the text straight through. Multi-turn → render +/// a role-tagged transcript (v1 flattening; native Transcript +/// reconstruction arrives with tool support). +func renderPrompt(_ messages: [WireMessage]) -> String { + if messages.count == 1, let only = messages.first { + return only.text + } + return messages.map { message in + let tag = message.role == "assistant" ? "Assistant" : "User" + return "\(tag): \(message.text)" + }.joined(separator: "\n\n") +} + +@available(macOS 26.0, *) +func generationOptions(_ options: WireOptions?) -> GenerationOptions { + guard let options else { return GenerationOptions() } + + var sampling: GenerationOptions.SamplingMode? 
+ if options.greedy == true { + sampling = .greedy + } else if let k = options.top_k { + sampling = .random(top: k, seed: options.seed) + } else if let p = options.top_p { + sampling = .random(probabilityThreshold: p, seed: options.seed) + } + + return GenerationOptions( + sampling: sampling, + temperature: options.temperature, + maximumResponseTokens: options.max_tokens + ) +} +#endif diff --git a/providers/applefm/bridge/Sources/AppleFMBridge/Stream.swift b/providers/applefm/bridge/Sources/AppleFMBridge/Stream.swift new file mode 100644 index 0000000..de1bc6f --- /dev/null +++ b/providers/applefm/bridge/Sources/AppleFMBridge/Stream.swift @@ -0,0 +1,93 @@ +// Streaming completion against the on-device model. +// +// FoundationModels streams *cumulative snapshots* ("The", "The meeting", +// …), while the wire protocol carries *deltas*. This file diffs +// consecutive snapshots and emits delta events through the caller's +// callback, ending with a `done` event carrying the authoritative full +// text (so downstream state is correct even if a snapshot revised +// earlier output). +// +// Blocking like the complete path: the function returns only after the +// stream finishes. Events are emitted from the streaming task's thread — +// the callback must be thread-safe (the Rust side feeds a channel). + +import Foundation +#if canImport(FoundationModels) +import FoundationModels +#endif + +func streamJSON(_ requestJSON: String, emit: @escaping @Sendable (String) -> Void) { + #if canImport(FoundationModels) + if #available(macOS 26.0, *) { + guard let data = requestJSON.data(using: .utf8), + let request = try? 
JSONDecoder().decode(CompleteRequest.self, from: data) + else { + emit(encodeJSON(StreamErrorEvent(error: ErrorBody( + kind: "decode", message: "malformed request JSON")))) + return + } + runStream(request, emit: emit) + } else { + emit(encodeJSON(StreamErrorEvent(error: ErrorBody( + kind: "unavailable", message: "macOS 26 or later is required")))) + } + #else + emit(encodeJSON(StreamErrorEvent(error: ErrorBody( + kind: "unavailable", + message: "built against an SDK without the FoundationModels framework")))) + #endif +} + +#if canImport(FoundationModels) +@available(macOS 26.0, *) +private func runStream(_ request: CompleteRequest, emit: @escaping @Sendable (String) -> Void) { + let semaphore = DispatchSemaphore(value: 0) + + Task.detached { + defer { semaphore.signal() } + + let session: LanguageModelSession + switch makeSession(for: request) { + case .failure(let body): + emit(encodeJSON(StreamErrorEvent(error: body))) + return + case .success(let s): + session = s + } + + do { + let stream = session.streamResponse( + to: renderPrompt(request.messages), + options: generationOptions(request.options) + ) + + var previous = "" + for try await snapshot in stream { + let current: String = snapshot.content + let delta = suffixDelta(previous: previous, current: current) + previous = current + if !delta.isEmpty { + emit(encodeJSON(StreamDeltaEvent(text: delta))) + } + } + emit(encodeJSON(StreamDoneEvent(text: previous, finish: "stop"))) + } catch { + emit(encodeJSON(StreamErrorEvent(error: ErrorBody( + kind: "generation", message: String(describing: error))))) + } + } + semaphore.wait() +} + +/// Delta between cumulative snapshots. Normally `current` extends +/// `previous`; when a snapshot revises earlier text instead, fall back to +/// everything after the longest common prefix (best effort — the `done` +/// event carries the authoritative full text). 
+private func suffixDelta(previous: String, current: String) -> String { + // Diff on Unicode scalars, not Characters: a snapshot may extend the previous one + // inside a grapheme cluster (combining mark, emoji sequence); that must stay an append. + let scalars = current.unicodeScalars + let common = zip(previous.unicodeScalars, scalars).prefix { $0 == $1 }.count + return String(scalars.dropFirst(common)) +} +#endif diff --git a/providers/applefm/bridge/Sources/AppleFMBridge/WireTypes.swift b/providers/applefm/bridge/Sources/AppleFMBridge/WireTypes.swift new file mode 100644 index 0000000..2a3d685 --- /dev/null +++ b/providers/applefm/bridge/Sources/AppleFMBridge/WireTypes.swift @@ -0,0 +1,76 @@ +// Codable mirrors of the JSON wire protocol defined in +// `src/api/types/mod.rs`. Keep the two in sync — this boundary is the +// provider's equivalent of an HTTP provider's wire format. + +import Foundation + +struct WireOptions: Codable, Sendable { + var temperature: Double? + var max_tokens: Int? + // Sampling mode, flattened. Picked as greedy > top_k > top_p. + var greedy: Bool? + var top_k: Int? + var top_p: Double? + var seed: UInt64? +} + +struct WireMessage: Codable, Sendable { + /// "user" | "assistant" + var role: String + var text: String +} + +struct CompleteRequest: Codable, Sendable { + var instructions: String? + /// Filesystem path to a `.fmadapter` package (a trained LoRA for the + /// on-device base model). Absent → the plain base model. + var lora: String? + var messages: [WireMessage] + var options: WireOptions? +} + +struct CompleteReply: Codable, Sendable { + var text: String + /// "stop" | "max_tokens" + var finish: String +} + +struct ErrorBody: Codable, Sendable, Error { + var kind: String + var message: String +} + +struct ErrorReply: Codable, Sendable { + var error: ErrorBody +} + +// Stream events, discriminated by `type`. 
+ +struct StreamDeltaEvent: Codable, Sendable { + var type = "delta" + var text: String +} + +struct StreamDoneEvent: Codable, Sendable { + var type = "done" + var text: String + var finish: String +} + +struct StreamErrorEvent: Codable, Sendable { + var type = "error" + var error: ErrorBody +} + +func encodeJSON(_ value: T) -> String { + guard let data = try? JSONEncoder().encode(value), + let s = String(data: data, encoding: .utf8) + else { + return #"{"error":{"kind":"internal","message":"reply encoding failed"}}"# + } + return s +} + +func errorJSON(_ kind: String, _ message: String) -> String { + encodeJSON(ErrorReply(error: ErrorBody(kind: kind, message: message))) +} diff --git a/providers/applefm/build.rs b/providers/applefm/build.rs new file mode 100644 index 0000000..b41418d --- /dev/null +++ b/providers/applefm/build.rs @@ -0,0 +1,79 @@ +//! Compiles the Swift bridge (`bridge/`) and links it into the crate. +//! +//! The bridge is the only piece of this provider that talks to Apple's +//! FoundationModels framework — see `bridge/Sources/AppleFMBridge`. It is +//! built with SwiftPM into a static library and linked here, so users only +//! ever run `cargo build`. +//! +//! On non-macOS targets (or when `APPLEFM_SKIP_BRIDGE` is set, or on +//! docs.rs) the bridge is not built and the crate falls back to a stub +//! that reports the model as unavailable. This keeps the workspace +//! compiling everywhere without Xcode. 
+ +use std::env; +use std::path::PathBuf; +use std::process::Command; + +fn main() { + println!("cargo:rustc-check-cfg=cfg(applefm_bridge)"); + println!("cargo:rerun-if-changed=bridge/Sources"); + println!("cargo:rerun-if-changed=bridge/Package.swift"); + println!("cargo:rerun-if-env-changed=APPLEFM_SKIP_BRIDGE"); + + let target_os = env::var("CARGO_CFG_TARGET_OS").unwrap_or_default(); + let docsrs = env::var("DOCS_RS").is_ok(); + let skip = env::var("APPLEFM_SKIP_BRIDGE").is_ok(); + if target_os != "macos" || docsrs || skip { + // Stub mode: `cfg(applefm_bridge)` is not emitted. + return; + } + + let manifest_dir = PathBuf::from(env::var("CARGO_MANIFEST_DIR").unwrap()); + let out_dir = PathBuf::from(env::var("OUT_DIR").unwrap()); + let scratch = out_dir.join("bridge-build"); + + let status = Command::new("swift") + .arg("build") + .args(["-c", "release"]) + .arg("--package-path") + .arg(manifest_dir.join("bridge")) + .arg("--scratch-path") + .arg(&scratch) + .status() + .expect( + "failed to invoke `swift build` — building chat-applefm on macOS \ + requires the Xcode command line tools (xcode-select --install)", + ); + assert!( + status.success(), + "swift build of the AppleFMBridge package failed" + ); + + println!( + "cargo:rustc-link-search=native={}", + scratch.join("release").display() + ); + println!("cargo:rustc-link-lib=static=AppleFMBridge"); + + // Swift runtime stubs from the SDK so the autolinked -lswiftCore etc. + // resolve at link time. At run time the OS runtime in /usr/lib/swift + // is used. 
+ let sdk_path = Command::new("xcrun") + .args(["--show-sdk-path"]) + .output() + .expect("failed to invoke `xcrun --show-sdk-path`"); + let sdk_path = String::from_utf8_lossy(&sdk_path.stdout).trim().to_string(); + println!("cargo:rustc-link-search=native={sdk_path}/usr/lib/swift"); + println!("cargo:rustc-link-search=native=/usr/lib/swift"); + + // Weak-link so binaries still load on macOS < 26; the bridge guards + // every use behind `#available` and reports unavailability instead. + println!("cargo:rustc-link-arg=-Wl,-weak_framework,FoundationModels"); + + // The Swift concurrency runtime is referenced as + // @rpath/libswift_Concurrency.dylib; point the rpath at the OS + // runtime so binaries resolve it from the dyld shared cache. + println!("cargo:rustc-link-arg=-Wl,-rpath,/usr/lib/swift"); + + println!("cargo:rustc-cfg=applefm_bridge"); +} diff --git a/providers/applefm/src/api/completion.rs b/providers/applefm/src/api/completion.rs new file mode 100644 index 0000000..6bba548 --- /dev/null +++ b/providers/applefm/src/api/completion.rs @@ -0,0 +1,45 @@ +use async_trait::async_trait; +use chat_core::error::{ChatError, ChatFailure}; +use chat_core::traits::CompletionProvider; +use chat_core::types::messages::Messages; +use chat_core::types::options::ChatOptions; +use chat_core::types::provider_meta::ProviderMeta; +use chat_core::types::response::ChatResponse; +use chat_core::types::tools::ToolDeclarations; + +use crate::api::types::{request, response}; +use crate::client::AppleFMClient; +use crate::ffi; + +#[async_trait] +impl CompletionProvider for AppleFMClient { + async fn complete( + &mut self, + messages: &mut Messages, + tool_declarations: Option<&dyn ToolDeclarations>, + options: Option<&ChatOptions>, + structured_output: Option<&schemars::Schema>, + ) -> Result { + let request_json = request::from_core( + &self.config, + messages, + options, + structured_output, + tool_declarations.is_some(), + )?; + + // The bridge call blocks (model inference); keep 
it off the + // async workers. + let reply_json = tokio::task::spawn_blocking(move || ffi::complete_json(&request_json)) + .await + .map_err(|e| { + ChatFailure::from_err(ChatError::Other(format!("bridge task failed: {e}"))) + })?; + + response::into_core(&self.model_slug(), &reply_json) + } + + fn metadata(&self) -> Option<&ProviderMeta> { + Some(&self.meta) + } +} diff --git a/providers/applefm/src/api/mod.rs b/providers/applefm/src/api/mod.rs new file mode 100644 index 0000000..b98b881 --- /dev/null +++ b/providers/applefm/src/api/mod.rs @@ -0,0 +1,4 @@ +pub(crate) mod completion; +#[cfg(feature = "stream")] +pub(crate) mod stream; +pub(crate) mod types; diff --git a/providers/applefm/src/api/stream.rs b/providers/applefm/src/api/stream.rs new file mode 100644 index 0000000..1c76f25 --- /dev/null +++ b/providers/applefm/src/api/stream.rs @@ -0,0 +1,84 @@ +use async_trait::async_trait; +use chat_core::error::ChatError; +use chat_core::traits::StreamProvider; +use chat_core::types::messages::Messages; +use chat_core::types::messages::content::{Content, RoleEnum}; +use chat_core::types::messages::parts::{PartEnum, Parts}; +use chat_core::types::messages::text::Text; +use chat_core::types::metadata::Metadata; +use chat_core::types::options::ChatOptions; +use chat_core::types::response::{ChatResponse, StreamEvent}; +use chat_core::types::tools::ToolDeclarations; +use futures::StreamExt; +use futures::stream::BoxStream; + +use crate::api::types::{WireStreamEvent, request, response}; +use crate::client::AppleFMClient; +use crate::ffi; + +#[async_trait] +impl StreamProvider for AppleFMClient { + async fn stream( + &mut self, + messages: &mut Messages, + tool_declarations: Option<&dyn ToolDeclarations>, + options: Option<&ChatOptions>, + ) -> Result>, ChatError> { + let request_json = request::from_core( + &self.config, + messages, + options, + None, + tool_declarations.is_some(), + ) + .map_err(|failure| failure.err)?; + let model_slug = self.model_slug(); + + // The 
bridge call blocks until the stream finishes, emitting one + // JSON event per callback; pump events through a channel into the + // async world. + let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel::(); + tokio::task::spawn_blocking(move || { + ffi::stream_json(&request_json, |event| { + let _ = tx.send(event.to_owned()); + }); + }); + + let stream = async_stream::try_stream! { + while let Some(event_json) = rx.recv().await { + let event: WireStreamEvent = + serde_json::from_str(&event_json).map_err(|e| { + ChatError::InvalidResponse(format!( + "malformed bridge stream event ({e}): {event_json}" + )) + })?; + match event { + WireStreamEvent::Delta { text } => { + yield StreamEvent::TextChunk(text); + } + WireStreamEvent::Done { text, finish } => { + // The done event carries the authoritative full + // text; deltas are best-effort display fragments. + yield StreamEvent::Done(ChatResponse { + metadata: Some(Metadata { + model_slug: Some(model_slug.clone()), + ..Default::default() + }), + content: Content { + role: RoleEnum::Model, + parts: Parts(vec![PartEnum::Text(Text::new(text))]), + complete_reason: response::map_finish(&finish), + }, + }); + break; + } + WireStreamEvent::Error { error } => { + Err(response::error_to_chat(error))?; + } + } + } + }; + + Ok(stream.boxed()) + } +} diff --git a/providers/applefm/src/api/types/mod.rs b/providers/applefm/src/api/types/mod.rs new file mode 100644 index 0000000..e37a19e --- /dev/null +++ b/providers/applefm/src/api/types/mod.rs @@ -0,0 +1,102 @@ +//! The JSON wire protocol crossing the C boundary. +//! +//! Mirrored by `bridge/Sources/AppleFMBridge/WireTypes.swift` — keep the +//! two in sync. 
+ +pub(crate) mod request; +pub(crate) mod response; + +use serde::{Deserialize, Serialize}; + +#[derive(Debug, Default, Serialize)] +pub(crate) struct WireOptions { + #[serde(skip_serializing_if = "Option::is_none")] + pub temperature: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub max_tokens: Option, + // Sampling mode, flattened. The bridge picks greedy > top_k > top_p. + #[serde(skip_serializing_if = "Option::is_none")] + pub greedy: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub top_k: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub top_p: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub seed: Option, +} + +#[derive(Debug, Serialize)] +pub(crate) struct WireMessage { + /// "user" | "assistant" + pub role: &'static str, + pub text: String, +} + +#[derive(Debug, Serialize)] +pub(crate) struct CompleteRequest { + #[serde(skip_serializing_if = "Option::is_none")] + pub instructions: Option, + /// Filesystem path to a `.fmadapter` LoRA package. + #[serde(skip_serializing_if = "Option::is_none")] + pub lora: Option, + pub messages: Vec, + #[serde(skip_serializing_if = "Option::is_none")] + pub options: Option, +} + +#[derive(Debug, Deserialize)] +pub(crate) struct CompleteReply { + pub text: String, + pub finish: String, +} + +#[derive(Debug, Deserialize)] +pub(crate) struct ErrorBody { + pub kind: String, + pub message: String, +} + +#[derive(Debug, Deserialize)] +pub(crate) struct ErrorReply { + pub error: ErrorBody, +} + +/// Events emitted by the bridge's streaming path, discriminated by `type`. +#[cfg(feature = "stream")] +#[derive(Debug, Deserialize)] +#[serde(tag = "type", rename_all = "snake_case")] +pub(crate) enum WireStreamEvent { + /// A text fragment (already diffed from cumulative snapshots). + Delta { + text: String, + }, + /// Stream finished; carries the authoritative full text. 
+ Done { + text: String, + finish: String, + }, + Error { + error: ErrorBody, + }, +} + +#[cfg(all(test, feature = "stream"))] +mod tests { + use super::*; + + #[test] + fn parses_stream_events() { + let event: WireStreamEvent = + serde_json::from_str(r#"{"type":"delta","text":"hi"}"#).unwrap(); + assert!(matches!(event, WireStreamEvent::Delta { text } if text == "hi")); + + let event: WireStreamEvent = + serde_json::from_str(r#"{"type":"done","text":"hi there","finish":"stop"}"#).unwrap(); + assert!(matches!(event, WireStreamEvent::Done { finish, .. } if finish == "stop")); + + let event: WireStreamEvent = + serde_json::from_str(r#"{"type":"error","error":{"kind":"generation","message":"x"}}"#) + .unwrap(); + assert!(matches!(event, WireStreamEvent::Error { error } if error.kind == "generation")); + } +} diff --git a/providers/applefm/src/api/types/request.rs b/providers/applefm/src/api/types/request.rs new file mode 100644 index 0000000..9b12562 --- /dev/null +++ b/providers/applefm/src/api/types/request.rs @@ -0,0 +1,245 @@ +//! `Messages` + `ChatOptions` → `CompleteRequest` JSON. +//! +//! All capability rejections live here, mistralrs-style: what the +//! on-device model can't do fails fast with a clear error instead of +//! crossing the bridge. 
+ +use chat_core::error::{ChatError, ChatFailure}; +use chat_core::types::messages::Messages; +use chat_core::types::messages::content::RoleEnum; +use chat_core::types::messages::parts::PartEnum; +use chat_core::types::options::ChatOptions; + +use crate::client::{Config, Sampling}; + +use super::{CompleteRequest, WireMessage, WireOptions}; + +pub(crate) fn from_core( + config: &Config, + messages: &Messages, + options: Option<&ChatOptions>, + structured_output: Option<&schemars::Schema>, + tools_present: bool, +) -> Result { + if tools_present { + return Err(unsupported("tool declarations")); + } + if structured_output.is_some() { + return Err(unsupported("structured outputs")); + } + + // System-role messages fold into the session instructions; the + // FoundationModels API has no system role in the conversation itself. + let mut instructions = String::new(); + let mut wire_messages = Vec::new(); + + for content in &messages.0 { + let text = flatten_text_only(&content.parts.0)?; + match content.role { + RoleEnum::System => { + if !instructions.is_empty() { + instructions.push('\n'); + } + instructions.push_str(&text); + } + RoleEnum::User => wire_messages.push(WireMessage { role: "user", text }), + RoleEnum::Model => wire_messages.push(WireMessage { + role: "assistant", + text, + }), + } + } + + if wire_messages.is_empty() { + return Err(ChatFailure::from_err(ChatError::Provider( + "chat-applefm needs at least one user message".into(), + ))); + } + + let request = CompleteRequest { + instructions: (!instructions.is_empty()).then_some(instructions), + lora: config + .lora + .as_ref() + .map(|p| p.to_string_lossy().into_owned()), + messages: wire_messages, + options: merge_options(config, options), + }; + + serde_json::to_string(&request) + .map_err(|e| ChatFailure::from_err(ChatError::Other(format!("request serialization: {e}")))) +} + +/// Builder defaults overlaid with per-call `ChatOptions`. Temperature and +/// max_tokens merge per field. 
Sampling merges as a family: if the call +/// carries *any* sampling key (`top_p`, or `greedy`/`top_k`/`seed` in +/// metadata) it replaces the builder's sampling default wholesale, so a +/// builder top-k never mixes with a per-call top-p. +fn merge_options(config: &Config, opts: Option<&ChatOptions>) -> Option { + let mut wire = WireOptions { + temperature: config.temperature, + max_tokens: config.max_tokens, + ..Default::default() + }; + match config.sampling { + Some(Sampling::Greedy) => wire.greedy = Some(true), + Some(Sampling::TopK { k, seed }) => (wire.top_k, wire.seed) = (Some(k), seed), + Some(Sampling::TopP { p, seed }) => (wire.top_p, wire.seed) = (Some(p), seed), + None => {} + } + + if let Some(opts) = opts { + if let Some(t) = opts.temperature { + wire.temperature = Some(f64::from(t)); + } + if let Some(m) = opts.max_tokens { + wire.max_tokens = Some(m); + } + + let greedy = opts.metadata.get("greedy").and_then(|v| v.as_bool()); + let top_k = opts.metadata.get("top_k").and_then(|v| v.as_u64()); + let seed = opts.metadata.get("seed").and_then(|v| v.as_u64()); + if greedy.is_some() || top_k.is_some() || opts.top_p.is_some() || seed.is_some() { + (wire.greedy, wire.top_k, wire.top_p, wire.seed) = ( + greedy, + top_k.map(|k| k as u32), + opts.top_p.map(f64::from), + seed, + ); + } + } + + let is_empty = matches!( + wire, + WireOptions { + temperature: None, + max_tokens: None, + greedy: None, + top_k: None, + top_p: None, + seed: None, + } + ); + (!is_empty).then_some(wire) +} + +/// The on-device model is text-only via this API; everything else is +/// rejected with a clear pointer to what's unsupported. 
+fn flatten_text_only(parts: &[PartEnum]) -> Result { + let mut buf = String::new(); + for part in parts { + match part { + PartEnum::Text(t) => { + if !buf.is_empty() { + buf.push('\n'); + } + buf.push_str(t.as_str()); + } + PartEnum::File(f) => { + return Err(unsupported(&format!("file parts (mimetype {})", f.mime))); + } + PartEnum::Tool(_) => return Err(unsupported("tool parts")), + PartEnum::Structured(_) => return Err(unsupported("structured parts in input")), + PartEnum::Reasoning(_) => return Err(unsupported("reasoning parts in input")), + PartEnum::Embeddings(_) => return Err(unsupported("embedding parts in input")), + } + } + Ok(buf) +} + +fn unsupported(what: &str) -> ChatFailure { + ChatFailure::from_err(ChatError::Provider(format!( + "chat-applefm does not yet support {what}" + ))) +} + +#[cfg(test)] +mod tests { + use super::*; + use chat_core::parts; + use chat_core::types::messages::content; + + fn config() -> Config { + Config { + lora: Some("adapters/transcripts.fmadapter".into()), + ..Default::default() + } + } + + #[test] + fn folds_system_into_instructions_and_carries_lora() { + let mut messages = Messages::default(); + messages.push(content::from_system(parts!["Talk like a pirate."])); + messages.push(content::from_user(parts!["hello"])); + + let json = from_core(&config(), &messages, None, None, false).unwrap(); + let v: serde_json::Value = serde_json::from_str(&json).unwrap(); + + assert_eq!(v["instructions"], "Talk like a pirate."); + assert_eq!(v["lora"], "adapters/transcripts.fmadapter"); + assert_eq!(v["messages"][0]["role"], "user"); + assert_eq!(v["messages"][0]["text"], "hello"); + } + + #[test] + fn rejects_tools_and_structured() { + let mut messages = Messages::default(); + messages.push(content::from_user(parts!["hi"])); + + assert!(from_core(&config(), &messages, None, None, true).is_err()); + let schema = schemars::json_schema!({"type": "object"}); + assert!(from_core(&config(), &messages, None, Some(&schema), false).is_err()); + 
} + + #[test] + fn builder_defaults_yield_to_call_options_as_a_family() { + let config = Config { + lora: None, + temperature: Some(0.7), + max_tokens: Some(100), + sampling: Some(Sampling::TopK { + k: 40, + seed: Some(7), + }), + }; + let mut messages = Messages::default(); + messages.push(content::from_user(parts!["hi"])); + + // No call options → builder defaults flow through. + let json = from_core(&config, &messages, None, None, false).unwrap(); + let v: serde_json::Value = serde_json::from_str(&json).unwrap(); + assert_eq!(v["options"]["top_k"], 40); + assert_eq!(v["options"]["seed"], 7); + + // A call-level top_p replaces the whole sampling family (no + // leftover top_k/seed) but per-field merge keeps max_tokens. + // 0.75 and 0.5 are exact in binary — immune to f32→f64 widening. + let mut opts = ChatOptions::default(); + opts.top_p = Some(0.75); + opts.temperature = Some(0.5); + let json = from_core(&config, &messages, Some(&opts), None, false).unwrap(); + let v: serde_json::Value = serde_json::from_str(&json).unwrap(); + assert_eq!(v["options"]["top_p"], 0.75); + assert!(v["options"]["top_k"].is_null()); + assert!(v["options"]["seed"].is_null()); + assert_eq!(v["options"]["max_tokens"], 100); + assert_eq!(v["options"]["temperature"], 0.5); + } + + #[test] + fn maps_options() { + let mut messages = Messages::default(); + messages.push(content::from_user(parts!["hi"])); + + let mut opts = ChatOptions::default(); + opts.temperature = Some(0.2); + opts.max_tokens = Some(64); + opts.metadata + .insert("greedy".into(), serde_json::Value::Bool(true)); + + let json = from_core(&config(), &messages, Some(&opts), None, false).unwrap(); + let v: serde_json::Value = serde_json::from_str(&json).unwrap(); + assert_eq!(v["options"]["max_tokens"], 64); + assert_eq!(v["options"]["greedy"], true); + } +} diff --git a/providers/applefm/src/api/types/response.rs b/providers/applefm/src/api/types/response.rs new file mode 100644 index 0000000..f52c5df --- /dev/null +++ 
b/providers/applefm/src/api/types/response.rs @@ -0,0 +1,87 @@ +//! Bridge reply JSON → `ChatResponse` (or a mapped `ChatFailure`). + +use chat_core::error::{ChatError, ChatFailure}; +use chat_core::types::messages::content::{CompleteReasonEnum, Content, RoleEnum}; +use chat_core::types::messages::parts::{PartEnum, Parts}; +use chat_core::types::messages::text::Text; +use chat_core::types::metadata::Metadata; +use chat_core::types::response::ChatResponse; + +use super::{CompleteReply, ErrorReply}; + +pub(crate) fn into_core(model_slug: &str, reply_json: &str) -> Result { + if let Ok(err) = serde_json::from_str::(reply_json) { + return Err(map_error(err)); + } + + let reply: CompleteReply = serde_json::from_str(reply_json).map_err(|e| { + ChatFailure::from_err(ChatError::InvalidResponse(format!( + "malformed bridge reply ({e}): {reply_json}" + ))) + })?; + + let complete_reason = map_finish(&reply.finish); + + Ok(ChatResponse { + metadata: Some(Metadata { + model_slug: Some(model_slug.to_owned()), + ..Default::default() + }), + content: Content { + role: RoleEnum::Model, + parts: Parts(vec![PartEnum::Text(Text::new(reply.text))]), + complete_reason, + }, + }) +} + +pub(crate) fn map_finish(finish: &str) -> CompleteReasonEnum { + match finish { + "stop" => CompleteReasonEnum::Stop, + "max_tokens" => CompleteReasonEnum::MaxTokens, + other => CompleteReasonEnum::Other(other.to_owned()), + } +} + +/// The bridge's error kinds, mapped onto chat-rs error semantics. All are +/// non-retryable `Provider` errors except `internal`, which is `Other`. 
+pub(crate) fn error_to_chat(error: super::ErrorBody) -> ChatError { + match error.kind.as_str() { + "internal" => ChatError::Other(format!("applefm bridge: {}", error.message)), + kind => ChatError::Provider(format!("applefm {kind}: {}", error.message)), + } +} + +fn map_error(reply: ErrorReply) -> ChatFailure { + let ErrorReply { error } = reply; + ChatFailure::from_err(error_to_chat(error)) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn parses_completion() { + let resp = into_core("apple-on-device", r#"{"text":"hi there","finish":"stop"}"#).unwrap(); + assert_eq!( + resp.content.parts.text_response().map(|t| t.as_str()), + Some("hi there") + ); + assert!(matches!( + resp.content.complete_reason, + CompleteReasonEnum::Stop + )); + } + + #[test] + fn maps_error_kinds() { + let err = into_core( + "apple-on-device", + r#"{"error":{"kind":"unavailable","message":"not enabled"}}"#, + ) + .unwrap_err(); + assert!(matches!(err.err, ChatError::Provider(_))); + assert!(err.err.to_string().contains("unavailable")); + } +} diff --git a/providers/applefm/src/builder.rs b/providers/applefm/src/builder.rs new file mode 100644 index 0000000..6b329fc --- /dev/null +++ b/providers/applefm/src/builder.rs @@ -0,0 +1,178 @@ +use std::path::PathBuf; +use std::sync::Arc; + +use chat_core::error::{ChatError, ChatFailure}; +use chat_core::types::provider_meta::ProviderMeta; + +use crate::client::{AppleFMClient, Config, Sampling}; + +/// Builder for [`AppleFMClient`]. +/// +/// A pure connection stub: it wires up *which model variant* to talk to +/// (base, or base + LoRA) and nothing else. Unlike other providers there +/// is no `with_model` typestate — the model is the one the OS ships. +/// Conversation concerns (system prompts, options) belong to the chat: +/// push a `System`-role message into `Messages` and the provider maps it +/// onto the session's instructions. 
+/// +/// ```no_run +/// # fn run() -> Result<(), Box> { +/// use chat_applefm::AppleFMBuilder; +/// +/// let client = AppleFMBuilder::new() +/// .with_lora("adapters/transcripts.fmadapter") +/// .build()?; +/// # Ok(()) } +/// ``` +#[derive(Debug, Default)] +pub struct AppleFMBuilder { + lora: Option, + temperature: Option, + max_tokens: Option, + sampling: Option, + description: Option, +} + +impl AppleFMBuilder { + pub fn new() -> Self { + Self::default() + } + + /// Apply a LoRA fine-tune over the on-device base model: the path to a + /// `.fmadapter` package produced by Apple's adapter training toolkit. + /// + /// Adapters are tied to a specific base-model version — when a macOS + /// update rolls the base model, the adapter must be retrained. Loading + /// an incompatible adapter fails at request time with a provider error. + pub fn with_lora(mut self, path: impl Into) -> Self { + self.lora = Some(path.into()); + self + } + + /// Default decoding temperature. Overridden per call by + /// `ChatOptions::temperature`. + pub fn with_temperature(mut self, temperature: f64) -> Self { + self.temperature = Some(temperature); + self + } + + /// Default response-length cap. Overridden per call by + /// `ChatOptions::max_tokens`. + pub fn with_max_tokens(mut self, max_tokens: u32) -> Self { + self.max_tokens = Some(max_tokens); + self + } + + /// Default sampling mode — greedy, top-k, or top-p (the complete set + /// FoundationModels exposes). Overridden per call when `ChatOptions` + /// carries any sampling key (`top_p`, or `greedy` / `top_k` / `seed` + /// in its metadata). 
+ pub fn with_sampling(mut self, sampling: Sampling) -> Self { + self.sampling = Some(sampling); + self + } + + pub fn with_description(mut self, description: impl Into) -> Self { + self.description = Some(description.into()); + self + } + + /// Build the client, validating the configuration upfront — a missing + /// `.fmadapter` path or nonsensical sampling parameters fail here, not + /// on the first request. Cheap otherwise: the model is probed and the + /// session created at request time. Call [`crate::availability`] + /// first to know whether requests can succeed on this machine at all. + pub fn build(self) -> Result { + if let Some(lora) = &self.lora + && !lora.exists() + { + return Err(invalid(format!( + "LoRA adapter not found at {}", + lora.display() + ))); + } + if let Some(t) = self.temperature + && (!t.is_finite() || t < 0.0) + { + return Err(invalid(format!("temperature must be >= 0, got {t}"))); + } + if self.max_tokens == Some(0) { + return Err(invalid("max_tokens must be >= 1".into())); + } + match self.sampling { + Some(Sampling::TopK { k: 0, .. }) => { + return Err(invalid("top-k sampling needs k >= 1".into())); + } + Some(Sampling::TopP { p, .. 
}) if !p.is_finite() || p <= 0.0 || p > 1.0 => { + return Err(invalid(format!( + "top-p sampling needs p in (0, 1], got {p}" + ))); + } + _ => {} + } + + Ok(AppleFMClient { + config: Arc::new(Config { + lora: self.lora, + temperature: self.temperature, + max_tokens: self.max_tokens, + sampling: self.sampling, + }), + meta: Arc::new(ProviderMeta { + description: self.description, + ..Default::default() + }), + }) + } +} + +fn invalid(message: String) -> ChatFailure { + ChatFailure::from_err(ChatError::Provider(format!( + "chat-applefm builder: {message}" + ))) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn rejects_missing_lora_path() { + let err = AppleFMBuilder::new() + .with_lora("/definitely/not/here.fmadapter") + .build() + .unwrap_err(); + assert!(err.err.to_string().contains("not found")); + } + + #[test] + fn accepts_existing_lora_path() { + // Any path that exists works for the check; the manifest is one. + let manifest = concat!(env!("CARGO_MANIFEST_DIR"), "/Cargo.toml"); + assert!(AppleFMBuilder::new().with_lora(manifest).build().is_ok()); + } + + #[test] + fn rejects_bad_sampling_and_options() { + assert!( + AppleFMBuilder::new() + .with_sampling(Sampling::TopP { p: 1.5, seed: None }) + .build() + .is_err() + ); + assert!( + AppleFMBuilder::new() + .with_sampling(Sampling::TopK { k: 0, seed: None }) + .build() + .is_err() + ); + assert!( + AppleFMBuilder::new() + .with_temperature(-0.1) + .build() + .is_err() + ); + assert!(AppleFMBuilder::new().with_max_tokens(0).build().is_err()); + assert!(AppleFMBuilder::new().build().is_ok()); + } +} diff --git a/providers/applefm/src/client.rs b/providers/applefm/src/client.rs new file mode 100644 index 0000000..579ebcf --- /dev/null +++ b/providers/applefm/src/client.rs @@ -0,0 +1,58 @@ +use std::path::PathBuf; +use std::sync::Arc; + +use chat_core::types::provider_meta::ProviderMeta; + +/// Decoding strategy for the on-device model — the full set +/// FoundationModels exposes via 
`GenerationOptions.SamplingMode`. +#[derive(Debug, Clone, Copy, PartialEq)] +pub enum Sampling { + /// Deterministic decoding. + Greedy, + /// Sample among the `k` most probable tokens. + TopK { k: u32, seed: Option }, + /// Nucleus sampling: sample within the smallest set of tokens whose + /// cumulative probability reaches `p`. + TopP { p: f64, seed: Option }, +} + +/// Model wiring baked in by the builder and sent to the bridge with +/// every request (sessions are rebuilt per call for now). Conversation +/// concerns — including system prompts — live in `Messages`, not here. +/// Everything below is a default; per-call `ChatOptions` override it. +#[derive(Debug, Default)] +pub(crate) struct Config { + /// Path to a `.fmadapter` package — a LoRA trained with Apple's + /// adapter training toolkit, applied over the on-device base model. + pub(crate) lora: Option, + pub(crate) temperature: Option, + pub(crate) max_tokens: Option, + pub(crate) sampling: Option, +} + +/// Client for the Apple Intelligence on-device foundation model. +/// +/// There is no model slug and no API key: the OS owns the (one) model. +/// What varies per client is the configuration — an optional LoRA adapter +/// and default generation options. Clone is cheap (`Arc` bumps). +#[derive(Clone, Debug)] +pub struct AppleFMClient { + pub(crate) config: Arc, + /// `Arc` because `ProviderMeta` is not `Clone`. + pub(crate) meta: Arc, +} + +impl AppleFMClient { + /// Identifier used as the response `model_slug`: the base model name, + /// plus the adapter file stem when a LoRA is loaded. 
+ pub fn model_slug(&self) -> String { + match self.config.lora.as_deref().and_then(|p| p.file_stem()) { + Some(stem) => format!("apple-on-device+{}", stem.to_string_lossy()), + None => "apple-on-device".to_owned(), + } + } + + pub fn provider_meta(&self) -> &ProviderMeta { + &self.meta + } +} diff --git a/providers/applefm/src/ffi.rs b/providers/applefm/src/ffi.rs new file mode 100644 index 0000000..9386473 --- /dev/null +++ b/providers/applefm/src/ffi.rs @@ -0,0 +1,112 @@ +//! Safe wrappers over the bridge's C surface. +//! +//! When the crate is built without the bridge (non-macOS target, +//! docs.rs, or `APPLEFM_SKIP_BRIDGE=1`), the stub module below stands in +//! and reports the model as unavailable — same shape, no linking. + +#[cfg(applefm_bridge)] +mod real { + use std::ffi::{CStr, CString, c_char, c_void}; + + unsafe extern "C" { + fn afm_availability() -> *mut c_char; + fn afm_complete(request_json: *const c_char) -> *mut c_char; + fn afm_respond_stream( + request_json: *const c_char, + on_event: Option, + context: *mut c_void, + ); + fn afm_string_free(ptr: *mut c_char); + } + + /// Copy a bridge-owned string out and release it. Maps a null return + /// to `None`. + /// + /// SAFETY: `ptr` must come from this bridge (strdup-allocated, + /// NUL-terminated) and must not be used after this call. + unsafe fn take_bridge_string(ptr: *mut c_char) -> Option { + if ptr.is_null() { + return None; + } + // SAFETY: per contract above. + unsafe { + let s = CStr::from_ptr(ptr).to_string_lossy().into_owned(); + afm_string_free(ptr); + Some(s) + } + } + + pub fn availability_json() -> String { + // SAFETY: afm_availability returns a bridge-owned string. 
+ unsafe { take_bridge_string(afm_availability()) } + .unwrap_or_else(|| r#"{"available":false,"reason":"bridge returned null"}"#.to_owned()) + } + + pub fn complete_json(request: &str) -> String { + let Ok(request) = CString::new(request) else { + return r#"{"error":{"kind":"decode","message":"request contained a NUL byte"}}"# + .to_owned(); + }; + // SAFETY: the pointer is valid for the duration of the call and + // the reply is a bridge-owned string. + unsafe { take_bridge_string(afm_complete(request.as_ptr())) }.unwrap_or_else(|| { + r#"{"error":{"kind":"internal","message":"bridge returned null"}}"#.to_owned() + }) + } + + /// Run one streaming completion, invoking `on_event` once per event + /// JSON. Blocks until the stream finishes — call from a worker thread. + /// `on_event` may be invoked from a different thread than the caller's. + pub fn stream_json(request: &str, on_event: impl FnMut(&str) + Send) { + let mut on_event: Box = Box::new(on_event); + let Ok(request) = CString::new(request) else { + on_event( + r#"{"type":"error","error":{"kind":"decode","message":"request contained a NUL byte"}}"#, + ); + return; + }; + + unsafe extern "C" fn trampoline(context: *mut c_void, event: *const c_char) { + if context.is_null() || event.is_null() { + return; + } + // SAFETY: `context` is the `&mut Box` passed below, + // alive for the whole (blocking) `afm_respond_stream` call; + // `event` is a NUL-terminated string valid for this invocation. + unsafe { + let on_event = &mut *context.cast::>(); + let event = CStr::from_ptr(event).to_string_lossy(); + on_event(&event); + } + } + + let context = (&raw mut on_event).cast::(); + // SAFETY: the bridge blocks until the stream completes, so the + // closure outlives every trampoline invocation. 
+ unsafe { afm_respond_stream(request.as_ptr(), Some(trampoline), context) } + } +} + +#[cfg(not(applefm_bridge))] +mod stub { + const STUB_REASON: &str = "chat-applefm was built without the Swift bridge (non-macOS target, docs build, or APPLEFM_SKIP_BRIDGE set)"; + + pub fn availability_json() -> String { + format!(r#"{{"available":false,"reason":"{STUB_REASON}"}}"#) + } + + pub fn complete_json(_request: &str) -> String { + format!(r#"{{"error":{{"kind":"unavailable","message":"{STUB_REASON}"}}}}"#) + } + + pub fn stream_json(_request: &str, mut on_event: impl FnMut(&str) + Send) { + on_event(&format!( + r#"{{"type":"error","error":{{"kind":"unavailable","message":"{STUB_REASON}"}}}}"# + )); + } +} + +#[cfg(applefm_bridge)] +pub(crate) use real::{availability_json, complete_json, stream_json}; +#[cfg(not(applefm_bridge))] +pub(crate) use stub::{availability_json, complete_json, stream_json}; diff --git a/providers/applefm/src/lib.rs b/providers/applefm/src/lib.rs new file mode 100644 index 0000000..30b1c01 --- /dev/null +++ b/providers/applefm/src/lib.rs @@ -0,0 +1,95 @@ +//! Apple on-device foundation model provider for chat-rs. +//! +//! Talks to the ~3B-parameter model that ships with Apple Intelligence +//! (macOS 26+) through Apple's FoundationModels framework. There is no +//! HTTP and no weights file: the OS owns the model; this crate owns the +//! translation. +//! +//! ## How it connects +//! +//! Rust cannot call Swift directly, so the crate embeds a small Swift +//! package (`bridge/`) exposing a few plain C functions. Requests and +//! responses cross that boundary as JSON strings — the same mental model +//! as an HTTP provider's wire format, minus the network. The bridge is +//! compiled automatically by `build.rs`; users only run `cargo build`. +//! +//! On non-macOS targets the crate still compiles (with a stub bridge) +//! and reports the model as unavailable. +//! +//! **Binary crates on macOS need one build.rs line** so the Swift +//! 
concurrency runtime resolves at load time: +//! +//! ```no_run +//! // build.rs of your binary crate +//! println!("cargo:rustc-link-arg=-Wl,-rpath,/usr/lib/swift"); +//! ``` +//! +//! ## Current scope +//! +//! Text completion (with optional LoRA fine-tune), streaming (behind the +//! `stream` feature), and availability probing. Tools and structured +//! output are rejected at request time for now and arrive in later slices. +//! +//! ```no_run +//! # async fn run() -> Result<(), Box<dyn std::error::Error>> { +//! use chat_applefm::AppleFMBuilder; +//! +//! let probe = chat_applefm::availability(); +//! if !probe.available { +//! println!("unavailable: {}", probe.reason.as_deref().unwrap_or("?")); +//! return Ok(()); +//! } +//! +//! let client = AppleFMBuilder::new() +//! .with_lora("adapters/transcripts.fmadapter") // optional fine-tune +//! .build()?; // validates config (e.g. the adapter path) upfront +//! // System prompts go through Messages like any other provider; the +//! // provider maps them onto the session's instructions. +//! # Ok(()) } +//! ``` +//! +//! ## LoRA fine-tunes +//! +//! [`AppleFMBuilder::with_lora`] loads a `.fmadapter` package — a LoRA +//! trained against the on-device base model with Apple's adapter +//! training toolkit. Adapters are version-locked to the base model: +//! after a macOS update rolls the model, retrain and reship. +//! +//! See `providers/AGENTS.md` for the overall provider architecture. + +#![allow(clippy::result_large_err)] + +mod api; +mod builder; +mod client; +mod ffi; + +pub use builder::AppleFMBuilder; +pub use client::{AppleFMClient, Sampling}; + +use serde::Deserialize; + +/// Result of probing whether the on-device Apple model is usable here. +#[derive(Debug, Clone, Deserialize)] +pub struct Availability { + /// Whether a session can be created on this machine right now. 
+ pub available: bool, + /// Human-readable explanation when `available` is `false` + /// (ineligible hardware, Apple Intelligence disabled, assets still + /// downloading, OS too old, or a bridge-less build). + #[serde(default)] + pub reason: Option, +} + +/// Ask the OS whether the Apple Intelligence on-device model can be used. +/// +/// Cheap to call; performs no generation. This is the recommended first +/// call before constructing a client, and what router strategies should +/// consult when deciding whether this provider is eligible at all. +pub fn availability() -> Availability { + let json = ffi::availability_json(); + serde_json::from_str(&json).unwrap_or_else(|_| Availability { + available: false, + reason: Some(format!("malformed bridge reply: {json}")), + }) +} diff --git a/src/lib.rs b/src/lib.rs index f11fe55..c2bba6c 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -100,6 +100,11 @@ pub mod router { pub use chat_router::*; } +#[cfg(feature = "applefm")] +pub mod applefm { + pub use chat_applefm::*; +} + pub mod prelude { pub use crate::ChatOptions; pub use crate::Messages; @@ -139,4 +144,7 @@ pub mod prelude { #[cfg(feature = "router")] pub use crate::router; + + #[cfg(feature = "applefm")] + pub use crate::applefm; }