@@ -49,30 +49,11 @@ func (c *Conn) Write(ctx context.Context, typ MessageType, p []byte) error {
4949}
5050
5151type msgWriter struct {
52- mw * msgWriterState
53- closed bool
54- }
55-
56- func (mw * msgWriter ) Write (p []byte ) (int , error ) {
57- if mw .closed {
58- return 0 , errors .New ("cannot use closed writer" )
59- }
60- return mw .mw .Write (p )
61- }
62-
63- func (mw * msgWriter ) Close () error {
64- if mw .closed {
65- return errors .New ("cannot use closed writer" )
66- }
67- mw .closed = true
68- return mw .mw .Close ()
69- }
70-
71- type msgWriterState struct {
7252 c * Conn
7353
7454 mu * mu
7555 writeMu * mu
56+ closed bool
7657
7758 ctx context.Context
7859 opcode opcode
@@ -82,16 +63,16 @@ type msgWriterState struct {
8263 flateWriter * flate.Writer
8364}
8465
85- func newMsgWriterState (c * Conn ) * msgWriterState {
86- mw := & msgWriterState {
66+ func newMsgWriter (c * Conn ) * msgWriter {
67+ mw := & msgWriter {
8768 c : c ,
8869 mu : newMu (c ),
8970 writeMu : newMu (c ),
9071 }
9172 return mw
9273}
9374
94- func (mw * msgWriterState ) ensureFlate () {
75+ func (mw * msgWriter ) ensureFlate () {
9576 if mw .trimWriter == nil {
9677 mw .trimWriter = & trimLastFourBytesWriter {
9778 w : util .WriterFunc (mw .write ),
@@ -104,22 +85,19 @@ func (mw *msgWriterState) ensureFlate() {
10485 mw .flate = true
10586}
10687
107- func (mw * msgWriterState ) flateContextTakeover () bool {
88+ func (mw * msgWriter ) flateContextTakeover () bool {
10889 if mw .c .client {
10990 return ! mw .c .copts .clientNoContextTakeover
11091 }
11192 return ! mw .c .copts .serverNoContextTakeover
11293}
11394
11495func (c * Conn ) writer (ctx context.Context , typ MessageType ) (io.WriteCloser , error ) {
115- err := c .msgWriterState .reset (ctx , typ )
96+ err := c .msgWriter .reset (ctx , typ )
11697 if err != nil {
11798 return nil , err
11899 }
119- return & msgWriter {
120- mw : c .msgWriterState ,
121- closed : false ,
122- }, nil
100+ return c .msgWriter , nil
123101}
124102
125103func (c * Conn ) write (ctx context.Context , typ MessageType , p []byte ) (int , error ) {
@@ -129,8 +107,8 @@ func (c *Conn) write(ctx context.Context, typ MessageType, p []byte) (int, error
129107 }
130108
131109 if ! c .flate () {
132- defer c .msgWriterState .mu .unlock ()
133- return c .writeFrame (ctx , true , false , c .msgWriterState .opcode , p )
110+ defer c .msgWriter .mu .unlock ()
111+ return c .writeFrame (ctx , true , false , c .msgWriter .opcode , p )
134112 }
135113
136114 n , err := mw .Write (p )
@@ -142,7 +120,7 @@ func (c *Conn) write(ctx context.Context, typ MessageType, p []byte) (int, error
142120 return n , err
143121}
144122
145- func (mw * msgWriterState ) reset (ctx context.Context , typ MessageType ) error {
123+ func (mw * msgWriter ) reset (ctx context.Context , typ MessageType ) error {
146124 err := mw .mu .lock (ctx )
147125 if err != nil {
148126 return err
@@ -151,21 +129,26 @@ func (mw *msgWriterState) reset(ctx context.Context, typ MessageType) error {
151129 mw .ctx = ctx
152130 mw .opcode = opcode (typ )
153131 mw .flate = false
132+ mw .closed = false
154133
155134 mw .trimWriter .reset ()
156135
157136 return nil
158137}
159138
160- func (mw * msgWriterState ) putFlateWriter () {
139+ func (mw * msgWriter ) putFlateWriter () {
161140 if mw .flateWriter != nil {
162141 putFlateWriter (mw .flateWriter )
163142 mw .flateWriter = nil
164143 }
165144}
166145
167146// Write writes the given bytes to the WebSocket connection.
168- func (mw * msgWriterState ) Write (p []byte ) (_ int , err error ) {
147+ func (mw * msgWriter ) Write (p []byte ) (_ int , err error ) {
148+ if mw .closed {
149+ return 0 , errors .New ("cannot use closed writer" )
150+ }
151+
169152 err = mw .writeMu .lock (mw .ctx )
170153 if err != nil {
171154 return 0 , fmt .Errorf ("failed to write: %w" , err )
@@ -194,7 +177,7 @@ func (mw *msgWriterState) Write(p []byte) (_ int, err error) {
194177 return mw .write (p )
195178}
196179
197- func (mw * msgWriterState ) write (p []byte ) (int , error ) {
180+ func (mw * msgWriter ) write (p []byte ) (int , error ) {
198181 n , err := mw .c .writeFrame (mw .ctx , false , mw .flate , mw .opcode , p )
199182 if err != nil {
200183 return n , fmt .Errorf ("failed to write data frame: %w" , err )
@@ -204,9 +187,14 @@ func (mw *msgWriterState) write(p []byte) (int, error) {
204187}
205188
206189// Close flushes the frame to the connection.
207- func (mw * msgWriterState ) Close () (err error ) {
190+ func (mw * msgWriter ) Close () (err error ) {
208191 defer errd .Wrap (& err , "failed to close writer" )
209192
193+ if mw .closed {
194+ return errors .New ("writer already closed" )
195+ }
196+ mw .closed = true
197+
210198 err = mw .writeMu .lock (mw .ctx )
211199 if err != nil {
212200 return err
@@ -232,7 +220,7 @@ func (mw *msgWriterState) Close() (err error) {
232220 return nil
233221}
234222
235- func (mw * msgWriterState ) close () {
223+ func (mw * msgWriter ) close () {
236224 if mw .c .client {
237225 mw .c .writeFrameMu .forceLock ()
238226 putBufioWriter (mw .c .bw )
0 commit comments