Skip to content

NimbusJwtEncoder should simplify constructing with javax.security Keys #17033

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -18,26 +18,37 @@

import java.net.URI;
import java.net.URL;
import java.security.KeyPair;
import java.security.interfaces.ECPublicKey;
import java.time.Instant;
import java.util.ArrayList;
import java.util.Date;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.UUID;
import java.util.concurrent.ConcurrentHashMap;

import javax.crypto.SecretKey;

import com.nimbusds.jose.JOSEException;
import com.nimbusds.jose.JOSEObjectType;
import com.nimbusds.jose.JWSAlgorithm;
import com.nimbusds.jose.JWSHeader;
import com.nimbusds.jose.JWSSigner;
import com.nimbusds.jose.crypto.factories.DefaultJWSSignerFactory;
import com.nimbusds.jose.jwk.Curve;
import com.nimbusds.jose.jwk.ECKey;
import com.nimbusds.jose.jwk.JWK;
import com.nimbusds.jose.jwk.JWKMatcher;
import com.nimbusds.jose.jwk.JWKSelector;
import com.nimbusds.jose.jwk.JWKSet;
import com.nimbusds.jose.jwk.KeyType;
import com.nimbusds.jose.jwk.KeyUse;
import com.nimbusds.jose.jwk.OctetSequenceKey;
import com.nimbusds.jose.jwk.RSAKey;
import com.nimbusds.jose.jwk.source.ImmutableJWKSet;
import com.nimbusds.jose.jwk.source.JWKSource;
import com.nimbusds.jose.proc.SecurityContext;
import com.nimbusds.jose.produce.JWSSignerFactory;
Expand All @@ -47,6 +58,7 @@
import com.nimbusds.jwt.SignedJWT;

