Skip to content

Hook up sha2 instruction and detect sha-supporting cpu #9

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion sha2/benches/lib.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
#![no_std]
#![feature(test)]
extern crate test;
extern crate sha2_asm;
extern crate test;

use test::Bencher;

Expand Down
18 changes: 12 additions & 6 deletions sha2/build.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,17 @@ fn main() {
panic!("Unsupported target architecture");
};
cc::Build::new()
.flag("-c")
.file(sha256_path)
.compile("libsha256.a");
.flag("-c")
.file(sha256_path)
.compile("libsha256.a");
cc::Build::new()
.flag("-c")
.file(sha512_path)
.compile("libsha512.a");
.flag("-c")
.file(sha512_path)
.compile("libsha512.a");
cc::Build::new()
.flag("-c")
.flag("-msse4.1")
.flag("-msha")
.file("src/sha256_x64.c")
.compile("libsha256_shani.a");
}
28 changes: 25 additions & 3 deletions sha2/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,18 +12,40 @@
#[cfg(not(any(target_arch = "x86_64", target_arch = "x86")))]
compile_error!("crate can only be used on x86 and x86-64 architectures");

#[link(name="sha256", kind="static")]
#[link(name = "sha256_shani", kind = "static")]
extern "C" {
fn sha256_process_x86(state: &mut [u32; 8], block: *const u8, length: u32);
}

#[inline]
pub fn compress256_shani(state: &mut [u32; 8], block: &[u8; 64]) {
unsafe { sha256_process_x86(state, block.as_ptr(), block.len() as u32) }
}

#[link(name = "sha256", kind = "static")]
extern "C" {
fn sha256_compress(state: &mut [u32; 8], block: &[u8; 64]);
}

use core::arch::x86_64::CpuidResult;
#[inline]
pub fn get_cpuid(info: u32) -> CpuidResult {
use core::arch::x86_64::__cpuid_count;
unsafe { __cpuid_count(info, 0) }
}

/// Safe wrapper around assembly implementation of SHA256 compression function
#[inline]
pub fn compress256(state: &mut [u32; 8], block: &[u8; 64]) {
unsafe { sha256_compress(state, block) }
let x = get_cpuid(0x7);
if x.ebx & (1 << 29) != 0 {
compress256_shani(state, block);
} else {
unsafe { sha256_compress(state, block) };
}
}

#[link(name="sha512", kind="static")]
#[link(name = "sha512", kind = "static")]
extern "C" {
fn sha512_compress(state: &mut [u64; 8], block: &[u8; 128]);
}
Expand Down
263 changes: 263 additions & 0 deletions sha2/src/sha256_x64.c
Original file line number Diff line number Diff line change
@@ -0,0 +1,263 @@
/* sha256-x86.c - Intel SHA extensions using C intrinsics */
/* Written and place in public domain by Jeffrey Walton */
/* Based on code from Intel, and by Sean Gulley for */
/* the miTLS project. */

/* gcc -DTEST_MAIN -msse4.1 -msha sha256-x86.c -o sha256.exe */

/* Include the GCC super header */
#if defined(__GNUC__)
# include <stdint.h>
# include <x86intrin.h>
#endif

/* Microsoft supports Intel SHA ACLE extensions as of Visual Studio 2015 */
#if defined(_MSC_VER)
# include <immintrin.h>
# define WIN32_LEAN_AND_MEAN
# include <Windows.h>
typedef UINT32 uint32_t;
typedef UINT8 uint8_t;
#endif

