Skip to content

Commit fefeec6

Browse files
authored
Support numpy as result type when slang function returns matrix (#250)
* Support numpy as result type when slang function returns matrix close #205 When calling a slang function which returns a matrix type, we cannot use "numpy" type bound to the result. This change add the support to that by *implementing store/load method for `{RW}NDBuffer` type in slang shader, * add type resolve logic for ndbuffer python type when we bind nbbuffer to slang matrix, it will known what type to bind with. * For every test in test_transform.py, also add matrix type support. So this will cover NDBuffer and Tensor type. * For test_numpy.py add test to cover the support of binding a numpy result to slang matrix return type.
1 parent 92366f1 commit fefeec6

File tree

9 files changed

+235
-83
lines changed

9 files changed

+235
-83
lines changed

slangpy/builtin/ndbuffer.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
SlangProgramLayout,
2727
SlangType,
2828
VectorType,
29+
MatrixType,
2930
StructuredBufferType,
3031
is_matching_array_type,
3132
)
@@ -111,9 +112,8 @@ def ndbuffer_resolve_type(
111112
# if implicit tensor casts enabled, allow conversion from vector to element type
112113
if context.options["implicit_tensor_casts"]:
113114
if (
114-
isinstance(bound_type, VectorType)
115-
and self.slang_element_type == bound_type.element_type
116-
):
115+
isinstance(bound_type, VectorType) or isinstance(bound_type, MatrixType)
116+
) and self.slang_element_type == bound_type.scalar_type:
117117
return bound_type
118118

119119
# Default to just casting to itself (i.e. no implicit cast)

slangpy/builtin/valueref.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ def numpy_to_slang_matrix_remove_padding(
6262
(slang_type.rows, slang_type.cols),
6363
dtype=kfr.SCALAR_TYPE_TO_NUMPY_TYPE[slang_type.slang_scalar_type],
6464
)
65-
mat_remove_padding[:, : slang_type.cols] = mat_aligned
65+
mat_remove_padding = mat_aligned[:, : slang_type.cols]
6666
return python_type(mat_remove_padding)
6767

6868

slangpy/slang/core.slang

Lines changed: 45 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -345,7 +345,23 @@ public struct NDBuffer<T, let N : int>
345345
}
346346
}
347347

348-
public void store<let VD : int>(ContextND<N - 1> context, out vector<T, VD> value) {}
348+
public void store<let VD : int>(ContextND<N - 1> context, in vector<T, VD> value) {}
349+
350+
public void load<let R : int, let C: int>(ContextND<N - 2> context, out matrix<T, R, C> value) {
351+
int call_id[N];
352+
for (int i = 0; i < N - 2; i++) { call_id[i] = context.call_id[i]; }
353+
354+
for (int r = 0; r < R; r++)
355+
{
356+
call_id[N - 2] = r;
357+
for (int vi = 0; vi < C; vi++) {
358+
call_id[N - 1] = vi;
359+
value[r][vi] = get(call_id);
360+
}
361+
}
362+
}
363+
364+
public void store<let R : int, let C: int>(ContextND<N - 2> context, in matrix<T, R, C> value) {}
349365
}
350366

351367
public struct RWNDBuffer<T, let N : int>
@@ -405,6 +421,34 @@ public struct RWNDBuffer<T, let N : int>
405421
set(call_id, value[vi]);
406422
}
407423
}
424+
425+
public void load<let R : int, let C: int>(ContextND<N - 2> context, out matrix<T, R, C> value) {
426+
int call_id[N];
427+
for (int i = 0; i < N - 2; i++) { call_id[i] = context.call_id[i]; }
428+
429+
for (int r = 0; r < R; r++)
430+
{
431+
call_id[N - 2] = r;
432+
for (int vi = 0; vi < C; vi++) {
433+
call_id[N - 1] = vi;
434+
value[r][vi] = get(call_id);
435+
}
436+
}
437+
}
438+
439+
public void store<let R : int, let C: int>(ContextND<N - 2> context, in matrix<T, R, C> value) {
440+
int call_id[N];
441+
for (int i = 0; i < N - 2; i++) { call_id[i] = context.call_id[i]; }
442+
443+
for (int r = 0; r < R; r++)
444+
{
445+
call_id[N - 2] = r;
446+
for (int vi = 0; vi < C; vi++) {
447+
call_id[N - 1] = vi;
448+
set(call_id, value[r][vi]);
449+
}
450+
}
451+
}
408452
}
409453

410454
namespace impl

