Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support SHA-2 accelerating instructions #103

Merged
merged 1 commit into from
Jun 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 32 additions & 0 deletions src/aarch64/assembler-aarch64.cc
Original file line number Diff line number Diff line change
Expand Up @@ -5957,6 +5957,38 @@ void Assembler::sha1su1(const VRegister& vd, const VRegister& vn) {
Emit(0x5e281800 | Rd(vd) | Rn(vn));
}

void Assembler::sha256h(const VRegister& vd, const VRegister& vn, const VRegister& vm) {
VIXL_ASSERT(CPUHas(CPUFeatures::kNEON));
VIXL_ASSERT(CPUHas(CPUFeatures::kSHA2));
VIXL_ASSERT(vd.IsQ() && vn.IsQ() && vm.Is4S());

Emit(0x5e004000 | Rd(vd) | Rn(vn) | Rm(vm));
}

void Assembler::sha256h2(const VRegister& vd, const VRegister& vn, const VRegister& vm) {
VIXL_ASSERT(CPUHas(CPUFeatures::kNEON));
VIXL_ASSERT(CPUHas(CPUFeatures::kSHA2));
VIXL_ASSERT(vd.IsQ() && vn.IsQ() && vm.Is4S());

Emit(0x5e005000 | Rd(vd) | Rn(vn) | Rm(vm));
}

void Assembler::sha256su0(const VRegister& vd, const VRegister& vn) {
VIXL_ASSERT(CPUHas(CPUFeatures::kNEON));
VIXL_ASSERT(CPUHas(CPUFeatures::kSHA2));
VIXL_ASSERT(vd.Is4S() && vn.Is4S());

Emit(0x5e282800 | Rd(vd) | Rn(vn));
}

void Assembler::sha256su1(const VRegister& vd, const VRegister& vn, const VRegister& vm) {
VIXL_ASSERT(CPUHas(CPUFeatures::kNEON));
VIXL_ASSERT(CPUHas(CPUFeatures::kSHA2));
VIXL_ASSERT(vd.Is4S() && vn.Is4S() && vm.Is4S());

Emit(0x5e006000 | Rd(vd) | Rn(vn) | Rm(vm));
}

