Skip to content

Commit 75900a7

Browse files
committed
Fixed np to ncp in forward and backward
1 parent 531873f commit 75900a7

File tree

1 file changed

+8
-8
lines changed

1 file changed

+8
-8
lines changed

pylops_mpi/basicoperators/MatrixMult.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -580,13 +580,13 @@ def _matvec(self, x: DistributedArray) -> DistributedArray:
580580
pad_m = bm - local_m
581581

582582
if pad_k > 0 or pad_m > 0:
583-
x_block = np.pad(x_block, [(0, pad_k), (0, pad_m)], mode='constant')
583+
x_block = ncp.pad(x_block, [(0, pad_k), (0, pad_m)], mode='constant')
584584

585-
Y_local = np.zeros((self.A.shape[0], bm),dtype=output_dtype)
585+
Y_local = ncp.zeros((self.A.shape[0], bm),dtype=output_dtype)
586586

587587
for k in range(self._P_prime):
588-
Atemp = self.A.copy() if self._col_id == k else np.empty_like(self.A)
589-
Xtemp = x_block.copy() if self._row_id == k else np.empty_like(x_block)
588+
Atemp = self.A.copy() if self._col_id == k else ncp.empty_like(self.A)
589+
Xtemp = x_block.copy() if self._row_id == k else ncp.empty_like(x_block)
590590
self._row_comm.Bcast(Atemp, root=k)
591591
self._col_comm.Bcast(Xtemp, root=k)
592592
Y_local += ncp.dot(Atemp, Xtemp)
@@ -646,14 +646,14 @@ def _rmatvec(self, x: DistributedArray) -> DistributedArray:
646646
pad_m = bm - local_m
647647

648648
if pad_n > 0 or pad_m > 0:
649-
x_block = np.pad(x_block, [(0, pad_n), (0, pad_m)], mode='constant')
649+
x_block = ncp.pad(x_block, [(0, pad_n), (0, pad_m)], mode='constant')
650650

651651
A_local = self.At if hasattr(self, "At") else self.A.T.conj()
652-
Y_local = np.zeros((self.A.shape[1], bm), dtype=output_dtype)
652+
Y_local = ncp.zeros((self.A.shape[1], bm), dtype=output_dtype)
653653

654654
for k in range(self._P_prime):
655655
requests = []
656-
ATtemp = np.empty_like(A_local)
656+
ATtemp = ncp.empty_like(A_local)
657657
srcA = k * self._P_prime + self._row_id
658658
tagA = (100 + k) * 1000 + self.rank
659659
requests.append(self.base_comm.Irecv(ATtemp, source=srcA, tag=tagA))
@@ -663,7 +663,7 @@ def _rmatvec(self, x: DistributedArray) -> DistributedArray:
663663
destA = fixed_col * self._P_prime + moving_col
664664
tagA = (100 + k) * 1000 + destA
665665
requests.append(self.base_comm.Isend(A_local, dest=destA, tag=tagA))
666-
Xtemp = x_block.copy() if self._row_id == k else np.empty_like(x_block)
666+
Xtemp = x_block.copy() if self._row_id == k else ncp.empty_like(x_block)
667667
requests.append(self._col_comm.Ibcast(Xtemp, root=k))
668668
MPI.Request.Waitall(requests)
669669
Y_local += ncp.dot(ATtemp, Xtemp)

0 commit comments

Comments
 (0)