@@ -534,7 +534,24 @@ def collect_args() -> List[hir.Value | hir.Ref]:
534
534
raise hir .ParsingError (expr , ret .message )
535
535
return ret
536
536
537
- def parse_binop (self , expr : ast .BinOp ) -> hir .Value :
537
+ # def parse_compare(self, expr: ast.Compare) -> hir.Value | ComptimeValue:
538
+ # cmpop_to_str: Dict[type, str] = {
539
+ # ast.Eq: "==",
540
+ # ast.NotEq: "!=",
541
+ # ast.Lt: "<",
542
+ # ast.LtE: "<=",
543
+ # ast.Gt: ">",
544
+ # ast.GtE: ">="
545
+ # }
546
+ # if len(expr.ops) != 1:
547
+ # raise hir.ParsingError(expr, "only one comparison operator is allowed")
548
+ # op = expr.ops[0]
549
+ # if type(op) not in cmpop_to_str:
550
+ # raise hir.ParsingError(expr, f"unsupported comparison operator {type(op)}")
551
+ # op_str = cmpop_to_str[type(op)]
552
+ # method_name = BINOP_TO_METHOD_NAMES[type(op)]
553
+
554
+ def parse_binop (self , expr : ast .BinOp | ast .Compare ) -> hir .Value :
538
555
binop_to_op_str : Dict [type , str ] = {
539
556
ast .Add : "+" ,
540
557
ast .Sub : "-" ,
@@ -556,20 +573,32 @@ def parse_binop(self, expr: ast.BinOp) -> hir.Value:
556
573
ast .GtE : ">=" ,
557
574
558
575
}
559
- op_str = binop_to_op_str [type (expr .op )]
560
- lhs = self .parse_expr (expr .left )
576
+ op : ast .AST
577
+ if isinstance (expr , ast .Compare ):
578
+ if len (expr .ops ) != 1 :
579
+ raise hir .ParsingError (
580
+ expr , "only one comparison operator is allowed" )
581
+ op = expr .ops [0 ]
582
+ left = expr .left
583
+ right = expr .comparators [0 ]
584
+ else :
585
+ op = expr .op
586
+ left = expr .left
587
+ right = expr .right
588
+ op_str = binop_to_op_str [type (op )]
589
+ lhs = self .parse_expr (left )
561
590
if isinstance (lhs , ComptimeValue ):
562
591
lhs = self .try_convert_comptime_value (lhs , hir .Span .from_ast (expr ))
563
592
if not lhs .type :
564
593
raise hir .ParsingError (
565
- expr . left , f"unable to infer type of left operand of binary operation { op_str } " )
566
- rhs = self .parse_expr (expr . right )
594
+ left , f"unable to infer type of left operand of binary operation { op_str } " )
595
+ rhs = self .parse_expr (right )
567
596
if isinstance (rhs , ComptimeValue ):
568
597
rhs = self .try_convert_comptime_value (rhs , hir .Span .from_ast (expr ))
569
598
if not rhs .type :
570
599
raise hir .ParsingError (
571
- expr . right , f"unable to infer type of right operand of binary operation { op_str } " )
572
- ops = BINOP_TO_METHOD_NAMES [type (expr . op )]
600
+ right , f"unable to infer type of right operand of binary operation { op_str } " )
601
+ ops = BINOP_TO_METHOD_NAMES [type (op )]
573
602
574
603
def infer_binop (name : str , rname : str ) -> hir .Value :
575
604
assert lhs .type and rhs .type
@@ -712,6 +741,30 @@ def check(i: int, val_type: hir.Type) -> None:
712
741
raise hir .ParsingError (
713
742
targets [0 ], f"unsupported type for unpacking: { values .type } " )
714
743
744
+ def parse_unary (self , expr : ast .UnaryOp ) -> hir .Value :
745
+ op = expr .op
746
+ if type (op ) not in UNARY_OP_TO_METHOD_NAMES :
747
+ raise hir .ParsingError (
748
+ expr , f"unsupported unary operator { type (op )} " )
749
+ op_str = UNARY_OP_TO_METHOD_NAMES [type (op )]
750
+ operand = self .parse_expr (expr .operand )
751
+ if isinstance (operand , ComptimeValue ):
752
+ operand = self .try_convert_comptime_value (
753
+ operand , hir .Span .from_ast (expr ))
754
+ if not operand .type :
755
+ raise hir .ParsingError (
756
+ expr .operand , f"unable to infer type of operand of unary operation { op_str } " )
757
+ method_name = UNARY_OP_TO_METHOD_NAMES [type (op )]
758
+ if (method := operand .type .method (method_name )) and method :
759
+ ret = self .parse_call_impl (
760
+ hir .Span .from_ast (expr ), method , [operand ])
761
+ if isinstance (ret , hir .TemplateMatchingError ):
762
+ raise hir .ParsingError (expr , ret .message )
763
+ return ret
764
+ else :
765
+ raise hir .ParsingError (
766
+ expr , f"operator { type (op )} not defined for type { operand .type } " )
767
+
715
768
def parse_expr (self , expr : ast .expr ) -> hir .Value | ComptimeValue :
716
769
match expr :
717
770
case ast .Constant ():
@@ -723,8 +776,10 @@ def parse_expr(self, expr: ast.expr) -> hir.Value | ComptimeValue:
723
776
return ret
724
777
case ast .Subscript () | ast .Attribute ():
725
778
return self .parse_access (expr )
726
- case ast .BinOp ():
779
+ case ast .BinOp () | ast . Compare () :
727
780
return self .parse_binop (expr )
781
+ case ast .UnaryOp ():
782
+ return self .parse_unary (expr )
728
783
case ast .Call ():
729
784
return self .parse_call (expr )
730
785
case ast .Tuple ():
@@ -970,40 +1025,6 @@ def parse_anno_ty() -> hir.Type:
970
1025
if stmt .value :
971
1026
self .parse_multi_assignment (
972
1027
[stmt .target ], [parse_anno_ty ], self .parse_expr (stmt .value ))
973
- # value = self.parse_expr(stmt.value)
974
- # if isinstance(value, ComptimeValue):
975
- # var = self.parse_ref(
976
- # stmt.target, new_var_hint='comptime')
977
- # else:
978
- # var = self.parse_ref(stmt.target, new_var_hint='dsl')
979
- # if isinstance(var, ComptimeValue):
980
- # if isinstance(value, ComptimeValue):
981
- # try:
982
- # var.update(value.value)
983
- # except Exception as e:
984
- # raise hir.ParsingError(
985
- # stmt, f"error updating comptime value: {e}") from e
986
- # return
987
- # else:
988
- # raise hir.ParsingError(
989
- # stmt, f"comptime value cannot be assigned with DSL value")
990
- # else:
991
- # if isinstance(value, ComptimeValue):
992
- # value = self.try_convert_comptime_value(
993
- # value, span)
994
- # assert value.type
995
- # anno_ty = parse_anno_ty()
996
- # if not var.type:
997
- # var.type = value.type
998
- # if not var.type.is_concrete():
999
- # raise hir.ParsingError(
1000
- # stmt, "only concrete type can be assigned, please annotate the variable with concrete types")
1001
- # if not hir.is_type_compatible_to(value.type, anno_ty):
1002
- # raise hir.ParsingError(
1003
- # stmt, f"expected {anno_ty}, got {value.type}")
1004
- # if not value.type.is_concrete():
1005
- # value.type = var.type
1006
- # self.cur_bb().append(hir.Assign(var, value, span))
1007
1028
else :
1008
1029
var = self .parse_ref (stmt .target , new_var_hint = 'dsl' )
1009
1030
anno_ty = parse_anno_ty ()
0 commit comments