Skip to content

Commit 531873f

Browse files
committed
Fixed block_gather fn
1 parent 8851e05 commit 531873f

File tree

2 files changed

+13
-46
lines changed

2 files changed

+13
-46
lines changed

pylops_mpi/basicoperators/MatrixMult.py

Lines changed: 9 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -122,7 +122,7 @@ def local_block_spit(global_shape: Tuple[int, int],
122122
return slice(i0, i1), slice(j0, j1)
123123

124124

125-
def block_gather(x: DistributedArray, new_shape: Tuple[int, int], orig_shape: Tuple[int, int], comm: MPI.Comm):
125+
def block_gather(x: DistributedArray, orig_shape: Tuple[int, int], comm: MPI.Comm):
126126
r"""
127127
Gather distributed local blocks from 2D block distributed matrix distributed
128128
amongst a square process grid into the full global array.
@@ -131,11 +131,8 @@ def block_gather(x: DistributedArray, new_shape: Tuple[int, int], orig_shape: Tu
131131
----------
132132
x : :obj:`pylops_mpi.DistributedArray`
133133
The distributed array to gather locally.
134-
new_shape : Tuple[int, int]
135-
Shape `(N', M')` of the padded global array, where both dimensions
136-
are multiples of :math:`\sqrt{\mathbf{P}}`.
137134
orig_shape : Tuple[int, int]
138-
Original shape `(N, M)` of the global array before padding.
135+
Original shape `(N, M)` of the global array to be gathered.
139136
comm : MPI.Comm
140137
MPI communicator whose size must be a perfect square (P = p_prime**2).
141138
@@ -156,45 +153,15 @@ def block_gather(x: DistributedArray, new_shape: Tuple[int, int], orig_shape: Tu
156153
raise RuntimeError(f"Communicator size must be a perfect square, got {comm.Get_size()!r}")
157154

158155
all_blks = comm.allgather(x.local_array)
159-
160-
nr, nc = new_shape
161-
orr, orc = orig_shape
162-
163-
# Calculate base block sizes
164-
br_base = nr // p_prime
165-
bc_base = nc // p_prime
166-
167-
# Calculate remainder rows/cols that need to be distributed
168-
r_remainder = nr % p_prime
169-
c_remainder = nc % p_prime
170-
171-
# Create the output matrix
156+
nr, nc = orig_shape
157+
br, bc = math.ceil(nr / p_prime), math.ceil(nc / p_prime)
172158
C = ncp.zeros((nr, nc), dtype=all_blks[0].dtype)
173-
174-
# Place each block in the correct position
175159
for rank in range(p_prime * p_prime):
176-
# Convert linear rank to 2D grid position
177-
proc_row = rank // p_prime
178-
proc_col = rank % p_prime
179-
180-
# Calculate this process's block dimensions
181-
block_rows = br_base + (1 if proc_row < r_remainder else 0)
182-
block_cols = bc_base + (1 if proc_col < c_remainder else 0)
183-
184-
# Calculate starting position in global matrix
185-
start_row = proc_row * br_base + min(proc_row, r_remainder)
186-
start_col = proc_col * bc_base + min(proc_col, c_remainder)
187-
188-
# Place the block
189-
block = all_blks[rank]
190-
if block.ndim == 1:
191-
block = block.reshape(block_rows, block_cols)
192-
C[start_row:start_row + block_rows,
193-
start_col:start_col + block_cols] = block
194-
195-
# Trim off any padding
196-
return C[:orr, :orc]
197-
160+
pr, pc = divmod(rank, p_prime)
161+
rs, cs = pr * br, pc * bc
162+
re, ce = min(rs + br, nr), min(cs + bc, nc)
163+
C[rs:re, cs:ce] = all_blks[rank].reshape(re - rs, ce - cs)
164+
return C
198165

199166
class _MPIBlockMatrixMult(MPILinearOperator):
200167
r"""MPI Blocked Matrix multiplication

tests/test_matrixmult.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -210,8 +210,8 @@ def test_MPISummaMatrixMult(N, K, M, dtype_str):
210210
xadj_dist = Aop.H @ y_dist
211211

212212
# Re-organize in local matrix
213-
y = block_gather(y_dist, (N,M), (N,M), comm)
214-
xadj = block_gather(xadj_dist, (K,M), (K,M), comm)
213+
y = block_gather(y_dist, (N,M), comm)
214+
xadj = block_gather(xadj_dist, (K,M), comm)
215215

216216
if rank == 0:
217217
y_loc = A_glob @ X_glob
@@ -239,8 +239,8 @@ def test_MPISummaMatrixMult(N, K, M, dtype_str):
239239
xadj1_dist = Op.H @ y1_dist
240240

241241
# Re-organize in local matrix
242-
y1 = block_gather(y1_dist, (N, M), (N, M), comm)
243-
xadj1 = block_gather(xadj1_dist, (K,M), (K,M), comm)
242+
y1 = block_gather(y1_dist, (N, M), comm)
243+
xadj1 = block_gather(xadj1_dist, (K,M), comm)
244244

245245
if rank == 0:
246246
y1_loc = ((A_glob @ X_glob).conj().ravel()).reshape(N, M) + 1.0j

0 commit comments

Comments (0)