From 014d6dfdb84c86fced98cb8f2b7660c6aac685c7 Mon Sep 17 00:00:00 2001 From: Mattt Zmuda Date: Thu, 2 Nov 2023 10:15:41 -0700 Subject: [PATCH] Add support for models.create endpoint --- Sources/Replicate/Client.swift | 54 +++++++++++++++++++ Tests/ReplicateTests/ClientTests.swift | 6 +++ .../Helpers/MockURLProtocol.swift | 17 ++++++ 3 files changed, 77 insertions(+) diff --git a/Sources/Replicate/Client.swift b/Sources/Replicate/Client.swift index d9282d73..f8c7f3ff 100644 --- a/Sources/Replicate/Client.swift +++ b/Sources/Replicate/Client.swift @@ -486,6 +486,60 @@ public class Client { return try await fetch(.get, "hardware") } + /// Create a model + /// + /// - Parameters: + /// - owner: The name of the user or organization that will own the model. This must be the same as the user or organization that is making the API request. In other words, the API token used in the request must belong to this user or organization. + /// - name: The name of the model. This must be unique among all models owned by the user or organization. + /// - visibility: Whether the model should be public or private. A public model can be viewed and run by anyone, whereas a private model can be viewed and run only by the user or organization members that own the model. + /// - hardware: The SKU for the hardware used to run the model. Possible values can be found by calling ``listHardware()``. + /// - description: A description of the model. + /// - githubURL: A URL for the model's source code on GitHub. + /// - paperURL: A URL for the model's paper. + /// - licenseURL: A URL for the model's license. + /// - coverImageURL: A URL for the model's cover image. This should be an image file. + public func createModel( + owner: String, + name: String, + visibility: Model.Visibility, + hardware: Hardware.ID, + description: String? = nil, + githubURL: URL? = nil, + paperURL: URL? = nil, + licenseURL: URL? = nil, + coverImageURL: URL? = nil + ) async throws -> Model + { + var params: [String: Value] = [ + "owner": "\(owner)", + "name": "\(name)", + "visibility": "\(visibility.rawValue)", + "hardware": "\(hardware)" + ] + + if let description { + params["description"] = "\(description)" + } + + if let githubURL { + params["github_url"] = "\(githubURL)" + } + + if let paperURL { + params["paper_url"] = "\(paperURL)" + } + + if let licenseURL { + params["license_url"] = "\(licenseURL)" + } + + if let coverImageURL { + params["cover_image_url"] = "\(coverImageURL)" + } + + return try await fetch(.post, "models", params: params) + } + // MARK: - private enum Method: String, Hashable { diff --git a/Tests/ReplicateTests/ClientTests.swift b/Tests/ReplicateTests/ClientTests.swift index a0115146..21ad5d3f 100644 --- a/Tests/ReplicateTests/ClientTests.swift +++ b/Tests/ReplicateTests/ClientTests.swift @@ -105,6 +105,12 @@ final class ClientTests: XCTestCase { XCTAssertEqual(version.id, "5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa") } + func testCreateModel() async throws { + let model = try await client.createModel(owner: "replicate", name: "hello-world", visibility: .public, hardware: "cpu") + XCTAssertEqual(model.owner, "replicate") + XCTAssertEqual(model.name, "hello-world") + } + func testListModelCollections() async throws { let collections = try await client.listModelCollections() XCTAssertEqual(collections.results.count, 1) diff --git a/Tests/ReplicateTests/Helpers/MockURLProtocol.swift b/Tests/ReplicateTests/Helpers/MockURLProtocol.swift index 8c9de87e..df467894 100644 --- a/Tests/ReplicateTests/Helpers/MockURLProtocol.swift +++ b/Tests/ReplicateTests/Helpers/MockURLProtocol.swift @@ -465,6 +465,23 @@ class MockURLProtocol: URLProtocol { } } """# + case ("POST", "https://api.replicate.com/v1/models"?): + statusCode = 200 + json = #""" + { + "url": "https://replicate.com/replicate/hello-world", + "owner": "replicate", + "name": "hello-world", + "description": "A tiny model that says hello", + "visibility": "public", + "github_url": null, + "paper_url": null, + "license_url": null, + "run_count": 0, + "cover_image_url": null, + "default_example": null + } + """# case ("GET", "https://api.replicate.com/v1/models/replicate/hello-world/versions"?): statusCode = 200 json = #"""