Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

protocol: limit use of buffered reader #120

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 22 additions & 32 deletions protocol.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package proxyproto

import (
"bufio"
"bytes"
"errors"
"fmt"
"io"
Expand Down Expand Up @@ -51,7 +52,6 @@ type Conn struct {
once sync.Once
readErr error
conn net.Conn
bufReader *bufio.Reader
reader io.Reader
header *Header
ProxyHeaderPolicy Policy
Expand Down Expand Up @@ -151,16 +151,8 @@ func (p *Listener) Addr() net.Addr {
// NewConn is used to wrap a net.Conn that may be speaking
// the proxy protocol into a proxyproto.Conn
func NewConn(conn net.Conn, opts ...func(*Conn)) *Conn {
// For v1 the header length is at most 108 bytes.
// For v2 the header length is at most 52 bytes plus the length of the TLVs.
// We use 256 bytes to be safe.
const bufSize = 256
br := bufio.NewReaderSize(conn, bufSize)

pConn := &Conn{
bufReader: br,
reader: io.MultiReader(br, conn),
conn: conn,
conn: conn,
}

for _, opt := range opts {
Expand Down Expand Up @@ -297,7 +289,23 @@ func (p *Conn) readHeader() error {
}
}

header, err := Read(p.bufReader)
// For v1 the header length is at most 108 bytes.
// For v2 the header length is at most 52 bytes plus the length of the TLVs.
// We use 256 bytes to be safe.
const bufSize = 256
br := bufio.NewReaderSize(p.conn, bufSize)

header, err := Read(br)

if br.Buffered() != 0 {
buf := make([]byte, br.Buffered())
if _, err := br.Read(buf); err != nil {
return err // this should never as we read buffered data
}
Comment on lines +301 to +304

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is not correct as Read may read less than Buffered() so the proper way is to use io.ReadFull but then I am not sure it won't trigger another read from p.conn (i.e. I am not sure br.Buffered() will be 0 after that).

That is why I used TeeReader in #119 to copy data read from connection into the bytes.NewBuffer

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

BTW err could be EOF

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is not correct as Read may read less than Buffered() so the proper way is to use io.ReadFull but then I am not sure it won't trigger another read from p.conn (i.e. I am not sure br.Buffered() will be 0 after that).

I think we can use Peek here to avoid allocation like:

	if br.Buffered() > 0 {
		b, _ := br.Peek(br.Buffered())
		p.reader = io.MultiReader(bytes.NewReader(b), p.conn)
	} else {
		p.reader = p.conn
	}

Its not super documented but I think Peek's design is to avoid underlying Read if n <= Buffered()

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Its not super documented but I think Peek's design is to avoid underlying Read if n <= Buffered()

Its not guaranteed unfortunately, see golang/go#63548 (comment)

Copy link
Contributor Author

@mmatczuk mmatczuk Oct 16, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@AlexanderYastrebov bufio.Reader.Read will never read here.

Call to read will jump directly here, this is because:

  • n != 0 - len(p)
  • b.r != b.w - there is something buffered

Also, note definition of Buffered().

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The question is about correctness.

There is simply no documented way to obtain buffered data from bufio.Reader and the current implementation can change in the future and this will break unexpectedly. See discussion golang/go#63548 for more details.

From https://pkg.go.dev/bufio#Reader.Read

To read exactly len(p) bytes, use io.ReadFull(b, p).

but that does not guarantee that reader does not buffer next chunk.

Therefore I think the right way is to capture everything that was read from the connection like I did in #119

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The probability they would change that Read behavior is close to 0.
But we sure can add a test that would guarantee that the Read does not touch the underlying reader.

At this point both branches are very similar, except for the fact that this one is easier to read and uses less resources.

I think you would agree this code is not very simple.

	bb := bytes.NewBuffer(make([]byte, 0, bufSize))
	br := bufio.NewReaderSize(io.TeeReader(p.conn, bb), bufSize)

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you, please, add such test?

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ping @mmatczuk 🙏🏻

p.reader = io.MultiReader(bytes.NewReader(buf), p.conn)
} else {
p.reader = p.conn
}

// If the connection's readHeaderTimeout is more than 0, undo the change to the
// deadline that we made above. Because we retain the readDeadline as part of our
Expand Down Expand Up @@ -364,26 +372,8 @@ func (p *Conn) WriteTo(w io.Writer) (int64, error) {
return 0, p.readErr
}

b := make([]byte, p.bufReader.Buffered())
if _, err := p.bufReader.Read(b); err != nil {
return 0, err // this should never as we read buffered data
}

var n int64
{
nn, err := w.Write(b)
n += int64(nn)
if err != nil {
return n, err
}
if wt, ok := p.reader.(io.WriterTo); ok {
return wt.WriteTo(w)
}
{
nn, err := io.Copy(w, p.conn)
n += nn
if err != nil {
return n, err
}
}

return n, nil
return io.Copy(w, p.reader)
}
Loading