diff --git a/messaging/messaging_batch.go b/messaging/messaging_batch.go index 27680b6f..365190a4 100644 --- a/messaging/messaging_batch.go +++ b/messaging/messaging_batch.go @@ -27,7 +27,6 @@ import ( "mime/multipart" "net/http" "net/textproto" - "sync" "firebase.google.com/go/v4/internal" ) @@ -165,53 +164,85 @@ func (c *fcmClient) sendEachInBatch(ctx context.Context, messages []*Message, dr return nil, fmt.Errorf("messages must not contain more than %d elements", maxMessages) } - var responses []*SendResponse = make([]*SendResponse, len(messages)) - var wg sync.WaitGroup - for idx, m := range messages { if err := validateMessage(m); err != nil { return nil, fmt.Errorf("invalid message at index %d: %v", idx, err) } - wg.Add(1) - go func(idx int, m *Message, dryRun bool, responses []*SendResponse) { - defer wg.Done() - var resp string - var err error - if dryRun { - resp, err = c.SendDryRun(ctx, m) - } else { - resp, err = c.Send(ctx, m) - } - if err == nil { - responses[idx] = &SendResponse{ - Success: true, - MessageID: resp, - } - } else { - responses[idx] = &SendResponse{ - Success: false, - Error: err, - } - } - }(idx, m, dryRun, responses) } - // Wait for all SendDryRun/Send calls to finish - wg.Wait() + + const numWorkers = 50 + jobs := make(chan job, len(messages)) + results := make(chan result, len(messages)) + + responses := make([]*SendResponse, len(messages)) + + for w := 0; w < numWorkers; w++ { + go worker(ctx, c, dryRun, jobs, results) + } + + for idx, m := range messages { + jobs <- job{message: m, index: idx} + } + close(jobs) + + for i := 0; i < len(messages); i++ { + res := <-results + responses[res.index] = res.response + } successCount := 0 + failureCount := 0 for _, r := range responses { if r.Success { successCount++ + } else { + failureCount++ } } return &BatchResponse{ Responses: responses, SuccessCount: successCount, - FailureCount: len(responses) - successCount, + FailureCount: failureCount, }, nil } +type job struct { + message *Message + index int +} + +type result struct { + response *SendResponse + index int +} + +func worker(ctx context.Context, c *fcmClient, dryRun bool, jobs <-chan job, results chan<- result) { + for j := range jobs { + var respMsg string + var err error + if dryRun { + respMsg, err = c.SendDryRun(ctx, j.message) + } else { + respMsg, err = c.Send(ctx, j.message) + } + + var sr *SendResponse + if err == nil { + sr = &SendResponse{ + Success: true, + MessageID: respMsg, + } + } else { + sr = &SendResponse{ + Success: false, + Error: err, + } + } + results <- result{response: sr, index: j.index} + } +} + // SendAll sends the messages in the given array via Firebase Cloud Messaging. // // The messages array may contain up to 500 messages. SendAll employs batching to send the entire diff --git a/messaging/messaging_batch_test.go b/messaging/messaging_batch_test.go index a13ca54b..e8603eae 100644 --- a/messaging/messaging_batch_test.go +++ b/messaging/messaging_batch_test.go @@ -27,25 +27,19 @@ import ( "net/http/httptest" "net/textproto" "strings" + "sync" "testing" "google.golang.org/api/option" ) -var testMessages = []*Message{ - {Topic: "topic1"}, - {Topic: "topic2"}, -} +var testMessages = []*Message{{Topic: "topic1"}, {Topic: "topic2"}} var testMulticastMessage = &MulticastMessage{ Tokens: []string{"token1", "token2"}, } var testSuccessResponse = []fcmResponse{ - { - Name: "projects/test-project/messages/1", - }, - { - Name: "projects/test-project/messages/2", - }, + {Name: "projects/test-project/messages/1"}, + {Name: "projects/test-project/messages/2"}, } const wantMime = "multipart/mixed; boundary=__END_OF_PART__" @@ -53,13 +47,11 @@ const wantSendURL = "/v1/projects/test-project/messages:send" func TestMultipartEntitySingle(t *testing.T) { entity := &multipartEntity{ - parts: []*part{ - { - method: "POST", - url: "http://example.com", - body: map[string]interface{}{"key": "value"}, - }, - }, + parts: []*part{{ + method: "POST", + url: "http://example.com", + body: map[string]interface{}{"key": "value"}, + }}, } mime := entity.Mime() @@ -90,6 +82,265 @@ func TestMultipartEntitySingle(t *testing.T) { } } +func TestSendEachWorkerPoolScenarios(t *testing.T) { + scenarios := []struct { + name string + numMessages int + allSuccessful bool + testNameSuffix string // To make test names more descriptive if needed + }{ + {numMessages: 5, allSuccessful: true, testNameSuffix: " (5msg < 50workers)"}, + {numMessages: 50, allSuccessful: true, testNameSuffix: " (50msg == 50workers)"}, + {numMessages: 75, allSuccessful: true, testNameSuffix: " (75msg > 50workers)"}, + {numMessages: 75, allSuccessful: false, testNameSuffix: " (75msg > 50workers, with Failures)"}, + } + + for _, s := range scenarios { + scenarioName := fmt.Sprintf("NumMessages_%d_AllSuccess_%v%s", s.numMessages, s.allSuccessful, s.testNameSuffix) + t.Run(scenarioName, func(t *testing.T) { + ctx := context.Background() + client, err := NewClient(ctx, testMessagingConfig) + if err != nil { + t.Fatal(err) + } + + messages := make([]*Message, s.numMessages) + expectedSuccessCount := s.numMessages + expectedFailureCount := 0 + + serverHitCount := 0 + mu := &sync.Mutex{} // To protect serverHitCount + + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + mu.Lock() + serverHitCount++ + mu.Unlock() + + var reqBody fcmRequest + if err := json.NewDecoder(r.Body).Decode(&reqBody); err != nil { + w.WriteHeader(http.StatusBadRequest) + return + } + + var originalIndex int + if !s.allSuccessful { // Only parse index if we might fail based on it + topicParts := strings.Split(reqBody.Message.Topic, "topic") + if len(topicParts) == 2 { + fmt.Sscanf(topicParts[1], "%d", &originalIndex) + } else { + t.Logf("Unexpected topic format: %s", reqBody.Message.Topic) + w.WriteHeader(http.StatusOK) + json.NewEncoder(w).Encode(map[string]string{ + "name": fmt.Sprintf("projects/test-project/messages/%s-unexpected", reqBody.Message.Topic), + }) + return + } + } + + if !s.allSuccessful && (originalIndex+1)%3 == 0 { + w.WriteHeader(http.StatusInternalServerError) + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(map[string]interface{}{ + "error": map[string]interface{}{ + "message": fmt.Sprintf("Simulated server error for original index %d", originalIndex), + "status": "INTERNAL", + }, + }) + } else { + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(map[string]string{ + "name": fmt.Sprintf("projects/test-project/messages/%s-idx%d", reqBody.Message.Topic, originalIndex), + }) + } + })) + defer ts.Close() + client.fcmEndpoint = ts.URL + + for i := 0; i < s.numMessages; i++ { + messages[i] = &Message{Topic: fmt.Sprintf("topic%d", i)} + } + + if !s.allSuccessful { + expectedSuccessCount = 0 + expectedFailureCount = 0 + for i := 0; i < s.numMessages; i++ { + if (i+1)%3 == 0 { + expectedFailureCount++ + } else { + expectedSuccessCount++ + } + } + } + + br, err := client.SendEach(ctx, messages) + if err != nil { + t.Fatalf("SendEach() unexpected error: %v", err) + } + + if br.SuccessCount != expectedSuccessCount { + t.Errorf("SuccessCount = %d; want = %d", br.SuccessCount, expectedSuccessCount) + } + if br.FailureCount != expectedFailureCount { + t.Errorf("FailureCount = %d; want = %d", br.FailureCount, expectedFailureCount) + } + if len(br.Responses) != s.numMessages { + t.Errorf("len(Responses) = %d; want = %d", len(br.Responses), s.numMessages) + } + mu.Lock() // Protect serverHitCount read + if serverHitCount != s.numMessages { + t.Errorf("Server hit count = %d; want = %d", serverHitCount, s.numMessages) + } + mu.Unlock() + + for i, resp := range br.Responses { + isExpectedToSucceed := s.allSuccessful || (i+1)%3 != 0 + if resp.Success != isExpectedToSucceed { + t.Errorf("Responses[%d].Success = %v; want = %v", i, resp.Success, isExpectedToSucceed) + } + if isExpectedToSucceed && resp.MessageID == "" { + t.Errorf("Responses[%d].MessageID is empty for a successful message", i) + } + if !isExpectedToSucceed && resp.Error == nil { + t.Errorf("Responses[%d].Error is nil for a failed message", i) + } + } + }) + } +} + +func TestSendEachResponseOrderWithConcurrency(t *testing.T) { + ctx := context.Background() + client, err := NewClient(ctx, testMessagingConfig) + if err != nil { + t.Fatal(err) + } + + numMessages := 75 // Ensure this is > new worker count of 50 + messages := make([]*Message, numMessages) + for i := 0; i < numMessages; i++ { + messages[i] = &Message{Token: fmt.Sprintf("token%d", i)} // Using Token for unique identification + } + + serverHitCount := 0 + messageIDLog := make(map[string]int) + var mu sync.Mutex + + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + mu.Lock() + serverHitCount++ + hitOrder := serverHitCount + mu.Unlock() + + var reqBody fcmRequest + if err := json.NewDecoder(r.Body).Decode(&reqBody); err != nil { + w.WriteHeader(http.StatusBadRequest) + return + } + messageIdentifier := reqBody.Message.Token + + mu.Lock() + messageIDLog[messageIdentifier] = hitOrder + mu.Unlock() + + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(map[string]string{ + "name": fmt.Sprintf("projects/test-project/messages/msg_for_%s", messageIdentifier), + }) + })) + defer ts.Close() + client.fcmEndpoint = ts.URL + + br, err := client.SendEach(ctx, messages) + if err != nil { + t.Fatalf("SendEach() unexpected error: %v", err) + } + + if br.SuccessCount != numMessages { + t.Errorf("SuccessCount = %d; want = %d", br.SuccessCount, numMessages) + } + if len(br.Responses) != numMessages { + t.Errorf("len(Responses) = %d; want = %d", len(br.Responses), numMessages) + } + + if serverHitCount != numMessages { + t.Errorf("Server hit count = %d; want = %d", serverHitCount, numMessages) + } + + for i, resp := range br.Responses { + if !resp.Success { + t.Errorf("Responses[%d] was not successful: %v", i, resp.Error) + continue + } + expectedMessageIDPart := fmt.Sprintf("msg_for_token%d", i) + if !strings.Contains(resp.MessageID, expectedMessageIDPart) { + t.Errorf("Responses[%d].MessageID = %q; want to contain %q", i, resp.MessageID, expectedMessageIDPart) + } + } +} + +func TestSendEachEarlyValidationSkipsSend(t *testing.T) { + ctx := context.Background() + client, err := NewClient(ctx, testMessagingConfig) + if err != nil { + t.Fatal(err) + } + + messagesWithInvalid := []*Message{{Topic: "topic1"}, nil, {Topic: "topic2"}} + + serverHitCount := 0 + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + serverHitCount++ + w.Header().Set("Content-Type", "application/json") + w.Write([]byte(`{ "name":"projects/test-project/messages/1" }`)) + })) + defer ts.Close() + client.fcmEndpoint = ts.URL + + br, err := client.SendEach(ctx, messagesWithInvalid) + if err == nil { + t.Errorf("SendEach() expected error for invalid message, got nil") + } + if br != nil { + t.Errorf("SendEach() expected nil BatchResponse for invalid message, got %v", br) + } + + if serverHitCount != 0 { + t.Errorf("Server hit count = %d; want = 0 due to early validation failure", serverHitCount) + } + + messagesWithInvalidFirst := []*Message{ + {Topic: "invalid", Condition: "invalid"}, // Invalid: both Topic and Condition + {Topic: "topic1"}, + } + serverHitCount = 0 + br, err = client.SendEach(ctx, messagesWithInvalidFirst) + if err == nil { + t.Errorf("SendEach() expected error for invalid first message, got nil") + } + if br != nil { + t.Errorf("SendEach() expected nil BatchResponse for invalid first message, got %v", br) + } + if serverHitCount != 0 { + t.Errorf("Server hit count = %d; want = 0 for invalid first message", serverHitCount) + } + + messagesWithInvalidLast := []*Message{ + {Topic: "topic1"}, // Valid first message + {Topic: "topic_last", Token: "token_last"}, // Invalid: cannot have both Topic and Token + } + serverHitCount = 0 + br, err = client.SendEach(ctx, messagesWithInvalidLast) + if err == nil { + t.Errorf("SendEach() expected error for invalid last message, got nil") + } + if br != nil { + t.Errorf("SendEach() expected nil BatchResponse for invalid last message, got %v", br) + } + if serverHitCount != 0 { + t.Errorf("Server hit count = %d; want = 0 for invalid last message", serverHitCount) + } +} + func TestMultipartEntity(t *testing.T) { entity := &multipartEntity{ parts: []*part{ @@ -97,8 +348,7 @@ func TestMultipartEntity(t *testing.T) { method: "POST", url: "http://example1.com", body: map[string]interface{}{"key1": "value"}, - }, - { + }, { method: "POST", url: "http://example2.com", body: map[string]interface{}{"key2": "value"}, @@ -150,13 +400,11 @@ func TestMultipartEntity(t *testing.T) { func TestMultipartEntityError(t *testing.T) { entity := &multipartEntity{ - parts: []*part{ - { - method: "POST", - url: "http://example.com", - body: func() {}, - }, - }, + parts: []*part{{ + method: "POST", + url: "http://example.com", + body: func() {}, + }}, } b, err := entity.Bytes() @@ -275,46 +523,54 @@ func TestSendEachDryRun(t *testing.T) { func TestSendEachPartialFailure(t *testing.T) { success := []fcmResponse{ - { - Name: "projects/test-project/messages/1", - }, + {Name: "projects/test-project/messages/1"}, } var failures []string - ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - req, _ := ioutil.ReadAll(r.Body) - - for idx, testMessage := range testMessages { - // Write success for topic1 and error for topic2 - if strings.Contains(string(req), testMessage.Topic) { - if idx%2 == 0 { - w.Header().Set("Content-Type", wantMime) - w.Write([]byte("{ \"name\":\"" + success[0].Name + "\" }")) - } else { - w.WriteHeader(http.StatusInternalServerError) - w.Header().Set("Content-Type", wantMime) - w.Write([]byte(failures[0])) - } - } - } - })) - defer ts.Close() - ctx := context.Background() client, err := NewClient(ctx, testMessagingConfig) if err != nil { t.Fatal(err) } - client.fcmEndpoint = ts.URL for idx, tc := range httpErrors { failures = []string{tc.resp} + serverHitCount := 0 + var mu sync.Mutex + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + mu.Lock() + serverHitCount++ + mu.Unlock() + reqBody, _ := ioutil.ReadAll(r.Body) + var msgIn fcmRequest + json.Unmarshal(reqBody, &msgIn) + + if msgIn.Message.Topic == testMessages[0].Topic { + w.Header().Set("Content-Type", "application/json") + w.Write([]byte(`{ "name":"` + success[0].Name + `" }`)) + } else if msgIn.Message.Topic == testMessages[1].Topic { + w.WriteHeader(http.StatusInternalServerError) + w.Header().Set("Content-Type", "application/json") + w.Write([]byte(failures[0])) + } else { + w.WriteHeader(http.StatusBadRequest) + w.Write([]byte(`{"error":"unknown topic"}`)) + } + })) + defer ts.Close() + client.fcmEndpoint = ts.URL br, err := client.SendEach(ctx, testMessages) if err != nil { - t.Fatal(err) + t.Fatalf("[%d] SendEach() unexpected error: %v", idx, err) } + mu.Lock() + if serverHitCount != len(testMessages) { + t.Errorf("[%d] Server hit count = %d; want = %d", idx, serverHitCount, len(testMessages)) + } + mu.Unlock() + if err := checkPartialErrorBatchResponse(br, tc); err != nil { t.Errorf("[%d] SendEach() = %v", idx, err) } @@ -322,28 +578,37 @@ func TestSendEachPartialFailure(t *testing.T) { } func TestSendEachTotalFailure(t *testing.T) { - var resp string - ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(http.StatusInternalServerError) - w.Header().Set("Content-Type", "application/json") - w.Write([]byte(resp)) - })) - defer ts.Close() - ctx := context.Background() client, err := NewClient(ctx, testMessagingConfig) if err != nil { t.Fatal(err) } - client.fcmEndpoint = ts.URL client.fcmClient.httpClient.RetryConfig = nil for idx, tc := range httpErrors { - resp = tc.resp + serverHitCount := 0 + var mu sync.Mutex + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + mu.Lock() + serverHitCount++ + mu.Unlock() + w.WriteHeader(http.StatusInternalServerError) + w.Header().Set("Content-Type", "application/json") + w.Write([]byte(tc.resp)) + })) + defer ts.Close() + client.fcmEndpoint = ts.URL + br, err := client.SendEach(ctx, testMessages) if err != nil { - t.Fatal(err) + t.Fatalf("[%d] SendEach() unexpected error: %v", idx, err) + } + + mu.Lock() + if serverHitCount != len(testMessages) { + t.Errorf("[%d] Server hit count = %d; want = %d", idx, serverHitCount, len(testMessages)) } + mu.Unlock() if err := checkTotalErrorBatchResponse(br, tc); err != nil { t.Errorf("[%d] SendEach() = %v", idx, err) @@ -423,10 +688,8 @@ func TestSendEachForMulticastInvalidMessage(t *testing.T) { want := "invalid message at index 0: priority must be 'normal' or 'high'" mm := &MulticastMessage{ - Tokens: []string{"token1"}, - Android: &AndroidConfig{ - Priority: "invalid", - }, + Tokens: []string{"token1"}, + Android: &AndroidConfig{Priority: "invalid"}, } br, err := client.SendEachForMulticast(ctx, mm) if err == nil || err.Error() != want { @@ -525,9 +788,7 @@ func TestSendEachForMulticastDryRun(t *testing.T) { func TestSendEachForMulticastPartialFailure(t *testing.T) { success := []fcmResponse{ - { - Name: "projects/test-project/messages/1", - }, + {Name: "projects/test-project/messages/1"}, } var failures []string @@ -536,7 +797,6 @@ func TestSendEachForMulticastPartialFailure(t *testing.T) { for idx, token := range testMulticastMessage.Tokens { if strings.Contains(string(req), token) { - // Write success for token1 and error for token2 if idx%2 == 0 { w.Header().Set("Content-Type", wantMime) w.Write([]byte("{ \"name\":\"" + success[0].Name + "\" }")) @@ -687,9 +947,7 @@ func TestSendAllDryRun(t *testing.T) { func TestSendAllPartialFailure(t *testing.T) { success := []fcmResponse{ - { - Name: "projects/test-project/messages/1", - }, + {Name: "projects/test-project/messages/1"}, } var req, resp []byte @@ -887,10 +1145,8 @@ func TestSendMulticastInvalidMessage(t *testing.T) { want := "invalid message at index 0: priority must be 'normal' or 'high'" mm := &MulticastMessage{ - Tokens: []string{"token1"}, - Android: &AndroidConfig{ - Priority: "invalid", - }, + Tokens: []string{"token1"}, + Android: &AndroidConfig{Priority: "invalid"}, } br, err := client.SendMulticast(ctx, mm) if err == nil || err.Error() != want { @@ -1001,9 +1257,7 @@ func TestSendMulticastDryRun(t *testing.T) { } func TestSendMulticastPartialFailure(t *testing.T) { - success := []fcmResponse{ - testSuccessResponse[0], - } + success := []fcmResponse{testSuccessResponse[0]} var resp []byte ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { @@ -1156,7 +1410,7 @@ func checkSuccessfulSendResponse(r *SendResponse, wantID string) error { } func checkMultipartRequest(b []byte, dryRun bool) error { - reader := multipart.NewReader(bytes.NewBuffer((b)), multipartBoundary) + reader := multipart.NewReader(bytes.NewBuffer(b), multipartBoundary) count := 0 for { part, err := reader.NextPart()