@@ -122,7 +122,7 @@ def local_block_spit(global_shape: Tuple[int, int],
122
122
return slice (i0 , i1 ), slice (j0 , j1 )
123
123
124
124
125
- def block_gather (x : DistributedArray , new_shape : Tuple [ int , int ], orig_shape : Tuple [int , int ], comm : MPI .Comm ):
125
+ def block_gather (x : DistributedArray , orig_shape : Tuple [int , int ], comm : MPI .Comm ):
126
126
r"""
127
127
Gather distributed local blocks from 2D block distributed matrix distributed
128
128
amongst a square process grid into the full global array.
@@ -131,11 +131,8 @@ def block_gather(x: DistributedArray, new_shape: Tuple[int, int], orig_shape: Tu
131
131
----------
132
132
x : :obj:`pylops_mpi.DistributedArray`
133
133
The distributed array to gather locally.
134
- new_shape : Tuple[int, int]
135
- Shape `(N', M')` of the padded global array, where both dimensions
136
- are multiples of :math:`\sqrt{\mathbf{P}}`.
137
134
orig_shape : Tuple[int, int]
138
- Original shape `(N, M)` of the global array before padding .
135
+ Original shape `(N, M)` of the global array to be gathered .
139
136
comm : MPI.Comm
140
137
MPI communicator whose size must be a perfect square (P = p_prime**2).
141
138
@@ -156,45 +153,15 @@ def block_gather(x: DistributedArray, new_shape: Tuple[int, int], orig_shape: Tu
156
153
raise RuntimeError (f"Communicator size must be a perfect square, got { comm .Get_size ()!r} " )
157
154
158
155
all_blks = comm .allgather (x .local_array )
159
-
160
- nr , nc = new_shape
161
- orr , orc = orig_shape
162
-
163
- # Calculate base block sizes
164
- br_base = nr // p_prime
165
- bc_base = nc // p_prime
166
-
167
- # Calculate remainder rows/cols that need to be distributed
168
- r_remainder = nr % p_prime
169
- c_remainder = nc % p_prime
170
-
171
- # Create the output matrix
156
+ nr , nc = orig_shape
157
+ br , bc = math .ceil (nr / p_prime ), math .ceil (nc / p_prime )
172
158
C = ncp .zeros ((nr , nc ), dtype = all_blks [0 ].dtype )
173
-
174
- # Place each block in the correct position
175
159
for rank in range (p_prime * p_prime ):
176
- # Convert linear rank to 2D grid position
177
- proc_row = rank // p_prime
178
- proc_col = rank % p_prime
179
-
180
- # Calculate this process's block dimensions
181
- block_rows = br_base + (1 if proc_row < r_remainder else 0 )
182
- block_cols = bc_base + (1 if proc_col < c_remainder else 0 )
183
-
184
- # Calculate starting position in global matrix
185
- start_row = proc_row * br_base + min (proc_row , r_remainder )
186
- start_col = proc_col * bc_base + min (proc_col , c_remainder )
187
-
188
- # Place the block
189
- block = all_blks [rank ]
190
- if block .ndim == 1 :
191
- block = block .reshape (block_rows , block_cols )
192
- C [start_row :start_row + block_rows ,
193
- start_col :start_col + block_cols ] = block
194
-
195
- # Trim off any padding
196
- return C [:orr , :orc ]
197
-
160
+ pr , pc = divmod (rank , p_prime )
161
+ rs , cs = pr * br , pc * bc
162
+ re , ce = min (rs + br , nr ), min (cs + bc , nc )
163
+ C [rs :re , cs :ce ] = all_blks [rank ].reshape (re - rs , ce - cs )
164
+ return C
198
165
199
166
class _MPIBlockMatrixMult (MPILinearOperator ):
200
167
r"""MPI Blocked Matrix multiplication
0 commit comments