Skip to content

Possible missed vectorization in unrolled_dot #3218

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 15 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions components/segmenter/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ databake = { version = "0.1.3", path = "../../utils/databake", optional = true,
serde = { version = "1.0", default-features = false, features = ["derive", "alloc"], optional = true }

libm = { version = "0.2", default-features = false, optional = true }
once_cell = { version = "1.17.1"}

[dev-dependencies]
criterion = "0.4"
Expand Down
96 changes: 4 additions & 92 deletions components/segmenter/src/complex/lstm/matrix.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@
// called LICENSE at the top level of the ICU4X source tree
// (online at: https://github.com/unicode-org/icu4x/blob/main/LICENSE ).

use super::ops::{dot_1, dot_2};
use alloc::vec;
use alloc::vec::Vec;
use core::ops::Range;
use zerovec::ule::AsULE;
use zerovec::ZeroSlice;

// This will be used in #[no_std] as f32::exp/f32::tanh are not in core.
Expand Down Expand Up @@ -276,14 +276,6 @@ impl<'a, const D: usize> MatrixBorrowedMut<'a, D> {
}
}

impl<'a> MatrixBorrowed<'a, 1> {
    /// Returns the dot product of this one-dimensional matrix with `other`.
    ///
    /// Both operands are expected to have identical dimensions; this is
    /// verified only in debug builds.
    #[allow(dead_code)] // could be useful
    pub(super) fn dot_1d(&self, other: MatrixZero<1>) -> f32 {
        debug_assert_eq!(self.dims, other.dims);
        let (lhs, rhs) = (self.data, other.data);
        unrolled_dot_1(lhs, rhs)
    }
}

