@@ -195,7 +195,6 @@ def block_gather(x: DistributedArray, new_shape: Tuple[int, int], orig_shape: Tu
195
195
return C [:orr , :orc ]
196
196
197
197
198
-
199
198
class _MPIBlockMatrixMult (MPILinearOperator ):
200
199
r"""MPI Blocked Matrix multiplication
201
200
@@ -214,7 +213,7 @@ class _MPIBlockMatrixMult(MPILinearOperator):
214
213
Global leading dimension (i.e., number of columns) of the matrices
215
214
representing the input model and data vectors.
216
215
saveAt : :obj:`bool`, optional
217
- Save ``A` ` and ``A.H`` to speed up the computation of adjoint
216
+ Save :math:`\mathbf{A}` and ``A.H`` to speed up the computation of adjoint
218
217
(``True``) or create ``A.H`` on-the-fly (``False``)
219
218
Note that ``saveAt=True`` will double the amount of required memory.
220
219
Default is ``False``.
@@ -253,22 +252,22 @@ class _MPIBlockMatrixMult(MPILinearOperator):
253
252
processes by a factor equivalent to :math:`\sqrt{P}` across a square process
254
253
grid (:math:`\sqrt{P}\times\sqrt{P}`). More specifically:
255
254
256
- - The matrix ``A` ` is distributed across MPI processes in a block-row fashion
257
- and each process holds a local block of ``A` ` with shape
255
+ - The matrix :math:`\mathbf{A}` is distributed across MPI processes in a block-row fashion
256
+ and each process holds a local block of :math:`\mathbf{A}` with shape
258
257
:math:`[N_{loc} \times K]`
259
- - The operand matrix ``X` ` is distributed in a block-column fashion and
260
- each process holds a local block of ``X` ` with shape
258
+ - The operand matrix :math:`\mathbf{X}` is distributed in a block-column fashion and
259
+ each process holds a local block of :math:`\mathbf{X}` with shape
261
260
:math:`[K \times M_{loc}]`
262
261
- Communication is minimized by using a 2D process grid layout
263
262
264
263
**Forward Operation step-by-step**
265
264
266
- 1. **Input Preparation**: The input vector ``x`` (flattened from matrix ``X` `
265
+ 1. **Input Preparation**: The input vector ``x`` (flattened from matrix :math:`\mathbf{X}`
267
266
of shape ``(K, M)``) is reshaped to ``(K, M_local)`` where ``M_local``
268
267
is the number of columns assigned to the current process.
269
268
270
269
2. **Local Computation**: Each process computes ``A_local @ X_local`` where:
271
- - ``A_local`` is the local block of matrix ``A` ` (shape ``N_local x K``)
270
+ - ``A_local`` is the local block of matrix :math:`\mathbf{A}` (shape ``N_local x K``)
272
271
- ``X_local`` is the broadcasted operand (shape ``K x M_local``)
273
272
274
273
3. **Row-wise Gather**: Results from all processes in each row are gathered
@@ -283,10 +282,10 @@ class _MPIBlockMatrixMult(MPILinearOperator):
283
282
representing the local columns of the input matrix.
284
283
285
284
2. **Local Adjoint Computation**: Each process computes
286
- ``A_local.H @ X_tile`` where ``A_local.H`` is either i) Pre -computed
287
- and stored in ``At`` (if ``saveAt=True``), ii) computed on-the-fly as
285
+ ``A_local.H @ X_tile`` where ``A_local.H`` is either pre-computed
286
+ and stored in ``At`` (if ``saveAt=True``), or computed on-the-fly as
288
287
``A.T.conj()`` (if ``saveAt=False``). Each process multiplies its
289
- transposed local ``A` ` block ``A_local^H`` (shape ``K x N_block``)
288
+ transposed local :math:`\mathbf{A}` block ``A_local^H`` (shape ``K x N_block``)
290
289
with the extracted ``X_tile`` (shape ``N_block x M_local``),
291
290
producing a partial result of shape ``(K, M_local)``.
292
291
This computes the local contribution of columns of ``A^H`` to the final
@@ -413,7 +412,7 @@ class _MPISummaMatrixMult(MPILinearOperator):
413
412
Global number of columns of the matrices representing the input model
414
413
and data vectors.
415
414
saveAt : :obj:`bool`, optional
416
- Save ``A` ` and ``A.H`` to speed up the computation of adjoint
415
+ Save :math:`\mathbf{A}` and ``A.H`` to speed up the computation of adjoint
417
416
(``True``) or create ``A.H`` on-the-fly (``False``).
418
417
Note that ``saveAt=True`` will double the amount of required memory.
419
418
Default is ``False``.
@@ -451,16 +450,16 @@ class _MPISummaMatrixMult(MPILinearOperator):
451
450
This implementation is based on a 2D block distribution across a square process
452
451
grid (:math:`\sqrt{P}\times\sqrt{P}`). The matrices are distributed as follows:
453
452
454
- - The matrix ``A` ` is distributed across MPI processes in 2D blocks where
455
- each process holds a local block of ``A` ` with shape :math:`[N_{loc} \times K_{loc}]`
453
+ - The matrix :math:`\mathbf{A}` is distributed across MPI processes in 2D blocks where
454
+ each process holds a local block of :math:`\mathbf{A}` with shape :math:`[N_{loc} \times K_{loc}]`
456
455
where :math:`N_{loc} = \frac{N}{\sqrt{P}}` and :math:`K_{loc} = \frac{K}{\sqrt{P}}`.
457
456
458
- - The operand matrix ``X` ` is also distributed across MPI processes in 2D blocks where
459
- each process holds a local block of ``X` ` with shape :math:`[K_{loc} \times M_{loc}]`
457
+ - The operand matrix :math:`\mathbf{X}` is also distributed across MPI processes in 2D blocks where
458
+ each process holds a local block of :math:`\mathbf{X}` with shape :math:`[K_{loc} \times M_{loc}]`
460
459
where :math:`K_{loc} = \frac{K}{\sqrt{P}}` and :math:`M_{loc} = \frac{M}{\sqrt{P}}`.
461
460
462
- - The result matrix ``Y` ` is also distributed across MPI processes in 2D blocks where
463
- each process holds a local block of ``Y` ` with shape :math:`[N_{loc} \times M_{loc}]`
461
+ - The result matrix :math:`\mathbf{Y}` is also distributed across MPI processes in 2D blocks where
462
+ each process holds a local block of :math:`\mathbf{Y}` with shape :math:`[N_{loc} \times M_{loc}]`
464
463
where :math:`N_{loc} = \frac{N}{\sqrt{P}}` and :math:`M_{loc} = \frac{M}{\sqrt{P}}`.
465
464
466
465
@@ -473,18 +472,18 @@ class _MPISummaMatrixMult(MPILinearOperator):
473
472
474
473
2. **SUMMA Iteration**: For each step ``k`` in the SUMMA algorithm -- :math:`k \in [0, \sqrt{P})`:
475
474
476
- a. **Broadcast A blocks**: Process in column ``k`` broadcasts its ``A` `
475
+ a. **Broadcast A blocks**: Process in column ``k`` broadcasts its :math:`\mathbf{A}`
477
476
block to all other processes in the same process row.
478
477
479
- b. **Broadcast X blocks**: Process in row ``k`` broadcasts its ``X` `
478
+ b. **Broadcast X blocks**: Process in row ``k`` broadcasts its :math:`\mathbf{X}`
480
479
block to all other processes in the same process column.
481
480
482
481
c. **Local Computation**: Each process computes the partial matrix
483
482
product ``A_broadcast @ X_broadcast`` and accumulates it to its
484
483
local result.
485
484
486
485
3. **Result Assembly**: After all k SUMMA iterations, each process has computed
487
- its local block of the result matrix ``Y` `.
486
+ its local block of the result matrix :math:`\mathbf{Y}`.
488
487
489
488
**Adjoint Operation (SUMMA Algorithm)**
490
489
@@ -496,11 +495,11 @@ class _MPISummaMatrixMult(MPILinearOperator):
496
495
497
496
2. **SUMMA Adjoint Iteration**: For each step ``k`` in the adjoint SUMMA algorithm:
498
497
499
- a. **Broadcast A^H blocks**: The conjugate transpose of ``A` ` blocks is
498
+ a. **Broadcast A^H blocks**: The conjugate transpose of :math:`\mathbf{A}` blocks is
500
499
communicated between processes. If ``saveAt=True``, the pre-computed
501
500
``A.H`` is used; otherwise, ``A.T.conj()`` is computed on-the-fly.
502
501
503
- b. **Broadcast Y blocks**: Process in row ``k`` broadcasts its ``Y` `
502
+ b. **Broadcast Y blocks**: Process in row ``k`` broadcasts its :math:`\mathbf{Y}`
504
503
block to all other processes in the same process column.
505
504
506
505
c. **Local Adjoint Computation**: Each process computes the partial
@@ -683,7 +682,14 @@ def _rmatvec(self, x: DistributedArray) -> DistributedArray:
683
682
y [:] = Y_local_unpadded .flatten ()
684
683
return y
685
684
686
- class MPIMatrixMult (MPILinearOperator ):
685
+ def MPIMatrixMult (
686
+ A : NDArray ,
687
+ M : int ,
688
+ saveAt : bool = False ,
689
+ base_comm : MPI .Comm = MPI .COMM_WORLD ,
690
+ kind : Literal ["summa" , "block" ] = "summa" ,
691
+ dtype : DTypeLike = "float64" ,
692
+ ):
687
693
r"""
688
694
MPI Distributed Matrix Multiplication Operator
689
695
@@ -694,32 +700,32 @@ class MPIMatrixMult(MPILinearOperator):
694
700
695
701
The forward operation computes:
696
702
697
- Y = A @ X
703
+ :math:`\mathbf{Y} = \mathbf{A} \cdot \mathbf{X}`
698
704
699
705
where:
700
- - ``A`` is the distributed operator matrix of shape `` [N x K]` `
701
- - ``X`` is the distributed operand matrix of shape `` [K x M]` `
702
- - ``Y`` is the resulting distributed matrix of shape `` [N x M]` `
706
+ - :math:`\mathbf{A}` is the distributed operator matrix of shape :math:`[N \times K]`
707
+ - :math:`\mathbf{X}` is the distributed operand matrix of shape :math:`[K \times M]`
708
+ - :math:`\mathbf{Y}` is the resulting distributed matrix of shape :math:`[N \times M]`
703
709
704
710
The adjoint (conjugate-transpose) operation computes:
711
+
712
+ :math:`\mathbf{X}_{adj} = \mathbf{A}^H \cdot \mathbf{Y}`
705
713
706
- X_adj = A.H @ Y
707
-
708
- where ``A.H`` is the complex-conjugate transpose of ``A``.
714
+ where :math:`\mathbf{A}^H` is the complex-conjugate transpose of :math:`\mathbf{A}`.
709
715
710
716
Distribution Layouts
711
717
--------------------
712
718
:summa:
713
719
2D block-grid distribution over a square process grid :math:`[\sqrt{P} \times \sqrt{P}]`:
714
- - ``A`` and ``X` ` are partitioned into :math:`[N_loc \times K_loc]` and
720
+ - :math:`\mathbf{A}` and :math:`\mathbf{X}` are partitioned into :math:`[N_{loc} \times K_{loc}]` and
715
721
:math:`[K_{loc} \times M_{loc}]` tiles on each rank, respectively.
716
- - Each SUMMA iteration broadcasts row- and column-blocks of ``A` ` and
717
- ``X` ` and accumulates local partial products.
722
+ - Each SUMMA iteration broadcasts row- and column-blocks of :math:`\mathbf{A}` and
723
+ :math:`\mathbf{X}` and accumulates local partial products.
718
724
719
725
:block:
720
- 1D block-row distribution over a 1 x P grid:
721
- - ``A` ` is partitioned into :math:`[N_loc \times K]` blocks across ranks.
722
- - ``X`` (and result ``Y` `) are partitioned into :math:`[K \times M_loc]` blocks.
726
+ 1D block-row distribution over a :math:`[1 \times P]` grid:
727
+ - :math:`\mathbf{A}` is partitioned into :math:`[N_{loc} \times K]` blocks across ranks.
728
+ - :math:`\mathbf{X}` (and result :math:`\mathbf{Y}`) are partitioned into :math:`[K \times M_{loc}]` blocks.
723
729
- Local multiplication is followed by row-wise gather (forward) or
724
730
allreduce (adjoint) across ranks.
725
731
@@ -730,7 +736,7 @@ class MPIMatrixMult(MPILinearOperator):
730
736
M : int
731
737
Global number of columns in the operand and result matrices.
732
738
saveAt : bool, optional
733
- If ``True``, store both ``A`` and its conjugate transpose ``A.H` `
739
+ If ``True``, store both :math:`\mathbf{A}` and its conjugate transpose :math:`\mathbf{A}^H`
734
740
to accelerate adjoint operations (uses twice the memory).
735
741
Default is ``False``.
736
742
base_comm : mpi4py.MPI.Comm, optional
@@ -758,26 +764,9 @@ class MPIMatrixMult(MPILinearOperator):
758
764
If the MPI communicator does not form a compatible grid for the
759
765
selected algorithm.
760
766
"""
761
- def __init__ (
762
- self ,
763
- A : NDArray ,
764
- M : int ,
765
- saveAt : bool = False ,
766
- base_comm : MPI .Comm = MPI .COMM_WORLD ,
767
- kind :Literal ["summa" , "block" ] = "summa" ,
768
- dtype : DTypeLike = "float64" ,
769
- ):
770
- if kind == "summa" :
771
- self ._f = _MPISummaMatrixMult (A ,M ,saveAt ,base_comm ,dtype )
772
- elif kind == "block" :
773
- self ._f = _MPIBlockMatrixMult (A , M , saveAt , base_comm , dtype )
774
- else :
775
- raise NotImplementedError ("kind must be summa or block" )
776
- self .kind = kind
777
- super ().__init__ (shape = self ._f .shape , dtype = dtype , base_comm = base_comm )
778
-
779
- def _matvec (self , x : DistributedArray ) -> DistributedArray :
780
- return self ._f .matvec (x )
781
-
782
- def _rmatvec (self , x : DistributedArray ) -> DistributedArray :
783
- return self ._f .rmatvec (x )
767
+ if kind == "summa" :
768
+ return _MPISummaMatrixMult (A ,M ,saveAt ,base_comm ,dtype )
769
+ elif kind == "block" :
770
+ return _MPIBlockMatrixMult (A , M , saveAt , base_comm , dtype )
771
+ else :
772
+ raise NotImplementedError ("kind must be summa or block" )
0 commit comments