Skip to content

Commit 52ef97d

Browse files
committed
Add NimbusJwtEncoder Builders
Closes gh-16267 Signed-off-by: Suraj Bhadrike <[email protected]>
1 parent ff8b77d commit 52ef97d

File tree

2 files changed

+454
-3
lines changed

2 files changed

+454
-3
lines changed

oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/NimbusJwtEncoder.java

+264-3
Original file line numberDiff line numberDiff line change
@@ -18,26 +18,37 @@
1818

1919
import java.net.URI;
2020
import java.net.URL;
21+
import java.security.KeyPair;
22+
import java.security.interfaces.ECPublicKey;
2123
import java.time.Instant;
2224
import java.util.ArrayList;
2325
import java.util.Date;
2426
import java.util.HashMap;
2527
import java.util.List;
2628
import java.util.Map;
2729
import java.util.Set;
30+
import java.util.UUID;
2831
import java.util.concurrent.ConcurrentHashMap;
2932

33+
import javax.crypto.SecretKey;
34+
3035
import com.nimbusds.jose.JOSEException;
3136
import com.nimbusds.jose.JOSEObjectType;
3237
import com.nimbusds.jose.JWSAlgorithm;
3338
import com.nimbusds.jose.JWSHeader;
3439
import com.nimbusds.jose.JWSSigner;
3540
import com.nimbusds.jose.crypto.factories.DefaultJWSSignerFactory;
41+
import com.nimbusds.jose.jwk.Curve;
42+
import com.nimbusds.jose.jwk.ECKey;
3643
import com.nimbusds.jose.jwk.JWK;
3744
import com.nimbusds.jose.jwk.JWKMatcher;
3845
import com.nimbusds.jose.jwk.JWKSelector;
46+
import com.nimbusds.jose.jwk.JWKSet;
3947
import com.nimbusds.jose.jwk.KeyType;
4048
import com.nimbusds.jose.jwk.KeyUse;
49+
import com.nimbusds.jose.jwk.OctetSequenceKey;
50+
import com.nimbusds.jose.jwk.RSAKey;
51+
import com.nimbusds.jose.jwk.source.ImmutableJWKSet;
4152
import com.nimbusds.jose.jwk.source.JWKSource;
4253
import com.nimbusds.jose.proc.SecurityContext;
4354
import com.nimbusds.jose.produce.JWSSignerFactory;
@@ -47,6 +58,7 @@
4758
import com.nimbusds.jwt.SignedJWT;
4859

