Skip to content

Commit

Permalink
Add support for models.create endpoint
Browse files Browse the repository at this point in the history
  • Loading branch information
mattt committed Nov 2, 2023
1 parent 0ae49eb commit 014d6df
Show file tree
Hide file tree
Showing 3 changed files with 77 additions and 0 deletions.
54 changes: 54 additions & 0 deletions Sources/Replicate/Client.swift
Original file line number Diff line number Diff line change
Expand Up @@ -486,6 +486,60 @@ public class Client {
return try await fetch(.get, "hardware")
}

/// Create a model
///
/// - Parameters:
/// - owner: The name of the user or organization that will own the model. This must be the same as the user or organization that is making the API request. In other words, the API token used in the request must belong to this user or organization.
/// - name: The name of the model. This must be unique among all models owned by the user or organization.
/// - visibility: Whether the model should be public or private. A public model can be viewed and run by anyone, whereas a private model can be viewed and run only by the user or organization members that own the model.
/// - hardware: The SKU for the hardware used to run the model. Possible values can be found by calling ``listHardware()``.
/// - description: A description of the model.
/// - githubURL: A URL for the model's source code on GitHub.
/// - paperURL: A URL for the model's paper.
/// - licenseURL: A URL for the model's license.
/// - coverImageURL: A URL for the model's cover image. This should be an image file.
public func createModel(
owner: String,
name: String,
visibility: Model.Visibility,
hardware: Hardware.ID,
description: String? = nil,
githubURL: URL? = nil,
paperURL: URL? = nil,
licenseURL: URL? = nil,
coverImageURL: URL? = nil
) async throws -> Model
{
var params: [String: Value] = [
"owner": "\(owner)",
"name": "\(name)",
"visibility": "\(visibility.rawValue)",
"hardware": "\(hardware)"
]

if let description {
params["description"] = "\(description)"
}

if let githubURL {
params["github_url"] = "\(githubURL)"
}

if let paperURL {
params["paper_url"] = "\(paperURL)"
}

if let licenseURL {
params["license_url"] = "\(licenseURL)"
}

if let coverImageURL {
params["cover_image_url"] = "\(coverImageURL)"
}

return try await fetch(.post, "models", params: params)
}

// MARK: -

private enum Method: String, Hashable {
Expand Down
6 changes: 6 additions & 0 deletions Tests/ReplicateTests/ClientTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,12 @@ final class ClientTests: XCTestCase {
XCTAssertEqual(version.id, "5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa")
}

func testCreateModel() async throws {
let model = try await client.createModel(owner: "replicate", name: "hello-world", visibility: .public, hardware: "cpu")
XCTAssertEqual(model.owner, "replicate")
XCTAssertEqual(model.name, "hello-world")
}

func testListModelCollections() async throws {
let collections = try await client.listModelCollections()
XCTAssertEqual(collections.results.count, 1)
Expand Down
17 changes: 17 additions & 0 deletions Tests/ReplicateTests/Helpers/MockURLProtocol.swift
Original file line number Diff line number Diff line change
Expand Up @@ -465,6 +465,23 @@ class MockURLProtocol: URLProtocol {
}
}
"""#
case ("POST", "https://api.replicate.com/v1/models"?):
statusCode = 200
json = #"""
{
"url": "https://replicate.com/replicate/hello-world",
"owner": "replicate",
"name": "hello-world",
"description": "A tiny model that says hello",
"visibility": "public",
"github_url": null,
"paper_url": null,
"license_url": null,
"run_count": 0,
"cover_image_url": null,
"default_example": null
}
"""#
case ("GET", "https://api.replicate.com/v1/models/replicate/hello-world/versions"?):
statusCode = 200
json = #"""
Expand Down

0 comments on commit 014d6df

Please sign in to comment.