Skip to content

[Custom Descriptors] Interpret descriptor casts #7677

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jun 24, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 37 additions & 8 deletions src/wasm-interpreter.h
Original file line number Diff line number Diff line change
Expand Up @@ -1651,6 +1651,36 @@ class ExpressionRunner : public OverriddenVisitor<SubType, Flow> {
return typename Cast::Failure{val};
}
}
template<typename T> Cast doDescCast(T* curr) {
Flow ref = self()->visit(curr->ref);
if (ref.breaking()) {
return typename Cast::Breaking{ref};
}
Flow desc = self()->visit(curr->desc);
if (desc.breaking()) {
return typename Cast::Breaking{ref};
}
auto expected = desc.getSingleValue().getGCData();
if (!expected) {
trap("null descriptor");
}
Literal val = ref.getSingleValue();
auto data = val.getGCData();
if (!data) {
// Check whether null is allowed.
if (curr->getCastType().isNullable()) {
return typename Cast::Success{val};
} else {
return typename Cast::Failure{val};
}
}
// The cast succeeds if we have the expected descriptor.
if (data->desc.getGCData() == expected) {
return typename Cast::Success{val};
} else {
return typename Cast::Failure{val};
}
}

Flow visitRefTest(RefTest* curr) {
NOTE_ENTER("RefTest");
Expand All @@ -1663,7 +1693,7 @@ class ExpressionRunner : public OverriddenVisitor<SubType, Flow> {
}
Flow visitRefCast(RefCast* curr) {
NOTE_ENTER("RefCast");
auto cast = doCast(curr);
auto cast = curr->desc ? doDescCast(curr) : doCast(curr);
if (auto* breaking = cast.getBreaking()) {
return *breaking;
} else if (auto* result = cast.getSuccess()) {
Expand All @@ -1690,29 +1720,28 @@ class ExpressionRunner : public OverriddenVisitor<SubType, Flow> {
// BrOnCast* uses the casting infrastructure, so handle them first.
switch (curr->op) {
case BrOnCast:
case BrOnCastFail: {
auto cast = doCast(curr);
case BrOnCastFail:
case BrOnCastDesc:
case BrOnCastDescFail: {
auto cast = curr->desc ? doDescCast(curr) : doCast(curr);
if (auto* breaking = cast.getBreaking()) {
return *breaking;
} else if (auto* original = cast.getFailure()) {
if (curr->op == BrOnCast) {
if (curr->op == BrOnCast || curr->op == BrOnCastDesc) {
return *original;
} else {
return Flow(curr->name, *original);
}
} else {
auto* result = cast.getSuccess();
assert(result);
if (curr->op == BrOnCast) {
if (curr->op == BrOnCast || curr->op == BrOnCastDesc) {
return Flow(curr->name, *result);
} else {
return *result;
}
}
}
case BrOnCastDesc:
case BrOnCastDescFail:
WASM_UNREACHABLE("TODO");
case BrOnNull:
case BrOnNonNull: {
// Otherwise we are just checking for null.
Expand Down
164 changes: 164 additions & 0 deletions test/spec/br_on_cast_desc.wast
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,13 @@
(type $sub.desc (sub $super.desc (describes $sub (struct))))
)

(global $super.desc1 (ref (exact $super.desc)) (struct.new $super.desc))
(global $super.desc2 (ref (exact $super.desc)) (struct.new $super.desc))
(global $super1 (ref $super) (struct.new $super (global.get $super.desc1)))

(global $sub.desc (ref (exact $sub.desc)) (struct.new $sub.desc))
(global $sub (ref $sub) (struct.new $sub (global.get $sub.desc)))

;; br_on_cast_desc

(func $br_on_cast_desc-unreachable (result anyref)
Expand All @@ -33,6 +40,78 @@
)
(unreachable)
)
(func (export "cast-success") (result i32)
(drop
(block $l (result anyref)
(br_on_cast_desc $l anyref (ref null $super)
(global.get $super1)
(global.get $super.desc1)
)
(return (i32.const 0))
)
)
(i32.const 1)
)
(func (export "cast-success-supertype") (result i32)
(drop
(block $l (result anyref)
(br_on_cast_desc $l anyref (ref null $super)
(global.get $sub)
(global.get $sub.desc)
)
(return (i32.const 0))
)
)
(i32.const 1)
)
(func (export "cast-success-null") (result i32)
(drop
(block $l (result anyref)
(br_on_cast_desc $l anyref (ref null $super)
(ref.null none)
(global.get $super.desc1)
)
(return (i32.const 0))
)
)
(i32.const 1)
)
(func (export "cast-fail-null") (result i32)
(drop
(block $l (result anyref)
(br_on_cast_desc $l anyref (ref $super)
(ref.null none)
(global.get $super.desc1)
)
(return (i32.const 0))
)
)
(i32.const 1)
)
(func (export "cast-fail-wrong-desc") (result i32)
(drop
(block $l (result anyref)
(br_on_cast_desc $l anyref (ref null $super)
(global.get $super1)
(global.get $super.desc2)
)
(return (i32.const 0))
)
)
(i32.const 1)
)
(func (export "cast-fail-null-desc") (result i32)
(drop
(block $l (result anyref)
(br_on_cast_desc $l anyref (ref null $super)
(global.get $super1)
(ref.null none)
)
(return (i32.const 0))
)
)
(i32.const 1)
)

;; br_on_cast_desc_fail

Expand Down Expand Up @@ -61,8 +140,93 @@
)
)
)
(func (export "fail-cast-success") (result i32)
(drop
(block $l (result anyref)
(br_on_cast_desc_fail $l anyref (ref null $super)
(global.get $super1)
(global.get $super.desc1)
)
(return (i32.const 0))
)
)
(i32.const 1)
)
(func (export "fail-cast-success-supertype") (result i32)
(drop
(block $l (result anyref)
(br_on_cast_desc_fail $l anyref (ref null $super)
(global.get $sub)
(global.get $sub.desc)
)
(return (i32.const 0))
)
)
(i32.const 1)
)
(func (export "fail-cast-success-null") (result i32)
(drop
(block $l (result anyref)
(br_on_cast_desc_fail $l anyref (ref null $super)
(ref.null none)
(global.get $super.desc1)
)
(return (i32.const 0))
)
)
(i32.const 1)
)
(func (export "fail-cast-fail-null") (result i32)
(drop
(block $l (result anyref)
(br_on_cast_desc_fail $l anyref (ref $super)
(ref.null none)
(global.get $super.desc1)
)
(return (i32.const 0))
)
)
(i32.const 1)
)
(func (export "fail-cast-fail-wrong-desc") (result i32)
(drop
(block $l (result anyref)
(br_on_cast_desc_fail $l anyref (ref null $super)
(global.get $super1)
(global.get $super.desc2)
)
(return (i32.const 0))
)
)
(i32.const 1)
)
(func (export "fail-cast-fail-null-desc") (result i32)
(drop
(block $l (result anyref)
(br_on_cast_desc_fail $l anyref (ref null $super)
(global.get $super1)
(ref.null none)
)
(return (i32.const 0))
)
)
(i32.const 1)
)
)

(assert_return (invoke "cast-success") (i32.const 1))
(assert_return (invoke "cast-success-supertype") (i32.const 1))
(assert_return (invoke "cast-success-null") (i32.const 1))
(assert_return (invoke "cast-fail-null") (i32.const 0))
(assert_return (invoke "cast-fail-wrong-desc") (i32.const 0))
(assert_trap (invoke "cast-fail-null-desc") "null descriptor")
(assert_return (invoke "fail-cast-success") (i32.const 0))
(assert_return (invoke "fail-cast-success-supertype") (i32.const 0))
(assert_return (invoke "fail-cast-success-null") (i32.const 0))
(assert_return (invoke "fail-cast-fail-null") (i32.const 1))
(assert_return (invoke "fail-cast-fail-wrong-desc") (i32.const 1))
(assert_trap (invoke "fail-cast-fail-null-desc") "null descriptor")

(assert_malformed
;; Input type must be a reference.
(module quote "(module (rec (type $struct (descriptor $desc (struct))) (type $desc (describes $struct (struct)))) (func (result anyref) (unreachable) (br_on_cast_desc 0 i32 (ref null $struct))))")
Expand Down
98 changes: 98 additions & 0 deletions test/spec/ref.cast_desc.wast
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,13 @@
(type $sub.desc (sub $super.desc (describes $sub (struct))))
)

(global $super.desc1 (ref (exact $super.desc)) (struct.new $super.desc))
(global $super.desc2 (ref (exact $super.desc)) (struct.new $super.desc))
(global $super1 (ref $super) (struct.new $super (global.get $super.desc1)))

(global $sub.desc (ref (exact $sub.desc)) (struct.new $sub.desc))
(global $sub (ref $sub) (struct.new $sub (global.get $sub.desc)))

;; ref.cast_desc (ref null ht)

(func $ref.cast_desc-null-unreachable (result anyref)
Expand All @@ -32,6 +39,46 @@
(local.get $super.desc)
)
)
(func (export "cast-success")
(drop
(ref.cast_desc (ref null $super)
(global.get $super1)
(global.get $super.desc1)
)
)
)
(func (export "cast-success-supertype")
(drop
(ref.cast_desc (ref null $super)
(global.get $sub)
(global.get $sub.desc)
)
)
)
(func (export "cast-success-null")
(drop
(ref.cast_desc (ref null $super)
(ref.null none)
(global.get $super.desc1)
)
)
)
(func (export "cast-fail-wrong-desc")
(drop
(ref.cast_desc (ref null $super)
(global.get $super1)
(global.get $super.desc2)
)
)
)
(func (export "cast-fail-null-desc")
(drop
(ref.cast_desc (ref null $super)
(global.get $super1)
(ref.null none)
)
)
)

;; ref.cast_desc (ref ht)

Expand All @@ -58,8 +105,59 @@
(local.get $super.desc)
)
)
(func (export "cast-nn-success")
(drop
(ref.cast_desc (ref $super)
(global.get $super1)
(global.get $super.desc1)
)
)
)
(func (export "cast-nn-success-supertype")
(drop
(ref.cast_desc (ref $super)
(global.get $sub)
(global.get $sub.desc)
)
)
)
(func (export "cast-nn-fail-null")
(drop
(ref.cast_desc (ref $super)
(ref.null none)
(global.get $super.desc1)
)
)
)
(func (export "cast-nn-fail-wrong-desc")
(drop
(ref.cast_desc (ref $super)
(global.get $super1)
(global.get $super.desc2)
)
)
)
(func (export "cast-nn-fail-null-desc")
(drop
(ref.cast_desc (ref $super)
(global.get $super1)
(ref.null none)
)
)
)
)

(assert_return (invoke "cast-success"))
(assert_return (invoke "cast-success-supertype"))
(assert_return (invoke "cast-success-null"))
(assert_trap (invoke "cast-fail-wrong-desc") "cast error")
(assert_trap (invoke "cast-fail-null-desc") "null descriptor")
(assert_return (invoke "cast-nn-success"))
(assert_return (invoke "cast-nn-success-supertype"))
(assert_trap (invoke "cast-nn-fail-null") "cast error")
(assert_trap (invoke "cast-nn-fail-wrong-desc") "cast error")
(assert_trap (invoke "cast-nn-fail-null-desc") "null descriptor")

(assert_malformed
;; Cast type must be a reference.
(module quote "(module (func (unreachable) (ref.cast_desc i32) (unreachable)))")
Expand Down
Loading