@@ -45,20 +45,47 @@ func TestAccept(t *testing.T) {
4545 t .Run ("badCompression" , func (t * testing.T ) {
4646 t .Parallel ()
4747
48- w := mockHijacker {
49- ResponseWriter : httptest .NewRecorder (),
48+ newRequest := func (extensions string ) * http.Request {
49+ r := httptest .NewRequest ("GET" , "/" , nil )
50+ r .Header .Set ("Connection" , "Upgrade" )
51+ r .Header .Set ("Upgrade" , "websocket" )
52+ r .Header .Set ("Sec-WebSocket-Version" , "13" )
53+ r .Header .Set ("Sec-WebSocket-Key" , "meow123" )
54+ r .Header .Set ("Sec-WebSocket-Extensions" , extensions )
55+ return r
56+ }
57+ newResponseWriter := func () http.ResponseWriter {
58+ return mockHijacker {
59+ ResponseWriter : httptest .NewRecorder (),
60+ hijack : func () (net.Conn , * bufio.ReadWriter , error ) {
61+ return nil , nil , errors .New ("hijack error" )
62+ },
63+ }
5064 }
51- r := httptest .NewRequest ("GET" , "/" , nil )
52- r .Header .Set ("Connection" , "Upgrade" )
53- r .Header .Set ("Upgrade" , "websocket" )
54- r .Header .Set ("Sec-WebSocket-Version" , "13" )
55- r .Header .Set ("Sec-WebSocket-Key" , "meow123" )
56- r .Header .Set ("Sec-WebSocket-Extensions" , "permessage-deflate; harharhar" )
5765
58- _ , err := Accept (w , r , & AcceptOptions {
59- CompressionMode : CompressionContextTakeover ,
66+ t .Run ("withoutFallback" , func (t * testing.T ) {
67+ t .Parallel ()
68+
69+ w := newResponseWriter ()
70+ r := newRequest ("permessage-deflate; harharhar" )
71+ _ , _ = Accept (w , r , & AcceptOptions {
72+ CompressionMode : CompressionNoContextTakeover ,
73+ })
74+ assert .Equal (t , "extension header" , w .Header ().Get ("Sec-WebSocket-Extensions" ), "" )
75+ })
76+ t .Run ("withFallback" , func (t * testing.T ) {
77+ t .Parallel ()
78+
79+ w := newResponseWriter ()
80+ r := newRequest ("permessage-deflate; harharhar, permessage-deflate" )
81+ _ , _ = Accept (w , r , & AcceptOptions {
82+ CompressionMode : CompressionNoContextTakeover ,
83+ })
84+ assert .Equal (t , "extension header" ,
85+ w .Header ().Get ("Sec-WebSocket-Extensions" ),
86+ CompressionNoContextTakeover .opts ().String (),
87+ )
6088 })
61- assert .Contains (t , err , `unsupported permessage-deflate parameter` )
6289 })
6390
6491 t .Run ("requireHttpHijacker" , func (t * testing.T ) {
@@ -321,79 +348,66 @@ func Test_authenticateOrigin(t *testing.T) {
321348 }
322349}
323350
324- func Test_acceptCompression (t * testing.T ) {
351+ func Test_selectDeflate (t * testing.T ) {
325352 t .Parallel ()
326353
327354 testCases := []struct {
328- name string
329- mode CompressionMode
330- reqSecWebSocketExtensions string
331- respSecWebSocketExtensions string
332- expCopts * compressionOptions
333- error bool
355+ name string
356+ mode CompressionMode
357+ header string
358+ expCopts * compressionOptions
359+ expOK bool
334360 }{
335361 {
336362 name : "disabled" ,
337363 mode : CompressionDisabled ,
338364 expCopts : nil ,
365+ expOK : false ,
339366 },
340367 {
341368 name : "noClientSupport" ,
342369 mode : CompressionNoContextTakeover ,
343370 expCopts : nil ,
371+ expOK : false ,
344372 },
345373 {
346- name : "permessage-deflate" ,
347- mode : CompressionNoContextTakeover ,
348- reqSecWebSocketExtensions : "permessage-deflate; client_max_window_bits" ,
349- respSecWebSocketExtensions : "permessage-deflate; client_no_context_takeover; server_no_context_takeover" ,
374+ name : "permessage-deflate" ,
375+ mode : CompressionNoContextTakeover ,
376+ header : "permessage-deflate; client_max_window_bits" ,
350377 expCopts : & compressionOptions {
351378 clientNoContextTakeover : true ,
352379 serverNoContextTakeover : true ,
353380 },
381+ expOK : true ,
382+ },
383+ {
384+ name : "permessage-deflate/unknown-parameter" ,
385+ mode : CompressionNoContextTakeover ,
386+ header : "permessage-deflate; meow" ,
387+ expOK : false ,
354388 },
355389 {
356- name : "permessage-deflate/error" ,
357- mode : CompressionNoContextTakeover ,
358- reqSecWebSocketExtensions : "permessage-deflate; meow" ,
359- error : true ,
390+ name : "permessage-deflate/unknown-parameter" ,
391+ mode : CompressionNoContextTakeover ,
392+ header : "permessage-deflate; meow, permessage-deflate; client_max_window_bits" ,
393+ expCopts : & compressionOptions {
394+ clientNoContextTakeover : true ,
395+ serverNoContextTakeover : true ,
396+ },
397+ expOK : true ,
360398 },
361- // {
362- // name: "x-webkit-deflate-frame",
363- // mode: CompressionNoContextTakeover,
364- // reqSecWebSocketExtensions: "x-webkit-deflate-frame; no_context_takeover",
365- // respSecWebSocketExtensions: "x-webkit-deflate-frame; no_context_takeover",
366- // expCopts: &compressionOptions{
367- // clientNoContextTakeover: true,
368- // serverNoContextTakeover: true,
369- // },
370- // },
371- // {
372- // name: "x-webkit-deflate/error",
373- // mode: CompressionNoContextTakeover,
374- // reqSecWebSocketExtensions: "x-webkit-deflate-frame; max_window_bits",
375- // error: true,
376- // },
377399 }
378400
379401 for _ , tc := range testCases {
380402 tc := tc
381403 t .Run (tc .name , func (t * testing.T ) {
382404 t .Parallel ()
383405
384- r := httptest .NewRequest (http .MethodGet , "/" , nil )
385- r .Header .Set ("Sec-WebSocket-Extensions" , tc .reqSecWebSocketExtensions )
386-
387- w := httptest .NewRecorder ()
388- copts , err := acceptCompression (r , w , tc .mode )
389- if tc .error {
390- assert .Error (t , err )
391- return
392- }
393-
394- assert .Success (t , err )
406+ h := http.Header {}
407+ h .Set ("Sec-WebSocket-Extensions" , tc .header )
408+ copts , ok := selectDeflate (websocketExtensions (h ), tc .mode )
409+ assert .Equal (t , "selected options" , tc .expOK , ok )
395410 assert .Equal (t , "compression options" , tc .expCopts , copts )
396- assert .Equal (t , "Sec-WebSocket-Extensions" , tc .respSecWebSocketExtensions , w .Header ().Get ("Sec-WebSocket-Extensions" ))
397411 })
398412 }
399413}
0 commit comments