Skip to content

Commit 50b6daf

Browse files
committed
Add NimbusJwtEncoder Builders
Closes spring-projectsgh-16267 Signed-off-by: Suraj Bhadrike <[email protected]>
1 parent 226e81d commit 50b6daf

File tree

2 files changed

+436
-3
lines changed

2 files changed

+436
-3
lines changed

oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/NimbusJwtEncoder.java

Lines changed: 244 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,26 +18,39 @@
1818

1919
import java.net.URI;
2020
import java.net.URL;
21+
import java.security.KeyPair;
22+
import java.security.interfaces.ECPrivateKey;
23+
import java.security.interfaces.ECPublicKey;
24+
import java.security.interfaces.RSAPrivateKey;
2125
import java.time.Instant;
2226
import java.util.ArrayList;
2327
import java.util.Date;
2428
import java.util.HashMap;
2529
import java.util.List;
2630
import java.util.Map;
2731
import java.util.Set;
32+
import java.util.UUID;
2833
import java.util.concurrent.ConcurrentHashMap;
2934

35+
import javax.crypto.SecretKey;
36+
3037
import com.nimbusds.jose.JOSEException;
3138
import com.nimbusds.jose.JOSEObjectType;
3239
import com.nimbusds.jose.JWSAlgorithm;
3340
import com.nimbusds.jose.JWSHeader;
3441
import com.nimbusds.jose.JWSSigner;
3542
import com.nimbusds.jose.crypto.factories.DefaultJWSSignerFactory;
43+
import com.nimbusds.jose.jwk.Curve;
44+
import com.nimbusds.jose.jwk.ECKey;
3645
import com.nimbusds.jose.jwk.JWK;
3746
import com.nimbusds.jose.jwk.JWKMatcher;
3847
import com.nimbusds.jose.jwk.JWKSelector;
48+
import com.nimbusds.jose.jwk.JWKSet;
3949
import com.nimbusds.jose.jwk.KeyType;
4050
import com.nimbusds.jose.jwk.KeyUse;
51+
import com.nimbusds.jose.jwk.OctetSequenceKey;
52+
import com.nimbusds.jose.jwk.RSAKey;
53+
import com.nimbusds.jose.jwk.source.ImmutableJWKSet;
4154
import com.nimbusds.jose.jwk.source.JWKSource;
4255
import com.nimbusds.jose.proc.SecurityContext;
4356
import com.nimbusds.jose.produce.JWSSignerFactory;
@@ -47,6 +60,7 @@
4760
import com.nimbusds.jwt.SignedJWT;
4861

