diff --git a/core/auth.go b/core/auth.go
new file mode 100644
index 0000000..0e54f36
--- /dev/null
+++ b/core/auth.go
@@ -0,0 +1,74 @@
+// Copyright (C) 2020-2021, IrineSistiana
+//
+// This file is part of simple-tls.
+//
+// simple-tls is free software: you can redistribute it and/or modify
+// it under the terms of the GNU General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// simple-tls is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU General Public License for more details.
+//
+// You should have received a copy of the GNU General Public License
+// along with this program. If not, see .
+
+package core
+
+import (
+ "context"
+ "crypto/md5"
+ "errors"
+ "fmt"
+ "io"
+ "net"
+ "time"
+)
+
+type AuthTransport struct {
+ nextTransport Transport
+ auth [md5.Size]byte
+}
+
+func (t *AuthTransport) Dial(ctx context.Context) (net.Conn, error) {
+ conn, err := t.nextTransport.Dial(ctx)
+ if err != nil {
+ return nil, err
+ }
+ if _, err := conn.Write(t.auth[:]); err != nil {
+ conn.Close()
+ return nil, fmt.Errorf("failed to write auth: %w", err)
+ }
+ return conn, nil
+}
+
+func NewAuthTransport(nextTransport Transport, auth string) *AuthTransport {
+ return &AuthTransport{nextTransport: nextTransport, auth: md5.Sum([]byte(auth))}
+}
+
+type AuthTransportHandler struct {
+ nextHandler TransportHandler
+ auth [md5.Size]byte
+}
+
+func NewAuthTransportHandler(nextHandler TransportHandler, auth string) *AuthTransportHandler {
+ return &AuthTransportHandler{nextHandler: nextHandler, auth: md5.Sum([]byte(auth))}
+}
+
+var errAuthFailed = errors.New("auth failed")
+
+func (h *AuthTransportHandler) Handle(conn net.Conn) error {
+ var auth [md5.Size]byte
+ if _, err := io.ReadFull(conn, auth[:]); err != nil {
+ return fmt.Errorf("failed to read auth header: %w", err)
+ }
+
+ if auth != h.auth {
+ discardRead(conn, time.Second*15)
+ return errAuthFailed
+ }
+
+ return h.nextHandler.Handle(conn)
+}
diff --git a/core/client.go b/core/client.go
index d1d1142..754fa31 100644
--- a/core/client.go
+++ b/core/client.go
@@ -18,124 +18,146 @@
package core
import (
- "crypto/md5"
+ "bytes"
+ "context"
+ "crypto/sha256"
"crypto/tls"
"crypto/x509"
+ "encoding/hex"
+ "errors"
"fmt"
"github.com/IrineSistiana/ctunnel"
- "log"
+ "io/ioutil"
"net"
+ "strings"
"time"
)
type Client struct {
- Listener net.Listener
- ServerAddr string
- NoTLS bool
- Auth string
+ BindAddr string
+ DstAddr string
+ Websocket bool
+ WebsocketPath string
+ Mux int
+ Auth string
+
ServerName string
- CertPool *x509.CertPool
+ CA string
+ CertHash string
InsecureSkipVerify bool
- Timeout time.Duration
- AndroidVPNMode bool
- TFO bool
- Mux int
-
- dialer net.Dialer
- auth [16]byte
- tlsConfig *tls.Config
- muxPool *muxPool // not nil if Mux > 0
+
+ IdleTimeout time.Duration
+ AndroidVPNMode bool
+ TFO bool
+
+ testListener net.Listener
}
+var errEmptyCAFile = errors.New("no valid certificate was found in the ca file")
+
func (c *Client) ActiveAndServe() error {
- c.dialer = net.Dialer{
- Timeout: time.Second * 5,
- Control: GetControlFunc(&TcpConfig{AndroidVPN: c.AndroidVPNMode, EnableTFO: c.TFO}),
+
+ var l net.Listener
+ if c.testListener != nil {
+ l = c.testListener
+ } else {
+ var err error
+ lc := net.ListenConfig{}
+ l, err = lc.Listen(context.Background(), "tcp", c.BindAddr)
+ if err != nil {
+ return err
+ }
}
- if !c.NoTLS {
- c.tlsConfig = new(tls.Config)
- c.tlsConfig.NextProtos = []string{"http/1.1", "h2"}
- c.tlsConfig.ServerName = c.ServerName
- c.tlsConfig.RootCAs = c.CertPool
- c.tlsConfig.InsecureSkipVerify = c.InsecureSkipVerify
+ if len(c.ServerName) == 0 {
+ c.ServerName = strings.SplitN(c.DstAddr, ":", 2)[0]
}
- if len(c.Auth) > 0 {
- c.auth = md5.Sum([]byte(c.Auth))
+ var rootCAs *x509.CertPool
+ if len(c.CA) != 0 {
+ rootCAs = x509.NewCertPool()
+ certPEMBlock, err := ioutil.ReadFile(c.CA)
+ if err != nil {
+ return fmt.Errorf("cannot read ca file: %w", err)
+ }
+ if ok := rootCAs.AppendCertsFromPEM(certPEMBlock); !ok {
+ return errEmptyCAFile
+ }
}
- if c.Mux > 0 {
- c.muxPool = newMuxPool(c.dialServerConn, c.Mux)
+ dialer := &net.Dialer{
+ Timeout: time.Second * 5,
+ Control: GetControlFunc(&TcpConfig{AndroidVPN: c.AndroidVPNMode, EnableTFO: c.TFO}),
}
- for {
- localConn, err := c.Listener.Accept()
+ var chb []byte
+ if len(c.CertHash) != 0 {
+ b, err := hex.DecodeString(c.CertHash)
if err != nil {
- return fmt.Errorf("l.Accept(): %w", err)
+ return fmt.Errorf("invalid cert hash: %w", err)
}
- reduceTCPLoopbackSocketBuf(localConn)
+ chb = b
+ }
- go func() {
- defer localConn.Close()
-
- var serverConn net.Conn
- if c.Mux > 0 {
- stream, _, err := c.muxPool.GetStream()
- if err != nil {
- log.Printf("ERROR: muxPool.GetStream: %v", err)
- return
+ tlsConfig := &tls.Config{
+ NextProtos: []string{"h2", "http/1.1"},
+ ServerName: c.ServerName,
+ RootCAs: rootCAs,
+ InsecureSkipVerify: c.InsecureSkipVerify,
+ VerifyConnection: func(state tls.ConnectionState) error {
+ if len(chb) != 0 {
+ cert := state.PeerCertificates[0]
+ h := sha256.Sum256(cert.RawTBSCertificate)
+ if bytes.Equal(h[:len(chb)], chb) {
+ return nil
}
- serverConn = stream
- } else {
- conn, err := c.dialServerConn()
- if err != nil {
- log.Printf("ERROR: dialServerConn: %v", err)
- return
- }
- serverConn = conn
+ return fmt.Errorf("cert hash mismatch, recieved cert hash is [%s]", hex.EncodeToString(h[:]))
}
- defer serverConn.Close()
- if err := ctunnel.OpenTunnel(localConn, serverConn, c.Timeout); err != nil {
- log.Printf("ERROR: ActiveAndServe: openTunnel: %v", err)
+ if state.Version != tls.VersionTLS13 {
+ return fmt.Errorf("unsafe tls version %d", state.Version)
}
- }()
+ return nil
+ },
}
-}
-func (c *Client) dialServerConn() (net.Conn, error) {
- serverConn, err := c.dialer.Dial("tcp", c.ServerAddr)
- if err != nil {
- return nil, err
+ var transport Transport
+ if c.Websocket {
+ transport = NewWebsocketTransport(c.DstAddr, c.ServerName, c.WebsocketPath, tlsConfig, dialer)
+ } else {
+ transport = NewRawConnTransport(c.DstAddr, dialer)
+ transport = NewTLSTransport(transport, tlsConfig)
}
- if !c.NoTLS {
- serverTLSConn := tls.Client(serverConn, c.tlsConfig)
- if err := tls13HandshakeWithTimeout(serverTLSConn, time.Second*5); err != nil {
- serverTLSConn.Close()
- return nil, err
- }
- serverConn = serverTLSConn
+ if len(c.Auth) > 0 {
+ transport = NewAuthTransport(transport, c.Auth)
}
- // write auth
- if len(c.Auth) > 0 {
- if _, err := serverConn.Write(c.auth[:]); err != nil {
- serverConn.Close()
- return nil, fmt.Errorf("failed to write auth: %w", err)
+ transport = NewMuxTransport(transport, c.Mux)
+
+ for {
+ clientConn, err := l.Accept()
+ if err != nil {
+ return err
}
- }
- // write mode
- mode := modePlain
- if c.Mux > 0 {
- mode = modeMux
- }
- if _, err := serverConn.Write([]byte{mode}); err != nil {
- serverConn.Close()
- return nil, fmt.Errorf("failed to write mode: %w", err)
+ go func() {
+ defer clientConn.Close()
+
+ ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
+ defer cancel()
+ serverConn, err := transport.Dial(ctx)
+ if err != nil {
+ errLogger.Printf("failed to dial server connection: %v", err)
+ return
+ }
+ defer serverConn.Close()
+
+ err = ctunnel.OpenTunnel(clientConn, serverConn, c.IdleTimeout)
+ if err != nil {
+ logConnErr(clientConn, fmt.Errorf("tunnel closed: %w", err))
+ }
+ }()
}
- return serverConn, nil
}
diff --git a/core/core_test.go b/core/core_test.go
index c654067..405d1a6 100644
--- a/core/core_test.go
+++ b/core/core_test.go
@@ -19,21 +19,23 @@ package core
import (
"bytes"
+ "crypto/sha256"
"crypto/tls"
- "crypto/x509"
+ "encoding/hex"
+ "fmt"
"io"
- "log"
"math/rand"
"net"
+ "sync"
"testing"
"time"
)
func Test_main(t *testing.T) {
dataSize := 512 * 1024
+ b := make([]byte, dataSize)
+ rand.Read(b)
randData := func() []byte {
- b := make([]byte, dataSize)
- rand.Read(b)
return b
}
@@ -68,26 +70,35 @@ func Test_main(t *testing.T) {
}()
// test1
- test := func(t *testing.T, mux int, noTLS bool) {
- // start server
- _, keyPEM, certPEM, err := GenerateCertificate("example.com")
- cert, err := tls.X509KeyPair(certPEM, keyPEM)
+ test := func(t *testing.T, mux int, ws bool, wsPath string, auth string) {
+ serverListener, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatal(err)
}
+ defer serverListener.Close()
- serverListener, err := net.Listen("tcp", "127.0.0.1:0")
+ _, x509cert, keyPEM, certPEM, err := GenerateCertificate("")
+ if err != nil {
+ t.Fatal(err)
+ }
+ h := sha256.Sum256(x509cert.RawTBSCertificate)
+ certHash := hex.EncodeToString(h[:])
+
+ cert, err := tls.X509KeyPair(certPEM, keyPEM)
if err != nil {
t.Fatal(err)
}
- defer serverListener.Close()
server := Server{
- Listener: serverListener,
- Dst: echoListener.Addr().String(),
- Certificates: []tls.Certificate{cert},
- Timeout: timeout,
+ DstAddr: echoListener.Addr().String(),
+ Websocket: ws,
+ WebsocketPath: wsPath,
+ Auth: auth,
+ IdleTimeout: timeout,
+ testListener: serverListener,
+ testCert: &cert,
}
+
go server.ActiveAndServe()
// start client
@@ -97,62 +108,61 @@ func Test_main(t *testing.T) {
}
defer clientListener.Close()
- caPool := x509.NewCertPool()
- ok := caPool.AppendCertsFromPEM(certPEM)
- if !ok {
- t.Fatal("appendCertsFromPEM failed")
- }
-
client := Client{
- Listener: clientListener,
- ServerAddr: serverListener.Addr().String(),
- ServerName: "example.com",
- CertPool: caPool,
- InsecureSkipVerify: false,
+ DstAddr: serverListener.Addr().String(),
+ Websocket: ws,
+ WebsocketPath: wsPath,
Mux: mux,
- Timeout: timeout,
- AndroidVPNMode: false,
- TFO: false,
+ Auth: auth,
+ CertHash: certHash,
+ InsecureSkipVerify: true,
+ IdleTimeout: timeout,
+ testListener: clientListener,
}
- go client.ActiveAndServe()
- log.Printf("echo: %v, server: %v client: %v", echoListener.Addr(), serverListener.Addr(), clientListener.Addr())
+ go client.ActiveAndServe()
+ wg := new(sync.WaitGroup)
for i := 0; i < 10; i++ {
- conn, err := net.Dial("tcp", clientListener.Addr().String())
- if err != nil {
- t.Fatal(err)
- }
- data := randData()
- buf := make([]byte, dataSize)
- _, err = conn.Write(data)
- if err != nil {
- t.Fatal(err)
- }
+ wg.Add(1)
+ go func() {
+ defer wg.Done()
+ conn, err := net.Dial("tcp", clientListener.Addr().String())
+ if err != nil {
+ t.Error(err)
+ return
+ }
+ data := randData()
+ buf := make([]byte, dataSize)
+ _, err = conn.Write(data)
+ if err != nil {
+ t.Error(err)
+ return
+ }
- _, err = io.ReadFull(conn, buf)
- if err != nil {
- t.Fatal(err)
- }
- if bytes.Equal(data, buf) == false {
- t.Fatal("corrupted data")
- }
+ _, err = io.ReadFull(conn, buf)
+ if err != nil {
+ t.Error(err)
+ return
+ }
+ if bytes.Equal(data, buf) == false {
+ t.Error("corrupted data")
+ return
+ }
+ }()
}
+ wg.Wait()
}
- tests := []struct {
- name string
- mux int
- noTLS bool
- }{
- {"plain", 0, false},
- {"mux", 5, false},
- {"no tls", 5, true},
- }
-
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- test(t, tt.mux, tt.noTLS)
- })
+ for _, mux := range [...]int{0, 5} {
+ for _, ws := range [...]bool{false, true} {
+ for _, wsPath := range [...]string{"", "/123456"} {
+ for _, auth := range [...]string{"", "123456"} {
+ t.Run(fmt.Sprintf("mux_%v_ws_%v_wsPath_%v_auth_%v", mux, ws, wsPath, auth), func(t *testing.T) {
+ test(t, mux, ws, wsPath, auth)
+ })
+ }
+ }
+ }
}
}
diff --git a/core/mode.go b/core/logger.go
similarity index 68%
rename from core/mode.go
rename to core/logger.go
index 85b562d..855a826 100644
--- a/core/mode.go
+++ b/core/logger.go
@@ -17,7 +17,19 @@
package core
-const (
- modePlain byte = iota
- modeMux
+import (
+ "log"
+ "net"
+ "net/http"
+ "os"
)
+
+var errLogger = log.New(os.Stderr, "err", log.LstdFlags)
+
+func logConnErr(conn net.Conn, err error) {
+ errLogger.Printf("connection %s <-> %s: %v", conn.RemoteAddr(), conn.LocalAddr(), err)
+}
+
+func logRequestErr(r *http.Request, err error) {
+ errLogger.Printf("request from %s %s: %v", r.RemoteAddr, r.RequestURI, err)
+}
diff --git a/core/server.go b/core/server.go
index b9a1a1f..6d25dd0 100644
--- a/core/server.go
+++ b/core/server.go
@@ -18,146 +18,100 @@
package core
import (
- "bytes"
- "crypto/md5"
+ "context"
"crypto/tls"
+ "errors"
"fmt"
- "github.com/IrineSistiana/ctunnel"
- "github.com/xtaci/smux"
- "io"
"log"
"net"
"time"
)
type Server struct {
- Listener net.Listener
- Dst string
- NoTLS bool
- Auth string
+ BindAddr string
+ DstAddr string
- Certificates []tls.Certificate
- Timeout time.Duration
+ Websocket bool
+ WebsocketPath string
- auth [16]byte
- tlsConfig *tls.Config
+ Cert, Key, ServerName string
+ Auth string
+ TFO bool
+ IdleTimeout time.Duration
+ NoTLS bool
+
+ testListener net.Listener
+ testCert *tls.Certificate
}
+var errMissingCertOrKey = errors.New("one of cert or key argument is missing")
+
func (s *Server) ActiveAndServe() error {
- if !s.NoTLS {
- s.tlsConfig = new(tls.Config)
- s.tlsConfig.NextProtos = []string{"h2"}
- s.tlsConfig.Certificates = s.Certificates
+ var l net.Listener
+ if s.testListener != nil {
+ l = s.testListener
+ } else {
+ var err error
+ lc := net.ListenConfig{Control: GetControlFunc(&TcpConfig{EnableTFO: s.TFO})}
+ l, err = lc.Listen(context.Background(), "tcp", s.BindAddr)
+ if err != nil {
+ return err
+ }
}
+ var transportHandler TransportHandler
+ transportHandler = NewBaseTransportHandler(s.DstAddr, s.IdleTimeout)
+ transportHandler = NewMuxTransportHandler(transportHandler)
if len(s.Auth) > 0 {
- s.auth = md5.Sum([]byte(s.Auth))
+ transportHandler = NewAuthTransportHandler(transportHandler, s.Auth)
}
- for {
- clientConn, err := s.Listener.Accept()
- if err != nil {
- return fmt.Errorf("l.Accept(): %w", err)
- }
-
- go func() {
- defer clientConn.Close()
-
- if !s.NoTLS {
- clientTLSConn := tls.Server(clientConn, s.tlsConfig)
- // handshake
- if err := tls13HandshakeWithTimeout(clientTLSConn, time.Second*5); err != nil {
- log.Printf("ERROR: %s, tls13HandshakeWithTimeout: %v", clientConn.RemoteAddr(), err)
- return
- }
- clientConn = clientTLSConn
- }
-
- // check auth
- if len(s.Auth) > 0 {
- auth := make([]byte, 16)
- if _, err := io.ReadFull(clientConn, auth); err != nil {
- log.Printf("ERROR: %s, read client auth header: %v", clientConn.RemoteAddr(), err)
- return
+ if !s.NoTLS {
+ var certificate tls.Certificate
+ if s.testCert != nil {
+ certificate = *s.testCert
+ } else {
+ switch {
+ case len(s.Cert) == 0 && len(s.Key) == 0: // no cert and key
+ dnsName, _, keyPEM, certPEM, err := GenerateCertificate(s.ServerName)
+ if err != nil {
+ return fmt.Errorf("failed to generate temp cert: %w", err)
}
- if !bytes.Equal(s.auth[:], auth) {
- log.Printf("ERROR: %s, auth failed", clientConn.RemoteAddr())
- discardRead(clientConn, time.Second*15)
- return
+ log.Printf("warnning: you are using a tmp certificate with dns name: %s", dnsName)
+ cer, err := tls.X509KeyPair(certPEM, keyPEM)
+ if err != nil {
+ return fmt.Errorf("cannot load x509 key pair from memory: %w", err)
}
- }
-
- // mode
- header := make([]byte, 1)
- if _, err := io.ReadFull(clientConn, header); err != nil {
- log.Printf("ERROR: %s, read client mode header: %v", clientConn.RemoteAddr(), err)
- return
- }
- switch header[0] {
- case modePlain:
- if err := s.handleClientConn(clientConn); err != nil {
- log.Printf("ERROR: %s, handleClientConn: %v", clientConn.RemoteAddr(), err)
- return
- }
- case modeMux:
- err := s.handleClientMux(clientConn)
+ certificate = cer
+ case len(s.Cert) != 0 && len(s.Key) != 0: // has a cert and a key
+ cer, err := tls.LoadX509KeyPair(s.Cert, s.Key) //load cert
if err != nil {
- log.Printf("ERROR: %s, handleClientMux: %v", clientConn.RemoteAddr(), err)
- return
+ return fmt.Errorf("cannot load x509 key pair from disk: %w", err)
}
+ certificate = cer
default:
- log.Printf("ERROR: %s, invalid header %d", clientConn.RemoteAddr(), header[0])
- return
+ return errMissingCertOrKey
}
- }()
- }
-}
-
-func discardRead(c net.Conn, t time.Duration) {
- c.SetDeadline(time.Now().Add(t))
- buf := make([]byte, 512)
- for {
- _, err := c.Read(buf)
- if err != nil {
- return
}
- }
-}
-func (s *Server) handleClientConn(cc net.Conn) (err error) {
- dstConn, err := net.Dial("tcp", s.Dst)
- if err != nil {
- return fmt.Errorf("net.Dial: %v", err)
- }
- reduceTCPLoopbackSocketBuf(dstConn)
- defer dstConn.Close()
-
- if err := ctunnel.OpenTunnel(dstConn, cc, s.Timeout); err != nil {
- return fmt.Errorf("openTunnel: %v", err)
- }
- return nil
-}
+ tlsConfig := &tls.Config{
+ NextProtos: []string{"h2", "http/1.1"},
+ Certificates: []tls.Certificate{certificate},
+ VerifyConnection: func(state tls.ConnectionState) error {
+ if state.Version != tls.VersionTLS13 {
+ return fmt.Errorf("unsafe tls version %d", state.Version)
+ }
+ return nil
+ },
+ }
-func (s *Server) handleClientMux(cc net.Conn) (err error) {
- sess, err := smux.Server(cc, muxConfig)
- if err != nil {
- return err
+ l = tls.NewListener(l, tlsConfig)
}
- defer sess.Close()
- for {
- stream, err := sess.AcceptStream()
- if err != nil {
- return nil // suppress smux err
- }
- go func() {
- defer stream.Close()
- if err := s.handleClientConn(stream); err != nil {
- log.Printf("ERROR: handleClientMux: %s, handleClientConn: %v", stream.RemoteAddr(), err)
- return
- }
- }()
+ if s.Websocket {
+ return ListenWebsocket(l, s.WebsocketPath, transportHandler)
}
+ return ListenRawConn(l, transportHandler)
}
diff --git a/core/sip003.go b/core/sip003.go
index 769db25..7e30123 100644
--- a/core/sip003.go
+++ b/core/sip003.go
@@ -24,7 +24,7 @@ import (
"strings"
)
-var ErrBrokenSIP003Args = errors.New("invalid SIP003 args")
+var errBrokenSIP003Args = errors.New("invalid SIP003 args")
//SIP003Args contains sip003 args
type SIP003Args struct {
@@ -55,7 +55,7 @@ func GetSIP003Args() (*SIP003Args, error) {
if srhOk || srpOk || slhOk || slpOk || spoOk { // has at least one arg
if !(srhOk && srpOk && slhOk && slpOk) { // but not has all 4 args
- return nil, ErrBrokenSIP003Args
+ return nil, errBrokenSIP003Args
}
} else {
return nil, nil // can't find any sip003 arg
diff --git a/core/smux.go b/core/smux.go
index 436a5df..1544b99 100644
--- a/core/smux.go
+++ b/core/smux.go
@@ -18,13 +18,20 @@
package core
import (
+ "context"
"fmt"
"github.com/xtaci/smux"
+ "io"
"log"
"net"
"sync"
)
+const (
+ modePlain byte = iota
+ modeMux
+)
+
var muxConfig = &smux.Config{
Version: 1,
KeepAliveDisabled: true,
@@ -33,8 +40,8 @@ var muxConfig = &smux.Config{
MaxStreamBuffer: 32 * 1024,
}
-type muxPool struct {
- dialFunc func() (c net.Conn, err error)
+type MuxTransport struct {
+ nextTransport Transport
maxConcurrent int
sm sync.Mutex
@@ -45,11 +52,8 @@ type muxPool struct {
dialWaiting int
}
-func newMuxPool(dialFunc func() (c net.Conn, err error), maxConcurrent int) *muxPool {
- if maxConcurrent < 1 {
- panic(fmt.Sprintf("invalid maxConcurrent: %d", maxConcurrent))
- }
- return &muxPool{dialFunc: dialFunc, maxConcurrent: maxConcurrent, sess: map[*smux.Session]struct{}{}}
+func NewMuxTransport(subTransport Transport, maxConcurrent int) *MuxTransport {
+ return &MuxTransport{nextTransport: subTransport, maxConcurrent: maxConcurrent, sess: map[*smux.Session]struct{}{}}
}
type dialCall struct {
@@ -58,21 +62,32 @@ type dialCall struct {
err error
}
-func (m *muxPool) GetStream() (stream *smux.Stream, sess *smux.Session, err error) {
- if stream, sess, ok := m.tryGetStream(); ok {
- return stream, sess, nil
+func (m *MuxTransport) Dial(ctx context.Context) (net.Conn, error) {
+ if m.maxConcurrent <= 1 {
+ conn, err := m.nextTransport.Dial(ctx)
+ if err != nil {
+ return nil, err
+ }
+ if _, err := conn.Write([]byte{modePlain}); err != nil {
+ conn.Close()
+ return nil, fmt.Errorf("failed to write mux header: %w", err)
+ }
+ }
+
+ if stream := m.tryGetStream(); stream != nil {
+ return stream, nil
}
- return m.tryGetStreamFlash()
+ return m.tryGetStreamFlash(ctx)
}
-func (m *muxPool) MarkDead(sess *smux.Session) {
+func (m *MuxTransport) MarkDead(sess *smux.Session) {
m.sm.Lock()
defer m.sm.Unlock()
delete(m.sess, sess)
sess.Close()
}
-func (m *muxPool) tryGetStream() (stream *smux.Stream, sess *smux.Session, ok bool) {
+func (m *MuxTransport) tryGetStream() (stream *smux.Stream) {
m.sm.Lock()
defer m.sm.Unlock()
for sess := range m.sess {
@@ -84,18 +99,18 @@ func (m *muxPool) tryGetStream() (stream *smux.Stream, sess *smux.Session, ok bo
delete(m.sess, sess)
continue
}
- return s, sess, true
+ return s
}
}
- return nil, nil, false
+ return nil
}
-func (m *muxPool) tryGetStreamFlash() (stream *smux.Stream, sess *smux.Session, err error) {
+func (m *MuxTransport) tryGetStreamFlash(ctx context.Context) (*smux.Stream, error) {
var call *dialCall
m.dm.Lock()
if m.dialing == nil || (m.dialing != nil && m.dialWaiting >= m.maxConcurrent) {
m.dialWaiting = 0
- m.dialing = m.dialSessLocked() // needs a new dial
+ m.dialing = m.dialSessLocked(ctx) // needs a new dial
} else {
m.dialWaiting++
}
@@ -103,27 +118,33 @@ func (m *muxPool) tryGetStreamFlash() (stream *smux.Stream, sess *smux.Session,
defer m.dm.Unlock()
<-call.done
- sess = call.s
- err = call.err
+ sess := call.s
+ err := call.err
if err != nil {
- return nil, nil, err
+ return nil, err
}
- stream, err = sess.OpenStream()
- return stream, sess, err
+ return sess.OpenStream()
}
-func (m *muxPool) dialSessLocked() (call *dialCall) {
+func (m *MuxTransport) dialSessLocked(ctx context.Context) (call *dialCall) {
call = &dialCall{
done: make(chan struct{}),
}
go func() {
- c, err := m.dialFunc()
+ c, err := m.nextTransport.Dial(ctx)
if err != nil {
call.err = err
close(call.done)
return
}
+ if _, err := c.Write([]byte{modeMux}); err != nil {
+ c.Close()
+ call.err = fmt.Errorf("failed to write mux header: %w", err)
+ close(call.done)
+ return
+ }
+
sess, err := smux.Client(c, muxConfig)
call.s = sess
call.err = err
@@ -139,3 +160,45 @@ func (m *muxPool) dialSessLocked() (call *dialCall) {
}()
return call
}
+
+type MuxTransportHandler struct {
+ nextHandler TransportHandler
+}
+
+func NewMuxTransportHandler(nextHandler TransportHandler) *MuxTransportHandler {
+ return &MuxTransportHandler{nextHandler: nextHandler}
+}
+
+func (h *MuxTransportHandler) Handle(conn net.Conn) error {
+ header := make([]byte, 1)
+ if _, err := io.ReadFull(conn, header); err != nil {
+ return fmt.Errorf("failed to read mux header: %w", err)
+ }
+
+ switch header[0] {
+ case modePlain:
+ return h.nextHandler.Handle(conn)
+ case modeMux:
+ sess, err := smux.Server(conn, muxConfig)
+ if err != nil {
+ return err
+ }
+ defer sess.Close()
+
+ for {
+ stream, err := sess.AcceptStream()
+ if err != nil {
+ return nil // suppress smux err
+ }
+ go func() {
+ defer stream.Close()
+ if err := h.nextHandler.Handle(stream); err != nil {
+ logConnErr(stream, err)
+ return
+ }
+ }()
+ }
+ default:
+ return fmt.Errorf("invalid mux header %d", header[0])
+ }
+}
diff --git a/core/transport.go b/core/transport.go
new file mode 100644
index 0000000..8fe4321
--- /dev/null
+++ b/core/transport.go
@@ -0,0 +1,111 @@
+// Copyright (C) 2020-2021, IrineSistiana
+//
+// This file is part of simple-tls.
+//
+// simple-tls is free software: you can redistribute it and/or modify
+// it under the terms of the GNU General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// simple-tls is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU General Public License for more details.
+//
+// You should have received a copy of the GNU General Public License
+// along with this program. If not, see .
+
+package core
+
+import (
+ "context"
+ "crypto/tls"
+ "fmt"
+ "github.com/IrineSistiana/ctunnel"
+ "net"
+ "time"
+)
+
+type Transport interface {
+ Dial(ctx context.Context) (net.Conn, error)
+}
+
+type TransportHandler interface {
+ Handle(conn net.Conn) error
+}
+
+type RawConnTransport struct {
+ addr string
+ dialer *net.Dialer
+}
+
+func (t *RawConnTransport) Dial(ctx context.Context) (net.Conn, error) {
+ return t.dialer.DialContext(ctx, "tcp", t.addr)
+}
+
+func NewRawConnTransport(addr string, dialer *net.Dialer) *RawConnTransport {
+ return &RawConnTransport{addr: addr, dialer: dialer}
+}
+
+type TLSTransport struct {
+ nextTransport Transport
+ conf *tls.Config
+}
+
+func (t *TLSTransport) Dial(ctx context.Context) (net.Conn, error) {
+ conn, err := t.nextTransport.Dial(ctx)
+ if err != nil {
+ return nil, err
+ }
+
+ tlsConn := tls.Client(conn, t.conf)
+ if err := tlsConn.HandshakeContext(ctx); err != nil {
+ tlsConn.Close()
+ return nil, err
+ }
+ return tlsConn, nil
+}
+
+func NewTLSTransport(nextTransport Transport, conf *tls.Config) *TLSTransport {
+ return &TLSTransport{nextTransport: nextTransport, conf: conf}
+}
+
+type BaseTransportHandler struct {
+ dst string
+ idleTimeout time.Duration
+}
+
+func (h *BaseTransportHandler) Handle(conn net.Conn) error {
+ dstConn, err := net.Dial("tcp", h.dst)
+ if err != nil {
+ return fmt.Errorf("cannot connect to the dst: %w", err)
+ }
+ reduceTCPLoopbackSocketBuf(dstConn)
+ defer dstConn.Close()
+
+ if err := ctunnel.OpenTunnel(dstConn, conn, h.idleTimeout); err != nil {
+ return fmt.Errorf("tunnel closed: %w", err)
+ }
+ return nil
+}
+
+func NewBaseTransportHandler(dst string, idleTimeout time.Duration) *BaseTransportHandler {
+ return &BaseTransportHandler{dst: dst, idleTimeout: idleTimeout}
+}
+
+func ListenRawConn(l net.Listener, nextHandler TransportHandler) error {
+ for {
+ conn, err := l.Accept()
+ if err != nil {
+ return err
+ }
+
+ go func() {
+ defer conn.Close()
+ err := nextHandler.Handle(conn)
+ if err != nil {
+ logConnErr(conn, err)
+ }
+ }()
+ }
+}
diff --git a/core/utils.go b/core/utils.go
index cd47183..50cd371 100644
--- a/core/utils.go
+++ b/core/utils.go
@@ -21,29 +21,17 @@ import (
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
- "crypto/tls"
"crypto/x509"
"crypto/x509/pkix"
"encoding/pem"
"fmt"
"math/big"
mathRand "math/rand"
+ "net"
"time"
)
-func tls13HandshakeWithTimeout(c *tls.Conn, timeout time.Duration) error {
- c.SetDeadline(time.Now().Add(timeout))
- if err := c.Handshake(); err != nil {
- return err
- }
- c.SetDeadline(time.Time{})
- if cVar := c.ConnectionState().Version; cVar != tls.VersionTLS13 {
- return fmt.Errorf("unexpected tls version: %x", cVar)
- }
- return nil
-}
-
-func GenerateCertificate(serverName string) (dnsName string, keyPEM, certPEM []byte, err error) {
+func GenerateCertificate(serverName string) (dnsName string, cert *x509.Certificate, keyPEM, certPEM []byte, err error) {
//priv key
key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
if err != nil {
@@ -89,7 +77,12 @@ func GenerateCertificate(serverName string) (dnsName string, keyPEM, certPEM []b
keyPEM = pem.EncodeToMemory(&pem.Block{Type: "EC PRIVATE KEY", Bytes: b})
certPEM = pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: certDER})
- return dnsName, keyPEM, certPEM, nil
+ cert, err = x509.ParseCertificate(certDER)
+ if err != nil {
+ return
+ }
+
+ return
}
func randServerName() string {
@@ -106,3 +99,14 @@ func randStr(length int) string {
}
return string(b)
}
+
+func discardRead(c net.Conn, t time.Duration) {
+ c.SetDeadline(time.Now().Add(t))
+ buf := make([]byte, 512)
+ for {
+ _, err := c.Read(buf)
+ if err != nil {
+ return
+ }
+ }
+}
diff --git a/core/websocket.go b/core/websocket.go
new file mode 100644
index 0000000..1bafa5e
--- /dev/null
+++ b/core/websocket.go
@@ -0,0 +1,130 @@
+// Copyright (C) 2020-2021, IrineSistiana
+//
+// This file is part of simple-tls.
+//
+// simple-tls is free software: you can redistribute it and/or modify
+// it under the terms of the GNU General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// simple-tls is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU General Public License for more details.
+//
+// You should have received a copy of the GNU General Public License
+// along with this program. If not, see .
+
+package core
+
+import (
+ "context"
+ "crypto/tls"
+ "errors"
+ "fmt"
+ "golang.org/x/net/http2"
+ "net"
+ "net/http"
+ "net/url"
+ "nhooyr.io/websocket"
+ "time"
+)
+
+type WebsocketTransport struct {
+ u string
+ op *websocket.DialOptions
+}
+
+func NewWebsocketTransport(serverAddr, serverName, urlPath string, tlsConfig *tls.Config, dialer *net.Dialer) *WebsocketTransport {
+ u := url.URL{
+ Scheme: "https",
+ Host: serverName,
+ Path: urlPath,
+ }
+ t := &http.Transport{
+ DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
+ return dialer.DialContext(ctx, network, serverAddr)
+ },
+ TLSClientConfig: tlsConfig,
+ TLSHandshakeTimeout: time.Second * 5,
+ DisableCompression: true,
+ ResponseHeaderTimeout: time.Second * 5,
+ ExpectContinueTimeout: 0,
+ WriteBufferSize: 32 * 1024,
+ ReadBufferSize: 32 * 1024,
+ ForceAttemptHTTP2: true,
+ }
+
+ t2, _ := http2.ConfigureTransports(t)
+ t2.ReadIdleTimeout = time.Second * 30
+ t2.PingTimeout = time.Second * 10
+
+ return &WebsocketTransport{
+ u: u.String(),
+ op: &websocket.DialOptions{
+ HTTPClient: &http.Client{
+ Transport: t,
+ Timeout: time.Second * 10,
+ },
+ CompressionMode: websocket.CompressionDisabled,
+ },
+ }
+}
+
+func (p *WebsocketTransport) Dial(ctx context.Context) (net.Conn, error) {
+ wsConn, _, err := websocket.Dial(ctx, p.u, p.op)
+ if err != nil {
+ return nil, err
+ }
+ return websocket.NetConn(context.Background(), wsConn, websocket.MessageBinary), nil
+}
+
+type wsHttpHandler struct {
+ nextHandler TransportHandler
+ path string
+}
+
+var errInvalidPath = errors.New("invalid request path")
+
+func (h *wsHttpHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
+ if len(h.path) != 0 && r.URL.Path != h.path {
+ w.WriteHeader(http.StatusNotFound)
+ logRequestErr(r, errInvalidPath)
+ return
+ }
+
+ wsConn, err := websocket.Accept(w, r, &websocket.AcceptOptions{
+ CompressionMode: websocket.CompressionDisabled,
+ })
+ if err != nil {
+ logRequestErr(r, fmt.Errorf("cannot accept websocket connection: %w", err))
+ return
+ }
+
+ clientConn := websocket.NetConn(context.Background(), wsConn, websocket.MessageBinary)
+ defer clientConn.Close()
+
+ if err := h.nextHandler.Handle(clientConn); err != nil {
+ logRequestErr(r, err)
+ return
+ }
+ return
+}
+
+func ListenWebsocket(l net.Listener, path string, nextHandler TransportHandler) error {
+ httpServer := &http.Server{
+ Handler: &wsHttpHandler{
+ nextHandler: nextHandler,
+ path: path,
+ },
+ ReadTimeout: time.Second * 10,
+ ReadHeaderTimeout: time.Second * 10,
+ WriteTimeout: time.Second * 10,
+ }
+
+ http2.ConfigureServer(httpServer, &http2.Server{
+ IdleTimeout: time.Second * 45,
+ })
+
+ return httpServer.Serve(l)
+}
diff --git a/go.mod b/go.mod
index 3031f22..7a972a0 100644
--- a/go.mod
+++ b/go.mod
@@ -5,5 +5,10 @@ go 1.14
require (
github.com/IrineSistiana/ctunnel v0.0.0-20210409113947-9756ebc29fdb
github.com/xtaci/smux v1.5.16
+ go.uber.org/atomic v1.9.0 // indirect
+ go.uber.org/multierr v1.7.0 // indirect
+ go.uber.org/zap v1.19.1
+ golang.org/x/net v0.0.0-20211112202133-69e39bad7dc2
golang.org/x/sys v0.0.0-20211003122950-b1ebd4e1001c
+ nhooyr.io/websocket v1.8.7
)
diff --git a/go.sum b/go.sum
index f10d9a1..27a731a 100644
--- a/go.sum
+++ b/go.sum
@@ -1,12 +1,99 @@
github.com/IrineSistiana/ctunnel v0.0.0-20210409113947-9756ebc29fdb h1:b7t7X5hUjO3ZTDfe2TZOKHGe7Aq8yfNzwwcS4Qg7QTA=
github.com/IrineSistiana/ctunnel v0.0.0-20210409113947-9756ebc29fdb/go.mod h1:xcxj4YojLW9ri1tSYmHFtCSaphfxXEGgp+9BnbVpDYk=
+github.com/benbjohnson/clock v1.1.0/go.mod h1:J11/hYXuz8f4ySSvYwY0FKfm+ezbsZBKZxNJlLklBHA=
+github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
+github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
+github.com/gin-contrib/sse v0.1.0/go.mod h1:RHrZQHXnP2xjPF+u1gW/2HnVO7nvIa9PG3Gm+fLHvGI=
+github.com/gin-gonic/gin v1.6.3/go.mod h1:75u5sXoLsGZoRN5Sgbi1eraJ4GU3++wFwWzhwvtwp4M=
+github.com/go-playground/assert/v2 v2.0.1/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4=
+github.com/go-playground/locales v0.13.0/go.mod h1:taPMhCMXrRLJO55olJkUXHZBHCxTMfnGwq/HNwmWNS8=
+github.com/go-playground/universal-translator v0.17.0/go.mod h1:UkSxE5sNxxRwHyU+Scu5vgOQjsIJAF8j9muTVoKLVtA=
+github.com/go-playground/validator/v10 v10.2.0/go.mod h1:uOYAAleCW8F/7oMFd6aG0GOhaH6EGOAJShg8Id5JGkI=
+github.com/gobwas/httphead v0.0.0-20180130184737-2c6c146eadee/go.mod h1:L0fX3K22YWvt/FAX9NnzrNzcI4wNYi9Yku4O0LKYflo=
+github.com/gobwas/pool v0.2.0/go.mod h1:q8bcK0KcYlCgd9e7WYLm9LpyS+YeLd8JVDW6WezmKEw=
+github.com/gobwas/ws v1.0.2/go.mod h1:szmBTxLgaFppYjEmNtny/v3w89xOydFnnZMcgRRu/EM=
+github.com/golang/protobuf v1.3.3/go.mod h1:vzj43D7+SQXF/4pzW/hwtAqwc6iTitCiVSaWz5lYuqw=
+github.com/golang/protobuf v1.3.5/go.mod h1:6O5/vntMXwX2lRkT1hjjk0nAC1IDOTvTlVgjlRvqsdk=
+github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
+github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
+github.com/gorilla/websocket v1.4.1/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
+github.com/json-iterator/go v1.1.9/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4=
+github.com/klauspost/compress v1.10.3 h1:OP96hzwJVBIHYU52pVTI6CczrxPvrGfgqF9N5eTO0Q8=
+github.com/klauspost/compress v1.10.3/go.mod h1:aoV0uJVorq1K+umq18yTdKaF57EivdYsUV+/s2qKfXs=
+github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo=
+github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
+github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
+github.com/leodido/go-urn v1.2.0/go.mod h1:+8+nEpDfqqsY+g338gtMEUOtuK+4dEMhiQEgxpxOKII=
+github.com/mattn/go-isatty v0.0.12/go.mod h1:cbi8OIDigv2wuxKPP5vlRcQ1OAZbq2CE4Kysco4FUpU=
+github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
+github.com/modern-go/reflect2 v0.0.0-20180701023420-4b7aa43c6742/go.mod h1:bx2lNnkwVCuqBIxFjflWJWanXIb3RllmbCylyMrvgv0=
+github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
+github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
+github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
+github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
+github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4=
+github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
+github.com/ugorji/go v1.1.7/go.mod h1:kZn38zHttfInRq0xu/PH0az30d+z6vm202qpg1oXVMw=
+github.com/ugorji/go/codec v1.1.7/go.mod h1:Ax+UKWsSmolVDwsd+7N3ZtXu+yMGCf907BLYF3GoBXY=
github.com/xtaci/smux v1.5.15 h1:6hMiXswcleXj5oNfcJc+DXS8Vj36XX2LaX98udog6Kc=
github.com/xtaci/smux v1.5.15/go.mod h1:OMlQbT5vcgl2gb49mFkYo6SMf+zP3rcjcwQz7ZU7IGY=
github.com/xtaci/smux v1.5.16 h1:FBPYOkW8ZTjLKUM4LI4xnnuuDC8CQ/dB04HD519WoEk=
github.com/xtaci/smux v1.5.16/go.mod h1:OMlQbT5vcgl2gb49mFkYo6SMf+zP3rcjcwQz7ZU7IGY=
+github.com/yuin/goldmark v1.3.5/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k=
+go.uber.org/atomic v1.7.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc=
+go.uber.org/atomic v1.9.0 h1:ECmE8Bn/WFTYwEW/bpKD3M8VtR/zQVbavAoalC1PYyE=
+go.uber.org/atomic v1.9.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc=
+go.uber.org/goleak v1.1.11-0.20210813005559-691160354723/go.mod h1:cwTWslyiVhfpKIDGSZEM2HlOvcqm+tG4zioyIeLoqMQ=
+go.uber.org/multierr v1.6.0/go.mod h1:cdWPpRnG4AhwMwsgIHip0KRBQjJy5kYEpYjJxpXp9iU=
+go.uber.org/multierr v1.7.0 h1:zaiO/rmgFjbmCXdSYJWQcdvOCsthmdaHfr3Gm2Kx4Ec=
+go.uber.org/multierr v1.7.0/go.mod h1:7EAYxJLBy9rStEaz58O2t4Uvip6FSURkq8/ppBp95ak=
+go.uber.org/zap v1.19.1 h1:ue41HOKd1vGURxrmeKIgELGb3jPW9DMUDGtsinblHwI=
+go.uber.org/zap v1.19.1/go.mod h1:j3DNczoxDZroyBnOT1L/Q79cfUMGZxlv/9dzN7SM1rI=
+golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
+golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
+golang.org/x/lint v0.0.0-20190930215403-16217165b5de/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc=
+golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
+golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
+golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
+golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
+golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4/go.mod h1:p54w0d4576C0XHj96bSt6lcn1PtDYWL6XObtHCRCNQM=
+golang.org/x/net v0.0.0-20211112202133-69e39bad7dc2 h1:CIJ76btIcR3eFI5EgSo6k1qKw9KJexJuRLI9G7Hp5wE=
+golang.org/x/net v0.0.0-20211112202133-69e39bad7dc2/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
+golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
+golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
+golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
+golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
+golang.org/x/sys v0.0.0-20200116001909-b77594299b42/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20200523222454-059865788121 h1:rITEj+UZHYC927n8GT97eC3zrpzXdb/voyeOuVKS46o=
golang.org/x/sys v0.0.0-20200523222454-059865788121/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
+golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
+golang.org/x/sys v0.0.0-20210330210617-4fbd30eecc44/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210403161142-5e06dd20ab57 h1:F5Gozwx4I1xtr/sr/8CFbb57iKi3297KFs0QDbGN60A=
golang.org/x/sys v0.0.0-20210403161142-5e06dd20ab57/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
+golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
+golang.org/x/sys v0.0.0-20210510120138-977fb7262007/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20211003122950-b1ebd4e1001c h1:EyJTLQbOxvk8V6oDdD8ILR1BOs3nEJXThD6aqsiPNkM=
golang.org/x/sys v0.0.0-20211003122950-b1ebd4e1001c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
+golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
+golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
+golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=
+golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
+golang.org/x/text v0.3.6 h1:aRYxNxv6iGQlyVaZmk6ZgYEDa+Jg18DxebPSrd6bg1M=
+golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
+golang.org/x/time v0.0.0-20191024005414-555d28b269f0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
+golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
+golang.org/x/tools v0.0.0-20190311212946-11955173bddd/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs=
+golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
+golang.org/x/tools v0.1.5/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk=
+golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
+golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
+golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
+golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
+gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
+gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
+gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
+gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
+gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
+gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
+nhooyr.io/websocket v1.8.7 h1:usjR2uOr/zjjkVMy0lW+PPohFok7PCow5sDjLgX4P4g=
+nhooyr.io/websocket v1.8.7/go.mod h1:B70DZP8IakI65RVQ51MsWP/8jndNma26DVA/nFSCgW0=
diff --git a/main.go b/main.go
index ec82798..64b1f76 100644
--- a/main.go
+++ b/main.go
@@ -18,20 +18,17 @@
package main
import (
- "context"
- "crypto/tls"
+ "crypto/sha256"
"crypto/x509"
- "encoding/base64"
+ "encoding/pem"
"flag"
"fmt"
"io/ioutil"
"log"
- "net"
"os"
"os/signal"
"runtime"
"strconv"
- "strings"
"syscall"
"time"
@@ -46,12 +43,12 @@ func main() {
osSignals := make(chan os.Signal, 1)
signal.Notify(osSignals, os.Interrupt, os.Kill, syscall.SIGTERM)
s := <-osSignals
- log.Printf("main: exiting: signal: %v", s)
+ log.Printf("exiting: signal: %v", s)
os.Exit(0)
}()
- var bindAddr, dstAddr, auth, serverName, cca, ca, cert, key string
- var noTLS, insecureSkipVerify, isServer, tfo, vpn, genCert, showVersion bool
+ var bindAddr, dstAddr, auth, serverName, wsPath, ca, cert, key, hashCert, certHash string
+ var ws, insecureSkipVerify, isServer, noTLS, tfo, vpn, genCert, showVersion bool
var cpu, mux int
var timeout time.Duration
var timeoutFlag int
@@ -61,20 +58,24 @@ func main() {
commandLine.StringVar(&bindAddr, "b", "", "[Host:Port] bind address")
commandLine.StringVar(&dstAddr, "d", "", "[Host:Port] destination address")
commandLine.StringVar(&auth, "auth", "", "server password")
- commandLine.BoolVar(&noTLS, "no-tls", false, "disable TLS (debug only)")
+
+ commandLine.BoolVar(&ws, "ws", false, "websocket mode")
+ commandLine.StringVar(&wsPath, "ws-path", "", "websocket path")
// client only
commandLine.IntVar(&mux, "mux", 0, "enable mux")
commandLine.StringVar(&serverName, "n", "", "server name")
commandLine.StringVar(&ca, "ca", "", "PEM CA file path")
- commandLine.StringVar(&cca, "cca", "", "base64 encoded PEM CA")
+ commandLine.StringVar(&certHash, "cert-hash", "", "server certificate hash")
+
commandLine.BoolVar(&insecureSkipVerify, "no-verify", false, "client won't verify the server's certificate chain and host name")
commandLine.BoolVar(&vpn, "V", false, "DO NOT USE, this is for android vpn mode")
// server only
- commandLine.BoolVar(&isServer, "s", false, "is server")
- commandLine.StringVar(&cert, "cert", "", "[Path] PEM cert file")
- commandLine.StringVar(&key, "key", "", "[Path] PEM key file")
+ commandLine.BoolVar(&isServer, "s", false, "run as a server (without this simple-tls runs as a client)")
+ commandLine.StringVar(&cert, "cert", "", "PEM cert file")
+ commandLine.StringVar(&key, "key", "", "PEM key file")
+ commandLine.BoolVar(&noTLS, "no-tls", false, "disable server tls")
// etc
commandLine.IntVar(&timeoutFlag, "t", 300, "timeout in sec")
@@ -84,73 +85,93 @@ func main() {
// helper commands
commandLine.BoolVar(&genCert, "gen-cert", false, "[This is a helper function]: generate a certificate with dns name [-n], store it's key to [-key] and cert to [-cert], print cert in base64 format without padding characters")
commandLine.BoolVar(&showVersion, "v", false, "output version info and exit")
+ commandLine.StringVar(&hashCert, "hash-cert", "", "print the hashes for the certificate")
err := commandLine.Parse(os.Args[1:])
if err != nil {
- log.Fatalf("main: invalid arg: %v", err)
+ log.Fatalf("invalid arg: %v", err)
}
// display version
if showVersion {
println(version)
- os.Exit(0)
+ return
}
// gen cert
if genCert {
- log.Print("main: WARNING: generating PEM encoded key and cert")
+ log.Print("generating PEM encoded key and cert")
- dnsName, keyPEM, certPEM, err := core.GenerateCertificate(serverName)
+ dnsName, certOut, keyPEM, certPEM, err := core.GenerateCertificate(serverName)
if err != nil {
- log.Fatalf("main: generateCertificate: %v", err)
+ log.Fatalf("generateCertificate: %v", err)
}
// key
if len(key) == 0 {
key = dnsName + ".key"
}
- log.Printf("main: generating PEM encoded key to %s", key)
+ log.Printf("generating PEM encoded key to %s", key)
keyFile, err := os.Create(key)
if err != nil {
- log.Fatalf("main: creating key file [%s]: %v", key, err)
+ log.Fatalf("creating key file [%s]: %v", key, err)
}
defer keyFile.Close()
_, err = keyFile.Write(keyPEM)
if err != nil {
- log.Fatalf("main: writing key file [%s]: %v", key, err)
+ log.Fatalf("writing key file [%s]: %v", key, err)
}
// cert
if len(cert) == 0 {
cert = dnsName + ".cert"
}
- log.Printf("main: generating PEM encoded cert to %s", cert)
+ log.Printf("generating PEM encoded cert to %s", cert)
certFile, err := os.Create(cert)
if err != nil {
- log.Fatalf("main: creating cert file [%s]: %v", cert, err)
+ log.Fatalf("creating cert file [%s]: %v", cert, err)
}
defer certFile.Close()
_, err = certFile.Write(certPEM)
if err != nil {
- log.Fatalf("main: writing cert file [%s]: %v", cert, err)
+ log.Fatalf("writing cert file [%s]: %v", cert, err)
}
- certBase64 := base64.RawStdEncoding.EncodeToString(certPEM)
fmt.Printf("Your new cert dns name is: %s\n", dnsName)
- fmt.Print("Your new cert base64 string is:\n")
- fmt.Printf("%s\n", certBase64)
- fmt.Println("Copy this string and import it to client using -cca option")
+ fmt.Print("Your new cert hash is:\n")
+ fmt.Printf("%x\n", sha256.Sum256(certOut.RawTBSCertificate))
+ return
+ }
+
+ if len(hashCert) != 0 {
+ rawCert, err := ioutil.ReadFile(hashCert)
+ if err != nil {
+ log.Fatalf("failed to read cert file: %v", err)
+ }
+ b, _ := pem.Decode(rawCert)
+ if b.Type != "CERTIFICATE" {
+ log.Fatalf("invaild pem type [%s]", b.Type)
+ }
+
+ certs, err := x509.ParseCertificates(b.Bytes)
+ if err != nil {
+ log.Fatalf("failed to parse cert file: %v", err)
+ }
+ for _, cert := range certs {
+ h := sha256.Sum256(cert.RawTBSCertificate)
+ fmt.Printf("[%v]: %x\n", cert.Subject, h)
+ }
return
}
// overwrite args from env
sip003Args, err := core.GetSIP003Args()
if err != nil {
- log.Fatalf("main: sip003 error: %v", err)
+ log.Fatalf("sip003 error: %v", err)
}
if sip003Args != nil {
- log.Print("main: simple-tls is running as a sip003 plugin")
+ log.Print("simple-tls is running as a sip003 plugin")
var ok bool
var s string
@@ -168,20 +189,24 @@ func main() {
setStrIfNotEmpty(&dstAddr, s)
s, _ = sip003Args.SS_PLUGIN_OPTIONS["auth"]
setStrIfNotEmpty(&auth, s)
- _, ok = sip003Args.SS_PLUGIN_OPTIONS["no-tls"]
- noTLS = noTLS || ok
+
+ _, ok = sip003Args.SS_PLUGIN_OPTIONS["ws"]
+ ws = ws || ok
+ s, _ = sip003Args.SS_PLUGIN_OPTIONS["ws-path"]
+ setStrIfNotEmpty(&wsPath, s)
// client
s, _ = sip003Args.SS_PLUGIN_OPTIONS["n"]
setStrIfNotEmpty(&serverName, s)
s, _ = sip003Args.SS_PLUGIN_OPTIONS["mux"]
if err := setIntIfNotZero(&mux, s); err != nil {
- log.Fatalf("main: invalid mux value, %v", err)
+ log.Fatalf("invalid mux value, %v", err)
}
s, _ = sip003Args.SS_PLUGIN_OPTIONS["ca"]
setStrIfNotEmpty(&ca, s)
- s, _ = sip003Args.SS_PLUGIN_OPTIONS["cca"]
- setStrIfNotEmpty(&cca, s)
+ s, _ = sip003Args.SS_PLUGIN_OPTIONS["cert-hash"]
+ setStrIfNotEmpty(&certHash, s)
+
_, ok = sip003Args.SS_PLUGIN_OPTIONS["no-verify"]
insecureSkipVerify = insecureSkipVerify || ok
@@ -192,15 +217,17 @@ func main() {
setStrIfNotEmpty(&cert, s)
s, _ = sip003Args.SS_PLUGIN_OPTIONS["key"]
setStrIfNotEmpty(&key, s)
+ _, ok = sip003Args.SS_PLUGIN_OPTIONS["no-tls"]
+ noTLS = noTLS || ok
// etc
s, _ = sip003Args.SS_PLUGIN_OPTIONS["t"]
if err := setIntIfNotZero(&timeoutFlag, s); err != nil {
- log.Fatalf("main: invalid timeout value, %v", err)
+ log.Fatalf("invalid timeout value, %v", err)
}
s, _ = sip003Args.SS_PLUGIN_OPTIONS["cpu"]
if err := setIntIfNotZero(&cpu, s); err != nil {
- log.Fatalf("main: invalid cpu number, %v", err)
+ log.Fatalf("invalid cpu number, %v", err)
}
_, ok = sip003Args.SS_PLUGIN_OPTIONS["fast-open"]
tfo = tfo || ok
@@ -218,117 +245,56 @@ func main() {
runtime.GOMAXPROCS(cpu)
if len(bindAddr) == 0 {
- log.Fatal("main: bind addr is required")
+ log.Fatal("bind addr is required")
}
if len(dstAddr) == 0 {
- log.Fatal("main: destination addr is required")
+ log.Fatal("destination addr is required")
}
- log.Printf("main: simple-tls %s (go version: %s, os: %s, arch: %s)", version, runtime.Version(), runtime.GOOS, runtime.GOARCH)
+ log.Printf("simple-tls %s (go version: %s, os: %s, arch: %s)", version, runtime.Version(), runtime.GOOS, runtime.GOARCH)
if isServer {
- var certificates []tls.Certificate
- if !noTLS {
- switch {
- case len(cert) == 0 && len(key) == 0: // no cert and key
- log.Printf("main: warnning: neither -key nor -cert is specified")
-
- dnsName, keyPEM, certPEM, err := core.GenerateCertificate(serverName)
- if err != nil {
- log.Fatalf("main: generateCertificate: %v", err)
- }
- log.Printf("main: warnning: using tmp certificate %s", dnsName)
- cer, err := tls.X509KeyPair(certPEM, keyPEM)
- if err != nil {
- log.Fatalf("main: X509KeyPair: %v", err)
- }
- certificates = []tls.Certificate{cer}
- case len(cert) != 0 && len(key) != 0: // has cert and key
- cer, err := tls.LoadX509KeyPair(cert, key) //load cert
- if err != nil {
- log.Fatalf("main: LoadX509KeyPair: %v", err)
- }
- certificates = []tls.Certificate{cer}
- default:
- log.Fatal("main: server must have a X509 key pair, aka. -cert and -key")
- }
- }
-
- lc := net.ListenConfig{Control: core.GetControlFunc(&core.TcpConfig{EnableTFO: tfo})}
- l, err := lc.Listen(context.Background(), "tcp", bindAddr)
- if err != nil {
- log.Fatalf("main: net.Listen: %v", err)
- }
-
server := core.Server{
- Listener: l,
- Dst: dstAddr,
- NoTLS: noTLS,
- Auth: auth,
- Certificates: certificates,
- Timeout: timeout,
+ BindAddr: bindAddr,
+ DstAddr: dstAddr,
+ Websocket: ws,
+ WebsocketPath: wsPath,
+ Cert: cert,
+ Key: key,
+ ServerName: serverName,
+ Auth: auth,
+ TFO: tfo,
+ IdleTimeout: timeout,
}
-
- err = server.ActiveAndServe()
- if err != nil {
- log.Fatalf("main: doServer: %v", err)
+ if err := server.ActiveAndServe(); err != nil {
+ log.Fatalf("server exited: %v", err)
}
+ log.Print("server exited")
+ return
} else { // do client
- var rootCAs *x509.CertPool
- if !noTLS {
- if len(serverName) == 0 {
- serverName = strings.SplitN(dstAddr, ":", 2)[0]
- }
-
- switch {
- case len(cca) != 0:
- cca = strings.TrimRight(cca, "=")
- pem, err := base64.RawStdEncoding.DecodeString(cca)
- if err != nil {
- log.Fatalf("main: base64.RawStdEncoding.DecodeString: %v", err)
- }
-
- rootCAs = x509.NewCertPool()
- if ok := rootCAs.AppendCertsFromPEM(pem); !ok {
- log.Fatal("main: AppendCertsFromPEM failed, cca is invalid")
- }
- case len(ca) != 0:
- rootCAs = x509.NewCertPool()
- certPEMBlock, err := ioutil.ReadFile(ca)
- if err != nil {
- log.Fatalf("main: ReadFile ca [%s], %v", ca, err)
- }
- if ok := rootCAs.AppendCertsFromPEM(certPEMBlock); !ok {
- log.Fatal("main: AppendCertsFromPEM failed, ca is invalid")
- }
- }
- }
-
- lc := net.ListenConfig{}
- l, err := lc.Listen(context.Background(), "tcp", bindAddr)
- if err != nil {
- log.Fatalf("main: net.Listen: %v", err)
- }
-
client := core.Client{
- Listener: l,
- ServerAddr: dstAddr,
- NoTLS: noTLS,
+ BindAddr: bindAddr,
+ DstAddr: dstAddr,
+ Websocket: ws,
+ WebsocketPath: wsPath,
+ Mux: mux,
Auth: auth,
ServerName: serverName,
- CertPool: rootCAs,
+ CA: ca,
+ CertHash: certHash,
InsecureSkipVerify: insecureSkipVerify,
- Timeout: timeout,
+ IdleTimeout: timeout,
AndroidVPNMode: vpn,
TFO: tfo,
- Mux: mux,
}
err = client.ActiveAndServe()
if err != nil {
- log.Fatalf("main: doServer: %v", err)
+ log.Fatalf("client exited: %v", err)
}
+ log.Print("client exited")
+ return
}
}