diff --git a/saml2/saml2-service-provider/spring-security-saml2-service-provider.gradle b/saml2/saml2-service-provider/spring-security-saml2-service-provider.gradle index 8b40d415d61..b05c1bbd577 100644 --- a/saml2/saml2-service-provider/spring-security-saml2-service-provider.gradle +++ b/saml2/saml2-service-provider/spring-security-saml2-service-provider.gradle @@ -106,6 +106,7 @@ dependencies { provided 'jakarta.servlet:jakarta.servlet-api' optional 'com.fasterxml.jackson.core:jackson-databind' + optional 'org.springframework:spring-jdbc' testImplementation 'com.squareup.okhttp3:mockwebserver' testImplementation "org.assertj:assertj-core" @@ -118,6 +119,7 @@ dependencies { testImplementation "org.springframework:spring-test" testRuntimeOnly 'org.junit.platform:junit-platform-launcher' + testRuntimeOnly 'org.hsqldb:hsqldb' } jar { diff --git a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/registration/JdbcRelyingPartyRegistrationRepository.java b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/registration/JdbcRelyingPartyRegistrationRepository.java new file mode 100644 index 00000000000..1240eb0dc76 --- /dev/null +++ b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/registration/JdbcRelyingPartyRegistrationRepository.java @@ -0,0 +1,451 @@ +/* + * Copyright 2002-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.saml2.provider.service.registration; + +import java.nio.charset.StandardCharsets; +import java.security.cert.CertificateException; +import java.security.cert.X509Certificate; +import java.security.interfaces.RSAPrivateKey; +import java.sql.ResultSet; +import java.sql.SQLException; +import java.sql.Types; +import java.util.ArrayList; +import java.util.Collection; +import java.util.Iterator; +import java.util.List; +import java.util.function.Consumer; + +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.core.type.TypeReference; +import com.fasterxml.jackson.databind.ObjectMapper; +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; +import org.cryptacular.util.KeyPairUtil; +import org.opensaml.security.x509.X509Support; + +import org.springframework.core.log.LogMessage; +import org.springframework.jdbc.core.ArgumentPreparedStatementSetter; +import org.springframework.jdbc.core.JdbcOperations; +import org.springframework.jdbc.core.PreparedStatementSetter; +import org.springframework.jdbc.core.RowMapper; +import org.springframework.jdbc.core.SqlParameterValue; +import org.springframework.jdbc.support.lob.DefaultLobHandler; +import org.springframework.jdbc.support.lob.LobHandler; +import org.springframework.security.saml2.core.Saml2X509Credential; +import org.springframework.util.Assert; +import org.springframework.util.StringUtils; + +/** + * A JDBC implementation of {@link RelyingPartyRegistrationRepository}. Also implements + * {@link Iterable} to simplify the default login page. + */ +public class JdbcRelyingPartyRegistrationRepository implements IterableRelyingPartyRegistrationRepository { + + private static final ObjectMapper OBJECT_MAPPER = new ObjectMapper(); + + // @formatter:off + static final String COLUMN_NAMES = "id, " + + "entity_id, " + + "name_id_format, " + + "acs_location, " + + "acs_binding, " + + "signing_credentials, " + + "decryption_credentials, " + + "singlelogout_url, " + + "singlelogout_response_url, " + + "singlelogout_binding, " + + "assertingparty_entity_id, " + + "assertingparty_metadata_uri, " + + "assertingparty_singlesignon_url, " + + "assertingparty_singlesignon_binding, " + + "assertingparty_singlesignon_sign_request, " + + "assertingparty_verification_credentials, " + + "assertingparty_singlelogout_url, " + + "assertingparty_singlelogout_response_url, " + + "assertingparty_singlelogout_binding"; + // @formatter:on + + private static final String TABLE_NAME = "saml2_relying_party_registration"; + + private static final String PK_FILTER = "id = ?"; + + private static final String ENTITY_ID_FILTER = "entity_id = ?"; + + // @formatter:off + private static final String LOAD_BY_ID_SQL = "SELECT " + COLUMN_NAMES + + " FROM " + TABLE_NAME + + " WHERE " + PK_FILTER; + + private static final String LOAD_BY_ENTITY_ID_SQL = "SELECT " + COLUMN_NAMES + + " FROM " + TABLE_NAME + + " WHERE " + ENTITY_ID_FILTER; + + private static final String LOAD_ALL_SQL = "SELECT " + COLUMN_NAMES + + " FROM " + TABLE_NAME; + // @formatter:on + + protected final JdbcOperations jdbcOperations; + + protected RowMapper relyingPartyRegistrationRowMapper; + + protected final LobHandler lobHandler; + + /** + * Constructs a {@code JdbcRelyingPartyRegistrationRepository} using the provided + * parameters. + * @param jdbcOperations the JDBC operations + */ + public JdbcRelyingPartyRegistrationRepository(JdbcOperations jdbcOperations) { + this(jdbcOperations, new DefaultLobHandler()); + } + + /** + * Constructs a {@code JdbcRelyingPartyRegistrationRepository} using the provided + * parameters. + * @param jdbcOperations the JDBC operations + * @param lobHandler the handler for large binary fields and large text fields + */ + public JdbcRelyingPartyRegistrationRepository(JdbcOperations jdbcOperations, LobHandler lobHandler) { + Assert.notNull(jdbcOperations, "jdbcOperations cannot be null"); + Assert.notNull(lobHandler, "lobHandler cannot be null"); + this.jdbcOperations = jdbcOperations; + this.lobHandler = lobHandler; + RelyingPartyRegistrationRowMapper rowMapper = new RelyingPartyRegistrationRowMapper(); + rowMapper.setLobHandler(lobHandler); + this.relyingPartyRegistrationRowMapper = rowMapper; + } + + /** + * Sets the {@link RowMapper} used for mapping the current row in + * {@code java.sql.ResultSet} to {@link RelyingPartyRegistration}. The default is + * {@link RelyingPartyRegistrationRowMapper}. + * @param relyingPartyRegistrationRowMapper the {@link RowMapper} used for mapping the + * current row in {@code java.sql.ResultSet} to {@link RelyingPartyRegistration} + */ + public final void setAuthorizedClientRowMapper( + RowMapper relyingPartyRegistrationRowMapper) { + Assert.notNull(relyingPartyRegistrationRowMapper, "relyingPartyRegistrationRowMapper cannot be null"); + this.relyingPartyRegistrationRowMapper = relyingPartyRegistrationRowMapper; + } + + @Override + public RelyingPartyRegistration findByRegistrationId(String registrationId) { + Assert.hasText(registrationId, "registrationId cannot be empty"); + SqlParameterValue[] parameters = new SqlParameterValue[] { + new SqlParameterValue(Types.VARCHAR, registrationId) }; + PreparedStatementSetter pss = new ArgumentPreparedStatementSetter(parameters); + List result = this.jdbcOperations.query(LOAD_BY_ID_SQL, pss, + this.relyingPartyRegistrationRowMapper); + return !result.isEmpty() ? result.get(0) : null; + } + + @Override + public RelyingPartyRegistration findUniqueByAssertingPartyEntityId(String entityId) { + Assert.hasText(entityId, "entityId cannot be empty"); + SqlParameterValue[] parameters = new SqlParameterValue[] { new SqlParameterValue(Types.VARCHAR, entityId) }; + PreparedStatementSetter pss = new ArgumentPreparedStatementSetter(parameters); + List result = this.jdbcOperations.query(LOAD_BY_ENTITY_ID_SQL, pss, + this.relyingPartyRegistrationRowMapper); + return !result.isEmpty() ? result.get(0) : null; + } + + @Override + public Iterator iterator() { + List result = this.jdbcOperations.query(LOAD_ALL_SQL, + this.relyingPartyRegistrationRowMapper); + return result.iterator(); + } + + /** + * The default {@link RowMapper} that maps the current row in + * {@code java.sql.ResultSet} to {@link RelyingPartyRegistration}. + */ + public static class RelyingPartyRegistrationRowMapper implements RowMapper { + + private final Log logger = LogFactory.getLog(getClass()); + + protected LobHandler lobHandler = new DefaultLobHandler(); + + public final void setLobHandler(LobHandler lobHandler) { + Assert.notNull(lobHandler, "lobHandler cannot be null"); + this.lobHandler = lobHandler; + } + + @Override + public RelyingPartyRegistration mapRow(ResultSet rs, int rowNum) throws SQLException { + String registrationId = rs.getString("id"); + String entityId = StringUtils.hasText(rs.getString("entity_id")) ? rs.getString("entity_id") + : "{baseUrl}/saml2/service-provider-metadata/{registrationId}"; + String nameIdFormat = rs.getString("name_id_format"); + String acsLocation = StringUtils.hasText(rs.getString("acs_location")) ? rs.getString("acs_location") + : "{baseUrl}/login/saml2/sso/{registrationId}"; + String acsBinding = StringUtils.hasText(rs.getString("acs_binding")) ? rs.getString("acs_binding") + : Saml2MessageBinding.POST.getUrn(); + List signingCredentials; + try { + signingCredentials = parseCredentials(getLobValue(rs, "signing_credentials")); + } + catch (JsonProcessingException ex) { + this.logger.error(LogMessage.format("Signing credentials of %s could not be parsed.", registrationId), + ex); + return null; + } + List decryptionCredentials; + try { + decryptionCredentials = parseCredentials(getLobValue(rs, "decryption_credentials")); + } + catch (JsonProcessingException ex) { + this.logger + .error(LogMessage.format("Decryption credentials of %s could not be parsed.", registrationId), ex); + return null; + } + String singleLogoutUrl = rs.getString("singlelogout_url"); + String singleLogoutResponseUrl = rs.getString("singlelogout_response_url"); + Saml2MessageBinding singleLogoutBinding = Saml2MessageBinding.from(rs.getString("singlelogout_binding")); + String assertingPartyEntityId = rs.getString("assertingparty_entity_id"); + String assertingPartyMetadataUri = rs.getString("assertingparty_metadata_uri"); + String assertingPartySingleSignOnUrl = rs.getString("assertingparty_singlesignon_url"); + Saml2MessageBinding assertingPartySingleSignOnBinding = Saml2MessageBinding + .from(rs.getString("assertingparty_singlesignon_binding")); + Boolean assertingPartySingleSignOnSignRequest = rs.getBoolean("assertingparty_singlesignon_sign_request"); + List assertingPartyVerificationCredentials; + try { + assertingPartyVerificationCredentials = parseCertificate( + getLobValue(rs, "assertingparty_verification_credentials")); + } + catch (JsonProcessingException ex) { + this.logger.error( + LogMessage.format("Verification certificate of %s could not be parsed.", registrationId), ex); + return null; + } + String assertingPartySingleLogoutUrl = rs.getString("assertingparty_singlelogout_url"); + String assertingPartySingleLogoutResponseUrl = rs.getString("assertingparty_singlelogout_response_url"); + Saml2MessageBinding assertingPartySingleLogoutBinding = Saml2MessageBinding + .from(rs.getString("assertingparty_singlelogout_binding")); + + boolean usingMetadata = StringUtils.hasText(assertingPartyMetadataUri); + RelyingPartyRegistration.Builder builder = (!usingMetadata) + ? RelyingPartyRegistration.withRegistrationId(registrationId) + : createBuilderUsingMetadata(assertingPartyEntityId, assertingPartyMetadataUri) + .registrationId(registrationId); + builder.assertionConsumerServiceLocation(acsLocation); + builder.assertionConsumerServiceBinding(Saml2MessageBinding.from(acsBinding)); + builder.assertingPartyMetadata(mapAssertingParty(assertingPartyEntityId, assertingPartySingleSignOnBinding, + assertingPartySingleSignOnUrl, assertingPartySingleSignOnSignRequest, + assertingPartySingleLogoutBinding, assertingPartySingleLogoutResponseUrl, + assertingPartySingleLogoutUrl)); + + for (Credential credential : signingCredentials) { + try { + Saml2X509Credential signingCredential = asSigningCredential(credential); + builder.signingX509Credentials((credentials) -> credentials.add(signingCredential)); + } + catch (Exception ex) { + this.logger.error(LogMessage.format("Signing credentials of %s must have a valid certificate.", + registrationId), ex); + return null; + } + } + for (Credential credential : decryptionCredentials) { + try { + Saml2X509Credential decryptionCredential = asDecryptionCredential(credential); + builder.decryptionX509Credentials((credentials) -> credentials.add(decryptionCredential)); + } + catch (Exception ex) { + this.logger.error(LogMessage.format("Decryption credentials of %s must have a valid certificate.", + registrationId), ex); + return null; + } + } + List verificationCredentials = new ArrayList<>(); + for (Certificate certificate : assertingPartyVerificationCredentials) { + try { + verificationCredentials.add(asVerificationCredential(certificate)); + } + catch (Exception ex) { + this.logger.error(LogMessage.format("Verification credentials of %s must have a valid certificate.", + registrationId), ex); + return null; + } + } + builder.assertingPartyMetadata((details) -> details + .verificationX509Credentials((credentials) -> credentials.addAll(verificationCredentials))); + builder.singleLogoutServiceLocation(singleLogoutUrl); + builder.singleLogoutServiceResponseLocation(singleLogoutResponseUrl); + builder.singleLogoutServiceBinding(singleLogoutBinding); + builder.entityId(entityId); + builder.nameIdFormat(nameIdFormat); + RelyingPartyRegistration registration = builder.build(); + boolean signRequest = registration.getAssertingPartyMetadata().getWantAuthnRequestsSigned(); + if (signRequest && signingCredentials.isEmpty()) { + this.logger + .error(LogMessage.format("Signing credentials of %s must not be empty when authentication requests " + + "require signing.", registrationId)); + return null; + } + return registration; + } + + private Saml2X509Credential asSigningCredential(Credential credential) throws Exception { + RSAPrivateKey privateKey = readPrivateKey(credential.getPrivateKey()); + X509Certificate certificate = readCertificate(credential.getCertificate()); + return new Saml2X509Credential(privateKey, certificate, + Saml2X509Credential.Saml2X509CredentialType.SIGNING); + } + + private Saml2X509Credential asDecryptionCredential(Credential credential) throws Exception { + RSAPrivateKey privateKey = readPrivateKey(credential.getPrivateKey()); + X509Certificate certificate = readCertificate(credential.getCertificate()); + return new Saml2X509Credential(privateKey, certificate, + Saml2X509Credential.Saml2X509CredentialType.DECRYPTION); + } + + private Saml2X509Credential asVerificationCredential(Certificate certificate) throws Exception { + X509Certificate x509Certificate = readCertificate(certificate.getCertificate()); + return new Saml2X509Credential(x509Certificate, Saml2X509Credential.Saml2X509CredentialType.ENCRYPTION, + Saml2X509Credential.Saml2X509CredentialType.VERIFICATION); + } + + private RSAPrivateKey readPrivateKey(String privateKey) { + return (RSAPrivateKey) KeyPairUtil.decodePrivateKey(privateKey.getBytes(StandardCharsets.UTF_8)); + } + + private X509Certificate readCertificate(String certificate) throws CertificateException { + return X509Support.decodeCertificate(certificate); + } + + private Consumer> mapAssertingParty(String assertingPartyEntityId, + Saml2MessageBinding assertingPartySingleSignOnBinding, String assertingPartySingleSignOnUrl, + Boolean assertingPartySingleSignOnSignRequest, Saml2MessageBinding assertingPartySingleLogoutBinding, + String assertingPartySingleLogoutResponseUrl, String assertingPartySingleLogoutUrl) { + return (details) -> { + applyingWhenNonNull(assertingPartyEntityId, details::entityId); + applyingWhenNonNull(assertingPartySingleSignOnBinding, details::singleSignOnServiceBinding); + applyingWhenNonNull(assertingPartySingleSignOnUrl, details::singleSignOnServiceLocation); + applyingWhenNonNull(assertingPartySingleSignOnSignRequest, details::wantAuthnRequestsSigned); + applyingWhenNonNull(assertingPartySingleLogoutUrl, details::singleLogoutServiceLocation); + applyingWhenNonNull(assertingPartySingleLogoutResponseUrl, + details::singleLogoutServiceResponseLocation); + applyingWhenNonNull(assertingPartySingleLogoutBinding, details::singleLogoutServiceBinding); + }; + } + + private void applyingWhenNonNull(T value, Consumer consumer) { + if (value != null) { + consumer.accept(value); + } + } + + private RelyingPartyRegistration.Builder createBuilderUsingMetadata(String assertingPartyEntityId, + String assertingPartyMetadataUri) { + Collection candidates = RelyingPartyRegistrations + .collectionFromMetadataLocation(assertingPartyMetadataUri); + for (RelyingPartyRegistration.Builder candidate : candidates) { + if (assertingPartyEntityId == null || assertingPartyEntityId.equals(getEntityId(candidate))) { + return candidate; + } + } + throw new IllegalStateException("No relying party with Entity ID '" + assertingPartyEntityId + "' found"); + } + + private Object getEntityId(RelyingPartyRegistration.Builder candidate) { + String[] result = new String[1]; + candidate.assertingPartyMetadata((builder) -> result[0] = builder.build().getEntityId()); + return result[0]; + } + + private List parseCredentials(String credentials) throws JsonProcessingException { + if (!StringUtils.hasText(credentials)) { + return new ArrayList<>(); + } + return OBJECT_MAPPER.readValue(credentials, new TypeReference<>() { + }); + } + + private List parseCertificate(String certificate) throws JsonProcessingException { + if (!StringUtils.hasText(certificate)) { + return new ArrayList<>(); + } + return OBJECT_MAPPER.readValue(certificate, new TypeReference<>() { + }); + } + + private String getLobValue(ResultSet rs, String columnName) throws SQLException { + String columnValue = null; + byte[] columnValueBytes = this.lobHandler.getBlobAsBytes(rs, columnName); + if (columnValueBytes != null) { + columnValue = new String(columnValueBytes, StandardCharsets.UTF_8); + } + return columnValue; + } + + } + + public static class Certificate { + + private String certificate; + + public Certificate() { + } + + public Certificate(String certificate) { + this.certificate = certificate; + } + + public String getCertificate() { + return this.certificate; + } + + public void setCertificate(String certificate) { + this.certificate = certificate; + } + + } + + public static class Credential { + + private String privateKey; + + private String certificate; + + public Credential() { + } + + public Credential(String privateKey, String certificate) { + this.privateKey = privateKey; + this.certificate = certificate; + } + + public String getPrivateKey() { + return this.privateKey; + } + + public void setPrivateKey(String privateKey) { + this.privateKey = privateKey; + } + + public String getCertificate() { + return this.certificate; + } + + public void setCertificate(String certificate) { + this.certificate = certificate; + } + + } + +} diff --git a/saml2/saml2-service-provider/src/main/resources/org/springframework/security/saml2/saml2-relying-party-registration-schema-postgres.sql b/saml2/saml2-service-provider/src/main/resources/org/springframework/security/saml2/saml2-relying-party-registration-schema-postgres.sql new file mode 100644 index 00000000000..bd5997df282 --- /dev/null +++ b/saml2/saml2-service-provider/src/main/resources/org/springframework/security/saml2/saml2-relying-party-registration-schema-postgres.sql @@ -0,0 +1,23 @@ +CREATE TABLE saml2_relying_party_registration +( + id VARCHAR(200) NOT NULL, + entity_id VARCHAR(1000), + name_id_format VARCHAR(200), + acs_location VARCHAR(1000), + acs_binding VARCHAR(200), + signing_credentials BYTEA, + decryption_credentials BYTEA, + singlelogout_url VARCHAR(1000), + singlelogout_response_url VARCHAR(1000), + singlelogout_binding VARCHAR(200), + assertingparty_entity_id VARCHAR(1000), + assertingparty_metadata_uri VARCHAR(1000), + assertingparty_singlesignon_url VARCHAR(1000), + assertingparty_singlesignon_binding VARCHAR(200), + assertingparty_singlesignon_sign_request VARCHAR(1000), + assertingparty_verification_credentials BYTEA, + assertingparty_singlelogout_url VARCHAR(1000), + assertingparty_singlelogout_response_url VARCHAR(1000), + assertingparty_singlelogout_binding VARCHAR(200), + PRIMARY KEY (id) +); diff --git a/saml2/saml2-service-provider/src/main/resources/org/springframework/security/saml2/saml2-relying-party-registration-schema.sql b/saml2/saml2-service-provider/src/main/resources/org/springframework/security/saml2/saml2-relying-party-registration-schema.sql new file mode 100644 index 00000000000..70ed301d899 --- /dev/null +++ b/saml2/saml2-service-provider/src/main/resources/org/springframework/security/saml2/saml2-relying-party-registration-schema.sql @@ -0,0 +1,23 @@ +CREATE TABLE saml2_relying_party_registration +( + id VARCHAR(200) NOT NULL, + entity_id VARCHAR(1000), + name_id_format VARCHAR(200), + acs_location VARCHAR(1000), + acs_binding VARCHAR(200), + signing_credentials blob, + decryption_credentials blob, + singlelogout_url VARCHAR(1000), + singlelogout_response_url VARCHAR(1000), + singlelogout_binding VARCHAR(200), + assertingparty_entity_id VARCHAR(1000), + assertingparty_metadata_uri VARCHAR(1000), + assertingparty_singlesignon_url VARCHAR(1000), + assertingparty_singlesignon_binding VARCHAR(200), + assertingparty_singlesignon_sign_request VARCHAR(1000), + assertingparty_verification_credentials blob, + assertingparty_singlelogout_url VARCHAR(1000), + assertingparty_singlelogout_response_url VARCHAR(1000), + assertingparty_singlelogout_binding VARCHAR(200), + PRIMARY KEY (id) +); diff --git a/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/registration/JdbcRelyingPartyRegistrationRepositoryTests.java b/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/registration/JdbcRelyingPartyRegistrationRepositoryTests.java new file mode 100644 index 00000000000..1f7722faec2 --- /dev/null +++ b/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/registration/JdbcRelyingPartyRegistrationRepositoryTests.java @@ -0,0 +1,520 @@ +/* + * Copyright 2002-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.saml2.provider.service.registration; + +import java.io.BufferedReader; +import java.io.IOException; +import java.io.InputStream; +import java.io.InputStreamReader; +import java.nio.charset.StandardCharsets; +import java.security.cert.CertificateFactory; +import java.security.cert.X509Certificate; +import java.security.interfaces.RSAPrivateKey; +import java.util.Base64; +import java.util.Iterator; +import java.util.List; +import java.util.stream.Collectors; + +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.databind.ObjectMapper; +import okhttp3.mockwebserver.MockResponse; +import okhttp3.mockwebserver.MockWebServer; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +import org.springframework.core.io.ClassPathResource; +import org.springframework.jdbc.core.JdbcOperations; +import org.springframework.jdbc.core.JdbcTemplate; +import org.springframework.jdbc.datasource.embedded.EmbeddedDatabase; +import org.springframework.jdbc.datasource.embedded.EmbeddedDatabaseBuilder; +import org.springframework.jdbc.datasource.embedded.EmbeddedDatabaseType; +import org.springframework.security.converter.RsaKeyConverters; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; + +/** + * Tests for {@link JdbcRelyingPartyRegistrationRepository} + */ +public class JdbcRelyingPartyRegistrationRepositoryTests { + + private static final String RP_REGISTRATION_SCHEMA_SQL_RESOURCE = "org/springframework/security/saml2/saml2-relying-party-registration-schema.sql"; + + private static final String SAVE_RP_REGISTRATION_SQL = "INSERT INTO saml2_relying_party_registration" + " (" + + JdbcRelyingPartyRegistrationRepository.COLUMN_NAMES + + ") VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)"; + + private static final String REGISTRATION_ID = "adfs"; + + private static final String ENTITY_ID = "https://rp.example.org/saml2/service-provider-metadata/adfs"; + + private static final String NAME_ID_FORMAT = "urn:oasis:names:tc:SAML:1.1:nameid-format:emailAddress"; + + private static final String ACS_LOCATION = "https://rp.example.org/login/saml2/sso/adfs"; + + private static final String ACS_BINDING = Saml2MessageBinding.POST.getUrn(); + + private static final String SINGLE_LOGOUT_URL = "https://rp.example.org/logout/saml2/slo"; + + private static final String SINGLE_LOGOUT_RESPONSE_URL = "https://rp.example.org/logout/saml2/slo"; + + private static final String SINGLE_LOGOUT_BINDING = Saml2MessageBinding.POST.getUrn(); + + private static final String ASSERTINGPARTY_ENTITY_ID = "https://localhost/simplesaml/saml2/idp/metadata.php"; + + private static final String ASSERTINGPARTY_SINGLE_SIGNON_URL = "https://localhost/SSO"; + + private static final String ASSERTINGPARTY_SINGLE_SIGNON_BINDING = Saml2MessageBinding.REDIRECT.getUrn(); + + private static final String ASSERTINGPARTY_SINGLE_SIGNON_SIGN_REQUEST = "true"; + + private static final String ASSERTINGPARTY_SINGLE_LOGOUT_URL = "https://localhost/SLO"; + + private static final String ASSERTINGPARTY_SINGLE_LOGOUT_RESPONSE_URL = "https://localhost/SLO/response"; + + private static final String ASSERTINGPARTY_SINGLE_LOGOUT_BINDING = Saml2MessageBinding.REDIRECT.getUrn(); + + private String assertingpartyMetadataUri; + + private String signingCredentials; + + private String decryptionCredentials; + + private String assertingpartyVerificationCredentials; + + private EmbeddedDatabase db; + + private JdbcRelyingPartyRegistrationRepository repository; + + private JdbcOperations jdbcOperations; + + private final ObjectMapper objectMapper = new ObjectMapper(); + + private final MockWebServer mockWebServer = new MockWebServer(); + + @BeforeEach + public void setUp() throws Exception { + this.db = createDb(); + this.jdbcOperations = new JdbcTemplate(this.db); + this.repository = new JdbcRelyingPartyRegistrationRepository(this.jdbcOperations); + + ClassPathResource resource = new ClassPathResource("test-federated-metadata.xml"); + String metadata; + try (BufferedReader reader = new BufferedReader(new InputStreamReader(resource.getInputStream()))) { + metadata = reader.lines().collect(Collectors.joining()); + } + + this.mockWebServer.enqueue(new MockResponse().setBody(metadata).setResponseCode(200)); + this.assertingpartyMetadataUri = this.mockWebServer.url("/metadata").toString(); + + X509Certificate x509Certificate = loadCertificate("rsa.crt"); + RSAPrivateKey rsaPrivateKey = loadPrivateKey("rsa.key"); + String credentials = this.objectMapper + .writeValueAsString(List.of(new JdbcRelyingPartyRegistrationRepository.Credential( + Base64.getEncoder().encodeToString(rsaPrivateKey.getEncoded()), + Base64.getEncoder().encodeToString(x509Certificate.getEncoded())))); + this.signingCredentials = credentials; + this.decryptionCredentials = credentials; + this.assertingpartyVerificationCredentials = this.objectMapper + .writeValueAsString(List.of(new JdbcRelyingPartyRegistrationRepository.Certificate( + Base64.getEncoder().encodeToString(x509Certificate.getEncoded())))); + } + + @AfterEach + public void tearDown() throws IOException { + this.db.shutdown(); + this.mockWebServer.close(); + } + + @Test + void constructorWhenJdbcOperationsIsNullThenThrowIllegalArgumentException() { + // @formatter:off + assertThatIllegalArgumentException() + .isThrownBy(() -> new JdbcRelyingPartyRegistrationRepository(null)) + .withMessage("jdbcOperations cannot be null"); + // @formatter:on + } + + @Test + void findByRegistrationIdWhenCredentialIdIsNullThenThrowIllegalArgumentException() { + // @formatter:off + assertThatIllegalArgumentException() + .isThrownBy(() -> this.repository.findByRegistrationId(null)) + .withMessage("registrationId cannot be empty"); + // @formatter:on + } + + @Test + void findByRegistrationIdWhenAssertingpartyMetadataUriIsNull() { + this.jdbcOperations.update(SAVE_RP_REGISTRATION_SQL, REGISTRATION_ID, ENTITY_ID, NAME_ID_FORMAT, ACS_LOCATION, + ACS_BINDING, this.signingCredentials.getBytes(StandardCharsets.UTF_8), + this.decryptionCredentials.getBytes(StandardCharsets.UTF_8), SINGLE_LOGOUT_URL, + SINGLE_LOGOUT_RESPONSE_URL, SINGLE_LOGOUT_BINDING, ASSERTINGPARTY_ENTITY_ID, null, + ASSERTINGPARTY_SINGLE_SIGNON_URL, ASSERTINGPARTY_SINGLE_SIGNON_BINDING, + ASSERTINGPARTY_SINGLE_SIGNON_SIGN_REQUEST, + this.assertingpartyVerificationCredentials.getBytes(StandardCharsets.UTF_8), + ASSERTINGPARTY_SINGLE_LOGOUT_URL, ASSERTINGPARTY_SINGLE_LOGOUT_RESPONSE_URL, + ASSERTINGPARTY_SINGLE_LOGOUT_BINDING); + + RelyingPartyRegistration result = this.repository.findByRegistrationId(REGISTRATION_ID); + + assertThat(result).isNotNull(); + assertThat(result.getRegistrationId()).isEqualTo(REGISTRATION_ID); + assertThat(result.getEntityId()).isEqualTo(ENTITY_ID); + assertThat(result.getNameIdFormat()).isEqualTo(NAME_ID_FORMAT); + assertThat(result.getAssertionConsumerServiceLocation()).isEqualTo(ACS_LOCATION); + assertThat(result.getAssertionConsumerServiceBinding()).isEqualTo(Saml2MessageBinding.from(ACS_BINDING)); + assertThat(result.getSingleLogoutServiceLocation()).isEqualTo(SINGLE_LOGOUT_URL); + assertThat(result.getSingleLogoutServiceResponseLocation()).isEqualTo(SINGLE_LOGOUT_RESPONSE_URL); + assertThat(result.getSingleLogoutServiceBinding()).isEqualTo(Saml2MessageBinding.from(SINGLE_LOGOUT_BINDING)); + assertThat(result.getSigningX509Credentials()).hasSize(1); + assertThat(result.getDecryptionX509Credentials()).hasSize(1); + AssertingPartyMetadata apm = result.getAssertingPartyMetadata(); + assertThat(apm.getEntityId()).isEqualTo(ASSERTINGPARTY_ENTITY_ID); + assertThat(apm.getSingleSignOnServiceLocation()).isEqualTo(ASSERTINGPARTY_SINGLE_SIGNON_URL); + assertThat(apm.getSingleSignOnServiceBinding()) + .isEqualTo(Saml2MessageBinding.from(ASSERTINGPARTY_SINGLE_SIGNON_BINDING)); + assertThat(apm.getWantAuthnRequestsSigned()).isTrue(); + assertThat(apm.getVerificationX509Credentials()).hasSize(1); + assertThat(apm.getSingleLogoutServiceLocation()).isEqualTo(ASSERTINGPARTY_SINGLE_LOGOUT_URL); + assertThat(apm.getSingleLogoutServiceResponseLocation()).isEqualTo(ASSERTINGPARTY_SINGLE_LOGOUT_RESPONSE_URL); + assertThat(apm.getSingleLogoutServiceBinding()) + .isEqualTo(Saml2MessageBinding.from(ASSERTINGPARTY_SINGLE_LOGOUT_BINDING)); + } + + @Test + void findByRegistrationIdWhenAssertingpartyMetadataUriExists() { + this.jdbcOperations.update(SAVE_RP_REGISTRATION_SQL, REGISTRATION_ID, ENTITY_ID, NAME_ID_FORMAT, ACS_LOCATION, + ACS_BINDING, this.signingCredentials.getBytes(StandardCharsets.UTF_8), + this.decryptionCredentials.getBytes(StandardCharsets.UTF_8), SINGLE_LOGOUT_URL, + SINGLE_LOGOUT_RESPONSE_URL, SINGLE_LOGOUT_BINDING, ASSERTINGPARTY_ENTITY_ID, + this.assertingpartyMetadataUri, null, null, null, null, null, null, null); + + RelyingPartyRegistration result = this.repository.findByRegistrationId(REGISTRATION_ID); + + assertThat(result).isNotNull(); + assertThat(result.getRegistrationId()).isEqualTo(REGISTRATION_ID); + assertThat(result.getEntityId()).isEqualTo(ENTITY_ID); + assertThat(result.getNameIdFormat()).isEqualTo(NAME_ID_FORMAT); + assertThat(result.getAssertionConsumerServiceLocation()).isEqualTo(ACS_LOCATION); + assertThat(result.getAssertionConsumerServiceBinding()).isEqualTo(Saml2MessageBinding.from(ACS_BINDING)); + assertThat(result.getSingleLogoutServiceLocation()).isEqualTo(SINGLE_LOGOUT_URL); + assertThat(result.getSingleLogoutServiceResponseLocation()).isEqualTo(SINGLE_LOGOUT_RESPONSE_URL); + assertThat(result.getSingleLogoutServiceBinding()).isEqualTo(Saml2MessageBinding.from(SINGLE_LOGOUT_BINDING)); + assertThat(result.getSigningX509Credentials()).hasSize(1); + assertThat(result.getDecryptionX509Credentials()).hasSize(1); + AssertingPartyMetadata apm = result.getAssertingPartyMetadata(); + assertThat(apm.getEntityId()).isEqualTo(ASSERTINGPARTY_ENTITY_ID); + assertThat(apm.getSingleSignOnServiceLocation()) + .isEqualTo("https://localhost/simplesaml/saml2/idp/SSOService.php"); + assertThat(apm.getSingleSignOnServiceBinding()).isEqualTo(Saml2MessageBinding.REDIRECT); + assertThat(apm.getWantAuthnRequestsSigned()).isFalse(); + assertThat(apm.getVerificationX509Credentials()).hasSize(1); + assertThat(apm.getSingleLogoutServiceLocation()) + .isEqualTo("https://localhost/simplesaml/saml2/idp/SingleLogoutService.php"); + assertThat(apm.getSingleLogoutServiceResponseLocation()) + .isEqualTo("https://localhost/simplesaml/saml2/idp/SingleLogoutService.php"); + assertThat(apm.getSingleLogoutServiceBinding()).isEqualTo(Saml2MessageBinding.REDIRECT); + } + + @Test + void findByRegistrationIdWhenAssertingpartyMetadataUriExistsAndAssertingpartyMetadataExists() { + this.jdbcOperations.update(SAVE_RP_REGISTRATION_SQL, REGISTRATION_ID, ENTITY_ID, NAME_ID_FORMAT, ACS_LOCATION, + ACS_BINDING, this.signingCredentials.getBytes(StandardCharsets.UTF_8), + this.decryptionCredentials.getBytes(StandardCharsets.UTF_8), SINGLE_LOGOUT_URL, + SINGLE_LOGOUT_RESPONSE_URL, SINGLE_LOGOUT_BINDING, ASSERTINGPARTY_ENTITY_ID, + this.assertingpartyMetadataUri, ASSERTINGPARTY_SINGLE_SIGNON_URL, ASSERTINGPARTY_SINGLE_SIGNON_BINDING, + ASSERTINGPARTY_SINGLE_SIGNON_SIGN_REQUEST, + this.assertingpartyVerificationCredentials.getBytes(StandardCharsets.UTF_8), + ASSERTINGPARTY_SINGLE_LOGOUT_URL, ASSERTINGPARTY_SINGLE_LOGOUT_RESPONSE_URL, + ASSERTINGPARTY_SINGLE_LOGOUT_BINDING); + + RelyingPartyRegistration result = this.repository.findByRegistrationId(REGISTRATION_ID); + + assertThat(result).isNotNull(); + assertThat(result.getRegistrationId()).isEqualTo(REGISTRATION_ID); + assertThat(result.getEntityId()).isEqualTo(ENTITY_ID); + assertThat(result.getNameIdFormat()).isEqualTo(NAME_ID_FORMAT); + assertThat(result.getAssertionConsumerServiceLocation()).isEqualTo(ACS_LOCATION); + assertThat(result.getAssertionConsumerServiceBinding()).isEqualTo(Saml2MessageBinding.from(ACS_BINDING)); + assertThat(result.getSingleLogoutServiceLocation()).isEqualTo(SINGLE_LOGOUT_URL); + assertThat(result.getSingleLogoutServiceResponseLocation()).isEqualTo(SINGLE_LOGOUT_RESPONSE_URL); + assertThat(result.getSingleLogoutServiceBinding()).isEqualTo(Saml2MessageBinding.from(SINGLE_LOGOUT_BINDING)); + assertThat(result.getSigningX509Credentials()).hasSize(1); + assertThat(result.getDecryptionX509Credentials()).hasSize(1); + AssertingPartyMetadata apm = result.getAssertingPartyMetadata(); + assertThat(apm.getEntityId()).isEqualTo(ASSERTINGPARTY_ENTITY_ID); + assertThat(apm.getSingleSignOnServiceLocation()).isEqualTo(ASSERTINGPARTY_SINGLE_SIGNON_URL); + assertThat(apm.getSingleSignOnServiceBinding()) + .isEqualTo(Saml2MessageBinding.from(ASSERTINGPARTY_SINGLE_SIGNON_BINDING)); + assertThat(apm.getWantAuthnRequestsSigned()).isTrue(); + assertThat(apm.getVerificationX509Credentials()).hasSize(2); + assertThat(apm.getSingleLogoutServiceLocation()).isEqualTo(ASSERTINGPARTY_SINGLE_LOGOUT_URL); + assertThat(apm.getSingleLogoutServiceResponseLocation()).isEqualTo(ASSERTINGPARTY_SINGLE_LOGOUT_RESPONSE_URL); + assertThat(apm.getSingleLogoutServiceBinding()) + .isEqualTo(Saml2MessageBinding.from(ASSERTINGPARTY_SINGLE_LOGOUT_BINDING)); + } + + @Test + void findByRegistrationIdWhenMissingEntityIdAndAcsLocationAndAcsBinding() { + this.jdbcOperations.update(SAVE_RP_REGISTRATION_SQL, REGISTRATION_ID, null, NAME_ID_FORMAT, null, null, + this.signingCredentials.getBytes(StandardCharsets.UTF_8), + this.decryptionCredentials.getBytes(StandardCharsets.UTF_8), SINGLE_LOGOUT_URL, + SINGLE_LOGOUT_RESPONSE_URL, SINGLE_LOGOUT_BINDING, ASSERTINGPARTY_ENTITY_ID, null, + ASSERTINGPARTY_SINGLE_SIGNON_URL, ASSERTINGPARTY_SINGLE_SIGNON_BINDING, + ASSERTINGPARTY_SINGLE_SIGNON_SIGN_REQUEST, + this.assertingpartyVerificationCredentials.getBytes(StandardCharsets.UTF_8), + ASSERTINGPARTY_SINGLE_LOGOUT_URL, ASSERTINGPARTY_SINGLE_LOGOUT_RESPONSE_URL, + ASSERTINGPARTY_SINGLE_LOGOUT_BINDING); + + RelyingPartyRegistration result = this.repository.findByRegistrationId(REGISTRATION_ID); + + assertThat(result).isNotNull(); + assertThat(result.getRegistrationId()).isEqualTo(REGISTRATION_ID); + assertThat(result.getEntityId()).isEqualTo("{baseUrl}/saml2/service-provider-metadata/{registrationId}"); + assertThat(result.getNameIdFormat()).isEqualTo(NAME_ID_FORMAT); + assertThat(result.getAssertionConsumerServiceLocation()) + .isEqualTo("{baseUrl}/login/saml2/sso/{registrationId}"); + assertThat(result.getAssertionConsumerServiceBinding()).isEqualTo(Saml2MessageBinding.POST); + assertThat(result.getSingleLogoutServiceLocation()).isEqualTo(SINGLE_LOGOUT_URL); + assertThat(result.getSingleLogoutServiceResponseLocation()).isEqualTo(SINGLE_LOGOUT_RESPONSE_URL); + assertThat(result.getSingleLogoutServiceBinding()).isEqualTo(Saml2MessageBinding.from(SINGLE_LOGOUT_BINDING)); + assertThat(result.getSigningX509Credentials()).hasSize(1); + assertThat(result.getDecryptionX509Credentials()).hasSize(1); + AssertingPartyMetadata apm = result.getAssertingPartyMetadata(); + assertThat(apm.getEntityId()).isEqualTo(ASSERTINGPARTY_ENTITY_ID); + assertThat(apm.getSingleSignOnServiceLocation()).isEqualTo(ASSERTINGPARTY_SINGLE_SIGNON_URL); + assertThat(apm.getSingleSignOnServiceBinding()) + .isEqualTo(Saml2MessageBinding.from(ASSERTINGPARTY_SINGLE_SIGNON_BINDING)); + assertThat(apm.getWantAuthnRequestsSigned()).isTrue(); + assertThat(apm.getVerificationX509Credentials()).hasSize(1); + assertThat(apm.getSingleLogoutServiceLocation()).isEqualTo(ASSERTINGPARTY_SINGLE_LOGOUT_URL); + assertThat(apm.getSingleLogoutServiceResponseLocation()).isEqualTo(ASSERTINGPARTY_SINGLE_LOGOUT_RESPONSE_URL); + assertThat(apm.getSingleLogoutServiceBinding()) + .isEqualTo(Saml2MessageBinding.from(ASSERTINGPARTY_SINGLE_LOGOUT_BINDING)); + } + + @Test + void findByRegistrationIdWhenWantAuthnRequestsSignedButMissingSigningCredentials() { + this.jdbcOperations.update(SAVE_RP_REGISTRATION_SQL, REGISTRATION_ID, ENTITY_ID, NAME_ID_FORMAT, ACS_LOCATION, + ACS_BINDING, null, this.decryptionCredentials.getBytes(StandardCharsets.UTF_8), SINGLE_LOGOUT_URL, + SINGLE_LOGOUT_RESPONSE_URL, SINGLE_LOGOUT_BINDING, ASSERTINGPARTY_ENTITY_ID, null, + ASSERTINGPARTY_SINGLE_SIGNON_URL, ASSERTINGPARTY_SINGLE_SIGNON_BINDING, + ASSERTINGPARTY_SINGLE_SIGNON_SIGN_REQUEST, + this.assertingpartyVerificationCredentials.getBytes(StandardCharsets.UTF_8), + ASSERTINGPARTY_SINGLE_LOGOUT_URL, ASSERTINGPARTY_SINGLE_LOGOUT_RESPONSE_URL, + ASSERTINGPARTY_SINGLE_LOGOUT_BINDING); + + RelyingPartyRegistration result = this.repository.findByRegistrationId(REGISTRATION_ID); + assertThat(result).isNull(); + } + + @Test + void findByRegistrationIdWhenSigningCredentialsWithNonPrivateKey() throws JsonProcessingException { + String invalidCredentials = this.objectMapper + .writeValueAsString(List.of(new JdbcRelyingPartyRegistrationRepository.Credential(null, ""))); + this.jdbcOperations.update(SAVE_RP_REGISTRATION_SQL, REGISTRATION_ID, ENTITY_ID, NAME_ID_FORMAT, ACS_LOCATION, + ACS_BINDING, invalidCredentials.getBytes(StandardCharsets.UTF_8), + this.decryptionCredentials.getBytes(StandardCharsets.UTF_8), SINGLE_LOGOUT_URL, + SINGLE_LOGOUT_RESPONSE_URL, SINGLE_LOGOUT_BINDING, ASSERTINGPARTY_ENTITY_ID, null, + ASSERTINGPARTY_SINGLE_SIGNON_URL, ASSERTINGPARTY_SINGLE_SIGNON_BINDING, + ASSERTINGPARTY_SINGLE_SIGNON_SIGN_REQUEST, + this.assertingpartyVerificationCredentials.getBytes(StandardCharsets.UTF_8), + ASSERTINGPARTY_SINGLE_LOGOUT_URL, ASSERTINGPARTY_SINGLE_LOGOUT_RESPONSE_URL, + ASSERTINGPARTY_SINGLE_LOGOUT_BINDING); + + RelyingPartyRegistration result = this.repository.findByRegistrationId(REGISTRATION_ID); + assertThat(result).isNull(); + } + + @Test + void findByRegistrationIdWhenSigningCredentialsIsInvalid() throws JsonProcessingException { + String invalidCredentials = this.objectMapper + .writeValueAsString(List.of(new JdbcRelyingPartyRegistrationRepository.Credential("invalid", "invalid"))); + this.jdbcOperations.update(SAVE_RP_REGISTRATION_SQL, REGISTRATION_ID, ENTITY_ID, NAME_ID_FORMAT, ACS_LOCATION, + ACS_BINDING, invalidCredentials.getBytes(StandardCharsets.UTF_8), + this.decryptionCredentials.getBytes(StandardCharsets.UTF_8), SINGLE_LOGOUT_URL, + SINGLE_LOGOUT_RESPONSE_URL, SINGLE_LOGOUT_BINDING, ASSERTINGPARTY_ENTITY_ID, null, + ASSERTINGPARTY_SINGLE_SIGNON_URL, ASSERTINGPARTY_SINGLE_SIGNON_BINDING, + ASSERTINGPARTY_SINGLE_SIGNON_SIGN_REQUEST, + this.assertingpartyVerificationCredentials.getBytes(StandardCharsets.UTF_8), + ASSERTINGPARTY_SINGLE_LOGOUT_URL, ASSERTINGPARTY_SINGLE_LOGOUT_RESPONSE_URL, + ASSERTINGPARTY_SINGLE_LOGOUT_BINDING); + + RelyingPartyRegistration result = this.repository.findByRegistrationId(REGISTRATION_ID); + assertThat(result).isNull(); + } + + @Test + void findByRegistrationIdWhenDecryptionCredentialsIsInvalid() throws JsonProcessingException { + String invalidCredentials = this.objectMapper + .writeValueAsString(List.of(new JdbcRelyingPartyRegistrationRepository.Credential("invalid", "invalid"))); + this.jdbcOperations.update(SAVE_RP_REGISTRATION_SQL, REGISTRATION_ID, ENTITY_ID, NAME_ID_FORMAT, ACS_LOCATION, + ACS_BINDING, this.signingCredentials.getBytes(StandardCharsets.UTF_8), + invalidCredentials.getBytes(StandardCharsets.UTF_8), SINGLE_LOGOUT_URL, SINGLE_LOGOUT_RESPONSE_URL, + SINGLE_LOGOUT_BINDING, ASSERTINGPARTY_ENTITY_ID, null, ASSERTINGPARTY_SINGLE_SIGNON_URL, + ASSERTINGPARTY_SINGLE_SIGNON_BINDING, ASSERTINGPARTY_SINGLE_SIGNON_SIGN_REQUEST, + this.assertingpartyVerificationCredentials.getBytes(StandardCharsets.UTF_8), + ASSERTINGPARTY_SINGLE_LOGOUT_URL, ASSERTINGPARTY_SINGLE_LOGOUT_RESPONSE_URL, + ASSERTINGPARTY_SINGLE_LOGOUT_BINDING); + + RelyingPartyRegistration result = this.repository.findByRegistrationId(REGISTRATION_ID); + assertThat(result).isNull(); + } + + @Test + void findByRegistrationIdWhenVerificationCredentialsIsInvalid() throws JsonProcessingException { + byte[] invalidVerificationCredential = this.signingCredentials.getBytes(StandardCharsets.UTF_8); + this.jdbcOperations.update(SAVE_RP_REGISTRATION_SQL, REGISTRATION_ID, ENTITY_ID, NAME_ID_FORMAT, ACS_LOCATION, + ACS_BINDING, this.signingCredentials.getBytes(StandardCharsets.UTF_8), + this.decryptionCredentials.getBytes(StandardCharsets.UTF_8), SINGLE_LOGOUT_URL, + SINGLE_LOGOUT_RESPONSE_URL, SINGLE_LOGOUT_BINDING, ASSERTINGPARTY_ENTITY_ID, null, + ASSERTINGPARTY_SINGLE_SIGNON_URL, ASSERTINGPARTY_SINGLE_SIGNON_BINDING, + ASSERTINGPARTY_SINGLE_SIGNON_SIGN_REQUEST, invalidVerificationCredential, + ASSERTINGPARTY_SINGLE_LOGOUT_URL, ASSERTINGPARTY_SINGLE_LOGOUT_RESPONSE_URL, + ASSERTINGPARTY_SINGLE_LOGOUT_BINDING); + + RelyingPartyRegistration result = this.repository.findByRegistrationId(REGISTRATION_ID); + assertThat(result).isNull(); + } + + @Test + void findByEntityId() { + this.jdbcOperations.update(SAVE_RP_REGISTRATION_SQL, REGISTRATION_ID, ENTITY_ID, NAME_ID_FORMAT, ACS_LOCATION, + ACS_BINDING, this.signingCredentials.getBytes(StandardCharsets.UTF_8), + this.decryptionCredentials.getBytes(StandardCharsets.UTF_8), SINGLE_LOGOUT_URL, + SINGLE_LOGOUT_RESPONSE_URL, SINGLE_LOGOUT_BINDING, ASSERTINGPARTY_ENTITY_ID, null, + ASSERTINGPARTY_SINGLE_SIGNON_URL, ASSERTINGPARTY_SINGLE_SIGNON_BINDING, + ASSERTINGPARTY_SINGLE_SIGNON_SIGN_REQUEST, + this.assertingpartyVerificationCredentials.getBytes(StandardCharsets.UTF_8), + ASSERTINGPARTY_SINGLE_LOGOUT_URL, ASSERTINGPARTY_SINGLE_LOGOUT_RESPONSE_URL, + ASSERTINGPARTY_SINGLE_LOGOUT_BINDING); + + RelyingPartyRegistration result = this.repository.findUniqueByAssertingPartyEntityId(ENTITY_ID); + + assertThat(result).isNotNull(); + assertThat(result.getRegistrationId()).isEqualTo(REGISTRATION_ID); + assertThat(result.getEntityId()).isEqualTo(ENTITY_ID); + assertThat(result.getNameIdFormat()).isEqualTo(NAME_ID_FORMAT); + assertThat(result.getAssertionConsumerServiceLocation()).isEqualTo(ACS_LOCATION); + assertThat(result.getAssertionConsumerServiceBinding()).isEqualTo(Saml2MessageBinding.from(ACS_BINDING)); + assertThat(result.getSingleLogoutServiceLocation()).isEqualTo(SINGLE_LOGOUT_URL); + assertThat(result.getSingleLogoutServiceResponseLocation()).isEqualTo(SINGLE_LOGOUT_RESPONSE_URL); + assertThat(result.getSingleLogoutServiceBinding()).isEqualTo(Saml2MessageBinding.from(SINGLE_LOGOUT_BINDING)); + assertThat(result.getSigningX509Credentials()).hasSize(1); + assertThat(result.getDecryptionX509Credentials()).hasSize(1); + AssertingPartyMetadata apm = result.getAssertingPartyMetadata(); + assertThat(apm.getEntityId()).isEqualTo(ASSERTINGPARTY_ENTITY_ID); + assertThat(apm.getSingleSignOnServiceLocation()).isEqualTo(ASSERTINGPARTY_SINGLE_SIGNON_URL); + assertThat(apm.getSingleSignOnServiceBinding()) + .isEqualTo(Saml2MessageBinding.from(ASSERTINGPARTY_SINGLE_SIGNON_BINDING)); + assertThat(apm.getWantAuthnRequestsSigned()).isTrue(); + assertThat(apm.getVerificationX509Credentials()).hasSize(1); + assertThat(apm.getSingleLogoutServiceLocation()).isEqualTo(ASSERTINGPARTY_SINGLE_LOGOUT_URL); + assertThat(apm.getSingleLogoutServiceResponseLocation()).isEqualTo(ASSERTINGPARTY_SINGLE_LOGOUT_RESPONSE_URL); + assertThat(apm.getSingleLogoutServiceBinding()) + .isEqualTo(Saml2MessageBinding.from(ASSERTINGPARTY_SINGLE_LOGOUT_BINDING)); + } + + @Test + void iterator() { + this.jdbcOperations.update(SAVE_RP_REGISTRATION_SQL, REGISTRATION_ID, ENTITY_ID, NAME_ID_FORMAT, ACS_LOCATION, + ACS_BINDING, this.signingCredentials.getBytes(StandardCharsets.UTF_8), + this.decryptionCredentials.getBytes(StandardCharsets.UTF_8), SINGLE_LOGOUT_URL, + SINGLE_LOGOUT_RESPONSE_URL, SINGLE_LOGOUT_BINDING, ASSERTINGPARTY_ENTITY_ID, null, + ASSERTINGPARTY_SINGLE_SIGNON_URL, ASSERTINGPARTY_SINGLE_SIGNON_BINDING, + ASSERTINGPARTY_SINGLE_SIGNON_SIGN_REQUEST, + this.assertingpartyVerificationCredentials.getBytes(StandardCharsets.UTF_8), + ASSERTINGPARTY_SINGLE_LOGOUT_URL, ASSERTINGPARTY_SINGLE_LOGOUT_RESPONSE_URL, + ASSERTINGPARTY_SINGLE_LOGOUT_BINDING); + this.jdbcOperations.update(SAVE_RP_REGISTRATION_SQL, "okta", ENTITY_ID, NAME_ID_FORMAT, ACS_LOCATION, + ACS_BINDING, this.signingCredentials.getBytes(StandardCharsets.UTF_8), + this.decryptionCredentials.getBytes(StandardCharsets.UTF_8), SINGLE_LOGOUT_URL, + SINGLE_LOGOUT_RESPONSE_URL, SINGLE_LOGOUT_BINDING, ASSERTINGPARTY_ENTITY_ID, null, + ASSERTINGPARTY_SINGLE_SIGNON_URL, ASSERTINGPARTY_SINGLE_SIGNON_BINDING, + ASSERTINGPARTY_SINGLE_SIGNON_SIGN_REQUEST, + this.assertingpartyVerificationCredentials.getBytes(StandardCharsets.UTF_8), + ASSERTINGPARTY_SINGLE_LOGOUT_URL, ASSERTINGPARTY_SINGLE_LOGOUT_RESPONSE_URL, + ASSERTINGPARTY_SINGLE_LOGOUT_BINDING); + + Iterator iterator = this.repository.iterator(); + + RelyingPartyRegistration adfs = iterator.next(); + assertThat(adfs).isNotNull(); + assertThat(adfs.getRegistrationId()).isEqualTo(REGISTRATION_ID); + assertThat(adfs.getEntityId()).isEqualTo(ENTITY_ID); + assertThat(adfs.getNameIdFormat()).isEqualTo(NAME_ID_FORMAT); + assertThat(adfs.getAssertionConsumerServiceLocation()).isEqualTo(ACS_LOCATION); + assertThat(adfs.getAssertionConsumerServiceBinding()).isEqualTo(Saml2MessageBinding.from(ACS_BINDING)); + assertThat(adfs.getSingleLogoutServiceLocation()).isEqualTo(SINGLE_LOGOUT_URL); + assertThat(adfs.getSingleLogoutServiceResponseLocation()).isEqualTo(SINGLE_LOGOUT_RESPONSE_URL); + assertThat(adfs.getSingleLogoutServiceBinding()).isEqualTo(Saml2MessageBinding.from(SINGLE_LOGOUT_BINDING)); + assertThat(adfs.getSigningX509Credentials()).hasSize(1); + assertThat(adfs.getDecryptionX509Credentials()).hasSize(1); + AssertingPartyMetadata apm = adfs.getAssertingPartyMetadata(); + assertThat(apm.getEntityId()).isEqualTo(ASSERTINGPARTY_ENTITY_ID); + assertThat(apm.getSingleSignOnServiceLocation()).isEqualTo(ASSERTINGPARTY_SINGLE_SIGNON_URL); + assertThat(apm.getSingleSignOnServiceBinding()) + .isEqualTo(Saml2MessageBinding.from(ASSERTINGPARTY_SINGLE_SIGNON_BINDING)); + assertThat(apm.getWantAuthnRequestsSigned()).isTrue(); + assertThat(apm.getVerificationX509Credentials()).hasSize(1); + assertThat(apm.getSingleLogoutServiceLocation()).isEqualTo(ASSERTINGPARTY_SINGLE_LOGOUT_URL); + assertThat(apm.getSingleLogoutServiceResponseLocation()).isEqualTo(ASSERTINGPARTY_SINGLE_LOGOUT_RESPONSE_URL); + assertThat(apm.getSingleLogoutServiceBinding()) + .isEqualTo(Saml2MessageBinding.from(ASSERTINGPARTY_SINGLE_LOGOUT_BINDING)); + + RelyingPartyRegistration okta = iterator.next(); + assertThat(okta).isNotNull(); + assertThat(okta.getRegistrationId()).isEqualTo("okta"); + } + + private static EmbeddedDatabase createDb() { + return createDb(RP_REGISTRATION_SCHEMA_SQL_RESOURCE); + } + + private static EmbeddedDatabase createDb(String schema) { + // @formatter:off + return new EmbeddedDatabaseBuilder() + .generateUniqueName(true) + .setType(EmbeddedDatabaseType.HSQL) + .setScriptEncoding("UTF-8") + .addScript(schema) + .build(); + // @formatter:on + } + + private X509Certificate loadCertificate(String path) { + try (InputStream is = new ClassPathResource(path).getInputStream()) { + CertificateFactory factory = CertificateFactory.getInstance("X.509"); + return (X509Certificate) factory.generateCertificate(is); + } + catch (Exception ex) { + throw new RuntimeException("Error loading certificate from " + path, ex); + } + } + + private RSAPrivateKey loadPrivateKey(String path) { + try (InputStream is = new ClassPathResource(path).getInputStream()) { + return RsaKeyConverters.pkcs8().convert(is); + } + catch (Exception ex) { + throw new RuntimeException("Error loading private key from " + path, ex); + } + } + +} diff --git a/saml2/saml2-service-provider/src/test/resources/rsa.crt b/saml2/saml2-service-provider/src/test/resources/rsa.crt new file mode 100644 index 00000000000..aa147065ded --- /dev/null +++ b/saml2/saml2-service-provider/src/test/resources/rsa.crt @@ -0,0 +1,23 @@ +-----BEGIN CERTIFICATE----- +MIID1zCCAr+gAwIBAgIUCzQeKBMTO0iHVW3iKmZC41haqCowDQYJKoZIhvcNAQEL +BQAwezELMAkGA1UEBhMCWFgxEjAQBgNVBAgMCVN0YXRlTmFtZTERMA8GA1UEBwwI +Q2l0eU5hbWUxFDASBgNVBAoMC0NvbXBhbnlOYW1lMRswGQYDVQQLDBJDb21wYW55 +U2VjdGlvbk5hbWUxEjAQBgNVBAMMCWxvY2FsaG9zdDAeFw0yMzA5MjAwODI5MDNa +Fw0zMzA5MTcwODI5MDNaMHsxCzAJBgNVBAYTAlhYMRIwEAYDVQQIDAlTdGF0ZU5h +bWUxETAPBgNVBAcMCENpdHlOYW1lMRQwEgYDVQQKDAtDb21wYW55TmFtZTEbMBkG +A1UECwwSQ29tcGFueVNlY3Rpb25OYW1lMRIwEAYDVQQDDAlsb2NhbGhvc3QwggEi +MA0GCSqGSIb3DQEBAQUAA4IBDwAwggEKAoIBAQDUfi4aaCotJZX6OSDjv6fxCCfc +ihSs91Z/mmN+yc1fsxVSs53SIbqUuo+Wzhv34kp8I/r03P9LWVTkFPbeDxAl75Oa +PGggxK55US0Zfy9Hj1BwWIKV3330N61emID1GDEtFKL4yJbJdreQXnIXTBL2o76V +nuV/tYozyZnb07IQ1WhUm5WDxgzM0yFudMynTczCBeZHfvharDtB8PFFhCZXW2/9 +TZVVfW4oOML8EAX3hvnvYBlFl/foxXekZSwq/odOkmWCZavT2+0sburHUlOnPGUh +Qj4tHwpMRczp7VX4ptV1D2UrxsK/2B+s9FK2QSLKQ9JzAYJ6WxQjHcvET9jvAgMB +AAGjUzBRMB0GA1UdDgQWBBQjDr/1E/01pfLPD8uWF7gbaYL0TTAfBgNVHSMEGDAW +gBQjDr/1E/01pfLPD8uWF7gbaYL0TTAPBgNVHRMBAf8EBTADAQH/MA0GCSqGSIb3 +DQEBCwUAA4IBAQAGjUuec0+0XNMCRDKZslbImdCAVsKsEWk6NpnUViDFAxL+KQuC +NW131UeHb9SCzMqRwrY4QI3nAwJQCmilL/hFM3ss4acn3WHu1yci/iKPUKeL1ec5 +kCFUmqX1NpTiVaytZ/9TKEr69SMVqNfQiuW5U1bIIYTqK8xo46WpM6YNNHO3eJK6 +NH0MW79Wx5ryi4i4C6afqYbVbx7tqcmy8CFeNxgZ0bFQ87SiwYXIj77b6sVYbu32 +doykBQgSHLcagWASPQ73m73CWUgo+7+EqSKIQqORbgmTLPmOUh99gFIx7jmjTyHm +NBszx1ZVWuIv3mWmp626Kncyc+LLM9tvgymx +-----END CERTIFICATE----- diff --git a/saml2/saml2-service-provider/src/test/resources/rsa.key b/saml2/saml2-service-provider/src/test/resources/rsa.key new file mode 100644 index 00000000000..e458f0d5eb4 --- /dev/null +++ b/saml2/saml2-service-provider/src/test/resources/rsa.key @@ -0,0 +1,28 @@ +-----BEGIN PRIVATE KEY----- +MIIEvQIBADANBgkqhkiG9w0BAQEFAASCBKcwggSjAgEAAoIBAQDUfi4aaCotJZX6 +OSDjv6fxCCfcihSs91Z/mmN+yc1fsxVSs53SIbqUuo+Wzhv34kp8I/r03P9LWVTk +FPbeDxAl75OaPGggxK55US0Zfy9Hj1BwWIKV3330N61emID1GDEtFKL4yJbJdreQ +XnIXTBL2o76VnuV/tYozyZnb07IQ1WhUm5WDxgzM0yFudMynTczCBeZHfvharDtB +8PFFhCZXW2/9TZVVfW4oOML8EAX3hvnvYBlFl/foxXekZSwq/odOkmWCZavT2+0s +burHUlOnPGUhQj4tHwpMRczp7VX4ptV1D2UrxsK/2B+s9FK2QSLKQ9JzAYJ6WxQj +HcvET9jvAgMBAAECggEADdeRuZml1F65mDJm1enduaH+NWvEm1yEr3ecr0fbujYI +bQ89+CVx/znvRvPH4aFwQwmgUZl12JrfS05MTectoPMBf/obDwtmPDPmsV2rdEi9 +2jEB11vW23T8X7L6hOdzCKHqrd8kkhzK1LuPnhHlaFipU8YlOBOuMYpv8eB78y79 +Qkd5/ZEygFhqVGz96R7nT/xS21aPC7OPhicAauLLuguF4caCNhwkjLi3bizLemUn +4i41q69drg7G8WX6BTxzem5FupKfI8rn2EkOjO/biVRknzGxAdqkM8SDHWkqeOuY +8QVhc1kZsMkB0BGPlDPStUwEHSfUiND4GJTcngc++QKBgQD2lyeW3PoPjQ1qzjN4 +V/0XE77zpcPE5dW7chLtiWRY1dqk2uOJ32iOtxuqk9Q/YMSZyPJlTkfI5JePuC/B +MB+QXzXuWN03Vn0ZrOpQlxcdA4A1o10NT1nEw8kZlf4+LyUk8GpMGUhjnxFZpZbf +5S3fy0/2V8wGvOmXR65c8m6ASQKBgQDcmfCV5npu1HrtO8jmU9gBIhniNjB4IWue +TSRt3ANDQaVBqsVaIMe/mUEQrZ6MdikMeA4bobOA6bUYwOiq8JGWSenAzGL22TbA +W51q6A8hgDCuH1JnoagqUIbr61kwEVcfbRHEFpuxLURsjoDg/xBtwO96SxWPh5Wr ++f1q8t5/dwKBgGWc+AVk3e6Wk1bVzcPjjjl6O4+vWTLD+wUZBs+3dBBfX4/bWzQv +Sai1r8Lk0+uh9qHgenJghZg1CneA0LztFbSqZ1DmcZIiI7720D+RY0bjcGup++hG +MJmyjCXs9y2sw8OrBkKBkKDspXupjriIehTkdPjwSPTl1+Qs9575j6txAoGAT8n+ +ErnCHsQLkjLFf0lkH0TOR9uBvHGaEy+jtXiWVYUw2IeDyg2BMfOkbPvfFL7IKhJi +R+w8mKvvLHzZqrpIbitduLY0NURrYTfBwCEfF+bdtJzvmTwHLwbhRgNhxtj+wgcZ +HetvdK4CyaDhTH/02T2nYHw32CoaIJHS7xPZFhECgYEAv7xRawjlrC4V0BLjP3Ej +pk8BbsRABxN1CrS6nJK+So4u2gKQDsL3WA0oJTS8v8AD5LvQUNr1d57FVlq9lwCd +u623eOIuluCUZBVy1iYdkRXWz9pg5bCidCgEYUpF3SqpsuFou0XFzDD773UVQFVw +VYriYasPwmzS2y2P7PKFzJs= +-----END PRIVATE KEY-----