Skip to content

Commit ada5292

Browse files
committed
Add NimbusJwtEncoder Builders
Closes spring-projectsgh-16267 Signed-off-by: Suraj Bhadrike <[email protected]>
1 parent 226e81d commit ada5292

File tree

2 files changed

+446
-3
lines changed

2 files changed

+446
-3
lines changed

oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/NimbusJwtEncoder.java

Lines changed: 254 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,26 +18,39 @@
1818

1919
import java.net.URI;
2020
import java.net.URL;
21+
import java.security.KeyPair;
22+
import java.security.interfaces.ECPrivateKey;
23+
import java.security.interfaces.ECPublicKey;
24+
import java.security.interfaces.RSAPrivateKey;
2125
import java.time.Instant;
2226
import java.util.ArrayList;
2327
import java.util.Date;
2428
import java.util.HashMap;
2529
import java.util.List;
2630
import java.util.Map;
2731
import java.util.Set;
32+
import java.util.UUID;
2833
import java.util.concurrent.ConcurrentHashMap;
2934

35+
import javax.crypto.SecretKey;
36+
3037
import com.nimbusds.jose.JOSEException;
3138
import com.nimbusds.jose.JOSEObjectType;
3239
import com.nimbusds.jose.JWSAlgorithm;
3340
import com.nimbusds.jose.JWSHeader;
3441
import com.nimbusds.jose.JWSSigner;
3542
import com.nimbusds.jose.crypto.factories.DefaultJWSSignerFactory;
43+
import com.nimbusds.jose.jwk.Curve;
44+
import com.nimbusds.jose.jwk.ECKey;
3645
import com.nimbusds.jose.jwk.JWK;
3746
import com.nimbusds.jose.jwk.JWKMatcher;
3847
import com.nimbusds.jose.jwk.JWKSelector;
48+
import com.nimbusds.jose.jwk.JWKSet;
3949
import com.nimbusds.jose.jwk.KeyType;
4050
import com.nimbusds.jose.jwk.KeyUse;
51+
import com.nimbusds.jose.jwk.OctetSequenceKey;
52+
import com.nimbusds.jose.jwk.RSAKey;
53+
import com.nimbusds.jose.jwk.source.ImmutableJWKSet;
4154
import com.nimbusds.jose.jwk.source.JWKSource;
4255
import com.nimbusds.jose.proc.SecurityContext;
4356
import com.nimbusds.jose.produce.JWSSignerFactory;
@@ -47,6 +60,7 @@
4760
import com.nimbusds.jwt.SignedJWT;
4861

