1
1
import inspect
2
- from types import NoneType
2
+ from types import GenericAlias , NoneType
3
3
import types
4
4
import typing
5
5
from typing import (
10
10
Optional ,
11
11
Set ,
12
12
Tuple ,
13
+ TypeAliasType ,
13
14
TypeVar ,
14
15
Generic ,
15
16
Dict ,
16
17
Type ,
17
18
Union ,
19
+ cast ,
18
20
)
19
21
import functools
20
22
from dataclasses import dataclass
21
23
22
24
23
25
class GenericInstance :
24
- origin : type
26
+ origin : 'VarType'
25
27
args : List ["VarType" ]
26
28
27
- def __init__ (self , origin : type , args : List ["VarType" ]):
29
+ def __init__ (self , origin : 'VarType' , args : List ["VarType" ]):
28
30
self .origin = origin
29
31
self .args = args
30
32
@@ -41,6 +43,9 @@ def __init__(self, types: List["VarType"]):
41
43
def __repr__ (self ):
42
44
return f"Union[{ ', ' .join (map (repr , self .types ))} ]"
43
45
46
+ def substitute (self , env : Dict [TypeVar , 'VarType' ]) -> "UnionType" :
47
+ return UnionType ([subst_type (ty , env ) for ty in self .types ])
48
+
44
49
45
50
class AnyType :
46
51
def __repr__ (self ):
@@ -56,7 +61,8 @@ def __repr__(self):
56
61
57
62
def __eq__ (self , other ):
58
63
return isinstance (other , SelfType )
59
-
64
+
65
+
60
66
class LiteralType :
61
67
value : Any
62
68
@@ -70,7 +76,23 @@ def __eq__(self, other):
70
76
return isinstance (other , LiteralType ) and self .value == other .value
71
77
72
78
73
- VarType = Union [TypeVar , Type , GenericInstance , UnionType , SelfType , AnyType , LiteralType ]
79
+ class AnnotatedType :
80
+ origin : 'VarType'
81
+ annotations : List [Any ]
82
+
83
+ def __init__ (self , origin : 'VarType' , annotations : List [Any ]):
84
+ self .origin = origin
85
+ self .annotations = annotations
86
+
87
+ def __repr__ (self ):
88
+ return f"Annotated[{ self .origin } , { self .annotations } ]"
89
+
90
+ def substitute (self , env : Dict [TypeVar , 'VarType' ]) -> "AnnotatedType" :
91
+ return AnnotatedType (subst_type (self .origin , env ), self .annotations )
92
+
93
+
94
+ type VarType = Union [TypeVar , Type , GenericInstance ,
95
+ UnionType , SelfType , AnyType , LiteralType , AnnotatedType ]
74
96
75
97
76
98
def subst_type (ty : VarType , env : Dict [TypeVar , VarType ]) -> VarType :
@@ -79,6 +101,8 @@ def subst_type(ty: VarType, env: Dict[TypeVar, VarType]) -> VarType:
79
101
return env .get (ty , ty )
80
102
case GenericInstance (origin = origin , args = args ):
81
103
return GenericInstance (origin , [subst_type (arg , env ) for arg in args ])
104
+ case MethodType () | UnionType () | AnnotatedType ():
105
+ return ty .substitute (env )
82
106
case _:
83
107
return ty
84
108
@@ -140,7 +164,8 @@ def __repr__(self):
140
164
def instantiate (self , type_args : List [VarType ]) -> "ClassType" :
141
165
if len (type_args ) != len (self .type_vars ):
142
166
raise RuntimeError (
143
- f"Expected { len (self .type_vars )} type arguments but got { len (type_args )} "
167
+ f"Expected { len (self .type_vars )} " +
168
+ f"type arguments but got { len (type_args )} "
144
169
)
145
170
env = dict (zip (self .type_vars , type_args ))
146
171
return ClassType (
@@ -172,7 +197,8 @@ def _get_base_classinfo(cls: type, globalns) -> List[tuple[str, ClassType]]:
172
197
for base in cls .__orig_bases__ :
173
198
if hasattr (base , "__origin__" ):
174
199
base_params = []
175
- base_orig = base .__origin__
200
+ base_orig : Any = base .__origin__
201
+
176
202
if not _is_class_registered (base_orig ) and base_orig not in _BUILTIN_ANNOTATION_BASES :
177
203
raise RuntimeError (
178
204
f"Base class { base_orig } of { cls } is not registered."
@@ -185,7 +211,8 @@ def _get_base_classinfo(cls: type, globalns) -> List[tuple[str, ClassType]]:
185
211
if base_orig in _BUILTIN_ANNOTATION_BASES :
186
212
pass
187
213
else :
188
- base_info = class_typeinfo (base_orig )
214
+ assert isinstance (base_orig , type )
215
+ base_info = class_typeinfo (cast (type , base_orig ))
189
216
info .append (
190
217
(base .__name__ , base_info .instantiate (base_params )))
191
218
else :
@@ -210,19 +237,40 @@ def parse_type_hint(hint: Any) -> VarType:
210
237
return UnionType ([parse_type_hint (arg ) for arg in hint .__args__ ])
211
238
if hint is typing .Any :
212
239
return AnyType ()
240
+ if isinstance (hint , TypeAliasType ):
241
+ return parse_type_hint (hint .__value__ )
242
+
213
243
origin = typing .get_origin (hint )
214
244
if origin :
215
- if isinstance (origin , type ):
216
- # assert isinstance(origin, type), f"origin must be a type but got {origin}"
217
- args = list (typing .get_args (hint ))
218
- return GenericInstance (origin , [parse_type_hint (arg ) for arg in args ])
245
+ if origin is typing .Annotated :
246
+ annotate_args = typing .get_args (hint )
247
+ return AnnotatedType (parse_type_hint (annotate_args [0 ]), list (annotate_args [1 :]))
219
248
elif origin is Union :
220
249
return UnionType ([parse_type_hint (arg ) for arg in typing .get_args (hint )])
221
250
elif origin is Literal :
222
251
return LiteralType (typing .get_args (hint )[0 ])
252
+ elif isinstance (origin , TypeAliasType ):
253
+ def do () -> VarType :
254
+ assert isinstance (hint , GenericAlias )
255
+ args = list (typing .get_args (hint ))
256
+ assert len (args ) == len (origin .__parameters__ ), f"Expected {
257
+ len (origin .__parameters__ )} type arguments but got { len (args )} "
258
+ true_origin = origin .__value__
259
+ parametric_args = origin .__parameters__
260
+ parsed_args = [parse_type_hint (arg ) for arg in args ]
261
+ env = dict (zip (parametric_args , parsed_args ))
262
+ parsed_origin = parse_type_hint (true_origin )
263
+ return subst_type (parsed_origin , env )
264
+ return do ()
265
+ elif isinstance (origin , type ):
266
+ # assert isinstance(origin, type), f"origin must be a type but got {origin}"
267
+ args = list (typing .get_args (hint ))
268
+ return GenericInstance (origin , [parse_type_hint (arg ) for arg in args ])
269
+
223
270
else :
224
- raise RuntimeError (f"Unsupported origin type: { origin } " )
225
-
271
+ raise RuntimeError (f"Unsupported origin type: {
272
+ origin } , { type (origin ), type (hint )} " )
273
+
226
274
if isinstance (hint , type ):
227
275
return hint
228
276
if hint == typing .Self :
@@ -242,7 +290,7 @@ def extract_type_vars_from_hint(hint: typing.Any) -> List[TypeVar]:
242
290
243
291
244
292
def get_type_vars (func : typing .Callable ) -> List [TypeVar ]:
245
- type_hints = typing .get_type_hints (func )
293
+ type_hints = typing .get_type_hints (func , include_extras = True )
246
294
type_vars = []
247
295
for hint in type_hints .values ():
248
296
type_vars .extend (extract_type_vars_from_hint (hint ))
0 commit comments