diff --git a/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/NimbusJwtEncoder.java b/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/NimbusJwtEncoder.java index fb0468fa9bc..07720a45d88 100644 --- a/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/NimbusJwtEncoder.java +++ b/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/NimbusJwtEncoder.java @@ -18,6 +18,8 @@ import java.net.URI; import java.net.URL; +import java.security.KeyPair; +import java.security.interfaces.ECPublicKey; import java.time.Instant; import java.util.ArrayList; import java.util.Date; @@ -25,19 +27,28 @@ import java.util.List; import java.util.Map; import java.util.Set; +import java.util.UUID; import java.util.concurrent.ConcurrentHashMap; +import javax.crypto.SecretKey; + import com.nimbusds.jose.JOSEException; import com.nimbusds.jose.JOSEObjectType; import com.nimbusds.jose.JWSAlgorithm; import com.nimbusds.jose.JWSHeader; import com.nimbusds.jose.JWSSigner; import com.nimbusds.jose.crypto.factories.DefaultJWSSignerFactory; +import com.nimbusds.jose.jwk.Curve; +import com.nimbusds.jose.jwk.ECKey; import com.nimbusds.jose.jwk.JWK; import com.nimbusds.jose.jwk.JWKMatcher; import com.nimbusds.jose.jwk.JWKSelector; +import com.nimbusds.jose.jwk.JWKSet; import com.nimbusds.jose.jwk.KeyType; import com.nimbusds.jose.jwk.KeyUse; +import com.nimbusds.jose.jwk.OctetSequenceKey; +import com.nimbusds.jose.jwk.RSAKey; +import com.nimbusds.jose.jwk.source.ImmutableJWKSet; import com.nimbusds.jose.jwk.source.JWKSource; import com.nimbusds.jose.proc.SecurityContext; import com.nimbusds.jose.produce.JWSSignerFactory; @@ -47,6 +58,7 @@ import com.nimbusds.jwt.SignedJWT; import org.springframework.core.convert.converter.Converter; +import org.springframework.security.oauth2.jose.jws.MacAlgorithm; import org.springframework.security.oauth2.jose.jws.SignatureAlgorithm; import org.springframework.util.Assert; import org.springframework.util.CollectionUtils; @@ -83,6 +95,8 @@ public final class NimbusJwtEncoder implements JwtEncoder { private static final JWSSignerFactory JWS_SIGNER_FACTORY = new DefaultJWSSignerFactory(); + private JwsHeader jwsHeader; + private final Map jwsSigners = new ConcurrentHashMap<>(); private final JWKSource jwkSource; @@ -119,14 +133,16 @@ public void setJwkSelector(Converter, JWK> jwkSelector) { this.jwkSelector = jwkSelector; } + public void setJwsHeader(JwsHeader jwsHeader) { + this.jwsHeader = jwsHeader; + } + @Override public Jwt encode(JwtEncoderParameters parameters) throws JwtEncodingException { Assert.notNull(parameters, "parameters cannot be null"); JwsHeader headers = parameters.getJwsHeader(); - if (headers == null) { - headers = DEFAULT_JWS_HEADER; - } + headers = (headers != null) ? headers : (this.jwsHeader != null) ? this.jwsHeader : DEFAULT_JWS_HEADER; JwtClaimsSet claims = parameters.getClaims(); JWK jwk = selectJwk(headers); @@ -369,4 +385,250 @@ private static URI convertAsURI(String header, URL url) { } } + /** + * Creates a builder for constructing a {@link NimbusJwtEncoder} using the provided + * {@link SecretKey}. + * @param secretKey the {@link SecretKey} to use for signing JWTs + * @return a {@link SecretKeyJwtEncoderBuilder} for further configuration + * @since 7.0 + */ + public static SecretKeyJwtEncoderBuilder withSecretKey(SecretKey secretKey) { + Assert.notNull(secretKey, "secretKey cannot be null"); + return new SecretKeyJwtEncoderBuilder(secretKey); + } + + /** + * Creates a builder for constructing a {@link NimbusJwtEncoder} using the provided + * {@link KeyPair}. The key pair must contain either an {@link RSAKey} or an + * {@link ECKey}. + * @param keyPair the {@link KeyPair} to use for signing JWTs + * @return a {@link KeyPairJwtEncoderBuilder} for further configuration + * @since 7.0 + */ + public static KeyPairJwtEncoderBuilder withKeyPair(KeyPair keyPair) { + Assert.isTrue(keyPair != null && keyPair.getPrivate() != null && keyPair.getPublic() != null, + "keyPair, its private key, and public key must not be null"); + if (keyPair.getPrivate() instanceof java.security.interfaces.RSAKey) { + return new RsaKeyPairJwtEncoderBuilder(keyPair); + } + if (keyPair.getPrivate() instanceof java.security.interfaces.ECKey) { + return new EcKeyPairJwtEncoderBuilder(keyPair); + } + throw new IllegalArgumentException("keyPair must be an RSAKey or an ECKey"); + } + + /** + * A builder for creating {@link NimbusJwtEncoder} instances configured with a + * {@link SecretKey}. + * + * @since 7.0 + */ + public static final class SecretKeyJwtEncoderBuilder { + + private final SecretKey secretKey; + + private String keyId; + + private JWSAlgorithm jwsAlgorithm = JWSAlgorithm.HS256; + + private SecretKeyJwtEncoderBuilder(SecretKey secretKey) { + this.secretKey = secretKey; + } + + /** + * Sets the JWS algorithm to use for signing. Defaults to + * {@link JWSAlgorithm#HS256}. Must be an HMAC-based algorithm (HS256, HS384, or + * HS512). + * @param macAlgorithm the {@link MacAlgorithm} to use + * @return this builder instance for method chaining + */ + public SecretKeyJwtEncoderBuilder macAlgorithm(MacAlgorithm macAlgorithm) { + Assert.notNull(macAlgorithm, "macAlgorithm cannot be null"); + Assert.state(JWSAlgorithm.Family.HMAC_SHA.contains(this.jwsAlgorithm), + () -> "The algorithm '" + this.jwsAlgorithm + "' is not compatible with a SecretKey. " + + "Please use one of the HS256, HS384, or HS512 algorithms."); + + this.jwsAlgorithm = JWSAlgorithm.parse(macAlgorithm.getName()); + return this; + } + + /** + * Sets the key ID ({@code kid}) to be included in the JWK and potentially the JWS + * header. + * @param keyId the key identifier + * @return this builder instance for method chaining + */ + public SecretKeyJwtEncoderBuilder keyId(String keyId) { + this.keyId = keyId; + return this; + } + + /** + * Builds the {@link NimbusJwtEncoder} instance. + * @return the configured {@link NimbusJwtEncoder} + * @throws IllegalStateException if the configured JWS algorithm is not compatible + * with a {@link SecretKey}. + */ + public NimbusJwtEncoder build() { + this.jwsAlgorithm = (this.jwsAlgorithm != null) ? this.jwsAlgorithm : JWSAlgorithm.HS256; + + OctetSequenceKey.Builder builder = new OctetSequenceKey.Builder(this.secretKey).keyUse(KeyUse.SIGNATURE) + .algorithm(this.jwsAlgorithm) + .keyID(this.keyId); + + OctetSequenceKey jwk = builder.build(); + JWKSource jwkSource = new ImmutableJWKSet<>(new JWKSet(jwk)); + NimbusJwtEncoder encoder = new NimbusJwtEncoder(jwkSource); + encoder.setJwsHeader(JwsHeader.with(MacAlgorithm.from(this.jwsAlgorithm.getName())).build()); + return encoder; + } + + } + + /** + * A builder for creating {@link NimbusJwtEncoder} instances configured with a + * {@link KeyPair}. + * + * @since 7.0 + */ + public abstract static class KeyPairJwtEncoderBuilder { + + private final KeyPair keyPair; + + private String keyId; + + private JWSAlgorithm jwsAlgorithm; + + private KeyPairJwtEncoderBuilder(KeyPair keyPair) { + this.keyPair = keyPair; + } + + /** + * Sets the JWS algorithm to use for signing. Must be compatible with the key type + * (RSA or EC). If not set, a default algorithm will be chosen based on the key + * type (e.g., RS256 for RSA, ES256 for EC). + * @param signatureAlgorithm the {@link SignatureAlgorithm} to use + * @return this builder instance for method chaining + */ + public KeyPairJwtEncoderBuilder signatureAlgorithm(SignatureAlgorithm signatureAlgorithm) { + Assert.notNull(signatureAlgorithm, "signatureAlgorithm cannot be null"); + if (this.keyPair.getPrivate() instanceof java.security.interfaces.RSAKey) { + Assert.state(JWSAlgorithm.Family.RSA.contains(JWSAlgorithm.parse(signatureAlgorithm.getName())), + () -> "The algorithm '" + signatureAlgorithm + "' is not compatible with an RSAKey. " + + "Please use one of the RS256, RS384, RS512, PS256, PS384, or PS512 algorithms."); + + } + if (this.keyPair.getPrivate() instanceof java.security.interfaces.ECKey) { + Assert.state(JWSAlgorithm.Family.EC.contains(JWSAlgorithm.parse(signatureAlgorithm.getName())), + () -> "The algorithm '" + signatureAlgorithm + "' is not compatible with an ECKey. " + + "Please use one of the ES256, ES384, or ES512 algorithms."); + } + this.jwsAlgorithm = JWSAlgorithm.parse(signatureAlgorithm.getName()); + return this; + } + + /** + * Sets the key ID ({@code kid}) to be included in the JWK and potentially the JWS + * header. + * @param keyId the key identifier + * @return this builder instance for method chaining + */ + public KeyPairJwtEncoderBuilder keyId(String keyId) { + this.keyId = keyId; + return this; + } + + /** + * Builds the {@link NimbusJwtEncoder} instance. + * @return the configured {@link NimbusJwtEncoder} + * @throws IllegalStateException if the key type is unsupported or the configured + * JWS algorithm is not compatible with the key type. + * @throws JwtEncodingException if the key is invalid (e.g., EC key with unknown + * curve) + */ + public NimbusJwtEncoder build() { + this.keyId = (this.keyId != null) ? this.keyId : UUID.randomUUID().toString(); + JWK jwk = buildJwk(); + JWKSource jwkSource = new ImmutableJWKSet<>(new JWKSet(jwk)); + NimbusJwtEncoder encoder = new NimbusJwtEncoder(jwkSource); + JwsHeader jwsHeader = JwsHeader.with(SignatureAlgorithm.from(this.jwsAlgorithm.getName())) + .keyId(jwk.getKeyID()) + .build(); + encoder.setJwsHeader(jwsHeader); + return encoder; + } + + protected abstract JWK buildJwk(); + + } + + /** + * A builder for creating {@link NimbusJwtEncoder} instances configured with a + * {@link KeyPair}. + * + * @since 7.0 + */ + public static final class RsaKeyPairJwtEncoderBuilder extends KeyPairJwtEncoderBuilder { + + private RsaKeyPairJwtEncoderBuilder(KeyPair keyPair) { + super(keyPair); + } + + @Override + protected JWK buildJwk() { + if (super.jwsAlgorithm == null) { + super.jwsAlgorithm = JWSAlgorithm.RS256; + } + + RSAKey.Builder builder = new RSAKey.Builder( + (java.security.interfaces.RSAPublicKey) super.keyPair.getPublic()) + .privateKey(super.keyPair.getPrivate()) + .keyID(super.keyId) + .keyUse(KeyUse.SIGNATURE) + .algorithm(super.jwsAlgorithm); + return builder.build(); + } + + } + + /** + * A builder for creating {@link NimbusJwtEncoder} instances configured with a + * {@link KeyPair}. + * + * @since 7.0 + */ + public static final class EcKeyPairJwtEncoderBuilder extends KeyPairJwtEncoderBuilder { + + private EcKeyPairJwtEncoderBuilder(KeyPair keyPair) { + super(keyPair); + } + + @Override + protected JWK buildJwk() { + if (super.jwsAlgorithm == null) { + super.jwsAlgorithm = JWSAlgorithm.ES256; + } + + ECPublicKey publicKey = (ECPublicKey) super.keyPair.getPublic(); + Curve curve = Curve.forECParameterSpec(publicKey.getParams()); + if (curve == null) { + throw new JwtEncodingException("Unable to determine Curve for EC public key."); + } + + com.nimbusds.jose.jwk.ECKey.Builder builder = new com.nimbusds.jose.jwk.ECKey.Builder(curve, publicKey) + .privateKey(super.keyPair.getPrivate()) + .keyUse(KeyUse.SIGNATURE) + .keyID(super.keyId) + .algorithm(super.jwsAlgorithm); + + try { + return builder.build(); + } + catch (IllegalStateException ex) { + throw new IllegalArgumentException("Failed to build ECKey: " + ex.getMessage(), ex); + } + } + + } + } diff --git a/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/NimbusJwtEncoderTests.java b/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/NimbusJwtEncoderTests.java index ab17156eacb..2abbaebf337 100644 --- a/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/NimbusJwtEncoderTests.java +++ b/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/NimbusJwtEncoderTests.java @@ -16,12 +16,20 @@ package org.springframework.security.oauth2.jwt; +import java.security.KeyPair; +import java.security.KeyPairGenerator; import java.security.interfaces.ECPrivateKey; import java.security.interfaces.ECPublicKey; +import java.time.Instant; +import java.time.temporal.ChronoUnit; import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; import java.util.List; +import java.util.UUID; + +import javax.crypto.SecretKey; +import javax.crypto.spec.SecretKeySpec; import com.nimbusds.jose.JWSAlgorithm; import com.nimbusds.jose.KeySourceException; @@ -48,6 +56,7 @@ import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatExceptionOfType; import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; +import static org.assertj.core.api.Assertions.assertThatNoException; import static org.mockito.ArgumentMatchers.any; import static org.mockito.BDDMockito.given; import static org.mockito.BDDMockito.willAnswer; @@ -344,6 +353,187 @@ public void encodeWhenNoKeysThenJwkSelectorIsNotUsed() throws Exception { verifyNoInteractions(selector); } + @Test + void secretKeyBuilderWithDefaultAlgorithm() { + SecretKey secretKey = new SecretKeySpec("thisIsASecretKeyUsedForTesting12345".getBytes(), "HMAC"); + JwtClaimsSet claims = buildClaims(); + + NimbusJwtEncoder encoder = NimbusJwtEncoder.withSecretKey(secretKey).build(); + Jwt jwt = encoder.encode(JwtEncoderParameters.from(claims)); + + assertThat(jwt).isNotNull(); + assertThat(jwt.getHeaders().get("alg").toString()).isEqualTo("HS256"); + assertThatNoException().isThrownBy(jwt::getClaims); + assertJwt(jwt); + } + + @Test + void secretKeyBuilderWithKeyId() { + SecretKey secretKey = new SecretKeySpec("thisIsASecretKeyUsedForTesting12345".getBytes(), "HMAC"); + String keyId = "test-key-id"; + JwtClaimsSet claims = buildClaims(); + + NimbusJwtEncoder encoder = NimbusJwtEncoder.withSecretKey(secretKey).keyId(keyId).build(); + Jwt jwt = encoder.encode(JwtEncoderParameters.from(claims)); + + assertThat(jwt).isNotNull(); + assertThat(jwt.getHeaders().get("kid").toString()).isEqualTo(keyId); + assertThat(jwt.getHeaders().get("alg").toString()).isEqualTo("HS256"); + assertThatNoException().isThrownBy(jwt::getClaims); + assertJwt(jwt); + } + + @Test + void secretKeyBuilderWithCustomJwkSelector() { + SecretKey secretKey = new SecretKeySpec("thisIsASecretKeyUsedForTesting12345".getBytes(), "HMAC"); + String keyId = "test-key-id"; + JwtClaimsSet claims = buildClaims(); + + NimbusJwtEncoder encoder = NimbusJwtEncoder.withSecretKey(secretKey).keyId(keyId).build(); + Jwt jwt = encoder.encode(JwtEncoderParameters.from(claims)); + + assertThat(jwt).isNotNull(); + assertThat(jwt.getHeaders().get("kid")).isEqualTo(keyId); + assertThat(jwt.getClaims()).containsEntry("sub", "subject"); + assertThatNoException().isThrownBy(() -> jwt.getClaims()); + assertJwt(jwt); + } + + @Test + void secretKeyBuilderWithCustomHeaders() { + SecretKey secretKey = new SecretKeySpec("thisIsASecretKeyUsedForTesting12345".getBytes(), "HMAC"); + JwtClaimsSet claims = buildClaims(); + JwsHeader headers = JwsHeader.with(org.springframework.security.oauth2.jose.jws.MacAlgorithm.HS256) + .type("JWT") + .contentType("application/jwt") + .build(); + + NimbusJwtEncoder encoder = NimbusJwtEncoder.withSecretKey(secretKey).build(); + Jwt jwt = encoder.encode(JwtEncoderParameters.from(headers, claims)); + + assertThat(jwt).isNotNull(); + assertThat(jwt.getHeaders().get("typ").toString()).isEqualTo("JWT"); + assertThat(jwt.getHeaders().get("cty").toString()).isEqualTo("application/jwt"); + assertThat(jwt.getHeaders().get("alg").toString()).isEqualTo("HS256"); + assertThatNoException().isThrownBy(() -> jwt.getClaims()); + assertJwt(jwt); + } + + @Test + void keyPairBuilderWithRsaDefaultAlgorithm() throws Exception { + KeyPairGenerator keyPairGenerator = KeyPairGenerator.getInstance("RSA"); + keyPairGenerator.initialize(2048); + KeyPair keyPair = keyPairGenerator.generateKeyPair(); + JwtClaimsSet claims = buildClaims(); + + NimbusJwtEncoder encoder = NimbusJwtEncoder.withKeyPair(keyPair).build(); + Jwt jwt = encoder.encode(JwtEncoderParameters.from(claims)); + + assertThat(jwt).isNotNull(); + assertThat(jwt.getHeaders().get("alg").toString()).isEqualTo("RS256"); + assertThat(jwt.getSubject()).isEqualTo(claims.getSubject()); + assertThat(jwt.getAudience()).isEqualTo(claims.getAudience()); + assertThatNoException().isThrownBy(() -> jwt.getClaims()); + assertJwt(jwt); + } + + @Test + void keyPairBuilderWithRsaCustomAlgorithm() throws Exception { + KeyPairGenerator keyPairGenerator = KeyPairGenerator.getInstance("RSA"); + keyPairGenerator.initialize(2048); + KeyPair keyPair = keyPairGenerator.generateKeyPair(); + JwtClaimsSet claims = buildClaims(); + + NimbusJwtEncoder encoder = NimbusJwtEncoder.withKeyPair(keyPair) + .signatureAlgorithm(SignatureAlgorithm.RS512) + .build(); + Jwt jwt = encoder.encode(JwtEncoderParameters.from(claims)); + + assertThat(jwt).isNotNull(); + assertThat(jwt.getHeaders().get("alg").toString()).isEqualTo("RS512"); + assertThat(jwt.getSubject()).isEqualTo(claims.getSubject()); + assertThatNoException().isThrownBy(() -> jwt.getClaims()); + assertJwt(jwt); + } + + @Test + void keyPairBuilderWithEcDefaultAlgorithm() throws Exception { + KeyPairGenerator keyPairGenerator = KeyPairGenerator.getInstance("EC"); + keyPairGenerator.initialize(256); + KeyPair keyPair = keyPairGenerator.generateKeyPair(); + JwtClaimsSet claims = buildClaims(); + + NimbusJwtEncoder encoder = NimbusJwtEncoder.withKeyPair(keyPair).build(); + Jwt jwt = encoder.encode(JwtEncoderParameters.from(claims)); + + assertThat(jwt).isNotNull(); + assertThat(jwt.getHeaders().get("alg").toString()).isEqualTo("ES256"); + assertThat(jwt.getSubject()).isEqualTo(claims.getSubject()); + assertThatNoException().isThrownBy(() -> jwt.getClaims()); + assertJwt(jwt); + } + + @Test + void keyPairBuilderWithEcCustomAlgorithm() throws Exception { + KeyPairGenerator keyPairGenerator = KeyPairGenerator.getInstance("EC"); + keyPairGenerator.initialize(256); + KeyPair keyPair = keyPairGenerator.generateKeyPair(); + NimbusJwtEncoder encoder = NimbusJwtEncoder.withKeyPair(keyPair) + .keyId(UUID.randomUUID().toString()) + .signatureAlgorithm(SignatureAlgorithm.ES256) + .build(); + + JwtClaimsSet claims = buildClaims(); + Jwt jwt = encoder.encode(JwtEncoderParameters.from(claims)); + + assertThat(jwt).isNotNull(); + assertThat(jwt.getHeaders().get("alg").toString()).isEqualTo("ES256"); + assertThatNoException().isThrownBy(() -> jwt.getClaims()); + assertJwt(jwt); + } + + @Test + void keyPairBuilderWithKeyId() throws Exception { // d + KeyPairGenerator keyPairGenerator = KeyPairGenerator.getInstance("RSA"); + keyPairGenerator.initialize(2048); + KeyPair keyPair = keyPairGenerator.generateKeyPair(); + String keyId = "test-key-id"; + JwtClaimsSet claims = buildClaims(); + + NimbusJwtEncoder encoder = NimbusJwtEncoder.withKeyPair(keyPair).keyId(keyId).build(); + Jwt jwt = encoder.encode(JwtEncoderParameters.from(claims)); + + assertThat(jwt).isNotNull(); + assertThat(jwt.getHeaders().get("kid")).isEqualTo(keyId); + assertThat(jwt.getHeaders().get("alg").toString()).isEqualTo("RS256"); + assertThatNoException().isThrownBy(() -> jwt.getClaims()); + } + + private JwtClaimsSet buildClaims() { + Instant now = Instant.now(); + return JwtClaimsSet.builder() + .issuer("https://example.com") + .subject("subject") + .audience(Collections.singletonList("audience")) + .issuedAt(now) + .notBefore(now) + .expiresAt(now.plus(1, ChronoUnit.HOURS)) + .id(UUID.randomUUID().toString()) + .claim("custom", "value") + .build(); + } + + private static void assertJwt(Jwt jwt) { + assertThat(jwt.getIssuer().toString()).isEqualTo("https://example.com"); + assertThat(jwt.getSubject()).isEqualTo("subject"); + assertThat(jwt.getAudience()).containsExactly("audience"); + assertThat(jwt.getIssuedAt()).isNotNull(); + assertThat(jwt.getNotBefore()).isNotNull(); + assertThat(jwt.getExpiresAt()).isNotNull(); + assertThat(jwt.getId()).isNotNull(); + assertThat(jwt.getClaim("custom").toString()).isEqualTo("value"); + } + private static final class JwkListResultCaptor implements Answer> { private List result;