diff --git a/CHANGELOG.md b/CHANGELOG.md index ec45d4422..7d596b7e4 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,6 +6,7 @@ - `warp.autograd.gradcheck`, `function_jacobian`, `function_jacobian_fd` now also accept arbitrary Python functions that have Warp arrays as inputs and outputs. - `warp.autograd.gradcheck_tape` now has additional optional arguments `reverse_launches` and `skip_to_launch_index`. +- Added preview of Tile Cholesky factorization and solve APIs through `tile_cholesky` and `tile_cholesky_solve`, and the `tile_diag_add` helper function. Those are preview APIs and subject to change. - Added preview of Tile Cholesky factorization and solve APIs through `tile_cholesky` and `tile_cholesky_solve`, as well as helpers `tile_tril` and `tile_add_diag`. Those are preview APIs and subject to change. - Support `assert` statements in kernels ([docs](https://nvidia.github.io/warp/debugging.html#assertions)). Assertions can only be triggered in `"debug"` mode ([GH-366](https://github.com/NVIDIA/warp/issues/336)). diff --git a/docs/modules/functions.rst b/docs/modules/functions.rst index 8c2fc6f0c..2d1c1a034 100644 --- a/docs/modules/functions.rst +++ b/docs/modules/functions.rst @@ -1241,11 +1241,6 @@ Tile Primitives Add a square matrix and a diagonal matrix -.. py:function:: tile_tril(a: Tile) -> Tile - - Zeroes the upper-triangular part of a square matrix and keep the lower triangular part + diagonal - - .. py:function:: tile_matmul(a: Tile, b: Tile, out: Tile) -> Tile Computes the matrix product and accumulates ``out += a*b``. @@ -1310,7 +1305,7 @@ Tile Primitives :param inout: The input/output tile -.. py:function:: tile_cholesky(A: Tile) -> None +.. py:function:: tile_cholesky(A: Tile) -> Tile Compute the Cholesky factorization L of a matrix A. L is lower triangular and satisfies LL^T = A. @@ -1321,7 +1316,8 @@ Tile Primitives * float32 * float64 - :param A: As input, the matrix A. As output, L. + :param A: A square, symmetric positive-definite, matrix. 
+ :returns L: A square, lower triangular, matrix, such that LL^T = A .. py:function:: tile_cholesky_solve(L: Tile, x: Tile) -> None @@ -1334,8 +1330,9 @@ Tile Primitives * float32 * float64 - :param L: The triangular matrix output of tile_cholesky, - :param x: As input, the right hand side y. As output, the solution x. + :param L: A square, lower triangular, matrix, such that LL^T = A + :param x: An Mx1 tile + :returns y: An Mx1 tile such that LL^T y = x diff --git a/docs/modules/tiles.rst b/docs/modules/tiles.rst index 2ae5bde44..f6d92cf4d 100644 --- a/docs/modules/tiles.rst +++ b/docs/modules/tiles.rst @@ -224,7 +224,6 @@ Linear Algebra * :func:`tile_cholesky` * :func:`tile_cholesky_solve` * :func:`tile_diag_add` -* :func:`tile_tril` Tiles and SIMT Code ------------------- diff --git a/warp/builtins.py b/warp/builtins.py index cdb99cbc7..f4a4a4d4b 100644 --- a/warp/builtins.py +++ b/warp/builtins.py @@ -5745,7 +5745,7 @@ def tile_scalar_mul_value_func(arg_types, arg_values): ) -def tile_diag_add_map_value_func(arg_types, arg_values): +def tile_diag_add_value_func(arg_types, arg_values): if arg_types is None: return Tile(dtype=Any, M=Any, N=Any) @@ -5753,28 +5753,40 @@ def tile_diag_add_map_value_func(arg_types, arg_values): b = arg_types["b"] # check all args are tiles - if not is_tile(a): - raise TypeError(f"tile_diag_add() arguments must be tiles, got type {a}") - - if not is_tile(b): - raise TypeError(f"tile_diag_add() arguments must be tiles, got type {b}") + if not is_tile(a) or not is_tile(b): + raise TypeError("tile_diag_add() arguments must be tiles") # use first argument to define output type if not types_equal(a.dtype, b.dtype): raise TypeError(f"tile_diag_add() arguments must all have the same type {a.dtype} != {b.dtype}") if a.M != a.N or a.M != b.M or b.N != 1: - raise ValueError("tile_diag_add() arguments must be square (first) and 1D (second)") + raise ValueError("tile_diag_add() first argument must be square and the second must be 1D") - return 
None + return Tile(dtype=a.dtype, M=a.M, N=a.N, storage="shared") + + +def tile_diag_add_dispatch_func( + arg_types: Mapping[str, type], + return_type: Any, + return_values: List[Var], + arg_values: Mapping[str, Var], + options: Mapping[str, Any], + builder: warp.context.ModuleBuilder, +): + a = arg_values["a"] + b = arg_values["b"] + a.type.storage = "shared" + b.type.storage = "shared" + out = return_values[0] + return ((a, b, out), [], [], 0) add_builtin( "tile_diag_add", input_types={"a": Tile(dtype=Any, M=Any, N=Any), "b": Tile(dtype=Any, M=Any, N=Any)}, - value_func=tile_diag_add_map_value_func, - # dispatch_func=tile_map_dispatch_func, - # variadic=True, + value_func=tile_diag_add_value_func, + lto_dispatch_func=tile_diag_add_dispatch_func, native_func="tile_diag_add", doc="Add a square matrix and a diagonal matrix", group="Tile Primitives", @@ -5782,31 +5794,6 @@ def tile_diag_add_map_value_func(arg_types, arg_values): ) -def tile_tril_value_func(arg_types, arg_values): - if arg_types is None: - return Tile(dtype=Any, M=Any, N=Any) - - a = arg_types["a"] - - if not is_tile(a): - raise TypeError(f"tile_tril() arguments must be tiles, got type {a}") - - if a.M != a.N: - raise ValueError("tile_tril() arguments must be square") - - return None - - -add_builtin( - "tile_tril", - input_types={"a": Tile(dtype=Any, M=Any, N=Any)}, - value_func=tile_tril_value_func, - native_func="tile_tril", - doc="Zeroes the upper-triangular part of a square matrix and keep the lower triangular part + diagonal", - group="Tile Primitives", - export=False, -) - ## ## MathDx, LTOIR-based, Tile functions ## @@ -6225,15 +6212,17 @@ def tile_fft_generic_lto_dispatch_func( ## def tile_cholesky_generic_value_func(arg_types, arg_values): if arg_types is None: - return None + return Tile(dtype=Any, M=Any, N=Any) if len(arg_types) != 1: raise TypeError("tile_cholesky() requires 1 positional args") - if not is_tile(arg_types["A"]): - raise TypeError("tile_cholesky() argument 0 must be a 
tile") + a = arg_types["A"] - return None + if not is_tile(a) or a.M != a.N: + raise TypeError("tile_cholesky() argument 0 must be a square tile") + + return Tile(dtype=a.dtype, M=a.M, N=a.N, storage="shared") def tile_cholesky_solve_generic_value_func(arg_types, arg_values): @@ -6243,10 +6232,12 @@ def tile_cholesky_solve_generic_value_func(arg_types, arg_values): if len(arg_types) != 2: raise TypeError("tile_cholesky_solve() requires 2 positional args") + l = arg_types["L"] + if not is_tile(arg_types["L"]) or not is_tile(arg_types["x"]): raise TypeError("tile_cholesky() argument 0 and 1 must be tiles") - return None + return Tile(dtype=l.dtype, M=l.M, N=1, storage="shared") cusolver_function_map = {"getrf": 0, "getrf_no_pivot": 1, "potrf": 2, "potrs": 3} @@ -6264,20 +6255,24 @@ def tile_cholesky_generic_lto_dispatch_func( options: Mapping[str, Any], builder: warp.context.ModuleBuilder, ): - inout = arg_values["A"] - inout.type.storage = "shared" + a = arg_values["A"] + a.type.storage = "shared" - if not is_tile(inout.type): + if not is_tile(a.type): raise TypeError("tile_cholesky() arguments must be a single tile with shared storage") - if inout.type.dtype not in cusolver_type_map.keys(): - raise TypeError("tile_cholesky() argument must be a tile of float64 entries") + if a.type.dtype not in cusolver_type_map.keys(): + raise TypeError("tile_cholesky() argument must be a tile of float32 or float64 entries") - dtype, precision_enum = cusolver_type_map[inout.type.dtype] + if len(return_values) != 1: + raise TypeError("tile_cholesky() returns one output") + out = return_values[0] - M, N = inout.type.M, inout.type.N - if M != N: - raise ValueError("Tile must be square") + dtype, precision_enum = cusolver_type_map[a.type.dtype] + + M, N = a.type.M, a.type.N + if M != N or out.type.M != M or out.type.N != M: + raise ValueError("Input and output Tile must be square") num_threads = options["block_dim"] arch = options["output_arch"] @@ -6321,14 +6316,13 @@ def 
tile_cholesky_generic_lto_dispatch_func( lto_code_path.unlink() builder.ltoirs[lto_symbol] = lto_code_data + builder.ltoirs_decl[lto_symbol] = f"void {lto_symbol}({dtype}*, unsigned);" return ( ( Var(lto_symbol, str, False, True, False), - Var(dtype, str, False, True, False), - Var(str(M), str, False, True, False), - Var(str(N), str, False, True, False), - inout, + a, + out, ), [], [lto_code_data], @@ -6345,27 +6339,31 @@ def tile_cholesky_solve_generic_lto_dispatch_func( builder: warp.context.ModuleBuilder, ): L = arg_values["L"] - inout = arg_values["x"] + x = arg_values["x"] L.type.storage = "shared" - inout.type.storage = "shared" + x.type.storage = "shared" + + if len(return_values) != 1: + raise TypeError("tile_cholesky_solve() returns one output") + y = return_values[0] - if not is_tile(inout.type) or not is_tile(L.type): + if not is_tile(x.type) or not is_tile(L.type): raise TypeError("tile_cholesky_solve() arguments must be two tile with shared storage") - if inout.type.dtype != L.type.dtype or any( - T not in cusolver_type_map.keys() for T in [inout.type.dtype, L.type.dtype] - ): + if x.type.dtype != L.type.dtype or any(T not in cusolver_type_map.keys() for T in [x.type.dtype, L.type.dtype]): raise TypeError("tile_cholesky_solve() arguments be tiles of float64 or float32 tiles") - dtype, precision_enum = cusolver_type_map[inout.type.dtype] - + dtype, precision_enum = cusolver_type_map[L.type.dtype] M, N = L.type.M, L.type.N if M != N: - raise ValueError("L Tile must be square") + raise ValueError("L tile must be square") - if inout.type.M != M or inout.type.N != 1: - raise ValueError(f"Right-hand side Tile must be {M}x1") + if x.type.M != M or x.type.N != 1: + raise ValueError(f"Input vector of tile_cholesky_solve must be {M}x1") + + if y.type.M != M or y.type.N != 1: + raise ValueError(f"Output vector of tile_cholesky_solve must be {M}x1") num_threads = options["block_dim"] arch = options["output_arch"] @@ -6409,15 +6407,14 @@ def 
tile_cholesky_solve_generic_lto_dispatch_func( lto_code_path.unlink() builder.ltoirs[lto_symbol] = lto_code_data + builder.ltoirs_decl[lto_symbol] = f"void {lto_symbol}({dtype}*, {dtype}*);" return ( ( Var(lto_symbol, str, False, True, False), - Var(dtype, str, False, True, False), - Var(str(M), str, False, True, False), - Var(str(N), str, False, True, False), L, - inout, + x, + y, ), [], [lto_code_data], @@ -6440,7 +6437,8 @@ def tile_cholesky_solve_generic_lto_dispatch_func( * float32 * float64 - :param A: As input, the matrix A. As output, L.""", + :param A: A square, symmetric positive-definite, matrix. + :returns L: A square, lower triangular, matrix, such that LL^T = A""", group="Tile Primitives", export=False, namespace="", @@ -6460,8 +6458,9 @@ def tile_cholesky_solve_generic_lto_dispatch_func( * float32 * float64 - :param L: The triangular matrix output of tile_cholesky, - :param x: As input, the right hand side y. As output, the solution x.""", + :param L: A square, lower triangular, matrix, such that LL^T = A + :param x: An Mx1 tile + :returns y: An Mx1 tile such that LL^T y = x""", group="Tile Primitives", export=False, namespace="", diff --git a/warp/examples/tile/example_tile_cholesky.py b/warp/examples/tile/example_tile_cholesky.py index 361e5f52a..6bdd253bf 100644 --- a/warp/examples/tile/example_tile_cholesky.py +++ b/warp/examples/tile/example_tile_cholesky.py @@ -38,13 +38,12 @@ def cholesky( i, j, _ = wp.tid() a = wp.tile_load(A, i, j, m=TILE, n=TILE) - wp.tile_cholesky(a) - wp.tile_tril(a) - wp.tile_store(L, i, j, a) + l = wp.tile_cholesky(a) + wp.tile_store(L, i, j, l) x = wp.tile_load(X, i, j, m=TILE, n=1) - wp.tile_cholesky_solve(a, x) - wp.tile_store(Y, i, j, x) + y = wp.tile_cholesky_solve(l, x) + wp.tile_store(Y, i, j, y) if __name__ == "__main__": @@ -62,7 +61,7 @@ def cholesky( X_wp = wp.array2d(X_h, dtype=wp_type) Y_wp = wp.array2d(Y_h, dtype=wp_type) - wp.launch(cholesky, dim=[1, 1, BLOCK_DIM], inputs=[A_wp, L_wp, X_wp, Y_wp], 
block_dim=BLOCK_DIM) + wp.launch_tiled(cholesky, dim=[1, 1], inputs=[A_wp, L_wp, X_wp, Y_wp], block_dim=BLOCK_DIM) L_np = np.linalg.cholesky(A_h) Y_np = np.linalg.solve(A_h, X_h) diff --git a/warp/native/tile.h b/warp/native/tile.h index 9dc3baea9..389b291d3 100644 --- a/warp/native/tile.h +++ b/warp/native/tile.h @@ -1782,30 +1782,63 @@ void adj_tile_matmul(Fwd fun_forward, AdjA fun_backward_A, AdjB fun_backward_B, tile_fft(function_name, dtype, shared_memory_size, batch_size, ept, adj_Xinout); \ } while (0) -#define tile_cholesky(function_name, dtype, M, N, Xinout) \ - do { \ - void function_name(dtype*, unsigned); \ - WP_TILE_SYNC(); \ - function_name(Xinout.data.ptr, M); \ - WP_TILE_SYNC(); \ - } while (0) +template <typename Fwd, typename TileA, typename TileL> +TileL& tile_cholesky(Fwd fun_forward, TileA& A, TileL& L) +{ + // Copy to L + + L = A; + + // Call cholesky on L + + WP_TILE_SYNC(); + + fun_forward(L.data.ptr, L.M); + + WP_TILE_SYNC(); + + // Zero-out the upper triangular part of L + + WP_PRAGMA_UNROLL + for (int i=threadIdx.x; i < L.Size ; i += WP_TILE_BLOCK_DIM) + { + coord_t c = L.coord(i); + if(c.i < c.j) { + L.data(i) = 0.0; + } + } + + WP_TILE_SYNC(); + + return L; +} -#define adj_tile_cholesky(function_name, dtype, M, N, Xinout, \ - adj_function_name, adj_dtype, adj_M, adj_N, adj_Xinout) \ +#define adj_tile_cholesky(function_name, A, L, \ + adj_function_name, adj_A, adj_L, adj_ret) \ do { \ assert(false); \ } while (0) -#define tile_cholesky_solve(function_name, dtype, M, N, L, Xinout) \ - do { \ - void function_name(dtype*, dtype*); \ - WP_TILE_SYNC(); \ - function_name(L.data.ptr, Xinout.data.ptr); \ - WP_TILE_SYNC(); \ - } while (0) +template <typename Fwd, typename TileL, typename TileX, typename TileY> +TileY& tile_cholesky_solve(Fwd fun_forward, TileL& L, TileX& X, TileY& Y) +{ + // Copy x to y + + Y = X; + + // Call cholesky solve on L & y -#define adj_tile_cholesky_solve(function_name, dtype, M, N, L, Xinout, \ - adj_function_name, adj_dtype, adj_M, adj_N, adj_L, adj_Xinout) \ + WP_TILE_SYNC(); + + fun_forward(L.data.ptr, Y.data.ptr); + 
WP_TILE_SYNC(); + + return Y; +} + +#define adj_tile_cholesky_solve(function_name, L, X, Y, \ + adj_function_name, adj_L, adj_X, adj_Y, adj_ret) \ do { \ assert(false); \ } while (0) @@ -1866,46 +1899,29 @@ inline CUDA_CALLABLE void tile_assign(TileA& dest, int i, int j, TileB& src) WP_TILE_SYNC(); } -template <typename TileA, typename TileB> -inline CUDA_CALLABLE void tile_diag_add(TileA& inout, TileB& diag) +template <typename TileA, typename TileB, typename TileC> +inline CUDA_CALLABLE TileC& tile_diag_add(TileA& a, TileB& b, TileC& c) { static_assert(TileA::M == TileA::N); static_assert(TileB::M == TileA::M); static_assert(TileB::N == 1); - - for (int t=threadIdx.x; t < TileA::M; t += WP_TILE_BLOCK_DIM) - { - inout.data(t, t) += diag.data(t, 1); - } + static_assert(TileC::M == TileA::M); + static_assert(TileC::M == TileC::N); - WP_TILE_SYNC(); -} - -template <typename Tile> -inline CUDA_CALLABLE void tile_tril(Tile& inout) -{ - static_assert(Tile::M == Tile::N); + c = a; - for (int t=threadIdx.x; t < inout.Size; t += WP_TILE_BLOCK_DIM) + for (int t=threadIdx.x; t < TileA::M; t += WP_TILE_BLOCK_DIM) { - coord_t c = inout.coord(t); - if(c.i < c.j) - { - inout.data(c.i, c.j) = 0.0; - } + c.data(t, t) += b.data(t, 0); } WP_TILE_SYNC(); -} -template <typename TileA, typename TileB, typename AdjTileA, typename AdjTileB> -inline CUDA_CALLABLE void adj_tile_diag_add(TileA& inout, TileB& diag, AdjTileA& adj_inout, AdjTileB& adj_diag) -{ - assert(false); + return c; } -template <typename Tile, typename AdjTile> -inline CUDA_CALLABLE void adj_tile_tril(Tile& inout, AdjTile& adj_inout) +template <typename TileA, typename TileB, typename TileC, typename AdjTileA, typename AdjTileB, typename AdjTileC> +inline CUDA_CALLABLE void adj_tile_diag_add(TileA& a, TileB& b, TileC& c, AdjTileA& adj_a, AdjTileB& adj_b, AdjTileC& adj_c, AdjTileC& adj_ret) { assert(false); } diff --git a/warp/sim/integrator_featherstone.py b/warp/sim/integrator_featherstone.py index 1cb6bc2e9..cedcc29c1 100644 --- a/warp/sim/integrator_featherstone.py +++ b/warp/sim/integrator_featherstone.py @@ -1239,12 +1239,9 @@ def eval_dense_gemm_and_cholesky_tile( # cholesky L L^T = (H + diag(R)) R = wp.tile_load(R_arr[articulation], 0, 0, m=num_dofs, n=1, storage="shared") - wp.tile_diag_add(H, R) - 
wp.tile_cholesky(H) - # We should somehow encode the tril nature of the matrix in the type - # and just store the lower part, instead of having to do this - wp.tile_tril(H) - wp.tile_store(L_arr[articulation], 0, 0, H) + H_R = wp.tile_diag_add(H, R) + L = wp.tile_cholesky(H_R) + wp.tile_store(L_arr[articulation], 0, 0, L) return eval_dense_gemm_and_cholesky_tile diff --git a/warp/stubs.py b/warp/stubs.py index 064830bea..844aea958 100644 --- a/warp/stubs.py +++ b/warp/stubs.py @@ -2914,12 +2914,6 @@ def tile_diag_add(a: Tile, b: Tile) -> Tile: ... -@over -def tile_tril(a: Tile) -> Tile: - """Zeroes the upper-triangular part of a square matrix and keep the lower triangular part + diagonal""" - ... - - @over def tile_matmul(a: Tile, b: Tile, out: Tile) -> Tile: """Computes the matrix product and accumulates ``out += a*b``. @@ -2991,7 +2985,7 @@ def tile_ifft(inout: Tile) -> Tile: @over -def tile_cholesky(A: Tile): +def tile_cholesky(A: Tile) -> Tile: """Compute the Cholesky factorization L of a matrix A. L is lower triangular and satisfies LL^T = A. @@ -3001,7 +2995,8 @@ def tile_cholesky(A: Tile): * float32 * float64 - :param A: As input, the matrix A. As output, L. + :param A: A square, symmetric positive-definite, matrix. + :returns L: A square, lower triangular, matrix, such that LL^T = A """ ... @@ -3016,8 +3011,9 @@ def tile_cholesky_solve(L: Tile, x: Tile): * float32 * float64 - :param L: The triangular matrix output of tile_cholesky, - :param x: As input, the right hand side y. As output, the solution x. + :param L: A square, lower triangular, matrix, such that LL^T = A + :param x: An Mx1 tile + :returns y: An Mx1 tile such that LL^T y = x """ ... 
diff --git a/warp/tests/test_tile_mathdx.py b/warp/tests/test_tile_mathdx.py index dc45238bb..263054723 100644 --- a/warp/tests/test_tile_mathdx.py +++ b/warp/tests/test_tile_mathdx.py @@ -114,6 +114,56 @@ def test_tile_math_fft(test, device, wp_dtype): # TODO: implement and test backward pass +@wp.kernel() +def tile_math_cholesky( + gA: wp.array2d(dtype=wp.float64), + gD: wp.array2d(dtype=wp.float64), + gL: wp.array2d(dtype=wp.float64), + gx: wp.array2d(dtype=wp.float64), + gy: wp.array2d(dtype=wp.float64), +): + i, j = wp.tid() + # Load A, D & x + a = wp.tile_load(gA, i, j, m=TILE_M, n=TILE_M, storage="shared") + d = wp.tile_load(gD, i, j, m=TILE_M, n=1, storage="shared") + x = wp.tile_load(gx, i, j, m=TILE_M, n=1, storage="shared") + # Compute L st LL^T = A + diag(D) + b = wp.tile_diag_add(a, d) + l = wp.tile_cholesky(b) + # Solve for y in LL^T y = x + y = wp.tile_cholesky_solve(l, x) + # Store L & y + wp.tile_store(gL, i, j, l) + wp.tile_store(gy, i, j, y) + + +def test_tile_math_cholesky(test, device): + A_h = np.ones((TILE_M, TILE_M), dtype=np.float64) + D_h = 8.0 * np.ones((TILE_M, 1), dtype=np.float64) + L_h = np.zeros_like(A_h) + X_h = np.arange(TILE_M, dtype=np.float64).reshape((TILE_M, 1)) + Y_h = np.zeros_like(X_h) + + A_np = A_h + np.diag(D_h.reshape((-1,)), 0) + L_np = np.linalg.cholesky(A_np) + Y_np = np.linalg.solve(A_np, X_h) + + A_wp = wp.array2d(A_h, requires_grad=True, dtype=wp.float64, device=device) + D_wp = wp.array2d(D_h, requires_grad=True, dtype=wp.float64, device=device) + L_wp = wp.array2d(L_h, requires_grad=True, dtype=wp.float64, device=device) + X_wp = wp.array2d(X_h, requires_grad=True, dtype=wp.float64, device=device) + Y_wp = wp.array2d(Y_h, requires_grad=True, dtype=wp.float64, device=device) + + wp.launch_tiled( + tile_math_cholesky, dim=[1, 1], inputs=[A_wp, D_wp, L_wp, X_wp, Y_wp], block_dim=TILE_DIM, device=device + ) + wp.synchronize_device() + + assert np.allclose(Y_wp.numpy(), Y_np) and np.allclose(L_wp.numpy(), L_np) + + # 
TODO: implement and test backward pass + + devices = get_cuda_test_devices() @@ -124,6 +174,9 @@ class TestTileMathDx(unittest.TestCase): # check_output=False so we can enable libmathdx's logging without failing the tests add_function_test(TestTileMathDx, "test_tile_math_matmul", test_tile_math_matmul, devices=devices, check_output=False) +add_function_test( + TestTileMathDx, "test_tile_math_cholesky", test_tile_math_cholesky, devices=devices, check_output=False +) add_function_test( TestTileMathDx, "test_tile_math_fft_vec2f",