From e01de4b1abc16cbd4582e0b3e8392437dbaa2a6d Mon Sep 17 00:00:00 2001 From: Mattt Date: Mon, 11 Dec 2023 04:48:53 -0800 Subject: [PATCH] Allow `Client.run` method to take model or version identifier (#70) * Update Identifier to make version property optional * Allow run method to take model argument --- Sources/Replicate/Client.swift | 13 +++++++-- Sources/Replicate/Identifier.swift | 29 ++++++++++++------- Tests/ReplicateTests/ClientTests.swift | 8 ++++- .../Helpers/MockURLProtocol.swift | 26 +++++++++++++++++ 4 files changed, 62 insertions(+), 14 deletions(-) diff --git a/Sources/Replicate/Client.swift b/Sources/Replicate/Client.swift index fa33d78a..2010ad5d 100644 --- a/Sources/Replicate/Client.swift +++ b/Sources/Replicate/Client.swift @@ -74,10 +74,19 @@ public class Client { webhook: Webhook? = nil, _ type: Output.Type = Value.self ) async throws -> Output? { - var prediction = try await createPrediction(Prediction.self, - version: identifier.version, + var prediction: Prediction + if let version = identifier.version { + prediction = try await createPrediction(Prediction.self, + version: version, input: input, webhook: webhook) + } else { + prediction = try await createPrediction(Prediction.self, + model: "\(identifier.owner)/\(identifier.name)", + input: input, + webhook: webhook) + } + try await prediction.wait(with: self) if prediction.status == .failed { diff --git a/Sources/Replicate/Identifier.swift b/Sources/Replicate/Identifier.swift index e767d7fe..cb2f0cb8 100644 --- a/Sources/Replicate/Identifier.swift +++ b/Sources/Replicate/Identifier.swift @@ -7,7 +7,7 @@ public struct Identifier: Hashable { public let name: String /// The version. - let version: Model.Version.ID + let version: Model.Version.ID? } // MARK: - Equatable & Comparable @@ -31,19 +31,26 @@ extension Identifier: RawRepresentable { let components = rawValue.split(separator: "/") guard components.count == 2 else { return nil } - let owner = String(components[0]) - - let nameAndVersion = components[1].split(separator: ":") - guard nameAndVersion.count == 2 else { return nil } - - let name = String(nameAndVersion[0]) - let version = Model.Version.ID(nameAndVersion[1]) - - self.init(owner: owner, name: name, version: version) + if components[1].contains(":") { + let nameAndVersion = components[1].split(separator: ":") + guard nameAndVersion.count == 2 else { return nil } + + self.init(owner: String(components[0]), + name: String(nameAndVersion[0]), + version: Model.Version.ID(nameAndVersion[1])) + } else { + self.init(owner: String(components[0]), + name: String(components[1]), + version: nil) + } } public var rawValue: String { - return "\(owner)/\(name):\(version)" + if let version = version { + return "\(owner)/\(name):\(version)" + } else { + return "\(owner)/\(name)" + } } } diff --git a/Tests/ReplicateTests/ClientTests.swift b/Tests/ReplicateTests/ClientTests.swift index 3369b9cd..2e130684 100644 --- a/Tests/ReplicateTests/ClientTests.swift +++ b/Tests/ReplicateTests/ClientTests.swift @@ -8,12 +8,18 @@ final class ClientTests: XCTestCase { URLProtocol.registerClass(MockURLProtocol.self) } - func testRun() async throws { + func testRunWithVersion() async throws { let identifier: Identifier = "test/example:5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa" let output = try await client.run(identifier, input: ["text": "Alice"]) XCTAssertEqual(output, ["Hello, Alice!"]) } + func testRunWithModel() async throws { + let identifier: Identifier = "meta/llama-2-70b-chat" + let output = try await client.run(identifier, input: ["prompt": "Please write a haiku about llamas."]) + XCTAssertEqual(output, ["I'm sorry, I'm afraid I can't do that"] ) + } + func testRunWithInvalidVersion() async throws { let identifier: Identifier = "test/example:invalid" do { diff --git a/Tests/ReplicateTests/Helpers/MockURLProtocol.swift b/Tests/ReplicateTests/Helpers/MockURLProtocol.swift index d7ed0d55..edc267e2 100644 --- a/Tests/ReplicateTests/Helpers/MockURLProtocol.swift +++ b/Tests/ReplicateTests/Helpers/MockURLProtocol.swift @@ -138,6 +138,32 @@ class MockURLProtocol: URLProtocol { } } """# + case ("GET", "https://api.replicate.com/v1/predictions/heat2o3bzn3ahtr6bjfftvbaci"?): + statusCode = 200 + json = #""" + { + "id": "heat2o3bzn3ahtr6bjfftvbaci", + "model": "meta/llama-2-70b-chat", + "version": "5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa", + "urls": { + "get": "https://api.replicate.com/v1/predictions/heat2o3bzn3ahtr6bjfftvbaci", + "cancel": "https://api.replicate.com/v1/predictions/heat2o3bzn3ahtr6bjfftvbaci/cancel" + }, + "created_at": "2022-04-26T22:13:06.224088Z", + "completed_at": "2022-04-26T22:15:06.224088Z", + "source": "web", + "status": "succeeded", + "input": { + "prompt": "Please write a haiku about llamas." + }, + "output": ["I'm sorry, I'm afraid I can't do that"], + "error": null, + "logs": "", + "metrics": { + "predict_time": 1.0 + } + } + """# case ("POST", "https://api.replicate.com/v1/predictions/ufawqhfynnddngldkgtslldrkq/cancel"?): statusCode = 200 json = #"""