4960
import org.springframework.core.convert.converter.Converter;
61+
import org.springframework.security.oauth2.jose.jws.MacAlgorithm;
5062
import org.springframework.security.oauth2.jose.jws.SignatureAlgorithm;
5163
import org.springframework.util.Assert;
5264
import org.springframework.util.CollectionUtils;
@@ -83,6 +95,8 @@ public final class NimbusJwtEncoder implements JwtEncoder {
8395

8496
private static final JWSSignerFactory JWS_SIGNER_FACTORY = new DefaultJWSSignerFactory();
8597

98+
private JwsHeader jwsHeader;
99+
86100
private final Map<JWK, JWSSigner> jwsSigners = new ConcurrentHashMap<>();
87101

88102
private final JWKSource<SecurityContext> jwkSource;
@@ -119,14 +133,16 @@ public void setJwkSelector(Converter<List<JWK>, JWK> jwkSelector) {
119133
this.jwkSelector = jwkSelector;
120134
}
121135

136+
public void setJwsHeader(JwsHeader jwsHeader) {
137+
this.jwsHeader = jwsHeader;
138+
}
139+
122140
@Override
123141
public Jwt encode(JwtEncoderParameters parameters) throws JwtEncodingException {
124142
Assert.notNull(parameters, "parameters cannot be null");
125143

126144
JwsHeader headers = parameters.getJwsHeader();
127-
if (headers == null) {
128-
headers = DEFAULT_JWS_HEADER;
129-
}
145+
headers = (headers != null) ? headers : (this.jwsHeader != null) ? this.jwsHeader : DEFAULT_JWS_HEADER;
130146
JwtClaimsSet claims = parameters.getClaims();
131147

132148
JWK jwk = selectJwk(headers);
@@ -369,4 +385,249 @@ private static URI convertAsURI(String header, URL url) {
369385
}
370386
}
371387

388+
/**
389+
* Creates a builder for constructing a {@link NimbusJwtEncoder} using the provided
390+
* {@link SecretKey}.
391+
* @param secretKey the {@link SecretKey} to use for signing JWTs
392+
* @return a {@link SecretKeyJwtEncoderBuilder} for further configuration
393+
* @since 7.0
394+
*/
395+
public static SecretKeyJwtEncoderBuilder withSecretKey(SecretKey secretKey) {
396+
Assert.notNull(secretKey, "secretKey cannot be null");
397+
return new SecretKeyJwtEncoderBuilder(secretKey);
398+
}
399+
400+
/**
401+
* Creates a builder for constructing a {@link NimbusJwtEncoder} using the provided
402+
* {@link KeyPair}. The key pair must contain either an {@link RSAKey} or an
403+
* {@link ECKey}.
404+
* @param keyPair the {@link KeyPair} to use for signing JWTs
405+
* @return a {@link KeyPairJwtEncoderBuilder} for further configuration
406+
* @since 7.0
407+
*/
408+
public static KeyPairJwtEncoderBuilder withKeyPair(KeyPair keyPair) {
409+
Assert.isTrue(keyPair != null && keyPair.getPrivate() != null && keyPair.getPublic() != null,
410+
"keyPair, its private key, and public key must not be null");
411+
Assert.isTrue(
412+
keyPair.getPrivate() instanceof java.security.interfaces.RSAKey
413+
|| keyPair.getPrivate() instanceof java.security.interfaces.ECKey,
414+
"keyPair must be an RSAKey or an ECKey");
415+
if (keyPair.getPrivate() instanceof java.security.interfaces.RSAKey) {
416+
return new RsaKeyPairJwtEncoderBuilder(keyPair);
417+
}
418+
if (keyPair.getPrivate() instanceof java.security.interfaces.ECKey) {
419+
return new EcKeyPairJwtEncoderBuilder(keyPair);
420+
}
421+
throw new IllegalArgumentException("keyPair must be an RSAKey or an ECKey");
422+
}
423+
424+
/**
425+
* A builder for creating {@link NimbusJwtEncoder} instances configured with a
426+
* {@link SecretKey}.
427+
*
428+
* @since 7.0
429+
*/
430+
public static final class SecretKeyJwtEncoderBuilder {
431+
432+
private final SecretKey secretKey;
433+
434+
private String keyId;
435+
436+
private JWSAlgorithm jwsAlgorithm = JWSAlgorithm.HS256;
437+
438+
private SecretKeyJwtEncoderBuilder(SecretKey secretKey) {
439+
this.secretKey = secretKey;
440+
}
441+
442+
/**
443+
* Sets the JWS algorithm to use for signing. Defaults to
444+
* {@link JWSAlgorithm#HS256}. Must be an HMAC-based algorithm (HS256, HS384, or
445+
* HS512).
446+
* @param macAlgorithm the {@link MacAlgorithm} to use
447+
* @return this builder instance for method chaining
448+
*/
449+
public SecretKeyJwtEncoderBuilder macAlgorithm(MacAlgorithm macAlgorithm) {
450+
Assert.notNull(macAlgorithm, "macAlgorithm cannot be null");
451+
Assert.state(JWSAlgorithm.Family.HMAC_SHA.contains(this.jwsAlgorithm),
452+
() -> "The algorithm '" + this.jwsAlgorithm + "' is not compatible with a SecretKey. "
453+
+ "Please use one of the HS256, HS384, or HS512 algorithms.");
454+
455+
this.jwsAlgorithm = JWSAlgorithm.parse(macAlgorithm.getName());
456+
return this;
457+
}
458+
459+
/**
460+
* Sets the key ID ({@code kid}) to be included in the JWK and potentially the JWS
461+
* header.
462+
* @param keyId the key identifier
463+
* @return this builder instance for method chaining
464+
*/
465+
public SecretKeyJwtEncoderBuilder keyId(String keyId) {
466+
this.keyId = keyId;
467+
return this;
468+
}
469+
470+
/**
471+
* Builds the {@link NimbusJwtEncoder} instance.
472+
* @return the configured {@link NimbusJwtEncoder}
473+
* @throws IllegalStateException if the configured JWS algorithm is not compatible
474+
* with a {@link SecretKey}.
475+
*/
476+
public NimbusJwtEncoder build() {
477+
this.jwsAlgorithm = (this.jwsAlgorithm != null) ? this.jwsAlgorithm : JWSAlgorithm.HS256;
478+
479+
OctetSequenceKey.Builder builder = new OctetSequenceKey.Builder(this.secretKey).keyUse(KeyUse.SIGNATURE)
480+
.algorithm(this.jwsAlgorithm)
481+
.keyID(this.keyId);
482+
483+
OctetSequenceKey jwk = builder.build();
484+
JWKSource<SecurityContext> jwkSource = new ImmutableJWKSet<>(new JWKSet(jwk));
485+
NimbusJwtEncoder encoder = new NimbusJwtEncoder(jwkSource);
486+
encoder.setJwsHeader(JwsHeader.with(MacAlgorithm.from(this.jwsAlgorithm.getName())).build());
487+
return encoder;
488+
}
489+
490+
}
491+
492+
/**
493+
* A builder for creating {@link NimbusJwtEncoder} instances configured with a
494+
* {@link KeyPair}.
495+
*
496+
* @since 7.0
497+
*/
498+
public abstract static class KeyPairJwtEncoderBuilder {
499+
500+
private final KeyPair keyPair;
501+
502+
private String keyId;
503+
504+
private JWSAlgorithm jwsAlgorithm;
505+
506+
private KeyPairJwtEncoderBuilder(KeyPair keyPair) {
507+
this.keyPair = keyPair;
508+
}
509+
510+
/**
511+
* Sets the JWS algorithm to use for signing. Must be compatible with the key type
512+
* (RSA or EC). If not set, a default algorithm will be chosen based on the key
513+
* type (e.g., RS256 for RSA, ES256 for EC).
514+
* @param signatureAlgorithm the {@link SignatureAlgorithm} to use
515+
* @return this builder instance for method chaining
516+
*/
517+
public KeyPairJwtEncoderBuilder signatureAlgorithm(SignatureAlgorithm signatureAlgorithm) {
518+
Assert.notNull(signatureAlgorithm, "signatureAlgorithm cannot be null");
519+
this.jwsAlgorithm = JWSAlgorithm.parse(signatureAlgorithm.getName());
520+
return this;
521+
}
522+
523+
/**
524+
* Sets the key ID ({@code kid}) to be included in the JWK and potentially the JWS
525+
* header.
526+
* @param keyId the key identifier
527+
* @return this builder instance for method chaining
528+
*/
529+
public KeyPairJwtEncoderBuilder keyId(String keyId) {
530+
this.keyId = keyId;
531+
return this;
532+
}
533+
534+
/**
535+
* Builds the {@link NimbusJwtEncoder} instance.
536+
* @return the configured {@link NimbusJwtEncoder}
537+
* @throws IllegalStateException if the key type is unsupported or the configured
538+
* JWS algorithm is not compatible with the key type.
539+
* @throws JwtEncodingException if the key is invalid (e.g., EC key with unknown
540+
* curve)
541+
*/
542+
public NimbusJwtEncoder build() {
543+
this.keyId = (this.keyId != null) ? this.keyId : UUID.randomUUID().toString();
544+
JWK jwk = buildJwk();
545+
JWKSource<SecurityContext> jwkSource = new ImmutableJWKSet<>(new JWKSet(jwk));
546+
NimbusJwtEncoder encoder = new NimbusJwtEncoder(jwkSource);
547+
JwsHeader jwsHeader = JwsHeader.with(SignatureAlgorithm.from(this.jwsAlgorithm.getName()))
548+
.keyId(jwk.getKeyID())
549+
.build();
550+
encoder.setJwsHeader(jwsHeader);
551+
return encoder;
552+
}
553+
554+
protected abstract JWK buildJwk();
555+
556+
}
557+
558+
/**
559+
* A builder for creating {@link NimbusJwtEncoder} instances configured with a
560+
* {@link KeyPair}.
561+
*
562+
* @since 7.0
563+
*/
564+
public static final class RsaKeyPairJwtEncoderBuilder extends KeyPairJwtEncoderBuilder {
565+
566+
private RsaKeyPairJwtEncoderBuilder(KeyPair keyPair) {
567+
super(keyPair);
568+
}
569+
570+
@Override
571+
protected JWK buildJwk() {
572+
if (super.jwsAlgorithm == null) {
573+
super.jwsAlgorithm = JWSAlgorithm.RS256;
574+
}
575+
Assert.state(JWSAlgorithm.Family.RSA.contains(super.jwsAlgorithm),
576+
() -> "The algorithm '" + super.jwsAlgorithm + "' is not compatible with an RSAKey. "
577+
+ "Please use one of the RS256, RS384, RS512, PS256, PS384, or PS512 algorithms.");
578+
579+
RSAKey.Builder builder = new RSAKey.Builder(
580+
(java.security.interfaces.RSAPublicKey) super.keyPair.getPublic())
581+
.privateKey(super.keyPair.getPrivate())
582+
.keyID(super.keyId)
583+
.keyUse(KeyUse.SIGNATURE)
584+
.algorithm(super.jwsAlgorithm);
585+
return builder.build();
586+
}
587+
588+
}
589+
590+
/**
591+
* A builder for creating {@link NimbusJwtEncoder} instances configured with a
592+
* {@link KeyPair}.
593+
*
594+
* @since 7.0
595+
*/
596+
public static final class EcKeyPairJwtEncoderBuilder extends KeyPairJwtEncoderBuilder {
597+
598+
private EcKeyPairJwtEncoderBuilder(KeyPair keyPair) {
599+
super(keyPair);
600+
}
601+
602+
@Override
603+
protected JWK buildJwk() {
604+
if (super.jwsAlgorithm == null) {
605+
super.jwsAlgorithm = JWSAlgorithm.ES256;
606+
}
607+
Assert.state(JWSAlgorithm.Family.EC.contains(super.jwsAlgorithm),
608+
() -> "The algorithm '" + super.jwsAlgorithm + "' is not compatible with an ECKey. "
609+
+ "Please use one of the ES256, ES384, or ES512 algorithms.");
610+
611+
ECPublicKey publicKey = (ECPublicKey) super.keyPair.getPublic();
612+
Curve curve = Curve.forECParameterSpec(publicKey.getParams());
613+
if (curve == null) {
614+
throw new JwtEncodingException("Unable to determine Curve for EC public key.");
615+
}
616+
617+
com.nimbusds.jose.jwk.ECKey.Builder builder = new com.nimbusds.jose.jwk.ECKey.Builder(curve, publicKey)
618+
.privateKey(super.keyPair.getPrivate())
619+
.keyUse(KeyUse.SIGNATURE)
620+
.keyID(super.keyId)
621+
.algorithm(super.jwsAlgorithm);
622+
623+
try {
624+
return builder.build();
625+
}
626+
catch (IllegalStateException ex) {
627+
throw new IllegalArgumentException("Failed to build ECKey: " + ex.getMessage(), ex);
628+
}
629+
}
630+
631+
}
632+
372633
}

0 commit comments

Comments
 (0)