Skip to content

Commit f0f5036

Browse files
committed
Refactor SslInfo to support multiple certificate stores.
1 parent 4db4fb7 commit f0f5036

File tree

3 files changed

+92
-27
lines changed

3 files changed

+92
-27
lines changed

spring-boot-project/spring-boot-actuator/src/test/java/org/springframework/boot/actuate/ssl/SslHealthIndicatorTests.java

+3-1
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
package org.springframework.boot.actuate.ssl;
1818

1919
import java.util.List;
20+
import java.util.Set;
2021

2122
import org.junit.jupiter.api.BeforeEach;
2223
import org.junit.jupiter.api.Test;
@@ -38,6 +39,7 @@
3839
* Tests for {@link SslHealthIndicator}.
3940
*
4041
* @author Jonatan Ivanov
42+
* @author Joshua Chen
4143
*/
4244
class SslHealthIndicatorTests {
4345

@@ -55,7 +57,7 @@ void setUp() {
5557
this.validity = mock(CertificateValidityInfo.class);
5658
given(sslInfo.getBundles()).willReturn(List.of(bundle));
5759
given(bundle.getCertificateChains()).willReturn(List.of(certificateChain));
58-
given(certificateChain.getCertificates()).willReturn(List.of(certificateInfo));
60+
given(certificateChain.getCertificates()).willReturn(Set.of(certificateInfo));
5961
given(certificateInfo.getValidity()).willReturn(this.validity);
6062
}
6163

spring-boot-project/spring-boot/src/main/java/org/springframework/boot/info/SslInfo.java

+48-7
Original file line numberDiff line numberDiff line change
@@ -24,23 +24,28 @@
2424
import java.security.cert.X509Certificate;
2525
import java.time.Duration;
2626
import java.time.Instant;
27+
import java.util.ArrayList;
2728
import java.util.Arrays;
2829
import java.util.Collections;
2930
import java.util.Date;
3031
import java.util.List;
32+
import java.util.Set;
3133
import java.util.function.Function;
34+
import java.util.stream.Collectors;
3235

3336
import javax.security.auth.x500.X500Principal;
3437

3538
import org.springframework.boot.info.SslInfo.CertificateValidityInfo.Status;
3639
import org.springframework.boot.ssl.SslBundle;
3740
import org.springframework.boot.ssl.SslBundles;
41+
import org.springframework.boot.ssl.SslStoreBundle;
3842
import org.springframework.util.ObjectUtils;
3943

4044
/**
4145
* Information about the certificates that the application uses.
4246
*
4347
* @author Jonatan Ivanov
48+
* @author Joshua Chen
4449
* @since 3.4.0
4550
*/
4651
public class SslInfo {
@@ -72,7 +77,21 @@ public final class BundleInfo {
7277

7378
private BundleInfo(String name, SslBundle sslBundle) {
7479
this.name = name;
75-
this.certificateChains = extractCertificateChains(sslBundle.getStores().getKeyStore());
80+
this.certificateChains = extractCertificateChains(sslBundle.getStores());
81+
}
82+
83+
private List<CertificateChainInfo> extractCertificateChains(SslStoreBundle storeBundle) {
84+
if (storeBundle == null) {
85+
return Collections.emptyList();
86+
}
87+
List<CertificateChainInfo> certificateChains = new ArrayList<>();
88+
if (storeBundle.getKeyStore() != null) {
89+
certificateChains.addAll(extractCertificateChains(storeBundle.getKeyStore()));
90+
}
91+
if (storeBundle.getTrustStore() != null) {
92+
certificateChains.addAll(extractCertificateChains(storeBundle.getTrustStore()));
93+
}
94+
return certificateChains;
7695
}
7796

7897
private List<CertificateChainInfo> extractCertificateChains(KeyStore keyStore) {
@@ -107,29 +126,34 @@ public final class CertificateChainInfo {
107126

108127
private final String alias;
109128

110-
private final List<CertificateInfo> certificates;
129+
private final Set<CertificateInfo> certificates;
111130

112131
CertificateChainInfo(KeyStore keyStore, String alias) {
113132
this.alias = alias;
114133
this.certificates = extractCertificates(keyStore, alias);
115134
}
116135

117-
private List<CertificateInfo> extractCertificates(KeyStore keyStore, String alias) {
136+
private Set<CertificateInfo> extractCertificates(KeyStore keyStore, String alias) {
118137
try {
119138
Certificate[] certificates = keyStore.getCertificateChain(alias);
120-
return (!ObjectUtils.isEmpty(certificates))
121-
? Arrays.stream(certificates).map(CertificateInfo::new).toList() : Collections.emptyList();
139+
Set<CertificateInfo> certificateInfos = new java.util.HashSet<>(!ObjectUtils.isEmpty(certificates)
140+
? Arrays.stream(certificates).map(CertificateInfo::new).collect(Collectors.toSet())
141+
: Collections.emptySet());
142+
if (keyStore.getCertificate(alias) != null) {
143+
certificateInfos.add(new CertificateInfo(keyStore.getCertificate(alias)));
144+
}
145+
return certificateInfos;
122146
}
123147
catch (KeyStoreException ex) {
124-
return Collections.emptyList();
148+
return Collections.emptySet();
125149
}
126150
}
127151

128152
public String getAlias() {
129153
return this.alias;
130154
}
131155

132-
public List<CertificateInfo> getCertificates() {
156+
public Set<CertificateInfo> getCertificates() {
133157
return this.certificates;
134158
}
135159

@@ -208,6 +232,23 @@ private <R> R extract(Function<X509Certificate, R> extractor) {
208232
return (this.certificate != null) ? extractor.apply(this.certificate) : null;
209233
}
210234

235+
@Override
236+
public boolean equals(Object other) {
237+
if (this == other) {
238+
return true;
239+
}
240+
if (other == null || getClass() != other.getClass()) {
241+
return false;
242+
}
243+
return (this.certificate != null) ? this.certificate.equals(((CertificateInfo) other).certificate)
244+
: super.equals(other);
245+
}
246+
247+
@Override
248+
public int hashCode() {
249+
return (this.certificate != null) ? this.certificate.hashCode() : super.hashCode();
250+
}
251+
211252
}
212253

213254
/**

spring-boot-project/spring-boot/src/test/java/org/springframework/boot/info/SslInfoTests.java

+41-19
Original file line numberDiff line numberDiff line change
@@ -23,10 +23,13 @@
2323
import java.nio.file.Path;
2424
import java.time.Duration;
2525
import java.util.List;
26+
import java.util.Set;
2627
import java.util.stream.Collectors;
2728

2829
import org.junit.jupiter.api.Test;
2930
import org.junit.jupiter.api.io.TempDir;
31+
import org.junit.jupiter.params.ParameterizedTest;
32+
import org.junit.jupiter.params.provider.EnumSource;
3033

3134
import org.springframework.boot.info.SslInfo.BundleInfo;
3235
import org.springframework.boot.info.SslInfo.CertificateChainInfo;
@@ -46,13 +49,15 @@
4649
* Tests for {@link SslInfo}.
4750
*
4851
* @author Jonatan Ivanov
52+
* @author Joshua Chen
4953
*/
5054
class SslInfoTests {
5155

52-
@Test
56+
@ParameterizedTest
57+
@EnumSource(StoreType.class)
5358
@WithPackageResources("test.p12")
54-
void validCertificatesShouldProvideSslInfo() {
55-
SslInfo sslInfo = createSslInfo("classpath:test.p12");
59+
void validCertificatesShouldProvideSslInfo(StoreType storeType) {
60+
SslInfo sslInfo = createSslInfo(storeType, "classpath:test.p12");
5661
assertThat(sslInfo.getBundles()).hasSize(1);
5762
BundleInfo bundle = sslInfo.getBundles().get(0);
5863
assertThat(bundle.getName()).isEqualTo("test-0");
@@ -62,10 +67,10 @@ void validCertificatesShouldProvideSslInfo() {
6267
assertThat(bundle.getCertificateChains().get(1).getAlias()).isEqualTo("test-alias");
6368
assertThat(bundle.getCertificateChains().get(1).getCertificates()).hasSize(1);
6469
assertThat(bundle.getCertificateChains().get(2).getAlias()).isEqualTo("spring-boot-cert");
65-
assertThat(bundle.getCertificateChains().get(2).getCertificates()).isEmpty();
70+
assertThat(bundle.getCertificateChains().get(2).getCertificates()).hasSize(1);
6671
assertThat(bundle.getCertificateChains().get(3).getAlias()).isEqualTo("test-alias-cert");
67-
assertThat(bundle.getCertificateChains().get(3).getCertificates()).isEmpty();
68-
CertificateInfo cert1 = bundle.getCertificateChains().get(0).getCertificates().get(0);
72+
assertThat(bundle.getCertificateChains().get(3).getCertificates()).hasSize(1);
73+
CertificateInfo cert1 = bundle.getCertificateChains().get(0).getCertificates().iterator().next();
6974
assertThat(cert1.getSubject()).isEqualTo("CN=localhost,OU=Spring,O=VMware,L=Palo Alto,ST=California,C=US");
7075
assertThat(cert1.getIssuer()).isEqualTo(cert1.getSubject());
7176
assertThat(cert1.getSerialNumber()).isNotEmpty();
@@ -76,7 +81,7 @@ void validCertificatesShouldProvideSslInfo() {
7681
assertThat(cert1.getValidity()).isNotNull();
7782
assertThat(cert1.getValidity().getStatus()).isSameAs(Status.VALID);
7883
assertThat(cert1.getValidity().getMessage()).isNull();
79-
CertificateInfo cert2 = bundle.getCertificateChains().get(1).getCertificates().get(0);
84+
CertificateInfo cert2 = bundle.getCertificateChains().get(1).getCertificates().iterator().next();
8085
assertThat(cert2.getSubject()).isEqualTo("CN=localhost,OU=Spring,O=VMware,L=Palo Alto,ST=California,C=US");
8186
assertThat(cert2.getIssuer()).isEqualTo(cert2.getSubject());
8287
assertThat(cert2.getSerialNumber()).isNotEmpty();
@@ -89,19 +94,20 @@ void validCertificatesShouldProvideSslInfo() {
8994
assertThat(cert2.getValidity().getMessage()).isNull();
9095
}
9196

92-
@Test
97+
@ParameterizedTest
98+
@EnumSource(StoreType.class)
9399
@WithPackageResources("test-not-yet-valid.p12")
94-
void notYetValidCertificateShouldProvideSslInfo() {
95-
SslInfo sslInfo = createSslInfo("classpath:test-not-yet-valid.p12");
100+
void notYetValidCertificateShouldProvideSslInfo(StoreType storeType) {
101+
SslInfo sslInfo = createSslInfo(storeType, "classpath:test-not-yet-valid.p12");
96102
assertThat(sslInfo.getBundles()).hasSize(1);
97103
BundleInfo bundle = sslInfo.getBundles().get(0);
98104
assertThat(bundle.getName()).isEqualTo("test-0");
99105
assertThat(bundle.getCertificateChains()).hasSize(1);
100106
CertificateChainInfo certificateChain = bundle.getCertificateChains().get(0);
101107
assertThat(certificateChain.getAlias()).isEqualTo("spring-boot");
102-
List<CertificateInfo> certs = certificateChain.getCertificates();
108+
Set<CertificateInfo> certs = certificateChain.getCertificates();
103109
assertThat(certs).hasSize(1);
104-
CertificateInfo cert = certs.get(0);
110+
CertificateInfo cert = certs.iterator().next();
105111
assertThat(cert.getSubject()).isEqualTo("CN=localhost,OU=Spring,O=VMware,L=Palo Alto,ST=California,C=US");
106112
assertThat(cert.getIssuer()).isEqualTo(cert.getSubject());
107113
assertThat(cert.getSerialNumber()).isNotEmpty();
@@ -124,9 +130,9 @@ void expiredCertificateShouldProvideSslInfo() {
124130
assertThat(bundle.getCertificateChains()).hasSize(1);
125131
CertificateChainInfo certificateChain = bundle.getCertificateChains().get(0);
126132
assertThat(certificateChain.getAlias()).isEqualTo("spring-boot");
127-
List<CertificateInfo> certs = certificateChain.getCertificates();
133+
Set<CertificateInfo> certs = certificateChain.getCertificates();
128134
assertThat(certs).hasSize(1);
129-
CertificateInfo cert = certs.get(0);
135+
CertificateInfo cert = certs.iterator().next();
130136
assertThat(cert.getSubject()).isEqualTo("CN=localhost,OU=Spring,O=VMware,L=Palo Alto,ST=California,C=US");
131137
assertThat(cert.getIssuer()).isEqualTo(cert.getSubject());
132138
assertThat(cert.getSerialNumber()).isNotEmpty();
@@ -150,9 +156,9 @@ void soonToBeExpiredCertificateShouldProvideSslInfo(@TempDir Path tempDir)
150156
assertThat(bundle.getCertificateChains()).hasSize(1);
151157
CertificateChainInfo certificateChain = bundle.getCertificateChains().get(0);
152158
assertThat(certificateChain.getAlias()).isEqualTo("spring-boot");
153-
List<CertificateInfo> certs = certificateChain.getCertificates();
159+
Set<CertificateInfo> certs = certificateChain.getCertificates();
154160
assertThat(certs).hasSize(1);
155-
CertificateInfo cert = certs.get(0);
161+
CertificateInfo cert = certs.iterator().next();
156162
assertThat(cert.getSubject()).isEqualTo("CN=localhost,OU=Spring,O=VMware,L=Palo Alto,ST=California,C=US");
157163
assertThat(cert.getIssuer()).isEqualTo(cert.getSubject());
158164
assertThat(cert.getSerialNumber()).isNotEmpty();
@@ -178,7 +184,7 @@ void multipleBundlesShouldProvideSslInfo(@TempDir Path tempDir) throws IOExcepti
178184
.flatMap((bundle) -> bundle.getCertificateChains().stream())
179185
.flatMap((certificateChain) -> certificateChain.getCertificates().stream())
180186
.toList();
181-
assertThat(certs).hasSize(5);
187+
assertThat(certs).hasSize(7);
182188
assertThat(certs).allSatisfy((cert) -> {
183189
assertThat(cert.getSubject()).isEqualTo("CN=localhost,OU=Spring,O=VMware,L=Palo Alto,ST=California,C=US");
184190
assertThat(cert.getIssuer()).isEqualTo(cert.getSubject());
@@ -227,10 +233,20 @@ void nullKeyStore() {
227233
}
228234

229235
private SslInfo createSslInfo(String... locations) {
236+
return createSslInfo(StoreType.KEYSTORE, locations);
237+
}
238+
239+
private SslInfo createSslInfo(StoreType storeType, String... locations) {
230240
DefaultSslBundleRegistry sslBundleRegistry = new DefaultSslBundleRegistry();
231241
for (int i = 0; i < locations.length; i++) {
232-
JksSslStoreDetails keyStoreDetails = JksSslStoreDetails.forLocation(locations[i]).withPassword("secret");
233-
SslStoreBundle sslStoreBundle = new JksSslStoreBundle(keyStoreDetails, null);
242+
JksSslStoreDetails storeDetails = JksSslStoreDetails.forLocation(locations[i]).withPassword("secret");
243+
SslStoreBundle sslStoreBundle;
244+
if (storeType == StoreType.TRUSTSTORE) {
245+
sslStoreBundle = new JksSslStoreBundle(null, storeDetails);
246+
}
247+
else {
248+
sslStoreBundle = new JksSslStoreBundle(storeDetails, null);
249+
}
234250
sslBundleRegistry.registerBundle("test-%d".formatted(i), SslBundle.of(sslStoreBundle));
235251
}
236252
return new SslInfo(sslBundleRegistry, Duration.ofDays(7));
@@ -270,4 +286,10 @@ private ProcessBuilder createProcessBuilder(Path keystore) {
270286
return processBuilder;
271287
}
272288

289+
private enum StoreType {
290+
291+
KEYSTORE, TRUSTSTORE
292+
293+
}
294+
273295
}

0 commit comments

Comments
 (0)