Skip to content

Commit 13375d0

Browse files
authored
[ty] Use the top materialization of classes for narrowing in class-patterns for match statements (#21150)
1 parent c0b04d4 commit 13375d0

File tree

4 files changed

+113
-8
lines changed

4 files changed

+113
-8
lines changed

crates/ty_python_semantic/resources/mdtest/exhaustiveness_checking.md

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -182,13 +182,21 @@ def match_non_exhaustive(x: Color):
182182

183183
## `isinstance` checks
184184

185+
```toml
186+
[environment]
187+
python-version = "3.12"
188+
```
189+
185190
```py
186191
from typing import assert_never
187192

188193
class A: ...
189194
class B: ...
190195
class C: ...
191196

197+
class GenericClass[T]:
198+
x: T
199+
192200
def if_else_exhaustive(x: A | B | C):
193201
if isinstance(x, A):
194202
pass
@@ -253,6 +261,17 @@ def match_non_exhaustive(x: A | B | C):
253261

254262
# this diagnostic is correct: the inferred type of `x` is `B & ~A & ~C`
255263
assert_never(x) # error: [type-assertion-failure]
264+
265+
# Note: no invalid-return-type diagnostic; the `match` is exhaustive
266+
def match_exhaustive_generic[T](obj: GenericClass[T]) -> GenericClass[T]:
267+
match obj:
268+
case GenericClass(x=42):
269+
reveal_type(obj) # revealed: GenericClass[T@match_exhaustive_generic]
270+
return obj
271+
case GenericClass(x=x):
272+
reveal_type(x) # revealed: @Todo(`match` pattern definition types)
273+
reveal_type(obj) # revealed: GenericClass[T@match_exhaustive_generic]
274+
return obj
256275
```
257276

258277
## `isinstance` checks with generics

crates/ty_python_semantic/resources/mdtest/narrow/match.md

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,81 @@ match x:
6969
reveal_type(x) # revealed: object
7070
```
7171

72+
## Class patterns with generic classes
73+
74+
```toml
75+
[environment]
76+
python-version = "3.12"
77+
```
78+
79+
```py
80+
from typing import assert_never
81+
82+
class Covariant[T]:
83+
def get(self) -> T:
84+
raise NotImplementedError
85+
86+
def f(x: Covariant[int]):
87+
match x:
88+
case Covariant():
89+
reveal_type(x) # revealed: Covariant[int]
90+
case _:
91+
reveal_type(x) # revealed: Never
92+
assert_never(x)
93+
```
94+
95+
## Class patterns with generic `@final` classes
96+
97+
These work the same as non-`@final` classes.
98+
99+
```toml
100+
[environment]
101+
python-version = "3.12"
102+
```
103+
104+
```py
105+
from typing import assert_never, final
106+
107+
@final
108+
class Covariant[T]:
109+
def get(self) -> T:
110+
raise NotImplementedError
111+
112+
def f(x: Covariant[int]):
113+
match x:
114+
case Covariant():
115+
reveal_type(x) # revealed: Covariant[int]
116+
case _:
117+
reveal_type(x) # revealed: Never
118+
assert_never(x)
119+
```
120+
121+
## Class patterns where the class pattern does not resolve to a class
122+
123+
In general this does not allow for narrowing, but we make an exception for `Any`. This is to support
124+
[real ecosystem code](https://github.com/jax-ml/jax/blob/d2ce04b6c3d03ae18b145965b8b8b92e09e8009c/jax/_src/pallas/mosaic_gpu/lowering.py#L3372-L3387)
125+
found in `jax`.
126+
127+
```py
128+
from typing import Any
129+
130+
X = Any
131+
132+
def f(obj: object):
133+
match obj:
134+
case int():
135+
reveal_type(obj) # revealed: int
136+
case X():
137+
reveal_type(obj) # revealed: Any & ~int
138+
139+
def g(obj: object, Y: Any):
140+
match obj:
141+
case int():
142+
reveal_type(obj) # revealed: int
143+
case Y():
144+
reveal_type(obj) # revealed: Any & ~int
145+
```
146+
72147
## Value patterns
73148

74149
Value patterns are evaluated by equality, which is overridable. Therefore successfully matching on

crates/ty_python_semantic/src/semantic_index/reachability_constraints.rs

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -771,8 +771,9 @@ impl ReachabilityConstraints {
771771
truthiness
772772
}
773773
PatternPredicateKind::Class(class_expr, kind) => {
774-
let class_ty =
775-
infer_expression_type(db, *class_expr, TypeContext::default()).to_instance(db);
774+
let class_ty = infer_expression_type(db, *class_expr, TypeContext::default())
775+
.as_class_literal()
776+
.map(|class| Type::instance(db, class.top_materialization(db)));
776777

777778
class_ty.map_or(Truthiness::Ambiguous, |class_ty| {
778779
if subject_ty.is_subtype_of(db, class_ty) {

crates/ty_python_semantic/src/types/narrow.rs

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,9 @@ use crate::types::enums::{enum_member_literals, enum_metadata};
1111
use crate::types::function::KnownFunction;
1212
use crate::types::infer::infer_same_file_expression_type;
1313
use crate::types::{
14-
ClassLiteral, ClassType, IntersectionBuilder, KnownClass, SubclassOfInner, SubclassOfType,
15-
Truthiness, Type, TypeContext, TypeVarBoundOrConstraints, UnionBuilder, infer_expression_types,
14+
ClassLiteral, ClassType, IntersectionBuilder, KnownClass, SpecialFormType, SubclassOfInner,
15+
SubclassOfType, Truthiness, Type, TypeContext, TypeVarBoundOrConstraints, UnionBuilder,
16+
infer_expression_types,
1617
};
1718

1819
use ruff_db::parsed::{ParsedModuleRef, parsed_module};
@@ -962,11 +963,20 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> {
962963
let subject = place_expr(subject.node_ref(self.db, self.module))?;
963964
let place = self.expect_place(&subject);
964965

965-
let ty = infer_same_file_expression_type(self.db, cls, TypeContext::default(), self.module)
966-
.to_instance(self.db)?;
966+
let class_type =
967+
infer_same_file_expression_type(self.db, cls, TypeContext::default(), self.module);
967968

968-
let ty = ty.negate_if(self.db, !is_positive);
969-
Some(NarrowingConstraints::from_iter([(place, ty)]))
969+
let narrowed_type = match class_type {
970+
Type::ClassLiteral(class) => {
971+
Type::instance(self.db, class.top_materialization(self.db))
972+
.negate_if(self.db, !is_positive)
973+
}
974+
dynamic @ Type::Dynamic(_) => dynamic,
975+
Type::SpecialForm(SpecialFormType::Any) => Type::any(),
976+
_ => return None,
977+
};
978+
979+
Some(NarrowingConstraints::from_iter([(place, narrowed_type)]))
970980
}
971981

972982
fn evaluate_match_pattern_value(

0 commit comments

Comments
 (0)