Skip to content

Commit 0d3089f

Browse files
committed
optional and mandatory claims check
Signed-off-by: Ferenc Kemeny <[email protected]>
1 parent 9df3a57 commit 0d3089f

File tree

6 files changed

+336
-39
lines changed

6 files changed

+336
-39
lines changed

oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/JwtAudienceValidator.java

+14-3
Original file line numberDiff line numberDiff line change
@@ -33,18 +33,29 @@ public final class JwtAudienceValidator implements OAuth2TokenValidator<Jwt> {
3333
private final JwtClaimValidator<Collection<String>> validator;
3434

3535
/**
36-
* Constructs a {@link JwtAudienceValidator} using the provided parameters
36+
* Constructs a {@link JwtAudienceValidator} using the provided parameters with
37+
* {@link JwtClaimNames#ISS "iss"} claim is REQUIRED
3738
* @param audience - The audience that each {@link Jwt} should have.
3839
*/
3940
public JwtAudienceValidator(String audience) {
41+
this(audience, true);
42+
}
43+
44+
/**
45+
* Constructs a {@link JwtIssuerValidator} using the provided parameters
46+
* @param audience - The audience that each {@link Jwt} should have.
47+
* @param required -{@code true} if the {@link JwtClaimNames#AUD "aud"} claim is
48+
* REQUIRED in the {@link Jwt}, {@code false} otherwise
49+
*/
50+
public JwtAudienceValidator(String audience, boolean required) {
4051
Assert.notNull(audience, "audience cannot be null");
4152
this.validator = new JwtClaimValidator<>(JwtClaimNames.AUD,
42-
(claimValue) -> (claimValue != null) && claimValue.contains(audience));
53+
(claimValue) -> (claimValue != null) ? claimValue.contains(audience) : !required);
4354
}
4455

4556
@Override
4657
public OAuth2TokenValidatorResult validate(Jwt token) {
47-
Assert.notNull(token, "token cannot be null");
58+
Assert.notNull(token, "jwt cannot be null");
4859
return this.validator.validate(token);
4960
}
5061

oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/JwtIssuerValidator.java

+15-7
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,6 @@
1616

1717
package org.springframework.security.oauth2.jwt;
1818

