From e5b9cc3f5b14df56dbfb98fadbdfa3a756e96527 Mon Sep 17 00:00:00 2001 From: liminggui Date: Mon, 12 Dec 2022 09:48:04 +0800 Subject: [PATCH] fix stream bug --- proto/socks5.go | 298 ++++++++++++++++++++++--------------------- proto/stats/stats.go | 94 +++++++------- 2 files changed, 198 insertions(+), 194 deletions(-) diff --git a/proto/socks5.go b/proto/socks5.go index 597310c..4dc2858 100644 --- a/proto/socks5.go +++ b/proto/socks5.go @@ -1,147 +1,151 @@ -package proto - -import ( - "bytes" - "encoding/binary" - "errors" - "fmt" - "io" - "log" - "net" - "strconv" - - "github.com/keminar/anyproxy/utils/trace" -) - -type socks5Stream struct { - req *Request -} - -func newSocks5Stream(req *Request) *socks5Stream { - c := &socks5Stream{ - req: req, - } - return c -} - -func (that *socks5Stream) validHead() bool { - if that.req.reader.Buffered() < 2 { - return false - } - - var buffer [1024]byte - n, err := that.req.reader.Read(buffer[:]) - if err != nil { - return false - } - tmpBuf := buffer[:n] - return len(tmpBuf) >= 2 && tmpBuf[0] == 0x05 -} - -func (that *socks5Stream) readRequest(from string) (canProxy bool, err error) { - if err = that.ParseHeader(); err != nil { - return false, err - } - return true, nil -} - -func (that *socks5Stream) response() error { - tunnel := newTunnel(that.req) - if ip, ok := tunnel.isAllowed(); !ok { - return errors.New(ip + " is not allowed") - } - - var err error - // 发送socks5应答 - _, err = that.req.conn.Write([]byte{0x05, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}) - if err != nil { - log.Println(trace.ID(that.req.ID), "write err", err.Error()) - return err - } - - that.showIP() - err = tunnel.handshake(protoTCP, that.req.DstName, that.req.DstIP, that.req.DstPort) - if err != nil { - log.Println(trace.ID(that.req.ID), "handshake err", err.Error()) - return err - } - - tunnel.transfer(-1) - return nil -} - -func (that *socks5Stream) showIP() { - if that.req.DstName != "" { - log.Println(trace.ID(that.req.ID), fmt.Sprintf("%s %s -> %s:%d", "Socks5", that.req.conn.RemoteAddr().String(), that.req.DstName, that.req.DstPort)) - } else { - log.Println(trace.ID(that.req.ID), fmt.Sprintf("%s %s -> %s:%d", "Socks5", that.req.conn.RemoteAddr().String(), that.req.DstIP, that.req.DstPort)) - } -} - -// parsing socks5 header, and return address and parsing error -func (that *socks5Stream) ParseHeader() error { - // response to socks5 client - // see rfc 1982 for more details (https://tools.ietf.org/html/rfc1928) - n, err := that.req.conn.Write([]byte{0x05, 0x00}) // version and no authentication required - if err != nil { - return err - } - - // step2: process client Requests and does Reply - /** - +----+-----+-------+------+----------+----------+ - |VER | CMD | RSV | ATYP | DST.ADDR | DST.PORT | - +----+-----+-------+------+----------+----------+ - | 1 | 1 | X'00' | 1 | Variable | 2 | - +----+-----+-------+------+----------+----------+ - */ - var buffer [1024]byte - n, err = that.req.reader.Read(buffer[:]) - if err != nil { - return err - } - if n < 6 { - return errors.New("not a socks protocol") - } - - switch buffer[3] { - case 0x01: - // ipv4 address - ipv4 := make([]byte, 4) - if _, err := io.ReadAtLeast(bytes.NewReader(buffer[4:]), ipv4, len(ipv4)); err != nil { - return err - } - //fmt.Println(1) - that.req.DstIP = net.IP(ipv4).String() - case 0x04: - // ipv6 - ipv6 := make([]byte, 16) - if _, err := io.ReadAtLeast(bytes.NewReader(buffer[4:]), ipv6, len(ipv6)); err != nil { - return err - } - that.req.DstIP = net.IP(ipv6).String() - case 0x03: - // domain - addrLen := int(buffer[4]) - domain := make([]byte, addrLen) - if _, err := io.ReadAtLeast(bytes.NewReader(buffer[5:]), domain, addrLen); err != nil { - return err - } - //fmt.Println(2) - that.req.DstName = string(domain) - } - - port := make([]byte, 2) - err = binary.Read(bytes.NewReader(buffer[n-2:n]), binary.BigEndian, &port) - if err != nil { - return err - } - - portStr := strconv.Itoa((int(port[0]) << 8) | int(port[1])) - c, err := strconv.ParseUint(portStr, 0, 16) - if err != nil { - return err - } - that.req.DstPort = uint16(c) - return nil -} +package proto + +import ( + "bytes" + "encoding/binary" + "errors" + "fmt" + "io" + "log" + "net" + "strconv" + + "github.com/keminar/anyproxy/utils/trace" +) + +type socks5Stream struct { + req *Request +} + +func newSocks5Stream(req *Request) *socks5Stream { + c := &socks5Stream{ + req: req, + } + return c +} + +func (that *socks5Stream) validHead() bool { + if that.req.reader.Buffered() < 2 { + return false + } + + tmpBuf, err := that.req.reader.Peek(2) + if err != nil { + return false + } + + isSocks5 := len(tmpBuf) >= 2 && tmpBuf[0] == 0x05 + if isSocks5 { + // 如果是SOCKS5则把已读信息从缓存区释放掉 + that.req.reader.UnreadBuf(-1) + } + return isSocks5 +} + +func (that *socks5Stream) readRequest(from string) (canProxy bool, err error) { + if err = that.ParseHeader(); err != nil { + return false, err + } + return true, nil +} + +func (that *socks5Stream) response() error { + tunnel := newTunnel(that.req) + if ip, ok := tunnel.isAllowed(); !ok { + return errors.New(ip + " is not allowed") + } + + var err error + // 发送socks5应答 + _, err = that.req.conn.Write([]byte{0x05, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}) + if err != nil { + log.Println(trace.ID(that.req.ID), "write err", err.Error()) + return err + } + + that.showIP() + err = tunnel.handshake(protoTCP, that.req.DstName, that.req.DstIP, that.req.DstPort) + if err != nil { + log.Println(trace.ID(that.req.ID), "handshake err", err.Error()) + return err + } + + tunnel.transfer(-1) + return nil +} + +func (that *socks5Stream) showIP() { + if that.req.DstName != "" { + log.Println(trace.ID(that.req.ID), fmt.Sprintf("%s %s -> %s:%d", "Socks5", that.req.conn.RemoteAddr().String(), that.req.DstName, that.req.DstPort)) + } else { + log.Println(trace.ID(that.req.ID), fmt.Sprintf("%s %s -> %s:%d", "Socks5", that.req.conn.RemoteAddr().String(), that.req.DstIP, that.req.DstPort)) + } +} + +// parsing socks5 header, and return address and parsing error +func (that *socks5Stream) ParseHeader() error { + // response to socks5 client + // see rfc 1982 for more details (https://tools.ietf.org/html/rfc1928) + n, err := that.req.conn.Write([]byte{0x05, 0x00}) // version and no authentication required + if err != nil { + return err + } + + // step2: process client Requests and does Reply + /** + +----+-----+-------+------+----------+----------+ + |VER | CMD | RSV | ATYP | DST.ADDR | DST.PORT | + +----+-----+-------+------+----------+----------+ + | 1 | 1 | X'00' | 1 | Variable | 2 | + +----+-----+-------+------+----------+----------+ + */ + var buffer [1024]byte + n, err = that.req.reader.Read(buffer[:]) + if err != nil { + return err + } + if n < 6 { + return errors.New("not a socks protocol") + } + + switch buffer[3] { + case 0x01: + // ipv4 address + ipv4 := make([]byte, 4) + if _, err := io.ReadAtLeast(bytes.NewReader(buffer[4:]), ipv4, len(ipv4)); err != nil { + return err + } + //fmt.Println(1) + that.req.DstIP = net.IP(ipv4).String() + case 0x04: + // ipv6 + ipv6 := make([]byte, 16) + if _, err := io.ReadAtLeast(bytes.NewReader(buffer[4:]), ipv6, len(ipv6)); err != nil { + return err + } + that.req.DstIP = net.IP(ipv6).String() + case 0x03: + // domain + addrLen := int(buffer[4]) + domain := make([]byte, addrLen) + if _, err := io.ReadAtLeast(bytes.NewReader(buffer[5:]), domain, addrLen); err != nil { + return err + } + //fmt.Println(2) + that.req.DstName = string(domain) + } + + port := make([]byte, 2) + err = binary.Read(bytes.NewReader(buffer[n-2:n]), binary.BigEndian, &port) + if err != nil { + return err + } + + portStr := strconv.Itoa((int(port[0]) << 8) | int(port[1])) + c, err := strconv.ParseUint(portStr, 0, 16) + if err != nil { + return err + } + that.req.DstPort = uint16(c) + return nil +} diff --git a/proto/stats/stats.go b/proto/stats/stats.go index 5516581..eaefe36 100644 --- a/proto/stats/stats.go +++ b/proto/stats/stats.go @@ -1,47 +1,47 @@ -package stats - -import ( - "log" - "sync" - "time" -) - -type Manager struct { - access sync.RWMutex - counters map[string]*Counter -} - -func NewManager() *Manager { - m := &Manager{ - counters: make(map[string]*Counter), - } - return m -} - -func (m *Manager) RegisterCounter(name string) *Counter { - m.access.Lock() - defer m.access.Unlock() - - if _, found := m.counters[name]; found { - m.counters[name].active = time.Now().Unix() - return m.counters[name] - } - c := new(Counter) - c.name = name - m.counters[name] = c - return c -} - -func (m *Manager) UnregisterCounter() { - m.access.Lock() - defer m.access.Unlock() - - now := time.Now().Unix() - - for _, v := range m.counters { - if now-v.active > 300 { - delete(m.counters, v.name) - } - } - log.Println("stats links:", len(m.counters)) -} +package stats + +import ( + "log" + "sync" + "time" +) + +type Manager struct { + access sync.RWMutex + counters map[string]*Counter +} + +func NewManager() *Manager { + m := &Manager{ + counters: make(map[string]*Counter), + } + return m +} + +func (m *Manager) RegisterCounter(name string) *Counter { + m.access.Lock() + defer m.access.Unlock() + + if _, found := m.counters[name]; found { + m.counters[name].active = time.Now().Unix() + return m.counters[name] + } + c := new(Counter) + c.name = name + m.counters[name] = c + return c +} + +func (m *Manager) UnregisterCounter() { + m.access.Lock() + defer m.access.Unlock() + + now := time.Now().Unix() + + for _, v := range m.counters { + if now-v.active > 300 { + delete(m.counters, v.name) + } + } + log.Println("stats links:", len(m.counters)) +}