Skip to content

Commit

Permalink
Eliminate redundant calls to CredentialRepository.lookup
Browse files Browse the repository at this point in the history
  • Loading branch information
emlun committed Nov 23, 2021
1 parent ddb1af3 commit e91d042
Show file tree
Hide file tree
Showing 3 changed files with 152 additions and 52 deletions.
8 changes: 8 additions & 0 deletions NEWS
Original file line number Diff line number Diff line change
@@ -1,3 +1,11 @@
== Version 1.12.1 (unreleased) ==

Fixes:

* `RelyingParty.finishAssertion()` no longer makes multiple (redundant) calls to
`CredentialRepository.lookup()`.


== Version 1.12.0 ==

New features:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -117,9 +117,27 @@ default AssertionResult run() throws InvalidSignatureCountException {

@Value
class Step0 implements Step<Step1> {

private final Optional<ByteArray> userHandle =
response
.getResponse()
.getUserHandle()
.map(Optional::of)
.orElseGet(
() -> credentialRepository.getUserHandleForUsername(request.getUsername().get()));

private final Optional<String> username =
request
.getUsername()
.map(Optional::of)
.orElseGet(
() ->
credentialRepository.getUsernameForUserHandle(
response.getResponse().getUserHandle().get()));

@Override
public Step1 nextStep() {
return new Step1(username().get(), userHandle().get(), allWarnings());
return new Step1(username.get(), userHandle.get(), allWarnings());
}

@Override
Expand All @@ -128,12 +146,12 @@ public void validate() {
request.getUsername().isPresent() || response.getResponse().getUserHandle().isPresent(),
"At least one of username and user handle must be given; none was.");
assure(
userHandle().isPresent(),
userHandle.isPresent(),
"No user found for username: %s, userHandle: %s",
request.getUsername(),
response.getResponse().getUserHandle());
assure(
username().isPresent(),
username.isPresent(),
"No user found for username: %s, userHandle: %s",
request.getUsername(),
response.getResponse().getUserHandle());
Expand All @@ -143,25 +161,6 @@ public void validate() {
public List<String> getPrevWarnings() {
return Collections.emptyList();
}

private Optional<ByteArray> userHandle() {
return response
.getResponse()
.getUserHandle()
.map(Optional::of)
.orElseGet(
() -> credentialRepository.getUserHandleForUsername(request.getUsername().get()));
}

private Optional<String> username() {
return request
.getUsername()
.map(Optional::of)
.orElseGet(
() ->
credentialRepository.getUsernameForUserHandle(
response.getResponse().getUserHandle().get()));
}
}

@Value
Expand Down Expand Up @@ -196,16 +195,22 @@ class Step2 implements Step<Step3> {
private final ByteArray userHandle;
private final List<String> prevWarnings;

private final Optional<RegisteredCredential> registration;

public Step2(String username, ByteArray userHandle, List<String> prevWarnings) {
this.username = username;
this.userHandle = userHandle;
this.prevWarnings = prevWarnings;
this.registration = credentialRepository.lookup(response.getId(), userHandle);
}

@Override
public Step3 nextStep() {
return new Step3(username, userHandle, allWarnings());
return new Step3(username, userHandle, registration, allWarnings());
}

@Override
public void validate() {
Optional<RegisteredCredential> registration =
credentialRepository.lookup(response.getId(), userHandle);

assure(registration.isPresent(), "Unknown credential: %s", response.getId());

assure(
Expand All @@ -220,29 +225,22 @@ public void validate() {
class Step3 implements Step<Step4> {
private final String username;
private final ByteArray userHandle;
private final Optional<RegisteredCredential> credential;
private final List<String> prevWarnings;

@Override
public Step4 nextStep() {
return new Step4(username, userHandle, credential(), allWarnings());
return new Step4(username, userHandle, credential.get(), allWarnings());
}

@Override
public void validate() {
assure(
maybeCredential().isPresent(),
credential.isPresent(),
"Unknown credential. Credential ID: %s, user handle: %s",
response.getId(),
userHandle);
}

private Optional<RegisteredCredential> maybeCredential() {
return credentialRepository.lookup(response.getId(), userHandle);
}

public RegisteredCredential credential() {
return maybeCredential().get();
}
}

@Value
Expand Down Expand Up @@ -581,7 +579,7 @@ public void validate() {

@Override
public Step17 nextStep() {
return new Step17(username, userHandle, allWarnings());
return new Step17(username, userHandle, credential, allWarnings());
}

public ByteArray signedBytes() {
Expand All @@ -593,19 +591,33 @@ public ByteArray signedBytes() {
class Step17 implements Step<Finished> {
private final String username;
private final ByteArray userHandle;
private final RegisteredCredential credential;
private final List<String> prevWarnings;
private final long storedSignatureCountBefore;

public Step17(
String username,
ByteArray userHandle,
RegisteredCredential credential,
List<String> prevWarnings) {
this.username = username;
this.userHandle = userHandle;
this.credential = credential;
this.prevWarnings = prevWarnings;
this.storedSignatureCountBefore = credential.getSignatureCount();
}

@Override
public void validate() throws InvalidSignatureCountException {
if (validateSignatureCounter && !signatureCounterValid()) {
throw new InvalidSignatureCountException(
response.getId(), storedSignatureCountBefore() + 1, assertionSignatureCount());
response.getId(), storedSignatureCountBefore + 1, assertionSignatureCount());
}
}

private boolean signatureCounterValid() {
return (assertionSignatureCount() == 0 && storedSignatureCountBefore() == 0)
|| assertionSignatureCount() > storedSignatureCountBefore();
return (assertionSignatureCount() == 0 && storedSignatureCountBefore == 0)
|| assertionSignatureCount() > storedSignatureCountBefore;
}

@Override
Expand All @@ -614,13 +626,6 @@ public Finished nextStep() {
username, userHandle, assertionSignatureCount(), signatureCounterValid(), allWarnings());
}

private long storedSignatureCountBefore() {
return credentialRepository
.lookup(response.getId(), userHandle)
.map(RegisteredCredential::getSignatureCount)
.orElse(0L);
}

private long assertionSignatureCount() {
return response.getResponse().getParsedAuthenticatorData().getSignatureCounter();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -329,9 +329,98 @@ class RelyingPartyAssertionSpec

}

describe("§7.2. Verifying an authentication assertion") {
describe("RelyingParty.finishAssertion") {

describe("When verifying a given PublicKeyCredential structure (credential) and an AuthenticationExtensionsClientOutputs structure clientExtensionResults, as part of an authentication ceremony, the Relying Party MUST proceed as follows:") {
it("does not make redundant calls to CredentialRepository.lookup().") {
val registrationTestData =
RegistrationTestData.Packed.BasicAttestationEdDsa
val testData = registrationTestData.assertion.get

val credRepo = {
val user = registrationTestData.userId
val credential = RegisteredCredential
.builder()
.credentialId(registrationTestData.response.getId)
.userHandle(registrationTestData.userId.getId)
.publicKeyCose(
registrationTestData.response.getResponse.getParsedAuthenticatorData.getAttestedCredentialData.get.getCredentialPublicKey
)
.signatureCount(0)
.build()

new CredentialRepository {
var lookupCount = 0

override def getCredentialIdsForUsername(
username: String
): java.util.Set[PublicKeyCredentialDescriptor] =
if (username == user.getName)
Set(
PublicKeyCredentialDescriptor
.builder()
.id(credential.getCredentialId)
.build()
).asJava
else Set.empty.asJava

override def getUserHandleForUsername(
username: String
): Optional[ByteArray] =
if (username == user.getName)
Some(user.getId).asJava
else None.asJava

override def getUsernameForUserHandle(
userHandle: ByteArray
): Optional[String] =
if (userHandle == user.getId)
Some(user.getName).asJava
else None.asJava

override def lookup(
credentialId: ByteArray,
userHandle: ByteArray,
): Optional[RegisteredCredential] = {
lookupCount += 1
if (
credentialId == credential.getCredentialId && userHandle == user.getId
)
Some(credential).asJava
else None.asJava
}

override def lookupAll(
credentialId: ByteArray
): java.util.Set[RegisteredCredential] =
if (credentialId == credential.getCredentialId)
Set(credential).asJava
else Set.empty.asJava
}
}
val rp = RelyingParty
.builder()
.identity(
RelyingPartyIdentity.builder().id("localhost").name("Test RP").build()
)
.credentialRepository(credRepo)
.build()

val result = rp.finishAssertion(
FinishAssertionOptions
.builder()
.request(testData.request)
.response(testData.response)
.build()
)

result.isSuccess should be(true)
result.getUserHandle should equal(registrationTestData.userId.getId)
result.getCredentialId should equal(registrationTestData.response.getId)
result.getCredentialId should equal(testData.response.getId)
credRepo.lookupCount should equal(1)
}

describe("§7.2. Verifying an authentication assertion: When verifying a given PublicKeyCredential structure (credential) and an AuthenticationExtensionsClientOutputs structure clientExtensionResults, as part of an authentication ceremony, the Relying Party MUST proceed as follows:") {

describe("1. If the allowCredentials option was given when this authentication ceremony was initiated, verify that credential.id identifies one of the public key credentials that were listed in allowCredentials.") {
it("Fails if returned credential ID is a requested one.") {
Expand Down Expand Up @@ -497,6 +586,7 @@ class RelyingPartyAssertionSpec
val step: steps.Step3 = new steps.Step3(
Defaults.username,
Defaults.userHandle,
None.asJava,
Nil.asJava,
)

Expand All @@ -523,9 +613,6 @@ class RelyingPartyAssertionSpec
val step: FinishAssertionSteps#Step3 = steps.begin.next.next.next

step.validations shouldBe a[Success[_]]
step.credential.getPublicKeyCose should equal(
getPublicKeyBytes(Defaults.credentialKey)
)
step.tryNext shouldBe a[Success[_]]
}
}
Expand Down

0 comments on commit e91d042

Please sign in to comment.