Skip to content

Commit

Permalink
fix: cache id and addr in discv5
Browse files Browse the repository at this point in the history
revert
  • Loading branch information
thinkAfCod authored and GrapeBaBa committed Oct 23, 2024
1 parent d34a772 commit cfbffc5
Show file tree
Hide file tree
Showing 4 changed files with 35 additions and 83 deletions.
60 changes: 8 additions & 52 deletions p2p/discover/portal_protocol.go
Original file line number Diff line number Diff line change
Expand Up @@ -167,10 +167,7 @@ func DefaultPortalProtocolConfig() *PortalProtocolConfig {
}

type PortalProtocol struct {
table *Table
cachedIdsLock sync.Mutex
cachedNodes map[string]*enode.Node
cachedIds map[string]enode.ID
table *Table

protocolId string
protocolName string
Expand Down Expand Up @@ -214,8 +211,6 @@ func NewPortalProtocol(config *PortalProtocolConfig, protocolId portalwire.Proto
closeCtx, cancelCloseCtx := context.WithCancel(context.Background())

protocol := &PortalProtocol{
cachedNodes: make(map[string]*enode.Node),
cachedIds: make(map[string]enode.ID),
protocolId: string(protocolId),
protocolName: protocolId.Name(),
ListenAddr: config.ListenAddr,
Expand Down Expand Up @@ -330,19 +325,11 @@ func (p *PortalProtocol) setupUDPListening() error {
func(buf []byte, addr *net.UDPAddr) (int, error) {
p.Log.Info("will send to target data", "ip", addr.IP.To4().String(), "port", addr.Port, "bufLength", len(buf))

p.cachedIdsLock.Lock()
defer p.cachedIdsLock.Unlock()
if n, ok := p.cachedNodes[addr.String()]; ok {
if n, ok := p.DiscV5.cachedAddrNode[addr.String()]; ok {
//_, err := p.DiscV5.TalkRequestToID(id, addr, string(portalwire.UTPNetwork), buf)
req := &v5wire.TalkRequest{Protocol: string(portalwire.Utp), Message: buf}
p.DiscV5.sendFromAnotherThreadWithNode(n, netip.AddrPortFrom(netutil.IPToAddr(addr.IP), uint16(addr.Port)), req)

return len(buf), err
} else if id, ok := p.cachedIds[addr.String()]; ok {
//_, err := p.DiscV5.TalkRequestToID(id, addr, string(portalwire.UTPNetwork), buf)
req := &v5wire.TalkRequest{Protocol: string(portalwire.Utp), Message: buf}
p.DiscV5.sendFromAnotherThread(id, netip.AddrPortFrom(netutil.IPToAddr(addr.IP), uint16(addr.Port)), req)

return len(buf), err
} else {
p.Log.Warn("not found target node info", "ip", addr.IP.To4().String(), "port", addr.Port, "bufLength", len(buf))
Expand Down Expand Up @@ -396,35 +383,6 @@ func (p *PortalProtocol) setupDiscV5AndTable() error {
return nil
}

func (p *PortalProtocol) cacheNode(node *enode.Node) {
p.cachedIdsLock.Lock()
defer p.cachedIdsLock.Unlock()
addr := &net.UDPAddr{IP: node.IP(), Port: node.UDP()}
if _, ok := p.cachedNodes[addr.String()]; !ok && node != nil {
p.cachedNodes[addr.String()] = node
}
}

func (p *PortalProtocol) cacheNodeId(id enode.ID, addr *net.UDPAddr) {
p.cachedIdsLock.Lock()
defer p.cachedIdsLock.Unlock()
if (id != enode.ID{}) {
p.cachedIds[addr.String()] = id
}
}

func (p *PortalProtocol) cacheNodeById(id enode.ID, addr *net.UDPAddr) {
go func() {
p.cacheNodeId(id, addr)
if _, ok := p.cachedNodes[addr.String()]; !ok {
n := p.ResolveNodeId(id)
if n != nil {
p.cacheNode(n)
}
}
}()
}

func (p *PortalProtocol) ping(node *enode.Node) (uint64, error) {
pong, err := p.pingInner(node)
if err != nil {
Expand Down Expand Up @@ -586,7 +544,6 @@ func (p *PortalProtocol) processOffer(target *enode.Node, resp []byte, request *
}

p.Log.Info("will process Offer", "id", target.ID(), "ip", target.IP().To4().String(), "port", target.UDP())
p.cacheNode(target)

accept := &portalwire.Accept{}
err = accept.UnmarshalSSZ(resp[1:])
Expand Down Expand Up @@ -724,7 +681,6 @@ func (p *PortalProtocol) processContent(target *enode.Node, resp []byte) (byte,
}

p.Log.Info("will process content", "id", target.ID(), "ip", target.IP().To4().String(), "port", target.UDP())
p.cacheNode(target)

switch resp[1] {
case portalwire.ContentRawSelector:
Expand Down Expand Up @@ -930,14 +886,18 @@ func (p *PortalProtocol) processPong(target *enode.Node, resp []byte) (*portalwi
}

func (p *PortalProtocol) handleUtpTalkRequest(id enode.ID, addr *net.UDPAddr, msg []byte) []byte {
p.cacheNodeById(id, addr)
if n := p.DiscV5.getNode(id); n != nil {
p.table.addInboundNode(n)
}
p.Log.Trace("receive utp data", "addr", addr, "msg-length", len(msg))
p.packetRouter.ReceiveMessage(msg, addr)
return []byte("")
}

func (p *PortalProtocol) handleTalkRequest(id enode.ID, addr *net.UDPAddr, msg []byte) []byte {
p.cacheNodeById(id, addr)
if n := p.DiscV5.getNode(id); n != nil {
p.table.addInboundNode(n)
}

msgCode := msg[0]

Expand Down Expand Up @@ -1120,8 +1080,6 @@ func (p *PortalProtocol) handleFindContent(id enode.ID, addr *net.UDPAddr, reque
return nil, err
}

p.cacheNodeById(id, addr)

if errors.Is(err, ContentNotFound) {
closestNodes := p.findNodesCloseToContent(contentId, portalFindnodesResultLimit)
for i, n := range closestNodes {
Expand Down Expand Up @@ -1316,8 +1274,6 @@ func (p *PortalProtocol) handleOffer(id enode.ID, addr *net.UDPAddr, request *po
}
}

p.cacheNodeById(id, addr)

idBuffer := make([]byte, 2)
if contentKeyBitlist.Count() != 0 {
connId := p.connIdGen.GenCid(id, false)
Expand Down
9 changes: 0 additions & 9 deletions p2p/discover/portal_protocol_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -114,15 +114,6 @@ func TestPortalWireProtocolUdp(t *testing.T) {
assert.NoError(t, err)
time.Sleep(12 * time.Second)

node1.cacheNode(node2.localNode.Node())
node1.cacheNode(node3.localNode.Node())

node2.cacheNode(node1.localNode.Node())
node2.cacheNode(node3.localNode.Node())

node3.cacheNode(node1.localNode.Node())
node3.cacheNode(node2.localNode.Node())

udpAddrStr1 := fmt.Sprintf("%s:%d", node1.localNode.Node().IP(), node1.localNode.Node().UDP())
udpAddrStr2 := fmt.Sprintf("%s:%d", node2.localNode.Node().IP(), node2.localNode.Node().UDP())

Expand Down
44 changes: 27 additions & 17 deletions p2p/discover/v5_udp.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,15 +62,17 @@ type codecV5 interface {
// UDPv5 is the implementation of protocol version 5.
type UDPv5 struct {
// static fields
conn UDPConn
tab *Table
netrestrict *netutil.Netlist
priv *ecdsa.PrivateKey
localNode *enode.LocalNode
db *enode.DB
log log.Logger
clock mclock.Clock
validSchemes enr.IdentityScheme
conn UDPConn
tab *Table
cachedIds map[enode.ID]*enode.Node
cachedAddrNode map[string]*enode.Node
netrestrict *netutil.Netlist
priv *ecdsa.PrivateKey
localNode *enode.LocalNode
db *enode.DB
log log.Logger
clock mclock.Clock
validSchemes enr.IdentityScheme

// misc buffers used during message handling
logcontext []interface{}
Expand Down Expand Up @@ -151,14 +153,16 @@ func newUDPv5(conn UDPConn, ln *enode.LocalNode, cfg Config) (*UDPv5, error) {
cfg = cfg.withDefaults()
t := &UDPv5{
// static fields
conn: newMeteredConn(conn),
localNode: ln,
db: ln.Database(),
netrestrict: cfg.NetRestrict,
priv: cfg.PrivateKey,
log: cfg.Log,
validSchemes: cfg.ValidSchemes,
clock: cfg.Clock,
conn: newMeteredConn(conn),
cachedAddrNode: make(map[string]*enode.Node),
cachedIds: make(map[enode.ID]*enode.Node),
localNode: ln,
db: ln.Database(),
netrestrict: cfg.NetRestrict,
priv: cfg.PrivateKey,
log: cfg.Log,
validSchemes: cfg.ValidSchemes,
clock: cfg.Clock,
// channels into dispatch
packetInCh: make(chan ReadPacket, 1),
readNextCh: make(chan struct{}, 1),
Expand Down Expand Up @@ -724,6 +728,10 @@ func (t *UDPv5) send(toID enode.ID, toAddr netip.AddrPort, packet v5wire.Packet,
t.log.Warn(">> "+packet.Name(), t.logcontext...)
return nonce, err
}
if c != nil && c.Node != nil {
t.cachedIds[toID] = c.Node
t.cachedAddrNode[toAddr.String()] = c.Node
}

_, err = t.conn.WriteToUDPAddrPort(enc, toAddr)
t.log.Trace(">> "+packet.Name(), t.logcontext...)
Expand Down Expand Up @@ -785,6 +793,8 @@ func (t *UDPv5) handlePacket(rawpacket []byte, fromAddr netip.AddrPort) error {
if fromNode != nil {
// Handshake succeeded, add to table.
t.tab.addInboundNode(fromNode)
t.cachedIds[fromID] = fromNode
t.cachedAddrNode[fromAddr.String()] = fromNode
}
if packet.Kind() != v5wire.WhoareyouPacket {
// WHOAREYOU logged separately to report errors.
Expand Down
5 changes: 0 additions & 5 deletions portalnetwork/history/history_network_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -140,11 +140,6 @@ func TestGetContentByKey(t *testing.T) {
contentId := historyNetwork1.portalProtocol.ToContentId(headerEntry.key)
err = historyNetwork1.portalProtocol.Put(headerEntry.key, contentId, headerEntry.value)
require.NoError(t, err)

header, err = historyNetwork1.GetBlockHeader(headerEntry.key[1:])
require.NoError(t, err)
require.NotNil(t, header)

// get content from historyNetwork1
header, err = historyNetwork2.GetBlockHeader(headerEntry.key[1:])
require.NoError(t, err)
Expand Down

0 comments on commit cfbffc5

Please sign in to comment.