Skip to content

Commit ef4e98a

Browse files
committed
Implement schedules conditions
1 parent d233ec8 commit ef4e98a

File tree

7 files changed

+450
-14
lines changed

7 files changed

+450
-14
lines changed

lib/accessmonitoring/request_mapping.go

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,10 @@ type AccessRequestExpressionEnv struct {
4242
// UserTraits includes arbitrary user traits dynamically provided by the
4343
// access monitoring rule handler.
4444
UserTraits map[string][]string
45+
46+
// Schedules contains the dictonary of schedules configured within the
47+
// access monitoring rule.
48+
Schedules ScheduleDict
4549
}
4650

4751
type accessRequestExpression typical.Expression[AccessRequestExpressionEnv, any]
@@ -127,11 +131,19 @@ func newRequestConditionParser() (*typical.Parser[AccessRequestExpressionEnv, an
127131
"user.traits": typical.DynamicMap(func(env AccessRequestExpressionEnv) (expression.Dict, error) {
128132
return expression.DictFromStringSliceMap(env.UserTraits), nil
129133
}),
134+
135+
"spec.schedules": typical.DynamicMap(func(env AccessRequestExpressionEnv) (ScheduleDict, error) {
136+
return env.Schedules, nil
137+
}),
130138
}
131139

132140
defParserSpec := expression.DefaultParserSpec[AccessRequestExpressionEnv]()
133141
defParserSpec.Variables = typicalEnvVar
134142

143+
inScheduleFn := typical.BinaryFunction[AccessRequestExpressionEnv](inSchedule)
144+
defParserSpec.Functions["in_schedule"] = inScheduleFn
145+
defParserSpec.Methods["in_schedule"] = inScheduleFn
146+
135147
requestConditionParser, err := typical.NewParser[AccessRequestExpressionEnv, any](defParserSpec)
136148
if err != nil {
137149
return nil, trace.Wrap(err)

lib/accessmonitoring/request_mapping_test.go

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,9 +20,11 @@ package accessmonitoring
2020

2121
import (
2222
"testing"
23+
"time"
2324

2425
"github.com/stretchr/testify/require"
2526

27+
accessmonitoringrulesv1 "github.com/gravitational/teleport/api/gen/proto/go/teleport/accessmonitoringrules/v1"
2628
"github.com/gravitational/teleport/api/types"
2729
)
2830

@@ -238,6 +240,69 @@ func TestEvaluateCondition(t *testing.T) {
238240
},
239241
match: false,
240242
},
243+
{
244+
description: "creation time is in schedule",
245+
condition: `access_request.spec.creation_time.in_schedule(spec.schedules["test"])`,
246+
env: AccessRequestExpressionEnv{
247+
CreationTime: time.Date(2025, time.August, 11, 14, 30, 0, 0, time.UTC), // Monday 14:30
248+
Schedules: ScheduleDict{
249+
"test": &accessmonitoringrulesv1.Schedule{
250+
Time: &accessmonitoringrulesv1.TimeSchedule{
251+
Shifts: []*accessmonitoringrulesv1.TimeSchedule_Shift{
252+
{
253+
Weekday: "Monday",
254+
Start: "14:00",
255+
End: "15:00",
256+
},
257+
},
258+
},
259+
},
260+
},
261+
},
262+
match: true,
263+
},
264+
{
265+
description: "shift interval is inclusive",
266+
condition: `access_request.spec.creation_time.in_schedule(spec.schedules["test"])`,
267+
env: AccessRequestExpressionEnv{
268+
CreationTime: time.Date(2025, time.August, 11, 14, 30, 0, 0, time.UTC), // Monday 14:30
269+
Schedules: ScheduleDict{
270+
"test": &accessmonitoringrulesv1.Schedule{
271+
Time: &accessmonitoringrulesv1.TimeSchedule{
272+
Shifts: []*accessmonitoringrulesv1.TimeSchedule_Shift{
273+
{
274+
Weekday: "Monday",
275+
Start: "14:30",
276+
End: "15:00",
277+
},
278+
},
279+
},
280+
},
281+
},
282+
},
283+
match: true,
284+
},
285+
{
286+
description: "schedule name not found",
287+
condition: `access_request.spec.creation_time.in_schedule(spec.schedules["not-found"])`,
288+
env: AccessRequestExpressionEnv{
289+
CreationTime: time.Date(2025, time.August, 11, 14, 30, 0, 0, time.UTC), // Monday 14:30
290+
Schedules: ScheduleDict{
291+
"test": &accessmonitoringrulesv1.Schedule{
292+
Time: &accessmonitoringrulesv1.TimeSchedule{
293+
Shifts: []*accessmonitoringrulesv1.TimeSchedule_Shift{
294+
{
295+
Weekday: "Monday",
296+
Start: "14:00",
297+
End: "15:00",
298+
},
299+
},
300+
},
301+
},
302+
},
303+
},
304+
match: false,
305+
},
241306
}
242307

