@@ -93,6 +93,22 @@ def __str__(self) -> str:
9393 return self .raw_value
9494
9595
96+ from typing import Generic , TypeVar , Iterable
97+
98+ T = TypeVar ("T" )
99+
100+
101+ class NonEmptyList (List [T ], Generic [T ]):
102+ """
103+ When using the List[T] annotation we assume that it is allowed to have zero elements in it.
104+ If you need to indicate, that a List is guaranteed to be non-empty, use this type annotation instead.
105+ Note that this is also enforced during parsing, i.e., parsing an empty list declared as being non-empty
106+ will result in an error.
107+ """
108+
109+ pass
110+
111+
96112class TypeBase :
97113 def construct (self , args : Dict [str , any ]):
98114 raise NotImplementedError (self )
@@ -244,7 +260,7 @@ def _resolve(self):
244260
245261def _reflect_list (item_cls , globals , locals , symbol_table : SymbolTable , allow_empty : bool ) -> ClassType :
246262 item_cls = _unwrap_forward_ref (item_cls , globals , locals )
247- _reflect (item_cls , globals , locals , symbol_table , False )
263+ _reflect (item_cls , globals , locals , symbol_table )
248264 return ListType (UnresolvedType (item_cls ), allow_empty )
249265
250266
@@ -261,40 +277,38 @@ def _reflect_class(cls, globals, locals, symbol_table: SymbolTable) -> ClassType
261277 except KeyError :
262278 pass
263279 force_name = field .metadata .get ("force_name" , None )
264- allow_empty_list = field .metadata .get ("allow_empty" , False )
265280 field_type = _unwrap_forward_ref (field_type , globals , locals )
266281 attrs .append (ClassType .Attribute (field .name , UnresolvedType (field_type ), required , force_name ))
267- _reflect (field_type , globals , locals , symbol_table , allow_empty_list )
282+ _reflect (field_type , globals , locals , symbol_table )
268283
269284 subclasses : List [TypeBase ] = []
270285 for subclass in _collect_subclasses (cls ):
271286 subclasses .append (UnresolvedType (subclass ))
272- _reflect (subclass , globals , locals , symbol_table , False )
287+ _reflect (subclass , globals , locals , symbol_table )
273288 return ClassType (cls , attrs , static_attrs , subclasses )
274289
275290
276- def _reflect (
277- cls : any , globals , locals , symbol_table : SymbolTable , allow_empty_list : bool
278- ) -> Tuple [TypeBase , SymbolTable ]:
291+ def _reflect (cls : any , globals , locals , symbol_table : SymbolTable ) -> Tuple [TypeBase , SymbolTable ]:
279292 key = str (cls )
280293 try :
281294 return symbol_table .symbols [key ]
282295 except KeyError :
283296 # Avoid infinite recursion if _reflect_unsafe calls itself again
284297 symbol_table .symbols [key ] = None
285- result = _reflect_unsafe (cls , globals , locals , symbol_table , allow_empty_list )
298+ result = _reflect_unsafe (cls , globals , locals , symbol_table )
286299 symbol_table .symbols [key ] = result
287300 return result
288301
289302
290- def _reflect_unsafe (
291- cls : any , globals , locals , symbol_table : SymbolTable , allow_empty_list : bool
292- ) -> Tuple [TypeBase , SymbolTable ]:
303+ def _reflect_unsafe (cls : any , globals , locals , symbol_table : SymbolTable ) -> Tuple [TypeBase , SymbolTable ]:
293304 origin = getattr (cls , "__origin__" , None )
294305 if origin :
295306 if origin is list :
296307 item_type = cls .__args__ [0 ]
297- return _reflect_list (item_type , globals , locals , symbol_table , allow_empty_list )
308+ return _reflect_list (item_type , globals , locals , symbol_table , True )
309+ elif origin is NonEmptyList :
310+ item_type = cls .__args__ [0 ]
311+ return _reflect_list (item_type , globals , locals , symbol_table , False )
298312 else :
299313 if cls is None :
300314 return NoneType ()
@@ -327,20 +341,20 @@ def _reflect_unsafe(
327341
328342def reflect (cls : any , globals = {}, locals = {}) -> Tuple [TypeBase , SymbolTable ]:
329343 symbol_table = SymbolTable ()
330- type = _reflect (cls , globals , locals , symbol_table , False )
344+ type = _reflect (cls , globals , locals , symbol_table )
331345 symbol_table ._resolve ()
332346 return type , symbol_table
333347
334348
335349def reflect_function (fn : callable , globals = {}, locals = {}) -> FunctionType :
336350 symbol_table = SymbolTable ()
337351 return_type = fn .__annotations__ .get ("return" , None )
338- r_return_type = _reflect (return_type , globals , locals , symbol_table , False )
352+ r_return_type = _reflect (return_type , globals , locals , symbol_table )
339353 args : List [FunctionType .Argument ] = []
340354 for key , value in fn .__annotations__ .items ():
341355 if key in ["return" ]:
342356 continue
343357 required , arg_type = _unwrap_optional (value )
344- r_arg_type = _reflect (arg_type , globals , locals , symbol_table , False )
358+ r_arg_type = _reflect (arg_type , globals , locals , symbol_table )
345359 args .append (FunctionType .Argument (key , r_arg_type , required ))
346360 return FunctionType (fn , r_return_type , args )
0 commit comments