Skip to content
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Package.swift
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ let package = Package(
.package(url: "https://github.com/swift-server/async-http-client.git", from: "1.20.1"),
.package(url: "https://github.com/orlandos-nl/DNSClient.git", from: "2.4.1"),
.package(url: "https://github.com/Bouke/DNS.git", from: "1.2.0"),
.package(url: "https://github.com/apple/containerization.git", exact: Version(stringLiteral: scVersion)),
.package(url: "https://github.com/apple/containerization.git", branch: "main"),
],
targets: [
.executableTarget(
Expand Down
7 changes: 4 additions & 3 deletions Sources/ContainerClient/Core/ClientImage.swift
Original file line number Diff line number Diff line change
Expand Up @@ -220,7 +220,7 @@ extension ClientImage {
})
}

public static func pull(reference: String, platform: Platform? = nil, scheme: RequestScheme = .auto, progressUpdate: ProgressUpdateHandler? = nil) async throws -> ClientImage {
public static func pull(reference: String, platform: Platform? = nil, scheme: RequestScheme = .auto, progressUpdate: ProgressUpdateHandler? = nil, maxConcurrentDownloads: Int = 3) async throws -> ClientImage {
let client = newXPCClient()
let request = newRequest(.imagePull)

Expand All @@ -234,6 +234,7 @@ extension ClientImage {

let insecure = try scheme.schemeFor(host: host) == .http
request.set(key: .insecureFlag, value: insecure)
request.set(key: .maxConcurrentDownloads, value: Int64(maxConcurrentDownloads))

var progressUpdateClient: ProgressUpdateClient?
if let progressUpdate {
Expand Down Expand Up @@ -293,7 +294,7 @@ extension ClientImage {
return (digests, size)
}

public static func fetch(reference: String, platform: Platform? = nil, scheme: RequestScheme = .auto, progressUpdate: ProgressUpdateHandler? = nil) async throws -> ClientImage
public static func fetch(reference: String, platform: Platform? = nil, scheme: RequestScheme = .auto, progressUpdate: ProgressUpdateHandler? = nil, maxConcurrentDownloads: Int = 3) async throws -> ClientImage
{
do {
let match = try await self.get(reference: reference)
Expand All @@ -307,7 +308,7 @@ extension ClientImage {
guard err.isCode(.notFound) else {
throw err
}
return try await Self.pull(reference: reference, platform: platform, scheme: scheme, progressUpdate: progressUpdate)
return try await Self.pull(reference: reference, platform: platform, scheme: scheme, progressUpdate: progressUpdate, maxConcurrentDownloads: maxConcurrentDownloads)
}
}
}
Expand Down
8 changes: 8 additions & 0 deletions Sources/ContainerClient/Flags.swift
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,15 @@ public struct Flags {
self.disableProgressUpdates = disableProgressUpdates
}

public init(disableProgressUpdates: Bool, maxConcurrentDownloads: Int) {
self.disableProgressUpdates = disableProgressUpdates
self.maxConcurrentDownloads = maxConcurrentDownloads
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What's the reason behind putting this in the progress group?

The problem I see here is that any command that wants to present the --progress flag (--disable-progress-updates is gone, apologies for letting this PR go so long so as to force a rebase here) will also see --max-concurrent-downloads whether downloads are relevant to the other command or not. Is there a better home for this? Or do we need to create another reusable group especially for download-related options?

I don't see any harm in copy/pasting this option into each command that needs it either, unless there's simply too many of them.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good point! I think image pull definitely needs it as a user might change this for large images.

container run and container create could use it but its probably not needed as an option.

Is there anything similar to Docker daemon.json? It would be convenient to set a config so all images are pulled with the same maxConcurrentDownloads

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is still unresolved. Consider:

  1. Removing maxConcurrentDownloads from Flags.Progress
  2. Creating new Flags.ImageFetch group specifically for image fetch options
  3. Adding @OptionGroup var imageFetchFlags: Flags.ImageFetch to:
    • container image pull
    • container run
    • container create

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thanks for the suggestion, set up the new option group

}

@Flag(name: .long, help: "Disable progress bar updates")
public var disableProgressUpdates = false

@Option(name: .long, help: "Maximum number of concurrent layer downloads (default: 3)")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@adityaramani Should we use the word "layer" or "blob" here?

public var maxConcurrentDownloads: Int = 3
}
}
2 changes: 1 addition & 1 deletion Sources/ContainerCommands/Image/ImagePull.swift
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ extension Application {
let taskManager = ProgressTaskCoordinator()
let fetchTask = await taskManager.startTask()
let image = try await ClientImage.pull(
reference: processedReference, platform: p, scheme: scheme, progressUpdate: ProgressTaskCoordinator.handler(for: fetchTask, from: progress.handler)
reference: processedReference, platform: p, scheme: scheme, progressUpdate: ProgressTaskCoordinator.handler(for: fetchTask, from: progress.handler), maxConcurrentDownloads: self.progressFlags.maxConcurrentDownloads
)

progress.set(description: "Unpacking image")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ public enum ImagesServiceXPCKeys: String {
case ociPlatform
case insecureFlag
case garbageCollect
case maxConcurrentDownloads

/// ContentStore
case digest
Expand All @@ -54,6 +55,10 @@ extension XPCMessage {
self.set(key: key.rawValue, value: value)
}

public func set(key: ImagesServiceXPCKeys, value: Int64) {
self.set(key: key.rawValue, value: value)
}

public func set(key: ImagesServiceXPCKeys, value: Data) {
self.set(key: key.rawValue, value: value)
}
Expand All @@ -78,6 +83,10 @@ extension XPCMessage {
self.uint64(key: key.rawValue)
}

public func int64(key: ImagesServiceXPCKeys) -> Int64 {
self.int64(key: key.rawValue)
}

public func bool(key: ImagesServiceXPCKeys) -> Bool {
self.bool(key: key.rawValue)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,11 +59,11 @@ public actor ImagesService {
return try await imageStore.list().map { $0.description.fromCZ }
}

public func pull(reference: String, platform: Platform?, insecure: Bool, progressUpdate: ProgressUpdateHandler?) async throws -> ImageDescription {
self.log.info("ImagesService: \(#function) - ref: \(reference), platform: \(String(describing: platform)), insecure: \(insecure)")
public func pull(reference: String, platform: Platform?, insecure: Bool, progressUpdate: ProgressUpdateHandler?, maxConcurrentDownloads: Int = 3) async throws -> ImageDescription {
self.log.info("ImagesService: \(#function) - ref: \(reference), platform: \(String(describing: platform)), insecure: \(insecure), maxConcurrentDownloads: \(maxConcurrentDownloads)")
let img = try await Self.withAuthentication(ref: reference) { auth in
try await self.imageStore.pull(
reference: reference, platform: platform, insecure: insecure, auth: auth, progress: ContainerizationProgressAdapter.handler(from: progressUpdate))
reference: reference, platform: platform, insecure: insecure, auth: auth, progress: ContainerizationProgressAdapter.handler(from: progressUpdate), maxConcurrentDownloads: maxConcurrentDownloads)
}
guard let img else {
throw ContainerizationError(.internalError, message: "Failed to pull image \(reference)")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,9 +47,10 @@ public struct ImagesServiceHarness: Sendable {
platform = try JSONDecoder().decode(ContainerizationOCI.Platform.self, from: platformData)
}
let insecure = message.bool(key: .insecureFlag)
let maxConcurrentDownloads = message.int64(key: .maxConcurrentDownloads)

let progressUpdateService = ProgressUpdateService(message: message)
let imageDescription = try await service.pull(reference: ref, platform: platform, insecure: insecure, progressUpdate: progressUpdateService?.handler)
let imageDescription = try await service.pull(reference: ref, platform: platform, insecure: insecure, progressUpdate: progressUpdateService?.handler, maxConcurrentDownloads: Int(maxConcurrentDownloads))

let imageData = try JSONEncoder().encode(imageDescription)
let reply = message.reply()
Expand Down
49 changes: 47 additions & 2 deletions Sources/TerminalProgress/ProgressTaskCoordinator.swift
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,9 @@
import Foundation

/// A type that represents a task whose progress is being monitored.
public struct ProgressTask: Sendable, Equatable {
public struct ProgressTask: Sendable, Equatable, Hashable {
private var id = UUID()
private var coordinator: ProgressTaskCoordinator
internal var coordinator: ProgressTaskCoordinator

init(manager: ProgressTaskCoordinator) {
self.coordinator = manager
Expand All @@ -29,6 +29,10 @@ public struct ProgressTask: Sendable, Equatable {
lhs.id == rhs.id
}

public func hash(into hasher: inout Hasher) {
hasher.combine(id)
}

/// Returns `true` if this task is the currently active task, `false` otherwise.
public func isCurrent() async -> Bool {
guard let currentTask = await coordinator.currentTask else {
Expand All @@ -41,6 +45,7 @@ public struct ProgressTask: Sendable, Equatable {
/// A type that coordinates progress tasks to ignore updates from completed tasks.
public actor ProgressTaskCoordinator {
var currentTask: ProgressTask?
var activeTasks: Set<ProgressTask> = []

/// Creates an instance of `ProgressTaskCoordinator`.
public init() {}
Expand All @@ -52,9 +57,36 @@ public actor ProgressTaskCoordinator {
return newTask
}

/// Starts multiple concurrent tasks and returns them.
/// - Parameter count: The number of concurrent tasks to start.
/// - Returns: An array of ProgressTask instances.
public func startConcurrentTasks(count: Int) -> [ProgressTask] {
var tasks: [ProgressTask] = []
for _ in 0..<count {
let task = ProgressTask(manager: self)
tasks.append(task)
activeTasks.insert(task)
}
return tasks
}

/// Marks a specific task as completed and removes it from active tasks.
/// - Parameter task: The task to mark as completed.
public func completeTask(_ task: ProgressTask) {
activeTasks.remove(task)
}

/// Checks if a task is currently active.
/// - Parameter task: The task to check.
/// - Returns: `true` if the task is active, `false` otherwise.
public func isTaskActive(_ task: ProgressTask) -> Bool {
activeTasks.contains(task)
}

/// Performs cleanup when the monitored tasks complete.
public func finish() {
currentTask = nil
activeTasks.removeAll()
}

/// Returns a handler that updates the progress of a given task.
Expand All @@ -69,4 +101,17 @@ public actor ProgressTaskCoordinator {
}
}
}

/// Returns a handler that updates the progress for concurrent tasks.
/// - Parameters:
/// - task: The task whose progress is being updated.
/// - progressUpdate: The handler to invoke when progress updates are received.
public static func concurrentHandler(for task: ProgressTask, from progressUpdate: @escaping ProgressUpdateHandler) -> ProgressUpdateHandler {
{ events in
// Only process updates if the task is still active
if await task.coordinator.isTaskActive(task) {
await progressUpdate(events)
}
}
}
}
109 changes: 109 additions & 0 deletions test_concurrency.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
#!/usr/bin/env swift

import Foundation

func testConcurrentDownloads() async throws {
print("Testing concurrent download behavior...\n")

// Track concurrent task count
actor ConcurrencyTracker {
var currentCount = 0
var maxObservedCount = 0
var completedTasks = 0

func taskStarted() {
currentCount += 1
maxObservedCount = max(maxObservedCount, currentCount)
}

func taskCompleted() {
currentCount -= 1
completedTasks += 1
}

func getStats() -> (max: Int, completed: Int) {
return (maxObservedCount, completedTasks)
}

func reset() {
currentCount = 0
maxObservedCount = 0
completedTasks = 0
}
}

let tracker = ConcurrencyTracker()

// Test with different concurrency limits
for maxConcurrent in [1, 3, 6] {
await tracker.reset()

// Simulate downloading 20 layers
let layerCount = 20
let layers = Array(0..<layerCount)

print("Testing maxConcurrent=\(maxConcurrent) with \(layerCount) layers...")

let startTime = Date()

try await withThrowingTaskGroup(of: Void.self) { group in
var iterator = layers.makeIterator()

// Start initial batch based on maxConcurrent
for _ in 0..<maxConcurrent {
if iterator.next() != nil {
group.addTask {
await tracker.taskStarted()
try await Task.sleep(nanoseconds: 10_000_000)
await tracker.taskCompleted()
}
}
}
for try await _ in group {
if iterator.next() != nil {
group.addTask {
await tracker.taskStarted()
try await Task.sleep(nanoseconds: 10_000_000)
await tracker.taskCompleted()
}
}
}
}

let duration = Date().timeIntervalSince(startTime)
let stats = await tracker.getStats()

print(" ✓ Completed: \(stats.completed)/\(layerCount)")
print(" ✓ Max concurrent: \(stats.max)")
print(" ✓ Duration: \(String(format: "%.3f", duration))s")

guard stats.max <= maxConcurrent + 1 else {
throw TestError.concurrencyLimitExceeded
}

guard stats.completed == layerCount else {
throw TestError.incompleteTasks
}

print(" ✅ PASSED\n")
}

print("All tests passed!")
}

enum TestError: Error {
case concurrencyLimitExceeded
case incompleteTasks
}

Task {
do {
try await testConcurrentDownloads()
exit(0)
} catch {
print("Test failed: \(error)")
exit(1)
}
}

RunLoop.main.run()
Loading