Skip to content

Commit d6154cb

Browse files
committed
Add NimbusJwtEncoder Builders
Closes gh-16267 Signed-off-by: Suraj Bhadrike <[email protected]>
1 parent ff8b77d commit d6154cb

File tree

2 files changed

+453
-3
lines changed

2 files changed

+453
-3
lines changed

oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/NimbusJwtEncoder.java

Lines changed: 263 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,26 +18,37 @@
1818

1919
import java.net.URI;
2020
import java.net.URL;
21+
import java.security.KeyPair;
22+
import java.security.interfaces.ECPublicKey;
2123
import java.time.Instant;
2224
import java.util.ArrayList;
2325
import java.util.Date;
2426
import java.util.HashMap;
2527
import java.util.List;
2628
import java.util.Map;
2729
import java.util.Set;
30+
import java.util.UUID;
2831
import java.util.concurrent.ConcurrentHashMap;
2932

33+
import javax.crypto.SecretKey;
34+
3035
import com.nimbusds.jose.JOSEException;
3136
import com.nimbusds.jose.JOSEObjectType;
3237
import com.nimbusds.jose.JWSAlgorithm;
3338
import com.nimbusds.jose.JWSHeader;
3439
import com.nimbusds.jose.JWSSigner;
3540
import com.nimbusds.jose.crypto.factories.DefaultJWSSignerFactory;
41+
import com.nimbusds.jose.jwk.Curve;
42+
import com.nimbusds.jose.jwk.ECKey;
3643
import com.nimbusds.jose.jwk.JWK;
3744
import com.nimbusds.jose.jwk.JWKMatcher;
3845
import com.nimbusds.jose.jwk.JWKSelector;
46+
import com.nimbusds.jose.jwk.JWKSet;
3947
import com.nimbusds.jose.jwk.KeyType;
4048
import com.nimbusds.jose.jwk.KeyUse;
49+
import com.nimbusds.jose.jwk.OctetSequenceKey;
50+
import com.nimbusds.jose.jwk.RSAKey;
51+
import com.nimbusds.jose.jwk.source.ImmutableJWKSet;
4152
import com.nimbusds.jose.jwk.source.JWKSource;
4253
import com.nimbusds.jose.proc.SecurityContext;
4354
import com.nimbusds.jose.produce.JWSSignerFactory;
@@ -47,6 +58,7 @@
4758
import com.nimbusds.jwt.SignedJWT;
4859

