@@ -10,6 +10,7 @@ import (
10
10
"net/http"
11
11
"net/http/httptest"
12
12
"strings"
13
+ "sync"
13
14
"testing"
14
15
15
16
"nhooyr.io/websocket/internal/test/assert"
@@ -142,6 +143,42 @@ func TestAccept(t *testing.T) {
142
143
_ , err := Accept (w , r , nil )
143
144
assert .Contains (t , err , `failed to hijack connection` )
144
145
})
146
+ t .Run ("closeRace" , func (t * testing.T ) {
147
+ t .Parallel ()
148
+
149
+ server , _ := net .Pipe ()
150
+
151
+ rw := bufio .NewReadWriter (bufio .NewReader (server ), bufio .NewWriter (server ))
152
+ newResponseWriter := func () http.ResponseWriter {
153
+ return mockHijacker {
154
+ ResponseWriter : httptest .NewRecorder (),
155
+ hijack : func () (net.Conn , * bufio.ReadWriter , error ) {
156
+ return server , rw , nil
157
+ },
158
+ }
159
+ }
160
+ w := newResponseWriter ()
161
+
162
+ r := httptest .NewRequest ("GET" , "/" , nil )
163
+ r .Header .Set ("Connection" , "Upgrade" )
164
+ r .Header .Set ("Upgrade" , "websocket" )
165
+ r .Header .Set ("Sec-WebSocket-Version" , "13" )
166
+ r .Header .Set ("Sec-WebSocket-Key" , xrand .Base64 (16 ))
167
+
168
+ c , err := Accept (w , r , nil )
169
+ wg := & sync.WaitGroup {}
170
+ wg .Add (2 )
171
+ go func () {
172
+ c .Close (StatusInternalError , "the sky is falling" )
173
+ wg .Done ()
174
+ }()
175
+ go func () {
176
+ c .CloseNow ()
177
+ wg .Done ()
178
+ }()
179
+ wg .Wait ()
180
+ assert .Success (t , err )
181
+ })
145
182
}
146
183
147
184
func Test_verifyClientHandshake (t * testing.T ) {
0 commit comments