// Note:
// For all ToImm instructions below, a difference in case
// for the same letter indicates a negated bit.
Expand Down
12 changes: 12 additions & 0 deletions src/aarch64/assembler-aarch64.h
Original file line number Diff line number Diff line change
Expand Up @@ -3660,6 +3660,18 @@ class Assembler : public vixl::internal::AssemblerBase {
// SHA1 schedule update 1.
void sha1su1(const VRegister& vd, const VRegister& vn);

// SHA256 hash update (part 1).
void sha256h(const VRegister& vd, const VRegister& vn, const VRegister& vm);

// SHA256 hash update (part 2).
void sha256h2(const VRegister& vd, const VRegister& vn, const VRegister& vm);

// SHA256 schedule update 0.
void sha256su0(const VRegister& vd, const VRegister& vn);

// SHA256 schedule update 1.
void sha256su1(const VRegister& vd, const VRegister& vn, const VRegister& vm);

// Scalable Vector Extensions.

// Absolute value (predicated).
Expand Down
20 changes: 18 additions & 2 deletions src/aarch64/cpu-features-auditor-aarch64.cc
Original file line number Diff line number Diff line change
Expand Up @@ -247,13 +247,29 @@ void CPUFeaturesAuditor::VisitConditionalSelect(const Instruction* instr) {

void CPUFeaturesAuditor::VisitCrypto2RegSHA(const Instruction* instr) {
RecordInstructionFeaturesScope scope(this);
scope.Record(CPUFeatures::kNEON, CPUFeatures::kSHA1);
if (form_hash_ == "sha256su0_vv_cryptosha2"_h) {
scope.Record(CPUFeatures::kNEON, CPUFeatures::kSHA2);
} else {
scope.Record(CPUFeatures::kNEON, CPUFeatures::kSHA1);
}
USE(instr);
}

void CPUFeaturesAuditor::VisitCrypto3RegSHA(const Instruction* instr) {
RecordInstructionFeaturesScope scope(this);
scope.Record(CPUFeatures::kNEON, CPUFeatures::kSHA1);
switch (form_hash_) {
case "sha1c_qsv_cryptosha3"_h:
case "sha1m_qsv_cryptosha3"_h:
case "sha1p_qsv_cryptosha3"_h:
case "sha1su0_vvv_cryptosha3"_h:
scope.Record(CPUFeatures::kNEON, CPUFeatures::kSHA1);
break;
case "sha256h_qqv_cryptosha3"_h:
case "sha256h2_qqv_cryptosha3"_h:
case "sha256su1_vvv_cryptosha3"_h:
scope.Record(CPUFeatures::kNEON, CPUFeatures::kSHA2);
break;
}
USE(instr);
}

Expand Down
11 changes: 9 additions & 2 deletions src/aarch64/disasm-aarch64.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2182,8 +2182,15 @@ void Disassembler::VisitCrypto2RegSHA(const Instruction *instr) {

void Disassembler::VisitCrypto3RegSHA(const Instruction *instr) {
const char *form = "'Qd, 'Sn, 'Vm.4s";
if (form_hash_ == "sha1su0_vvv_cryptosha3"_h) {
form = "'Vd.4s, 'Vn.4s, 'Vm.4s";
switch (form_hash_) {
case "sha1su0_vvv_cryptosha3"_h:
case "sha256su1_vvv_cryptosha3"_h:
form = "'Vd.4s, 'Vn.4s, 'Vm.4s";
break;
case "sha256h_qqv_cryptosha3"_h:
case "sha256h2_qqv_cryptosha3"_h:
form = "'Qd, 'Qn, 'Vm.4s";
break;
}
FormatWithDecodedMnemonic(instr, form);
}
Expand Down
83 changes: 83 additions & 0 deletions src/aarch64/logic-aarch64.cc
Original file line number Diff line number Diff line change
Expand Up @@ -7909,6 +7909,89 @@ uint64_t SHA1Operation<"parity"_h>(uint64_t x, uint64_t y, uint64_t z) {
return x ^ y ^ z;
}

template <unsigned A, unsigned B, unsigned C>
static uint64_t SHA2Sigma(uint64_t x) {
return RotateRight(x, A, kSRegSize) ^ RotateRight(x, B, kSRegSize) ^
RotateRight(x, C, kSRegSize);
}

LogicVRegister Simulator::sha2h(LogicVRegister srcdst,
const LogicVRegister& src1,
const LogicVRegister& src2,
bool part1) {
uint64_t x[4] = {};
uint64_t y[4] = {};
if (part1) {
// Switch input order based on which part is being handled.
srcdst.UintArray(kFormat4S, x);
src1.UintArray(kFormat4S, y);
} else {
src1.UintArray(kFormat4S, x);
srcdst.UintArray(kFormat4S, y);
}

for (unsigned i = 0; i < ArrayLength(x); i++) {
uint64_t chs = SHA1Operation<"choose"_h>(y[0], y[1], y[2]);
uint64_t maj = SHA1Operation<"majority"_h>(x[0], x[1], x[2]);

uint64_t w = src2.Uint(kFormat4S, i);
uint64_t t = y[3] + SHA2Sigma<6, 11, 25>(y[0]) + chs + w;

x[3] += t;
y[3] = t + SHA2Sigma<2, 13, 22>(x[0]) + maj;

// y:x = ROL(y:x, 32)
SHARotateEltsLeftOne(x);
SHARotateEltsLeftOne(y);
std::swap(x[0], y[0]);
}

srcdst.SetUintArray(kFormat4S, part1 ? x : y);
return srcdst;
}

template <unsigned A, unsigned B, unsigned C>
static uint64_t SHA2SURotate(uint64_t x) {
return RotateRight(x, A, kSRegSize) ^ RotateRight(x, B, kSRegSize) ^
((x & 0xffffffff) >> C);
}

LogicVRegister Simulator::sha2su0(LogicVRegister srcdst,
const LogicVRegister& src1) {
uint64_t w[4] = {};
uint64_t result[4];
srcdst.UintArray(kFormat4S, w);
uint64_t x = src1.Uint(kFormat4S, 0);

result[0] = SHA2SURotate<7, 18, 3>(w[1]) + w[0];
result[1] = SHA2SURotate<7, 18, 3>(w[2]) + w[1];
result[2] = SHA2SURotate<7, 18, 3>(w[3]) + w[2];
result[3] = SHA2SURotate<7, 18, 3>(x) + w[3];

srcdst.SetUintArray(kFormat4S, result);
return srcdst;
}

LogicVRegister Simulator::sha2su1(LogicVRegister srcdst,
const LogicVRegister& src1,
const LogicVRegister& src2) {
uint64_t w[4] = {};
uint64_t x[4] = {};
uint64_t y[4] = {};
uint64_t result[4];
srcdst.UintArray(kFormat4S, w);
src1.UintArray(kFormat4S, x);
src2.UintArray(kFormat4S, y);

result[0] = SHA2SURotate<17, 19, 10>(y[2]) + w[0] + x[1];
result[1] = SHA2SURotate<17, 19, 10>(y[3]) + w[1] + x[2];
result[2] = SHA2SURotate<17, 19, 10>(result[0]) + w[2] + x[3];
result[3] = SHA2SURotate<17, 19, 10>(result[1]) + w[3] + y[0];

srcdst.SetUintArray(kFormat4S, result);
return srcdst;
}

} // namespace aarch64
} // namespace vixl

Expand Down
4 changes: 4 additions & 0 deletions src/aarch64/macro-assembler-aarch64.h
Original file line number Diff line number Diff line change
Expand Up @@ -2804,6 +2804,9 @@ class MacroAssembler : public Assembler, public MacroAssemblerInterface {
V(sha1m, Sha1m) \
V(sha1p, Sha1p) \
V(sha1su0, Sha1su0) \
V(sha256h, Sha256h) \
V(sha256h2, Sha256h2) \
V(sha256su1, Sha256su1) \
V(shadd, Shadd) \
V(shsub, Shsub) \
V(smax, Smax) \
Expand Down Expand Up @@ -2950,6 +2953,7 @@ class MacroAssembler : public Assembler, public MacroAssemblerInterface {
V(saddlv, Saddlv) \
V(sha1h, Sha1h) \
V(sha1su1, Sha1su1) \
V(sha256su0, Sha256su0) \
V(smaxv, Smaxv) \
V(sminv, Sminv) \
V(sqabs, Sqabs) \
Expand Down
11 changes: 10 additions & 1 deletion src/aarch64/simulator-aarch64.cc
Original file line number Diff line number Diff line change
Expand Up @@ -7193,7 +7193,7 @@ void Simulator::VisitCrypto2RegSHA(const Instruction* instr) {
break;
}
case "sha256su0_vv_cryptosha2"_h:
VIXL_UNIMPLEMENTED();
sha2su0(rd, rn);
break;
}
}
Expand Down Expand Up @@ -7221,6 +7221,15 @@ void Simulator::VisitCrypto3RegSHA(const Instruction* instr) {
eor(kFormat16B, rd, temp, rm);
break;
}
case "sha256h_qqv_cryptosha3"_h:
sha2h(rd, rn, rm, /* part1 = */ true);
break;
case "sha256h2_qqv_cryptosha3"_h:
sha2h(rd, rn, rm, /* part1 = */ false);
break;
case "sha256su1_vvv_cryptosha3"_h:
sha2su1(rd, rn, rm);
break;
}
}

