@@ -29,6 +29,7 @@ def __init__(self, origin: type, args: List["VarType"]):
29
29
def __repr__ (self ):
30
30
return f"{ self .origin } [{ ', ' .join (map (repr , self .args ))} ]"
31
31
32
+
32
33
class UnionType :
33
34
types : List ["VarType" ]
34
35
@@ -38,7 +39,23 @@ def __init__(self, types: List["VarType"]):
38
39
def __repr__ (self ):
39
40
return f"Union[{ ', ' .join (map (repr , self .types ))} ]"
40
41
41
- VarType = Union [TypeVar , Type , GenericInstance , UnionType ]
42
+ class AnyType :
43
+ def __repr__ (self ):
44
+ return "Any"
45
+
46
+ def __eq__ (self , other ):
47
+ return isinstance (other , AnyType )
48
+
49
+ class SelfType :
50
+ def __repr__ (self ):
51
+ return "Self"
52
+
53
+ def __eq__ (self , other ):
54
+ return isinstance (other , SelfType )
55
+
56
+
57
+ VarType = Union [TypeVar , Type , GenericInstance , UnionType , SelfType , AnyType ]
58
+
42
59
43
60
def subst_type (ty : VarType , env : Dict [TypeVar , VarType ]) -> VarType :
44
61
match ty :
@@ -49,24 +66,27 @@ def subst_type(ty: VarType, env: Dict[TypeVar, VarType]) -> VarType:
49
66
case _:
50
67
return ty
51
68
69
+
52
70
class MethodType :
53
71
type_vars : List [TypeVar ]
54
72
args : List [VarType ]
55
73
return_type : VarType
56
74
env : Dict [TypeVar , VarType ]
75
+ is_static : bool
57
76
58
77
def __init__ (
59
- self , type_vars : List [TypeVar ], args : List [VarType ], return_type : VarType , env : Optional [Dict [TypeVar , VarType ]] = None
78
+ self , type_vars : List [TypeVar ], args : List [VarType ], return_type : VarType , env : Optional [Dict [TypeVar , VarType ]] = None , is_static : bool = False
60
79
):
61
80
self .type_vars = type_vars
62
81
self .args = args
63
82
self .return_type = return_type
64
83
self .env = env or {}
84
+ self .is_static = is_static
65
85
66
86
def __repr__ (self ):
67
87
# [a, b, c](x: T, y: U) -> V
68
88
return f"[{ ', ' .join (map (repr , self .type_vars ))} ]({ ', ' .join (map (repr , self .args ))} ) -> { self .return_type } "
69
-
89
+
70
90
def substitute (self , env : Dict [TypeVar , VarType ]) -> "MethodType" :
71
91
return MethodType ([], [subst_type (arg , env ) for arg in self .args ], subst_type (self .return_type , env ), env )
72
92
@@ -108,24 +128,27 @@ def instantiate(self, type_args: List[VarType]) -> "ClassType":
108
128
)
109
129
env = dict (zip (self .type_vars , type_args ))
110
130
return ClassType (
111
- [], {name : subst_type (ty , env ) for name , ty in self .fields .items ()}, {name : method .substitute (env ) for name , method in self .methods .items ()}
131
+ [], {name : subst_type (ty , env ) for name , ty in self .fields .items ()}, {
132
+ name : method .substitute (env ) for name , method in self .methods .items ()}
112
133
)
113
134
135
+
114
136
_CLS_TYPE_INFO : Dict [type , ClassType ] = {}
115
137
116
138
117
- def _class_typeinfo (cls : type ) -> ClassType :
139
+ def class_typeinfo (cls : type ) -> ClassType :
118
140
if cls in _CLS_TYPE_INFO :
119
141
return _CLS_TYPE_INFO [cls ]
120
142
raise RuntimeError (f"Class { cls } is not registered." )
121
143
122
144
123
-
124
145
def _is_class_registered (cls : type ) -> bool :
125
146
return cls in _CLS_TYPE_INFO
126
147
148
+
127
149
_BUILTIN_ANNOTATION_BASES = set ([typing .Generic , typing .Protocol , object ])
128
150
151
+
129
152
def _get_base_classinfo (cls : type , globalns ) -> List [tuple [str , ClassType ]]:
130
153
if not hasattr (cls , "__orig_bases__" ):
131
154
return []
@@ -140,23 +163,99 @@ def _get_base_classinfo(cls: type, globalns) -> List[tuple[str, ClassType]]:
140
163
)
141
164
for arg in base .__args__ :
142
165
if isinstance (arg , typing .ForwardRef ):
143
- arg : type = typing ._eval_type (arg , globalns , globalns ) #type: ignore
166
+ arg : type = typing ._eval_type ( # type: ignore
167
+ arg , globalns , globalns ) # type: ignore
144
168
base_params .append (arg )
145
169
if base_orig in _BUILTIN_ANNOTATION_BASES :
146
170
pass
147
171
else :
148
- base_info = _class_typeinfo (base_orig )
149
- info .append ((base .__name__ , base_info .instantiate (base_params )))
172
+ base_info = class_typeinfo (base_orig )
173
+ info .append (
174
+ (base .__name__ , base_info .instantiate (base_params )))
150
175
else :
151
176
if _is_class_registered (base ):
152
- info .append ((base .__name__ , _class_typeinfo (base )))
177
+ info .append ((base .__name__ , class_typeinfo (base )))
153
178
return info
154
179
180
+
155
181
def _get_cls_globalns (cls : type ) -> Dict [str , Any ]:
156
182
module = inspect .getmodule (cls )
157
183
assert module is not None
158
184
return module .__dict__
159
185
186
+
187
+ def parse_type_hint (hint : Any ) -> VarType :
188
+ if hint is None :
189
+ return NoneType
190
+ if isinstance (hint , TypeVar ):
191
+ return hint
192
+ origin = typing .get_origin (hint )
193
+ if origin :
194
+ if isinstance (origin , type ):
195
+ # assert isinstance(origin, type), f"origin must be a type but got {origin}"
196
+ args = list (typing .get_args (hint ))
197
+ return GenericInstance (origin , [parse_type_hint (arg ) for arg in args ])
198
+ elif origin is Union :
199
+ return UnionType ([parse_type_hint (arg ) for arg in typing .get_args (hint )])
200
+ else :
201
+ raise RuntimeError (f"Unsupported origin type: { origin } " )
202
+ if isinstance (hint , type ):
203
+ return hint
204
+ if hint == typing .Self :
205
+ return SelfType ()
206
+ raise RuntimeError (f"Unsupported type hint: { hint } " )
207
+
208
+
209
+ def extract_type_vars_from_hint (hint : typing .Any ) -> List [TypeVar ]:
210
+ if isinstance (hint , TypeVar ):
211
+ return [hint ]
212
+ if hasattr (hint , "__args__" ): # Handle custom generic types like Foo[T]
213
+ type_vars = []
214
+ for arg in hint .__args__ :
215
+ type_vars .extend (extract_type_vars_from_hint (arg ))
216
+ return type_vars
217
+ return []
218
+
219
+
220
+ def get_type_vars (func : typing .Callable ) -> List [TypeVar ]:
221
+ type_hints = typing .get_type_hints (func )
222
+ type_vars = []
223
+ for hint in type_hints .values ():
224
+ type_vars .extend (extract_type_vars_from_hint (hint ))
225
+ return list (set (type_vars )) # Return unique type vars
226
+
227
+
228
+ def parse_func_signature (func : object , globalns : Dict [str , Any ], foreign_type_vars : List [TypeVar ], self_type : Optional [VarType ] = None , is_static : bool = False ) -> MethodType :
229
+ assert inspect .isfunction (func )
230
+ signature = inspect .signature (func )
231
+ method_type_hints = typing .get_type_hints (func , globalns )
232
+ param_types : List [VarType ] = []
233
+ type_vars = get_type_vars (func )
234
+ for param in signature .parameters .values ():
235
+ if param .name == "self" :
236
+ assert self_type is not None
237
+ param_types .append (self_type )
238
+ else :
239
+ param_types .append (parse_type_hint (
240
+ method_type_hints [param .name ]))
241
+ if "return" in method_type_hints :
242
+ return_type = parse_type_hint (method_type_hints .get ("return" ))
243
+ else :
244
+ return_type = AnyType ()
245
+ # remove foreign type vars from type_vars
246
+ type_vars = [tv for tv in type_vars if tv not in foreign_type_vars ]
247
+ return MethodType (type_vars , param_types , return_type , is_static = is_static )
248
+
249
+
250
+ def is_static (cls : type , method_name : str ) -> bool :
251
+ method = getattr (cls , method_name , None )
252
+ if method is None :
253
+ return False
254
+ # Using inspect to retrieve the method directly from the class
255
+ method = cls .__dict__ .get (method_name , None )
256
+ return isinstance (method , staticmethod )
257
+
258
+
160
259
def register_class (cls : type ) -> None :
161
260
cls_qualname = cls .__qualname__
162
261
globalns = _get_cls_globalns (cls )
@@ -202,47 +301,20 @@ def register_class(cls: type) -> None:
202
301
continue
203
302
local_methods .add (name )
204
303
205
- def parse_type_hint (hint : Any ) -> VarType :
206
-
207
- if hint is None :
208
- return NoneType
209
- if isinstance (hint , TypeVar ):
210
- return hint
211
- origin = typing .get_origin (hint )
212
- if origin :
213
- if isinstance (origin , type ):
214
- # assert isinstance(origin, type), f"origin must be a type but got {origin}"
215
- args = list (typing .get_args (hint ))
216
- return GenericInstance (origin , [parse_type_hint (arg ) for arg in args ])
217
- elif origin is Union :
218
- return UnionType ([parse_type_hint (arg ) for arg in typing .get_args (hint )])
219
- else :
220
- raise RuntimeError (f"Unsupported origin type: { origin } " )
221
- if isinstance (hint , type ):
222
- return hint
223
- raise RuntimeError (f"Unsupported type hint: { hint } " )
224
-
225
304
cls_ty = ClassType ([], {}, {})
226
305
for _base_name , base_info in base_infos :
227
306
cls_ty .fields .update (base_info .fields )
228
307
cls_ty .methods .update (base_info .methods )
308
+
229
309
if type_vars :
230
310
for tv in type_vars :
231
311
cls_ty .type_vars .append (tv )
312
+ self_ty : VarType = SelfType ()
232
313
for name , member in inspect .getmembers (cls ):
233
314
if name in local_methods :
234
- assert inspect .isfunction (member )
235
- signature = inspect .signature (member )
236
- method_type_hints = typing .get_type_hints (member )
237
- param_types : List [VarType ] = []
238
- for param in signature .parameters .values ():
239
- if param .name == "self" :
240
- param_types .append (cls )
241
- else :
242
- param_types .append (parse_type_hint (
243
- method_type_hints [param .name ]))
244
- return_type = parse_type_hint (method_type_hints .get ("return" ))
245
- cls_ty .methods [name ] = MethodType ([], param_types , return_type )
315
+ # print(f'Found local method: {name} in {cls}')
316
+ cls_ty .methods [name ] = parse_func_signature (
317
+ member , globalns , cls_ty .type_vars , self_ty , is_static = is_static (cls , name ))
246
318
for name in local_fields :
247
319
cls_ty .fields [name ] = parse_type_hint (type_hints [name ])
248
320
_CLS_TYPE_INFO [cls ] = cls_ty
0 commit comments