@@ -62,20 +62,47 @@ func TestAccept(t *testing.T) {
6262 t .Run ("badCompression" , func (t * testing.T ) {
6363 t .Parallel ()
6464
65- w := mockHijacker {
66- ResponseWriter : httptest .NewRecorder (),
65+ newRequest := func (extensions string ) * http.Request {
66+ r := httptest .NewRequest ("GET" , "/" , nil )
67+ r .Header .Set ("Connection" , "Upgrade" )
68+ r .Header .Set ("Upgrade" , "websocket" )
69+ r .Header .Set ("Sec-WebSocket-Version" , "13" )
70+ r .Header .Set ("Sec-WebSocket-Key" , "meow123" )
71+ r .Header .Set ("Sec-WebSocket-Extensions" , extensions )
72+ return r
73+ }
74+ newResponseWriter := func () http.ResponseWriter {
75+ return mockHijacker {
76+ ResponseWriter : httptest .NewRecorder (),
77+ hijack : func () (net.Conn , * bufio.ReadWriter , error ) {
78+ return nil , nil , errors .New ("hijack error" )
79+ },
80+ }
6781 }
68- r := httptest .NewRequest ("GET" , "/" , nil )
69- r .Header .Set ("Connection" , "Upgrade" )
70- r .Header .Set ("Upgrade" , "websocket" )
71- r .Header .Set ("Sec-WebSocket-Version" , "13" )
72- r .Header .Set ("Sec-WebSocket-Key" , "meow123" )
73- r .Header .Set ("Sec-WebSocket-Extensions" , "permessage-deflate; harharhar" )
7482
75- _ , err := Accept (w , r , & AcceptOptions {
76- CompressionMode : CompressionContextTakeover ,
83+ t .Run ("withoutFallback" , func (t * testing.T ) {
84+ t .Parallel ()
85+
86+ w := newResponseWriter ()
87+ r := newRequest ("permessage-deflate; harharhar" )
88+ _ , _ = Accept (w , r , & AcceptOptions {
89+ CompressionMode : CompressionNoContextTakeover ,
90+ })
91+ assert .Equal (t , "extension header" , w .Header ().Get ("Sec-WebSocket-Extensions" ), "" )
92+ })
93+ t .Run ("withFallback" , func (t * testing.T ) {
94+ t .Parallel ()
95+
96+ w := newResponseWriter ()
97+ r := newRequest ("permessage-deflate; harharhar, permessage-deflate" )
98+ _ , _ = Accept (w , r , & AcceptOptions {
99+ CompressionMode : CompressionNoContextTakeover ,
100+ })
101+ assert .Equal (t , "extension header" ,
102+ w .Header ().Get ("Sec-WebSocket-Extensions" ),
103+ CompressionNoContextTakeover .opts ().String (),
104+ )
77105 })
78- assert .Contains (t , err , `unsupported permessage-deflate parameter` )
79106 })
80107
81108 t .Run ("requireHttpHijacker" , func (t * testing.T ) {
@@ -344,42 +371,53 @@ func Test_authenticateOrigin(t *testing.T) {
344371 }
345372}
346373
347- func Test_acceptCompression (t * testing.T ) {
374+ func Test_selectDeflate (t * testing.T ) {
348375 t .Parallel ()
349376
350377 testCases := []struct {
351- name string
352- mode CompressionMode
353- reqSecWebSocketExtensions string
354- respSecWebSocketExtensions string
355- expCopts * compressionOptions
356- error bool
378+ name string
379+ mode CompressionMode
380+ header string
381+ expCopts * compressionOptions
382+ expOK bool
357383 }{
358384 {
359385 name : "disabled" ,
360386 mode : CompressionDisabled ,
361387 expCopts : nil ,
388+ expOK : false ,
362389 },
363390 {
364391 name : "noClientSupport" ,
365392 mode : CompressionNoContextTakeover ,
366393 expCopts : nil ,
394+ expOK : false ,
367395 },
368396 {
369- name : "permessage-deflate" ,
370- mode : CompressionNoContextTakeover ,
371- reqSecWebSocketExtensions : "permessage-deflate; client_max_window_bits" ,
372- respSecWebSocketExtensions : "permessage-deflate; client_no_context_takeover; server_no_context_takeover" ,
397+ name : "permessage-deflate" ,
398+ mode : CompressionNoContextTakeover ,
399+ header : "permessage-deflate; client_max_window_bits" ,
373400 expCopts : & compressionOptions {
374401 clientNoContextTakeover : true ,
375402 serverNoContextTakeover : true ,
376403 },
404+ expOK : true ,
405+ },
406+ {
407+ name : "permessage-deflate/unknown-parameter" ,
408+ mode : CompressionNoContextTakeover ,
409+ header : "permessage-deflate; meow" ,
410+ expOK : false ,
377411 },
378412 {
379- name : "permessage-deflate/error" ,
380- mode : CompressionNoContextTakeover ,
381- reqSecWebSocketExtensions : "permessage-deflate; meow" ,
382- error : true ,
413+ name : "permessage-deflate/unknown-parameter" ,
414+ mode : CompressionNoContextTakeover ,
415+ header : "permessage-deflate; meow, permessage-deflate; client_max_window_bits" ,
416+ expCopts : & compressionOptions {
417+ clientNoContextTakeover : true ,
418+ serverNoContextTakeover : true ,
419+ },
420+ expOK : true ,
383421 },
384422 // {
385423 // name: "x-webkit-deflate-frame",
@@ -404,19 +442,11 @@ func Test_acceptCompression(t *testing.T) {
404442 t .Run (tc .name , func (t * testing.T ) {
405443 t .Parallel ()
406444
407- r := httptest .NewRequest (http .MethodGet , "/" , nil )
408- r .Header .Set ("Sec-WebSocket-Extensions" , tc .reqSecWebSocketExtensions )
409-
410- w := httptest .NewRecorder ()
411- copts , err := acceptCompression (r , w , tc .mode )
412- if tc .error {
413- assert .Error (t , err )
414- return
415- }
416-
417- assert .Success (t , err )
445+ h := http.Header {}
446+ h .Set ("Sec-WebSocket-Extensions" , tc .header )
447+ copts , ok := selectDeflate (websocketExtensions (h ), tc .mode )
448+ assert .Equal (t , "selected options" , tc .expOK , ok )
418449 assert .Equal (t , "compression options" , tc .expCopts , copts )
419- assert .Equal (t , "Sec-WebSocket-Extensions" , tc .respSecWebSocketExtensions , w .Header ().Get ("Sec-WebSocket-Extensions" ))
420450 })
421451 }
422452}
0 commit comments