slangpy/tests/slangpy_tests/helpers.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -135,7 +135,10 @@ def read_ndbuffer_from_numpy(buffer: NDBuffer) -> np.ndarray:
135135
data = np.array([])
136136
shape = np.prod(np.array(buffer.shape))
137137
for i in range(shape):
138-
data = np.append(data, cursor[i].read())
138+
element = cursor[i].read()
139+
if cursor.element_type_layout.kind == TypeReflection.Kind.matrix:
140+
element = element.to_numpy()
141+
data = np.append(data, element)
139142

140143
return data
141144

@@ -149,13 +152,19 @@ def write_ndbuffer_from_numpy(buffer: NDBuffer, data: np.ndarray, element_count:
149152
element_count = 1
150153
elif cursor.element_type_layout.kind == TypeReflection.Kind.vector:
151154
element_count = cursor.element_type.col_count
155+
elif cursor.element_type_layout.kind == TypeReflection.Kind.matrix:
156+
element_count = cursor.element_type.row_count * cursor.element_type.col_count
152157
else:
153158
raise ValueError(
154159
f"element_count not set and type is not scalar or vector: {cursor.element_type_layout.kind}"
155160
)
156161

157162
for i in range(shape):
158163
buffer_data = np.array(data[i * element_count : (i + 1) * element_count])
164+
if cursor.element_type_layout.kind == TypeReflection.Kind.matrix:
165+
buffer_data = buffer_data.reshape(
166+
cursor.element_type.row_count, cursor.element_type.col_count
167+
)
159168
cursor[i].write(buffer_data)
160169

161170
cursor.apply()

slangpy/tests/slangpy_tests/test_numpy.py

Lines changed: 28 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,10 @@
3737
}
3838
return res;
3939
}
40+
41+
matrix<float, R, C> matFunc1<int R, int C>(matrix<float, R, C> input){
42+
return input + 2.0f;
43+
}
4044
"""
4145

4246

@@ -161,8 +165,8 @@ def test_return_numpy_matrix(device_type: DeviceType):
161165

162166
module = load_test_module(device_type)
163167

164-
for R in range(2, 4):
165-
for C in range(2, 4):
168+
for R in range(2, 5):
169+
for C in range(2, 5):
166170
funName = f"matFunc<{R}, {C}>"
167171
func = module.find_function(funName)
168172
assert func is not None
@@ -177,8 +181,8 @@ def test_return_numpy_matrix(device_type: DeviceType):
177181
def test_setup_numpy_matrix(device_type: DeviceType):
178182

179183
module = load_test_module(device_type)
180-
for R in range(2, 4):
181-
for C in range(2, 4):
184+
for R in range(2, 5):
185+
for C in range(2, 5):
182186
funName = f"flattenMatrix<{R}, {C}>"
183187
func = module.find_function(funName)
184188
assert func is not None
@@ -189,5 +193,25 @@ def test_setup_numpy_matrix(device_type: DeviceType):
189193
assert np.allclose(res["data"][0 : R * C], np.ones(R * C))
190194

191195

196+
# test that we can use "numpy" as the result type for a function that returns a matrix
197+
@pytest.mark.parametrize("device_type", helpers.DEFAULT_DEVICE_TYPES)
198+
def test_numpy_matrix_as_result(device_type: DeviceType):
199+
200+
module = load_test_module(device_type)
201+
202+
for R in range(2, 5):
203+
for C in range(2, 5):
204+
funName = f"matFunc1<{R}, {C}>"
205+
func = module.find_function(funName)
206+
assert func is not None
207+
matType = getattr(spy, f"float{R}x{C}")
208+
N = R * C
209+
res = func(matType(np.arange(1, N + 1).reshape(R, C)), _result="numpy")
210+
211+
assert res is not None
212+
assert res.shape == (R, C)
213+
assert np.allclose(res, (np.arange(1, N + 1) + 2.0).reshape(R, C))
214+
215+
192216
if __name__ == "__main__":
193217
pytest.main([__file__, "-v", "-s"])

slangpy/tests/slangpy_tests/test_tensor_with_grads.py

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -9,13 +9,6 @@
99
import sys
1010

1111

12-
if sys.platform == "darwin":
13-
pytest.skip(
14-
"Skipping on macOS: Waiting for slang-gfx fix for resource clear API https://github.com/shader-slang/slang/issues/6640",
15-
allow_module_level=True,
16-
)
17-
18-
1912
def get_test_tensors(device: Device, N: int = 4):
2013
np.random.seed(0)
2114

0 commit comments

Comments
 (0)