@@ -613,44 +613,46 @@ def get_supported_annotation(
613
613
* ,
614
614
_none : type = NoneType ,
615
615
_mapping : Dict [Any , Type [Transformer ]] = BUILT_IN_TRANSFORMERS ,
616
- ) -> Tuple [Any , Any ]:
616
+ ) -> Tuple [Any , Any , bool ]:
617
617
"""Returns an appropriate, yet supported, annotation along with an optional default value.
618
618
619
+ The third boolean element of the tuple indicates if default values should be validated.
620
+
619
621
This differs from the built in mapping by supporting a few more things.
620
622
Likewise, this returns a "transformed" annotation that is ready to use with CommandParameter.transform.
621
623
"""
622
624
623
625
try :
624
- return (_mapping [annotation ], MISSING )
626
+ return (_mapping [annotation ], MISSING , True )
625
627
except KeyError :
626
628
pass
627
629
628
630
if hasattr (annotation , '__discord_app_commands_transform__' ):
629
- return (annotation .metadata , MISSING )
631
+ return (annotation .metadata , MISSING , False )
630
632
631
633
if hasattr (annotation , '__metadata__' ):
632
634
return get_supported_annotation (annotation .__metadata__ [0 ])
633
635
634
636
if inspect .isclass (annotation ):
635
637
if issubclass (annotation , Transformer ):
636
- return (annotation , MISSING )
638
+ return (annotation , MISSING , False )
637
639
if issubclass (annotation , (Enum , InternalEnum )):
638
640
if all (isinstance (v .value , (str , int , float )) for v in annotation ):
639
- return (_make_enum_transformer (annotation ), MISSING )
641
+ return (_make_enum_transformer (annotation ), MISSING , False )
640
642
else :
641
- return (_make_complex_enum_transformer (annotation ), MISSING )
643
+ return (_make_complex_enum_transformer (annotation ), MISSING , False )
642
644
if annotation is Choice :
643
645
raise TypeError (f'Choice requires a type argument of int, str, or float' )
644
646
645
647
# Check if there's an origin
646
648
origin = getattr (annotation , '__origin__' , None )
647
649
if origin is Literal :
648
650
args = annotation .__args__ # type: ignore
649
- return (_make_literal_transformer (args ), MISSING )
651
+ return (_make_literal_transformer (args ), MISSING , True )
650
652
651
653
if origin is Choice :
652
654
arg = annotation .__args__ [0 ] # type: ignore
653
- return (_make_choice_transformer (arg ), MISSING )
655
+ return (_make_choice_transformer (arg ), MISSING , True )
654
656
655
657
if origin is not Union :
656
658
# Only Union/Optional is supported right now so bail early
@@ -661,18 +663,18 @@ def get_supported_annotation(
661
663
if args [- 1 ] is _none :
662
664
if len (args ) == 2 :
663
665
underlying = args [0 ]
664
- inner , _ = get_supported_annotation (underlying )
666
+ inner , _ , validate_default = get_supported_annotation (underlying )
665
667
if inner is None :
666
668
raise TypeError (f'unsupported inner optional type { underlying !r} ' )
667
- return (inner , None )
669
+ return (inner , None , validate_default )
668
670
else :
669
671
args = args [:- 1 ]
670
672
default = None
671
673
672
674
# Check for channel union types
673
675
if any (arg in CHANNEL_TO_TYPES for arg in args ):
674
676
# If any channel type is given, then *all* must be channel types
675
- return (channel_transformer (* args , raw = None ), default )
677
+ return (channel_transformer (* args , raw = None ), default , True )
676
678
677
679
# The only valid transformations here are:
678
680
# [Member, User] => user
@@ -682,9 +684,9 @@ def get_supported_annotation(
682
684
if not all (arg in supported_types for arg in args ):
683
685
raise TypeError (f'unsupported types given inside { annotation !r} ' )
684
686
if args == (User , Member ) or args == (Member , User ):
685
- return (passthrough_transformer (AppCommandOptionType .user ), default )
687
+ return (passthrough_transformer (AppCommandOptionType .user ), default , True )
686
688
687
- return (passthrough_transformer (AppCommandOptionType .mentionable ), default )
689
+ return (passthrough_transformer (AppCommandOptionType .mentionable ), default , True )
688
690
689
691
690
692
def annotation_to_parameter (annotation : Any , parameter : inspect .Parameter ) -> CommandParameter :
@@ -695,7 +697,7 @@ def annotation_to_parameter(annotation: Any, parameter: inspect.Parameter) -> Co
695
697
of a command parameter.
696
698
"""
697
699
698
- (inner , default ) = get_supported_annotation (annotation )
700
+ (inner , default , validate_default ) = get_supported_annotation (annotation )
699
701
type = inner .type ()
700
702
701
703
if default is MISSING or default is None :
@@ -704,12 +706,10 @@ def annotation_to_parameter(annotation: Any, parameter: inspect.Parameter) -> Co
704
706
default = param_default
705
707
706
708
# Verify validity of the default parameter
707
- if default is not MISSING :
708
- enum_type = getattr (inner , '__discord_app_commands_transformer_enum__' , None )
709
- if default .__class__ is not enum_type :
710
- valid_types : Tuple [Any , ...] = ALLOWED_DEFAULTS .get (type , (NoneType ,))
711
- if not isinstance (default , valid_types ):
712
- raise TypeError (f'invalid default parameter type given ({ default .__class__ } ), expected { valid_types } ' )
709
+ if default is not MISSING and validate_default :
710
+ valid_types : Tuple [Any , ...] = ALLOWED_DEFAULTS .get (type , (NoneType ,))
711
+ if not isinstance (default , valid_types ):
712
+ raise TypeError (f'invalid default parameter type given ({ default .__class__ } ), expected { valid_types } ' )
713
713
714
714
result = CommandParameter (
715
715
type = type ,
0 commit comments