diff --git a/core/auth.go b/core/auth.go new file mode 100644 index 0000000..0e54f36 --- /dev/null +++ b/core/auth.go @@ -0,0 +1,74 @@ +// Copyright (C) 2020-2021, IrineSistiana +// +// This file is part of simple-tls. +// +// simple-tls is free software: you can redistribute it and/or modify +// it under the terms of the GNU General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// simple-tls is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU General Public License for more details. +// +// You should have received a copy of the GNU General Public License +// along with this program. If not, see . + +package core + +import ( + "context" + "crypto/md5" + "errors" + "fmt" + "io" + "net" + "time" +) + +type AuthTransport struct { + nextTransport Transport + auth [md5.Size]byte +} + +func (t *AuthTransport) Dial(ctx context.Context) (net.Conn, error) { + conn, err := t.nextTransport.Dial(ctx) + if err != nil { + return nil, err + } + if _, err := conn.Write(t.auth[:]); err != nil { + conn.Close() + return nil, fmt.Errorf("failed to write auth: %w", err) + } + return conn, nil +} + +func NewAuthTransport(nextTransport Transport, auth string) *AuthTransport { + return &AuthTransport{nextTransport: nextTransport, auth: md5.Sum([]byte(auth))} +} + +type AuthTransportHandler struct { + nextHandler TransportHandler + auth [md5.Size]byte +} + +func NewAuthTransportHandler(nextHandler TransportHandler, auth string) *AuthTransportHandler { + return &AuthTransportHandler{nextHandler: nextHandler, auth: md5.Sum([]byte(auth))} +} + +var errAuthFailed = errors.New("auth failed") + +func (h *AuthTransportHandler) Handle(conn net.Conn) error { + var auth [md5.Size]byte + if _, err := io.ReadFull(conn, auth[:]); err != nil { + return fmt.Errorf("failed to read auth header: %w", err) + } + + if auth != h.auth { + discardRead(conn, time.Second*15) + return errAuthFailed + } + + return h.nextHandler.Handle(conn) +} diff --git a/core/client.go b/core/client.go index d1d1142..754fa31 100644 --- a/core/client.go +++ b/core/client.go @@ -18,124 +18,146 @@ package core import ( - "crypto/md5" + "bytes" + "context" + "crypto/sha256" "crypto/tls" "crypto/x509" + "encoding/hex" + "errors" "fmt" "github.com/IrineSistiana/ctunnel" - "log" + "io/ioutil" "net" + "strings" "time" ) type Client struct { - Listener net.Listener - ServerAddr string - NoTLS bool - Auth string + BindAddr string + DstAddr string + Websocket bool + WebsocketPath string + Mux int + Auth string + ServerName string - CertPool *x509.CertPool + CA string + CertHash string InsecureSkipVerify bool - Timeout time.Duration - AndroidVPNMode bool - TFO bool - Mux int - - dialer net.Dialer - auth [16]byte - tlsConfig *tls.Config - muxPool *muxPool // not nil if Mux > 0 + + IdleTimeout time.Duration + AndroidVPNMode bool + TFO bool + + testListener net.Listener } +var errEmptyCAFile = errors.New("no valid certificate was found in the ca file") + func (c *Client) ActiveAndServe() error { - c.dialer = net.Dialer{ - Timeout: time.Second * 5, - Control: GetControlFunc(&TcpConfig{AndroidVPN: c.AndroidVPNMode, EnableTFO: c.TFO}), + + var l net.Listener + if c.testListener != nil { + l = c.testListener + } else { + var err error + lc := net.ListenConfig{} + l, err = lc.Listen(context.Background(), "tcp", c.BindAddr) + if err != nil { + return err + } } - if !c.NoTLS { - c.tlsConfig = new(tls.Config) - c.tlsConfig.NextProtos = []string{"http/1.1", "h2"} - c.tlsConfig.ServerName = c.ServerName - c.tlsConfig.RootCAs = c.CertPool - c.tlsConfig.InsecureSkipVerify = c.InsecureSkipVerify + if len(c.ServerName) == 0 { + c.ServerName = strings.SplitN(c.DstAddr, ":", 2)[0] } - if len(c.Auth) > 0 { - c.auth = md5.Sum([]byte(c.Auth)) + var rootCAs *x509.CertPool + if len(c.CA) != 0 { + rootCAs = x509.NewCertPool() + certPEMBlock, err := ioutil.ReadFile(c.CA) + if err != nil { + return fmt.Errorf("cannot read ca file: %w", err) + } + if ok := rootCAs.AppendCertsFromPEM(certPEMBlock); !ok { + return errEmptyCAFile + } } - if c.Mux > 0 { - c.muxPool = newMuxPool(c.dialServerConn, c.Mux) + dialer := &net.Dialer{ + Timeout: time.Second * 5, + Control: GetControlFunc(&TcpConfig{AndroidVPN: c.AndroidVPNMode, EnableTFO: c.TFO}), } - for { - localConn, err := c.Listener.Accept() + var chb []byte + if len(c.CertHash) != 0 { + b, err := hex.DecodeString(c.CertHash) if err != nil { - return fmt.Errorf("l.Accept(): %w", err) + return fmt.Errorf("invalid cert hash: %w", err) } - reduceTCPLoopbackSocketBuf(localConn) + chb = b + } - go func() { - defer localConn.Close() - - var serverConn net.Conn - if c.Mux > 0 { - stream, _, err := c.muxPool.GetStream() - if err != nil { - log.Printf("ERROR: muxPool.GetStream: %v", err) - return + tlsConfig := &tls.Config{ + NextProtos: []string{"h2", "http/1.1"}, + ServerName: c.ServerName, + RootCAs: rootCAs, + InsecureSkipVerify: c.InsecureSkipVerify, + VerifyConnection: func(state tls.ConnectionState) error { + if len(chb) != 0 { + cert := state.PeerCertificates[0] + h := sha256.Sum256(cert.RawTBSCertificate) + if bytes.Equal(h[:len(chb)], chb) { + return nil } - serverConn = stream - } else { - conn, err := c.dialServerConn() - if err != nil { - log.Printf("ERROR: dialServerConn: %v", err) - return - } - serverConn = conn + return fmt.Errorf("cert hash mismatch, recieved cert hash is [%s]", hex.EncodeToString(h[:])) } - defer serverConn.Close() - if err := ctunnel.OpenTunnel(localConn, serverConn, c.Timeout); err != nil { - log.Printf("ERROR: ActiveAndServe: openTunnel: %v", err) + if state.Version != tls.VersionTLS13 { + return fmt.Errorf("unsafe tls version %d", state.Version) } - }() + return nil + }, } -} -func (c *Client) dialServerConn() (net.Conn, error) { - serverConn, err := c.dialer.Dial("tcp", c.ServerAddr) - if err != nil { - return nil, err + var transport Transport + if c.Websocket { + transport = NewWebsocketTransport(c.DstAddr, c.ServerName, c.WebsocketPath, tlsConfig, dialer) + } else { + transport = NewRawConnTransport(c.DstAddr, dialer) + transport = NewTLSTransport(transport, tlsConfig) } - if !c.NoTLS { - serverTLSConn := tls.Client(serverConn, c.tlsConfig) - if err := tls13HandshakeWithTimeout(serverTLSConn, time.Second*5); err != nil { - serverTLSConn.Close() - return nil, err - } - serverConn = serverTLSConn + if len(c.Auth) > 0 { + transport = NewAuthTransport(transport, c.Auth) } - // write auth - if len(c.Auth) > 0 { - if _, err := serverConn.Write(c.auth[:]); err != nil { - serverConn.Close() - return nil, fmt.Errorf("failed to write auth: %w", err) + transport = NewMuxTransport(transport, c.Mux) + + for { + clientConn, err := l.Accept() + if err != nil { + return err } - } - // write mode - mode := modePlain - if c.Mux > 0 { - mode = modeMux - } - if _, err := serverConn.Write([]byte{mode}); err != nil { - serverConn.Close() - return nil, fmt.Errorf("failed to write mode: %w", err) + go func() { + defer clientConn.Close() + + ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) + defer cancel() + serverConn, err := transport.Dial(ctx) + if err != nil { + errLogger.Printf("failed to dial server connection: %v", err) + return + } + defer serverConn.Close() + + err = ctunnel.OpenTunnel(clientConn, serverConn, c.IdleTimeout) + if err != nil { + logConnErr(clientConn, fmt.Errorf("tunnel closed: %w", err)) + } + }() } - return serverConn, nil } diff --git a/core/core_test.go b/core/core_test.go index c654067..405d1a6 100644 --- a/core/core_test.go +++ b/core/core_test.go @@ -19,21 +19,23 @@ package core import ( "bytes" + "crypto/sha256" "crypto/tls" - "crypto/x509" + "encoding/hex" + "fmt" "io" - "log" "math/rand" "net" + "sync" "testing" "time" ) func Test_main(t *testing.T) { dataSize := 512 * 1024 + b := make([]byte, dataSize) + rand.Read(b) randData := func() []byte { - b := make([]byte, dataSize) - rand.Read(b) return b } @@ -68,26 +70,35 @@ func Test_main(t *testing.T) { }() // test1 - test := func(t *testing.T, mux int, noTLS bool) { - // start server - _, keyPEM, certPEM, err := GenerateCertificate("example.com") - cert, err := tls.X509KeyPair(certPEM, keyPEM) + test := func(t *testing.T, mux int, ws bool, wsPath string, auth string) { + serverListener, err := net.Listen("tcp", "127.0.0.1:0") if err != nil { t.Fatal(err) } + defer serverListener.Close() - serverListener, err := net.Listen("tcp", "127.0.0.1:0") + _, x509cert, keyPEM, certPEM, err := GenerateCertificate("") + if err != nil { + t.Fatal(err) + } + h := sha256.Sum256(x509cert.RawTBSCertificate) + certHash := hex.EncodeToString(h[:]) + + cert, err := tls.X509KeyPair(certPEM, keyPEM) if err != nil { t.Fatal(err) } - defer serverListener.Close() server := Server{ - Listener: serverListener, - Dst: echoListener.Addr().String(), - Certificates: []tls.Certificate{cert}, - Timeout: timeout, + DstAddr: echoListener.Addr().String(), + Websocket: ws, + WebsocketPath: wsPath, + Auth: auth, + IdleTimeout: timeout, + testListener: serverListener, + testCert: &cert, } + go server.ActiveAndServe() // start client @@ -97,62 +108,61 @@ func Test_main(t *testing.T) { } defer clientListener.Close() - caPool := x509.NewCertPool() - ok := caPool.AppendCertsFromPEM(certPEM) - if !ok { - t.Fatal("appendCertsFromPEM failed") - } - client := Client{ - Listener: clientListener, - ServerAddr: serverListener.Addr().String(), - ServerName: "example.com", - CertPool: caPool, - InsecureSkipVerify: false, + DstAddr: serverListener.Addr().String(), + Websocket: ws, + WebsocketPath: wsPath, Mux: mux, - Timeout: timeout, - AndroidVPNMode: false, - TFO: false, + Auth: auth, + CertHash: certHash, + InsecureSkipVerify: true, + IdleTimeout: timeout, + testListener: clientListener, } - go client.ActiveAndServe() - log.Printf("echo: %v, server: %v client: %v", echoListener.Addr(), serverListener.Addr(), clientListener.Addr()) + go client.ActiveAndServe() + wg := new(sync.WaitGroup) for i := 0; i < 10; i++ { - conn, err := net.Dial("tcp", clientListener.Addr().String()) - if err != nil { - t.Fatal(err) - } - data := randData() - buf := make([]byte, dataSize) - _, err = conn.Write(data) - if err != nil { - t.Fatal(err) - } + wg.Add(1) + go func() { + defer wg.Done() + conn, err := net.Dial("tcp", clientListener.Addr().String()) + if err != nil { + t.Error(err) + return + } + data := randData() + buf := make([]byte, dataSize) + _, err = conn.Write(data) + if err != nil { + t.Error(err) + return + } - _, err = io.ReadFull(conn, buf) - if err != nil { - t.Fatal(err) - } - if bytes.Equal(data, buf) == false { - t.Fatal("corrupted data") - } + _, err = io.ReadFull(conn, buf) + if err != nil { + t.Error(err) + return + } + if bytes.Equal(data, buf) == false { + t.Error("corrupted data") + return + } + }() } + wg.Wait() } - tests := []struct { - name string - mux int - noTLS bool - }{ - {"plain", 0, false}, - {"mux", 5, false}, - {"no tls", 5, true}, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - test(t, tt.mux, tt.noTLS) - }) + for _, mux := range [...]int{0, 5} { + for _, ws := range [...]bool{false, true} { + for _, wsPath := range [...]string{"", "/123456"} { + for _, auth := range [...]string{"", "123456"} { + t.Run(fmt.Sprintf("mux_%v_ws_%v_wsPath_%v_auth_%v", mux, ws, wsPath, auth), func(t *testing.T) { + test(t, mux, ws, wsPath, auth) + }) + } + } + } } } diff --git a/core/mode.go b/core/logger.go similarity index 68% rename from core/mode.go rename to core/logger.go index 85b562d..855a826 100644 --- a/core/mode.go +++ b/core/logger.go @@ -17,7 +17,19 @@ package core -const ( - modePlain byte = iota - modeMux +import ( + "log" + "net" + "net/http" + "os" ) + +var errLogger = log.New(os.Stderr, "err", log.LstdFlags) + +func logConnErr(conn net.Conn, err error) { + errLogger.Printf("connection %s <-> %s: %v", conn.RemoteAddr(), conn.LocalAddr(), err) +} + +func logRequestErr(r *http.Request, err error) { + errLogger.Printf("request from %s %s: %v", r.RemoteAddr, r.RequestURI, err) +} diff --git a/core/server.go b/core/server.go index b9a1a1f..6d25dd0 100644 --- a/core/server.go +++ b/core/server.go @@ -18,146 +18,100 @@ package core import ( - "bytes" - "crypto/md5" + "context" "crypto/tls" + "errors" "fmt" - "github.com/IrineSistiana/ctunnel" - "github.com/xtaci/smux" - "io" "log" "net" "time" ) type Server struct { - Listener net.Listener - Dst string - NoTLS bool - Auth string + BindAddr string + DstAddr string - Certificates []tls.Certificate - Timeout time.Duration + Websocket bool + WebsocketPath string - auth [16]byte - tlsConfig *tls.Config + Cert, Key, ServerName string + Auth string + TFO bool + IdleTimeout time.Duration + NoTLS bool + + testListener net.Listener + testCert *tls.Certificate } +var errMissingCertOrKey = errors.New("one of cert or key argument is missing") + func (s *Server) ActiveAndServe() error { - if !s.NoTLS { - s.tlsConfig = new(tls.Config) - s.tlsConfig.NextProtos = []string{"h2"} - s.tlsConfig.Certificates = s.Certificates + var l net.Listener + if s.testListener != nil { + l = s.testListener + } else { + var err error + lc := net.ListenConfig{Control: GetControlFunc(&TcpConfig{EnableTFO: s.TFO})} + l, err = lc.Listen(context.Background(), "tcp", s.BindAddr) + if err != nil { + return err + } } + var transportHandler TransportHandler + transportHandler = NewBaseTransportHandler(s.DstAddr, s.IdleTimeout) + transportHandler = NewMuxTransportHandler(transportHandler) if len(s.Auth) > 0 { - s.auth = md5.Sum([]byte(s.Auth)) + transportHandler = NewAuthTransportHandler(transportHandler, s.Auth) } - for { - clientConn, err := s.Listener.Accept() - if err != nil { - return fmt.Errorf("l.Accept(): %w", err) - } - - go func() { - defer clientConn.Close() - - if !s.NoTLS { - clientTLSConn := tls.Server(clientConn, s.tlsConfig) - // handshake - if err := tls13HandshakeWithTimeout(clientTLSConn, time.Second*5); err != nil { - log.Printf("ERROR: %s, tls13HandshakeWithTimeout: %v", clientConn.RemoteAddr(), err) - return - } - clientConn = clientTLSConn - } - - // check auth - if len(s.Auth) > 0 { - auth := make([]byte, 16) - if _, err := io.ReadFull(clientConn, auth); err != nil { - log.Printf("ERROR: %s, read client auth header: %v", clientConn.RemoteAddr(), err) - return + if !s.NoTLS { + var certificate tls.Certificate + if s.testCert != nil { + certificate = *s.testCert + } else { + switch { + case len(s.Cert) == 0 && len(s.Key) == 0: // no cert and key + dnsName, _, keyPEM, certPEM, err := GenerateCertificate(s.ServerName) + if err != nil { + return fmt.Errorf("failed to generate temp cert: %w", err) } - if !bytes.Equal(s.auth[:], auth) { - log.Printf("ERROR: %s, auth failed", clientConn.RemoteAddr()) - discardRead(clientConn, time.Second*15) - return + log.Printf("warnning: you are using a tmp certificate with dns name: %s", dnsName) + cer, err := tls.X509KeyPair(certPEM, keyPEM) + if err != nil { + return fmt.Errorf("cannot load x509 key pair from memory: %w", err) } - } - - // mode - header := make([]byte, 1) - if _, err := io.ReadFull(clientConn, header); err != nil { - log.Printf("ERROR: %s, read client mode header: %v", clientConn.RemoteAddr(), err) - return - } - switch header[0] { - case modePlain: - if err := s.handleClientConn(clientConn); err != nil { - log.Printf("ERROR: %s, handleClientConn: %v", clientConn.RemoteAddr(), err) - return - } - case modeMux: - err := s.handleClientMux(clientConn) + certificate = cer + case len(s.Cert) != 0 && len(s.Key) != 0: // has a cert and a key + cer, err := tls.LoadX509KeyPair(s.Cert, s.Key) //load cert if err != nil { - log.Printf("ERROR: %s, handleClientMux: %v", clientConn.RemoteAddr(), err) - return + return fmt.Errorf("cannot load x509 key pair from disk: %w", err) } + certificate = cer default: - log.Printf("ERROR: %s, invalid header %d", clientConn.RemoteAddr(), header[0]) - return + return errMissingCertOrKey } - }() - } -} - -func discardRead(c net.Conn, t time.Duration) { - c.SetDeadline(time.Now().Add(t)) - buf := make([]byte, 512) - for { - _, err := c.Read(buf) - if err != nil { - return } - } -} -func (s *Server) handleClientConn(cc net.Conn) (err error) { - dstConn, err := net.Dial("tcp", s.Dst) - if err != nil { - return fmt.Errorf("net.Dial: %v", err) - } - reduceTCPLoopbackSocketBuf(dstConn) - defer dstConn.Close() - - if err := ctunnel.OpenTunnel(dstConn, cc, s.Timeout); err != nil { - return fmt.Errorf("openTunnel: %v", err) - } - return nil -} + tlsConfig := &tls.Config{ + NextProtos: []string{"h2", "http/1.1"}, + Certificates: []tls.Certificate{certificate}, + VerifyConnection: func(state tls.ConnectionState) error { + if state.Version != tls.VersionTLS13 { + return fmt.Errorf("unsafe tls version %d", state.Version) + } + return nil + }, + } -func (s *Server) handleClientMux(cc net.Conn) (err error) { - sess, err := smux.Server(cc, muxConfig) - if err != nil { - return err + l = tls.NewListener(l, tlsConfig) } - defer sess.Close() - for { - stream, err := sess.AcceptStream() - if err != nil { - return nil // suppress smux err - } - go func() { - defer stream.Close() - if err := s.handleClientConn(stream); err != nil { - log.Printf("ERROR: handleClientMux: %s, handleClientConn: %v", stream.RemoteAddr(), err) - return - } - }() + if s.Websocket { + return ListenWebsocket(l, s.WebsocketPath, transportHandler) } + return ListenRawConn(l, transportHandler) } diff --git a/core/sip003.go b/core/sip003.go index 769db25..7e30123 100644 --- a/core/sip003.go +++ b/core/sip003.go @@ -24,7 +24,7 @@ import ( "strings" ) -var ErrBrokenSIP003Args = errors.New("invalid SIP003 args") +var errBrokenSIP003Args = errors.New("invalid SIP003 args") //SIP003Args contains sip003 args type SIP003Args struct { @@ -55,7 +55,7 @@ func GetSIP003Args() (*SIP003Args, error) { if srhOk || srpOk || slhOk || slpOk || spoOk { // has at least one arg if !(srhOk && srpOk && slhOk && slpOk) { // but not has all 4 args - return nil, ErrBrokenSIP003Args + return nil, errBrokenSIP003Args } } else { return nil, nil // can't find any sip003 arg diff --git a/core/smux.go b/core/smux.go index 436a5df..1544b99 100644 --- a/core/smux.go +++ b/core/smux.go @@ -18,13 +18,20 @@ package core import ( + "context" "fmt" "github.com/xtaci/smux" + "io" "log" "net" "sync" ) +const ( + modePlain byte = iota + modeMux +) + var muxConfig = &smux.Config{ Version: 1, KeepAliveDisabled: true, @@ -33,8 +40,8 @@ var muxConfig = &smux.Config{ MaxStreamBuffer: 32 * 1024, } -type muxPool struct { - dialFunc func() (c net.Conn, err error) +type MuxTransport struct { + nextTransport Transport maxConcurrent int sm sync.Mutex @@ -45,11 +52,8 @@ type muxPool struct { dialWaiting int } -func newMuxPool(dialFunc func() (c net.Conn, err error), maxConcurrent int) *muxPool { - if maxConcurrent < 1 { - panic(fmt.Sprintf("invalid maxConcurrent: %d", maxConcurrent)) - } - return &muxPool{dialFunc: dialFunc, maxConcurrent: maxConcurrent, sess: map[*smux.Session]struct{}{}} +func NewMuxTransport(subTransport Transport, maxConcurrent int) *MuxTransport { + return &MuxTransport{nextTransport: subTransport, maxConcurrent: maxConcurrent, sess: map[*smux.Session]struct{}{}} } type dialCall struct { @@ -58,21 +62,32 @@ type dialCall struct { err error } -func (m *muxPool) GetStream() (stream *smux.Stream, sess *smux.Session, err error) { - if stream, sess, ok := m.tryGetStream(); ok { - return stream, sess, nil +func (m *MuxTransport) Dial(ctx context.Context) (net.Conn, error) { + if m.maxConcurrent <= 1 { + conn, err := m.nextTransport.Dial(ctx) + if err != nil { + return nil, err + } + if _, err := conn.Write([]byte{modePlain}); err != nil { + conn.Close() + return nil, fmt.Errorf("failed to write mux header: %w", err) + } + } + + if stream := m.tryGetStream(); stream != nil { + return stream, nil } - return m.tryGetStreamFlash() + return m.tryGetStreamFlash(ctx) } -func (m *muxPool) MarkDead(sess *smux.Session) { +func (m *MuxTransport) MarkDead(sess *smux.Session) { m.sm.Lock() defer m.sm.Unlock() delete(m.sess, sess) sess.Close() } -func (m *muxPool) tryGetStream() (stream *smux.Stream, sess *smux.Session, ok bool) { +func (m *MuxTransport) tryGetStream() (stream *smux.Stream) { m.sm.Lock() defer m.sm.Unlock() for sess := range m.sess { @@ -84,18 +99,18 @@ func (m *muxPool) tryGetStream() (stream *smux.Stream, sess *smux.Session, ok bo delete(m.sess, sess) continue } - return s, sess, true + return s } } - return nil, nil, false + return nil } -func (m *muxPool) tryGetStreamFlash() (stream *smux.Stream, sess *smux.Session, err error) { +func (m *MuxTransport) tryGetStreamFlash(ctx context.Context) (*smux.Stream, error) { var call *dialCall m.dm.Lock() if m.dialing == nil || (m.dialing != nil && m.dialWaiting >= m.maxConcurrent) { m.dialWaiting = 0 - m.dialing = m.dialSessLocked() // needs a new dial + m.dialing = m.dialSessLocked(ctx) // needs a new dial } else { m.dialWaiting++ } @@ -103,27 +118,33 @@ func (m *muxPool) tryGetStreamFlash() (stream *smux.Stream, sess *smux.Session, defer m.dm.Unlock() <-call.done - sess = call.s - err = call.err + sess := call.s + err := call.err if err != nil { - return nil, nil, err + return nil, err } - stream, err = sess.OpenStream() - return stream, sess, err + return sess.OpenStream() } -func (m *muxPool) dialSessLocked() (call *dialCall) { +func (m *MuxTransport) dialSessLocked(ctx context.Context) (call *dialCall) { call = &dialCall{ done: make(chan struct{}), } go func() { - c, err := m.dialFunc() + c, err := m.nextTransport.Dial(ctx) if err != nil { call.err = err close(call.done) return } + if _, err := c.Write([]byte{modeMux}); err != nil { + c.Close() + call.err = fmt.Errorf("failed to write mux header: %w", err) + close(call.done) + return + } + sess, err := smux.Client(c, muxConfig) call.s = sess call.err = err @@ -139,3 +160,45 @@ func (m *muxPool) dialSessLocked() (call *dialCall) { }() return call } + +type MuxTransportHandler struct { + nextHandler TransportHandler +} + +func NewMuxTransportHandler(nextHandler TransportHandler) *MuxTransportHandler { + return &MuxTransportHandler{nextHandler: nextHandler} +} + +func (h *MuxTransportHandler) Handle(conn net.Conn) error { + header := make([]byte, 1) + if _, err := io.ReadFull(conn, header); err != nil { + return fmt.Errorf("failed to read mux header: %w", err) + } + + switch header[0] { + case modePlain: + return h.nextHandler.Handle(conn) + case modeMux: + sess, err := smux.Server(conn, muxConfig) + if err != nil { + return err + } + defer sess.Close() + + for { + stream, err := sess.AcceptStream() + if err != nil { + return nil // suppress smux err + } + go func() { + defer stream.Close() + if err := h.nextHandler.Handle(stream); err != nil { + logConnErr(stream, err) + return + } + }() + } + default: + return fmt.Errorf("invalid mux header %d", header[0]) + } +} diff --git a/core/transport.go b/core/transport.go new file mode 100644 index 0000000..8fe4321 --- /dev/null +++ b/core/transport.go @@ -0,0 +1,111 @@ +// Copyright (C) 2020-2021, IrineSistiana +// +// This file is part of simple-tls. +// +// simple-tls is free software: you can redistribute it and/or modify +// it under the terms of the GNU General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// simple-tls is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU General Public License for more details. +// +// You should have received a copy of the GNU General Public License +// along with this program. If not, see . + +package core + +import ( + "context" + "crypto/tls" + "fmt" + "github.com/IrineSistiana/ctunnel" + "net" + "time" +) + +type Transport interface { + Dial(ctx context.Context) (net.Conn, error) +} + +type TransportHandler interface { + Handle(conn net.Conn) error +} + +type RawConnTransport struct { + addr string + dialer *net.Dialer +} + +func (t *RawConnTransport) Dial(ctx context.Context) (net.Conn, error) { + return t.dialer.DialContext(ctx, "tcp", t.addr) +} + +func NewRawConnTransport(addr string, dialer *net.Dialer) *RawConnTransport { + return &RawConnTransport{addr: addr, dialer: dialer} +} + +type TLSTransport struct { + nextTransport Transport + conf *tls.Config +} + +func (t *TLSTransport) Dial(ctx context.Context) (net.Conn, error) { + conn, err := t.nextTransport.Dial(ctx) + if err != nil { + return nil, err + } + + tlsConn := tls.Client(conn, t.conf) + if err := tlsConn.HandshakeContext(ctx); err != nil { + tlsConn.Close() + return nil, err + } + return tlsConn, nil +} + +func NewTLSTransport(nextTransport Transport, conf *tls.Config) *TLSTransport { + return &TLSTransport{nextTransport: nextTransport, conf: conf} +} + +type BaseTransportHandler struct { + dst string + idleTimeout time.Duration +} + +func (h *BaseTransportHandler) Handle(conn net.Conn) error { + dstConn, err := net.Dial("tcp", h.dst) + if err != nil { + return fmt.Errorf("cannot connect to the dst: %w", err) + } + reduceTCPLoopbackSocketBuf(dstConn) + defer dstConn.Close() + + if err := ctunnel.OpenTunnel(dstConn, conn, h.idleTimeout); err != nil { + return fmt.Errorf("tunnel closed: %w", err) + } + return nil +} + +func NewBaseTransportHandler(dst string, idleTimeout time.Duration) *BaseTransportHandler { + return &BaseTransportHandler{dst: dst, idleTimeout: idleTimeout} +} + +func ListenRawConn(l net.Listener, nextHandler TransportHandler) error { + for { + conn, err := l.Accept() + if err != nil { + return err + } + + go func() { + defer conn.Close() + err := nextHandler.Handle(conn) + if err != nil { + logConnErr(conn, err) + } + }() + } +} diff --git a/core/utils.go b/core/utils.go index cd47183..50cd371 100644 --- a/core/utils.go +++ b/core/utils.go @@ -21,29 +21,17 @@ import ( "crypto/ecdsa" "crypto/elliptic" "crypto/rand" - "crypto/tls" "crypto/x509" "crypto/x509/pkix" "encoding/pem" "fmt" "math/big" mathRand "math/rand" + "net" "time" ) -func tls13HandshakeWithTimeout(c *tls.Conn, timeout time.Duration) error { - c.SetDeadline(time.Now().Add(timeout)) - if err := c.Handshake(); err != nil { - return err - } - c.SetDeadline(time.Time{}) - if cVar := c.ConnectionState().Version; cVar != tls.VersionTLS13 { - return fmt.Errorf("unexpected tls version: %x", cVar) - } - return nil -} - -func GenerateCertificate(serverName string) (dnsName string, keyPEM, certPEM []byte, err error) { +func GenerateCertificate(serverName string) (dnsName string, cert *x509.Certificate, keyPEM, certPEM []byte, err error) { //priv key key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) if err != nil { @@ -89,7 +77,12 @@ func GenerateCertificate(serverName string) (dnsName string, keyPEM, certPEM []b keyPEM = pem.EncodeToMemory(&pem.Block{Type: "EC PRIVATE KEY", Bytes: b}) certPEM = pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: certDER}) - return dnsName, keyPEM, certPEM, nil + cert, err = x509.ParseCertificate(certDER) + if err != nil { + return + } + + return } func randServerName() string { @@ -106,3 +99,14 @@ func randStr(length int) string { } return string(b) } + +func discardRead(c net.Conn, t time.Duration) { + c.SetDeadline(time.Now().Add(t)) + buf := make([]byte, 512) + for { + _, err := c.Read(buf) + if err != nil { + return + } + } +} diff --git a/core/websocket.go b/core/websocket.go new file mode 100644 index 0000000..1bafa5e --- /dev/null +++ b/core/websocket.go @@ -0,0 +1,130 @@ +// Copyright (C) 2020-2021, IrineSistiana +// +// This file is part of simple-tls. +// +// simple-tls is free software: you can redistribute it and/or modify +// it under the terms of the GNU General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// simple-tls is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU General Public License for more details. +// +// You should have received a copy of the GNU General Public License +// along with this program. If not, see . + +package core + +import ( + "context" + "crypto/tls" + "errors" + "fmt" + "golang.org/x/net/http2" + "net" + "net/http" + "net/url" + "nhooyr.io/websocket" + "time" +) + +type WebsocketTransport struct { + u string + op *websocket.DialOptions +} + +func NewWebsocketTransport(serverAddr, serverName, urlPath string, tlsConfig *tls.Config, dialer *net.Dialer) *WebsocketTransport { + u := url.URL{ + Scheme: "https", + Host: serverName, + Path: urlPath, + } + t := &http.Transport{ + DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { + return dialer.DialContext(ctx, network, serverAddr) + }, + TLSClientConfig: tlsConfig, + TLSHandshakeTimeout: time.Second * 5, + DisableCompression: true, + ResponseHeaderTimeout: time.Second * 5, + ExpectContinueTimeout: 0, + WriteBufferSize: 32 * 1024, + ReadBufferSize: 32 * 1024, + ForceAttemptHTTP2: true, + } + + t2, _ := http2.ConfigureTransports(t) + t2.ReadIdleTimeout = time.Second * 30 + t2.PingTimeout = time.Second * 10 + + return &WebsocketTransport{ + u: u.String(), + op: &websocket.DialOptions{ + HTTPClient: &http.Client{ + Transport: t, + Timeout: time.Second * 10, + }, + CompressionMode: websocket.CompressionDisabled, + }, + } +} + +func (p *WebsocketTransport) Dial(ctx context.Context) (net.Conn, error) { + wsConn, _, err := websocket.Dial(ctx, p.u, p.op) + if err != nil { + return nil, err + } + return websocket.NetConn(context.Background(), wsConn, websocket.MessageBinary), nil +} + +type wsHttpHandler struct { + nextHandler TransportHandler + path string +} + +var errInvalidPath = errors.New("invalid request path") + +func (h *wsHttpHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + if len(h.path) != 0 && r.URL.Path != h.path { + w.WriteHeader(http.StatusNotFound) + logRequestErr(r, errInvalidPath) + return + } + + wsConn, err := websocket.Accept(w, r, &websocket.AcceptOptions{ + CompressionMode: websocket.CompressionDisabled, + }) + if err != nil { + logRequestErr(r, fmt.Errorf("cannot accept websocket connection: %w", err)) + return + } + + clientConn := websocket.NetConn(context.Background(), wsConn, websocket.MessageBinary) + defer clientConn.Close() + + if err := h.nextHandler.Handle(clientConn); err != nil { + logRequestErr(r, err) + return + } + return +} + +func ListenWebsocket(l net.Listener, path string, nextHandler TransportHandler) error { + httpServer := &http.Server{ + Handler: &wsHttpHandler{ + nextHandler: nextHandler, + path: path, + }, + ReadTimeout: time.Second * 10, + ReadHeaderTimeout: time.Second * 10, + WriteTimeout: time.Second * 10, + } + + http2.ConfigureServer(httpServer, &http2.Server{ + IdleTimeout: time.Second * 45, + }) + + return httpServer.Serve(l) +} diff --git a/go.mod b/go.mod index 3031f22..7a972a0 100644 --- a/go.mod +++ b/go.mod @@ -5,5 +5,10 @@ go 1.14 require ( github.com/IrineSistiana/ctunnel v0.0.0-20210409113947-9756ebc29fdb github.com/xtaci/smux v1.5.16 + go.uber.org/atomic v1.9.0 // indirect + go.uber.org/multierr v1.7.0 // indirect + go.uber.org/zap v1.19.1 + golang.org/x/net v0.0.0-20211112202133-69e39bad7dc2 golang.org/x/sys v0.0.0-20211003122950-b1ebd4e1001c + nhooyr.io/websocket v1.8.7 ) diff --git a/go.sum b/go.sum index f10d9a1..27a731a 100644 --- a/go.sum +++ b/go.sum @@ -1,12 +1,99 @@ github.com/IrineSistiana/ctunnel v0.0.0-20210409113947-9756ebc29fdb h1:b7t7X5hUjO3ZTDfe2TZOKHGe7Aq8yfNzwwcS4Qg7QTA= github.com/IrineSistiana/ctunnel v0.0.0-20210409113947-9756ebc29fdb/go.mod h1:xcxj4YojLW9ri1tSYmHFtCSaphfxXEGgp+9BnbVpDYk= +github.com/benbjohnson/clock v1.1.0/go.mod h1:J11/hYXuz8f4ySSvYwY0FKfm+ezbsZBKZxNJlLklBHA= +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/gin-contrib/sse v0.1.0/go.mod h1:RHrZQHXnP2xjPF+u1gW/2HnVO7nvIa9PG3Gm+fLHvGI= +github.com/gin-gonic/gin v1.6.3/go.mod h1:75u5sXoLsGZoRN5Sgbi1eraJ4GU3++wFwWzhwvtwp4M= +github.com/go-playground/assert/v2 v2.0.1/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4= +github.com/go-playground/locales v0.13.0/go.mod h1:taPMhCMXrRLJO55olJkUXHZBHCxTMfnGwq/HNwmWNS8= +github.com/go-playground/universal-translator v0.17.0/go.mod h1:UkSxE5sNxxRwHyU+Scu5vgOQjsIJAF8j9muTVoKLVtA= +github.com/go-playground/validator/v10 v10.2.0/go.mod h1:uOYAAleCW8F/7oMFd6aG0GOhaH6EGOAJShg8Id5JGkI= +github.com/gobwas/httphead v0.0.0-20180130184737-2c6c146eadee/go.mod h1:L0fX3K22YWvt/FAX9NnzrNzcI4wNYi9Yku4O0LKYflo= +github.com/gobwas/pool v0.2.0/go.mod h1:q8bcK0KcYlCgd9e7WYLm9LpyS+YeLd8JVDW6WezmKEw= +github.com/gobwas/ws v1.0.2/go.mod h1:szmBTxLgaFppYjEmNtny/v3w89xOydFnnZMcgRRu/EM= +github.com/golang/protobuf v1.3.3/go.mod h1:vzj43D7+SQXF/4pzW/hwtAqwc6iTitCiVSaWz5lYuqw= +github.com/golang/protobuf v1.3.5/go.mod h1:6O5/vntMXwX2lRkT1hjjk0nAC1IDOTvTlVgjlRvqsdk= +github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= +github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= +github.com/gorilla/websocket v1.4.1/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= +github.com/json-iterator/go v1.1.9/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4= +github.com/klauspost/compress v1.10.3 h1:OP96hzwJVBIHYU52pVTI6CczrxPvrGfgqF9N5eTO0Q8= +github.com/klauspost/compress v1.10.3/go.mod h1:aoV0uJVorq1K+umq18yTdKaF57EivdYsUV+/s2qKfXs= +github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= +github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= +github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= +github.com/leodido/go-urn v1.2.0/go.mod h1:+8+nEpDfqqsY+g338gtMEUOtuK+4dEMhiQEgxpxOKII= +github.com/mattn/go-isatty v0.0.12/go.mod h1:cbi8OIDigv2wuxKPP5vlRcQ1OAZbq2CE4Kysco4FUpU= +github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= +github.com/modern-go/reflect2 v0.0.0-20180701023420-4b7aa43c6742/go.mod h1:bx2lNnkwVCuqBIxFjflWJWanXIb3RllmbCylyMrvgv0= +github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= +github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= +github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/ugorji/go v1.1.7/go.mod h1:kZn38zHttfInRq0xu/PH0az30d+z6vm202qpg1oXVMw= +github.com/ugorji/go/codec v1.1.7/go.mod h1:Ax+UKWsSmolVDwsd+7N3ZtXu+yMGCf907BLYF3GoBXY= github.com/xtaci/smux v1.5.15 h1:6hMiXswcleXj5oNfcJc+DXS8Vj36XX2LaX98udog6Kc= github.com/xtaci/smux v1.5.15/go.mod h1:OMlQbT5vcgl2gb49mFkYo6SMf+zP3rcjcwQz7ZU7IGY= github.com/xtaci/smux v1.5.16 h1:FBPYOkW8ZTjLKUM4LI4xnnuuDC8CQ/dB04HD519WoEk= github.com/xtaci/smux v1.5.16/go.mod h1:OMlQbT5vcgl2gb49mFkYo6SMf+zP3rcjcwQz7ZU7IGY= +github.com/yuin/goldmark v1.3.5/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k= +go.uber.org/atomic v1.7.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc= +go.uber.org/atomic v1.9.0 h1:ECmE8Bn/WFTYwEW/bpKD3M8VtR/zQVbavAoalC1PYyE= +go.uber.org/atomic v1.9.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc= +go.uber.org/goleak v1.1.11-0.20210813005559-691160354723/go.mod h1:cwTWslyiVhfpKIDGSZEM2HlOvcqm+tG4zioyIeLoqMQ= +go.uber.org/multierr v1.6.0/go.mod h1:cdWPpRnG4AhwMwsgIHip0KRBQjJy5kYEpYjJxpXp9iU= +go.uber.org/multierr v1.7.0 h1:zaiO/rmgFjbmCXdSYJWQcdvOCsthmdaHfr3Gm2Kx4Ec= +go.uber.org/multierr v1.7.0/go.mod h1:7EAYxJLBy9rStEaz58O2t4Uvip6FSURkq8/ppBp95ak= +go.uber.org/zap v1.19.1 h1:ue41HOKd1vGURxrmeKIgELGb3jPW9DMUDGtsinblHwI= +go.uber.org/zap v1.19.1/go.mod h1:j3DNczoxDZroyBnOT1L/Q79cfUMGZxlv/9dzN7SM1rI= +golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= +golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= +golang.org/x/lint v0.0.0-20190930215403-16217165b5de/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc= +golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= +golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= +golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= +golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4/go.mod h1:p54w0d4576C0XHj96bSt6lcn1PtDYWL6XObtHCRCNQM= +golang.org/x/net v0.0.0-20211112202133-69e39bad7dc2 h1:CIJ76btIcR3eFI5EgSo6k1qKw9KJexJuRLI9G7Hp5wE= +golang.org/x/net v0.0.0-20211112202133-69e39bad7dc2/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= +golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200116001909-b77594299b42/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200523222454-059865788121 h1:rITEj+UZHYC927n8GT97eC3zrpzXdb/voyeOuVKS46o= golang.org/x/sys v0.0.0-20200523222454-059865788121/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210330210617-4fbd30eecc44/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210403161142-5e06dd20ab57 h1:F5Gozwx4I1xtr/sr/8CFbb57iKi3297KFs0QDbGN60A= golang.org/x/sys v0.0.0-20210403161142-5e06dd20ab57/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210510120138-977fb7262007/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20211003122950-b1ebd4e1001c h1:EyJTLQbOxvk8V6oDdD8ILR1BOs3nEJXThD6aqsiPNkM= golang.org/x/sys v0.0.0-20211003122950-b1ebd4e1001c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= +golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= +golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= +golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= +golang.org/x/text v0.3.6 h1:aRYxNxv6iGQlyVaZmk6ZgYEDa+Jg18DxebPSrd6bg1M= +golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= +golang.org/x/time v0.0.0-20191024005414-555d28b269f0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= +golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +golang.org/x/tools v0.0.0-20190311212946-11955173bddd/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= +golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= +golang.org/x/tools v0.1.5/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk= +golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +nhooyr.io/websocket v1.8.7 h1:usjR2uOr/zjjkVMy0lW+PPohFok7PCow5sDjLgX4P4g= +nhooyr.io/websocket v1.8.7/go.mod h1:B70DZP8IakI65RVQ51MsWP/8jndNma26DVA/nFSCgW0= diff --git a/main.go b/main.go index ec82798..64b1f76 100644 --- a/main.go +++ b/main.go @@ -18,20 +18,17 @@ package main import ( - "context" - "crypto/tls" + "crypto/sha256" "crypto/x509" - "encoding/base64" + "encoding/pem" "flag" "fmt" "io/ioutil" "log" - "net" "os" "os/signal" "runtime" "strconv" - "strings" "syscall" "time" @@ -46,12 +43,12 @@ func main() { osSignals := make(chan os.Signal, 1) signal.Notify(osSignals, os.Interrupt, os.Kill, syscall.SIGTERM) s := <-osSignals - log.Printf("main: exiting: signal: %v", s) + log.Printf("exiting: signal: %v", s) os.Exit(0) }() - var bindAddr, dstAddr, auth, serverName, cca, ca, cert, key string - var noTLS, insecureSkipVerify, isServer, tfo, vpn, genCert, showVersion bool + var bindAddr, dstAddr, auth, serverName, wsPath, ca, cert, key, hashCert, certHash string + var ws, insecureSkipVerify, isServer, noTLS, tfo, vpn, genCert, showVersion bool var cpu, mux int var timeout time.Duration var timeoutFlag int @@ -61,20 +58,24 @@ func main() { commandLine.StringVar(&bindAddr, "b", "", "[Host:Port] bind address") commandLine.StringVar(&dstAddr, "d", "", "[Host:Port] destination address") commandLine.StringVar(&auth, "auth", "", "server password") - commandLine.BoolVar(&noTLS, "no-tls", false, "disable TLS (debug only)") + + commandLine.BoolVar(&ws, "ws", false, "websocket mode") + commandLine.StringVar(&wsPath, "ws-path", "", "websocket path") // client only commandLine.IntVar(&mux, "mux", 0, "enable mux") commandLine.StringVar(&serverName, "n", "", "server name") commandLine.StringVar(&ca, "ca", "", "PEM CA file path") - commandLine.StringVar(&cca, "cca", "", "base64 encoded PEM CA") + commandLine.StringVar(&certHash, "cert-hash", "", "server certificate hash") + commandLine.BoolVar(&insecureSkipVerify, "no-verify", false, "client won't verify the server's certificate chain and host name") commandLine.BoolVar(&vpn, "V", false, "DO NOT USE, this is for android vpn mode") // server only - commandLine.BoolVar(&isServer, "s", false, "is server") - commandLine.StringVar(&cert, "cert", "", "[Path] PEM cert file") - commandLine.StringVar(&key, "key", "", "[Path] PEM key file") + commandLine.BoolVar(&isServer, "s", false, "run as a server (without this simple-tls runs as a client)") + commandLine.StringVar(&cert, "cert", "", "PEM cert file") + commandLine.StringVar(&key, "key", "", "PEM key file") + commandLine.BoolVar(&noTLS, "no-tls", false, "disable server tls") // etc commandLine.IntVar(&timeoutFlag, "t", 300, "timeout in sec") @@ -84,73 +85,93 @@ func main() { // helper commands commandLine.BoolVar(&genCert, "gen-cert", false, "[This is a helper function]: generate a certificate with dns name [-n], store it's key to [-key] and cert to [-cert], print cert in base64 format without padding characters") commandLine.BoolVar(&showVersion, "v", false, "output version info and exit") + commandLine.StringVar(&hashCert, "hash-cert", "", "print the hashes for the certificate") err := commandLine.Parse(os.Args[1:]) if err != nil { - log.Fatalf("main: invalid arg: %v", err) + log.Fatalf("invalid arg: %v", err) } // display version if showVersion { println(version) - os.Exit(0) + return } // gen cert if genCert { - log.Print("main: WARNING: generating PEM encoded key and cert") + log.Print("generating PEM encoded key and cert") - dnsName, keyPEM, certPEM, err := core.GenerateCertificate(serverName) + dnsName, certOut, keyPEM, certPEM, err := core.GenerateCertificate(serverName) if err != nil { - log.Fatalf("main: generateCertificate: %v", err) + log.Fatalf("generateCertificate: %v", err) } // key if len(key) == 0 { key = dnsName + ".key" } - log.Printf("main: generating PEM encoded key to %s", key) + log.Printf("generating PEM encoded key to %s", key) keyFile, err := os.Create(key) if err != nil { - log.Fatalf("main: creating key file [%s]: %v", key, err) + log.Fatalf("creating key file [%s]: %v", key, err) } defer keyFile.Close() _, err = keyFile.Write(keyPEM) if err != nil { - log.Fatalf("main: writing key file [%s]: %v", key, err) + log.Fatalf("writing key file [%s]: %v", key, err) } // cert if len(cert) == 0 { cert = dnsName + ".cert" } - log.Printf("main: generating PEM encoded cert to %s", cert) + log.Printf("generating PEM encoded cert to %s", cert) certFile, err := os.Create(cert) if err != nil { - log.Fatalf("main: creating cert file [%s]: %v", cert, err) + log.Fatalf("creating cert file [%s]: %v", cert, err) } defer certFile.Close() _, err = certFile.Write(certPEM) if err != nil { - log.Fatalf("main: writing cert file [%s]: %v", cert, err) + log.Fatalf("writing cert file [%s]: %v", cert, err) } - certBase64 := base64.RawStdEncoding.EncodeToString(certPEM) fmt.Printf("Your new cert dns name is: %s\n", dnsName) - fmt.Print("Your new cert base64 string is:\n") - fmt.Printf("%s\n", certBase64) - fmt.Println("Copy this string and import it to client using -cca option") + fmt.Print("Your new cert hash is:\n") + fmt.Printf("%x\n", sha256.Sum256(certOut.RawTBSCertificate)) + return + } + + if len(hashCert) != 0 { + rawCert, err := ioutil.ReadFile(hashCert) + if err != nil { + log.Fatalf("failed to read cert file: %v", err) + } + b, _ := pem.Decode(rawCert) + if b.Type != "CERTIFICATE" { + log.Fatalf("invaild pem type [%s]", b.Type) + } + + certs, err := x509.ParseCertificates(b.Bytes) + if err != nil { + log.Fatalf("failed to parse cert file: %v", err) + } + for _, cert := range certs { + h := sha256.Sum256(cert.RawTBSCertificate) + fmt.Printf("[%v]: %x\n", cert.Subject, h) + } return } // overwrite args from env sip003Args, err := core.GetSIP003Args() if err != nil { - log.Fatalf("main: sip003 error: %v", err) + log.Fatalf("sip003 error: %v", err) } if sip003Args != nil { - log.Print("main: simple-tls is running as a sip003 plugin") + log.Print("simple-tls is running as a sip003 plugin") var ok bool var s string @@ -168,20 +189,24 @@ func main() { setStrIfNotEmpty(&dstAddr, s) s, _ = sip003Args.SS_PLUGIN_OPTIONS["auth"] setStrIfNotEmpty(&auth, s) - _, ok = sip003Args.SS_PLUGIN_OPTIONS["no-tls"] - noTLS = noTLS || ok + + _, ok = sip003Args.SS_PLUGIN_OPTIONS["ws"] + ws = ws || ok + s, _ = sip003Args.SS_PLUGIN_OPTIONS["ws-path"] + setStrIfNotEmpty(&wsPath, s) // client s, _ = sip003Args.SS_PLUGIN_OPTIONS["n"] setStrIfNotEmpty(&serverName, s) s, _ = sip003Args.SS_PLUGIN_OPTIONS["mux"] if err := setIntIfNotZero(&mux, s); err != nil { - log.Fatalf("main: invalid mux value, %v", err) + log.Fatalf("invalid mux value, %v", err) } s, _ = sip003Args.SS_PLUGIN_OPTIONS["ca"] setStrIfNotEmpty(&ca, s) - s, _ = sip003Args.SS_PLUGIN_OPTIONS["cca"] - setStrIfNotEmpty(&cca, s) + s, _ = sip003Args.SS_PLUGIN_OPTIONS["cert-hash"] + setStrIfNotEmpty(&certHash, s) + _, ok = sip003Args.SS_PLUGIN_OPTIONS["no-verify"] insecureSkipVerify = insecureSkipVerify || ok @@ -192,15 +217,17 @@ func main() { setStrIfNotEmpty(&cert, s) s, _ = sip003Args.SS_PLUGIN_OPTIONS["key"] setStrIfNotEmpty(&key, s) + _, ok = sip003Args.SS_PLUGIN_OPTIONS["no-tls"] + noTLS = noTLS || ok // etc s, _ = sip003Args.SS_PLUGIN_OPTIONS["t"] if err := setIntIfNotZero(&timeoutFlag, s); err != nil { - log.Fatalf("main: invalid timeout value, %v", err) + log.Fatalf("invalid timeout value, %v", err) } s, _ = sip003Args.SS_PLUGIN_OPTIONS["cpu"] if err := setIntIfNotZero(&cpu, s); err != nil { - log.Fatalf("main: invalid cpu number, %v", err) + log.Fatalf("invalid cpu number, %v", err) } _, ok = sip003Args.SS_PLUGIN_OPTIONS["fast-open"] tfo = tfo || ok @@ -218,117 +245,56 @@ func main() { runtime.GOMAXPROCS(cpu) if len(bindAddr) == 0 { - log.Fatal("main: bind addr is required") + log.Fatal("bind addr is required") } if len(dstAddr) == 0 { - log.Fatal("main: destination addr is required") + log.Fatal("destination addr is required") } - log.Printf("main: simple-tls %s (go version: %s, os: %s, arch: %s)", version, runtime.Version(), runtime.GOOS, runtime.GOARCH) + log.Printf("simple-tls %s (go version: %s, os: %s, arch: %s)", version, runtime.Version(), runtime.GOOS, runtime.GOARCH) if isServer { - var certificates []tls.Certificate - if !noTLS { - switch { - case len(cert) == 0 && len(key) == 0: // no cert and key - log.Printf("main: warnning: neither -key nor -cert is specified") - - dnsName, keyPEM, certPEM, err := core.GenerateCertificate(serverName) - if err != nil { - log.Fatalf("main: generateCertificate: %v", err) - } - log.Printf("main: warnning: using tmp certificate %s", dnsName) - cer, err := tls.X509KeyPair(certPEM, keyPEM) - if err != nil { - log.Fatalf("main: X509KeyPair: %v", err) - } - certificates = []tls.Certificate{cer} - case len(cert) != 0 && len(key) != 0: // has cert and key - cer, err := tls.LoadX509KeyPair(cert, key) //load cert - if err != nil { - log.Fatalf("main: LoadX509KeyPair: %v", err) - } - certificates = []tls.Certificate{cer} - default: - log.Fatal("main: server must have a X509 key pair, aka. -cert and -key") - } - } - - lc := net.ListenConfig{Control: core.GetControlFunc(&core.TcpConfig{EnableTFO: tfo})} - l, err := lc.Listen(context.Background(), "tcp", bindAddr) - if err != nil { - log.Fatalf("main: net.Listen: %v", err) - } - server := core.Server{ - Listener: l, - Dst: dstAddr, - NoTLS: noTLS, - Auth: auth, - Certificates: certificates, - Timeout: timeout, + BindAddr: bindAddr, + DstAddr: dstAddr, + Websocket: ws, + WebsocketPath: wsPath, + Cert: cert, + Key: key, + ServerName: serverName, + Auth: auth, + TFO: tfo, + IdleTimeout: timeout, } - - err = server.ActiveAndServe() - if err != nil { - log.Fatalf("main: doServer: %v", err) + if err := server.ActiveAndServe(); err != nil { + log.Fatalf("server exited: %v", err) } + log.Print("server exited") + return } else { // do client - var rootCAs *x509.CertPool - if !noTLS { - if len(serverName) == 0 { - serverName = strings.SplitN(dstAddr, ":", 2)[0] - } - - switch { - case len(cca) != 0: - cca = strings.TrimRight(cca, "=") - pem, err := base64.RawStdEncoding.DecodeString(cca) - if err != nil { - log.Fatalf("main: base64.RawStdEncoding.DecodeString: %v", err) - } - - rootCAs = x509.NewCertPool() - if ok := rootCAs.AppendCertsFromPEM(pem); !ok { - log.Fatal("main: AppendCertsFromPEM failed, cca is invalid") - } - case len(ca) != 0: - rootCAs = x509.NewCertPool() - certPEMBlock, err := ioutil.ReadFile(ca) - if err != nil { - log.Fatalf("main: ReadFile ca [%s], %v", ca, err) - } - if ok := rootCAs.AppendCertsFromPEM(certPEMBlock); !ok { - log.Fatal("main: AppendCertsFromPEM failed, ca is invalid") - } - } - } - - lc := net.ListenConfig{} - l, err := lc.Listen(context.Background(), "tcp", bindAddr) - if err != nil { - log.Fatalf("main: net.Listen: %v", err) - } - client := core.Client{ - Listener: l, - ServerAddr: dstAddr, - NoTLS: noTLS, + BindAddr: bindAddr, + DstAddr: dstAddr, + Websocket: ws, + WebsocketPath: wsPath, + Mux: mux, Auth: auth, ServerName: serverName, - CertPool: rootCAs, + CA: ca, + CertHash: certHash, InsecureSkipVerify: insecureSkipVerify, - Timeout: timeout, + IdleTimeout: timeout, AndroidVPNMode: vpn, TFO: tfo, - Mux: mux, } err = client.ActiveAndServe() if err != nil { - log.Fatalf("main: doServer: %v", err) + log.Fatalf("client exited: %v", err) } + log.Print("client exited") + return } }