Skip to content

Commit 6370999

Browse files
committed
NimbusJwtEncoder should simplify constructing with javax.security.Keys Signed-off-by: Suraj Bhadrike <[email protected]>
Signed-off-by: surajbh <[email protected]>
1 parent 226e81d commit 6370999

File tree

2 files changed

+529
-21
lines changed

2 files changed

+529
-21
lines changed

oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/NimbusJwtEncoder.java

Lines changed: 323 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -16,42 +16,55 @@
1616

1717
package org.springframework.security.oauth2.jwt;
1818

19-
import java.net.URI;
20-
import java.net.URL;
21-
import java.time.Instant;
22-
import java.util.ArrayList;
23-
import java.util.Date;
24-
import java.util.HashMap;
25-
import java.util.List;
26-
import java.util.Map;
27-
import java.util.Set;
28-
import java.util.concurrent.ConcurrentHashMap;
29-
3019
import com.nimbusds.jose.JOSEException;
3120
import com.nimbusds.jose.JOSEObjectType;
3221
import com.nimbusds.jose.JWSAlgorithm;
3322
import com.nimbusds.jose.JWSHeader;
3423
import com.nimbusds.jose.JWSSigner;
3524
import com.nimbusds.jose.crypto.factories.DefaultJWSSignerFactory;
25+
import com.nimbusds.jose.jwk.Curve;
26+
import com.nimbusds.jose.jwk.ECKey;
3627
import com.nimbusds.jose.jwk.JWK;
3728
import com.nimbusds.jose.jwk.JWKMatcher;
3829
import com.nimbusds.jose.jwk.JWKSelector;
30+
import com.nimbusds.jose.jwk.JWKSet;
3931
import com.nimbusds.jose.jwk.KeyType;
4032
import com.nimbusds.jose.jwk.KeyUse;
33+
import com.nimbusds.jose.jwk.OctetSequenceKey;
34+
import com.nimbusds.jose.jwk.RSAKey;
35+
import com.nimbusds.jose.jwk.source.ImmutableJWKSet;
4136
import com.nimbusds.jose.jwk.source.JWKSource;
4237
import com.nimbusds.jose.proc.SecurityContext;
4338
import com.nimbusds.jose.produce.JWSSignerFactory;
4439
import com.nimbusds.jose.util.Base64;
4540
import com.nimbusds.jose.util.Base64URL;
4641
import com.nimbusds.jwt.JWTClaimsSet;
4742
import com.nimbusds.jwt.SignedJWT;
48-
4943
import org.springframework.core.convert.converter.Converter;
44+
import org.springframework.security.oauth2.jose.jws.MacAlgorithm;
5045
import org.springframework.security.oauth2.jose.jws.SignatureAlgorithm;
5146
import org.springframework.util.Assert;
5247
import org.springframework.util.CollectionUtils;
5348
import org.springframework.util.StringUtils;
5449

