Skip to content

Commit 7684e01

Browse files
committed
Rust: Use Call in type inference
1 parent 4786478 commit 7684e01

File tree

1 file changed

+35
-180
lines changed

1 file changed

+35
-180
lines changed

rust/ql/lib/codeql/rust/internal/TypeInference.qll

Lines changed: 35 additions & 180 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ private import TypeMention
88
private import codeql.typeinference.internal.TypeInference
99
private import codeql.rust.frameworks.stdlib.Stdlib
1010
private import codeql.rust.frameworks.stdlib.Builtins as Builtins
11+
private import codeql.rust.elements.Call
1112

1213
class Type = T::Type;
1314

@@ -496,28 +497,25 @@ private Type inferPathExprType(PathExpr pe, TypePath path) {
496497
* like `foo::bar(baz)` and `foo.bar(baz)`.
497498
*/
498499
private module CallExprBaseMatchingInput implements MatchingInputSig {
499-
private predicate paramPos(ParamList pl, Param p, int pos, boolean inMethod) {
500-
p = pl.getParam(pos) and
501-
if pl.hasSelfParam() then inMethod = true else inMethod = false
502-
}
500+
private predicate paramPos(ParamList pl, Param p, int pos) { p = pl.getParam(pos) }
503501

504502
private newtype TDeclarationPosition =
505503
TSelfDeclarationPosition() or
506-
TPositionalDeclarationPosition(int pos, boolean inMethod) { paramPos(_, _, pos, inMethod) } or
504+
TPositionalDeclarationPosition(int pos) { paramPos(_, _, pos) } or
507505
TReturnDeclarationPosition()
508506

509507
class DeclarationPosition extends TDeclarationPosition {
510508
predicate isSelf() { this = TSelfDeclarationPosition() }
511509

512-
int asPosition(boolean inMethod) { this = TPositionalDeclarationPosition(result, inMethod) }
510+
int asPosition() { this = TPositionalDeclarationPosition(result) }
513511

514512
predicate isReturn() { this = TReturnDeclarationPosition() }
515513

516514
string toString() {
517515
this.isSelf() and
518516
result = "self"
519517
or
520-
result = this.asPosition(_).toString()
518+
result = this.asPosition().toString()
521519
or
522520
this.isReturn() and
523521
result = "(return)"
@@ -550,7 +548,7 @@ private module CallExprBaseMatchingInput implements MatchingInputSig {
550548
override Type getParameterType(DeclarationPosition dpos, TypePath path) {
551549
exists(int pos |
552550
result = this.getTupleField(pos).getTypeRepr().(TypeMention).resolveTypeAt(path) and
553-
dpos = TPositionalDeclarationPosition(pos, false)
551+
dpos = TPositionalDeclarationPosition(pos)
554552
)
555553
}
556554

@@ -573,7 +571,7 @@ private module CallExprBaseMatchingInput implements MatchingInputSig {
573571
override Type getParameterType(DeclarationPosition dpos, TypePath path) {
574572
exists(int p |
575573
result = this.getTupleField(p).getTypeRepr().(TypeMention).resolveTypeAt(path) and
576-
dpos = TPositionalDeclarationPosition(p, false)
574+
dpos = TPositionalDeclarationPosition(p)
577575
)
578576
}
579577

@@ -606,9 +604,9 @@ private module CallExprBaseMatchingInput implements MatchingInputSig {
606604
}
607605

608606
override Type getParameterType(DeclarationPosition dpos, TypePath path) {
609-
exists(Param p, int i, boolean inMethod |
610-
paramPos(this.getParamList(), p, i, inMethod) and
611-
dpos = TPositionalDeclarationPosition(i, inMethod) and
607+
exists(Param p, int i |
608+
paramPos(this.getParamList(), p, i) and
609+
dpos = TPositionalDeclarationPosition(i) and
612610
result = inferAnnotatedType(p.getPat(), path)
613611
)
614612
or
@@ -640,125 +638,44 @@ private module CallExprBaseMatchingInput implements MatchingInputSig {
640638
}
641639
}
642640

643-
private predicate argPos(CallExprBase call, Expr e, int pos, boolean isMethodCall) {
644-
exists(ArgList al |
645-
e = al.getArg(pos) and
646-
call.getArgList() = al and
647-
if call instanceof MethodCallExpr then isMethodCall = true else isMethodCall = false
648-
)
649-
}
650-
651-
private newtype TAccessPosition =
652-
TSelfAccessPosition() or
653-
TPositionalAccessPosition(int pos, boolean isMethodCall) { argPos(_, _, pos, isMethodCall) } or
654-
TReturnAccessPosition()
655-
656-
class AccessPosition extends TAccessPosition {
657-
predicate isSelf() { this = TSelfAccessPosition() }
658-
659-
int asPosition(boolean isMethodCall) { this = TPositionalAccessPosition(result, isMethodCall) }
660-
661-
predicate isReturn() { this = TReturnAccessPosition() }
662-
663-
string toString() {
664-
this.isSelf() and
665-
result = "self"
666-
or
667-
result = this.asPosition(_).toString()
668-
or
669-
this.isReturn() and
670-
result = "(return)"
671-
}
672-
}
641+
class AccessPosition = DeclarationPosition;
673642

674643
private import codeql.rust.elements.internal.CallExprImpl::Impl as CallExprImpl
675644

676-
abstract class Access extends Expr {
677-
abstract Type getTypeArgument(TypeArgumentPosition apos, TypePath path);
678-
679-
abstract AstNode getNodeAt(AccessPosition apos);
680-
681-
abstract Type getInferredType(AccessPosition apos, TypePath path);
682-
683-
abstract Declaration getTarget();
684-
}
685-
686-
private class CallExprBaseAccess extends Access instanceof CallExprBase {
687-
private TypeMention getMethodTypeArg(int i) {
688-
result = this.(MethodCallExpr).getGenericArgList().getTypeArg(i)
689-
}
690-
691-
override Type getTypeArgument(TypeArgumentPosition apos, TypePath path) {
645+
final class Access extends Call {
646+
Type getTypeArgument(TypeArgumentPosition apos, TypePath path) {
692647
exists(TypeMention arg | result = arg.resolveTypeAt(path) |
693648
arg = getExplicitTypeArgMention(CallExprImpl::getFunctionPath(this), apos.asTypeParam())
694649
or
695-
arg = this.getMethodTypeArg(apos.asMethodTypeArgumentPosition())
650+
arg =
651+
this.(MethodCallExpr).getGenericArgList().getTypeArg(apos.asMethodTypeArgumentPosition())
696652
)
697653
}
698654

699-
override AstNode getNodeAt(AccessPosition apos) {
700-
exists(int p, boolean isMethodCall |
701-
argPos(this, result, p, isMethodCall) and
702-
apos = TPositionalAccessPosition(p, isMethodCall)
703-
)
655+
AstNode getNodeAt(AccessPosition apos) {
656+
result = this.getArgument(apos.asPosition())
704657
or
705-
result = this.(MethodCallExpr).getReceiver() and
706-
apos = TSelfAccessPosition()
658+
result = this.getReceiver() and apos.isSelf()
707659
or
708-
result = this and
709-
apos = TReturnAccessPosition()
660+
result = this and apos.isReturn()
710661
}
711662

712-
override Type getInferredType(AccessPosition apos, TypePath path) {
663+
Type getInferredType(AccessPosition apos, TypePath path) {
713664
result = inferType(this.getNodeAt(apos), path)
714665
}
715666

716-
override Declaration getTarget() {
717-
result = CallExprImpl::getResolvedFunction(this)
718-
or
667+
Declaration getTarget() {
719668
result = inferMethodCallTarget(this) // mutual recursion; resolving method calls requires resolving types and vice versa
720-
}
721-
}
722-
723-
private class OperationAccess extends Access instanceof Operation {
724-
OperationAccess() { super.isOverloaded(_, _) }
725-
726-
override Type getTypeArgument(TypeArgumentPosition apos, TypePath path) {
727-
// The syntax for operators does not allow type arguments.
728-
none()
729-
}
730-
731-
override AstNode getNodeAt(AccessPosition apos) {
732-
result = super.getOperand(0) and apos = TSelfAccessPosition()
733-
or
734-
result = super.getOperand(1) and apos = TPositionalAccessPosition(0, true)
735669
or
736-
result = this and apos = TReturnAccessPosition()
737-
}
738-
739-
override Type getInferredType(AccessPosition apos, TypePath path) {
740-
result = inferType(this.getNodeAt(apos), path)
741-
}
742-
743-
override Declaration getTarget() {
744-
result = inferMethodCallTarget(this) // mutual recursion; resolving method calls requires resolving types and vice versa
670+
result = CallExprImpl::getResolvedFunction(this)
745671
}
746672
}
747673

748674
predicate accessDeclarationPositionMatch(AccessPosition apos, DeclarationPosition dpos) {
749675
apos.isSelf() and
750676
dpos.isSelf()
751677
or
752-
exists(int pos, boolean isMethodCall | pos = apos.asPosition(isMethodCall) |
753-
pos = 0 and
754-
isMethodCall = false and
755-
dpos.isSelf()
756-
or
757-
isMethodCall = false and
758-
pos = dpos.asPosition(true) + 1
759-
or
760-
pos = dpos.asPosition(isMethodCall)
761-
)
678+
apos.asPosition() = dpos.asPosition()
762679
or
763680
apos.isReturn() and
764681
dpos.isReturn()
@@ -1180,91 +1097,29 @@ private Type inferIndexExprType(IndexExpr ie, TypePath path) {
11801097
)
11811098
}
11821099

1183-
private module MethodCall {
1184-
/** An expression that calls a method. */
1185-
abstract private class MethodCallImpl extends Expr {
1186-
/** Gets the name of the method targeted. */
1187-
abstract string getMethodName();
1188-
1189-
/** Gets the number of arguments _excluding_ the `self` argument. */
1190-
abstract int getArity();
1191-
1192-
/** Gets the trait targeted by this method call, if any. */
1193-
Trait getTrait() { none() }
1194-
1195-
/** Gets the type of the receiver of the method call at `path`. */
1196-
abstract Type getTypeAt(TypePath path);
1100+
final class MethodCall extends Call {
1101+
MethodCall() {
1102+
exists(this.getReceiver()) and
1103+
// We want the method calls that don't have a path to a concrete method in
1104+
// an impl block. We need to exclude calls like `MyType::my_method(..)`.
1105+
(this instanceof CallExpr implies exists(this.getTrait()))
11971106
}
11981107

1199-
final class MethodCall = MethodCallImpl;
1200-
1201-
private class MethodCallExprMethodCall extends MethodCallImpl instanceof MethodCallExpr {
1202-
override string getMethodName() { result = super.getIdentifier().getText() }
1203-
1204-
override int getArity() { result = super.getArgList().getNumberOfArgs() }
1205-
1206-
pragma[nomagic]
1207-
override Type getTypeAt(TypePath path) {
1108+
/** Gets the type of the receiver of the method call at `path`. */
1109+
Type getTypeAt(TypePath path) {
1110+
if this.receiverImplicitlyBorrowed()
1111+
then
12081112
exists(TypePath path0 | result = inferType(super.getReceiver(), path0) |
12091113
path0.isCons(TRefTypeParameter(), path)
12101114
or
12111115
not path0.isCons(TRefTypeParameter(), _) and
12121116
not (path0.isEmpty() and result = TRefType()) and
12131117
path = path0
12141118
)
1215-
}
1216-
}
1217-
1218-
private class CallExprMethodCall extends MethodCallImpl instanceof CallExpr {
1219-
TraitItemNode trait;
1220-
string methodName;
1221-
Expr receiver;
1222-
1223-
CallExprMethodCall() {
1224-
receiver = this.getArg(0) and
1225-
exists(Path path, Function f |
1226-
path = this.getFunction().(PathExpr).getPath() and
1227-
f = resolvePath(path) and
1228-
f.getParamList().hasSelfParam() and
1229-
trait = resolvePath(path.getQualifier()) and
1230-
trait.getAnAssocItem() = f and
1231-
path.getSegment().getIdentifier().getText() = methodName
1232-
)
1233-
}
1234-
1235-
override string getMethodName() { result = methodName }
1236-
1237-
override int getArity() { result = super.getArgList().getNumberOfArgs() - 1 }
1238-
1239-
override Trait getTrait() { result = trait }
1240-
1241-
pragma[nomagic]
1242-
override Type getTypeAt(TypePath path) { result = inferType(receiver, path) }
1243-
}
1244-
1245-
private class OperationMethodCall extends MethodCallImpl instanceof Operation {
1246-
TraitItemNode trait;
1247-
string methodName;
1248-
1249-
OperationMethodCall() { super.isOverloaded(trait, methodName) }
1250-
1251-
override string getMethodName() { result = methodName }
1252-
1253-
override int getArity() { result = this.(Operation).getNumberOfOperands() - 1 }
1254-
1255-
override Trait getTrait() { result = trait }
1256-
1257-
pragma[nomagic]
1258-
override Type getTypeAt(TypePath path) {
1259-
result = inferType(this.(BinaryExpr).getLhs(), path)
1260-
or
1261-
result = inferType(this.(PrefixExpr).getExpr(), path)
1262-
}
1119+
else result = inferType(super.getReceiver(), path)
12631120
}
12641121
}
12651122

1266-
import MethodCall
1267-
12681123
/**
12691124
* Holds if a method for `type` with the name `name` and the arity `arity`
12701125
* exists in `impl`.
@@ -1293,7 +1148,7 @@ private module IsInstantiationOfInput implements IsInstantiationOfInputSig<Metho
12931148
private predicate isMethodCall(MethodCall mc, Type rootType, string name, int arity) {
12941149
rootType = mc.getTypeAt(TypePath::nil()) and
12951150
name = mc.getMethodName() and
1296-
arity = mc.getArity()
1151+
arity = mc.getNumberOfArguments()
12971152
}
12981153

12991154
pragma[nomagic]

0 commit comments

Comments
 (0)