Skip to content

Commit 0c5cb7e

Browse files
committed
consistancy
1 parent 75900a7 commit 0c5cb7e

File tree

2 files changed

+5
-10
lines changed

2 files changed

+5
-10
lines changed

pylops_mpi/basicoperators/MatrixMult.py

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -108,18 +108,14 @@ def local_block_spit(global_shape: Tuple[int, int],
108108
if not ( isinstance(rank, int) and 0 <= rank < size ):
109109
raise ValueError(f"rank must be integer in [0, {size}), got {rank!r}")
110110

111-
proc_i, proc_j = divmod(rank, p_prime)
111+
pr, pc = divmod(rank, p_prime)
112112
orig_r, orig_c = global_shape
113-
114113
new_r = math.ceil(orig_r / p_prime) * p_prime
115114
new_c = math.ceil(orig_c / p_prime) * p_prime
116-
117115
blkr, blkc = new_r // p_prime, new_c // p_prime
118-
119-
i0, j0 = proc_i * blkr, proc_j * blkc
120-
i1, j1 = min(i0 + blkr, orig_r), min(j0 + blkc, orig_c)
121-
122-
return slice(i0, i1), slice(j0, j1)
116+
rs, cs = pr * blkr, pc * blkc
117+
re, ce = min(rs + blkr, orig_r), min(cs + blkc, orig_c)
118+
return slice(rs, re), slice(cs, ce)
123119

124120

125121
def block_gather(x: DistributedArray, orig_shape: Tuple[int, int], comm: MPI.Comm):

tests/test_matrixmult.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -243,8 +243,7 @@ def test_MPISummaMatrixMult(N, K, M, dtype_str):
243243
xadj1 = block_gather(xadj1_dist, (K,M), comm)
244244

245245
if rank == 0:
246-
y1_loc = ((A_glob @ X_glob).conj().ravel()).reshape(N, M) + 1.0j
247-
y1_loc = y1_loc - 1.0j
246+
y1_loc = ((A_glob @ X_glob).conj().ravel()).reshape(N, M)
248247

249248
assert_allclose(
250249
y1.squeeze(),

0 commit comments

Comments
 (0)