Skip to content

Commit f7e5281

Browse files
committed
wip
1 parent fc9ed22 commit f7e5281

File tree

4 files changed

+185
-94
lines changed

4 files changed

+185
-94
lines changed

rust/ql/lib/codeql/rust/internal/PathResolution.qll

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2137,7 +2137,7 @@ private module Debug {
21372137
exists(string filepath, int startline, int startcolumn, int endline, int endcolumn |
21382138
result.getLocation().hasLocationInfo(filepath, startline, startcolumn, endline, endcolumn) and
21392139
filepath.matches("%/main.rs") and
2140-
startline = 52
2140+
startline = 2909
21412141
)
21422142
}
21432143

rust/ql/lib/codeql/rust/internal/TypeInference.qll

Lines changed: 170 additions & 77 deletions
Original file line numberDiff line numberDiff line change
@@ -748,6 +748,8 @@ private Type inferTypeEquality(AstNode n, TypePath path) {
748748
/**
749749
* A matching configuration for resolving types of struct expressions
750750
* like `Foo { bar = baz }`.
751+
*
752+
* This also includes nullary struct expressions like `None`.
751753
*/
752754
private module StructExprMatchingInput implements MatchingInputSig {
753755
private newtype TPos =
@@ -830,26 +832,96 @@ private module StructExprMatchingInput implements MatchingInputSig {
830832

831833
class AccessPosition = DeclarationPosition;
832834

833-
class Access extends StructExpr {
835+
abstract class Access extends AstNode {
836+
pragma[nomagic]
837+
abstract AstNode getNodeAt(AccessPosition apos);
838+
839+
pragma[nomagic]
840+
Type getInferredType(AccessPosition apos, TypePath path) {
841+
result = inferType(this.getNodeAt(apos), path)
842+
}
843+
844+
pragma[nomagic]
845+
abstract Path getStructPath();
846+
847+
pragma[nomagic]
848+
Declaration getTarget() { result = resolvePath(this.getStructPath()) }
849+
850+
pragma[nomagic]
834851
Type getTypeArgument(TypeArgumentPosition apos, TypePath path) {
852+
exists(TypeMention tm, Path p, ItemNode target, int i |
853+
p = this.getStructPath() and
854+
target = this.getTarget() and
855+
apos.asTypeParam() = target.getTypeParam(pragma[only_bind_into](i)) and
856+
result = tm.resolveTypeAt(path)
857+
|
858+
// e.g. `Option::None::<i32>`
859+
tm = p.getSegment().getGenericArgList().getTypeArg(pragma[only_bind_into](i))
860+
or
861+
// todo: share logic with TypeMention.qll
862+
// e.g. `Option::<i32>::None`
863+
target instanceof Variant and
864+
not exists(p.getSegment().getGenericArgList().getTypeArg(_)) and
865+
tm = p.getQualifier().getSegment().getGenericArgList().getTypeArg(pragma[only_bind_into](i))
866+
)
867+
}
868+
869+
/**
870+
* Holds if the return type of this call at `path` may have to be inferred
871+
* from the context.
872+
*/
873+
pragma[nomagic]
874+
predicate isContextTypedAt(DeclarationPosition pos, TypePath path) {
875+
// Struct declarations, such as `Foo::Bar{field = ...}`, may also be context typed
876+
exists(Declaration td, TypeParameter tp |
877+
td = this.getTarget() and
878+
pos.isStructPos() and
879+
tp = td.getDeclaredType(pos, path) and
880+
not exists(DeclarationPosition paramDpos |
881+
not paramDpos.isStructPos() and
882+
tp = td.getDeclaredType(paramDpos, _)
883+
) and
884+
// check that no explicit type arguments have been supplied for `tp`
885+
not exists(TypeArgumentPosition tapos |
886+
exists(this.getTypeArgument(tapos, _)) and
887+
TTypeParamTypeParameter(tapos.asTypeParam()) = tp
888+
)
889+
)
890+
}
891+
}
892+
893+
private class StructExprAccess extends Access, StructExpr {
894+
override Type getTypeArgument(TypeArgumentPosition apos, TypePath path) {
895+
result = super.getTypeArgument(apos, path)
896+
or
835897
exists(TypePath suffix |
836898
suffix.isCons(TTypeParamTypeParameter(apos.asTypeParam()), path) and
837899
result = CertainTypeInference::inferCertainType(this, suffix)
838900
)
839901
}
840902

841-
AstNode getNodeAt(AccessPosition apos) {
903+
override AstNode getNodeAt(AccessPosition apos) {
842904
result = this.getFieldExpr(apos.asFieldPos()).getExpr()
843905
or
844906
result = this and
845907
apos.isStructPos()
846908
}
847909

848-
Type getInferredType(AccessPosition apos, TypePath path) {
849-
result = inferType(this.getNodeAt(apos), path)
910+
override Path getStructPath() { result = this.getPath() }
911+
}
912+
913+
/**
914+
* A potential nullary struct/variant construction such as `None`.
915+
*/
916+
private class PathExprAccess extends Access, PathExpr {
917+
PathExprAccess() { not exists(CallExpr ce | this = ce.getFunction()) }
918+
919+
override AstNode getNodeAt(AccessPosition apos) {
920+
result = this and
921+
apos.isStructPos()
850922
}
851923

852-
Declaration getTarget() { result = resolvePath(this.getPath()) }
924+
override Path getStructPath() { result = this.getPath() }
853925
}
854926

855927
predicate accessDeclarationPositionMatch(AccessPosition apos, DeclarationPosition dpos) {
@@ -859,36 +931,32 @@ private module StructExprMatchingInput implements MatchingInputSig {
859931

860932
private module StructExprMatching = Matching<StructExprMatchingInput>;
861933

862-
/**
863-
* Gets the type of `n` at `path`, where `n` is either a struct expression or
864-
* a field expression of a struct expression.
865-
*/
866934
pragma[nomagic]
867-
private Type inferStructExprType(AstNode n, TypePath path) {
935+
private Type inferStructExprType0(AstNode n, boolean isReturn, TypePath path) {
868936
exists(StructExprMatchingInput::Access a, StructExprMatchingInput::AccessPosition apos |
869937
n = a.getNodeAt(apos) and
938+
if apos.isStructPos() then isReturn = true else isReturn = false
939+
|
870940
result = StructExprMatching::inferAccessType(a, apos, path)
941+
or
942+
a.isContextTypedAt(apos, path) and
943+
result = TContextType()
871944
)
872945
}
873946

947+
/**
948+
* Gets the type of `n` at `path`, where `n` is either a struct expression or
949+
* a field expression of a struct expression.
950+
*/
951+
private predicate inferStructExprType =
952+
ContextTyping::CheckContextTyping<inferStructExprType0/3>::check/2;
953+
874954
pragma[nomagic]
875955
private Type inferTupleRootType(AstNode n) {
876956
// `typeEquality` handles the non-root cases
877957
result = TTuple([n.(TupleExpr).getNumberOfFields(), n.(TuplePat).getTupleArity()])
878958
}
879959

880-
pragma[nomagic]
881-
private Type inferPathExprType(PathExpr pe, TypePath path) {
882-
// nullary struct/variant constructors
883-
not exists(CallExpr ce | pe = ce.getFunction()) and
884-
path.isEmpty() and
885-
exists(ItemNode i | i = resolvePath(pe.getPath()) |
886-
result = TEnum(i.(Variant).getEnum())
887-
or
888-
result = TStruct(i)
889-
)
890-
}
891-
892960
pragma[nomagic]
893961
private Path getCallExprPathQualifier(CallExpr ce) {
894962
result = CallExprImpl::getFunctionPath(ce).getQualifier()
@@ -982,7 +1050,7 @@ private module ContextTyping {
9821050
pragma[nomagic]
9831051
private predicate isContextTyped(AstNode n) { isContextTyped(n, _) }
9841052

985-
signature Type inferCallTypeSig(AstNode n, FunctionPosition pos, TypePath path);
1053+
signature Type inferCallTypeSig(AstNode n, boolean isReturn, TypePath path);
9861054

9871055
/**
9881056
* Given a predicate `inferCallType` for inferring the type of a call at a given
@@ -992,30 +1060,24 @@ private module ContextTyping {
9921060
*/
9931061
module CheckContextTyping<inferCallTypeSig/3 inferCallType> {
9941062
pragma[nomagic]
995-
private Type inferCallTypeFromContextCand(
996-
AstNode n, FunctionPosition pos, TypePath path, TypePath prefix
997-
) {
998-
result = inferCallType(n, pos, path) and
999-
not pos.isReturn() and
1063+
private Type inferCallTypeFromContextCand(AstNode n, TypePath path, TypePath prefix) {
1064+
result = inferCallType(n, false, path) and
10001065
isContextTyped(n) and
10011066
prefix = path
10021067
or
10031068
exists(TypePath mid |
1004-
result = inferCallTypeFromContextCand(n, pos, path, mid) and
1069+
result = inferCallTypeFromContextCand(n, path, mid) and
10051070
mid.isSnoc(prefix, _)
10061071
)
10071072
}
10081073

10091074
pragma[nomagic]
10101075
Type check(AstNode n, TypePath path) {
1011-
exists(FunctionPosition pos |
1012-
result = inferCallType(n, pos, path) and
1013-
pos.isReturn()
1014-
or
1015-
exists(TypePath prefix |
1016-
result = inferCallTypeFromContextCand(n, pos, path, prefix) and
1017-
isContextTyped(n, prefix)
1018-
)
1076+
result = inferCallType(n, true, path)
1077+
or
1078+
exists(TypePath prefix |
1079+
result = inferCallTypeFromContextCand(n, path, prefix) and
1080+
isContextTyped(n, prefix)
10191081
)
10201082
}
10211083
}
@@ -2131,11 +2193,13 @@ private Type inferMethodCallType0(
21312193
}
21322194

21332195
pragma[nomagic]
2134-
private Type inferMethodCallType1(
2135-
AstNode n, MethodCallMatchingInput::AccessPosition apos, TypePath path
2136-
) {
2137-
exists(MethodCallMatchingInput::Access a, string derefChainBorrow, TypePath path0 |
2138-
result = inferMethodCallType0(a, apos, n, derefChainBorrow, path0)
2196+
private Type inferMethodCallType1(AstNode n, boolean isReturn, TypePath path) {
2197+
exists(
2198+
MethodCallMatchingInput::Access a, MethodCallMatchingInput::AccessPosition apos,
2199+
string derefChainBorrow, TypePath path0
2200+
|
2201+
result = inferMethodCallType0(a, apos, n, derefChainBorrow, path0) and
2202+
if apos.isReturn() then isReturn = true else isReturn = false
21392203
|
21402204
(
21412205
not apos.isSelf()
@@ -2460,6 +2524,9 @@ private module NonMethodResolution {
24602524
/**
24612525
* A matching configuration for resolving types of calls like
24622526
* `foo::bar(baz)` where the target is not a method.
2527+
*
2528+
* This also includes "calls" to tuple variants and tuple structs such
2529+
* as `Result::Ok(42)`.
24632530
*/
24642531
private module NonMethodCallMatchingInput implements MatchingInputSig {
24652532
import FunctionPositionMatchingInput
@@ -2577,42 +2644,31 @@ private module NonMethodCallMatchingInput implements MatchingInputSig {
25772644
}
25782645
}
25792646

2580-
class Access extends NonMethodResolution::NonMethodCall, ContextTyping::ContextTypedCallCand {
2647+
abstract class Access extends AstNode {
25812648
pragma[nomagic]
2582-
override Type getTypeArgument(TypeArgumentPosition apos, TypePath path) {
2583-
result = getCallExprTypeArgument(this, apos).resolveTypeAt(path)
2584-
}
2649+
abstract Type getTypeArgument(TypeArgumentPosition apos, TypePath path);
25852650

25862651
pragma[nomagic]
2587-
Type getInferredType(AccessPosition apos, TypePath path) {
2588-
apos.isSelf() and
2589-
result = getCallExprTypeQualifier(this, path)
2590-
or
2591-
result = inferType(this.getNodeAt(apos), path)
2592-
}
2652+
abstract AstNode getNodeAt(FunctionPosition pos);
25932653

2594-
Declaration getTarget() {
2595-
result = this.resolveCallTarget() // potential mutual recursion; resolving some associated function calls requires resolving types
2596-
}
2654+
pragma[nomagic]
2655+
abstract Type getInferredType(AccessPosition apos, TypePath path);
2656+
2657+
pragma[nomagic]
2658+
abstract Declaration getTarget();
2659+
2660+
pragma[nomagic]
2661+
abstract Declaration getTargetForContextTyping();
25972662

25982663
/**
25992664
* Holds if the return type of this call at `path` may have to be inferred
26002665
* from the context.
26012666
*/
26022667
pragma[nomagic]
26032668
predicate isContextTypedAt(FunctionPosition pos, TypePath path) {
2604-
exists(ImplOrTraitItemNode i |
2605-
this.isContextTypedAt(i,
2606-
[
2607-
this.resolveCallTargetViaPathResolution().(NonMethodFunction),
2608-
this.resolveCallTargetViaTypeInference(i),
2609-
this.resolveTraitFunctionViaPathResolution(i)
2610-
], pos, path)
2611-
)
2612-
or
2613-
// Tuple declarations, such as `None`, may also be context typed
2669+
// Tuple declarations, such as `Result::Ok(...)`, may also be context typed
26142670
exists(TupleDeclaration td, TypeParameter tp |
2615-
td = this.resolveCallTargetViaPathResolution() and
2671+
td = this.getTargetForContextTyping() and
26162672
pos.isReturn() and
26172673
tp = td.getReturnType(path) and
26182674
not tp = td.getParameterType(_, _) and
@@ -2624,15 +2680,55 @@ private module NonMethodCallMatchingInput implements MatchingInputSig {
26242680
)
26252681
}
26262682
}
2683+
2684+
private class NonMethodCallAccess extends Access, ContextTyping::ContextTypedCallCand instanceof NonMethodResolution::NonMethodCall
2685+
{
2686+
override Type getTypeArgument(TypeArgumentPosition apos, TypePath path) {
2687+
result = getCallExprTypeArgument(this, apos).resolveTypeAt(path)
2688+
}
2689+
2690+
override AstNode getNodeAt(FunctionPosition pos) {
2691+
result = NonMethodResolution::NonMethodCall.super.getNodeAt(pos)
2692+
}
2693+
2694+
override Type getInferredType(AccessPosition apos, TypePath path) {
2695+
apos.isSelf() and
2696+
result = getCallExprTypeQualifier(this, path)
2697+
or
2698+
result = inferType(this.getNodeAt(apos), path)
2699+
}
2700+
2701+
override Declaration getTarget() {
2702+
result = super.resolveCallTarget() // potential mutual recursion; resolving some associated function calls requires resolving types
2703+
}
2704+
2705+
override Declaration getTargetForContextTyping() {
2706+
result = super.resolveCallTargetViaPathResolution()
2707+
}
2708+
2709+
override predicate isContextTypedAt(FunctionPosition pos, TypePath path) {
2710+
super.isContextTypedAt(pos, path)
2711+
or
2712+
exists(ImplOrTraitItemNode i |
2713+
this.isContextTypedAt(i,
2714+
[
2715+
super.resolveCallTargetViaPathResolution().(NonMethodFunction),
2716+
super.resolveCallTargetViaTypeInference(i),
2717+
super.resolveTraitFunctionViaPathResolution(i)
2718+
], pos, path)
2719+
)
2720+
}
2721+
}
26272722
}
26282723

26292724
private module NonMethodCallMatching = Matching<NonMethodCallMatchingInput>;
26302725

26312726
pragma[nomagic]
2632-
private Type inferNonMethodCallType0(
2633-
AstNode n, NonMethodCallMatchingInput::AccessPosition apos, TypePath path
2634-
) {
2635-
exists(NonMethodCallMatchingInput::Access a | n = a.getNodeAt(apos) |
2727+
private Type inferNonMethodCallType0(AstNode n, boolean isReturn, TypePath path) {
2728+
exists(NonMethodCallMatchingInput::Access a, NonMethodCallMatchingInput::AccessPosition apos |
2729+
n = a.getNodeAt(apos) and
2730+
if apos.isReturn() then isReturn = true else isReturn = false
2731+
|
26362732
result = NonMethodCallMatching::inferAccessType(a, apos, path)
26372733
or
26382734
a.isContextTypedAt(apos, path) and
@@ -2715,12 +2811,11 @@ private module OperationMatchingInput implements MatchingInputSig {
27152811
private module OperationMatching = Matching<OperationMatchingInput>;
27162812

27172813
pragma[nomagic]
2718-
private Type inferOperationType0(
2719-
AstNode n, OperationMatchingInput::AccessPosition apos, TypePath path
2720-
) {
2721-
exists(OperationMatchingInput::Access a |
2814+
private Type inferOperationType0(AstNode n, boolean isReturn, TypePath path) {
2815+
exists(OperationMatchingInput::Access a, OperationMatchingInput::AccessPosition apos |
27222816
n = a.getNodeAt(apos) and
2723-
result = OperationMatching::inferAccessType(a, apos, path)
2817+
result = OperationMatching::inferAccessType(a, apos, path) and
2818+
if apos.isReturn() then isReturn = true else isReturn = false
27242819
)
27252820
}
27262821

@@ -3488,8 +3583,6 @@ private module Cached {
34883583
or
34893584
result = inferStructExprType(n, path)
34903585
or
3491-
result = inferPathExprType(n, path)
3492-
or
34933586
result = inferMethodCallType(n, path)
34943587
or
34953588
result = inferNonMethodCallType(n, path)

0 commit comments

Comments
 (0)