diff --git a/mypy/checkexpr.py b/mypy/checkexpr.py index b97c78cba2fc..02d5a8718c41 100644 --- a/mypy/checkexpr.py +++ b/mypy/checkexpr.py @@ -98,7 +98,13 @@ ) from mypy.semanal_enum import ENUM_BASES from mypy.state import state -from mypy.subtypes import is_equivalent, is_same_type, is_subtype, non_method_protocol_members +from mypy.subtypes import ( + is_equivalent, + is_proper_subtype, + is_same_type, + is_subtype, + non_method_protocol_members, +) from mypy.traverser import has_await_expression from mypy.typeanal import ( check_for_explicit_any, @@ -259,7 +265,7 @@ class ExpressionChecker(ExpressionVisitor[Type]): type_context: list[Type | None] # cache resolved types in some cases - resolved_type: dict[Expression, ProperType] + resolved_type: dict[tuple[Expression, Type | None], ProperType] strfrm_checker: StringFormatterChecker plugin: Plugin @@ -3994,34 +4000,59 @@ def fast_container_type( self, e: ListExpr | SetExpr | TupleExpr, container_fullname: str ) -> Type | None: """ - Fast path to determine the type of a list or set literal, - based on the list of entries. This mostly impacts large - module-level constant definitions. + Fast path to determine the type of a list or set literal, based on the list of entries. + This mostly impacts large constant definitions. Limitations: - no active type context - no star expressions - the joined type of all entries must be an Instance or Tuple type """ - ctx = self.type_context[-1] - if ctx: + ctx = get_proper_type(self.type_context[-1]) + # TODO: can we safely allow TypeVarType (with appropriate upper_bound or values) ? + if not ctx or isinstance(ctx, AnyType): + return self._fast_container_type(e, container_fullname) + + # TODO: support relevant typing base classes (Sequence, AbstractSet, ...) ? + if not (isinstance(ctx, Instance) and ctx.type.fullname == container_fullname): return None - rt = self.resolved_type.get(e, None) + + vt = get_proper_type(ctx.args[0]) + if not (isinstance(vt, Instance) and allow_fast_container_literal(vt)): + return None + + with self.msg.filter_errors() as w: + rt = self._fast_container_type(e, container_fullname) + if w.has_new_errors(): + # fallback to slow path if we run into any errors here + return None + if rt and is_proper_subtype(rt.args[0], vt): + # let the type context win if we inferred a more precise type (List is invariant...) + return ctx + return None + + def _fast_container_type( + self, + e: ListExpr | SetExpr | TupleExpr, + container_fullname: str, + ctx: Optional[Instance] = None, + ) -> Instance | None: + rt = self.resolved_type.get((e, ctx), None) if rt is not None: return rt if isinstance(rt, Instance) else None values: list[Type] = [] for item in e.items: if isinstance(item, StarExpr): # fallback to slow path - self.resolved_type[e] = NoneType() + self.resolved_type[(e, ctx)] = NoneType() return None - values.append(self.accept(item)) + values.append(self.accept(item, type_context=ctx.args[0] if ctx else None)) vt = join.join_type_list(values) if not allow_fast_container_literal(vt): - self.resolved_type[e] = NoneType() + self.resolved_type[(e, ctx)] = NoneType() return None ct = self.chk.named_generic_type(container_fullname, [vt]) - self.resolved_type[e] = ct + self.resolved_type[(e, ctx)] = ct return ct def check_lst_expr(self, e: ListExpr | SetExpr | TupleExpr, fullname: str, tag: str) -> Type: @@ -4116,19 +4147,53 @@ def visit_tuple_expr(self, e: TupleExpr) -> Type: def fast_dict_type(self, e: DictExpr) -> Type | None: """ - Fast path to determine the type of a dict literal, - based on the list of entries. This mostly impacts large - module-level constant definitions. + Fast path to determine the type of a dict literal, based on the list of entries. + This mostly impacts large module-level constant definitions. + + Will only trigger for the following type contexts: + - None (i.e no type context) + - AnyType + - Dict[K, V] where K and V are both Instance or Tuple types + """ + ctx = get_proper_type(self.type_context[-1]) + # TODO: can we safely allow TypeVarType (with appropriate upper_bound or values) ? + if not ctx or isinstance(ctx, AnyType): + return self._fast_dict_type(e) + + # TODO: support relevant typing base classes (Mapping, MutableMapping, ...) ? + if not (isinstance(ctx, Instance) and ctx.type.fullname == "builtins.dict"): + return None + + kt = get_proper_type(ctx.args[0]) + vt = get_proper_type(ctx.args[1]) + if not ( + isinstance(kt, Instance) + and isinstance(vt, Instance) + and allow_fast_container_literal(kt) + and allow_fast_container_literal(vt) + ): + return None + + with self.msg.filter_errors() as w: + rt = self._fast_dict_type(e, ctx) + if w.has_new_errors(): + # fallback to slow path if we run into any errors here + return None + if rt and is_proper_subtype(rt.args[0], kt) and is_proper_subtype(rt.args[1], vt): + # let the type context win if we inferred a more precise type (Dict is invariant...) + return ctx + return None + + def _fast_dict_type(self, e: DictExpr, ctx: Optional[Instance] = None) -> Instance | None: + """ + Fast path to determine the type of a dict literal, based on the list of entries. + This mostly impacts large constant definitions. Limitations: - - no active type context - only supported star expressions are other dict instances - the joined types of all keys and values must be Instance or Tuple types """ - ctx = self.type_context[-1] - if ctx: - return None - rt = self.resolved_type.get(e, None) + rt = self.resolved_type.get((e, ctx), None) if rt is not None: return rt if isinstance(rt, Instance) else None keys: list[Type] = [] @@ -4136,7 +4201,7 @@ def fast_dict_type(self, e: DictExpr) -> Type | None: stargs: tuple[Type, Type] | None = None for key, value in e.items: if key is None: - st = get_proper_type(self.accept(value)) + st = get_proper_type(self.accept(value, type_context=ctx.args[1] if ctx else None)) if ( isinstance(st, Instance) and st.type.fullname == "builtins.dict" @@ -4144,21 +4209,21 @@ def fast_dict_type(self, e: DictExpr) -> Type | None: ): stargs = (st.args[0], st.args[1]) else: - self.resolved_type[e] = NoneType() + self.resolved_type[(e, ctx)] = NoneType() return None else: - keys.append(self.accept(key)) - values.append(self.accept(value)) + keys.append(self.accept(key, type_context=ctx.args[0] if ctx else None)) + values.append(self.accept(value, type_context=ctx.args[1] if ctx else None)) kt = join.join_type_list(keys) vt = join.join_type_list(values) if not (allow_fast_container_literal(kt) and allow_fast_container_literal(vt)): - self.resolved_type[e] = NoneType() + self.resolved_type[(e, ctx)] = NoneType() return None if stargs and (stargs[0] != kt or stargs[1] != vt): - self.resolved_type[e] = NoneType() + self.resolved_type[(e, ctx)] = NoneType() return None dt = self.chk.named_generic_type("builtins.dict", [kt, vt]) - self.resolved_type[e] = dt + self.resolved_type[(e, ctx)] = dt return dt def visit_dict_expr(self, e: DictExpr) -> Type: