Skip to content

Commit ce30705

Browse files
committed
better types; now int, float are supported dsl types; only float and int types of other bit width needs to write lc.ixx, lc.fxx explicitly
1 parent 9c0acb7 commit ce30705

File tree

4 files changed

+1535
-1521
lines changed

4 files changed

+1535
-1521
lines changed

luisa_lang/hir.py

+32
Original file line numberDiff line numberDiff line change
@@ -1597,3 +1597,35 @@ def inline(func: Function, args: List[Value | Ref], body: BasicBlock, span: Opti
15971597
inliner = FunctionInliner(func, args, body, span)
15981598
assert inliner.ret
15991599
return inliner.ret
1600+
1601+
1602+
def register_dsl_type_alias(target: type, alias: type):
1603+
"""
1604+
Allow a type to be remapped to another type within DSL code.
1605+
Parameters:
1606+
target (type): The type to be remapped.
1607+
alias (type): The type to which the target type will be remapped.
1608+
Example:
1609+
1610+
For example,
1611+
```python
1612+
@lc.struct
1613+
class Foo:
1614+
x: int
1615+
y: int
1616+
1617+
class SomeOtherFoo:
1618+
components: List[int]
1619+
1620+
register_dsl_type_alias(SomeOtherFoo, Foo)
1621+
1622+
@lc.func
1623+
def foo(f: SomeOtherFoo): # SomeOtherFoo is interpreted as Foo
1624+
...
1625+
1626+
```
1627+
"""
1628+
ctx = GlobalContext.get()
1629+
alias_ty = get_dsl_type(alias)
1630+
assert alias_ty, f"alias type {alias} is not a DSL type"
1631+
ctx.types[target] = alias_ty

luisa_lang/lang_builtins.py

+5-6
Original file line numberDiff line numberDiff line change
@@ -211,11 +211,11 @@ class Array(Generic[T, N]):
211211
def __init__(self) -> None:
212212
self = intrinsic("init.array", Array[T, N])
213213

214-
def __getitem__(self, index: int | i32 | u32 | i64 | u64) -> T:
214+
def __getitem__(self, index: int | u32 | i64 | u64) -> T:
215215
return intrinsic("array.ref", T, byref(self), index) # type: ignore
216216

217-
def __setitem__(self, index: int | i32 | u32 | i64 | u64, value: T | int | float) -> None:
218-
"""value: T | int | float annotation is to make mypy happy. this function is ignored by the compiler"""
217+
def __setitem__(self, index: int | u32 | i64 | u64, value: T) -> None:
218+
pass
219219

220220
def __len__(self) -> u64:
221221
return intrinsic("array.size", u64, self) # type: ignore
@@ -236,11 +236,10 @@ def __len__(self) -> u64:
236236

237237
@opaque("Buffer")
238238
class Buffer(Generic[T]):
239-
def __getitem__(self, index: int | i32 | u32 | i64 | u64) -> T:
239+
def __getitem__(self, index: int | u32 | i64 | u64) -> T:
240240
return intrinsic("buffer.ref", T, self, index) # type: ignore
241241

242-
def __setitem__(self, index: int | i32 | u32 | i64 | u64, value: T | int | float) -> None:
243-
"""value: T | int | float annotation is to make mypy happy. this function is ignored by the compiler"""
242+
def __setitem__(self, index: int | u32 | i64 | u64, value: T) -> None:
244243
pass
245244

246245
def __len__(self) -> u64:

0 commit comments

Comments
 (0)