diff --git a/internal/util/speedometer.go b/internal/util/speedometer.go deleted file mode 100644 index e3cbbae..0000000 --- a/internal/util/speedometer.go +++ /dev/null @@ -1,237 +0,0 @@ -package util - -import ( - "errors" - "fmt" - "io" - "sync" - "sync/atomic" - "time" -) - -var ErrLimitReached = errors.New("limit reached") - -// Speedometer is an io.Writer wrapper that will limit the rate at which data is written to the underlying target. -// -// It is safe for concurrent use, but writers will block when slowed down. -// -// Optionally, it can be given; -// -// - a capacity, which will cause it to return an error if the capacity is exceeded. -// -// - a speed limit, causing slow downs of data written to the underlying writer if the speed limit is exceeded. -type Speedometer struct { - ceiling int64 - speedLimit *SpeedLimit - internal atomics - w io.Writer -} - -type atomics struct { - count *atomic.Int64 - closed *atomic.Bool - start *sync.Once - stop *sync.Once - birth *atomic.Pointer[time.Time] - duration *atomic.Pointer[time.Duration] - slow *atomic.Bool -} - -func newAtomics() atomics { - manhattan := atomics{ - count: new(atomic.Int64), - closed: new(atomic.Bool), - start: new(sync.Once), - stop: new(sync.Once), - birth: new(atomic.Pointer[time.Time]), - duration: new(atomic.Pointer[time.Duration]), - slow: new(atomic.Bool), - } - manhattan.birth.Store(&time.Time{}) - manhattan.closed.Store(false) - manhattan.count.Store(0) - return manhattan -} - -// SpeedLimit is used to limit the rate at which data is written to the underlying writer. -type SpeedLimit struct { - // Burst is the number of bytes that can be written to the underlying writer per Frame. - Burst int - // Frame is the duration of the frame in which Burst can be written to the underlying writer. - Frame time.Duration - // CheckEveryBytes is the number of bytes written before checking if the speed limit has been exceeded. - CheckEveryBytes int - // Delay is the duration to delay writing if the speed limit has been exceeded during a Write call. (blocking) - Delay time.Duration -} - -const fallbackDelay = 100 - -func regulateSpeedLimit(speedLimit *SpeedLimit) (*SpeedLimit, error) { - if speedLimit.Burst <= 0 || speedLimit.Frame <= 0 { - return nil, errors.New("invalid speed limit") - } - if speedLimit.CheckEveryBytes <= 0 { - speedLimit.CheckEveryBytes = speedLimit.Burst - } - if speedLimit.Delay <= 0 { - speedLimit.Delay = fallbackDelay * time.Millisecond - } - return speedLimit, nil -} - -func newSpeedometer(w io.Writer, speedLimit *SpeedLimit, ceiling int64) (*Speedometer, error) { - if w == nil { - return nil, errors.New("writer cannot be nil") - } - var err error - if speedLimit != nil { - if speedLimit, err = regulateSpeedLimit(speedLimit); err != nil { - return nil, err - } - } - - return &Speedometer{ - w: w, - ceiling: ceiling, - speedLimit: speedLimit, - internal: newAtomics(), - }, nil -} - -// NewSpeedometer creates a new Speedometer that wraps the given io.Writer. -// It will not limit the rate at which data is written to the underlying writer, it only measures it. -func NewSpeedometer(w io.Writer) (*Speedometer, error) { - return newSpeedometer(w, nil, -1) -} - -// NewLimitedSpeedometer creates a new Speedometer that wraps the given io.Writer. -// If the speed limit is exceeded, writes to the underlying writer will be limited. -// See SpeedLimit for more information. -func NewLimitedSpeedometer(w io.Writer, speedLimit *SpeedLimit) (*Speedometer, error) { - return newSpeedometer(w, speedLimit, -1) -} - -// NewCappedSpeedometer creates a new Speedometer that wraps the given io.Writer. -// If len(written) bytes exceeds cap, writes to the underlying writer will be ceased permanently for the Speedometer. -func NewCappedSpeedometer(w io.Writer, capacity int64) (*Speedometer, error) { - return newSpeedometer(w, nil, capacity) -} - -// NewCappedLimitedSpeedometer creates a new Speedometer that wraps the given io.Writer. -// It is a combination of NewLimitedSpeedometer and NewCappedSpeedometer. -func NewCappedLimitedSpeedometer(w io.Writer, speedLimit *SpeedLimit, capacity int64) (*Speedometer, error) { - return newSpeedometer(w, speedLimit, capacity) -} - -func (s *Speedometer) increment(inc int64) (int, error) { - if s.internal.closed.Load() { - return 0, io.ErrClosedPipe - } - var err error - if s.ceiling > 0 && s.Total()+inc > s.ceiling { - _ = s.Close() - err = ErrLimitReached - inc = s.ceiling - s.Total() - } - s.internal.count.Add(inc) - return int(inc), err -} - -// Running returns true if the Speedometer is still running. -func (s *Speedometer) Running() bool { - return !s.internal.closed.Load() -} - -// Total returns the total number of bytes written to the underlying writer. -func (s *Speedometer) Total() int64 { - return s.internal.count.Load() -} - -// Close stops the Speedometer. No additional writes will be accepted. -func (s *Speedometer) Close() error { - if s.internal.closed.Load() { - return io.ErrClosedPipe - } - s.internal.stop.Do(func() { - s.internal.closed.Store(true) - stopped := time.Now() - birth := s.internal.birth.Load() - duration := stopped.Sub(*birth) - s.internal.duration.Store(&duration) - }) - return nil -} - -/*func (s *Speedometer) IsSlow() bool { - return s.internal.slow.Load() -}*/ - -// Rate returns the rate at which data is being written to the underlying writer per second. -func (s *Speedometer) Rate() float64 { - if s.internal.closed.Load() { - return float64(s.Total()) / s.internal.duration.Load().Seconds() - } - return float64(s.Total()) / time.Since(*s.internal.birth.Load()).Seconds() -} - -func (s *Speedometer) slowDown() error { - switch { - case s.speedLimit == nil: - return nil - case s.speedLimit.Burst <= 0 || s.speedLimit.Frame <= 0, - s.speedLimit.CheckEveryBytes <= 0, s.speedLimit.Delay <= 0: - return errors.New("invalid speed limit") - default: - // - } - if s.Total()%int64(s.speedLimit.CheckEveryBytes) != 0 { - return nil - } - s.internal.slow.Store(true) - for s.Rate() > float64(s.speedLimit.Burst)/s.speedLimit.Frame.Seconds() { - time.Sleep(s.speedLimit.Delay) - } - s.internal.slow.Store(false) - return nil -} - -// Write writes p to the underlying writer, following all defined speed limits. -func (s *Speedometer) Write(p []byte) (n int, err error) { - if s.internal.closed.Load() { - return 0, io.ErrClosedPipe - } - s.internal.start.Do(func() { - now := time.Now() - s.internal.birth.Store(&now) - }) - - // if no speed limit, just write and record - if s.speedLimit == nil { - n, err = s.w.Write(p) - if err != nil { - return n, fmt.Errorf("error writing to underlying writer: %w", err) - } - return s.increment(int64(len(p))) - } - - var ( - wErr error - accepted int - ) - accepted, wErr = s.increment(int64(len(p))) - - if wErr != nil { - return 0, fmt.Errorf("error incrementing: %w", wErr) - } - - if sErr := s.slowDown(); sErr != nil { - return 0, fmt.Errorf("error slowing down: %w", sErr) - } - - var iErr error - if n, iErr = s.w.Write(p[:accepted]); iErr != nil { - return n, fmt.Errorf("error writing to underlying writer: %w", iErr) - } - return -} diff --git a/internal/util/speedometer_test.go b/internal/util/speedometer_test.go deleted file mode 100644 index 20afaba..0000000 --- a/internal/util/speedometer_test.go +++ /dev/null @@ -1,393 +0,0 @@ -package util - -import ( - "bytes" - "errors" - "fmt" - "io" - "net" - "sync" - "sync/atomic" - "testing" - "time" -) - -type testWriter struct { - t *testing.T - total int64 -} - -func (w *testWriter) Write(p []byte) (n int, err error) { - atomic.AddInt64(&w.total, int64(len(p))) - return len(p), nil -} - -func writeStuff(t *testing.T, target io.Writer, count int) error { - t.Helper() - write := func() error { - _, err := target.Write([]byte("a")) - if err != nil { - return fmt.Errorf("error writing: %w", err) - } - return nil - } - - if count < 0 { - var err error - for err = write(); err == nil; err = write() { - time.Sleep(5 * time.Millisecond) - } - return err - } - for i := 0; i < count; i++ { - if err := write(); err != nil { - return err - } - } - return nil -} - -//nolint:funlen -func Test_Speedometer(t *testing.T) { - type results struct { - total int64 - written int - rate float64 - err error - } - - isIt := func(want, have results) { - t.Helper() - if have.total != want.total { - t.Errorf("total: want %d, have %d", want.total, have.total) - } - if have.written != want.written { - t.Errorf("written: want %d, have %d", want.written, have.written) - } - if have.rate != want.rate { - t.Errorf("rate: want %f, have %f", want.rate, have.rate) - } - if !errors.Is(have.err, want.err) { - t.Errorf("wantErr: want %v, have %v", want.err, have.err) - } - } - - var ( - errChan = make(chan error, 10) - ) - - t.Run("EarlyClose", func(t *testing.T) { - var ( - err error - cnt int - ) - t.Parallel() - sp, nerr := NewSpeedometer(&testWriter{t: t}) - if nerr != nil { - t.Errorf("unexpected error: %v", nerr) - } - go func() { - errChan <- writeStuff(t, sp, -1) - }() - time.Sleep(1 * time.Second) - if closeErr := sp.Close(); closeErr != nil { - t.Errorf("wantErr: want %v, have %v", nil, closeErr) - } - err = <-errChan - if !errors.Is(err, io.ErrClosedPipe) { - t.Errorf("wantErr: want %v, have %v", io.ErrClosedPipe, err) - } - cnt, err = sp.Write([]byte("a")) - isIt(results{err: io.ErrClosedPipe, written: 0}, results{err: err, written: cnt}) - }) - - t.Run("Basic", func(t *testing.T) { - var ( - err error - cnt int - ) - t.Parallel() - sp, nerr := NewSpeedometer(&testWriter{t: t}) - if nerr != nil { - t.Errorf("unexpected error: %v", nerr) - } - cnt, err = sp.Write([]byte("a")) - isIt(results{err: nil, written: 1, total: 1}, results{err: err, written: cnt, total: sp.Total()}) - cnt, err = sp.Write([]byte("aa")) - isIt(results{err: nil, written: 2, total: 3}, results{err: err, written: cnt, total: sp.Total()}) - cnt, err = sp.Write([]byte("a")) - isIt(results{err: nil, written: 1, total: 4}, results{err: err, written: cnt, total: sp.Total()}) - cnt, err = sp.Write([]byte("a")) - isIt(results{err: nil, written: 1, total: 5}, results{err: err, written: cnt, total: sp.Total()}) - }) - - t.Run("ConcurrentWrites", func(t *testing.T) { - var ( - err error - ) - - count := int64(0) - sp, nerr := NewSpeedometer(&testWriter{t: t}) - if nerr != nil { - t.Errorf("unexpected error: %v", nerr) - } - wg := &sync.WaitGroup{} - wg.Add(100) - for i := 0; i < 100; i++ { - go func() { - var counted int - var gerr error - counted, gerr = sp.Write([]byte("a")) - if gerr != nil { - t.Errorf("unexpected error: %v", err) - } - atomic.AddInt64(&count, int64(counted)) - wg.Done() - }() - } - wg.Wait() - isIt(results{err: nil, written: 100, total: 100}, - results{err: err, written: int(atomic.LoadInt64(&count)), total: sp.Total()}) - }) - - t.Run("GottaGoFast", func(t *testing.T) { - t.Parallel() - var ( - err error - ) - sp, nerr := NewSpeedometer(&testWriter{t: t}) - if nerr != nil { - t.Errorf("unexpected error: %v", nerr) - } - go func() { - errChan <- writeStuff(t, sp, -1) - }() - var count = 0 - for sp.Running() { - select { - case err = <-errChan: - if !errors.Is(err, io.ErrClosedPipe) { - t.Errorf("unexpected error: %v", err) - } else { - if count < 5 { - t.Errorf("too few iterations: %d", count) - } - t.Logf("final rate: %v per second", sp.Rate()) - } - default: - if count > 5 { - _ = sp.Close() - } - time.Sleep(100 * time.Millisecond) - t.Logf("rate: %v per second", sp.Rate()) - count++ - } - } - }) - - // test limiter with speedlimit - t.Run("CantGoFast", func(t *testing.T) { - t.Parallel() - t.Run("10BytesASecond", func(t *testing.T) { - t.Parallel() - var ( - err error - ) - sp, nerr := NewLimitedSpeedometer(&testWriter{t: t}, &SpeedLimit{ - Burst: 10, - Frame: time.Second, - CheckEveryBytes: 1, - Delay: 100 * time.Millisecond, - }) - if nerr != nil { - t.Errorf("unexpected error: %v", nerr) - } - for i := 0; i < 15; i++ { - if _, err = sp.Write([]byte("a")); err != nil { - t.Errorf("unexpected error: %v", err) - } - /*if sp.IsSlow() { - t.Errorf("unexpected slow state") - }*/ - t.Logf("rate: %v per second", sp.Rate()) - if sp.Rate() > 10 { - t.Errorf("speeding in a school zone (expected under %d): %v", sp.speedLimit.Burst, sp.Rate()) - } - } - }) - - t.Run("1000BytesPer5SecondsMeasuredEvery5000Bytes", func(t *testing.T) { - t.Parallel() - var ( - err error - ) - sp, nerr := NewLimitedSpeedometer(&testWriter{t: t}, &SpeedLimit{ - Burst: 1000, - Frame: 2 * time.Second, - CheckEveryBytes: 5000, - Delay: 500 * time.Millisecond, - }) - - if nerr != nil { - t.Errorf("unexpected error: %v", nerr) - } - - for i := 0; i < 4999; i++ { - if _, err = sp.Write([]byte("a")); err != nil { - t.Errorf("unexpected error: %v", err) - } - if i%1000 == 0 { - t.Logf("rate: %v per second", sp.Rate()) - } - if sp.Rate() < 1000 { - t.Errorf("shouldn't have slowed down yet (expected over %d): %v", sp.speedLimit.Burst, sp.Rate()) - } - } - if _, err = sp.Write([]byte("a")); err != nil { - t.Errorf("unexpected error: %v", err) - } - for i := 0; i < 10; i++ { - if _, err = sp.Write([]byte("a")); err != nil { - t.Errorf("unexpected error: %v", err) - } - t.Logf("rate: %v per second", sp.Rate()) - if sp.Rate() > 1000 { - t.Errorf("speeding in a school zone (expected under %d): %v", sp.speedLimit.Burst, sp.Rate()) - } - } - }) - }) - - // test capped speedometer - t.Run("OnlyALittle", func(t *testing.T) { - t.Parallel() - var ( - err error - ) - sp, nerr := NewCappedSpeedometer(&testWriter{t: t}, 1024) - if nerr != nil { - t.Errorf("unexpected error: %v", nerr) - } - for i := 0; i < 1024; i++ { - if _, err = sp.Write([]byte("a")); err != nil { - t.Errorf("unexpected error: %v", err) - } - if sp.Total() > 1024 { - t.Errorf("shouldn't have written more than 1024 bytes") - } - } - if _, err = sp.Write([]byte("a")); err == nil { - t.Errorf("expected error when writing over capacity") - } - }) - - t.Run("SynSynAckAck", func(t *testing.T) { - t.Parallel() - var ( - server net.Listener - err error - ) - //goland:noinspection GoCommentLeadingSpace - if server, err = net.Listen("tcp", ":8080"); err != nil { // #nosec:G102 - this is a unit test. - t.Fatalf("Failed to start server: %v", err) - } - defer func(server net.Listener) { - if cErr := server.Close(); cErr != nil { - t.Errorf("Failed to close server: %v", err) - } - }(server) - - go func() { - var ( - conn net.Conn - aErr error - ) - if conn, aErr = server.Accept(); aErr != nil { - t.Errorf("Failed to accept connection: %v", err) - } - - t.Logf("Accepted connection from %s", conn.RemoteAddr().String()) - - defer func(conn net.Conn) { - if cErr := conn.Close(); cErr != nil { - t.Errorf("Failed to close connection: %v", err) - } - }(conn) - - speedLimit := &SpeedLimit{ - Burst: 512, - Frame: time.Second, - CheckEveryBytes: 1, - Delay: 10 * time.Millisecond, - } - - var ( - speedometer *Speedometer - sErr error - ) - if speedometer, sErr = NewCappedLimitedSpeedometer(conn, speedLimit, 4096); sErr != nil { - t.Errorf("Failed to create speedometer: %v", sErr) - } - - buf := make([]byte, 1024) - for i := range buf { - targ := byte('E') - if i%2 == 0 { - targ = byte('e') - } - buf[i] = targ - } - for { - n, wErr := speedometer.Write(buf) - switch { - case errors.Is(wErr, io.EOF), errors.Is(wErr, ErrLimitReached): - return - case wErr != nil: - t.Errorf("Failed to write: %v", wErr) - case n != len(buf): - t.Errorf("Failed to write all bytes: %d", n) - default: - t.Logf("Wrote %d bytes", n) - } - } - }() - - var ( - client net.Conn - aErr error - ) - - if client, aErr = net.Dial("tcp", "localhost:8080"); aErr != nil { - t.Fatalf("Failed to connect to server: %v", err) - } - - defer func(client net.Conn) { - if clErr := client.Close(); clErr != nil { - t.Errorf("Failed to close client: %v", err) - } - }(client) - - buf := &bytes.Buffer{} - startTime := time.Now() - n, cpErr := io.Copy(buf, client) - if cpErr != nil { - t.Errorf("Failed to copy: %v", cpErr) - } - - duration := time.Since(startTime) - if buf.Len() == 0 || n == 0 { - t.Fatalf("No data received") - } - - rate := measureRate(t, n, duration) - - if rate > 512.0 { - t.Fatalf("Rate exceeded: got %f, expected <= 100.0", rate) - } - }) -} - -func measureRate(t *testing.T, received int64, duration time.Duration) float64 { - t.Helper() - return float64(received) / duration.Seconds() -}