Skip to content

Commit b97ab83

Browse files
committed
.await for async native functions outside of the VM
This avoids manually polling boxed futures within the __call__ implementation for NativeAsyncFunctions. This instead stashes the boxed future under `Context::async_call` so that we can `.await` for it's completion within `Context::run_async_with_budget`. This approach should integrate much better with async runtimes like Tokio, since it doesn't involve manually polling with a task::Context + Waker that aren't managed by the current async runtime. This also means the thread can block waiting for async IO without polling for native function completion in a busy loop. This still needs further iteration, but hopefully serves as a usable draft / proof of concept.
1 parent 3de68a7 commit b97ab83

File tree

4 files changed

+101
-159
lines changed

4 files changed

+101
-159
lines changed

core/engine/src/context/mod.rs

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ use intrinsics::Intrinsics;
1414
use temporal_rs::tzdb::FsTzdbProvider;
1515

1616
use crate::job::Job;
17+
use crate::native_function::AsyncCallState;
1718
use crate::vm::RuntimeLimits;
1819
use crate::{
1920
builtins,
@@ -129,6 +130,17 @@ pub struct Context {
129130
parser_identifier: u32,
130131

131132
data: HostDefined,
133+
134+
/// State of any boxed future that should be .awaited before continuing
135+
/// to execute VM instructions.
136+
///
137+
/// XXX: How do we make sure that any
138+
/// `AsyncCallState::Finished(Result<JsValue>)` can't be garbage collected?
139+
///
140+
/// XXX: there's maybe a better place for this, or better abstraction /
141+
/// generalization than this, but this hopefully works for a draft /
142+
/// proof-of-concept.
143+
pub(crate) async_call: AsyncCallState,
132144
}
133145

134146
impl std::fmt::Debug for Context {
@@ -1131,6 +1143,7 @@ impl ContextBuilder {
11311143
parser_identifier: 0,
11321144
can_block: self.can_block,
11331145
data: HostDefined::default(),
1146+
async_call: AsyncCallState::None,
11341147
};
11351148

11361149
builtins::set_default_global_bindings(&mut context)?;

core/engine/src/native_function/mod.rs

Lines changed: 76 additions & 156 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@ use std::pin::Pin;
99
use std::rc::Rc;
1010

1111
use boa_gc::{custom_trace, Finalize, Gc, Trace};
12-
use futures_lite::FutureExt as _;
1312

1413
use crate::job::NativeAsyncJob;
1514
use crate::value::JsVariant;
@@ -134,55 +133,32 @@ enum AsyncCallResult {
134133
Ready(JsResult<JsValue>),
135134
}
136135

137-
enum AsyncRunningState {
138-
None,
139-
Calling,
140-
Constructing { new_target: JsValue },
141-
}
142136
trait TraceableAsyncFunction: Trace {
143-
fn running(&self) -> AsyncRunningState;
144137
fn call_or_construct(
145138
&self,
146139
this: &JsValue,
147140
args: &[JsValue],
148141
new_target: &JsValue,
149142
context: &mut Context,
150143
) -> AsyncCallResult;
151-
fn poll(&self, context: &mut Context) -> AsyncCallResult {
152-
self.call_or_construct(&JsValue::undefined(), &[], &JsValue::undefined(), context)
153-
}
154144
}
155145

156-
#[derive(Finalize)]
157-
enum AsyncCallState {
146+
#[derive(Default, Finalize)]
147+
pub(crate) enum AsyncCallState {
148+
#[default]
158149
None,
159150
Calling {
160-
this: JsValue,
161-
args: Vec<JsValue>,
162-
f: Pin<Box<dyn Future<Output = JsResult<JsValue>>>>,
163-
},
164-
Constructing {
165-
new_target: JsValue,
166-
args: Vec<JsValue>,
167151
f: Pin<Box<dyn Future<Output = JsResult<JsValue>>>>,
168152
},
153+
Finished(JsResult<JsValue>),
169154
}
170155
unsafe impl Trace for AsyncCallState {
171156
custom_trace!(this, mark, {
172157
match this {
173-
AsyncCallState::None => {}
174-
AsyncCallState::Calling { this, args, f: _ } => {
175-
mark(this);
176-
mark(args);
177-
}
178-
AsyncCallState::Constructing {
179-
new_target,
180-
args,
181-
f: _,
182-
} => {
183-
mark(new_target);
184-
mark(args);
185-
}
158+
AsyncCallState::None
159+
| AsyncCallState::Calling { .. }
160+
| AsyncCallState::Finished(Err(_)) => {}
161+
AsyncCallState::Finished(Ok(value)) => mark(value),
186162
}
187163
});
188164
}
@@ -191,83 +167,36 @@ unsafe impl Trace for AsyncCallState {
191167
struct NativeAsyncFunction<T: Trace> {
192168
start: Rc<SpawnAsyncFunctionFn>,
193169
captures: T,
194-
state: RefCell<AsyncCallState>,
195170
}
196171
unsafe impl<T: Trace> Trace for NativeAsyncFunction<T> {
197172
custom_trace!(this, mark, {
198173
mark(&this.captures);
199-
mark(&*this.state.borrow());
200174
});
201175
}
202176

203177
impl<T> TraceableAsyncFunction for NativeAsyncFunction<T>
204178
where
205179
T: Trace,
206180
{
207-
fn running(&self) -> AsyncRunningState {
208-
match &*self.state.borrow() {
209-
AsyncCallState::None => AsyncRunningState::None,
210-
AsyncCallState::Calling { .. } => AsyncRunningState::Calling,
211-
AsyncCallState::Constructing { new_target, .. } => AsyncRunningState::Constructing {
212-
new_target: new_target.clone(),
213-
},
214-
}
215-
}
216-
217181
fn call_or_construct(
218182
&self,
219183
this: &JsValue,
220184
args: &[JsValue],
221185
new_target: &JsValue,
222186
context: &mut Context,
223187
) -> AsyncCallResult {
224-
let mut state = self.state.borrow_mut();
225-
match *state {
226-
AsyncCallState::None => {
227-
if new_target.is_undefined() {
228-
let f = (*self.start)(this, args, context);
229-
let f = match f {
230-
Ok(f) => f,
231-
Err(e) => return AsyncCallResult::Ready(Err(e)),
232-
};
233-
*state = AsyncCallState::Calling {
234-
this: this.clone(),
235-
args: args.to_vec(),
236-
f,
237-
};
238-
} else {
239-
let f = (*self.start)(new_target, args, context);
240-
let f = match f {
241-
Ok(f) => f,
242-
Err(e) => return AsyncCallResult::Ready(Err(e)),
243-
};
244-
*state = AsyncCallState::Constructing {
245-
new_target: new_target.clone(),
246-
args: args.to_vec(),
247-
f,
248-
};
249-
}
188+
let f = if new_target.is_undefined() {
189+
(*self.start)(this, args, context)
190+
} else {
191+
(*self.start)(new_target, args, context)
192+
};
193+
let f = match f {
194+
Ok(f) => f,
195+
Err(e) => return AsyncCallResult::Ready(Err(e)),
196+
};
197+
context.async_call = AsyncCallState::Calling { f };
250198

251-
AsyncCallResult::Pending
252-
}
253-
AsyncCallState::Calling { ref mut f, .. }
254-
| AsyncCallState::Constructing { ref mut f, .. } => {
255-
// FIXME: figure out how to work with an async Context / Waker from the application (e.g. from Tokio)
256-
let waker = std::task::Waker::noop();
257-
let mut context = std::task::Context::from_waker(waker);
258-
let result = f.poll(&mut context);
259-
match result {
260-
std::task::Poll::Pending => {
261-
//println!("Pending");
262-
AsyncCallResult::Pending
263-
}
264-
std::task::Poll::Ready(result) => {
265-
*state = AsyncCallState::None;
266-
AsyncCallResult::Ready(result)
267-
}
268-
}
269-
}
270-
}
199+
AsyncCallResult::Pending
271200
}
272201
}
273202

@@ -495,7 +424,6 @@ impl NativeFunction {
495424
let ptr = Gc::into_raw(Gc::new(NativeAsyncFunction {
496425
start: Rc::new(f),
497426
captures,
498-
state: RefCell::new(AsyncCallState::None),
499427
}));
500428

501429
// SAFETY: The pointer returned by `into_raw` is only used to coerce to a trait object,
@@ -608,46 +536,39 @@ impl NativeFunction {
608536
context: &mut Context,
609537
) -> AsyncCallResult {
610538
//println!("[NativeFunction] call, arg_count: {}", argument_count);
539+
let args = context
540+
.vm
541+
.stack
542+
.calling_convention_pop_arguments(argument_count);
543+
let func = context.vm.stack.pop();
544+
let this = context.vm.stack.pop();
545+
let this_ref = if is_constructor {
546+
&JsValue::undefined()
547+
} else {
548+
&this
549+
};
550+
611551
if let Inner::AsyncFn(ref f) = self.inner {
612-
match f.running() {
613-
AsyncRunningState::None => {
614-
let args = context
615-
.vm
616-
.stack
617-
.calling_convention_pop_arguments(argument_count);
618-
let func = context.vm.stack.pop();
619-
let this = context.vm.stack.pop();
620-
let this_ref = if is_constructor {
621-
&JsValue::undefined()
622-
} else {
623-
&this
624-
};
625-
let result =
626-
f.call_or_construct(this_ref, &args, &JsValue::undefined(), context);
627-
if matches!(result, AsyncCallResult::Pending) {
628-
context.vm.stack.push(this);
629-
context.vm.stack.push(func);
630-
context.vm.stack.calling_convention_push_arguments(&args);
631-
}
632-
result
552+
let result = match &context.async_call {
553+
AsyncCallState::None => {
554+
f.call_or_construct(this_ref, &args, &JsValue::undefined(), context)
633555
}
634-
AsyncRunningState::Calling => f.poll(context),
635-
AsyncRunningState::Constructing { new_target: _ } => {
636-
unreachable!()
556+
AsyncCallState::Calling { .. } => AsyncCallResult::Pending,
557+
AsyncCallState::Finished(_) => {
558+
let AsyncCallState::Finished(result) = std::mem::take(&mut context.async_call)
559+
else {
560+
unreachable!()
561+
};
562+
AsyncCallResult::Ready(result)
637563
}
564+
};
565+
if matches!(result, AsyncCallResult::Pending) {
566+
context.vm.stack.push(this);
567+
context.vm.stack.push(func);
568+
context.vm.stack.calling_convention_push_arguments(&args);
638569
}
570+
result
639571
} else {
640-
let args = context
641-
.vm
642-
.stack
643-
.calling_convention_pop_arguments(argument_count);
644-
let _func = context.vm.stack.pop();
645-
let this = context.vm.stack.pop();
646-
let this = if is_constructor {
647-
JsValue::undefined()
648-
} else {
649-
this
650-
};
651572
AsyncCallResult::Ready(match self.inner {
652573
Inner::PointerFn(f) => f(&this, &args, context),
653574
Inner::Closure(ref c) => c.call(&this, &args, context),
@@ -661,43 +582,42 @@ impl NativeFunction {
661582
argument_count: usize,
662583
context: &mut Context,
663584
) -> (JsValue, AsyncCallResult) {
664-
if let Inner::AsyncFn(ref f) = self.inner {
665-
match f.running() {
666-
AsyncRunningState::None => {
667-
let new_target = context.vm.stack.pop();
668-
let args = context
669-
.vm
670-
.stack
671-
.calling_convention_pop_arguments(argument_count);
672-
let _func = context.vm.stack.pop();
673-
let _this = context.vm.stack.pop();
674-
(
675-
new_target.clone(),
676-
f.call_or_construct(&JsValue::undefined(), &args, &new_target, context),
677-
)
678-
}
679-
AsyncRunningState::Constructing { new_target } => {
680-
(new_target.clone(), f.poll(context))
585+
let new_target = context.vm.stack.pop();
586+
let args = context
587+
.vm
588+
.stack
589+
.calling_convention_pop_arguments(argument_count);
590+
let func = context.vm.stack.pop();
591+
let this = context.vm.stack.pop();
592+
let result = if let Inner::AsyncFn(ref f) = self.inner {
593+
let result = match &context.async_call {
594+
AsyncCallState::None => {
595+
f.call_or_construct(&JsValue::undefined(), &args, &new_target, context)
681596
}
682-
AsyncRunningState::Calling => {
683-
unreachable!()
597+
AsyncCallState::Calling { .. } => AsyncCallResult::Pending,
598+
AsyncCallState::Finished(_) => {
599+
let AsyncCallState::Finished(result) = std::mem::take(&mut context.async_call)
600+
else {
601+
unreachable!()
602+
};
603+
AsyncCallResult::Ready(result)
684604
}
605+
};
606+
if matches!(result, AsyncCallResult::Pending) {
607+
context.vm.stack.push(this);
608+
context.vm.stack.push(func);
609+
context.vm.stack.calling_convention_push_arguments(&args);
610+
context.vm.stack.push(new_target.clone());
685611
}
612+
result
686613
} else {
687-
let new_target = context.vm.stack.pop();
688-
let args = context
689-
.vm
690-
.stack
691-
.calling_convention_pop_arguments(argument_count);
692-
let _func = context.vm.stack.pop();
693-
let _this = context.vm.stack.pop();
694-
let result = AsyncCallResult::Ready(match self.inner {
614+
AsyncCallResult::Ready(match self.inner {
695615
Inner::PointerFn(f) => f(&new_target, &args, context),
696616
Inner::Closure(ref c) => c.call(&new_target, &args, context),
697617
Inner::AsyncFn(_) => unreachable!(),
698-
});
699-
(new_target.clone(), result)
700-
}
618+
})
619+
};
620+
(new_target.clone(), result)
701621
}
702622

703623
/// Converts this `NativeFunction` into a `JsFunction` without setting its name or length.

core/engine/src/vm/mod.rs

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
use crate::{
88
builtins::promise::{PromiseCapability, ResolvingFunctions},
99
environments::EnvironmentStack,
10+
native_function::AsyncCallState,
1011
object::JsFunction,
1112
realm::Realm,
1213
script::Script,
@@ -819,7 +820,15 @@ impl Context {
819820

820821
if runtime_budget == 0 {
821822
runtime_budget = budget;
822-
yield_now().await;
823+
match &mut self.async_call {
824+
AsyncCallState::None | AsyncCallState::Finished(_) => {
825+
yield_now().await;
826+
}
827+
AsyncCallState::Calling { f } => {
828+
let result = f.await;
829+
self.async_call = AsyncCallState::Finished(result);
830+
}
831+
}
823832
}
824833
}
825834

@@ -846,6 +855,7 @@ impl Context {
846855
ControlFlow::Continue(_) => {}
847856
ControlFlow::Break(value) => return value,
848857
}
858+
debug_assert!(matches!(self.async_call, AsyncCallState::None));
849859
}
850860

851861
CompletionRecord::Throw(JsError::from_native(JsNativeError::error()))

core/engine/src/vm/opcode/call/mod.rs

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -197,8 +197,7 @@ impl Call {
197197
.__call__(argument_count.into())
198198
.async_resolve(context)?
199199
{
200-
ResolvedCallValue::Ready => Ok(OpStatus::Finished),
201-
ResolvedCallValue::Complete => Ok(OpStatus::Finished),
200+
ResolvedCallValue::Ready | ResolvedCallValue::Complete => Ok(OpStatus::Finished),
202201
ResolvedCallValue::Pending => {
203202
//println!("Pending call");
204203
Ok(OpStatus::Pending)

0 commit comments

Comments
 (0)