From 3632c35798fc3aa59285e52083b81590e293a299 Mon Sep 17 00:00:00 2001 From: halulu Date: Sat, 2 Jun 2018 19:25:03 +0800 Subject: [PATCH 1/8] simple mux --- core/client.go | 47 ++++++----- core/mux.go | 195 ++++++++++++++++++++++++++++++++++++++++++++++ core/server.go | 7 ++ core/websocket.go | 1 + websocks.go | 10 ++- 5 files changed, 241 insertions(+), 19 deletions(-) create mode 100644 core/mux.go diff --git a/core/client.go b/core/client.go index d69c488..225fa6d 100644 --- a/core/client.go +++ b/core/client.go @@ -17,6 +17,10 @@ type Client struct { ListenAddr *net.TCPAddr URL *url.URL + Mux bool + MuxS []*Mux + MuxTCPConn map[uint64]*net.TCPConn + Dialer *websocket.Dialer CreatedAt time.Time @@ -25,6 +29,16 @@ type Client struct { func (client *Client) Listen() (err error) { logger.SetLogLevel(client.LogLevel) + if client.Mux { + mux, err := client.DialMux() + if err != nil { + return err + } + + client.MuxS = []*Mux{mux} + logger.Debugf("mux ok") + } + listener, err := net.ListenTCP("tcp", client.ListenAddr) if err != nil { return err @@ -68,18 +82,30 @@ func (client *Client) handleConn(conn *net.TCPConn) (err error) { return } - logger.Debugf("Host: %s", host) - _, err = conn.Write([]byte{0x05, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x08, 0x43}) if err != nil { return } - ws, err := client.dial(host) + if client.Mux { + client.ClientHandleMux(conn, host) + return + } + + wsConn, _, err := client.Dialer.Dial(client.URL.String(), map[string][]string{ + "WebSocks-Host": {host}, + }) + if err != nil { return } + logger.Debugf("host: %s", host) + + ws := &WebSocket{ + conn: wsConn, + } + go func() { _, err = io.Copy(ws, conn) if err != nil { @@ -96,18 +122,3 @@ func (client *Client) handleConn(conn *net.TCPConn) (err error) { return } - -func (client *Client) dial(host string) (ws *WebSocket, err error) { - conn, _, err := client.Dialer.Dial(client.URL.String(), map[string][]string{ - "WebSocks-Host": {host}, - }) - - if err != nil { - return - } - - ws = &WebSocket{ - conn: conn, - } - return -} diff --git a/core/mux.go b/core/mux.go new file mode 100644 index 0000000..4c1673d --- /dev/null +++ b/core/mux.go @@ -0,0 +1,195 @@ +package core + +import ( + "encoding/gob" + "math/rand" + "net" + "time" +) + +type Mux struct { + WS *WebSocket + Decoder *gob.Decoder + Encoder *gob.Encoder +} + +type MuxRequest struct { + ID uint64 + Method string + Data []byte +} + +type MuxResponse struct { + ID uint64 + Data []byte +} + +//DialMux dial new mux conn, listen and write to local conn +func (client *Client) DialMux() (mux *Mux, err error) { + conn, _, err := client.Dialer.Dial(client.URL.String(), map[string][]string{ + "WebSocks-Mux": {"mux"}, + }) + + if err != nil { + return + } + + ws := &WebSocket{ + conn: conn, + } + + mux = &Mux{ + WS: ws, + Decoder: gob.NewDecoder(ws), + Encoder: gob.NewEncoder(ws), + } + + go func() { + for { + resp := &MuxResponse{} + err := mux.Decoder.Decode(resp) + if err != nil { + logger.Errorf(err.Error()) + return + } + + conn := client.MuxTCPConn[resp.ID] + _, err = conn.Write(resp.Data) + if err != nil { + logger.Errorf(err.Error()) + continue + } + } + }() + + return +} + +func (server *Server) ServerHandleMux(ws *WebSocket) { + dec := gob.NewDecoder(ws) + enc := gob.NewEncoder(ws) + + for { + req := &MuxRequest{} + err := dec.Decode(req) + if err != nil { + logger.Errorf(err.Error()) + continue + } + + go server.ServerHandleMuxRequest(req, enc) + } + return +} + +func (server *Server) ServerHandleMuxRequest(req *MuxRequest, enc *gob.Encoder) { + id := req.ID + method := req.Method + var err error + + //If method is "new", dial new and listen + if method == "new" { + host := string(req.Data) + conn, err := net.Dial("tcp", host) + if err != nil { + logger.Errorf(err.Error()) + return + } + + server.MuxConn[id] = conn + logger.Debugf("dialed %s, id %d", host, id) + + go func() { + data := make([]byte, 32*1024) + for { + n, err := conn.Read(data) + if err != nil { + logger.Errorf(err.Error()) + return + } + + resp := &MuxResponse{ + ID: id, + Data: data[:n], + } + + err = enc.Encode(resp) + if err != nil { + logger.Errorf(err.Error()) + continue + } + } + }() + + return + } + + //If method is not "new", write data + conn := server.MuxConn[id] + if conn == nil { + time.Sleep(3 * time.Second) + conn = server.MuxConn[id] + if conn == nil { + logger.Errorf("conn %d does not exist", id) + } + return + } + _, err = conn.Write(req.Data) + if err != nil { + logger.Errorf(err.Error()) + conn.Close() + delete(server.MuxConn, id) + logger.Debugf("conn %d closed", id) + return + } +} + +//ClientHandleMux read from local conn and write to remote server +func (client *Client) ClientHandleMux(conn *net.TCPConn, host string) { + mux := client.MuxS[0] + id, err := mux.NewConn(host) + if err != nil { + logger.Errorf(err.Error()) + return + } + + client.MuxTCPConn[id] = conn + + data := make([]byte, 32*1024) + + for { + n, err := conn.Read(data) + if err != nil { + logger.Errorf(err.Error()) + return + } + + req := &MuxRequest{ + ID: id, + Data: data[:n], + } + + err = mux.Encoder.Encode(req) + if err != nil { + logger.Errorf(err.Error()) + continue + } + } +} + +func (mux *Mux) NewConn(host string) (id uint64, err error) { + id = rand.Uint64() + req := &MuxRequest{ + ID: id, + Method: "new", + Data: []byte(host), + } + + err = mux.Encoder.Encode(req) + if err != nil { + return + } + + logger.Debugf("mux dialed %s", host) + return +} diff --git a/core/server.go b/core/server.go index 149fd8a..93349f2 100644 --- a/core/server.go +++ b/core/server.go @@ -24,6 +24,8 @@ type Server struct { KeyPath string Proxy string + MuxConn map[uint64]net.Conn + Upgrader *websocket.Upgrader CreatedAt time.Time @@ -48,6 +50,11 @@ func (server *Server) HandleWebSocket(w http.ResponseWriter, r *http.Request) { atomic.AddUint64(&server.Opened, 1) defer atomic.AddUint64(&server.Closed, 1) + if r.Header.Get("WebSocks-Mux") == "mux" { + server.ServerHandleMux(ws) + return + } + host := r.Header.Get("WebSocks-Host") logger.Debugf("Dial %s", host) diff --git a/core/websocket.go b/core/websocket.go index 3f1b69e..979b299 100644 --- a/core/websocket.go +++ b/core/websocket.go @@ -19,6 +19,7 @@ func (ws *WebSocket) Read(p []byte) (n int, err error) { n = copy(p, ws.buf) ws.buf = ws.buf[n:] + return } diff --git a/websocks.go b/websocks.go index 34173ba..172f5f7 100644 --- a/websocks.go +++ b/websocks.go @@ -54,6 +54,10 @@ func main() { Value: "ws://localhost:23333/websocks", Usage: "server url", }, + cli.BoolFlag{ + Name: "mux", + Usage: "mux mode", + }, cli.StringFlag{ Name: "n", Value: "", @@ -68,6 +72,7 @@ func main() { debug := c.GlobalBool("debug") listenAddr := c.String("l") serverURL := c.String("s") + mux := c.Bool("mux") serverName := c.String("n") insecureCert := false if c.Bool("insecure") { @@ -108,7 +113,9 @@ func main() { HandshakeTimeout: 10 * time.Second, TLSClientConfig: tlsConfig, }, - CreatedAt: time.Now(), + Mux: mux, + MuxTCPConn: make(map[uint64]*net.TCPConn), + CreatedAt: time.Now(), } err = local.Listen() @@ -177,6 +184,7 @@ func main() { CertPath: certPath, KeyPath: keyPath, Proxy: proxy, + MuxConn: make(map[uint64]net.Conn), Upgrader: &websocket.Upgrader{ ReadBufferSize: 4 * 1024, WriteBufferSize: 4 * 1024, From 8e9d0089e5590b31ec2b755b679fba1da6490aed Mon Sep 17 00:00:00 2001 From: halulu Date: Sat, 2 Jun 2018 20:20:27 +0800 Subject: [PATCH 2/8] use github.com/xtaci/smux --- core/client.go | 88 +++++++++++++++------ core/mux.go | 192 +--------------------------------------------- core/server.go | 56 +++++++++++++- core/websocket.go | 5 ++ websocks.go | 6 +- 5 files changed, 125 insertions(+), 222 deletions(-) diff --git a/core/client.go b/core/client.go index 225fa6d..667d5a0 100644 --- a/core/client.go +++ b/core/client.go @@ -6,8 +6,11 @@ import ( "net/url" "time" + "encoding/json" + "github.com/gorilla/websocket" "github.com/juju/loggo" + "github.com/xtaci/smux" ) var logger = loggo.GetLogger("core") @@ -17,9 +20,7 @@ type Client struct { ListenAddr *net.TCPAddr URL *url.URL - Mux bool - MuxS []*Mux - MuxTCPConn map[uint64]*net.TCPConn + Mux bool Dialer *websocket.Dialer @@ -29,16 +30,6 @@ type Client struct { func (client *Client) Listen() (err error) { logger.SetLogLevel(client.LogLevel) - if client.Mux { - mux, err := client.DialMux() - if err != nil { - return err - } - - client.MuxS = []*Mux{mux} - logger.Debugf("mux ok") - } - listener, err := net.ListenTCP("tcp", client.ListenAddr) if err != nil { return err @@ -61,34 +52,85 @@ func (client *Client) Listen() (err error) { return nil } -func (client *Client) handleConn(conn *net.TCPConn) (err error) { - defer func() { - if err != nil { - logger.Debugf("Handle connection error: %s", err.Error()) - } - }() - +func (client *Client) handleConn(conn *net.TCPConn) { defer conn.Close() conn.SetLinger(0) - err = handShake(conn) + err := handShake(conn) if err != nil { + logger.Errorf(err.Error()) return } _, host, err := getRequest(conn) if err != nil { + logger.Errorf(err.Error()) return } _, err = conn.Write([]byte{0x05, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x08, 0x43}) if err != nil { + logger.Errorf(err.Error()) return } + logger.Debugf("host: %s", host) + if client.Mux { - client.ClientHandleMux(conn, host) + wsConn, _, err := client.Dialer.Dial(client.URL.String(), map[string][]string{ + "WebSocks-Mux": {"mux"}, + }) + if err != nil { + logger.Errorf(err.Error()) + return + } + + ws := &WebSocket{ + conn: wsConn, + } + + session, err := smux.Client(ws, nil) + if err != nil { + logger.Errorf(err.Error()) + return + } + + stream, err := session.OpenStream() + if err != nil { + logger.Errorf(err.Error()) + return + } + + req := MuxRequest{ + Host: host, + } + + enc := json.NewEncoder(stream) + err = enc.Encode(req) + + if err != nil { + logger.Errorf(err.Error()) + return + } + + go func() { + _, err = io.Copy(stream, conn) + if err != nil { + logger.Debugf(err.Error()) + stream.Close() + return + } + return + }() + + _, err = io.Copy(conn, stream) + if err != nil { + logger.Errorf(err.Error()) + stream.Close() + return + } + return } @@ -100,8 +142,6 @@ func (client *Client) handleConn(conn *net.TCPConn) (err error) { return } - logger.Debugf("host: %s", host) - ws := &WebSocket{ conn: wsConn, } diff --git a/core/mux.go b/core/mux.go index 4c1673d..4f2e324 100644 --- a/core/mux.go +++ b/core/mux.go @@ -1,195 +1,5 @@ package core -import ( - "encoding/gob" - "math/rand" - "net" - "time" -) - -type Mux struct { - WS *WebSocket - Decoder *gob.Decoder - Encoder *gob.Encoder -} - type MuxRequest struct { - ID uint64 - Method string - Data []byte -} - -type MuxResponse struct { - ID uint64 - Data []byte -} - -//DialMux dial new mux conn, listen and write to local conn -func (client *Client) DialMux() (mux *Mux, err error) { - conn, _, err := client.Dialer.Dial(client.URL.String(), map[string][]string{ - "WebSocks-Mux": {"mux"}, - }) - - if err != nil { - return - } - - ws := &WebSocket{ - conn: conn, - } - - mux = &Mux{ - WS: ws, - Decoder: gob.NewDecoder(ws), - Encoder: gob.NewEncoder(ws), - } - - go func() { - for { - resp := &MuxResponse{} - err := mux.Decoder.Decode(resp) - if err != nil { - logger.Errorf(err.Error()) - return - } - - conn := client.MuxTCPConn[resp.ID] - _, err = conn.Write(resp.Data) - if err != nil { - logger.Errorf(err.Error()) - continue - } - } - }() - - return -} - -func (server *Server) ServerHandleMux(ws *WebSocket) { - dec := gob.NewDecoder(ws) - enc := gob.NewEncoder(ws) - - for { - req := &MuxRequest{} - err := dec.Decode(req) - if err != nil { - logger.Errorf(err.Error()) - continue - } - - go server.ServerHandleMuxRequest(req, enc) - } - return -} - -func (server *Server) ServerHandleMuxRequest(req *MuxRequest, enc *gob.Encoder) { - id := req.ID - method := req.Method - var err error - - //If method is "new", dial new and listen - if method == "new" { - host := string(req.Data) - conn, err := net.Dial("tcp", host) - if err != nil { - logger.Errorf(err.Error()) - return - } - - server.MuxConn[id] = conn - logger.Debugf("dialed %s, id %d", host, id) - - go func() { - data := make([]byte, 32*1024) - for { - n, err := conn.Read(data) - if err != nil { - logger.Errorf(err.Error()) - return - } - - resp := &MuxResponse{ - ID: id, - Data: data[:n], - } - - err = enc.Encode(resp) - if err != nil { - logger.Errorf(err.Error()) - continue - } - } - }() - - return - } - - //If method is not "new", write data - conn := server.MuxConn[id] - if conn == nil { - time.Sleep(3 * time.Second) - conn = server.MuxConn[id] - if conn == nil { - logger.Errorf("conn %d does not exist", id) - } - return - } - _, err = conn.Write(req.Data) - if err != nil { - logger.Errorf(err.Error()) - conn.Close() - delete(server.MuxConn, id) - logger.Debugf("conn %d closed", id) - return - } -} - -//ClientHandleMux read from local conn and write to remote server -func (client *Client) ClientHandleMux(conn *net.TCPConn, host string) { - mux := client.MuxS[0] - id, err := mux.NewConn(host) - if err != nil { - logger.Errorf(err.Error()) - return - } - - client.MuxTCPConn[id] = conn - - data := make([]byte, 32*1024) - - for { - n, err := conn.Read(data) - if err != nil { - logger.Errorf(err.Error()) - return - } - - req := &MuxRequest{ - ID: id, - Data: data[:n], - } - - err = mux.Encoder.Encode(req) - if err != nil { - logger.Errorf(err.Error()) - continue - } - } -} - -func (mux *Mux) NewConn(host string) (id uint64, err error) { - id = rand.Uint64() - req := &MuxRequest{ - ID: id, - Method: "new", - Data: []byte(host), - } - - err = mux.Encoder.Encode(req) - if err != nil { - return - } - - logger.Debugf("mux dialed %s", host) - return + Host string } diff --git a/core/server.go b/core/server.go index 93349f2..6c7c7e8 100644 --- a/core/server.go +++ b/core/server.go @@ -11,8 +11,11 @@ import ( "crypto/tls" + "encoding/json" + "github.com/gorilla/websocket" "github.com/juju/loggo" + "github.com/xtaci/smux" ) type Server struct { @@ -24,8 +27,6 @@ type Server struct { KeyPath string Proxy string - MuxConn map[uint64]net.Conn - Upgrader *websocket.Upgrader CreatedAt time.Time @@ -51,7 +52,56 @@ func (server *Server) HandleWebSocket(w http.ResponseWriter, r *http.Request) { defer atomic.AddUint64(&server.Closed, 1) if r.Header.Get("WebSocks-Mux") == "mux" { - server.ServerHandleMux(ws) + session, err := smux.Server(ws, nil) + if err != nil { + logger.Errorf(err.Error()) + return + } + + for { + stream, err := session.AcceptStream() + if err != nil { + logger.Errorf(err.Error()) + return + } + + dec := json.NewDecoder(stream) + req := &MuxRequest{} + err = dec.Decode(req) + if err != nil { + logger.Errorf(err.Error()) + return + } + + host := req.Host + + conn, err := net.Dial("tcp", host) + if err != nil { + if err != nil { + logger.Debugf(err.Error()) + } + return + } + + go func() { + downloaded, err := io.Copy(conn, stream) + atomic.AddUint64(&server.Downloaded, uint64(downloaded)) + if err != nil { + logger.Debugf(err.Error()) + stream.Close() + return + } + }() + + uploaded, err := io.Copy(stream, conn) + atomic.AddUint64(&server.Uploaded, uint64(uploaded)) + if err != nil { + logger.Debugf(err.Error()) + stream.Close() + return + } + + } return } diff --git a/core/websocket.go b/core/websocket.go index 979b299..1749f7b 100644 --- a/core/websocket.go +++ b/core/websocket.go @@ -31,3 +31,8 @@ func (ws *WebSocket) Write(p []byte) (n int, err error) { return len(p), nil } + +func (ws *WebSocket) Close() (err error) { + ws.conn.Close() + return +} diff --git a/websocks.go b/websocks.go index 172f5f7..fc73d17 100644 --- a/websocks.go +++ b/websocks.go @@ -113,9 +113,8 @@ func main() { HandshakeTimeout: 10 * time.Second, TLSClientConfig: tlsConfig, }, - Mux: mux, - MuxTCPConn: make(map[uint64]*net.TCPConn), - CreatedAt: time.Now(), + Mux: mux, + CreatedAt: time.Now(), } err = local.Listen() @@ -184,7 +183,6 @@ func main() { CertPath: certPath, KeyPath: keyPath, Proxy: proxy, - MuxConn: make(map[uint64]net.Conn), Upgrader: &websocket.Upgrader{ ReadBufferSize: 4 * 1024, WriteBufferSize: 4 * 1024, From c74b6b1b73bb587970661a2bd370c2a82d212a57 Mon Sep 17 00:00:00 2001 From: halulu Date: Sat, 2 Jun 2018 22:25:54 +0800 Subject: [PATCH 3/8] smux chan --- core/client.go | 47 +++++++++++++-------------------------- core/mux.go | 60 ++++++++++++++++++++++++++++++++++++++++++++++++++ core/server.go | 4 +--- websocks.go | 8 ++++--- 4 files changed, 81 insertions(+), 38 deletions(-) diff --git a/core/client.go b/core/client.go index 667d5a0..c9c1c4a 100644 --- a/core/client.go +++ b/core/client.go @@ -6,8 +6,6 @@ import ( "net/url" "time" - "encoding/json" - "github.com/gorilla/websocket" "github.com/juju/loggo" "github.com/xtaci/smux" @@ -20,7 +18,9 @@ type Client struct { ListenAddr *net.TCPAddr URL *url.URL - Mux bool + Mux bool + Opened int + StreamChan chan *smux.Stream Dialer *websocket.Dialer @@ -78,37 +78,20 @@ func (client *Client) handleConn(conn *net.TCPConn) { logger.Debugf("host: %s", host) if client.Mux { - wsConn, _, err := client.Dialer.Dial(client.URL.String(), map[string][]string{ - "WebSocks-Mux": {"mux"}, - }) - if err != nil { - logger.Errorf(err.Error()) - return - } - - ws := &WebSocket{ - conn: wsConn, + l := len(client.StreamChan) + c := cap(client.StreamChan) + logger.Debugf("%d %d", l, c) + if l != c { + go func() { + err := client.OpenSession() + if err != nil { + logger.Errorf(err.Error()) + return + } + }() } - session, err := smux.Client(ws, nil) - if err != nil { - logger.Errorf(err.Error()) - return - } - - stream, err := session.OpenStream() - if err != nil { - logger.Errorf(err.Error()) - return - } - - req := MuxRequest{ - Host: host, - } - - enc := json.NewEncoder(stream) - err = enc.Encode(req) - + stream, err := client.GetStream(host) if err != nil { logger.Errorf(err.Error()) return diff --git a/core/mux.go b/core/mux.go index 4f2e324..f47cbbb 100644 --- a/core/mux.go +++ b/core/mux.go @@ -1,5 +1,65 @@ package core +import ( + "encoding/json" + "time" + + "github.com/xtaci/smux" +) + type MuxRequest struct { Host string } + +func (client *Client) OpenSession() (err error) { + wsConn, _, err := client.Dialer.Dial(client.URL.String(), map[string][]string{ + "WebSocks-Mux": {"mux"}, + }) + + if err != nil { + logger.Errorf(err.Error()) + return + } + + ws := &WebSocket{ + conn: wsConn, + } + + session, err := smux.Client(ws, nil) + if err != nil { + return + } + + go func() { + for { + if session.NumStreams() > 2 { + time.Sleep(time.Second) + continue + } + + stream, err := session.OpenStream() + if err != nil { + session.Close() + logger.Errorf(err.Error()) + return + } + + client.StreamChan <- stream + } + return + }() + + return +} + +func (client *Client) GetStream(host string) (stream *smux.Stream, err error) { + stream = <-client.StreamChan + + req := MuxRequest{ + Host: host, + } + + enc := json.NewEncoder(stream) + err = enc.Encode(req) + return +} diff --git a/core/server.go b/core/server.go index 6c7c7e8..6e10b00 100644 --- a/core/server.go +++ b/core/server.go @@ -1,6 +1,7 @@ package core import ( + "encoding/json" "io" "net" "net/http" @@ -11,8 +12,6 @@ import ( "crypto/tls" - "encoding/json" - "github.com/gorilla/websocket" "github.com/juju/loggo" "github.com/xtaci/smux" @@ -100,7 +99,6 @@ func (server *Server) HandleWebSocket(w http.ResponseWriter, r *http.Request) { stream.Close() return } - } return } diff --git a/websocks.go b/websocks.go index fc73d17..20dacab 100644 --- a/websocks.go +++ b/websocks.go @@ -17,6 +17,7 @@ import ( "github.com/juju/loggo" "github.com/lzjluzijie/websocks/core" "github.com/urfave/cli" + "github.com/xtaci/smux" ) func main() { @@ -46,7 +47,7 @@ func main() { Flags: []cli.Flag{ cli.StringFlag{ Name: "l", - Value: ":10801", + Value: "127.0.0.1:10801", Usage: "local listening port", }, cli.StringFlag{ @@ -113,8 +114,9 @@ func main() { HandshakeTimeout: 10 * time.Second, TLSClientConfig: tlsConfig, }, - Mux: mux, - CreatedAt: time.Now(), + Mux: mux, + StreamChan: make(chan *smux.Stream, 8), + CreatedAt: time.Now(), } err = local.Listen() From b6235c3921b5b28961e98235ea0708a03ca4c827 Mon Sep 17 00:00:00 2001 From: halulu Date: Sun, 3 Jun 2018 17:59:39 +0800 Subject: [PATCH 4/8] mux client and server (can not use) --- core/client.go | 71 ++++++++----------- core/mux.go | 70 ++++--------------- core/muxclient.go | 169 ++++++++++++++++++++++++++++++++++++++++++++++ core/muxserver.go | 114 +++++++++++++++++++++++++++++++ core/server.go | 58 ++-------------- core/websocket.go | 9 ++- websocks.go | 9 ++- 7 files changed, 340 insertions(+), 160 deletions(-) create mode 100644 core/muxclient.go create mode 100644 core/muxserver.go diff --git a/core/client.go b/core/client.go index c9c1c4a..64457c5 100644 --- a/core/client.go +++ b/core/client.go @@ -8,7 +8,6 @@ import ( "github.com/gorilla/websocket" "github.com/juju/loggo" - "github.com/xtaci/smux" ) var logger = loggo.GetLogger("core") @@ -18,9 +17,7 @@ type Client struct { ListenAddr *net.TCPAddr URL *url.URL - Mux bool - Opened int - StreamChan chan *smux.Stream + Mux bool Dialer *websocket.Dialer @@ -39,6 +36,32 @@ func (client *Client) Listen() (err error) { defer listener.Close() + if client.Mux { + muxClient := &MuxClient{ + Client: client, + MessageChan: make(chan *Message), + } + + for i := 0; i < 4; i++ { + err = muxClient.Open() + if err != nil { + logger.Debugf(err.Error()) + return + } + } + + for { + conn, err := listener.AcceptTCP() + if err != nil { + logger.Debugf(err.Error()) + continue + } + + go muxClient.handleConn(conn) + } + return + } + for { conn, err := listener.AcceptTCP() if err != nil { @@ -77,46 +100,6 @@ func (client *Client) handleConn(conn *net.TCPConn) { logger.Debugf("host: %s", host) - if client.Mux { - l := len(client.StreamChan) - c := cap(client.StreamChan) - logger.Debugf("%d %d", l, c) - if l != c { - go func() { - err := client.OpenSession() - if err != nil { - logger.Errorf(err.Error()) - return - } - }() - } - - stream, err := client.GetStream(host) - if err != nil { - logger.Errorf(err.Error()) - return - } - - go func() { - _, err = io.Copy(stream, conn) - if err != nil { - logger.Debugf(err.Error()) - stream.Close() - return - } - return - }() - - _, err = io.Copy(conn, stream) - if err != nil { - logger.Errorf(err.Error()) - stream.Close() - return - } - - return - } - wsConn, _, err := client.Dialer.Dial(client.URL.String(), map[string][]string{ "WebSocks-Host": {host}, }) diff --git a/core/mux.go b/core/mux.go index f47cbbb..2178343 100644 --- a/core/mux.go +++ b/core/mux.go @@ -1,65 +1,19 @@ package core -import ( - "encoding/json" - "time" - - "github.com/xtaci/smux" +const ( + MessageMethodData = iota + MessageMethodDial ) -type MuxRequest struct { - Host string -} - -func (client *Client) OpenSession() (err error) { - wsConn, _, err := client.Dialer.Dial(client.URL.String(), map[string][]string{ - "WebSocks-Mux": {"mux"}, - }) - - if err != nil { - logger.Errorf(err.Error()) - return - } - - ws := &WebSocket{ - conn: wsConn, - } - - session, err := smux.Client(ws, nil) - if err != nil { - return - } - - go func() { - for { - if session.NumStreams() > 2 { - time.Sleep(time.Second) - continue - } - - stream, err := session.OpenStream() - if err != nil { - session.Close() - logger.Errorf(err.Error()) - return - } - - client.StreamChan <- stream - } - return - }() - - return +type MuxConn struct { + ID int + DataID int + DataChan chan []byte } -func (client *Client) GetStream(host string) (stream *smux.Stream, err error) { - stream = <-client.StreamChan - - req := MuxRequest{ - Host: host, - } - - enc := json.NewEncoder(stream) - err = enc.Encode(req) - return +type Message struct { + Method byte + ConnID int + MessageID int + Data []byte } diff --git a/core/muxclient.go b/core/muxclient.go new file mode 100644 index 0000000..a48f4d6 --- /dev/null +++ b/core/muxclient.go @@ -0,0 +1,169 @@ +package core + +import ( + "encoding/gob" + "errors" + "math/rand" + "net" + "sync" +) + +type MuxClient struct { + *Client + MessageChan chan *Message + muxConnMap sync.Map + Mutex sync.Mutex +} + +func (client *MuxClient) Open() (err error) { + wsConn, _, err := client.Dialer.Dial(client.URL.String(), map[string][]string{ + "WebSocks-Mux": {"mux"}, + }) + + if err != nil { + logger.Errorf(err.Error()) + return + } + + ws := &WebSocket{ + conn: wsConn, + } + + dec := gob.NewDecoder(ws) + enc := gob.NewEncoder(ws) + + go func() { + for { + m := &Message{} + err = dec.Decode(m) + if err != nil { + logger.Debugf(err.Error()) + return + } + + err = client.HandleMessage(m) + if err != nil { + logger.Debugf(err.Error()) + continue + } + } + }() + + go func() { + for { + m := <-client.MessageChan + err = enc.Encode(m) + if err != nil { + logger.Debugf(err.Error()) + return + } + } + }() + return +} + +func (client *MuxClient) Dial(conn *net.TCPConn, host string) { + dataChan := make(chan []byte) + id := rand.Int() + muxConn := &MuxConn{ + DataChan: dataChan, + ID: id, + } + + client.muxConnMap.Store(id, muxConn) + + m := &Message{ + Method: MessageMethodDial, + ConnID: id, + Data: []byte(host), + } + client.MessageChan <- m + + //listen local conn and send message + go func() { + messageID := 1 + buf := make([]byte, 32*1024) + for { + n, err := conn.Read(buf) + if err != nil { + logger.Errorf(err.Error()) + return + } + + println(n) + + dataMessage := &Message{ + Method: MessageMethodData, + ConnID: id, + MessageID: messageID, + Data: buf[:n], + } + + messageID++ + client.MessageChan <- dataMessage + } + }() + + go func() { + for { + _, err := conn.Write(<-dataChan) + if err != nil { + logger.Debugf(err.Error()) + conn.Close() + } + } + }() + + return +} + +func (client *MuxClient) HandleMessage(m *Message) (err error) { + if m.Method != MessageMethodData { + return errors.New("unknown method") + } + + connID := m.ConnID + c, ok := client.muxConnMap.Load(connID) + if !ok { + return errors.New("can not load conn") + } + conn := c.(*MuxConn) + + go func() { + for { + if conn.DataID == m.MessageID { + conn.DataChan <- m.Data + return + } + } + }() + + return +} + +func (client *MuxClient) handleConn(conn *net.TCPConn) { + defer conn.Close() + + conn.SetLinger(0) + + err := handShake(conn) + if err != nil { + logger.Errorf(err.Error()) + return + } + + _, host, err := getRequest(conn) + if err != nil { + logger.Errorf(err.Error()) + return + } + + _, err = conn.Write([]byte{0x05, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x08, 0x43}) + if err != nil { + logger.Errorf(err.Error()) + return + } + + client.Dial(conn, host) + return +} diff --git a/core/muxserver.go b/core/muxserver.go new file mode 100644 index 0000000..9b44889 --- /dev/null +++ b/core/muxserver.go @@ -0,0 +1,114 @@ +package core + +import ( + "encoding/gob" + "errors" + "net" + "time" +) + +func (server *Server) HandleWS(ws *WebSocket) (err error) { + dec := gob.NewDecoder(ws) + enc := gob.NewEncoder(ws) + + //receive messages + go func() { + for { + m := &Message{} + err = dec.Decode(m) + if err != nil { + logger.Debugf(err.Error()) + return + } + + err = server.HandleMessage(m) + if err != nil { + logger.Debugf(err.Error()) + continue + } + } + }() + + //send messages + go func() { + for { + m := <-server.MessageChan + err = enc.Encode(m) + if err != nil { + logger.Debugf(err.Error()) + return + } + } + }() + + time.Sleep(time.Minute) + return +} + +func (server *Server) HandleMessage(m *Message) (err error) { + if m.Method == MessageMethodDial { + id := m.ConnID + dataChan := make(chan []byte) + conn := &MuxConn{ + ID: id, + DataChan: dataChan, + } + + server.muxConnMap.Store(id, conn) + server.DialRemote(conn, string(m.Data)) + return + } + + if m.Method != MessageMethodData { + return errors.New("unknown method") + } + + connID := m.ConnID + c, ok := server.muxConnMap.Load(connID) + if !ok { + return errors.New("can not load conn") + } + + conn := c.(*MuxConn) + go func() { + for { + if conn.DataID == m.MessageID { + conn.DataChan <- m.Data + return + } + } + }() + + return +} + +func (server *Server) DialRemote(muxConn *MuxConn, host string) { + conn, err := net.Dial("tcp", host) + if err != nil { + logger.Debugf(err.Error()) + return + } + + go func() { + for { + buf := make([]byte, 32*1024) + n, err := conn.Read(buf) + if err != nil { + logger.Debugf(err.Error()) + return + } + + m := &Message{ + Method: MessageMethodData, + ConnID: muxConn.ID, + MessageID: muxConn.DataID, + Data: buf[:n], + } + muxConn.DataID++ + + server.MessageChan <- m + } + }() + + return +} diff --git a/core/server.go b/core/server.go index 6e10b00..b358d5f 100644 --- a/core/server.go +++ b/core/server.go @@ -1,7 +1,6 @@ package core import ( - "encoding/json" "io" "net" "net/http" @@ -12,9 +11,10 @@ import ( "crypto/tls" + "sync" + "github.com/gorilla/websocket" "github.com/juju/loggo" - "github.com/xtaci/smux" ) type Server struct { @@ -28,6 +28,10 @@ type Server struct { Upgrader *websocket.Upgrader + MessageChan chan *Message + muxConnMap sync.Map + Mutex sync.Mutex + CreatedAt time.Time Opened uint64 @@ -51,55 +55,7 @@ func (server *Server) HandleWebSocket(w http.ResponseWriter, r *http.Request) { defer atomic.AddUint64(&server.Closed, 1) if r.Header.Get("WebSocks-Mux") == "mux" { - session, err := smux.Server(ws, nil) - if err != nil { - logger.Errorf(err.Error()) - return - } - - for { - stream, err := session.AcceptStream() - if err != nil { - logger.Errorf(err.Error()) - return - } - - dec := json.NewDecoder(stream) - req := &MuxRequest{} - err = dec.Decode(req) - if err != nil { - logger.Errorf(err.Error()) - return - } - - host := req.Host - - conn, err := net.Dial("tcp", host) - if err != nil { - if err != nil { - logger.Debugf(err.Error()) - } - return - } - - go func() { - downloaded, err := io.Copy(conn, stream) - atomic.AddUint64(&server.Downloaded, uint64(downloaded)) - if err != nil { - logger.Debugf(err.Error()) - stream.Close() - return - } - }() - - uploaded, err := io.Copy(stream, conn) - atomic.AddUint64(&server.Uploaded, uint64(uploaded)) - if err != nil { - logger.Debugf(err.Error()) - stream.Close() - return - } - } + server.HandleWS(ws) return } diff --git a/core/websocket.go b/core/websocket.go index 1749f7b..02e8b41 100644 --- a/core/websocket.go +++ b/core/websocket.go @@ -1,12 +1,15 @@ package core import ( + "sync" + "github.com/gorilla/websocket" ) type WebSocket struct { - conn *websocket.Conn - buf []byte + conn *websocket.Conn + buf []byte + mutex sync.Mutex } func (ws *WebSocket) Read(p []byte) (n int, err error) { @@ -24,11 +27,13 @@ func (ws *WebSocket) Read(p []byte) (n int, err error) { } func (ws *WebSocket) Write(p []byte) (n int, err error) { + ws.mutex.Lock() err = ws.conn.WriteMessage(websocket.BinaryMessage, p) if err != nil { return } + ws.mutex.Unlock() return len(p), nil } diff --git a/websocks.go b/websocks.go index 20dacab..8a10b3d 100644 --- a/websocks.go +++ b/websocks.go @@ -17,7 +17,6 @@ import ( "github.com/juju/loggo" "github.com/lzjluzijie/websocks/core" "github.com/urfave/cli" - "github.com/xtaci/smux" ) func main() { @@ -114,9 +113,8 @@ func main() { HandshakeTimeout: 10 * time.Second, TLSClientConfig: tlsConfig, }, - Mux: mux, - StreamChan: make(chan *smux.Stream, 8), - CreatedAt: time.Now(), + Mux: mux, + CreatedAt: time.Now(), } err = local.Listen() @@ -190,7 +188,8 @@ func main() { WriteBufferSize: 4 * 1024, HandshakeTimeout: 10 * time.Second, }, - CreatedAt: time.Now(), + MessageChan: make(chan *core.Message), + CreatedAt: time.Now(), } logger.Infof("Listening at %s", listenAddr) From 3d8b80a787bf78474dd21eef0bba98061673f7e6 Mon Sep 17 00:00:00 2001 From: halulu Date: Mon, 4 Jun 2018 20:38:04 +0800 Subject: [PATCH 5/8] rewrite mux (still can not use) --- core/client.go | 35 +++---- core/mux.go | 117 ++++++++++++++++++++-- core/muxclient.go | 231 ++++++++++++++++++++----------------------- core/muxserver.go | 217 ++++++++++++++++++++-------------------- core/muxwebsocket.go | 124 +++++++++++++++++++++++ core/server.go | 6 +- core/websocket.go | 9 +- 7 files changed, 463 insertions(+), 276 deletions(-) create mode 100644 core/muxwebsocket.go diff --git a/core/client.go b/core/client.go index 64457c5..ad7fb79 100644 --- a/core/client.go +++ b/core/client.go @@ -17,7 +17,8 @@ type Client struct { ListenAddr *net.TCPAddr URL *url.URL - Mux bool + Mux bool + MuxWS *MuxWebSocket Dialer *websocket.Dialer @@ -37,29 +38,13 @@ func (client *Client) Listen() (err error) { defer listener.Close() if client.Mux { - muxClient := &MuxClient{ - Client: client, - MessageChan: make(chan *Message), - } - - for i := 0; i < 4; i++ { - err = muxClient.Open() - if err != nil { - logger.Debugf(err.Error()) - return - } + muxWS, err := client.OpenMux() + if err != nil { + logger.Debugf(err.Error()) + return err } - for { - conn, err := listener.AcceptTCP() - if err != nil { - logger.Debugf(err.Error()) - continue - } - - go muxClient.handleConn(conn) - } - return + client.MuxWS = muxWS } for { @@ -69,7 +54,11 @@ func (client *Client) Listen() (err error) { continue } - go client.handleConn(conn) + if client.Mux { + go client.handleMuxConn(conn) + } else { + go client.handleConn(conn) + } } return nil diff --git a/core/mux.go b/core/mux.go index 2178343..b37e2a3 100644 --- a/core/mux.go +++ b/core/mux.go @@ -1,19 +1,120 @@ package core +import ( + "io" + "math/rand" + "net" + "sync" + "sync/atomic" +) + const ( MessageMethodData = iota MessageMethodDial ) -type MuxConn struct { - ID int - DataID int - DataChan chan []byte -} - type Message struct { Method byte - ConnID int - MessageID int + ConnID uint64 + MessageID uint64 Data []byte } + +type MuxConn struct { + ID uint64 + muxWS *MuxWebSocket + + messages []*Message + mutex sync.Mutex + buf []byte + + receiveMessageID uint64 + sendMessageID *uint64 +} + +//client use +func NewMuxConn(muxWS *MuxWebSocket) (conn *MuxConn) { + conn = new(MuxConn) + conn.muxWS = muxWS + conn.ID = rand.Uint64() + return +} + +func (conn *MuxConn) Write(p []byte) (n int, err error) { + m := &Message{ + Method: MessageMethodData, + ConnID: conn.ID, + MessageID: conn.SendMessageID(), + Data: p, + } + + err = conn.muxWS.SendMessage(m) + if err != nil { + return 0, err + } + return len(p), nil +} + +func (conn *MuxConn) Read(p []byte) (n int, err error) { + for { + if len(conn.buf) != 0 { + break + } + } + println("readed") + + conn.mutex.Lock() + n = copy(p, conn.buf) + conn.buf = conn.buf[n:] + conn.mutex.Unlock() + return +} + +func (conn *MuxConn) ReceiveMessage(m *Message) (err error) { + for { + if conn.receiveMessageID == m.MessageID { + conn.mutex.Lock() + conn.buf = append(conn.buf, m.Data...) + conn.receiveMessageID++ + conn.mutex.Unlock() + return + } + } + return +} + +//client dial remote +func (conn *MuxConn) DialMessage(host string) (err error) { + m := &Message{ + Method: MessageMethodDial, + MessageID: 18446744073709551615, + ConnID: conn.ID, + Data: []byte(host), + } + + err = conn.muxWS.SendMessage(m) + return +} + +func (conn *MuxConn) SendMessageID() (id uint64) { + id = atomic.LoadUint64(conn.sendMessageID) + atomic.AddUint64(conn.sendMessageID, 1) + return +} + +func (conn *MuxConn) Run(c *net.TCPConn) { + go func() { + _, err := io.Copy(conn, c) + if err != nil { + logger.Debugf(err.Error()) + } + }() + + go func() { + _, err := io.Copy(c, conn) + if err != nil { + logger.Debugf(err.Error()) + } + }() + return +} diff --git a/core/muxclient.go b/core/muxclient.go index a48f4d6..d788a5b 100644 --- a/core/muxclient.go +++ b/core/muxclient.go @@ -1,27 +1,15 @@ package core import ( - "encoding/gob" - "errors" - "math/rand" "net" - "sync" ) -type MuxClient struct { - *Client - MessageChan chan *Message - muxConnMap sync.Map - Mutex sync.Mutex -} - -func (client *MuxClient) Open() (err error) { +func (client *Client) OpenMux() (muxWS *MuxWebSocket, err error) { wsConn, _, err := client.Dialer.Dial(client.URL.String(), map[string][]string{ "WebSocks-Mux": {"mux"}, }) if err != nil { - logger.Errorf(err.Error()) return } @@ -29,119 +17,10 @@ func (client *MuxClient) Open() (err error) { conn: wsConn, } - dec := gob.NewDecoder(ws) - enc := gob.NewEncoder(ws) - - go func() { - for { - m := &Message{} - err = dec.Decode(m) - if err != nil { - logger.Debugf(err.Error()) - return - } - - err = client.HandleMessage(m) - if err != nil { - logger.Debugf(err.Error()) - continue - } - } - }() - - go func() { - for { - m := <-client.MessageChan - err = enc.Encode(m) - if err != nil { - logger.Debugf(err.Error()) - return - } - } - }() - return -} - -func (client *MuxClient) Dial(conn *net.TCPConn, host string) { - dataChan := make(chan []byte) - id := rand.Int() - muxConn := &MuxConn{ - DataChan: dataChan, - ID: id, - } - - client.muxConnMap.Store(id, muxConn) - - m := &Message{ - Method: MessageMethodDial, - ConnID: id, - Data: []byte(host), - } - client.MessageChan <- m - - //listen local conn and send message - go func() { - messageID := 1 - buf := make([]byte, 32*1024) - for { - n, err := conn.Read(buf) - if err != nil { - logger.Errorf(err.Error()) - return - } - - println(n) - - dataMessage := &Message{ - Method: MessageMethodData, - ConnID: id, - MessageID: messageID, - Data: buf[:n], - } - - messageID++ - client.MessageChan <- dataMessage - } - }() - - go func() { - for { - _, err := conn.Write(<-dataChan) - if err != nil { - logger.Debugf(err.Error()) - conn.Close() - } - } - }() - - return -} - -func (client *MuxClient) HandleMessage(m *Message) (err error) { - if m.Method != MessageMethodData { - return errors.New("unknown method") - } - - connID := m.ConnID - c, ok := client.muxConnMap.Load(connID) - if !ok { - return errors.New("can not load conn") - } - conn := c.(*MuxConn) - - go func() { - for { - if conn.DataID == m.MessageID { - conn.DataChan <- m.Data - return - } - } - }() - + muxWS = NewMuxWebSocket(ws) return } - -func (client *MuxClient) handleConn(conn *net.TCPConn) { +func (client *Client) handleMuxConn(conn *net.TCPConn) { defer conn.Close() conn.SetLinger(0) @@ -164,6 +43,108 @@ func (client *MuxClient) handleConn(conn *net.TCPConn) { return } - client.Dial(conn, host) + logger.Debugf("host: %s", host) + + muxConn := NewMuxConn(client.MuxWS) + + err = muxConn.DialMessage(host) + if err != nil { + logger.Errorf(err.Error()) + return + } + + muxConn.Run(conn) return } + +//func (client *Client) Dial(conn *net.TCPConn, host string) { +// muxConn := NewMuxConn(client.MuxWS) +// +// //listen local conn and send message +// go func() { +// messageID := 1 +// buf := make([]byte, 32*1024) +// for { +// n, err := conn.Read(buf) +// if err != nil { +// logger.Errorf(err.Error()) +// return +// } +// +// println(n) +// +// dataMessage := &Message{ +// Method: MessageMethodData, +// ConnID: id, +// MessageID: messageID, +// Data: buf[:n], +// } +// +// messageID++ +// client.MessageChan <- dataMessage +// } +// }() +// +// go func() { +// for { +// _, err := conn.Write(<-dataChan) +// if err != nil { +// logger.Debugf(err.Error()) +// conn.Close() +// } +// } +// }() +// +// return +//} +// +//func (client *MuxClient) HandleMessage(m *Message) (err error) { +// if m.Method != MessageMethodData { +// return errors.New("unknown method") +// } +// +// connID := m.ConnID +// c, ok := client.muxConnMap.Load(connID) +// if !ok { +// return errors.New("can not load conn") +// } +// conn := c.(*MuxConn) +// +// go func() { +// for { +// if conn.DataID == m.MessageID { +// conn.DataChan <- m.Data +// return +// } +// } +// }() +// +// return +//} +// +//func (client *Client) handleConn(conn *net.TCPConn) { +// defer conn.Close() +// +// conn.SetLinger(0) +// +// err := handShake(conn) +// if err != nil { +// logger.Errorf(err.Error()) +// return +// } +// +// _, host, err := getRequest(conn) +// if err != nil { +// logger.Errorf(err.Error()) +// return +// } +// +// _, err = conn.Write([]byte{0x05, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x08, 0x43}) +// if err != nil { +// logger.Errorf(err.Error()) +// return +// } +// +// client.Dial(conn, host) +// return +//} diff --git a/core/muxserver.go b/core/muxserver.go index 9b44889..9ab97e4 100644 --- a/core/muxserver.go +++ b/core/muxserver.go @@ -1,114 +1,107 @@ package core -import ( - "encoding/gob" - "errors" - "net" - "time" -) - -func (server *Server) HandleWS(ws *WebSocket) (err error) { - dec := gob.NewDecoder(ws) - enc := gob.NewEncoder(ws) - - //receive messages - go func() { - for { - m := &Message{} - err = dec.Decode(m) - if err != nil { - logger.Debugf(err.Error()) - return - } - - err = server.HandleMessage(m) - if err != nil { - logger.Debugf(err.Error()) - continue - } - } - }() - - //send messages - go func() { - for { - m := <-server.MessageChan - err = enc.Encode(m) - if err != nil { - logger.Debugf(err.Error()) - return - } - } - }() - - time.Sleep(time.Minute) - return -} - -func (server *Server) HandleMessage(m *Message) (err error) { - if m.Method == MessageMethodDial { - id := m.ConnID - dataChan := make(chan []byte) - conn := &MuxConn{ - ID: id, - DataChan: dataChan, - } - - server.muxConnMap.Store(id, conn) - server.DialRemote(conn, string(m.Data)) - return - } - - if m.Method != MessageMethodData { - return errors.New("unknown method") - } - - connID := m.ConnID - c, ok := server.muxConnMap.Load(connID) - if !ok { - return errors.New("can not load conn") - } - - conn := c.(*MuxConn) - go func() { - for { - if conn.DataID == m.MessageID { - conn.DataChan <- m.Data - return - } - } - }() - - return -} - -func (server *Server) DialRemote(muxConn *MuxConn, host string) { - conn, err := net.Dial("tcp", host) - if err != nil { - logger.Debugf(err.Error()) - return - } - - go func() { - for { - buf := make([]byte, 32*1024) - n, err := conn.Read(buf) - if err != nil { - logger.Debugf(err.Error()) - return - } - - m := &Message{ - Method: MessageMethodData, - ConnID: muxConn.ID, - MessageID: muxConn.DataID, - Data: buf[:n], - } - muxConn.DataID++ - - server.MessageChan <- m - } - }() - - return -} +//func (server *Server) HandleMuxWS(ws *WebSocket) (muxWS *MuxWebSocket,err error) { +// dec := gob.NewDecoder(ws) +// enc := gob.NewEncoder(ws) +// +// //receive messages +// go func() { +// for { +// m := &Message{} +// err = dec.Decode(m) +// if err != nil { +// logger.Debugf(err.Error()) +// return +// } +// +// err = server.HandleMessage(m) +// if err != nil { +// logger.Debugf(err.Error()) +// continue +// } +// } +// }() +// +// //send messages +// go func() { +// for { +// m := <-server.MessageChan +// err = enc.Encode(m) +// if err != nil { +// logger.Debugf(err.Error()) +// return +// } +// } +// }() +// +// time.Sleep(time.Minute) +// return +//} +// +//func (server *Server) HandleMessage(m *Message) (err error) { +// if m.Method == MessageMethodDial { +// id := m.ConnID +// dataChan := make(chan []byte) +// conn := &MuxConn{ +// ID: id, +// DataChan: dataChan, +// } +// +// server.muxConnMap.Store(id, conn) +// server.DialRemote(conn, string(m.Data)) +// return +// } +// +// if m.Method != MessageMethodData { +// return errors.New("unknown method") +// } +// +// connID := m.ConnID +// c, ok := server.muxConnMap.Load(connID) +// if !ok { +// return errors.New("can not load conn") +// } +// +// conn := c.(*MuxConn) +// go func() { +// for { +// if conn.DataID == m.MessageID { +// conn.DataChan <- m.Data +// return +// } +// } +// }() +// +// return +//} +// +//func (server *Server) DialRemote(muxConn *MuxConn, host string) { +// conn, err := net.Dial("tcp", host) +// if err != nil { +// logger.Debugf(err.Error()) +// return +// } +// +// go func() { +// for { +// buf := make([]byte, 32*1024) +// n, err := conn.Read(buf) +// if err != nil { +// logger.Debugf(err.Error()) +// return +// } +// +// m := &Message{ +// Method: MessageMethodData, +// ConnID: muxConn.ID, +// MessageID: muxConn.DataID, +// Data: buf[:n], +// } +// muxConn.DataID++ +// +// server.MessageChan <- m +// } +// }() +// +// return +//} diff --git a/core/muxwebsocket.go b/core/muxwebsocket.go new file mode 100644 index 0000000..a8965c2 --- /dev/null +++ b/core/muxwebsocket.go @@ -0,0 +1,124 @@ +package core + +import ( + "encoding/gob" + "fmt" + "net" + "sync" + + "github.com/pkg/errors" +) + +type MuxWebSocket struct { + *WebSocket + Decoder *gob.Decoder + Encoder *gob.Encoder + + connMap sync.Map + + mutex sync.RWMutex +} + +func NewMuxWebSocket(ws *WebSocket) (muxWS *MuxWebSocket) { + dec := gob.NewDecoder(ws) + enc := gob.NewEncoder(ws) + + muxWS = &MuxWebSocket{ + WebSocket: ws, + Decoder: dec, + Encoder: enc, + } + return +} + +func (muxWS *MuxWebSocket) AcceptMuxConn(m *Message) (conn *MuxConn, host string, err error) { + if m.Method != MessageMethodDial { + err = errors.New(fmt.Sprintf("wrong message method %d", m.Method)) + return + } + + host = string(m.Data) + + conn = &MuxConn{ + ID: m.ConnID, + muxWS: muxWS, + } + muxWS.PutMuxConn(conn) + return +} + +func (muxWS *MuxWebSocket) SendMessage(m *Message) (err error) { + muxWS.mutex.Lock() + err = muxWS.Encoder.Encode(m) + muxWS.mutex.Unlock() + return +} + +func (muxWS *MuxWebSocket) ReceiveMessage() (m *Message, err error) { + m = &Message{} + muxWS.mutex.RLock() + err = muxWS.Decoder.Decode(m) + muxWS.mutex.RUnlock() + return +} + +func (muxWS *MuxWebSocket) Listen() (err error) { + //block and listen + for { + m, err := muxWS.ReceiveMessage() + if err != nil { + return err + } + + fmt.Println(m) + + //accept new conn + if m.Method == MessageMethodDial { + conn, host, err := muxWS.AcceptMuxConn(m) + if err != nil { + logger.Debugf(err.Error()) + continue + } + + logger.Debugf("Accepted mux conn %s", host) + + tcpAddr, err := net.ResolveTCPAddr("tcp", host) + if err != nil { + logger.Debugf(err.Error()) + continue + } + + tcpConn, err := net.DialTCP("tcp", nil, tcpAddr) + if err != nil { + logger.Debugf(err.Error()) + continue + } + + conn.Run(tcpConn) + continue + } + + //get conn and send message + conn := muxWS.GetMuxConn(m.ConnID) + err = conn.ReceiveMessage(m) + if err != nil { + logger.Debugf(err.Error()) + continue + } + } + return +} + +func (muxWS *MuxWebSocket) PutMuxConn(conn *MuxConn) { + muxWS.connMap.Store(conn.ID, conn) + return +} + +func (muxWS *MuxWebSocket) GetMuxConn(id uint64) (conn *MuxConn) { + c, ok := muxWS.connMap.Load(id) + if !ok { + panic("not ok!") + } + + return c.(*MuxConn) +} diff --git a/core/server.go b/core/server.go index b358d5f..57247ba 100644 --- a/core/server.go +++ b/core/server.go @@ -55,7 +55,11 @@ func (server *Server) HandleWebSocket(w http.ResponseWriter, r *http.Request) { defer atomic.AddUint64(&server.Closed, 1) if r.Header.Get("WebSocks-Mux") == "mux" { - server.HandleWS(ws) + muxWS := NewMuxWebSocket(ws) + err = muxWS.Listen() + if err != nil { + logger.Debugf(err.Error()) + } return } diff --git a/core/websocket.go b/core/websocket.go index 02e8b41..1749f7b 100644 --- a/core/websocket.go +++ b/core/websocket.go @@ -1,15 +1,12 @@ package core import ( - "sync" - "github.com/gorilla/websocket" ) type WebSocket struct { - conn *websocket.Conn - buf []byte - mutex sync.Mutex + conn *websocket.Conn + buf []byte } func (ws *WebSocket) Read(p []byte) (n int, err error) { @@ -27,13 +24,11 @@ func (ws *WebSocket) Read(p []byte) (n int, err error) { } func (ws *WebSocket) Write(p []byte) (n int, err error) { - ws.mutex.Lock() err = ws.conn.WriteMessage(websocket.BinaryMessage, p) if err != nil { return } - ws.mutex.Unlock() return len(p), nil } From e08582a9608ec9338893eb216440390f4c916686 Mon Sep 17 00:00:00 2001 From: halulu Date: Mon, 4 Jun 2018 22:23:18 +0800 Subject: [PATCH 6/8] mux server receive data --- core/client.go | 1 + core/mux.go | 46 ++++++++++++++++----------- core/muxclient.go | 22 +++++++++++-- core/muxserver.go | 74 +++++++++++++++++++++++++++++++++++++++++++ core/muxwebsocket.go | 75 ++------------------------------------------ core/server.go | 5 +-- 6 files changed, 126 insertions(+), 97 deletions(-) diff --git a/core/client.go b/core/client.go index ad7fb79..b504454 100644 --- a/core/client.go +++ b/core/client.go @@ -45,6 +45,7 @@ func (client *Client) Listen() (err error) { } client.MuxWS = muxWS + go client.MuxWS.ClientListen() } for { diff --git a/core/mux.go b/core/mux.go index b37e2a3..d1b3b74 100644 --- a/core/mux.go +++ b/core/mux.go @@ -24,9 +24,9 @@ type MuxConn struct { ID uint64 muxWS *MuxWebSocket - messages []*Message - mutex sync.Mutex - buf []byte + mutex sync.Mutex + buf []byte + wait chan int receiveMessageID uint64 sendMessageID *uint64 @@ -34,10 +34,12 @@ type MuxConn struct { //client use func NewMuxConn(muxWS *MuxWebSocket) (conn *MuxConn) { - conn = new(MuxConn) - conn.muxWS = muxWS - conn.ID = rand.Uint64() - return + return &MuxConn{ + ID: rand.Uint64(), + muxWS: muxWS, + wait: make(chan int), + sendMessageID: new(uint64), + } } func (conn *MuxConn) Write(p []byte) (n int, err error) { @@ -56,12 +58,9 @@ func (conn *MuxConn) Write(p []byte) (n int, err error) { } func (conn *MuxConn) Read(p []byte) (n int, err error) { - for { - if len(conn.buf) != 0 { - break - } + if len(conn.buf) == 0 { + <-conn.wait } - println("readed") conn.mutex.Lock() n = copy(p, conn.buf) @@ -70,15 +69,18 @@ func (conn *MuxConn) Read(p []byte) (n int, err error) { return } -func (conn *MuxConn) ReceiveMessage(m *Message) (err error) { +func (conn *MuxConn) HandleMessage(m *Message) (err error) { for { if conn.receiveMessageID == m.MessageID { conn.mutex.Lock() conn.buf = append(conn.buf, m.Data...) conn.receiveMessageID++ + close(conn.wait) + conn.wait = make(chan int) conn.mutex.Unlock() return } + <-conn.wait } return } @@ -92,7 +94,15 @@ func (conn *MuxConn) DialMessage(host string) (err error) { Data: []byte(host), } + logger.Debugf("dial for %s", host) + err = conn.muxWS.SendMessage(m) + if err != nil { + return + } + + conn.muxWS.PutMuxConn(conn) + logger.Debugf("%d %s", conn.ID, host) return } @@ -103,12 +113,10 @@ func (conn *MuxConn) SendMessageID() (id uint64) { } func (conn *MuxConn) Run(c *net.TCPConn) { - go func() { - _, err := io.Copy(conn, c) - if err != nil { - logger.Debugf(err.Error()) - } - }() + _, err := io.Copy(conn, c) + if err != nil { + logger.Debugf(err.Error()) + } go func() { _, err := io.Copy(c, conn) diff --git a/core/muxclient.go b/core/muxclient.go index d788a5b..ff226e1 100644 --- a/core/muxclient.go +++ b/core/muxclient.go @@ -43,8 +43,6 @@ func (client *Client) handleMuxConn(conn *net.TCPConn) { return } - logger.Debugf("host: %s", host) - muxConn := NewMuxConn(client.MuxWS) err = muxConn.DialMessage(host) @@ -53,10 +51,30 @@ func (client *Client) handleMuxConn(conn *net.TCPConn) { return } + logger.Debugf("dialed for %s", host) + muxConn.Run(conn) return } +func (muxWS *MuxWebSocket) ClientListen() { + for { + m, err := muxWS.ReceiveMessage() + if err != nil { + logger.Debugf(err.Error()) + return + } + + //get conn and send message + conn := muxWS.GetMuxConn(m.ConnID) + err = conn.HandleMessage(m) + if err != nil { + logger.Debugf(err.Error()) + continue + } + } +} + //func (client *Client) Dial(conn *net.TCPConn, host string) { // muxConn := NewMuxConn(client.MuxWS) // diff --git a/core/muxserver.go b/core/muxserver.go index 9ab97e4..1c1fc53 100644 --- a/core/muxserver.go +++ b/core/muxserver.go @@ -1,5 +1,79 @@ package core +import ( + "errors" + "fmt" + "net" +) + +func (muxWS *MuxWebSocket) ServerListen() { + //block and listen + for { + m, err := muxWS.ReceiveMessage() + if err != nil { + logger.Debugf(err.Error()) + return + } + + go muxWS.ServerHandleMessage(m) + } + return +} + +func (muxWS *MuxWebSocket) AcceptMuxConn(m *Message) (conn *MuxConn, host string, err error) { + if m.Method != MessageMethodDial { + err = errors.New(fmt.Sprintf("wrong message method %d", m.Method)) + return + } + + host = string(m.Data) + + conn = &MuxConn{ + ID: m.ConnID, + muxWS: muxWS, + wait: make(chan int), + sendMessageID: new(uint64), + } + muxWS.PutMuxConn(conn) + return +} + +func (muxWS *MuxWebSocket) ServerHandleMessage(m *Message) { + //accept new conn + if m.Method == MessageMethodDial { + conn, host, err := muxWS.AcceptMuxConn(m) + if err != nil { + logger.Debugf(err.Error()) + return + } + + tcpAddr, err := net.ResolveTCPAddr("tcp", host) + if err != nil { + logger.Debugf(err.Error()) + return + } + + tcpConn, err := net.DialTCP("tcp", nil, tcpAddr) + if err != nil { + logger.Debugf(err.Error()) + return + } + + logger.Debugf("Accepted mux conn %s", host) + + conn.Run(tcpConn) + return + } + + //get conn and send message + conn := muxWS.GetMuxConn(m.ConnID) + err := conn.HandleMessage(m) + if err != nil { + logger.Debugf(err.Error()) + return + } +} + //func (server *Server) HandleMuxWS(ws *WebSocket) (muxWS *MuxWebSocket,err error) { // dec := gob.NewDecoder(ws) // enc := gob.NewEncoder(ws) diff --git a/core/muxwebsocket.go b/core/muxwebsocket.go index a8965c2..eb803e9 100644 --- a/core/muxwebsocket.go +++ b/core/muxwebsocket.go @@ -2,11 +2,7 @@ package core import ( "encoding/gob" - "fmt" - "net" "sync" - - "github.com/pkg/errors" ) type MuxWebSocket struct { @@ -16,7 +12,7 @@ type MuxWebSocket struct { connMap sync.Map - mutex sync.RWMutex + //mutex sync.RWMutex } func NewMuxWebSocket(ws *WebSocket) (muxWS *MuxWebSocket) { @@ -31,81 +27,16 @@ func NewMuxWebSocket(ws *WebSocket) (muxWS *MuxWebSocket) { return } -func (muxWS *MuxWebSocket) AcceptMuxConn(m *Message) (conn *MuxConn, host string, err error) { - if m.Method != MessageMethodDial { - err = errors.New(fmt.Sprintf("wrong message method %d", m.Method)) - return - } - - host = string(m.Data) - - conn = &MuxConn{ - ID: m.ConnID, - muxWS: muxWS, - } - muxWS.PutMuxConn(conn) - return -} - func (muxWS *MuxWebSocket) SendMessage(m *Message) (err error) { - muxWS.mutex.Lock() err = muxWS.Encoder.Encode(m) - muxWS.mutex.Unlock() + logger.Debugf("sent %#v", m) return } func (muxWS *MuxWebSocket) ReceiveMessage() (m *Message, err error) { m = &Message{} - muxWS.mutex.RLock() err = muxWS.Decoder.Decode(m) - muxWS.mutex.RUnlock() - return -} - -func (muxWS *MuxWebSocket) Listen() (err error) { - //block and listen - for { - m, err := muxWS.ReceiveMessage() - if err != nil { - return err - } - - fmt.Println(m) - - //accept new conn - if m.Method == MessageMethodDial { - conn, host, err := muxWS.AcceptMuxConn(m) - if err != nil { - logger.Debugf(err.Error()) - continue - } - - logger.Debugf("Accepted mux conn %s", host) - - tcpAddr, err := net.ResolveTCPAddr("tcp", host) - if err != nil { - logger.Debugf(err.Error()) - continue - } - - tcpConn, err := net.DialTCP("tcp", nil, tcpAddr) - if err != nil { - logger.Debugf(err.Error()) - continue - } - - conn.Run(tcpConn) - continue - } - - //get conn and send message - conn := muxWS.GetMuxConn(m.ConnID) - err = conn.ReceiveMessage(m) - if err != nil { - logger.Debugf(err.Error()) - continue - } - } + logger.Debugf("received %#v", m) return } diff --git a/core/server.go b/core/server.go index 57247ba..7caecb0 100644 --- a/core/server.go +++ b/core/server.go @@ -56,10 +56,7 @@ func (server *Server) HandleWebSocket(w http.ResponseWriter, r *http.Request) { if r.Header.Get("WebSocks-Mux") == "mux" { muxWS := NewMuxWebSocket(ws) - err = muxWS.Listen() - if err != nil { - logger.Debugf(err.Error()) - } + muxWS.ServerListen() return } From 02c221af97f3641a8a2ef44d3137cdcfe21d10d8 Mon Sep 17 00:00:00 2001 From: halulu Date: Tue, 5 Jun 2018 16:41:47 +0800 Subject: [PATCH 7/8] single mux --- core/mux.go | 15 ++++++++++----- core/muxwebsocket.go | 28 ++++++++++++++++++---------- 2 files changed, 28 insertions(+), 15 deletions(-) diff --git a/core/mux.go b/core/mux.go index d1b3b74..6b5133d 100644 --- a/core/mux.go +++ b/core/mux.go @@ -59,10 +59,12 @@ func (conn *MuxConn) Write(p []byte) (n int, err error) { func (conn *MuxConn) Read(p []byte) (n int, err error) { if len(conn.buf) == 0 { + logger.Debugf("%d buf is 0, waiting", conn.ID) <-conn.wait } conn.mutex.Lock() + logger.Debugf("%d buf: %v", conn.buf) n = copy(p, conn.buf) conn.buf = conn.buf[n:] conn.mutex.Unlock() @@ -70,6 +72,7 @@ func (conn *MuxConn) Read(p []byte) (n int, err error) { } func (conn *MuxConn) HandleMessage(m *Message) (err error) { + logger.Debugf("handle message %d %d", m.ConnID, m.MessageID) for { if conn.receiveMessageID == m.MessageID { conn.mutex.Lock() @@ -78,6 +81,7 @@ func (conn *MuxConn) HandleMessage(m *Message) (err error) { close(conn.wait) conn.wait = make(chan int) conn.mutex.Unlock() + logger.Debugf("handled message %d %d", m.ConnID, m.MessageID) return } <-conn.wait @@ -113,16 +117,17 @@ func (conn *MuxConn) SendMessageID() (id uint64) { } func (conn *MuxConn) Run(c *net.TCPConn) { - _, err := io.Copy(conn, c) - if err != nil { - logger.Debugf(err.Error()) - } - go func() { _, err := io.Copy(c, conn) if err != nil { logger.Debugf(err.Error()) } }() + + _, err := io.Copy(conn, c) + if err != nil { + logger.Debugf(err.Error()) + } + return } diff --git a/core/muxwebsocket.go b/core/muxwebsocket.go index eb803e9..f80a386 100644 --- a/core/muxwebsocket.go +++ b/core/muxwebsocket.go @@ -10,9 +10,9 @@ type MuxWebSocket struct { Decoder *gob.Decoder Encoder *gob.Encoder - connMap sync.Map - - //mutex sync.RWMutex + muxConns []*MuxConn + muxConnID []uint64 + mutex sync.Mutex } func NewMuxWebSocket(ws *WebSocket) (muxWS *MuxWebSocket) { @@ -30,6 +30,7 @@ func NewMuxWebSocket(ws *WebSocket) (muxWS *MuxWebSocket) { func (muxWS *MuxWebSocket) SendMessage(m *Message) (err error) { err = muxWS.Encoder.Encode(m) logger.Debugf("sent %#v", m) + //logger.Debugf("sent message %d %d %s", m.ConnID, m.MessageID, string(m.Data)) return } @@ -37,19 +38,26 @@ func (muxWS *MuxWebSocket) ReceiveMessage() (m *Message, err error) { m = &Message{} err = muxWS.Decoder.Decode(m) logger.Debugf("received %#v", m) + //logger.Debugf("received message %d %d %s", m.ConnID, m.MessageID, string(m.Data)) return } func (muxWS *MuxWebSocket) PutMuxConn(conn *MuxConn) { - muxWS.connMap.Store(conn.ID, conn) + muxWS.mutex.Lock() + muxWS.muxConns = append(muxWS.muxConns, conn) + muxWS.muxConnID = append(muxWS.muxConnID, conn.ID) + muxWS.mutex.Unlock() return } -func (muxWS *MuxWebSocket) GetMuxConn(id uint64) (conn *MuxConn) { - c, ok := muxWS.connMap.Load(id) - if !ok { - panic("not ok!") +func (muxWS *MuxWebSocket) GetMuxConn(connID uint64) (conn *MuxConn) { + muxWS.mutex.Lock() + for n, id := range muxWS.muxConnID { + if id == connID { + conn = muxWS.muxConns[n] + break + } } - - return c.(*MuxConn) + muxWS.mutex.Unlock() + return } From 74ebdaf881b606a8b68aa99f44c79e6d8a55ed92 Mon Sep 17 00:00:00 2001 From: halulu Date: Tue, 5 Jun 2018 17:21:44 +0800 Subject: [PATCH 8/8] improve code --- core/client.go | 23 +++--- core/mux.go | 21 ----- core/muxclient.go | 177 +++++++++++-------------------------------- core/muxserver.go | 150 +++++++----------------------------- core/muxwebsocket.go | 2 - websocks.go | 2 +- 6 files changed, 89 insertions(+), 286 deletions(-) diff --git a/core/client.go b/core/client.go index b504454..0206b4a 100644 --- a/core/client.go +++ b/core/client.go @@ -38,13 +38,12 @@ func (client *Client) Listen() (err error) { defer listener.Close() if client.Mux { - muxWS, err := client.OpenMux() + err := client.OpenMux() if err != nil { logger.Debugf(err.Error()) return err } - client.MuxWS = muxWS go client.MuxWS.ClientListen() } @@ -55,11 +54,7 @@ func (client *Client) Listen() (err error) { continue } - if client.Mux { - go client.handleMuxConn(conn) - } else { - go client.handleConn(conn) - } + go client.handleConn(conn) } return nil @@ -88,8 +83,16 @@ func (client *Client) handleConn(conn *net.TCPConn) { return } - logger.Debugf("host: %s", host) + if client.Mux { + client.DialMuxConn(host, conn) + } else { + client.DialWSConn(host, conn) + } + + return +} +func (client *Client) DialWSConn(host string, conn *net.TCPConn) { wsConn, _, err := client.Dialer.Dial(client.URL.String(), map[string][]string{ "WebSocks-Host": {host}, }) @@ -98,6 +101,8 @@ func (client *Client) handleConn(conn *net.TCPConn) { return } + logger.Debugf("dialed ws for %s", host) + ws := &WebSocket{ conn: wsConn, } @@ -113,8 +118,8 @@ func (client *Client) handleConn(conn *net.TCPConn) { _, err = io.Copy(conn, ws) if err != nil { + logger.Debugf(err.Error()) return } - return } diff --git a/core/mux.go b/core/mux.go index 6b5133d..64dc5d0 100644 --- a/core/mux.go +++ b/core/mux.go @@ -89,27 +89,6 @@ func (conn *MuxConn) HandleMessage(m *Message) (err error) { return } -//client dial remote -func (conn *MuxConn) DialMessage(host string) (err error) { - m := &Message{ - Method: MessageMethodDial, - MessageID: 18446744073709551615, - ConnID: conn.ID, - Data: []byte(host), - } - - logger.Debugf("dial for %s", host) - - err = conn.muxWS.SendMessage(m) - if err != nil { - return - } - - conn.muxWS.PutMuxConn(conn) - logger.Debugf("%d %s", conn.ID, host) - return -} - func (conn *MuxConn) SendMessageID() (id uint64) { id = atomic.LoadUint64(conn.sendMessageID) atomic.AddUint64(conn.sendMessageID, 1) diff --git a/core/muxclient.go b/core/muxclient.go index ff226e1..9c6cb2d 100644 --- a/core/muxclient.go +++ b/core/muxclient.go @@ -4,7 +4,25 @@ import ( "net" ) -func (client *Client) OpenMux() (muxWS *MuxWebSocket, err error) { +func (muxWS *MuxWebSocket) ClientListen() { + for { + m, err := muxWS.ReceiveMessage() + if err != nil { + logger.Debugf(err.Error()) + return + } + + //get conn and send message + conn := muxWS.GetMuxConn(m.ConnID) + err = conn.HandleMessage(m) + if err != nil { + logger.Debugf(err.Error()) + continue + } + } +} + +func (client *Client) OpenMux() (err error) { wsConn, _, err := client.Dialer.Dial(client.URL.String(), map[string][]string{ "WebSocks-Mux": {"mux"}, }) @@ -17,152 +35,47 @@ func (client *Client) OpenMux() (muxWS *MuxWebSocket, err error) { conn: wsConn, } - muxWS = NewMuxWebSocket(ws) + muxWS := NewMuxWebSocket(ws) + client.MuxWS = muxWS return } -func (client *Client) handleMuxConn(conn *net.TCPConn) { - defer conn.Close() - - conn.SetLinger(0) +func (client *Client) DialMuxConn(host string, conn *net.TCPConn) { + muxConn := NewMuxConn(client.MuxWS) - err := handShake(conn) + err := muxConn.DialMessage(host) if err != nil { logger.Errorf(err.Error()) + err = client.OpenMux() + if err != nil { + logger.Errorf(err.Error()) + } return } - _, host, err := getRequest(conn) - if err != nil { - logger.Errorf(err.Error()) - return - } + muxConn.muxWS.PutMuxConn(muxConn) - _, err = conn.Write([]byte{0x05, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x08, 0x43}) - if err != nil { - logger.Errorf(err.Error()) - return + logger.Debugf("dialed mux for %s", host) + + muxConn.Run(conn) + return +} + +//client dial remote +func (conn *MuxConn) DialMessage(host string) (err error) { + m := &Message{ + Method: MessageMethodDial, + MessageID: 18446744073709551615, + ConnID: conn.ID, + Data: []byte(host), } - muxConn := NewMuxConn(client.MuxWS) + logger.Debugf("dial for %s", host) - err = muxConn.DialMessage(host) + err = conn.muxWS.SendMessage(m) if err != nil { - logger.Errorf(err.Error()) return } - logger.Debugf("dialed for %s", host) - - muxConn.Run(conn) + logger.Debugf("%d %s", conn.ID, host) return } - -func (muxWS *MuxWebSocket) ClientListen() { - for { - m, err := muxWS.ReceiveMessage() - if err != nil { - logger.Debugf(err.Error()) - return - } - - //get conn and send message - conn := muxWS.GetMuxConn(m.ConnID) - err = conn.HandleMessage(m) - if err != nil { - logger.Debugf(err.Error()) - continue - } - } -} - -//func (client *Client) Dial(conn *net.TCPConn, host string) { -// muxConn := NewMuxConn(client.MuxWS) -// -// //listen local conn and send message -// go func() { -// messageID := 1 -// buf := make([]byte, 32*1024) -// for { -// n, err := conn.Read(buf) -// if err != nil { -// logger.Errorf(err.Error()) -// return -// } -// -// println(n) -// -// dataMessage := &Message{ -// Method: MessageMethodData, -// ConnID: id, -// MessageID: messageID, -// Data: buf[:n], -// } -// -// messageID++ -// client.MessageChan <- dataMessage -// } -// }() -// -// go func() { -// for { -// _, err := conn.Write(<-dataChan) -// if err != nil { -// logger.Debugf(err.Error()) -// conn.Close() -// } -// } -// }() -// -// return -//} -// -//func (client *MuxClient) HandleMessage(m *Message) (err error) { -// if m.Method != MessageMethodData { -// return errors.New("unknown method") -// } -// -// connID := m.ConnID -// c, ok := client.muxConnMap.Load(connID) -// if !ok { -// return errors.New("can not load conn") -// } -// conn := c.(*MuxConn) -// -// go func() { -// for { -// if conn.DataID == m.MessageID { -// conn.DataChan <- m.Data -// return -// } -// } -// }() -// -// return -//} -// -//func (client *Client) handleConn(conn *net.TCPConn) { -// defer conn.Close() -// -// conn.SetLinger(0) -// -// err := handShake(conn) -// if err != nil { -// logger.Errorf(err.Error()) -// return -// } -// -// _, host, err := getRequest(conn) -// if err != nil { -// logger.Errorf(err.Error()) -// return -// } -// -// _, err = conn.Write([]byte{0x05, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x08, 0x43}) -// if err != nil { -// logger.Errorf(err.Error()) -// return -// } -// -// client.Dial(conn, host) -// return -//} diff --git a/core/muxserver.go b/core/muxserver.go index 1c1fc53..7631ef7 100644 --- a/core/muxserver.go +++ b/core/muxserver.go @@ -4,6 +4,7 @@ import ( "errors" "fmt" "net" + "time" ) func (muxWS *MuxWebSocket) ServerListen() { @@ -20,25 +21,12 @@ func (muxWS *MuxWebSocket) ServerListen() { return } -func (muxWS *MuxWebSocket) AcceptMuxConn(m *Message) (conn *MuxConn, host string, err error) { - if m.Method != MessageMethodDial { - err = errors.New(fmt.Sprintf("wrong message method %d", m.Method)) +func (muxWS *MuxWebSocket) ServerHandleMessage(m *Message) { + //check message + if m.Data == nil { return } - host = string(m.Data) - - conn = &MuxConn{ - ID: m.ConnID, - muxWS: muxWS, - wait: make(chan int), - sendMessageID: new(uint64), - } - muxWS.PutMuxConn(conn) - return -} - -func (muxWS *MuxWebSocket) ServerHandleMessage(m *Message) { //accept new conn if m.Method == MessageMethodDial { conn, host, err := muxWS.AcceptMuxConn(m) @@ -67,6 +55,14 @@ func (muxWS *MuxWebSocket) ServerHandleMessage(m *Message) { //get conn and send message conn := muxWS.GetMuxConn(m.ConnID) + if conn == nil { + time.Sleep(time.Second) + conn = muxWS.GetMuxConn(m.ConnID) + if conn == nil { + logger.Debugf("conn %d do not exist", m.ConnID) + return + } + } err := conn.HandleMessage(m) if err != nil { logger.Debugf(err.Error()) @@ -74,108 +70,20 @@ func (muxWS *MuxWebSocket) ServerHandleMessage(m *Message) { } } -//func (server *Server) HandleMuxWS(ws *WebSocket) (muxWS *MuxWebSocket,err error) { -// dec := gob.NewDecoder(ws) -// enc := gob.NewEncoder(ws) -// -// //receive messages -// go func() { -// for { -// m := &Message{} -// err = dec.Decode(m) -// if err != nil { -// logger.Debugf(err.Error()) -// return -// } -// -// err = server.HandleMessage(m) -// if err != nil { -// logger.Debugf(err.Error()) -// continue -// } -// } -// }() -// -// //send messages -// go func() { -// for { -// m := <-server.MessageChan -// err = enc.Encode(m) -// if err != nil { -// logger.Debugf(err.Error()) -// return -// } -// } -// }() -// -// time.Sleep(time.Minute) -// return -//} -// -//func (server *Server) HandleMessage(m *Message) (err error) { -// if m.Method == MessageMethodDial { -// id := m.ConnID -// dataChan := make(chan []byte) -// conn := &MuxConn{ -// ID: id, -// DataChan: dataChan, -// } -// -// server.muxConnMap.Store(id, conn) -// server.DialRemote(conn, string(m.Data)) -// return -// } -// -// if m.Method != MessageMethodData { -// return errors.New("unknown method") -// } -// -// connID := m.ConnID -// c, ok := server.muxConnMap.Load(connID) -// if !ok { -// return errors.New("can not load conn") -// } -// -// conn := c.(*MuxConn) -// go func() { -// for { -// if conn.DataID == m.MessageID { -// conn.DataChan <- m.Data -// return -// } -// } -// }() -// -// return -//} -// -//func (server *Server) DialRemote(muxConn *MuxConn, host string) { -// conn, err := net.Dial("tcp", host) -// if err != nil { -// logger.Debugf(err.Error()) -// return -// } -// -// go func() { -// for { -// buf := make([]byte, 32*1024) -// n, err := conn.Read(buf) -// if err != nil { -// logger.Debugf(err.Error()) -// return -// } -// -// m := &Message{ -// Method: MessageMethodData, -// ConnID: muxConn.ID, -// MessageID: muxConn.DataID, -// Data: buf[:n], -// } -// muxConn.DataID++ -// -// server.MessageChan <- m -// } -// }() -// -// return -//} +func (muxWS *MuxWebSocket) AcceptMuxConn(m *Message) (conn *MuxConn, host string, err error) { + if m.Method != MessageMethodDial { + err = errors.New(fmt.Sprintf("wrong message method %d", m.Method)) + return + } + + host = string(m.Data) + + conn = &MuxConn{ + ID: m.ConnID, + muxWS: muxWS, + wait: make(chan int), + sendMessageID: new(uint64), + } + muxWS.PutMuxConn(conn) + return +} diff --git a/core/muxwebsocket.go b/core/muxwebsocket.go index f80a386..099fc88 100644 --- a/core/muxwebsocket.go +++ b/core/muxwebsocket.go @@ -30,7 +30,6 @@ func NewMuxWebSocket(ws *WebSocket) (muxWS *MuxWebSocket) { func (muxWS *MuxWebSocket) SendMessage(m *Message) (err error) { err = muxWS.Encoder.Encode(m) logger.Debugf("sent %#v", m) - //logger.Debugf("sent message %d %d %s", m.ConnID, m.MessageID, string(m.Data)) return } @@ -38,7 +37,6 @@ func (muxWS *MuxWebSocket) ReceiveMessage() (m *Message, err error) { m = &Message{} err = muxWS.Decoder.Decode(m) logger.Debugf("received %#v", m) - //logger.Debugf("received message %d %d %s", m.ConnID, m.MessageID, string(m.Data)) return } diff --git a/websocks.go b/websocks.go index 8a10b3d..2b3baf4 100644 --- a/websocks.go +++ b/websocks.go @@ -25,7 +25,7 @@ func main() { app := cli.NewApp() app.Name = "WebSocks" - app.Version = "0.7.0" + app.Version = "0.8.0" app.Usage = "A secure proxy based on WebSocket." app.Description = "See https://github.com/lzjluzijie/websocks" app.Author = "Halulu"