243308
for _, test := range tests {

lib/accessmonitoring/review/review.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -241,6 +241,8 @@ func (handler *Handler) getMatchingRule(
241241
var reviewRule *accessmonitoringrulesv1.AccessMonitoringRule
242242

243243
for _, rule := range handler.rules.Get() {
244+
env.Schedules = accessmonitoring.ScheduleDict(rule.GetSpec().GetSchedules())
245+
244246
conditionMatch, err := accessmonitoring.EvaluateCondition(rule.GetSpec().GetCondition(), env)
245247
if err != nil {
246248
handler.Logger.WarnContext(ctx, "Failed to evaluate access monitoring rule",

lib/accessmonitoring/review/review_test.go

Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -198,6 +198,113 @@ func TestConflictingRules(t *testing.T) {
198198
require.NoError(t, handler.HandleAccessRequest(ctx, event))
199199
}
200200

201+
func TestScheduleRequest(t *testing.T) {
202+
ctx, cancel := context.WithTimeout(context.Background(), time.Second*30)
203+
t.Cleanup(cancel)
204+
205+
testReqID := uuid.New().String()
206+
testRuleName := "test-rule"
207+
withSecretsFalse := false
208+
requesterUserName := "requester"
209+
210+
requester, err := types.NewUser(requesterUserName)
211+
require.NoError(t, err)
212+
213+
testRule := newApprovedRule(
214+
testRuleName,
215+
`access_request.spec.creation_time.in_schedule(spec.schedules["test-schedule"])`)
216+
217+
testRule.Spec.Schedules = map[string]*accessmonitoringrulesv1.Schedule{
218+
"test-schedule": {
219+
Time: &accessmonitoringrulesv1.TimeSchedule{
220+
Shifts: []*accessmonitoringrulesv1.TimeSchedule_Shift{
221+
{
222+
Weekday: time.Monday.String(),
223+
Start: "14:00",
224+
End: "15:00",
225+
},
226+
},
227+
},
228+
},
229+
}
230+
231+
cache := accessmonitoring.NewCache()
232+
cache.Put([]*accessmonitoringrulesv1.AccessMonitoringRule{testRule})
233+
234+
tests := []struct {
235+
description string
236+
setupMock func(m *mockClient)
237+
creationTime time.Time
238+
assertErr require.ErrorAssertionFunc
239+
}{
240+
{
241+
description: "test within schedule",
242+
setupMock: func(m *mockClient) {
243+
m.On("GetUser", mock.Anything, requesterUserName, withSecretsFalse).
244+
Return(requester, nil)
245+
246+
review, err := newAccessReview(
247+
requesterUserName,
248+
testRuleName,
249+
types.RequestState_APPROVED.String(),
250+
time.Time{},
251+
)
252+
require.NoError(t, err)
253+
254+
m.On("SubmitAccessReview", mock.Anything, types.AccessReviewSubmission{
255+
RequestID: testReqID,
256+
Review: review,
257+
}).Return(mock.Anything, nil)
258+
},
259+
creationTime: time.Date(2025, time.August, 11, 14, 30, 0, 0, time.UTC),
260+
assertErr: require.NoError,
261+
},
262+
{
263+
description: "test outside schedule",
264+
setupMock: func(m *mockClient) {
265+
m.On("GetUser", mock.Anything, requesterUserName, withSecretsFalse).
266+
Return(requester, nil)
267+
268+
m.AssertNotCalled(t, "SubmitAccessReview")
269+
},
270+
creationTime: time.Date(2025, time.August, 11, 15, 30, 0, 0, time.UTC),
271+
assertErr: require.NoError,
272+
},
273+
}
274+
275+
for _, test := range tests {
276+
t.Run(test.description, func(t *testing.T) {
277+
t.Parallel()
278+
279+
client := &mockClient{}
280+
if test.setupMock != nil {
281+
test.setupMock(client)
282+
}
283+
284+
handler, err := NewHandler(Config{
285+
HandlerName: handlerName,
286+
Client: client,
287+
Cache: cache,
288+
})
289+
require.NoError(t, err)
290+
291+
req, err := types.NewAccessRequest(
292+
testReqID,
293+
requesterUserName,
294+
"role",
295+
)
296+
require.NoError(t, err)
297+
req.SetCreationTime(test.creationTime)
298+
299+
test.assertErr(t, handler.HandleAccessRequest(ctx, types.Event{
300+
Type: types.OpPut,
301+
Resource: req,
302+
}))
303+
client.AssertExpectations(t)
304+
})
305+
}
306+
}
307+
201308
func TestResourceRequest(t *testing.T) {
202309
ctx, cancel := context.WithTimeout(context.Background(), time.Second*30)
203310
t.Cleanup(cancel)

lib/accessmonitoring/schedule.go

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
/*
2+
Copyright 2025 Gravitational, Inc.
3+
4+
Licensed under the Apache License, Version 2.0 (the "License");
5+
you may not use this file except in compliance with the License.
6+
You may obtain a copy of the License at
7+
8+
http://www.apache.org/licenses/LICENSE-2.0
9+
10+
Unless required by applicable law or agreed to in writing, software
11+
distributed under the License is distributed on an "AS IS" BASIS,
12+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
See the License for the specific language governing permissions and
14+
limitations under the License.
15+
*/
16+
17+
package accessmonitoring
18+
19+
import (
20+
"time"
21+
22+
accessmonitoringrulesv1 "github.com/gravitational/teleport/api/gen/proto/go/teleport/accessmonitoringrules/v1"
23+
24+
"github.com/gravitational/trace"
25+
)
26+
27+
// ScheduleDict specifies a dictionary of schedules.
28+
//
29+
// Implements [typical.Getter]
30+
type ScheduleDict map[string]*accessmonitoringrulesv1.Schedule
31+
32+
// Get returns the schedule with the specified key name.
33+
func (d ScheduleDict) Get(key string) (*accessmonitoringrulesv1.Schedule, error) {
34+
return d[key], nil
35+
}
36+
37+
// ClockTime returns a new time value overriding the hour and minute.
38+
func ClockTime(timestamp time.Time, hourMinute string) (time.Time, error) {
39+
const hourMinuteFormat = "15:04" // 24-hour HH:MM format
40+
41+
parsed, err := time.ParseInLocation(hourMinuteFormat, hourMinute, timestamp.Location())
42+
if err != nil {
43+
return time.Time{}, trace.Wrap(err)
44+
}
45+
46+
return time.Date(timestamp.Year(), timestamp.Month(), timestamp.Day(),
47+
parsed.Hour(), parsed.Minute(), 0, 0, timestamp.Location()), nil
48+
}
49+
50+
// inSchedule returns true if the timestamp is within the schedule.
51+
func inSchedule(timestamp time.Time, schedule *accessmonitoringrulesv1.Schedule) (bool, error) {
52+
if schedule.GetTime() == nil {
53+
return false, nil
54+
}
55+
56+
if len(schedule.GetTime().GetShifts()) == 0 {
57+
return false, nil
58+
}
59+
60+
weekday := timestamp.Weekday().String()
61+
for _, shift := range schedule.GetTime().GetShifts() {
62+
if weekday != shift.Weekday {
63+
continue
64+
}
65+
66+
startTime, err := ClockTime(timestamp, shift.Start)
67+
if err != nil {
68+
return false, trace.Wrap(err, "invalid start time: %q", shift.Start)
69+
}
70+
71+
endTime, err := ClockTime(timestamp, shift.End)
72+
if err != nil {
73+
return false, trace.Wrap(err, "invalid end time: %q", shift.End)
74+
}
75+
76+
if !timestamp.Before(startTime) && !timestamp.After(endTime) {
77+
return true, nil
78+
}
79+
}
80+
return false, nil
81+
}

0 commit comments

Comments
 (0)