diff --git a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/Saml2AuthenticationTokenConverter.java b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/Saml2AuthenticationTokenConverter.java index 8acc64bf97c..0bd46918bf9 100644 --- a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/Saml2AuthenticationTokenConverter.java +++ b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/Saml2AuthenticationTokenConverter.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2022 the original author or authors. + * Copyright 2002-2025 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -43,6 +43,8 @@ public final class Saml2AuthenticationTokenConverter implements AuthenticationCo private Saml2AuthenticationRequestRepository authenticationRequestRepository; + private Boolean shouldInflate; + /** * Constructs a {@link Saml2AuthenticationTokenConverter} given a strategy for * resolving {@link RelyingPartyRegistration}s @@ -86,16 +88,26 @@ public void setAuthenticationRequestRepository( this.authenticationRequestRepository = authenticationRequestRepository; } + /** + * Use the given {@code shouldInflate} to inflate request. + * @param shouldInflate the {@code shouldInflate} to use + * @since 7.0 + */ + public void setShouldInflateResponse(boolean shouldInflate) { + this.shouldInflate = shouldInflate; + } + private String decode(HttpServletRequest request) { + // prevent to break passivity in Saml2LoginBeanDefinitionParserTests + if (this.shouldInflate == null) { + this.shouldInflate = HttpMethod.GET.matches(request.getMethod()); + } String encoded = request.getParameter(Saml2ParameterNames.SAML_RESPONSE); if (encoded == null) { return null; } try { - return Saml2Utils.withEncoded(encoded) - .requireBase64(true) - .inflate(HttpMethod.GET.matches(request.getMethod())) - .decode(); + return Saml2Utils.withEncoded(encoded).requireBase64(true).inflate(this.shouldInflate).decode(); } catch (Exception ex) { throw new Saml2AuthenticationException(new Saml2Error(Saml2ErrorCodes.INVALID_RESPONSE, ex.getMessage()), diff --git a/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/web/Saml2AuthenticationTokenConverterTests.java b/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/web/Saml2AuthenticationTokenConverterTests.java index f54788a4ea7..9db3a5c1c0d 100644 --- a/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/web/Saml2AuthenticationTokenConverterTests.java +++ b/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/web/Saml2AuthenticationTokenConverterTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2021 the original author or authors. + * Copyright 2002-2025 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -61,6 +61,7 @@ public class Saml2AuthenticationTokenConverterTests { public void convertWhenSamlResponseThenToken() { Saml2AuthenticationTokenConverter converter = new Saml2AuthenticationTokenConverter( this.relyingPartyRegistrationResolver); + converter.setShouldInflateResponse(false); given(this.relyingPartyRegistrationResolver.resolve(any(HttpServletRequest.class), any())) .willReturn(this.relyingPartyRegistration); MockHttpServletRequest request = new MockHttpServletRequest(); @@ -76,6 +77,7 @@ public void convertWhenSamlResponseThenToken() { public void convertWhenSamlResponseWithRelyingPartyRegistrationResolver( @Mock RelyingPartyRegistrationResolver resolver) { Saml2AuthenticationTokenConverter converter = new Saml2AuthenticationTokenConverter(resolver); + converter.setShouldInflateResponse(false); given(resolver.resolve(any(HttpServletRequest.class), any())).willReturn(this.relyingPartyRegistration); MockHttpServletRequest request = new MockHttpServletRequest(); request.setParameter(Saml2ParameterNames.SAML_RESPONSE, @@ -161,6 +163,7 @@ public void convertWhenGetRequestInvalidDeflatedThenSaml2AuthenticationException public void convertWhenUsingSamlUtilsBase64ThenXmlIsValid() throws Exception { Saml2AuthenticationTokenConverter converter = new Saml2AuthenticationTokenConverter( this.relyingPartyRegistrationResolver); + converter.setShouldInflateResponse(false); given(this.relyingPartyRegistrationResolver.resolve(any(HttpServletRequest.class), any())) .willReturn(this.relyingPartyRegistration); MockHttpServletRequest request = new MockHttpServletRequest(); @@ -178,6 +181,7 @@ public void convertWhenSavedAuthenticationRequestThenToken() { .willReturn(this.relyingPartyRegistration.getRegistrationId()); Saml2AuthenticationTokenConverter converter = new Saml2AuthenticationTokenConverter( this.relyingPartyRegistrationResolver); + converter.setShouldInflateResponse(false); converter.setAuthenticationRequestRepository(authenticationRequestRepository); given(this.relyingPartyRegistrationResolver.resolve(any(HttpServletRequest.class), any())) .willReturn(this.relyingPartyRegistration); @@ -203,6 +207,7 @@ public void convertWhenSavedAuthenticationRequestThenTokenWithRelyingPartyRegist .willReturn(this.relyingPartyRegistration.getRegistrationId()); Saml2AuthenticationTokenConverter converter = new Saml2AuthenticationTokenConverter(resolver); converter.setAuthenticationRequestRepository(authenticationRequestRepository); + converter.setShouldInflateResponse(false); given(resolver.resolve(any(HttpServletRequest.class), any())).willReturn(this.relyingPartyRegistration); given(authenticationRequestRepository.loadAuthenticationRequest(any(HttpServletRequest.class))) .willReturn(authenticationRequest);