@@ -51,9 +51,8 @@ type Conn struct {
5151	br              * bufio.Reader 
5252	bw              * bufio.Writer 
5353
54- 	readTimeout      chan  context.Context 
55- 	writeTimeout     chan  context.Context 
56- 	timeoutLoopDone  chan  struct {}
54+ 	readTimeoutStop   atomic.Pointer [func () bool ]
55+ 	writeTimeoutStop  atomic.Pointer [func () bool ]
5756
5857	// Read state. 
5958	readMu          * mu 
@@ -113,10 +112,6 @@ func newConn(cfg connConfig) *Conn {
113112		br : cfg .br ,
114113		bw : cfg .bw ,
115114
116- 		readTimeout :     make (chan  context.Context ),
117- 		writeTimeout :    make (chan  context.Context ),
118- 		timeoutLoopDone : make (chan  struct {}),
119- 
120115		closed :         make (chan  struct {}),
121116		activePings :    make (map [string ]chan <-  struct {}),
122117		onPingReceived : cfg .onPingReceived ,
@@ -144,8 +139,6 @@ func newConn(cfg connConfig) *Conn {
144139		c .close ()
145140	})
146141
147- 	go  c .timeoutLoop ()
148- 
149142	return  c 
150143}
151144
@@ -175,27 +168,34 @@ func (c *Conn) close() error {
175168	return  err 
176169}
177170
178- func  (c  * Conn ) timeoutLoop () {
179- 	defer  close (c .timeoutLoopDone )
171+ func  (c  * Conn ) setupWriteTimeout (ctx  context.Context ) {
172+ 	stop  :=  context .AfterFunc (ctx , func () {
173+ 		c .clearWriteTimeout ()
174+ 		c .close ()
175+ 	})
176+ 	swapTimeoutStop (& c .writeTimeoutStop , & stop )
177+ }
180178
181- 	readCtx  :=  context .Background ()
182- 	writeCtx  :=  context .Background ()
179+ func  (c  * Conn ) clearWriteTimeout () {
180+ 	swapTimeoutStop (& c .writeTimeoutStop , nil )
181+ }
183182
184- 	for  {
185- 		select  {
186- 		case  <- c .closed :
187- 			return 
188- 
189- 		case  writeCtx  =  <- c .writeTimeout :
190- 		case  readCtx  =  <- c .readTimeout :
191- 
192- 		case  <- readCtx .Done ():
193- 			c .close ()
194- 			return 
195- 		case  <- writeCtx .Done ():
196- 			c .close ()
197- 			return 
198- 		}
183+ func  (c  * Conn ) setupReadTimeout (ctx  context.Context ) {
184+ 	stop  :=  context .AfterFunc (ctx , func () {
185+ 		c .clearReadTimeout ()
186+ 		c .close ()
187+ 	})
188+ 	swapTimeoutStop (& c .readTimeoutStop , & stop )
189+ }
190+ 
191+ func  (c  * Conn ) clearReadTimeout () {
192+ 	swapTimeoutStop (& c .readTimeoutStop , nil )
193+ }
194+ 
195+ func  swapTimeoutStop (p  * atomic.Pointer [func () bool ], newStop  * func () bool ) {
196+ 	oldStop  :=  p .Swap (newStop )
197+ 	if  oldStop  !=  nil  {
198+ 		(* oldStop )()
199199	}
200200}
201201
0 commit comments