/* Process multiple blocks. The caller is responsible for setting the initial */
/* state, and the caller is responsible for padding the final block. */
void sha256_process_x86(uint32_t state[8], const uint8_t data[], uint32_t length)
{
__m128i STATE0, STATE1;
__m128i MSG, TMP;
__m128i MSG0, MSG1, MSG2, MSG3;
__m128i ABEF_SAVE, CDGH_SAVE;
const __m128i MASK = _mm_set_epi64x(0x0c0d0e0f08090a0bULL, 0x0405060700010203ULL);

/* Load initial values */
TMP = _mm_loadu_si128((const __m128i*) &state[0]);
STATE1 = _mm_loadu_si128((const __m128i*) &state[4]);


TMP = _mm_shuffle_epi32(TMP, 0xB1); /* CDAB */
STATE1 = _mm_shuffle_epi32(STATE1, 0x1B); /* EFGH */
STATE0 = _mm_alignr_epi8(TMP, STATE1, 8); /* ABEF */
STATE1 = _mm_blend_epi16(STATE1, TMP, 0xF0); /* CDGH */

while (length >= 64)
{
/* Save current state */
ABEF_SAVE = STATE0;
CDGH_SAVE = STATE1;

/* Rounds 0-3 */
MSG = _mm_loadu_si128((const __m128i*) (data+0));
MSG0 = _mm_shuffle_epi8(MSG, MASK);
MSG = _mm_add_epi32(MSG0, _mm_set_epi64x(0xE9B5DBA5B5C0FBCFULL, 0x71374491428A2F98ULL));
STATE1 = _mm_sha256rnds2_epu32(STATE1, STATE0, MSG);
MSG = _mm_shuffle_epi32(MSG, 0x0E);
STATE0 = _mm_sha256rnds2_epu32(STATE0, STATE1, MSG);

/* Rounds 4-7 */
MSG1 = _mm_loadu_si128((const __m128i*) (data+16));
MSG1 = _mm_shuffle_epi8(MSG1, MASK);
MSG = _mm_add_epi32(MSG1, _mm_set_epi64x(0xAB1C5ED5923F82A4ULL, 0x59F111F13956C25BULL));
STATE1 = _mm_sha256rnds2_epu32(STATE1, STATE0, MSG);
MSG = _mm_shuffle_epi32(MSG, 0x0E);
STATE0 = _mm_sha256rnds2_epu32(STATE0, STATE1, MSG);
MSG0 = _mm_sha256msg1_epu32(MSG0, MSG1);

/* Rounds 8-11 */
MSG2 = _mm_loadu_si128((const __m128i*) (data+32));
MSG2 = _mm_shuffle_epi8(MSG2, MASK);
MSG = _mm_add_epi32(MSG2, _mm_set_epi64x(0x550C7DC3243185BEULL, 0x12835B01D807AA98ULL));
STATE1 = _mm_sha256rnds2_epu32(STATE1, STATE0, MSG);
MSG = _mm_shuffle_epi32(MSG, 0x0E);
STATE0 = _mm_sha256rnds2_epu32(STATE0, STATE1, MSG);
MSG1 = _mm_sha256msg1_epu32(MSG1, MSG2);

/* Rounds 12-15 */
MSG3 = _mm_loadu_si128((const __m128i*) (data+48));
MSG3 = _mm_shuffle_epi8(MSG3, MASK);
MSG = _mm_add_epi32(MSG3, _mm_set_epi64x(0xC19BF1749BDC06A7ULL, 0x80DEB1FE72BE5D74ULL));
STATE1 = _mm_sha256rnds2_epu32(STATE1, STATE0, MSG);
TMP = _mm_alignr_epi8(MSG3, MSG2, 4);
MSG0 = _mm_add_epi32(MSG0, TMP);
MSG0 = _mm_sha256msg2_epu32(MSG0, MSG3);
MSG = _mm_shuffle_epi32(MSG, 0x0E);
STATE0 = _mm_sha256rnds2_epu32(STATE0, STATE1, MSG);
MSG2 = _mm_sha256msg1_epu32(MSG2, MSG3);

/* Rounds 16-19 */
MSG = _mm_add_epi32(MSG0, _mm_set_epi64x(0x240CA1CC0FC19DC6ULL, 0xEFBE4786E49B69C1ULL));
STATE1 = _mm_sha256rnds2_epu32(STATE1, STATE0, MSG);
TMP = _mm_alignr_epi8(MSG0, MSG3, 4);
MSG1 = _mm_add_epi32(MSG1, TMP);
MSG1 = _mm_sha256msg2_epu32(MSG1, MSG0);
MSG = _mm_shuffle_epi32(MSG, 0x0E);
STATE0 = _mm_sha256rnds2_epu32(STATE0, STATE1, MSG);
MSG3 = _mm_sha256msg1_epu32(MSG3, MSG0);

/* Rounds 20-23 */
MSG = _mm_add_epi32(MSG1, _mm_set_epi64x(0x76F988DA5CB0A9DCULL, 0x4A7484AA2DE92C6FULL));
STATE1 = _mm_sha256rnds2_epu32(STATE1, STATE0, MSG);
TMP = _mm_alignr_epi8(MSG1, MSG0, 4);
MSG2 = _mm_add_epi32(MSG2, TMP);
MSG2 = _mm_sha256msg2_epu32(MSG2, MSG1);
MSG = _mm_shuffle_epi32(MSG, 0x0E);
STATE0 = _mm_sha256rnds2_epu32(STATE0, STATE1, MSG);
MSG0 = _mm_sha256msg1_epu32(MSG0, MSG1);

/* Rounds 24-27 */
MSG = _mm_add_epi32(MSG2, _mm_set_epi64x(0xBF597FC7B00327C8ULL, 0xA831C66D983E5152ULL));
STATE1 = _mm_sha256rnds2_epu32(STATE1, STATE0, MSG);
TMP = _mm_alignr_epi8(MSG2, MSG1, 4);
MSG3 = _mm_add_epi32(MSG3, TMP);
MSG3 = _mm_sha256msg2_epu32(MSG3, MSG2);
MSG = _mm_shuffle_epi32(MSG, 0x0E);
STATE0 = _mm_sha256rnds2_epu32(STATE0, STATE1, MSG);
MSG1 = _mm_sha256msg1_epu32(MSG1, MSG2);

/* Rounds 28-31 */
MSG = _mm_add_epi32(MSG3, _mm_set_epi64x(0x1429296706CA6351ULL, 0xD5A79147C6E00BF3ULL));
STATE1 = _mm_sha256rnds2_epu32(STATE1, STATE0, MSG);
TMP = _mm_alignr_epi8(MSG3, MSG2, 4);
MSG0 = _mm_add_epi32(MSG0, TMP);
MSG0 = _mm_sha256msg2_epu32(MSG0, MSG3);
MSG = _mm_shuffle_epi32(MSG, 0x0E);
STATE0 = _mm_sha256rnds2_epu32(STATE0, STATE1, MSG);
MSG2 = _mm_sha256msg1_epu32(MSG2, MSG3);

/* Rounds 32-35 */
MSG = _mm_add_epi32(MSG0, _mm_set_epi64x(0x53380D134D2C6DFCULL, 0x2E1B213827B70A85ULL));
STATE1 = _mm_sha256rnds2_epu32(STATE1, STATE0, MSG);
TMP = _mm_alignr_epi8(MSG0, MSG3, 4);
MSG1 = _mm_add_epi32(MSG1, TMP);
MSG1 = _mm_sha256msg2_epu32(MSG1, MSG0);
MSG = _mm_shuffle_epi32(MSG, 0x0E);
STATE0 = _mm_sha256rnds2_epu32(STATE0, STATE1, MSG);
MSG3 = _mm_sha256msg1_epu32(MSG3, MSG0);

/* Rounds 36-39 */
MSG = _mm_add_epi32(MSG1, _mm_set_epi64x(0x92722C8581C2C92EULL, 0x766A0ABB650A7354ULL));
STATE1 = _mm_sha256rnds2_epu32(STATE1, STATE0, MSG);
TMP = _mm_alignr_epi8(MSG1, MSG0, 4);
MSG2 = _mm_add_epi32(MSG2, TMP);
MSG2 = _mm_sha256msg2_epu32(MSG2, MSG1);
MSG = _mm_shuffle_epi32(MSG, 0x0E);
STATE0 = _mm_sha256rnds2_epu32(STATE0, STATE1, MSG);
MSG0 = _mm_sha256msg1_epu32(MSG0, MSG1);

/* Rounds 40-43 */
MSG = _mm_add_epi32(MSG2, _mm_set_epi64x(0xC76C51A3C24B8B70ULL, 0xA81A664BA2BFE8A1ULL));
STATE1 = _mm_sha256rnds2_epu32(STATE1, STATE0, MSG);
TMP = _mm_alignr_epi8(MSG2, MSG1, 4);
MSG3 = _mm_add_epi32(MSG3, TMP);
MSG3 = _mm_sha256msg2_epu32(MSG3, MSG2);
MSG = _mm_shuffle_epi32(MSG, 0x0E);
STATE0 = _mm_sha256rnds2_epu32(STATE0, STATE1, MSG);
MSG1 = _mm_sha256msg1_epu32(MSG1, MSG2);

/* Rounds 44-47 */
MSG = _mm_add_epi32(MSG3, _mm_set_epi64x(0x106AA070F40E3585ULL, 0xD6990624D192E819ULL));
STATE1 = _mm_sha256rnds2_epu32(STATE1, STATE0, MSG);
TMP = _mm_alignr_epi8(MSG3, MSG2, 4);
MSG0 = _mm_add_epi32(MSG0, TMP);
MSG0 = _mm_sha256msg2_epu32(MSG0, MSG3);
MSG = _mm_shuffle_epi32(MSG, 0x0E);
STATE0 = _mm_sha256rnds2_epu32(STATE0, STATE1, MSG);
MSG2 = _mm_sha256msg1_epu32(MSG2, MSG3);

/* Rounds 48-51 */
MSG = _mm_add_epi32(MSG0, _mm_set_epi64x(0x34B0BCB52748774CULL, 0x1E376C0819A4C116ULL));
STATE1 = _mm_sha256rnds2_epu32(STATE1, STATE0, MSG);
TMP = _mm_alignr_epi8(MSG0, MSG3, 4);
MSG1 = _mm_add_epi32(MSG1, TMP);
MSG1 = _mm_sha256msg2_epu32(MSG1, MSG0);
MSG = _mm_shuffle_epi32(MSG, 0x0E);
STATE0 = _mm_sha256rnds2_epu32(STATE0, STATE1, MSG);
MSG3 = _mm_sha256msg1_epu32(MSG3, MSG0);

/* Rounds 52-55 */
MSG = _mm_add_epi32(MSG1, _mm_set_epi64x(0x682E6FF35B9CCA4FULL, 0x4ED8AA4A391C0CB3ULL));
STATE1 = _mm_sha256rnds2_epu32(STATE1, STATE0, MSG);
TMP = _mm_alignr_epi8(MSG1, MSG0, 4);
MSG2 = _mm_add_epi32(MSG2, TMP);
MSG2 = _mm_sha256msg2_epu32(MSG2, MSG1);
MSG = _mm_shuffle_epi32(MSG, 0x0E);
STATE0 = _mm_sha256rnds2_epu32(STATE0, STATE1, MSG);

/* Rounds 56-59 */
MSG = _mm_add_epi32(MSG2, _mm_set_epi64x(0x8CC7020884C87814ULL, 0x78A5636F748F82EEULL));
STATE1 = _mm_sha256rnds2_epu32(STATE1, STATE0, MSG);
TMP = _mm_alignr_epi8(MSG2, MSG1, 4);
MSG3 = _mm_add_epi32(MSG3, TMP);
MSG3 = _mm_sha256msg2_epu32(MSG3, MSG2);
MSG = _mm_shuffle_epi32(MSG, 0x0E);
STATE0 = _mm_sha256rnds2_epu32(STATE0, STATE1, MSG);

/* Rounds 60-63 */
MSG = _mm_add_epi32(MSG3, _mm_set_epi64x(0xC67178F2BEF9A3F7ULL, 0xA4506CEB90BEFFFAULL));
STATE1 = _mm_sha256rnds2_epu32(STATE1, STATE0, MSG);
MSG = _mm_shuffle_epi32(MSG, 0x0E);
STATE0 = _mm_sha256rnds2_epu32(STATE0, STATE1, MSG);

/* Combine state */
STATE0 = _mm_add_epi32(STATE0, ABEF_SAVE);
STATE1 = _mm_add_epi32(STATE1, CDGH_SAVE);

data += 64;
length -= 64;
}

TMP = _mm_shuffle_epi32(STATE0, 0x1B); /* FEBA */
STATE1 = _mm_shuffle_epi32(STATE1, 0xB1); /* DCHG */
STATE0 = _mm_blend_epi16(TMP, STATE1, 0xF0); /* DCBA */
STATE1 = _mm_alignr_epi8(STATE1, TMP, 8); /* ABEF */

/* Save state */
_mm_storeu_si128((__m128i*) &state[0], STATE0);
_mm_storeu_si128((__m128i*) &state[4], STATE1);
}

