Skip to content

Commit 9097870

Browse files
committed
Rust: Restrict type propagation into arguments
1 parent 46d0506 commit 9097870

File tree

13 files changed

+155
-563
lines changed

13 files changed

+155
-563
lines changed

rust/ql/lib/codeql/rust/internal/Type.qll

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@ newtype TType =
5151
TSliceType() or
5252
TNeverType() or
5353
TPtrType() or
54+
TContextType() or
5455
TTupleTypeParameter(int arity, int i) { exists(TTuple(arity)) and i in [0 .. arity - 1] } or
5556
TTypeParamTypeParameter(TypeParam t) or
5657
TAssociatedTypeTypeParameter(TypeAlias t) { any(TraitItemNode trait).getAnAssocItem() = t } or
@@ -371,6 +372,26 @@ class PtrType extends Type, TPtrType {
371372
override Location getLocation() { result instanceof EmptyLocation }
372373
}
373374

375+
/**
376+
* A special pseudo type used to indicate that the actual type is to be inferred
377+
* from a context.
378+
*
379+
* For example, a call like `Default::default()` is assigned this type, which
380+
* means that the actual type is to be inferred from the context in which the call
381+
* occurs.
382+
*
383+
* Context types are not restricted to root types, for example in a call like
384+
* `Vec::new()` we assign this type at the type path corresponding to the type
385+
* parameter of `Vec`.
386+
*/
387+
class ContextType extends Type, TContextType {
388+
override TypeParameter getPositionalTypeParameter(int i) { none() }
389+
390+
override string toString() { result = "(context typed)" }
391+
392+
override Location getLocation() { result instanceof EmptyLocation }
393+
}
394+
374395
/** A type parameter. */
375396
abstract class TypeParameter extends Type {
376397
override TypeParameter getPositionalTypeParameter(int i) { none() }

rust/ql/lib/codeql/rust/internal/TypeInference.qll

Lines changed: 71 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -865,6 +865,38 @@ private Type inferPathExprType(PathExpr pe, TypePath path) {
865865
)
866866
}
867867

