Skip to content

Commit

Permalink
Allow Client.run method to take model or version identifier (#70)
Browse files Browse the repository at this point in the history
* Update Identifier to make version property optional

* Allow run method to take model argument
  • Loading branch information
mattt authored Dec 11, 2023
1 parent 81c775e commit e01de4b
Show file tree
Hide file tree
Showing 4 changed files with 62 additions and 14 deletions.
13 changes: 11 additions & 2 deletions Sources/Replicate/Client.swift
Original file line number Diff line number Diff line change
Expand Up @@ -74,10 +74,19 @@ public class Client {
webhook: Webhook? = nil,
_ type: Output.Type = Value.self
) async throws -> Output? {
var prediction = try await createPrediction(Prediction<Input, Output>.self,
version: identifier.version,
var prediction: Prediction<Input, Output>
if let version = identifier.version {
prediction = try await createPrediction(Prediction<Input, Output>.self,
version: version,
input: input,
webhook: webhook)
} else {
prediction = try await createPrediction(Prediction<Input, Output>.self,
model: "\(identifier.owner)/\(identifier.name)",
input: input,
webhook: webhook)
}

try await prediction.wait(with: self)

if prediction.status == .failed {
Expand Down
29 changes: 18 additions & 11 deletions Sources/Replicate/Identifier.swift
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ public struct Identifier: Hashable {
public let name: String

/// The version.
let version: Model.Version.ID
let version: Model.Version.ID?
}

// MARK: - Equatable & Comparable
Expand All @@ -31,19 +31,26 @@ extension Identifier: RawRepresentable {
let components = rawValue.split(separator: "/")
guard components.count == 2 else { return nil }

let owner = String(components[0])

let nameAndVersion = components[1].split(separator: ":")
guard nameAndVersion.count == 2 else { return nil }

let name = String(nameAndVersion[0])
let version = Model.Version.ID(nameAndVersion[1])

self.init(owner: owner, name: name, version: version)
if components[1].contains(":") {
let nameAndVersion = components[1].split(separator: ":")
guard nameAndVersion.count == 2 else { return nil }

self.init(owner: String(components[0]),
name: String(nameAndVersion[0]),
version: Model.Version.ID(nameAndVersion[1]))
} else {
self.init(owner: String(components[0]),
name: String(components[1]),
version: nil)
}
}

public var rawValue: String {
return "\(owner)/\(name):\(version)"
if let version = version {
return "\(owner)/\(name):\(version)"
} else {
return "\(owner)/\(name)"
}
}
}

Expand Down
8 changes: 7 additions & 1 deletion Tests/ReplicateTests/ClientTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,18 @@ final class ClientTests: XCTestCase {
URLProtocol.registerClass(MockURLProtocol.self)
}

func testRun() async throws {
func testRunWithVersion() async throws {
let identifier: Identifier = "test/example:5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa"
let output = try await client.run(identifier, input: ["text": "Alice"])
XCTAssertEqual(output, ["Hello, Alice!"])
}

func testRunWithModel() async throws {
let identifier: Identifier = "meta/llama-2-70b-chat"
let output = try await client.run(identifier, input: ["prompt": "Please write a haiku about llamas."])
XCTAssertEqual(output, ["I'm sorry, I'm afraid I can't do that"] )
}

func testRunWithInvalidVersion() async throws {
let identifier: Identifier = "test/example:invalid"
do {
Expand Down
26 changes: 26 additions & 0 deletions Tests/ReplicateTests/Helpers/MockURLProtocol.swift
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,32 @@ class MockURLProtocol: URLProtocol {
}
}
"""#
case ("GET", "https://api.replicate.com/v1/predictions/heat2o3bzn3ahtr6bjfftvbaci"?):
statusCode = 200
json = #"""
{
"id": "heat2o3bzn3ahtr6bjfftvbaci",
"model": "meta/llama-2-70b-chat",
"version": "5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa",
"urls": {
"get": "https://api.replicate.com/v1/predictions/heat2o3bzn3ahtr6bjfftvbaci",
"cancel": "https://api.replicate.com/v1/predictions/heat2o3bzn3ahtr6bjfftvbaci/cancel"
},
"created_at": "2022-04-26T22:13:06.224088Z",
"completed_at": "2022-04-26T22:15:06.224088Z",
"source": "web",
"status": "succeeded",
"input": {
"prompt": "Please write a haiku about llamas."
},
"output": ["I'm sorry, I'm afraid I can't do that"],
"error": null,
"logs": "",
"metrics": {
"predict_time": 1.0
}
}
"""#
case ("POST", "https://api.replicate.com/v1/predictions/ufawqhfynnddngldkgtslldrkq/cancel"?):
statusCode = 200
json = #"""
Expand Down

0 comments on commit e01de4b

Please sign in to comment.