19-
import java.util.function.Predicate;
20-
2119
import org.springframework.security.oauth2.core.OAuth2TokenValidator;
2220
import org.springframework.security.oauth2.core.OAuth2TokenValidatorResult;
2321
import org.springframework.util.Assert;
@@ -33,19 +31,29 @@ public final class JwtIssuerValidator implements OAuth2TokenValidator<Jwt> {
3331
private final JwtClaimValidator<Object> validator;
3432

3533
/**
36-
* Constructs a {@link JwtIssuerValidator} using the provided parameters
34+
* Constructs a {@link JwtIssuerValidator} using the provided parameters with
35+
* {@link JwtClaimNames#ISS "iss"} claim is REQUIRED
3736
* @param issuer - The issuer that each {@link Jwt} should have.
3837
*/
3938
public JwtIssuerValidator(String issuer) {
40-
Assert.notNull(issuer, "issuer cannot be null");
39+
this(issuer, true);
40+
}
4141

42-
Predicate<Object> testClaimValue = (claimValue) -> (claimValue != null) && issuer.equals(claimValue.toString());
43-
this.validator = new JwtClaimValidator<>(JwtClaimNames.ISS, testClaimValue);
42+
/**
43+
* Constructs a {@link JwtIssuerValidator} using the provided parameters
44+
* @param issuer - The issuer that each {@link Jwt} should have.
45+
* @param required -{@code true} if the {@link JwtClaimNames#ISS "iss"} claim is
46+
* REQUIRED in the {@link Jwt}, {@code false} otherwise
47+
*/
48+
public JwtIssuerValidator(String issuer, boolean required) {
49+
Assert.notNull(issuer, "issuer cannot be null");
50+
this.validator = new JwtClaimValidator<>(JwtClaimNames.ISS,
51+
(claimValue) -> (claimValue != null) ? issuer.equals(claimValue.toString()) : !required);
4452
}
4553

4654
@Override
4755
public OAuth2TokenValidatorResult validate(Jwt token) {
48-
Assert.notNull(token, "token cannot be null");
56+
Assert.notNull(token, "jwt cannot be null");
4957
return this.validator.validate(token);
5058
}
5159

oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/JwtTimestampValidator.java

+17-2
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,8 @@ public final class JwtTimestampValidator implements OAuth2TokenValidator<Jwt> {
5252

5353
private static final Duration DEFAULT_MAX_CLOCK_SKEW = Duration.of(60, ChronoUnit.SECONDS);
5454

55+
private final boolean required;
56+
5557
private final Duration clockSkew;
5658

5759
private Clock clock = Clock.systemUTC();
@@ -60,25 +62,38 @@ public final class JwtTimestampValidator implements OAuth2TokenValidator<Jwt> {
6062
* A basic instance with no custom verification and the default max clock skew
6163
*/
6264
public JwtTimestampValidator() {
63-
this(DEFAULT_MAX_CLOCK_SKEW);
65+
this(DEFAULT_MAX_CLOCK_SKEW, false);
66+
}
67+
68+
public JwtTimestampValidator(boolean required) {
69+
this(DEFAULT_MAX_CLOCK_SKEW, required);
6470
}
6571

6672
public JwtTimestampValidator(Duration clockSkew) {
73+
this(clockSkew, false);
74+
}
75+
76+
public JwtTimestampValidator(Duration clockSkew, boolean required) {
6777
Assert.notNull(clockSkew, "clockSkew cannot be null");
78+
this.required = required;
6879
this.clockSkew = clockSkew;
6980
}
7081

7182
@Override
7283
public OAuth2TokenValidatorResult validate(Jwt jwt) {
7384
Assert.notNull(jwt, "jwt cannot be null");
7485
Instant expiry = jwt.getExpiresAt();
86+
Instant notBefore = jwt.getNotBefore();
87+
if (this.required && !(expiry != null || notBefore != null)) {
88+
OAuth2Error oAuth2Error = createOAuth2Error("exp and nbf are required");
89+
return OAuth2TokenValidatorResult.failure(oAuth2Error);
90+
}
7591
if (expiry != null) {
7692
if (Instant.now(this.clock).minus(this.clockSkew).isAfter(expiry)) {
7793
OAuth2Error oAuth2Error = createOAuth2Error(String.format("Jwt expired at %s", jwt.getExpiresAt()));
7894
return OAuth2TokenValidatorResult.failure(oAuth2Error);
7995
}
8096
}
81-
Instant notBefore = jwt.getNotBefore();
8297
if (notBefore != null) {
8398
if (Instant.now(this.clock).plus(this.clockSkew).isBefore(notBefore)) {
8499
OAuth2Error oAuth2Error = createOAuth2Error(String.format("Jwt used before %s", jwt.getNotBefore()));

oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/JwtAudienceValidatorTests.java

+53-7
Original file line numberDiff line numberDiff line change
@@ -31,27 +31,73 @@
3131
*/
3232
class JwtAudienceValidatorTests {
3333

34-
private final JwtAudienceValidator validator = new JwtAudienceValidator("audience");
34+
private final JwtAudienceValidator validatorDefault = new JwtAudienceValidator("audience");
35+
36+
private final JwtAudienceValidator validatorRequiredTrue = new JwtAudienceValidator("audience", true);
37+
38+
private final JwtAudienceValidator validatorRequiredFalse = new JwtAudienceValidator("audience", false);
39+
40+
@Test
41+
void givenRequiredDefaultJwtWithMatchingAudienceThenShouldValidate() {
42+
Jwt jwt = TestJwts.jwt().audience(List.of("audience")).build();
43+
OAuth2TokenValidatorResult result = this.validatorDefault.validate(jwt);
44+
assertThat(result).isEqualTo(OAuth2TokenValidatorResult.success());
45+
}
46+
47+
@Test
48+
void givenRequiredJwtWithMatchingAudienceThenShouldValidate() {
49+
Jwt jwt = TestJwts.jwt().audience(List.of("audience")).build();
50+
OAuth2TokenValidatorResult result = this.validatorRequiredTrue.validate(jwt);
51+
assertThat(result).isEqualTo(OAuth2TokenValidatorResult.success());
52+
}
3553

3654
@Test
37-
void givenJwtWithMatchingAudienceThenShouldValidate() {
55+
void givenNotRequiredJwtWithMatchingAudienceThenShouldValidate() {
3856
Jwt jwt = TestJwts.jwt().audience(List.of("audience")).build();
39-
OAuth2TokenValidatorResult result = this.validator.validate(jwt);
57+
OAuth2TokenValidatorResult result = this.validatorRequiredFalse.validate(jwt);
4058
assertThat(result).isEqualTo(OAuth2TokenValidatorResult.success());
4159
}
4260

4361
@Test
44-
void givenJwtWithoutMatchingAudienceThenShouldValidate() {
62+
void givenRequiredDefaultJwtWithoutMatchingAudienceThenShouldValidate() {
4563
Jwt jwt = TestJwts.jwt().audience(List.of("other")).build();
46-
OAuth2TokenValidatorResult result = this.validator.validate(jwt);
64+
OAuth2TokenValidatorResult result = this.validatorDefault.validate(jwt);
4765
assertThat(result.hasErrors()).isTrue();
4866
}
4967

5068
@Test
51-
void givenJwtWithoutAudienceThenShouldValidate() {
69+
void givenRequiredJwtWithoutMatchingAudienceThenShouldValidate() {
70+
Jwt jwt = TestJwts.jwt().audience(List.of("other")).build();
71+
OAuth2TokenValidatorResult result = this.validatorRequiredTrue.validate(jwt);
72+
assertThat(result.hasErrors()).isTrue();
73+
}
74+
75+
@Test
76+
void givenNotRequiredJwtWithoutMatchingAudienceThenShouldValidate() {
77+
Jwt jwt = TestJwts.jwt().audience(List.of("other")).build();
78+
OAuth2TokenValidatorResult result = this.validatorRequiredFalse.validate(jwt);
79+
assertThat(result.hasErrors()).isTrue();
80+
}
81+
82+
@Test
83+
void givenRequiredDefaultJwtWithoutAudienceThenShouldValidate() {
5284
Jwt jwt = TestJwts.jwt().audience(null).build();
53-
OAuth2TokenValidatorResult result = this.validator.validate(jwt);
85+
OAuth2TokenValidatorResult result = this.validatorDefault.validate(jwt);
5486
assertThat(result.hasErrors()).isTrue();
5587
}
5688

89+
@Test
90+
void givenRequiredJwtWithoutAudienceThenShouldValidate() {
91+
Jwt jwt = TestJwts.jwt().audience(null).build();
92+
OAuth2TokenValidatorResult result = this.validatorRequiredTrue.validate(jwt);
93+
assertThat(result.hasErrors()).isTrue();
94+
}
95+
96+
@Test
97+
void givenNotRequiredJwtWithoutAudienceThenShouldValidate() {
98+
Jwt jwt = TestJwts.jwt().audience(null).build();
99+
OAuth2TokenValidatorResult result = this.validatorRequiredFalse.validate(jwt);
100+
assertThat(result.hasErrors()).isFalse();
101+
}
102+
57103
}

0 commit comments

Comments
 (0)