Skip to content

Commit fca1614

Browse files
committed
wip
1 parent 865bca6 commit fca1614

File tree

4 files changed

+284
-88
lines changed

4 files changed

+284
-88
lines changed

rust/ql/lib/codeql/rust/internal/PathResolution.qll

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2137,7 +2137,7 @@ private module Debug {
21372137
exists(string filepath, int startline, int startcolumn, int endline, int endcolumn |
21382138
result.getLocation().hasLocationInfo(filepath, startline, startcolumn, endline, endcolumn) and
21392139
filepath.matches("%/main.rs") and
2140-
startline = 52
2140+
startline = 2909
21412141
)
21422142
}
21432143

rust/ql/lib/codeql/rust/internal/TypeInference.qll

Lines changed: 155 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -748,6 +748,8 @@ private Type inferTypeEquality(AstNode n, TypePath path) {
748748
/**
749749
* A matching configuration for resolving types of struct expressions
750750
* like `Foo { bar = baz }`.
751+
*
752+
* This also includes nullary struct expressions like `None`.
751753
*/
752754
private module StructExprMatchingInput implements MatchingInputSig {
753755
private newtype TPos =
@@ -830,26 +832,100 @@ private module StructExprMatchingInput implements MatchingInputSig {
830832

831833
class AccessPosition = DeclarationPosition;
832834

833-
class Access extends StructExpr {
835+
abstract class Access extends AstNode {
836+
pragma[nomagic]
837+
abstract AstNode getNodeAt(AccessPosition apos);
838+
839+
pragma[nomagic]
840+
Type getInferredType(AccessPosition apos, TypePath path) {
841+
result = inferType(this.getNodeAt(apos), path)
842+
}
843+
844+
pragma[nomagic]
845+
abstract Path getStructPath();
846+
847+
pragma[nomagic]
848+
Declaration getTarget() { result = resolvePath(this.getStructPath()) }
849+
850+
pragma[nomagic]
834851
Type getTypeArgument(TypeArgumentPosition apos, TypePath path) {
852+
exists(TypeMention tm, Path p, ItemNode target, int i |
853+
p = this.getStructPath() and
854+
target = this.getTarget() and
855+
apos.asTypeParam() = target.getTypeParam(pragma[only_bind_into](i)) and
856+
result = tm.resolveTypeAt(path)
857+
|
858+
// e.g. `Option::None::<i32>`
859+
tm = p.getSegment().getGenericArgList().getTypeArg(pragma[only_bind_into](i))
860+
or
861+
// todo: share logic with TypeMention.qll
862+
// e.g. `Option::<i32>::None`
863+
target instanceof Variant and
864+
not exists(p.getSegment().getGenericArgList().getTypeArg(_)) and
865+
tm = p.getQualifier().getSegment().getGenericArgList().getTypeArg(pragma[only_bind_into](i))
866+
)
867+
}
868+
869+
/**
870+
* Holds if the return type of this call at `path` may have to be inferred
871+
* from the context.
872+
*/
873+
pragma[nomagic]
874+
predicate isContextTypedAt(DeclarationPosition pos, TypePath path) {
875+
// Tuple declarations, such as `Result::Ok(...)`, may also be context typed
876+
exists(Declaration td, TypeParameter tp |
877+
td = this.getTarget() and
878+
pos.isStructPos() and
879+
tp = td.getDeclaredType(pos, path) and
880+
not exists(DeclarationPosition paramDpos |
881+
not paramDpos.isStructPos() and
882+
tp = td.getDeclaredType(paramDpos, _)
883+
) and
884+
// check that no explicit type arguments have been supplied for `tp`
885+
not exists(TypeArgumentPosition tapos |
886+
exists(this.getTypeArgument(tapos, _)) and
887+
TTypeParamTypeParameter(tapos.asTypeParam()) = tp
888+
)
889+
)
890+
}
891+
}
892+
893+
private class StructExprAccess extends Access, StructExpr {
894+
override Type getTypeArgument(TypeArgumentPosition apos, TypePath path) {
895+
result = super.getTypeArgument(apos, path)
896+
or
835897
exists(TypePath suffix |
836898
suffix.isCons(TTypeParamTypeParameter(apos.asTypeParam()), path) and
837899
result = CertainTypeInference::inferCertainType(this, suffix)
838900
)
839901
}
840902

841-
AstNode getNodeAt(AccessPosition apos) {
903+
override AstNode getNodeAt(AccessPosition apos) {
842904
result = this.getFieldExpr(apos.asFieldPos()).getExpr()
843905
or
844906
result = this and
845907
apos.isStructPos()
846908
}
847909

848-
Type getInferredType(AccessPosition apos, TypePath path) {
849-
result = inferType(this.getNodeAt(apos), path)
910+
override Path getStructPath() { result = this.getPath() }
911+
}
912+
913+
/**
914+
* A potential nullary struct/variant construction such as `None`.
915+
*/
916+
private class PathExprAccess extends Access, PathExpr {
917+
PathExprAccess() { not exists(CallExpr ce | this = ce.getFunction()) }
918+
919+
override AstNode getNodeAt(AccessPosition apos) {
920+
result = this and
921+
apos.isStructPos()
850922
}
851923

852-
Declaration getTarget() { result = resolvePath(this.getPath()) }
924+
override Path getStructPath() { result = this.getPath() }
925+
926+
private predicate testisContextTypedAt(DeclarationPosition pos, TypePath path) {
927+
super.isContextTypedAt(pos, path)
928+
}
853929
}
854930

855931
predicate accessDeclarationPositionMatch(AccessPosition apos, DeclarationPosition dpos) {
@@ -866,29 +942,27 @@ private module StructExprMatching = Matching<StructExprMatchingInput>;
866942
pragma[nomagic]
867943
private Type inferStructExprType(AstNode n, TypePath path) {
868944
exists(StructExprMatchingInput::Access a, StructExprMatchingInput::AccessPosition apos |
869-
n = a.getNodeAt(apos) and
945+
n = a.getNodeAt(apos)
946+
|
870947
result = StructExprMatching::inferAccessType(a, apos, path)
948+
or
949+
a.isContextTypedAt(apos, path) and
950+
result = TContextType()
871951
)
872952
}
873953

954+
// /**
955+
// * Gets the type of `n` at `path`, where `n` is either a struct expression or
956+
// * a field expression of a struct expression.
957+
// */
958+
// private predicate inferStructExprType =
959+
// ContextTyping::CheckContextTyping<inferStructExprType0/3>::check/2;
874960
pragma[nomagic]
875961
private Type inferTupleRootType(AstNode n) {
876962
// `typeEquality` handles the non-root cases
877963
result = TTuple([n.(TupleExpr).getNumberOfFields(), n.(TuplePat).getTupleArity()])
878964
}
879965

880-
pragma[nomagic]
881-
private Type inferPathExprType(PathExpr pe, TypePath path) {
882-
// nullary struct/variant constructors
883-
not exists(CallExpr ce | pe = ce.getFunction()) and
884-
path.isEmpty() and
885-
exists(ItemNode i | i = resolvePath(pe.getPath()) |
886-
result = TEnum(i.(Variant).getEnum())
887-
or
888-
result = TStruct(i)
889-
)
890-
}
891-
892966
pragma[nomagic]
893967
private Path getCallExprPathQualifier(CallExpr ce) {
894968
result = CallExprImpl::getFunctionPath(ce).getQualifier()
@@ -2460,6 +2534,9 @@ private module NonMethodResolution {
24602534
/**
24612535
* A matching configuration for resolving types of calls like
24622536
* `foo::bar(baz)` where the target is not a method.
2537+
*
2538+
* This also includes "calls" to tuple variants and tuple structs such
2539+
* as `Result::Ok(42)`.
24632540
*/
24642541
private module NonMethodCallMatchingInput implements MatchingInputSig {
24652542
import FunctionPositionMatchingInput
@@ -2490,7 +2567,10 @@ private module NonMethodCallMatchingInput implements MatchingInputSig {
24902567
}
24912568

24922569
private class TupleStructDecl extends TupleDeclaration, Struct {
2493-
TupleStructDecl() { this.isTuple() }
2570+
TupleStructDecl() {
2571+
this.isTuple() or
2572+
not this.hasFieldList()
2573+
}
24942574

24952575
override TypeParameter getTypeParameter(TypeParameterPosition ppos) {
24962576
typeParamMatchPosition(this.getGenericParamList().getATypeParam(), result, ppos)
@@ -2513,7 +2593,10 @@ private module NonMethodCallMatchingInput implements MatchingInputSig {
25132593
}
25142594

25152595
private class TupleVariantDecl extends TupleDeclaration, Variant {
2516-
TupleVariantDecl() { this.isTuple() }
2596+
TupleVariantDecl() {
2597+
this.isTuple() or
2598+
not this.hasFieldList()
2599+
}
25172600

25182601
override TypeParameter getTypeParameter(TypeParameterPosition ppos) {
25192602
typeParamMatchPosition(this.getEnum().getGenericParamList().getATypeParam(), result, ppos)
@@ -2577,42 +2660,31 @@ private module NonMethodCallMatchingInput implements MatchingInputSig {
25772660
}
25782661
}
25792662

2580-
class Access extends NonMethodResolution::NonMethodCall, ContextTyping::ContextTypedCallCand {
2663+
abstract class Access extends AstNode {
25812664
pragma[nomagic]
2582-
override Type getTypeArgument(TypeArgumentPosition apos, TypePath path) {
2583-
result = getCallExprTypeArgument(this, apos).resolveTypeAt(path)
2584-
}
2665+
abstract Type getTypeArgument(TypeArgumentPosition apos, TypePath path);
25852666

25862667
pragma[nomagic]
2587-
Type getInferredType(AccessPosition apos, TypePath path) {
2588-
apos.isSelf() and
2589-
result = getCallExprTypeQualifier(this, path)
2590-
or
2591-
result = inferType(this.getNodeAt(apos), path)
2592-
}
2668+
abstract AstNode getNodeAt(FunctionPosition pos);
25932669

2594-
Declaration getTarget() {
2595-
result = this.resolveCallTarget() // potential mutual recursion; resolving some associated function calls requires resolving types
2596-
}
2670+
pragma[nomagic]
2671+
abstract Type getInferredType(AccessPosition apos, TypePath path);
2672+
2673+
pragma[nomagic]
2674+
abstract Declaration getTarget();
2675+
2676+
pragma[nomagic]
2677+
abstract Declaration getTargetForContextTyping();
25972678

25982679
/**
25992680
* Holds if the return type of this call at `path` may have to be inferred
26002681
* from the context.
26012682
*/
26022683
pragma[nomagic]
26032684
predicate isContextTypedAt(FunctionPosition pos, TypePath path) {
2604-
exists(ImplOrTraitItemNode i |
2605-
this.isContextTypedAt(i,
2606-
[
2607-
this.resolveCallTargetViaPathResolution().(NonMethodFunction),
2608-
this.resolveCallTargetViaTypeInference(i),
2609-
this.resolveTraitFunctionViaPathResolution(i)
2610-
], pos, path)
2611-
)
2612-
or
2613-
// Tuple declarations, such as `None`, may also be context typed
2685+
// Tuple declarations, such as `Result::Ok(...)`, may also be context typed
26142686
exists(TupleDeclaration td, TypeParameter tp |
2615-
td = this.resolveCallTargetViaPathResolution() and
2687+
td = this.getTargetForContextTyping() and
26162688
pos.isReturn() and
26172689
tp = td.getReturnType(path) and
26182690
not tp = td.getParameterType(_, _) and
@@ -2624,6 +2696,45 @@ private module NonMethodCallMatchingInput implements MatchingInputSig {
26242696
)
26252697
}
26262698
}
2699+
2700+
private class NonMethodCallAccess extends Access, ContextTyping::ContextTypedCallCand instanceof NonMethodResolution::NonMethodCall
2701+
{
2702+
override Type getTypeArgument(TypeArgumentPosition apos, TypePath path) {
2703+
result = getCallExprTypeArgument(this, apos).resolveTypeAt(path)
2704+
}
2705+
2706+
override AstNode getNodeAt(FunctionPosition pos) {
2707+
result = NonMethodResolution::NonMethodCall.super.getNodeAt(pos)
2708+
}
2709+
2710+
override Type getInferredType(AccessPosition apos, TypePath path) {
2711+
apos.isSelf() and
2712+
result = getCallExprTypeQualifier(this, path)
2713+
or
2714+
result = inferType(this.getNodeAt(apos), path)
2715+
}
2716+
2717+
override Declaration getTarget() {
2718+
result = super.resolveCallTarget() // potential mutual recursion; resolving some associated function calls requires resolving types
2719+
}
2720+
2721+
override Declaration getTargetForContextTyping() {
2722+
result = super.resolveCallTargetViaPathResolution()
2723+
}
2724+
2725+
override predicate isContextTypedAt(FunctionPosition pos, TypePath path) {
2726+
super.isContextTypedAt(pos, path)
2727+
or
2728+
exists(ImplOrTraitItemNode i |
2729+
this.isContextTypedAt(i,
2730+
[
2731+
super.resolveCallTargetViaPathResolution().(NonMethodFunction),
2732+
super.resolveCallTargetViaTypeInference(i),
2733+
super.resolveTraitFunctionViaPathResolution(i)
2734+
], pos, path)
2735+
)
2736+
}
2737+
}
26272738
}
26282739

26292740
private module NonMethodCallMatching = Matching<NonMethodCallMatchingInput>;
@@ -3488,8 +3599,6 @@ private module Cached {
34883599
or
34893600
result = inferStructExprType(n, path)
34903601
or
3491-
result = inferPathExprType(n, path)
3492-
or
34933602
result = inferMethodCallType(n, path)
34943603
or
34953604
result = inferNonMethodCallType(n, path)

rust/ql/test/library-tests/type-inference/main.rs

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2904,6 +2904,28 @@ mod block_types {
29042904
}
29052905
}
29062906

2907+
mod context_typed {
2908+
pub fn f() {
2909+
let x: Option<i32> = None; // $ type=x:T.i32
2910+
let x = Option::<i32>::None; // $ type=x:T.i32
2911+
let x = Option::None::<i32>; // $ type=x:T.i32
2912+
2913+
fn pin_option<T>(opt: Option<T>, x: T) {}
2914+
2915+
let x = None; // $ type=x:T.i32
2916+
pin_option(x, 0); // $ target=pin_option
2917+
2918+
let x: Result<i32, String> = Result::Ok(0); // $ type=x:E.String
2919+
let x = Result::<i32, String>::Ok(0); // $ type=x:E.String
2920+
let x = Result::Ok::<i32, String>(0); // $ type=x:E.String
2921+
2922+
fn pin_result<T, E>(res: Result<T, E>, x: E) {}
2923+
2924+
let x = Result::Ok(0); // $ type=x:T.i32 type=x:E.bool
2925+
pin_result(x, false); // $ target=pin_result
2926+
}
2927+
}
2928+
29072929
mod blanket_impl;
29082930
mod closure;
29092931
mod dereference;

0 commit comments

Comments
 (0)