Skip to content

Commit 8121532

Browse files
Addign cors headers unit tests
1 parent 2e8601f commit 8121532

File tree

5 files changed

+361
-37
lines changed

5 files changed

+361
-37
lines changed

server/server-core/build.gradle.kts

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,4 +16,5 @@ dependencies {
1616
implementation(libs.smithy.model)
1717
implementation(project(":io"))
1818
implementation(project(":logging"))
19+
implementation(libs.netty.all)
1920
}

server/server-core/src/main/java/software/amazon/smithy/java/server/core/CorsHeaders.java

Lines changed: 28 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -7,61 +7,59 @@
77

88
import static software.amazon.smithy.java.core.schema.TraitKey.CORS_TRAIT;
99

10-
import java.util.*;
11-
import software.amazon.smithy.java.logging.InternalLogger;
12-
import software.amazon.smithy.model.traits.CorsTrait;
10+
import io.netty.handler.codec.http.HttpHeaders;
11+
import java.util.Arrays;
12+
import java.util.List;
13+
import java.util.Map;
1314

1415
public final class CorsHeaders {
15-
private static final InternalLogger LOGGER = InternalLogger.getLogger(CorsHeaders.class);
1616

17-
private CorsHeaders() {} // Prevent instantiation
17+
private CorsHeaders() {}
1818

19-
public static Map<String, List<String>> of(HttpJob job) {
19+
private static final Map<String, Iterable<?>> BASE_CORS_HEADERS = Map.of(
20+
"Access-Control-Allow-Methods",
21+
List.of("GET, POST, PUT, DELETE, OPTIONS"),
22+
"Access-Control-Allow-Headers",
23+
List.of("*,Access-Control-Allow-Headers,Access-Control-Allow-Methods,Access-Control-Allow-Origin,Amz-Sdk-Invocation-Id,Amz-Sdk-Request,Authorization,Content-Length,Content-Type,X-Amz-User-Agent,X-Amzn-Trace-Id"),
24+
"Access-Control-Max-Age",
25+
List.of("600"));
26+
27+
public static void of(HttpJob job, HttpHeaders headers) {
2028
if (!shouldAddCorsHeaders(job)) {
21-
return Map.of();
29+
return;
2230
}
2331

2432
String requestOrigin = job.request().headers().firstValue("origin");
25-
Optional<String> configuredOrigin = getConfiguredOrigin(job);
33+
String configuredOrigin = getConfiguredOrigin(job);
2634

27-
if (!isOriginAllowed(configuredOrigin.orElse(null), requestOrigin)) {
28-
return Map.of();
35+
if (!isOriginAllowed(configuredOrigin, requestOrigin)) {
36+
return;
2937
}
3038

31-
Map<String, List<String>> headers = new HashMap<>();
32-
headers.put("Access-Control-Allow-Origin", List.of(requestOrigin));
33-
headers.put("Access-Control-Allow-Methods", List.of("GET, POST, PUT, DELETE, OPTIONS"));
34-
headers.put("Access-Control-Allow-Headers",
35-
List.of("*,Access-Control-Allow-Headers,Access-Control-Allow-Methods,Access-Control-Allow-Origin,Amz-Sdk-Invocation-Id,Amz-Sdk-Request,Authorization,Content-Length,Content-Type,X-Amz-User-Agent,X-Amzn-Trace-Id"));
36-
headers.put("Access-Control-Max-Age", List.of("600"));
39+
for (Map.Entry<String, Iterable<?>> entrySet : BASE_CORS_HEADERS.entrySet()) {
40+
headers.set(entrySet.getKey(), entrySet.getValue());
41+
}
3742

38-
return headers;
43+
headers.set("Access-Control-Allow-Origin", List.of(requestOrigin));
3944
}
4045

4146
private static boolean shouldAddCorsHeaders(HttpJob job) {
42-
if (job == null || job.operation() == null
43-
|| job.operation().getApiOperation() == null
44-
||
45-
job.operation().getApiOperation().service() == null
46-
||
47+
if (job.operation().getApiOperation().service() == null ||
4748
job.operation().getApiOperation().service().schema() == null
4849
||
4950
!job.operation().getApiOperation().service().schema().hasTrait(CORS_TRAIT)) {
5051
return false;
5152
}
52-
53-
return job.request() != null
54-
&& job.request().headers() != null
55-
&& job.request().headers().hasHeader("origin");
53+
return job.request().headers().hasHeader("origin");
5654
}
5755

58-
private static Optional<String> getConfiguredOrigin(HttpJob job) {
59-
return Optional.ofNullable(job.operation()
56+
private static String getConfiguredOrigin(HttpJob job) {
57+
return job.operation()
6058
.getApiOperation()
6159
.service()
6260
.schema()
63-
.getTrait(CORS_TRAIT))
64-
.map(CorsTrait::getOrigin);
61+
.getTrait(CORS_TRAIT)
62+
.getOrigin();
6563
}
6664

6765
private static boolean isOriginAllowed(String configuredOrigin, String requestOrigin) {
Lines changed: 127 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,127 @@
1+
/*
2+
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
3+
* SPDX-License-Identifier: Apache-2.0
4+
*/
5+
6+
package software.amazon.smithy.java.server.core;
7+
8+
import static org.junit.jupiter.api.Assertions.assertFalse;
9+
import static org.junit.jupiter.api.Assertions.assertTrue;
10+
11+
import io.netty.handler.codec.http.DefaultHttpHeaders;
12+
import java.net.URI;
13+
import java.net.URISyntaxException;
14+
import java.util.List;
15+
import java.util.Map;
16+
import java.util.stream.Stream;
17+
import org.junit.jupiter.params.ParameterizedTest;
18+
import org.junit.jupiter.params.provider.MethodSource;
19+
import software.amazon.smithy.java.core.schema.ApiService;
20+
import software.amazon.smithy.java.core.schema.Schema;
21+
import software.amazon.smithy.java.http.api.HttpHeaders;
22+
import software.amazon.smithy.java.server.Operation;
23+
import software.amazon.smithy.model.traits.CorsTrait;
24+
25+
public class CorsHeadersTest {
26+
27+
record TestCase(
28+
String name,
29+
String configuredOrigin,
30+
String requestOrigin,
31+
boolean shouldHaveHeaders,
32+
String expectedAllowOrigin) {}
33+
34+
private static Stream<TestCase> corsTestCases() {
35+
return Stream.of(
36+
new TestCase(
37+
"No request origin header",
38+
"http://allowed-origin.com",
39+
null,
40+
false,
41+
null),
42+
new TestCase(
43+
"Not allowed origin",
44+
"http://allowed-origin.com",
45+
"http://not-allowed-origin.com",
46+
false,
47+
null),
48+
new TestCase(
49+
"Multiple allowed origins",
50+
"http://allowed-origin.com,http://other-allowed-origin.com",
51+
"http://allowed-origin.com",
52+
true,
53+
"http://allowed-origin.com"),
54+
new TestCase(
55+
"Wildcard origin",
56+
"*",
57+
"http://any-origin.com",
58+
true,
59+
"http://any-origin.com"),
60+
new TestCase(
61+
"Exact match origin",
62+
"http://allowed-origin.com",
63+
"http://allowed-origin.com",
64+
true,
65+
"http://allowed-origin.com"));
66+
}
67+
68+
@ParameterizedTest(name = "{0}")
69+
@MethodSource("corsTestCases")
70+
void testCorsHeaders(TestCase testCase) throws URISyntaxException {
71+
// Create operation with configured CORS origin
72+
Operation testOperation = Operation.of(
73+
"TestOperation",
74+
(input, context) -> new TestStructs.TestOutput(),
75+
new TestStructs.TestApiOperation(new ApiService() {
76+
@Override
77+
public Schema schema() {
78+
return Schema.structureBuilder(
79+
CorsTrait.ID,
80+
CorsTrait.builder().origin(testCase.configuredOrigin).build()).build();
81+
}
82+
}),
83+
new TestStructs.TestService());
84+
85+
// Create request headers
86+
Map<String, List<String>> headerMap =
87+
testCase.requestOrigin != null ? Map.of("Origin", List.of(testCase.requestOrigin)) : Map.of();
88+
HttpHeaders requestHeaders = HttpHeaders.of(headerMap);
89+
90+
// Create request
91+
HttpRequest request = new HttpRequest(
92+
requestHeaders,
93+
new URI("http://test.com"),
94+
"GET");
95+
96+
// Create response and job
97+
HttpResponse response = new HttpResponse(new TestStructs.TestModifiableHttpHeaders());
98+
ServerProtocol protocol = new TestStructs.TestServerProtocol(List.of());
99+
HttpJob job = new HttpJob(testOperation, protocol, request, response);
100+
101+
// Apply CORS headers
102+
var responseHeaders = new DefaultHttpHeaders();
103+
CorsHeaders.of(job, responseHeaders);
104+
105+
// Verify headers
106+
if (testCase.shouldHaveHeaders) {
107+
assertContainsHeader(responseHeaders, testCase.expectedAllowOrigin);
108+
} else {
109+
assertDoNotContainsHeader(responseHeaders);
110+
}
111+
}
112+
113+
private void assertContainsHeader(io.netty.handler.codec.http.HttpHeaders responseHeaders, String allowedOrigin) {
114+
assertTrue(responseHeaders.contains("Access-Control-Allow-Methods"));
115+
assertTrue(responseHeaders.contains("Access-Control-Allow-Headers"));
116+
assertTrue(responseHeaders.contains("Access-Control-Max-Age"));
117+
assertTrue(responseHeaders.contains("Access-Control-Allow-Origin"));
118+
assertTrue(responseHeaders.get("Access-Control-Allow-Origin").contains(allowedOrigin));
119+
}
120+
121+
private void assertDoNotContainsHeader(io.netty.handler.codec.http.HttpHeaders responseHeaders) {
122+
assertFalse(responseHeaders.contains("Access-Control-Allow-Methods"));
123+
assertFalse(responseHeaders.contains("Access-Control-Allow-Headers"));
124+
assertFalse(responseHeaders.contains("Access-Control-Max-Age"));
125+
assertFalse(responseHeaders.contains("Access-Control-Allow-Origin"));
126+
}
127+
}

0 commit comments

Comments
 (0)