From e91d0420411b332a04c4544dddbdcea63fb60e70 Mon Sep 17 00:00:00 2001 From: Emil Lundberg Date: Tue, 23 Nov 2021 16:58:57 +0100 Subject: [PATCH] Eliminate redundant calls to CredentialRepository.lookup --- NEWS | 8 ++ .../yubico/webauthn/FinishAssertionSteps.java | 99 ++++++++++--------- .../webauthn/RelyingPartyAssertionSpec.scala | 97 +++++++++++++++++- 3 files changed, 152 insertions(+), 52 deletions(-) diff --git a/NEWS b/NEWS index f77784cb3..9f398ea7a 100644 --- a/NEWS +++ b/NEWS @@ -1,3 +1,11 @@ +== Version 1.12.1 (unreleased) == + +Fixes: + +* `RelyingParty.finishAssertion()` no longer makes multiple (redundant) calls to + `CredentialRepository.lookup()`. + + == Version 1.12.0 == New features: diff --git a/webauthn-server-core/src/main/java/com/yubico/webauthn/FinishAssertionSteps.java b/webauthn-server-core/src/main/java/com/yubico/webauthn/FinishAssertionSteps.java index 667862bf4..4f5ba22a9 100644 --- a/webauthn-server-core/src/main/java/com/yubico/webauthn/FinishAssertionSteps.java +++ b/webauthn-server-core/src/main/java/com/yubico/webauthn/FinishAssertionSteps.java @@ -117,9 +117,27 @@ default AssertionResult run() throws InvalidSignatureCountException { @Value class Step0 implements Step { + + private final Optional userHandle = + response + .getResponse() + .getUserHandle() + .map(Optional::of) + .orElseGet( + () -> credentialRepository.getUserHandleForUsername(request.getUsername().get())); + + private final Optional username = + request + .getUsername() + .map(Optional::of) + .orElseGet( + () -> + credentialRepository.getUsernameForUserHandle( + response.getResponse().getUserHandle().get())); + @Override public Step1 nextStep() { - return new Step1(username().get(), userHandle().get(), allWarnings()); + return new Step1(username.get(), userHandle.get(), allWarnings()); } @Override @@ -128,12 +146,12 @@ public void validate() { request.getUsername().isPresent() || response.getResponse().getUserHandle().isPresent(), "At least one of username and user handle must be given; none was."); assure( - userHandle().isPresent(), + userHandle.isPresent(), "No user found for username: %s, userHandle: %s", request.getUsername(), response.getResponse().getUserHandle()); assure( - username().isPresent(), + username.isPresent(), "No user found for username: %s, userHandle: %s", request.getUsername(), response.getResponse().getUserHandle()); @@ -143,25 +161,6 @@ public void validate() { public List getPrevWarnings() { return Collections.emptyList(); } - - private Optional userHandle() { - return response - .getResponse() - .getUserHandle() - .map(Optional::of) - .orElseGet( - () -> credentialRepository.getUserHandleForUsername(request.getUsername().get())); - } - - private Optional username() { - return request - .getUsername() - .map(Optional::of) - .orElseGet( - () -> - credentialRepository.getUsernameForUserHandle( - response.getResponse().getUserHandle().get())); - } } @Value @@ -196,16 +195,22 @@ class Step2 implements Step { private final ByteArray userHandle; private final List prevWarnings; + private final Optional registration; + + public Step2(String username, ByteArray userHandle, List prevWarnings) { + this.username = username; + this.userHandle = userHandle; + this.prevWarnings = prevWarnings; + this.registration = credentialRepository.lookup(response.getId(), userHandle); + } + @Override public Step3 nextStep() { - return new Step3(username, userHandle, allWarnings()); + return new Step3(username, userHandle, registration, allWarnings()); } @Override public void validate() { - Optional registration = - credentialRepository.lookup(response.getId(), userHandle); - assure(registration.isPresent(), "Unknown credential: %s", response.getId()); assure( @@ -220,29 +225,22 @@ public void validate() { class Step3 implements Step { private final String username; private final ByteArray userHandle; + private final Optional credential; private final List prevWarnings; @Override public Step4 nextStep() { - return new Step4(username, userHandle, credential(), allWarnings()); + return new Step4(username, userHandle, credential.get(), allWarnings()); } @Override public void validate() { assure( - maybeCredential().isPresent(), + credential.isPresent(), "Unknown credential. Credential ID: %s, user handle: %s", response.getId(), userHandle); } - - private Optional maybeCredential() { - return credentialRepository.lookup(response.getId(), userHandle); - } - - public RegisteredCredential credential() { - return maybeCredential().get(); - } } @Value @@ -581,7 +579,7 @@ public void validate() { @Override public Step17 nextStep() { - return new Step17(username, userHandle, allWarnings()); + return new Step17(username, userHandle, credential, allWarnings()); } public ByteArray signedBytes() { @@ -593,19 +591,33 @@ public ByteArray signedBytes() { class Step17 implements Step { private final String username; private final ByteArray userHandle; + private final RegisteredCredential credential; private final List prevWarnings; + private final long storedSignatureCountBefore; + + public Step17( + String username, + ByteArray userHandle, + RegisteredCredential credential, + List prevWarnings) { + this.username = username; + this.userHandle = userHandle; + this.credential = credential; + this.prevWarnings = prevWarnings; + this.storedSignatureCountBefore = credential.getSignatureCount(); + } @Override public void validate() throws InvalidSignatureCountException { if (validateSignatureCounter && !signatureCounterValid()) { throw new InvalidSignatureCountException( - response.getId(), storedSignatureCountBefore() + 1, assertionSignatureCount()); + response.getId(), storedSignatureCountBefore + 1, assertionSignatureCount()); } } private boolean signatureCounterValid() { - return (assertionSignatureCount() == 0 && storedSignatureCountBefore() == 0) - || assertionSignatureCount() > storedSignatureCountBefore(); + return (assertionSignatureCount() == 0 && storedSignatureCountBefore == 0) + || assertionSignatureCount() > storedSignatureCountBefore; } @Override @@ -614,13 +626,6 @@ public Finished nextStep() { username, userHandle, assertionSignatureCount(), signatureCounterValid(), allWarnings()); } - private long storedSignatureCountBefore() { - return credentialRepository - .lookup(response.getId(), userHandle) - .map(RegisteredCredential::getSignatureCount) - .orElse(0L); - } - private long assertionSignatureCount() { return response.getResponse().getParsedAuthenticatorData().getSignatureCounter(); } diff --git a/webauthn-server-core/src/test/scala/com/yubico/webauthn/RelyingPartyAssertionSpec.scala b/webauthn-server-core/src/test/scala/com/yubico/webauthn/RelyingPartyAssertionSpec.scala index e525c3190..949830f49 100644 --- a/webauthn-server-core/src/test/scala/com/yubico/webauthn/RelyingPartyAssertionSpec.scala +++ b/webauthn-server-core/src/test/scala/com/yubico/webauthn/RelyingPartyAssertionSpec.scala @@ -329,9 +329,98 @@ class RelyingPartyAssertionSpec } - describe("§7.2. Verifying an authentication assertion") { + describe("RelyingParty.finishAssertion") { - describe("When verifying a given PublicKeyCredential structure (credential) and an AuthenticationExtensionsClientOutputs structure clientExtensionResults, as part of an authentication ceremony, the Relying Party MUST proceed as follows:") { + it("does not make redundant calls to CredentialRepository.lookup().") { + val registrationTestData = + RegistrationTestData.Packed.BasicAttestationEdDsa + val testData = registrationTestData.assertion.get + + val credRepo = { + val user = registrationTestData.userId + val credential = RegisteredCredential + .builder() + .credentialId(registrationTestData.response.getId) + .userHandle(registrationTestData.userId.getId) + .publicKeyCose( + registrationTestData.response.getResponse.getParsedAuthenticatorData.getAttestedCredentialData.get.getCredentialPublicKey + ) + .signatureCount(0) + .build() + + new CredentialRepository { + var lookupCount = 0 + + override def getCredentialIdsForUsername( + username: String + ): java.util.Set[PublicKeyCredentialDescriptor] = + if (username == user.getName) + Set( + PublicKeyCredentialDescriptor + .builder() + .id(credential.getCredentialId) + .build() + ).asJava + else Set.empty.asJava + + override def getUserHandleForUsername( + username: String + ): Optional[ByteArray] = + if (username == user.getName) + Some(user.getId).asJava + else None.asJava + + override def getUsernameForUserHandle( + userHandle: ByteArray + ): Optional[String] = + if (userHandle == user.getId) + Some(user.getName).asJava + else None.asJava + + override def lookup( + credentialId: ByteArray, + userHandle: ByteArray, + ): Optional[RegisteredCredential] = { + lookupCount += 1 + if ( + credentialId == credential.getCredentialId && userHandle == user.getId + ) + Some(credential).asJava + else None.asJava + } + + override def lookupAll( + credentialId: ByteArray + ): java.util.Set[RegisteredCredential] = + if (credentialId == credential.getCredentialId) + Set(credential).asJava + else Set.empty.asJava + } + } + val rp = RelyingParty + .builder() + .identity( + RelyingPartyIdentity.builder().id("localhost").name("Test RP").build() + ) + .credentialRepository(credRepo) + .build() + + val result = rp.finishAssertion( + FinishAssertionOptions + .builder() + .request(testData.request) + .response(testData.response) + .build() + ) + + result.isSuccess should be(true) + result.getUserHandle should equal(registrationTestData.userId.getId) + result.getCredentialId should equal(registrationTestData.response.getId) + result.getCredentialId should equal(testData.response.getId) + credRepo.lookupCount should equal(1) + } + + describe("§7.2. Verifying an authentication assertion: When verifying a given PublicKeyCredential structure (credential) and an AuthenticationExtensionsClientOutputs structure clientExtensionResults, as part of an authentication ceremony, the Relying Party MUST proceed as follows:") { describe("1. If the allowCredentials option was given when this authentication ceremony was initiated, verify that credential.id identifies one of the public key credentials that were listed in allowCredentials.") { it("Fails if returned credential ID is a requested one.") { @@ -497,6 +586,7 @@ class RelyingPartyAssertionSpec val step: steps.Step3 = new steps.Step3( Defaults.username, Defaults.userHandle, + None.asJava, Nil.asJava, ) @@ -523,9 +613,6 @@ class RelyingPartyAssertionSpec val step: FinishAssertionSteps#Step3 = steps.begin.next.next.next step.validations shouldBe a[Success[_]] - step.credential.getPublicKeyCose should equal( - getPublicKeyBytes(Defaults.credentialKey) - ) step.tryNext shouldBe a[Success[_]] } }