diff --git a/core/client.go b/core/client.go index d69c488..0206b4a 100644 --- a/core/client.go +++ b/core/client.go @@ -17,6 +17,9 @@ type Client struct { ListenAddr *net.TCPAddr URL *url.URL + Mux bool + MuxWS *MuxWebSocket + Dialer *websocket.Dialer CreatedAt time.Time @@ -34,6 +37,16 @@ func (client *Client) Listen() (err error) { defer listener.Close() + if client.Mux { + err := client.OpenMux() + if err != nil { + logger.Debugf(err.Error()) + return err + } + + go client.MuxWS.ClientListen() + } + for { conn, err := listener.AcceptTCP() if err != nil { @@ -47,39 +60,53 @@ func (client *Client) Listen() (err error) { return nil } -func (client *Client) handleConn(conn *net.TCPConn) (err error) { - defer func() { - if err != nil { - logger.Debugf("Handle connection error: %s", err.Error()) - } - }() - +func (client *Client) handleConn(conn *net.TCPConn) { defer conn.Close() conn.SetLinger(0) - err = handShake(conn) + err := handShake(conn) if err != nil { + logger.Errorf(err.Error()) return } _, host, err := getRequest(conn) if err != nil { + logger.Errorf(err.Error()) return } - logger.Debugf("Host: %s", host) - _, err = conn.Write([]byte{0x05, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x08, 0x43}) if err != nil { + logger.Errorf(err.Error()) return } - ws, err := client.dial(host) + if client.Mux { + client.DialMuxConn(host, conn) + } else { + client.DialWSConn(host, conn) + } + + return +} + +func (client *Client) DialWSConn(host string, conn *net.TCPConn) { + wsConn, _, err := client.Dialer.Dial(client.URL.String(), map[string][]string{ + "WebSocks-Host": {host}, + }) + if err != nil { return } + logger.Debugf("dialed ws for %s", host) + + ws := &WebSocket{ + conn: wsConn, + } + go func() { _, err = io.Copy(ws, conn) if err != nil { @@ -91,23 +118,8 @@ func (client *Client) handleConn(conn *net.TCPConn) (err error) { _, err = io.Copy(conn, ws) if err != nil { + logger.Debugf(err.Error()) return } - - return -} - -func (client *Client) dial(host string) (ws *WebSocket, err error) { - conn, _, err := client.Dialer.Dial(client.URL.String(), map[string][]string{ - "WebSocks-Host": {host}, - }) - - if err != nil { - return - } - - ws = &WebSocket{ - conn: conn, - } return } diff --git a/core/mux.go b/core/mux.go new file mode 100644 index 0000000..64dc5d0 --- /dev/null +++ b/core/mux.go @@ -0,0 +1,112 @@ +package core + +import ( + "io" + "math/rand" + "net" + "sync" + "sync/atomic" +) + +const ( + MessageMethodData = iota + MessageMethodDial +) + +type Message struct { + Method byte + ConnID uint64 + MessageID uint64 + Data []byte +} + +type MuxConn struct { + ID uint64 + muxWS *MuxWebSocket + + mutex sync.Mutex + buf []byte + wait chan int + + receiveMessageID uint64 + sendMessageID *uint64 +} + +//client use +func NewMuxConn(muxWS *MuxWebSocket) (conn *MuxConn) { + return &MuxConn{ + ID: rand.Uint64(), + muxWS: muxWS, + wait: make(chan int), + sendMessageID: new(uint64), + } +} + +func (conn *MuxConn) Write(p []byte) (n int, err error) { + m := &Message{ + Method: MessageMethodData, + ConnID: conn.ID, + MessageID: conn.SendMessageID(), + Data: p, + } + + err = conn.muxWS.SendMessage(m) + if err != nil { + return 0, err + } + return len(p), nil +} + +func (conn *MuxConn) Read(p []byte) (n int, err error) { + if len(conn.buf) == 0 { + logger.Debugf("%d buf is 0, waiting", conn.ID) + <-conn.wait + } + + conn.mutex.Lock() + logger.Debugf("%d buf: %v", conn.buf) + n = copy(p, conn.buf) + conn.buf = conn.buf[n:] + conn.mutex.Unlock() + return +} + +func (conn *MuxConn) HandleMessage(m *Message) (err error) { + logger.Debugf("handle message %d %d", m.ConnID, m.MessageID) + for { + if conn.receiveMessageID == m.MessageID { + conn.mutex.Lock() + conn.buf = append(conn.buf, m.Data...) + conn.receiveMessageID++ + close(conn.wait) + conn.wait = make(chan int) + conn.mutex.Unlock() + logger.Debugf("handled message %d %d", m.ConnID, m.MessageID) + return + } + <-conn.wait + } + return +} + +func (conn *MuxConn) SendMessageID() (id uint64) { + id = atomic.LoadUint64(conn.sendMessageID) + atomic.AddUint64(conn.sendMessageID, 1) + return +} + +func (conn *MuxConn) Run(c *net.TCPConn) { + go func() { + _, err := io.Copy(c, conn) + if err != nil { + logger.Debugf(err.Error()) + } + }() + + _, err := io.Copy(conn, c) + if err != nil { + logger.Debugf(err.Error()) + } + + return +} diff --git a/core/muxclient.go b/core/muxclient.go new file mode 100644 index 0000000..9c6cb2d --- /dev/null +++ b/core/muxclient.go @@ -0,0 +1,81 @@ +package core + +import ( + "net" +) + +func (muxWS *MuxWebSocket) ClientListen() { + for { + m, err := muxWS.ReceiveMessage() + if err != nil { + logger.Debugf(err.Error()) + return + } + + //get conn and send message + conn := muxWS.GetMuxConn(m.ConnID) + err = conn.HandleMessage(m) + if err != nil { + logger.Debugf(err.Error()) + continue + } + } +} + +func (client *Client) OpenMux() (err error) { + wsConn, _, err := client.Dialer.Dial(client.URL.String(), map[string][]string{ + "WebSocks-Mux": {"mux"}, + }) + + if err != nil { + return + } + + ws := &WebSocket{ + conn: wsConn, + } + + muxWS := NewMuxWebSocket(ws) + client.MuxWS = muxWS + return +} +func (client *Client) DialMuxConn(host string, conn *net.TCPConn) { + muxConn := NewMuxConn(client.MuxWS) + + err := muxConn.DialMessage(host) + if err != nil { + logger.Errorf(err.Error()) + err = client.OpenMux() + if err != nil { + logger.Errorf(err.Error()) + } + return + } + + muxConn.muxWS.PutMuxConn(muxConn) + + logger.Debugf("dialed mux for %s", host) + + muxConn.Run(conn) + return +} + +//client dial remote +func (conn *MuxConn) DialMessage(host string) (err error) { + m := &Message{ + Method: MessageMethodDial, + MessageID: 18446744073709551615, + ConnID: conn.ID, + Data: []byte(host), + } + + logger.Debugf("dial for %s", host) + + err = conn.muxWS.SendMessage(m) + if err != nil { + return + } + + logger.Debugf("%d %s", conn.ID, host) + return +} diff --git a/core/muxserver.go b/core/muxserver.go new file mode 100644 index 0000000..7631ef7 --- /dev/null +++ b/core/muxserver.go @@ -0,0 +1,89 @@ +package core + +import ( + "errors" + "fmt" + "net" + "time" +) + +func (muxWS *MuxWebSocket) ServerListen() { + //block and listen + for { + m, err := muxWS.ReceiveMessage() + if err != nil { + logger.Debugf(err.Error()) + return + } + + go muxWS.ServerHandleMessage(m) + } + return +} + +func (muxWS *MuxWebSocket) ServerHandleMessage(m *Message) { + //check message + if m.Data == nil { + return + } + + //accept new conn + if m.Method == MessageMethodDial { + conn, host, err := muxWS.AcceptMuxConn(m) + if err != nil { + logger.Debugf(err.Error()) + return + } + + tcpAddr, err := net.ResolveTCPAddr("tcp", host) + if err != nil { + logger.Debugf(err.Error()) + return + } + + tcpConn, err := net.DialTCP("tcp", nil, tcpAddr) + if err != nil { + logger.Debugf(err.Error()) + return + } + + logger.Debugf("Accepted mux conn %s", host) + + conn.Run(tcpConn) + return + } + + //get conn and send message + conn := muxWS.GetMuxConn(m.ConnID) + if conn == nil { + time.Sleep(time.Second) + conn = muxWS.GetMuxConn(m.ConnID) + if conn == nil { + logger.Debugf("conn %d do not exist", m.ConnID) + return + } + } + err := conn.HandleMessage(m) + if err != nil { + logger.Debugf(err.Error()) + return + } +} + +func (muxWS *MuxWebSocket) AcceptMuxConn(m *Message) (conn *MuxConn, host string, err error) { + if m.Method != MessageMethodDial { + err = errors.New(fmt.Sprintf("wrong message method %d", m.Method)) + return + } + + host = string(m.Data) + + conn = &MuxConn{ + ID: m.ConnID, + muxWS: muxWS, + wait: make(chan int), + sendMessageID: new(uint64), + } + muxWS.PutMuxConn(conn) + return +} diff --git a/core/muxwebsocket.go b/core/muxwebsocket.go new file mode 100644 index 0000000..099fc88 --- /dev/null +++ b/core/muxwebsocket.go @@ -0,0 +1,61 @@ +package core + +import ( + "encoding/gob" + "sync" +) + +type MuxWebSocket struct { + *WebSocket + Decoder *gob.Decoder + Encoder *gob.Encoder + + muxConns []*MuxConn + muxConnID []uint64 + mutex sync.Mutex +} + +func NewMuxWebSocket(ws *WebSocket) (muxWS *MuxWebSocket) { + dec := gob.NewDecoder(ws) + enc := gob.NewEncoder(ws) + + muxWS = &MuxWebSocket{ + WebSocket: ws, + Decoder: dec, + Encoder: enc, + } + return +} + +func (muxWS *MuxWebSocket) SendMessage(m *Message) (err error) { + err = muxWS.Encoder.Encode(m) + logger.Debugf("sent %#v", m) + return +} + +func (muxWS *MuxWebSocket) ReceiveMessage() (m *Message, err error) { + m = &Message{} + err = muxWS.Decoder.Decode(m) + logger.Debugf("received %#v", m) + return +} + +func (muxWS *MuxWebSocket) PutMuxConn(conn *MuxConn) { + muxWS.mutex.Lock() + muxWS.muxConns = append(muxWS.muxConns, conn) + muxWS.muxConnID = append(muxWS.muxConnID, conn.ID) + muxWS.mutex.Unlock() + return +} + +func (muxWS *MuxWebSocket) GetMuxConn(connID uint64) (conn *MuxConn) { + muxWS.mutex.Lock() + for n, id := range muxWS.muxConnID { + if id == connID { + conn = muxWS.muxConns[n] + break + } + } + muxWS.mutex.Unlock() + return +} diff --git a/core/server.go b/core/server.go index 149fd8a..7caecb0 100644 --- a/core/server.go +++ b/core/server.go @@ -11,6 +11,8 @@ import ( "crypto/tls" + "sync" + "github.com/gorilla/websocket" "github.com/juju/loggo" ) @@ -26,6 +28,10 @@ type Server struct { Upgrader *websocket.Upgrader + MessageChan chan *Message + muxConnMap sync.Map + Mutex sync.Mutex + CreatedAt time.Time Opened uint64 @@ -48,6 +54,12 @@ func (server *Server) HandleWebSocket(w http.ResponseWriter, r *http.Request) { atomic.AddUint64(&server.Opened, 1) defer atomic.AddUint64(&server.Closed, 1) + if r.Header.Get("WebSocks-Mux") == "mux" { + muxWS := NewMuxWebSocket(ws) + muxWS.ServerListen() + return + } + host := r.Header.Get("WebSocks-Host") logger.Debugf("Dial %s", host) diff --git a/core/websocket.go b/core/websocket.go index 3f1b69e..1749f7b 100644 --- a/core/websocket.go +++ b/core/websocket.go @@ -19,6 +19,7 @@ func (ws *WebSocket) Read(p []byte) (n int, err error) { n = copy(p, ws.buf) ws.buf = ws.buf[n:] + return } @@ -30,3 +31,8 @@ func (ws *WebSocket) Write(p []byte) (n int, err error) { return len(p), nil } + +func (ws *WebSocket) Close() (err error) { + ws.conn.Close() + return +} diff --git a/websocks.go b/websocks.go index 34173ba..2b3baf4 100644 --- a/websocks.go +++ b/websocks.go @@ -25,7 +25,7 @@ func main() { app := cli.NewApp() app.Name = "WebSocks" - app.Version = "0.7.0" + app.Version = "0.8.0" app.Usage = "A secure proxy based on WebSocket." app.Description = "See https://github.com/lzjluzijie/websocks" app.Author = "Halulu" @@ -46,7 +46,7 @@ func main() { Flags: []cli.Flag{ cli.StringFlag{ Name: "l", - Value: ":10801", + Value: "127.0.0.1:10801", Usage: "local listening port", }, cli.StringFlag{ @@ -54,6 +54,10 @@ func main() { Value: "ws://localhost:23333/websocks", Usage: "server url", }, + cli.BoolFlag{ + Name: "mux", + Usage: "mux mode", + }, cli.StringFlag{ Name: "n", Value: "", @@ -68,6 +72,7 @@ func main() { debug := c.GlobalBool("debug") listenAddr := c.String("l") serverURL := c.String("s") + mux := c.Bool("mux") serverName := c.String("n") insecureCert := false if c.Bool("insecure") { @@ -108,6 +113,7 @@ func main() { HandshakeTimeout: 10 * time.Second, TLSClientConfig: tlsConfig, }, + Mux: mux, CreatedAt: time.Now(), } @@ -182,7 +188,8 @@ func main() { WriteBufferSize: 4 * 1024, HandshakeTimeout: 10 * time.Second, }, - CreatedAt: time.Now(), + MessageChan: make(chan *core.Message), + CreatedAt: time.Now(), } logger.Infof("Listening at %s", listenAddr)