Skip to content

Commit 78ef577

Browse files
committed
NimbusJwtEncoder should simplify constructing with javax.security.Keys Signed-off-by: Suraj Bhadrike <[email protected]>
Signed-off-by: surajbh <[email protected]>
1 parent 226e81d commit 78ef577

File tree

2 files changed

+518
-45
lines changed

2 files changed

+518
-45
lines changed

oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/NimbusJwtEncoder.java

+314-23
Original file line numberDiff line numberDiff line change
@@ -16,42 +16,36 @@
1616

1717
package org.springframework.security.oauth2.jwt;
1818

19-
import java.net.URI;
20-
import java.net.URL;
21-
import java.time.Instant;
22-
import java.util.ArrayList;
23-
import java.util.Date;
24-
import java.util.HashMap;
25-
import java.util.List;
26-
import java.util.Map;
27-
import java.util.Set;
28-
import java.util.concurrent.ConcurrentHashMap;
29-
30-
import com.nimbusds.jose.JOSEException;
31-
import com.nimbusds.jose.JOSEObjectType;
32-
import com.nimbusds.jose.JWSAlgorithm;
33-
import com.nimbusds.jose.JWSHeader;
34-
import com.nimbusds.jose.JWSSigner;
19+
import com.nimbusds.jose.*;
3520
import com.nimbusds.jose.crypto.factories.DefaultJWSSignerFactory;
36-
import com.nimbusds.jose.jwk.JWK;
37-
import com.nimbusds.jose.jwk.JWKMatcher;
38-
import com.nimbusds.jose.jwk.JWKSelector;
39-
import com.nimbusds.jose.jwk.KeyType;
40-
import com.nimbusds.jose.jwk.KeyUse;
21+
import com.nimbusds.jose.jwk.*;
22+
import com.nimbusds.jose.jwk.source.ImmutableJWKSet;
4123
import com.nimbusds.jose.jwk.source.JWKSource;
4224
import com.nimbusds.jose.proc.SecurityContext;
4325
import com.nimbusds.jose.produce.JWSSignerFactory;
4426
import com.nimbusds.jose.util.Base64;
4527
import com.nimbusds.jose.util.Base64URL;
4628
import com.nimbusds.jwt.JWTClaimsSet;
4729
import com.nimbusds.jwt.SignedJWT;
48-
4930
import org.springframework.core.convert.converter.Converter;
31+
import org.springframework.security.oauth2.jose.jws.MacAlgorithm;
5032
import org.springframework.security.oauth2.jose.jws.SignatureAlgorithm;
5133
import org.springframework.util.Assert;
5234
import org.springframework.util.CollectionUtils;
5335
import org.springframework.util.StringUtils;
5436

