diff --git a/Blockchain/Sources/Blockchain/Validator/ServiceBase2.swift b/Blockchain/Sources/Blockchain/Validator/ServiceBase2.swift index 3de4f700..906cfbe1 100644 --- a/Blockchain/Sources/Blockchain/Validator/ServiceBase2.swift +++ b/Blockchain/Sources/Blockchain/Validator/ServiceBase2.swift @@ -45,13 +45,17 @@ public class ServiceBase2: ServiceBase, @unchecked Sendable { let cancellables = cancellables let cancellable = scheduler.schedule(id: id, delay: delay, repeats: repeats) { if !repeats { - cancellables.write { $0.remove(IdCancellable(id: id, cancellable: nil)) } + cancellables.write { + $0.remove(IdCancellable(id: id, cancellable: nil)) + } } await task() } onCancel: { cancellables.write { $0.remove(IdCancellable(id: id, cancellable: nil)) } } - cancellables.write { $0.insert(IdCancellable(id: id, cancellable: cancellable)) } + cancellables.write { + $0.insert(IdCancellable(id: id, cancellable: cancellable)) + } return cancellable } diff --git a/Blockchain/Tests/BlockchainTests/MockScheduler.swift b/Blockchain/Tests/BlockchainTests/MockScheduler.swift index f116e5a3..db55836b 100644 --- a/Blockchain/Tests/BlockchainTests/MockScheduler.swift +++ b/Blockchain/Tests/BlockchainTests/MockScheduler.swift @@ -65,7 +65,7 @@ final class MockScheduler: Scheduler, Sendable { storage.tasks.insert(task) } return Cancellable { - self.storage.mutate { storage in + self.storage.write { storage in if let index = storage.tasks.array.firstIndex(where: { $0.id == id }) { let task = storage.tasks.remove(at: index) task.cancel?() @@ -80,7 +80,7 @@ final class MockScheduler: Scheduler, Sendable { } func advanceNext(to time: TimeInterval) async -> Bool { - let task: SchedulerTask? = storage.mutate { storage in + let task: SchedulerTask? = storage.write { storage in if let task = storage.tasks.array.first, task.scheduleTime <= time { storage.tasks.remove(at: 0) return task diff --git a/Boka/Sources/BokaLogger.swift b/Boka/Sources/BokaLogger.swift index 9e2daf17..4b7592b6 100644 --- a/Boka/Sources/BokaLogger.swift +++ b/Boka/Sources/BokaLogger.swift @@ -79,7 +79,7 @@ public struct BokaLogger: LogHandler, Sendable { } let defaultLevel = defaultLevel - return filters.mutate { filters in + return filters.write { filters in for (key, value) in filters where label.hasPrefix(key) { filters[label] = value return value diff --git a/Networking/Package.swift b/Networking/Package.swift index 19fddfa8..36e29a4d 100644 --- a/Networking/Package.swift +++ b/Networking/Package.swift @@ -20,12 +20,14 @@ let package = Package( .package(url: "https://github.com/apple/swift-log.git", from: "1.6.0"), .package(url: "https://github.com/apple/swift-certificates.git", from: "1.5.0"), .package(url: "https://github.com/apple/swift-testing.git", branch: "0.10.0"), + .package(url: "https://github.com/gh123man/Async-Channels.git", from: "1.0.0"), ], targets: [ .target( name: "Networking", dependencies: [ "MsQuicSwift", + .product(name: "AsyncChannels", package: "Async-Channels"), .product(name: "Logging", package: "swift-log"), .product(name: "X509", package: "swift-certificates"), ] diff --git a/Networking/Sources/MsQuicSwift/NetAddr.swift b/Networking/Sources/MsQuicSwift/NetAddr.swift index 4da717ad..1c776393 100644 --- a/Networking/Sources/MsQuicSwift/NetAddr.swift +++ b/Networking/Sources/MsQuicSwift/NetAddr.swift @@ -16,6 +16,7 @@ public struct NetAddr: Hashable, Sendable { self.ipAddress = ipAddress self.port = port self.ipv4 = ipv4 + // TODO: automatically determine the ip address family } public init(quicAddr: QUIC_ADDR) { diff --git a/Networking/Sources/MsQuicSwift/QuicConnection.swift b/Networking/Sources/MsQuicSwift/QuicConnection.swift index 863d449f..d5af66cf 100644 --- a/Networking/Sources/MsQuicSwift/QuicConnection.swift +++ b/Networking/Sources/MsQuicSwift/QuicConnection.swift @@ -100,7 +100,7 @@ public final class QuicConnection: Sendable { public func connect(to address: NetAddr) throws { logger.debug("connecting to \(address)") - try storage.mutate { storage in + try storage.write { storage in guard var storage2 = storage else { throw QuicError.alreadyClosed } @@ -122,7 +122,7 @@ public final class QuicConnection: Sendable { public func shutdown(errorCode: QuicErrorCode = .success) throws { logger.debug("closing connection") - try storage.mutate { storage in + try storage.write { storage in guard let storage2 = storage else { throw QuicError.alreadyClosed } @@ -285,3 +285,13 @@ private func connectionCallback( return handle.callbackHandler(event: event!).rawValue } + +extension QuicConnection: Hashable { + public func hash(into hasher: inout Hasher) { + hasher.combine(ObjectIdentifier(self)) + } + + public static func == (lhs: QuicConnection, rhs: QuicConnection) -> Bool { + ObjectIdentifier(lhs) == ObjectIdentifier(rhs) + } +} diff --git a/Networking/Sources/MsQuicSwift/QuicStatus.swift b/Networking/Sources/MsQuicSwift/QuicStatus.swift index df79be9e..cd755c82 100644 --- a/Networking/Sources/MsQuicSwift/QuicStatus.swift +++ b/Networking/Sources/MsQuicSwift/QuicStatus.swift @@ -23,7 +23,7 @@ public enum QuicStatus: Equatable, Sendable, Codable, RawRepresentable { } extension QuicStatus { - var isSucceeded: Bool { + public var isSucceeded: Bool { switch self { case let .code(code): Int32(bitPattern: code.rawValue) <= 0 diff --git a/Networking/Sources/MsQuicSwift/QuicStream.swift b/Networking/Sources/MsQuicSwift/QuicStream.swift index c0df98c2..1b446284 100644 --- a/Networking/Sources/MsQuicSwift/QuicStream.swift +++ b/Networking/Sources/MsQuicSwift/QuicStream.swift @@ -72,7 +72,7 @@ public final class QuicStream: Sendable { public func shutdown(errorCode: QuicErrorCode = .success) throws { logger.debug("closing stream \(errorCode)") - try storage.mutate { storage in + try storage.write { storage in guard let storage2 = storage else { throw QuicError.alreadyClosed } @@ -88,7 +88,7 @@ public final class QuicStream: Sendable { } } - public func send(with data: Data, startStream: Bool = false, closeStream: Bool = false) throws { + public func send(data: Data, startStream: Bool = false, closeStream: Bool = false) throws { logger.trace("Sending \(data.count) bytes") try storage.read { storage in @@ -241,3 +241,13 @@ private func streamCallback( return handle.callbackHandler(event: event!).rawValue } + +extension QuicStream: Hashable { + public func hash(into hasher: inout Hasher) { + hasher.combine(ObjectIdentifier(self)) + } + + public static func == (lhs: QuicStream, rhs: QuicStream) -> Bool { + ObjectIdentifier(lhs) == ObjectIdentifier(rhs) + } +} diff --git a/Networking/Sources/Networking/Alpn.swift b/Networking/Sources/Networking/Alpn.swift index c7d1864d..265f1d87 100644 --- a/Networking/Sources/Networking/Alpn.swift +++ b/Networking/Sources/Networking/Alpn.swift @@ -4,7 +4,10 @@ import Utils public struct Alpn: Sendable { public let data: Data private static let headerPrefixLength = 8 - init(_ protocolName: String = "jamnp-s", version: String = "0", genesisHeader: Data32) { - data = Data("\(protocolName)/\(version)/\(genesisHeader.toHexString().prefix(Alpn.headerPrefixLength))".utf8) + init(protocolName: String = "jamnp-s", version: String = "0", genesisHeader: Data32, builder: Bool) { + let header: String.SubSequence = genesisHeader.toHexString().prefix(Alpn.headerPrefixLength) + data = Data( + "\(protocolName)/\(version)/\(header)\(builder ? "/builder" : "")".utf8 + ) } } diff --git a/Networking/Sources/Networking/Connection.swift b/Networking/Sources/Networking/Connection.swift new file mode 100644 index 00000000..11c1c475 --- /dev/null +++ b/Networking/Sources/Networking/Connection.swift @@ -0,0 +1,68 @@ +import Foundation +import MsQuicSwift +import TracingUtils +import Utils + +private let logger = Logger(label: "Connection") + +public final class Connection: Sendable { + let connection: QuicConnection + let impl: PeerImpl + let mode: PeerMode + let remoteAddress: NetAddr + let presistentStreams: ThreadSafeContainer<[UniquePresistentStreamKind: Stream]> = .init([:]) + + init(_ connection: QuicConnection, impl: PeerImpl, mode: PeerMode, remoteAddress: NetAddr) { + self.connection = connection + self.impl = impl + self.mode = mode + self.remoteAddress = remoteAddress + } + + public func getStream(kind: UniquePresistentStreamKind) throws -> Stream { + let stream = presistentStreams.read { presistentStreams in + presistentStreams[kind] + } + return try stream ?? presistentStreams.write { presistentStreams in + if let stream = presistentStreams[kind] { + return stream + } + let stream = try self.createStream(kind: kind.rawValue) + presistentStreams[kind] = stream + return stream + } + } + + private func createStream(kind: UInt8) throws -> Stream { + let stream = try Stream(connection.createStream(), impl: impl) + impl.addStream(stream) + try stream.send(data: Data([kind])) + return stream + } + + public func createStream(kind: CommonEphemeralStreamKind) throws -> Stream { + try createStream(kind: kind.rawValue) + } + + func streamStarted(stream: QuicStream) { + let stream = Stream(stream, impl: impl) + impl.addStream(stream) + Task { + guard let byte = await stream.receiveByte() else { + logger.debug("stream closed without receiving kind. status: \(stream.status)") + return + } + if let upKind = UniquePresistentStreamKind(rawValue: byte) { + // TODO: handle duplicated UP streams + presistentStreams.write { presistentStreams in + presistentStreams[upKind] = stream + } + return + } + if let ceKind = CommonEphemeralStreamKind(rawValue: byte) { + logger.debug("stream opened. kind: \(ceKind)") + // TODO: handle requests + } + } + } +} diff --git a/Networking/Sources/Networking/Peer.swift b/Networking/Sources/Networking/Peer.swift index 95b8754c..038db69a 100644 --- a/Networking/Sources/Networking/Peer.swift +++ b/Networking/Sources/Networking/Peer.swift @@ -3,28 +3,46 @@ import Logging import MsQuicSwift import Utils -private let logger = Logger(label: "PeerServer") - public enum StreamType: Sendable { case uniquePersistent case commonEphemeral } +public enum PeerMode: Sendable, Hashable { + case validator + case builder + // case proxy // not yet specified +} + public protocol Message { func encode() -> Data } -public struct PeerConfiguration: Sendable { +public struct PeerOptions: Sendable { + public var mode: PeerMode public var listenAddress: NetAddr - public var alpns: [Alpn] - public var pkcs12: Data - public var settings: QuicSettings + public var genesisHeader: Data32 + public var secretKey: Ed25519.SecretKey + public var serverSettings: QuicSettings + public var clientSettings: QuicSettings + public var peerSettings: PeerSettings - public init(listenAddress: NetAddr, alpns: [Alpn], pkcs12: Data, settings: QuicSettings = .defaultSettings) { + public init( + mode: PeerMode, + listenAddress: NetAddr, + genesisHeader: Data32, + secretKey: Ed25519.SecretKey, + serverSettings: QuicSettings = .defaultSettings, + clientSettings: QuicSettings = .defaultSettings, + peerSettings: PeerSettings = .defaultSettings + ) { + self.mode = mode self.listenAddress = listenAddress - self.alpns = alpns - self.pkcs12 = pkcs12 - self.settings = settings + self.genesisHeader = genesisHeader + self.secretKey = secretKey + self.serverSettings = serverSettings + self.clientSettings = clientSettings + self.peerSettings = peerSettings } } @@ -33,35 +51,228 @@ public struct PeerConfiguration: Sendable { // - limit max connections per connection type // - manage peer reputation and rotate connections when full public final class Peer: Sendable { - private let config: PeerConfiguration - private let eventBus: EventBus + private let impl: PeerImpl + private let listener: QuicListener - private let connections: ThreadSafeContainer<[NetAddr: QuicConnection]> = .init([:]) - private let streams: ThreadSafeContainer<[NetAddr: [QuicStream]]> = .init([:]) public var events: some Subscribable { - eventBus + impl.eventBus } - public init(config: PeerConfiguration, eventBus: EventBus) async throws { - self.config = config - self.eventBus = eventBus + private var logger: Logger { + impl.logger + } + + public init(options: PeerOptions, eventBus: EventBus) throws { + let logger = Logger(label: "Peer".uniqueId) + let eventBus = eventBus + + let alpns = [ + PeerMode.validator: Alpn(genesisHeader: options.genesisHeader, builder: false).data, + PeerMode.builder: Alpn(genesisHeader: options.genesisHeader, builder: true).data, + ] + let allAlpns = Array(alpns.values) - let alpns = config.alpns.map(\.data) + let pkcs12 = try generateSelfSignedCertificate(privateKey: options.secretKey) let registration = try QuicRegistration() - let configuration = try QuicConfiguration( - registration: registration, pkcs12: config.pkcs12, alpns: alpns, client: false, settings: config.settings + let serverConfiguration = try QuicConfiguration( + registration: registration, pkcs12: pkcs12, alpns: allAlpns, client: false, settings: options.serverSettings + ) + + let clientAlpn = alpns[options.mode]! + let clientConfiguration = try QuicConfiguration( + registration: registration, pkcs12: pkcs12, alpns: [clientAlpn], client: true, settings: options.clientSettings + ) + + impl = PeerImpl( + logger: logger, + eventBus: eventBus, + mode: options.mode, + settings: options.peerSettings, + alpns: alpns, + clientConfiguration: clientConfiguration ) listener = try QuicListener( - handler: PeerEventHandler(), + handler: PeerEventHandler(impl), registration: registration, - configuration: configuration, - listenAddress: config.listenAddress, - alpns: alpns + configuration: serverConfiguration, + listenAddress: options.listenAddress, + alpns: allAlpns ) } + + public func connect(to address: NetAddr, mode: PeerMode) throws -> Connection { + let conn = impl.connections.read { connections in + connections.byType[mode]?[address] + } + return try conn ?? impl.connections.write { connections in + let curr = connections.byType[mode, default: [:]][address] + if let curr { + return curr + } + let conn = try Connection( + QuicConnection( + handler: PeerEventHandler(self.impl), + registration: self.impl.clientConfiguration.registration, + configuration: self.impl.clientConfiguration + ), + impl: self.impl, + mode: mode, + remoteAddress: address + ) + connections.byType[mode, default: [:]][address] = conn + connections.byId[conn.connection] = conn + return conn + } + } +} + +struct ConnectionStorage { + var byType: [PeerMode: [NetAddr: Connection]] = [:] + var byId: [QuicConnection: Connection] = [:] +} + +final class PeerImpl: Sendable { + fileprivate let logger: Logger + fileprivate let eventBus: EventBus + fileprivate let mode: PeerMode + fileprivate let settings: PeerSettings + fileprivate let alpns: [PeerMode: Data] + fileprivate let alpnLookup: [Data: PeerMode] + + fileprivate let clientConfiguration: QuicConfiguration + + fileprivate let connections: ThreadSafeContainer = .init(.init()) + fileprivate let streams: ThreadSafeContainer<[QuicStream: Stream]> = .init([:]) + + init( + logger: Logger, + eventBus: EventBus, + mode: PeerMode, + settings: PeerSettings, + alpns: [PeerMode: Data], + clientConfiguration: QuicConfiguration + ) { + self.logger = logger + self.eventBus = eventBus + self.mode = mode + self.settings = settings + self.alpns = alpns + self.clientConfiguration = clientConfiguration + + var alpnLookup = [Data: PeerMode]() + for (mode, alpn) in alpns { + alpnLookup[alpn] = mode + } + self.alpnLookup = alpnLookup + } + + func addConnection(_ connection: QuicConnection, addr: NetAddr, mode: PeerMode) -> Bool { + connections.write { connections in + if mode == .builder { + let currentCount = connections.byType[mode]?.count ?? 0 + if currentCount >= self.settings.maxBuilderConnections { + self.logger.warning("max builder connections reached") + // TODO: consider connection rotation strategy + return false + } + } + if connections.byType[mode, default: [:]][addr] != nil { + self.logger.warning("connection already exists") + return false + } + let conn = Connection(connection, impl: self, mode: mode, remoteAddress: addr) + connections.byType[mode, default: [:]][addr] = conn + connections.byId[connection] = conn + return true + } + } + + func addStream(_ stream: Stream) { + streams.write { streams in + if streams[stream.stream] != nil { + self.logger.warning("stream already exists") + } + streams[stream.stream] = stream + } + } } -public final class PeerEventHandler: QuicEventHandler {} +private final class PeerEventHandler: QuicEventHandler { + private let impl: PeerImpl + + private var logger: Logger { + impl.logger + } + + init(_ impl: PeerImpl) { + self.impl = impl + } + + func newConnection(_: QuicListener, connection: QuicConnection, info: ConnectionInfo) -> QuicStatus { + let addr = info.remoteAddress + let mode = impl.alpnLookup[info.negotiatedAlpn] + guard let mode else { + logger.warning("unknown alpn: \(String(data: info.negotiatedAlpn, encoding: .utf8) ?? info.negotiatedAlpn.toDebugHexString())") + return .code(.alpnNegFailure) + } + logger.debug("new connection: \(addr) mode: \(mode)") + if impl.addConnection(connection, addr: addr, mode: mode) { + return .code(.success) + } else { + return .code(.connectionRefused) + } + } + + func shouldOpen(_: QuicConnection, certificate _: Data?) -> QuicStatus { + // TODO: verify certificate + // - Require a certificate + // - Verify the alt name matches to the public key + // - Check connection mode and if validator, verify if it is current or next validator + .code(.success) + } + + func connected(_: QuicConnection) {} + + func shutdownInitiated(_ connection: QuicConnection, reason _: ConnectionCloseReason) { + impl.connections.write { connections in + if let conn = connections.byId[connection] { + connections.byId.removeValue(forKey: connection) + connections.byType[conn.mode]?.removeValue(forKey: conn.remoteAddress) + } + } + } + + func streamStarted(_ connection: QuicConnection, stream: QuicStream) { + let conn = impl.connections.read { connections in + connections.byId[connection] + } + if let conn { + conn.streamStarted(stream: stream) + } + } + + func dataReceived(_ stream: QuicStream, data: Data) { + let stream = impl.streams.read { streams in + streams[stream] + } + if let stream { + stream.received(data: data) + } + } + + func closed(_ stream: QuicStream, status: QuicStatus, code _: QuicErrorCode) { + let stream = impl.streams.read { streams in + streams[stream] + } + if let stream { + if status.isSucceeded { + stream.close() + } else { + stream.abort() + } + } + } +} diff --git a/Networking/Sources/Networking/PeerSettings.swift b/Networking/Sources/Networking/PeerSettings.swift new file mode 100644 index 00000000..1a4a1152 --- /dev/null +++ b/Networking/Sources/Networking/PeerSettings.swift @@ -0,0 +1,7 @@ +public struct PeerSettings: Sendable { + public var maxBuilderConnections: Int +} + +extension PeerSettings { + public static let defaultSettings = PeerSettings(maxBuilderConnections: 20) +} diff --git a/Networking/Sources/Networking/Stream.swift b/Networking/Sources/Networking/Stream.swift new file mode 100644 index 00000000..8e1e8b1e --- /dev/null +++ b/Networking/Sources/Networking/Stream.swift @@ -0,0 +1,93 @@ +import AsyncChannels +import Foundation +import MsQuicSwift +import Utils + +public enum StreamStatus: Sendable { + case open, closed, aborted +} + +enum StreamError: Error { + case notOpen +} + +public final class Stream: Sendable { + let stream: QuicStream + let impl: PeerImpl + private let channel: Channel = .init(capacity: 100) + // TODO: https://github.com/gh123man/Async-Channels/issues/12 + private let nextData: ThreadSafeContainer = .init(nil) + private let _status: ThreadSafeContainer = .init(.open) + + public private(set) var status: StreamStatus { + get { + _status.value + } + set { + _status.value = newValue + } + } + + init(_ stream: QuicStream, impl: PeerImpl) { + self.stream = stream + self.impl = impl + } + + public func send(data: Data) throws { + guard status == .open else { + throw StreamError.notOpen + } + try stream.send(data: data) + } + + func received(data: Data) { + if data.isEmpty { + return + } + // TODO: backpressure handling + // https://github.com/gh123man/Async-Channels/issues/11 + Task { + await channel.send(data) + } + } + + func close() { + status = .closed + channel.close() + } + + func abort() { + status = .aborted + channel.close() + } + + public func receive() async -> Data? { + if let data = nextData.value { + nextData.value = nil + return data + } + return await channel.receive() + } + + public func receiveByte() async -> UInt8? { + if var data = nextData.value { + let byte = data.removeFirst() + if data.isEmpty { + nextData.value = nil + } else { + nextData.value = data + } + return byte + } + + guard var data = await receive() else { + return nil + } + + let byte = data.removeFirst() + if !data.isEmpty { + nextData.value = data + } + return byte + } +} diff --git a/Networking/Sources/Networking/StreamKind.swift b/Networking/Sources/Networking/StreamKind.swift new file mode 100644 index 00000000..b2b7890e --- /dev/null +++ b/Networking/Sources/Networking/StreamKind.swift @@ -0,0 +1,23 @@ +public enum UniquePresistentStreamKind: UInt8, Sendable, Hashable, CaseIterable { + case blockAnnouncement = 0 +} + +public enum CommonEphemeralStreamKind: UInt8, Sendable, Hashable, CaseIterable { + case blockRequest = 128 + case stateRequest = 129 + case safroleTicket1 = 131 + case safroleTicket2 = 132 + case workPackageSubmission = 133 + case workPackageSharing = 134 + case workReportDistrubution = 135 + case workReportRequest = 136 + case shardDistribution = 137 + case auditShardRequest = 138 + case segmentShardRequest1 = 139 + case segmentShardRequest2 = 140 + case assuranceDistribution = 141 + case preimageAnnouncement = 142 + case preimageRequest = 143 + case auditAnnouncement = 144 + case judgementPublication = 145 +} diff --git a/Networking/Tests/MsQuicSwiftTests/QuicListenerTests.swift b/Networking/Tests/MsQuicSwiftTests/QuicListenerTests.swift index 613ca712..d1606709 100644 --- a/Networking/Tests/MsQuicSwiftTests/QuicListenerTests.swift +++ b/Networking/Tests/MsQuicSwiftTests/QuicListenerTests.swift @@ -71,7 +71,7 @@ struct QuicListenerTests { let stream1 = try clientConnection.createStream() - try stream1.send(with: Data("test data 1".utf8)) + try stream1.send(data: Data("test data 1".utf8)) try? await Task.sleep(for: .milliseconds(100)) let (serverConnection, info) = serverHandler.events.value.compactMap { @@ -89,7 +89,7 @@ struct QuicListenerTests { #expect(info.remoteAddress.ipAddress == "127.0.0.1") let stream2 = try serverConnection.createStream() - try stream2.send(with: Data("other test data 2".utf8)) + try stream2.send(data: Data("other test data 2".utf8)) try? await Task.sleep(for: .milliseconds(100)) let remoteStream1 = clientHandler.events.value.compactMap { @@ -100,7 +100,7 @@ struct QuicListenerTests { nil } }.first! - try remoteStream1.send(with: Data("replay to 1".utf8)) + try remoteStream1.send(data: Data("replay to 1".utf8)) try? await Task.sleep(for: .milliseconds(100)) let remoteStream2 = serverHandler.events.value.compactMap { @@ -111,7 +111,7 @@ struct QuicListenerTests { nil } }.first! - try remoteStream2.send(with: Data("another replay to 2".utf8)) + try remoteStream2.send(data: Data("another replay to 2".utf8)) try? await Task.sleep(for: .milliseconds(100)) let receivedData = serverHandler.events.value.compactMap { diff --git a/Utils/Sources/Utils/EventBus/StoreMiddleware.swift b/Utils/Sources/Utils/EventBus/StoreMiddleware.swift index fffe4e86..cfc0cbcd 100644 --- a/Utils/Sources/Utils/EventBus/StoreMiddleware.swift +++ b/Utils/Sources/Utils/EventBus/StoreMiddleware.swift @@ -12,7 +12,7 @@ public struct StoreMiddleware: MiddlewareProtocol { public func handle(_ event: T, next: @escaping MiddlewareHandler) async throws { logger.debug(">>> dispatching event: \(event)") let task = Task { try await next(event) } - storage.mutate { storage in + storage.write { storage in storage.append((event, task)) } try await task.value diff --git a/Utils/Sources/Utils/ReadWriteLock.swift b/Utils/Sources/Utils/ReadWriteLock.swift new file mode 100644 index 00000000..d7e1f5bf --- /dev/null +++ b/Utils/Sources/Utils/ReadWriteLock.swift @@ -0,0 +1,45 @@ +#if canImport(Glibc) + import Glibc +#elseif canImport(Darwin) + import Darwin +#endif + +public final class ReadWriteLock: @unchecked Sendable { + private var lock: pthread_rwlock_t = .init() + + public init() { + let result = pthread_rwlock_init(&lock, nil) + precondition(result == 0, "Failed to initialize read-write lock") + } + + deinit { + pthread_rwlock_destroy(&lock) + } + + private func readLock() { + let result = pthread_rwlock_rdlock(&lock) + precondition(result == 0, "Failed to acquire read lock") + } + + private func writeLock() { + let result = pthread_rwlock_wrlock(&lock) + precondition(result == 0, "Failed to acquire write lock") + } + + private func unlock() { + let result = pthread_rwlock_unlock(&lock) + precondition(result == 0, "Failed to release lock") + } + + public func withReadLock(_ closure: () throws -> T) rethrows -> T { + readLock() + defer { unlock() } + return try closure() + } + + public func withWriteLock(_ closure: () throws -> T) rethrows -> T { + writeLock() + defer { unlock() } + return try closure() + } +} diff --git a/Utils/Sources/Utils/ThreadSafeContainer.swift b/Utils/Sources/Utils/ThreadSafeContainer.swift index 43d37665..14952298 100644 --- a/Utils/Sources/Utils/ThreadSafeContainer.swift +++ b/Utils/Sources/Utils/ThreadSafeContainer.swift @@ -2,42 +2,28 @@ import Foundation public final class ThreadSafeContainer: @unchecked Sendable { private var storage: T - private let queue: DispatchQueue + private let lock: ReadWriteLock = .init() - public init(_ initialValue: T, label: String = "boka.threadsafecontainer") { + public init(_ initialValue: T) { storage = initialValue - queue = DispatchQueue(label: label, attributes: .concurrent) } public func read(_ action: (T) throws -> U) rethrows -> U { - try queue.sync { try action(self.storage) } + try lock.withReadLock { try action(self.storage) } } - public func write(_ action: @escaping @Sendable (inout T) -> Void) { - queue.async(flags: .barrier) { - action(&self.storage) - } - } - - public func mutate(_ action: @escaping (inout T) throws -> U) rethrows -> U { - try queue.sync(flags: .barrier) { + public func write(_ action: (inout T) throws -> Void) rethrows { + try lock.withWriteLock { try action(&self.storage) } } -} -extension ThreadSafeContainer { - public var value: T { - get { - read { $0 } - } - set { - mutate { $0 = newValue } + public func write(_ action: (inout T) throws -> U) rethrows -> U { + try lock.withWriteLock { + try action(&self.storage) } } -} -extension ThreadSafeContainer where T: Sendable { public var value: T { get { read { $0 } diff --git a/boka.xcodeproj/project.xcworkspace/xcshareddata/swiftpm/Package.resolved b/boka.xcodeproj/project.xcworkspace/xcshareddata/swiftpm/Package.resolved index 3db802d8..3b97f812 100644 --- a/boka.xcodeproj/project.xcworkspace/xcshareddata/swiftpm/Package.resolved +++ b/boka.xcodeproj/project.xcworkspace/xcshareddata/swiftpm/Package.resolved @@ -1,6 +1,15 @@ { - "originHash" : "47e2af6bfda504977277ddca98e4101bce9ec9e241cd666da6b97c6dade33a15", + "originHash" : "15c8990c2a71904ae0fb2311d87b37d53447f70bb359a55f0a30906bb5dcd56f", "pins" : [ + { + "identity" : "async-channels", + "kind" : "remoteSourceControl", + "location" : "https://github.com/gh123man/Async-Channels.git", + "state" : { + "revision" : "37d32cfc70f08b72a38a2c40f65338ee023afa45", + "version" : "1.0.1" + } + }, { "identity" : "async-http-client", "kind" : "remoteSourceControl",