Skip to content

Commit 0ae49eb

Browse files
committed
Add support for hardware.list endpoint
1 parent 91819c8 commit 0ae49eb

File tree

4 files changed

+57
-0
lines changed

4 files changed

+57
-0
lines changed

Sources/Replicate/Client.swift

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -479,6 +479,13 @@ public class Client {
479479
return try await fetch(.post, "trainings/\(id)/cancel")
480480
}
481481

482+
/// List hardware available for running a model on Replicate.
483+
///
484+
/// - Returns: An array of hardware.
485+
public func listHardware() async throws -> [Hardware] {
486+
return try await fetch(.get, "hardware")
487+
}
488+
482489
// MARK: -
483490

484491
private enum Method: String, Hashable {

Sources/Replicate/Hardware.swift

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
// Hardware for running a model on Replicate.
2+
public struct Hardware: Hashable, Codable {
3+
public typealias ID = String
4+
5+
/// The product identifier for the hardware.
6+
///
7+
/// For example, "gpu-a40-large".
8+
public let sku: String
9+
10+
/// The name of the hardware.
11+
///
12+
/// For example, "Nvidia A40 (Large) GPU".
13+
public let name: String
14+
}
15+
16+
// MARK: - Identifiable
17+
18+
extension Hardware: Identifiable {
19+
public var id: String {
20+
return self.sku
21+
}
22+
}
23+
24+
// MARK: - CustomStringConvertible
25+
26+
extension Hardware: CustomStringConvertible {
27+
public var description: String {
28+
return self.name
29+
}
30+
}

Tests/ReplicateTests/ClientTests.swift

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -153,6 +153,13 @@ final class ClientTests: XCTestCase {
153153
XCTAssertEqual(trainings.results.count, 1)
154154
}
155155

156+
func testListHardware() async throws {
157+
let hardware = try await client.listHardware()
158+
XCTAssertGreaterThan(hardware.count, 1)
159+
XCTAssertEqual(hardware.first?.name, "CPU")
160+
XCTAssertEqual(hardware.first?.sku, "cpu")
161+
}
162+
156163
func testCustomBaseURL() async throws {
157164
let client = Client(baseURLString: "https://v1.replicate.proxy", token: MockURLProtocol.validToken).mocked
158165
let collection = try await client.getModelCollection("super-resolution")

Tests/ReplicateTests/Helpers/MockURLProtocol.swift

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -158,6 +158,19 @@ class MockURLProtocol: URLProtocol {
158158
"metrics": {}
159159
}
160160
"""#
161+
case ("GET", "https://api.replicate.com/v1/hardware"?):
162+
statusCode = 200
163+
json = #"""
164+
[
165+
{ "name": "CPU", "sku": "cpu" },
166+
{ "name": "Nvidia T4 GPU", "sku": "gpu-t4" },
167+
{ "name": "Nvidia A40 GPU", "sku": "gpu-a40-small" },
168+
{ "name": "Nvidia A40 (Large) GPU", "sku": "gpu-a40-large" },
169+
{ "name": "Nvidia A40 (Large) GPU (8x)", "sku": "gpu-a40-large-8x" },
170+
{ "name": "Nvidia A100 (40GB) GPU", "sku": "gpu-a100-small" },
171+
{ "name": "Nvidia A100 (80GB) GPU", "sku": "gpu-a100-large" },
172+
]
173+
"""#
161174
case ("GET", "https://api.replicate.com/v1/models"?):
162175
statusCode = 200
163176
json = #"""

0 commit comments

Comments
 (0)