Skip to content

Commit abddfb3

Browse files
committed
NimbusJwtEncoder should simplify constructing with javax.security.Keys Signed-off-by: Suraj Bhadrike <[email protected]>
Signed-off-by: surajbh <[email protected]>
1 parent 226e81d commit abddfb3

File tree

1 file changed

+318
-22
lines changed

1 file changed

+318
-22
lines changed

oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/NimbusJwtEncoder.java

+318-22
Original file line numberDiff line numberDiff line change
@@ -16,42 +16,39 @@
1616

1717
package org.springframework.security.oauth2.jwt;
1818

19-
import java.net.URI;
20-
import java.net.URL;
21-
import java.time.Instant;
22-
import java.util.ArrayList;
23-
import java.util.Date;
24-
import java.util.HashMap;
25-
import java.util.List;
26-
import java.util.Map;
27-
import java.util.Set;
28-
import java.util.concurrent.ConcurrentHashMap;
29-
30-
import com.nimbusds.jose.JOSEException;
31-
import com.nimbusds.jose.JOSEObjectType;
32-
import com.nimbusds.jose.JWSAlgorithm;
33-
import com.nimbusds.jose.JWSHeader;
34-
import com.nimbusds.jose.JWSSigner;
19+
import com.nimbusds.jose.*;
3520
import com.nimbusds.jose.crypto.factories.DefaultJWSSignerFactory;
36-
import com.nimbusds.jose.jwk.JWK;
37-
import com.nimbusds.jose.jwk.JWKMatcher;
38-
import com.nimbusds.jose.jwk.JWKSelector;
39-
import com.nimbusds.jose.jwk.KeyType;
40-
import com.nimbusds.jose.jwk.KeyUse;
21+
import com.nimbusds.jose.jwk.*;
22+
import com.nimbusds.jose.jwk.gen.OctetSequenceKeyGenerator;
23+
import com.nimbusds.jose.jwk.source.ImmutableJWKSet;
4124
import com.nimbusds.jose.jwk.source.JWKSource;
4225
import com.nimbusds.jose.proc.SecurityContext;
4326
import com.nimbusds.jose.produce.JWSSignerFactory;
4427
import com.nimbusds.jose.util.Base64;
4528
import com.nimbusds.jose.util.Base64URL;
4629
import com.nimbusds.jwt.JWTClaimsSet;
4730
import com.nimbusds.jwt.SignedJWT;
48-
4931
import org.springframework.core.convert.converter.Converter;
32+
import org.springframework.security.oauth2.jose.jws.MacAlgorithm;
5033
import org.springframework.security.oauth2.jose.jws.SignatureAlgorithm;
5134
import org.springframework.util.Assert;
5235
import org.springframework.util.CollectionUtils;
5336
import org.springframework.util.StringUtils;
5437

