diff --git a/integrations/access/accessmonitoring/access_monitoring_rules.go b/integrations/access/accessmonitoring/access_monitoring_rules.go index 35ff0d4c8874a..00bd0b578f2f8 100644 --- a/integrations/access/accessmonitoring/access_monitoring_rules.go +++ b/integrations/access/accessmonitoring/access_monitoring_rules.go @@ -154,6 +154,22 @@ func (amrh *RuleHandler) RecipientsFromAccessMonitoringRules(ctx context.Context } for _, rule := range amrh.getAccessMonitoringRules() { + // Check if creation time is within rule schedules. + isInSchedules, err := accessmonitoring.InSchedules(rule.GetSpec().GetSchedules(), env.CreationTime) + if err != nil { + log.WarnContext(ctx, "Failed to evaluate access monitoring rule", + "error", err, + "rule", rule.GetMetadata().GetName(), + ) + continue + } + if len(rule.GetSpec().GetSchedules()) != 0 && !isInSchedules { + log.DebugContext(ctx, "Access request does not satisfy schedule condition", + "rule", rule.GetMetadata().GetName()) + continue + } + + // Check if environment matches rule conditions. match, err := accessmonitoring.EvaluateCondition(rule.Spec.Condition, env) if err != nil { log.WarnContext(ctx, "Failed to parse access monitoring notification rule", @@ -188,6 +204,22 @@ func (amrh *RuleHandler) RawRecipientsFromAccessMonitoringRules(ctx context.Cont } for _, rule := range amrh.getAccessMonitoringRules() { + // Check if creation time is within rule schedules. + isInSchedules, err := accessmonitoring.InSchedules(rule.GetSpec().GetSchedules(), env.CreationTime) + if err != nil { + log.WarnContext(ctx, "Failed to evaluate access monitoring rule", + "error", err, + "rule", rule.GetMetadata().GetName(), + ) + continue + } + if len(rule.GetSpec().GetSchedules()) != 0 && !isInSchedules { + log.DebugContext(ctx, "Access request does not satisfy schedule condition", + "rule", rule.GetMetadata().GetName()) + continue + } + + // Check if environment matches rule conditions. match, err := accessmonitoring.EvaluateCondition(rule.Spec.Condition, env) if err != nil { log.WarnContext(ctx, "Failed to parse access monitoring notification rule", diff --git a/integrations/access/accessmonitoring/access_monitoring_rules_test.go b/integrations/access/accessmonitoring/access_monitoring_rules_test.go index fffbc062baac6..9a39c9b0b01d9 100644 --- a/integrations/access/accessmonitoring/access_monitoring_rules_test.go +++ b/integrations/access/accessmonitoring/access_monitoring_rules_test.go @@ -21,6 +21,7 @@ package accessmonitoring import ( "context" "testing" + "time" "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" @@ -52,7 +53,7 @@ func mockFetchRecipient(ctx context.Context, recipient string) (*common.Recipien return nil, nil } -func TestRecipeints(t *testing.T) { +func TestRecipients(t *testing.T) { const ( pluginName = "fakePluginName" pluginType = "fakePluginType" @@ -133,7 +134,7 @@ func TestRecipeints(t *testing.T) { require.ElementsMatch(t, []string{}, rawRecipients) } -func TestRecipeintsWithResources(t *testing.T) { +func TestRecipientsWithResources(t *testing.T) { const ( pluginName = "fakePluginName" pluginType = "fakePluginType" @@ -205,6 +206,85 @@ func TestRecipeintsWithResources(t *testing.T) { require.ElementsMatch(t, []string{recipient}, rawRecipients) } +func TestRecipientsWithSchedules(t *testing.T) { + const ( + pluginName = "fakePluginName" + pluginType = "fakePluginType" + recipient = "recipient@goteleport.com" + ) + + teleportClient := &mockTeleportClient{} + teleportClient. + On("GetUser", mock.Anything, mock.Anything, mock.Anything). + Return(&types.UserV2{}, nil) + + amrh := NewRuleHandler(RuleHandlerConfig{ + Client: teleportClient, + PluginType: pluginType, + PluginName: pluginName, + FetchRecipientCallback: func(ctx context.Context, recipient string) (*common.Recipient, error) { + return emailRecipient(recipient), nil + }, + }) + + rule1, err := services.NewAccessMonitoringRuleWithLabels("rule1", nil, &pb.AccessMonitoringRuleSpec{ + Subjects: []string{types.KindAccessRequest}, + Schedules: map[string]*pb.Schedule{ + "default": { + Time: &pb.TimeSchedule{ + Shifts: []*pb.TimeSchedule_Shift{ + { + Weekday: time.Monday.String(), + Start: "14:00", + End: "15:00", + }, + }, + }, + }, + }, + Condition: `true`, + Notification: &pb.Notification{ + Name: pluginName, + Recipients: []string{recipient}, + }, + }) + require.NoError(t, err) + err = amrh.HandleAccessMonitoringRule(context.Background(), types.Event{ + Type: types.OpPut, + Resource: types.Resource153ToLegacy(rule1), + }) + require.NoError(t, err) + require.Len(t, amrh.getAccessMonitoringRules(), 1) + + ctx := context.Background() + + // Expect recipient from matching rule. + req := &types.AccessRequestV3{ + Spec: types.AccessRequestSpecV3{ + Created: time.Date(2025, time.August, 11, 14, 30, 0, 0, time.UTC), + }, + } + + recipients := amrh.RecipientsFromAccessMonitoringRules(ctx, req) + require.ElementsMatch(t, []common.Recipient{*emailRecipient(recipient)}, recipients.ToSlice()) + + rawRecipients := amrh.RawRecipientsFromAccessMonitoringRules(ctx, req) + require.ElementsMatch(t, []string{recipient}, rawRecipients) + + // Expect no recipient when not in schedule. + req = &types.AccessRequestV3{ + Spec: types.AccessRequestSpecV3{ + Created: time.Date(2025, time.August, 11, 15, 30, 0, 0, time.UTC), + }, + } + + recipients = amrh.RecipientsFromAccessMonitoringRules(ctx, req) + require.ElementsMatch(t, []common.Recipient{}, recipients.ToSlice()) + + rawRecipients = amrh.RawRecipientsFromAccessMonitoringRules(ctx, req) + require.ElementsMatch(t, []string{}, rawRecipients) +} + func emailRecipient(recipient string) *common.Recipient { return &common.Recipient{ Name: recipient, diff --git a/lib/accessmonitoring/review/review.go b/lib/accessmonitoring/review/review.go index 37a409a22cd7d..4b4d2d06d175f 100644 --- a/lib/accessmonitoring/review/review.go +++ b/lib/accessmonitoring/review/review.go @@ -241,7 +241,9 @@ func (handler *Handler) getMatchingRule( var reviewRule *accessmonitoringrulesv1.AccessMonitoringRule for _, rule := range handler.rules.Get() { - conditionMatch, err := accessmonitoring.EvaluateCondition(rule.GetSpec().GetCondition(), env) + + // Check if creation time is within rule schedules. + isInSchedules, err := accessmonitoring.InSchedules(rule.GetSpec().GetSchedules(), env.CreationTime) if err != nil { handler.Logger.WarnContext(ctx, "Failed to evaluate access monitoring rule", "error", err, @@ -249,7 +251,21 @@ func (handler *Handler) getMatchingRule( ) continue } + if len(rule.GetSpec().GetSchedules()) != 0 && !isInSchedules { + handler.Logger.DebugContext(ctx, "Access request does not satisfy schedule condition", + "rule", rule.GetMetadata().GetName()) + continue + } + // Check if environment matches rule conditions. + conditionMatch, err := accessmonitoring.EvaluateCondition(rule.GetSpec().GetCondition(), env) + if err != nil { + handler.Logger.WarnContext(ctx, "Failed to evaluate access monitoring rule", + "error", err, + "rule", rule.GetMetadata().GetName(), + ) + continue + } if !conditionMatch { continue } diff --git a/lib/accessmonitoring/review/review_test.go b/lib/accessmonitoring/review/review_test.go index 2e70fa5315664..b3315e13d1812 100644 --- a/lib/accessmonitoring/review/review_test.go +++ b/lib/accessmonitoring/review/review_test.go @@ -198,6 +198,113 @@ func TestConflictingRules(t *testing.T) { require.NoError(t, handler.HandleAccessRequest(ctx, event)) } +func TestScheduleRequest(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), time.Second*30) + t.Cleanup(cancel) + + testReqID := uuid.New().String() + testRuleName := "test-rule" + withSecretsFalse := false + requesterUserName := "requester" + + requester, err := types.NewUser(requesterUserName) + require.NoError(t, err) + + testRule := newApprovedRule( + testRuleName, + `true`) + + testRule.Spec.Schedules = map[string]*accessmonitoringrulesv1.Schedule{ + "test-schedule": { + Time: &accessmonitoringrulesv1.TimeSchedule{ + Shifts: []*accessmonitoringrulesv1.TimeSchedule_Shift{ + { + Weekday: time.Monday.String(), + Start: "14:00", + End: "15:00", + }, + }, + }, + }, + } + + cache := accessmonitoring.NewCache() + cache.Put([]*accessmonitoringrulesv1.AccessMonitoringRule{testRule}) + + tests := []struct { + description string + setupMock func(m *mockClient) + creationTime time.Time + assertErr require.ErrorAssertionFunc + }{ + { + description: "test within schedule", + setupMock: func(m *mockClient) { + m.On("GetUser", mock.Anything, requesterUserName, withSecretsFalse). + Return(requester, nil) + + review, err := newAccessReview( + requesterUserName, + testRuleName, + types.RequestState_APPROVED.String(), + time.Time{}, + ) + require.NoError(t, err) + + m.On("SubmitAccessReview", mock.Anything, types.AccessReviewSubmission{ + RequestID: testReqID, + Review: review, + }).Return(mock.Anything, nil) + }, + creationTime: time.Date(2025, time.August, 11, 14, 30, 0, 0, time.UTC), + assertErr: require.NoError, + }, + { + description: "test outside schedule", + setupMock: func(m *mockClient) { + m.On("GetUser", mock.Anything, requesterUserName, withSecretsFalse). + Return(requester, nil) + + m.AssertNotCalled(t, "SubmitAccessReview") + }, + creationTime: time.Date(2025, time.August, 11, 15, 30, 0, 0, time.UTC), + assertErr: require.NoError, + }, + } + + for _, test := range tests { + t.Run(test.description, func(t *testing.T) { + t.Parallel() + + client := &mockClient{} + if test.setupMock != nil { + test.setupMock(client) + } + + handler, err := NewHandler(Config{ + HandlerName: handlerName, + Client: client, + Cache: cache, + }) + require.NoError(t, err) + + req, err := types.NewAccessRequest( + testReqID, + requesterUserName, + "role", + ) + require.NoError(t, err) + req.SetCreationTime(test.creationTime) + + test.assertErr(t, handler.HandleAccessRequest(ctx, types.Event{ + Type: types.OpPut, + Resource: req, + })) + client.AssertExpectations(t) + }) + } +} + func TestResourceRequest(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), time.Second*30) t.Cleanup(cancel) diff --git a/lib/accessmonitoring/schedule.go b/lib/accessmonitoring/schedule.go new file mode 100644 index 0000000000000..69510a721fab5 --- /dev/null +++ b/lib/accessmonitoring/schedule.go @@ -0,0 +1,93 @@ +/* +Copyright 2025 Gravitational, Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package accessmonitoring + +import ( + "time" + + "github.com/gravitational/trace" + + accessmonitoringrulesv1 "github.com/gravitational/teleport/api/gen/proto/go/teleport/accessmonitoringrules/v1" +) + +// ClockTime returns a new time value overriding the hour and minute. +func ClockTime(timestamp time.Time, hourMinute string) (time.Time, error) { + const hourMinuteFormat = "15:04" // 24-hour HH:MM format + + parsed, err := time.ParseInLocation(hourMinuteFormat, hourMinute, timestamp.Location()) + if err != nil { + return time.Time{}, trace.Wrap(err) + } + + return time.Date(timestamp.Year(), timestamp.Month(), timestamp.Day(), + parsed.Hour(), parsed.Minute(), 0, 0, timestamp.Location()), nil +} + +// inSchedule returns true if the timestamp is within the schedule. +func inSchedule(schedule *accessmonitoringrulesv1.Schedule, timestamp time.Time) (bool, error) { + if schedule.GetTime() == nil { + return false, nil + } + + if len(schedule.GetTime().GetShifts()) == 0 { + return false, nil + } + + loc, err := time.LoadLocation(schedule.GetTime().GetTimezone()) + if err != nil { + return false, trace.Wrap(err) + } + + timestamp = timestamp.In(loc) + weekday := timestamp.Weekday().String() + + for _, shift := range schedule.GetTime().GetShifts() { + if weekday != shift.Weekday { + continue + } + + startTime, err := ClockTime(timestamp, shift.Start) + if err != nil { + return false, trace.Wrap(err, "invalid start time: %q", shift.Start) + } + + endTime, err := ClockTime(timestamp, shift.End) + if err != nil { + return false, trace.Wrap(err, "invalid end time: %q", shift.End) + } + + if !timestamp.Before(startTime) && !timestamp.After(endTime) { + return true, nil + } + } + return false, nil +} + +// InSchedules returns true if the provided timestamp is within an of the provided +// schedules. Returns false if schedules is empty. +func InSchedules(schedules map[string]*accessmonitoringrulesv1.Schedule, timestamp time.Time) (bool, error) { + for _, schedule := range schedules { + isInSchedule, err := inSchedule(schedule, timestamp) + if err != nil { + return false, trace.Wrap(err) + } + if isInSchedule { + return true, nil + } + } + return false, nil +} diff --git a/lib/accessmonitoring/schedule_test.go b/lib/accessmonitoring/schedule_test.go new file mode 100644 index 0000000000000..08316b84fcff5 --- /dev/null +++ b/lib/accessmonitoring/schedule_test.go @@ -0,0 +1,260 @@ +/* +Copyright 2025 Gravitational, Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package accessmonitoring + +import ( + "testing" + "time" + + "github.com/stretchr/testify/require" + + accessmonitoringrulesv1 "github.com/gravitational/teleport/api/gen/proto/go/teleport/accessmonitoringrules/v1" +) + +func TestClockTime(t *testing.T) { + timestamp := time.Date(2025, time.August, 11, 14, 30, 0, 0, time.UTC) + tests := []struct { + description string + clockTime string + assertErr require.ErrorAssertionFunc + assertTime require.ValueAssertionFunc + }{ + { + description: "min clock time", + clockTime: "00:00", + assertErr: require.NoError, + assertTime: func(t require.TestingT, ts any, _ ...any) { + require.Equal(t, time.Date(2025, time.August, 11, 0, 0, 0, 0, time.UTC), ts) + }, + }, + { + description: "max clock time", + clockTime: "23:59", + assertErr: require.NoError, + assertTime: func(t require.TestingT, ts any, _ ...any) { + require.Equal(t, time.Date(2025, time.August, 11, 23, 59, 0, 0, time.UTC), ts) + }, + }, + { + description: "24 hour out of range", + clockTime: "24:00", + assertErr: func(t require.TestingT, err error, _ ...any) { + require.ErrorContains(t, err, "hour out of range") + }, + }, + { + description: "60 minute out of range", + clockTime: "00:60", + assertErr: func(t require.TestingT, err error, _ ...any) { + require.ErrorContains(t, err, "minute out of range") + }, + }, + { + description: "seconds specified", + clockTime: "12:34:56", + assertErr: func(t require.TestingT, err error, _ ...any) { + require.ErrorContains(t, err, "extra text") + }, + }, + } + + for _, tt := range tests { + t.Run(tt.description, func(t *testing.T) { + t.Parallel() + ts, err := ClockTime(timestamp, tt.clockTime) + tt.assertErr(t, err) + if tt.assertTime != nil { + tt.assertTime(t, ts) + } + }) + } +} + +func TestInSchedule(t *testing.T) { + timestamp := time.Date(2025, time.August, 11, 14, 30, 0, 0, time.UTC) // Monday 14:30 + tests := []struct { + description string + schedules map[string]*accessmonitoringrulesv1.Schedule + assertErr require.ErrorAssertionFunc + assertInSchedule require.BoolAssertionFunc + }{ + { + description: "in schedule", + schedules: map[string]*accessmonitoringrulesv1.Schedule{ + "default": { + Time: &accessmonitoringrulesv1.TimeSchedule{ + Shifts: []*accessmonitoringrulesv1.TimeSchedule_Shift{ + { + Weekday: time.Monday.String(), + Start: "14:00", + End: "15:00", + }, + }, + }, + }, + }, + assertErr: require.NoError, + assertInSchedule: require.True, + }, + { + description: "schedule does not contain any shifts", + schedules: map[string]*accessmonitoringrulesv1.Schedule{ + "default": { + Time: &accessmonitoringrulesv1.TimeSchedule{ + Shifts: []*accessmonitoringrulesv1.TimeSchedule_Shift{}, + }, + }, + }, + assertErr: require.NoError, + assertInSchedule: require.False, + }, + { + description: "invalid timezone", + schedules: map[string]*accessmonitoringrulesv1.Schedule{ + "default": { + Time: &accessmonitoringrulesv1.TimeSchedule{ + Timezone: "invalid", + Shifts: []*accessmonitoringrulesv1.TimeSchedule_Shift{ + { + Weekday: time.Monday.String(), + Start: "14:00", + End: "15:00", + }, + }, + }, + }, + }, + assertErr: func(t require.TestingT, err error, _ ...any) { + require.ErrorContains(t, err, "unknown time zone") + }, + assertInSchedule: require.False, + }, + { + description: "different timezone", + schedules: map[string]*accessmonitoringrulesv1.Schedule{ + "default": { + Time: &accessmonitoringrulesv1.TimeSchedule{ + Timezone: "America/Los_Angeles", + Shifts: []*accessmonitoringrulesv1.TimeSchedule_Shift{ + { + Weekday: time.Monday.String(), + Start: "14:00", + End: "15:00", + }, + }, + }, + }, + }, + assertErr: require.NoError, + assertInSchedule: require.False, + }, + { + description: "different weekday", + schedules: map[string]*accessmonitoringrulesv1.Schedule{ + "default": { + Time: &accessmonitoringrulesv1.TimeSchedule{ + Shifts: []*accessmonitoringrulesv1.TimeSchedule_Shift{ + { + Weekday: time.Tuesday.String(), + Start: "14:00", + End: "15:00", + }, + }, + }, + }, + }, + assertErr: require.NoError, + assertInSchedule: require.False, + }, + { + description: "before schedule", + schedules: map[string]*accessmonitoringrulesv1.Schedule{ + "default": { + Time: &accessmonitoringrulesv1.TimeSchedule{ + Shifts: []*accessmonitoringrulesv1.TimeSchedule_Shift{ + { + Weekday: time.Monday.String(), + Start: "14:31", + End: "15:00", + }, + }, + }, + }, + }, + assertErr: require.NoError, + assertInSchedule: require.False, + }, + { + description: "exact start time", + schedules: map[string]*accessmonitoringrulesv1.Schedule{ + "default": { + Time: &accessmonitoringrulesv1.TimeSchedule{ + Shifts: []*accessmonitoringrulesv1.TimeSchedule_Shift{ + { + Weekday: time.Monday.String(), + Start: "14:30", + End: "15:00", + }, + }, + }, + }, + }, + assertErr: require.NoError, + assertInSchedule: require.True, + }, + { + description: "multiple schedules", + schedules: map[string]*accessmonitoringrulesv1.Schedule{ + "schedule1": { + Time: &accessmonitoringrulesv1.TimeSchedule{ + Shifts: []*accessmonitoringrulesv1.TimeSchedule_Shift{ + { + Weekday: time.Monday.String(), + Start: "14:30", + End: "15:00", + }, + }, + }, + }, + "schedule2": { + Time: &accessmonitoringrulesv1.TimeSchedule{ + Shifts: []*accessmonitoringrulesv1.TimeSchedule_Shift{ + { + Weekday: time.Tuesday.String(), + Start: "14:30", + End: "15:00", + }, + }, + }, + }, + }, + assertErr: require.NoError, + assertInSchedule: require.True, + }, + } + + for _, tt := range tests { + t.Run(tt.description, func(t *testing.T) { + t.Parallel() + ts, err := InSchedules(tt.schedules, timestamp) + tt.assertErr(t, err) + if tt.assertInSchedule != nil { + tt.assertInSchedule(t, ts) + } + }) + } +} diff --git a/lib/services/access_monitoring_rules.go b/lib/services/access_monitoring_rules.go index 6ff8c60b9e18d..baf4f1704b4da 100644 --- a/lib/services/access_monitoring_rules.go +++ b/lib/services/access_monitoring_rules.go @@ -178,12 +178,12 @@ func validateShift(shift *accessmonitoringrulesv1.TimeSchedule_Shift) error { return trace.BadParameter("failed to parse weekday: %v", shift.GetWeekday()) } - start, err := clockTime(time.Time{}, shift.GetStart()) + start, err := accessmonitoring.ClockTime(time.Time{}, shift.GetStart()) if err != nil { return trace.Wrap(err, "invalid start time") } - end, err := clockTime(time.Time{}, shift.GetEnd()) + end, err := accessmonitoring.ClockTime(time.Time{}, shift.GetEnd()) if err != nil { return trace.Wrap(err, "invalid end time") } @@ -194,18 +194,6 @@ func validateShift(shift *accessmonitoringrulesv1.TimeSchedule_Shift) error { return nil } -func clockTime(timestamp time.Time, hourMinute string) (time.Time, error) { - const hourMinuteFormat = "15:04" // 24-hour HH:MM format - - parsed, err := time.ParseInLocation(hourMinuteFormat, hourMinute, timestamp.Location()) - if err != nil { - return time.Time{}, trace.Wrap(err) - } - - return time.Date(timestamp.Year(), timestamp.Month(), timestamp.Day(), - parsed.Hour(), parsed.Minute(), 0, 0, timestamp.Location()), nil -} - // MarshalAccessMonitoringRule marshals AccessMonitoringRule resource to JSON. func MarshalAccessMonitoringRule(accessMonitoringRule *accessmonitoringrulesv1.AccessMonitoringRule, opts ...MarshalOption) ([]byte, error) { return FastMarshalProtoResourceDeprecated(accessMonitoringRule, opts...)