Skip to content

Commit 1e0a2e3

Browse files
committed
Add support for models.search endpoint
1 parent 3fd73c6 commit 1e0a2e3

File tree

3 files changed

+95
-0
lines changed

3 files changed

+95
-0
lines changed

Sources/Replicate/Client.swift

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -367,6 +367,17 @@ public class Client {
367367
return try await fetch(.get, "models", cursor: cursor)
368368
}
369369

370+
/// Search for public models on Replicate.
371+
///
372+
/// - Parameter query: The search query string.
373+
/// - Returns: A page of models matching the search query.
374+
public func searchModels(query: String) async throws -> Pagination.Page<Model> {
375+
var request = try createRequest(method: .query, path: "models")
376+
request.addValue("text/plain", forHTTPHeaderField: "Content-Type")
377+
request.httpBody = query.description.data(using: .utf8)
378+
return try await sendRequest(request)
379+
}
380+
370381
/// Get a model
371382
///
372383
/// - Parameters:
@@ -638,6 +649,7 @@ public class Client {
638649
private enum Method: String, Hashable {
639650
case get = "GET"
640651
case post = "POST"
652+
case query = "QUERY"
641653
}
642654

643655
private func fetch<T: Decodable>(_ method: Method,

Tests/ReplicateTests/ClientTests.swift

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -240,4 +240,11 @@ final class ClientTests: XCTestCase {
240240
XCTAssertEqual(error.detail, "Authentication credentials were not provided.")
241241
}
242242
}
243+
244+
func testSearchModels() async throws {
245+
let models = try await client.searchModels(query: "greeter")
246+
XCTAssertEqual(models.results.count, 1)
247+
XCTAssertEqual(models.results[0].owner, "replicate")
248+
XCTAssertEqual(models.results[0].name, "hello-world")
249+
}
243250
}

Tests/ReplicateTests/Helpers/MockURLProtocol.swift

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -284,6 +284,82 @@ class MockURLProtocol: URLProtocol {
284284
]
285285
}
286286
"""#
287+
case ("QUERY", "https://api.replicate.com/v1/models"?):
288+
statusCode = 200
289+
json = #"""
290+
{
291+
"next": null,
292+
"previous": null,
293+
"results": [
294+
{
295+
"url": "https://replicate.com/replicate/hello-world",
296+
"owner": "replicate",
297+
"name": "hello-world",
298+
"description": "A tiny model that says hello",
299+
"visibility": "public",
300+
"github_url": "https://github.com/replicate/cog-examples",
301+
"paper_url": null,
302+
"license_url": null,
303+
"run_count": 930512,
304+
"cover_image_url": "https://tjzk.replicate.delivery/models_models_cover_image/9c1f748e-a9fc-4cfd-a497-68262ee6151a/replicate-prediction-caujujsgrng7.png",
305+
"default_example": {
306+
"completed_at": "2022-04-26T19:30:10.926419Z",
307+
"created_at": "2022-04-26T19:30:10.761396Z",
308+
"error": null,
309+
"id": "3s2vyrb3pfblrnyp2smdsxxjvu",
310+
"input": {
311+
"text": "Alice"
312+
},
313+
"logs": null,
314+
"metrics": {
315+
"predict_time": 2e-06
316+
},
317+
"output": "hello Alice",
318+
"started_at": "2022-04-26T19:30:10.926417Z",
319+
"status": "succeeded",
320+
"urls": {
321+
"get": "https://api.replicate.com/v1/predictions/3s2vyrb3pfblrnyp2smdsxxjvu",
322+
"cancel": "https://api.replicate.com/v1/predictions/3s2vyrb3pfblrnyp2smdsxxjvu/cancel"
323+
},
324+
"model": "replicate/hello-world",
325+
"version": "5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa",
326+
"webhook_completed": null
327+
},
328+
"latest_version": {
329+
"id": "5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa",
330+
"created_at": "2022-04-26T19:29:04.418669Z",
331+
"cog_version": "0.3.0",
332+
"openapi_schema": {
333+
"openapi": "3.1.0",
334+
"components": {
335+
"schemas": {
336+
"Input": {
337+
"type": "object",
338+
"title": "Input",
339+
"required": [
340+
"text"
341+
],
342+
"properties": {
343+
"text": {
344+
"type": "string",
345+
"title": "Text",
346+
"x-order": 0,
347+
"description": "Text to prefix with 'hello '"
348+
}
349+
}
350+
},
351+
"Output": {
352+
"type": "string",
353+
"title": "Output"
354+
}
355+
}
356+
}
357+
}
358+
}
359+
}
360+
]
361+
}
362+
"""#
287363
case ("GET", "https://api.replicate.com/v1/models/replicate/hello-world"?):
288364
statusCode = 200
289365
json = #"""

0 commit comments

Comments
 (0)