Skip to content

Commit

Permalink
Replace RSA/ECDSA_Sign/Verify with EVP_DigestSign/Verify APIs in ACVP…
Browse files Browse the repository at this point in the history
… tests. (#264)

* Replace RSA/ECDSA_Sign/Verify APIs with EVP_DigestSign/Verify APIs in ACVP tests

* Fix memory leak.

* Rename key to rsa.

* Apply suggestions from code review

Co-authored-by: torben-hansen <[email protected]>

Co-authored-by: torben-hansen <[email protected]>
  • Loading branch information
bryce-shang and torben-hansen authored Oct 1, 2021
1 parent c81510d commit 7e7f06c
Show file tree
Hide file tree
Showing 6 changed files with 131 additions and 73 deletions.
23 changes: 19 additions & 4 deletions util/fipstools/acvp/acvptool/subprocess/ecdsa.go
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,8 @@ func (e *ecdsa) Process(vectorSet []byte, m Transactable) (interface{}, error) {
ID: group.ID,
}
var sigGenPrivateKey []byte
var qxHex []byte
var qyHex []byte

for _, test := range group.Tests {
var testResp ecdsaTestResponse
Expand Down Expand Up @@ -148,8 +150,10 @@ func (e *ecdsa) Process(vectorSet []byte, m Transactable) (interface{}, error) {
}

sigGenPrivateKey = result[0]
response.QxHex = hex.EncodeToString(result[1])
response.QyHex = hex.EncodeToString(result[2])
qxHex = result[1]
qyHex = result[2]
response.QxHex = hex.EncodeToString(qxHex)
response.QyHex = hex.EncodeToString(qyHex)
}

msg, err := hex.DecodeString(test.MsgHex)
Expand All @@ -167,8 +171,19 @@ func (e *ecdsa) Process(vectorSet []byte, m Transactable) (interface{}, error) {
if err != nil {
return nil, fmt.Errorf("signature generation failed for test case %d/%d: %s", group.ID, test.ID, err)
}
testResp.RHex = hex.EncodeToString(result[0])
testResp.SHex = hex.EncodeToString(result[1])
rHex := result[0]
sHex := result[1]
testResp.RHex = hex.EncodeToString(rHex)
testResp.SHex = hex.EncodeToString(sHex)
// Ask the subprocess to verify the generated signature for this test case.
ver_result, ver_err := m.Transact(e.algo+"/"+"sigVer", 1, []byte(group.Curve), []byte(group.HashAlgo), msg, qxHex, qyHex, rHex, sHex)
if ver_err != nil {
return nil, fmt.Errorf("After signature generation, signature verification failed for test case %d/%d: %s", group.ID, test.ID, ver_err)
}
// ver_result[0] should be a single byte. The value should be one in this case.
if !bytes.Equal(ver_result[0], []byte{01}) {
return nil, fmt.Errorf("After signature generation, signature verification returned unexpected result: %q for test case %d/%d.", ver_result[0], group.ID, test.ID)
}

case "sigVer":
p := e.primitives[group.HashAlgo]
Expand Down
21 changes: 17 additions & 4 deletions util/fipstools/acvp/acvptool/subprocess/rsa.go
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,7 @@ func processSigGen(vectorSet []byte, m Transactable) (interface{}, error) {
}

operation := "RSA/sigGen/" + group.Hash + "/" + group.SigType
ver_operation := "RSA/sigVer/" + group.Hash + "/" + group.SigType

for _, test := range group.Tests {
msg, err := hex.DecodeString(test.MessageHex)
Expand All @@ -190,16 +191,28 @@ func processSigGen(vectorSet []byte, m Transactable) (interface{}, error) {
return nil, err
}

n := results[0]
e := results[1]
sig := results[2]

if len(response.N) == 0 {
response.N = hex.EncodeToString(results[0])
response.E = hex.EncodeToString(results[1])
} else if response.N != hex.EncodeToString(results[0]) {
response.N = hex.EncodeToString(n)
response.E = hex.EncodeToString(e)
} else if response.N != hex.EncodeToString(n) {
return nil, fmt.Errorf("module wrapper returned different RSA keys for the same SigGen configuration")
}

// Ask the subprocess to verify the generated signature for this test case.
ver_results, ver_err := m.Transact(ver_operation, 1, n, e, msg, sig)
if ver_err != nil {
return nil, ver_err
}
if len(ver_results[0]) != 1 || ver_results[0][0] != 1 {
return nil, fmt.Errorf("module wrapper returned RSA Sig cannot be verified for test case %d/%d.", group.ID, test.ID)
}
response.Tests = append(response.Tests, rsaSigGenTestResponse{
ID: test.ID,
Sig: hex.EncodeToString(results[2]),
Sig: hex.EncodeToString(sig),
})
}

Expand Down
2 changes: 2 additions & 0 deletions util/fipstools/acvp/acvptool/test/tests.json
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
{"Wrapper": "modulewrapper", "In": "vectors/CMAC-AES.bz2", "Out": "expected/CMAC-AES.bz2"},
{"Wrapper": "modulewrapper", "In": "vectors/ctrDRBG.bz2", "Out": "expected/ctrDRBG.bz2"},
{"Wrapper": "modulewrapper", "In": "vectors/ECDSA.bz2", "Out": "expected/ECDSA.bz2"},
{"Wrapper": "modulewrapper", "In": "vectors/ECDSA-SigGen.bz2"},
{"Wrapper": "modulewrapper", "In": "vectors/HMAC-SHA-1.bz2", "Out": "expected/HMAC-SHA-1.bz2"},
{"Wrapper": "modulewrapper", "In": "vectors/HMAC-SHA2-224.bz2", "Out": "expected/HMAC-SHA2-224.bz2"},
{"Wrapper": "modulewrapper", "In": "vectors/HMAC-SHA2-256.bz2", "Out": "expected/HMAC-SHA2-256.bz2"},
Expand All @@ -25,6 +26,7 @@
{"Wrapper": "modulewrapper", "In": "vectors/ACVP-AES-GCM-internal-IV.bz2"},
{"Wrapper": "modulewrapper", "In": "vectors/kdf-components.bz2", "Out": "expected/kdf-components.bz2"},
{"Wrapper": "modulewrapper", "In": "vectors/RSA.bz2", "Out": "expected/RSA.bz2"},
{"Wrapper": "modulewrapper", "In": "vectors/RSA-SigGen.bz2"},
{"Wrapper": "modulewrapper", "In": "vectors/SHA-1.bz2", "Out": "expected/SHA-1.bz2"},
{"Wrapper": "modulewrapper", "In": "vectors/SHA2-224.bz2", "Out": "expected/SHA2-224.bz2"},
{"Wrapper": "modulewrapper", "In": "vectors/SHA2-256.bz2", "Out": "expected/SHA2-256.bz2"},
Expand Down
Binary file not shown.
Binary file not shown.
158 changes: 93 additions & 65 deletions util/fipstools/acvp/modulewrapper/modulewrapper.cc
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
#include <openssl/dh.h>
#include <openssl/digest.h>
#include <openssl/ec.h>
#include <openssl/evp.h>
#include <openssl/ec_key.h>
#include <openssl/ecdh.h>
#include <openssl/ecdsa.h>
Expand Down Expand Up @@ -1567,16 +1568,27 @@ static bool ECDSASigGen(const Span<const uint8_t> args[], ReplyCallback write_re
bssl::UniquePtr<EC_KEY> key = ECKeyFromName(args[0]);
bssl::UniquePtr<BIGNUM> d = BytesToBIGNUM(args[1]);
const EVP_MD *hash = HashFromName(args[2]);
uint8_t digest[EVP_MAX_MD_SIZE];
unsigned digest_len;
if (!key || !hash ||
!EVP_Digest(args[3].data(), args[3].size(), digest, &digest_len, hash,
/*impl=*/nullptr) ||
!EC_KEY_set_private_key(key.get(), d.get())) {
auto msg = args[3];
if (!key || !hash || !EC_KEY_set_private_key(key.get(), d.get())) {
return false;
}

bssl::UniquePtr<ECDSA_SIG> sig(ECDSA_do_sign(digest, digest_len, key.get()));
bssl::ScopedEVP_MD_CTX ctx;
EVP_PKEY_CTX *pctx;
bssl::UniquePtr<EVP_PKEY> evp_pkey(EVP_PKEY_new());
if (!evp_pkey || !EVP_PKEY_set1_EC_KEY(evp_pkey.get(), key.get())) {
return false;
}
std::vector<uint8_t> sig_der;
size_t len;
if (!EVP_DigestSignInit(ctx.get(), &pctx, hash, nullptr, evp_pkey.get()) ||
!EVP_DigestSign(ctx.get(), nullptr, &len, msg.data(), msg.size())) {
return false;
}
sig_der.resize(len);
if (!EVP_DigestSign(ctx.get(), sig_der.data(), &len, msg.data(), msg.size())) {
return false;
}
bssl::UniquePtr<ECDSA_SIG> sig(ECDSA_SIG_from_bytes(sig_der.data(), len));
if (!sig) {
return false;
}
Expand All @@ -1599,27 +1611,35 @@ static bool ECDSASigVer(const Span<const uint8_t> args[], ReplyCallback write_re
ECDSA_SIG sig;
sig.r = r.get();
sig.s = s.get();

uint8_t digest[EVP_MAX_MD_SIZE];
unsigned digest_len;
if (!key || !hash ||
!EVP_Digest(msg.data(), msg.size(), digest, &digest_len, hash,
/*impl=*/nullptr)) {
uint8_t *der;
size_t der_len;
if (!key || !hash || !ECDSA_SIG_to_bytes(&der, &der_len, &sig)) {
return false;
}

// Let |delete_der| manage the release of |der|.
bssl::UniquePtr<uint8_t> delete_der(der);
bssl::UniquePtr<EC_POINT> point(EC_POINT_new(EC_KEY_get0_group(key.get())));
uint8_t reply[1];
if (!EC_POINT_set_affine_coordinates_GFp(EC_KEY_get0_group(key.get()),
point.get(), x.get(), y.get(),
/*ctx=*/nullptr) ||
!EC_KEY_set_public_key(key.get(), point.get()) ||
!EC_KEY_check_fips(key.get()) ||
!ECDSA_do_verify(digest, digest_len, &sig, key.get())) {
!EC_KEY_check_fips(key.get())) {
return false;
}
bssl::ScopedEVP_MD_CTX ctx;
EVP_PKEY_CTX *pctx;
bssl::UniquePtr<EVP_PKEY> evp_pkey(EVP_PKEY_new());
if (!evp_pkey || !EVP_PKEY_set1_EC_KEY(evp_pkey.get(), key.get())) {
return false;
}
uint8_t reply[1];
if (!EVP_DigestVerifyInit(ctx.get(), &pctx, hash, nullptr, evp_pkey.get()) ||
!EVP_DigestVerify(ctx.get(), der, der_len, msg.data(), msg.size())) {
reply[0] = 0;
} else {
reply[0] = 1;
}
ERR_clear_error();

return write_reply({Span<const uint8_t>(reply)});
}
Expand Down Expand Up @@ -1657,26 +1677,34 @@ static bool CMAC_AESVerify(const Span<const uint8_t> args[], ReplyCallback write
return write_reply({Span<const uint8_t>(&ok, sizeof(ok))});
}

static std::map<unsigned, bssl::UniquePtr<RSA>>& CachedRSAKeys() {
static std::map<unsigned, bssl::UniquePtr<RSA>> keys;
static std::map<unsigned, bssl::UniquePtr<EVP_PKEY>>& CachedRSAEVPKeys() {
static std::map<unsigned, bssl::UniquePtr<EVP_PKEY>> keys;
return keys;
}

static RSA* GetRSAKey(unsigned bits) {
auto it = CachedRSAKeys().find(bits);
if (it != CachedRSAKeys().end()) {
static EVP_PKEY* AddRSAKeyToCache(bssl::UniquePtr<RSA>& rsa, unsigned bits) {
bssl::UniquePtr<EVP_PKEY> evp_pkey(EVP_PKEY_new());
if (!evp_pkey || !EVP_PKEY_set1_RSA(evp_pkey.get(), rsa.get())) {
return nullptr;
}

EVP_PKEY *const ret = evp_pkey.get();
CachedRSAEVPKeys().emplace(static_cast<unsigned>(bits), std::move(evp_pkey));
return ret;
}

static EVP_PKEY* GetRSAKey(unsigned bits) {
auto it = CachedRSAEVPKeys().find(bits);
if (it != CachedRSAEVPKeys().end()) {
return it->second.get();
}

bssl::UniquePtr<RSA> key(RSA_new());
if (!RSA_generate_key_fips(key.get(), bits, nullptr)) {
bssl::UniquePtr<RSA> rsa(RSA_new());
if (!RSA_generate_key_fips(rsa.get(), bits, nullptr)) {
abort();
}

RSA *const ret = key.get();
CachedRSAKeys().emplace(static_cast<unsigned>(bits), std::move(key));

return ret;
return AddRSAKeyToCache(rsa, bits);
}

static bool RSAKeyGen(const Span<const uint8_t> args[], ReplyCallback write_reply) {
Expand All @@ -1686,22 +1714,24 @@ static bool RSAKeyGen(const Span<const uint8_t> args[], ReplyCallback write_repl
}
memcpy(&bits, args[0].data(), sizeof(bits));

bssl::UniquePtr<RSA> key(RSA_new());
if (!RSA_generate_key_fips(key.get(), bits, nullptr)) {
bssl::UniquePtr<RSA> rsa(RSA_new());
if (!RSA_generate_key_fips(rsa.get(), bits, nullptr)) {
LOG_ERROR("RSA_generate_key_fips failed for modulus length %u.\n", bits);
return false;
}

const BIGNUM *n, *e, *d, *p, *q;
RSA_get0_key(key.get(), &n, &e, &d);
RSA_get0_factors(key.get(), &p, &q);
RSA_get0_key(rsa.get(), &n, &e, &d);
RSA_get0_factors(rsa.get(), &p, &q);

if (!write_reply({BIGNUMBytes(e), BIGNUMBytes(p), BIGNUMBytes(q),
BIGNUMBytes(n), BIGNUMBytes(d)})) {
return false;
}

CachedRSAKeys().emplace(static_cast<unsigned>(bits), std::move(key));
if (AddRSAKeyToCache(rsa, bits) == nullptr) {
return false;
}
return true;
}

Expand All @@ -1713,35 +1743,32 @@ static bool RSASigGen(const Span<const uint8_t> args[], ReplyCallback write_repl
}
memcpy(&bits, args[0].data(), sizeof(bits));
const Span<const uint8_t> msg = args[1];

RSA *const key = GetRSAKey(bits);
const EVP_MD *const md = MDFunc();
uint8_t digest_buf[EVP_MAX_MD_SIZE];
unsigned digest_len;
if (!EVP_Digest(msg.data(), msg.size(), digest_buf, &digest_len, md, NULL)) {
EVP_PKEY *const evp_pkey = GetRSAKey(bits);
if (evp_pkey == nullptr) {
return false;
}

std::vector<uint8_t> sig(RSA_size(key));
RSA *const rsa = EVP_PKEY_get0_RSA(evp_pkey);
if (rsa == nullptr) {
return false;
}
const EVP_MD *const md = MDFunc();
std::vector<uint8_t> sig;
size_t sig_len;
if (UsePSS) {
if (!RSA_sign_pss_mgf1(key, &sig_len, sig.data(), sig.size(), digest_buf,
digest_len, md, md, -1)) {
return false;
}
} else {
unsigned sig_len_u;
if (!RSA_sign(EVP_MD_type(md), digest_buf, digest_len, sig.data(),
&sig_len_u, key)) {
return false;
}
sig_len = sig_len_u;
bssl::ScopedEVP_MD_CTX ctx;
EVP_PKEY_CTX *pctx;
int padding = UsePSS ? RSA_PKCS1_PSS_PADDING : RSA_PKCS1_PADDING;
if (!EVP_DigestSignInit(ctx.get(), &pctx, md, nullptr, evp_pkey) ||
!EVP_PKEY_CTX_set_rsa_padding(pctx, padding) ||
!EVP_DigestSign(ctx.get(), nullptr, &sig_len, msg.data(), msg.size())) {
return false;
}

sig.resize(sig_len);
if (!EVP_DigestSign(ctx.get(), sig.data(), &sig_len, msg.data(), msg.size())) {
return false;
}

return write_reply(
{BIGNUMBytes(RSA_get0_n(key)), BIGNUMBytes(RSA_get0_e(key)), sig});
{BIGNUMBytes(RSA_get0_n(rsa)), BIGNUMBytes(RSA_get0_e(rsa)), sig});
}

template <const EVP_MD *(MDFunc)(), bool UsePSS>
Expand All @@ -1761,19 +1788,20 @@ static bool RSASigVer(const Span<const uint8_t> args[], ReplyCallback write_repl
}

const EVP_MD *const md = MDFunc();
uint8_t digest_buf[EVP_MAX_MD_SIZE];
unsigned digest_len;
if (!EVP_Digest(msg.data(), msg.size(), digest_buf, &digest_len, md, NULL)) {
bssl::ScopedEVP_MD_CTX ctx;
EVP_PKEY_CTX *pctx;
bssl::UniquePtr<EVP_PKEY> evp_pkey(EVP_PKEY_new());
if (!evp_pkey || !EVP_PKEY_set1_RSA(evp_pkey.get(), key.get())) {
return false;
}

uint8_t ok;
if (UsePSS) {
ok = RSA_verify_pss_mgf1(key.get(), digest_buf, digest_len, md, md, -1,
sig.data(), sig.size());
int padding = UsePSS ? RSA_PKCS1_PSS_PADDING : RSA_PKCS1_PADDING;
if (!EVP_DigestVerifyInit(ctx.get(), &pctx, md, nullptr, evp_pkey.get()) ||
!EVP_PKEY_CTX_set_rsa_padding(pctx, padding) ||
!EVP_DigestVerify(ctx.get(), sig.data(), sig.size(), msg.data(), msg.size())) {
ok = 0;
} else {
ok = RSA_verify(EVP_MD_type(md), digest_buf, digest_len, sig.data(),
sig.size(), key.get());
ok = 1;
}
ERR_clear_error();

Expand Down

0 comments on commit 7e7f06c

Please sign in to comment.