Skip to content

Commit 406e494

Browse files
committed
Add JDBC based implementation for RelyingPartyRegistrationRepository
Closes spring-projectsgh-16012 Signed-off-by: amirelgayed <[email protected]>
1 parent 8eee71a commit 406e494

File tree

2 files changed

+117
-0
lines changed

2 files changed

+117
-0
lines changed

saml2/saml2-service-provider/spring-security-saml2-service-provider.gradle

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,7 @@ dependencies {
105105

106106
provided 'jakarta.servlet:jakarta.servlet-api'
107107

108+
optional 'org.springframework:spring-jdbc'
108109
optional 'com.fasterxml.jackson.core:jackson-databind'
109110

110111
testImplementation 'com.squareup.okhttp3:mockwebserver'
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,116 @@
1+
/*
2+
* Copyright 2002-2024 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package org.springframework.security.saml2.provider.service.registration;
18+
19+
import org.springframework.jdbc.core.*;
20+
import org.springframework.util.Assert;
21+
import org.springframework.util.LinkedMultiValueMap;
22+
import org.springframework.util.MultiValueMap;
23+
24+
import java.io.ByteArrayInputStream;
25+
import java.sql.ResultSet;
26+
import java.sql.Types;
27+
import java.util.*;
28+
import java.util.stream.Collectors;
29+
30+
/**
31+
* A JDBC implementation of {@link RelyingPartyRegistrationRepository}. Also
32+
* implements {@link Iterable} to simplify the default login page.
33+
*
34+
* @author َAmir Elgayed
35+
* @since 6.5
36+
*/
37+
38+
// TODO https://github.com/spring-projects/spring-security/issues/16012
39+
public class JdbcRelyingPartyRegistrationRepository implements RelyingPartyRegistrationRepository {
40+
// TODO add logging
41+
private final JdbcOperations jdbcOperations;
42+
43+
// TODO ticket specifies that data is saved in the DB as XML
44+
private static final String TABLE_NAME = "relying_party_registration";
45+
46+
private static final String[] COLUMN_NAMES = {"id", "metadata"};
47+
48+
private static final String JOINT_COLUMN_NAMES = String.join(", ", COLUMN_NAMES);
49+
50+
private static final String FILTER = "id = ?";
51+
52+
// @formatter:off
53+
private static final String SAVE_RELYING_PARTY_REGISTRATION_SQL = "INSERT INTO " + TABLE_NAME
54+
+ " (" + JOINT_COLUMN_NAMES + ")"
55+
+ " VALUES ("
56+
+ Arrays.stream(COLUMN_NAMES).map(columnName -> "?").collect(Collectors.joining(", "))
57+
+")";
58+
59+
private static final String SELECT_RELYING_PARTY_REGISTRATION_SQL = "SELECT "
60+
+ JOINT_COLUMN_NAMES
61+
+ " FROM " + TABLE_NAME
62+
+ " WHERE " + FILTER;
63+
// @formatter:on
64+
65+
private static final String DELETE_RELYING_PARTY_REGISTRATION_SQL = "DELETE FROM " + TABLE_NAME + " WHERE " + FILTER;
66+
67+
/**
68+
* Constructs a {@code JdbcRelyingPartyRegistrationRepository} using the provide parameters.
69+
* @param jdbcOperations the JDBC operations
70+
*/
71+
public JdbcRelyingPartyRegistrationRepository(JdbcOperations jdbcOperations) {
72+
Assert.notNull(jdbcOperations, "jdbcOperations cannot be null");
73+
this.jdbcOperations = jdbcOperations;
74+
}
75+
76+
// TODO to be removed
77+
private static Map<String, RelyingPartyRegistration> createMappingToIdentityProvider(
78+
Collection<RelyingPartyRegistration> rps) {
79+
LinkedHashMap<String, RelyingPartyRegistration> result = new LinkedHashMap<>();
80+
for (RelyingPartyRegistration rp : rps) {
81+
Assert.notNull(rp, "relying party collection cannot contain null values");
82+
String key = rp.getRegistrationId();
83+
Assert.notNull(key, "relying party identifier cannot be null");
84+
Assert.isNull(result.get(key), () -> "relying party duplicate identifier '" + key + "' detected.");
85+
result.put(key, rp);
86+
}
87+
return Collections.unmodifiableMap(result);
88+
}
89+
90+
// TODO to be removed
91+
private static Map<String, List<RelyingPartyRegistration>> createMappingByAssertingPartyEntityId(
92+
Collection<RelyingPartyRegistration> rps) {
93+
MultiValueMap<String, RelyingPartyRegistration> result = new LinkedMultiValueMap<>();
94+
for (RelyingPartyRegistration rp : rps) {
95+
result.add(rp.getAssertingPartyMetadata().getEntityId(), rp);
96+
}
97+
return Collections.unmodifiableMap(result);
98+
}
99+
100+
private String fetchRelyingRegistryMetadata(String id){
101+
SqlParameterValue[] parameters = { new SqlParameterValue(Types.VARCHAR, id) };
102+
PreparedStatementSetter pss = new ArgumentPreparedStatementSetter(parameters);
103+
return this.jdbcOperations.query(
104+
SELECT_RELYING_PARTY_REGISTRATION_SQL,
105+
pss,
106+
(ResultSet rs) -> rs.getString("metadata")
107+
);
108+
}
109+
110+
// TODO test contract: return null if no match is found
111+
@Override
112+
public RelyingPartyRegistration findByRegistrationId(String id) {
113+
String metadata = fetchRelyingRegistryMetadata(id);
114+
return RelyingPartyRegistrations.fromMetadata(new ByteArrayInputStream(metadata.getBytes())).build();
115+
}
116+
}

0 commit comments

Comments
 (0)