From 1ee01621b6631b1b73e6aa93c2021761c40d04eb Mon Sep 17 00:00:00 2001 From: halulu Date: Thu, 24 May 2018 20:00:14 +0800 Subject: [PATCH] use github.com/gorilla/websocke instead of golang.org/x/net/websocket --- core/client.go | 57 +++++++++++++++++++++-------------------------- core/server.go | 11 +++++++-- core/websocket.go | 23 +++---------------- websocks.go | 33 ++++++++++++++++++++++----- 4 files changed, 65 insertions(+), 59 deletions(-) diff --git a/core/client.go b/core/client.go index f1a8c60..ffa4292 100644 --- a/core/client.go +++ b/core/client.go @@ -1,26 +1,26 @@ package core import ( + "errors" "io" "net" "net/url" - "errors" - "crypto/tls" + "time" + "github.com/gorilla/websocket" "github.com/juju/loggo" - "golang.org/x/net/websocket" ) var logger = loggo.GetLogger("core") type Client struct { - LogLevel loggo.Level - ListenAddr *net.TCPAddr - URL *url.URL - Origin string - ServerName string - InsecureCert bool - WSConfig websocket.Config + LogLevel loggo.Level + ListenAddr *net.TCPAddr + URL *url.URL + Origin string + + Dialer *websocket.Dialer + CreatedAt time.Time } func (client *Client) Listen() (err error) { @@ -37,19 +37,6 @@ func (client *Client) Listen() (err error) { logger.Debugf(client.Origin) - config, err := websocket.NewConfig(client.URL.String(), client.Origin) - if err != nil { - return - } - - config.TlsConfig = &tls.Config{ - InsecureSkipVerify: client.InsecureCert, - } - if client.ServerName != "" { - config.TlsConfig.ServerName = client.ServerName - } - client.WSConfig = *config - listener, err := net.ListenTCP("tcp", client.ListenAddr) if err != nil { return err @@ -100,18 +87,11 @@ func (client *Client) handleConn(conn *net.TCPConn) (err error) { return } - config := client.WSConfig - config.Header = map[string][]string{ - "WebSocks-Host": {host}, - } - - ws, err := websocket.DialConfig(&config) + ws, err := client.dial(host) if err != nil { return } - defer ws.Close() - go func() { _, err = io.Copy(ws, conn) if err != nil { @@ -128,3 +108,18 @@ func (client *Client) handleConn(conn *net.TCPConn) (err error) { return } + +func (client *Client) dial(host string) (ws *WebSocket, err error) { + conn, _, err := client.Dialer.Dial(client.URL.String(), map[string][]string{ + "WebSocks-Host": {host}, + }) + + if err != nil { + return + } + + ws = &WebSocket{ + conn: conn, + } + return +} diff --git a/core/server.go b/core/server.go index e4feb5b..d98d228 100644 --- a/core/server.go +++ b/core/server.go @@ -1,6 +1,7 @@ package core import ( + "fmt" "io" "net" "net/http" @@ -8,8 +9,8 @@ import ( "net/url" "sync/atomic" "time" - "fmt" + "github.com/gorilla/websocket" "github.com/juju/loggo" ) @@ -22,6 +23,8 @@ type Server struct { KeyPath string Proxy string + Upgrader *websocket.Upgrader + CreatedAt time.Time Opened uint64 @@ -31,12 +34,16 @@ type Server struct { } func (server *Server) HandleWebSocket(w http.ResponseWriter, r *http.Request) { - ws, err := NewWebSocket(w, r) + c, err := server.Upgrader.Upgrade(w, r, nil) if err != nil { logger.Debugf(err.Error()) return } + ws := &WebSocket{ + conn: c, + } + atomic.AddUint64(&server.Opened, 1) defer atomic.AddUint64(&server.Closed, 1) diff --git a/core/websocket.go b/core/websocket.go index 1e72a0e..3f1b69e 100644 --- a/core/websocket.go +++ b/core/websocket.go @@ -1,35 +1,18 @@ package core import ( - "net/http" - "time" - "github.com/gorilla/websocket" ) -var upgrader = &websocket.Upgrader{ - ReadBufferSize: 4 * 1024, - WriteBufferSize: 4 * 1024, - HandshakeTimeout: 10 * time.Second, -} - type WebSocket struct { - conn *websocket.Conn - buf []byte -} - -func NewWebSocket(w http.ResponseWriter, r *http.Request) (ws *WebSocket, err error) { - c, err := upgrader.Upgrade(w, r, nil) - ws = &WebSocket{ - conn: c, - } - return + conn *websocket.Conn + buf []byte } func (ws *WebSocket) Read(p []byte) (n int, err error) { if len(ws.buf) == 0 { _, ws.buf, err = ws.conn.ReadMessage() - if err != nil{ + if err != nil { return } } diff --git a/websocks.go b/websocks.go index a2f2c53..03fd884 100644 --- a/websocks.go +++ b/websocks.go @@ -11,6 +11,9 @@ import ( "time" + "crypto/tls" + + "github.com/gorilla/websocket" "github.com/juju/loggo" "github.com/lzjluzijie/websocks/core" "github.com/urfave/cli" @@ -87,12 +90,25 @@ func main() { return } + tlsConfig := &tls.Config{ + InsecureSkipVerify: insecureCert, + } + + if serverName != "" { + tlsConfig.ServerName = serverName + } + local := core.Client{ - LogLevel: logger.LogLevel(), - ListenAddr: lAddr, - URL: u, - ServerName: serverName, - InsecureCert: insecureCert, + LogLevel: logger.LogLevel(), + ListenAddr: lAddr, + URL: u, + Dialer: &websocket.Dialer{ + ReadBufferSize: 4 * 1024, + WriteBufferSize: 4 * 1024, + HandshakeTimeout: 10 * time.Second, + TLSClientConfig: tlsConfig, + }, + CreatedAt: time.Now(), } err = local.Listen() @@ -161,7 +177,12 @@ func main() { CertPath: certPath, KeyPath: keyPath, Proxy: proxy, - CreatedAt: time.Now(), + Upgrader: &websocket.Upgrader{ + ReadBufferSize: 4 * 1024, + WriteBufferSize: 4 * 1024, + HandshakeTimeout: 10 * time.Second, + }, + CreatedAt: time.Now(), } logger.Infof("Listening at %s", listenAddr)