Skip to content

Rust: Implement type inference for closures and calls to closures #20130

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 20 additions & 0 deletions rust/ql/lib/codeql/rust/frameworks/stdlib/Stdlib.qll
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,26 @@ class FutureTrait extends Trait {
}
}

/**
* The [`FnOnce` trait][1].
*
* [1]: https://doc.rust-lang.org/std/ops/trait.FnOnce.html
*/
class FnOnceTrait extends Trait {
pragma[nomagic]
FnOnceTrait() { this.getCanonicalPath() = "core::ops::function::FnOnce" }

/** Gets the type parameter of this trait. */
TypeParam getTypeParam() { result = this.getGenericParamList().getGenericParam(0) }

/** Gets the `Output` associated type. */
pragma[nomagic]
TypeAlias getOutputType() {
result = this.getAssocItemList().getAnAssocItem() and
result.getName().getText() = "Output"
}
}

/**
* The [`Iterator` trait][1].
*
Expand Down
129 changes: 129 additions & 0 deletions rust/ql/lib/codeql/rust/internal/TypeInference.qll
Original file line number Diff line number Diff line change
Expand Up @@ -383,6 +383,17 @@ private predicate typeEquality(AstNode n1, TypePath prefix1, AstNode n2, TypePat
prefix2.isEmpty() and
s = getRangeType(n1)
)
or
exists(ClosureExpr ce, int index |
n1 = ce and
n2 = ce.getParam(index).getPat() and
prefix1 = closureParameterPath(ce.getNumberOfParams(), index) and
prefix2.isEmpty()
)
or
n1.(ClosureExpr).getBody() = n2 and
prefix1 = closureReturnPath() and
prefix2.isEmpty()
}

pragma[nomagic]
Expand Down Expand Up @@ -1435,6 +1446,120 @@ private Type inferForLoopExprType(AstNode n, TypePath path) {
)
}

/**
* An invoked expression, the target of a call that is either a local variable
* or a non-path expression. This means that the expression denotes a
* first-class function.
*/
final private class InvokedClosureExpr extends Expr {
private CallExpr call;

InvokedClosureExpr() {
call.getFunction() = this and
Copy link
Preview

Copilot AI Jul 28, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[nitpick] The condition call.getFunction() = this could be inefficient as it requires checking all CallExpr nodes. Consider using a more targeted approach or adding an index if this becomes a performance bottleneck.

Suggested change
call.getFunction() = this and
/**
* Associates a CallExpr with its function.
*/
private predicate isFunctionOfCall(Expr function, CallExpr call) {
call.getFunction() = function
}
InvokedClosureExpr() {
isFunctionOfCall(this, call) and

Copilot uses AI. Check for mistakes.

(not this instanceof PathExpr or this = any(Variable v).getAnAccess())
}

Type getTypeAt(TypePath path) { result = inferType(this, path) }

CallExpr getCall() { result = call }
}

private module InvokedClosureSatisfiesConstraintInput implements
SatisfiesConstraintInputSig<InvokedClosureExpr>
{
predicate relevantConstraint(InvokedClosureExpr term, Type constraint) {
exists(term) and
constraint.(TraitType).getTrait() instanceof FnOnceTrait
}
}

/** Gets the type of `ce` when viewed as an implementation of `FnOnce`. */
private Type invokedClosureFnTypeAt(InvokedClosureExpr ce, TypePath path) {
SatisfiesConstraint<InvokedClosureExpr, InvokedClosureSatisfiesConstraintInput>::satisfiesConstraintType(ce,
_, path, result)
}

/** Gets the path to a closure's return type. */
private TypePath closureReturnPath() {
result = TypePath::singleton(TDynTraitTypeParameter(any(FnOnceTrait t).getOutputType()))
}

/** Gets the path to a closure with arity `arity`s `index`th parameter type. */
private TypePath closureParameterPath(int arity, int index) {
result =
TypePath::cons(TDynTraitTypeParameter(any(FnOnceTrait t).getTypeParam()),
TypePath::singleton(TTupleTypeParameter(arity, index)))
}

/** Gets the path to the return type of the `FnOnce` trait. */
private TypePath fnReturnPath() {
result = TypePath::singleton(TAssociatedTypeTypeParameter(any(FnOnceTrait t).getOutputType()))
}

/**
* Gets the path to the parameter type of the `FnOnce` trait with arity `arity`
* and index `index`.
*/
private TypePath fnParameterPath(int arity, int index) {
result =
TypePath::cons(TTypeParamTypeParameter(any(FnOnceTrait t).getTypeParam()),
TypePath::singleton(TTupleTypeParameter(arity, index)))
}

pragma[nomagic]
private Type inferDynamicCallExprType(Expr n, TypePath path) {
exists(InvokedClosureExpr ce |
// Propagate the function's return type to the call expression
exists(TypePath path0 | result = invokedClosureFnTypeAt(ce, path0) |
n = ce.getCall() and
path = path0.stripPrefix(fnReturnPath())
or
// Propagate the function's parameter type to the arguments
exists(int index |
n = ce.getCall().getArgList().getArg(index) and
path = path0.stripPrefix(fnParameterPath(ce.getCall().getNumberOfArgs(), index))
)
)
or
// _If_ the invoked expression has the type of a closure, then we propagate
// the surrounding types into the closure.
exists(int arity, TypePath path0 |
ce.getTypeAt(TypePath::nil()).(DynTraitType).getTrait() instanceof FnOnceTrait
Copy link
Preview

Copilot AI Jul 28, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This condition performs multiple method calls and type checks in sequence. Consider caching the result of ce.getTypeAt(TypePath::nil()) or restructuring to avoid repeated computation.

Suggested change
ce.getTypeAt(TypePath::nil()).(DynTraitType).getTrait() instanceof FnOnceTrait
let cachedType = ce.getTypeAt(TypePath::nil()) |
cachedType.(DynTraitType).getTrait() instanceof FnOnceTrait

Copilot uses AI. Check for mistakes.

|
// Propagate the type of arguments to the parameter types of closure
exists(int index |
n = ce and
arity = ce.getCall().getNumberOfArgs() and
result = inferType(ce.getCall().getArg(index), path0) and
path = closureParameterPath(arity, index).append(path0)
)
or
// Propagate the type of the call expression to the return type of the closure
n = ce and
arity = ce.getCall().getNumberOfArgs() and
result = inferType(ce.getCall(), path0) and
path = closureReturnPath().append(path0)
)
)
}

pragma[nomagic]
private Type inferClosureExprType(AstNode n, TypePath path) {
exists(ClosureExpr ce |
n = ce and
path.isEmpty() and
result = TDynTraitType(any(FnOnceTrait t))
or
n = ce and
path = TypePath::singleton(TDynTraitTypeParameter(any(FnOnceTrait t).getTypeParam())) and
result = TTuple(ce.getNumberOfParams())
or
// Propagate return type annotation to body
n = ce.getBody() and
result = ce.getRetType().getTypeRepr().(TypeMention).resolveTypeAt(path)
)
}

pragma[nomagic]
private Type inferCastExprType(CastExpr ce, TypePath path) {
result = ce.getTypeRepr().(TypeMention).resolveTypeAt(path)
Expand Down Expand Up @@ -2062,6 +2187,10 @@ private module Cached {
or
result = inferForLoopExprType(n, path)
or
result = inferDynamicCallExprType(n, path)
or
result = inferClosureExprType(n, path)
or
result = inferCastExprType(n, path)
or
result = inferStructPatType(n, path)
Expand Down
24 changes: 24 additions & 0 deletions rust/ql/lib/codeql/rust/internal/TypeMention.qll
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
/** Provides classes for representing type mentions, used in type inference. */

private import rust
private import codeql.rust.frameworks.stdlib.Stdlib
private import Type
private import PathResolution
private import TypeInference
Expand All @@ -26,6 +27,18 @@ class TupleTypeReprMention extends TypeMention instanceof TupleTypeRepr {
}
}

class ParenthesizedArgListMention extends TypeMention instanceof ParenthesizedArgList {
override Type resolveTypeAt(TypePath path) {
path.isEmpty() and
result = TTuple(super.getNumberOfTypeArgs())
or
exists(TypePath suffix, int index |
result = super.getTypeArg(index).getTypeRepr().(TypeMention).resolveTypeAt(suffix) and
path = TypePath::cons(TTupleTypeParameter(super.getNumberOfTypeArgs(), index), suffix)
)
}
}

class ArrayTypeReprMention extends TypeMention instanceof ArrayTypeRepr {
override Type resolveTypeAt(TypePath path) {
path.isEmpty() and
Expand Down Expand Up @@ -215,6 +228,17 @@ class NonAliasPathTypeMention extends PathTypeMention {
.(TraitItemNode)
.getAssocItem(pragma[only_bind_into](name)))
)
or
// Handle the special syntactic sugar for function traits. For now we only
// support `FnOnce` as we can't support the "inherited" associated types of
// `Fn` and `FnMut` yet.
exists(FnOnceTrait t | t = resolved |
tp = TTypeParamTypeParameter(t.getTypeParam()) and
result = this.getSegment().getParenthesizedArgList()
or
tp = TAssociatedTypeTypeParameter(t.getOutputType()) and
result = this.getSegment().getRetType().getTypeRepr()
)
}

Type resolveRootType() {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
category: minorAnalysis
---
* Type inference now supports closures, calls to closures, and trait bounds
sing the `FnOnce` trait.
76 changes: 76 additions & 0 deletions rust/ql/test/library-tests/type-inference/closure.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
/// Tests for type inference for closures and higher-order functions.

mod simple_closures {
pub fn test() {
// A simple closure without type annotations or invocations.
let my_closure = |a, b| a && b;

let x: i64 = 1i64; // $ type=x:i64
let add_one = |n| n + 1i64; // $ target=add
let _y = add_one(x); // $ type=_y:i64

// The type of `x` is inferred from the closure's argument type.
let x = Default::default(); // $ type=x:i64 target=default
let add_zero = |n: i64| n;
let _y = add_zero(x); // $ type=_y:i64

let _get_bool = || -> bool {
// The return type annotation on the closure lets us infer the type of `b`.
let b = Default::default(); // $ type=b:bool target=default
b
};

// The parameter type of `id` is inferred from the argument.
let id = |b| b; // $ type=b:bool
let _b = id(true); // $ type=_b:bool

// The return type of `id2` is inferred from the type of the call expression.
let id2 = |b| b;
let arg = Default::default(); // $ target=default type=arg:bool
let _b2: bool = id2(arg); // $ type=_b2:bool
}
}

mod fn_once_trait {
fn return_type<F: FnOnce(bool) -> i64>(f: F) {
let _return = f(true); // $ type=_return:i64
}

fn argument_type<F: FnOnce(bool) -> i64>(f: F) {
let arg = Default::default(); // $ target=default type=arg:bool
f(arg);
}

fn apply<A, B, F: FnOnce(A) -> B>(f: F, a: A) -> B {
f(a)
}

fn apply_two(f: impl FnOnce(i64) -> i64) -> i64 {
f(2)
}

fn test() {
let f = |x: bool| -> i64 {
if x {
1
} else {
0
}
};
let _r = apply(f, true); // $ target=apply type=_r:i64

let f = |x| x + 1; // $ MISSING: type=x:i64 target=add
let _r2 = apply_two(f); // $ target=apply_two type=_r2:i64
}
}

mod dyn_fn_once {
fn apply_boxed<A, B, F: FnOnce(A) -> B + ?Sized>(f: Box<F>, arg: A) -> B {
f(arg)
}

fn apply_boxed_dyn<A, B>(f: Box<dyn FnOnce(A) -> B>, arg: A) {
let _r1 = apply_boxed(f, arg); // $ target=apply_boxed type=_r1:B
let _r2 = apply_boxed(Box::new(|_: i64| true), 3); // $ target=apply_boxed target=new type=_r2:bool
}
}
42 changes: 1 addition & 41 deletions rust/ql/test/library-tests/type-inference/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2459,46 +2459,7 @@ pub mod pattern_matching_experimental {
}
}

mod closures {
struct Row {
data: i64,
}

impl Row {
fn get(&self) -> i64 {
self.data // $ fieldof=Row
}
}

struct Table {
rows: Vec<Row>,
}

impl Table {
fn new() -> Self {
Table { rows: Vec::new() } // $ target=new
}

fn count_with(&self, property: impl Fn(Row) -> bool) -> i64 {
0 // (not implemented)
}
}

pub fn f() {
Some(1).map(|x| {
let x = x; // $ MISSING: type=x:i32
println!("{x}");
}); // $ target=map

let table = Table::new(); // $ target=new type=table:Table
let result = table.count_with(|row| // $ type=result:i64
{
let v = row.get(); // $ MISSING: target=get type=v:i64
v > 0 // $ MISSING: target=gt
}); // $ target=count_with
}
}

mod closure;
mod dereference;
mod dyn_type;

Expand Down Expand Up @@ -2532,6 +2493,5 @@ fn main() {
dereference::test(); // $ target=test
pattern_matching::test_all_patterns(); // $ target=test_all_patterns
pattern_matching_experimental::box_patterns(); // $ target=box_patterns
closures::f(); // $ target=f
dyn_type::test(); // $ target=test
}
Loading