50+
import javax.crypto.SecretKey;
51+
import java.net.URI;
52+
import java.net.URL;
53+
import java.security.KeyPair;
54+
import java.security.interfaces.ECPrivateKey;
55+
import java.security.interfaces.ECPublicKey;
56+
import java.security.interfaces.RSAPrivateKey;
57+
import java.time.Instant;
58+
import java.util.ArrayList;
59+
import java.util.Date;
60+
import java.util.HashMap;
61+
import java.util.List;
62+
import java.util.Map;
63+
import java.util.Set;
64+
import java.util.UUID;
65+
import java.util.concurrent.ConcurrentHashMap;
66+
import java.util.stream.Stream;
67+
5568
/**
5669
* An implementation of a {@link JwtEncoder} that encodes a JSON Web Token (JWT) using the
5770
* JSON Web Signature (JWS) Compact Serialization format. The private/secret key used for
@@ -79,7 +92,7 @@ public final class NimbusJwtEncoder implements JwtEncoder {
7992

8093
private static final String ENCODING_ERROR_MESSAGE_TEMPLATE = "An error occurred while attempting to encode the Jwt: %s";
8194

82-
private static final JwsHeader DEFAULT_JWS_HEADER = JwsHeader.with(SignatureAlgorithm.RS256).build();
95+
private static JwsHeader DEFAULT_JWS_HEADER;
8396

8497
private static final JWSSignerFactory JWS_SIGNER_FACTORY = new DefaultJWSSignerFactory();
8598

@@ -102,6 +115,12 @@ public final class NimbusJwtEncoder implements JwtEncoder {
102115
public NimbusJwtEncoder(JWKSource<SecurityContext> jwkSource) {
103116
Assert.notNull(jwkSource, "jwkSource cannot be null");
104117
this.jwkSource = jwkSource;
118+
this.DEFAULT_JWS_HEADER = JwsHeader.with(SignatureAlgorithm.RS256).build();
119+
}
120+
public NimbusJwtEncoder(JWKSource<SecurityContext> jwkSource, JwsHeader jwsHeader) {
121+
Assert.notNull(jwkSource, "jwkSource cannot be null");
122+
this.jwkSource = jwkSource;
123+
this.DEFAULT_JWS_HEADER = jwsHeader;
105124
}
106125

107126
/**
@@ -369,4 +388,295 @@ private static URI convertAsURI(String header, URL url) {
369388
}
370389
}
371390

391+
/**
392+
* Creates a builder for constructing a {@link NimbusJwtEncoder} using the provided
393+
* {@link SecretKey}.
394+
*
395+
* @param secretKey the {@link SecretKey} to use for signing JWTs
396+
* @return a {@link SecretKeyJwtEncoderBuilder} for further configuration
397+
* @since // TODO: Update version
398+
*/
399+
public static SecretKeyJwtEncoderBuilder withSecretKey(SecretKey secretKey) {
400+
Assert.notNull(secretKey, "secretKey cannot be null");
401+
return new SecretKeyJwtEncoderBuilder(secretKey);
402+
}
403+
404+
/**
405+
* Creates a builder for constructing a {@link NimbusJwtEncoder} using the provided
406+
* {@link KeyPair}. The key pair must contain either an {@link RSAKey} or an
407+
* {@link ECKey}.
408+
*
409+
* @param keyPair the {@link KeyPair} to use for signing JWTs
410+
* @return a {@link KeyPairJwtEncoderBuilder} for further configuration
411+
* @since // TODO: Update version
412+
*/
413+
public static KeyPairJwtEncoderBuilder withKeyPair(KeyPair keyPair) {
414+
Assert.notNull(keyPair, "keyPair cannot be null");
415+
Assert.notNull(keyPair.getPrivate(), "keyPair must contain a private key");
416+
Assert.notNull(keyPair.getPublic(), "keyPair must contain a public key");
417+
Assert.isTrue(keyPair.getPrivate() instanceof java.security.interfaces.RSAKey || keyPair.getPrivate() instanceof java.security.interfaces.ECKey,
418+
"keyPair must be an RSAKey or an ECKey");
419+
return new KeyPairJwtEncoderBuilder(keyPair);
420+
}
421+
422+
/**
423+
* A builder for creating {@link NimbusJwtEncoder} instances configured with a
424+
* {@link SecretKey}.
425+
*
426+
* @since // TODO: Update version
427+
*/
428+
public static final class SecretKeyJwtEncoderBuilder {
429+
430+
private final SecretKey secretKey;
431+
432+
private String keyId;
433+
434+
private JWSAlgorithm jwsAlgorithm = JWSAlgorithm.HS256;
435+
436+
private Converter<List<JWK>, JWK> jwkSelector;
437+
438+
private JwsHeader jwsHeader;
439+
440+
private SecretKeyJwtEncoderBuilder(SecretKey secretKey) {
441+
this.secretKey = secretKey;
442+
}
443+
444+
/**
445+
* Sets the JWS algorithm to use for signing. Defaults to {@link JWSAlgorithm#HS256}.
446+
* Must be an HMAC-based algorithm (HS256, HS384, or HS512).
447+
*
448+
* @param macAlgorithm the {@link MacAlgorithm} to use
449+
* @return this builder instance for method chaining
450+
*/
451+
public SecretKeyJwtEncoderBuilder macAlgorithm(MacAlgorithm macAlgorithm) {
452+
Assert.notNull(macAlgorithm, "macAlgorithm cannot be null");
453+
this.jwsAlgorithm = JWSAlgorithm.parse(macAlgorithm.getName());
454+
return this;
455+
}
456+
457+
/**
458+
* Sets the key ID ({@code kid}) to be included in the JWK and potentially the JWS
459+
* header.
460+
*
461+
* @param keyId the key identifier
462+
* @return this builder instance for method chaining
463+
*/
464+
public SecretKeyJwtEncoderBuilder keyId(String keyId) {
465+
this.keyId = keyId;
466+
return this;
467+
}
468+
469+
/**
470+
* Configures the {@link Converter} used to select the JWK when multiple keys match
471+
* the header criteria. This is generally not needed for single-key setups but is
472+
* provided for consistency.
473+
*
474+
* @param jwkSelector the {@link Converter} to select a {@link JWK}
475+
* @return this builder instance for method chaining
476+
* @since // TODO: Update version
477+
*/
478+
public SecretKeyJwtEncoderBuilder jwkSelector(Converter<List<JWK>, JWK> jwkSelector) {
479+
Assert.notNull(jwkSelector, "jwkSelector cannot be null");
480+
this.jwkSelector = jwkSelector;
481+
return this;
482+
}
483+
484+
485+
/**
486+
* Builds the {@link NimbusJwtEncoder} instance.
487+
*
488+
* @return the configured {@link NimbusJwtEncoder}
489+
* @throws IllegalStateException if the configured JWS algorithm is not compatible
490+
* with a {@link SecretKey}.
491+
*/
492+
public NimbusJwtEncoder build() {
493+
this.jwsAlgorithm = this.jwsAlgorithm != null ? this.jwsAlgorithm : JWSAlgorithm.HS256;
494+
Assert.state(JWSAlgorithm.Family.HMAC_SHA.contains(this.jwsAlgorithm),
495+
() -> "The algorithm '" + this.jwsAlgorithm + "' is not compatible with a SecretKey. "
496+
+ "Please use one of the HS256, HS384, or HS512 algorithms.");
497+
498+
OctetSequenceKey.Builder builder = new OctetSequenceKey.Builder(this.secretKey)
499+
.keyUse(KeyUse.SIGNATURE)
500+
.algorithm(this.jwsAlgorithm);
501+
502+
this.jwsHeader = JwsHeader.with(MacAlgorithm.from(this.jwsAlgorithm.getName())).build();
503+
504+
if (StringUtils.hasText(this.keyId)) {
505+
builder.keyID(this.keyId);
506+
}
507+
508+
OctetSequenceKey jwk = builder.build();
509+
JWKSource<SecurityContext> jwkSource = new ImmutableJWKSet<>(new JWKSet(jwk));
510+
NimbusJwtEncoder encoder = this.jwsHeader != null
511+
? new NimbusJwtEncoder(jwkSource, this.jwsHeader)
512+
: new NimbusJwtEncoder(jwkSource);
513+
if (this.jwkSelector != null) {
514+
encoder.setJwkSelector(this.jwkSelector);
515+
} else {
516+
encoder.setJwkSelector(jwkSet -> jwkSet.stream().findFirst().orElse(null));
517+
}
518+
return encoder;
519+
}
520+
521+
}
522+
523+
/**
524+
* A builder for creating {@link NimbusJwtEncoder} instances configured with a
525+
* {@link KeyPair}.
526+
*
527+
* @since // TODO: Update version
528+
*/
529+
public static final class KeyPairJwtEncoderBuilder {
530+
531+
private final KeyPair keyPair;
532+
533+
private String keyId;
534+
535+
private JWSAlgorithm jwsAlgorithm;
536+
537+
private Converter<List<JWK>, JWK> jwkSelector;
538+
539+
private JwsHeader jwsHeader;
540+
541+
private static boolean isEcAlgorithm(SignatureAlgorithm algorithm) {
542+
return Stream.of(SignatureAlgorithm.ES256, SignatureAlgorithm.ES384, SignatureAlgorithm.ES512)
543+
.anyMatch(ecAlg -> ecAlg == algorithm);
544+
}
545+
546+
private KeyPairJwtEncoderBuilder(KeyPair keyPair) {
547+
this.keyPair = keyPair;
548+
}
549+
550+
/**
551+
* Sets the JWS algorithm to use for signing. Must be compatible with the key type
552+
* (RSA or EC). If not set, a default algorithm will be chosen based on the key
553+
* type (e.g., RS256 for RSA, ES256 for EC).
554+
*
555+
* @param signatureAlgorithm the {@link SignatureAlgorithm} to use
556+
* @return this builder instance for method chaining
557+
*/
558+
public KeyPairJwtEncoderBuilder signatureAlgorithm(SignatureAlgorithm signatureAlgorithm) {
559+
Assert.notNull(signatureAlgorithm, "signatureAlgorithm cannot be null");
560+
this.jwsAlgorithm = JWSAlgorithm.parse(signatureAlgorithm.getName());
561+
return this;
562+
}
563+
564+
/**
565+
* Sets the key ID ({@code kid}) to be included in the JWK and potentially the JWS
566+
* header.
567+
*
568+
* @param keyId the key identifier
569+
* @return this builder instance for method chaining
570+
*/
571+
public KeyPairJwtEncoderBuilder keyId(String keyId) {
572+
this.keyId = keyId;
573+
return this;
574+
}
575+
576+
/**
577+
* Configures the {@link Converter} used to select the JWK when multiple keys match
578+
* the header criteria. This is generally not needed for single-key setups but is
579+
* provided for consistency.
580+
*
581+
* @param jwkSelector the {@link Converter} to select a {@link JWK}
582+
* @return this builder instance for method chaining
583+
* @since // TODO: Update version
584+
*/
585+
public KeyPairJwtEncoderBuilder jwkSelector(Converter<List<JWK>, JWK> jwkSelector) {
586+
Assert.notNull(jwkSelector, "jwkSelector cannot be null");
587+
this.jwkSelector = jwkSelector;
588+
return this;
589+
}
590+
591+
/**
592+
* Builds the {@link NimbusJwtEncoder} instance.
593+
*
594+
* @return the configured {@link NimbusJwtEncoder}
595+
* @throws IllegalStateException if the key type is unsupported or the configured
596+
* JWS algorithm is not compatible with the key type.
597+
* @throws JwtEncodingException if the key is invalid (e.g., EC key with unknown curve)
598+
*/
599+
public NimbusJwtEncoder build() {
600+
this.keyId = this.keyId != null ? this.keyId : UUID.randomUUID().toString();
601+
JWK jwk = buildJwk();
602+
JWKSource<SecurityContext> jwkSource = new ImmutableJWKSet<>(new JWKSet(jwk));
603+
NimbusJwtEncoder encoder = this.jwsHeader == null ?
604+
new NimbusJwtEncoder(jwkSource) :
605+
new NimbusJwtEncoder(jwkSource, this.jwsHeader);
606+
if (this.jwkSelector != null) {
607+
encoder.setJwkSelector(this.jwkSelector);
608+
} else {
609+
encoder.setJwkSelector(jwkSet -> jwkSet.stream().findFirst().orElse(null));
610+
}
611+
return encoder;
612+
}
613+
614+
private JWK buildJwk() {
615+
if (this.keyPair.getPrivate() instanceof RSAPrivateKey) {
616+
RSAKey rsaKey = buildRsaJwk();
617+
this.jwsHeader = JwsHeader.with(SignatureAlgorithm.from(jwsAlgorithm.getName())) // Use Spring's enum
618+
.keyId(rsaKey.getKeyID())
619+
.build();
620+
return rsaKey;
621+
} else if (this.keyPair.getPrivate() instanceof ECPrivateKey) {
622+
ECKey ecKey = buildEcJwk();
623+
this.jwsHeader = JwsHeader.with(SignatureAlgorithm.from(jwsAlgorithm.getName())) // Use Spring's enum
624+
.keyId(ecKey.getKeyID())
625+
.build();
626+
return ecKey;
627+
} else {
628+
throw new IllegalStateException("Unsupported key pair type: " + this.keyPair.getPrivate().getClass().getName());
629+
}
630+
}
631+
632+
private RSAKey buildRsaJwk() {
633+
if (this.jwsAlgorithm == null) {
634+
this.jwsAlgorithm = JWSAlgorithm.RS256;
635+
}
636+
Assert.state(JWSAlgorithm.Family.RSA.contains(this.jwsAlgorithm),
637+
() -> "The algorithm '" + this.jwsAlgorithm + "' is not compatible with an RSAKey. "
638+
+ "Please use one of the RS256, RS384, RS512, PS256, PS384, or PS512 algorithms.");
639+
640+
RSAKey.Builder builder = new RSAKey.Builder((java.security.interfaces.RSAPublicKey) this.keyPair.getPublic())
641+
.privateKey(this.keyPair.getPrivate())
642+
.keyUse(KeyUse.SIGNATURE)
643+
.algorithm(this.jwsAlgorithm);
644+
645+
if (StringUtils.hasText(this.keyId)) {
646+
builder.keyID(this.keyId);
647+
}
648+
return builder.build();
649+
}
650+
651+
private com.nimbusds.jose.jwk.ECKey buildEcJwk() {
652+
if (this.jwsAlgorithm == null) {
653+
this.jwsAlgorithm = JWSAlgorithm.ES256;
654+
}
655+
Assert.state(JWSAlgorithm.Family.EC.contains(this.jwsAlgorithm),
656+
() -> "The algorithm '" + this.jwsAlgorithm + "' is not compatible with an ECKey. "
657+
+ "Please use one of the ES256, ES384, or ES512 algorithms.");
658+
659+
ECPublicKey publicKey = (ECPublicKey) this.keyPair.getPublic();
660+
Curve curve = Curve.forECParameterSpec(publicKey.getParams());
661+
if (curve == null) {
662+
throw new JwtEncodingException("Unable to determine Curve for EC public key.");
663+
}
664+
665+
com.nimbusds.jose.jwk.ECKey.Builder builder = new com.nimbusds.jose.jwk.ECKey.Builder(curve, publicKey)
666+
.privateKey(this.keyPair.getPrivate())
667+
.keyUse(KeyUse.SIGNATURE)
668+
.algorithm(this.jwsAlgorithm);
669+
670+
if (StringUtils.hasText(this.keyId)) {
671+
builder.keyID(this.keyId);
672+
}
673+
674+
try {
675+
return builder.build();
676+
} catch (IllegalStateException ex) {
677+
throw new JwtEncodingException("Failed to build ECKey: " + ex.getMessage(), ex);
678+
}
679+
}
680+
}
681+
372682
}

0 commit comments

Comments
 (0)