Skip to content

Commit

Permalink
try to use public key instead of network address (#204)
Browse files Browse the repository at this point in the history
* try to use public key instead of network address

* connection status and test fix
  • Loading branch information
xlc authored Oct 29, 2024
1 parent 61defc9 commit a6bb5d3
Show file tree
Hide file tree
Showing 12 changed files with 393 additions and 127 deletions.
1 change: 1 addition & 0 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ jobs:
test:
name: Build and Test
runs-on: [self-hosted, linux]
timeout-minutes: 30
steps:
- name: Checkout Code
uses: actions/checkout@v4
Expand Down
5 changes: 3 additions & 2 deletions Networking/Sources/MsQuicSwift/QuicConnection.swift
Original file line number Diff line number Diff line change
Expand Up @@ -226,13 +226,14 @@ private class ConnectionHandle {

case QUIC_CONNECTION_EVENT_SHUTDOWN_INITIATED_BY_TRANSPORT:
let evtData = event.pointee.SHUTDOWN_INITIATED_BY_TRANSPORT
if evtData.Status == QuicStatusCode.connectionIdle.rawValue {
let status = QuicStatus(rawValue: evtData.Status)
if status == .code(.connectionIdle) {
logger.trace("Successfully shut down on idle.")
if let connection {
connection.handler.shutdownInitiated(connection, reason: .idle)
}
} else {
logger.debug("Shut down by transport. Status: \(evtData.Status) Error: \(evtData.ErrorCode)")
logger.debug("Shut down by transport. Status: \(status) Error: \(evtData.ErrorCode)")
if let connection {
connection.handler.shutdownInitiated(
connection,
Expand Down
93 changes: 93 additions & 0 deletions Networking/Sources/Networking/Connection.swift
Original file line number Diff line number Diff line change
Expand Up @@ -10,22 +10,47 @@ public protocol ConnectionInfoProtocol {
var id: UniqueId { get }
var role: PeerRole { get }
var remoteAddress: NetAddr { get }
var publicKey: Data? { get }
}

enum ConnectionError: Error {
case receiveFailed
case invalidLength
case unexpectedState
case closed
}

enum ConnectionState {
case connecting(continuations: [CheckedContinuation<Void, Error>])
case connected(publicKey: Data)
case closed
}

public final class Connection<Handler: StreamHandler>: Sendable, ConnectionInfoProtocol {
let connection: QuicConnection
let impl: PeerImpl<Handler>

public let role: PeerRole
public let remoteAddress: NetAddr

let presistentStreams: ThreadSafeContainer<
[Handler.PresistentHandler.StreamKind: Stream<Handler>]
> = .init([:])
let initiatedByLocal: Bool
private let state: ThreadSafeContainer<ConnectionState> = .init(.connecting(continuations: []))

public var publicKey: Data? {
state.read {
switch $0 {
case .connecting:
nil
case let .connected(publicKey):
publicKey
case .closed:
nil
}
}
}

public var id: UniqueId {
connection.id
Expand All @@ -39,11 +64,79 @@ public final class Connection<Handler: StreamHandler>: Sendable, ConnectionInfoP
self.initiatedByLocal = initiatedByLocal
}

func opened(publicKey: Data) throws {
try state.write { state in
if case let .connecting(continuations) = state {
for continuation in continuations {
continuation.resume()
}
state = .connected(publicKey: publicKey)
} else {
throw ConnectionError.unexpectedState
}
}
}

func closed() {
state.write { state in
if case let .connecting(continuations) = state {
for continuation in continuations {
continuation.resume(throwing: ConnectionError.closed)
}
state = .closed
}
state = .closed
}
}

public var isClosed: Bool {
state.read {
switch $0 {
case .connecting:
false
case .connected:
false
case .closed:
true
}
}
}

public func ready() async throws {
let isReady = state.read {
switch $0 {
case .connecting:
false
case .connected:
true
case .closed:
true
}
}

if isReady {
return
}
try await withCheckedThrowingContinuation { continuation in
state.write { state in
if case var .connecting(continuations) = state {
continuations.append(continuation)
state = .connecting(continuations: continuations)
} else {
continuation.resume()
}
}
}
}

public func close(abort: Bool = false) {
try? connection.shutdown(errorCode: abort ? 1 : 0) // TODO: define some error code
}

public func request(_ request: Handler.EphemeralHandler.Request) async throws -> Data {
guard !isClosed else {
throw ConnectionError.closed
}
logger.trace("sending request", metadata: ["kind": "\(request.kind)"])
let data = try request.encode()
let kind = request.kind
Expand Down
75 changes: 64 additions & 11 deletions Networking/Sources/Networking/Peer.swift
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,8 @@ public final class Peer<Handler: StreamHandler>: Sendable {
impl.logger
}

public let publicKey: Data

public init(options: PeerOptions<Handler>) throws {
let logger = Logger(label: "Peer".uniqueId)

Expand All @@ -81,11 +83,14 @@ public final class Peer<Handler: StreamHandler>: Sendable {
registration: registration, pkcs12: pkcs12, alpns: [clientAlpn], client: true, settings: options.clientSettings
)

publicKey = options.secretKey.publicKey.data.data

impl = PeerImpl(
logger: logger,
role: options.role,
settings: options.peerSettings,
alpns: alpns,
publicKey: publicKey,
clientConfiguration: clientConfiguration,
presistentStreamHandler: options.presistentStreamHandler,
ephemeralStreamHandler: options.ephemeralStreamHandler
Expand All @@ -98,6 +103,12 @@ public final class Peer<Handler: StreamHandler>: Sendable {
listenAddress: options.listenAddress,
alpns: allAlpns
)

logger.debug("Peer initialized", metadata: [
"listenAddress": "\(options.listenAddress)",
"role": "\(options.role)",
"publicKey": "\(options.secretKey.publicKey.data.toHexString())",
])
}

public func listenAddress() throws -> NetAddr {
Expand All @@ -107,11 +118,10 @@ public final class Peer<Handler: StreamHandler>: Sendable {
// TODO: see if we can remove the role parameter
public func connect(to address: NetAddr, role: PeerRole) throws -> Connection<Handler> {
let conn = impl.connections.read { connections in
connections.byAddr[address]?.0
connections.byAddr[address]
}
return try conn ?? impl.connections.write { connections in
let curr = connections.byAddr[address]?.0
if let curr {
if let curr = connections.byAddr[address] {
return curr
}

Expand All @@ -130,12 +140,18 @@ public final class Peer<Handler: StreamHandler>: Sendable {
remoteAddress: address,
initiatedByLocal: true
)
connections.byAddr[address] = (conn, role)
connections.byAddr[address] = conn
connections.byId[conn.id] = conn
return conn
}
}

public func getConnection(publicKey: Data) -> Connection<Handler>? {
impl.connections.read { connections in
connections.byPublicKey[publicKey]
}
}

public func broadcast(kind: Handler.PresistentHandler.StreamKind, message: Handler.PresistentHandler.Message) {
let connections = impl.connections.read { connections in
connections.byId.values
Expand Down Expand Up @@ -170,15 +186,17 @@ public final class Peer<Handler: StreamHandler>: Sendable {

final class PeerImpl<Handler: StreamHandler>: Sendable {
struct ConnectionStorage {
var byAddr: [NetAddr: (Connection<Handler>, PeerRole)] = [:]
var byAddr: [NetAddr: Connection<Handler>] = [:]
var byId: [UniqueId: Connection<Handler>] = [:]
var byPublicKey: [Data: Connection<Handler>] = [:]
}

fileprivate let logger: Logger
fileprivate let role: PeerRole
fileprivate let settings: PeerSettings
fileprivate let alpns: [PeerRole: Data]
fileprivate let alpnLookup: [Data: PeerRole]
fileprivate let publicKey: Data

fileprivate let clientConfiguration: QuicConfiguration

Expand All @@ -193,6 +211,7 @@ final class PeerImpl<Handler: StreamHandler>: Sendable {
role: PeerRole,
settings: PeerSettings,
alpns: [PeerRole: Data],
publicKey: Data,
clientConfiguration: QuicConfiguration,
presistentStreamHandler: Handler.PresistentHandler,
ephemeralStreamHandler: Handler.EphemeralHandler
Expand All @@ -201,6 +220,7 @@ final class PeerImpl<Handler: StreamHandler>: Sendable {
self.role = role
self.settings = settings
self.alpns = alpns
self.publicKey = publicKey
self.clientConfiguration = clientConfiguration
self.presistentStreamHandler = presistentStreamHandler
self.ephemeralStreamHandler = ephemeralStreamHandler
Expand All @@ -215,7 +235,7 @@ final class PeerImpl<Handler: StreamHandler>: Sendable {
func addConnection(_ connection: QuicConnection, addr: NetAddr, role: PeerRole) -> Bool {
connections.write { connections in
if role == .builder {
let currentCount = connections.byAddr.values.filter { $0.1 == role }.count
let currentCount = connections.byAddr.values.filter { $0.role == role }.count
if currentCount >= self.settings.maxBuilderConnections {
self.logger.warning("max builder connections reached")
// TODO: consider connection rotation strategy
Expand All @@ -233,7 +253,7 @@ final class PeerImpl<Handler: StreamHandler>: Sendable {
remoteAddress: addr,
initiatedByLocal: false
)
connections.byAddr[addr] = (conn, role)
connections.byAddr[addr] = conn
connections.byId[connection.id] = conn
return true
}
Expand Down Expand Up @@ -279,6 +299,13 @@ private struct PeerEventHandler<Handler: StreamHandler>: QuicEventHandler {
guard let certificate else {
return .code(.requiredCert)
}
let conn = impl.connections.read { connections in
connections.byId[connection.id]
}
guard let conn else {
logger.warning("Trying to open connection but connection is gone?", metadata: ["connectionId": "\(connection.id)"])
return .code(.connectionRefused)
}
do {
let (publicKey, alternativeName) = try parseCertificate(data: certificate, type: .x509)
logger.trace("Certificate parsed", metadata: [
Expand All @@ -289,16 +316,38 @@ private struct PeerEventHandler<Handler: StreamHandler>: QuicEventHandler {
if alternativeName != generateSubjectAlternativeName(pubkey: publicKey) {
return .code(.badCert)
}
if impl.role == PeerRole.validator {
// TODO: verify if it is current or next validator

if publicKey == impl.publicKey {
// self connection
logger.trace("self connection rejected", metadata: [
"connectionId": "\(connection.id)",
"publicKey": "\(publicKey.toHexString())",
])
return .code(.connectionRefused)
}

// TODO: verify if it is current or next validator

return try impl.connections.write { connections in
if connections.byPublicKey.keys.contains(publicKey) {
// duplicated connection
logger.debug("duplicated connection rejected", metadata: [
"connectionId": "\(connection.id)",
"publicKey": "\(publicKey.toHexString())",
])
// TODO: write a test for this
return .code(.connectionRefused)
}
connections.byPublicKey[publicKey] = conn
try conn.opened(publicKey: publicKey)
return .code(.success)
}
} catch {
logger.warning("Failed to parse certificate", metadata: [
"connectionId": "\(connection.id)",
"error": "\(error)"])
return .code(.badCert)
}
return .code(.success)
}

func connected(_ connection: QuicConnection) {
Expand Down Expand Up @@ -330,8 +379,12 @@ private struct PeerEventHandler<Handler: StreamHandler>: QuicEventHandler {
logger.trace("connection shutdown complete", metadata: ["connectionId": "\(connection.id)"])
impl.connections.write { connections in
if let conn = connections.byId[connection.id] {
conn.closed()
connections.byId.removeValue(forKey: connection.id)
connections.byAddr[conn.remoteAddress] = nil
connections.byAddr.removeValue(forKey: conn.remoteAddress)
if let publicKey = conn.publicKey {
connections.byPublicKey.removeValue(forKey: publicKey)
}
}
}
}
Expand Down
2 changes: 1 addition & 1 deletion Networking/Tests/MsQuicSwiftTests/QuicListenerTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ struct QuicListenerTests {
}.first!
try remoteStream2.send(data: Data("another replay to 2".utf8))

try? await Task.sleep(for: .milliseconds(100))
try? await Task.sleep(for: .milliseconds(200))
let receivedData = serverHandler.events.value.compactMap {
switch $0 {
case let .dataReceived(_, data):
Expand Down
Loading

0 comments on commit a6bb5d3

Please sign in to comment.