@@ -14,6 +14,8 @@ import (
1414 "strings"
1515 "time"
1616
17+ "golang.org/x/sync/errgroup"
18+
1719 "github.com/google/uuid"
1820 "golang.org/x/xerrors"
1921 "nhooyr.io/websocket"
@@ -317,142 +319,28 @@ func (c *Client) DialWorkspaceAgent(dialCtx context.Context, agentID uuid.UUID,
317319 q := coordinateURL .Query ()
318320 q .Add ("version" , proto .CurrentVersion .String ())
319321 coordinateURL .RawQuery = q .Encode ()
320- closedCoordinator := make (chan struct {})
321- // Must only ever be used once, send error OR close to avoid
322- // reassignment race. Buffered so we don't hang in goroutine.
323- firstCoordinator := make (chan error , 1 )
324- go func () {
325- defer close (closedCoordinator )
326- isFirst := true
327- for retrier := retry .New (50 * time .Millisecond , 10 * time .Second ); retrier .Wait (ctx ); {
328- options .Logger .Debug (ctx , "connecting" )
329- // nolint:bodyclose
330- ws , res , err := websocket .Dial (ctx , coordinateURL .String (), & websocket.DialOptions {
331- HTTPClient : c .HTTPClient ,
332- HTTPHeader : headers ,
333- // Need to disable compression to avoid a data-race.
334- CompressionMode : websocket .CompressionDisabled ,
335- })
336- if isFirst {
337- if res != nil && res .StatusCode == http .StatusConflict {
338- firstCoordinator <- ReadBodyAsError (res )
339- return
340- }
341- isFirst = false
342- close (firstCoordinator )
343- }
344- if err != nil {
345- if errors .Is (err , context .Canceled ) {
346- return
347- }
348- options .Logger .Debug (ctx , "failed to dial" , slog .Error (err ))
349- continue
350- }
351- client , err := tailnet .NewDRPCClient (websocket .NetConn (ctx , ws , websocket .MessageBinary ))
352- if err != nil {
353- options .Logger .Debug (ctx , "failed to create DRPCClient" , slog .Error (err ))
354- _ = ws .Close (websocket .StatusInternalError , "" )
355- continue
356- }
357- coordinate , err := client .Coordinate (ctx )
358- if err != nil {
359- options .Logger .Debug (ctx , "failed to reach the Coordinate endpoint" , slog .Error (err ))
360- _ = ws .Close (websocket .StatusInternalError , "" )
361- continue
362- }
363-
364- coordination := tailnet .NewRemoteCoordination (options .Logger , coordinate , conn , agentID )
365- options .Logger .Debug (ctx , "serving coordinator" )
366- err = <- coordination .Error ()
367- if errors .Is (err , context .Canceled ) {
368- _ = ws .Close (websocket .StatusGoingAway , "" )
369- return
370- }
371- if err != nil {
372- options .Logger .Debug (ctx , "error serving coordinator" , slog .Error (err ))
373- _ = ws .Close (websocket .StatusGoingAway , "" )
374- continue
375- }
376- _ = ws .Close (websocket .StatusGoingAway , "" )
377- }
378- }()
379-
380- derpMapURL , err := c .URL .Parse ("/api/v2/derp-map" )
381- if err != nil {
382- return nil , xerrors .Errorf ("parse url: %w" , err )
383- }
384- closedDerpMap := make (chan struct {})
385- // Must only ever be used once, send error OR close to avoid
386- // reassignment race. Buffered so we don't hang in goroutine.
387- firstDerpMap := make (chan error , 1 )
388- go func () {
389- defer close (closedDerpMap )
390- isFirst := true
391- for retrier := retry .New (50 * time .Millisecond , 10 * time .Second ); retrier .Wait (ctx ); {
392- options .Logger .Debug (ctx , "connecting to server for derp map updates" )
393- // nolint:bodyclose
394- ws , res , err := websocket .Dial (ctx , derpMapURL .String (), & websocket.DialOptions {
395- HTTPClient : c .HTTPClient ,
396- HTTPHeader : headers ,
397- // Need to disable compression to avoid a data-race.
398- CompressionMode : websocket .CompressionDisabled ,
399- })
400- if isFirst {
401- if res != nil && res .StatusCode == http .StatusConflict {
402- firstDerpMap <- ReadBodyAsError (res )
403- return
404- }
405- isFirst = false
406- close (firstDerpMap )
407- }
408- if err != nil {
409- if errors .Is (err , context .Canceled ) {
410- return
411- }
412- options .Logger .Debug (ctx , "failed to dial" , slog .Error (err ))
413- continue
414- }
415-
416- var (
417- nconn = websocket .NetConn (ctx , ws , websocket .MessageBinary )
418- dec = json .NewDecoder (nconn )
419- )
420- for {
421- var derpMap tailcfg.DERPMap
422- err := dec .Decode (& derpMap )
423- if xerrors .Is (err , context .Canceled ) {
424- _ = ws .Close (websocket .StatusGoingAway , "" )
425- return
426- }
427- if err != nil {
428- options .Logger .Debug (ctx , "failed to decode derp map" , slog .Error (err ))
429- _ = ws .Close (websocket .StatusGoingAway , "" )
430- return
431- }
432-
433- if ! tailnet .CompareDERPMaps (conn .DERPMap (), & derpMap ) {
434- options .Logger .Debug (ctx , "updating derp map due to detected changes" )
435- conn .SetDERPMap (& derpMap )
436- }
437- }
438- }
439- }()
440322
441- for firstCoordinator != nil || firstDerpMap != nil {
442- select {
443- case <- dialCtx .Done ():
444- return nil , xerrors .Errorf ("timed out waiting for coordinator and derp map: %w" , dialCtx .Err ())
445- case err = <- firstCoordinator :
446- if err != nil {
447- return nil , xerrors .Errorf ("start coordinator: %w" , err )
448- }
449- firstCoordinator = nil
450- case err = <- firstDerpMap :
451- if err != nil {
452- return nil , xerrors .Errorf ("receive derp map: %w" , err )
453- }
454- firstDerpMap = nil
323+ connector := runTailnetAPIConnector (ctx , options .Logger ,
324+ agentID , coordinateURL .String (),
325+ & websocket.DialOptions {
326+ HTTPClient : c .HTTPClient ,
327+ HTTPHeader : headers ,
328+ // Need to disable compression to avoid a data-race.
329+ CompressionMode : websocket .CompressionDisabled ,
330+ },
331+ conn ,
332+ )
333+ options .Logger .Debug (ctx , "running tailnet API v2+ connector" )
334+
335+ select {
336+ case <- dialCtx .Done ():
337+ return nil , xerrors .Errorf ("timed out waiting for coordinator and derp map: %w" , dialCtx .Err ())
338+ case err = <- connector .connected :
339+ if err != nil {
340+ options .Logger .Error (ctx , "failed to connect to tailnet v2+ API" , slog .Error (err ))
341+ return nil , xerrors .Errorf ("start connector: %w" , err )
455342 }
343+ options .Logger .Debug (ctx , "connected to tailnet v2+ API" )
456344 }
457345
458346 agentConn = NewWorkspaceAgentConn (conn , WorkspaceAgentConnOptions {
@@ -464,8 +352,7 @@ func (c *Client) DialWorkspaceAgent(dialCtx context.Context, agentID uuid.UUID,
464352 AgentIP : WorkspaceAgentIP ,
465353 CloseFunc : func () error {
466354 cancel ()
467- <- closedCoordinator
468- <- closedDerpMap
355+ <- connector .closed
469356 return conn .Close ()
470357 },
471358 })
@@ -478,6 +365,169 @@ func (c *Client) DialWorkspaceAgent(dialCtx context.Context, agentID uuid.UUID,
478365 return agentConn , nil
479366}
480367
368+ // tailnetAPIConnector dials the tailnet API (v2+) and then uses the API with a tailnet.Conn to
369+ //
370+ // 1) run the Coordinate API and pass node information back and forth
371+ // 2) stream DERPMap updates and program the Conn
372+ //
373+ // These functions share the same websocket, and so are combined here so that if we hit a problem
374+ // we tear the whole thing down and start over with a new websocket.
375+ type tailnetAPIConnector struct {
376+ ctx context.Context
377+ logger slog.Logger
378+
379+ agentID uuid.UUID
380+ coordinateURL string
381+ dialOptions * websocket.DialOptions
382+ conn * tailnet.Conn
383+
384+ connected chan error
385+ isFirst bool
386+ closed chan struct {}
387+ }
388+
389+ // runTailnetAPIConnector creates and runs a tailnetAPIConnector
390+ func runTailnetAPIConnector (
391+ ctx context.Context , logger slog.Logger ,
392+ agentID uuid.UUID , coordinateURL string , dialOptions * websocket.DialOptions ,
393+ conn * tailnet.Conn ,
394+ ) * tailnetAPIConnector {
395+ tac := & tailnetAPIConnector {
396+ ctx : ctx ,
397+ logger : logger ,
398+ agentID : agentID ,
399+ coordinateURL : coordinateURL ,
400+ dialOptions : dialOptions ,
401+ conn : conn ,
402+ connected : make (chan error , 1 ),
403+ closed : make (chan struct {}),
404+ }
405+ go tac .run ()
406+ return tac
407+ }
408+
409+ func (tac * tailnetAPIConnector ) run () {
410+ tac .isFirst = true
411+ defer close (tac .closed )
412+ for retrier := retry .New (50 * time .Millisecond , 10 * time .Second ); retrier .Wait (tac .ctx ); {
413+ tailnetClient , err := tac .dial ()
414+ if err != nil {
415+ continue
416+ }
417+ tac .logger .Debug (tac .ctx , "obtained tailnet API v2+ client" )
418+ tac .coordinateAndDERPMap (tailnetClient )
419+ tac .logger .Debug (tac .ctx , "tailnet API v2+ connection lost" )
420+ }
421+ }
422+
423+ func (tac * tailnetAPIConnector ) dial () (proto.DRPCTailnetClient , error ) {
424+ tac .logger .Debug (tac .ctx , "dialing Coder tailnet v2+ API" )
425+ // nolint:bodyclose
426+ ws , res , err := websocket .Dial (tac .ctx , tac .coordinateURL , tac .dialOptions )
427+ if tac .isFirst {
428+ if res != nil && res .StatusCode == http .StatusConflict {
429+ err = ReadBodyAsError (res )
430+ tac .connected <- err
431+ return nil , err
432+ }
433+ tac .isFirst = false
434+ close (tac .connected )
435+ }
436+ if err != nil {
437+ if ! errors .Is (err , context .Canceled ) {
438+ tac .logger .Error (tac .ctx , "failed to dial tailnet v2+ API" , slog .Error (err ))
439+ }
440+ return nil , err
441+ }
442+ client , err := tailnet .NewDRPCClient (websocket .NetConn (tac .ctx , ws , websocket .MessageBinary ))
443+ if err != nil {
444+ tac .logger .Debug (tac .ctx , "failed to create DRPCClient" , slog .Error (err ))
445+ _ = ws .Close (websocket .StatusInternalError , "" )
446+ return nil , err
447+ }
448+ return client , err
449+ }
450+
451+ // coordinateAndDERPMap uses the provided client to coordinate and stream DERP Maps. It is combined
452+ // into one function so that a problem with one tears down the other and triggers a retry (if
453+ // appropriate). We multiplex both RPCs over the same websocket, so we want them to share the same
454+ // fate.
455+ func (tac * tailnetAPIConnector ) coordinateAndDERPMap (client proto.DRPCTailnetClient ) {
456+ defer func () {
457+ conn := client .DRPCConn ()
458+ closeErr := conn .Close ()
459+ if closeErr != nil &&
460+ ! xerrors .Is (closeErr , io .EOF ) &&
461+ ! xerrors .Is (closeErr , context .Canceled ) &&
462+ ! xerrors .Is (closeErr , context .DeadlineExceeded ) {
463+ tac .logger .Error (tac .ctx , "error closing DRPC connection" , slog .Error (closeErr ))
464+ <- conn .Closed ()
465+ }
466+ }()
467+ eg , egCtx := errgroup .WithContext (tac .ctx )
468+ eg .Go (func () error {
469+ return tac .coordinate (egCtx , client )
470+ })
471+ eg .Go (func () error {
472+ return tac .derpMap (egCtx , client )
473+ })
474+ err := eg .Wait ()
475+ if err != nil &&
476+ ! xerrors .Is (err , io .EOF ) &&
477+ ! xerrors .Is (err , context .Canceled ) &&
478+ ! xerrors .Is (err , context .DeadlineExceeded ) {
479+ tac .logger .Error (tac .ctx , "error while connected to tailnet v2+ API" )
480+ }
481+ }
482+
483+ func (tac * tailnetAPIConnector ) coordinate (ctx context.Context , client proto.DRPCTailnetClient ) error {
484+ coord , err := client .Coordinate (ctx )
485+ if err != nil {
486+ return xerrors .Errorf ("failed to connect to Coordinate RPC: %w" , err )
487+ }
488+ defer func () {
489+ cErr := coord .Close ()
490+ if cErr != nil {
491+ tac .logger .Debug (ctx , "error closing Coordinate RPC" , slog .Error (cErr ))
492+ }
493+ }()
494+ coordination := tailnet .NewRemoteCoordination (tac .logger , coord , tac .conn , tac .agentID )
495+ tac .logger .Debug (ctx , "serving coordinator" )
496+ err = <- coordination .Error ()
497+ if err != nil &&
498+ ! xerrors .Is (err , io .EOF ) &&
499+ ! xerrors .Is (err , context .Canceled ) &&
500+ ! xerrors .Is (err , context .DeadlineExceeded ) {
501+ return xerrors .Errorf ("remote coordination error: %w" , err )
502+ }
503+ return nil
504+ }
505+
506+ func (tac * tailnetAPIConnector ) derpMap (ctx context.Context , client proto.DRPCTailnetClient ) error {
507+ s , err := client .StreamDERPMaps (ctx , & proto.StreamDERPMapsRequest {})
508+ if err != nil {
509+ return xerrors .Errorf ("failed to connect to StreamDERPMaps RPC: %w" , err )
510+ }
511+ defer func () {
512+ cErr := s .Close ()
513+ if cErr != nil {
514+ tac .logger .Debug (ctx , "error closing StreamDERPMaps RPC" , slog .Error (cErr ))
515+ }
516+ }()
517+ for {
518+ dmp , err := s .Recv ()
519+ if err != nil {
520+ if xerrors .Is (err , io .EOF ) || xerrors .Is (err , context .Canceled ) || xerrors .Is (err , context .DeadlineExceeded ) {
521+ return nil
522+ }
523+ return xerrors .Errorf ("error receiving DERP Map: %w" , err )
524+ }
525+ tac .logger .Debug (ctx , "got new DERP Map" , slog .F ("derp_map" , dmp ))
526+ dm := tailnet .DERPMapFromProto (dmp )
527+ tac .conn .SetDERPMap (dm )
528+ }
529+ }
530+
481531// WatchWorkspaceAgentMetadata watches the metadata of a workspace agent.
482532// The returned channel will be closed when the context is canceled. Exactly
483533// one error will be sent on the error channel. The metadata channel is never closed.
0 commit comments