Skip to content

Commit

Permalink
[red-knot] Decompose bool to Literal[True, False] in unions and i…
Browse files Browse the repository at this point in the history
…ntersections
  • Loading branch information
AlexWaygood committed Jan 25, 2025
1 parent f85ea1b commit 0ff14bf
Show file tree
Hide file tree
Showing 8 changed files with 192 additions and 145 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,22 @@ reveal_type(c >= d) # revealed: Literal[True]
#### Results with Ambiguity

```py
def _(x: bool, y: int):
class P:
def __lt__(self, other: "P") -> bool:
return True

def __le__(self, other: "P") -> bool:
return True

def __gt__(self, other: "P") -> bool:
return True

def __ge__(self, other: "P") -> bool:
return True

class Q(P): ...

def _(x: P, y: Q):
a = (x,)
b = (y,)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -455,9 +455,9 @@ else:
reveal_type(x) # revealed: slice
finally:
# TODO: should be `Literal[1] | str | bytes | bool | memoryview | float | range | slice`
reveal_type(x) # revealed: bool | float | slice
reveal_type(x) # revealed: bool | slice | float

reveal_type(x) # revealed: bool | float | slice
reveal_type(x) # revealed: bool | slice | float
```

## Nested `try`/`except` blocks
Expand Down Expand Up @@ -534,7 +534,7 @@ try:
reveal_type(x) # revealed: slice
finally:
# TODO: should be `Literal[1] | str | bytes | bool | memoryview | float | range | slice`
reveal_type(x) # revealed: bool | float | slice
reveal_type(x) # revealed: bool | slice | float
x = 2
reveal_type(x) # revealed: Literal[2]
reveal_type(x) # revealed: Literal[2]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,22 +21,22 @@ else:
if x and not x:
reveal_type(x) # revealed: Never
else:
reveal_type(x) # revealed: Literal[0, -1, "", "foo", b"", b"bar"] | bool | None | tuple[()]
reveal_type(x) # revealed: Literal[0, -1, "", "foo", b"", b"bar"] | bool | tuple[()] | None

if not (x and not x):
reveal_type(x) # revealed: Literal[0, -1, "", "foo", b"", b"bar"] | bool | None | tuple[()]
reveal_type(x) # revealed: Literal[0, -1, "", "foo", b"", b"bar"] | bool | tuple[()] | None
else:
reveal_type(x) # revealed: Never

if x or not x:
reveal_type(x) # revealed: Literal[0, -1, "", "foo", b"", b"bar"] | bool | None | tuple[()]
reveal_type(x) # revealed: Literal[0, -1, "", "foo", b"", b"bar"] | bool | tuple[()] | None
else:
reveal_type(x) # revealed: Never

if not (x or not x):
reveal_type(x) # revealed: Never
else:
reveal_type(x) # revealed: Literal[0, -1, "", "foo", b"", b"bar"] | bool | None | tuple[()]
reveal_type(x) # revealed: Literal[0, -1, "", "foo", b"", b"bar"] | bool | tuple[()] | None

if (isinstance(x, int) or isinstance(x, str)) and x:
reveal_type(x) # revealed: Literal[-1, True, "foo"]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -106,4 +106,48 @@ static_assert(
)
```

## Unions containing tuples containing `bool`

```py
from knot_extensions import is_equivalent_to, static_assert
from typing_extensions import Literal

class P: ...

static_assert(is_equivalent_to(tuple[Literal[True, False]] | P, tuple[bool] | P))
static_assert(is_equivalent_to(P | tuple[bool], P | tuple[Literal[True, False]]))
```

## Unions and intersections involving `AlwaysTruthy`, `bool` and `AlwaysFalsy`

```py
from knot_extensions import AlwaysTruthy, AlwaysFalsy, static_assert, is_equivalent_to, Not
from typing_extensions import Literal

static_assert(is_equivalent_to(AlwaysTruthy | bool, Literal[False] | AlwaysTruthy))
static_assert(is_equivalent_to(AlwaysFalsy | bool, Literal[True] | AlwaysFalsy))
static_assert(is_equivalent_to(Not[AlwaysTruthy] | bool, Not[AlwaysTruthy] | Literal[True]))
static_assert(is_equivalent_to(Not[AlwaysFalsy] | bool, Literal[False] | Not[AlwaysFalsy]))
```

## Unions and intersections involving `AlwaysTruthy`, `LiteralString` and `AlwaysFalsy`

```py
from knot_extensions import AlwaysTruthy, AlwaysFalsy, static_assert, is_equivalent_to, Not, Intersection
from typing_extensions import Literal, LiteralString

# TODO: these should all pass!

