Skip to content

Commit

Permalink
Add support for hardware.list endpoint
Browse files Browse the repository at this point in the history
  • Loading branch information
mattt committed Nov 2, 2023
1 parent 91819c8 commit 0ae49eb
Show file tree
Hide file tree
Showing 4 changed files with 57 additions and 0 deletions.
7 changes: 7 additions & 0 deletions Sources/Replicate/Client.swift
Original file line number Diff line number Diff line change
Expand Up @@ -479,6 +479,13 @@ public class Client {
return try await fetch(.post, "trainings/\(id)/cancel")
}

/// List hardware available for running a model on Replicate.
///
/// - Returns: An array of hardware.
public func listHardware() async throws -> [Hardware] {
return try await fetch(.get, "hardware")
}

// MARK: -

private enum Method: String, Hashable {
Expand Down
30 changes: 30 additions & 0 deletions Sources/Replicate/Hardware.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
// Hardware for running a model on Replicate.
public struct Hardware: Hashable, Codable {
public typealias ID = String

/// The product identifier for the hardware.
///
/// For example, "gpu-a40-large".
public let sku: String

/// The name of the hardware.
///
/// For example, "Nvidia A40 (Large) GPU".
public let name: String
}

// MARK: - Identifiable

extension Hardware: Identifiable {
public var id: String {
return self.sku
}
}

// MARK: - CustomStringConvertible

extension Hardware: CustomStringConvertible {
public var description: String {
return self.name
}
}
7 changes: 7 additions & 0 deletions Tests/ReplicateTests/ClientTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,13 @@ final class ClientTests: XCTestCase {
XCTAssertEqual(trainings.results.count, 1)
}

func testListHardware() async throws {
let hardware = try await client.listHardware()
XCTAssertGreaterThan(hardware.count, 1)
XCTAssertEqual(hardware.first?.name, "CPU")
XCTAssertEqual(hardware.first?.sku, "cpu")
}

func testCustomBaseURL() async throws {
let client = Client(baseURLString: "https://v1.replicate.proxy", token: MockURLProtocol.validToken).mocked
let collection = try await client.getModelCollection("super-resolution")
Expand Down
13 changes: 13 additions & 0 deletions Tests/ReplicateTests/Helpers/MockURLProtocol.swift
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,19 @@ class MockURLProtocol: URLProtocol {
"metrics": {}
}
"""#
case ("GET", "https://api.replicate.com/v1/hardware"?):
statusCode = 200
json = #"""
[
{ "name": "CPU", "sku": "cpu" },
{ "name": "Nvidia T4 GPU", "sku": "gpu-t4" },
{ "name": "Nvidia A40 GPU", "sku": "gpu-a40-small" },
{ "name": "Nvidia A40 (Large) GPU", "sku": "gpu-a40-large" },
{ "name": "Nvidia A40 (Large) GPU (8x)", "sku": "gpu-a40-large-8x" },
{ "name": "Nvidia A100 (40GB) GPU", "sku": "gpu-a100-small" },
{ "name": "Nvidia A100 (80GB) GPU", "sku": "gpu-a100-large" },
]
"""#
case ("GET", "https://api.replicate.com/v1/models"?):
statusCode = 200
json = #"""
Expand Down

0 comments on commit 0ae49eb

Please sign in to comment.