diff --git a/cmd/shisui/main.go b/cmd/shisui/main.go index 16243775bd1c..e4d3065dacd3 100644 --- a/cmd/shisui/main.go +++ b/cmd/shisui/main.go @@ -6,9 +6,11 @@ import ( "fmt" "net" "net/http" + "os/signal" "path" "slices" "strings" + "syscall" "os" @@ -45,6 +47,14 @@ type Config struct { Networks []string } +type Client struct { + DiscV5API *discover.DiscV5API + HistoryNetwork *history.HistoryNetwork + BeaconNetwork *beacon.BeaconNetwork + StateNetwork *state.StateNetwork + Server *http.Server +} + var app = flags.NewApp("the go-portal-network command line interface") var ( @@ -85,6 +95,9 @@ func shisui(ctx *cli.Context) error { setDefaultLogger(*config) + clientChan := make(chan *Client, 1) + go handlerInterrupt(clientChan) + addr, err := net.ResolveUDPAddr("udp", config.Protocol.ListenAddr) if err != nil { return err @@ -94,7 +107,7 @@ func shisui(ctx *cli.Context) error { return err } - return startPortalRpcServer(*config, conn, config.RpcAddr) + return startPortalRpcServer(*config, conn, config.RpcAddr, clientChan) } func setDefaultLogger(config Config) { @@ -105,7 +118,51 @@ func setDefaultLogger(config Config) { log.SetDefault(defaultLogger) } -func startPortalRpcServer(config Config, conn discover.UDPConn, addr string) error { +func handlerInterrupt(clientChan <-chan *Client) { + interrupt := make(chan os.Signal, 1) + signal.Notify(interrupt, syscall.SIGINT, syscall.SIGTERM) + defer signal.Stop(interrupt) + + <-interrupt + log.Warn("Closing Shisui gracefully (type CTRL-C again to force quit)") + + go func() { + if len(clientChan) == 0 { + log.Warn("Waiting for the client to start...") + } + c := <-clientChan + c.closePortalRpcServer() + }() + + <-interrupt + os.Exit(1) +} + +func (cli *Client) closePortalRpcServer() { + if cli.HistoryNetwork != nil { + log.Info("Closing history network...") + cli.HistoryNetwork.Stop() + } + if cli.BeaconNetwork != nil { + log.Info("Closing beacon network...") + cli.BeaconNetwork.Stop() + } + if cli.StateNetwork != nil { + log.Info("Closing state network...") + cli.StateNetwork.Stop() + } + log.Info("Closing Database...") + cli.DiscV5API.DiscV5.LocalNode().Database().Close() + log.Info("Closing UDPv5 protocol...") + cli.DiscV5API.DiscV5.Close() + log.Info("Closing servers...") + cli.Server.Close() + os.Exit(1) +} + +func startPortalRpcServer(config Config, conn discover.UDPConn, addr string, clientChan chan<- *Client) error { + client := &Client{} + discV5, localNode, err := initDiscV5(config, conn) if err != nil { return err @@ -117,6 +174,7 @@ func startPortalRpcServer(config Config, conn discover.UDPConn, addr string) err if err != nil { return err } + client.DiscV5API = discV5API api := &web3.API{} err = server.RegisterName("web3", api) @@ -130,20 +188,25 @@ func startPortalRpcServer(config Config, conn discover.UDPConn, addr string) err if err != nil { return err } + client.HistoryNetwork = historyNetwork } + var beaconNetwork *beacon.BeaconNetwork if slices.Contains(config.Networks, portalwire.Beacon.Name()) { - err = initBeacon(config, server, conn, localNode, discV5) + beaconNetwork, err = initBeacon(config, server, conn, localNode, discV5) if err != nil { return err } + client.BeaconNetwork = beaconNetwork } + var stateNetwork *state.StateNetwork if slices.Contains(config.Networks, portalwire.State.Name()) { - err = initState(config, server, conn, localNode, discV5) + stateNetwork, err = initState(config, server, conn, localNode, discV5) if err != nil { return err } + client.StateNetwork = stateNetwork } ethapi := ðapi.API{ @@ -160,6 +223,9 @@ func startPortalRpcServer(config Config, conn discover.UDPConn, addr string) err Addr: addr, Handler: server, } + client.Server = httpServer + + clientChan <- client return httpServer.ListenAndServe() } @@ -223,15 +289,15 @@ func initHistory(config Config, server *rpc.Server, conn discover.UDPConn, local return historyNetwork, historyNetwork.Start() } -func initBeacon(config Config, server *rpc.Server, conn discover.UDPConn, localNode *enode.LocalNode, discV5 *discover.UDPv5) error { +func initBeacon(config Config, server *rpc.Server, conn discover.UDPConn, localNode *enode.LocalNode, discV5 *discover.UDPv5) (*beacon.BeaconNetwork, error) { dbPath := path.Join(config.DataDir, "beacon") err := os.MkdirAll(dbPath, 0755) if err != nil { - return err + return nil, err } sqlDb, err := sql.Open("sqlite3", path.Join(dbPath, "beacon.sqlite")) if err != nil { - return err + return nil, err } contentStorage, err := beacon.NewBeaconStorage(storage.PortalStorageConfig{ @@ -241,32 +307,32 @@ func initBeacon(config Config, server *rpc.Server, conn discover.UDPConn, localN Spec: configs.Mainnet, }) if err != nil { - return err + return nil, err } contentQueue := make(chan *discover.ContentElement, 50) protocol, err := discover.NewPortalProtocol(config.Protocol, portalwire.Beacon, config.PrivateKey, conn, localNode, discV5, contentStorage, contentQueue) if err != nil { - return err + return nil, err } portalApi := discover.NewPortalAPI(protocol) beaconAPI := beacon.NewBeaconNetworkAPI(portalApi) err = server.RegisterName("portal", beaconAPI) if err != nil { - return err + return nil, err } beaconNetwork := beacon.NewBeaconNetwork(protocol) - return beaconNetwork.Start() + return beaconNetwork, beaconNetwork.Start() } -func initState(config Config, server *rpc.Server, conn discover.UDPConn, localNode *enode.LocalNode, discV5 *discover.UDPv5) error { +func initState(config Config, server *rpc.Server, conn discover.UDPConn, localNode *enode.LocalNode, discV5 *discover.UDPv5) (*state.StateNetwork, error) { networkName := portalwire.State.Name() db, err := history.NewDB(config.DataDir, networkName) if err != nil { - return err + return nil, err } contentStorage, err := history.NewHistoryStorage(storage.PortalStorageConfig{ StorageCapacityMB: config.DataCapacity, @@ -275,24 +341,24 @@ func initState(config Config, server *rpc.Server, conn discover.UDPConn, localNo NetworkName: networkName, }) if err != nil { - return err + return nil, err } contentQueue := make(chan *discover.ContentElement, 50) protocol, err := discover.NewPortalProtocol(config.Protocol, portalwire.State, config.PrivateKey, conn, localNode, discV5, contentStorage, contentQueue) if err != nil { - return err + return nil, err } api := discover.NewPortalAPI(protocol) stateNetworkAPI := state.NewStateNetworkAPI(api) err = server.RegisterName("portal", stateNetworkAPI) if err != nil { - return err + return nil, err } client := rpc.DialInProc(server) historyNetwork := state.NewStateNetwork(protocol, client) - return historyNetwork.Start() + return historyNetwork, historyNetwork.Start() } func getPortalConfig(ctx *cli.Context) (*Config, error) {