@@ -580,13 +580,13 @@ def _matvec(self, x: DistributedArray) -> DistributedArray:
580
580
pad_m = bm - local_m
581
581
582
582
if pad_k > 0 or pad_m > 0 :
583
- x_block = np .pad (x_block , [(0 , pad_k ), (0 , pad_m )], mode = 'constant' )
583
+ x_block = ncp .pad (x_block , [(0 , pad_k ), (0 , pad_m )], mode = 'constant' )
584
584
585
- Y_local = np .zeros ((self .A .shape [0 ], bm ),dtype = output_dtype )
585
+ Y_local = ncp .zeros ((self .A .shape [0 ], bm ),dtype = output_dtype )
586
586
587
587
for k in range (self ._P_prime ):
588
- Atemp = self .A .copy () if self ._col_id == k else np .empty_like (self .A )
589
- Xtemp = x_block .copy () if self ._row_id == k else np .empty_like (x_block )
588
+ Atemp = self .A .copy () if self ._col_id == k else ncp .empty_like (self .A )
589
+ Xtemp = x_block .copy () if self ._row_id == k else ncp .empty_like (x_block )
590
590
self ._row_comm .Bcast (Atemp , root = k )
591
591
self ._col_comm .Bcast (Xtemp , root = k )
592
592
Y_local += ncp .dot (Atemp , Xtemp )
@@ -646,14 +646,14 @@ def _rmatvec(self, x: DistributedArray) -> DistributedArray:
646
646
pad_m = bm - local_m
647
647
648
648
if pad_n > 0 or pad_m > 0 :
649
- x_block = np .pad (x_block , [(0 , pad_n ), (0 , pad_m )], mode = 'constant' )
649
+ x_block = ncp .pad (x_block , [(0 , pad_n ), (0 , pad_m )], mode = 'constant' )
650
650
651
651
A_local = self .At if hasattr (self , "At" ) else self .A .T .conj ()
652
- Y_local = np .zeros ((self .A .shape [1 ], bm ), dtype = output_dtype )
652
+ Y_local = ncp .zeros ((self .A .shape [1 ], bm ), dtype = output_dtype )
653
653
654
654
for k in range (self ._P_prime ):
655
655
requests = []
656
- ATtemp = np .empty_like (A_local )
656
+ ATtemp = ncp .empty_like (A_local )
657
657
srcA = k * self ._P_prime + self ._row_id
658
658
tagA = (100 + k ) * 1000 + self .rank
659
659
requests .append (self .base_comm .Irecv (ATtemp , source = srcA , tag = tagA ))
@@ -663,7 +663,7 @@ def _rmatvec(self, x: DistributedArray) -> DistributedArray:
663
663
destA = fixed_col * self ._P_prime + moving_col
664
664
tagA = (100 + k ) * 1000 + destA
665
665
requests .append (self .base_comm .Isend (A_local , dest = destA , tag = tagA ))
666
- Xtemp = x_block .copy () if self ._row_id == k else np .empty_like (x_block )
666
+ Xtemp = x_block .copy () if self ._row_id == k else ncp .empty_like (x_block )
667
667
requests .append (self ._col_comm .Ibcast (Xtemp , root = k ))
668
668
MPI .Request .Waitall (requests )
669
669
Y_local += ncp .dot (ATtemp , Xtemp )
0 commit comments