38+
import javax.crypto.SecretKey;
39+
import javax.crypto.spec.SecretKeySpec;
40+
import java.net.URI;
41+
import java.net.URL;
42+
import java.security.KeyPair;
43+
import java.security.KeyPairGenerator;
44+
import java.security.NoSuchAlgorithmException;
45+
import java.security.interfaces.ECPrivateKey;
46+
import java.security.interfaces.ECPublicKey;
47+
import java.security.interfaces.RSAPrivateKey;
48+
import java.time.Instant;
49+
import java.util.*;
50+
import java.util.concurrent.ConcurrentHashMap;
51+
5552
/**
5653
* An implementation of a {@link JwtEncoder} that encodes a JSON Web Token (JWT) using the
5754
* JSON Web Signature (JWS) Compact Serialization format. The private/secret key used for
@@ -369,4 +366,303 @@ private static URI convertAsURI(String header, URL url) {
369366
}
370367
}
371368

369+
/**
370+
* Creates a builder for constructing a {@link NimbusJwtEncoder} using the provided
371+
* {@link SecretKey}.
372+
*
373+
* @param secretKey the {@link SecretKey} to use for signing JWTs
374+
* @return a {@link SecretKeyJwtEncoderBuilder} for further configuration
375+
* @since // TODO: Update version
376+
*/
377+
public static SecretKeyJwtEncoderBuilder withSecretKey(SecretKey secretKey) {
378+
Assert.notNull(secretKey, "secretKey cannot be null");
379+
return new SecretKeyJwtEncoderBuilder(secretKey);
380+
}
381+
382+
/**
383+
* Creates a builder for constructing a {@link NimbusJwtEncoder} using the provided
384+
* {@link KeyPair}. The key pair must contain either an {@link RSAKey} or an
385+
* {@link ECKey}.
386+
*
387+
* @param keyPair the {@link KeyPair} to use for signing JWTs
388+
* @return a {@link KeyPairJwtEncoderBuilder} for further configuration
389+
* @since // TODO: Update version
390+
*/
391+
public static KeyPairJwtEncoderBuilder withKeyPair(KeyPair keyPair) {
392+
Assert.notNull(keyPair, "keyPair cannot be null");
393+
Assert.notNull(keyPair.getPrivate(), "keyPair must contain a private key");
394+
Assert.notNull(keyPair.getPublic(), "keyPair must contain a public key");
395+
Assert.isTrue(keyPair.getPrivate() instanceof java.security.interfaces.RSAKey || keyPair.getPrivate() instanceof java.security.interfaces.ECKey,
396+
"keyPair must be an RSAKey or an ECKey");
397+
return new KeyPairJwtEncoderBuilder(keyPair);
398+
}
399+
400+
/**
401+
* A builder for creating {@link NimbusJwtEncoder} instances configured with a
402+
* {@link SecretKey}.
403+
*
404+
* @since // TODO: Update version
405+
*/
406+
public static final class SecretKeyJwtEncoderBuilder {
407+
408+
private final SecretKey secretKey;
409+
410+
private String keyId;
411+
412+
private JWSAlgorithm jwsAlgorithm = JWSAlgorithm.HS256;
413+
414+
private Converter<List<JWK>, JWK> jwkSelector;
415+
416+
private SecretKeyJwtEncoderBuilder(SecretKey secretKey) {
417+
this.secretKey = secretKey;
418+
}
419+
420+
/**
421+
* Sets the JWS algorithm to use for signing. Defaults to {@link JWSAlgorithm#HS256}.
422+
* Must be an HMAC-based algorithm (HS256, HS384, or HS512).
423+
*
424+
* @param macAlgorithm the {@link MacAlgorithm} to use
425+
* @return this builder instance for method chaining
426+
*/
427+
public SecretKeyJwtEncoderBuilder macAlgorithm(MacAlgorithm macAlgorithm) {
428+
Assert.notNull(macAlgorithm, "macAlgorithm cannot be null");
429+
this.jwsAlgorithm = JWSAlgorithm.parse(macAlgorithm.getName());
430+
return this;
431+
}
432+
433+
/**
434+
* Sets the key ID ({@code kid}) to be included in the JWK and potentially the JWS
435+
* header.
436+
*
437+
* @param keyId the key identifier
438+
* @return this builder instance for method chaining
439+
*/
440+
public SecretKeyJwtEncoderBuilder keyId(String keyId) {
441+
this.keyId = keyId;
442+
return this;
443+
}
444+
445+
/**
446+
* Configures the {@link Converter} used to select the JWK when multiple keys match
447+
* the header criteria. This is generally not needed for single-key setups but is
448+
* provided for consistency.
449+
*
450+
* @param jwkSelector the {@link Converter} to select a {@link JWK}
451+
* @return this builder instance for method chaining
452+
* @since // TODO: Update version
453+
*/
454+
public SecretKeyJwtEncoderBuilder jwkSelector(Converter<List<JWK>, JWK> jwkSelector) {
455+
Assert.notNull(jwkSelector, "jwkSelector cannot be null");
456+
this.jwkSelector = jwkSelector;
457+
return this;
458+
}
459+
460+
461+
/**
462+
* Builds the {@link NimbusJwtEncoder} instance.
463+
*
464+
* @return the configured {@link NimbusJwtEncoder}
465+
* @throws IllegalStateException if the configured JWS algorithm is not compatible
466+
* with a {@link SecretKey}.
467+
*/
468+
public NimbusJwtEncoder build() {
469+
Assert.state(JWSAlgorithm.Family.HMAC_SHA.contains(this.jwsAlgorithm),
470+
() -> "The algorithm '" + this.jwsAlgorithm + "' is not compatible with a SecretKey. "
471+
+ "Please use one of the HS256, HS384, or HS512 algorithms.");
472+
473+
OctetSequenceKey.Builder builder = new OctetSequenceKey.Builder(this.secretKey)
474+
.keyUse(KeyUse.SIGNATURE)
475+
.algorithm(this.jwsAlgorithm);
476+
477+
if (StringUtils.hasText(this.keyId)) {
478+
builder.keyID(this.keyId);
479+
}
480+
481+
OctetSequenceKey jwk = builder.build();
482+
JWKSource<SecurityContext> jwkSource = new ImmutableJWKSet<>(new JWKSet(jwk));
483+
NimbusJwtEncoder encoder = new NimbusJwtEncoder(jwkSource);
484+
if (this.jwkSelector != null) {
485+
encoder.setJwkSelector(this.jwkSelector);
486+
} else {
487+
encoder.setJwkSelector(jwkSet -> jwkSet.stream().findFirst().orElse(null));
488+
}
489+
return encoder;
490+
}
491+
492+
}
493+
494+
/**
495+
* A builder for creating {@link NimbusJwtEncoder} instances configured with a
496+
* {@link KeyPair}.
497+
*
498+
* @since // TODO: Update version
499+
*/
500+
public static final class KeyPairJwtEncoderBuilder {
501+
502+
private final KeyPair keyPair;
503+
504+
private String keyId;
505+
506+
private JWSAlgorithm jwsAlgorithm;
507+
508+
private Converter<List<JWK>, JWK> jwkSelector;
509+
510+
private KeyPairJwtEncoderBuilder(KeyPair keyPair) {
511+
this.keyPair = keyPair;
512+
}
513+
514+
/**
515+
* Sets the JWS algorithm to use for signing. Must be compatible with the key type
516+
* (RSA or EC). If not set, a default algorithm will be chosen based on the key
517+
* type (e.g., RS256 for RSA, ES256 for EC).
518+
*
519+
* @param signatureAlgorithm the {@link SignatureAlgorithm} to use
520+
* @return this builder instance for method chaining
521+
*/
522+
public KeyPairJwtEncoderBuilder signatureAlgorithm(SignatureAlgorithm signatureAlgorithm) {
523+
Assert.notNull(signatureAlgorithm, "signatureAlgorithm cannot be null");
524+
this.jwsAlgorithm = JWSAlgorithm.parse(signatureAlgorithm.getName());
525+
return this;
526+
}
527+
528+
/**
529+
* Sets the key ID ({@code kid}) to be included in the JWK and potentially the JWS
530+
* header.
531+
*
532+
* @param keyId the key identifier
533+
* @return this builder instance for method chaining
534+
*/
535+
public KeyPairJwtEncoderBuilder keyId(String keyId) {
536+
this.keyId = keyId;
537+
return this;
538+
}
539+
540+
/**
541+
* Configures the {@link Converter} used to select the JWK when multiple keys match
542+
* the header criteria. This is generally not needed for single-key setups but is
543+
* provided for consistency.
544+
*
545+
* @param jwkSelector the {@link Converter} to select a {@link JWK}
546+
* @return this builder instance for method chaining
547+
* @since // TODO: Update version
548+
*/
549+
public KeyPairJwtEncoderBuilder jwkSelector(Converter<List<JWK>, JWK> jwkSelector) {
550+
Assert.notNull(jwkSelector, "jwkSelector cannot be null");
551+
this.jwkSelector = jwkSelector;
552+
return this;
553+
}
554+
555+
/**
556+
* Builds the {@link NimbusJwtEncoder} instance.
557+
*
558+
* @return the configured {@link NimbusJwtEncoder}
559+
* @throws IllegalStateException if the key type is unsupported or the configured
560+
* JWS algorithm is not compatible with the key type.
561+
* @throws JwtEncodingException if the key is invalid (e.g., EC key with unknown curve)
562+
*/
563+
public NimbusJwtEncoder build() {
564+
JWK jwk = buildJwk();
565+
JWKSource<SecurityContext> jwkSource = new ImmutableJWKSet<>(new JWKSet(jwk));
566+
NimbusJwtEncoder encoder = new NimbusJwtEncoder(jwkSource);
567+
if (this.jwkSelector != null) {
568+
encoder.setJwkSelector(this.jwkSelector);
569+
} else {
570+
encoder.setJwkSelector(jwkSet -> jwkSet.stream().findFirst().orElse(null));
571+
}
572+
return encoder;
573+
}
574+
575+
private JWK buildJwk() {
576+
if (this.keyPair.getPrivate() instanceof RSAPrivateKey) {
577+
return buildRsaJwk();
578+
} else if (this.keyPair.getPrivate() instanceof ECPrivateKey) {
579+
return buildEcJwk();
580+
} else {
581+
throw new IllegalStateException("Unsupported key pair type: " + this.keyPair.getPrivate().getClass().getName());
582+
}
583+
}
584+
585+
private RSAKey buildRsaJwk() {
586+
if (this.jwsAlgorithm == null) {
587+
this.jwsAlgorithm = JWSAlgorithm.RS256;
588+
}
589+
Assert.state(JWSAlgorithm.Family.RSA.contains(this.jwsAlgorithm),
590+
() -> "The algorithm '" + this.jwsAlgorithm + "' is not compatible with an RSAKey. "
591+
+ "Please use one of the RS256, RS384, RS512, PS256, PS384, or PS512 algorithms.");
592+
593+
RSAKey.Builder builder = new RSAKey.Builder((java.security.interfaces.RSAPublicKey) this.keyPair.getPublic())
594+
.privateKey(this.keyPair.getPrivate())
595+
.keyUse(KeyUse.SIGNATURE)
596+
.algorithm(this.jwsAlgorithm);
597+
598+
if (StringUtils.hasText(this.keyId)) {
599+
builder.keyID(this.keyId);
600+
}
601+
return builder.build();
602+
}
603+
604+
private com.nimbusds.jose.jwk.ECKey buildEcJwk() {
605+
if (this.jwsAlgorithm == null) {
606+
this.jwsAlgorithm = JWSAlgorithm.ES256;
607+
}
608+
Assert.state(JWSAlgorithm.Family.EC.contains(this.jwsAlgorithm),
609+
() -> "The algorithm '" + this.jwsAlgorithm + "' is not compatible with an ECKey. "
610+
+ "Please use one of the ES256, ES384, or ES512 algorithms.");
611+
612+
ECPublicKey publicKey = (ECPublicKey) this.keyPair.getPublic();
613+
Curve curve = Curve.forECParameterSpec(publicKey.getParams());
614+
if (curve == null) {
615+
throw new JwtEncodingException("Unable to determine Curve for EC public key.");
616+
}
617+
618+
com.nimbusds.jose.jwk.ECKey.Builder builder = new com.nimbusds.jose.jwk.ECKey.Builder(curve, publicKey)
619+
.privateKey(this.keyPair.getPrivate())
620+
.keyUse(KeyUse.SIGNATURE)
621+
.algorithm(this.jwsAlgorithm);
622+
623+
if (StringUtils.hasText(this.keyId)) {
624+
builder.keyID(this.keyId);
625+
}
626+
627+
try {
628+
return builder.build();
629+
} catch (IllegalStateException ex) {
630+
throw new JwtEncodingException("Failed to build ECKey: " + ex.getMessage(), ex);
631+
}
632+
}
633+
}
634+
635+
public static void main(String[] args) throws JOSEException, NoSuchAlgorithmException {
636+
637+
// 1
638+
SecretKeySpec key1 = new SecretKeySpec("secret".getBytes(), SignatureAlgorithm.ES256.getName());
639+
NimbusJwtEncoder encoder1 = NimbusJwtEncoder.withSecretKey(key1).build();
640+
641+
// 2
642+
KeyPairGenerator keyPairGenerator = KeyPairGenerator.getInstance("RSA");
643+
keyPairGenerator.initialize(2048);
644+
KeyPair keyPair = keyPairGenerator.generateKeyPair();
645+
NimbusJwtEncoder encoder2 = NimbusJwtEncoder.withKeyPair(keyPair)
646+
.keyId(UUID.randomUUID().toString())
647+
.signatureAlgorithm(SignatureAlgorithm.RS256)
648+
.build();
649+
650+
// 3
651+
SecretKeySpec key3 = new SecretKeySpec("secret".getBytes(), MacAlgorithm.HS256.getName());
652+
NimbusJwtEncoder encoder3 = NimbusJwtEncoder.withSecretKey(key3)
653+
.macAlgorithm(MacAlgorithm.HS256)
654+
.keyId(UUID.randomUUID().toString())
655+
.build();
656+
657+
// 4
658+
OctetSequenceKey jwk = new OctetSequenceKeyGenerator(256)
659+
.keyID(UUID.randomUUID().toString())
660+
.algorithm(JWSAlgorithm.HS256)
661+
.issueTime(new Date())
662+
.generate();
663+
JWKSource<SecurityContext> source = new ImmutableJWKSet<>(new JWKSet(jwk));
664+
NimbusJwtEncoder encoder4 = new NimbusJwtEncoder(source);
665+
666+
}
667+
372668
}

0 commit comments

Comments
 (0)