Expand Down
27 changes: 21 additions & 6 deletions src/aarch64/simulator-aarch64.h
Original file line number Diff line number Diff line change
Expand Up @@ -4498,6 +4498,16 @@ class Simulator : public DecoderVisitor {
const LogicVRegister& src1,
const LogicVRegister& src2);

template <unsigned N>
static void SHARotateEltsLeftOne(uint64_t (&x)[N]) {
VIXL_STATIC_ASSERT(N == 4);
uint64_t temp = x[3];
x[3] = x[2];
x[2] = x[1];
x[1] = x[0];
x[0] = temp;
}

template <uint32_t mode>
LogicVRegister sha1(LogicVRegister srcdst,
const LogicVRegister& src1,
Expand All @@ -4515,18 +4525,23 @@ class Simulator : public DecoderVisitor {
sd[1] = RotateLeft(sd[1], 30, kSRegSize);

// y:sd = ROL(y:sd, 32)
uint64_t temp = sd[3];
sd[3] = sd[2];
sd[2] = sd[1];
sd[1] = sd[0];
sd[0] = y;
y = temp;
SHARotateEltsLeftOne(sd);
std::swap(sd[0], y);
}

srcdst.SetUintArray(kFormat4S, sd);
return srcdst;
}

LogicVRegister sha2h(LogicVRegister srcdst,
const LogicVRegister& src1,
const LogicVRegister& src2,
bool part1);
LogicVRegister sha2su0(LogicVRegister srcdst, const LogicVRegister& src1);
LogicVRegister sha2su1(LogicVRegister srcdst,
const LogicVRegister& src1,
const LogicVRegister& src2);

