Skip to content

Commit 27ac22d

Browse files
committed
Move cancellation to LambdaRuntimeClient.nextInvocation
1 parent 467d012 commit 27ac22d

File tree

2 files changed

+74
-57
lines changed

2 files changed

+74
-57
lines changed

Sources/AWSLambdaRuntimeCore/LambdaRuntimeClient.swift

Lines changed: 30 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -145,22 +145,28 @@ final actor LambdaRuntimeClient: LambdaRuntimeClientProtocol {
145145
}
146146

147147
func nextInvocation() async throws -> (Invocation, Writer) {
148-
switch self.lambdaState {
149-
case .idle:
150-
self.lambdaState = .waitingForNextInvocation
151-
let handler = try await self.makeOrGetConnection()
152-
let invocation = try await handler.nextInvocation()
153-
guard case .waitingForNextInvocation = self.lambdaState else {
148+
try await withTaskCancellationHandler {
149+
switch self.lambdaState {
150+
case .idle:
151+
self.lambdaState = .waitingForNextInvocation
152+
let handler = try await self.makeOrGetConnection()
153+
let invocation = try await handler.nextInvocation()
154+
guard case .waitingForNextInvocation = self.lambdaState else {
155+
fatalError("Invalid state: \(self.lambdaState)")
156+
}
157+
self.lambdaState = .waitingForResponse(requestID: invocation.metadata.requestID)
158+
return (invocation, Writer(runtimeClient: self))
159+
160+
case .waitingForNextInvocation,
161+
.waitingForResponse,
162+
.sendingResponse,
163+
.sentResponse:
154164
fatalError("Invalid state: \(self.lambdaState)")
155165
}
156-
self.lambdaState = .waitingForResponse(requestID: invocation.metadata.requestID)
157-
return (invocation, Writer(runtimeClient: self))
158-
159-
case .waitingForNextInvocation,
160-
.waitingForResponse,
161-
.sendingResponse,
162-
.sentResponse:
163-
fatalError("Invalid state: \(self.lambdaState)")
166+
} onCancel: {
167+
Task {
168+
await self.close()
169+
}
164170
}
165171
}
166172

@@ -469,37 +475,10 @@ private final class LambdaChannelHandler<Delegate: LambdaChannelHandlerDelegate>
469475
func nextInvocation(isolation: isolated (any Actor)? = #isolation) async throws -> Invocation {
470476
switch self.state {
471477
case .connected(let context, .idle):
472-
return try await withTaskCancellationHandler {
473-
try Task.checkCancellation()
474-
return try await withCheckedThrowingContinuation {
475-
(continuation: CheckedContinuation<Invocation, any Error>) in
476-
self.state = .connected(context, .waitingForNextInvocation(continuation))
477-
478-
let unsafeContext = NIOLoopBound(context, eventLoop: context.eventLoop)
479-
context.eventLoop.execute { [nextInvocationPath, defaultHeaders] in
480-
// Send next request. The function `sendNextRequest` requires `self` which is not
481-
// Sendable so just inlined the code instead
482-
let httpRequest = HTTPRequestHead(
483-
version: .http1_1,
484-
method: .GET,
485-
uri: nextInvocationPath,
486-
headers: defaultHeaders
487-
)
488-
let context = unsafeContext.value
489-
context.write(Self.wrapOutboundOut(.head(httpRequest)), promise: nil)
490-
context.write(Self.wrapOutboundOut(.end(nil)), promise: nil)
491-
context.flush()
492-
}
493-
}
494-
} onCancel: {
495-
switch self.state {
496-
case .connected(_, .waitingForNextInvocation(let continuation)):
497-
continuation.resume(throwing: CancellationError())
498-
case .connected(_, .idle):
499-
break
500-
default:
501-
fatalError("Invalid state: \(self.state)")
502-
}
478+
return try await withCheckedThrowingContinuation {
479+
(continuation: CheckedContinuation<Invocation, any Error>) in
480+
self.state = .connected(context, .waitingForNextInvocation(continuation))
481+
self.sendNextRequest(context: context)
503482
}
504483

505484
case .connected(_, .sendingResponse),
@@ -846,6 +825,12 @@ extension LambdaChannelHandler: ChannelInboundHandler {
846825

847826
func channelInactive(context: ChannelHandlerContext) {
848827
// fail any pending responses with last error or assume peer disconnected
828+
switch self.state {
829+
case .connected(_, .waitingForNextInvocation(let continuation)):
830+
continuation.resume(throwing: self.lastError ?? ChannelError.ioOnClosedChannel)
831+
default:
832+
break
833+
}
849834

850835
// we don't need to forward channelInactive to the delegate, as the delegate observes the
851836
// closeFuture

Tests/AWSLambdaRuntimeCoreTests/LambdaRuntimeClientTests.swift

Lines changed: 44 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -89,20 +89,52 @@ struct LambdaRuntimeClientTests {
8989

9090
@Test
9191
func testCancellation() async throws {
92-
try await LambdaRuntimeClient.withRuntimeClient(
93-
configuration: .init(ip: "127.0.0.1", port: 7000),
94-
eventLoop: NIOSingletons.posixEventLoopGroup.next(),
95-
logger: self.logger
96-
) { runtimeClient in
97-
try await withThrowingTaskGroup(of: Void.self) { group in
98-
group.addTask {
99-
while true {
100-
_ = try await runtimeClient.nextInvocation()
92+
struct HappyBehavior: LambdaServerBehavior {
93+
let requestId = UUID().uuidString
94+
let event = "hello"
95+
96+
func getInvocation() -> GetInvocationResult {
97+
.success((self.requestId, self.event))
98+
}
99+
100+
func processResponse(requestId: String, response: String?) -> Result<Void, ProcessResponseError> {
101+
#expect(self.requestId == requestId)
102+
#expect(self.event == response)
103+
return .success(())
104+
}
105+
106+
func processError(requestId: String, error: ErrorResponse) -> Result<Void, ProcessErrorError> {
107+
Issue.record("should not report error")
108+
return .failure(.internalServerError)
109+
}
110+
111+
func processInitError(error: ErrorResponse) -> Result<Void, ProcessErrorError> {
112+
Issue.record("should not report init error")
113+
return .failure(.internalServerError)
114+
}
115+
}
116+
117+
try await withMockServer(behaviour: HappyBehavior()) { port in
118+
try await LambdaRuntimeClient.withRuntimeClient(
119+
configuration: .init(ip: "127.0.0.1", port: port),
120+
eventLoop: NIOSingletons.posixEventLoopGroup.next(),
121+
logger: self.logger
122+
) { runtimeClient in
123+
try await withThrowingTaskGroup(of: Void.self) { group in
124+
group.addTask {
125+
while true {
126+
print("Waiting")
127+
let (_, writer) = try await runtimeClient.nextInvocation()
128+
try await Task {
129+
try await writer.write(ByteBuffer(string: "hello"))
130+
try await writer.finish()
131+
}.value
132+
}
101133
}
134+
// wait a small amount to ensure we are waiting for continuation
135+
try await Task.sleep(for: .milliseconds(100))
136+
group.cancelAll()
102137
}
103-
// wait a small amount to ensure we are waiting for continuation
104-
try await Task.sleep(for: .milliseconds(100))
105-
group.cancelAll()
106138
}
107139
}
108140
}

0 commit comments

Comments
 (0)