
Commit a994192

Added SUMMA tests and fixed dtype problem

1 parent 0956e7b · commit a994192

File tree: 4 files changed, +191 -54 lines changed


examples/plot_matrixmult.py

Lines changed: 3 additions & 3 deletions
@@ -28,6 +28,7 @@
 
 import pylops_mpi
 from pylops_mpi import Partition
+from pylops_mpi.basicoperators.MatrixMult import active_grid_comm, MPIMatrixMult
 
 plt.close("all")
 
@@ -88,8 +89,7 @@
 # than the row or columm ranks.
 
 base_comm = MPI.COMM_WORLD
-comm, rank, row_id, col_id, is_active = \
-    pylops_mpi.MPIMatrixMult.active_grid_comm(base_comm, N, M)
+comm, rank, row_id, col_id, is_active = active_grid_comm(base_comm, N, M)
 print(f"Process {base_comm.Get_rank()} is {'active' if is_active else 'inactive'}")
 if not is_active: exit(0)
 
@@ -147,7 +147,7 @@
 ################################################################################
 # We are now ready to create the :py:class:`pylops_mpi.basicoperators.MPIMatrixMult`
 # operator and the input matrix :math:`\mathbf{X}`
-Aop = pylops_mpi.MPIMatrixMult(A_p, M, base_comm=comm, dtype="float32")
+Aop = MPIMatrixMult(A_p, M, base_comm=comm, dtype="float32", kind="block")
 
 col_lens = comm.allgather(my_own_cols)
 total_cols = np.sum(col_lens)
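The change above moves the example onto the free-standing active_grid_comm helper and pins the operator to the block-distributed implementation via kind="block". The helper exists because the world communicator may contain more ranks than the process grid can use; surplus ranks are reported as inactive and simply exit, as the unchanged print/exit lines show. The following rough, single-process sketch only illustrates that bookkeeping; the usable_ranks helper is hypothetical and not part of pylops-mpi, which also builds and returns a sub-communicator.

# Hypothetical illustration of why some ranks come back inactive: only as many
# ranks as the process grid needs take part, the rest exit early.
import math

def usable_ranks(world_size: int) -> int:
    # e.g. for a square grid (as in the SUMMA example below) keep the largest perfect square
    p_prime = math.isqrt(world_size)
    return p_prime * p_prime

for world_size in (4, 6, 9, 12):
    active = usable_ranks(world_size)
    print(f"{world_size} ranks -> {active} active, {world_size - active} inactive")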

examples/plot_summamatrixmult.py

Lines changed: 46 additions & 32 deletions
@@ -1,11 +1,28 @@
+r"""
+Distributed SUMMA Matrix Multiplication
+=======================================
+This example shows how to use the :py:class:`pylops_mpi.basicoperators.MPISummaMatrixMult`
+operator to perform matrix-matrix multiplication between a matrix :math:`\mathbf{A}`
+distributed in 2D blocks across a square process grid and matrices :math:`\mathbf{X}`
+and :math:`\mathbf{Y}` distributed in 2D blocks across the same grid. Similarly,
+the adjoint operation can be performed with a matrix :math:`\mathbf{Y}` distributed
+in the same fashion as matrix :math:`\mathbf{X}`.
+
+Note that whilst the different blocks of matrix :math:`\mathbf{A}` are directly
+stored in the operator on different ranks, the matrices :math:`\mathbf{X}` and
+:math:`\mathbf{Y}` are effectively represented by 1-D :py:class:`pylops_mpi.DistributedArray`
+objects where the different blocks are flattened and stored on different ranks.
+Note that to optimize communications, the ranks are organized in a square grid and
+blocks of :math:`\mathbf{A}` and :math:`\mathbf{X}` are systematically broadcast
+across different ranks during computation - see below for details.
+"""
+
 import math
 import numpy as np
 from mpi4py import MPI
 
 import pylops_mpi
-from pylops_mpi.basicoperators.MatrixMult import (local_block_spit,
-                                                  block_gather,
-                                                  MPISummaMatrixMult)
+from pylops_mpi.basicoperators.MatrixMult import (local_block_spit, block_gather, MPIMatrixMult)
 
 comm = MPI.COMM_WORLD
 rank = comm.Get_rank()
@@ -16,43 +33,40 @@
 K = 9
 
 A_shape = (N, K)
-B_shape = (K, M)
-C_shape = (N, M)
+x_shape = (K, M)
+y_shape = (N, M)
 
 p_prime = math.isqrt(size)
-assert p_prime * p_prime == size, "Number of processes must be a perfect square"
-
 A_data = np.arange(int(A_shape[0] * A_shape[1])).reshape(A_shape)
-B_data = np.arange(int(B_shape[0] * B_shape[1])).reshape(B_shape)
+x_data = np.arange(int(x_shape[0] * x_shape[1])).reshape(x_shape)
 
 A_slice = local_block_spit(A_shape, rank, comm)
-B_slice = local_block_spit(B_shape, rank, comm)
+x_slice = local_block_spit(x_shape, rank, comm)
 A_local = A_data[A_slice]
-B_local = B_data[B_slice]
-# A_local, (N_new, K_new) = block_distribute(A_data,rank, comm)
-# B_local, (K_new, M_new) = block_distribute(B_data,rank, comm)
+x_local = x_data[x_slice]
 
-B_dist = pylops_mpi.DistributedArray(global_shape=(K * M),
-                                     local_shapes=comm.allgather(B_local.shape[0] * B_local.shape[1]),
+x_dist = pylops_mpi.DistributedArray(global_shape=(K * M),
+                                     local_shapes=comm.allgather(x_local.shape[0] * x_local.shape[1]),
                                      base_comm=comm,
-                                     partition=pylops_mpi.Partition.SCATTER)
-B_dist.local_array[:] = B_local.flatten()
+                                     partition=pylops_mpi.Partition.SCATTER,
+                                     dtype=x_local.dtype)
+x_dist.local_array[:] = x_local.flatten()
 
-Aop = MPISummaMatrixMult(A_local, M, base_comm=comm)
-C_dist = Aop @ B_dist
-Z_dist = Aop.H @ C_dist
+Aop = MPIMatrixMult(A_local, M, base_comm=comm, kind="summa", dtype=A_local.dtype)
+y_dist = Aop @ x_dist
+xadj_dist = Aop.H @ y_dist
 
-C = block_gather(C_dist, (N,M), (N,M), comm)
-Z = block_gather(Z_dist, (K,M), (K,M), comm)
+y = block_gather(y_dist, (N,M), (N,M), comm)
+xadj = block_gather(xadj_dist, (K,M), (K,M), comm)
 if rank == 0 :
-    C_correct = np.allclose(A_data @ B_data, C)
-    print("C expected: ", C_correct)
-    if not C_correct:
-        print("expected:\n", A_data @ B_data)
-        print("calculated:\n",C)
-
-    Z_correct = np.allclose((A_data.T.dot((A_data @ B_data).conj())).conj(), Z.astype(np.int32))
-    print("Z expected: ", Z_correct)
-    if not Z_correct:
-        print("expected:\n", (A_data.T.dot((A_data @ B_data).conj())).conj())
-        print("calculated:\n", Z.astype(np.int32))
+    y_correct = np.allclose(A_data @ x_data, y)
+    print("y expected: ", y_correct)
+    if not y_correct:
+        print("expected:\n", A_data @ x_data)
+        print("calculated:\n",y)
+
+    xadj_correct = np.allclose((A_data.T.dot((A_data @ x_data).conj())).conj(), xadj.astype(np.int32))
+    print("xadj expected: ", xadj_correct)
+    if not xadj_correct:
+        print("expected:\n", (A_data.T.dot((A_data @ x_data).conj())).conj())
+        print("calculated:\n", xadj.astype(np.int32))

pylops_mpi/basicoperators/MatrixMult.py

Lines changed: 38 additions & 16 deletions
@@ -74,7 +74,7 @@ def active_grid_comm(base_comm: MPI.Comm, N: int, M: int):
 def local_block_spit(global_shape: Tuple[int, int],
                      rank: int,
                      comm: MPI.Comm) -> Tuple[slice, slice]:
-    """
+    r"""
     Compute the local sub‐block of a 2D global array for a process in a square process grid.
 
     Parameters
@@ -122,7 +122,7 @@ def local_block_spit(global_shape: Tuple[int, int],
 
 
 def block_gather(x: DistributedArray, new_shape: Tuple[int, int], orig_shape: Tuple[int, int], comm: MPI.Comm):
-    """
+    r"""
     Gather distributed local blocks from 2D block distributed matrix distributed
     amongst a square process grid into the full global array.
 
@@ -351,19 +351,19 @@ def _matvec(self, x: DistributedArray) -> DistributedArray:
         ncp = get_module(x.engine)
         if x.partition != Partition.SCATTER:
             raise ValueError(f"x should have partition={Partition.SCATTER} Got {x.partition} instead...")
-
+        output_dtype = np.result_type(self.dtype, x.dtype)
         y = DistributedArray(
             global_shape=(self.N * self.dimsd[1]),
             local_shapes=[(self.N * c) for c in self._rank_col_lens],
             mask=x.mask,
             partition=Partition.SCATTER,
-            dtype=self.dtype,
+            dtype=output_dtype,
             base_comm=self.base_comm
         )
 
         my_own_cols = self._rank_col_lens[self.rank]
         x_arr = x.local_array.reshape((self.dims[0], my_own_cols))
-        X_local = x_arr.astype(self.dtype)
+        X_local = x_arr.astype(output_dtype)
         Y_local = ncp.vstack(
             self._row_comm.allgather(
                 ncp.matmul(self.A, X_local)
@@ -377,16 +377,28 @@ def _rmatvec(self, x: DistributedArray) -> DistributedArray:
         if x.partition != Partition.SCATTER:
             raise ValueError(f"x should have partition={Partition.SCATTER}. Got {x.partition} instead.")
 
+        # - If A is real: A^H = A^T,
+        #   so result_type(real_A, x.dtype) = x.dtype (if x is complex) or real (if x is real)
+        # - If A is complex: A^H is complex,
+        #   so result will be complex regardless of x
+        if np.iscomplexobj(self.A):
+            output_dtype = np.result_type(self.dtype, x.dtype)
+        else:
+            # Real matrix: A^T @ x preserves input type complexity
+            output_dtype = x.dtype if np.iscomplexobj(x.local_array) else self.dtype
+            # But still need to check type promotion for precision
+            output_dtype = np.result_type(self.dtype, output_dtype)
+
         y = DistributedArray(
             global_shape=(self.K * self.dimsd[1]),
             local_shapes=[self.K * c for c in self._rank_col_lens],
             mask=x.mask,
             partition=Partition.SCATTER,
-            dtype=self.dtype,
+            dtype=output_dtype,
             base_comm=self.base_comm
         )
 
-        x_arr = x.local_array.reshape((self.N, self._local_ncols)).astype(self.dtype)
+        x_arr = x.local_array.reshape((self.N, self._local_ncols)).astype(output_dtype)
         X_tile = x_arr[self._row_start:self._row_end, :]
         A_local = self.At if hasattr(self, "At") else self.A.T.conj()
         Y_local = ncp.matmul(A_local, X_tile)
@@ -536,7 +548,6 @@ def __init__(
         self._col_comm = base_comm.Split(color=self._col_id, key=self._row_id)
 
         self.A = A.astype(np.dtype(dtype))
-        if saveAt: self.At = A.T.conj()
 
         self.N = self._col_comm.allreduce(A.shape[0])
         self.K = self._row_comm.allreduce(A.shape[1])
@@ -569,6 +580,7 @@ def _matvec(self, x: DistributedArray) -> DistributedArray:
         if x.partition != Partition.SCATTER:
             raise ValueError(f"x should have partition={Partition.SCATTER} Got {x.partition} instead...")
 
+        output_dtype = np.result_type(self.dtype, x.dtype)
         # Calculate local shapes for block distribution
         bn = self._N_padded // self._P_prime # block size in N dimension
         bm = self._M_padded // self._P_prime # block size in M dimension
@@ -582,9 +594,8 @@ def _matvec(self, x: DistributedArray) -> DistributedArray:
             mask=x.mask,
             local_shapes=local_shapes,
             partition=Partition.SCATTER,
-            dtype=self.dtype,
-            base_comm=self.base_comm
-        )
+            dtype=output_dtype,
+            base_comm=self.base_comm)
 
         # Calculate expected padded dimensions for x
         bk = self._K_padded // self._P_prime # block size in K dimension
@@ -603,13 +614,13 @@ def _matvec(self, x: DistributedArray) -> DistributedArray:
         if pad_k > 0 or pad_m > 0:
             x_block = np.pad(x_block, [(0, pad_k), (0, pad_m)], mode='constant')
 
-        Y_local = np.zeros((self.A.shape[0], bm))
+        Y_local = np.zeros((self.A.shape[0], bm),dtype=output_dtype)
 
        for k in range(self._P_prime):
            Atemp = self.A.copy() if self._col_id == k else np.empty_like(self.A)
            Xtemp = x_block.copy() if self._row_id == k else np.empty_like(x_block)
-            self._row_comm.bcast(Atemp, root=k)
-            self._col_comm.bcast(Xtemp, root=k)
+            self._row_comm.Bcast(Atemp, root=k)
+            self._col_comm.Bcast(Xtemp, root=k)
             Y_local += ncp.dot(Atemp, Xtemp)
 
         Y_local_unpadded = Y_local[:local_n, :local_m]
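The loop above is the SUMMA inner iteration sketched in the new example docstring: at step k, the rank in column k of each grid row broadcasts its block of A along the row, the rank in row k of each grid column broadcasts its block of X down the column, and every rank accumulates a partial product. The switch from bcast to Bcast matters here because, in mpi4py, lowercase bcast pickles its argument and returns the broadcast object (a discarded return value leaves the pre-allocated Atemp/Xtemp buffers on non-root ranks untouched), whereas uppercase Bcast fills the supplied contiguous buffer in place, which is presumably why the change was made. A minimal mpi4py sketch of the distinction, independent of pylops-mpi and runnable under any number of processes:

# Minimal mpi4py sketch of the bcast/Bcast difference.
import numpy as np
from mpi4py import MPI

comm = MPI.COMM_WORLD
buf = np.arange(4.0) if comm.Get_rank() == 0 else np.empty(4)

comm.Bcast(buf, root=0)                     # buffer-based: fills buf in place on every rank
obj = comm.bcast(float(buf.sum()), root=0)  # pickle-based: the value must be captured from the return
print(comm.Get_rank(), buf, obj)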
@@ -631,13 +642,24 @@ def _rmatvec(self, x: DistributedArray) -> DistributedArray:
         local_m = bm if self._col_id != self._P_prime - 1 else self.M - (self._P_prime - 1) * bm
 
         local_shapes = self.base_comm.allgather(local_k * local_m)
+        # - If A is real: A^H = A^T,
+        #   so result_type(real_A, x.dtype) = x.dtype (if x is complex) or real (if x is real)
+        # - If A is complex: A^H is complex,
+        #   so result will be complex regardless of x
+        if np.iscomplexobj(self.A):
+            output_dtype = np.result_type(self.dtype, x.dtype)
+        else:
+            # Real matrix: A^T @ x preserves input type complexity
+            output_dtype = x.dtype if np.iscomplexobj(x.local_array) else self.dtype
+            # But still need to check type promotion for precision
+            output_dtype = np.result_type(self.dtype, output_dtype)
 
         y = DistributedArray(
             global_shape=(self.K * self.M),
             mask=x.mask,
             local_shapes=local_shapes,
             partition=Partition.SCATTER,
-            dtype=self.dtype,
+            dtype=output_dtype,
             base_comm=self.base_comm
         )
 
@@ -659,7 +681,7 @@ def _rmatvec(self, x: DistributedArray) -> DistributedArray:
             x_block = np.pad(x_block, [(0, pad_n), (0, pad_m)], mode='constant')
 
         A_local = self.At if hasattr(self, "At") else self.A.T.conj()
-        Y_local = np.zeros((self.A.shape[1], bm))
+        Y_local = np.zeros((self.A.shape[1], bm), dtype=output_dtype)
 
         for k in range(self._P_prime):
             requests = []
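The comment block added to _rmatvec spells out the dtype rule this commit fixes: for a complex A the adjoint output is always complex, while for a real A the output is complex only if x is complex, with np.result_type applied at the end to settle the final precision (the forward pass simply takes np.result_type(self.dtype, x.dtype)). The following standalone NumPy sketch mirrors that branch at the dtype level only; no MPI or DistributedArray is involved.

# Standalone sketch mirroring the output-dtype branch added in _rmatvec.
import numpy as np

def adjoint_output_dtype(A_dtype, x_dtype):
    if np.issubdtype(A_dtype, np.complexfloating):
        # complex A: the adjoint is complex regardless of x
        return np.result_type(A_dtype, x_dtype)
    # real A: complex only if x is complex, then settle precision with result_type
    out = x_dtype if np.issubdtype(x_dtype, np.complexfloating) else A_dtype
    return np.result_type(A_dtype, out)

print(adjoint_output_dtype(np.float32, np.complex64))  # complex64: complex x promotes a real A
print(adjoint_output_dtype(np.float32, np.float64))    # float32: with a real x the rule keeps A's dtype
print(adjoint_output_dtype(np.complex64, np.float64))  # complex128: precision settled by result_type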