# error: [static-assert-error]
static_assert(is_equivalent_to(AlwaysTruthy | LiteralString, Literal[""] | AlwaysTruthy))
# error: [static-assert-error]
static_assert(is_equivalent_to(AlwaysFalsy | LiteralString, Intersection[LiteralString, Not[Literal[""]]] | AlwaysFalsy))
# error: [static-assert-error]
static_assert(is_equivalent_to(Not[AlwaysFalsy] | LiteralString, Literal[""] | Not[AlwaysFalsy]))
# error: [static-assert-error]
static_assert(
is_equivalent_to(Not[AlwaysTruthy] | LiteralString, Not[AlwaysTruthy] | Intersection[LiteralString, Not[Literal[""]]])
)
```

[the equivalence relation]: https://typing.readthedocs.io/en/latest/spec/glossary.html#term-equivalent
71 changes: 64 additions & 7 deletions crates/red_knot_python_semantic/src/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -811,6 +811,31 @@ impl<'db> Type<'db> {
}
}

/// Normalize the type `bool` -> `Literal[True, False]`.
///
/// Using this method in various type-relational methods
/// ensures that the following invariants hold true:
///
/// - bool ≡ Literal[True, False]
/// - bool | T ≡ Literal[True, False] | T
/// - bool <: Literal[True, False]
/// - bool | T <: Literal[True, False] | T
/// - Literal[True, False] <: bool
/// - Literal[True, False] | T <: bool | T
#[must_use]
pub fn with_normalized_bools(self, db: &'db dyn Db) -> Self {
const LITERAL_BOOLS: [Type; 2] = [Type::BooleanLiteral(false), Type::BooleanLiteral(true)];

match self {
Type::Instance(InstanceType { class }) if class.is_known(db, KnownClass::Bool) => {
Type::Union(UnionType::new(db, Box::from(LITERAL_BOOLS)))
}
// TODO: decompose `LiteralString` into `Literal[""] | TruthyLiteralString`?
// We'd need to rename this method... --Alex
_ => self,
}
}

/// Return a normalized version of `self` in which all unions and intersections are sorted
/// according to a canonical order, no matter how "deeply" a union/intersection may be nested.
#[must_use]
Expand Down Expand Up @@ -859,6 +884,12 @@ impl<'db> Type<'db> {
return true;
}

let normalized_self = self.with_normalized_bools(db);
let normalized_target = target.with_normalized_bools(db);
if normalized_self != self || normalized_target != target {
return normalized_self.is_subtype_of(db, normalized_target);
}

// Non-fully-static types do not participate in subtyping.
//
// Type `A` can only be a subtype of type `B` if the set of possible runtime objects
Expand Down Expand Up @@ -961,7 +992,7 @@ impl<'db> Type<'db> {
KnownClass::Str.to_instance(db).is_subtype_of(db, target)
}
(Type::BooleanLiteral(_), _) => {
KnownClass::Bool.to_instance(db).is_subtype_of(db, target)
KnownClass::Int.to_instance(db).is_subtype_of(db, target)
}
(Type::IntLiteral(_), _) => KnownClass::Int.to_instance(db).is_subtype_of(db, target),
(Type::BytesLiteral(_), _) => {
Expand Down Expand Up @@ -1077,6 +1108,11 @@ impl<'db> Type<'db> {
if self.is_gradual_equivalent_to(db, target) {
return true;
}
let normalized_self = self.with_normalized_bools(db);
let normalized_target = target.with_normalized_bools(db);
if normalized_self != self || normalized_target != target {
return normalized_self.is_assignable_to(db, normalized_target);
}
match (self, target) {
// Never can be assigned to any type.
(Type::Never, _) => true,
Expand Down Expand Up @@ -1177,6 +1213,13 @@ impl<'db> Type<'db> {
pub(crate) fn is_equivalent_to(self, db: &'db dyn Db, other: Type<'db>) -> bool {
// TODO equivalent but not identical types: TypedDicts, Protocols, type aliases, etc.

let normalized_self = self.with_normalized_bools(db);
let normalized_other = other.with_normalized_bools(db);

if normalized_self != self || normalized_other != other {
return normalized_self.is_equivalent_to(db, normalized_other);
}

match (self, other) {
(Type::Union(left), Type::Union(right)) => left.is_equivalent_to(db, right),
(Type::Intersection(left), Type::Intersection(right)) => {
Expand Down Expand Up @@ -1218,6 +1261,13 @@ impl<'db> Type<'db> {
///
/// [Summary of type relations]: https://typing.readthedocs.io/en/latest/spec/concepts.html#summary-of-type-relations
pub(crate) fn is_gradual_equivalent_to(self, db: &'db dyn Db, other: Type<'db>) -> bool {
let normalized_self = self.with_normalized_bools(db);
let normalized_other = other.with_normalized_bools(db);

if normalized_self != self || normalized_other != other {
return normalized_self.is_gradual_equivalent_to(db, normalized_other);
}

if self == other {
return true;
}
Expand Down Expand Up @@ -1250,6 +1300,12 @@ impl<'db> Type<'db> {
/// Note: This function aims to have no false positives, but might return
/// wrong `false` answers in some cases.
pub(crate) fn is_disjoint_from(self, db: &'db dyn Db, other: Type<'db>) -> bool {
let normalized_self = self.with_normalized_bools(db);
let normalized_other = other.with_normalized_bools(db);
if normalized_self != self || normalized_other != other {
return normalized_self.is_disjoint_from(db, normalized_other);
}

match (self, other) {
(Type::Never, _) | (_, Type::Never) => true,

Expand Down Expand Up @@ -4624,18 +4680,19 @@ pub struct TupleType<'db> {
}

impl<'db> TupleType<'db> {
pub fn from_elements<T: Into<Type<'db>>>(
db: &'db dyn Db,
types: impl IntoIterator<Item = T>,
) -> Type<'db> {
pub fn from_elements<I, T>(db: &'db dyn Db, types: I) -> Type<'db>
where
I: IntoIterator<Item = T>,
T: Into<Type<'db>>,
{
let mut elements = vec![];

for ty in types {
let ty = ty.into();
let ty: Type<'db> = ty.into();
if ty.is_never() {
return Type::Never;
}
elements.push(ty);
elements.push(ty.with_normalized_bools(db));
}

Type::Tuple(Self::new(db, elements.into_boxed_slice()))
Expand Down
Loading

0 comments on commit 0ff14bf

Please sign in to comment.