#define NEON_3VREG_LOGIC_LIST(V) \
V(addhn) \
V(addhn2) \
Expand Down
10 changes: 10 additions & 0 deletions test/aarch64/test-cpu-features-aarch64.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3805,5 +3805,15 @@ TEST_NEON_SHA1(sha1su0_0, sha1su0(v19.V4S(), v9.V4S(), v27.V4S()))
TEST_NEON_SHA1(sha1h_0, sha1h(s12, s0))
TEST_NEON_SHA1(sha1su1_0, sha1su1(v2.V4S(), v4.V4S()))

#define TEST_FEAT(NAME, ASM) \
TEST_TEMPLATE(CPUFeatures(CPUFeatures::kNEON, CPUFeatures::kSHA2), \
NEON_SHA2_##NAME, \
ASM)
TEST_FEAT(sha256h_0, sha256h(q0, q12, v20.V4S()))
TEST_FEAT(sha256h2_0, sha256h2(q22, q2, v13.V4S()))
TEST_FEAT(sha256su0_0, sha256su0(v2.V4S(), v4.V4S()))
TEST_FEAT(sha256su1_0, sha256su1(v19.V4S(), v9.V4S(), v27.V4S()))
#undef TEST_FEAT

} // namespace aarch64
} // namespace vixl
12 changes: 12 additions & 0 deletions test/aarch64/test-disasm-neon-aarch64.cc
Original file line number Diff line number Diff line change
Expand Up @@ -4544,6 +4544,18 @@ TEST(neon_sha1) {
CLEANUP();
}

TEST(neon_sha2) {
SETUP();

COMPARE_MACRO(Sha256h(q0, q12, v20.V4S()), "sha256h q0, q12, v20.4s");
COMPARE_MACRO(Sha256h2(q22, q2, v13.V4S()), "sha256h2 q22, q2, v13.4s");
COMPARE_MACRO(Sha256su0(v2.V4S(), v4.V4S()), "sha256su0 v2.4s, v4.4s");
COMPARE_MACRO(Sha256su1(v19.V4S(), v9.V4S(), v27.V4S()),
"sha256su1 v19.4s, v9.4s, v27.4s");

CLEANUP();
}

TEST(neon_unallocated_regression_test) {
SETUP();

Expand Down
Loading