Skip to content

Commit 699aff2

Browse files
committed
added comparison and unary op
1 parent 1646db0 commit 699aff2

File tree

1 file changed

+63
-42
lines changed

1 file changed

+63
-42
lines changed

luisa_lang/parse.py

+63-42
Original file line numberDiff line numberDiff line change
@@ -534,7 +534,24 @@ def collect_args() -> List[hir.Value | hir.Ref]:
534534
raise hir.ParsingError(expr, ret.message)
535535
return ret
536536

537-
def parse_binop(self, expr: ast.BinOp) -> hir.Value:
537+
# def parse_compare(self, expr: ast.Compare) -> hir.Value | ComptimeValue:
538+
# cmpop_to_str: Dict[type, str] = {
539+
# ast.Eq: "==",
540+
# ast.NotEq: "!=",
541+
# ast.Lt: "<",
542+
# ast.LtE: "<=",
543+
# ast.Gt: ">",
544+
# ast.GtE: ">="
545+
# }
546+
# if len(expr.ops) != 1:
547+
# raise hir.ParsingError(expr, "only one comparison operator is allowed")
548+
# op = expr.ops[0]
549+
# if type(op) not in cmpop_to_str:
550+
# raise hir.ParsingError(expr, f"unsupported comparison operator {type(op)}")
551+
# op_str = cmpop_to_str[type(op)]
552+
# method_name = BINOP_TO_METHOD_NAMES[type(op)]
553+
554+
def parse_binop(self, expr: ast.BinOp | ast.Compare) -> hir.Value:
538555
binop_to_op_str: Dict[type, str] = {
539556
ast.Add: "+",
540557
ast.Sub: "-",
@@ -556,20 +573,32 @@ def parse_binop(self, expr: ast.BinOp) -> hir.Value:
556573
ast.GtE: ">=",
557574

558575
}
559-
op_str = binop_to_op_str[type(expr.op)]
560-
lhs = self.parse_expr(expr.left)
576+
op: ast.AST
577+
if isinstance(expr, ast.Compare):
578+
if len(expr.ops) != 1:
579+
raise hir.ParsingError(
580+
expr, "only one comparison operator is allowed")
581+
op = expr.ops[0]
582+
left = expr.left
583+
right = expr.comparators[0]
584+
else:
585+
op = expr.op
586+
left = expr.left
587+
right = expr.right
588+
op_str = binop_to_op_str[type(op)]
589+
lhs = self.parse_expr(left)
561590
if isinstance(lhs, ComptimeValue):
562591
lhs = self.try_convert_comptime_value(lhs, hir.Span.from_ast(expr))
563592
if not lhs.type:
564593
raise hir.ParsingError(
565-
expr.left, f"unable to infer type of left operand of binary operation {op_str}")
566-
rhs = self.parse_expr(expr.right)
594+
left, f"unable to infer type of left operand of binary operation {op_str}")
595+
rhs = self.parse_expr(right)
567596
if isinstance(rhs, ComptimeValue):
568597
rhs = self.try_convert_comptime_value(rhs, hir.Span.from_ast(expr))
569598
if not rhs.type:
570599
raise hir.ParsingError(
571-
expr.right, f"unable to infer type of right operand of binary operation {op_str}")
572-
ops = BINOP_TO_METHOD_NAMES[type(expr.op)]
600+
right, f"unable to infer type of right operand of binary operation {op_str}")
601+
ops = BINOP_TO_METHOD_NAMES[type(op)]
573602

574603
def infer_binop(name: str, rname: str) -> hir.Value:
575604
assert lhs.type and rhs.type
@@ -712,6 +741,30 @@ def check(i: int, val_type: hir.Type) -> None:
712741
raise hir.ParsingError(
713742
targets[0], f"unsupported type for unpacking: {values.type}")
714743

744+
def parse_unary(self, expr: ast.UnaryOp) -> hir.Value:
745+
op = expr.op
746+
if type(op) not in UNARY_OP_TO_METHOD_NAMES:
747+
raise hir.ParsingError(
748+
expr, f"unsupported unary operator {type(op)}")
749+
op_str = UNARY_OP_TO_METHOD_NAMES[type(op)]
750+
operand = self.parse_expr(expr.operand)
751+
if isinstance(operand, ComptimeValue):
752+
operand = self.try_convert_comptime_value(
753+
operand, hir.Span.from_ast(expr))
754+
if not operand.type:
755+
raise hir.ParsingError(
756+
expr.operand, f"unable to infer type of operand of unary operation {op_str}")
757+
method_name = UNARY_OP_TO_METHOD_NAMES[type(op)]
758+
if (method := operand.type.method(method_name)) and method:
759+
ret = self.parse_call_impl(
760+
hir.Span.from_ast(expr), method, [operand])
761+
if isinstance(ret, hir.TemplateMatchingError):
762+
raise hir.ParsingError(expr, ret.message)
763+
return ret
764+
else:
765+
raise hir.ParsingError(
766+
expr, f"operator {type(op)} not defined for type {operand.type}")
767+
715768
def parse_expr(self, expr: ast.expr) -> hir.Value | ComptimeValue:
716769
match expr:
717770
case ast.Constant():
@@ -723,8 +776,10 @@ def parse_expr(self, expr: ast.expr) -> hir.Value | ComptimeValue:
723776
return ret
724777
case ast.Subscript() | ast.Attribute():
725778
return self.parse_access(expr)
726-
case ast.BinOp():
779+
case ast.BinOp() | ast.Compare():
727780
return self.parse_binop(expr)
781+
case ast.UnaryOp():
782+
return self.parse_unary(expr)
728783
case ast.Call():
729784
return self.parse_call(expr)
730785
case ast.Tuple():
@@ -970,40 +1025,6 @@ def parse_anno_ty() -> hir.Type:
9701025
if stmt.value:
9711026
self.parse_multi_assignment(
9721027
[stmt.target], [parse_anno_ty], self.parse_expr(stmt.value))
973-
# value = self.parse_expr(stmt.value)
974-
# if isinstance(value, ComptimeValue):
975-
# var = self.parse_ref(
976-
# stmt.target, new_var_hint='comptime')
977-
# else:
978-
# var = self.parse_ref(stmt.target, new_var_hint='dsl')
979-
# if isinstance(var, ComptimeValue):
980-
# if isinstance(value, ComptimeValue):
981-
# try:
982-
# var.update(value.value)
983-
# except Exception as e:
984-
# raise hir.ParsingError(
985-
# stmt, f"error updating comptime value: {e}") from e
986-
# return
987-
# else:
988-
# raise hir.ParsingError(
989-
# stmt, f"comptime value cannot be assigned with DSL value")
990-
# else:
991-
# if isinstance(value, ComptimeValue):
992-
# value = self.try_convert_comptime_value(
993-
# value, span)
994-
# assert value.type
995-
# anno_ty = parse_anno_ty()
996-
# if not var.type:
997-
# var.type = value.type
998-
# if not var.type.is_concrete():
999-
# raise hir.ParsingError(
1000-
# stmt, "only concrete type can be assigned, please annotate the variable with concrete types")
1001-
# if not hir.is_type_compatible_to(value.type, anno_ty):
1002-
# raise hir.ParsingError(
1003-
# stmt, f"expected {anno_ty}, got {value.type}")
1004-
# if not value.type.is_concrete():
1005-
# value.type = var.type
1006-
# self.cur_bb().append(hir.Assign(var, value, span))
10071028
else:
10081029
var = self.parse_ref(stmt.target, new_var_hint='dsl')
10091030
anno_ty = parse_anno_ty()

0 commit comments

Comments
 (0)