@@ -121,8 +121,11 @@ def device_assert(cond: bool, msg: str = "") -> typing.NoReturn:
121
121
raise NotImplementedError (
122
122
"device_assert should not be called in host-side Python code. " )
123
123
124
+
124
125
def sizeof (t : type [T ]) -> u64 :
125
- raise NotImplementedError ("sizeof should not be called in host-side Python code. " )
126
+ raise NotImplementedError (
127
+ "sizeof should not be called in host-side Python code. " )
128
+
126
129
127
130
@overload
128
131
def range (n : T ) -> List [T ]: ...
@@ -208,11 +211,11 @@ class Array(Generic[T, N]):
208
211
def __init__ (self ) -> None :
209
212
self = intrinsic ("init.array" , Array [T , N ])
210
213
211
- def __getitem__ (self , index : int | u32 | u64 ) -> T :
214
+ def __getitem__ (self , index : int | i32 | u32 | i64 | u64 ) -> T :
212
215
return intrinsic ("array.ref" , T , byref (self ), index ) # type: ignore
213
216
214
- def __setitem__ (self , index : int | u32 | u64 , value : T ) -> None :
215
- pass
217
+ def __setitem__ (self , index : int | i32 | u32 | i64 | u64 , value : T | int | float ) -> None :
218
+ """value: T | int | float annotation is to make mypy happy. this function is ignored by the compiler"""
216
219
217
220
def __len__ (self ) -> u64 :
218
221
return intrinsic ("array.size" , u64 , self ) # type: ignore
@@ -233,10 +236,11 @@ def __len__(self) -> u64:
233
236
234
237
@opaque ("Buffer" )
235
238
class Buffer (Generic [T ]):
236
- def __getitem__ (self , index : int | u32 | u64 ) -> T :
239
+ def __getitem__ (self , index : int | i32 | u32 | i64 | u64 ) -> T :
237
240
return intrinsic ("buffer.ref" , T , self , index ) # type: ignore
238
241
239
- def __setitem__ (self , index : int | u32 | u64 , value : T ) -> None :
242
+ def __setitem__ (self , index : int | i32 | u32 | i64 | u64 , value : T | int | float ) -> None :
243
+ """value: T | int | float annotation is to make mypy happy. this function is ignored by the compiler"""
240
244
pass
241
245
242
246
def __len__ (self ) -> u64 :
0 commit comments