impl<'a> MatrixBorrowedMut<'a, 1> {
/// Calculate the dot product of a and b, adding the result to self.
///
Expand Down Expand Up @@ -311,7 +303,7 @@ impl<'a> MatrixBorrowedMut<'a, 1> {
for i in 0..n {
if let (Some(dest), Some(b_sub)) = (self.as_mut_slice().get_mut(i), b.submatrix::<1>(i))
{
*dest += unrolled_dot_1(a.data, b_sub.data);
*dest += dot_1(a.data, b_sub.data);
} else {
debug_assert!(false, "unreachable: dims checked above");
}
Expand Down Expand Up @@ -353,7 +345,7 @@ impl<'a> MatrixBorrowedMut<'a, 2> {
self.as_mut_slice().get_mut(i),
b.as_slice().get_subslice(i * m..(i + 1) * m),
) {
*dest += unrolled_dot_1(lhs, rhs);
*dest += dot_1(lhs, rhs);
} else {
debug_assert!(false, "unreachable: dims checked above");
}
Expand Down Expand Up @@ -393,7 +385,7 @@ impl<'a> MatrixBorrowedMut<'a, 2> {
self.as_mut_slice().get_mut(i),
b.as_slice().get_subslice(i * m..(i + 1) * m),
) {
*dest += unrolled_dot_2(lhs, rhs);
*dest += dot_2(lhs, rhs);
} else {
debug_assert!(false, "unreachable: dims checked above");
}
Expand Down Expand Up @@ -476,83 +468,3 @@ impl<'a, const D: usize> MatrixZero<'a, D> {
(n * index..n * (index + 1), sub_dims)
}
}

// Shorthand for turning an unaligned `f32` ULE value into a plain `f32`.
macro_rules! f32c {
    ($raw:expr) => {
        <f32 as AsULE>::from_unaligned($raw)
    };
}

/// Dot product of an aligned `f32` slice with an unaligned `f32` slice.
///
/// `xs` and `ys` must be the same length; this is only checked in debug
/// builds.
///
/// (Based on ndarray 0.15.6)
fn unrolled_dot_1(xs: &[f32], ys: &ZeroSlice<f32>) -> f32 {
    debug_assert_eq!(xs.len(), ys.len());

    let x_chunks = xs.chunks_exact(8);
    let y_chunks = ys.as_ule_slice().chunks_exact(8);

    // Products of the trailing elements that do not fill a whole chunk of 8.
    let tail = x_chunks
        .remainder()
        .iter()
        .zip(y_chunks.remainder())
        .map(|(x, y)| x * f32c!(*y))
        .sum::<f32>();

    // Eight independent accumulators (an eightfold unroll) so that the
    // floating point work can be vectorized even with strict floating point
    // accuracy semantics.
    let mut lanes = [0.0f32; 8];
    for (x_chunk, y_chunk) in x_chunks.zip(y_chunks) {
        // TODO: Use array_chunks once stable to avoid the unwrap.
        // <https://github.com/rust-lang/rust/issues/74985>
        // `chunks_exact(8)` only yields slices of length 8, so the
        // conversions to fixed-size arrays cannot fail.
        #[allow(clippy::unwrap_used)]
        let x8 = <&[f32; 8]>::try_from(x_chunk).unwrap();
        #[allow(clippy::unwrap_used)]
        let y8 = <&[<f32 as AsULE>::ULE; 8]>::try_from(y_chunk).unwrap();
        for ((lane, x), y) in lanes.iter_mut().zip(x8).zip(y8) {
            *lane += x * f32c!(*y);
        }
    }

    // Combine the partial sums pairwise, in a fixed order.
    let [l0, l1, l2, l3, l4, l5, l6, l7] = lanes;
    tail + (l0 + l4) + (l1 + l5) + (l2 + l6) + (l3 + l7)
}

/// Dot product of two unaligned `f32` slices.
///
/// `xs` and `ys` must be the same length; this is only checked in debug
/// builds.
///
/// (Based on ndarray 0.15.6)
fn unrolled_dot_2(xs: &ZeroSlice<f32>, ys: &ZeroSlice<f32>) -> f32 {
    debug_assert_eq!(xs.len(), ys.len());

    let x_chunks = xs.as_ule_slice().chunks_exact(8);
    let y_chunks = ys.as_ule_slice().chunks_exact(8);

    // Products of the trailing elements that do not fill a whole chunk of 8.
    let tail = x_chunks
        .remainder()
        .iter()
        .zip(y_chunks.remainder())
        .map(|(x, y)| f32c!(*x) * f32c!(*y))
        .sum::<f32>();

    // Eight independent accumulators (an eightfold unroll) so that the
    // floating point work can be vectorized even with strict floating point
    // accuracy semantics.
    let mut lanes = [0.0f32; 8];
    for (x_chunk, y_chunk) in x_chunks.zip(y_chunks) {
        // TODO: Use array_chunks once stable to avoid the unwrap.
        // <https://github.com/rust-lang/rust/issues/74985>
        // `chunks_exact(8)` only yields slices of length 8, so the
        // conversions to fixed-size arrays cannot fail.
        #[allow(clippy::unwrap_used)]
        let x8 = <&[<f32 as AsULE>::ULE; 8]>::try_from(x_chunk).unwrap();
        #[allow(clippy::unwrap_used)]
        let y8 = <&[<f32 as AsULE>::ULE; 8]>::try_from(y_chunk).unwrap();
        for ((lane, x), y) in lanes.iter_mut().zip(x8).zip(y8) {
            *lane += f32c!(*x) * f32c!(*y);
        }
    }

    // Combine the partial sums pairwise, in a fixed order.
    let [l0, l1, l2, l3, l4, l5, l6, l7] = lanes;
    tail + (l0 + l4) + (l1 + l5) + (l2 + l6) + (l3 + l7)
}
1 change: 1 addition & 0 deletions components/segmenter/src/complex/lstm/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ use zerovec::{maps::ZeroMapBorrowed, ule::UnvalidatedStr};

mod matrix;
use matrix::*;
mod ops;

// A word break iterator using an LSTM model. The input string has to be in a single language.

Expand Down
Loading