From 1e61c3867d7f9f1f2d2e1b38ff6a1d9e784a9d18 Mon Sep 17 00:00:00 2001 From: Jason Song Date: Sat, 8 Aug 2020 13:30:07 +0800 Subject: [PATCH] add access control support for admin service --- .../AdminServiceAutoConfiguration.java | 32 ++ .../AdminServiceAuthenticationFilter.java | 87 ++++ .../controller/AbstractControllerTest.java | 2 +- .../AdminServiceAuthenticationFilterTest.java | 210 +++++++++ ...nServiceAuthenticationIntegrationTest.java | 128 ++++++ .../filter/test-access-control-disabled.sql | 4 + .../test-access-control-enabled-no-token.sql | 3 + .../filter/test-access-control-enabled.sql | 4 + .../apollo/biz/config/BizConfig.java | 7 + .../RemoteConfigLongPollServiceTest.java | 3 +- .../internals/RemoteConfigRepositoryTest.java | 3 +- .../filter/ClientAuthenticationFilter.java | 3 +- .../ClientAuthenticationFilterTest.java | 5 +- .../apollo/core/signature/Signature.java | 4 +- .../apollo/core/signature/SignatureTest.java | 3 +- .../filter/ConsumerAuthenticationFilter.java | 3 +- .../component/RetryableRestTemplate.java | 111 ++++- .../portal/component/config/PortalConfig.java | 4 + .../main/resources/static/scripts/AppUtils.js | 2 +- .../ConsumerAuthenticationFilterTest.java | 5 +- .../portal/RetryableRestTemplateTest.java | 428 ++++++++++++++++-- 21 files changed, 977 insertions(+), 74 deletions(-) create mode 100644 apollo-adminservice/src/main/java/com/ctrip/framework/apollo/adminservice/AdminServiceAutoConfiguration.java create mode 100644 apollo-adminservice/src/main/java/com/ctrip/framework/apollo/adminservice/filter/AdminServiceAuthenticationFilter.java create mode 100644 apollo-adminservice/src/test/java/com/ctrip/framework/apollo/adminservice/filter/AdminServiceAuthenticationFilterTest.java create mode 100644 apollo-adminservice/src/test/java/com/ctrip/framework/apollo/adminservice/filter/AdminServiceAuthenticationIntegrationTest.java create mode 100644 apollo-adminservice/src/test/resources/filter/test-access-control-disabled.sql create mode 100644 apollo-adminservice/src/test/resources/filter/test-access-control-enabled-no-token.sql create mode 100644 apollo-adminservice/src/test/resources/filter/test-access-control-enabled.sql diff --git a/apollo-adminservice/src/main/java/com/ctrip/framework/apollo/adminservice/AdminServiceAutoConfiguration.java b/apollo-adminservice/src/main/java/com/ctrip/framework/apollo/adminservice/AdminServiceAutoConfiguration.java new file mode 100644 index 00000000000..bf61f015ee5 --- /dev/null +++ b/apollo-adminservice/src/main/java/com/ctrip/framework/apollo/adminservice/AdminServiceAutoConfiguration.java @@ -0,0 +1,32 @@ +package com.ctrip.framework.apollo.adminservice; + +import com.ctrip.framework.apollo.adminservice.filter.AdminServiceAuthenticationFilter; +import com.ctrip.framework.apollo.biz.config.BizConfig; +import org.springframework.boot.web.servlet.FilterRegistrationBean; +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Configuration; + +@Configuration +public class AdminServiceAutoConfiguration { + + private final BizConfig bizConfig; + + public AdminServiceAutoConfiguration(final BizConfig bizConfig) { + this.bizConfig = bizConfig; + } + + @Bean + public FilterRegistrationBean adminServiceAuthenticationFilter() { + FilterRegistrationBean filterRegistrationBean = new FilterRegistrationBean<>(); + + filterRegistrationBean.setFilter(new AdminServiceAuthenticationFilter(bizConfig)); + filterRegistrationBean.addUrlPatterns("/apps/*"); + filterRegistrationBean.addUrlPatterns("/appnamespaces/*"); + filterRegistrationBean.addUrlPatterns("/instances/*"); + filterRegistrationBean.addUrlPatterns("/items/*"); + filterRegistrationBean.addUrlPatterns("/namespaces/*"); + filterRegistrationBean.addUrlPatterns("/releases/*"); + + return filterRegistrationBean; + } +} diff --git a/apollo-adminservice/src/main/java/com/ctrip/framework/apollo/adminservice/filter/AdminServiceAuthenticationFilter.java b/apollo-adminservice/src/main/java/com/ctrip/framework/apollo/adminservice/filter/AdminServiceAuthenticationFilter.java new file mode 100644 index 00000000000..5ff022cd7fe --- /dev/null +++ b/apollo-adminservice/src/main/java/com/ctrip/framework/apollo/adminservice/filter/AdminServiceAuthenticationFilter.java @@ -0,0 +1,87 @@ +package com.ctrip.framework.apollo.adminservice.filter; + +import com.ctrip.framework.apollo.biz.config.BizConfig; +import com.google.common.base.Splitter; +import com.google.common.base.Strings; +import java.io.IOException; +import java.util.List; +import javax.servlet.Filter; +import javax.servlet.FilterChain; +import javax.servlet.FilterConfig; +import javax.servlet.ServletException; +import javax.servlet.ServletRequest; +import javax.servlet.ServletResponse; +import javax.servlet.http.HttpServletRequest; +import javax.servlet.http.HttpServletResponse; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.springframework.http.HttpHeaders; + +public class AdminServiceAuthenticationFilter implements Filter { + + private static final Logger logger = LoggerFactory + .getLogger(AdminServiceAuthenticationFilter.class); + private static final Splitter ACCESS_TOKEN_SPLITTER = Splitter.on(",").omitEmptyStrings() + .trimResults(); + + private final BizConfig bizConfig; + private volatile String lastAccessTokens; + private volatile List accessTokenList; + + public AdminServiceAuthenticationFilter(BizConfig bizConfig) { + this.bizConfig = bizConfig; + } + + @Override + public void init(FilterConfig filterConfig) throws ServletException { + + } + + @Override + public void doFilter(ServletRequest req, ServletResponse resp, FilterChain chain) + throws IOException, ServletException { + if (bizConfig.isAdminServiceAccessControlEnabled()) { + HttpServletRequest request = (HttpServletRequest) req; + HttpServletResponse response = (HttpServletResponse) resp; + + String token = request.getHeader(HttpHeaders.AUTHORIZATION); + + if (!checkAccessToken(token)) { + logger.warn("Invalid access token: {} for uri: {}", token, request.getRequestURI()); + response.sendError(HttpServletResponse.SC_UNAUTHORIZED, "Unauthorized"); + return; + } + } + + chain.doFilter(req, resp); + } + + private boolean checkAccessToken(String token) { + String accessTokens = bizConfig.getAdminServiceAccessTokens(); + + // if user forget to configure access tokens, then default to pass + if (Strings.isNullOrEmpty(accessTokens)) { + return true; + } + + // no need to check + if (Strings.isNullOrEmpty(token)) { + return false; + } + + // update cache + if (!accessTokens.equals(lastAccessTokens)) { + synchronized (this) { + accessTokenList = ACCESS_TOKEN_SPLITTER.splitToList(accessTokens); + lastAccessTokens = accessTokens; + } + } + + return accessTokenList.contains(token); + } + + @Override + public void destroy() { + + } +} diff --git a/apollo-adminservice/src/test/java/com/ctrip/framework/apollo/adminservice/controller/AbstractControllerTest.java b/apollo-adminservice/src/test/java/com/ctrip/framework/apollo/adminservice/controller/AbstractControllerTest.java index 1afd888d907..3d4922495b3 100644 --- a/apollo-adminservice/src/test/java/com/ctrip/framework/apollo/adminservice/controller/AbstractControllerTest.java +++ b/apollo-adminservice/src/test/java/com/ctrip/framework/apollo/adminservice/controller/AbstractControllerTest.java @@ -31,7 +31,7 @@ private void postConstruct() { } @Value("${local.server.port}") - int port; + protected int port; protected String url(String path) { return "http://localhost:" + port + path; diff --git a/apollo-adminservice/src/test/java/com/ctrip/framework/apollo/adminservice/filter/AdminServiceAuthenticationFilterTest.java b/apollo-adminservice/src/test/java/com/ctrip/framework/apollo/adminservice/filter/AdminServiceAuthenticationFilterTest.java new file mode 100644 index 00000000000..4d9a3ba3b65 --- /dev/null +++ b/apollo-adminservice/src/test/java/com/ctrip/framework/apollo/adminservice/filter/AdminServiceAuthenticationFilterTest.java @@ -0,0 +1,210 @@ +package com.ctrip.framework.apollo.adminservice.filter; + +import static org.mockito.ArgumentMatchers.anyInt; +import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +import com.ctrip.framework.apollo.biz.config.BizConfig; +import javax.servlet.FilterChain; +import javax.servlet.http.HttpServletRequest; +import javax.servlet.http.HttpServletResponse; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.mockito.Mock; +import org.mockito.junit.MockitoJUnitRunner; +import org.springframework.http.HttpHeaders; + +@RunWith(MockitoJUnitRunner.class) +public class AdminServiceAuthenticationFilterTest { + + @Mock + private BizConfig bizConfig; + private HttpServletRequest servletRequest; + private HttpServletResponse servletResponse; + private FilterChain filterChain; + + private AdminServiceAuthenticationFilter authenticationFilter; + + @Before + public void setUp() throws Exception { + authenticationFilter = new AdminServiceAuthenticationFilter(bizConfig); + initVariables(); + } + + private void initVariables() { + servletRequest = mock(HttpServletRequest.class); + servletResponse = mock(HttpServletResponse.class); + filterChain = mock(FilterChain.class); + } + + @Test + public void testWithAccessControlDisabled() throws Exception { + when(bizConfig.isAdminServiceAccessControlEnabled()).thenReturn(false); + + authenticationFilter.doFilter(servletRequest, servletResponse, filterChain); + + verify(bizConfig, times(1)).isAdminServiceAccessControlEnabled(); + verify(filterChain, times(1)).doFilter(servletRequest, servletResponse); + verify(bizConfig, never()).getAdminServiceAccessTokens(); + verify(servletRequest, never()).getHeader(HttpHeaders.AUTHORIZATION); + verify(servletResponse, never()).sendError(anyInt(), anyString()); + } + + @Test + public void testWithAccessControlEnabledWithTokenSpecifiedWithValidTokenPassed() + throws Exception { + String someValidToken = "someToken"; + + when(bizConfig.isAdminServiceAccessControlEnabled()).thenReturn(true); + when(bizConfig.getAdminServiceAccessTokens()).thenReturn(someValidToken); + when(servletRequest.getHeader(HttpHeaders.AUTHORIZATION)).thenReturn(someValidToken); + + authenticationFilter.doFilter(servletRequest, servletResponse, filterChain); + + verify(bizConfig, times(1)).isAdminServiceAccessControlEnabled(); + verify(bizConfig, times(1)).getAdminServiceAccessTokens(); + verify(filterChain, times(1)).doFilter(servletRequest, servletResponse); + verify(servletResponse, never()).sendError(anyInt(), anyString()); + } + + @Test + public void testWithAccessControlEnabledWithTokenSpecifiedWithInvalidTokenPassed() + throws Exception { + String someValidToken = "someValidToken"; + String someInvalidToken = "someInvalidToken"; + + when(bizConfig.isAdminServiceAccessControlEnabled()).thenReturn(true); + when(bizConfig.getAdminServiceAccessTokens()).thenReturn(someValidToken); + when(servletRequest.getHeader(HttpHeaders.AUTHORIZATION)).thenReturn(someInvalidToken); + + authenticationFilter.doFilter(servletRequest, servletResponse, filterChain); + + verify(bizConfig, times(1)).isAdminServiceAccessControlEnabled(); + verify(bizConfig, times(1)).getAdminServiceAccessTokens(); + verify(servletResponse, times(1)) + .sendError(HttpServletResponse.SC_UNAUTHORIZED, "Unauthorized"); + verify(filterChain, never()).doFilter(servletRequest, servletResponse); + } + + @Test + public void testWithAccessControlEnabledWithTokenSpecifiedWithNoTokenPassed() throws Exception { + String someValidToken = "someValidToken"; + + when(bizConfig.isAdminServiceAccessControlEnabled()).thenReturn(true); + when(bizConfig.getAdminServiceAccessTokens()).thenReturn(someValidToken); + when(servletRequest.getHeader(HttpHeaders.AUTHORIZATION)).thenReturn(null); + + authenticationFilter.doFilter(servletRequest, servletResponse, filterChain); + + verify(bizConfig, times(1)).isAdminServiceAccessControlEnabled(); + verify(bizConfig, times(1)).getAdminServiceAccessTokens(); + verify(servletResponse, times(1)) + .sendError(HttpServletResponse.SC_UNAUTHORIZED, "Unauthorized"); + verify(filterChain, never()).doFilter(servletRequest, servletResponse); + } + + + @Test + public void testWithAccessControlEnabledWithMultipleTokenSpecifiedWithValidTokenPassed() + throws Exception { + String someToken = "someToken"; + String anotherToken = "anotherToken"; + + when(bizConfig.isAdminServiceAccessControlEnabled()).thenReturn(true); + when(bizConfig.getAdminServiceAccessTokens()) + .thenReturn(String.format("%s,%s", someToken, anotherToken)); + when(servletRequest.getHeader(HttpHeaders.AUTHORIZATION)).thenReturn(someToken); + + authenticationFilter.doFilter(servletRequest, servletResponse, filterChain); + + verify(bizConfig, times(1)).isAdminServiceAccessControlEnabled(); + verify(bizConfig, times(1)).getAdminServiceAccessTokens(); + verify(filterChain, times(1)).doFilter(servletRequest, servletResponse); + verify(servletResponse, never()).sendError(anyInt(), anyString()); + } + + @Test + public void testWithAccessControlEnabledWithNoTokenSpecifiedWithTokenPassed() throws Exception { + String someToken = "someToken"; + + when(bizConfig.isAdminServiceAccessControlEnabled()).thenReturn(true); + when(bizConfig.getAdminServiceAccessTokens()).thenReturn(null); + when(servletRequest.getHeader(HttpHeaders.AUTHORIZATION)).thenReturn(someToken); + + authenticationFilter.doFilter(servletRequest, servletResponse, filterChain); + + verify(bizConfig, times(1)).isAdminServiceAccessControlEnabled(); + verify(bizConfig, times(1)).getAdminServiceAccessTokens(); + verify(filterChain, times(1)).doFilter(servletRequest, servletResponse); + verify(servletResponse, never()).sendError(anyInt(), anyString()); + } + + @Test + public void testWithAccessControlEnabledWithNoTokenSpecifiedWithNoTokenPassed() throws Exception { + String someToken = "someToken"; + + when(bizConfig.isAdminServiceAccessControlEnabled()).thenReturn(true); + when(bizConfig.getAdminServiceAccessTokens()).thenReturn(null); + when(servletRequest.getHeader(HttpHeaders.AUTHORIZATION)).thenReturn(null); + + authenticationFilter.doFilter(servletRequest, servletResponse, filterChain); + + verify(bizConfig, times(1)).isAdminServiceAccessControlEnabled(); + verify(bizConfig, times(1)).getAdminServiceAccessTokens(); + verify(filterChain, times(1)).doFilter(servletRequest, servletResponse); + verify(servletResponse, never()).sendError(anyInt(), anyString()); + } + + @Test + public void testWithConfigChanged() throws Exception { + String someToken = "someToken"; + String anotherToken = "anotherToken"; + String yetAnotherToken = "yetAnotherToken"; + + // case 1: init state + when(bizConfig.isAdminServiceAccessControlEnabled()).thenReturn(true); + when(bizConfig.getAdminServiceAccessTokens()).thenReturn(someToken); + + when(servletRequest.getHeader(HttpHeaders.AUTHORIZATION)).thenReturn(someToken); + + authenticationFilter.doFilter(servletRequest, servletResponse, filterChain); + + verify(filterChain, times(1)).doFilter(servletRequest, servletResponse); + verify(servletResponse, never()).sendError(anyInt(), anyString()); + + // case 2: change access tokens specified + initVariables(); + when(bizConfig.getAdminServiceAccessTokens()) + .thenReturn(String.format("%s,%s", anotherToken, yetAnotherToken)); + when(servletRequest.getHeader(HttpHeaders.AUTHORIZATION)).thenReturn(someToken); + + authenticationFilter.doFilter(servletRequest, servletResponse, filterChain); + + verify(servletResponse, times(1)) + .sendError(HttpServletResponse.SC_UNAUTHORIZED, "Unauthorized"); + verify(filterChain, never()).doFilter(servletRequest, servletResponse); + + initVariables(); + when(servletRequest.getHeader(HttpHeaders.AUTHORIZATION)).thenReturn(anotherToken); + + authenticationFilter.doFilter(servletRequest, servletResponse, filterChain); + + verify(filterChain, times(1)).doFilter(servletRequest, servletResponse); + verify(servletResponse, never()).sendError(anyInt(), anyString()); + + // case 3: change access control flag + initVariables(); + when(bizConfig.isAdminServiceAccessControlEnabled()).thenReturn(false); + + authenticationFilter.doFilter(servletRequest, servletResponse, filterChain); + + verify(filterChain, times(1)).doFilter(servletRequest, servletResponse); + verify(servletResponse, never()).sendError(anyInt(), anyString()); + verify(servletRequest, never()).getHeader(HttpHeaders.AUTHORIZATION); + } +} \ No newline at end of file diff --git a/apollo-adminservice/src/test/java/com/ctrip/framework/apollo/adminservice/filter/AdminServiceAuthenticationIntegrationTest.java b/apollo-adminservice/src/test/java/com/ctrip/framework/apollo/adminservice/filter/AdminServiceAuthenticationIntegrationTest.java new file mode 100644 index 00000000000..23d424559f5 --- /dev/null +++ b/apollo-adminservice/src/test/java/com/ctrip/framework/apollo/adminservice/filter/AdminServiceAuthenticationIntegrationTest.java @@ -0,0 +1,128 @@ +package com.ctrip.framework.apollo.adminservice.filter; + +import com.ctrip.framework.apollo.adminservice.controller.AbstractControllerTest; +import com.ctrip.framework.apollo.common.config.RefreshablePropertySource; +import com.ctrip.framework.apollo.common.dto.AppDTO; +import java.util.List; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.http.HttpEntity; +import org.springframework.http.HttpHeaders; +import org.springframework.http.HttpMethod; +import org.springframework.test.annotation.DirtiesContext; +import org.springframework.test.context.jdbc.Sql; +import org.springframework.test.context.jdbc.Sql.ExecutionPhase; +import org.springframework.test.util.ReflectionTestUtils; +import org.springframework.web.client.HttpClientErrorException; + +@DirtiesContext +public class AdminServiceAuthenticationIntegrationTest extends AbstractControllerTest { + + @Autowired + private List propertySources; + + @Before + public void setUp() throws Exception { + doRefresh(propertySources); + } + + @Test + @Sql(scripts = "/controller/test-release.sql", executionPhase = ExecutionPhase.BEFORE_TEST_METHOD) + @Sql(scripts = "/filter/test-access-control-disabled.sql", executionPhase = ExecutionPhase.BEFORE_TEST_METHOD) + @Sql(scripts = "/controller/cleanup.sql", executionPhase = ExecutionPhase.AFTER_TEST_METHOD) + public void testWithAccessControlDisabledExplicitly() { + String appId = "someAppId"; + AppDTO app = restTemplate + .getForObject("http://localhost:" + port + "/apps/" + appId, AppDTO.class); + + Assert.assertEquals("someAppId", app.getAppId()); + } + + @Test + @Sql(scripts = "/controller/test-release.sql", executionPhase = ExecutionPhase.BEFORE_TEST_METHOD) + @Sql(scripts = "/filter/test-access-control-disabled.sql", executionPhase = ExecutionPhase.BEFORE_TEST_METHOD) + @Sql(scripts = "/controller/cleanup.sql", executionPhase = ExecutionPhase.AFTER_TEST_METHOD) + public void testWithAccessControlDisabledExplicitlyWithAccessToken() { + String appId = "someAppId"; + String someToken = "someToken"; + HttpHeaders headers = new HttpHeaders(); + headers.add(HttpHeaders.AUTHORIZATION, someToken); + HttpEntity entity = new HttpEntity<>(headers); + + AppDTO app = restTemplate + .exchange("http://localhost:" + port + "/apps/" + appId, HttpMethod.GET, entity, + AppDTO.class).getBody(); + + Assert.assertEquals("someAppId", app.getAppId()); + } + + @Test + @Sql(scripts = "/controller/test-release.sql", executionPhase = ExecutionPhase.BEFORE_TEST_METHOD) + @Sql(scripts = "/filter/test-access-control-enabled.sql", executionPhase = ExecutionPhase.BEFORE_TEST_METHOD) + @Sql(scripts = "/controller/cleanup.sql", executionPhase = ExecutionPhase.AFTER_TEST_METHOD) + public void testWithAccessControlEnabledWithValidAccessToken() { + String appId = "someAppId"; + String someValidToken = "someToken"; + HttpHeaders headers = new HttpHeaders(); + headers.add(HttpHeaders.AUTHORIZATION, someValidToken); + HttpEntity entity = new HttpEntity<>(headers); + + AppDTO app = restTemplate + .exchange("http://localhost:" + port + "/apps/" + appId, HttpMethod.GET, entity, + AppDTO.class).getBody(); + + Assert.assertEquals("someAppId", app.getAppId()); + } + + @Test(expected = HttpClientErrorException.class) + @Sql(scripts = "/controller/test-release.sql", executionPhase = ExecutionPhase.BEFORE_TEST_METHOD) + @Sql(scripts = "/filter/test-access-control-enabled.sql", executionPhase = ExecutionPhase.BEFORE_TEST_METHOD) + @Sql(scripts = "/controller/cleanup.sql", executionPhase = ExecutionPhase.AFTER_TEST_METHOD) + public void testWithAccessControlEnabledWithNoAccessToken() { + String appId = "someAppId"; + AppDTO app = restTemplate + .getForObject("http://localhost:" + port + "/apps/" + appId, AppDTO.class); + } + + @Test(expected = HttpClientErrorException.class) + @Sql(scripts = "/controller/test-release.sql", executionPhase = ExecutionPhase.BEFORE_TEST_METHOD) + @Sql(scripts = "/filter/test-access-control-enabled.sql", executionPhase = ExecutionPhase.BEFORE_TEST_METHOD) + @Sql(scripts = "/controller/cleanup.sql", executionPhase = ExecutionPhase.AFTER_TEST_METHOD) + public void testWithAccessControlEnabledWithInValidAccessToken() { + String appId = "someAppId"; + String someValidToken = "someInvalidToken"; + HttpHeaders headers = new HttpHeaders(); + headers.add(HttpHeaders.AUTHORIZATION, someValidToken); + HttpEntity entity = new HttpEntity<>(headers); + + AppDTO app = restTemplate + .exchange("http://localhost:" + port + "/apps/" + appId, HttpMethod.GET, entity, + AppDTO.class).getBody(); + } + + @Test + @Sql(scripts = "/controller/test-release.sql", executionPhase = ExecutionPhase.BEFORE_TEST_METHOD) + @Sql(scripts = "/filter/test-access-control-enabled-no-token.sql", executionPhase = ExecutionPhase.BEFORE_TEST_METHOD) + @Sql(scripts = "/controller/cleanup.sql", executionPhase = ExecutionPhase.AFTER_TEST_METHOD) + public void testWithAccessControlEnabledWithNoTokenSpecified() { + String appId = "someAppId"; + String someToken = "someToken"; + HttpHeaders headers = new HttpHeaders(); + headers.add(HttpHeaders.AUTHORIZATION, someToken); + HttpEntity entity = new HttpEntity<>(headers); + + AppDTO app = restTemplate + .exchange("http://localhost:" + port + "/apps/" + appId, HttpMethod.GET, entity, + AppDTO.class).getBody(); + + Assert.assertEquals("someAppId", app.getAppId()); + } + + + private void doRefresh(List propertySources) { + propertySources.forEach(refreshablePropertySource -> ReflectionTestUtils + .invokeMethod(refreshablePropertySource, "refresh")); + } +} diff --git a/apollo-adminservice/src/test/resources/filter/test-access-control-disabled.sql b/apollo-adminservice/src/test/resources/filter/test-access-control-disabled.sql new file mode 100644 index 00000000000..da2672754bd --- /dev/null +++ b/apollo-adminservice/src/test/resources/filter/test-access-control-disabled.sql @@ -0,0 +1,4 @@ +INSERT INTO `ServerConfig` (`Key`, `Cluster`, `Value`) +VALUES + ('admin-service.access.tokens', 'default', 'someToken,anotherToken'), + ('admin-service.access.control.enabled', 'default', 'false'); diff --git a/apollo-adminservice/src/test/resources/filter/test-access-control-enabled-no-token.sql b/apollo-adminservice/src/test/resources/filter/test-access-control-enabled-no-token.sql new file mode 100644 index 00000000000..5e302ccf32d --- /dev/null +++ b/apollo-adminservice/src/test/resources/filter/test-access-control-enabled-no-token.sql @@ -0,0 +1,3 @@ +INSERT INTO `ServerConfig` (`Key`, `Cluster`, `Value`) +VALUES + ('admin-service.access.control.enabled', 'default', 'true'); diff --git a/apollo-adminservice/src/test/resources/filter/test-access-control-enabled.sql b/apollo-adminservice/src/test/resources/filter/test-access-control-enabled.sql new file mode 100644 index 00000000000..05d7ba38595 --- /dev/null +++ b/apollo-adminservice/src/test/resources/filter/test-access-control-enabled.sql @@ -0,0 +1,4 @@ +INSERT INTO `ServerConfig` (`Key`, `Cluster`, `Value`) +VALUES + ('admin-service.access.tokens', 'default', 'someToken,anotherToken'), + ('admin-service.access.control.enabled', 'default', 'true'); diff --git a/apollo-biz/src/main/java/com/ctrip/framework/apollo/biz/config/BizConfig.java b/apollo-biz/src/main/java/com/ctrip/framework/apollo/biz/config/BizConfig.java index 0ec4b8bc70f..1cece1df8c9 100644 --- a/apollo-biz/src/main/java/com/ctrip/framework/apollo/biz/config/BizConfig.java +++ b/apollo-biz/src/main/java/com/ctrip/framework/apollo/biz/config/BizConfig.java @@ -173,4 +173,11 @@ int checkInt(int value, int min, int max, int defaultValue) { return defaultValue; } + public boolean isAdminServiceAccessControlEnabled() { + return getBooleanProperty("admin-service.access.control.enabled", false); + } + + public String getAdminServiceAccessTokens() { + return getValue("admin-service.access.tokens"); + } } diff --git a/apollo-client/src/test/java/com/ctrip/framework/apollo/internals/RemoteConfigLongPollServiceTest.java b/apollo-client/src/test/java/com/ctrip/framework/apollo/internals/RemoteConfigLongPollServiceTest.java index d4af113c192..b8668d2d041 100644 --- a/apollo-client/src/test/java/com/ctrip/framework/apollo/internals/RemoteConfigLongPollServiceTest.java +++ b/apollo-client/src/test/java/com/ctrip/framework/apollo/internals/RemoteConfigLongPollServiceTest.java @@ -23,6 +23,7 @@ import com.ctrip.framework.apollo.util.http.HttpUtil; import com.google.common.collect.ImmutableMap; import com.google.common.collect.Lists; +import com.google.common.net.HttpHeaders; import com.google.common.util.concurrent.SettableFuture; import java.lang.reflect.Type; import java.util.List; @@ -213,7 +214,7 @@ public HttpResponse> answer(InvocationOnMock invo Map headers = request.getHeaders(); assertNotNull(headers); assertTrue(headers.containsKey(Signature.HTTP_HEADER_TIMESTAMP)); - assertTrue(headers.containsKey(Signature.HTTP_HEADER_AUTHORIZATION)); + assertTrue(headers.containsKey(HttpHeaders.AUTHORIZATION)); return pollResponse; } diff --git a/apollo-client/src/test/java/com/ctrip/framework/apollo/internals/RemoteConfigRepositoryTest.java b/apollo-client/src/test/java/com/ctrip/framework/apollo/internals/RemoteConfigRepositoryTest.java index daf0ccaca59..49c658da958 100644 --- a/apollo-client/src/test/java/com/ctrip/framework/apollo/internals/RemoteConfigRepositoryTest.java +++ b/apollo-client/src/test/java/com/ctrip/framework/apollo/internals/RemoteConfigRepositoryTest.java @@ -32,6 +32,7 @@ import com.google.common.collect.ImmutableMap; import com.google.common.collect.Lists; import com.google.common.collect.Maps; +import com.google.common.net.HttpHeaders; import com.google.common.net.UrlEscapers; import com.google.common.util.concurrent.SettableFuture; import com.google.gson.Gson; @@ -186,7 +187,7 @@ public HttpResponse answer(InvocationOnMock invocation) throws Thr Map headers = request.getHeaders(); assertNotNull(headers); assertTrue(headers.containsKey(Signature.HTTP_HEADER_TIMESTAMP)); - assertTrue(headers.containsKey(Signature.HTTP_HEADER_AUTHORIZATION)); + assertTrue(headers.containsKey(HttpHeaders.AUTHORIZATION)); return someResponse; } diff --git a/apollo-configservice/src/main/java/com/ctrip/framework/apollo/configservice/filter/ClientAuthenticationFilter.java b/apollo-configservice/src/main/java/com/ctrip/framework/apollo/configservice/filter/ClientAuthenticationFilter.java index 1cb43dfa94b..dad62880138 100644 --- a/apollo-configservice/src/main/java/com/ctrip/framework/apollo/configservice/filter/ClientAuthenticationFilter.java +++ b/apollo-configservice/src/main/java/com/ctrip/framework/apollo/configservice/filter/ClientAuthenticationFilter.java @@ -3,6 +3,7 @@ import com.ctrip.framework.apollo.configservice.util.AccessKeyUtil; import com.ctrip.framework.apollo.core.signature.Signature; import com.ctrip.framework.apollo.core.utils.StringUtils; +import com.google.common.net.HttpHeaders; import java.io.IOException; import java.util.List; import java.util.Objects; @@ -53,7 +54,7 @@ public void doFilter(ServletRequest req, ServletResponse resp, FilterChain chain List availableSecrets = accessKeyUtil.findAvailableSecret(appId); if (!CollectionUtils.isEmpty(availableSecrets)) { String timestamp = request.getHeader(Signature.HTTP_HEADER_TIMESTAMP); - String authorization = request.getHeader(Signature.HTTP_HEADER_AUTHORIZATION); + String authorization = request.getHeader(HttpHeaders.AUTHORIZATION); // check timestamp, valid within 1 minute if (!checkTimestamp(timestamp)) { diff --git a/apollo-configservice/src/test/java/com/ctrip/framework/apollo/configservice/filter/ClientAuthenticationFilterTest.java b/apollo-configservice/src/test/java/com/ctrip/framework/apollo/configservice/filter/ClientAuthenticationFilterTest.java index a74ca2654cf..ffe31a74af9 100644 --- a/apollo-configservice/src/test/java/com/ctrip/framework/apollo/configservice/filter/ClientAuthenticationFilterTest.java +++ b/apollo-configservice/src/test/java/com/ctrip/framework/apollo/configservice/filter/ClientAuthenticationFilterTest.java @@ -18,6 +18,7 @@ import org.junit.runner.RunWith; import org.mockito.Mock; import org.mockito.junit.MockitoJUnitRunner; +import org.springframework.http.HttpHeaders; /** * @author nisiyong @@ -95,7 +96,7 @@ public void testUnauthorized() throws Exception { when(accessKeyUtil.findAvailableSecret(appId)).thenReturn(secrets); when(accessKeyUtil.buildSignature(any(), any(), any(), any())).thenReturn(availableSignature); when(request.getHeader(Signature.HTTP_HEADER_TIMESTAMP)).thenReturn(oneMinAgoTimestamp); - when(request.getHeader(Signature.HTTP_HEADER_AUTHORIZATION)).thenReturn(errorAuthorization); + when(request.getHeader(HttpHeaders.AUTHORIZATION)).thenReturn(errorAuthorization); clientAuthenticationFilter.doFilter(request, response, filterChain); @@ -115,7 +116,7 @@ public void testAuthorizedSuccessfully() throws Exception { when(accessKeyUtil.findAvailableSecret(appId)).thenReturn(secrets); when(accessKeyUtil.buildSignature(any(), any(), any(), any())).thenReturn(availableSignature); when(request.getHeader(Signature.HTTP_HEADER_TIMESTAMP)).thenReturn(oneMinAgoTimestamp); - when(request.getHeader(Signature.HTTP_HEADER_AUTHORIZATION)).thenReturn(correctAuthorization); + when(request.getHeader(HttpHeaders.AUTHORIZATION)).thenReturn(correctAuthorization); clientAuthenticationFilter.doFilter(request, response, filterChain); diff --git a/apollo-core/src/main/java/com/ctrip/framework/apollo/core/signature/Signature.java b/apollo-core/src/main/java/com/ctrip/framework/apollo/core/signature/Signature.java index ba52fc3e5dc..f3bded2f928 100644 --- a/apollo-core/src/main/java/com/ctrip/framework/apollo/core/signature/Signature.java +++ b/apollo-core/src/main/java/com/ctrip/framework/apollo/core/signature/Signature.java @@ -1,6 +1,7 @@ package com.ctrip.framework.apollo.core.signature; import com.google.common.collect.Maps; +import com.google.common.net.HttpHeaders; import java.net.MalformedURLException; import java.net.URL; import java.util.Map; @@ -16,7 +17,6 @@ public class Signature { private static final String AUTHORIZATION_FORMAT = "Apollo %s:%s"; private static final String DELIMITER = "\n"; - public static final String HTTP_HEADER_AUTHORIZATION = "Authorization"; public static final String HTTP_HEADER_TIMESTAMP = "Timestamp"; public static String signature(String timestamp, String pathWithQuery, String secret) { @@ -32,7 +32,7 @@ public static Map buildHttpHeaders(String url, String appId, Str String signature = signature(timestamp, pathWithQuery, secret); Map headers = Maps.newHashMap(); - headers.put(HTTP_HEADER_AUTHORIZATION, String.format(AUTHORIZATION_FORMAT, appId, signature)); + headers.put(HttpHeaders.AUTHORIZATION, String.format(AUTHORIZATION_FORMAT, appId, signature)); headers.put(HTTP_HEADER_TIMESTAMP, timestamp); return headers; } diff --git a/apollo-core/src/test/java/com/ctrip/framework/apollo/core/signature/SignatureTest.java b/apollo-core/src/test/java/com/ctrip/framework/apollo/core/signature/SignatureTest.java index 8933fc076cf..ea805b12a7a 100644 --- a/apollo-core/src/test/java/com/ctrip/framework/apollo/core/signature/SignatureTest.java +++ b/apollo-core/src/test/java/com/ctrip/framework/apollo/core/signature/SignatureTest.java @@ -3,6 +3,7 @@ import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertTrue; +import com.google.common.net.HttpHeaders; import java.util.Map; import org.junit.Test; @@ -31,7 +32,7 @@ public void testBuildHttpHeaders() { Map actualHttpHeaders = Signature.buildHttpHeaders(url, appId, secret); - assertTrue(actualHttpHeaders.containsKey(Signature.HTTP_HEADER_AUTHORIZATION)); + assertTrue(actualHttpHeaders.containsKey(HttpHeaders.AUTHORIZATION)); assertTrue(actualHttpHeaders.containsKey(Signature.HTTP_HEADER_TIMESTAMP)); } } \ No newline at end of file diff --git a/apollo-portal/src/main/java/com/ctrip/framework/apollo/openapi/filter/ConsumerAuthenticationFilter.java b/apollo-portal/src/main/java/com/ctrip/framework/apollo/openapi/filter/ConsumerAuthenticationFilter.java index adec13dcb47..b212e4a1637 100644 --- a/apollo-portal/src/main/java/com/ctrip/framework/apollo/openapi/filter/ConsumerAuthenticationFilter.java +++ b/apollo-portal/src/main/java/com/ctrip/framework/apollo/openapi/filter/ConsumerAuthenticationFilter.java @@ -13,6 +13,7 @@ import javax.servlet.ServletResponse; import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletResponse; +import org.springframework.http.HttpHeaders; /** * @author Jason Song(song_s@ctrip.com) @@ -37,7 +38,7 @@ public void doFilter(ServletRequest req, ServletResponse resp, FilterChain chain HttpServletRequest request = (HttpServletRequest) req; HttpServletResponse response = (HttpServletResponse) resp; - String token = request.getHeader("Authorization"); + String token = request.getHeader(HttpHeaders.AUTHORIZATION); Long consumerId = consumerAuthUtil.getConsumerId(token); diff --git a/apollo-portal/src/main/java/com/ctrip/framework/apollo/portal/component/RetryableRestTemplate.java b/apollo-portal/src/main/java/com/ctrip/framework/apollo/portal/component/RetryableRestTemplate.java index 97bcc27012b..e07b22897e0 100644 --- a/apollo-portal/src/main/java/com/ctrip/framework/apollo/portal/component/RetryableRestTemplate.java +++ b/apollo-portal/src/main/java/com/ctrip/framework/apollo/portal/component/RetryableRestTemplate.java @@ -1,18 +1,30 @@ package com.ctrip.framework.apollo.portal.component; import com.ctrip.framework.apollo.common.exception.ServiceException; -import com.ctrip.framework.apollo.portal.environment.PortalMetaDomainService; import com.ctrip.framework.apollo.core.dto.ServiceDTO; -import com.ctrip.framework.apollo.portal.environment.Env; +import com.ctrip.framework.apollo.portal.component.config.PortalConfig; import com.ctrip.framework.apollo.portal.constant.TracerEventType; +import com.ctrip.framework.apollo.portal.environment.Env; +import com.ctrip.framework.apollo.portal.environment.PortalMetaDomainService; import com.ctrip.framework.apollo.tracer.Tracer; import com.ctrip.framework.apollo.tracer.spi.Transaction; +import com.google.common.base.Strings; +import com.google.common.collect.Maps; +import com.google.gson.Gson; +import com.google.gson.reflect.TypeToken; +import java.lang.reflect.Type; +import java.net.SocketTimeoutException; +import java.util.List; +import java.util.Map; +import javax.annotation.PostConstruct; import org.apache.http.conn.ConnectTimeoutException; import org.apache.http.conn.HttpHostConnectException; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.springframework.context.annotation.Lazy; import org.springframework.core.ParameterizedTypeReference; +import org.springframework.http.HttpEntity; +import org.springframework.http.HttpHeaders; import org.springframework.http.HttpMethod; import org.springframework.http.ResponseEntity; import org.springframework.stereotype.Component; @@ -22,10 +34,6 @@ import org.springframework.web.util.DefaultUriBuilderFactory; import org.springframework.web.util.UriTemplateHandler; -import javax.annotation.PostConstruct; -import java.net.SocketTimeoutException; -import java.util.List; - /** * 封装RestTemplate. admin server集群在某些机器宕机或者超时的情况下轮询重试 */ @@ -36,20 +44,31 @@ public class RetryableRestTemplate { private UriTemplateHandler uriTemplateHandler = new DefaultUriBuilderFactory(); + private Gson gson = new Gson(); + /** + * Admin service access tokens in "PortalDB.ServerConfig" + */ + private static final Type ACCESS_TOKENS = new TypeToken>(){}.getType(); + private RestTemplate restTemplate; private final RestTemplateFactory restTemplateFactory; private final AdminServiceAddressLocator adminServiceAddressLocator; private final PortalMetaDomainService portalMetaDomainService; + private final PortalConfig portalConfig; + private volatile String lastAdminServiceAccessTokens; + private volatile Map adminServiceAccessTokenMap; public RetryableRestTemplate( final @Lazy RestTemplateFactory restTemplateFactory, final @Lazy AdminServiceAddressLocator adminServiceAddressLocator, - final PortalMetaDomainService portalMetaDomainService + final PortalMetaDomainService portalMetaDomainService, + final PortalConfig portalConfig ) { this.restTemplateFactory = restTemplateFactory; this.adminServiceAddressLocator = adminServiceAddressLocator; this.portalMetaDomainService = portalMetaDomainService; + this.portalConfig = portalConfig; } @@ -95,11 +114,12 @@ private T execute(HttpMethod method, Env env, String path, Object request, C ct.addData("Env", env); List services = getAdminServices(env, ct); + HttpHeaders extraHeaders = assembleExtraHeaders(env); for (ServiceDTO serviceDTO : services) { try { - T result = doExecute(method, serviceDTO, path, request, responseType, uriVariables); + T result = doExecute(method, extraHeaders, serviceDTO, path, request, responseType, uriVariables); ct.setStatus(Transaction.SUCCESS); ct.complete(); @@ -137,12 +157,13 @@ private ResponseEntity exchangeGet(Env env, String path, ParameterizedTyp ct.addData("Env", env); List services = getAdminServices(env, ct); + HttpEntity entity = new HttpEntity<>(assembleExtraHeaders(env)); for (ServiceDTO serviceDTO : services) { try { ResponseEntity result = - restTemplate.exchange(parseHost(serviceDTO) + path, HttpMethod.GET, null, reference, uriVariables); + restTemplate.exchange(parseHost(serviceDTO) + path, HttpMethod.GET, entity, reference, uriVariables); ct.setStatus(Transaction.SUCCESS); ct.complete(); @@ -171,6 +192,18 @@ private ResponseEntity exchangeGet(Env env, String path, ParameterizedTyp } + private HttpHeaders assembleExtraHeaders(Env env) { + String adminServiceAccessToken = getAdminServiceAccessToken(env); + + if (!Strings.isNullOrEmpty(adminServiceAccessToken)) { + HttpHeaders headers = new HttpHeaders(); + headers.add(HttpHeaders.AUTHORIZATION, adminServiceAccessToken); + return headers; + } + + return null; + } + private List getAdminServices(Env env, Transaction ct) { List services = adminServiceAddressLocator.getServiceList(env); @@ -188,23 +221,61 @@ private List getAdminServices(Env env, Transaction ct) { return services; } - private T doExecute(HttpMethod method, ServiceDTO service, String path, Object request, - Class responseType, - Object... uriVariables) { + private String getAdminServiceAccessToken(Env env) { + String accessTokens = portalConfig.getAdminServiceAccessTokens(); + + if (Strings.isNullOrEmpty(accessTokens)) { + return null; + } + + if (!accessTokens.equals(lastAdminServiceAccessTokens)) { + synchronized (this) { + adminServiceAccessTokenMap = parseAdminServiceAccessTokens(accessTokens); + lastAdminServiceAccessTokens = accessTokens; + } + } + + return adminServiceAccessTokenMap.get(env); + } + + private Map parseAdminServiceAccessTokens(String accessTokens) { + Map tokenMap = Maps.newHashMap(); + try { + // try to parse + Map map = gson.fromJson(accessTokens, ACCESS_TOKENS); + map.forEach((env, token) -> { + if (Env.exists(env)) { + tokenMap.put(Env.valueOf(env), token); + } + }); + } catch (Exception e) { + logger.error("Wrong format of admin service access tokens: {}", accessTokens, e); + } + return tokenMap; + } + private T doExecute(HttpMethod method, HttpHeaders extraHeaders, ServiceDTO service, String path, Object request, + Class responseType, Object... uriVariables) { T result = null; switch (method) { case GET: - result = restTemplate.getForObject(parseHost(service) + path, responseType, uriVariables); - break; case POST: - result = - restTemplate.postForEntity(parseHost(service) + path, request, responseType, uriVariables).getBody(); - break; case PUT: - restTemplate.put(parseHost(service) + path, request, uriVariables); - break; case DELETE: - restTemplate.delete(parseHost(service) + path, uriVariables); + HttpEntity entity; + if (request instanceof HttpEntity) { + entity = (HttpEntity) request; + if (!CollectionUtils.isEmpty(extraHeaders)) { + HttpHeaders headers = new HttpHeaders(); + headers.addAll(entity.getHeaders()); + headers.addAll(extraHeaders); + entity = new HttpEntity<>(entity.getBody(), headers); + } + } else { + entity = new HttpEntity<>(request, extraHeaders); + } + result = restTemplate + .exchange(parseHost(service) + path, method, entity, responseType, uriVariables) + .getBody(); break; default: throw new UnsupportedOperationException(String.format("unsupported http method(method=%s)", method)); diff --git a/apollo-portal/src/main/java/com/ctrip/framework/apollo/portal/component/config/PortalConfig.java b/apollo-portal/src/main/java/com/ctrip/framework/apollo/portal/component/config/PortalConfig.java index fac0bbcd73f..805bb00493a 100644 --- a/apollo-portal/src/main/java/com/ctrip/framework/apollo/portal/component/config/PortalConfig.java +++ b/apollo-portal/src/main/java/com/ctrip/framework/apollo/portal/component/config/PortalConfig.java @@ -208,6 +208,10 @@ public boolean isManageAppMasterPermissionEnabled() { return getBooleanProperty(SystemRoleManagerService.MANAGE_APP_MASTER_LIMIT_SWITCH_KEY, false); } + public String getAdminServiceAccessTokens() { + return getValue("admin-service.access.tokens"); + } + /*** * The following configurations are used in ctrip profile **/ diff --git a/apollo-portal/src/main/resources/static/scripts/AppUtils.js b/apollo-portal/src/main/resources/static/scripts/AppUtils.js index 76165b392a6..507bbd35bf8 100644 --- a/apollo-portal/src/main/resources/static/scripts/AppUtils.js +++ b/apollo-portal/src/main/resources/static/scripts/AppUtils.js @@ -56,7 +56,7 @@ appUtil.service('AppUtil', ['toastr', '$window', '$q', '$translate', 'prefixLoca if (!query) { //如果不传这个参数或者false则返回到首页(参数出错) if (!notJumpToHomePage) { - $window.location.href = '/index.html'; + $window.location.href = prefixLocation + '/index.html'; } else { return {}; } diff --git a/apollo-portal/src/test/java/com/ctrip/framework/apollo/openapi/filter/ConsumerAuthenticationFilterTest.java b/apollo-portal/src/test/java/com/ctrip/framework/apollo/openapi/filter/ConsumerAuthenticationFilterTest.java index e5ac79d4c72..326eb5f842b 100644 --- a/apollo-portal/src/test/java/com/ctrip/framework/apollo/openapi/filter/ConsumerAuthenticationFilterTest.java +++ b/apollo-portal/src/test/java/com/ctrip/framework/apollo/openapi/filter/ConsumerAuthenticationFilterTest.java @@ -12,6 +12,7 @@ import javax.servlet.FilterChain; import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletResponse; +import org.springframework.http.HttpHeaders; import static org.mockito.ArgumentMatchers.anyLong; import static org.mockito.ArgumentMatchers.anyString; @@ -48,7 +49,7 @@ public void testAuthSuccessfully() throws Exception { String someToken = "someToken"; Long someConsumerId = 1L; - when(request.getHeader("Authorization")).thenReturn(someToken); + when(request.getHeader(HttpHeaders.AUTHORIZATION)).thenReturn(someToken); when(consumerAuthUtil.getConsumerId(someToken)).thenReturn(someConsumerId); authenticationFilter.doFilter(request, response, filterChain); @@ -62,7 +63,7 @@ public void testAuthSuccessfully() throws Exception { public void testAuthFailed() throws Exception { String someInvalidToken = "someInvalidToken"; - when(request.getHeader("Authorization")).thenReturn(someInvalidToken); + when(request.getHeader(HttpHeaders.AUTHORIZATION)).thenReturn(someInvalidToken); when(consumerAuthUtil.getConsumerId(someInvalidToken)).thenReturn(null); authenticationFilter.doFilter(request, response, filterChain); diff --git a/apollo-portal/src/test/java/com/ctrip/framework/apollo/portal/RetryableRestTemplateTest.java b/apollo-portal/src/test/java/com/ctrip/framework/apollo/portal/RetryableRestTemplateTest.java index 46497f6eba4..eda4fd92f34 100644 --- a/apollo-portal/src/test/java/com/ctrip/framework/apollo/portal/RetryableRestTemplateTest.java +++ b/apollo-portal/src/test/java/com/ctrip/framework/apollo/portal/RetryableRestTemplateTest.java @@ -1,6 +1,14 @@ package com.ctrip.framework.apollo.portal; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertSame; +import static org.junit.Assert.assertTrue; import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.ArgumentMatchers.isNull; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; @@ -9,20 +17,28 @@ import com.ctrip.framework.apollo.core.dto.ServiceDTO; import com.ctrip.framework.apollo.portal.component.AdminServiceAddressLocator; import com.ctrip.framework.apollo.portal.component.RetryableRestTemplate; +import com.ctrip.framework.apollo.portal.component.config.PortalConfig; import com.ctrip.framework.apollo.portal.environment.Env; import com.ctrip.framework.apollo.portal.environment.PortalMetaDomainService; +import com.google.common.collect.Maps; +import com.google.gson.Gson; import java.net.SocketTimeoutException; import java.util.Arrays; import java.util.Collections; +import java.util.List; +import java.util.Map; import org.apache.http.HttpHost; import org.apache.http.conn.ConnectTimeoutException; import org.apache.http.conn.HttpHostConnectException; -import org.junit.Assert; import org.junit.Before; import org.junit.Test; +import org.mockito.ArgumentCaptor; import org.mockito.InjectMocks; import org.mockito.Mock; -import org.springframework.http.HttpStatus; +import org.springframework.core.ParameterizedTypeReference; +import org.springframework.http.HttpEntity; +import org.springframework.http.HttpHeaders; +import org.springframework.http.HttpMethod; import org.springframework.http.ResponseEntity; import org.springframework.web.client.ResourceAccessException; import org.springframework.web.client.RestTemplate; @@ -35,9 +51,13 @@ public class RetryableRestTemplateTest extends AbstractUnitTest { private RestTemplate restTemplate; @Mock private PortalMetaDomainService portalMetaDomainService; + @Mock + private PortalConfig portalConfig; @InjectMocks private RetryableRestTemplate retryableRestTemplate; + private Gson gson = new Gson(); + private String path = "app"; private String serviceOne = "http://10.0.0.1"; private String serviceTwo = "http://10.0.0.2"; @@ -46,15 +66,16 @@ public class RetryableRestTemplateTest extends AbstractUnitTest { private ResourceAccessException httpHostConnectException = new ResourceAccessException(""); private ResourceAccessException connectTimeoutException = new ResourceAccessException(""); private Object request = new Object(); - private ResponseEntity entity = new ResponseEntity<>(HttpStatus.OK); - + private Object result = new Object(); + private Class requestType = request.getClass(); @Before public void init() { socketTimeoutException.initCause(new SocketTimeoutException()); httpHostConnectException - .initCause(new HttpHostConnectException(new ConnectTimeoutException(), new HttpHost(serviceOne, 80))); + .initCause(new HttpHostConnectException(new ConnectTimeoutException(), + new HttpHost(serviceOne, 80))); connectTimeoutException.initCause(new ConnectTimeoutException()); } @@ -70,72 +91,397 @@ public void testNoAdminServer() { public void testAllServerDown() { when(serviceAddressLocator.getServiceList(any())) - .thenReturn(Arrays.asList(mockService(serviceOne), mockService(serviceTwo), mockService(serviceThree))); - when(restTemplate.getForObject(serviceOne + "/" + path, Object.class)).thenThrow(socketTimeoutException); - when(restTemplate.getForObject(serviceTwo + "/" + path, Object.class)).thenThrow(httpHostConnectException); - when(restTemplate.getForObject(serviceThree + "/" + path, Object.class)).thenThrow(connectTimeoutException); + .thenReturn(Arrays + .asList(mockService(serviceOne), mockService(serviceTwo), mockService(serviceThree))); + when(restTemplate + .exchange(eq(serviceOne + "/" + path), eq(HttpMethod.GET), any(HttpEntity.class), + eq(Object.class))).thenThrow(socketTimeoutException); + when(restTemplate + .exchange(eq(serviceTwo + "/" + path), eq(HttpMethod.GET), any(HttpEntity.class), + eq(Object.class))).thenThrow(httpHostConnectException); + when(restTemplate + .exchange(eq(serviceThree + "/" + path), eq(HttpMethod.GET), any(HttpEntity.class), + eq(Object.class))).thenThrow(connectTimeoutException); retryableRestTemplate.get(Env.DEV, path, Object.class); - verify(restTemplate).getForObject(serviceOne + "/" + path, Object.class); - verify(restTemplate).getForObject(serviceTwo + "/" + path, Object.class); - verify(restTemplate).getForObject(serviceThree + "/" + path, Object.class); + verify(restTemplate, times(1)) + .exchange(eq(serviceOne + "/" + path), eq(HttpMethod.GET), any(HttpEntity.class), + eq(Object.class)); + verify(restTemplate, times(1)) + .exchange(eq(serviceTwo + "/" + path), eq(HttpMethod.GET), any(HttpEntity.class), + eq(Object.class)); + verify(restTemplate, times(1)) + .exchange(eq(serviceThree + "/" + path), eq(HttpMethod.GET), any(HttpEntity.class), + eq(Object.class)); + } + @Test + public void testOneServerDown() { + ResponseEntity someEntity = mock(ResponseEntity.class); + when(someEntity.getBody()).thenReturn(result); + when(serviceAddressLocator.getServiceList(any())) + .thenReturn(Arrays + .asList(mockService(serviceOne), mockService(serviceTwo), mockService(serviceThree))); + when(restTemplate + .exchange(eq(serviceOne + "/" + path), eq(HttpMethod.GET), any(HttpEntity.class), + eq(Object.class))).thenThrow(socketTimeoutException); + when(restTemplate + .exchange(eq(serviceTwo + "/" + path), eq(HttpMethod.GET), any(HttpEntity.class), + eq(Object.class))).thenReturn(someEntity); + when(restTemplate + .exchange(eq(serviceThree + "/" + path), eq(HttpMethod.GET), any(HttpEntity.class), + eq(Object.class))).thenThrow(connectTimeoutException); + + Object actualResult = retryableRestTemplate.get(Env.DEV, path, Object.class); + + verify(restTemplate, times(1)) + .exchange(eq(serviceOne + "/" + path), eq(HttpMethod.GET), any(HttpEntity.class), + eq(Object.class)); + verify(restTemplate, times(1)) + .exchange(eq(serviceTwo + "/" + path), eq(HttpMethod.GET), any(HttpEntity.class), + eq(Object.class)); + verify(restTemplate, never()) + .exchange(eq(serviceThree + "/" + path), eq(HttpMethod.GET), any(HttpEntity.class), + eq(Object.class)); + assertEquals(result, actualResult); } @Test - public void testOneServerDown() { + public void testPostSocketTimeoutNotRetry() { + ResponseEntity someEntity = mock(ResponseEntity.class); + when(someEntity.getBody()).thenReturn(result); + + when(serviceAddressLocator.getServiceList(any())) + .thenReturn(Arrays + .asList(mockService(serviceOne), mockService(serviceTwo), mockService(serviceThree))); + when(restTemplate + .exchange(eq(serviceOne + "/" + path), eq(HttpMethod.POST), any(HttpEntity.class), + eq(Object.class))).thenThrow(socketTimeoutException); + when(restTemplate + .exchange(eq(serviceTwo + "/" + path), eq(HttpMethod.POST), any(HttpEntity.class), + eq(Object.class))).thenReturn(someEntity); + + Throwable exception = null; + Object actualResult = null; + try { + actualResult = retryableRestTemplate.post(Env.DEV, path, request, Object.class); + } catch (Throwable ex) { + exception = ex; + } + + assertNull(actualResult); + assertSame(socketTimeoutException, exception); + verify(restTemplate, times(1)) + .exchange(eq(serviceOne + "/" + path), eq(HttpMethod.POST), any(HttpEntity.class), + eq(Object.class)); + verify(restTemplate, never()) + .exchange(eq(serviceTwo + "/" + path), eq(HttpMethod.POST), any(HttpEntity.class), + eq(Object.class)); + } + + @Test + public void testDelete() { + ResponseEntity someEntity = mock(ResponseEntity.class); - Object result = new Object(); when(serviceAddressLocator.getServiceList(any())) - .thenReturn(Arrays.asList(mockService(serviceOne), mockService(serviceTwo), mockService(serviceThree))); - when(restTemplate.getForObject(serviceOne + "/" + path, Object.class)).thenThrow(socketTimeoutException); - when(restTemplate.getForObject(serviceTwo + "/" + path, Object.class)).thenReturn(result); - when(restTemplate.getForObject(serviceThree + "/" + path, Object.class)).thenThrow(connectTimeoutException); + .thenReturn(Arrays + .asList(mockService(serviceOne), mockService(serviceTwo), mockService(serviceThree))); + when(restTemplate + .exchange(eq(serviceOne + "/" + path), eq(HttpMethod.DELETE), any(HttpEntity.class), + (Class) isNull())).thenReturn(someEntity); - Object o = retryableRestTemplate.get(Env.DEV, path, Object.class); + retryableRestTemplate.delete(Env.DEV, path); - verify(restTemplate).getForObject(serviceOne + "/" + path, Object.class); - verify(restTemplate).getForObject(serviceTwo + "/" + path, Object.class); - verify(restTemplate, times(0)).getForObject(serviceThree + "/" + path, Object.class); - Assert.assertEquals(result, o); + verify(restTemplate) + .exchange(eq(serviceOne + "/" + path), eq(HttpMethod.DELETE), any(HttpEntity.class), + (Class) isNull()); } - @Test(expected = ResourceAccessException.class) - public void testPostSocketTimeoutNotRetry(){ + @Test + public void testPut() { + ResponseEntity someEntity = mock(ResponseEntity.class); + when(serviceAddressLocator.getServiceList(any())) - .thenReturn(Arrays.asList(mockService(serviceOne), mockService(serviceTwo), mockService(serviceThree))); + .thenReturn(Arrays + .asList(mockService(serviceOne), mockService(serviceTwo), mockService(serviceThree))); + when(restTemplate + .exchange(eq(serviceOne + "/" + path), eq(HttpMethod.PUT), any(HttpEntity.class), + (Class) isNull())).thenReturn(someEntity); + + retryableRestTemplate.put(Env.DEV, path, request); + + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(HttpEntity.class); + verify(restTemplate) + .exchange(eq(serviceOne + "/" + path), eq(HttpMethod.PUT), argumentCaptor.capture(), + (Class) isNull()); + + assertEquals(request, argumentCaptor.getValue().getBody()); + } + + @Test + public void testPostObjectWithNoAccessToken() { + Env someEnv = Env.DEV; + ResponseEntity someEntity = mock(ResponseEntity.class); + + when(serviceAddressLocator.getServiceList(someEnv)) + .thenReturn(Collections.singletonList(mockService(serviceOne))); + when(restTemplate + .exchange(eq(serviceOne + "/" + path), eq(HttpMethod.POST), any(HttpEntity.class), + eq(requestType))).thenReturn(someEntity); + when(someEntity.getBody()).thenReturn(result); - when(restTemplate.postForEntity(serviceOne + "/" + path, request, Object.class)).thenThrow(socketTimeoutException); - when(restTemplate.postForEntity(serviceTwo + "/" + path, request, Object.class)).thenReturn(entity); + Object actualResult = retryableRestTemplate.post(someEnv, path, request, requestType); - retryableRestTemplate.post(Env.DEV, path, request, Object.class); + assertEquals(result, actualResult); - verify(restTemplate).postForEntity(serviceOne + "/" + path, request, Object.class); - verify(restTemplate, times(0)).postForEntity(serviceTwo + "/" + path, request, Object.class); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(HttpEntity.class); + verify(restTemplate, times(1)) + .exchange(eq(serviceOne + "/" + path), eq(HttpMethod.POST), argumentCaptor.capture(), + eq(requestType)); + + HttpEntity entity = argumentCaptor.getValue(); + HttpHeaders headers = entity.getHeaders(); + + assertSame(request, entity.getBody()); + assertTrue(headers.isEmpty()); } + @Test + public void testPostObjectWithAccessToken() { + Env someEnv = Env.DEV; + String someToken = "someToken"; + ResponseEntity someEntity = mock(ResponseEntity.class); + + when(portalConfig.getAdminServiceAccessTokens()) + .thenReturn(mockAdminServiceTokens(someEnv, someToken)); + when(serviceAddressLocator.getServiceList(someEnv)) + .thenReturn(Collections.singletonList(mockService(serviceOne))); + when(restTemplate + .exchange(eq(serviceOne + "/" + path), eq(HttpMethod.POST), any(HttpEntity.class), + eq(requestType))).thenReturn(someEntity); + when(someEntity.getBody()).thenReturn(result); + + Object actualResult = retryableRestTemplate.post(someEnv, path, request, requestType); + + assertEquals(result, actualResult); + + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(HttpEntity.class); + verify(restTemplate, times(1)) + .exchange(eq(serviceOne + "/" + path), eq(HttpMethod.POST), argumentCaptor.capture(), + eq(requestType)); + + HttpEntity entity = argumentCaptor.getValue(); + HttpHeaders headers = entity.getHeaders(); + List headerValue = headers.get(HttpHeaders.AUTHORIZATION); + + assertSame(request, entity.getBody()); + assertEquals(1, headers.size()); + assertEquals(1, headerValue.size()); + assertEquals(someToken, headerValue.get(0)); + } @Test - public void testDelete(){ - when(serviceAddressLocator.getServiceList(any())) - .thenReturn(Arrays.asList(mockService(serviceOne), mockService(serviceTwo), mockService(serviceThree))); + public void testPostObjectWithNoAccessTokenForEnv() { + Env someEnv = Env.DEV; + Env anotherEnv = Env.PRO; + String someToken = "someToken"; + ResponseEntity someEntity = mock(ResponseEntity.class); + + when(portalConfig.getAdminServiceAccessTokens()) + .thenReturn(mockAdminServiceTokens(someEnv, someToken)); + when(serviceAddressLocator.getServiceList(someEnv)) + .thenReturn(Collections.singletonList(mockService(serviceOne))); + when(serviceAddressLocator.getServiceList(anotherEnv)) + .thenReturn(Collections.singletonList(mockService(serviceTwo))); + when(restTemplate + .exchange(eq(serviceTwo + "/" + path), eq(HttpMethod.POST), any(HttpEntity.class), + eq(requestType))).thenReturn(someEntity); + when(someEntity.getBody()).thenReturn(result); + + Object actualResult = retryableRestTemplate.post(anotherEnv, path, request, requestType); + + assertEquals(result, actualResult); + + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(HttpEntity.class); + verify(restTemplate, times(1)) + .exchange(eq(serviceTwo + "/" + path), eq(HttpMethod.POST), argumentCaptor.capture(), + eq(requestType)); + + HttpEntity entity = argumentCaptor.getValue(); + HttpHeaders headers = entity.getHeaders(); + + assertSame(request, entity.getBody()); + assertTrue(headers.isEmpty()); + } - retryableRestTemplate.delete(Env.DEV, path); + @Test + public void testPostEntityWithNoAccessToken() { + Env someEnv = Env.DEV; + String originalHeader = "someHeader"; + String originalValue = "someValue"; + HttpHeaders originalHeaders = new HttpHeaders(); + originalHeaders.add(originalHeader, originalValue); + HttpEntity requestEntity = new HttpEntity<>(request, originalHeaders); + ResponseEntity someEntity = mock(ResponseEntity.class); + + when(serviceAddressLocator.getServiceList(someEnv)) + .thenReturn(Collections.singletonList(mockService(serviceOne))); + when(restTemplate + .exchange(eq(serviceOne + "/" + path), eq(HttpMethod.POST), any(HttpEntity.class), + eq(requestType))).thenReturn(someEntity); + when(someEntity.getBody()).thenReturn(result); + + Object actualResult = retryableRestTemplate.post(someEnv, path, requestEntity, requestType); + + assertEquals(result, actualResult); + + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(HttpEntity.class); + verify(restTemplate, times(1)) + .exchange(eq(serviceOne + "/" + path), eq(HttpMethod.POST), argumentCaptor.capture(), + eq(requestType)); + + HttpEntity entity = argumentCaptor.getValue(); + + assertSame(requestEntity, entity); + assertSame(request, entity.getBody()); + assertEquals(originalHeaders, entity.getHeaders()); + } + + @Test + public void testPostEntityWithAccessToken() { + Env someEnv = Env.DEV; + String someToken = "someToken"; + String originalHeader = "someHeader"; + String originalValue = "someValue"; + HttpHeaders originalHeaders = new HttpHeaders(); + originalHeaders.add(originalHeader, originalValue); + HttpEntity requestEntity = new HttpEntity<>(request, originalHeaders); + ResponseEntity someEntity = mock(ResponseEntity.class); + + when(portalConfig.getAdminServiceAccessTokens()) + .thenReturn(mockAdminServiceTokens(someEnv, someToken)); + when(serviceAddressLocator.getServiceList(someEnv)) + .thenReturn(Collections.singletonList(mockService(serviceOne))); + when(restTemplate + .exchange(eq(serviceOne + "/" + path), eq(HttpMethod.POST), any(HttpEntity.class), + eq(requestType))).thenReturn(someEntity); + when(someEntity.getBody()).thenReturn(result); + + Object actualResult = retryableRestTemplate.post(someEnv, path, requestEntity, requestType); + + assertEquals(result, actualResult); + + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(HttpEntity.class); + verify(restTemplate, times(1)) + .exchange(eq(serviceOne + "/" + path), eq(HttpMethod.POST), argumentCaptor.capture(), + eq(requestType)); + + HttpEntity entity = argumentCaptor.getValue(); + HttpHeaders headers = entity.getHeaders(); + + assertSame(request, entity.getBody()); + assertEquals(2, headers.size()); + assertEquals(originalValue, headers.get(originalHeader).get(0)); + assertEquals(someToken, headers.get(HttpHeaders.AUTHORIZATION).get(0)); + } - verify(restTemplate).delete(serviceOne + "/" + path); + @Test + public void testGetEntityWithNoAccessToken() { + Env someEnv = Env.DEV; + ParameterizedTypeReference requestType = mock(ParameterizedTypeReference.class); + ResponseEntity someEntity = mock(ResponseEntity.class); + + when(serviceAddressLocator.getServiceList(someEnv)) + .thenReturn(Collections.singletonList(mockService(serviceOne))); + when(restTemplate + .exchange(eq(serviceOne + "/" + path), eq(HttpMethod.GET), any(HttpEntity.class), + eq(requestType))).thenReturn(someEntity); + + ResponseEntity actualResult = retryableRestTemplate.get(someEnv, path, requestType); + + assertEquals(someEntity, actualResult); + + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(HttpEntity.class); + verify(restTemplate, times(1)) + .exchange(eq(serviceOne + "/" + path), eq(HttpMethod.GET), argumentCaptor.capture(), + eq(requestType)); + + HttpHeaders headers = argumentCaptor.getValue().getHeaders(); + assertTrue(headers.isEmpty()); } @Test - public void testPut(){ - when(serviceAddressLocator.getServiceList(any())) - .thenReturn(Arrays.asList(mockService(serviceOne), mockService(serviceTwo), mockService(serviceThree))); + public void testGetEntityWithAccessToken() { + Env someEnv = Env.DEV; + String someToken = "someToken"; + ParameterizedTypeReference requestType = mock(ParameterizedTypeReference.class); + ResponseEntity someEntity = mock(ResponseEntity.class); + + when(portalConfig.getAdminServiceAccessTokens()) + .thenReturn(mockAdminServiceTokens(someEnv, someToken)); + when(serviceAddressLocator.getServiceList(someEnv)) + .thenReturn(Collections.singletonList(mockService(serviceOne))); + when(restTemplate + .exchange(eq(serviceOne + "/" + path), eq(HttpMethod.GET), any(HttpEntity.class), + eq(requestType))).thenReturn(someEntity); + + ResponseEntity actualResult = retryableRestTemplate.get(someEnv, path, requestType); + + assertEquals(someEntity, actualResult); + + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(HttpEntity.class); + verify(restTemplate, times(1)) + .exchange(eq(serviceOne + "/" + path), eq(HttpMethod.GET), argumentCaptor.capture(), + eq(requestType)); + + HttpHeaders headers = argumentCaptor.getValue().getHeaders(); + List headerValue = headers.get(HttpHeaders.AUTHORIZATION); + + assertEquals(1, headers.size()); + assertEquals(1, headerValue.size()); + assertEquals(someToken, headerValue.get(0)); + } - retryableRestTemplate.put(Env.DEV, path, request); + @Test + public void testGetEntityWithNoAccessTokenForEnv() { + Env someEnv = Env.DEV; + Env anotherEnv = Env.PRO; + String someToken = "someToken"; + ParameterizedTypeReference requestType = mock(ParameterizedTypeReference.class); + ResponseEntity someEntity = mock(ResponseEntity.class); + + when(portalConfig.getAdminServiceAccessTokens()) + .thenReturn(mockAdminServiceTokens(someEnv, someToken)); + when(serviceAddressLocator.getServiceList(someEnv)) + .thenReturn(Collections.singletonList(mockService(serviceOne))); + when(serviceAddressLocator.getServiceList(anotherEnv)) + .thenReturn(Collections.singletonList(mockService(serviceTwo))); + when(restTemplate + .exchange(eq(serviceTwo + "/" + path), eq(HttpMethod.GET), any(HttpEntity.class), + eq(requestType))).thenReturn(someEntity); + + ResponseEntity actualResult = retryableRestTemplate.get(anotherEnv, path, requestType); + + assertEquals(someEntity, actualResult); + + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(HttpEntity.class); + verify(restTemplate, times(1)) + .exchange(eq(serviceTwo + "/" + path), eq(HttpMethod.GET), argumentCaptor.capture(), + eq(requestType)); + + HttpHeaders headers = argumentCaptor.getValue().getHeaders(); + + assertTrue(headers.isEmpty()); + } + + private String mockAdminServiceTokens(Env env, String token) { + Map tokenMap = Maps.newHashMap(); + tokenMap.put(env.getName(), token); - verify(restTemplate).put(serviceOne + "/" + path, request); + return gson.toJson(tokenMap); } private ServiceDTO mockService(String homeUrl) {