diff --git a/spring-security-kerberos-web/src/main/java/org/springframework/security/kerberos/web/authentication/SpnegoAuthenticationProcessingFilter.java b/spring-security-kerberos-web/src/main/java/org/springframework/security/kerberos/web/authentication/SpnegoAuthenticationProcessingFilter.java index c914b43..051ba37 100644 --- a/spring-security-kerberos-web/src/main/java/org/springframework/security/kerberos/web/authentication/SpnegoAuthenticationProcessingFilter.java +++ b/spring-security-kerberos-web/src/main/java/org/springframework/security/kerberos/web/authentication/SpnegoAuthenticationProcessingFilter.java @@ -27,7 +27,9 @@ import org.springframework.security.authentication.AuthenticationManager; import org.springframework.security.core.Authentication; import org.springframework.security.core.AuthenticationException; +import org.springframework.security.core.context.SecurityContext; import org.springframework.security.core.context.SecurityContextHolder; +import org.springframework.security.core.context.SecurityContextHolderStrategy; import org.springframework.security.crypto.codec.Base64; import org.springframework.security.kerberos.authentication.KerberosServiceAuthenticationProvider; import org.springframework.security.kerberos.authentication.KerberosServiceRequestToken; @@ -36,6 +38,8 @@ import org.springframework.security.web.authentication.WebAuthenticationDetailsSource; import org.springframework.security.web.authentication.session.NullAuthenticatedSessionStrategy; import org.springframework.security.web.authentication.session.SessionAuthenticationStrategy; +import org.springframework.security.web.context.RequestAttributeSecurityContextRepository; +import org.springframework.security.web.context.SecurityContextRepository; import org.springframework.util.Assert; import org.springframework.web.filter.OncePerRequestFilter; @@ -107,6 +111,9 @@ */ public class SpnegoAuthenticationProcessingFilter extends OncePerRequestFilter { + private SecurityContextHolderStrategy securityContextHolderStrategy = SecurityContextHolder + .getContextHolderStrategy(); + private SecurityContextRepository securityContextRepository = new RequestAttributeSecurityContextRepository(); private AuthenticationDetailsSource authenticationDetailsSource = new WebAuthenticationDetailsSource(); private AuthenticationManager authenticationManager; private AuthenticationSuccessHandler successHandler; @@ -154,7 +161,7 @@ protected void doFilterInternal(HttpServletRequest request, HttpServletResponse // That shouldn't happen, as it is most likely a wrong // configuration on the server side logger.warn("Negotiate Header was invalid: " + header, e); - SecurityContextHolder.clearContext(); + securityContextHolderStrategy.clearContext(); if (failureHandler != null) { failureHandler.onAuthenticationFailure(request, response, e); } else { @@ -164,7 +171,11 @@ protected void doFilterInternal(HttpServletRequest request, HttpServletResponse return; } sessionStrategy.onAuthentication(authentication, request, response); - SecurityContextHolder.getContext().setAuthentication(authentication); + + SecurityContext context = securityContextHolderStrategy.createEmptyContext(); + context.setAuthentication(authentication); + securityContextHolderStrategy.setContext(context); + securityContextRepository.saveContext(context, request, response); if (successHandler != null) { successHandler.onAuthenticationSuccess(request, response, authentication); } @@ -264,4 +275,29 @@ public void setAuthenticationDetailsSource( public void setStopFilterChainOnSuccessfulAuthentication(boolean shouldStop) { this.stopFilterChainOnSuccessfulAuthentication = shouldStop; } + + /** + * Sets the {@link SecurityContextRepository} to save the {@link SecurityContext} on + * authentication success. The default action is not to save the + * {@link SecurityContext}. + * + * @param securityContextRepository the {@link SecurityContextRepository} to use. + * Cannot be null. + */ + public void setSecurityContextRepository(SecurityContextRepository securityContextRepository) { + Assert.notNull(securityContextRepository, "securityContextRepository cannot be null"); + this.securityContextRepository = securityContextRepository; + } + + /** + * Sets the {@link SecurityContextHolderStrategy} to use. The default action is to use + * the {@link SecurityContextHolderStrategy} stored in {@link SecurityContextHolder}. + * + * @param securityContextHolderStrategy the {@link SecurityContextHolderStrategy} to use. + * Cannot be null. + */ + public void setSecurityContextHolderStrategy(SecurityContextHolderStrategy securityContextHolderStrategy) { + Assert.notNull(securityContextHolderStrategy, "securityContextHolderStrategy cannot be null"); + this.securityContextHolderStrategy = securityContextHolderStrategy; + } } diff --git a/spring-security-kerberos-web/src/test/java/org/springframework/security/kerberos/web/SpnegoAuthenticationProcessingFilterTest.java b/spring-security-kerberos-web/src/test/java/org/springframework/security/kerberos/web/SpnegoAuthenticationProcessingFilterTest.java index 59ba062..426e0db 100644 --- a/spring-security-kerberos-web/src/test/java/org/springframework/security/kerberos/web/SpnegoAuthenticationProcessingFilterTest.java +++ b/spring-security-kerberos-web/src/test/java/org/springframework/security/kerberos/web/SpnegoAuthenticationProcessingFilterTest.java @@ -41,6 +41,7 @@ import org.springframework.security.web.authentication.AuthenticationFailureHandler; import org.springframework.security.web.authentication.AuthenticationSuccessHandler; import org.springframework.security.web.authentication.WebAuthenticationDetailsSource; +import org.springframework.security.web.context.SecurityContextRepository; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.mockito.ArgumentMatchers.any; @@ -145,11 +146,14 @@ private void everythingWorksVerifyHandlers() throws Exception { private void everythingWorks(String tokenPrefix) throws IOException, ServletException { // stubbing + SecurityContextRepository securityContextRepository = mock(SecurityContextRepository.class); + filter.setSecurityContextRepository(securityContextRepository); everythingWorksStub(tokenPrefix); // testing filter.doFilter(request, response, chain); verify(chain).doFilter(request, response); + verify(securityContextRepository).saveContext(SecurityContextHolder.getContext(), request, response); assertEquals(AUTHENTICATION, SecurityContextHolder.getContext().getAuthentication()); }