diff --git a/core/client.go b/core/client.go index edd2651..ffa4292 100644 --- a/core/client.go +++ b/core/client.go @@ -1,28 +1,26 @@ package core import ( + "errors" "io" "net" - - "errors" "net/url" + "time" - "crypto/tls" - + "github.com/gorilla/websocket" "github.com/juju/loggo" - "golang.org/x/net/websocket" ) var logger = loggo.GetLogger("core") type Client struct { - LogLevel loggo.Level - ListenAddr *net.TCPAddr - URL *url.URL - Origin string - ServerName string - InsecureCert bool - WSConfig websocket.Config + LogLevel loggo.Level + ListenAddr *net.TCPAddr + URL *url.URL + Origin string + + Dialer *websocket.Dialer + CreatedAt time.Time } func (client *Client) Listen() (err error) { @@ -39,19 +37,6 @@ func (client *Client) Listen() (err error) { logger.Debugf(client.Origin) - config, err := websocket.NewConfig(client.URL.String(), client.Origin) - if err != nil { - return - } - - config.TlsConfig = &tls.Config{ - InsecureSkipVerify: client.InsecureCert, - } - if client.ServerName != "" { - config.TlsConfig.ServerName = client.ServerName - } - client.WSConfig = *config - listener, err := net.ListenTCP("tcp", client.ListenAddr) if err != nil { return err @@ -102,18 +87,11 @@ func (client *Client) handleConn(conn *net.TCPConn) (err error) { return } - config := client.WSConfig - config.Header = map[string][]string{ - "WebSocks-Host": {host}, - } - - ws, err := websocket.DialConfig(&config) + ws, err := client.dial(host) if err != nil { return } - defer ws.Close() - go func() { _, err = io.Copy(ws, conn) if err != nil { @@ -130,3 +108,18 @@ func (client *Client) handleConn(conn *net.TCPConn) (err error) { return } + +func (client *Client) dial(host string) (ws *WebSocket, err error) { + conn, _, err := client.Dialer.Dial(client.URL.String(), map[string][]string{ + "WebSocks-Host": {host}, + }) + + if err != nil { + return + } + + ws = &WebSocket{ + conn: conn, + } + return +} diff --git a/core/server.go b/core/server.go index 702de47..d98d228 100644 --- a/core/server.go +++ b/core/server.go @@ -1,21 +1,17 @@ package core import ( + "fmt" "io" "net" "net/http" - - "time" - "net/http/httputil" "net/url" - "sync/atomic" + "time" - "fmt" - + "github.com/gorilla/websocket" "github.com/juju/loggo" - "golang.org/x/net/websocket" ) type Server struct { @@ -27,6 +23,8 @@ type Server struct { KeyPath string Proxy string + Upgrader *websocket.Upgrader + CreatedAt time.Time Opened uint64 @@ -35,13 +33,21 @@ type Server struct { Downloaded uint64 } -func (server *Server) HandleWebSocket(ws *websocket.Conn) { - defer ws.Close() +func (server *Server) HandleWebSocket(w http.ResponseWriter, r *http.Request) { + c, err := server.Upgrader.Upgrade(w, r, nil) + if err != nil { + logger.Debugf(err.Error()) + return + } + + ws := &WebSocket{ + conn: c, + } atomic.AddUint64(&server.Opened, 1) defer atomic.AddUint64(&server.Closed, 1) - host := ws.Request().Header.Get("WebSocks-Host") + host := r.Header.Get("WebSocks-Host") logger.Debugf("Dial %s", host) conn, err := net.Dial("tcp", host) @@ -57,9 +63,7 @@ func (server *Server) HandleWebSocket(ws *websocket.Conn) { downloaded, err := io.Copy(conn, ws) atomic.AddUint64(&server.Downloaded, uint64(downloaded)) if err != nil { - if err != nil { - logger.Debugf(err.Error()) - } + logger.Debugf(err.Error()) return } }() @@ -67,11 +71,10 @@ func (server *Server) HandleWebSocket(ws *websocket.Conn) { uploaded, err := io.Copy(ws, conn) atomic.AddUint64(&server.Uploaded, uint64(uploaded)) if err != nil { - if err != nil { - logger.Debugf(err.Error()) - } + logger.Debugf(err.Error()) return } + return } func (server *Server) Status(w http.ResponseWriter, r *http.Request) { @@ -89,7 +92,7 @@ func (server *Server) Listen() (err error) { }() mux := http.NewServeMux() - mux.Handle(server.Pattern, websocket.Handler(server.HandleWebSocket)) + mux.HandleFunc(server.Pattern, server.HandleWebSocket) mux.HandleFunc("/status", server.Status) if server.Proxy != "" { mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { diff --git a/core/socks5.go b/core/socks5.go index a9e2ab2..4c86d8e 100644 --- a/core/socks5.go +++ b/core/socks5.go @@ -69,7 +69,7 @@ func getRequest(conn net.Conn) (rawaddr []byte, host string, err error) { idVer = 0 idCmd = 1 idType = 3 // address type index - idIP0 = 4 // ip addres start index + idIP0 = 4 // ip address start index idDmLen = 4 // domain address length index idDm0 = 5 // domain address start index diff --git a/core/websocket.go b/core/websocket.go new file mode 100644 index 0000000..3f1b69e --- /dev/null +++ b/core/websocket.go @@ -0,0 +1,32 @@ +package core + +import ( + "github.com/gorilla/websocket" +) + +type WebSocket struct { + conn *websocket.Conn + buf []byte +} + +func (ws *WebSocket) Read(p []byte) (n int, err error) { + if len(ws.buf) == 0 { + _, ws.buf, err = ws.conn.ReadMessage() + if err != nil { + return + } + } + + n = copy(p, ws.buf) + ws.buf = ws.buf[n:] + return +} + +func (ws *WebSocket) Write(p []byte) (n int, err error) { + err = ws.conn.WriteMessage(websocket.BinaryMessage, p) + if err != nil { + return + } + + return len(p), nil +} diff --git a/websocks.go b/websocks.go index a2405c0..03fd884 100644 --- a/websocks.go +++ b/websocks.go @@ -11,6 +11,9 @@ import ( "time" + "crypto/tls" + + "github.com/gorilla/websocket" "github.com/juju/loggo" "github.com/lzjluzijie/websocks/core" "github.com/urfave/cli" @@ -87,12 +90,25 @@ func main() { return } + tlsConfig := &tls.Config{ + InsecureSkipVerify: insecureCert, + } + + if serverName != "" { + tlsConfig.ServerName = serverName + } + local := core.Client{ - LogLevel: logger.LogLevel(), - ListenAddr: lAddr, - URL: u, - ServerName: serverName, - InsecureCert: insecureCert, + LogLevel: logger.LogLevel(), + ListenAddr: lAddr, + URL: u, + Dialer: &websocket.Dialer{ + ReadBufferSize: 4 * 1024, + WriteBufferSize: 4 * 1024, + HandshakeTimeout: 10 * time.Second, + TLSClientConfig: tlsConfig, + }, + CreatedAt: time.Now(), } err = local.Listen() @@ -110,7 +126,7 @@ func main() { Flags: []cli.Flag{ cli.StringFlag{ Name: "l", - Value: "127.0.0.1:23333", + Value: "0.0.0.0:23333", Usage: "local listening port", }, cli.StringFlag{ @@ -161,7 +177,12 @@ func main() { CertPath: certPath, KeyPath: keyPath, Proxy: proxy, - CreatedAt: time.Now(), + Upgrader: &websocket.Upgrader{ + ReadBufferSize: 4 * 1024, + WriteBufferSize: 4 * 1024, + HandshakeTimeout: 10 * time.Second, + }, + CreatedAt: time.Now(), } logger.Infof("Listening at %s", listenAddr)