Skip to content

Enhance SslInfo to support multiple certificate stores. #45355

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: 3.4.x
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
package org.springframework.boot.actuate.ssl;

import java.util.List;
import java.util.Set;

import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
Expand All @@ -38,6 +39,7 @@
* Tests for {@link SslHealthIndicator}.
*
* @author Jonatan Ivanov
* @author Joshua Chen
*/
class SslHealthIndicatorTests {

Expand All @@ -55,7 +57,7 @@ void setUp() {
this.validity = mock(CertificateValidityInfo.class);
given(sslInfo.getBundles()).willReturn(List.of(bundle));
given(bundle.getCertificateChains()).willReturn(List.of(certificateChain));
given(certificateChain.getCertificates()).willReturn(List.of(certificateInfo));
given(certificateChain.getCertificates()).willReturn(Set.of(certificateInfo));
given(certificateInfo.getValidity()).willReturn(this.validity);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,23 +24,28 @@
import java.security.cert.X509Certificate;
import java.time.Duration;
import java.time.Instant;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Date;
import java.util.List;
import java.util.Set;
import java.util.function.Function;
import java.util.stream.Collectors;

import javax.security.auth.x500.X500Principal;

import org.springframework.boot.info.SslInfo.CertificateValidityInfo.Status;
import org.springframework.boot.ssl.SslBundle;
import org.springframework.boot.ssl.SslBundles;
import org.springframework.boot.ssl.SslStoreBundle;
import org.springframework.util.ObjectUtils;

/**
* Information about the certificates that the application uses.
*
* @author Jonatan Ivanov
* @author Joshua Chen
* @since 3.4.0
*/
public class SslInfo {
Expand Down Expand Up @@ -72,7 +77,21 @@ public final class BundleInfo {

private BundleInfo(String name, SslBundle sslBundle) {
this.name = name;
this.certificateChains = extractCertificateChains(sslBundle.getStores().getKeyStore());
this.certificateChains = extractCertificateChains(sslBundle.getStores());
}

private List<CertificateChainInfo> extractCertificateChains(SslStoreBundle storeBundle) {
if (storeBundle == null) {
return Collections.emptyList();
}
List<CertificateChainInfo> certificateChains = new ArrayList<>();
if (storeBundle.getKeyStore() != null) {
certificateChains.addAll(extractCertificateChains(storeBundle.getKeyStore()));
}
if (storeBundle.getTrustStore() != null) {
certificateChains.addAll(extractCertificateChains(storeBundle.getTrustStore()));
}
return certificateChains;
}

private List<CertificateChainInfo> extractCertificateChains(KeyStore keyStore) {
Expand Down Expand Up @@ -107,29 +126,34 @@ public final class CertificateChainInfo {

private final String alias;

private final List<CertificateInfo> certificates;
private final Set<CertificateInfo> certificates;

CertificateChainInfo(KeyStore keyStore, String alias) {
this.alias = alias;
this.certificates = extractCertificates(keyStore, alias);
}

private List<CertificateInfo> extractCertificates(KeyStore keyStore, String alias) {
private Set<CertificateInfo> extractCertificates(KeyStore keyStore, String alias) {
try {
Certificate[] certificates = keyStore.getCertificateChain(alias);
return (!ObjectUtils.isEmpty(certificates))
? Arrays.stream(certificates).map(CertificateInfo::new).toList() : Collections.emptyList();
Set<CertificateInfo> certificateInfos = new java.util.HashSet<>(!ObjectUtils.isEmpty(certificates)
? Arrays.stream(certificates).map(CertificateInfo::new).collect(Collectors.toSet())
: Collections.emptySet());
if (keyStore.getCertificate(alias) != null) {
certificateInfos.add(new CertificateInfo(keyStore.getCertificate(alias)));
}
return certificateInfos;
}
catch (KeyStoreException ex) {
return Collections.emptyList();
return Collections.emptySet();
}
}

public String getAlias() {
return this.alias;
}

public List<CertificateInfo> getCertificates() {
public Set<CertificateInfo> getCertificates() {
return this.certificates;
}

Expand Down Expand Up @@ -208,6 +232,23 @@ private <R> R extract(Function<X509Certificate, R> extractor) {
return (this.certificate != null) ? extractor.apply(this.certificate) : null;
}

@Override
public boolean equals(Object other) {
if (this == other) {
return true;
}
if (other == null || getClass() != other.getClass()) {
return false;
}
return (this.certificate != null) ? this.certificate.equals(((CertificateInfo) other).certificate)
: super.equals(other);
}

@Override
public int hashCode() {
return (this.certificate != null) ? this.certificate.hashCode() : super.hashCode();
}

}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,13 @@
import java.nio.file.Path;
import java.time.Duration;
import java.util.List;
import java.util.Set;
import java.util.stream.Collectors;

import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.io.TempDir;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.EnumSource;

import org.springframework.boot.info.SslInfo.BundleInfo;
import org.springframework.boot.info.SslInfo.CertificateChainInfo;
Expand All @@ -46,13 +49,15 @@
* Tests for {@link SslInfo}.
*
* @author Jonatan Ivanov
* @author Joshua Chen
*/
class SslInfoTests {

@Test
@ParameterizedTest
@EnumSource(StoreType.class)
@WithPackageResources("test.p12")
void validCertificatesShouldProvideSslInfo() {
SslInfo sslInfo = createSslInfo("classpath:test.p12");
void validCertificatesShouldProvideSslInfo(StoreType storeType) {
SslInfo sslInfo = createSslInfo(storeType, "classpath:test.p12");
assertThat(sslInfo.getBundles()).hasSize(1);
BundleInfo bundle = sslInfo.getBundles().get(0);
assertThat(bundle.getName()).isEqualTo("test-0");
Expand All @@ -62,10 +67,10 @@ void validCertificatesShouldProvideSslInfo() {
assertThat(bundle.getCertificateChains().get(1).getAlias()).isEqualTo("test-alias");
assertThat(bundle.getCertificateChains().get(1).getCertificates()).hasSize(1);
assertThat(bundle.getCertificateChains().get(2).getAlias()).isEqualTo("spring-boot-cert");
assertThat(bundle.getCertificateChains().get(2).getCertificates()).isEmpty();
assertThat(bundle.getCertificateChains().get(2).getCertificates()).hasSize(1);
assertThat(bundle.getCertificateChains().get(3).getAlias()).isEqualTo("test-alias-cert");
assertThat(bundle.getCertificateChains().get(3).getCertificates()).isEmpty();
CertificateInfo cert1 = bundle.getCertificateChains().get(0).getCertificates().get(0);
assertThat(bundle.getCertificateChains().get(3).getCertificates()).hasSize(1);
CertificateInfo cert1 = bundle.getCertificateChains().get(0).getCertificates().iterator().next();
assertThat(cert1.getSubject()).isEqualTo("CN=localhost,OU=Spring,O=VMware,L=Palo Alto,ST=California,C=US");
assertThat(cert1.getIssuer()).isEqualTo(cert1.getSubject());
assertThat(cert1.getSerialNumber()).isNotEmpty();
Expand All @@ -76,7 +81,7 @@ void validCertificatesShouldProvideSslInfo() {
assertThat(cert1.getValidity()).isNotNull();
assertThat(cert1.getValidity().getStatus()).isSameAs(Status.VALID);
assertThat(cert1.getValidity().getMessage()).isNull();
CertificateInfo cert2 = bundle.getCertificateChains().get(1).getCertificates().get(0);
CertificateInfo cert2 = bundle.getCertificateChains().get(1).getCertificates().iterator().next();
assertThat(cert2.getSubject()).isEqualTo("CN=localhost,OU=Spring,O=VMware,L=Palo Alto,ST=California,C=US");
assertThat(cert2.getIssuer()).isEqualTo(cert2.getSubject());
assertThat(cert2.getSerialNumber()).isNotEmpty();
Expand All @@ -89,19 +94,20 @@ void validCertificatesShouldProvideSslInfo() {
assertThat(cert2.getValidity().getMessage()).isNull();
}

@Test
@ParameterizedTest
@EnumSource(StoreType.class)
@WithPackageResources("test-not-yet-valid.p12")
void notYetValidCertificateShouldProvideSslInfo() {
SslInfo sslInfo = createSslInfo("classpath:test-not-yet-valid.p12");
void notYetValidCertificateShouldProvideSslInfo(StoreType storeType) {
SslInfo sslInfo = createSslInfo(storeType, "classpath:test-not-yet-valid.p12");
assertThat(sslInfo.getBundles()).hasSize(1);
BundleInfo bundle = sslInfo.getBundles().get(0);
assertThat(bundle.getName()).isEqualTo("test-0");
assertThat(bundle.getCertificateChains()).hasSize(1);
CertificateChainInfo certificateChain = bundle.getCertificateChains().get(0);
assertThat(certificateChain.getAlias()).isEqualTo("spring-boot");
List<CertificateInfo> certs = certificateChain.getCertificates();
Set<CertificateInfo> certs = certificateChain.getCertificates();
assertThat(certs).hasSize(1);
CertificateInfo cert = certs.get(0);
CertificateInfo cert = certs.iterator().next();
assertThat(cert.getSubject()).isEqualTo("CN=localhost,OU=Spring,O=VMware,L=Palo Alto,ST=California,C=US");
assertThat(cert.getIssuer()).isEqualTo(cert.getSubject());
assertThat(cert.getSerialNumber()).isNotEmpty();
Expand All @@ -124,9 +130,9 @@ void expiredCertificateShouldProvideSslInfo() {
assertThat(bundle.getCertificateChains()).hasSize(1);
CertificateChainInfo certificateChain = bundle.getCertificateChains().get(0);
assertThat(certificateChain.getAlias()).isEqualTo("spring-boot");
List<CertificateInfo> certs = certificateChain.getCertificates();
Set<CertificateInfo> certs = certificateChain.getCertificates();
assertThat(certs).hasSize(1);
CertificateInfo cert = certs.get(0);
CertificateInfo cert = certs.iterator().next();
assertThat(cert.getSubject()).isEqualTo("CN=localhost,OU=Spring,O=VMware,L=Palo Alto,ST=California,C=US");
assertThat(cert.getIssuer()).isEqualTo(cert.getSubject());
assertThat(cert.getSerialNumber()).isNotEmpty();
Expand All @@ -150,9 +156,9 @@ void soonToBeExpiredCertificateShouldProvideSslInfo(@TempDir Path tempDir)
assertThat(bundle.getCertificateChains()).hasSize(1);
CertificateChainInfo certificateChain = bundle.getCertificateChains().get(0);
assertThat(certificateChain.getAlias()).isEqualTo("spring-boot");
List<CertificateInfo> certs = certificateChain.getCertificates();
Set<CertificateInfo> certs = certificateChain.getCertificates();
assertThat(certs).hasSize(1);
CertificateInfo cert = certs.get(0);
CertificateInfo cert = certs.iterator().next();
assertThat(cert.getSubject()).isEqualTo("CN=localhost,OU=Spring,O=VMware,L=Palo Alto,ST=California,C=US");
assertThat(cert.getIssuer()).isEqualTo(cert.getSubject());
assertThat(cert.getSerialNumber()).isNotEmpty();
Expand All @@ -178,7 +184,7 @@ void multipleBundlesShouldProvideSslInfo(@TempDir Path tempDir) throws IOExcepti
.flatMap((bundle) -> bundle.getCertificateChains().stream())
.flatMap((certificateChain) -> certificateChain.getCertificates().stream())
.toList();
assertThat(certs).hasSize(5);
assertThat(certs).hasSize(7);
assertThat(certs).allSatisfy((cert) -> {
assertThat(cert.getSubject()).isEqualTo("CN=localhost,OU=Spring,O=VMware,L=Palo Alto,ST=California,C=US");
assertThat(cert.getIssuer()).isEqualTo(cert.getSubject());
Expand Down Expand Up @@ -227,10 +233,20 @@ void nullKeyStore() {
}

private SslInfo createSslInfo(String... locations) {
return createSslInfo(StoreType.KEYSTORE, locations);
}

private SslInfo createSslInfo(StoreType storeType, String... locations) {
DefaultSslBundleRegistry sslBundleRegistry = new DefaultSslBundleRegistry();
for (int i = 0; i < locations.length; i++) {
JksSslStoreDetails keyStoreDetails = JksSslStoreDetails.forLocation(locations[i]).withPassword("secret");
SslStoreBundle sslStoreBundle = new JksSslStoreBundle(keyStoreDetails, null);
JksSslStoreDetails storeDetails = JksSslStoreDetails.forLocation(locations[i]).withPassword("secret");
SslStoreBundle sslStoreBundle;
if (storeType == StoreType.TRUSTSTORE) {
sslStoreBundle = new JksSslStoreBundle(null, storeDetails);
}
else {
sslStoreBundle = new JksSslStoreBundle(storeDetails, null);
}
sslBundleRegistry.registerBundle("test-%d".formatted(i), SslBundle.of(sslStoreBundle));
}
return new SslInfo(sslBundleRegistry, Duration.ofDays(7));
Expand Down Expand Up @@ -270,4 +286,10 @@ private ProcessBuilder createProcessBuilder(Path keystore) {
return processBuilder;
}

private enum StoreType {

KEYSTORE, TRUSTSTORE

}

}