From 808eb87cfd70492c5188349d53268f09e85fc8c1 Mon Sep 17 00:00:00 2001 From: Mathis Engelbart Date: Mon, 27 Jan 2025 15:49:14 +0100 Subject: [PATCH] Implement pacing interceptor --- go.mod | 3 +- go.sum | 2 + pkg/pacing/interceptor.go | 240 +++++++++++++++++++++++++++++++++ pkg/pacing/interceptor_test.go | 143 ++++++++++++++++++++ pkg/pacing/rate_limit_pacer.go | 33 +++++ 5 files changed, 420 insertions(+), 1 deletion(-) create mode 100644 pkg/pacing/interceptor.go create mode 100644 pkg/pacing/interceptor_test.go create mode 100644 pkg/pacing/rate_limit_pacer.go diff --git a/go.mod b/go.mod index 68c47f16..11752622 100644 --- a/go.mod +++ b/go.mod @@ -1,6 +1,6 @@ module github.com/pion/interceptor -go 1.21 +go 1.21.0 require ( github.com/pion/logging v0.2.4 @@ -8,6 +8,7 @@ require ( github.com/pion/rtp v1.8.24 github.com/pion/transport/v3 v3.0.8 github.com/stretchr/testify v1.11.1 + golang.org/x/time v0.10.0 ) require ( diff --git a/go.sum b/go.sum index eb36f669..c0faa306 100644 --- a/go.sum +++ b/go.sum @@ -16,6 +16,8 @@ github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= golang.org/x/net v0.34.0 h1:Mb7Mrk043xzHgnRM88suvJFwzVrRfHEHJEl5/71CKw0= golang.org/x/net v0.34.0/go.mod h1:di0qlW3YNM5oh6GqDGQr92MyTozJPmybPK4Ev/Gm31k= +golang.org/x/time v0.10.0 h1:3usCWA8tQn0L8+hFJQNgzpWbd89begxN66o1Ojdn5L4= +golang.org/x/time v0.10.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= diff --git a/pkg/pacing/interceptor.go b/pkg/pacing/interceptor.go new file mode 100644 index 00000000..9b443dac --- /dev/null +++ b/pkg/pacing/interceptor.go @@ -0,0 +1,240 @@ +// SPDX-FileCopyrightText: 2025 The Pion community +// SPDX-License-Identifier: MIT + +// Package pacing implements a pacing interceptor. +package pacing + +import ( + "errors" + "log/slog" + "maps" + "sync" + "time" + + "github.com/pion/interceptor" + "github.com/pion/logging" + "github.com/pion/rtp" +) + +var ( + errPacerClosed = errors.New("pacer closed") + errPacerOverflow = errors.New("pacer queue overflow") +) + +type pacerFactory func(initialRate, burst int) pacer + +type pacer interface { + SetRate(rate, burst int) + Budget(time.Time) float64 + AllowN(time.Time, int) bool +} + +// Option is a configuration option for pacing interceptors. +type Option func(*Interceptor) error + +// InitialRate configures the initial pacing rate for interceptors created by +// the interceptor factory. +func InitialRate(rate int) Option { + return func(i *Interceptor) error { + i.initialRate = rate + + return nil + } +} + +// Interval configures the pacing interval for interceptors created by the +// interceptor factory. +func Interval(interval time.Duration) Option { + return func(i *Interceptor) error { + i.interval = interval + + return nil + } +} + +func setPacerFactory(f pacerFactory) Option { + return func(i *Interceptor) error { + i.pacerFactory = f + + return nil + } +} + +// InterceptorFactory is a factory for pacing interceptors. It also keeps a map +// of interceptors created in the past by ID. +type InterceptorFactory struct { + lock sync.Mutex + opts []Option + interceptors map[string]*Interceptor +} + +// NewInterceptor returns a new InterceptorFactory. +func NewInterceptor(opts ...Option) *InterceptorFactory { + return &InterceptorFactory{ + lock: sync.Mutex{}, + opts: opts, + interceptors: map[string]*Interceptor{}, + } +} + +// SetRate updates the pacing rate of the pacing interceptor with the given ID. +func (f *InterceptorFactory) SetRate(id string, r int) { + f.lock.Lock() + defer f.lock.Unlock() + + i, ok := f.interceptors[id] + if !ok { + return + } + i.setRate(r) +} + +// NewInterceptor creates a new pacing interceptor. +func (f *InterceptorFactory) NewInterceptor(id string) (interceptor.Interceptor, error) { + f.lock.Lock() + defer f.lock.Unlock() + + interceptor := &Interceptor{ + NoOp: interceptor.NoOp{}, + log: logging.NewDefaultLoggerFactory().NewLogger("pacer_interceptor"), + initialRate: 1_000_000, + interval: 5 * time.Millisecond, + queueSize: 1_000_000, + pacerFactory: func(initialRate, burst int) pacer { + return newRateLimitPacer(initialRate, burst) + }, + limit: nil, + queue: nil, + closed: make(chan struct{}), + wg: sync.WaitGroup{}, + } + for _, opt := range f.opts { + if err := opt(interceptor); err != nil { + return nil, err + } + } + interceptor.limit = interceptor.pacerFactory( + interceptor.initialRate, + burst(interceptor.initialRate, interceptor.interval), + ) + interceptor.queue = make(chan packet, interceptor.queueSize) + + f.interceptors[id] = interceptor + + interceptor.wg.Add(1) + go func() { + defer interceptor.wg.Done() + interceptor.loop() + }() + + return interceptor, nil +} + +// Interceptor implements packet pacing using a token bucket filter and sends +// packets at a fixed interval. +type Interceptor struct { + interceptor.NoOp + log logging.LeveledLogger + + // config + initialRate int + interval time.Duration + queueSize int + pacerFactory pacerFactory + + // limiter and queue + limit pacer + queue chan packet + + // shutdown + closed chan struct{} + wg sync.WaitGroup +} + +// burst calculates the minimal burst size required to reach the given rate and +// pacing interval. +func burst(rate int, interval time.Duration) int { + if interval == 0 { + interval = time.Millisecond + } + f := float64(time.Second.Milliseconds() / interval.Milliseconds()) + + return 8 * int(float64(rate)/f) +} + +// setRate updates the pacing rate and burst of the rate limiter. +func (i *Interceptor) setRate(r int) { + i.limit.SetRate(r, burst(r, i.interval)) +} + +// BindLocalStream implements interceptor.Interceptor. +func (i *Interceptor) BindLocalStream( + info *interceptor.StreamInfo, + writer interceptor.RTPWriter, +) interceptor.RTPWriter { + return interceptor.RTPWriterFunc(func( + header *rtp.Header, + payload []byte, + attributes interceptor.Attributes, + ) (int, error) { + hdr := header.Clone() + pay := make([]byte, len(payload)) + copy(pay, payload) + attr := maps.Clone(attributes) + select { + case i.queue <- packet{ + writer: writer, + header: &hdr, + payload: pay, + attributes: attr, + }: + case <-i.closed: + return 0, errPacerClosed + default: + return 0, errPacerOverflow + } + + return header.MarshalSize() + len(payload), nil + }) +} + +// Close implements interceptor.Interceptor. +func (i *Interceptor) Close() error { + defer i.wg.Wait() + close(i.closed) + + return nil +} + +func (i *Interceptor) loop() { + ticker := time.NewTicker(i.interval) + queue := make([]packet, 0) + for { + select { + case now := <-ticker.C: + for len(queue) > 0 && i.limit.Budget(now) > 8*float64(queue[0].len()) { + i.limit.AllowN(now, 8*queue[0].len()) + var next packet + next, queue = queue[0], queue[1:] + if _, err := next.writer.Write(next.header, next.payload, next.attributes); err != nil { + slog.Warn("error on writing RTP packet", "error", err) + } + } + case pkt := <-i.queue: + queue = append(queue, pkt) + case <-i.closed: + return + } + } +} + +type packet struct { + writer interceptor.RTPWriter + header *rtp.Header + payload []byte + attributes interceptor.Attributes +} + +func (p *packet) len() int { + return p.header.MarshalSize() + len(p.payload) +} diff --git a/pkg/pacing/interceptor_test.go b/pkg/pacing/interceptor_test.go new file mode 100644 index 00000000..c48894a8 --- /dev/null +++ b/pkg/pacing/interceptor_test.go @@ -0,0 +1,143 @@ +// SPDX-FileCopyrightText: 2025 The Pion community +// SPDX-License-Identifier: MIT + +package pacing + +import ( + "sync" + "testing" + "time" + + "github.com/pion/interceptor" + "github.com/pion/interceptor/internal/test" + "github.com/pion/rtp" + "github.com/stretchr/testify/assert" +) + +type mockPacer struct { + lock sync.Mutex + + rate int + burst int + + allow bool + allowCalled bool + budget float64 + budgetCalled bool +} + +// AllowN implements pacer. +func (m *mockPacer) AllowN(time.Time, int) bool { + m.lock.Lock() + defer m.lock.Unlock() + m.allowCalled = true + + return m.allow +} + +// Budget implements pacer. +func (m *mockPacer) Budget(time.Time) float64 { + m.lock.Lock() + defer m.lock.Unlock() + + m.budgetCalled = true + + return m.budget +} + +// SetRate implements pacer. +func (m *mockPacer) SetRate(rate int, burst int) { + m.lock.Lock() + defer m.lock.Unlock() + + m.rate = rate + m.burst = burst +} + +func TestInterceptor(t *testing.T) { + t.Run("calls_set_rate", func(t *testing.T) { + mp := &mockPacer{} + i := NewInterceptor( + setPacerFactory(func(initialRate, burst int) pacer { + return mp + }), + ) + + _, err := i.NewInterceptor("") + assert.NoError(t, err) + + i.SetRate("", 1_000_000) + assert.Equal(t, 1_000_000, mp.rate) + assert.Equal(t, 40_000, mp.burst) + }) + + t.Run("paces_packets", func(t *testing.T) { + mp := &mockPacer{ + rate: 0, + burst: 0, + allow: false, + allowCalled: false, + budget: 0, + budgetCalled: false, + } + i := NewInterceptor( + setPacerFactory(func(initialRate, burst int) pacer { + return mp + }), + Interval(time.Millisecond), + ) + + pacer, err := i.NewInterceptor("") + assert.NoError(t, err) + + stream := test.NewMockStream(&interceptor.StreamInfo{}, pacer) + defer func() { + assert.NoError(t, stream.Close()) + }() + + mp.lock.Lock() + mp.allow = true + mp.budget = 8 * 1500 + mp.lock.Unlock() + + hdr := rtp.Header{} + err = stream.WriteRTP(&rtp.Packet{ + Header: hdr, + Payload: make([]byte, 1200-hdr.MarshalSize()), + }) + assert.NoError(t, err) + + select { + case <-stream.WrittenRTP(): + case <-time.After(time.Second): + assert.Fail(t, "no RTP packet written") + } + mp.lock.Lock() + assert.True(t, mp.allowCalled) + assert.True(t, mp.budgetCalled) + mp.lock.Unlock() + + mp.lock.Lock() + mp.allow = false + mp.budget = 0 + mp.lock.Unlock() + + hdr = rtp.Header{} + err = stream.WriteRTP(&rtp.Packet{ + Header: hdr, + Payload: make([]byte, 1200-hdr.MarshalSize()), + }) + assert.NoError(t, err) + + mp.lock.Lock() + assert.True(t, mp.allowCalled) + assert.True(t, mp.budgetCalled) + mp.lock.Unlock() + + select { + case <-stream.WrittenRTP(): + assert.Fail(t, "RTP packet written without pacing budget") + case <-time.After(10 * time.Millisecond): + } + }) +} diff --git a/pkg/pacing/rate_limit_pacer.go b/pkg/pacing/rate_limit_pacer.go new file mode 100644 index 00000000..9b50d5d6 --- /dev/null +++ b/pkg/pacing/rate_limit_pacer.go @@ -0,0 +1,33 @@ +// SPDX-FileCopyrightText: 2025 The Pion community +// SPDX-License-Identifier: MIT + +package pacing + +import ( + "time" + + "golang.org/x/time/rate" +) + +type rateLimitPacer struct { + limiter *rate.Limiter +} + +func newRateLimitPacer(initialRate, burst int) *rateLimitPacer { + return &rateLimitPacer{ + limiter: rate.NewLimiter(rate.Limit(initialRate), burst), + } +} + +func (p *rateLimitPacer) SetRate(r, burst int) { + p.limiter.SetLimit(rate.Limit(r)) + p.limiter.SetBurst(burst) +} + +func (p *rateLimitPacer) Budget(t time.Time) float64 { + return p.limiter.TokensAt(t) +} + +func (p *rateLimitPacer) AllowN(t time.Time, n int) bool { + return p.limiter.AllowN(t, n) +}