diff --git a/go.mod b/go.mod index 4b05999ab..9eb4c9ff6 100644 --- a/go.mod +++ b/go.mod @@ -22,7 +22,7 @@ require ( github.com/sirupsen/logrus v1.9.3 github.com/skycoin/noise v0.0.0-20180327030543-2492fe189ae6 github.com/skycoin/skycoin v0.28.1-0.20250823221707-c533551dfabd //DO NOT MODIFY OR UPDATE v0.28.1-0.20241105130348-39b49a2d0a7f - github.com/skycoin/skywire v1.3.31-0.20250724153549-ec7ca3554d42 + github.com/skycoin/skywire v1.3.31-0.20250810155428-30d83a379b39 github.com/spf13/cobra v1.9.1 github.com/stretchr/testify v1.10.0 golang.org/x/net v0.43.0 @@ -30,6 +30,8 @@ require ( golang.org/x/term v0.34.0 ) +require github.com/xtaci/smux v1.5.34 + require ( github.com/Azure/go-ansiterm v0.0.0-20250102033503-faa5f7b0171c // indirect github.com/bytedance/sonic v1.14.0 // indirect diff --git a/go.sum b/go.sum index fc6b09d52..eef8b660e 100644 --- a/go.sum +++ b/go.sum @@ -118,8 +118,8 @@ github.com/skycoin/noise v0.0.0-20180327030543-2492fe189ae6 h1:1Nc5EBY6pjfw1kwW0 github.com/skycoin/noise v0.0.0-20180327030543-2492fe189ae6/go.mod h1:UXghlricA7J3aRD/k7p/zBObQfmBawwCxIVPVjz2Q3o= github.com/skycoin/skycoin v0.28.1-0.20250823221707-c533551dfabd h1:yKo1t3+P78TcCZvWqEJDV7DAB162C3qVHDKLjB8b2hA= github.com/skycoin/skycoin v0.28.1-0.20250823221707-c533551dfabd/go.mod h1:9w5J+CJ7fWwkmpttrQ2SFksiSPc0t0DtwsCdXLdl4Qg= -github.com/skycoin/skywire v1.3.31-0.20250724153549-ec7ca3554d42 h1:9Hr/ht404g8fDo80Bw9YIPwu0IuDKrG3mRkZeH6y/Vc= -github.com/skycoin/skywire v1.3.31-0.20250724153549-ec7ca3554d42/go.mod h1:JnR5EJHpryaFFILpPpFJybtUT+0+2/aQxSxgjKvsPZs= +github.com/skycoin/skywire v1.3.31-0.20250810155428-30d83a379b39 h1:6+YIdW2rrU9ZDCvigTP9j/oUGcWgEzPPNFM0yo7Z2F0= +github.com/skycoin/skywire v1.3.31-0.20250810155428-30d83a379b39/go.mod h1:8fUvhqqo54SR0lMlUpGX/qnXSjKZEQw6TFYsXzwBAqg= github.com/spf13/cobra v1.4.0/go.mod h1:Wo4iy3BUC+X2Fybo0PDqwJIv3dNRiZLHQymsfxlB84g= github.com/spf13/cobra v1.9.1 h1:CXSaggrXdbHK9CF+8ywj8Amf7PBRmPCOJugH954Nnlo= github.com/spf13/cobra v1.9.1/go.mod h1:nDyEzZ8ogv936Cinf6g1RU9MRY64Ir93oCnqb9wxYW0= @@ -145,6 +145,8 @@ github.com/valyala/fastrand v1.1.0 h1:f+5HkLW4rsgzdNoleUOB69hyT9IlD2ZQh9GyDMfb5G github.com/valyala/fastrand v1.1.0/go.mod h1:HWqCzkrkg6QXT8V2EXWvXCoow7vLwOFN002oeRzjapQ= github.com/valyala/histogram v1.2.0 h1:wyYGAZZt3CpwUiIb9AU/Zbllg1llXyrtApRS815OLoQ= github.com/valyala/histogram v1.2.0/go.mod h1:Hb4kBwb4UxsaNbbbh+RRz8ZR6pdodR57tzWUS3BUzXY= +github.com/xtaci/smux v1.5.34 h1:OUA9JaDFHJDT8ZT3ebwLWPAgEfE6sWo2LaTy3anXqwg= +github.com/xtaci/smux v1.5.34/go.mod h1:OMlQbT5vcgl2gb49mFkYo6SMf+zP3rcjcwQz7ZU7IGY= golang.org/x/arch v0.20.0 h1:dx1zTU0MAE98U+TQ8BLl7XsJbgze2WnNKF/8tGp/Q6c= golang.org/x/arch v0.20.0/go.mod h1:bdwinDaKcfZUGpH09BB7ZmOfhalA8lQdzl62l8gGWsk= golang.org/x/crypto v0.41.0 h1:WKYxWedPGCTVVl5+WHSSrOBT0O8lx32+zxmHxijgXp4= diff --git a/pkg/disc/entry.go b/pkg/disc/entry.go index 705b4b373..dbb286a35 100644 --- a/pkg/disc/entry.go +++ b/pkg/disc/entry.go @@ -123,6 +123,9 @@ type Entry struct { // Signature for proving authenticity of an Entry. Signature string `json:"signature,omitempty"` + + // Protocol is the lib that use for multiplexing. + Protocol string `json:"protocol,omitempty"` } func (e *Entry) String() string { diff --git a/pkg/dmsg/client.go b/pkg/dmsg/client.go index 0a9d2e2b7..6129a5249 100644 --- a/pkg/dmsg/client.go +++ b/pkg/dmsg/client.go @@ -10,9 +10,11 @@ import ( "sync" "time" + "github.com/hashicorp/yamux" "github.com/skycoin/skywire/pkg/skywire-utilities/pkg/cipher" "github.com/skycoin/skywire/pkg/skywire-utilities/pkg/logging" "github.com/skycoin/skywire/pkg/skywire-utilities/pkg/netutil" + "github.com/xtaci/smux" "golang.org/x/net/proxy" "github.com/skycoin/dmsg/pkg/disc" @@ -47,6 +49,7 @@ type Config struct { Callbacks *ClientCallbacks ClientType string ConnectedServersType string + Protocol string } // Ensure ensures all config values are set. @@ -156,6 +159,8 @@ func (ce *Client) Serve(ctx context.Context) { updateEntryLoopOnce := new(sync.Once) + needInitialPost := true + for { if isClosed(ce.done) { return @@ -208,6 +213,18 @@ func (ce *Client) Serve(ctx context.Context) { rand.Shuffle(len(entries), func(i, j int) { entries[i], entries[j] = entries[j], entries[i] }) + + if needInitialPost { + // use this for put protocol type of client to disc, for dicision part of dmsg-server + err = ce.initilizeClientEntry(cancellabelCtx, ce.conf.ClientType, ce.conf.Protocol) + if err != nil { + ce.log.WithError(err).Warn("Initial post entry failed") + } else { + ce.log.WithError(err).Info("Initial post entry successed") + } + needInitialPost = false + } + for n, entry := range entries { if isClosed(ce.done) { return @@ -490,7 +507,7 @@ func (ce *Client) EnsureSession(ctx context.Context, entry *disc.Entry) error { ce.log.WithField("remote_pk", entry.Static).Debug("Session already exists...") return nil } - + entry.Protocol = ce.conf.Protocol // Dial session. _, err := ce.dialSession(ctx, entry) return err @@ -537,6 +554,19 @@ func (ce *Client) dialSession(ctx context.Context, entry *disc.Entry) (cs Client if err != nil { return ClientSession{}, err } + if entry.Protocol == "smux" { + dSes.sm.smux, err = smux.Client(conn, smux.DefaultConfig()) + if err != nil { + return ClientSession{}, err + } + ce.log.Infof("smux stream session initial for %s", dSes.RemotePK().String()) + } else { + dSes.sm.yamux, err = yamux.Client(conn, yamux.DefaultConfig()) + if err != nil { + return ClientSession{}, err + } + ce.log.Infof("yamux stream session initial for %s", dSes.RemotePK().String()) + } if !ce.setSession(ctx, dSes.SessionCommon) { _ = dSes.Close() //nolint:errcheck diff --git a/pkg/dmsg/entity_common.go b/pkg/dmsg/entity_common.go index eeb83556e..2e6ce82c4 100644 --- a/pkg/dmsg/entity_common.go +++ b/pkg/dmsg/entity_common.go @@ -225,6 +225,29 @@ func (c *EntityCommon) updateServerEntryLoop(ctx context.Context, addr string, m } } +func (c *EntityCommon) initilizeClientEntry(ctx context.Context, clientType string, protocol string) (err error) { + // Record last update on success. + defer func() { + if err == nil { + c.recordUpdate() + } + }() + + srvPKs := make([]cipher.PubKey, 0, len(c.sessions)) + + _, err = c.dc.Entry(ctx, c.pk) + if err != nil { + entry := disc.NewClientEntry(c.pk, 0, srvPKs) + entry.ClientType = clientType + entry.Protocol = protocol + if err := entry.Sign(c.sk); err != nil { + return err + } + return c.dc.PostEntry(ctx, entry) + } + return nil +} + func (c *EntityCommon) updateClientEntry(ctx context.Context, done chan struct{}, clientType string) (err error) { if isClosed(done) { return nil @@ -295,6 +318,17 @@ func (c *EntityCommon) updateClientEntryLoop(ctx context.Context, done chan stru } } +func (c *EntityCommon) entryProtocol(ctx context.Context, pk cipher.PubKey) string { + entry, err := c.dc.Entry(ctx, pk) + if err != nil { + c.log.WithField("entry", entry).WithError(err).Warn("Entry not found, so return empty as protocol.\n") + return "" + } + + c.log.WithField("entry", entry).Debug("Entry's protocol fetch.\n") + return entry.Protocol +} + func (c *EntityCommon) delEntry(ctx context.Context) (err error) { entry, err := c.dc.Entry(ctx, c.pk) diff --git a/pkg/dmsg/server.go b/pkg/dmsg/server.go index 77105b83c..ece1bc399 100644 --- a/pkg/dmsg/server.go +++ b/pkg/dmsg/server.go @@ -7,9 +7,11 @@ import ( "sync" "time" + "github.com/hashicorp/yamux" "github.com/skycoin/skywire/pkg/skywire-utilities/pkg/cipher" "github.com/skycoin/skywire/pkg/skywire-utilities/pkg/logging" "github.com/skycoin/skywire/pkg/skywire-utilities/pkg/netutil" + "github.com/xtaci/smux" "github.com/skycoin/dmsg/internal/servermetrics" "github.com/skycoin/dmsg/pkg/disc" @@ -214,7 +216,6 @@ func (s *Server) handleSession(conn net.Conn) { } return } - log = log.WithField("remote_pk", dSes.RemotePK()) log.Info("Started session.") @@ -223,6 +224,27 @@ func (s *Server) handleSession(conn net.Conn) { awaitDone(ctx, s.done) log.WithError(dSes.Close()).Info("Stopped session.") }() + // detect visor protocol for dmsg + protocol := s.entryProtocol(ctx, dSes.RemotePK()) + + // based on protocol, create smux or yamux stream session + if protocol == "smux" { + dSes.sm.smux, err = smux.Server(conn, smux.DefaultConfig()) + if err != nil { + cancel() + return + } + dSes.sm.addr = dSes.sm.smux.RemoteAddr() + log.Infof("smux stream session initial for %s", dSes.RemotePK().String()) + } else { + dSes.sm.yamux, err = yamux.Server(conn, yamux.DefaultConfig()) + if err != nil { + cancel() + return + } + dSes.sm.addr = dSes.sm.yamux.RemoteAddr() + log.Infof("yamux stream session initial for %s", dSes.RemotePK().String()) + } if s.setSession(ctx, dSes.SessionCommon) { dSes.Serve() diff --git a/pkg/dmsg/server_session.go b/pkg/dmsg/server_session.go index e92aa28b5..fd15e52db 100644 --- a/pkg/dmsg/server_session.go +++ b/pkg/dmsg/server_session.go @@ -9,6 +9,7 @@ import ( "github.com/hashicorp/yamux" "github.com/sirupsen/logrus" "github.com/skycoin/skywire/pkg/skywire-utilities/pkg/netutil" + "github.com/xtaci/smux" "github.com/skycoin/dmsg/internal/servermetrics" "github.com/skycoin/dmsg/pkg/noise" @@ -44,30 +45,54 @@ func (ss *ServerSession) Close() error { func (ss *ServerSession) Serve() { ss.m.RecordSession(servermetrics.DeltaConnect) // record successful connection defer ss.m.RecordSession(servermetrics.DeltaDisconnect) // record disconnection - - for { - yStr, err := ss.ys.AcceptStream() - if err != nil { - switch err { - case yamux.ErrSessionShutdown, io.EOF: - ss.log.WithError(err).Info("Stopping session...") - default: - ss.log.WithError(err).Warn("Failed to accept stream, stopping session...") + if ss.sm.smux != nil { + for { + sStr, err := ss.sm.smux.AcceptStream() + if err != nil { + switch err { + case io.EOF: + ss.log.WithError(err).Info("Stopping session...") + default: + ss.log.WithError(err).Warn("Failed to accept stream, stopping session...") + } + return } - return + + log := ss.log.WithField("smux_id", sStr.ID()) + log.Info("Initiating stream.") + + go func(sStr *smux.Stream) { + err := ss.serveStream(log, sStr, ss.sm.addr) + log.WithError(err).Info("Stopped stream.") + }(sStr) } + } else { + for { + yStr, err := ss.sm.yamux.AcceptStream() + if err != nil { + switch err { + case yamux.ErrSessionShutdown, io.EOF: + ss.log.WithError(err).Info("Stopping session...") + default: + ss.log.WithError(err).Warn("Failed to accept stream, stopping session...") + } + return + } - log := ss.log.WithField("yamux_id", yStr.StreamID()) - log.Info("Initiating stream.") + log := ss.log.WithField("yamux_id", yStr.StreamID()) + log.Info("Initiating stream.") - go func(yStr *yamux.Stream) { - err := ss.serveStream(log, yStr) - log.WithError(err).Info("Stopped stream.") - }(yStr) + go func(yStr *yamux.Stream) { + err := ss.serveStream(log, yStr, ss.sm.addr) + log.WithError(err).Info("Stopped stream.") + }(yStr) + } } } -func (ss *ServerSession) serveStream(log logrus.FieldLogger, yStr *yamux.Stream) error { +// struct + +func (ss *ServerSession) serveStream(log logrus.FieldLogger, yStr io.ReadWriteCloser, addr net.Addr) error { readRequest := func() (StreamRequest, error) { obj, err := ss.readObject(yStr) if err != nil { @@ -102,7 +127,7 @@ func (ss *ServerSession) serveStream(log logrus.FieldLogger, yStr *yamux.Stream) if req.IPinfo && req.DstAddr.PK == ss.entity.LocalPK() { log.Debug("Received IP stream request.") - ip, err := addrToIP(yStr.RemoteAddr()) + ip, err := addrToIP(addr) if err != nil { ss.m.RecordStream(servermetrics.DeltaFailed) // record failed stream return err @@ -164,22 +189,27 @@ func addrToIP(addr net.Addr) (net.IP, error) { } } -func (ss *ServerSession) forwardRequest(req StreamRequest) (yStr *yamux.Stream, respObj SignedObject, err error) { +func (ss *ServerSession) forwardRequest(req StreamRequest) (mStr io.ReadWriteCloser, respObj SignedObject, err error) { defer func() { - if err != nil && yStr != nil { + if err != nil && mStr != nil { ss.log. - WithError(yStr.Close()). + WithError(mStr.Close()). Debugf("After forwardRequest failed, the yamux stream is closed.") } }() - - if yStr, err = ss.ys.OpenStream(); err != nil { - return nil, nil, err + if ss.sm.smux != nil { + if mStr, err = ss.sm.smux.OpenStream(); err != nil { + return nil, nil, err + } + } else { + if mStr, err = ss.sm.yamux.OpenStream(); err != nil { + return nil, nil, err + } } - if err = ss.writeObject(yStr, req.raw); err != nil { + if err = ss.writeObject(mStr, req.raw); err != nil { return nil, nil, err } - if respObj, err = ss.readObject(yStr); err != nil { + if respObj, err = ss.readObject(mStr); err != nil { return nil, nil, err } var resp StreamResponse @@ -189,5 +219,5 @@ func (ss *ServerSession) forwardRequest(req StreamRequest) (yStr *yamux.Stream, if err = resp.Verify(req); err != nil { return nil, nil, err } - return yStr, respObj, nil + return mStr, respObj, nil } diff --git a/pkg/dmsg/session_common.go b/pkg/dmsg/session_common.go index 69903a106..e9ac8c4a0 100644 --- a/pkg/dmsg/session_common.go +++ b/pkg/dmsg/session_common.go @@ -13,6 +13,7 @@ import ( "github.com/hashicorp/yamux" "github.com/sirupsen/logrus" "github.com/skycoin/skywire/pkg/skywire-utilities/pkg/cipher" + "github.com/xtaci/smux" "github.com/skycoin/dmsg/pkg/noise" ) @@ -24,15 +25,24 @@ type SessionCommon struct { rPK cipher.PubKey // remote pk netConn net.Conn // underlying net.Conn (TCP connection to the dmsg server) - ys *yamux.Session - ns *noise.Noise - nMap noise.NonceMap - rMx sync.Mutex - wMx sync.Mutex + // ys *yamux.Session + // ss *smux.Session + sm SessionManager + ns *noise.Noise + nMap noise.NonceMap + rMx sync.Mutex + wMx sync.Mutex log logrus.FieldLogger } +// SessionManager blablabla +type SessionManager struct { + yamux *yamux.Session + smux *smux.Session + addr net.Addr +} + // GetConn returns underlying TCP `net.Conn`. func (sc *SessionCommon) GetConn() net.Conn { return sc.netConn @@ -70,16 +80,9 @@ func (sc *SessionCommon) initClient(entity *EntityCommon, conn net.Conn, rPK cip if rw.Buffered() > 0 { return ErrSessionHandshakeExtraBytes } - - ySes, err := yamux.Client(conn, yamux.DefaultConfig()) - if err != nil { - return err - } - sc.entity = entity sc.rPK = rPK sc.netConn = conn - sc.ys = ySes sc.ns = ns sc.nMap = make(noise.NonceMap) sc.log = entity.log.WithField("session", ns.RemoteStatic()) @@ -104,15 +107,9 @@ func (sc *SessionCommon) initServer(entity *EntityCommon, conn net.Conn) error { return ErrSessionHandshakeExtraBytes } - ySes, err := yamux.Server(conn, yamux.DefaultConfig()) - if err != nil { - return err - } - sc.entity = entity sc.rPK = ns.RemoteStatic() sc.netConn = conn - sc.ys = ySes sc.ns = ns sc.nMap = make(noise.NonceMap) sc.log = entity.log.WithField("session", ns.RemoteStatic()) @@ -170,14 +167,25 @@ func (sc *SessionCommon) LocalTCPAddr() net.Addr { return sc.netConn.LocalAddr() func (sc *SessionCommon) RemoteTCPAddr() net.Addr { return sc.netConn.RemoteAddr() } // Ping obtains the round trip latency of the session. -func (sc *SessionCommon) Ping() (time.Duration, error) { return sc.ys.Ping() } +func (sc *SessionCommon) Ping() (time.Duration, error) { + if sc.sm.yamux != nil { + return sc.sm.yamux.Ping() + } + return 0, fmt.Errorf("Ping not available on SMUX protocol") +} // Close closes the session. func (sc *SessionCommon) Close() error { if sc == nil { return nil } - err := sc.ys.Close() + var err error + if sc.sm.smux != nil { + err = sc.sm.smux.Close() + } + if sc.sm.yamux != nil { + err = sc.sm.yamux.Close() + } sc.rMx.Lock() sc.nMap = nil sc.rMx.Unlock() diff --git a/pkg/dmsg/stream.go b/pkg/dmsg/stream.go index 69f1dfed9..2c255ab77 100644 --- a/pkg/dmsg/stream.go +++ b/pkg/dmsg/stream.go @@ -9,6 +9,7 @@ import ( "github.com/hashicorp/yamux" "github.com/sirupsen/logrus" "github.com/skycoin/skywire/pkg/skywire-utilities/pkg/cipher" + "github.com/xtaci/smux" "github.com/skycoin/dmsg/pkg/noise" ) @@ -17,7 +18,7 @@ import ( type Stream struct { ses *ClientSession // back reference yStr *yamux.Stream - + sStr *smux.Stream // The following fields are to be filled after handshake. lAddr Addr rAddr Addr @@ -28,15 +29,30 @@ type Stream struct { } func newInitiatingStream(cSes *ClientSession) (*Stream, error) { - yStr, err := cSes.ys.OpenStream() + if cSes.sm.smux != nil { + sStr, err := cSes.sm.smux.OpenStream() + if err != nil { + return nil, err + } + return &Stream{ses: cSes, sStr: sStr}, nil + } + yStr, err := cSes.sm.yamux.OpenStream() if err != nil { return nil, err } return &Stream{ses: cSes, yStr: yStr}, nil + } func newRespondingStream(cSes *ClientSession) (*Stream, error) { - yStr, err := cSes.ys.AcceptStream() + if cSes.sm.smux != nil { + sStr, err := cSes.sm.smux.AcceptStream() + if err != nil { + return nil, err + } + return &Stream{ses: cSes, sStr: sStr}, nil + } + yStr, err := cSes.sm.yamux.AcceptStream() if err != nil { return nil, err } @@ -51,6 +67,9 @@ func (s *Stream) Close() error { if s.close != nil { s.close() } + if s.sStr != nil { + return s.sStr.Close() + } return s.yStr.Close() } @@ -83,6 +102,10 @@ func (s *Stream) writeRequest(rAddr Addr) (req StreamRequest, err error) { obj := MakeSignedStreamRequest(&req, s.ses.localSK()) // Write request. + if s.sStr != nil { + err = s.ses.writeObject(s.sStr, obj) + return + } err = s.ses.writeObject(s.yStr, obj) return } @@ -106,15 +129,26 @@ func (s *Stream) writeIPRequest(rAddr Addr) (req StreamRequest, err error) { obj := MakeSignedStreamRequest(&req, s.ses.localSK()) // Write request. + if s.sStr != nil { + err = s.ses.writeObject(s.sStr, obj) + return + } err = s.ses.writeObject(s.yStr, obj) return } func (s *Stream) readRequest() (req StreamRequest, err error) { var obj SignedObject - if obj, err = s.ses.readObject(s.yStr); err != nil { - return + if s.sStr != nil { + if obj, err = s.ses.readObject(s.sStr); err != nil { + return + } + } else { + if obj, err = s.ses.readObject(s.yStr); err != nil { + return + } } + if req, err = obj.ObtainStreamRequest(); err != nil { return } @@ -158,8 +192,14 @@ func (s *Stream) writeResponse(reqHash cipher.SHA256) error { } obj := MakeSignedStreamResponse(&resp, s.ses.localSK()) - if err := s.ses.writeObject(s.yStr, obj); err != nil { - return err + if s.sStr != nil { + if err := s.ses.writeObject(s.sStr, obj); err != nil { + return err + } + } else { + if err := s.ses.writeObject(s.yStr, obj); err != nil { + return err + } } // Push stream to listener. @@ -167,9 +207,18 @@ func (s *Stream) writeResponse(reqHash cipher.SHA256) error { } func (s *Stream) readResponse(req StreamRequest) error { - obj, err := s.ses.readObject(s.yStr) - if err != nil { - return err + var obj SignedObject + var err error + if s.sStr != nil { + obj, err = s.ses.readObject(s.sStr) + if err != nil { + return err + } + } else { + obj, err = s.ses.readObject(s.yStr) + if err != nil { + return err + } } resp, err := obj.ObtainStreamResponse() if err != nil { @@ -182,9 +231,18 @@ func (s *Stream) readResponse(req StreamRequest) error { } func (s *Stream) readIPResponse(req StreamRequest) (net.IP, error) { - obj, err := s.ses.readObject(s.yStr) - if err != nil { - return nil, err + var obj SignedObject + var err error + if s.sStr != nil { + obj, err = s.ses.readObject(s.sStr) + if err != nil { + return nil, err + } + } else { + obj, err = s.ses.readObject(s.yStr) + if err != nil { + return nil, err + } } resp, err := obj.ObtainStreamResponse() if err != nil { @@ -210,7 +268,11 @@ func (s *Stream) prepareFields(init bool, lAddr, rAddr Addr) { s.lAddr = lAddr s.rAddr = rAddr s.ns = ns - s.nsConn = noise.NewReadWriter(s.yStr, s.ns) + if s.sStr != nil { + s.nsConn = noise.NewReadWriter(s.sStr, s.ns) + } else { + s.nsConn = noise.NewReadWriter(s.yStr, s.ns) + } s.log = s.ses.log.WithField("stream", s.lAddr.ShortString()+"->"+s.rAddr.ShortString()) } @@ -241,6 +303,9 @@ func (s *Stream) ServerPK() cipher.PubKey { // StreamID returns the stream ID. func (s *Stream) StreamID() uint32 { + if s.sStr != nil { + return s.sStr.ID() + } return s.yStr.StreamID() } @@ -256,15 +321,24 @@ func (s *Stream) Write(b []byte) (int, error) { // SetDeadline implements net.Conn func (s *Stream) SetDeadline(t time.Time) error { + if s.sStr != nil { + return s.sStr.SetDeadline(t) + } return s.yStr.SetDeadline(t) } // SetReadDeadline implements net.Conn func (s *Stream) SetReadDeadline(t time.Time) error { + if s.sStr != nil { + return s.sStr.SetReadDeadline(t) + } return s.yStr.SetReadDeadline(t) } // SetWriteDeadline implements net.Conn func (s *Stream) SetWriteDeadline(t time.Time) error { + if s.sStr != nil { + return s.sStr.SetWriteDeadline(t) + } return s.yStr.SetWriteDeadline(t) } diff --git a/vendor/github.com/xtaci/smux/.gitignore b/vendor/github.com/xtaci/smux/.gitignore new file mode 100644 index 000000000..daf913b1b --- /dev/null +++ b/vendor/github.com/xtaci/smux/.gitignore @@ -0,0 +1,24 @@ +# Compiled Object files, Static and Dynamic libs (Shared Objects) +*.o +*.a +*.so + +# Folders +_obj +_test + +# Architecture specific extensions/prefixes +*.[568vq] +[568vq].out + +*.cgo1.go +*.cgo2.c +_cgo_defun.c +_cgo_gotypes.go +_cgo_export.* + +_testmain.go + +*.exe +*.test +*.prof diff --git a/vendor/github.com/xtaci/smux/.travis.yml b/vendor/github.com/xtaci/smux/.travis.yml new file mode 100644 index 000000000..0d37d1168 --- /dev/null +++ b/vendor/github.com/xtaci/smux/.travis.yml @@ -0,0 +1,20 @@ +arch: + - amd64 + - ppc64le +language: go +go: + - 1.9.x + - 1.10.x + - 1.11.x + +before_install: + - go get -t -v ./... + +install: + - go get github.com/xtaci/smux + +script: + - go test -coverprofile=coverage.txt -covermode=atomic -bench . + +after_success: + - bash <(curl -s https://codecov.io/bash) diff --git a/vendor/github.com/xtaci/smux/LICENSE b/vendor/github.com/xtaci/smux/LICENSE new file mode 100644 index 000000000..d072f023a --- /dev/null +++ b/vendor/github.com/xtaci/smux/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2016-2017 xtaci + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/vendor/github.com/xtaci/smux/README.md b/vendor/github.com/xtaci/smux/README.md new file mode 100644 index 000000000..c3306515e --- /dev/null +++ b/vendor/github.com/xtaci/smux/README.md @@ -0,0 +1,136 @@ +smux + +[![GoDoc][1]][2] [![MIT licensed][3]][4] [![Build Status][5]][6] [![Go Report Card][7]][8] [![Coverage Statusd][9]][10] [![Sourcegraph][11]][12] + +smux + +[1]: https://godoc.org/github.com/xtaci/smux?status.svg +[2]: https://godoc.org/github.com/xtaci/smux +[3]: https://img.shields.io/badge/license-MIT-blue.svg +[4]: LICENSE +[5]: https://img.shields.io/github/created-at/xtaci/smux +[6]: https://img.shields.io/github/created-at/xtaci/smux +[7]: https://goreportcard.com/badge/github.com/xtaci/smux +[8]: https://goreportcard.com/report/github.com/xtaci/smux +[9]: https://codecov.io/gh/xtaci/smux/branch/master/graph/badge.svg +[10]: https://codecov.io/gh/xtaci/smux +[11]: https://sourcegraph.com/github.com/xtaci/smux/-/badge.svg +[12]: https://sourcegraph.com/github.com/xtaci/smux?badge + +## Introduction + +Smux ( **S**imple **MU**ltiple**X**ing) is a multiplexing library for Golang. It relies on an underlying connection to provide reliability and ordering, such as TCP or [KCP](https://github.com/xtaci/kcp-go), and provides stream-oriented multiplexing. The original intention of this library is to power the connection management for [kcp-go](https://github.com/xtaci/kcp-go). + +## Features + +1. ***Token bucket*** controlled receiving, which provides smoother bandwidth graph(see picture below). +2. Session-wide receive buffer, shared among streams, **fully controlled** overall memory usage. +3. Minimized header(8Bytes), maximized payload. +4. Well-tested on millions of devices in [kcptun](https://github.com/xtaci/kcptun). +5. Builtin fair queue traffic shaping. +6. Per-stream sliding window to control congestion.(protocol version 2+). + +![smooth bandwidth curve](assets/curve.jpg) + +## Documentation + +For complete documentation, see the associated [Godoc](https://godoc.org/github.com/xtaci/smux). + +## Benchmark +``` +$ go test -v -run=^$ -bench . +goos: darwin +goarch: amd64 +pkg: github.com/xtaci/smux +BenchmarkMSB-4 30000000 51.8 ns/op +BenchmarkAcceptClose-4 50000 36783 ns/op +BenchmarkConnSmux-4 30000 58335 ns/op 2246.88 MB/s 1208 B/op 19 allocs/op +BenchmarkConnTCP-4 50000 25579 ns/op 5124.04 MB/s 0 B/op 0 allocs/op +PASS +ok github.com/xtaci/smux 7.811s +``` + +## Specification + +``` +VERSION(1B) | CMD(1B) | LENGTH(2B) | STREAMID(4B) | DATA(LENGTH) + +VALUES FOR LATEST VERSION: +VERSION: + 1/2 + +CMD: + cmdSYN(0) + cmdFIN(1) + cmdPSH(2) + cmdNOP(3) + cmdUPD(4) // only supported on version 2 + +STREAMID: + client use odd numbers starts from 1 + server use even numbers starts from 0 + +cmdUPD: + | CONSUMED(4B) | WINDOW(4B) | +``` + +## Usage + +```go + +func client() { + // Get a TCP connection + conn, err := net.Dial(...) + if err != nil { + panic(err) + } + + // Setup client side of smux + session, err := smux.Client(conn, nil) + if err != nil { + panic(err) + } + + // Open a new stream + stream, err := session.OpenStream() + if err != nil { + panic(err) + } + + // Stream implements io.ReadWriteCloser + stream.Write([]byte("ping")) + stream.Close() + session.Close() +} + +func server() { + // Accept a TCP connection + conn, err := listener.Accept() + if err != nil { + panic(err) + } + + // Setup server side of smux + session, err := smux.Server(conn, nil) + if err != nil { + panic(err) + } + + // Accept a stream + stream, err := session.AcceptStream() + if err != nil { + panic(err) + } + + // Listen for a message + buf := make([]byte, 4) + stream.Read(buf) + stream.Close() + session.Close() +} + +``` + +## Status + +Stable diff --git a/vendor/github.com/xtaci/smux/alloc.go b/vendor/github.com/xtaci/smux/alloc.go new file mode 100644 index 000000000..b7bcf2b29 --- /dev/null +++ b/vendor/github.com/xtaci/smux/alloc.go @@ -0,0 +1,101 @@ +// MIT License +// +// Copyright (c) 2016-2017 xtaci +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in all +// copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +// SOFTWARE. + +package smux + +import ( + "errors" + "sync" +) + +var ( + defaultAllocator *Allocator + debruijinPos = [...]byte{0, 9, 1, 10, 13, 21, 2, 29, 11, 14, 16, 18, 22, 25, 3, 30, 8, 12, 20, 28, 15, 17, 24, 7, 19, 27, 23, 6, 26, 5, 4, 31} +) + +func init() { + defaultAllocator = NewAllocator() +} + +// Allocator for incoming frames, optimized to prevent overwriting after zeroing +type Allocator struct { + buffers []sync.Pool +} + +// NewAllocator initiates a []byte allocator for frames less than 65536 bytes, +// the waste(memory fragmentation) of space allocation is guaranteed to be +// no more than 50%. +func NewAllocator() *Allocator { + alloc := new(Allocator) + alloc.buffers = make([]sync.Pool, 17) // 1B -> 64K + for k := range alloc.buffers { + i := k + alloc.buffers[k].New = func() interface{} { + b := make([]byte, 1< 65536 { + return nil + } + + bits := msb(size) + if size == 1< 65536 || cap(*p) != 1<> 1 + v |= v >> 2 + v |= v >> 4 + v |= v >> 8 + v |= v >> 16 + return debruijinPos[(v*0x07C4ACDD)>>27] +} diff --git a/vendor/github.com/xtaci/smux/frame.go b/vendor/github.com/xtaci/smux/frame.go new file mode 100644 index 000000000..902b655da --- /dev/null +++ b/vendor/github.com/xtaci/smux/frame.go @@ -0,0 +1,106 @@ +// MIT License +// +// Copyright (c) 2016-2017 xtaci +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in all +// copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +// SOFTWARE. + +package smux + +import ( + "encoding/binary" + "fmt" +) + +const ( // cmds + // protocol version 1: + cmdSYN byte = iota // stream open + cmdFIN // stream close, a.k.a EOF mark + cmdPSH // data push + cmdNOP // no operation + + // protocol version 2 extra commands + // notify bytes consumed by remote peer-end + cmdUPD +) + +const ( + // data size of cmdUPD, format: + // |4B data consumed(ACK)| 4B window size(WINDOW) | + szCmdUPD = 8 +) + +const ( + // initial peer window guess, a slow-start + initialPeerWindow = 262144 +) + +const ( + sizeOfVer = 1 + sizeOfCmd = 1 + sizeOfLength = 2 + sizeOfSid = 4 + headerSize = sizeOfVer + sizeOfCmd + sizeOfSid + sizeOfLength +) + +// Frame defines a packet from or to be multiplexed into a single connection +type Frame struct { + ver byte // version + cmd byte // command + sid uint32 // stream id + data []byte // payload +} + +// newFrame creates a new frame with given version, command and stream id +func newFrame(version byte, cmd byte, sid uint32) Frame { + return Frame{ver: version, cmd: cmd, sid: sid} +} + +// rawHeader is a byte array representation of Frame header +type rawHeader [headerSize]byte + +func (h rawHeader) Version() byte { + return h[0] +} + +func (h rawHeader) Cmd() byte { + return h[1] +} + +func (h rawHeader) Length() uint16 { + return binary.LittleEndian.Uint16(h[2:]) +} + +func (h rawHeader) StreamID() uint32 { + return binary.LittleEndian.Uint32(h[4:]) +} + +func (h rawHeader) String() string { + return fmt.Sprintf("Version:%d Cmd:%d StreamID:%d Length:%d", + h.Version(), h.Cmd(), h.StreamID(), h.Length()) +} + +// updHeader is a byte array representation of cmdUPD +type updHeader [szCmdUPD]byte + +func (h updHeader) Consumed() uint32 { + return binary.LittleEndian.Uint32(h[:]) +} +func (h updHeader) Window() uint32 { + return binary.LittleEndian.Uint32(h[4:]) +} diff --git a/vendor/github.com/xtaci/smux/mux.go b/vendor/github.com/xtaci/smux/mux.go new file mode 100644 index 000000000..39815c711 --- /dev/null +++ b/vendor/github.com/xtaci/smux/mux.go @@ -0,0 +1,128 @@ +// MIT License +// +// Copyright (c) 2016-2017 xtaci +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in all +// copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +// SOFTWARE. + +package smux + +import ( + "errors" + "fmt" + "io" + "math" + "time" +) + +// Config is used to tune the Smux session +type Config struct { + // SMUX Protocol version, support 1,2 + Version int + + // Disabled keepalive + KeepAliveDisabled bool + + // KeepAliveInterval is how often to send a NOP command to the remote + KeepAliveInterval time.Duration + + // KeepAliveTimeout is how long the session + // will be closed if no data has arrived + KeepAliveTimeout time.Duration + + // MaxFrameSize is used to control the maximum + // frame size to sent to the remote + MaxFrameSize int + + // MaxReceiveBuffer is used to control the maximum + // number of data in the buffer pool + MaxReceiveBuffer int + + // MaxStreamBuffer is used to control the maximum + // number of data per stream + MaxStreamBuffer int +} + +// DefaultConfig is used to return a default configuration +func DefaultConfig() *Config { + return &Config{ + Version: 1, + KeepAliveInterval: 10 * time.Second, + KeepAliveTimeout: 30 * time.Second, + MaxFrameSize: 32768, + MaxReceiveBuffer: 4194304, + MaxStreamBuffer: 65536, + } +} + +// VerifyConfig is used to verify the sanity of configuration +func VerifyConfig(config *Config) error { + if !(config.Version == 1 || config.Version == 2) { + return errors.New("unsupported protocol version") + } + if !config.KeepAliveDisabled { + if config.KeepAliveInterval == 0 { + return errors.New("keep-alive interval must be positive") + } + if config.KeepAliveTimeout < config.KeepAliveInterval { + return fmt.Errorf("keep-alive timeout must be larger than keep-alive interval") + } + } + if config.MaxFrameSize <= 0 { + return errors.New("max frame size must be positive") + } + if config.MaxFrameSize > 65535 { + return errors.New("max frame size must not be larger than 65535") + } + if config.MaxReceiveBuffer <= 0 { + return errors.New("max receive buffer must be positive") + } + if config.MaxStreamBuffer <= 0 { + return errors.New("max stream buffer must be positive") + } + if config.MaxStreamBuffer > config.MaxReceiveBuffer { + return errors.New("max stream buffer must not be larger than max receive buffer") + } + if config.MaxStreamBuffer > math.MaxInt32 { + return errors.New("max stream buffer cannot be larger than 2147483647") + } + return nil +} + +// Server is used to initialize a new server-side connection. +func Server(conn io.ReadWriteCloser, config *Config) (*Session, error) { + if config == nil { + config = DefaultConfig() + } + if err := VerifyConfig(config); err != nil { + return nil, err + } + return newSession(config, conn, false), nil +} + +// Client is used to initialize a new client-side connection. +func Client(conn io.ReadWriteCloser, config *Config) (*Session, error) { + if config == nil { + config = DefaultConfig() + } + + if err := VerifyConfig(config); err != nil { + return nil, err + } + return newSession(config, conn, true), nil +} diff --git a/vendor/github.com/xtaci/smux/pkg.go b/vendor/github.com/xtaci/smux/pkg.go new file mode 100644 index 000000000..e5f110a77 --- /dev/null +++ b/vendor/github.com/xtaci/smux/pkg.go @@ -0,0 +1,6 @@ +// Package smux is a multiplexing library for Golang. +// +// It relies on an underlying connection to provide reliability and ordering, such as TCP or KCP, +// and provides stream-oriented multiplexing over a single channel. + +package smux diff --git a/vendor/github.com/xtaci/smux/session.go b/vendor/github.com/xtaci/smux/session.go new file mode 100644 index 000000000..ee07d4e29 --- /dev/null +++ b/vendor/github.com/xtaci/smux/session.go @@ -0,0 +1,619 @@ +// MIT License +// +// Copyright (c) 2016-2017 xtaci +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in all +// copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +// SOFTWARE. + +package smux + +import ( + "container/heap" + "encoding/binary" + "errors" + "io" + "net" + "runtime" + "sync" + "sync/atomic" + "time" +) + +const ( + defaultAcceptBacklog = 1024 + maxShaperSize = 1024 + openCloseTimeout = 30 * time.Second // Timeout for opening/closing streams +) + +// CLASSID represents the class of a frame +type CLASSID int + +const ( + CLSCTRL CLASSID = iota // prioritized control signal + CLSDATA +) + +// timeoutError representing timeouts for operations such as accept, read and write +// +// To better cooperate with the standard library, timeoutError should implement the standard library's `net.Error`. +// +// For example, using smux to implement net.Listener and work with http.Server, the keep-alive connection (*smux.Stream) will be unexpectedly closed. +// For more details, see https://github.com/xtaci/smux/pull/99. +type timeoutError struct{} + +func (timeoutError) Error() string { return "timeout" } +func (timeoutError) Temporary() bool { return true } +func (timeoutError) Timeout() bool { return true } + +var ( + ErrInvalidProtocol = errors.New("invalid protocol") + ErrConsumed = errors.New("peer consumed more than sent") + ErrGoAway = errors.New("stream id overflows, should start a new connection") + ErrTimeout net.Error = &timeoutError{} + ErrWouldBlock = errors.New("operation would block on IO") +) + +// writeRequest represents a request to write a frame +type writeRequest struct { + class CLASSID + frame Frame + seq uint32 + result chan writeResult +} + +// writeResult represents the result of a write request +type writeResult struct { + n int + err error +} + +// Session defines a multiplexed connection for streams +type Session struct { + conn io.ReadWriteCloser + + config *Config + nextStreamID uint32 // next stream identifier + nextStreamIDLock sync.Mutex + + bucket int32 // token bucket + bucketNotify chan struct{} // used for waiting for tokens + + streams map[uint32]*stream // all streams in this session + streamLock sync.Mutex // locks streams + + die chan struct{} // flag session has died + dieOnce sync.Once + + // socket error handling + socketReadError atomic.Value + socketWriteError atomic.Value + chSocketReadError chan struct{} + chSocketWriteError chan struct{} + socketReadErrorOnce sync.Once + socketWriteErrorOnce sync.Once + + // smux protocol errors + protoError atomic.Value + chProtoError chan struct{} + protoErrorOnce sync.Once + + chAccepts chan *stream + + dataReady int32 // flag data has arrived + + goAway int32 // flag id exhausted + + deadline atomic.Value + + requestID uint32 // Monotonic increasing write request ID + shaper chan writeRequest // a shaper for writing + writes chan writeRequest +} + +func newSession(config *Config, conn io.ReadWriteCloser, client bool) *Session { + s := new(Session) + s.die = make(chan struct{}) + s.conn = conn + s.config = config + s.streams = make(map[uint32]*stream) + s.chAccepts = make(chan *stream, defaultAcceptBacklog) + s.bucket = int32(config.MaxReceiveBuffer) + s.bucketNotify = make(chan struct{}, 1) + s.shaper = make(chan writeRequest) + s.writes = make(chan writeRequest) + s.chSocketReadError = make(chan struct{}) + s.chSocketWriteError = make(chan struct{}) + s.chProtoError = make(chan struct{}) + + if client { + s.nextStreamID = 1 + } else { + s.nextStreamID = 0 + } + + go s.shaperLoop() + go s.recvLoop() + go s.sendLoop() + if !config.KeepAliveDisabled { + go s.keepalive() + } + return s +} + +// OpenStream is used to create a new stream +func (s *Session) OpenStream() (*Stream, error) { + if s.IsClosed() { + return nil, io.ErrClosedPipe + } + + // generate stream id + s.nextStreamIDLock.Lock() + if s.goAway > 0 { + s.nextStreamIDLock.Unlock() + return nil, ErrGoAway + } + + s.nextStreamID += 2 + sid := s.nextStreamID + if sid == sid%2 { // stream-id overflows + s.goAway = 1 + s.nextStreamIDLock.Unlock() + return nil, ErrGoAway + } + s.nextStreamIDLock.Unlock() + + stream := newStream(sid, s.config.MaxFrameSize, s) + + if _, err := s.writeControlFrame(newFrame(byte(s.config.Version), cmdSYN, sid)); err != nil { + return nil, err + } + + s.streamLock.Lock() + defer s.streamLock.Unlock() + select { + case <-s.chSocketReadError: + return nil, s.socketReadError.Load().(error) + case <-s.chSocketWriteError: + return nil, s.socketWriteError.Load().(error) + case <-s.die: + return nil, io.ErrClosedPipe + default: + s.streams[sid] = stream + wrapper := &Stream{stream: stream} + // NOTE(x): disabled finalizer for issue #997 + /* + runtime.SetFinalizer(wrapper, func(s *Stream) { + s.Close() + }) + */ + return wrapper, nil + } +} + +// Open returns a generic ReadWriteCloser +func (s *Session) Open() (io.ReadWriteCloser, error) { + return s.OpenStream() +} + +// AcceptStream is used to block until the next available stream +// is ready to be accepted. +func (s *Session) AcceptStream() (*Stream, error) { + var deadline <-chan time.Time + if d, ok := s.deadline.Load().(time.Time); ok && !d.IsZero() { + timer := time.NewTimer(time.Until(d)) + defer timer.Stop() + deadline = timer.C + } + + select { + case stream := <-s.chAccepts: + wrapper := &Stream{stream: stream} + runtime.SetFinalizer(wrapper, func(s *Stream) { + s.Close() + }) + return wrapper, nil + case <-deadline: + return nil, ErrTimeout + case <-s.chSocketReadError: + return nil, s.socketReadError.Load().(error) + case <-s.chProtoError: + return nil, s.protoError.Load().(error) + case <-s.die: + return nil, io.ErrClosedPipe + } +} + +// Accept Returns a generic ReadWriteCloser instead of smux.Stream +func (s *Session) Accept() (io.ReadWriteCloser, error) { + return s.AcceptStream() +} + +// Close is used to close the session and all streams. +func (s *Session) Close() error { + var once bool + s.dieOnce.Do(func() { + close(s.die) + once = true + }) + + if once { + s.streamLock.Lock() + for k := range s.streams { + s.streams[k].sessionClose() + } + s.streamLock.Unlock() + return s.conn.Close() + } else { + return io.ErrClosedPipe + } +} + +// CloseChan can be used by someone who wants to be notified immediately when this +// session is closed +func (s *Session) CloseChan() <-chan struct{} { + return s.die +} + +// notifyBucket notifies recvLoop that bucket is available +func (s *Session) notifyBucket() { + select { + case s.bucketNotify <- struct{}{}: + default: + } +} + +func (s *Session) notifyReadError(err error) { + s.socketReadErrorOnce.Do(func() { + s.socketReadError.Store(err) + close(s.chSocketReadError) + }) +} + +func (s *Session) notifyWriteError(err error) { + s.socketWriteErrorOnce.Do(func() { + s.socketWriteError.Store(err) + close(s.chSocketWriteError) + }) +} + +func (s *Session) notifyProtoError(err error) { + s.protoErrorOnce.Do(func() { + s.protoError.Store(err) + close(s.chProtoError) + }) +} + +// IsClosed does a safe check to see if we have shutdown +func (s *Session) IsClosed() bool { + select { + case <-s.die: + return true + default: + return false + } +} + +// NumStreams returns the number of currently open streams +func (s *Session) NumStreams() int { + if s.IsClosed() { + return 0 + } + s.streamLock.Lock() + defer s.streamLock.Unlock() + return len(s.streams) +} + +// SetDeadline sets a deadline used by Accept* calls. +// A zero time value disables the deadline. +func (s *Session) SetDeadline(t time.Time) error { + s.deadline.Store(t) + return nil +} + +// LocalAddr satisfies net.Conn interface +func (s *Session) LocalAddr() net.Addr { + if ts, ok := s.conn.(interface { + LocalAddr() net.Addr + }); ok { + return ts.LocalAddr() + } + return nil +} + +// RemoteAddr satisfies net.Conn interface +func (s *Session) RemoteAddr() net.Addr { + if ts, ok := s.conn.(interface { + RemoteAddr() net.Addr + }); ok { + return ts.RemoteAddr() + } + return nil +} + +// notify the session that a stream has closed +func (s *Session) streamClosed(sid uint32) { + s.streamLock.Lock() + if stream, ok := s.streams[sid]; ok { + n := stream.recycleTokens() + if n > 0 { // return remaining tokens to the bucket + if atomic.AddInt32(&s.bucket, int32(n)) > 0 { + s.notifyBucket() + } + } + delete(s.streams, sid) + } + s.streamLock.Unlock() +} + +// returnTokens is called by stream to return token after read +func (s *Session) returnTokens(n int) { + if atomic.AddInt32(&s.bucket, int32(n)) > 0 { + s.notifyBucket() + } +} + +// recvLoop keeps on reading from underlying connection if tokens are available +func (s *Session) recvLoop() { + var hdr rawHeader + var updHdr updHeader + + for { + for atomic.LoadInt32(&s.bucket) <= 0 && !s.IsClosed() { + select { + case <-s.bucketNotify: + case <-s.die: + return + } + } + + // read header first + if _, err := io.ReadFull(s.conn, hdr[:]); err == nil { + atomic.StoreInt32(&s.dataReady, 1) + if hdr.Version() != byte(s.config.Version) { + s.notifyProtoError(ErrInvalidProtocol) + return + } + sid := hdr.StreamID() + switch hdr.Cmd() { + case cmdNOP: + case cmdSYN: // stream opening + s.streamLock.Lock() + if _, ok := s.streams[sid]; !ok { + stream := newStream(sid, s.config.MaxFrameSize, s) + s.streams[sid] = stream + select { + case s.chAccepts <- stream: + case <-s.die: + } + } + s.streamLock.Unlock() + case cmdFIN: // stream closing + s.streamLock.Lock() + if stream, ok := s.streams[sid]; ok { + stream.fin() + stream.notifyReadEvent() + } + s.streamLock.Unlock() + case cmdPSH: // data frame + if hdr.Length() > 0 { + pNewbuf := defaultAllocator.Get(int(hdr.Length())) + if written, err := io.ReadFull(s.conn, *pNewbuf); err == nil { + s.streamLock.Lock() + if stream, ok := s.streams[sid]; ok { + stream.pushBytes(pNewbuf) + // a stream used some token + atomic.AddInt32(&s.bucket, -int32(written)) + stream.notifyReadEvent() + } else { + // data directed to a missing/closed stream, recycle the buffer immediately. + defaultAllocator.Put(pNewbuf) + } + s.streamLock.Unlock() + } else { + s.notifyReadError(err) + return + } + } + case cmdUPD: // a window update signal + if _, err := io.ReadFull(s.conn, updHdr[:]); err == nil { + s.streamLock.Lock() + if stream, ok := s.streams[sid]; ok { + stream.update(updHdr.Consumed(), updHdr.Window()) + } + s.streamLock.Unlock() + } else { + s.notifyReadError(err) + return + } + default: + s.notifyProtoError(ErrInvalidProtocol) + return + } + } else { + s.notifyReadError(err) + return + } + } +} + +// keepalive sends NOP frame to peer to keep the connection alive, and detect dead peers +func (s *Session) keepalive() { + tickerPing := time.NewTicker(s.config.KeepAliveInterval) + tickerTimeout := time.NewTicker(s.config.KeepAliveTimeout) + defer tickerPing.Stop() + defer tickerTimeout.Stop() + for { + select { + case <-tickerPing.C: + s.writeFrameInternal(newFrame(byte(s.config.Version), cmdNOP, 0), tickerPing.C, CLSCTRL) + s.notifyBucket() // force a signal to the recvLoop + case <-tickerTimeout.C: + if !atomic.CompareAndSwapInt32(&s.dataReady, 1, 0) { + // recvLoop may block while bucket is 0, in this case, + // session should not be closed. + if atomic.LoadInt32(&s.bucket) > 0 { + s.Close() + return + } + } + case <-s.die: + return + } + } +} + +// shaperLoop implements a priority queue for write requests, +// some control messages are prioritized over data messages +func (s *Session) shaperLoop() { + var reqs shaperHeap + var next writeRequest + var chWrite chan writeRequest + var chShaper chan writeRequest + + for { + // chWrite is not available until it has packet to send + if len(reqs) > 0 { + chWrite = s.writes + next = heap.Pop(&reqs).(writeRequest) + } else { + chWrite = nil + } + + // control heap size, chShaper is not available until packets are less than maximum allowed + if len(reqs) >= maxShaperSize { + chShaper = nil + } else { + chShaper = s.shaper + } + + // assertion on non nil + if chShaper == nil && chWrite == nil { + panic("both channel are nil") + } + + select { + case <-s.die: + return + case r := <-chShaper: + if chWrite != nil { // next is valid, reshape + heap.Push(&reqs, next) + } + heap.Push(&reqs, r) + case chWrite <- next: + } + } +} + +// sendLoop sends frames to the underlying connection +func (s *Session) sendLoop() { + var buf []byte + var n int + var err error + var vec [][]byte // vector for writeBuffers + + bw, ok := s.conn.(interface { + WriteBuffers(v [][]byte) (n int, err error) + }) + + if ok { + buf = make([]byte, headerSize) + vec = make([][]byte, 2) + } else { + buf = make([]byte, (1<<16)+headerSize) + } + + for { + select { + case <-s.die: + return + case request := <-s.writes: + buf[0] = request.frame.ver + buf[1] = request.frame.cmd + binary.LittleEndian.PutUint16(buf[2:], uint16(len(request.frame.data))) + binary.LittleEndian.PutUint32(buf[4:], request.frame.sid) + + // support for scatter-gather I/O + if len(vec) > 0 { + vec[0] = buf[:headerSize] + vec[1] = request.frame.data + n, err = bw.WriteBuffers(vec) + } else { + copy(buf[headerSize:], request.frame.data) + n, err = s.conn.Write(buf[:headerSize+len(request.frame.data)]) + } + + n -= headerSize + if n < 0 { + n = 0 + } + + result := writeResult{ + n: n, + err: err, + } + + request.result <- result + close(request.result) + + // store conn error + if err != nil { + s.notifyWriteError(err) + return + } + } + } +} + +// writeControlFrame writes the control frame to the underlying connection +// and returns the number of bytes written if successful +func (s *Session) writeControlFrame(f Frame) (n int, err error) { + timer := time.NewTimer(openCloseTimeout) + defer timer.Stop() + + return s.writeFrameInternal(f, timer.C, CLSCTRL) +} + +// internal writeFrame version to support deadline used in keepalive +func (s *Session) writeFrameInternal(f Frame, deadline <-chan time.Time, class CLASSID) (int, error) { + req := writeRequest{ + class: class, + frame: f, + seq: atomic.AddUint32(&s.requestID, 1), + result: make(chan writeResult, 1), + } + select { + case s.shaper <- req: + case <-s.die: + return 0, io.ErrClosedPipe + case <-s.chSocketWriteError: + return 0, s.socketWriteError.Load().(error) + case <-deadline: + return 0, ErrTimeout + } + + select { + case result := <-req.result: + return result.n, result.err + case <-s.die: + return 0, io.ErrClosedPipe + case <-s.chSocketWriteError: + return 0, s.socketWriteError.Load().(error) + case <-deadline: + return 0, ErrTimeout + } +} diff --git a/vendor/github.com/xtaci/smux/shaper.go b/vendor/github.com/xtaci/smux/shaper.go new file mode 100644 index 000000000..27ea4e49c --- /dev/null +++ b/vendor/github.com/xtaci/smux/shaper.go @@ -0,0 +1,56 @@ +// MIT License +// +// Copyright (c) 2016-2017 xtaci +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in all +// copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +// SOFTWARE. + +package smux + +// _itimediff returns the time difference between two uint32 values. +// The result is a signed 32-bit integer representing the difference between 'later' and 'earlier'. +func _itimediff(later, earlier uint32) int32 { + return (int32)(later - earlier) +} + +// shaperHeap is a min-heap of writeRequest. +// It orders writeRequests by class first, then by sequence number within the same class. +type shaperHeap []writeRequest + +func (h shaperHeap) Len() int { return len(h) } + +// Less determines the ordering of elements in the heap. +// Requests are ordered by their class first. If two requests have the same class, +// they are ordered by their sequence numbers. +func (h shaperHeap) Less(i, j int) bool { + if h[i].class != h[j].class { + return h[i].class < h[j].class + } + return _itimediff(h[j].seq, h[i].seq) > 0 +} + +func (h shaperHeap) Swap(i, j int) { h[i], h[j] = h[j], h[i] } +func (h *shaperHeap) Push(x interface{}) { *h = append(*h, x.(writeRequest)) } + +func (h *shaperHeap) Pop() interface{} { + old := *h + n := len(old) + x := old[n-1] + *h = old[0 : n-1] + return x +} diff --git a/vendor/github.com/xtaci/smux/stream.go b/vendor/github.com/xtaci/smux/stream.go new file mode 100644 index 000000000..6ce6ccbcb --- /dev/null +++ b/vendor/github.com/xtaci/smux/stream.go @@ -0,0 +1,619 @@ +// MIT License +// +// Copyright (c) 2016-2017 xtaci +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in all +// copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +// SOFTWARE. + +package smux + +import ( + "encoding/binary" + "io" + "net" + "sync" + "sync/atomic" + "time" +) + +// wrapper for GC +type Stream struct { + *stream +} + +// Stream implements net.Conn +type stream struct { + id uint32 // Stream identifier + sess *Session + + buffers []*[]byte // the sequential buffers of stream + heads []*[]byte // slice heads of the buffers above, kept for recycle + + bufferLock sync.Mutex // Mutex to protect access to buffers + frameSize int // Maximum frame size for the stream + + // notify a read event + chReadEvent chan struct{} + + // flag the stream has closed + die chan struct{} + dieOnce sync.Once // Ensures die channel is closed only once + + // FIN command + chFinEvent chan struct{} + finEventOnce sync.Once // Ensures chFinEvent is closed only once + + // deadlines + readDeadline atomic.Value + writeDeadline atomic.Value + + // per stream sliding window control + numRead uint32 // count num of bytes read + numWritten uint32 // count num of bytes written + incr uint32 // bytes sent since last window update + + // UPD command + peerConsumed uint32 // num of bytes the peer has consumed + peerWindow uint32 // peer window, initialized to 256KB, updated by peer + chUpdate chan struct{} // notify of remote data consuming and window update +} + +// newStream initializes and returns a new Stream. +func newStream(id uint32, frameSize int, sess *Session) *stream { + s := new(stream) + s.id = id + s.chReadEvent = make(chan struct{}, 1) + s.chUpdate = make(chan struct{}, 1) + s.frameSize = frameSize + s.sess = sess + s.die = make(chan struct{}) + s.chFinEvent = make(chan struct{}) + s.peerWindow = initialPeerWindow // set to initial window size + + return s +} + +// ID returns the stream's unique identifier. +func (s *stream) ID() uint32 { + return s.id +} + +// Read reads data from the stream into the provided buffer. +func (s *stream) Read(b []byte) (n int, err error) { + for { + n, err = s.tryRead(b) + if err == ErrWouldBlock { + if ew := s.waitRead(); ew != nil { + return 0, ew + } + } else { + return n, err + } + } +} + +// tryRead attempts to read data from the stream without blocking. +func (s *stream) tryRead(b []byte) (n int, err error) { + if s.sess.config.Version == 2 { + return s.tryReadv2(b) + } + + if len(b) == 0 { + return 0, nil + } + + // A critical section to copy data from buffers to + s.bufferLock.Lock() + if len(s.buffers) > 0 { + n = copy(b, *s.buffers[0]) + s.buffers[0] = s.buffers[0] + *s.buffers[0] = (*s.buffers[0])[n:] + if len(*s.buffers[0]) == 0 { + s.buffers[0] = nil + s.buffers = s.buffers[1:] + // full recycle + defaultAllocator.Put(s.heads[0]) + s.heads = s.heads[1:] + } + } + s.bufferLock.Unlock() + + if n > 0 { + s.sess.returnTokens(n) + return n, nil + } + + select { + case <-s.die: + return 0, io.EOF + default: + return 0, ErrWouldBlock + } +} + +// tryReadv2 is the non-blocking version of Read for version 2 streams. +func (s *stream) tryReadv2(b []byte) (n int, err error) { + if len(b) == 0 { + return 0, nil + } + + var notifyConsumed uint32 + s.bufferLock.Lock() + if len(s.buffers) > 0 { + n = copy(b, *s.buffers[0]) + s.buffers[0] = s.buffers[0] + *s.buffers[0] = (*s.buffers[0])[n:] + if len(*s.buffers[0]) == 0 { + s.buffers[0] = nil + s.buffers = s.buffers[1:] + // full recycle + defaultAllocator.Put(s.heads[0]) + s.heads = s.heads[1:] + } + } + + // in an ideal environment: + // if more than half of buffer has consumed, send read ack to peer + // based on round-trip time of ACK, continous flowing data + // won't slow down due to waiting for ACK, as long as the + // consumer keeps on reading data. + // + // s.numRead == n implies that it's the initial reading + s.numRead += uint32(n) + s.incr += uint32(n) + + // for initial reading, send window update + if s.incr >= uint32(s.sess.config.MaxStreamBuffer/2) || s.numRead == uint32(n) { + notifyConsumed = s.numRead + s.incr = 0 // reset couting for next window update + } + s.bufferLock.Unlock() + + if n > 0 { + s.sess.returnTokens(n) + + // send window update if necessary + if notifyConsumed > 0 { + err := s.sendWindowUpdate(notifyConsumed) + return n, err + } else { + return n, nil + } + } + + select { + case <-s.die: + return 0, io.EOF + default: + return 0, ErrWouldBlock + } +} + +// WriteTo implements io.WriteTo +// WriteTo writes data to w until there's no more data to write or when an error occurs. +// The return value n is the number of bytes written. Any error encountered during the write is also returned. +// WriteTo calls Write in a loop until there is no more data to write or when an error occurs. +// If the underlying stream is a v2 stream, it will send window update to peer when necessary. +// If the underlying stream is a v1 stream, it will not send window update to peer. +func (s *stream) WriteTo(w io.Writer) (n int64, err error) { + if s.sess.config.Version == 2 { + return s.writeTov2(w) + } + + for { + var pbuf *[]byte + s.bufferLock.Lock() + if len(s.buffers) > 0 { + pbuf = s.buffers[0] + s.buffers = s.buffers[1:] + s.heads = s.heads[1:] + } + s.bufferLock.Unlock() + + if pbuf != nil { + nw, ew := w.Write(*pbuf) + // NOTE: WriteTo is a reader, so we need to return tokens here + s.sess.returnTokens(len(*pbuf)) + defaultAllocator.Put(pbuf) + if nw > 0 { + n += int64(nw) + } + + if ew != nil { + return n, ew + } + } else if ew := s.waitRead(); ew != nil { + return n, ew + } + } +} + +// check comments in WriteTo +func (s *stream) writeTov2(w io.Writer) (n int64, err error) { + for { + var notifyConsumed uint32 + var pbuf *[]byte + s.bufferLock.Lock() + if len(s.buffers) > 0 { + pbuf = s.buffers[0] + s.buffers = s.buffers[1:] + s.heads = s.heads[1:] + } + var bufLen uint32 + if pbuf != nil { + bufLen = uint32(len(*pbuf)) + } + s.numRead += bufLen + s.incr += bufLen + if s.incr >= uint32(s.sess.config.MaxStreamBuffer/2) || s.numRead == bufLen { + notifyConsumed = s.numRead + s.incr = 0 + } + s.bufferLock.Unlock() + + if pbuf != nil { + nw, ew := w.Write(*pbuf) + // NOTE: WriteTo is a reader, so we need to return tokens here + s.sess.returnTokens(len(*pbuf)) + defaultAllocator.Put(pbuf) + if nw > 0 { + n += int64(nw) + } + + if ew != nil { + return n, ew + } + + if notifyConsumed > 0 { + if err := s.sendWindowUpdate(notifyConsumed); err != nil { + return n, err + } + } + } else if ew := s.waitRead(); ew != nil { + return n, ew + } + } +} + +// sendWindowUpdate sends a window update frame to the peer. +func (s *stream) sendWindowUpdate(consumed uint32) error { + var timer *time.Timer + var deadline <-chan time.Time + if d, ok := s.readDeadline.Load().(time.Time); ok && !d.IsZero() { + timer = time.NewTimer(time.Until(d)) + defer timer.Stop() + deadline = timer.C + } + + frame := newFrame(byte(s.sess.config.Version), cmdUPD, s.id) + var hdr updHeader + binary.LittleEndian.PutUint32(hdr[:], consumed) + binary.LittleEndian.PutUint32(hdr[4:], uint32(s.sess.config.MaxStreamBuffer)) + frame.data = hdr[:] + _, err := s.sess.writeFrameInternal(frame, deadline, CLSCTRL) + return err +} + +// waitRead blocks until a read event occurs or a deadline is reached. +func (s *stream) waitRead() error { + var timer *time.Timer + var deadline <-chan time.Time + if d, ok := s.readDeadline.Load().(time.Time); ok && !d.IsZero() { + timer = time.NewTimer(time.Until(d)) + defer timer.Stop() + deadline = timer.C + } + + select { + case <-s.chReadEvent: // notify some data has arrived, or closed + return nil + case <-s.chFinEvent: + // BUGFIX(xtaci): Fix for https://github.com/xtaci/smux/issues/82 + s.bufferLock.Lock() + defer s.bufferLock.Unlock() + if len(s.buffers) > 0 { + return nil + } + return io.EOF + case <-s.sess.chSocketReadError: + return s.sess.socketReadError.Load().(error) + case <-s.sess.chProtoError: + return s.sess.protoError.Load().(error) + case <-deadline: + return ErrTimeout + case <-s.die: + return io.ErrClosedPipe + } + +} + +// Write implements net.Conn +// +// Note that the behavior when multiple goroutines write concurrently is not deterministic, +// frames may interleave in random way. +func (s *stream) Write(b []byte) (n int, err error) { + if s.sess.config.Version == 2 { + return s.writeV2(b) + } + + var deadline <-chan time.Time + if d, ok := s.writeDeadline.Load().(time.Time); ok && !d.IsZero() { + timer := time.NewTimer(time.Until(d)) + defer timer.Stop() + deadline = timer.C + } + + // check if stream has closed + select { + case <-s.chFinEvent: // passive closing + return 0, io.EOF + case <-s.die: + return 0, io.ErrClosedPipe + default: + } + + // frame split and transmit + sent := 0 + frame := newFrame(byte(s.sess.config.Version), cmdPSH, s.id) + bts := b + for len(bts) > 0 { + sz := len(bts) + if sz > s.frameSize { + sz = s.frameSize + } + frame.data = bts[:sz] + bts = bts[sz:] + n, err := s.sess.writeFrameInternal(frame, deadline, CLSDATA) + s.numWritten++ + sent += n + if err != nil { + return sent, err + } + } + + return sent, nil +} + +// writeV2 writes data to the stream for version 2 streams. +func (s *stream) writeV2(b []byte) (n int, err error) { + // check empty input + if len(b) == 0 { + return 0, nil + } + + // check if stream has closed + select { + case <-s.chFinEvent: + return 0, io.EOF + case <-s.die: + return 0, io.ErrClosedPipe + default: + } + + // create write deadline timer + var deadline <-chan time.Time + if d, ok := s.writeDeadline.Load().(time.Time); ok && !d.IsZero() { + timer := time.NewTimer(time.Until(d)) + defer timer.Stop() + deadline = timer.C + } + + // frame split and transmit process + sent := 0 + frame := newFrame(byte(s.sess.config.Version), cmdPSH, s.id) + + for { + // per stream sliding window control + // [.... [consumed... numWritten] ... win... ] + // [.... [consumed...................+rmtwnd]] + var bts []byte + // note: + // even if uint32 overflow, this math still works: + // eg1: uint32(0) - uint32(math.MaxUint32) = 1 + // eg2: int32(uint32(0) - uint32(1)) = -1 + // + // basicially, you can take it as a MODULAR ARITHMETIC + inflight := int32(atomic.LoadUint32(&s.numWritten) - atomic.LoadUint32(&s.peerConsumed)) + if inflight < 0 { // security check for malformed data + return 0, ErrConsumed + } + + // make sure you understand 'win' is calculated in modular arithmetic(2^32(4GB)) + win := int32(atomic.LoadUint32(&s.peerWindow)) - inflight + + if win > 0 { + // determine how many bytes to send + if win > int32(len(b)) { + bts = b + b = nil + } else { + bts = b[:win] + b = b[win:] + } + + // frame split and transmit + for len(bts) > 0 { + // splitting frame + sz := len(bts) + if sz > s.frameSize { + sz = s.frameSize + } + frame.data = bts[:sz] + bts = bts[sz:] + + // transmit of frame + n, err := s.sess.writeFrameInternal(frame, deadline, CLSDATA) + atomic.AddUint32(&s.numWritten, uint32(sz)) + sent += n + if err != nil { + return sent, err + } + } + } + + // if there is any data left to be sent, + // wait until stream closes, window changes or deadline reached + // this blocking behavior will back propagate flow control to upper layer. + if len(b) > 0 { + select { + case <-s.chFinEvent: + return 0, io.EOF + case <-s.die: + return sent, io.ErrClosedPipe + case <-deadline: + return sent, ErrTimeout + case <-s.sess.chSocketWriteError: + return sent, s.sess.socketWriteError.Load().(error) + case <-s.chUpdate: // notify of remote data consuming and window update + continue + } + } else { + return sent, nil + } + } +} + +// Close implements net.Conn +func (s *stream) Close() error { + var once bool + var err error + s.dieOnce.Do(func() { + close(s.die) + once = true + }) + + if once { + // send FIN in order + f := newFrame(byte(s.sess.config.Version), cmdFIN, s.id) + + timer := time.NewTimer(openCloseTimeout) + defer timer.Stop() + + _, err = s.sess.writeFrameInternal(f, timer.C, CLSDATA) + s.sess.streamClosed(s.id) + return err + } else { + return io.ErrClosedPipe + } +} + +// GetDieCh returns a readonly chan which can be readable +// when the stream is to be closed. +func (s *stream) GetDieCh() <-chan struct{} { + return s.die +} + +// SetReadDeadline sets the read deadline as defined by +// net.Conn.SetReadDeadline. +// A zero time value disables the deadline. +func (s *stream) SetReadDeadline(t time.Time) error { + s.readDeadline.Store(t) + s.notifyReadEvent() + return nil +} + +// SetWriteDeadline sets the write deadline as defined by +// net.Conn.SetWriteDeadline. +// A zero time value disables the deadline. +func (s *stream) SetWriteDeadline(t time.Time) error { + s.writeDeadline.Store(t) + return nil +} + +// SetDeadline sets both read and write deadlines as defined by +// net.Conn.SetDeadline. +// A zero time value disables the deadlines. +func (s *stream) SetDeadline(t time.Time) error { + if err := s.SetReadDeadline(t); err != nil { + return err + } + if err := s.SetWriteDeadline(t); err != nil { + return err + } + return nil +} + +// session closes +func (s *stream) sessionClose() { s.dieOnce.Do(func() { close(s.die) }) } + +// LocalAddr satisfies net.Conn interface +func (s *stream) LocalAddr() net.Addr { + if ts, ok := s.sess.conn.(interface { + LocalAddr() net.Addr + }); ok { + return ts.LocalAddr() + } + return nil +} + +// RemoteAddr satisfies net.Conn interface +func (s *stream) RemoteAddr() net.Addr { + if ts, ok := s.sess.conn.(interface { + RemoteAddr() net.Addr + }); ok { + return ts.RemoteAddr() + } + return nil +} + +// pushBytes append buf to buffers +func (s *stream) pushBytes(pbuf *[]byte) (written int, err error) { + s.bufferLock.Lock() + s.buffers = append(s.buffers, pbuf) + s.heads = append(s.heads, pbuf) + s.bufferLock.Unlock() + return +} + +// recycleTokens transform remaining bytes to tokens(will truncate buffer) +func (s *stream) recycleTokens() (n int) { + s.bufferLock.Lock() + for k := range s.buffers { + n += len(*s.buffers[k]) + defaultAllocator.Put(s.heads[k]) + } + s.buffers = nil + s.heads = nil + s.bufferLock.Unlock() + return +} + +// notify read event +func (s *stream) notifyReadEvent() { + select { + case s.chReadEvent <- struct{}{}: + default: + } +} + +// update command +func (s *stream) update(consumed uint32, window uint32) { + atomic.StoreUint32(&s.peerConsumed, consumed) + atomic.StoreUint32(&s.peerWindow, window) + select { + case s.chUpdate <- struct{}{}: + default: + } +} + +// mark this stream has been closed in protocol +func (s *stream) fin() { + s.finEventOnce.Do(func() { + close(s.chFinEvent) + }) +} diff --git a/vendor/modules.txt b/vendor/modules.txt index 811c0f3ca..5b13a851f 100644 --- a/vendor/modules.txt +++ b/vendor/modules.txt @@ -206,7 +206,7 @@ github.com/skycoin/skycoin/src/cipher/ripemd160 github.com/skycoin/skycoin/src/cipher/secp256k1-go github.com/skycoin/skycoin/src/cipher/secp256k1-go/secp256k1-go2 github.com/skycoin/skycoin/src/util/logging -# github.com/skycoin/skywire v1.3.31-0.20250724153549-ec7ca3554d42 +# github.com/skycoin/skywire v1.3.31-0.20250810155428-30d83a379b39 ## explicit; go 1.24 github.com/skycoin/skywire/deployment github.com/skycoin/skywire/pkg/skywire-utilities/pkg/buildinfo @@ -258,6 +258,9 @@ github.com/valyala/fastrand # github.com/valyala/histogram v1.2.0 ## explicit; go 1.12 github.com/valyala/histogram +# github.com/xtaci/smux v1.5.34 +## explicit; go 1.13 +github.com/xtaci/smux # golang.org/x/arch v0.20.0 ## explicit; go 1.23.0 golang.org/x/arch/x86/x86asm