Skip to content

Commit

Permalink
Feat[speedo]: flesh out speedo and add TCP transfer test
Browse files Browse the repository at this point in the history
  • Loading branch information
yunginnanet committed Oct 19, 2023
1 parent ec44773 commit d88ed90
Show file tree
Hide file tree
Showing 2 changed files with 222 additions and 51 deletions.
110 changes: 69 additions & 41 deletions internal/util/speedometer.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package util

import (
"errors"
"fmt"
"io"
"sync"
"sync/atomic"
Expand All @@ -10,7 +11,7 @@ import (

var ErrLimitReached = errors.New("limit reached")

// Speedometer is a wrapper around an io.Writer that will limit the rate at which data is written to the underlying writer.
// Speedometer is an io.Writer wrapper that will limit the rate at which data is written to the underlying target.
//
// It is safe for concurrent use, but writers will block when slowed down.
//
Expand All @@ -20,23 +21,38 @@ var ErrLimitReached = errors.New("limit reached")
//
// - a speed limit, causing slow downs of data written to the underlying writer if the speed limit is exceeded.
type Speedometer struct {
cap int64
ceiling int64
speedLimit *SpeedLimit
internal atomics
hardLock *sync.RWMutex
w io.Writer
}

type atomics struct {
count *atomic.Int64
closed *atomic.Bool
count *int64
start *sync.Once
stop *sync.Once
birth *atomic.Pointer[time.Time]
duration *atomic.Pointer[time.Duration]
slow *atomic.Bool
}

func newAtomics() atomics {
manhattan := atomics{
count: new(atomic.Int64),
closed: new(atomic.Bool),
start: new(sync.Once),
stop: new(sync.Once),
birth: new(atomic.Pointer[time.Time]),
duration: new(atomic.Pointer[time.Duration]),
slow: new(atomic.Bool),
}
manhattan.birth.Store(&time.Time{})
manhattan.closed.Store(false)
manhattan.count.Store(0)
return manhattan
}

// SpeedLimit is used to limit the rate at which data is written to the underlying writer.
type SpeedLimit struct {
// Burst is the number of bytes that can be written to the underlying writer per Frame.
Expand All @@ -49,6 +65,8 @@ type SpeedLimit struct {
Delay time.Duration
}

const fallbackDelay = 100

func regulateSpeedLimit(speedLimit *SpeedLimit) (*SpeedLimit, error) {
if speedLimit.Burst <= 0 || speedLimit.Frame <= 0 {
return nil, errors.New("invalid speed limit")
Expand All @@ -57,12 +75,12 @@ func regulateSpeedLimit(speedLimit *SpeedLimit) (*SpeedLimit, error) {
speedLimit.CheckEveryBytes = speedLimit.Burst
}
if speedLimit.Delay <= 0 {
speedLimit.Delay = 100 * time.Millisecond
speedLimit.Delay = fallbackDelay * time.Millisecond
}
return speedLimit, nil
}

func newSpeedometer(w io.Writer, speedLimit *SpeedLimit, cap int64) (*Speedometer, error) {
func newSpeedometer(w io.Writer, speedLimit *SpeedLimit, ceiling int64) (*Speedometer, error) {
if w == nil {
return nil, errors.New("writer cannot be nil")
}
Expand All @@ -72,25 +90,13 @@ func newSpeedometer(w io.Writer, speedLimit *SpeedLimit, cap int64) (*Speedomete
return nil, err
}
}
z := int64(0)
speedo := &Speedometer{

return &Speedometer{
w: w,
cap: cap,
ceiling: ceiling,
speedLimit: speedLimit,
hardLock: &sync.RWMutex{},
internal: atomics{
count: &z,
birth: new(atomic.Pointer[time.Time]),
duration: new(atomic.Pointer[time.Duration]),
closed: new(atomic.Bool),
stop: new(sync.Once),
start: new(sync.Once),
slow: new(atomic.Bool),
},
}
speedo.internal.birth.Store(&time.Time{})
speedo.internal.closed.Store(false)
return speedo, nil
internal: newAtomics(),
}, nil
}

// NewSpeedometer creates a new Speedometer that wraps the given io.Writer.
Expand All @@ -108,21 +114,27 @@ func NewLimitedSpeedometer(w io.Writer, speedLimit *SpeedLimit) (*Speedometer, e

// NewCappedSpeedometer creates a new Speedometer that wraps the given io.Writer.
// If len(written) bytes exceeds cap, writes to the underlying writer will be ceased permanently for the Speedometer.
func NewCappedSpeedometer(w io.Writer, cap int64) (*Speedometer, error) {
return newSpeedometer(w, nil, cap)
func NewCappedSpeedometer(w io.Writer, capacity int64) (*Speedometer, error) {
return newSpeedometer(w, nil, capacity)
}

// NewCappedLimitedSpeedometer creates a new Speedometer that wraps the given io.Writer.
// It is a combination of NewLimitedSpeedometer and NewCappedSpeedometer.
func NewCappedLimitedSpeedometer(w io.Writer, speedLimit *SpeedLimit, capacity int64) (*Speedometer, error) {
return newSpeedometer(w, speedLimit, capacity)
}

func (s *Speedometer) increment(inc int64) (int, error) {
if s.internal.closed.Load() || !s.hardLock.TryRLock() {
if s.internal.closed.Load() {
return 0, io.ErrClosedPipe
}
var err error
if s.cap > 0 && s.Total()+inc > s.cap {
if s.ceiling > 0 && s.Total()+inc > s.ceiling {
_ = s.Close()
err = ErrLimitReached
inc = s.cap - s.Total()
inc = s.ceiling - s.Total()
}
atomic.AddInt64(s.internal.count, inc)
s.internal.count.Add(inc)
return int(inc), err
}

Expand All @@ -133,12 +145,11 @@ func (s *Speedometer) Running() bool {

// Total returns the total number of bytes written to the underlying writer.
func (s *Speedometer) Total() int64 {
return atomic.LoadInt64(s.internal.count)
return s.internal.count.Load()
}

// Close stops the Speedometer. No additional writes will be accepted.
func (s *Speedometer) Close() error {
s.hardLock.TryLock()
if s.internal.closed.Load() {
return io.ErrClosedPipe
}
Expand Down Expand Up @@ -187,23 +198,40 @@ func (s *Speedometer) slowDown() error {

// Write writes p to the underlying writer, following all defined speed limits.
func (s *Speedometer) Write(p []byte) (n int, err error) {
if !s.hardLock.TryRLock() {
if s.internal.closed.Load() {
return 0, io.ErrClosedPipe
}
s.internal.start.Do(func() {
now := time.Now()
s.internal.birth.Store(&now)
})
accepted, err := s.increment(int64(len(p)))
if err != nil {
wn, innerErr := s.w.Write(p[:accepted])
if innerErr != nil {
err = innerErr

// if no speed limit, just write and record
if s.speedLimit == nil {
n, err = s.w.Write(p)
if err != nil {
return n, fmt.Errorf("error writing to underlying writer: %w", err)
}
return wn, err
return s.increment(int64(len(p)))
}

var (
wErr error
accepted int
)
accepted, wErr = s.increment(int64(len(p)))

if wErr != nil {
return 0, fmt.Errorf("error incrementing: %w", wErr)
}
if err = s.slowDown(); err != nil {
return 0, err

if sErr := s.slowDown(); sErr != nil {
return 0, fmt.Errorf("error slowing down: %w", sErr)
}

var iErr error
if n, iErr = s.w.Write(p[:accepted]); iErr != nil {
return n, fmt.Errorf("error writing to underlying writer: %w", iErr)
}
return s.w.Write(p)
return
}
Loading

0 comments on commit d88ed90

Please sign in to comment.