@@ -74,7 +74,7 @@ def active_grid_comm(base_comm: MPI.Comm, N: int, M: int):
74
74
def local_block_spit (global_shape : Tuple [int , int ],
75
75
rank : int ,
76
76
comm : MPI .Comm ) -> Tuple [slice , slice ]:
77
- """
77
+ r """
78
78
Compute the local sub‐block of a 2D global array for a process in a square process grid.
79
79
80
80
Parameters
@@ -122,7 +122,7 @@ def local_block_spit(global_shape: Tuple[int, int],
122
122
123
123
124
124
def block_gather (x : DistributedArray , new_shape : Tuple [int , int ], orig_shape : Tuple [int , int ], comm : MPI .Comm ):
125
- """
125
+ r """
126
126
Gather distributed local blocks from 2D block distributed matrix distributed
127
127
amongst a square process grid into the full global array.
128
128
@@ -351,19 +351,19 @@ def _matvec(self, x: DistributedArray) -> DistributedArray:
351
351
ncp = get_module (x .engine )
352
352
if x .partition != Partition .SCATTER :
353
353
raise ValueError (f"x should have partition={ Partition .SCATTER } Got { x .partition } instead..." )
354
-
354
+ output_dtype = np . result_type ( self . dtype , x . dtype )
355
355
y = DistributedArray (
356
356
global_shape = (self .N * self .dimsd [1 ]),
357
357
local_shapes = [(self .N * c ) for c in self ._rank_col_lens ],
358
358
mask = x .mask ,
359
359
partition = Partition .SCATTER ,
360
- dtype = self . dtype ,
360
+ dtype = output_dtype ,
361
361
base_comm = self .base_comm
362
362
)
363
363
364
364
my_own_cols = self ._rank_col_lens [self .rank ]
365
365
x_arr = x .local_array .reshape ((self .dims [0 ], my_own_cols ))
366
- X_local = x_arr .astype (self . dtype )
366
+ X_local = x_arr .astype (output_dtype )
367
367
Y_local = ncp .vstack (
368
368
self ._row_comm .allgather (
369
369
ncp .matmul (self .A , X_local )
@@ -377,16 +377,28 @@ def _rmatvec(self, x: DistributedArray) -> DistributedArray:
377
377
if x .partition != Partition .SCATTER :
378
378
raise ValueError (f"x should have partition={ Partition .SCATTER } . Got { x .partition } instead." )
379
379
380
+ # - If A is real: A^H = A^T,
381
+ # so the result is complex only when x is complex; np.result_type may still widen precision (e.g. float64 with complex64 gives complex128)
382
+ # - If A is complex: A^H is complex,
383
+ # so result will be complex regardless of x
384
+ if np .iscomplexobj (self .A ):
385
+ output_dtype = np .result_type (self .dtype , x .dtype )
386
+ else :
387
+ # Real matrix: A^T @ x preserves input type complexity
388
+ output_dtype = x .dtype if np .iscomplexobj (x .local_array ) else self .dtype
389
+ # Promote against self.dtype so the output is never narrower than the operator's dtype
390
+ output_dtype = np .result_type (self .dtype , output_dtype )
391
+
380
392
y = DistributedArray (
381
393
global_shape = (self .K * self .dimsd [1 ]),
382
394
local_shapes = [self .K * c for c in self ._rank_col_lens ],
383
395
mask = x .mask ,
384
396
partition = Partition .SCATTER ,
385
- dtype = self . dtype ,
397
+ dtype = output_dtype ,
386
398
base_comm = self .base_comm
387
399
)
388
400
389
- x_arr = x .local_array .reshape ((self .N , self ._local_ncols )).astype (self . dtype )
401
+ x_arr = x .local_array .reshape ((self .N , self ._local_ncols )).astype (output_dtype )
390
402
X_tile = x_arr [self ._row_start :self ._row_end , :]
391
403
A_local = self .At if hasattr (self , "At" ) else self .A .T .conj ()
392
404
Y_local = ncp .matmul (A_local , X_tile )
@@ -536,7 +548,6 @@ def __init__(
536
548
self ._col_comm = base_comm .Split (color = self ._col_id , key = self ._row_id )
537
549
538
550
self .A = A .astype (np .dtype (dtype ))
539
- if saveAt : self .At = A .T .conj ()
540
551
541
552
self .N = self ._col_comm .allreduce (A .shape [0 ])
542
553
self .K = self ._row_comm .allreduce (A .shape [1 ])
@@ -569,6 +580,7 @@ def _matvec(self, x: DistributedArray) -> DistributedArray:
569
580
if x .partition != Partition .SCATTER :
570
581
raise ValueError (f"x should have partition={ Partition .SCATTER } Got { x .partition } instead..." )
571
582
583
+ output_dtype = np .result_type (self .dtype , x .dtype )
572
584
# Calculate local shapes for block distribution
573
585
bn = self ._N_padded // self ._P_prime # block size in N dimension
574
586
bm = self ._M_padded // self ._P_prime # block size in M dimension
@@ -582,9 +594,8 @@ def _matvec(self, x: DistributedArray) -> DistributedArray:
582
594
mask = x .mask ,
583
595
local_shapes = local_shapes ,
584
596
partition = Partition .SCATTER ,
585
- dtype = self .dtype ,
586
- base_comm = self .base_comm
587
- )
597
+ dtype = output_dtype ,
598
+ base_comm = self .base_comm )
588
599
589
600
# Calculate expected padded dimensions for x
590
601
bk = self ._K_padded // self ._P_prime # block size in K dimension
@@ -603,13 +614,13 @@ def _matvec(self, x: DistributedArray) -> DistributedArray:
603
614
if pad_k > 0 or pad_m > 0 :
604
615
x_block = np .pad (x_block , [(0 , pad_k ), (0 , pad_m )], mode = 'constant' )
605
616
606
- Y_local = np .zeros ((self .A .shape [0 ], bm ))
617
+ Y_local = np .zeros ((self .A .shape [0 ], bm ), dtype = output_dtype )
607
618
608
619
for k in range (self ._P_prime ):
609
620
Atemp = self .A .copy () if self ._col_id == k else np .empty_like (self .A )
610
621
Xtemp = x_block .copy () if self ._row_id == k else np .empty_like (x_block )
611
- self ._row_comm .bcast (Atemp , root = k )
612
- self ._col_comm .bcast (Xtemp , root = k )
622
+ self ._row_comm .Bcast (Atemp , root = k )
623
+ self ._col_comm .Bcast (Xtemp , root = k )
613
624
Y_local += ncp .dot (Atemp , Xtemp )
614
625
615
626
Y_local_unpadded = Y_local [:local_n , :local_m ]
@@ -631,13 +642,24 @@ def _rmatvec(self, x: DistributedArray) -> DistributedArray:
631
642
local_m = bm if self ._col_id != self ._P_prime - 1 else self .M - (self ._P_prime - 1 ) * bm
632
643
633
644
local_shapes = self .base_comm .allgather (local_k * local_m )
645
+ # - If A is real: A^H = A^T,
646
+ # so the result is complex only when x is complex; np.result_type may still widen precision (e.g. float64 with complex64 gives complex128)
647
+ # - If A is complex: A^H is complex,
648
+ # so result will be complex regardless of x
649
+ if np .iscomplexobj (self .A ):
650
+ output_dtype = np .result_type (self .dtype , x .dtype )
651
+ else :
652
+ # Real matrix: A^T @ x preserves input type complexity
653
+ output_dtype = x .dtype if np .iscomplexobj (x .local_array ) else self .dtype
654
+ # Promote against self.dtype so the output is never narrower than the operator's dtype
655
+ output_dtype = np .result_type (self .dtype , output_dtype )
634
656
635
657
y = DistributedArray (
636
658
global_shape = (self .K * self .M ),
637
659
mask = x .mask ,
638
660
local_shapes = local_shapes ,
639
661
partition = Partition .SCATTER ,
640
- dtype = self . dtype ,
662
+ dtype = output_dtype ,
641
663
base_comm = self .base_comm
642
664
)
643
665
@@ -659,7 +681,7 @@ def _rmatvec(self, x: DistributedArray) -> DistributedArray:
659
681
x_block = np .pad (x_block , [(0 , pad_n ), (0 , pad_m )], mode = 'constant' )
660
682
661
683
A_local = self .At if hasattr (self , "At" ) else self .A .T .conj ()
662
- Y_local = np .zeros ((self .A .shape [1 ], bm ))
684
+ Y_local = np .zeros ((self .A .shape [1 ], bm ), dtype = output_dtype )
663
685
664
686
for k in range (self ._P_prime ):
665
687
requests = []
0 commit comments