
Commit 0956e7b

Converted it to a function
1 parent fa07ae8 commit 0956e7b

1 file changed (+50 −61 lines)

pylops_mpi/basicoperators/MatrixMult.py

Lines changed: 50 additions & 61 deletions
@@ -195,7 +195,6 @@ def block_gather(x: DistributedArray, new_shape: Tuple[int, int], orig_shape: Tu
     return C[:orr, :orc]


-
 class _MPIBlockMatrixMult(MPILinearOperator):
     r"""MPI Blocked Matrix multiplication

@@ -214,7 +213,7 @@ class _MPIBlockMatrixMult(MPILinearOperator):
         Global leading dimension (i.e., number of columns) of the matrices
         representing the input model and data vectors.
     saveAt : :obj:`bool`, optional
-        Save ``A`` and ``A.H`` to speed up the computation of adjoint
+        Save :math:`\mathbf{A}` and ``A.H`` to speed up the computation of adjoint
         (``True``) or create ``A.H`` on-the-fly (``False``)
         Note that ``saveAt=True`` will double the amount of required memory.
         Default is ``False``.
@@ -253,22 +252,22 @@ class _MPIBlockMatrixMult(MPILinearOperator):
     processes by a factor equivalent to :math:`\sqrt{P}` across a square process
     grid (:math:`\sqrt{P}\times\sqrt{P}`). More specifically:

-    - The matrix ``A`` is distributed across MPI processes in a block-row fashion
-      and each process holds a local block of ``A`` with shape
+    - The matrix :math:`\mathbf{A}` is distributed across MPI processes in a block-row fashion
+      and each process holds a local block of :math:`\mathbf{A}` with shape
       :math:`[N_{loc} \times K]`
-    - The operand matrix ``X`` is distributed in a block-column fashion and
-      each process holds a local block of ``X`` with shape
+    - The operand matrix :math:`\mathbf{X}` is distributed in a block-column fashion and
+      each process holds a local block of :math:`\mathbf{X}` with shape
       :math:`[K \times M_{loc}]`
     - Communication is minimized by using a 2D process grid layout

     **Forward Operation step-by-step**

-    1. **Input Preparation**: The input vector ``x`` (flattened from matrix ``X``
+    1. **Input Preparation**: The input vector ``x`` (flattened from matrix :math:`\mathbf{X}`
        of shape ``(K, M)``) is reshaped to ``(K, M_local)`` where ``M_local``
       is the number of columns assigned to the current process.

     2. **Local Computation**: Each process computes ``A_local @ X_local`` where:
-       - ``A_local`` is the local block of matrix ``A`` (shape ``N_local x K``)
+       - ``A_local`` is the local block of matrix :math:`\mathbf{A}` (shape ``N_local x K``)
        - ``X_local`` is the broadcasted operand (shape ``K x M_local``)

     3. **Row-wise Gather**: Results from all processes in each row are gathered
@@ -283,10 +282,10 @@ class _MPIBlockMatrixMult(MPILinearOperator):
       representing the local columns of the input matrix.

     2. **Local Adjoint Computation**: Each process computes
-       ``A_local.H @ X_tile`` where ``A_local.H`` is either i) Pre-computed
-       and stored in ``At`` (if ``saveAt=True``), ii) computed on-the-fly as
+       ``A_local.H @ X_tile`` where ``A_local.H`` is either pre-computed
+       and stored in ``At`` (if ``saveAt=True``), or computed on-the-fly as
        ``A.T.conj()`` (if ``saveAt=False``). Each process multiplies its
-       transposed local ``A`` block ``A_local^H`` (shape ``K x N_block``)
+       transposed local :math:`\mathbf{A}` block ``A_local^H`` (shape ``K x N_block``)
        with the extracted ``X_tile`` (shape ``N_block x M_local``),
        producing a partial result of shape ``(K, M_local)``.
        This computes the local contribution of columns of ``A^H`` to the final
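The forward/adjoint pattern described in the docstring above can be illustrated with a minimal mpi4py sketch; this is not the library's implementation, and the toy shapes, the names A_local/X_local/Y_local, and the use of allgather/allreduce are assumptions made for illustration only:

# Minimal sketch of the local block computations (illustrative only)
import numpy as np
from mpi4py import MPI

comm = MPI.COMM_WORLD
rank, size = comm.Get_rank(), comm.Get_size()

N, K, M = 4 * size, 6, 2 * size      # toy global sizes divisible by the number of ranks
A_local = np.ones((N // size, K))    # this rank's block-row of A, shape (N_local, K)
X_local = np.ones((K, M // size))    # this rank's block-column of X, shape (K, M_local)

# Forward: local product, then gather the per-rank result blocks
Y_local = A_local @ X_local                   # shape (N_local, M_local)
Y_blocks = comm.allgather(Y_local)            # every rank receives all partial blocks

# Adjoint: multiply by the conjugate transpose of the local A block and
# sum the partial contributions across ranks
Xadj_partial = A_local.conj().T @ Y_local     # shape (K, M_local)
Xadj_local = comm.allreduce(Xadj_partial, op=MPI.SUM)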
@@ -413,7 +412,7 @@ class _MPISummaMatrixMult(MPILinearOperator):
         Global number of columns of the matrices representing the input model
         and data vectors.
     saveAt : :obj:`bool`, optional
-        Save ``A`` and ``A.H`` to speed up the computation of adjoint
+        Save :math:`\mathbf{A}` and ``A.H`` to speed up the computation of adjoint
         (``True``) or create ``A.H`` on-the-fly (``False``).
         Note that ``saveAt=True`` will double the amount of required memory.
         Default is ``False``.
@@ -451,16 +450,16 @@ class _MPISummaMatrixMult(MPILinearOperator):
     This implementation is based on a 2D block distribution across a square process
     grid (:math:`\sqrt{P}\times\sqrt{P}`). The matrices are distributed as follows:

-    - The matrix ``A`` is distributed across MPI processes in 2D blocks where
-      each process holds a local block of ``A`` with shape :math:`[N_{loc} \times K_{loc}]`
+    - The matrix :math:`\mathbf{A}` is distributed across MPI processes in 2D blocks where
+      each process holds a local block of :math:`\mathbf{A}` with shape :math:`[N_{loc} \times K_{loc}]`
       where :math:`N_{loc} = \frac{N}{\sqrt{P}}` and :math:`K_{loc} = \frac{K}{\sqrt{P}}`.

-    - The operand matrix ``X`` is also distributed across MPI processes in 2D blocks where
-      each process holds a local block of ``X`` with shape :math:`[K_{loc} \times M_{loc}]`
+    - The operand matrix :math:`\mathbf{X}` is also distributed across MPI processes in 2D blocks where
+      each process holds a local block of :math:`\mathbf{X}` with shape :math:`[K_{loc} \times M_{loc}]`
       where :math:`K_{loc} = \frac{K}{\sqrt{P}}` and :math:`M_{loc} = \frac{M}{\sqrt{P}}`.

-    - The result matrix ``Y`` is also distributed across MPI processes in 2D blocks where
-      each process holds a local block of ``Y`` with shape :math:`[N_{loc} \times M_{loc}]`
+    - The result matrix :math:`\mathbf{Y}` is also distributed across MPI processes in 2D blocks where
+      each process holds a local block of :math:`\mathbf{Y}` with shape :math:`[N_{loc} \times M_{loc}]`
       where :math:`N_{loc} = \frac{N}{\sqrt{P}}` and :math:`M_{loc} = \frac{M}{\sqrt{P}}`.

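As a quick orientation for the layout just described, here is a hedged sketch of how a rank could locate itself on the :math:`\sqrt{P}\times\sqrt{P}` grid and size its blocks; the divmod ordering and the divisibility of N, K and M by :math:`\sqrt{P}` are assumptions, not details taken from this commit:

# Illustrative only: grid coordinates and local block shapes on a square grid
import math
from mpi4py import MPI

comm = MPI.COMM_WORLD
rank, nprocs = comm.Get_rank(), comm.Get_size()
p = math.isqrt(nprocs)          # grid side; nprocs is assumed to be a perfect square
row, col = divmod(rank, p)      # 2D coordinates of this rank

N, K, M = 12, 9, 6              # toy global sizes, assumed divisible by p
N_loc, K_loc, M_loc = N // p, K // p, M // p
# This rank would hold A[row*N_loc:(row+1)*N_loc, col*K_loc:(col+1)*K_loc]
# and X[row*K_loc:(row+1)*K_loc, col*M_loc:(col+1)*M_loc].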
@@ -473,18 +472,18 @@ class _MPISummaMatrixMult(MPILinearOperator):

     2. **SUMMA Iteration**: For each step ``k`` in the SUMMA algorithm -- :math:`k \in [0, \sqrt{P})`:

-       a. **Broadcast A blocks**: Process in column ``k`` broadcasts its ``A``
+       a. **Broadcast A blocks**: Process in column ``k`` broadcasts its :math:`\mathbf{A}`
          block to all other processes in the same process row.

-       b. **Broadcast X blocks**: Process in row ``k`` broadcasts its ``X``
+       b. **Broadcast X blocks**: Process in row ``k`` broadcasts its :math:`\mathbf{X}`
          block to all other processes in the same process column.

        c. **Local Computation**: Each process computes the partial matrix
          product ``A_broadcast @ X_broadcast`` and accumulates it to its
          local result.

     3. **Result Assembly**: After all k SUMMA iterations, each process has computed
-       its local block of the result matrix ``Y``.
+       its local block of the result matrix :math:`\mathbf{Y}`.

     **Adjoint Operation (SUMMA Algorithm)**

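The forward SUMMA loop in the docstring above maps onto a short mpi4py sketch; the Split-based row/column communicators, the block sizes and the variable names are assumptions for illustration and not the _MPISummaMatrixMult code:

# Illustrative SUMMA forward loop (a sketch, not the library implementation)
import math
import numpy as np
from mpi4py import MPI

comm = MPI.COMM_WORLD
rank, nprocs = comm.Get_rank(), comm.Get_size()
p = math.isqrt(nprocs)                       # square process grid assumed
row, col = divmod(rank, p)
row_comm = comm.Split(color=row, key=col)    # ranks sharing a grid row
col_comm = comm.Split(color=col, key=row)    # ranks sharing a grid column

N_loc, K_loc, M_loc = 2, 2, 2
A_loc = np.full((N_loc, K_loc), float(rank))
X_loc = np.ones((K_loc, M_loc))
Y_loc = np.zeros((N_loc, M_loc))

for k in range(p):
    # a. the process in grid column k broadcasts its A block along its row
    A_k = row_comm.bcast(A_loc if col == k else None, root=k)
    # b. the process in grid row k broadcasts its X block along its column
    X_k = col_comm.bcast(X_loc if row == k else None, root=k)
    # c. accumulate the local partial product
    Y_loc += A_k @ X_k
# After p iterations Y_loc holds this rank's block of Y = A @ X; the adjoint
# follows the same pattern with the conjugate-transposed A blocks.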
@@ -496,11 +495,11 @@ class _MPISummaMatrixMult(MPILinearOperator):

     2. **SUMMA Adjoint Iteration**: For each step ``k`` in the adjoint SUMMA algorithm:

-       a. **Broadcast A^H blocks**: The conjugate transpose of ``A`` blocks is
+       a. **Broadcast A^H blocks**: The conjugate transpose of :math:`\mathbf{A}` blocks is
          communicated between processes. If ``saveAt=True``, the pre-computed
          ``A.H`` is used; otherwise, ``A.T.conj()`` is computed on-the-fly.

-       b. **Broadcast Y blocks**: Process in row ``k`` broadcasts its ``Y``
+       b. **Broadcast Y blocks**: Process in row ``k`` broadcasts its :math:`\mathbf{Y}`
          block to all other processes in the same process column.

        c. **Local Adjoint Computation**: Each process computes the partial
@@ -683,7 +682,14 @@ def _rmatvec(self, x: DistributedArray) -> DistributedArray:
         y[:] = Y_local_unpadded.flatten()
         return y

-class MPIMatrixMult(MPILinearOperator):
+def MPIMatrixMult(
+        A: NDArray,
+        M: int,
+        saveAt: bool = False,
+        base_comm: MPI.Comm = MPI.COMM_WORLD,
+        kind: Literal["summa", "block"] = "summa",
+        dtype: DTypeLike = "float64",
+):
     r"""
     MPI Distributed Matrix Multiplication Operator

@@ -694,32 +700,32 @@ class MPIMatrixMult(MPILinearOperator):

     The forward operation computes::

-        Y = A @ X
+        :math:`\mathbf{Y} = \mathbf{A} \cdot \mathbf{X}`

     where:
-    - ``A`` is the distributed operator matrix of shape ``[N x K]``
-    - ``X`` is the distributed operand matrix of shape ``[K x M]``
-    - ``Y`` is the resulting distributed matrix of shape ``[N x M]``
+    - :math:`\mathbf{A}` is the distributed operator matrix of shape :math:`[N \times K]`
+    - :math:`\mathbf{X}` is the distributed operand matrix of shape :math:`[K \times M]`
+    - :math:`\mathbf{Y}` is the resulting distributed matrix of shape :math:`[N \times M]`

     The adjoint (conjugate-transpose) operation computes::
+
+        :math:`\mathbf{X}_{adj} = \mathbf{A}^H \cdot \mathbf{Y}`

-        X_adj = A.H @ Y
-
-    where ``A.H`` is the complex-conjugate transpose of ``A``.
+    where :math:`\mathbf{A}^H` is the complex-conjugate transpose of :math:`\mathbf{A}`.


     Distribution Layouts
     --------------------
     :summa:
        2D block-grid distribution over a square process grid :math:`[\sqrt{P} \times \sqrt{P}]`:
-       - ``A`` and ``X`` are partitioned into :math:`[N_loc \times K_loc]` and
+       - :math:`\mathbf{A}` and :math:`\mathbf{X}` are partitioned into :math:`[N_loc \times K_loc]` and
          :math:`[K_loc \times M_loc]` tiles on each rank, respectively.
-       - Each SUMMA iteration broadcasts row- and column-blocks of ``A`` and
-         ``X`` and accumulates local partial products.
+       - Each SUMMA iteration broadcasts row- and column-blocks of :math:`\mathbf{A}` and
+         :math:`\mathbf{X}` and accumulates local partial products.

     :block:
-       1D block-row distribution over a 1 x P grid:
-       - ``A`` is partitioned into :math:`[N_loc \times K]` blocks across ranks.
-       - ``X`` (and result ``Y``) are partitioned into :math:`[K \times M_loc]` blocks.
+       1D block-row distribution over a :math:`[1 \times P]` grid:
+       - :math:`\mathbf{A}` is partitioned into :math:`[N_loc \times K]` blocks across ranks.
+       - :math:`\mathbf{X}` (and result :math:`\mathbf{Y}`) are partitioned into :math:`[K \times M_loc]` blocks.
        - Local multiplication is followed by row-wise gather (forward) or
          allreduce (adjoint) across ranks.

@@ -730,7 +736,7 @@ class MPIMatrixMult(MPILinearOperator):
     M : int
         Global number of columns in the operand and result matrices.
     saveAt : bool, optional
-        If ``True``, store both ``A`` and its conjugate transpose ``A.H``
+        If ``True``, store both :math:`\mathbf{A}` and its conjugate transpose :math:`\mathbf{A}^H`
         to accelerate adjoint operations (uses twice the memory).
        Default is ``False``.
     base_comm : mpi4py.MPI.Comm, optional
@@ -758,26 +764,9 @@ class MPIMatrixMult(MPILinearOperator):
         If the MPI communicator does not form a compatible grid for the
         selected algorithm.
     """
-    def __init__(
-            self,
-            A: NDArray,
-            M: int,
-            saveAt: bool = False,
-            base_comm: MPI.Comm = MPI.COMM_WORLD,
-            kind:Literal["summa", "block"] = "summa",
-            dtype: DTypeLike = "float64",
-    ):
-        if kind == "summa":
-            self._f = _MPISummaMatrixMult(A,M,saveAt,base_comm,dtype)
-        elif kind == "block":
-            self._f = _MPIBlockMatrixMult(A, M, saveAt, base_comm, dtype)
-        else:
-            raise NotImplementedError("kind must be summa or block")
-        self.kind = kind
-        super().__init__(shape=self._f.shape, dtype=dtype, base_comm=base_comm)
-
-    def _matvec(self, x: DistributedArray) -> DistributedArray:
-        return self._f.matvec(x)
-
-    def _rmatvec(self, x: DistributedArray) -> DistributedArray:
-        return self._f.rmatvec(x)
+    if kind == "summa":
+        return _MPISummaMatrixMult(A,M,saveAt,base_comm,dtype)
+    elif kind == "block":
+        return _MPIBlockMatrixMult(A, M, saveAt, base_comm, dtype)
+    else:
+        raise NotImplementedError("kind must be summa or block")
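With this commit MPIMatrixMult becomes a factory function that returns the selected operator directly, so the former self._f delegation, the self.kind attribute and the _matvec/_rmatvec wrappers disappear. A hedged usage sketch follows; the local block contents, the size values and the module import path are placeholders rather than an API guarantee:

# Illustrative usage of the converted factory function (values are placeholders)
import numpy as np
from mpi4py import MPI
from pylops_mpi.basicoperators.MatrixMult import MPIMatrixMult

comm = MPI.COMM_WORLD
A_local = np.ones((4, 6))    # this rank's block of A, laid out as the chosen kind expects
Op = MPIMatrixMult(A_local, M=8, saveAt=False, base_comm=comm, kind="summa")
# Op is an _MPISummaMatrixMult instance (or _MPIBlockMatrixMult for kind="block")
# y = Op.matvec(x)      # x: a DistributedArray with the layout the operator expects
# xadj = Op.rmatvec(y)

One behavioural consequence worth noting: because the call now returns the underlying operator, isinstance(Op, MPIMatrixMult) checks no longer apply after this change.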

0 commit comments