4962
import org.springframework.core.convert.converter.Converter;
63+
import org.springframework.security.oauth2.jose.jws.MacAlgorithm;
5064
import org.springframework.security.oauth2.jose.jws.SignatureAlgorithm;
5165
import org.springframework.util.Assert;
5266
import org.springframework.util.CollectionUtils;
@@ -83,6 +97,8 @@ public final class NimbusJwtEncoder implements JwtEncoder {
8397

8498
private static final JWSSignerFactory JWS_SIGNER_FACTORY = new DefaultJWSSignerFactory();
8599

100+
private JwsHeader jwsHeader;
101+
86102
private final Map<JWK, JWSSigner> jwsSigners = new ConcurrentHashMap<>();
87103

88104
private final JWKSource<SecurityContext> jwkSource;
@@ -119,14 +135,16 @@ public void setJwkSelector(Converter<List<JWK>, JWK> jwkSelector) {
119135
this.jwkSelector = jwkSelector;
120136
}
121137

138+
public void setJwsHeader(JwsHeader jwsHeader) {
139+
this.jwsHeader = jwsHeader;
140+
}
141+
122142
@Override
123143
public Jwt encode(JwtEncoderParameters parameters) throws JwtEncodingException {
124144
Assert.notNull(parameters, "parameters cannot be null");
125145

126146
JwsHeader headers = parameters.getJwsHeader();
127-
if (headers == null) {
128-
headers = DEFAULT_JWS_HEADER;
129-
}
147+
headers = (headers != null) ? headers : (this.jwsHeader != null) ? this.jwsHeader : DEFAULT_JWS_HEADER;
130148
JwtClaimsSet claims = parameters.getClaims();
131149

132150
JWK jwk = selectJwk(headers);
@@ -369,4 +387,227 @@ private static URI convertAsURI(String header, URL url) {
369387
}
370388
}
371389

390+
/**
391+
* Creates a builder for constructing a {@link NimbusJwtEncoder} using the provided
392+
* {@link SecretKey}.
393+
* @param secretKey the {@link SecretKey} to use for signing JWTs
394+
* @return a {@link SecretKeyJwtEncoderBuilder} for further configuration
395+
* @since 7.0
396+
*/
397+
public static SecretKeyJwtEncoderBuilder withSecretKey(SecretKey secretKey) {
398+
Assert.notNull(secretKey, "secretKey cannot be null");
399+
return new SecretKeyJwtEncoderBuilder(secretKey);
400+
}
401+
402+
/**
403+
* Creates a builder for constructing a {@link NimbusJwtEncoder} using the provided
404+
* {@link KeyPair}. The key pair must contain either an {@link RSAKey} or an
405+
* {@link ECKey}.
406+
* @param keyPair the {@link KeyPair} to use for signing JWTs
407+
* @return a {@link KeyPairJwtEncoderBuilder} for further configuration
408+
* @since 7.0
409+
*/
410+
public static KeyPairJwtEncoderBuilder withKeyPair(KeyPair keyPair) {
411+
Assert.isTrue(keyPair != null && keyPair.getPrivate() != null && keyPair.getPublic() != null,
412+
"keyPair, its private key, and public key must not be null");
413+
Assert.isTrue(
414+
keyPair.getPrivate() instanceof java.security.interfaces.RSAKey
415+
|| keyPair.getPrivate() instanceof java.security.interfaces.ECKey,
416+
"keyPair must be an RSAKey or an ECKey");
417+
return new KeyPairJwtEncoderBuilder(keyPair);
418+
}
419+
420+
/**
421+
* A builder for creating {@link NimbusJwtEncoder} instances configured with a
422+
* {@link SecretKey}.
423+
*
424+
* @since 7.0
425+
*/
426+
public static final class SecretKeyJwtEncoderBuilder {
427+
428+
private final SecretKey secretKey;
429+
430+
private String keyId;
431+
432+
private JWSAlgorithm jwsAlgorithm = JWSAlgorithm.HS256;
433+
434+
private SecretKeyJwtEncoderBuilder(SecretKey secretKey) {
435+
this.secretKey = secretKey;
436+
}
437+
438+
/**
439+
* Sets the JWS algorithm to use for signing. Defaults to
440+
* {@link JWSAlgorithm#HS256}. Must be an HMAC-based algorithm (HS256, HS384, or
441+
* HS512).
442+
* @param macAlgorithm the {@link MacAlgorithm} to use
443+
* @return this builder instance for method chaining
444+
*/
445+
public SecretKeyJwtEncoderBuilder macAlgorithm(MacAlgorithm macAlgorithm) {
446+
Assert.notNull(macAlgorithm, "macAlgorithm cannot be null");
447+
this.jwsAlgorithm = JWSAlgorithm.parse(macAlgorithm.getName());
448+
return this;
449+
}
450+
451+
/**
452+
* Sets the key ID ({@code kid}) to be included in the JWK and potentially the JWS
453+
* header.
454+
* @param keyId the key identifier
455+
* @return this builder instance for method chaining
456+
*/
457+
public SecretKeyJwtEncoderBuilder keyId(String keyId) {
458+
this.keyId = keyId;
459+
return this;
460+
}
461+
462+
/**
463+
* Builds the {@link NimbusJwtEncoder} instance.
464+
* @return the configured {@link NimbusJwtEncoder}
465+
* @throws IllegalStateException if the configured JWS algorithm is not compatible
466+
* with a {@link SecretKey}.
467+
*/
468+
public NimbusJwtEncoder build() {
469+
Assert.state(JWSAlgorithm.Family.HMAC_SHA.contains(this.jwsAlgorithm),
470+
() -> "The algorithm '" + this.jwsAlgorithm + "' is not compatible with a SecretKey. "
471+
+ "Please use one of the HS256, HS384, or HS512 algorithms.");
472+
this.jwsAlgorithm = (this.jwsAlgorithm != null) ? this.jwsAlgorithm : JWSAlgorithm.HS256;
473+
474+
OctetSequenceKey.Builder builder = new OctetSequenceKey.Builder(this.secretKey).keyUse(KeyUse.SIGNATURE)
475+
.algorithm(this.jwsAlgorithm)
476+
.keyID(this.keyId);
477+
478+
OctetSequenceKey jwk = builder.build();
479+
JWKSource<SecurityContext> jwkSource = new ImmutableJWKSet<>(new JWKSet(jwk));
480+
NimbusJwtEncoder encoder = new NimbusJwtEncoder(jwkSource);
481+
encoder.setJwsHeader(JwsHeader.with(MacAlgorithm.from(this.jwsAlgorithm.getName())).build());
482+
return encoder;
483+
}
484+
485+
}
486+
487+
/**
488+
* A builder for creating {@link NimbusJwtEncoder} instances configured with a
489+
* {@link KeyPair}.
490+
*
491+
* @since 7.0
492+
*/
493+
public static final class KeyPairJwtEncoderBuilder {
494+
495+
private final KeyPair keyPair;
496+
497+
private String keyId;
498+
499+
private JWSAlgorithm jwsAlgorithm;
500+
501+
private KeyPairJwtEncoderBuilder(KeyPair keyPair) {
502+
this.keyPair = keyPair;
503+
}
504+
505+
/**
506+
* Sets the JWS algorithm to use for signing. Must be compatible with the key type
507+
* (RSA or EC). If not set, a default algorithm will be chosen based on the key
508+
* type (e.g., RS256 for RSA, ES256 for EC).
509+
* @param signatureAlgorithm the {@link SignatureAlgorithm} to use
510+
* @return this builder instance for method chaining
511+
*/
512+
public KeyPairJwtEncoderBuilder signatureAlgorithm(SignatureAlgorithm signatureAlgorithm) {
513+
Assert.notNull(signatureAlgorithm, "signatureAlgorithm cannot be null");
514+
this.jwsAlgorithm = JWSAlgorithm.parse(signatureAlgorithm.getName());
515+
return this;
516+
}
517+
518+
/**
519+
* Sets the key ID ({@code kid}) to be included in the JWK and potentially the JWS
520+
* header.
521+
* @param keyId the key identifier
522+
* @return this builder instance for method chaining
523+
*/
524+
public KeyPairJwtEncoderBuilder keyId(String keyId) {
525+
this.keyId = keyId;
526+
return this;
527+
}
528+
529+
530+
/**
531+
* Builds the {@link NimbusJwtEncoder} instance.
532+
* @return the configured {@link NimbusJwtEncoder}
533+
* @throws IllegalStateException if the key type is unsupported or the configured
534+
* JWS algorithm is not compatible with the key type.
535+
* @throws JwtEncodingException if the key is invalid (e.g., EC key with unknown
536+
* curve)
537+
*/
538+
public NimbusJwtEncoder build() {
539+
this.keyId = (this.keyId != null) ? this.keyId : UUID.randomUUID().toString();
540+
JWK jwk = buildJwk();
541+
JWKSource<SecurityContext> jwkSource = new ImmutableJWKSet<>(new JWKSet(jwk));
542+
NimbusJwtEncoder encoder = new NimbusJwtEncoder(jwkSource);
543+
JwsHeader jwsHeader = JwsHeader.with(SignatureAlgorithm.from(this.jwsAlgorithm.getName()))
544+
.keyId(jwk.getKeyID())
545+
.build();
546+
encoder.setJwsHeader(jwsHeader);
547+
return encoder;
548+
}
549+
550+
private JWK buildJwk() {
551+
if (this.keyPair.getPrivate() instanceof RSAPrivateKey) {
552+
return buildRsaJwk();
553+
}
554+
else if (this.keyPair.getPrivate() instanceof ECPrivateKey) {
555+
return buildEcJwk();
556+
}
557+
else {
558+
throw new IllegalStateException(
559+
"Unsupported key pair type: " + this.keyPair.getPrivate().getClass().getName());
560+
}
561+
}
562+
563+
private RSAKey buildRsaJwk() {
564+
if (this.jwsAlgorithm == null) {
565+
this.jwsAlgorithm = JWSAlgorithm.RS256;
566+
}
567+
Assert.state(JWSAlgorithm.Family.RSA.contains(this.jwsAlgorithm),
568+
() -> "The algorithm '" + this.jwsAlgorithm + "' is not compatible with an RSAKey. "
569+
+ "Please use one of the RS256, RS384, RS512, PS256, PS384, or PS512 algorithms.");
570+
571+
RSAKey.Builder builder = new RSAKey.Builder(
572+
(java.security.interfaces.RSAPublicKey) this.keyPair.getPublic())
573+
.privateKey(this.keyPair.getPrivate())
574+
.keyUse(KeyUse.SIGNATURE)
575+
.algorithm(this.jwsAlgorithm);
576+
577+
if (StringUtils.hasText(this.keyId)) {
578+
builder.keyID(this.keyId);
579+
}
580+
return builder.build();
581+
}
582+
583+
private com.nimbusds.jose.jwk.ECKey buildEcJwk() {
584+
if (this.jwsAlgorithm == null) {
585+
this.jwsAlgorithm = JWSAlgorithm.ES256;
586+
}
587+
Assert.state(JWSAlgorithm.Family.EC.contains(this.jwsAlgorithm),
588+
() -> "The algorithm '" + this.jwsAlgorithm + "' is not compatible with an ECKey. "
589+
+ "Please use one of the ES256, ES384, or ES512 algorithms.");
590+
591+
ECPublicKey publicKey = (ECPublicKey) this.keyPair.getPublic();
592+
Curve curve = Curve.forECParameterSpec(publicKey.getParams());
593+
if (curve == null) {
594+
throw new JwtEncodingException("Unable to determine Curve for EC public key.");
595+
}
596+
597+
com.nimbusds.jose.jwk.ECKey.Builder builder = new com.nimbusds.jose.jwk.ECKey.Builder(curve, publicKey)
598+
.privateKey(this.keyPair.getPrivate())
599+
.keyUse(KeyUse.SIGNATURE)
600+
.keyID(this.keyId)
601+
.algorithm(this.jwsAlgorithm);
602+
603+
try {
604+
return builder.build();
605+
}
606+
catch (IllegalStateException ex) {
607+
throw new IllegalArgumentException("Failed to build ECKey: " + ex.getMessage(), ex);
608+
}
609+
}
610+
611+
}
612+
372613
}

0 commit comments

Comments
 (0)