66 "io"
77 "math"
88 "net"
9- "sync"
9+ "sync/atomic "
1010 "time"
1111)
1212
@@ -28,9 +28,10 @@ import (
2828//
2929// Close will close the *websocket.Conn with StatusNormalClosure.
3030//
31- // When a deadline is hit, the connection will be closed. This is
32- // different from most net.Conn implementations where only the
33- // reading/writing goroutines are interrupted but the connection is kept alive.
31+ // When a deadline is hit and there is an active read or write goroutine, the
32+ // connection will be closed. This is different from most net.Conn implementations
33+ // where only the reading/writing goroutines are interrupted but the connection
34+ // is kept alive.
3435//
3536// The Addr methods will return a mock net.Addr that returns "websocket" for Network
3637// and "websocket/unknown-addr" for String.
@@ -41,17 +42,43 @@ func NetConn(ctx context.Context, c *Conn, msgType MessageType) net.Conn {
4142 nc := & netConn {
4243 c : c ,
4344 msgType : msgType ,
45+ readMu : newMu (c ),
46+ writeMu : newMu (c ),
4447 }
4548
46- var cancel context.CancelFunc
47- nc .writeContext , cancel = context .WithCancel (ctx )
48- nc .writeTimer = time .AfterFunc (math .MaxInt64 , cancel )
49+ var writeCancel context.CancelFunc
50+ nc .writeCtx , writeCancel = context .WithCancel (ctx )
51+ var readCancel context.CancelFunc
52+ nc .readCtx , readCancel = context .WithCancel (ctx )
53+
54+ nc .writeTimer = time .AfterFunc (math .MaxInt64 , func () {
55+ if ! nc .writeMu .tryLock () {
56+ // If the lock cannot be acquired, then there is an
57+ // active write goroutine and so we should cancel the context.
58+ writeCancel ()
59+ return
60+ }
61+ defer nc .writeMu .unlock ()
62+
63+ // Prevents future writes from writing until the deadline is reset.
64+ atomic .StoreInt64 (& nc .writeExpired , 1 )
65+ })
4966 if ! nc .writeTimer .Stop () {
5067 <- nc .writeTimer .C
5168 }
5269
53- nc .readContext , cancel = context .WithCancel (ctx )
54- nc .readTimer = time .AfterFunc (math .MaxInt64 , cancel )
70+ nc .readTimer = time .AfterFunc (math .MaxInt64 , func () {
71+ if ! nc .readMu .tryLock () {
72+ // If the lock cannot be acquired, then there is an
73+ // active read goroutine and so we should cancel the context.
74+ readCancel ()
75+ return
76+ }
77+ defer nc .readMu .unlock ()
78+
79+ // Prevents future reads from reading until the deadline is reset.
80+ atomic .StoreInt64 (& nc .readExpired , 1 )
81+ })
5582 if ! nc .readTimer .Stop () {
5683 <- nc .readTimer .C
5784 }
@@ -64,59 +91,72 @@ type netConn struct {
6491 msgType MessageType
6592
6693 writeTimer * time.Timer
67- writeContext context.Context
94+ writeMu * mu
95+ writeExpired int64
96+ writeCtx context.Context
6897
6998 readTimer * time.Timer
70- readContext context. Context
71-
72- readMu sync. Mutex
73- eofed bool
74- reader io.Reader
99+ readMu * mu
100+ readExpired int64
101+ readCtx context. Context
102+ readEOFed bool
103+ reader io.Reader
75104}
76105
77106var _ net.Conn = & netConn {}
78107
79- func (c * netConn ) Close () error {
80- return c .c .Close (StatusNormalClosure , "" )
108+ func (nc * netConn ) Close () error {
109+ return nc .c .Close (StatusNormalClosure , "" )
81110}
82111
83- func (c * netConn ) Write (p []byte ) (int , error ) {
84- err := c .c .Write (c .writeContext , c .msgType , p )
112+ func (nc * netConn ) Write (p []byte ) (int , error ) {
113+ nc .writeMu .forceLock ()
114+ defer nc .writeMu .unlock ()
115+
116+ if atomic .LoadInt64 (& nc .writeExpired ) == 1 {
117+ return 0 , fmt .Errorf ("failed to write: %w" , context .DeadlineExceeded )
118+ }
119+
120+ err := nc .c .Write (nc .writeCtx , nc .msgType , p )
85121 if err != nil {
86122 return 0 , err
87123 }
88124 return len (p ), nil
89125}
90126
91- func (c * netConn ) Read (p []byte ) (int , error ) {
92- c .readMu .Lock ()
93- defer c .readMu .Unlock ()
127+ func (nc * netConn ) Read (p []byte ) (int , error ) {
128+ nc .readMu .forceLock ()
129+ defer nc .readMu .unlock ()
130+
131+ if atomic .LoadInt64 (& nc .readExpired ) == 1 {
132+ return 0 , fmt .Errorf ("failed to read: %w" , context .DeadlineExceeded )
133+ }
94134
95- if c . eofed {
135+ if nc . readEOFed {
96136 return 0 , io .EOF
97137 }
98138
99- if c .reader == nil {
100- typ , r , err := c .c .Reader (c . readContext )
139+ if nc .reader == nil {
140+ typ , r , err := nc .c .Reader (nc . readCtx )
101141 if err != nil {
102142 switch CloseStatus (err ) {
103143 case StatusNormalClosure , StatusGoingAway :
104- c . eofed = true
144+ nc . readEOFed = true
105145 return 0 , io .EOF
106146 }
107147 return 0 , err
108148 }
109- if typ != c .msgType {
110- err := fmt .Errorf ("unexpected frame type read (expected %v): %v" , c .msgType , typ )
111- c .c .Close (StatusUnsupportedData , err .Error ())
149+ if typ != nc .msgType {
150+ err := fmt .Errorf ("unexpected frame type read (expected %v): %v" , nc .msgType , typ )
151+ nc .c .Close (StatusUnsupportedData , err .Error ())
112152 return 0 , err
113153 }
114- c .reader = r
154+ nc .reader = r
115155 }
116156
117- n , err := c .reader .Read (p )
157+ n , err := nc .reader .Read (p )
118158 if err == io .EOF {
119- c .reader = nil
159+ nc .reader = nil
120160 err = nil
121161 }
122162 return n , err
@@ -133,34 +173,36 @@ func (a websocketAddr) String() string {
133173 return "websocket/unknown-addr"
134174}
135175
136- func (c * netConn ) RemoteAddr () net.Addr {
176+ func (nc * netConn ) RemoteAddr () net.Addr {
137177 return websocketAddr {}
138178}
139179
140- func (c * netConn ) LocalAddr () net.Addr {
180+ func (nc * netConn ) LocalAddr () net.Addr {
141181 return websocketAddr {}
142182}
143183
144- func (c * netConn ) SetDeadline (t time.Time ) error {
145- c .SetWriteDeadline (t )
146- c .SetReadDeadline (t )
184+ func (nc * netConn ) SetDeadline (t time.Time ) error {
185+ nc .SetWriteDeadline (t )
186+ nc .SetReadDeadline (t )
147187 return nil
148188}
149189
150- func (c * netConn ) SetWriteDeadline (t time.Time ) error {
190+ func (nc * netConn ) SetWriteDeadline (t time.Time ) error {
191+ atomic .StoreInt64 (& nc .writeExpired , 0 )
151192 if t .IsZero () {
152- c .writeTimer .Stop ()
193+ nc .writeTimer .Stop ()
153194 } else {
154- c .writeTimer .Reset (t .Sub (time .Now ()))
195+ nc .writeTimer .Reset (t .Sub (time .Now ()))
155196 }
156197 return nil
157198}
158199
159- func (c * netConn ) SetReadDeadline (t time.Time ) error {
200+ func (nc * netConn ) SetReadDeadline (t time.Time ) error {
201+ atomic .StoreInt64 (& nc .readExpired , 0 )
160202 if t .IsZero () {
161- c .readTimer .Stop ()
203+ nc .readTimer .Stop ()
162204 } else {
163- c .readTimer .Reset (t .Sub (time .Now ()))
205+ nc .readTimer .Reset (t .Sub (time .Now ()))
164206 }
165207 return nil
166208}
0 commit comments