Skip to content

Commit da68f83

Browse files
committed
added if and while
1 parent 0d624e5 commit da68f83

File tree

4 files changed

+329
-129
lines changed

4 files changed

+329
-129
lines changed

luisa_lang/hir.py

Lines changed: 43 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,8 @@
3131
# matched: bool
3232

3333

34-
FunctionTemplateResolvingArgs = List[Tuple[str, Union['Type', 'ComptimeValue']]]
34+
FunctionTemplateResolvingArgs = List[Tuple[str,
35+
Union['Type', 'ComptimeValue']]]
3536
"""
3637
[Function parameter name, Type or Value].
3738
The reason for using parameter name instead of GenericParameter is that python supports passing type[T] as a parameter,
@@ -727,13 +728,24 @@ class Ref(TypedNode):
727728
pass
728729

729730

731+
class LocalRef(Ref):
732+
value: 'Value'
733+
734+
def __init__(self, value: 'Value') -> None:
735+
super().__init__(value.type)
736+
self.value = value
737+
self.span = value.span
738+
739+
730740
class Value(TypedNode):
731741
pass
732742

743+
733744
class Unit(Value):
734745
def __init__(self) -> None:
735746
super().__init__(UnitType())
736747

748+
737749
class SymbolicConstant(Value):
738750
generic: GenericParameter
739751

@@ -884,20 +896,32 @@ def __str__(self) -> str:
884896
return f"Template matching error at {self.span}:\n\t{self.message}"
885897

886898

887-
class TypeInferenceError(Exception):
899+
class SpannedError(Exception):
888900
span: Span | None
889901
message: str
890902

891-
def __init__(self, node: Node | Span | None, message: str) -> None:
903+
def __init__(self, node: Node | Span | ast.AST | None, message: str) -> None:
892904
if node is not None:
893-
if isinstance(node, Node):
894-
self.span = node.span
895-
else:
896-
self.span = node
905+
match node:
906+
case Node():
907+
self.span = node.span
908+
case Span():
909+
self.span = node
910+
case ast.AST():
911+
self.span = Span.from_ast(node)
897912
else:
898913
self.span = None
899914
self.message = message
900915

916+
917+
class ParsingError(SpannedError):
918+
def __str__(self) -> str:
919+
if self.span is None:
920+
return f"Parsing error:\n\t{self.message}"
921+
return f"Parsing error at {self.span}:\n\t{self.message}"
922+
923+
924+
class TypeInferenceError(SpannedError):
901925
def __str__(self) -> str:
902926
if self.span is None:
903927
return f"Type inference error:\n\t{self.message}"
@@ -998,6 +1022,18 @@ def __init__(self, value: Optional[Value], span: Optional[Span] = None) -> None:
9981022
self.value = value
9991023

10001024

1025+
class Range(Value):
1026+
start: Value
1027+
step: Optional[Value]
1028+
stop: Optional[Value]
1029+
1030+
def __init__(self, start: Value, stop: Optional[Value] = None, step: Optional[Value] = None, span: Optional[Span] = None) -> None:
1031+
super().__init__(None, span)
1032+
self.start = start
1033+
self.stop = stop
1034+
self.step = step
1035+
1036+
10011037
class ComptimeValue:
10021038
value: Any
10031039
update_func: Optional[Callable[[Any], None]]

luisa_lang/lang.py

Lines changed: 25 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
List,
88
Optional,
99
Sequence,
10+
Set,
1011
Tuple,
1112
TypeAlias,
1213
TypeVar,
@@ -50,34 +51,48 @@ def _make_func_template(f: Callable[..., Any], func_name: str, func_globals: Dic
5051

5152
func_sig = classinfo.parse_func_signature(f, func_globals, [])
5253
func_sig_converted, sig_parser = parse.convert_func_signature(
53-
func_sig, func_name, func_globals, {}, [], self_type)
54+
func_sig, func_name, func_globals, {}, {}, self_type)
55+
implicit_type_params = sig_parser.implicit_type_params
56+
implicit_generic_params: Set[hir.GenericParameter] = set()
57+
for p in implicit_type_params.values():
58+
assert isinstance(p, hir.SymbolicType)
59+
implicit_generic_params.add(p.param)
5460

5561
def parsing_func(args: hir.FunctionTemplateResolvingArgs) -> hir.FunctionLike:
5662
type_var_ns: Dict[TypeVar, hir.Type | hir.ComptimeValue] = {}
57-
any_param_types: List[hir.Type] = []
63+
mapped_implicit_type_params: Dict[str,
64+
hir.Type] = dict()
5865
if is_generic:
5966
mapping = hir.match_func_template_args(func_sig_converted, args)
6067
if isinstance(mapping, hir.TypeInferenceError):
6168
raise mapping
6269
if len(mapping) != len(func_sig_converted.generic_params):
63-
# print(mapping, func_sig_converted.generic_params)
6470
raise hir.TypeInferenceError(
6571
None, "not all type parameters are resolved")
6672
for p in func_sig_converted.generic_params:
6773
if p not in mapping:
6874
raise hir.TypeInferenceError(
6975
None, f"type parameter {p} is not resolved")
70-
type_var_ns[sig_parser.generic_param_to_type_var[p]
71-
] = mapping[p]
72-
# print(f'binding {p.name} = {mapping[p]}, tv: {sig_parser.generic_param_to_type_var[p]} @{id(sig_parser.generic_param_to_type_var[p])}')
73-
# print('parsing instantiated signature')
74-
func_sig_instantiated, _ = parse.convert_func_signature(
75-
func_sig, func_name, func_globals, type_var_ns, any_param_types, self_type)
76+
if p not in implicit_generic_params:
77+
type_var_ns[sig_parser.generic_param_to_type_var[p]
78+
] = mapping[p]
79+
80+
for name, itp, in implicit_type_params.items():
81+
assert isinstance(itp, hir.SymbolicType)
82+
gp = itp.param
83+
mapped_type = mapping[gp]
84+
assert isinstance(mapped_type, hir.Type)
85+
mapped_implicit_type_params[name] = mapped_type
86+
func_sig_instantiated, _p = parse.convert_func_signature(
87+
func_sig, func_name, func_globals, type_var_ns, mapped_implicit_type_params, self_type, mode='instantiate')
88+
assert len(
89+
func_sig_instantiated.generic_params) == 0, f"generic params should be resolved but found {func_sig_instantiated.generic_params}"
7690
func_parser = FuncParser(
7791
func_name, f, func_sig_instantiated, func_globals, type_var_ns, self_type)
7892
return func_parser.parse_body()
7993
params = [v[0] for v in func_sig.args]
80-
is_generic = len(func_sig.type_vars) > 0
94+
is_generic = len(func_sig_converted.generic_params) > 0
95+
# print(f"func {func_name} is_generic: {is_generic}")
8196
return hir.FunctionTemplate(func_name, params, parsing_func, is_generic)
8297

8398

@@ -223,5 +238,3 @@ def decorator(f):
223238
return impl(f)
224239

225240
return decorator
226-
227-

0 commit comments

Comments
 (0)