diff --git a/conf/router.yaml b/conf/router.yaml index 64e0559..63c1320 100644 --- a/conf/router.yaml +++ b/conf/router.yaml @@ -50,7 +50,7 @@ hosts: dns: remote - name: www.baidu.com match: equal - target: local + target: auto - name: google match: contain target: deny diff --git a/proto/tcpcopy.go b/proto/tcpcopy.go index 4d9324e..be66278 100644 --- a/proto/tcpcopy.go +++ b/proto/tcpcopy.go @@ -35,13 +35,12 @@ func (that *tcpCopy) response() error { that.req.DstIP = conf.RouterConfig.TcpCopy.IP that.req.DstPort = conf.RouterConfig.TcpCopy.Port - network, connAddr := tunnel.buildAddress("", that.req.DstIP, that.req.DstPort) + network, connAddr := tunnel.buildAddress("", that.req.DstIP, that.req.DstPort, true) if connAddr == "" { err = errors.New("target address is empty") return err } - tunnel.registerCounter("", that.req.DstIP, that.req.DstPort) - err = tunnel.dail(network, connAddr) + err = tunnel.dail(network, connAddr, 0) if err != nil { log.Println(trace.ID(that.req.ID), "dail err", err.Error()) return err diff --git a/proto/tunnel.go b/proto/tunnel.go index cfcda96..08b3d6b 100644 --- a/proto/tunnel.go +++ b/proto/tunnel.go @@ -58,7 +58,6 @@ func init() { type tunnel struct { req *Request conn *net.TCPConn // 后端服务 - connAddr string // 后端地址 curState int inboundIP string // 来源IP @@ -227,19 +226,19 @@ func (s *tunnel) logCopyErr(name string, err error) { } // dail tcp连接 -func (s *tunnel) dail(network, connAddr string) error { - if s.connAddr == connAddr && s.conn != nil { - return nil - } +func (s *tunnel) dail(network, connAddr string, second int64) error { if config.DebugLevel >= config.LevelLong { log.Printf("%s create new connection to server %s\n", trace.ID(s.req.ID), connAddr) } + connTimeout := time.Duration(5) * time.Second + if second > 0 { + connTimeout = time.Duration(second) * time.Second + } conn, err := net.DialTimeout(network, connAddr, connTimeout) if err != nil { return err } - s.connAddr = connAddr s.conn = conn.(*net.TCPConn) return nil } @@ -264,7 +263,7 @@ func (s *tunnel) registerCounter(dstName, dstIP string, dstPort uint16) { } // 连接地址优先使用IP -func (s *tunnel) buildAddress(dstName, dstIP string, dstPort uint16) (network string, connAddr string) { +func (s *tunnel) buildAddress(dstName, dstIP string, dstPort uint16, addCounter bool) (network string, connAddr string) { network = "tcp" if dstIP != "" { if strings.Contains(dstIP, ":") { @@ -276,6 +275,10 @@ func (s *tunnel) buildAddress(dstName, dstIP string, dstPort uint16) (network st } else if dstName != "" { connAddr = fmt.Sprintf("%s:%d", dstName, dstPort) } + + if addCounter && connAddr != "" { + s.registerCounter(dstName, dstIP, dstPort) + } return } @@ -388,11 +391,11 @@ func (s *tunnel) handshake(proto string, dstName, dstIP string, dstPort uint16) if state != cache.StateFail { //local dial成功则返回,走本地网络 //auto 只能优化ip ping 不通的情况,能dail通访问不了的需要手动remote - //如果最终连的地址是相同的,也会复用 - network, connAddr := s.buildAddress(dstName, dstIP, dstPort) + network, connAddr := s.buildAddress(dstName, dstIP, dstPort, true) if connAddr != "" { - err = s.dail(network, connAddr) + err = s.dail(network, connAddr, 1) if err == nil { + log.Println(trace.ID(s.req.ID), fmt.Sprintf("auto to %s", connAddr)) s.curState = stateNew return } @@ -413,17 +416,16 @@ func (s *tunnel) handshake(proto string, dstName, dstIP string, dstPort uint16) if dstName == "" { dstName = dstIP } - targetNet, targetAddr = s.buildAddress(dstName, "", dstPort) + targetNet, targetAddr = s.buildAddress(dstName, "", dstPort, false) } else { - targetNet, targetAddr = s.buildAddress("", dstIP, dstPort) + targetNet, targetAddr = s.buildAddress("", dstIP, dstPort, false) } if targetAddr == "" || targetAddr[0] == ':' { err = errors.New("target host is empty") return } - network, connAddr := s.buildAddress(proxyServer, "", proxyPort) - s.registerCounter(proxyServer, "", proxyPort) + network, connAddr := s.buildAddress(proxyServer, "", proxyPort, true) switch proxyScheme { case "socks5": log.Println(trace.ID(s.req.ID), fmt.Sprintf("PROXY %s for %s", connAddr, targetAddr)) @@ -434,7 +436,7 @@ func (s *tunnel) handshake(proto string, dstName, dstIP string, dstPort uint16) case "http": if proto == protoHTTP { //可避免转发到charles显示2次域名,且部分电脑请求出错 log.Println(trace.ID(s.req.ID), fmt.Sprintf("PROXY %s", connAddr)) - err = s.dail(network, connAddr) + err = s.dail(network, connAddr, 0) } else { log.Println(trace.ID(s.req.ID), fmt.Sprintf("PROXY %s for %s", connAddr, targetAddr)) err = s.httpConnect(network, connAddr, targetAddr, false) @@ -444,15 +446,14 @@ func (s *tunnel) handshake(proto string, dstName, dstIP string, dstPort uint16) return } } else { - network, connAddr := s.buildAddress(dstName, dstIP, dstPort) + network, connAddr := s.buildAddress(dstName, dstIP, dstPort, true) if connAddr != "" { - s.registerCounter(dstName, dstIP, dstPort) if dstName == "" { log.Println(trace.ID(s.req.ID), fmt.Sprintf("direct to %s", connAddr)) } else { log.Println(trace.ID(s.req.ID), fmt.Sprintf("direct to %s for %s", connAddr, dstName)) } - err = s.dail(network, connAddr) + err = s.dail(network, connAddr, 0) } else { err = errors.New("dstName && dstIP is empty") } @@ -520,7 +521,7 @@ func (s *tunnel) socks5(network, connAddr string, targetNet, targetAddr string) // http代理 func (s *tunnel) httpConnect(network, connAddr string, target string, encrypt bool) (err error) { - err = s.dail(network, connAddr) + err = s.dail(network, connAddr, 0) if err != nil { log.Println(trace.ID(s.req.ID), "dail err", err.Error()) return