diff --git a/conn_test.go b/conn_test.go index 58ac394c..c3ccc886 100644 --- a/conn_test.go +++ b/conn_test.go @@ -421,6 +421,25 @@ func TestConn(t *testing.T) { err = c1.Close(websocket.StatusNormalClosure, "") assert.Success(t, err) }) + + t.Run("ReadLimitExceededReturnsErrMessageTooBig", func(t *testing.T) { + tt, c1, c2 := newConnTest(t, nil, nil) + + c1.SetReadLimit(1024) + _ = c2.CloseRead(tt.ctx) + + writeDone := xsync.Go(func() error { + payload := strings.Repeat("x", 4096) + return c2.Write(tt.ctx, websocket.MessageText, []byte(payload)) + }) + + _, _, err := c1.Read(tt.ctx) + assert.ErrorIs(t, websocket.ErrMessageTooBig, err) + assert.Contains(t, err, "read limited at 1025 bytes") + + _ = c2.CloseNow() + <-writeDone + }) } func TestWasm(t *testing.T) { diff --git a/errors.go b/errors.go new file mode 100644 index 00000000..bf4fc2b0 --- /dev/null +++ b/errors.go @@ -0,0 +1,8 @@ +package websocket + +import ( + "errors" +) + +// ErrMessageTooBig is returned when a message exceeds the read limit. +var ErrMessageTooBig = errors.New("websocket: message too big") diff --git a/read.go b/read.go index aab9e141..520acd50 100644 --- a/read.go +++ b/read.go @@ -90,7 +90,8 @@ func (c *Conn) CloseRead(ctx context.Context) context.Context { // // By default, the connection has a message read limit of 32768 bytes. // -// When the limit is hit, the connection will be closed with StatusMessageTooBig. +// When the limit is hit, reads return an error wrapping ErrMessageTooBig and +// the connection is closed with StatusMessageTooBig. // // Set to -1 to disable. func (c *Conn) SetReadLimit(n int64) { @@ -520,9 +521,9 @@ func (lr *limitReader) Read(p []byte) (int, error) { } if lr.n == 0 { - err := fmt.Errorf("read limited at %v bytes", lr.limit.Load()) - lr.c.writeError(StatusMessageTooBig, err) - return 0, err + reason := fmt.Errorf("read limited at %d bytes", lr.limit.Load()) + lr.c.writeError(StatusMessageTooBig, reason) + return 0, fmt.Errorf("%w: %v", ErrMessageTooBig, reason) } if int64(len(p)) > lr.n { diff --git a/ws_js.go b/ws_js.go index 8d52aeab..026b75fc 100644 --- a/ws_js.go +++ b/ws_js.go @@ -144,9 +144,9 @@ func (c *Conn) Read(ctx context.Context) (MessageType, []byte, error) { } readLimit := c.msgReadLimit.Load() if readLimit >= 0 && int64(len(p)) > readLimit { - err := fmt.Errorf("read limited at %v bytes", c.msgReadLimit.Load()) - c.Close(StatusMessageTooBig, err.Error()) - return 0, nil, err + reason := fmt.Errorf("read limited at %d bytes", c.msgReadLimit.Load()) + c.Close(StatusMessageTooBig, reason.Error()) + return 0, nil, fmt.Errorf("%w: %v", ErrMessageTooBig, reason) } return typ, p, nil }