diff --git a/Sources/Replicate/Client.swift b/Sources/Replicate/Client.swift index d139c832..d9282d73 100644 --- a/Sources/Replicate/Client.swift +++ b/Sources/Replicate/Client.swift @@ -479,6 +479,13 @@ public class Client { return try await fetch(.post, "trainings/\(id)/cancel") } + /// List hardware available for running a model on Replicate. + /// + /// - Returns: An array of hardware. + public func listHardware() async throws -> [Hardware] { + return try await fetch(.get, "hardware") + } + // MARK: - private enum Method: String, Hashable { diff --git a/Sources/Replicate/Hardware.swift b/Sources/Replicate/Hardware.swift new file mode 100644 index 00000000..33a84bbb --- /dev/null +++ b/Sources/Replicate/Hardware.swift @@ -0,0 +1,30 @@ +// Hardware for running a model on Replicate. +public struct Hardware: Hashable, Codable { + public typealias ID = String + + /// The product identifier for the hardware. + /// + /// For example, "gpu-a40-large". + public let sku: String + + /// The name of the hardware. + /// + /// For example, "Nvidia A40 (Large) GPU". + public let name: String +} + +// MARK: - Identifiable + +extension Hardware: Identifiable { + public var id: String { + return self.sku + } +} + +// MARK: - CustomStringConvertible + +extension Hardware: CustomStringConvertible { + public var description: String { + return self.name + } +} diff --git a/Tests/ReplicateTests/ClientTests.swift b/Tests/ReplicateTests/ClientTests.swift index 34a44515..a0115146 100644 --- a/Tests/ReplicateTests/ClientTests.swift +++ b/Tests/ReplicateTests/ClientTests.swift @@ -153,6 +153,13 @@ final class ClientTests: XCTestCase { XCTAssertEqual(trainings.results.count, 1) } + func testListHardware() async throws { + let hardware = try await client.listHardware() + XCTAssertGreaterThan(hardware.count, 1) + XCTAssertEqual(hardware.first?.name, "CPU") + XCTAssertEqual(hardware.first?.sku, "cpu") + } + func testCustomBaseURL() async throws { let client = Client(baseURLString: "https://v1.replicate.proxy", token: MockURLProtocol.validToken).mocked let collection = try await client.getModelCollection("super-resolution") diff --git a/Tests/ReplicateTests/Helpers/MockURLProtocol.swift b/Tests/ReplicateTests/Helpers/MockURLProtocol.swift index 44a8dfdc..8c9de87e 100644 --- a/Tests/ReplicateTests/Helpers/MockURLProtocol.swift +++ b/Tests/ReplicateTests/Helpers/MockURLProtocol.swift @@ -158,6 +158,19 @@ class MockURLProtocol: URLProtocol { "metrics": {} } """# + case ("GET", "https://api.replicate.com/v1/hardware"?): + statusCode = 200 + json = #""" + [ + { "name": "CPU", "sku": "cpu" }, + { "name": "Nvidia T4 GPU", "sku": "gpu-t4" }, + { "name": "Nvidia A40 GPU", "sku": "gpu-a40-small" }, + { "name": "Nvidia A40 (Large) GPU", "sku": "gpu-a40-large" }, + { "name": "Nvidia A40 (Large) GPU (8x)", "sku": "gpu-a40-large-8x" }, + { "name": "Nvidia A100 (40GB) GPU", "sku": "gpu-a100-small" }, + { "name": "Nvidia A100 (80GB) GPU", "sku": "gpu-a100-large" }, + ] + """# case ("GET", "https://api.replicate.com/v1/models"?): statusCode = 200 json = #"""