diff --git a/spring-boot-project/spring-boot-actuator/src/test/java/org/springframework/boot/actuate/ssl/SslHealthIndicatorTests.java b/spring-boot-project/spring-boot-actuator/src/test/java/org/springframework/boot/actuate/ssl/SslHealthIndicatorTests.java index 74475c5ebf42..29cc949d5ecd 100644 --- a/spring-boot-project/spring-boot-actuator/src/test/java/org/springframework/boot/actuate/ssl/SslHealthIndicatorTests.java +++ b/spring-boot-project/spring-boot-actuator/src/test/java/org/springframework/boot/actuate/ssl/SslHealthIndicatorTests.java @@ -17,6 +17,7 @@ package org.springframework.boot.actuate.ssl; import java.util.List; +import java.util.Set; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; @@ -38,6 +39,7 @@ * Tests for {@link SslHealthIndicator}. * * @author Jonatan Ivanov + * @author Joshua Chen */ class SslHealthIndicatorTests { @@ -55,7 +57,7 @@ void setUp() { this.validity = mock(CertificateValidityInfo.class); given(sslInfo.getBundles()).willReturn(List.of(bundle)); given(bundle.getCertificateChains()).willReturn(List.of(certificateChain)); - given(certificateChain.getCertificates()).willReturn(List.of(certificateInfo)); + given(certificateChain.getCertificates()).willReturn(Set.of(certificateInfo)); given(certificateInfo.getValidity()).willReturn(this.validity); } diff --git a/spring-boot-project/spring-boot/src/main/java/org/springframework/boot/info/SslInfo.java b/spring-boot-project/spring-boot/src/main/java/org/springframework/boot/info/SslInfo.java index c88eb43a288a..0b65a444290d 100644 --- a/spring-boot-project/spring-boot/src/main/java/org/springframework/boot/info/SslInfo.java +++ b/spring-boot-project/spring-boot/src/main/java/org/springframework/boot/info/SslInfo.java @@ -24,23 +24,28 @@ import java.security.cert.X509Certificate; import java.time.Duration; import java.time.Instant; +import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; import java.util.Date; import java.util.List; +import java.util.Set; import java.util.function.Function; +import java.util.stream.Collectors; import javax.security.auth.x500.X500Principal; import org.springframework.boot.info.SslInfo.CertificateValidityInfo.Status; import org.springframework.boot.ssl.SslBundle; import org.springframework.boot.ssl.SslBundles; +import org.springframework.boot.ssl.SslStoreBundle; import org.springframework.util.ObjectUtils; /** * Information about the certificates that the application uses. * * @author Jonatan Ivanov + * @author Joshua Chen * @since 3.4.0 */ public class SslInfo { @@ -72,7 +77,21 @@ public final class BundleInfo { private BundleInfo(String name, SslBundle sslBundle) { this.name = name; - this.certificateChains = extractCertificateChains(sslBundle.getStores().getKeyStore()); + this.certificateChains = extractCertificateChains(sslBundle.getStores()); + } + + private List extractCertificateChains(SslStoreBundle storeBundle) { + if (storeBundle == null) { + return Collections.emptyList(); + } + List certificateChains = new ArrayList<>(); + if (storeBundle.getKeyStore() != null) { + certificateChains.addAll(extractCertificateChains(storeBundle.getKeyStore())); + } + if (storeBundle.getTrustStore() != null) { + certificateChains.addAll(extractCertificateChains(storeBundle.getTrustStore())); + } + return certificateChains; } private List extractCertificateChains(KeyStore keyStore) { @@ -107,21 +126,26 @@ public final class CertificateChainInfo { private final String alias; - private final List certificates; + private final Set certificates; CertificateChainInfo(KeyStore keyStore, String alias) { this.alias = alias; this.certificates = extractCertificates(keyStore, alias); } - private List extractCertificates(KeyStore keyStore, String alias) { + private Set extractCertificates(KeyStore keyStore, String alias) { try { Certificate[] certificates = keyStore.getCertificateChain(alias); - return (!ObjectUtils.isEmpty(certificates)) - ? Arrays.stream(certificates).map(CertificateInfo::new).toList() : Collections.emptyList(); + Set certificateInfos = new java.util.HashSet<>(!ObjectUtils.isEmpty(certificates) + ? Arrays.stream(certificates).map(CertificateInfo::new).collect(Collectors.toSet()) + : Collections.emptySet()); + if (keyStore.getCertificate(alias) != null) { + certificateInfos.add(new CertificateInfo(keyStore.getCertificate(alias))); + } + return certificateInfos; } catch (KeyStoreException ex) { - return Collections.emptyList(); + return Collections.emptySet(); } } @@ -129,7 +153,7 @@ public String getAlias() { return this.alias; } - public List getCertificates() { + public Set getCertificates() { return this.certificates; } @@ -208,6 +232,23 @@ private R extract(Function extractor) { return (this.certificate != null) ? extractor.apply(this.certificate) : null; } + @Override + public boolean equals(Object other) { + if (this == other) { + return true; + } + if (other == null || getClass() != other.getClass()) { + return false; + } + return (this.certificate != null) ? this.certificate.equals(((CertificateInfo) other).certificate) + : super.equals(other); + } + + @Override + public int hashCode() { + return (this.certificate != null) ? this.certificate.hashCode() : super.hashCode(); + } + } /** diff --git a/spring-boot-project/spring-boot/src/test/java/org/springframework/boot/info/SslInfoTests.java b/spring-boot-project/spring-boot/src/test/java/org/springframework/boot/info/SslInfoTests.java index 04d817b021ae..4c251d1c002a 100644 --- a/spring-boot-project/spring-boot/src/test/java/org/springframework/boot/info/SslInfoTests.java +++ b/spring-boot-project/spring-boot/src/test/java/org/springframework/boot/info/SslInfoTests.java @@ -23,10 +23,13 @@ import java.nio.file.Path; import java.time.Duration; import java.util.List; +import java.util.Set; import java.util.stream.Collectors; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.io.TempDir; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.EnumSource; import org.springframework.boot.info.SslInfo.BundleInfo; import org.springframework.boot.info.SslInfo.CertificateChainInfo; @@ -46,13 +49,15 @@ * Tests for {@link SslInfo}. * * @author Jonatan Ivanov + * @author Joshua Chen */ class SslInfoTests { - @Test + @ParameterizedTest + @EnumSource(StoreType.class) @WithPackageResources("test.p12") - void validCertificatesShouldProvideSslInfo() { - SslInfo sslInfo = createSslInfo("classpath:test.p12"); + void validCertificatesShouldProvideSslInfo(StoreType storeType) { + SslInfo sslInfo = createSslInfo(storeType, "classpath:test.p12"); assertThat(sslInfo.getBundles()).hasSize(1); BundleInfo bundle = sslInfo.getBundles().get(0); assertThat(bundle.getName()).isEqualTo("test-0"); @@ -62,10 +67,10 @@ void validCertificatesShouldProvideSslInfo() { assertThat(bundle.getCertificateChains().get(1).getAlias()).isEqualTo("test-alias"); assertThat(bundle.getCertificateChains().get(1).getCertificates()).hasSize(1); assertThat(bundle.getCertificateChains().get(2).getAlias()).isEqualTo("spring-boot-cert"); - assertThat(bundle.getCertificateChains().get(2).getCertificates()).isEmpty(); + assertThat(bundle.getCertificateChains().get(2).getCertificates()).hasSize(1); assertThat(bundle.getCertificateChains().get(3).getAlias()).isEqualTo("test-alias-cert"); - assertThat(bundle.getCertificateChains().get(3).getCertificates()).isEmpty(); - CertificateInfo cert1 = bundle.getCertificateChains().get(0).getCertificates().get(0); + assertThat(bundle.getCertificateChains().get(3).getCertificates()).hasSize(1); + CertificateInfo cert1 = bundle.getCertificateChains().get(0).getCertificates().iterator().next(); assertThat(cert1.getSubject()).isEqualTo("CN=localhost,OU=Spring,O=VMware,L=Palo Alto,ST=California,C=US"); assertThat(cert1.getIssuer()).isEqualTo(cert1.getSubject()); assertThat(cert1.getSerialNumber()).isNotEmpty(); @@ -76,7 +81,7 @@ void validCertificatesShouldProvideSslInfo() { assertThat(cert1.getValidity()).isNotNull(); assertThat(cert1.getValidity().getStatus()).isSameAs(Status.VALID); assertThat(cert1.getValidity().getMessage()).isNull(); - CertificateInfo cert2 = bundle.getCertificateChains().get(1).getCertificates().get(0); + CertificateInfo cert2 = bundle.getCertificateChains().get(1).getCertificates().iterator().next(); assertThat(cert2.getSubject()).isEqualTo("CN=localhost,OU=Spring,O=VMware,L=Palo Alto,ST=California,C=US"); assertThat(cert2.getIssuer()).isEqualTo(cert2.getSubject()); assertThat(cert2.getSerialNumber()).isNotEmpty(); @@ -89,19 +94,20 @@ void validCertificatesShouldProvideSslInfo() { assertThat(cert2.getValidity().getMessage()).isNull(); } - @Test + @ParameterizedTest + @EnumSource(StoreType.class) @WithPackageResources("test-not-yet-valid.p12") - void notYetValidCertificateShouldProvideSslInfo() { - SslInfo sslInfo = createSslInfo("classpath:test-not-yet-valid.p12"); + void notYetValidCertificateShouldProvideSslInfo(StoreType storeType) { + SslInfo sslInfo = createSslInfo(storeType, "classpath:test-not-yet-valid.p12"); assertThat(sslInfo.getBundles()).hasSize(1); BundleInfo bundle = sslInfo.getBundles().get(0); assertThat(bundle.getName()).isEqualTo("test-0"); assertThat(bundle.getCertificateChains()).hasSize(1); CertificateChainInfo certificateChain = bundle.getCertificateChains().get(0); assertThat(certificateChain.getAlias()).isEqualTo("spring-boot"); - List certs = certificateChain.getCertificates(); + Set certs = certificateChain.getCertificates(); assertThat(certs).hasSize(1); - CertificateInfo cert = certs.get(0); + CertificateInfo cert = certs.iterator().next(); assertThat(cert.getSubject()).isEqualTo("CN=localhost,OU=Spring,O=VMware,L=Palo Alto,ST=California,C=US"); assertThat(cert.getIssuer()).isEqualTo(cert.getSubject()); assertThat(cert.getSerialNumber()).isNotEmpty(); @@ -124,9 +130,9 @@ void expiredCertificateShouldProvideSslInfo() { assertThat(bundle.getCertificateChains()).hasSize(1); CertificateChainInfo certificateChain = bundle.getCertificateChains().get(0); assertThat(certificateChain.getAlias()).isEqualTo("spring-boot"); - List certs = certificateChain.getCertificates(); + Set certs = certificateChain.getCertificates(); assertThat(certs).hasSize(1); - CertificateInfo cert = certs.get(0); + CertificateInfo cert = certs.iterator().next(); assertThat(cert.getSubject()).isEqualTo("CN=localhost,OU=Spring,O=VMware,L=Palo Alto,ST=California,C=US"); assertThat(cert.getIssuer()).isEqualTo(cert.getSubject()); assertThat(cert.getSerialNumber()).isNotEmpty(); @@ -150,9 +156,9 @@ void soonToBeExpiredCertificateShouldProvideSslInfo(@TempDir Path tempDir) assertThat(bundle.getCertificateChains()).hasSize(1); CertificateChainInfo certificateChain = bundle.getCertificateChains().get(0); assertThat(certificateChain.getAlias()).isEqualTo("spring-boot"); - List certs = certificateChain.getCertificates(); + Set certs = certificateChain.getCertificates(); assertThat(certs).hasSize(1); - CertificateInfo cert = certs.get(0); + CertificateInfo cert = certs.iterator().next(); assertThat(cert.getSubject()).isEqualTo("CN=localhost,OU=Spring,O=VMware,L=Palo Alto,ST=California,C=US"); assertThat(cert.getIssuer()).isEqualTo(cert.getSubject()); assertThat(cert.getSerialNumber()).isNotEmpty(); @@ -178,7 +184,7 @@ void multipleBundlesShouldProvideSslInfo(@TempDir Path tempDir) throws IOExcepti .flatMap((bundle) -> bundle.getCertificateChains().stream()) .flatMap((certificateChain) -> certificateChain.getCertificates().stream()) .toList(); - assertThat(certs).hasSize(5); + assertThat(certs).hasSize(7); assertThat(certs).allSatisfy((cert) -> { assertThat(cert.getSubject()).isEqualTo("CN=localhost,OU=Spring,O=VMware,L=Palo Alto,ST=California,C=US"); assertThat(cert.getIssuer()).isEqualTo(cert.getSubject()); @@ -227,10 +233,20 @@ void nullKeyStore() { } private SslInfo createSslInfo(String... locations) { + return createSslInfo(StoreType.KEYSTORE, locations); + } + + private SslInfo createSslInfo(StoreType storeType, String... locations) { DefaultSslBundleRegistry sslBundleRegistry = new DefaultSslBundleRegistry(); for (int i = 0; i < locations.length; i++) { - JksSslStoreDetails keyStoreDetails = JksSslStoreDetails.forLocation(locations[i]).withPassword("secret"); - SslStoreBundle sslStoreBundle = new JksSslStoreBundle(keyStoreDetails, null); + JksSslStoreDetails storeDetails = JksSslStoreDetails.forLocation(locations[i]).withPassword("secret"); + SslStoreBundle sslStoreBundle; + if (storeType == StoreType.TRUSTSTORE) { + sslStoreBundle = new JksSslStoreBundle(null, storeDetails); + } + else { + sslStoreBundle = new JksSslStoreBundle(storeDetails, null); + } sslBundleRegistry.registerBundle("test-%d".formatted(i), SslBundle.of(sslStoreBundle)); } return new SslInfo(sslBundleRegistry, Duration.ofDays(7)); @@ -270,4 +286,10 @@ private ProcessBuilder createProcessBuilder(Path keystore) { return processBuilder; } + private enum StoreType { + + KEYSTORE, TRUSTSTORE + + } + }