Skip to content

Commit 7d21694

Browse files
committed
Rust: Account for borrows in operators in type inference
1 parent ccce99c commit 7d21694

File tree

5 files changed

+50
-44
lines changed

5 files changed

+50
-44
lines changed

rust/ql/lib/codeql/rust/elements/internal/CallImpl.qll

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -133,14 +133,19 @@ module Impl {
133133
private class OperatorCall extends Call instanceof Operation {
134134
Trait trait;
135135
string methodName;
136+
int borrows;
136137

137-
OperatorCall() { super.isOverloaded(trait, methodName) }
138+
OperatorCall() { super.isOverloaded(trait, methodName, borrows) }
138139

139140
override string getMethodName() { result = methodName }
140141

141142
override Trait getTrait() { result = trait }
142143

143-
override predicate implicitBorrowAt(ArgumentPosition pos) { none() }
144+
override predicate implicitBorrowAt(ArgumentPosition pos) {
145+
pos.isSelf() and borrows >= 1
146+
or
147+
pos.asPosition() = 0 and borrows = 2
148+
}
144149

145150
override Expr getArgument(ArgumentPosition pos) {
146151
pos.isSelf() and result = super.getOperand(0)

rust/ql/lib/codeql/rust/elements/internal/OperationImpl.qll

Lines changed: 34 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -9,79 +9,80 @@ private import codeql.rust.elements.internal.ExprImpl::Impl as ExprImpl
99

1010
/**
1111
* Holds if the operator `op` with arity `arity` is overloaded to a trait with
12-
* the canonical path `path` and the method name `method`.
12+
* the canonical path `path` and the method name `method`, and if it borrows its
13+
* first `borrows` arguments.
1314
*/
14-
private predicate isOverloaded(string op, int arity, string path, string method) {
15+
private predicate isOverloaded(string op, int arity, string path, string method, int borrows) {
1516
arity = 1 and
1617
(
1718
// Negation
18-
op = "-" and path = "core::ops::arith::Neg" and method = "neg"
19+
op = "-" and path = "core::ops::arith::Neg" and method = "neg" and borrows = 0
1920
or
2021
// Not
21-
op = "!" and path = "core::ops::bit::Not" and method = "not"
22+
op = "!" and path = "core::ops::bit::Not" and method = "not" and borrows = 0
2223
or
2324
// Dereference
24-
op = "*" and path = "core::ops::deref::Deref" and method = "deref"
25+
op = "*" and path = "core::ops::deref::Deref" and method = "deref" and borrows = 0
2526
)
2627
or
2728
arity = 2 and
2829
(
2930
// Comparison operators
30-
op = "==" and path = "core::cmp::PartialEq" and method = "eq"
31+
op = "==" and path = "core::cmp::PartialEq" and method = "eq" and borrows = 2
3132
or
32-
op = "!=" and path = "core::cmp::PartialEq" and method = "ne"
33+
op = "!=" and path = "core::cmp::PartialEq" and method = "ne" and borrows = 2
3334
or
34-
op = "<" and path = "core::cmp::PartialOrd" and method = "lt"
35+
op = "<" and path = "core::cmp::PartialOrd" and method = "lt" and borrows = 2
3536
or
36-
op = "<=" and path = "core::cmp::PartialOrd" and method = "le"
37+
op = "<=" and path = "core::cmp::PartialOrd" and method = "le" and borrows = 2
3738
or
38-
op = ">" and path = "core::cmp::PartialOrd" and method = "gt"
39+
op = ">" and path = "core::cmp::PartialOrd" and method = "gt" and borrows = 2
3940
or
40-
op = ">=" and path = "core::cmp::PartialOrd" and method = "ge"
41+
op = ">=" and path = "core::cmp::PartialOrd" and method = "ge" and borrows = 2
4142
or
4243
// Arithmetic operators
43-
op = "+" and path = "core::ops::arith::Add" and method = "add"
44+
op = "+" and path = "core::ops::arith::Add" and method = "add" and borrows = 0
4445
or
45-
op = "-" and path = "core::ops::arith::Sub" and method = "sub"
46+
op = "-" and path = "core::ops::arith::Sub" and method = "sub" and borrows = 0
4647
or
47-
op = "*" and path = "core::ops::arith::Mul" and method = "mul"
48+
op = "*" and path = "core::ops::arith::Mul" and method = "mul" and borrows = 0
4849
or
49-
op = "/" and path = "core::ops::arith::Div" and method = "div"
50+
op = "/" and path = "core::ops::arith::Div" and method = "div" and borrows = 0
5051
or
51-
op = "%" and path = "core::ops::arith::Rem" and method = "rem"
52+
op = "%" and path = "core::ops::arith::Rem" and method = "rem" and borrows = 0
5253
or
5354
// Arithmetic assignment expressions
54-
op = "+=" and path = "core::ops::arith::AddAssign" and method = "add_assign"
55+
op = "+=" and path = "core::ops::arith::AddAssign" and method = "add_assign" and borrows = 1
5556
or
56-
op = "-=" and path = "core::ops::arith::SubAssign" and method = "sub_assign"
57+
op = "-=" and path = "core::ops::arith::SubAssign" and method = "sub_assign" and borrows = 1
5758
or
58-
op = "*=" and path = "core::ops::arith::MulAssign" and method = "mul_assign"
59+
op = "*=" and path = "core::ops::arith::MulAssign" and method = "mul_assign" and borrows = 1
5960
or
60-
op = "/=" and path = "core::ops::arith::DivAssign" and method = "div_assign"
61+
op = "/=" and path = "core::ops::arith::DivAssign" and method = "div_assign" and borrows = 1
6162
or
62-
op = "%=" and path = "core::ops::arith::RemAssign" and method = "rem_assign"
63+
op = "%=" and path = "core::ops::arith::RemAssign" and method = "rem_assign" and borrows = 1
6364
or
6465
// Bitwise operators
65-
op = "&" and path = "core::ops::bit::BitAnd" and method = "bitand"
66+
op = "&" and path = "core::ops::bit::BitAnd" and method = "bitand" and borrows = 0
6667
or
67-
op = "|" and path = "core::ops::bit::BitOr" and method = "bitor"
68+
op = "|" and path = "core::ops::bit::BitOr" and method = "bitor" and borrows = 0
6869
or
69-
op = "^" and path = "core::ops::bit::BitXor" and method = "bitxor"
70+
op = "^" and path = "core::ops::bit::BitXor" and method = "bitxor" and borrows = 0
7071
or
71-
op = "<<" and path = "core::ops::bit::Shl" and method = "shl"
72+
op = "<<" and path = "core::ops::bit::Shl" and method = "shl" and borrows = 0
7273
or
73-
op = ">>" and path = "core::ops::bit::Shr" and method = "shr"
74+
op = ">>" and path = "core::ops::bit::Shr" and method = "shr" and borrows = 0
7475
or
7576
// Bitwise assignment operators
76-
op = "&=" and path = "core::ops::bit::BitAndAssign" and method = "bitand_assign"
77+
op = "&=" and path = "core::ops::bit::BitAndAssign" and method = "bitand_assign" and borrows = 1
7778
or
78-
op = "|=" and path = "core::ops::bit::BitOrAssign" and method = "bitor_assign"
79+
op = "|=" and path = "core::ops::bit::BitOrAssign" and method = "bitor_assign" and borrows = 1
7980
or
80-
op = "^=" and path = "core::ops::bit::BitXorAssign" and method = "bitxor_assign"
81+
op = "^=" and path = "core::ops::bit::BitXorAssign" and method = "bitxor_assign" and borrows = 1
8182
or
82-
op = "<<=" and path = "core::ops::bit::ShlAssign" and method = "shl_assign"
83+
op = "<<=" and path = "core::ops::bit::ShlAssign" and method = "shl_assign" and borrows = 1
8384
or
84-
op = ">>=" and path = "core::ops::bit::ShrAssign" and method = "shr_assign"
85+
op = ">>=" and path = "core::ops::bit::ShrAssign" and method = "shr_assign" and borrows = 1
8586
)
8687
}
8788

@@ -114,9 +115,9 @@ module Impl {
114115
* Holds if this operation is overloaded to the method `methodName` of the
115116
* trait `trait`.
116117
*/
117-
predicate isOverloaded(Trait trait, string methodName) {
118+
predicate isOverloaded(Trait trait, string methodName, int borrows) {
118119
isOverloaded(this.getOperatorName(), this.getNumberOfOperands(), trait.getCanonicalPath(),
119-
methodName)
120+
methodName, borrows)
120121
}
121122
}
122123
}

rust/ql/lib/codeql/rust/internal/TypeInference.qll

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -705,7 +705,7 @@ private module CallExprBaseMatchingInput implements MatchingInputSig {
705705
predicate adjustAccessType(
706706
AccessPosition apos, Declaration target, TypePath path, Type t, TypePath pathAdj, Type tAdj
707707
) {
708-
if apos.getArgumentPosition().isSelf() and apos.isBorrowed()
708+
if apos.isBorrowed()
709709
then
710710
exists(Type selfParamType |
711711
selfParamType =
@@ -767,13 +767,10 @@ private Type inferCallExprBaseType(AstNode n, TypePath path) {
767767
|
768768
n = a.getNodeAt(apos) and
769769
result = CallExprBaseMatching::inferAccessType(a, apos, path0) and
770-
// temporary workaround until implicit borrows are handled correctly
771-
if a instanceof Operation then apos.isReturn() else any()
772-
|
773-
if apos.getArgumentPosition().isSelf()
770+
if apos.isBorrowed()
774771
then
775-
exists(Type receiverType | receiverType = inferType(n) |
776-
if receiverType = TRefType()
772+
exists(Type argType | argType = inferType(n) |
773+
if argType = TRefType()
777774
then
778775
path = path0 and
779776
path0.isCons(TRefTypeParameter(), _)
@@ -784,7 +781,7 @@ private Type inferCallExprBaseType(AstNode n, TypePath path) {
784781
path = TypePath::cons(TRefTypeParameter(), path0)
785782
else (
786783
not (
787-
receiverType.(StructType).asItemNode() instanceof StringStruct and
784+
argType.(StructType).asItemNode() instanceof StringStruct and
788785
result.(StructType).asItemNode() instanceof Builtins::Str
789786
) and
790787
(

rust/ql/test/library-tests/type-inference/main.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1678,7 +1678,7 @@ mod overloadable_operators {
16781678
let vec2_not = !v1; // $ type=vec2_not:Vec2 method=Vec2::not
16791679

16801680
// Here the type of `default_vec2` must be inferred from the `+` call.
1681-
let default_vec2 = Default::default(); // $ MISSING: type=default_vec2:Vec2
1681+
let default_vec2 = Default::default(); // $ type=default_vec2:Vec2
16821682
let vec2_zero_plus = Vec2 { x: 0, y: 0 } + default_vec2; // $ method=Vec2::add
16831683
}
16841684
}

rust/ql/test/library-tests/type-inference/type-inference.expected

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2464,13 +2464,16 @@ inferType
24642464
| main.rs:1678:13:1678:20 | vec2_not | | main.rs:1322:5:1327:5 | Vec2 |
24652465
| main.rs:1678:24:1678:26 | ! ... | | main.rs:1322:5:1327:5 | Vec2 |
24662466
| main.rs:1678:25:1678:26 | v1 | | main.rs:1322:5:1327:5 | Vec2 |
2467+
| main.rs:1681:13:1681:24 | default_vec2 | | main.rs:1322:5:1327:5 | Vec2 |
2468+
| main.rs:1681:28:1681:45 | ...::default(...) | | main.rs:1322:5:1327:5 | Vec2 |
24672469
| main.rs:1682:13:1682:26 | vec2_zero_plus | | main.rs:1322:5:1327:5 | Vec2 |
24682470
| main.rs:1682:30:1682:48 | Vec2 {...} | | main.rs:1322:5:1327:5 | Vec2 |
24692471
| main.rs:1682:30:1682:63 | ... + ... | | main.rs:1322:5:1327:5 | Vec2 |
24702472
| main.rs:1682:40:1682:40 | 0 | | {EXTERNAL LOCATION} | i32 |
24712473
| main.rs:1682:40:1682:40 | 0 | | {EXTERNAL LOCATION} | i64 |
24722474
| main.rs:1682:46:1682:46 | 0 | | {EXTERNAL LOCATION} | i32 |
24732475
| main.rs:1682:46:1682:46 | 0 | | {EXTERNAL LOCATION} | i64 |
2476+
| main.rs:1682:52:1682:63 | default_vec2 | | main.rs:1322:5:1327:5 | Vec2 |
24742477
| main.rs:1692:18:1692:21 | SelfParam | | main.rs:1689:5:1689:14 | S1 |
24752478
| main.rs:1695:25:1697:5 | { ... } | | main.rs:1689:5:1689:14 | S1 |
24762479
| main.rs:1696:9:1696:10 | S1 | | main.rs:1689:5:1689:14 | S1 |

0 commit comments

Comments
 (0)