4962
import org.springframework.core.convert.converter.Converter;
63+
import org.springframework.security.oauth2.jose.jws.MacAlgorithm;
5064
import org.springframework.security.oauth2.jose.jws.SignatureAlgorithm;
5165
import org.springframework.util.Assert;
5266
import org.springframework.util.CollectionUtils;
@@ -83,6 +97,8 @@ public final class NimbusJwtEncoder implements JwtEncoder {
8397

8498
private static final JWSSignerFactory JWS_SIGNER_FACTORY = new DefaultJWSSignerFactory();
8599

100+
private JwsHeader jwsHeader;
101+
86102
private final Map<JWK, JWSSigner> jwsSigners = new ConcurrentHashMap<>();
87103

88104
private final JWKSource<SecurityContext> jwkSource;
@@ -119,14 +135,16 @@ public void setJwkSelector(Converter<List<JWK>, JWK> jwkSelector) {
119135
this.jwkSelector = jwkSelector;
120136
}
121137

138+
public void setJwsHeader(JwsHeader jwsHeader) {
139+
this.jwsHeader = jwsHeader;
140+
}
141+
122142
@Override
123143
public Jwt encode(JwtEncoderParameters parameters) throws JwtEncodingException {
124144
Assert.notNull(parameters, "parameters cannot be null");
125145

126146
JwsHeader headers = parameters.getJwsHeader();
127-
if (headers == null) {
128-
headers = DEFAULT_JWS_HEADER;
129-
}
147+
headers = (headers != null) ? headers : (this.jwsHeader != null) ? this.jwsHeader : DEFAULT_JWS_HEADER;
130148
JwtClaimsSet claims = parameters.getClaims();
131149

132150
JWK jwk = selectJwk(headers);
@@ -369,4 +387,237 @@ private static URI convertAsURI(String header, URL url) {
369387
}
370388
}
371389

390+
/**
391+
* Creates a builder for constructing a {@link NimbusJwtEncoder} using the provided
392+
* {@link SecretKey}.
393+
* @param secretKey the {@link SecretKey} to use for signing JWTs
394+
* @return a {@link SecretKeyJwtEncoderBuilder} for further configuration
395+
* @since 7.0
396+
*/
397+
public static SecretKeyJwtEncoderBuilder withSecretKey(SecretKey secretKey) {
398+
Assert.notNull(secretKey, "secretKey cannot be null");
399+
return new SecretKeyJwtEncoderBuilder(secretKey);
400+
}
401+
402+
/**
403+
* Creates a builder for constructing a {@link NimbusJwtEncoder} using the provided
404+
* {@link KeyPair}. The key pair must contain either an {@link RSAKey} or an
405+
* {@link ECKey}.
406+
* @param keyPair the {@link KeyPair} to use for signing JWTs
407+
* @return a {@link KeyPairJwtEncoderBuilder} for further configuration
408+
* @since 7.0
409+
*/
410+
public static KeyPairJwtEncoderBuilder withKeyPair(KeyPair keyPair) {
411+
Assert.notNull(keyPair, "keyPair cannot be null");
412+
Assert.notNull(keyPair.getPrivate(), "keyPair must contain a private key");
413+
Assert.notNull(keyPair.getPublic(), "keyPair must contain a public key");
414+
Assert.isTrue(
415+
keyPair.getPrivate() instanceof java.security.interfaces.RSAKey
416+
|| keyPair.getPrivate() instanceof java.security.interfaces.ECKey,
417+
"keyPair must be an RSAKey or an ECKey");
418+
return new KeyPairJwtEncoderBuilder(keyPair);
419+
}
420+
421+
/**
422+
* A builder for creating {@link NimbusJwtEncoder} instances configured with a
423+
* {@link SecretKey}.
424+
*
425+
* @since 7.0
426+
*/
427+
public static final class SecretKeyJwtEncoderBuilder {
428+
429+
private final SecretKey secretKey;
430+
431+
private String keyId;
432+
433+
private JWSAlgorithm jwsAlgorithm = JWSAlgorithm.HS256;
434+
435+
private SecretKeyJwtEncoderBuilder(SecretKey secretKey) {
436+
this.secretKey = secretKey;
437+
}
438+
439+
/**
440+
* Sets the JWS algorithm to use for signing. Defaults to
441+
* {@link JWSAlgorithm#HS256}. Must be an HMAC-based algorithm (HS256, HS384, or
442+
* HS512).
443+
* @param macAlgorithm the {@link MacAlgorithm} to use
444+
* @return this builder instance for method chaining
445+
*/
446+
public SecretKeyJwtEncoderBuilder macAlgorithm(MacAlgorithm macAlgorithm) {
447+
Assert.notNull(macAlgorithm, "macAlgorithm cannot be null");
448+
this.jwsAlgorithm = JWSAlgorithm.parse(macAlgorithm.getName());
449+
return this;
450+
}
451+
452+
/**
453+
* Sets the key ID ({@code kid}) to be included in the JWK and potentially the JWS
454+
* header.
455+
* @param keyId the key identifier
456+
* @return this builder instance for method chaining
457+
*/
458+
public SecretKeyJwtEncoderBuilder keyId(String keyId) {
459+
this.keyId = keyId;
460+
return this;
461+
}
462+
463+
/**
464+
* Builds the {@link NimbusJwtEncoder} instance.
465+
* @return the configured {@link NimbusJwtEncoder}
466+
* @throws IllegalStateException if the configured JWS algorithm is not compatible
467+
* with a {@link SecretKey}.
468+
*/
469+
public NimbusJwtEncoder build() {
470+
Assert.state(JWSAlgorithm.Family.HMAC_SHA.contains(this.jwsAlgorithm),
471+
() -> "The algorithm '" + this.jwsAlgorithm + "' is not compatible with a SecretKey. "
472+
+ "Please use one of the HS256, HS384, or HS512 algorithms.");
473+
this.jwsAlgorithm = (this.jwsAlgorithm != null) ? this.jwsAlgorithm : JWSAlgorithm.HS256;
474+
475+
OctetSequenceKey.Builder builder = new OctetSequenceKey.Builder(this.secretKey).keyUse(KeyUse.SIGNATURE)
476+
.algorithm(this.jwsAlgorithm)
477+
.keyID(this.keyId);
478+
479+
OctetSequenceKey jwk = builder.build();
480+
JWKSource<SecurityContext> jwkSource = new ImmutableJWKSet<>(new JWKSet(jwk));
481+
NimbusJwtEncoder encoder = new NimbusJwtEncoder(jwkSource);
482+
encoder.setJwsHeader(JwsHeader.with(MacAlgorithm.from(this.jwsAlgorithm.getName())).build());
483+
return encoder;
484+
}
485+
486+
}
487+
488+
/**
489+
* A builder for creating {@link NimbusJwtEncoder} instances configured with a
490+
* {@link KeyPair}.
491+
*
492+
* @since 7.0
493+
*/
494+
public static final class KeyPairJwtEncoderBuilder {
495+
496+
private final KeyPair keyPair;
497+
498+
private String keyId;
499+
500+
private JWSAlgorithm jwsAlgorithm;
501+
502+
private JwsHeader jwsHeader;
503+
504+
private KeyPairJwtEncoderBuilder(KeyPair keyPair) {
505+
this.keyPair = keyPair;
506+
}
507+
508+
/**
509+
* Sets the JWS algorithm to use for signing. Must be compatible with the key type
510+
* (RSA or EC). If not set, a default algorithm will be chosen based on the key
511+
* type (e.g., RS256 for RSA, ES256 for EC).
512+
* @param signatureAlgorithm the {@link SignatureAlgorithm} to use
513+
* @return this builder instance for method chaining
514+
*/
515+
public KeyPairJwtEncoderBuilder signatureAlgorithm(SignatureAlgorithm signatureAlgorithm) {
516+
Assert.notNull(signatureAlgorithm, "signatureAlgorithm cannot be null");
517+
this.jwsAlgorithm = JWSAlgorithm.parse(signatureAlgorithm.getName());
518+
return this;
519+
}
520+
521+
/**
522+
* Sets the key ID ({@code kid}) to be included in the JWK and potentially the JWS
523+
* header.
524+
* @param keyId the key identifier
525+
* @return this builder instance for method chaining
526+
*/
527+
public KeyPairJwtEncoderBuilder keyId(String keyId) {
528+
this.keyId = keyId;
529+
return this;
530+
}
531+
532+
533+
/**
534+
* Builds the {@link NimbusJwtEncoder} instance.
535+
* @return the configured {@link NimbusJwtEncoder}
536+
* @throws IllegalStateException if the key type is unsupported or the configured
537+
* JWS algorithm is not compatible with the key type.
538+
* @throws JwtEncodingException if the key is invalid (e.g., EC key with unknown
539+
* curve)
540+
*/
541+
public NimbusJwtEncoder build() {
542+
this.keyId = (this.keyId != null) ? this.keyId : UUID.randomUUID().toString();
543+
JWK jwk = buildJwk();
544+
JWKSource<SecurityContext> jwkSource = new ImmutableJWKSet<>(new JWKSet(jwk));
545+
NimbusJwtEncoder encoder = new NimbusJwtEncoder(jwkSource);
546+
if (this.jwsHeader != null) {
547+
encoder.setJwsHeader(this.jwsHeader);
548+
}
549+
return encoder;
550+
}
551+
552+
private JWK buildJwk() {
553+
if (this.keyPair.getPrivate() instanceof RSAPrivateKey) {
554+
RSAKey rsaKey = buildRsaJwk();
555+
this.jwsHeader = JwsHeader.with(SignatureAlgorithm.from(this.jwsAlgorithm.getName()))
556+
.keyId(rsaKey.getKeyID())
557+
.build();
558+
return rsaKey;
559+
}
560+
else if (this.keyPair.getPrivate() instanceof ECPrivateKey) {
561+
ECKey ecKey = buildEcJwk();
562+
this.jwsHeader = JwsHeader.with(SignatureAlgorithm.from(this.jwsAlgorithm.getName()))
563+
.keyId(ecKey.getKeyID())
564+
.build();
565+
return ecKey;
566+
}
567+
else {
568+
throw new IllegalStateException(
569+
"Unsupported key pair type: " + this.keyPair.getPrivate().getClass().getName());
570+
}
571+
}
572+
573+
private RSAKey buildRsaJwk() {
574+
if (this.jwsAlgorithm == null) {
575+
this.jwsAlgorithm = JWSAlgorithm.RS256;
576+
}
577+
Assert.state(JWSAlgorithm.Family.RSA.contains(this.jwsAlgorithm),
578+
() -> "The algorithm '" + this.jwsAlgorithm + "' is not compatible with an RSAKey. "
579+
+ "Please use one of the RS256, RS384, RS512, PS256, PS384, or PS512 algorithms.");
580+
581+
RSAKey.Builder builder = new RSAKey.Builder(
582+
(java.security.interfaces.RSAPublicKey) this.keyPair.getPublic())
583+
.privateKey(this.keyPair.getPrivate())
584+
.keyUse(KeyUse.SIGNATURE)
585+
.algorithm(this.jwsAlgorithm);
586+
587+
if (StringUtils.hasText(this.keyId)) {
588+
builder.keyID(this.keyId);
589+
}
590+
return builder.build();
591+
}
592+
593+
private com.nimbusds.jose.jwk.ECKey buildEcJwk() {
594+
if (this.jwsAlgorithm == null) {
595+
this.jwsAlgorithm = JWSAlgorithm.ES256;
596+
}
597+
Assert.state(JWSAlgorithm.Family.EC.contains(this.jwsAlgorithm),
598+
() -> "The algorithm '" + this.jwsAlgorithm + "' is not compatible with an ECKey. "
599+
+ "Please use one of the ES256, ES384, or ES512 algorithms.");
600+
601+
ECPublicKey publicKey = (ECPublicKey) this.keyPair.getPublic();
602+
Curve curve = Curve.forECParameterSpec(publicKey.getParams());
603+
if (curve == null) {
604+
throw new JwtEncodingException("Unable to determine Curve for EC public key.");
605+
}
606+
607+
com.nimbusds.jose.jwk.ECKey.Builder builder = new com.nimbusds.jose.jwk.ECKey.Builder(curve, publicKey)
608+
.privateKey(this.keyPair.getPrivate())
609+
.keyUse(KeyUse.SIGNATURE)
610+
.keyID(this.keyId)
611+
.algorithm(this.jwsAlgorithm);
612+
613+
try {
614+
return builder.build();
615+
}
616+
catch (IllegalStateException ex) {
617+
throw new IllegalArgumentException("Failed to build ECKey: " + ex.getMessage(), ex);
618+
}
619+
}
620+
621+
}
622+
372623
}

0 commit comments

Comments
 (0)