Skip to content

Commit

Permalink
fix tunnel bug
Browse files Browse the repository at this point in the history
  • Loading branch information
keminar committed Dec 11, 2022
1 parent fadb8a0 commit 13e4f4f
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 23 deletions.
2 changes: 1 addition & 1 deletion conf/router.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ hosts:
dns: remote
- name: www.baidu.com
match: equal
target: local
target: auto
- name: google
match: contain
target: deny
Expand Down
5 changes: 2 additions & 3 deletions proto/tcpcopy.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,13 +35,12 @@ func (that *tcpCopy) response() error {
that.req.DstIP = conf.RouterConfig.TcpCopy.IP
that.req.DstPort = conf.RouterConfig.TcpCopy.Port

network, connAddr := tunnel.buildAddress("", that.req.DstIP, that.req.DstPort)
network, connAddr := tunnel.buildAddress("", that.req.DstIP, that.req.DstPort, true)
if connAddr == "" {
err = errors.New("target address is empty")
return err
}
tunnel.registerCounter("", that.req.DstIP, that.req.DstPort)
err = tunnel.dail(network, connAddr)
err = tunnel.dail(network, connAddr, 0)
if err != nil {
log.Println(trace.ID(that.req.ID), "dail err", err.Error())
return err
Expand Down
39 changes: 20 additions & 19 deletions proto/tunnel.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,6 @@ func init() {
type tunnel struct {
req *Request
conn *net.TCPConn // 后端服务
connAddr string // 后端地址
curState int

inboundIP string // 来源IP
Expand Down Expand Up @@ -227,19 +226,19 @@ func (s *tunnel) logCopyErr(name string, err error) {
}

// dail tcp连接
func (s *tunnel) dail(network, connAddr string) error {
if s.connAddr == connAddr && s.conn != nil {
return nil
}
func (s *tunnel) dail(network, connAddr string, second int64) error {
if config.DebugLevel >= config.LevelLong {
log.Printf("%s create new connection to server %s\n", trace.ID(s.req.ID), connAddr)
}

connTimeout := time.Duration(5) * time.Second
if second > 0 {
connTimeout = time.Duration(second) * time.Second
}
conn, err := net.DialTimeout(network, connAddr, connTimeout)
if err != nil {
return err
}
s.connAddr = connAddr
s.conn = conn.(*net.TCPConn)
return nil
}
Expand All @@ -264,7 +263,7 @@ func (s *tunnel) registerCounter(dstName, dstIP string, dstPort uint16) {
}

// 连接地址优先使用IP
func (s *tunnel) buildAddress(dstName, dstIP string, dstPort uint16) (network string, connAddr string) {
func (s *tunnel) buildAddress(dstName, dstIP string, dstPort uint16, addCounter bool) (network string, connAddr string) {
network = "tcp"
if dstIP != "" {
if strings.Contains(dstIP, ":") {
Expand All @@ -276,6 +275,10 @@ func (s *tunnel) buildAddress(dstName, dstIP string, dstPort uint16) (network st
} else if dstName != "" {
connAddr = fmt.Sprintf("%s:%d", dstName, dstPort)
}

if addCounter && connAddr != "" {
s.registerCounter(dstName, dstIP, dstPort)
}
return
}

Expand Down Expand Up @@ -388,11 +391,11 @@ func (s *tunnel) handshake(proto string, dstName, dstIP string, dstPort uint16)
if state != cache.StateFail {
//local dial成功则返回,走本地网络
//auto 只能优化ip ping 不通的情况,能dail通访问不了的需要手动remote
//如果最终连的地址是相同的,也会复用
network, connAddr := s.buildAddress(dstName, dstIP, dstPort)
network, connAddr := s.buildAddress(dstName, dstIP, dstPort, true)
if connAddr != "" {
err = s.dail(network, connAddr)
err = s.dail(network, connAddr, 1)
if err == nil {
log.Println(trace.ID(s.req.ID), fmt.Sprintf("auto to %s", connAddr))
s.curState = stateNew
return
}
Expand All @@ -413,17 +416,16 @@ func (s *tunnel) handshake(proto string, dstName, dstIP string, dstPort uint16)
if dstName == "" {
dstName = dstIP
}
targetNet, targetAddr = s.buildAddress(dstName, "", dstPort)
targetNet, targetAddr = s.buildAddress(dstName, "", dstPort, false)
} else {
targetNet, targetAddr = s.buildAddress("", dstIP, dstPort)
targetNet, targetAddr = s.buildAddress("", dstIP, dstPort, false)
}
if targetAddr == "" || targetAddr[0] == ':' {
err = errors.New("target host is empty")
return
}

network, connAddr := s.buildAddress(proxyServer, "", proxyPort)
s.registerCounter(proxyServer, "", proxyPort)
network, connAddr := s.buildAddress(proxyServer, "", proxyPort, true)
switch proxyScheme {
case "socks5":
log.Println(trace.ID(s.req.ID), fmt.Sprintf("PROXY %s for %s", connAddr, targetAddr))
Expand All @@ -434,7 +436,7 @@ func (s *tunnel) handshake(proto string, dstName, dstIP string, dstPort uint16)
case "http":
if proto == protoHTTP { //可避免转发到charles显示2次域名,且部分电脑请求出错
log.Println(trace.ID(s.req.ID), fmt.Sprintf("PROXY %s", connAddr))
err = s.dail(network, connAddr)
err = s.dail(network, connAddr, 0)
} else {
log.Println(trace.ID(s.req.ID), fmt.Sprintf("PROXY %s for %s", connAddr, targetAddr))
err = s.httpConnect(network, connAddr, targetAddr, false)
Expand All @@ -444,15 +446,14 @@ func (s *tunnel) handshake(proto string, dstName, dstIP string, dstPort uint16)
return
}
} else {
network, connAddr := s.buildAddress(dstName, dstIP, dstPort)
network, connAddr := s.buildAddress(dstName, dstIP, dstPort, true)
if connAddr != "" {
s.registerCounter(dstName, dstIP, dstPort)
if dstName == "" {
log.Println(trace.ID(s.req.ID), fmt.Sprintf("direct to %s", connAddr))
} else {
log.Println(trace.ID(s.req.ID), fmt.Sprintf("direct to %s for %s", connAddr, dstName))
}
err = s.dail(network, connAddr)
err = s.dail(network, connAddr, 0)
} else {
err = errors.New("dstName && dstIP is empty")
}
Expand Down Expand Up @@ -520,7 +521,7 @@ func (s *tunnel) socks5(network, connAddr string, targetNet, targetAddr string)

// http代理
func (s *tunnel) httpConnect(network, connAddr string, target string, encrypt bool) (err error) {
err = s.dail(network, connAddr)
err = s.dail(network, connAddr, 0)
if err != nil {
log.Println(trace.ID(s.req.ID), "dail err", err.Error())
return
Expand Down

0 comments on commit 13e4f4f

Please sign in to comment.