Skip to content

Commit 4562fc9

Browse files
committed
progress
1 parent e5d69de commit 4562fc9

File tree

8 files changed

+3047
-2925
lines changed

8 files changed

+3047
-2925
lines changed

README.md

+30
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ A new Python DSL frontend for LuisaCompute. Will be integrated into LuisaCompute
1111
- [Functions](#functions)
1212
- [User-defined Structs](#user-defined-structs)
1313
- [Control Flow](#control-flow)
14+
- [Define DSL Operation for Non-DSL Types](#define-dsl-operation-for-non-dsl-types)
1415
- [Advanced Usage](#advanced-syntax)
1516
- [Generics](#generics)
1617
- [Metaprogramming](#metaprogramming)
@@ -104,6 +105,35 @@ class Sphere:
104105
radius: lc.float
105106
```
106107

108+
### Define DSL Operation for Non-DSL Types
109+
Sometimes we want to use a non-DSL type in our DSL code. Such type could be imported from a third-party library or a built-in Python type. As long as we know the object layout, we can define the DSL operation for it by first defining a proxy struct that mirrors the object layout, and then define the operation for the proxy struct.
110+
111+
```python
112+
# Assume we have a third-party library that defines a Vec3 class
113+
class Vec3:
114+
def __init__(self, x, y, z):
115+
self.x = x
116+
self.y = y
117+
self.z = z
118+
119+
@lc.struct
120+
class Vec3Proxy:
121+
x: lc.float
122+
y: lc.float
123+
z: lc.float
124+
125+
# write DSL operations here
126+
127+
lc.register_dsl_type_alias(Vec3, Vec3Proxy)
128+
129+
@lc.func
130+
def use_vec3(v: Vec3): # Vec3 is now treated as Vec3Proxy internally
131+
v.x = 1.0
132+
v.y = 2.0
133+
v.z = 3.0
134+
135+
```
136+
107137
### Generics
108138
```python
109139
T = TypeVar('T', bound=Any)

docs/luisa_lang/hir.html

+2,825-2,780
Large diffs are not rendered by default.

docs/luisa_lang/lang_builtins.html

+135-141
Large diffs are not rendered by default.

docs/search.js

+1-1
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

luisa_lang/hir.py

+2
Original file line numberDiff line numberDiff line change
@@ -394,6 +394,8 @@ def member(self, field: Any) -> Optional['Type']:
394394
def __len__(self) -> int:
395395
return self.count
396396

397+
class MatrixType(Type):
398+
pass
397399

398400
class ArrayType(Type):
399401
element: Type

luisa_lang/lang_builtins.py

+37-2
Original file line numberDiff line numberDiff line change
@@ -208,6 +208,14 @@ def _instantiate_array_type(args: List[Any]) -> hir.Type:
208208
_instantiate_array_type
209209
)
210210
class Array(Generic[T, N]):
211+
"""
212+
An array is a fixed-size collection of elements of type T. N need to be a Literal type.
213+
Example:
214+
```python
215+
arr = Array[int, Literal[10]]()
216+
```
217+
"""
218+
211219
def __init__(self) -> None:
212220
self = intrinsic("init.array", Array[T, N])
213221

@@ -236,6 +244,10 @@ def __len__(self) -> u64:
236244

237245
@opaque("Buffer")
238246
class Buffer(Generic[T]):
247+
"""
248+
A buffer is a contiguuos memory of elements of type T.
249+
"""
250+
239251
def __getitem__(self, index: int | u32 | i64 | u64) -> T:
240252
return intrinsic("buffer.ref", T, self, index) # type: ignore
241253

@@ -262,13 +274,30 @@ def _inst_pointer_type(args: List[Any]) -> hir.Type:
262274
_inst_pointer_type
263275
)
264276
class Pointer(Generic[T]):
277+
"""
278+
A physical pointer (just like C pointer) to a memory location of type T.
279+
Note that pointers might not be available in all backends.
280+
281+
```python
282+
p = Pointer[int](123456798) # 123456798 is the address of the memory location
283+
i = p[0] # read the value at the memory location
284+
p[0] = 10 # write the value at the memory location
285+
# alternatively
286+
i = p.read()
287+
p.write(10)
288+
# offset the pointer
289+
x = p[1]
290+
y = (p + 1).read()
291+
```
292+
"""
293+
265294
def __init__(self, addr: u64) -> None:
266295
self = intrinsic("init.pointer", Pointer[T], addr)
267296

268-
def __getitem__(self, index: int | i32 | i64 | u32 | u64) -> T:
297+
def __getitem__(self, index: i32 | i64 | u32 | u64) -> T:
269298
return intrinsic("pointer.read", T, self, index) # type: ignore
270299

271-
def __setitem__(self, index: int | i32 | i64 | u32 | u64, value: T) -> None:
300+
def __setitem__(self, index: i32 | i64 | u32 | u64, value: T) -> None:
272301
pass
273302

274303
def read(self) -> T:
@@ -277,6 +306,12 @@ def read(self) -> T:
277306
def write(self, value: T) -> None:
278307
intrinsic("pointer.write", None, self, value) # type: ignore
279308

309+
def __add__(self, offset: i32 | i64 | u32 | u64) -> 'Pointer[T]':
310+
return intrinsic("pointer.add", Pointer[T], self, offset)
311+
312+
def __sub__(self, offset: i32 | i64 | u32 | u64) -> 'Pointer[T]':
313+
return intrinsic("pointer.sub", Pointer[T], self, offset)
314+
280315

281316
__all__: List[str] = [
282317
# 'Pointer',

luisa_lang/utils.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ def increment_lineno_and_col_offset(
4141
return node
4242

4343

44-
def dedent_and_retrieve_indentation(lines: str) -> Tuple[str, int]:
44+
def dedent_and_retrieve_indentation(lines: List[str]) -> Tuple[str, int]:
4545
"""
4646
Dedent the lines and return the indentation level of the first line.
4747
"""

scripts/gen_math_types.py

+16
Original file line numberDiff line numberDiff line change
@@ -212,6 +212,22 @@ class {ty}{inherits_str}:
212212
)
213213
print("")
214214

215+
def gen_matrix_type(ty:str, vector_ty:str, scalar_ty:str, literal_scalar_ty:str, dim:int):
216+
nonlocal exports
217+
exports.append(ty)
218+
comps = "xyzw"[:dim]
219+
fields_def = "".join([f" {comp}: {vector_ty}\n" for comp in comps])
220+
inherits:List[str] = []
221+
# if kind == Kind.FLOAT:
222+
# inherits.append(f"FloatBuiltin['{ty}']")
223+
inherits_str = "" if len(inherits) == 0 else f"({', '.join(inherits)})"
224+
print(
225+
f"""@builtin_type(_hir.MatrixType({dim}))
226+
class {ty}{inherits_str}:
227+
{fields_def}
228+
def __init__(self) -> None: self = intrinsic("init.{ty}", {ty})
229+
""")
230+
215231
float_types = ["f32", "f64"]
216232
for size in [2, 3, 4]:
217233
float_types.append(f"float{size}")

0 commit comments

Comments
 (0)