diff --git a/crates/red_knot_python_semantic/src/types.rs b/crates/red_knot_python_semantic/src/types.rs index 893ad9f9f738c..46a5aec99d222 100644 --- a/crates/red_knot_python_semantic/src/types.rs +++ b/crates/red_knot_python_semantic/src/types.rs @@ -824,11 +824,9 @@ impl<'db> Type<'db> { /// - Literal[True, False] | T <: bool | T #[must_use] pub fn with_normalized_bools(self, db: &'db dyn Db) -> Self { - const LITERAL_BOOLS: [Type; 2] = [Type::BooleanLiteral(false), Type::BooleanLiteral(true)]; - match self { Type::Instance(InstanceType { class }) if class.is_known(db, KnownClass::Bool) => { - Type::Union(UnionType::new(db, Box::from(LITERAL_BOOLS))) + Type::normalized_bool(db) } // TODO: decompose `LiteralString` into `Literal[""] | TruthyLiteralString`? // We'd need to rename this method... --Alex @@ -884,12 +882,6 @@ impl<'db> Type<'db> { return true; } - let normalized_self = self.with_normalized_bools(db); - let normalized_target = target.with_normalized_bools(db); - if normalized_self != self || normalized_target != target { - return normalized_self.is_subtype_of(db, normalized_target); - } - // Non-fully-static types do not participate in subtyping. // // Type `A` can only be a subtype of type `B` if the set of possible runtime objects @@ -912,6 +904,13 @@ impl<'db> Type<'db> { (Type::Never, _) => true, (_, Type::Never) => false, + (Type::Instance(InstanceType { class }), _) if class.is_known(db, KnownClass::Bool) => { + Type::normalized_bool(db).is_subtype_of(db, target) + } + (_, Type::Instance(InstanceType { class })) if class.is_known(db, KnownClass::Bool) => { + self.is_subtype_of(db, Type::normalized_bool(db)) + } + (Type::Union(union), _) => union .elements(db) .iter() @@ -1108,11 +1107,6 @@ impl<'db> Type<'db> { if self.is_gradual_equivalent_to(db, target) { return true; } - let normalized_self = self.with_normalized_bools(db); - let normalized_target = target.with_normalized_bools(db); - if normalized_self != self || normalized_target != target { - return normalized_self.is_assignable_to(db, normalized_target); - } match (self, target) { // Never can be assigned to any type. (Type::Never, _) => true, @@ -1129,6 +1123,13 @@ impl<'db> Type<'db> { true } + (Type::Instance(InstanceType { class }), _) if class.is_known(db, KnownClass::Bool) => { + Type::normalized_bool(db).is_assignable_to(db, target) + } + (_, Type::Instance(InstanceType { class })) if class.is_known(db, KnownClass::Bool) => { + self.is_assignable_to(db, Type::normalized_bool(db)) + } + // A union is assignable to a type T iff every element of the union is assignable to T. (Type::Union(union), ty) => union .elements(db) @@ -1213,19 +1214,18 @@ impl<'db> Type<'db> { pub(crate) fn is_equivalent_to(self, db: &'db dyn Db, other: Type<'db>) -> bool { // TODO equivalent but not identical types: TypedDicts, Protocols, type aliases, etc. - let normalized_self = self.with_normalized_bools(db); - let normalized_other = other.with_normalized_bools(db); - - if normalized_self != self || normalized_other != other { - return normalized_self.is_equivalent_to(db, normalized_other); - } - match (self, other) { (Type::Union(left), Type::Union(right)) => left.is_equivalent_to(db, right), (Type::Intersection(left), Type::Intersection(right)) => { left.is_equivalent_to(db, right) } (Type::Tuple(left), Type::Tuple(right)) => left.is_equivalent_to(db, right), + (Type::Instance(InstanceType { class }), _) if class.is_known(db, KnownClass::Bool) => { + Type::normalized_bool(db).is_equivalent_to(db, other) + } + (_, Type::Instance(InstanceType { class })) if class.is_known(db, KnownClass::Bool) => { + self.is_equivalent_to(db, Type::normalized_bool(db)) + } _ => self == other && self.is_fully_static(db) && other.is_fully_static(db), } } @@ -1261,13 +1261,6 @@ impl<'db> Type<'db> { /// /// [Summary of type relations]: https://typing.readthedocs.io/en/latest/spec/concepts.html#summary-of-type-relations pub(crate) fn is_gradual_equivalent_to(self, db: &'db dyn Db, other: Type<'db>) -> bool { - let normalized_self = self.with_normalized_bools(db); - let normalized_other = other.with_normalized_bools(db); - - if normalized_self != self || normalized_other != other { - return normalized_self.is_gradual_equivalent_to(db, normalized_other); - } - if self == other { return true; } @@ -1291,6 +1284,13 @@ impl<'db> Type<'db> { first.is_gradual_equivalent_to(db, second) } + (Type::Instance(InstanceType { class }), _) if class.is_known(db, KnownClass::Bool) => { + Type::normalized_bool(db).is_gradual_equivalent_to(db, other) + } + (_, Type::Instance(InstanceType { class })) if class.is_known(db, KnownClass::Bool) => { + self.is_gradual_equivalent_to(db, Type::normalized_bool(db)) + } + _ => false, } } @@ -1300,17 +1300,18 @@ impl<'db> Type<'db> { /// Note: This function aims to have no false positives, but might return /// wrong `false` answers in some cases. pub(crate) fn is_disjoint_from(self, db: &'db dyn Db, other: Type<'db>) -> bool { - let normalized_self = self.with_normalized_bools(db); - let normalized_other = other.with_normalized_bools(db); - if normalized_self != self || normalized_other != other { - return normalized_self.is_disjoint_from(db, normalized_other); - } - match (self, other) { (Type::Never, _) | (_, Type::Never) => true, (Type::Dynamic(_), _) | (_, Type::Dynamic(_)) => false, + (Type::Instance(InstanceType { class }), ty) + | (ty, Type::Instance(InstanceType { class })) + if class.is_known(db, KnownClass::Bool) => + { + Type::normalized_bool(db).is_disjoint_from(db, ty) + } + (Type::Union(union), other) | (other, Type::Union(union)) => union .elements(db) .iter() @@ -2427,6 +2428,13 @@ impl<'db> Type<'db> { KnownClass::NoneType.to_instance(db) } + /// The type `Literal[True, False]`, which is exactly equivalent to `bool` + /// (and which `bool` is eagerly normalized to in several situations) + pub fn normalized_bool(db: &'db dyn Db) -> Type<'db> { + const LITERAL_BOOLS: [Type; 2] = [Type::BooleanLiteral(false), Type::BooleanLiteral(true)]; + Type::Union(UnionType::new(db, Box::from(LITERAL_BOOLS))) + } + /// Return the type of `tuple(sys.version_info)`. /// /// This is not exactly the type that `sys.version_info` has at runtime,