868+
/**
869+
* A call where the type of the result may have to be inferred from the
870+
* context in which the call appears.
871+
*/
872+
abstract private class ContextTypedCallCand extends AstNode {
873+
abstract predicate hasTypeArgument(TypeArgumentPosition apos);
874+
875+
/**
876+
* Holds if `this` call resolves to `target` and the type at `pos` and `path`
877+
* may be inferred from the context.
878+
*/
879+
bindingset[this, target]
880+
predicate isContextTypedAt(Function target, TypePath path, FunctionPosition pos) {
881+
exists(TypeParameter tp |
882+
assocFunctionReturnContextTypedAt(target, pos, path, tp) and
883+
// check that no explicit type arguments have been supplied for `tp`
884+
not exists(TypeArgumentPosition tapos | this.hasTypeArgument(tapos) |
885+
exists(int i |
886+
i = tapos.asMethodTypeArgumentPosition() and
887+
tp = TTypeParamTypeParameter(target.getGenericParamList().getTypeParam(i))
888+
)
889+
or
890+
TTypeParamTypeParameter(tapos.asTypeParam()) = tp
891+
) and
892+
not (
893+
tp instanceof TSelfTypeParameter and
894+
exists(getCallExprTypeQualifier(this, _))
895+
)
896+
)
897+
}
898+
}
899+
868900
pragma[nomagic]
869901
private Path getCallExprPathQualifier(CallExpr ce) {
870902
result = CallExprImpl::getFunctionPath(ce).getQualifier()
@@ -1890,7 +1922,7 @@ private module MethodCallMatchingInput implements MatchingWithEnvironmentInputSi
18901922

18911923
final private class MethodCallFinal = MethodResolution::MethodCall;
18921924

1893-
class Access extends MethodCallFinal {
1925+
class Access extends MethodCallFinal, ContextTypedCallCand {
18941926
Access() {
18951927
// handled in the `OperationMatchingInput` module
18961928
not this instanceof Operation
@@ -1906,6 +1938,10 @@ private module MethodCallMatchingInput implements MatchingWithEnvironmentInputSi
19061938
)
19071939
}
19081940

1941+
override predicate hasTypeArgument(TypeArgumentPosition apos) {
1942+
exists(this.getTypeArgument(apos, _))
1943+
}
1944+
19091945
pragma[nomagic]
19101946
private Type getInferredSelfType(string derefChain, boolean borrow, TypePath path) {
19111947
result = this.getACandidateReceiverTypeAt(derefChain, borrow, path)
@@ -1961,7 +1997,12 @@ private Type inferMethodCallType0(
19611997
) {
19621998
exists(TypePath path0 |
19631999
n = a.getNodeAt(apos) and
1964-
result = MethodCallMatching::inferAccessType(a, derefChainBorrow, apos, path0)
2000+
(
2001+
result = MethodCallMatching::inferAccessType(a, derefChainBorrow, apos, path0)
2002+
or
2003+
a.isContextTypedAt(a.getTarget(derefChainBorrow), path0, apos) and
2004+
result = TContextType()
2005+
)
19652006
|
19662007
if
19672008
// index expression `x[i]` desugars to `*x.index(i)`, so we must account for
@@ -1973,6 +2014,9 @@ private Type inferMethodCallType0(
19732014
)
19742015
}
19752016

2017+
pragma[nomagic]
2018+
private TypePath getContextTypePath(AstNode n) { inferType(n, result) = TContextType() }
2019+
19762020
/**
19772021
* Gets the type of `n` at `path`, where `n` is either a method call or an
19782022
* argument/receiver of a method call.
@@ -1983,7 +2027,8 @@ private Type inferMethodCallType(AstNode n, TypePath path) {
19832027
MethodCallMatchingInput::Access a, MethodCallMatchingInput::AccessPosition apos,
19842028
string derefChainBorrow, TypePath path0
19852029
|
1986-
result = inferMethodCallType0(a, apos, n, derefChainBorrow, path0)
2030+
result = inferMethodCallType0(a, apos, n, derefChainBorrow, path0) and
2031+
if not apos.isReturn() then path.startsWith(getContextTypePath(n)) else any()
19872032
|
19882033
(
19892034
not apos.isSelf()
@@ -2171,6 +2216,12 @@ private module NonMethodResolution {
21712216
or
21722217
result = this.resolveCallTargetRec()
21732218
}
2219+
2220+
pragma[nomagic]
2221+
Function resolveTraitFunction() {
2222+
this.(Call).hasTrait() and
2223+
result = this.getPathResolutionResolved()
2224+
}
21742225
}
21752226

21762227
private newtype TCallAndBlanketPos =
@@ -2405,12 +2456,16 @@ private module NonMethodCallMatchingInput implements MatchingInputSig {
24052456
}
24062457
}
24072458

2408-
class Access extends NonMethodResolution::NonMethodCall {
2459+
class Access extends NonMethodResolution::NonMethodCall, ContextTypedCallCand {
24092460
pragma[nomagic]
24102461
Type getTypeArgument(TypeArgumentPosition apos, TypePath path) {
24112462
result = getCallExprTypeArgument(this, apos).resolveTypeAt(path)
24122463
}
24132464

2465+
override predicate hasTypeArgument(TypeArgumentPosition apos) {
2466+
exists(this.getTypeArgument(apos, _))
2467+
}
2468+
24142469
pragma[nomagic]
24152470
Type getInferredType(AccessPosition apos, TypePath path) {
24162471
apos.isSelf() and
@@ -2431,7 +2486,12 @@ pragma[nomagic]
24312486
private Type inferNonMethodCallType(AstNode n, TypePath path) {
24322487
exists(NonMethodCallMatchingInput::Access a, NonMethodCallMatchingInput::AccessPosition apos |
24332488
n = a.getNodeAt(apos) and
2489+
if not apos.isReturn() then path.startsWith(getContextTypePath(n)) else any()
2490+
|
24342491
result = NonMethodCallMatching::inferAccessType(a, apos, path)
2492+
or
2493+
a.isContextTypedAt([a.resolveCallTarget().(Function), a.resolveTraitFunction()], path, apos) and
2494+
result = TContextType()
24352495
)
24362496
}
24372497

@@ -2510,7 +2570,8 @@ pragma[nomagic]
25102570
private Type inferOperationType(AstNode n, TypePath path) {
25112571
exists(OperationMatchingInput::Access a, OperationMatchingInput::AccessPosition apos |
25122572
n = a.getNodeAt(apos) and
2513-
result = OperationMatching::inferAccessType(a, apos, path)
2573+
result = OperationMatching::inferAccessType(a, apos, path) and
2574+
if not apos.isReturn() then path.startsWith(getContextTypePath(n)) else any()
25142575
)
25152576
}
25162577

@@ -3291,8 +3352,10 @@ private module Debug {
32913352
Locatable getRelevantLocatable() {
32923353
exists(string filepath, int startline, int startcolumn, int endline, int endcolumn |
32933354
result.getLocation().hasLocationInfo(filepath, startline, startcolumn, endline, endcolumn) and
3294-
filepath.matches("%/sqlx.rs") and
3295-
startline = [56 .. 60]
3355+
filepath.matches("%/crate/data_derive/src/lib.rs") and
3356+
startline = [48, 74]
3357+
// filepath.matches("%/main.rs") and
3358+
// startline = [2525]
32963359
)
32973360
}
32983361

@@ -3355,7 +3418,7 @@ private module Debug {
33553418
}
33563419

33573420
predicate countTypesForNodeAtLimit(AstNode n, int c) {
3358-
n = getRelevantLocatable() and
3421+
// n = getRelevantLocatable() and
33593422
c = strictcount(Type t, TypePath path | t = debugInferTypeForNodeAtLimit(n, path))
33603423
}
33613424

rust/ql/lib/codeql/rust/internal/typeinference/FunctionType.qll

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -418,3 +418,24 @@ module ArgsAreInstantiationsOf<ArgsAreInstantiationsOfInputSig Input> {
418418
)
419419
}
420420
}
421+
422+
/**
423+
* Holds if the return type of the function `f` at path `path` is `tp`,
424+
* and `tp` does not appear in the type of any parameter of `f`.
425+
*
426+
* In this case, the context in which `f` is called may be needed to infer
427+
* the instantiation of `tp`.
428+
*/
429+
pragma[nomagic]
430+
predicate assocFunctionReturnContextTypedAt(
431+
Function f, FunctionPosition pos, TypePath path, TypeParameter tp
432+
) {
433+
exists(ImplOrTraitItemNode i |
434+
pos.isReturn() and
435+
assocFunctionTypeAt(f, i, pos, path, tp) and
436+
not exists(FunctionPosition nonResPos |
437+
not nonResPos.isReturn() and
438+
assocFunctionTypeAt(f, i, nonResPos, _, tp)
439+
)
440+
)
441+
}
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
multipleCallTargets
2+
| test.rs:113:62:113:77 | ...::from(...) |
3+
| test.rs:120:58:120:73 | ...::from(...) |
4+
| test.rs:229:22:229:72 | ... .read_to_string(...) |
5+
| test.rs:911:24:911:34 | row.take(...) |
6+
| test.rs:998:24:998:34 | row.take(...) |
7+
| test.rs:1096:50:1096:66 | ...::from(...) |
8+
| test.rs:1096:50:1096:66 | ...::from(...) |
9+
| test_futures_io.rs:35:26:35:63 | pinned.poll_read(...) |
10+
| test_futures_io.rs:62:22:62:50 | pinned.poll_fill_buf(...) |
11+
| test_futures_io.rs:69:23:69:67 | ... .poll_fill_buf(...) |
12+
| test_futures_io.rs:93:26:93:63 | pinned.poll_read(...) |
13+
| test_futures_io.rs:116:22:116:50 | pinned.poll_fill_buf(...) |
14+
multiplePathResolutions
15+
| test.rs:897:28:897:65 | Result::<...> |
16+
| test.rs:984:40:984:49 | Result::<...> |
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,6 @@
1+
multipleCallTargets
2+
| test.rs:24:24:24:34 | row.take(...) |
3+
| test.rs:111:24:111:34 | row.take(...) |
14
multiplePathResolutions
25
| test.rs:10:28:10:65 | Result::<...> |
36
| test.rs:97:40:97:49 | Result::<...> |
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
multipleCallTargets
2+
| test.rs:288:7:288:36 | ... .as_str() |

0 commit comments

Comments
 (0)