Skip to content

Commit

Permalink
Add support for models.search endpoint
Browse files Browse the repository at this point in the history
  • Loading branch information
mattt committed Jul 19, 2024
1 parent 3fd73c6 commit 1e0a2e3
Show file tree
Hide file tree
Showing 3 changed files with 95 additions and 0 deletions.
12 changes: 12 additions & 0 deletions Sources/Replicate/Client.swift
Original file line number Diff line number Diff line change
Expand Up @@ -367,6 +367,17 @@ public class Client {
return try await fetch(.get, "models", cursor: cursor)
}

/// Search for public models on Replicate.
///
/// - Parameter query: The search query string.
/// - Returns: A page of models matching the search query.
public func searchModels(query: String) async throws -> Pagination.Page<Model> {
var request = try createRequest(method: .query, path: "models")
request.addValue("text/plain", forHTTPHeaderField: "Content-Type")
request.httpBody = query.description.data(using: .utf8)
return try await sendRequest(request)
}

/// Get a model
///
/// - Parameters:
Expand Down Expand Up @@ -638,6 +649,7 @@ public class Client {
private enum Method: String, Hashable {
case get = "GET"
case post = "POST"
case query = "QUERY"
}

private func fetch<T: Decodable>(_ method: Method,
Expand Down
7 changes: 7 additions & 0 deletions Tests/ReplicateTests/ClientTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -240,4 +240,11 @@ final class ClientTests: XCTestCase {
XCTAssertEqual(error.detail, "Authentication credentials were not provided.")
}
}

func testSearchModels() async throws {
let models = try await client.searchModels(query: "greeter")
XCTAssertEqual(models.results.count, 1)
XCTAssertEqual(models.results[0].owner, "replicate")
XCTAssertEqual(models.results[0].name, "hello-world")
}
}
76 changes: 76 additions & 0 deletions Tests/ReplicateTests/Helpers/MockURLProtocol.swift
Original file line number Diff line number Diff line change
Expand Up @@ -284,6 +284,82 @@ class MockURLProtocol: URLProtocol {
]
}
"""#
case ("QUERY", "https://api.replicate.com/v1/models"?):
statusCode = 200
json = #"""
{
"next": null,
"previous": null,
"results": [
{
"url": "https://replicate.com/replicate/hello-world",
"owner": "replicate",
"name": "hello-world",
"description": "A tiny model that says hello",
"visibility": "public",
"github_url": "https://github.com/replicate/cog-examples",
"paper_url": null,
"license_url": null,
"run_count": 930512,
"cover_image_url": "https://tjzk.replicate.delivery/models_models_cover_image/9c1f748e-a9fc-4cfd-a497-68262ee6151a/replicate-prediction-caujujsgrng7.png",
"default_example": {
"completed_at": "2022-04-26T19:30:10.926419Z",
"created_at": "2022-04-26T19:30:10.761396Z",
"error": null,
"id": "3s2vyrb3pfblrnyp2smdsxxjvu",
"input": {
"text": "Alice"
},
"logs": null,
"metrics": {
"predict_time": 2e-06
},
"output": "hello Alice",
"started_at": "2022-04-26T19:30:10.926417Z",
"status": "succeeded",
"urls": {
"get": "https://api.replicate.com/v1/predictions/3s2vyrb3pfblrnyp2smdsxxjvu",
"cancel": "https://api.replicate.com/v1/predictions/3s2vyrb3pfblrnyp2smdsxxjvu/cancel"
},
"model": "replicate/hello-world",
"version": "5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa",
"webhook_completed": null
},
"latest_version": {
"id": "5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa",
"created_at": "2022-04-26T19:29:04.418669Z",
"cog_version": "0.3.0",
"openapi_schema": {
"openapi": "3.1.0",
"components": {
"schemas": {
"Input": {
"type": "object",
"title": "Input",
"required": [
"text"
],
"properties": {
"text": {
"type": "string",
"title": "Text",
"x-order": 0,
"description": "Text to prefix with 'hello '"
}
}
},
"Output": {
"type": "string",
"title": "Output"
}
}
}
}
}
}
]
}
"""#
case ("GET", "https://api.replicate.com/v1/models/replicate/hello-world"?):
statusCode = 200
json = #"""
Expand Down

0 comments on commit 1e0a2e3

Please sign in to comment.