4960
import org.springframework.core.convert.converter.Converter;
61+
import org.springframework.security.oauth2.jose.jws.MacAlgorithm;
5062
import org.springframework.security.oauth2.jose.jws.SignatureAlgorithm;
5163
import org.springframework.util.Assert;
5264
import org.springframework.util.CollectionUtils;
@@ -83,6 +95,8 @@ public final class NimbusJwtEncoder implements JwtEncoder {
8395

8496
private static final JWSSignerFactory JWS_SIGNER_FACTORY = new DefaultJWSSignerFactory();
8597

98+
private JwsHeader jwsHeader;
99+
86100
private final Map<JWK, JWSSigner> jwsSigners = new ConcurrentHashMap<>();
87101

88102
private final JWKSource<SecurityContext> jwkSource;
@@ -119,14 +133,16 @@ public void setJwkSelector(Converter<List<JWK>, JWK> jwkSelector) {
119133
this.jwkSelector = jwkSelector;
120134
}
121135

136+
public void setJwsHeader(JwsHeader jwsHeader) {
137+
this.jwsHeader = jwsHeader;
138+
}
139+
122140
@Override
123141
public Jwt encode(JwtEncoderParameters parameters) throws JwtEncodingException {
124142
Assert.notNull(parameters, "parameters cannot be null");
125143

126144
JwsHeader headers = parameters.getJwsHeader();
127-
if (headers == null) {
128-
headers = DEFAULT_JWS_HEADER;
129-
}
145+
headers = (headers != null) ? headers : (this.jwsHeader != null) ? this.jwsHeader : DEFAULT_JWS_HEADER;
130146
JwtClaimsSet claims = parameters.getClaims();
131147

132148
JWK jwk = selectJwk(headers);
@@ -369,4 +385,248 @@ private static URI convertAsURI(String header, URL url) {
369385
}
370386
}
371387

388+
/**
389+
* Creates a builder for constructing a {@link NimbusJwtEncoder} using the provided
390+
* {@link SecretKey}.
391+
* @param secretKey the {@link SecretKey} to use for signing JWTs
392+
* @return a {@link SecretKeyJwtEncoderBuilder} for further configuration
393+
* @since 7.0
394+
*/
395+
public static SecretKeyJwtEncoderBuilder withSecretKey(SecretKey secretKey) {
396+
Assert.notNull(secretKey, "secretKey cannot be null");
397+
return new SecretKeyJwtEncoderBuilder(secretKey);
398+
}
399+
400+
/**
401+
* Creates a builder for constructing a {@link NimbusJwtEncoder} using the provided
402+
* {@link KeyPair}. The key pair must contain either an {@link RSAKey} or an
403+
* {@link ECKey}.
404+
* @param keyPair the {@link KeyPair} to use for signing JWTs
405+
* @return a {@link KeyPairJwtEncoderBuilder} for further configuration
406+
* @since 7.0
407+
*/
408+
public static KeyPairJwtEncoderBuilder withKeyPair(KeyPair keyPair) {
409+
Assert.isTrue(keyPair != null && keyPair.getPrivate() != null && keyPair.getPublic() != null,
410+
"keyPair, its private key, and public key must not be null");
411+
Assert.isTrue(
412+
keyPair.getPrivate() instanceof java.security.interfaces.RSAKey
413+
|| keyPair.getPrivate() instanceof java.security.interfaces.ECKey,
414+
"keyPair must be an RSAKey or an ECKey");
415+
if (keyPair.getPrivate() instanceof java.security.interfaces.RSAKey) {
416+
return new RsaKeyPairJwtEncoderBuilder(keyPair);
417+
}
418+
if (keyPair.getPrivate() instanceof java.security.interfaces.ECKey) {
419+
return new EcKeyPairJwtEncoderBuilder(keyPair);
420+
}
421+
throw new IllegalArgumentException("keyPair must be an RSAKey or an ECKey");
422+
}
423+
424+
/**
425+
* A builder for creating {@link NimbusJwtEncoder} instances configured with a
426+
* {@link SecretKey}.
427+
*
428+
* @since 7.0
429+
*/
430+
public static final class SecretKeyJwtEncoderBuilder {
431+
432+
private final SecretKey secretKey;
433+
434+
private String keyId;
435+
436+
private JWSAlgorithm jwsAlgorithm = JWSAlgorithm.HS256;
437+
438+
private SecretKeyJwtEncoderBuilder(SecretKey secretKey) {
439+
this.secretKey = secretKey;
440+
}
441+
442+
/**
443+
* Sets the JWS algorithm to use for signing. Defaults to
444+
* {@link JWSAlgorithm#HS256}. Must be an HMAC-based algorithm (HS256, HS384, or
445+
* HS512).
446+
* @param macAlgorithm the {@link MacAlgorithm} to use
447+
* @return this builder instance for method chaining
448+
*/
449+
public SecretKeyJwtEncoderBuilder macAlgorithm(MacAlgorithm macAlgorithm) {
450+
Assert.notNull(macAlgorithm, "macAlgorithm cannot be null");
451+
this.jwsAlgorithm = JWSAlgorithm.parse(macAlgorithm.getName());
452+
return this;
453+
}
454+
455+
/**
456+
* Sets the key ID ({@code kid}) to be included in the JWK and potentially the JWS
457+
* header.
458+
* @param keyId the key identifier
459+
* @return this builder instance for method chaining
460+
*/
461+
public SecretKeyJwtEncoderBuilder keyId(String keyId) {
462+
this.keyId = keyId;
463+
return this;
464+
}
465+
466+
/**
467+
* Builds the {@link NimbusJwtEncoder} instance.
468+
* @return the configured {@link NimbusJwtEncoder}
469+
* @throws IllegalStateException if the configured JWS algorithm is not compatible
470+
* with a {@link SecretKey}.
471+
*/
472+
public NimbusJwtEncoder build() {
473+
Assert.state(JWSAlgorithm.Family.HMAC_SHA.contains(this.jwsAlgorithm),
474+
() -> "The algorithm '" + this.jwsAlgorithm + "' is not compatible with a SecretKey. "
475+
+ "Please use one of the HS256, HS384, or HS512 algorithms.");
476+
this.jwsAlgorithm = (this.jwsAlgorithm != null) ? this.jwsAlgorithm : JWSAlgorithm.HS256;
477+
478+
OctetSequenceKey.Builder builder = new OctetSequenceKey.Builder(this.secretKey).keyUse(KeyUse.SIGNATURE)
479+
.algorithm(this.jwsAlgorithm)
480+
.keyID(this.keyId);
481+
482+
OctetSequenceKey jwk = builder.build();
483+
JWKSource<SecurityContext> jwkSource = new ImmutableJWKSet<>(new JWKSet(jwk));
484+
NimbusJwtEncoder encoder = new NimbusJwtEncoder(jwkSource);
485+
encoder.setJwsHeader(JwsHeader.with(MacAlgorithm.from(this.jwsAlgorithm.getName())).build());
486+
return encoder;
487+
}
488+
489+
}
490+
491+
/**
492+
* A builder for creating {@link NimbusJwtEncoder} instances configured with a
493+
* {@link KeyPair}.
494+
*
495+
* @since 7.0
496+
*/
497+
public abstract static class KeyPairJwtEncoderBuilder {
498+
499+
private final KeyPair keyPair;
500+
501+
private String keyId;
502+
503+
private JWSAlgorithm jwsAlgorithm;
504+
505+
private KeyPairJwtEncoderBuilder(KeyPair keyPair) {
506+
this.keyPair = keyPair;
507+
}
508+
509+
/**
510+
* Sets the JWS algorithm to use for signing. Must be compatible with the key type
511+
* (RSA or EC). If not set, a default algorithm will be chosen based on the key
512+
* type (e.g., RS256 for RSA, ES256 for EC).
513+
* @param signatureAlgorithm the {@link SignatureAlgorithm} to use
514+
* @return this builder instance for method chaining
515+
*/
516+
public KeyPairJwtEncoderBuilder signatureAlgorithm(SignatureAlgorithm signatureAlgorithm) {
517+
Assert.notNull(signatureAlgorithm, "signatureAlgorithm cannot be null");
518+
this.jwsAlgorithm = JWSAlgorithm.parse(signatureAlgorithm.getName());
519+
return this;
520+
}
521+
522+
/**
523+
* Sets the key ID ({@code kid}) to be included in the JWK and potentially the JWS
524+
* header.
525+
* @param keyId the key identifier
526+
* @return this builder instance for method chaining
527+
*/
528+
public KeyPairJwtEncoderBuilder keyId(String keyId) {
529+
this.keyId = keyId;
530+
return this;
531+
}
532+
533+
/**
534+
* Builds the {@link NimbusJwtEncoder} instance.
535+
* @return the configured {@link NimbusJwtEncoder}
536+
* @throws IllegalStateException if the key type is unsupported or the configured
537+
* JWS algorithm is not compatible with the key type.
538+
* @throws JwtEncodingException if the key is invalid (e.g., EC key with unknown
539+
* curve)
540+
*/
541+
public NimbusJwtEncoder build() {
542+
this.keyId = (this.keyId != null) ? this.keyId : UUID.randomUUID().toString();
543+
JWK jwk = buildJwk();
544+
JWKSource<SecurityContext> jwkSource = new ImmutableJWKSet<>(new JWKSet(jwk));
545+
NimbusJwtEncoder encoder = new NimbusJwtEncoder(jwkSource);
546+
JwsHeader jwsHeader = JwsHeader.with(SignatureAlgorithm.from(this.jwsAlgorithm.getName()))
547+
.keyId(jwk.getKeyID())
548+
.build();
549+
encoder.setJwsHeader(jwsHeader);
550+
return encoder;
551+
}
552+
553+
protected abstract JWK buildJwk();
554+
555+
}
556+
557+
/**
558+
* A builder for creating {@link NimbusJwtEncoder} instances configured with a
559+
* {@link KeyPair}.
560+
*
561+
* @since 7.0
562+
*/
563+
public static final class RsaKeyPairJwtEncoderBuilder extends KeyPairJwtEncoderBuilder {
564+
565+
private RsaKeyPairJwtEncoderBuilder(KeyPair keyPair) {
566+
super(keyPair);
567+
}
568+
569+
@Override
570+
protected JWK buildJwk() {
571+
if (super.jwsAlgorithm == null) {
572+
super.jwsAlgorithm = JWSAlgorithm.RS256;
573+
}
574+
Assert.state(JWSAlgorithm.Family.RSA.contains(super.jwsAlgorithm),
575+
() -> "The algorithm '" + super.jwsAlgorithm + "' is not compatible with an RSAKey. "
576+
+ "Please use one of the RS256, RS384, RS512, PS256, PS384, or PS512 algorithms.");
577+
578+
RSAKey.Builder builder = new RSAKey.Builder(
579+
(java.security.interfaces.RSAPublicKey) super.keyPair.getPublic())
580+
.privateKey(super.keyPair.getPrivate())
581+
.keyID(super.keyId)
582+
.keyUse(KeyUse.SIGNATURE)
583+
.algorithm(super.jwsAlgorithm);
584+
return builder.build();
585+
}
586+
587+
}
588+
589+
/**
590+
* A builder for creating {@link NimbusJwtEncoder} instances configured with a
591+
* {@link KeyPair}.
592+
*
593+
* @since 7.0
594+
*/
595+
public static final class EcKeyPairJwtEncoderBuilder extends KeyPairJwtEncoderBuilder {
596+
597+
private EcKeyPairJwtEncoderBuilder(KeyPair keyPair) {
598+
super(keyPair);
599+
}
600+
601+
@Override
602+
protected JWK buildJwk() {
603+
if (super.jwsAlgorithm == null) {
604+
super.jwsAlgorithm = JWSAlgorithm.ES256;
605+
}
606+
Assert.state(JWSAlgorithm.Family.EC.contains(super.jwsAlgorithm),
607+
() -> "The algorithm '" + super.jwsAlgorithm + "' is not compatible with an ECKey. "
608+
+ "Please use one of the ES256, ES384, or ES512 algorithms.");
609+
610+
ECPublicKey publicKey = (ECPublicKey) super.keyPair.getPublic();
611+
Curve curve = Curve.forECParameterSpec(publicKey.getParams());
612+
if (curve == null) {
613+
throw new JwtEncodingException("Unable to determine Curve for EC public key.");
614+
}
615+
616+
com.nimbusds.jose.jwk.ECKey.Builder builder = new com.nimbusds.jose.jwk.ECKey.Builder(curve, publicKey)
617+
.privateKey(super.keyPair.getPrivate())
618+
.keyUse(KeyUse.SIGNATURE)
619+
.keyID(super.keyId)
620+
.algorithm(super.jwsAlgorithm);
621+
622+
try {
623+
return builder.build();
624+
}
625+
catch (IllegalStateException ex) {
626+
throw new IllegalArgumentException("Failed to build ECKey: " + ex.getMessage(), ex);
627+
}
628+
}
629+
630+
}
631+
372632
}

0 commit comments

Comments
 (0)