@@ -28,7 +28,7 @@ use libc::c_int;
28
28
#[ cfg( feature = "blas" ) ]
29
29
use cblas_sys as blas_sys;
30
30
#[ cfg( feature = "blas" ) ]
31
- use cblas_sys:: { CblasNoTrans , CblasTrans , CBLAS_LAYOUT } ;
31
+ use cblas_sys:: { CblasNoTrans , CblasTrans , CBLAS_LAYOUT , CBLAS_TRANSPOSE } ;
32
32
33
33
/// len of vector before we use blas
34
34
#[ cfg( feature = "blas" ) ]
@@ -400,40 +400,33 @@ fn mat_mul_impl<A>(
400
400
// Compute A B -> C
401
401
// We require for BLAS compatibility that:
402
402
// A, B, C are contiguous (stride=1) in their fastest dimension,
403
- // but it can be either first or second axis (either rowmajor /"c" or colmajor /"f") .
403
+ // but they can be either row major /"c" or col major /"f".
404
404
//
405
405
// The "normal case" is CblasRowMajor for cblas.
406
- // Select CblasRowMajor, CblasColMajor to fit C's memory order.
406
+ // Select CblasRowMajor / CblasColMajor to fit C's memory order.
407
407
//
408
- // Apply transpose to A, B as needed if they differ from the normal case.
408
+ // Apply transpose to A, B as needed if they differ from the row major case.
409
409
// If C is CblasColMajor then transpose both A, B (again!)
410
410
411
- let ( a_layout, a_axis, b_layout, b_axis, c_layout) =
412
- match ( get_blas_compatible_layout ( a) ,
413
- get_blas_compatible_layout ( b) ,
414
- get_blas_compatible_layout ( c) )
411
+ let ( a_layout, b_layout, c_layout) =
412
+ if let ( Some ( a_layout) , Some ( b_layout) , Some ( c_layout) ) =
413
+ ( get_blas_compatible_layout ( a) ,
414
+ get_blas_compatible_layout ( b) ,
415
+ get_blas_compatible_layout ( c) )
415
416
{
416
- ( Some ( a_layout) , Some ( b_layout) , Some ( c_layout @ MemoryOrder :: C ) ) => {
417
- ( a_layout, a_layout. lead_axis ( ) ,
418
- b_layout, b_layout. lead_axis ( ) , c_layout)
419
- } ,
420
- ( Some ( a_layout) , Some ( b_layout) , Some ( c_layout @ MemoryOrder :: F ) ) => {
421
- // CblasColMajor is the "other case"
422
- // Mark a, b as having layouts opposite of what they were detected as, which
423
- // ends up with the correct transpose setting w.r.t col major
424
- ( a_layout. opposite ( ) , a_layout. lead_axis ( ) ,
425
- b_layout. opposite ( ) , b_layout. lead_axis ( ) , c_layout)
426
- } ,
427
- _ => break ' blas_block,
417
+ ( a_layout, b_layout, c_layout)
418
+ } else {
419
+ break ' blas_block;
428
420
} ;
429
421
430
- let a_trans = a_layout. to_cblas_transpose ( ) ;
431
- let lda = blas_stride ( & a, a_axis) ;
422
+ let cblas_layout = c_layout. to_cblas_layout ( ) ;
423
+ let a_trans = a_layout. to_cblas_transpose_for ( cblas_layout) ;
424
+ let lda = blas_stride ( & a, a_layout) ;
432
425
433
- let b_trans = b_layout. to_cblas_transpose ( ) ;
434
- let ldb = blas_stride ( & b, b_axis ) ;
426
+ let b_trans = b_layout. to_cblas_transpose_for ( cblas_layout ) ;
427
+ let ldb = blas_stride ( & b, b_layout ) ;
435
428
436
- let ldc = blas_stride ( & c, c_layout. lead_axis ( ) ) ;
429
+ let ldc = blas_stride ( & c, c_layout) ;
437
430
438
431
macro_rules! gemm_scalar_cast {
439
432
( f32 , $var: ident) => {
@@ -457,7 +450,7 @@ fn mat_mul_impl<A>(
457
450
// Where Op is notrans/trans/conjtrans
458
451
unsafe {
459
452
blas_sys:: $gemm(
460
- c_layout . to_cblas_layout ( ) ,
453
+ cblas_layout ,
461
454
a_trans,
462
455
b_trans,
463
456
m as blas_index, // m, rows of Op(a)
@@ -696,16 +689,8 @@ unsafe fn general_mat_vec_mul_impl<A, S1, S2>(
696
689
// may be arbitrary.
697
690
let a_trans = CblasNoTrans ;
698
691
699
- let ( a_stride, cblas_layout) = match layout {
700
- MemoryOrder :: C => {
701
- ( a. strides( ) [ 0 ] . max( k as isize ) as blas_index,
702
- CBLAS_LAYOUT :: CblasRowMajor )
703
- }
704
- MemoryOrder :: F => {
705
- ( a. strides( ) [ 1 ] . max( m as isize ) as blas_index,
706
- CBLAS_LAYOUT :: CblasColMajor )
707
- }
708
- } ;
692
+ let a_stride = blas_stride( & a, layout) ;
693
+ let cblas_layout = layout. to_cblas_layout( ) ;
709
694
710
695
// Low addr in memory pointers required for x, y
711
696
let x_offset = offset_from_low_addr_ptr_to_logical_ptr( & x. dim, & x. strides) ;
@@ -835,61 +820,66 @@ where
835
820
#[ cfg( feature = "blas" ) ]
836
821
#[ derive( Copy , Clone ) ]
837
822
#[ cfg_attr( test, derive( PartialEq , Eq , Debug ) ) ]
838
- enum MemoryOrder
823
+ enum BlasOrder
839
824
{
840
825
C ,
841
826
F ,
842
827
}
843
828
844
829
#[ cfg( feature = "blas" ) ]
845
- impl MemoryOrder
830
+ impl BlasOrder
846
831
{
847
- #[ inline]
848
- /// Axis of leading stride (opposite of contiguous axis)
849
- fn lead_axis ( self ) -> usize
832
+ fn transpose ( self ) -> Self
850
833
{
851
834
match self {
852
- MemoryOrder :: C => 0 ,
853
- MemoryOrder :: F => 1 ,
835
+ Self :: C => Self :: F ,
836
+ Self :: F => Self :: C ,
854
837
}
855
838
}
856
839
857
- /// Get opposite memory order
858
840
#[ inline]
859
- fn opposite ( self ) -> Self
841
+ /// Axis of leading stride (opposite of contiguous axis)
842
+ fn get_blas_lead_axis ( self ) -> usize
860
843
{
861
844
match self {
862
- MemoryOrder :: C => MemoryOrder :: F ,
863
- MemoryOrder :: F => MemoryOrder :: C ,
845
+ Self :: C => 0 ,
846
+ Self :: F => 1 ,
864
847
}
865
848
}
866
849
867
- fn to_cblas_transpose ( self ) -> cblas_sys :: CBLAS_TRANSPOSE
850
+ fn to_cblas_layout ( self ) -> CBLAS_LAYOUT
868
851
{
869
852
match self {
870
- MemoryOrder :: C => CblasNoTrans ,
871
- MemoryOrder :: F => CblasTrans ,
853
+ Self :: C => CBLAS_LAYOUT :: CblasRowMajor ,
854
+ Self :: F => CBLAS_LAYOUT :: CblasColMajor ,
872
855
}
873
856
}
874
857
875
- fn to_cblas_layout ( self ) -> CBLAS_LAYOUT
858
+ /// When using cblas_sgemm (etc) with C matrix using `for_layout`,
859
+ /// how should this `self` matrix be transposed
860
+ fn to_cblas_transpose_for ( self , for_layout : CBLAS_LAYOUT ) -> CBLAS_TRANSPOSE
876
861
{
877
- match self {
878
- MemoryOrder :: C => CBLAS_LAYOUT :: CblasRowMajor ,
879
- MemoryOrder :: F => CBLAS_LAYOUT :: CblasColMajor ,
862
+ let effective_order = match for_layout {
863
+ CBLAS_LAYOUT :: CblasRowMajor => self ,
864
+ CBLAS_LAYOUT :: CblasColMajor => self . transpose ( ) ,
865
+ } ;
866
+
867
+ match effective_order {
868
+ Self :: C => CblasNoTrans ,
869
+ Self :: F => CblasTrans ,
880
870
}
881
871
}
882
872
}
883
873
884
874
#[ cfg( feature = "blas" ) ]
885
- fn is_blas_2d ( dim : & Ix2 , stride : & Ix2 , order : MemoryOrder ) -> bool
875
+ fn is_blas_2d ( dim : & Ix2 , stride : & Ix2 , order : BlasOrder ) -> bool
886
876
{
887
877
let ( m, n) = dim. into_pattern ( ) ;
888
878
let s0 = stride[ 0 ] as isize ;
889
879
let s1 = stride[ 1 ] as isize ;
890
880
let ( inner_stride, outer_stride, inner_dim, outer_dim) = match order {
891
- MemoryOrder :: C => ( s1, s0, m, n) ,
892
- MemoryOrder :: F => ( s0, s1, n, m) ,
881
+ BlasOrder :: C => ( s1, s0, m, n) ,
882
+ BlasOrder :: F => ( s0, s1, n, m) ,
893
883
} ;
894
884
895
885
if !( inner_stride == 1 || outer_dim == 1 ) {
@@ -920,13 +910,13 @@ fn is_blas_2d(dim: &Ix2, stride: &Ix2, order: MemoryOrder) -> bool
920
910
921
911
/// Get BLAS compatible layout if any (C or F, preferring the former)
922
912
#[ cfg( feature = "blas" ) ]
923
- fn get_blas_compatible_layout < S > ( a : & ArrayBase < S , Ix2 > ) -> Option < MemoryOrder >
913
+ fn get_blas_compatible_layout < S > ( a : & ArrayBase < S , Ix2 > ) -> Option < BlasOrder >
924
914
where S : Data
925
915
{
926
- if is_blas_2d ( & a. dim , & a. strides , MemoryOrder :: C ) {
927
- Some ( MemoryOrder :: C )
928
- } else if is_blas_2d ( & a. dim , & a. strides , MemoryOrder :: F ) {
929
- Some ( MemoryOrder :: F )
916
+ if is_blas_2d ( & a. dim , & a. strides , BlasOrder :: C ) {
917
+ Some ( BlasOrder :: C )
918
+ } else if is_blas_2d ( & a. dim , & a. strides , BlasOrder :: F ) {
919
+ Some ( BlasOrder :: F )
930
920
} else {
931
921
None
932
922
}
@@ -937,10 +927,10 @@ where S: Data
937
927
///
938
928
/// Return leading stride (lda, ldb, ldc) of array
939
929
#[ cfg( feature = "blas" ) ]
940
- fn blas_stride < S > ( a : & ArrayBase < S , Ix2 > , axis : usize ) -> blas_index
930
+ fn blas_stride < S > ( a : & ArrayBase < S , Ix2 > , order : BlasOrder ) -> blas_index
941
931
where S : Data
942
932
{
943
- debug_assert ! ( axis <= 1 ) ;
933
+ let axis = order . get_blas_lead_axis ( ) ;
944
934
let other_axis = 1 - axis;
945
935
let len_this = a. shape ( ) [ axis] ;
946
936
let len_other = a. shape ( ) [ other_axis] ;
@@ -968,7 +958,7 @@ where
968
958
if !same_type :: < A , S :: Elem > ( ) {
969
959
return false ;
970
960
}
971
- is_blas_2d ( & a. dim , & a. strides , MemoryOrder :: C )
961
+ is_blas_2d ( & a. dim , & a. strides , BlasOrder :: C )
972
962
}
973
963
974
964
#[ cfg( test) ]
@@ -982,7 +972,7 @@ where
982
972
if !same_type :: < A , S :: Elem > ( ) {
983
973
return false ;
984
974
}
985
- is_blas_2d ( & a. dim , & a. strides , MemoryOrder :: F )
975
+ is_blas_2d ( & a. dim , & a. strides , BlasOrder :: F )
986
976
}
987
977
988
978
#[ cfg( test) ]
@@ -1096,7 +1086,7 @@ mod blas_tests
1096
1086
if stride < N {
1097
1087
assert_eq ! ( get_blas_compatible_layout( & m) , None ) ;
1098
1088
} else {
1099
- assert_eq ! ( get_blas_compatible_layout( & m) , Some ( MemoryOrder :: C ) ) ;
1089
+ assert_eq ! ( get_blas_compatible_layout( & m) , Some ( BlasOrder :: C ) ) ;
1100
1090
}
1101
1091
}
1102
1092
}
0 commit comments