@@ -24,11 +24,13 @@ import (
2424 "testing"
2525 "time"
2626
27+ "github.com/google/fleetspeak/fleetspeak/src/client/comms"
2728 "github.com/google/fleetspeak/fleetspeak/src/client/stats"
29+ "golang.org/x/exp/slices"
2830)
2931
3032type testStatsCollector struct {
31- stats.CommunicatorCollector
33+ stats.Collector
3234 fetches atomic.Int64
3335}
3436
@@ -38,12 +40,33 @@ func (c *testStatsCollector) AfterGetFileRequest(_, _, _ string, didFetch bool,
3840 }
3941}
4042
41- func (* testStatsCollector ) OutboundContactData (string , int , error ) {}
43+ type testCommsContext struct {
44+ comms.Context
45+ stats stats.Collector
46+ clientLabels []string
47+ }
4248
43- func (* testStatsCollector ) InboundContactData (string , int , error ) {}
49+ func (c * testCommsContext ) Stats () stats.Collector {
50+ return c .stats
51+ }
4452
45- func createFakeServer (lastModified time.Time ) (* httptest.Server , []string ) {
53+ func (c * testCommsContext ) CurrentIdentity () (comms.ClientIdentity , error ) {
54+ return comms.ClientIdentity {Labels : c .clientLabels }, nil
55+ }
56+
57+ func (c * testCommsContext ) ServerInfo () (comms.ServerInfo , error ) {
58+ return comms.ServerInfo {}, nil
59+ }
60+
61+ func createFakeServer (lastModified time.Time , blockedLabels ... string ) (* httptest.Server , []string ) {
4662 fakeServer := httptest .NewTLSServer (http .HandlerFunc (func (w http.ResponseWriter , r * http.Request ) {
63+ for _ , label := range r .Header .Values ("X-Fleetspeak-Labels" ) {
64+ if slices .Contains (blockedLabels , label ) {
65+ http .Error (w , "unauthorized" , http .StatusUnauthorized )
66+ return
67+ }
68+ }
69+
4770 content := strings .NewReader ("test" )
4871 http .ServeContent (w , r , "test.txt" , lastModified , content )
4972 }))
@@ -52,12 +75,12 @@ func createFakeServer(lastModified time.Time) (*httptest.Server, []string) {
5275 return fakeServer , hosts
5376}
5477
55- func doRequest (t * testing.T , hosts []string , client * http.Client , lastModifiedOnClient time.Time , stats stats. CommunicatorCollector ) (string , time.Time ) {
78+ func doRequest (t * testing.T , cctx comms. Context , hosts []string , client * http.Client , lastModifiedOnClient time.Time ) (string , time.Time ) {
5679 t .Helper ()
57- ctx , cancel := context .WithTimeout (context . Background (), 2 * time .Second )
80+ ctx , cancel := context .WithTimeout (t . Context (), 2 * time .Second )
5881 defer cancel ()
5982
60- reader , modTime , err := getFileIfModified (ctx , hosts , client , "TestService" , "test.txt" , lastModifiedOnClient , stats )
83+ reader , modTime , err := getFileIfModified (ctx , cctx , nil , hosts , client , "TestService" , "test.txt" , lastModifiedOnClient )
6184 if err != nil {
6285 t .Fatalf ("getFileIfModified() failed: %v" , err )
6386 }
@@ -74,12 +97,13 @@ func doRequest(t *testing.T, hosts []string, client *http.Client, lastModifiedOn
7497
7598func TestGetFileIfModified (t * testing.T ) {
7699 stats := & testStatsCollector {}
100+ cctx := & testCommsContext {stats : stats }
77101 lastModifiedOnServer := time .Date (2024 , 1 , 1 , 0 , 0 , 0 , 0 , time .UTC )
78102 lastModifiedOnClient := lastModifiedOnServer .Add (- time .Hour )
79103 fakeServer , hosts := createFakeServer (lastModifiedOnServer )
80104 defer fakeServer .Close ()
81105
82- body , modTime := doRequest (t , hosts , fakeServer .Client (), lastModifiedOnClient , stats )
106+ body , modTime := doRequest (t , cctx , hosts , fakeServer .Client (), lastModifiedOnClient )
83107 if ! modTime .Equal (lastModifiedOnServer ) {
84108 t .Errorf ("Unexpected modTime, got: %v, want: %v" , modTime , lastModifiedOnServer )
85109 }
@@ -94,12 +118,13 @@ func TestGetFileIfModified(t *testing.T) {
94118
95119func TestGetFileIfNotModified (t * testing.T ) {
96120 stats := & testStatsCollector {}
121+ cctx := & testCommsContext {stats : stats }
97122 lastModifiedOnServer := time .Date (2024 , 1 , 1 , 0 , 0 , 0 , 0 , time .UTC )
98123 lastModifiedOnClient := lastModifiedOnServer
99124 fakeServer , hosts := createFakeServer (lastModifiedOnServer )
100125 defer fakeServer .Close ()
101126
102- body , _ := doRequest (t , hosts , fakeServer .Client (), lastModifiedOnClient , stats )
127+ body , _ := doRequest (t , cctx , hosts , fakeServer .Client (), lastModifiedOnClient )
103128 if want := "" ; body != want {
104129 t .Errorf ("Unexpected response body, got: %q, want: %q" , body , want )
105130 }
@@ -110,6 +135,7 @@ func TestGetFileIfNotModified(t *testing.T) {
110135}
111136
112137func TestGetFileUnreachableHost (t * testing.T ) {
138+ cctx := & testCommsContext {stats : & stats.NoopCollector {}}
113139 lastModifiedOnServer := time .Date (2024 , 1 , 1 , 0 , 0 , 0 , 0 , time .UTC )
114140 lastModifiedOnClient := lastModifiedOnServer .Add (- time .Hour )
115141 fakeServer , hosts := createFakeServer (lastModifiedOnServer )
@@ -119,11 +145,28 @@ func TestGetFileUnreachableHost(t *testing.T) {
119145 // should still succeed by trying the next one in the list.
120146 hosts = append ([]string {"unreachable_host" }, hosts ... )
121147
122- body , modTime := doRequest (t , hosts , fakeServer .Client (), lastModifiedOnClient , stats. NoopCollector {} )
148+ body , modTime := doRequest (t , cctx , hosts , fakeServer .Client (), lastModifiedOnClient )
123149 if ! modTime .Equal (lastModifiedOnServer ) {
124150 t .Errorf ("Unexpected modTime, got: %v, want: %v" , modTime , lastModifiedOnServer )
125151 }
126152 if want := "test" ; body != want {
127153 t .Errorf ("Unexpected body, got: %q, want: %q" , body , want )
128154 }
129155}
156+
157+ func TestGetFileUnauthorizedClient (t * testing.T ) {
158+ stats := & testStatsCollector {}
159+ cctx := & testCommsContext {stats : stats , clientLabels : []string {"label1" }}
160+ lastModifiedOnServer := time .Date (2024 , 1 , 1 , 0 , 0 , 0 , 0 , time .UTC )
161+ lastModifiedOnClient := lastModifiedOnServer .Add (- time .Hour )
162+ fakeServer , hosts := createFakeServer (lastModifiedOnServer , "label1" )
163+ defer fakeServer .Close ()
164+
165+ ctx , cancel := context .WithTimeout (t .Context (), 2 * time .Second )
166+ defer cancel ()
167+ _ , _ , err := getFileIfModified (ctx , cctx , nil , hosts , fakeServer .Client (), "TestService" , "test.txt" , lastModifiedOnClient )
168+
169+ if err == nil {
170+ t .Errorf ("getFileIfModified() succeeded, want error" )
171+ }
172+ }
0 commit comments