#if defined(TEST_MAIN)

#include <stdio.h>
#include <string.h>
int main(int argc, char* argv[])
{
/* empty message with padding */
uint8_t message[64];
memset(message, 0x00, sizeof(message));
message[0] = 0x80;

/* initial state */
uint32_t state[8] = {
0x6a09e667, 0xbb67ae85, 0x3c6ef372, 0xa54ff53a,
0x510e527f, 0x9b05688c, 0x1f83d9ab, 0x5be0cd19
};

sha256_process_x86(state, message, sizeof(message));

const uint8_t b1 = (uint8_t)(state[0] >> 24);
const uint8_t b2 = (uint8_t)(state[0] >> 16);
const uint8_t b3 = (uint8_t)(state[0] >> 8);
const uint8_t b4 = (uint8_t)(state[0] >> 0);
const uint8_t b5 = (uint8_t)(state[1] >> 24);
const uint8_t b6 = (uint8_t)(state[1] >> 16);
const uint8_t b7 = (uint8_t)(state[1] >> 8);
const uint8_t b8 = (uint8_t)(state[1] >> 0);

/* e3b0c44298fc1c14... */
printf("SHA256 hash of empty message: ");
printf("%02X%02X%02X%02X%02X%02X%02X%02X...\n",
b1, b2, b3, b4, b5, b6, b7, b8);

int success = ((b1 == 0xE3) && (b2 == 0xB0) && (b3 == 0xC4) && (b4 == 0x42) &&
(b5 == 0x98) && (b6 == 0xFC) && (b7 == 0x1C) && (b8 == 0x14));

if (success)
printf("Success!\n");
else
printf("Failure!\n");

return (success != 0 ? 0 : 1);
}

#endif