37+
import javax.crypto.SecretKey;
38+
import java.net.URI;
39+
import java.net.URL;
40+
import java.security.KeyPair;
41+
import java.security.interfaces.ECPrivateKey;
42+
import java.security.interfaces.ECPublicKey;
43+
import java.security.interfaces.RSAPrivateKey;
44+
import java.time.Instant;
45+
import java.util.*;
46+
import java.util.concurrent.ConcurrentHashMap;
47+
import java.util.stream.Stream;
48+
5549
/**
5650
* An implementation of a {@link JwtEncoder} that encodes a JSON Web Token (JWT) using the
5751
* JSON Web Signature (JWS) Compact Serialization format. The private/secret key used for
@@ -79,7 +73,7 @@ public final class NimbusJwtEncoder implements JwtEncoder {
7973

8074
private static final String ENCODING_ERROR_MESSAGE_TEMPLATE = "An error occurred while attempting to encode the Jwt: %s";
8175

82-
private static final JwsHeader DEFAULT_JWS_HEADER = JwsHeader.with(SignatureAlgorithm.RS256).build();
76+
private static JwsHeader DEFAULT_JWS_HEADER;
8377

8478
private static final JWSSignerFactory JWS_SIGNER_FACTORY = new DefaultJWSSignerFactory();
8579

@@ -102,6 +96,12 @@ public final class NimbusJwtEncoder implements JwtEncoder {
10296
public NimbusJwtEncoder(JWKSource<SecurityContext> jwkSource) {
10397
Assert.notNull(jwkSource, "jwkSource cannot be null");
10498
this.jwkSource = jwkSource;
99+
this.DEFAULT_JWS_HEADER = JwsHeader.with(SignatureAlgorithm.RS256).build();
100+
}
101+
public NimbusJwtEncoder(JWKSource<SecurityContext> jwkSource, JwsHeader jwsHeader) {
102+
Assert.notNull(jwkSource, "jwkSource cannot be null");
103+
this.jwkSource = jwkSource;
104+
this.DEFAULT_JWS_HEADER = jwsHeader;
105105
}
106106

107107
/**
@@ -369,4 +369,295 @@ private static URI convertAsURI(String header, URL url) {
369369
}
370370
}
371371

372+
/**
373+
* Creates a builder for constructing a {@link NimbusJwtEncoder} using the provided
374+
* {@link SecretKey}.
375+
*
376+
* @param secretKey the {@link SecretKey} to use for signing JWTs
377+
* @return a {@link SecretKeyJwtEncoderBuilder} for further configuration
378+
* @since // TODO: Update version
379+
*/
380+
public static SecretKeyJwtEncoderBuilder withSecretKey(SecretKey secretKey) {
381+
Assert.notNull(secretKey, "secretKey cannot be null");
382+
return new SecretKeyJwtEncoderBuilder(secretKey);
383+
}
384+
385+
/**
386+
* Creates a builder for constructing a {@link NimbusJwtEncoder} using the provided
387+
* {@link KeyPair}. The key pair must contain either an {@link RSAKey} or an
388+
* {@link ECKey}.
389+
*
390+
* @param keyPair the {@link KeyPair} to use for signing JWTs
391+
* @return a {@link KeyPairJwtEncoderBuilder} for further configuration
392+
* @since // TODO: Update version
393+
*/
394+
public static KeyPairJwtEncoderBuilder withKeyPair(KeyPair keyPair) {
395+
Assert.notNull(keyPair, "keyPair cannot be null");
396+
Assert.notNull(keyPair.getPrivate(), "keyPair must contain a private key");
397+
Assert.notNull(keyPair.getPublic(), "keyPair must contain a public key");
398+
Assert.isTrue(keyPair.getPrivate() instanceof java.security.interfaces.RSAKey || keyPair.getPrivate() instanceof java.security.interfaces.ECKey,
399+
"keyPair must be an RSAKey or an ECKey");
400+
return new KeyPairJwtEncoderBuilder(keyPair);
401+
}
402+
403+
/**
404+
* A builder for creating {@link NimbusJwtEncoder} instances configured with a
405+
* {@link SecretKey}.
406+
*
407+
* @since // TODO: Update version
408+
*/
409+
public static final class SecretKeyJwtEncoderBuilder {
410+
411+
private final SecretKey secretKey;
412+
413+
private String keyId;
414+
415+
private JWSAlgorithm jwsAlgorithm = JWSAlgorithm.HS256;
416+
417+
private Converter<List<JWK>, JWK> jwkSelector;
418+
419+
private JwsHeader jwsHeader;
420+
421+
private SecretKeyJwtEncoderBuilder(SecretKey secretKey) {
422+
this.secretKey = secretKey;
423+
}
424+
425+
/**
426+
* Sets the JWS algorithm to use for signing. Defaults to {@link JWSAlgorithm#HS256}.
427+
* Must be an HMAC-based algorithm (HS256, HS384, or HS512).
428+
*
429+
* @param macAlgorithm the {@link MacAlgorithm} to use
430+
* @return this builder instance for method chaining
431+
*/
432+
public SecretKeyJwtEncoderBuilder macAlgorithm(MacAlgorithm macAlgorithm) {
433+
Assert.notNull(macAlgorithm, "macAlgorithm cannot be null");
434+
this.jwsAlgorithm = JWSAlgorithm.parse(macAlgorithm.getName());
435+
return this;
436+
}
437+
438+
/**
439+
* Sets the key ID ({@code kid}) to be included in the JWK and potentially the JWS
440+
* header.
441+
*
442+
* @param keyId the key identifier
443+
* @return this builder instance for method chaining
444+
*/
445+
public SecretKeyJwtEncoderBuilder keyId(String keyId) {
446+
this.keyId = keyId;
447+
return this;
448+
}
449+
450+
/**
451+
* Configures the {@link Converter} used to select the JWK when multiple keys match
452+
* the header criteria. This is generally not needed for single-key setups but is
453+
* provided for consistency.
454+
*
455+
* @param jwkSelector the {@link Converter} to select a {@link JWK}
456+
* @return this builder instance for method chaining
457+
* @since // TODO: Update version
458+
*/
459+
public SecretKeyJwtEncoderBuilder jwkSelector(Converter<List<JWK>, JWK> jwkSelector) {
460+
Assert.notNull(jwkSelector, "jwkSelector cannot be null");
461+
this.jwkSelector = jwkSelector;
462+
return this;
463+
}
464+
465+
466+
/**
467+
* Builds the {@link NimbusJwtEncoder} instance.
468+
*
469+
* @return the configured {@link NimbusJwtEncoder}
470+
* @throws IllegalStateException if the configured JWS algorithm is not compatible
471+
* with a {@link SecretKey}.
472+
*/
473+
public NimbusJwtEncoder build() {
474+
this.jwsAlgorithm = this.jwsAlgorithm != null ? this.jwsAlgorithm : JWSAlgorithm.HS256;
475+
Assert.state(JWSAlgorithm.Family.HMAC_SHA.contains(this.jwsAlgorithm),
476+
() -> "The algorithm '" + this.jwsAlgorithm + "' is not compatible with a SecretKey. "
477+
+ "Please use one of the HS256, HS384, or HS512 algorithms.");
478+
479+
OctetSequenceKey.Builder builder = new OctetSequenceKey.Builder(this.secretKey)
480+
.keyUse(KeyUse.SIGNATURE)
481+
.algorithm(this.jwsAlgorithm);
482+
483+
this.jwsHeader = JwsHeader.with(MacAlgorithm.from(this.jwsAlgorithm.getName())).build();
484+
485+
if (StringUtils.hasText(this.keyId)) {
486+
builder.keyID(this.keyId);
487+
}
488+
489+
OctetSequenceKey jwk = builder.build();
490+
JWKSource<SecurityContext> jwkSource = new ImmutableJWKSet<>(new JWKSet(jwk));
491+
NimbusJwtEncoder encoder = this.jwsHeader != null
492+
? new NimbusJwtEncoder(jwkSource, this.jwsHeader)
493+
: new NimbusJwtEncoder(jwkSource);
494+
if (this.jwkSelector != null) {
495+
encoder.setJwkSelector(this.jwkSelector);
496+
} else {
497+
encoder.setJwkSelector(jwkSet -> jwkSet.stream().findFirst().orElse(null));
498+
}
499+
return encoder;
500+
}
501+
502+
}
503+
504+
/**
505+
* A builder for creating {@link NimbusJwtEncoder} instances configured with a
506+
* {@link KeyPair}.
507+
*
508+
* @since // TODO: Update version
509+
*/
510+
public static final class KeyPairJwtEncoderBuilder {
511+
512+
private final KeyPair keyPair;
513+
514+
private String keyId;
515+
516+
private JWSAlgorithm jwsAlgorithm;
517+
518+
private Converter<List<JWK>, JWK> jwkSelector;
519+
520+
private JwsHeader jwsHeader;
521+
522+
private static boolean isEcAlgorithm(SignatureAlgorithm algorithm) {
523+
return Stream.of(SignatureAlgorithm.ES256, SignatureAlgorithm.ES384, SignatureAlgorithm.ES512)
524+
.anyMatch(ecAlg -> ecAlg == algorithm);
525+
}
526+
527+
private KeyPairJwtEncoderBuilder(KeyPair keyPair) {
528+
this.keyPair = keyPair;
529+
}
530+
531+
/**
532+
* Sets the JWS algorithm to use for signing. Must be compatible with the key type
533+
* (RSA or EC). If not set, a default algorithm will be chosen based on the key
534+
* type (e.g., RS256 for RSA, ES256 for EC).
535+
*
536+
* @param signatureAlgorithm the {@link SignatureAlgorithm} to use
537+
* @return this builder instance for method chaining
538+
*/
539+
public KeyPairJwtEncoderBuilder signatureAlgorithm(SignatureAlgorithm signatureAlgorithm) {
540+
Assert.notNull(signatureAlgorithm, "signatureAlgorithm cannot be null");
541+
this.jwsAlgorithm = JWSAlgorithm.parse(signatureAlgorithm.getName());
542+
return this;
543+
}
544+
545+
/**
546+
* Sets the key ID ({@code kid}) to be included in the JWK and potentially the JWS
547+
* header.
548+
*
549+
* @param keyId the key identifier
550+
* @return this builder instance for method chaining
551+
*/
552+
public KeyPairJwtEncoderBuilder keyId(String keyId) {
553+
this.keyId = keyId;
554+
return this;
555+
}
556+
557+
/**
558+
* Configures the {@link Converter} used to select the JWK when multiple keys match
559+
* the header criteria. This is generally not needed for single-key setups but is
560+
* provided for consistency.
561+
*
562+
* @param jwkSelector the {@link Converter} to select a {@link JWK}
563+
* @return this builder instance for method chaining
564+
* @since // TODO: Update version
565+
*/
566+
public KeyPairJwtEncoderBuilder jwkSelector(Converter<List<JWK>, JWK> jwkSelector) {
567+
Assert.notNull(jwkSelector, "jwkSelector cannot be null");
568+
this.jwkSelector = jwkSelector;
569+
return this;
570+
}
571+
572+
/**
573+
* Builds the {@link NimbusJwtEncoder} instance.
574+
*
575+
* @return the configured {@link NimbusJwtEncoder}
576+
* @throws IllegalStateException if the key type is unsupported or the configured
577+
* JWS algorithm is not compatible with the key type.
578+
* @throws JwtEncodingException if the key is invalid (e.g., EC key with unknown curve)
579+
*/
580+
public NimbusJwtEncoder build() {
581+
this.keyId = this.keyId != null ? this.keyId : UUID.randomUUID().toString();
582+
JWK jwk = buildJwk();
583+
JWKSource<SecurityContext> jwkSource = new ImmutableJWKSet<>(new JWKSet(jwk));
584+
NimbusJwtEncoder encoder = this.jwsHeader == null ?
585+
new NimbusJwtEncoder(jwkSource) :
586+
new NimbusJwtEncoder(jwkSource, this.jwsHeader);
587+
if (this.jwkSelector != null) {
588+
encoder.setJwkSelector(this.jwkSelector);
589+
} else {
590+
encoder.setJwkSelector(jwkSet -> jwkSet.stream().findFirst().orElse(null));
591+
}
592+
return encoder;
593+
}
594+
595+
private JWK buildJwk() {
596+
if (this.keyPair.getPrivate() instanceof RSAPrivateKey) {
597+
RSAKey rsaKey = buildRsaJwk();
598+
this.jwsHeader = JwsHeader.with(SignatureAlgorithm.from(jwsAlgorithm.getName())) // Use Spring's enum
599+
.keyId(rsaKey.getKeyID())
600+
.build();
601+
return rsaKey;
602+
} else if (this.keyPair.getPrivate() instanceof ECPrivateKey) {
603+
ECKey ecKey = buildEcJwk();
604+
this.jwsHeader = JwsHeader.with(SignatureAlgorithm.from(jwsAlgorithm.getName())) // Use Spring's enum
605+
.keyId(ecKey.getKeyID())
606+
.build();
607+
return ecKey;
608+
} else {
609+
throw new IllegalStateException("Unsupported key pair type: " + this.keyPair.getPrivate().getClass().getName());
610+
}
611+
}
612+
613+
private RSAKey buildRsaJwk() {
614+
if (this.jwsAlgorithm == null) {
615+
this.jwsAlgorithm = JWSAlgorithm.RS256;
616+
}
617+
Assert.state(JWSAlgorithm.Family.RSA.contains(this.jwsAlgorithm),
618+
() -> "The algorithm '" + this.jwsAlgorithm + "' is not compatible with an RSAKey. "
619+
+ "Please use one of the RS256, RS384, RS512, PS256, PS384, or PS512 algorithms.");
620+
621+
RSAKey.Builder builder = new RSAKey.Builder((java.security.interfaces.RSAPublicKey) this.keyPair.getPublic())
622+
.privateKey(this.keyPair.getPrivate())
623+
.keyUse(KeyUse.SIGNATURE)
624+
.algorithm(this.jwsAlgorithm);
625+
626+
if (StringUtils.hasText(this.keyId)) {
627+
builder.keyID(this.keyId);
628+
}
629+
return builder.build();
630+
}
631+
632+
private com.nimbusds.jose.jwk.ECKey buildEcJwk() {
633+
if (this.jwsAlgorithm == null) {
634+
this.jwsAlgorithm = JWSAlgorithm.ES256;
635+
}
636+
Assert.state(JWSAlgorithm.Family.EC.contains(this.jwsAlgorithm),
637+
() -> "The algorithm '" + this.jwsAlgorithm + "' is not compatible with an ECKey. "
638+
+ "Please use one of the ES256, ES384, or ES512 algorithms.");
639+
640+
ECPublicKey publicKey = (ECPublicKey) this.keyPair.getPublic();
641+
Curve curve = Curve.forECParameterSpec(publicKey.getParams());
642+
if (curve == null) {
643+
throw new JwtEncodingException("Unable to determine Curve for EC public key.");
644+
}
645+
646+
com.nimbusds.jose.jwk.ECKey.Builder builder = new com.nimbusds.jose.jwk.ECKey.Builder(curve, publicKey)
647+
.privateKey(this.keyPair.getPrivate())
648+
.keyUse(KeyUse.SIGNATURE)
649+
.algorithm(this.jwsAlgorithm);
650+
651+
if (StringUtils.hasText(this.keyId)) {
652+
builder.keyID(this.keyId);
653+
}
654+
655+
try {
656+
return builder.build();
657+
} catch (IllegalStateException ex) {
658+
throw new JwtEncodingException("Failed to build ECKey: " + ex.getMessage(), ex);
659+
}
660+
}
661+
}
662+
372663
}

0 commit comments

Comments
 (0)