Skip to content

Commit

Permalink
Support SM3 accelerating instructions (#108)
Browse files Browse the repository at this point in the history
Add support for seven Neon SM3 accelerating instructions.
  • Loading branch information
mmc28a authored Jul 24, 2024
1 parent da718c2 commit b7ab6ec
Show file tree
Hide file tree
Showing 15 changed files with 710 additions and 24 deletions.
64 changes: 64 additions & 0 deletions src/aarch64/assembler-aarch64.cc
Original file line number Diff line number Diff line change
Expand Up @@ -6053,6 +6053,70 @@ void Assembler::aesmc(const VRegister& vd, const VRegister& vn) {
Emit(0x4e286800 | Rd(vd) | Rn(vn));
}

void Assembler::sm3partw1(const VRegister& vd, const VRegister& vn, const VRegister& vm) {
VIXL_ASSERT(CPUHas(CPUFeatures::kNEON));
VIXL_ASSERT(CPUHas(CPUFeatures::kSM3));
VIXL_ASSERT(vd.Is4S() && vn.Is4S() && vm.Is4S());

Emit(0xce60c000 | Rd(vd) | Rn(vn) | Rm(vm));
}

void Assembler::sm3partw2(const VRegister& vd, const VRegister& vn, const VRegister& vm) {
VIXL_ASSERT(CPUHas(CPUFeatures::kNEON));
VIXL_ASSERT(CPUHas(CPUFeatures::kSM3));
VIXL_ASSERT(vd.Is4S() && vn.Is4S() && vm.Is4S());

Emit(0xce60c400 | Rd(vd) | Rn(vn) | Rm(vm));
}

void Assembler::sm3ss1(const VRegister& vd, const VRegister& vn, const VRegister& vm, const VRegister& va) {
VIXL_ASSERT(CPUHas(CPUFeatures::kNEON));
VIXL_ASSERT(CPUHas(CPUFeatures::kSM3));
VIXL_ASSERT(vd.Is4S() && vn.Is4S() && vm.Is4S() && va.Is4S());

Emit(0xce400000 | Rd(vd) | Rn(vn) | Rm(vm) | Ra(va));
}

void Assembler::sm3tt1a(const VRegister& vd, const VRegister& vn, const VRegister& vm, int index) {
VIXL_ASSERT(CPUHas(CPUFeatures::kNEON));
VIXL_ASSERT(CPUHas(CPUFeatures::kSM3));
VIXL_ASSERT(vd.Is4S() && vn.Is4S() && vm.Is4S());
VIXL_ASSERT(IsUint2(index));

Instr i = static_cast<uint32_t>(index) << 12;
Emit(0xce408000 | Rd(vd) | Rn(vn) | Rm(vm) | i);
}

void Assembler::sm3tt1b(const VRegister& vd, const VRegister& vn, const VRegister& vm, int index) {
VIXL_ASSERT(CPUHas(CPUFeatures::kNEON));
VIXL_ASSERT(CPUHas(CPUFeatures::kSM3));
VIXL_ASSERT(vd.Is4S() && vn.Is4S() && vm.Is4S());
VIXL_ASSERT(IsUint2(index));

Instr i = static_cast<uint32_t>(index) << 12;
Emit(0xce408400 | Rd(vd) | Rn(vn) | Rm(vm) | i);
}

void Assembler::sm3tt2a(const VRegister& vd, const VRegister& vn, const VRegister& vm, int index) {
VIXL_ASSERT(CPUHas(CPUFeatures::kNEON));
VIXL_ASSERT(CPUHas(CPUFeatures::kSM3));
VIXL_ASSERT(vd.Is4S() && vn.Is4S() && vm.Is4S());
VIXL_ASSERT(IsUint2(index));

Instr i = static_cast<uint32_t>(index) << 12;
Emit(0xce408800 | Rd(vd) | Rn(vn) | Rm(vm) | i);
}

void Assembler::sm3tt2b(const VRegister& vd, const VRegister& vn, const VRegister& vm, int index) {
VIXL_ASSERT(CPUHas(CPUFeatures::kNEON));
VIXL_ASSERT(CPUHas(CPUFeatures::kSM3));
VIXL_ASSERT(vd.Is4S() && vn.Is4S() && vm.Is4S());
VIXL_ASSERT(IsUint2(index));

Instr i = static_cast<uint32_t>(index) << 12;
Emit(0xce408c00 | Rd(vd) | Rn(vn) | Rm(vm) | i);
}

// Note:
// For all ToImm instructions below, a difference in case
// for the same letter indicates a negated bit.
Expand Down
36 changes: 36 additions & 0 deletions src/aarch64/assembler-aarch64.h
Original file line number Diff line number Diff line change
Expand Up @@ -3696,6 +3696,42 @@ class Assembler : public vixl::internal::AssemblerBase {
// AES mix columns.
void aesmc(const VRegister& vd, const VRegister& vn);

// SM3PARTW1.
void sm3partw1(const VRegister& vd, const VRegister& vn, const VRegister& vm);

// SM3PARTW2.
void sm3partw2(const VRegister& vd, const VRegister& vn, const VRegister& vm);

// SM3SS1.
void sm3ss1(const VRegister& vd,
const VRegister& vn,
const VRegister& vm,
const VRegister& va);

// SM3TT1A.
void sm3tt1a(const VRegister& vd,
const VRegister& vn,
const VRegister& vm,
int index);

// SM3TT1B.
void sm3tt1b(const VRegister& vd,
const VRegister& vn,
const VRegister& vm,
int index);

// SM3TT2A.
void sm3tt2a(const VRegister& vd,
const VRegister& vn,
const VRegister& vm,
int index);

// SM3TT2B.
void sm3tt2b(const VRegister& vd,
const VRegister& vn,
const VRegister& vm,
int index);

// Scalable Vector Extensions.

// Absolute value (predicated).
Expand Down
6 changes: 6 additions & 0 deletions src/aarch64/cpu-features-auditor-aarch64.cc
Original file line number Diff line number Diff line change
Expand Up @@ -279,6 +279,12 @@ void CPUFeaturesAuditor::VisitCryptoAES(const Instruction* instr) {
USE(instr);
}

void CPUFeaturesAuditor::VisitCryptoSM3(const Instruction* instr) {
RecordInstructionFeaturesScope scope(this);
scope.Record(CPUFeatures::kNEON, CPUFeatures::kSM3);
USE(instr);
}

void CPUFeaturesAuditor::VisitDataProcessing1Source(const Instruction* instr) {
RecordInstructionFeaturesScope scope(this);
switch (instr->Mask(DataProcessing1SourceMask)) {
Expand Down
1 change: 1 addition & 0 deletions src/aarch64/cpu-features-auditor-aarch64.h
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,7 @@ class CPUFeaturesAuditor : public DecoderVisitor {
#define DECLARE(A) virtual void Visit##A(const Instruction* instr);
VISITOR_LIST(DECLARE)
#undef DECLARE
void VisitCryptoSM3(const Instruction* instr);

void LoadStoreHelper(const Instruction* instr);
void LoadStorePairHelper(const Instruction* instr);
Expand Down
14 changes: 7 additions & 7 deletions src/aarch64/decoder-visitor-map-aarch64.h
Original file line number Diff line number Diff line change
Expand Up @@ -2656,13 +2656,13 @@
{"ldtrsw_64_ldst_unpriv"_h, &VISITORCLASS::VisitUnimplemented}, \
{"ldtr_32_ldst_unpriv"_h, &VISITORCLASS::VisitUnimplemented}, \
{"ldtr_64_ldst_unpriv"_h, &VISITORCLASS::VisitUnimplemented}, \
{"sm3partw1_vvv4_cryptosha512_3"_h, &VISITORCLASS::VisitUnimplemented}, \
{"sm3partw2_vvv4_cryptosha512_3"_h, &VISITORCLASS::VisitUnimplemented}, \
{"sm3ss1_vvv4_crypto4"_h, &VISITORCLASS::VisitUnimplemented}, \
{"sm3tt1a_vvv4_crypto3_imm2"_h, &VISITORCLASS::VisitUnimplemented}, \
{"sm3tt1b_vvv4_crypto3_imm2"_h, &VISITORCLASS::VisitUnimplemented}, \
{"sm3tt2a_vvv4_crypto3_imm2"_h, &VISITORCLASS::VisitUnimplemented}, \
{"sm3tt2b_vvv_crypto3_imm2"_h, &VISITORCLASS::VisitUnimplemented}, \
{"sm3partw1_vvv4_cryptosha512_3"_h, &VISITORCLASS::VisitCryptoSM3}, \
{"sm3partw2_vvv4_cryptosha512_3"_h, &VISITORCLASS::VisitCryptoSM3}, \
{"sm3ss1_vvv4_crypto4"_h, &VISITORCLASS::VisitCryptoSM3}, \
{"sm3tt1a_vvv4_crypto3_imm2"_h, &VISITORCLASS::VisitCryptoSM3}, \
{"sm3tt1b_vvv4_crypto3_imm2"_h, &VISITORCLASS::VisitCryptoSM3}, \
{"sm3tt2a_vvv4_crypto3_imm2"_h, &VISITORCLASS::VisitCryptoSM3}, \
{"sm3tt2b_vvv_crypto3_imm2"_h, &VISITORCLASS::VisitCryptoSM3}, \
{"sm4ekey_vvv4_cryptosha512_3"_h, &VISITORCLASS::VisitUnimplemented}, \
{"sm4e_vv4_cryptosha512_2"_h, &VISITORCLASS::VisitUnimplemented}, \
{"st64b_64l_memop"_h, &VISITORCLASS::VisitUnimplemented}, \
Expand Down
19 changes: 19 additions & 0 deletions src/aarch64/disasm-aarch64.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2204,6 +2204,25 @@ void Disassembler::VisitCryptoAES(const Instruction *instr) {
FormatWithDecodedMnemonic(instr, "'Vd.16b, 'Vn.16b");
}

void Disassembler::VisitCryptoSM3(const Instruction *instr) {
const char *form = "'Vd.4s, 'Vn.4s, 'Vm.";
const char *suffix = "4s";

switch (form_hash_) {
case "sm3ss1_vvv4_crypto4"_h:
suffix = "4s, 'Va.4s";
break;
case "sm3tt1a_vvv4_crypto3_imm2"_h:
case "sm3tt1b_vvv4_crypto3_imm2"_h:
case "sm3tt2a_vvv4_crypto3_imm2"_h:
case "sm3tt2b_vvv_crypto3_imm2"_h:
suffix = "s['u1312]";
break;
}

FormatWithDecodedMnemonic(instr, form, suffix);
}

void Disassembler::DisassembleSHA512(const Instruction *instr) {
const char *form = "'Qd, 'Qn, 'Vm.2d";
const char *suffix = NULL;
Expand Down
2 changes: 2 additions & 0 deletions src/aarch64/disasm-aarch64.h
Original file line number Diff line number Diff line change
Expand Up @@ -243,6 +243,8 @@ class Disassembler : public DecoderVisitor {
void Disassemble_Xd_XnSP_Xm(const Instruction* instr);
void Disassemble_Xd_XnSP_XmSP(const Instruction* instr);

void VisitCryptoSM3(const Instruction* instr);

void Format(const Instruction* instr,
const char* mnemonic,
const char* format0,
Expand Down
129 changes: 124 additions & 5 deletions src/aarch64/logic-aarch64.cc
Original file line number Diff line number Diff line change
Expand Up @@ -7895,17 +7895,17 @@ LogicVRegister Simulator::fmatmul(VectorFormat vform,
}

template <>
uint64_t SHA1Operation<"choose"_h>(uint64_t x, uint64_t y, uint64_t z) {
uint64_t CryptoOp<"choose"_h>(uint64_t x, uint64_t y, uint64_t z) {
return ((y ^ z) & x) ^ z;
}

template <>
uint64_t SHA1Operation<"majority"_h>(uint64_t x, uint64_t y, uint64_t z) {
uint64_t CryptoOp<"majority"_h>(uint64_t x, uint64_t y, uint64_t z) {
return (x & y) | ((x | y) & z);
}

template <>
uint64_t SHA1Operation<"parity"_h>(uint64_t x, uint64_t y, uint64_t z) {
uint64_t CryptoOp<"parity"_h>(uint64_t x, uint64_t y, uint64_t z) {
return x ^ y ^ z;
}

Expand All @@ -7932,8 +7932,8 @@ LogicVRegister Simulator::sha2h(LogicVRegister srcdst,
}

for (unsigned i = 0; i < ArrayLength(x); i++) {
uint64_t chs = SHA1Operation<"choose"_h>(y[0], y[1], y[2]);
uint64_t maj = SHA1Operation<"majority"_h>(x[0], x[1], x[2]);
uint64_t chs = CryptoOp<"choose"_h>(y[0], y[1], y[2]);
uint64_t maj = CryptoOp<"majority"_h>(x[0], x[1], x[2]);

uint64_t w = src2.Uint(kFormat4S, i);
uint64_t t = y[3] + SHASigma<uint32_t, 6, 11, 25>(y[0]) + chs + w;
Expand Down Expand Up @@ -8351,6 +8351,125 @@ LogicVRegister Simulator::aes(LogicVRegister dst,
return dst;
}

LogicVRegister Simulator::sm3partw1(LogicVRegister srcdst,
const LogicVRegister& src1,
const LogicVRegister& src2) {
using namespace std::placeholders;
auto ROL = std::bind(RotateLeft, _1, _2, kSRegSize);

SimVRegister temp;

ext(kFormat16B, temp, src2, temp, 4);
rol(kFormat4S, temp, temp, 15);
eor(kFormat4S, temp, temp, src1);
LogicVRegister r = eor(kFormat4S, temp, temp, srcdst);

uint64_t result[4] = {};
r.UintArray(kFormat4S, result);
for (int i = 0; i < 4; i++) {
if (i == 3) {
// result[3] already contains srcdst[3] ^ src1[3] from the operations
// above.
result[i] ^= ROL(result[0], 15);
}
result[i] ^= ROL(result[i], 15) ^ ROL(result[i], 23);
}
srcdst.SetUintArray(kFormat4S, result);
return srcdst;
}

LogicVRegister Simulator::sm3partw2(LogicVRegister srcdst,
const LogicVRegister& src1,
const LogicVRegister& src2) {
using namespace std::placeholders;
auto ROL = std::bind(RotateLeft, _1, _2, kSRegSize);

SimVRegister temp;
VectorFormat vf = kFormat4S;

rol(vf, temp, src2, 7);
LogicVRegister r = eor(vf, temp, temp, src1);
eor(vf, srcdst, temp, srcdst);

uint64_t tmp2 = ROL(r.Uint(vf, 0), 15);
tmp2 ^= ROL(tmp2, 15) ^ ROL(tmp2, 23);
srcdst.SetUint(vf, 3, srcdst.Uint(vf, 3) ^ tmp2);
return srcdst;
}

LogicVRegister Simulator::sm3ss1(LogicVRegister dst,
const LogicVRegister& src1,
const LogicVRegister& src2,
const LogicVRegister& src3) {
using namespace std::placeholders;
auto ROL = std::bind(RotateLeft, _1, _2, kSRegSize);

VectorFormat vf = kFormat4S;
uint64_t result = ROL(src1.Uint(vf, 3), 12);
result += src2.Uint(vf, 3) + src3.Uint(vf, 3);
dst.Clear();
dst.SetUint(vf, 3, ROL(result, 7));
return dst;
}

LogicVRegister Simulator::sm3tt1(LogicVRegister srcdst,
const LogicVRegister& src1,
const LogicVRegister& src2,
int index,
bool is_a) {
VectorFormat vf = kFormat4S;
using namespace std::placeholders;
auto ROL = std::bind(RotateLeft, _1, _2, kSRegSize);
auto sd = std::bind(&LogicVRegister::Uint, srcdst, vf, _1);

VIXL_ASSERT(IsUint2(index));

uint64_t wjprime = src2.Uint(vf, index);
uint64_t ss2 = src1.Uint(vf, 3) ^ ROL(sd(3), 12);

uint64_t tt1;
if (is_a) {
tt1 = CryptoOp<"parity"_h>(sd(1), sd(2), sd(3));
} else {
tt1 = CryptoOp<"majority"_h>(sd(1), sd(2), sd(3));
}
tt1 += sd(0) + ss2 + wjprime;

ext(kFormat16B, srcdst, srcdst, srcdst, 4);
srcdst.SetUint(vf, 1, ROL(sd(1), 9));
srcdst.SetUint(vf, 3, tt1);
return srcdst;
}

LogicVRegister Simulator::sm3tt2(LogicVRegister srcdst,
const LogicVRegister& src1,
const LogicVRegister& src2,
int index,
bool is_a) {
VectorFormat vf = kFormat4S;
using namespace std::placeholders;
auto ROL = std::bind(RotateLeft, _1, _2, kSRegSize);
auto sd = std::bind(&LogicVRegister::Uint, srcdst, vf, _1);

VIXL_ASSERT(IsUint2(index));

uint64_t wj = src2.Uint(vf, index);

uint64_t tt2;
if (is_a) {
tt2 = CryptoOp<"parity"_h>(sd(1), sd(2), sd(3));
} else {
tt2 = CryptoOp<"choose"_h>(sd(3), sd(2), sd(1));
}
tt2 += sd(0) + src1.Uint(vf, 3) + wj;

ext(kFormat16B, srcdst, srcdst, srcdst, 4);
srcdst.SetUint(vf, 1, ROL(sd(1), 19));
tt2 ^= ROL(tt2, 9) ^ ROL(tt2, 17);
srcdst.SetUint(vf, 3, tt2);
return srcdst;
}

} // namespace aarch64
} // namespace vixl

Expand Down
16 changes: 15 additions & 1 deletion src/aarch64/macro-assembler-aarch64.h
Original file line number Diff line number Diff line change
Expand Up @@ -2812,6 +2812,8 @@ class MacroAssembler : public Assembler, public MacroAssemblerInterface {
V(sha512su1, Sha512su1) \
V(shadd, Shadd) \
V(shsub, Shsub) \
V(sm3partw1, Sm3partw1) \
V(sm3partw2, Sm3partw2) \
V(smax, Smax) \
V(smaxp, Smaxp) \
V(smin, Smin) \
Expand Down Expand Up @@ -3052,7 +3054,11 @@ class MacroAssembler : public Assembler, public MacroAssemblerInterface {
V(umlsl, Umlsl) \
V(umlsl2, Umlsl2) \
V(sudot, Sudot) \
V(usdot, Usdot)
V(usdot, Usdot) \
V(sm3tt1a, Sm3tt1a) \
V(sm3tt1b, Sm3tt1b) \
V(sm3tt2a, Sm3tt2a) \
V(sm3tt2b, Sm3tt2b)


#define DEFINE_MACRO_ASM_FUNC(ASM, MASM) \
Expand Down Expand Up @@ -3523,6 +3529,14 @@ class MacroAssembler : public Assembler, public MacroAssemblerInterface {
SingleEmissionCheckScope guard(this);
st4(vt, vt2, vt3, vt4, lane, dst);
}
void Sm3ss1(const VRegister& vd,
const VRegister& vn,
const VRegister& vm,
const VRegister& va) {
VIXL_ASSERT(allow_macro_instructions_);
SingleEmissionCheckScope guard(this);
sm3ss1(vd, vn, vm, va);
}
void Smov(const Register& rd, const VRegister& vn, int vn_index) {
VIXL_ASSERT(allow_macro_instructions_);
SingleEmissionCheckScope guard(this);
Expand Down
Loading

0 comments on commit b7ab6ec

Please sign in to comment.