diff --git a/Sources/Replicate/Client.swift b/Sources/Replicate/Client.swift index 93f4df80..7a4e0022 100644 --- a/Sources/Replicate/Client.swift +++ b/Sources/Replicate/Client.swift @@ -367,6 +367,17 @@ public class Client { return try await fetch(.get, "models", cursor: cursor) } + /// Search for public models on Replicate. + /// + /// - Parameter query: The search query string. + /// - Returns: A page of models matching the search query. + public func searchModels(query: String) async throws -> Pagination.Page { + var request = try createRequest(method: .query, path: "models") + request.addValue("text/plain", forHTTPHeaderField: "Content-Type") + request.httpBody = query.description.data(using: .utf8) + return try await sendRequest(request) + } + /// Get a model /// /// - Parameters: @@ -638,6 +649,7 @@ public class Client { private enum Method: String, Hashable { case get = "GET" case post = "POST" + case query = "QUERY" } private func fetch(_ method: Method, diff --git a/Tests/ReplicateTests/ClientTests.swift b/Tests/ReplicateTests/ClientTests.swift index 8221f6ae..c7c96dc8 100644 --- a/Tests/ReplicateTests/ClientTests.swift +++ b/Tests/ReplicateTests/ClientTests.swift @@ -240,4 +240,11 @@ final class ClientTests: XCTestCase { XCTAssertEqual(error.detail, "Authentication credentials were not provided.") } } + + func testSearchModels() async throws { + let models = try await client.searchModels(query: "greeter") + XCTAssertEqual(models.results.count, 1) + XCTAssertEqual(models.results[0].owner, "replicate") + XCTAssertEqual(models.results[0].name, "hello-world") + } } diff --git a/Tests/ReplicateTests/Helpers/MockURLProtocol.swift b/Tests/ReplicateTests/Helpers/MockURLProtocol.swift index 669fa99e..bc974129 100644 --- a/Tests/ReplicateTests/Helpers/MockURLProtocol.swift +++ b/Tests/ReplicateTests/Helpers/MockURLProtocol.swift @@ -284,6 +284,82 @@ class MockURLProtocol: URLProtocol { ] } """# + case ("QUERY", "https://api.replicate.com/v1/models"?): + statusCode = 200 + json = #""" + { + "next": null, + "previous": null, + "results": [ + { + "url": "https://replicate.com/replicate/hello-world", + "owner": "replicate", + "name": "hello-world", + "description": "A tiny model that says hello", + "visibility": "public", + "github_url": "https://github.com/replicate/cog-examples", + "paper_url": null, + "license_url": null, + "run_count": 930512, + "cover_image_url": "https://tjzk.replicate.delivery/models_models_cover_image/9c1f748e-a9fc-4cfd-a497-68262ee6151a/replicate-prediction-caujujsgrng7.png", + "default_example": { + "completed_at": "2022-04-26T19:30:10.926419Z", + "created_at": "2022-04-26T19:30:10.761396Z", + "error": null, + "id": "3s2vyrb3pfblrnyp2smdsxxjvu", + "input": { + "text": "Alice" + }, + "logs": null, + "metrics": { + "predict_time": 2e-06 + }, + "output": "hello Alice", + "started_at": "2022-04-26T19:30:10.926417Z", + "status": "succeeded", + "urls": { + "get": "https://api.replicate.com/v1/predictions/3s2vyrb3pfblrnyp2smdsxxjvu", + "cancel": "https://api.replicate.com/v1/predictions/3s2vyrb3pfblrnyp2smdsxxjvu/cancel" + }, + "model": "replicate/hello-world", + "version": "5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa", + "webhook_completed": null + }, + "latest_version": { + "id": "5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa", + "created_at": "2022-04-26T19:29:04.418669Z", + "cog_version": "0.3.0", + "openapi_schema": { + "openapi": "3.1.0", + "components": { + "schemas": { + "Input": { + "type": "object", + "title": "Input", + "required": [ + "text" + ], + "properties": { + "text": { + "type": "string", + "title": "Text", + "x-order": 0, + "description": "Text to prefix with 'hello '" + } + } + }, + "Output": { + "type": "string", + "title": "Output" + } + } + } + } + } + } + ] + } + """# case ("GET", "https://api.replicate.com/v1/models/replicate/hello-world"?): statusCode = 200 json = #"""