Skip to content

Commit 453eae3

Browse files
committed
blas: Refactor and simplify gemm call further
Further clarify transpose logic by putting it into BlasOrder methods.
1 parent 7226d39 commit 453eae3

File tree

1 file changed

+57
-67
lines changed

1 file changed

+57
-67
lines changed

src/linalg/impl_linalg.rs

Lines changed: 57 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ use libc::c_int;
2828
#[cfg(feature = "blas")]
2929
use cblas_sys as blas_sys;
3030
#[cfg(feature = "blas")]
31-
use cblas_sys::{CblasNoTrans, CblasTrans, CBLAS_LAYOUT};
31+
use cblas_sys::{CblasNoTrans, CblasTrans, CBLAS_LAYOUT, CBLAS_TRANSPOSE};
3232

3333
/// Minimum vector length before we use BLAS
3434
#[cfg(feature = "blas")]
@@ -400,40 +400,33 @@ fn mat_mul_impl<A>(
400400
// Compute A B -> C
401401
// We require for BLAS compatibility that:
402402
// A, B, C are contiguous (stride=1) in their fastest dimension,
403-
// but it can be either first or second axis (either rowmajor/"c" or colmajor/"f").
403+
// but they can be either row major/"c" or col major/"f".
404404
//
405405
// The "normal case" is CblasRowMajor for cblas.
406-
// Select CblasRowMajor, CblasColMajor to fit C's memory order.
406+
// Select CblasRowMajor / CblasColMajor to fit C's memory order.
407407
//
408-
// Apply transpose to A, B as needed if they differ from the normal case.
408+
// Apply transpose to A, B as needed if they differ from the row major case.
409409
// If C is CblasColMajor, then flip the transpose setting of both A and B (again!)
410410

411-
let (a_layout, a_axis, b_layout, b_axis, c_layout) =
412-
match (get_blas_compatible_layout(a),
413-
get_blas_compatible_layout(b),
414-
get_blas_compatible_layout(c))
411+
let (a_layout, b_layout, c_layout) =
412+
if let (Some(a_layout), Some(b_layout), Some(c_layout)) =
413+
(get_blas_compatible_layout(a),
414+
get_blas_compatible_layout(b),
415+
get_blas_compatible_layout(c))
415416
{
416-
(Some(a_layout), Some(b_layout), Some(c_layout @ MemoryOrder::C)) => {
417-
(a_layout, a_layout.lead_axis(),
418-
b_layout, b_layout.lead_axis(), c_layout)
419-
},
420-
(Some(a_layout), Some(b_layout), Some(c_layout @ MemoryOrder::F)) => {
421-
// CblasColMajor is the "other case"
422-
// Mark a, b as having layouts opposite of what they were detected as, which
423-
// ends up with the correct transpose setting w.r.t col major
424-
(a_layout.opposite(), a_layout.lead_axis(),
425-
b_layout.opposite(), b_layout.lead_axis(), c_layout)
426-
},
427-
_ => break 'blas_block,
417+
(a_layout, b_layout, c_layout)
418+
} else {
419+
break 'blas_block;
428420
};
429421

430-
let a_trans = a_layout.to_cblas_transpose();
431-
let lda = blas_stride(&a, a_axis);
422+
let cblas_layout = c_layout.to_cblas_layout();
423+
let a_trans = a_layout.to_cblas_transpose_for(cblas_layout);
424+
let lda = blas_stride(&a, a_layout);
432425

433-
let b_trans = b_layout.to_cblas_transpose();
434-
let ldb = blas_stride(&b, b_axis);
426+
let b_trans = b_layout.to_cblas_transpose_for(cblas_layout);
427+
let ldb = blas_stride(&b, b_layout);
435428

436-
let ldc = blas_stride(&c, c_layout.lead_axis());
429+
let ldc = blas_stride(&c, c_layout);
437430

438431
macro_rules! gemm_scalar_cast {
439432
(f32, $var:ident) => {
@@ -457,7 +450,7 @@ fn mat_mul_impl<A>(
457450
// Where Op is notrans/trans/conjtrans
458451
unsafe {
459452
blas_sys::$gemm(
460-
c_layout.to_cblas_layout(),
453+
cblas_layout,
461454
a_trans,
462455
b_trans,
463456
m as blas_index, // m, rows of Op(a)
@@ -696,16 +689,8 @@ unsafe fn general_mat_vec_mul_impl<A, S1, S2>(
696689
// may be arbitrary.
697690
let a_trans = CblasNoTrans;
698691

699-
let (a_stride, cblas_layout) = match layout {
700-
MemoryOrder::C => {
701-
(a.strides()[0].max(k as isize) as blas_index,
702-
CBLAS_LAYOUT::CblasRowMajor)
703-
}
704-
MemoryOrder::F => {
705-
(a.strides()[1].max(m as isize) as blas_index,
706-
CBLAS_LAYOUT::CblasColMajor)
707-
}
708-
};
692+
let a_stride = blas_stride(&a, layout);
693+
let cblas_layout = layout.to_cblas_layout();
709694

710695
// Low addr in memory pointers required for x, y
711696
let x_offset = offset_from_low_addr_ptr_to_logical_ptr(&x.dim, &x.strides);
@@ -835,61 +820,66 @@ where
835820
#[cfg(feature = "blas")]
836821
#[derive(Copy, Clone)]
837822
#[cfg_attr(test, derive(PartialEq, Eq, Debug))]
838-
enum MemoryOrder
823+
enum BlasOrder
839824
{
840825
C,
841826
F,
842827
}
843828

844829
#[cfg(feature = "blas")]
845-
impl MemoryOrder
830+
impl BlasOrder
846831
{
847-
#[inline]
848-
/// Axis of leading stride (opposite of contiguous axis)
849-
fn lead_axis(self) -> usize
832+
fn transpose(self) -> Self
850833
{
851834
match self {
852-
MemoryOrder::C => 0,
853-
MemoryOrder::F => 1,
835+
Self::C => Self::F,
836+
Self::F => Self::C,
854837
}
855838
}
856839

857-
/// Get opposite memory order
858840
#[inline]
859-
fn opposite(self) -> Self
841+
/// Axis of leading stride (opposite of contiguous axis)
842+
fn get_blas_lead_axis(self) -> usize
860843
{
861844
match self {
862-
MemoryOrder::C => MemoryOrder::F,
863-
MemoryOrder::F => MemoryOrder::C,
845+
Self::C => 0,
846+
Self::F => 1,
864847
}
865848
}
866849

867-
fn to_cblas_transpose(self) -> cblas_sys::CBLAS_TRANSPOSE
850+
fn to_cblas_layout(self) -> CBLAS_LAYOUT
868851
{
869852
match self {
870-
MemoryOrder::C => CblasNoTrans,
871-
MemoryOrder::F => CblasTrans,
853+
Self::C => CBLAS_LAYOUT::CblasRowMajor,
854+
Self::F => CBLAS_LAYOUT::CblasColMajor,
872855
}
873856
}
874857

875-
fn to_cblas_layout(self) -> CBLAS_LAYOUT
858+
/// When calling cblas_sgemm (etc.) with the C matrix stored in `for_layout`,
859+
/// return how this `self` matrix should be transposed.
860+
fn to_cblas_transpose_for(self, for_layout: CBLAS_LAYOUT) -> CBLAS_TRANSPOSE
876861
{
877-
match self {
878-
MemoryOrder::C => CBLAS_LAYOUT::CblasRowMajor,
879-
MemoryOrder::F => CBLAS_LAYOUT::CblasColMajor,
862+
let effective_order = match for_layout {
863+
CBLAS_LAYOUT::CblasRowMajor => self,
864+
CBLAS_LAYOUT::CblasColMajor => self.transpose(),
865+
};
866+
867+
match effective_order {
868+
Self::C => CblasNoTrans,
869+
Self::F => CblasTrans,
880870
}
881871
}
882872
}
883873

884874
#[cfg(feature = "blas")]
885-
fn is_blas_2d(dim: &Ix2, stride: &Ix2, order: MemoryOrder) -> bool
875+
fn is_blas_2d(dim: &Ix2, stride: &Ix2, order: BlasOrder) -> bool
886876
{
887877
let (m, n) = dim.into_pattern();
888878
let s0 = stride[0] as isize;
889879
let s1 = stride[1] as isize;
890880
let (inner_stride, outer_stride, inner_dim, outer_dim) = match order {
891-
MemoryOrder::C => (s1, s0, m, n),
892-
MemoryOrder::F => (s0, s1, n, m),
881+
BlasOrder::C => (s1, s0, m, n),
882+
BlasOrder::F => (s0, s1, n, m),
893883
};
894884

895885
if !(inner_stride == 1 || outer_dim == 1) {
@@ -920,13 +910,13 @@ fn is_blas_2d(dim: &Ix2, stride: &Ix2, order: MemoryOrder) -> bool
920910

921911
/// Get BLAS compatible layout if any (C or F, preferring the former)
922912
#[cfg(feature = "blas")]
923-
fn get_blas_compatible_layout<S>(a: &ArrayBase<S, Ix2>) -> Option<MemoryOrder>
913+
fn get_blas_compatible_layout<S>(a: &ArrayBase<S, Ix2>) -> Option<BlasOrder>
924914
where S: Data
925915
{
926-
if is_blas_2d(&a.dim, &a.strides, MemoryOrder::C) {
927-
Some(MemoryOrder::C)
928-
} else if is_blas_2d(&a.dim, &a.strides, MemoryOrder::F) {
929-
Some(MemoryOrder::F)
916+
if is_blas_2d(&a.dim, &a.strides, BlasOrder::C) {
917+
Some(BlasOrder::C)
918+
} else if is_blas_2d(&a.dim, &a.strides, BlasOrder::F) {
919+
Some(BlasOrder::F)
930920
} else {
931921
None
932922
}
@@ -937,10 +927,10 @@ where S: Data
937927
///
938928
/// Return leading stride (lda, ldb, ldc) of array
939929
#[cfg(feature = "blas")]
940-
fn blas_stride<S>(a: &ArrayBase<S, Ix2>, axis: usize) -> blas_index
930+
fn blas_stride<S>(a: &ArrayBase<S, Ix2>, order: BlasOrder) -> blas_index
941931
where S: Data
942932
{
943-
debug_assert!(axis <= 1);
933+
let axis = order.get_blas_lead_axis();
944934
let other_axis = 1 - axis;
945935
let len_this = a.shape()[axis];
946936
let len_other = a.shape()[other_axis];
@@ -968,7 +958,7 @@ where
968958
if !same_type::<A, S::Elem>() {
969959
return false;
970960
}
971-
is_blas_2d(&a.dim, &a.strides, MemoryOrder::C)
961+
is_blas_2d(&a.dim, &a.strides, BlasOrder::C)
972962
}
973963

974964
#[cfg(test)]
@@ -982,7 +972,7 @@ where
982972
if !same_type::<A, S::Elem>() {
983973
return false;
984974
}
985-
is_blas_2d(&a.dim, &a.strides, MemoryOrder::F)
975+
is_blas_2d(&a.dim, &a.strides, BlasOrder::F)
986976
}
987977

988978
#[cfg(test)]
@@ -1096,7 +1086,7 @@ mod blas_tests
10961086
if stride < N {
10971087
assert_eq!(get_blas_compatible_layout(&m), None);
10981088
} else {
1099-
assert_eq!(get_blas_compatible_layout(&m), Some(MemoryOrder::C));
1089+
assert_eq!(get_blas_compatible_layout(&m), Some(BlasOrder::C));
11001090
}
11011091
}
11021092
}

0 commit comments

Comments
 (0)