import org.springframework.core.convert.converter.Converter;
import org.springframework.security.oauth2.jose.jws.MacAlgorithm;
import org.springframework.security.oauth2.jose.jws.SignatureAlgorithm;
import org.springframework.util.Assert;
import org.springframework.util.CollectionUtils;
Expand Down Expand Up @@ -83,6 +95,8 @@ public final class NimbusJwtEncoder implements JwtEncoder {

private static final JWSSignerFactory JWS_SIGNER_FACTORY = new DefaultJWSSignerFactory();

private JwsHeader jwsHeader;

private final Map<JWK, JWSSigner> jwsSigners = new ConcurrentHashMap<>();

private final JWKSource<SecurityContext> jwkSource;
Expand Down Expand Up @@ -119,14 +133,16 @@ public void setJwkSelector(Converter<List<JWK>, JWK> jwkSelector) {
this.jwkSelector = jwkSelector;
}

public void setJwsHeader(JwsHeader jwsHeader) {
this.jwsHeader = jwsHeader;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please check for null here instead of in encode. You can use Assert.notNull for this.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it can never be null here
as we setting default value if jws header is null i.e. not sent by user

}

@Override
public Jwt encode(JwtEncoderParameters parameters) throws JwtEncodingException {
Assert.notNull(parameters, "parameters cannot be null");

JwsHeader headers = parameters.getJwsHeader();
if (headers == null) {
headers = DEFAULT_JWS_HEADER;
}
headers = (headers != null) ? headers : (this.jwsHeader != null) ? this.jwsHeader : DEFAULT_JWS_HEADER;
JwtClaimsSet claims = parameters.getClaims();

JWK jwk = selectJwk(headers);
Expand Down Expand Up @@ -369,4 +385,250 @@ private static URI convertAsURI(String header, URL url) {
}
}

/**
* Creates a builder for constructing a {@link NimbusJwtEncoder} using the provided
* {@link SecretKey}.
* @param secretKey the {@link SecretKey} to use for signing JWTs
* @return a {@link SecretKeyJwtEncoderBuilder} for further configuration
* @since 7.0
*/
public static SecretKeyJwtEncoderBuilder withSecretKey(SecretKey secretKey) {
Assert.notNull(secretKey, "secretKey cannot be null");
return new SecretKeyJwtEncoderBuilder(secretKey);
}

/**
* Creates a builder for constructing a {@link NimbusJwtEncoder} using the provided
* {@link KeyPair}. The key pair must contain either an {@link RSAKey} or an
* {@link ECKey}.
* @param keyPair the {@link KeyPair} to use for signing JWTs
* @return a {@link KeyPairJwtEncoderBuilder} for further configuration
* @since 7.0
*/
public static KeyPairJwtEncoderBuilder withKeyPair(KeyPair keyPair) {
Assert.isTrue(keyPair != null && keyPair.getPrivate() != null && keyPair.getPublic() != null,
"keyPair, its private key, and public key must not be null");
if (keyPair.getPrivate() instanceof java.security.interfaces.RSAKey) {
return new RsaKeyPairJwtEncoderBuilder(keyPair);
}
if (keyPair.getPrivate() instanceof java.security.interfaces.ECKey) {
return new EcKeyPairJwtEncoderBuilder(keyPair);
}
throw new IllegalArgumentException("keyPair must be an RSAKey or an ECKey");
}

/**
* A builder for creating {@link NimbusJwtEncoder} instances configured with a
* {@link SecretKey}.
*
* @since 7.0
*/
public static final class SecretKeyJwtEncoderBuilder {

private final SecretKey secretKey;

private String keyId;

private JWSAlgorithm jwsAlgorithm = JWSAlgorithm.HS256;

private SecretKeyJwtEncoderBuilder(SecretKey secretKey) {
this.secretKey = secretKey;
}

/**
* Sets the JWS algorithm to use for signing. Defaults to
* {@link JWSAlgorithm#HS256}. Must be an HMAC-based algorithm (HS256, HS384, or
* HS512).
* @param macAlgorithm the {@link MacAlgorithm} to use
* @return this builder instance for method chaining
*/
public SecretKeyJwtEncoderBuilder macAlgorithm(MacAlgorithm macAlgorithm) {
Assert.notNull(macAlgorithm, "macAlgorithm cannot be null");
Assert.state(JWSAlgorithm.Family.HMAC_SHA.contains(this.jwsAlgorithm),
() -> "The algorithm '" + this.jwsAlgorithm + "' is not compatible with a SecretKey. "
+ "Please use one of the HS256, HS384, or HS512 algorithms.");

this.jwsAlgorithm = JWSAlgorithm.parse(macAlgorithm.getName());
return this;
}

/**
* Sets the key ID ({@code kid}) to be included in the JWK and potentially the JWS
* header.
* @param keyId the key identifier
* @return this builder instance for method chaining
*/
public SecretKeyJwtEncoderBuilder keyId(String keyId) {
this.keyId = keyId;
return this;
}

/**
* Builds the {@link NimbusJwtEncoder} instance.
* @return the configured {@link NimbusJwtEncoder}
* @throws IllegalStateException if the configured JWS algorithm is not compatible
* with a {@link SecretKey}.
*/
public NimbusJwtEncoder build() {
this.jwsAlgorithm = (this.jwsAlgorithm != null) ? this.jwsAlgorithm : JWSAlgorithm.HS256;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since you are giving this a default value and there isn't a way for the application to set it back to null, I believe you can remove this line.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In the following case, it may be null.

    SecretKey defaultSecretKey = createSecretKey("HmacSHA256", 256); // Key for default HS256
    NimbusJwtEncoder encoderWithSecretDefaults = NimbusJwtEncoder.withSecretKey(defaultSecretKey)
            // .macAlgorithm(MacAlgorithm.HS256) // Default is HS256
            // .keyId(null)                     // Default keyId is null
            // .jwkSelector(...)               // Default selector throws if multiple keys match
            .build();


OctetSequenceKey.Builder builder = new OctetSequenceKey.Builder(this.secretKey).keyUse(KeyUse.SIGNATURE)
.algorithm(this.jwsAlgorithm)
.keyID(this.keyId);

OctetSequenceKey jwk = builder.build();
JWKSource<SecurityContext> jwkSource = new ImmutableJWKSet<>(new JWKSet(jwk));
NimbusJwtEncoder encoder = new NimbusJwtEncoder(jwkSource);
encoder.setJwsHeader(JwsHeader.with(MacAlgorithm.from(this.jwsAlgorithm.getName())).build());
return encoder;
}

}

/**
* A builder for creating {@link NimbusJwtEncoder} instances configured with a
* {@link KeyPair}.
*
* @since 7.0
*/
public abstract static class KeyPairJwtEncoderBuilder {

private final KeyPair keyPair;

private String keyId;

private JWSAlgorithm jwsAlgorithm;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's please set a default JWSAlgorithm.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added in buildJwk() as now we have two separate builder for RSA and EC


private KeyPairJwtEncoderBuilder(KeyPair keyPair) {
this.keyPair = keyPair;
}

/**
* Sets the JWS algorithm to use for signing. Must be compatible with the key type
* (RSA or EC). If not set, a default algorithm will be chosen based on the key
* type (e.g., RS256 for RSA, ES256 for EC).
* @param signatureAlgorithm the {@link SignatureAlgorithm} to use
* @return this builder instance for method chaining
*/
public KeyPairJwtEncoderBuilder signatureAlgorithm(SignatureAlgorithm signatureAlgorithm) {
Assert.notNull(signatureAlgorithm, "signatureAlgorithm cannot be null");
if (this.keyPair.getPrivate() instanceof java.security.interfaces.RSAKey) {
Assert.state(JWSAlgorithm.Family.RSA.contains(JWSAlgorithm.parse(signatureAlgorithm.getName())),
() -> "The algorithm '" + signatureAlgorithm + "' is not compatible with an RSAKey. "
+ "Please use one of the RS256, RS384, RS512, PS256, PS384, or PS512 algorithms.");

}
if (this.keyPair.getPrivate() instanceof java.security.interfaces.ECKey) {
Assert.state(JWSAlgorithm.Family.EC.contains(JWSAlgorithm.parse(signatureAlgorithm.getName())),
() -> "The algorithm '" + signatureAlgorithm + "' is not compatible with an ECKey. "
+ "Please use one of the ES256, ES384, or ES512 algorithms.");
}
this.jwsAlgorithm = JWSAlgorithm.parse(signatureAlgorithm.getName());
return this;
}

/**
* Sets the key ID ({@code kid}) to be included in the JWK and potentially the JWS
* header.
* @param keyId the key identifier
* @return this builder instance for method chaining
*/
public KeyPairJwtEncoderBuilder keyId(String keyId) {
this.keyId = keyId;
return this;
}

/**
* Builds the {@link NimbusJwtEncoder} instance.
* @return the configured {@link NimbusJwtEncoder}
* @throws IllegalStateException if the key type is unsupported or the configured
* JWS algorithm is not compatible with the key type.
* @throws JwtEncodingException if the key is invalid (e.g., EC key with unknown
* curve)
*/
public NimbusJwtEncoder build() {
this.keyId = (this.keyId != null) ? this.keyId : UUID.randomUUID().toString();
JWK jwk = buildJwk();
JWKSource<SecurityContext> jwkSource = new ImmutableJWKSet<>(new JWKSet(jwk));
NimbusJwtEncoder encoder = new NimbusJwtEncoder(jwkSource);
JwsHeader jwsHeader = JwsHeader.with(SignatureAlgorithm.from(this.jwsAlgorithm.getName()))
.keyId(jwk.getKeyID())
.build();
encoder.setJwsHeader(jwsHeader);
return encoder;
}

protected abstract JWK buildJwk();

}

/**
* A builder for creating {@link NimbusJwtEncoder} instances configured with a
* {@link KeyPair}.
*
* @since 7.0
*/
public static final class RsaKeyPairJwtEncoderBuilder extends KeyPairJwtEncoderBuilder {

private RsaKeyPairJwtEncoderBuilder(KeyPair keyPair) {
super(keyPair);
}

@Override
protected JWK buildJwk() {
if (super.jwsAlgorithm == null) {
super.jwsAlgorithm = JWSAlgorithm.RS256;
}

RSAKey.Builder builder = new RSAKey.Builder(
(java.security.interfaces.RSAPublicKey) super.keyPair.getPublic())
.privateKey(super.keyPair.getPrivate())
.keyID(super.keyId)
.keyUse(KeyUse.SIGNATURE)
.algorithm(super.jwsAlgorithm);
return builder.build();
}

}

/**
* A builder for creating {@link NimbusJwtEncoder} instances configured with a
* {@link KeyPair}.
*
* @since 7.0
*/
public static final class EcKeyPairJwtEncoderBuilder extends KeyPairJwtEncoderBuilder {

private EcKeyPairJwtEncoderBuilder(KeyPair keyPair) {
super(keyPair);
}

@Override
protected JWK buildJwk() {
if (super.jwsAlgorithm == null) {
super.jwsAlgorithm = JWSAlgorithm.ES256;
}

ECPublicKey publicKey = (ECPublicKey) super.keyPair.getPublic();
Curve curve = Curve.forECParameterSpec(publicKey.getParams());
if (curve == null) {
throw new JwtEncodingException("Unable to determine Curve for EC public key.");
}

com.nimbusds.jose.jwk.ECKey.Builder builder = new com.nimbusds.jose.jwk.ECKey.Builder(curve, publicKey)
.privateKey(super.keyPair.getPrivate())
.keyUse(KeyUse.SIGNATURE)
.keyID(super.keyId)
.algorithm(super.jwsAlgorithm);

try {
return builder.build();
}
catch (IllegalStateException ex) {
throw new IllegalArgumentException("Failed to build ECKey: " + ex.getMessage(), ex);
}
}

}

}
Loading