Skip to content

Commit

Permalink
Merge pull request #178 from r4f4ss/save_private_key
Browse files Browse the repository at this point in the history
Save/load private key to file "clientKey"
  • Loading branch information
r4f4ss authored Oct 1, 2024
2 parents 80d0a8b + fffa3c7 commit ee38628
Show file tree
Hide file tree
Showing 2 changed files with 101 additions and 13 deletions.
36 changes: 34 additions & 2 deletions cmd/shisui/config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,11 @@ package main

import (
"flag"
"os"
"path/filepath"
"testing"

"github.com/ethereum/go-ethereum/crypto"
"github.com/stretchr/testify/require"
"github.com/urfave/cli/v2"
)
Expand All @@ -13,7 +16,8 @@ func TestGenConfig(t *testing.T) {
flagSet := flag.NewFlagSet("test", 0)
flagSet.String("rpc.addr", "127.0.0.11", "test")
flagSet.String("rpc.port", "8888", "test")
flagSet.String("data.dir", "./test", "test")
tmpDir := t.TempDir()
flagSet.String("data.dir", tmpDir, "test")
flagSet.Uint64("data.capacity", size, "test")
// flagSet.String("udp.addr", "172.23.50.11", "test")
flagSet.Int("udp.port", 9999, "test")
Expand All @@ -30,9 +34,37 @@ func TestGenConfig(t *testing.T) {
require.NoError(t, err)

require.Equal(t, config.DataCapacity, size)
require.Equal(t, config.DataDir, "./test")
require.Equal(t, config.DataDir, tmpDir)
require.Equal(t, config.LogLevel, 3)
// require.Equal(t, config.RpcAddr, "127.0.0.11:8888")
require.Equal(t, config.Protocol.ListenAddr, ":9999")
require.Equal(t, config.Networks, []string{"history"})
}

func TestKeyConfig(t *testing.T) {
flagSet := flag.NewFlagSet("test", 0)
tmpDir := t.TempDir()
flagSet.String("data.dir", tmpDir, "test")
pk := "a19d7a264e68004832327fca0ac46636332e0ec4b2a20a7ac942020754fcb666"
flagSet.String("private.key", "0x"+pk, "test")

command := &cli.Command{Name: "mycommand"}

ctx := cli.NewContext(nil, flagSet, nil)
ctx.Command = command

config, err := getPortalConfig(ctx)
require.NoError(t, err)

require.Equal(t, config.DataDir, tmpDir)

keyPk, err := crypto.HexToECDSA(pk)
require.Nil(t, err)
require.Equal(t, config.PrivateKey, keyPk)

fullPath := filepath.Join(config.DataDir, privateKeyFileName)
keyStored, err := os.ReadFile(fullPath)
require.Nil(t, err)
keyEnc := string(keyStored)
require.Equal(t, keyEnc, pk)
}
78 changes: 67 additions & 11 deletions cmd/shisui/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,13 @@ package main
import (
"crypto/ecdsa"
"database/sql"
"encoding/hex"
"fmt"
"net"
"net/http"
"os/signal"
"path"
"path/filepath"
"slices"
"strings"
"syscall"
Expand Down Expand Up @@ -37,6 +39,10 @@ import (
"github.com/urfave/cli/v2"
)

const (
privateKeyFileName = "clientKey"
)

type Config struct {
Protocol *discover.PortalProtocolConfig
PrivateKey *ecdsa.PrivateKey
Expand Down Expand Up @@ -88,13 +94,13 @@ func main() {
}

func shisui(ctx *cli.Context) error {
setDefaultLogger(ctx.Int(utils.PortalLogLevelFlag.Name))

config, err := getPortalConfig(ctx)
if err != nil {
return nil
}

setDefaultLogger(*config)

clientChan := make(chan *Client, 1)
go handlerInterrupt(clientChan)

Expand All @@ -110,9 +116,9 @@ func shisui(ctx *cli.Context) error {
return startPortalRpcServer(*config, conn, config.RpcAddr, clientChan)
}

func setDefaultLogger(config Config) {
func setDefaultLogger(logLevel int) {
glogger := log.NewGlogHandler(log.NewTerminalHandler(os.Stderr, true))
slogVerbosity := log.FromLegacyLevel(config.LogLevel)
slogVerbosity := log.FromLegacyLevel(logLevel)
glogger.Verbosity(slogVerbosity)
defaultLogger := log.NewLogger(glogger)
log.SetDefault(defaultLogger)
Expand Down Expand Up @@ -366,10 +372,6 @@ func getPortalConfig(ctx *cli.Context) (*Config, error) {
config := &Config{
Protocol: discover.DefaultPortalProtocolConfig(),
}
err := setPrivateKey(ctx, config)
if err != nil {
return config, err
}

httpAddr := ctx.String(utils.PortalRPCListenAddrFlag.Name)
httpPort := ctx.String(utils.PortalRPCPortFlag.Name)
Expand All @@ -384,6 +386,11 @@ func getPortalConfig(ctx *cli.Context) (*Config, error) {
config.Protocol.ListenAddr = port
}

err := setPrivateKey(ctx, config)
if err != nil {
return config, err
}

natString := ctx.String(utils.PortalNATFlag.Name)
if natString != "" {
natInterface, err := nat.Parse(natString)
Expand Down Expand Up @@ -412,15 +419,64 @@ func setPrivateKey(ctx *cli.Context, config *Config) error {
return err
}
} else {
privateKey, err = crypto.GenerateKey()
if err != nil {
return err
if _, err := os.Stat(filepath.Join(config.DataDir, privateKeyFileName)); err == nil {
log.Info("Loading private key from file", "datadir", config.DataDir, "file", privateKeyFileName)
privateKey, err = readPrivateKey(config, privateKeyFileName)
if err != nil {
return err
}
} else {
log.Info("Creating new private key")
privateKey, err = crypto.GenerateKey()
if err != nil {
return err
}
}
}

config.PrivateKey = privateKey
err = writePrivateKey(privateKey, config, privateKeyFileName)
if err != nil {
return err
}
return nil
}

func writePrivateKey(privateKey *ecdsa.PrivateKey, config *Config, fileName string) error {
keyEnc := hex.EncodeToString(crypto.FromECDSA(privateKey))

fullPath := filepath.Join(config.DataDir, fileName)
file, err := os.OpenFile(fullPath, os.O_CREATE|os.O_WRONLY, 0600)
if err != nil {
return err
}
defer file.Close()

_, err = file.WriteString(keyEnc)
if err != nil {
return err
}

return nil
}

func readPrivateKey(config *Config, fileName string) (*ecdsa.PrivateKey, error) {
fullPath := filepath.Join(config.DataDir, fileName)

keyBytes, err := os.ReadFile(fullPath)
if err != nil {
return nil, err
}

keyEnc := string(keyBytes)
key, err := crypto.HexToECDSA(keyEnc)
if err != nil {
return nil, err
}

return key, nil
}

// setPortalBootstrapNodes creates a list of bootstrap nodes from the command line
// flags, reverting to pre-configured ones if none have been specified.
func setPortalBootstrapNodes(ctx *cli.Context, config *Config) {
Expand Down

0 comments on commit ee38628

Please sign in to comment.