diff --git a/Sources/Replicate/Client.swift b/Sources/Replicate/Client.swift
index fa33d78a..2010ad5d 100644
--- a/Sources/Replicate/Client.swift
+++ b/Sources/Replicate/Client.swift
@@ -74,10 +74,19 @@ public class Client {
webhook: Webhook? = nil,
_ type: Output.Type = Value.self
) async throws -> Output? {
- var prediction = try await createPrediction(Prediction.self,
- version: identifier.version,
+ var prediction: Prediction
+ if let version = identifier.version {
+ prediction = try await createPrediction(Prediction.self,
+ version: version,
input: input,
webhook: webhook)
+ } else {
+ prediction = try await createPrediction(Prediction.self,
+ model: "\(identifier.owner)/\(identifier.name)",
+ input: input,
+ webhook: webhook)
+ }
+
try await prediction.wait(with: self)
if prediction.status == .failed {
diff --git a/Sources/Replicate/Identifier.swift b/Sources/Replicate/Identifier.swift
index e767d7fe..cb2f0cb8 100644
--- a/Sources/Replicate/Identifier.swift
+++ b/Sources/Replicate/Identifier.swift
@@ -7,7 +7,7 @@ public struct Identifier: Hashable {
public let name: String
/// The version.
- let version: Model.Version.ID
+ let version: Model.Version.ID?
}
// MARK: - Equatable & Comparable
@@ -31,19 +31,26 @@ extension Identifier: RawRepresentable {
let components = rawValue.split(separator: "/")
guard components.count == 2 else { return nil }
- let owner = String(components[0])
-
- let nameAndVersion = components[1].split(separator: ":")
- guard nameAndVersion.count == 2 else { return nil }
-
- let name = String(nameAndVersion[0])
- let version = Model.Version.ID(nameAndVersion[1])
-
- self.init(owner: owner, name: name, version: version)
+ if components[1].contains(":") {
+ let nameAndVersion = components[1].split(separator: ":")
+ guard nameAndVersion.count == 2 else { return nil }
+
+ self.init(owner: String(components[0]),
+ name: String(nameAndVersion[0]),
+ version: Model.Version.ID(nameAndVersion[1]))
+ } else {
+ self.init(owner: String(components[0]),
+ name: String(components[1]),
+ version: nil)
+ }
}
public var rawValue: String {
- return "\(owner)/\(name):\(version)"
+ if let version = version {
+ return "\(owner)/\(name):\(version)"
+ } else {
+ return "\(owner)/\(name)"
+ }
}
}
diff --git a/Tests/ReplicateTests/ClientTests.swift b/Tests/ReplicateTests/ClientTests.swift
index 3369b9cd..2e130684 100644
--- a/Tests/ReplicateTests/ClientTests.swift
+++ b/Tests/ReplicateTests/ClientTests.swift
@@ -8,12 +8,18 @@ final class ClientTests: XCTestCase {
URLProtocol.registerClass(MockURLProtocol.self)
}
- func testRun() async throws {
+ func testRunWithVersion() async throws {
let identifier: Identifier = "test/example:5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa"
let output = try await client.run(identifier, input: ["text": "Alice"])
XCTAssertEqual(output, ["Hello, Alice!"])
}
+ func testRunWithModel() async throws {
+ let identifier: Identifier = "meta/llama-2-70b-chat"
+ let output = try await client.run(identifier, input: ["prompt": "Please write a haiku about llamas."])
+ XCTAssertEqual(output, ["I'm sorry, I'm afraid I can't do that"] )
+ }
+
func testRunWithInvalidVersion() async throws {
let identifier: Identifier = "test/example:invalid"
do {
diff --git a/Tests/ReplicateTests/Helpers/MockURLProtocol.swift b/Tests/ReplicateTests/Helpers/MockURLProtocol.swift
index d7ed0d55..edc267e2 100644
--- a/Tests/ReplicateTests/Helpers/MockURLProtocol.swift
+++ b/Tests/ReplicateTests/Helpers/MockURLProtocol.swift
@@ -138,6 +138,32 @@ class MockURLProtocol: URLProtocol {
}
}
"""#
+ case ("GET", "https://api.replicate.com/v1/predictions/heat2o3bzn3ahtr6bjfftvbaci"?):
+ statusCode = 200
+ json = #"""
+ {
+ "id": "heat2o3bzn3ahtr6bjfftvbaci",
+ "model": "meta/llama-2-70b-chat",
+ "version": "5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa",
+ "urls": {
+ "get": "https://api.replicate.com/v1/predictions/heat2o3bzn3ahtr6bjfftvbaci",
+ "cancel": "https://api.replicate.com/v1/predictions/heat2o3bzn3ahtr6bjfftvbaci/cancel"
+ },
+ "created_at": "2022-04-26T22:13:06.224088Z",
+ "completed_at": "2022-04-26T22:15:06.224088Z",
+ "source": "web",
+ "status": "succeeded",
+ "input": {
+ "prompt": "Please write a haiku about llamas."
+ },
+ "output": ["I'm sorry, I'm afraid I can't do that"],
+ "error": null,
+ "logs": "",
+ "metrics": {
+ "predict_time": 1.0
+ }
+ }
+ """#
case ("POST", "https://api.replicate.com/v1/predictions/ufawqhfynnddngldkgtslldrkq/cancel"?):
statusCode = 200
json = #"""