Skip to content

Commit

Permalink
Add progress property to Prediction (#75)
Browse files Browse the repository at this point in the history
* Add progress property to Prediction

* Check for macOS 13 in tests
  • Loading branch information
mattt authored Sep 16, 2024
1 parent d3d3bb7 commit d144107
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 1 deletion.
29 changes: 29 additions & 0 deletions Sources/Replicate/Prediction.swift
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
import struct Foundation.Date
import struct Foundation.TimeInterval
import struct Foundation.URL
import class Foundation.Progress
import struct Dispatch.DispatchTime

import RegexBuilder

/// A prediction with unspecified inputs and outputs.
public typealias AnyPrediction = Prediction<[String: Value], Value>

Expand Down Expand Up @@ -81,6 +84,32 @@ public struct Prediction<Input, Output>: Identifiable where Input: Codable, Outp

// MARK: -

@available(macOS 13.0, iOS 16.0, watchOS 9.0, tvOS 16.0, *)
public var progress: Progress? {
guard let logs = self.logs else { return nil }

let regex: Regex = #/^\s*(\d+)%\s*\|.+?\|\s*(\d+)\/(\d+)/#

let lines = logs.split(separator: "\n")
guard !lines.isEmpty else { return nil }

for line in lines.reversed() {
let lineString = String(line).trimmingCharacters(in: .whitespaces)
if let match = try? regex.firstMatch(in: lineString),
let current = Int64(match.output.2),
let total = Int64(match.output.3)
{
let progress = Progress(totalUnitCount: total)
progress.completedUnitCount = current
return progress
}
}

return nil
}

// MARK: -

/// Wait for the prediction to complete.
///
/// - Parameters:
Expand Down
5 changes: 5 additions & 0 deletions Tests/ReplicateTests/ClientTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,11 @@ final class ClientTests: XCTestCase {
XCTAssertEqual(prediction.status, .succeeded)
XCTAssertEqual(prediction.createdAt.timeIntervalSinceReferenceDate, 672703986.224, accuracy: 1)
XCTAssertEqual(prediction.urls["cancel"]?.absoluteString, "https://api.replicate.com/v1/predictions/ufawqhfynnddngldkgtslldrkq/cancel")

if #available(macOS 13.0, *) {
XCTAssertEqual(prediction.progress?.completedUnitCount, 5)
XCTAssertEqual(prediction.progress?.totalUnitCount, 5)
}
}

func testCancelPrediction() async throws {
Expand Down
2 changes: 1 addition & 1 deletion Tests/ReplicateTests/Helpers/MockURLProtocol.swift
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ class MockURLProtocol: URLProtocol {
},
"output": ["Hello, Alice!"],
"error": null,
"logs": "",
"logs": "Using seed: 12345,\n0%| | 0/5 [00:00<?, ?it/s]\n20%|██ | 1/5 [00:00<00:01, 21.38it/s]\n40%|████▍ | 2/5 [00:01<00:01, 22.46it/s]\n60%|████▍ | 3/5 [00:01<00:01, 22.46it/s]\n 80%|████████ | 4/5 [00:01<00:00, 22.86it/s]\n100%|██████████| 5/5 [00:02<00:00, 22.26it/s]",
"metrics": {
"predict_time": 10.0
}
Expand Down

0 comments on commit d144107

Please sign in to comment.