Skip to content

Commit

Permalink
DX-631 Cholesky: making all functions pure, removing tile_tril, addin…
Browse files Browse the repository at this point in the history
…g tests
  • Loading branch information
nvlcambier authored and mmacklin committed Jan 16, 2025
1 parent 23b00ae commit ac5b31f
Show file tree
Hide file tree
Showing 9 changed files with 206 additions and 149 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

### Added

- Added preview of Tile Cholesky factorization and solve APIs through `tile_cholesky` and `tile_cholesky_solve` and the `tile_diag_add` helper function. These are preview APIs and are subject to change.
- Added preview of Tile Cholesky factorization and solve APIs through `tile_cholesky` and `tile_cholesky_solve`, as well as helpers `tile_tril` and `tile_diag_add`. These are preview APIs and are subject to change.
- Support `assert` statements in kernels ([docs](https://nvidia.github.io/warp/debugging.html#assertions)).
Assertions can only be triggered in `"debug"` mode ([GH-366](https://github.com/NVIDIA/warp/issues/336)).
Expand Down
15 changes: 6 additions & 9 deletions docs/modules/functions.rst
Original file line number Diff line number Diff line change
Expand Up @@ -1241,11 +1241,6 @@ Tile Primitives
Add a square matrix and a diagonal matrix


.. py:function:: tile_tril(a: Tile) -> Tile
Zeroes the upper-triangular part of a square matrix and keep the lower triangular part + diagonal


.. py:function:: tile_matmul(a: Tile, b: Tile, out: Tile) -> Tile
Computes the matrix product and accumulates ``out += a*b``.
Expand Down Expand Up @@ -1310,7 +1305,7 @@ Tile Primitives
:param inout: The input/output tile


.. py:function:: tile_cholesky(A: Tile) -> None
.. py:function:: tile_cholesky(A: Tile) -> Tile
Compute the Cholesky factorization L of a matrix A.
L is lower triangular and satisfies LL^T = A.
Expand All @@ -1321,7 +1316,8 @@ Tile Primitives
* float32
* float64

:param A: As input, the matrix A. As output, L.
:param A: A square, symmetric positive-definite, matrix.
:returns L: A square, lower triangular, matrix, such that LL^T = A


.. py:function:: tile_cholesky_solve(L: Tile, x: Tile) -> Tile
Expand All @@ -1334,8 +1330,9 @@ Tile Primitives
* float32
* float64

:param L: The triangular matrix output of tile_cholesky,
:param x: As input, the right hand side y. As output, the solution x.
:param L: A square, lower triangular, matrix, such that LL^T = A
:param x: An Mx1 tile
:returns y: An Mx1 tile such that LL^T y = x



Expand Down
1 change: 0 additions & 1 deletion docs/modules/tiles.rst
Original file line number Diff line number Diff line change
Expand Up @@ -224,7 +224,6 @@ Linear Algebra
* :func:`tile_cholesky`
* :func:`tile_cholesky_solve`
* :func:`tile_diag_add`
* :func:`tile_tril`

Tiles and SIMT Code
-------------------
Expand Down
143 changes: 71 additions & 72 deletions warp/builtins.py
Original file line number Diff line number Diff line change
Expand Up @@ -5745,68 +5745,55 @@ def tile_scalar_mul_value_func(arg_types, arg_values):
)


def tile_diag_add_map_value_func(arg_types, arg_values):
def tile_diag_add_value_func(arg_types, arg_values):
if arg_types is None:
return Tile(dtype=Any, M=Any, N=Any)

a = arg_types["a"]
b = arg_types["b"]

# check all args are tiles
if not is_tile(a):
raise TypeError(f"tile_diag_add() arguments must be tiles, got type {a}")

if not is_tile(b):
raise TypeError(f"tile_diag_add() arguments must be tiles, got type {b}")
if not is_tile(a) or not is_tile(b):
raise TypeError("tile_diag_add() arguments must be tiles")

# use first argument to define output type
if not types_equal(a.dtype, b.dtype):
raise TypeError(f"tile_diag_add() arguments must all have the same type {a.dtype} != {b.dtype}")

if a.M != a.N or a.M != b.M or b.N != 1:
raise ValueError("tile_diag_add() arguments must be square (first) and 1D (second)")
raise ValueError("tile_diag_add() first argument must be square and the second must be 1D")

return None
return Tile(dtype=a.dtype, M=a.M, N=a.N, storage="shared")


def tile_diag_add_dispatch_func(
    arg_types: Mapping[str, type],
    return_type: Any,
    return_values: List[Var],
    arg_values: Mapping[str, Var],
    options: Mapping[str, Any],
    builder: warp.context.ModuleBuilder,
):
    """Build the dispatch arguments for the ``tile_diag_add`` builtin.

    Both input tiles are switched to ``"shared"`` storage and are forwarded,
    together with the output tile, as the arguments of the call.

    :param arg_types: Not read by this function.
    :param return_type: Not read by this function.
    :param return_values: The output variables; exactly one is expected.
    :param arg_values: Mapping holding the input variables under ``"a"`` and ``"b"``.
    :param options: Not read by this function.
    :param builder: Not read by this function.
    :returns: The 4-tuple ``((a, b, out), [], [], 0)``.
    :raises TypeError: If ``return_values`` does not contain exactly one entry.
    """
    # Same guard as the Cholesky dispatch functions: fail with a clear message
    # instead of an IndexError when the output is missing.
    if len(return_values) != 1:
        raise TypeError("tile_diag_add() returns one output")

    a = arg_values["a"]
    b = arg_values["b"]

    # The storage of the *input* tile types is modified in place.
    a.type.storage = "shared"
    b.type.storage = "shared"

    out = return_values[0]

    # NOTE(review): the meaning of the trailing ``[], [], 0`` entries is defined
    # by the caller of the dispatch function and is not visible here -- confirm.
    return ((a, b, out), [], [], 0)


add_builtin(
"tile_diag_add",
input_types={"a": Tile(dtype=Any, M=Any, N=Any), "b": Tile(dtype=Any, M=Any, N=Any)},
value_func=tile_diag_add_map_value_func,
# dispatch_func=tile_map_dispatch_func,
# variadic=True,
value_func=tile_diag_add_value_func,
lto_dispatch_func=tile_diag_add_dispatch_func,
native_func="tile_diag_add",
doc="Add a square matrix and a diagonal matrix",
group="Tile Primitives",
export=False,
)


def tile_tril_value_func(arg_types, arg_values):
    """Validate the argument of ``tile_tril`` and report its return type.

    :param arg_types: Mapping of argument name to type, or ``None`` when no
        concrete argument types are supplied.
    :param arg_values: Not read by this function.
    :returns: A generic ``Tile`` when ``arg_types`` is ``None``; otherwise
        ``None`` once the argument has been validated.
    :raises TypeError: If ``a`` is not a tile.
    :raises ValueError: If the tile is not square (``M != N``).
    """
    if arg_types is None:
        return Tile(dtype=Any, M=Any, N=Any)

    a = arg_types["a"]

    if not is_tile(a):
        raise TypeError(f"tile_tril() arguments must be tiles, got type {a}")

    if a.M != a.N:
        raise ValueError("tile_tril() arguments must be square")

    # No value is produced for concrete argument types.
    # NOTE(review): presumably the native function works on ``a`` in place -- confirm.
    return None


# add_builtin() call for "tile_tril": one tile argument "a", checked by
# tile_tril_value_func and mapped to the native function "tile_tril".
add_builtin(
"tile_tril",
input_types={"a": Tile(dtype=Any, M=Any, N=Any)},
value_func=tile_tril_value_func,
native_func="tile_tril",
doc="Zeroes the upper-triangular part of a square matrix and keep the lower triangular part + diagonal",
group="Tile Primitives",
export=False,
)

##
## MathDx, LTOIR-based, Tile functions
##
Expand Down Expand Up @@ -6225,15 +6212,17 @@ def tile_fft_generic_lto_dispatch_func(
##
def tile_cholesky_generic_value_func(arg_types, arg_values):
if arg_types is None:
return None
return Tile(dtype=Any, M=Any, N=Any)

if len(arg_types) != 1:
raise TypeError("tile_cholesky() requires 1 positional args")

if not is_tile(arg_types["A"]):
raise TypeError("tile_cholesky() argument 0 must be a tile")
a = arg_types["A"]

return None
if not is_tile(a) or a.M != a.N:
raise TypeError("tile_cholesky() argument 0 must be a square tile")

return Tile(dtype=a.dtype, M=a.M, N=a.N, storage="shared")


def tile_cholesky_solve_generic_value_func(arg_types, arg_values):
Expand All @@ -6243,10 +6232,12 @@ def tile_cholesky_solve_generic_value_func(arg_types, arg_values):
if len(arg_types) != 2:
raise TypeError("tile_cholesky_solve() requires 2 positional args")

l = arg_types["L"]

if not is_tile(arg_types["L"]) or not is_tile(arg_types["x"]):
raise TypeError("tile_cholesky() argument 0 and 1 must be tiles")

return None
return Tile(dtype=l.dtype, M=l.M, N=1, storage="shared")


cusolver_function_map = {"getrf": 0, "getrf_no_pivot": 1, "potrf": 2, "potrs": 3}
Expand All @@ -6264,20 +6255,24 @@ def tile_cholesky_generic_lto_dispatch_func(
options: Mapping[str, Any],
builder: warp.context.ModuleBuilder,
):
inout = arg_values["A"]
inout.type.storage = "shared"
a = arg_values["A"]
a.type.storage = "shared"

if not is_tile(inout.type):
if not is_tile(a.type):
raise TypeError("tile_cholesky() arguments must be a single tile with shared storage")

if inout.type.dtype not in cusolver_type_map.keys():
raise TypeError("tile_cholesky() argument must be a tile of float64 entries")
if a.type.dtype not in cusolver_type_map.keys():
raise TypeError("tile_cholesky() argument must be a tile of float32 or float64 entries")

dtype, precision_enum = cusolver_type_map[inout.type.dtype]
if len(return_values) != 1:
raise TypeError("tile_cholesky() returns one output")
out = return_values[0]

M, N = inout.type.M, inout.type.N
if M != N:
raise ValueError("Tile must be square")
dtype, precision_enum = cusolver_type_map[a.type.dtype]

M, N = a.type.M, a.type.N
if M != N or out.type.M != M or out.type.N != M:
raise ValueError("Input and output Tile must be square")

num_threads = options["block_dim"]
arch = options["output_arch"]
Expand Down Expand Up @@ -6321,14 +6316,13 @@ def tile_cholesky_generic_lto_dispatch_func(
lto_code_path.unlink()

builder.ltoirs[lto_symbol] = lto_code_data
builder.ltoirs_decl[lto_symbol] = f"void {lto_symbol}({dtype}*, unsigned);"

return (
(
Var(lto_symbol, str, False, True, False),
Var(dtype, str, False, True, False),
Var(str(M), str, False, True, False),
Var(str(N), str, False, True, False),
inout,
a,
out,
),
[],
[lto_code_data],
Expand All @@ -6345,27 +6339,31 @@ def tile_cholesky_solve_generic_lto_dispatch_func(
builder: warp.context.ModuleBuilder,
):
L = arg_values["L"]
inout = arg_values["x"]
x = arg_values["x"]
L.type.storage = "shared"
inout.type.storage = "shared"
x.type.storage = "shared"

if len(return_values) != 1:
raise TypeError("tile_cholesky_solve() returns one output")
y = return_values[0]

if not is_tile(inout.type) or not is_tile(L.type):
if not is_tile(x.type) or not is_tile(L.type):
raise TypeError("tile_cholesky_solve() arguments must be two tile with shared storage")

if inout.type.dtype != L.type.dtype or any(
T not in cusolver_type_map.keys() for T in [inout.type.dtype, L.type.dtype]
):
if x.type.dtype != L.type.dtype or any(T not in cusolver_type_map.keys() for T in [x.type.dtype, L.type.dtype]):
raise TypeError("tile_cholesky_solve() arguments be tiles of float64 or float32 tiles")

dtype, precision_enum = cusolver_type_map[inout.type.dtype]

dtype, precision_enum = cusolver_type_map[L.type.dtype]
M, N = L.type.M, L.type.N

if M != N:
raise ValueError("L Tile must be square")
raise ValueError("L tile must be square")

if inout.type.M != M or inout.type.N != 1:
raise ValueError(f"Right-hand side Tile must be {M}x1")
if x.type.M != M or x.type.N != 1:
raise ValueError(f"Input vector of tile_cholesky_solve must be {M}x1")

if y.type.M != M or y.type.N != 1:
raise ValueError(f"Output vectir of tile_cholesky_solve be {M}x1")

num_threads = options["block_dim"]
arch = options["output_arch"]
Expand Down Expand Up @@ -6409,15 +6407,14 @@ def tile_cholesky_solve_generic_lto_dispatch_func(
lto_code_path.unlink()

builder.ltoirs[lto_symbol] = lto_code_data
builder.ltoirs_decl[lto_symbol] = f"void {lto_symbol}({dtype}*, {dtype}*);"

return (
(
Var(lto_symbol, str, False, True, False),
Var(dtype, str, False, True, False),
Var(str(M), str, False, True, False),
Var(str(N), str, False, True, False),
L,
inout,
x,
y,
),
[],
[lto_code_data],
Expand All @@ -6440,7 +6437,8 @@ def tile_cholesky_solve_generic_lto_dispatch_func(
* float32
* float64
:param A: As input, the matrix A. As output, L.""",
:param A: A square, symmetric positive-definite, matrix.
:returns L: A square, lower triangular, matrix, such that LL^T = A""",
group="Tile Primitives",
export=False,
namespace="",
Expand All @@ -6460,8 +6458,9 @@ def tile_cholesky_solve_generic_lto_dispatch_func(
* float32
* float64
:param L: The triangular matrix output of tile_cholesky,
:param x: As input, the right hand side y. As output, the solution x.""",
:param L: A square, lower triangular, matrix, such that LL^T = A
:param x: An Mx1 tile
:returns y: An Mx1 tile such that LL^T y = x""",
group="Tile Primitives",
export=False,
namespace="",
Expand Down
11 changes: 5 additions & 6 deletions warp/examples/tile/example_tile_cholesky.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,13 +38,12 @@ def cholesky(
i, j, _ = wp.tid()

a = wp.tile_load(A, i, j, m=TILE, n=TILE)
wp.tile_cholesky(a)
wp.tile_tril(a)
wp.tile_store(L, i, j, a)
l = wp.tile_cholesky(a)
wp.tile_store(L, i, j, l)

x = wp.tile_load(X, i, j, m=TILE, n=1)
wp.tile_cholesky_solve(a, x)
wp.tile_store(Y, i, j, x)
y = wp.tile_cholesky_solve(l, x)
wp.tile_store(Y, i, j, y)


if __name__ == "__main__":
Expand All @@ -62,7 +61,7 @@ def cholesky(
X_wp = wp.array2d(X_h, dtype=wp_type)
Y_wp = wp.array2d(Y_h, dtype=wp_type)

wp.launch(cholesky, dim=[1, 1, BLOCK_DIM], inputs=[A_wp, L_wp, X_wp, Y_wp], block_dim=BLOCK_DIM)
wp.launch_tiled(cholesky, dim=[1, 1], inputs=[A_wp, L_wp, X_wp, Y_wp], block_dim=BLOCK_DIM)

L_np = np.linalg.cholesky(A_h)
Y_np = np.linalg.solve(A_h, X_h)
Expand Down
Loading

0 comments on commit ac5b31f

Please sign in to comment.