From 3fcb46e431b98cf7018d0b96f8ada210197aab8c Mon Sep 17 00:00:00 2001 From: Aadhar Sachdeva Date: Sun, 14 Feb 2021 13:35:09 +0530 Subject: [PATCH 1/6] Sane state, refactor pending --- cmd/dkvctl/main.go | 32 ++++- cmd/dkvsrv/main.go | 216 ++++++++++++++++++++++++++--- go.mod | 4 +- go.sum | 7 + internal/master/ds_service_test.go | 4 +- internal/master/ss_service_test.go | 2 +- internal/slave/service_test.go | 2 +- internal/utils.go | 110 +++++++++++++++ pkg/ctl/client.go | 10 +- 9 files changed, 359 insertions(+), 28 deletions(-) create mode 100644 internal/utils.go diff --git a/cmd/dkvctl/main.go b/cmd/dkvctl/main.go index e537e9c0..02d0189c 100644 --- a/cmd/dkvctl/main.go +++ b/cmd/dkvctl/main.go @@ -3,6 +3,7 @@ package main import ( "flag" "fmt" + utils "github.com/flipkart-incubator/dkv/internal" "os" "strconv" "strings" @@ -186,10 +187,20 @@ func (c *cmd) replica(client *ctl.DKVClient, args ...string) { } } -var dkvAddr string +var ( + dkvAddr string + dkvMode string + certPath string + keyPath string + caCertPath string +) func init() { flag.StringVar(&dkvAddr, "dkvAddr", "127.0.0.1:8080", ": - DKV server address") + flag.StringVar(&dkvMode, "dkvMode", "insecure", "Security mode for the DKV server node - insecure|autoTLS|serverTLS|mutualTLS") + flag.StringVar(&certPath, "certPath", "", "Path for certificate file of this node") + flag.StringVar(&keyPath, "keyPath", "", "Path for key file of this node") + flag.StringVar(&caCertPath, "caCertPath", "", "Path for root certificate of the chain, i.e. CA certificate") for _, c := range cmds { flag.StringVar(&c.value, c.name, c.value, c.cmdDesc) } @@ -198,13 +209,21 @@ func init() { func usage() { fmt.Printf("Usage of %s:\n", os.Args[0]) - dkvAddrFlag := flag.Lookup("dkvAddr") - fmt.Printf(" -dkvAddr %s (default: %s)\n", dkvAddrFlag.Usage, dkvAddrFlag.DefValue) + printUsage([]string {"dkvAddr", "dkvMode", "certPath", "keyPath", "caCertPath"}) for _, cmd := range cmds { cmd.usage() } } +func printUsage(flags []string) { + for _, flagName := range flags { + dkvFlag := flag.Lookup(flagName) + if dkvFlag != nil { + fmt.Printf(" -%s %s (default: %s)\n", dkvFlag.Name, dkvFlag.Usage, dkvFlag.DefValue) + } + } +} + func trimLower(str string) string { return strings.ToLower(strings.TrimSpace(str)) } @@ -216,8 +235,13 @@ func main() { } flag.Parse() + utils.ValidateDKVConfig(utils.DKVConfig{ServerMode: dkvMode, + SrvrAddr: dkvAddr, KeyPath: keyPath, CertPath: certPath, + CaCertPath: caCertPath}) fmt.Printf("Connecting to DKV service at %s...", dkvAddr) - client, err := ctl.NewInSecureDKVClient(dkvAddr) + client, err := utils.NewDKVClient(utils.DKVConfig{ServerMode: dkvMode, + SrvrAddr: dkvAddr, KeyPath: keyPath, CertPath: certPath, + CaCertPath: caCertPath}) if err != nil { fmt.Printf("\nUnable to create DKV client. Error: %v\n", err) return diff --git a/cmd/dkvsrv/main.go b/cmd/dkvsrv/main.go index f14654b7..92b2bde1 100644 --- a/cmd/dkvsrv/main.go +++ b/cmd/dkvsrv/main.go @@ -1,17 +1,29 @@ package main import ( + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rand" + "crypto/tls" + "crypto/x509" + "crypto/x509/pkix" + "encoding/pem" "flag" "fmt" + "google.golang.org/grpc/credentials" + "io/ioutil" "log" + "math/big" "net" "os" "os/signal" "path" + "path/filepath" "strings" "syscall" "time" + utils "github.com/flipkart-incubator/dkv/internal" "github.com/flipkart-incubator/dkv/internal/master" "github.com/flipkart-incubator/dkv/internal/slave" "github.com/flipkart-incubator/dkv/internal/stats" @@ -19,7 +31,6 @@ import ( "github.com/flipkart-incubator/dkv/internal/storage/badger" "github.com/flipkart-incubator/dkv/internal/storage/rocksdb" "github.com/flipkart-incubator/dkv/internal/sync" - "github.com/flipkart-incubator/dkv/pkg/ctl" "github.com/flipkart-incubator/dkv/pkg/serverpb" nexus_api "github.com/flipkart-incubator/nexus/pkg/api" nexus "github.com/flipkart-incubator/nexus/pkg/raft" @@ -31,14 +42,21 @@ import ( ) var ( - disklessMode bool - dbEngine string - dbFolder string - dbListenAddr string - dbRole string - statsdAddr string - replMasterAddr string - replPollInterval time.Duration + disklessMode bool + dbEngine string + dbFolder string + dbListenAddr string + dbRole string + statsdAddr string + replMasterAddr string + masterMode string + replPollInterval time.Duration + mode string + certPath string + keyPath string + caCertPath string + generatedCertValidity int + generatedCertDir string // Logging vars dbAccessLog string @@ -59,9 +77,16 @@ func init() { flag.StringVar(&dbRole, "dbRole", "none", "DB role of this node - none|master|slave") flag.StringVar(&statsdAddr, "statsdAddr", "", "StatsD service address in host:port format") flag.StringVar(&replMasterAddr, "replMasterAddr", "", "Service address of DKV master node for replication") + flag.StringVar(&masterMode, "masterMode", "insecure", "Service address of DKV master node for replication") flag.DurationVar(&replPollInterval, "replPollInterval", 5*time.Second, "Interval used for polling changes from master. Eg., 10s, 5ms, 2h, etc.") flag.StringVar(&dbAccessLog, "dbAccessLog", "", "File for logging DKV accesses eg., stdout, stderr, /tmp/access.log") + flag.StringVar(&mode, "mode", "insecure", "Security mode for this node - insecure|autoTLS|serverTLS|mutualTLS") + flag.StringVar(&certPath, "certPath", "", "Path for certificate file of this node") + flag.StringVar(&keyPath, "keyPath", "", "Path for key file of this node") + flag.StringVar(&caCertPath, "caCertPath", "", "Path for root certificate of the chain, i.e. CA certificate") flag.BoolVar(&verboseLogging, "verbose", false, fmt.Sprintf("Enable verbose logging.\nBy default, only warnings and errors are logged. (default %v)", verboseLogging)) + flag.IntVar(&generatedCertValidity, "generatedCertValidity", 365, "Validity(in days) for the generated certificate") + flag.StringVar(&generatedCertDir, "generatedCertDir", "/tmp/dkv-certs", "Directory to store the generate key and certificates") setDKVDefaultsForNexusDirs() } @@ -109,7 +134,11 @@ func main() { serverpb.RegisterDKVServer(grpcSrvr, dkvSvc) serverpb.RegisterDKVReplicationServer(grpcSrvr, dkvSvc) case slaveRole: - if replCli, err := ctl.NewInSecureDKVClient(replMasterAddr); err != nil { + if replCli, err := utils.NewDKVClient(utils.DKVConfig{ServerMode: masterMode, + SrvrAddr: replMasterAddr, KeyPath: keyPath, CertPath: certPath, + CaCertPath: caCertPath}); + + err != nil { panic(err) } else { defer replCli.Close() @@ -138,11 +167,38 @@ func validateFlags() { if disklessMode && (strings.ToLower(dbEngine) == "rocksdb" || strings.ToLower(dbRole) == masterRole) { log.Panicf("diskless is available only on Badger storage and for standalone and slave roles") } - if strings.ToLower(dbRole) == slaveRole && replMasterAddr == "" { + if strings.ToLower(dbRole) == slaveRole && (replMasterAddr == "" || masterMode == "") { log.Panicf("replMasterAddr must be given in slave mode") } + + if strings.ToLower(dbRole) == slaveRole { + utils.ValidateDKVConfig(utils.DKVConfig{ServerMode: masterMode, + SrvrAddr: replMasterAddr, KeyPath: keyPath, CertPath: certPath, + CaCertPath: caCertPath}) + } + + validateServerModeCompliantFlags() +} + +func validateServerModeCompliantFlags() { + utils.ValidateDKVConfig(utils.DKVConfig{ServerMode: mode, + SrvrAddr: replMasterAddr, KeyPath: keyPath, CertPath: certPath, + CaCertPath: caCertPath}) + srvrMode := utils.ToServerMode(mode) + switch srvrMode { + case utils.AutoTLS: + if generatedCertValidity <= 0 { + log.Panicf("Validaty period for the generated certificates(generatedCertValidity) should be > 0") + } + + if generatedCertDir == "" { + log.Panicf("Directory for certificate generation(generatedCertDir) must not be empty") + } + } } + + func setupAccessLogger() { accessLogger = zap.NewNop() if dbAccessLog != "" { @@ -215,14 +271,142 @@ func setupDKVLogger() { } func newGrpcServerListener() (*grpc.Server, net.Listener) { - grpcSrvr := grpc.NewServer( - grpc.StreamInterceptor(grpc_zap.StreamServerInterceptor(accessLogger)), - grpc.UnaryInterceptor(grpc_zap.UnaryServerInterceptor(accessLogger)), - ) + var grpcSrvr *grpc.Server + + if mode == utils.Insecure { + grpcSrvr = grpc.NewServer( + grpc.StreamInterceptor(grpc_zap.StreamServerInterceptor(accessLogger)), + grpc.UnaryInterceptor(grpc_zap.UnaryServerInterceptor(accessLogger)), + ) + } else { + srvrCred, err := loadTLSCredentials() + if err != nil { + dkvLogger.Sugar().Fatal("Unable to load tls credentials", err) + } + grpcSrvr = grpc.NewServer( + grpc.Creds(srvrCred), + grpc.StreamInterceptor(grpc_zap.StreamServerInterceptor(accessLogger)), + grpc.UnaryInterceptor(grpc_zap.UnaryServerInterceptor(accessLogger)), + ) + } reflection.Register(grpcSrvr) return grpcSrvr, newListener() } +func loadTLSCredentials() (credentials.TransportCredentials, error) { + var serverCert tls.Certificate + var err error + if mode == utils.AutoTLS { + err = generateSelfSignedCert() + if err != nil { + return nil, err + } + } + + serverCert, err = tls.LoadX509KeyPair(certPath, keyPath) + + if err != nil { + return nil, err + } + + var config *tls.Config + + if mode == utils.MutualTLS { + pemClientCA, err := ioutil.ReadFile(caCertPath) + if err != nil { + return nil, err + } + + certPool := x509.NewCertPool() + if !certPool.AppendCertsFromPEM(pemClientCA) { + return nil, fmt.Errorf("failed to add client CA's certificate") + } + + config = &tls.Config{ + Certificates: []tls.Certificate{serverCert}, + ClientAuth: tls.RequireAndVerifyClientCert, + ClientCAs: certPool, + } + } else { + config = &tls.Config{ + Certificates: []tls.Certificate{serverCert}, + ClientAuth: tls.NoClientCert, + } + } + return credentials.NewTLS(config), nil +} + +func generateSelfSignedCert() error { + var err error + if _, errCreate := os.Stat(generatedCertDir); os.IsNotExist(errCreate) { + err = os.MkdirAll(generatedCertDir, os.ModePerm) + } + + if err != nil { + return err + } + certPath = filepath.Join(generatedCertDir, "cert.pem") + keyPath = filepath.Join(generatedCertDir, "key.pem") + + _, errcert := os.Stat(certPath) + _, errkey := os.Stat(keyPath) + if errcert == nil && errkey == nil { + return nil + } + + serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128) + serialNumber, err := rand.Int(rand.Reader, serialNumberLimit) + if err != nil { + return err + } + + tmpl := x509.Certificate{ + SerialNumber: serialNumber, + Subject: pkix.Name{Organization: []string{"DKV"}}, + NotBefore: time.Now(), + NotAfter: time.Now().Add(time.Duration(generatedCertValidity) * (24 * time.Hour)), + + KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, + BasicConstraintsValid: true, + } + + + log.Println("automatically generate certificates", + zap.Time("certificate-validity-bound-not-after", tmpl.NotAfter)) + + priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + if err != nil { + return err + } + + derBytes, err := x509.CreateCertificate(rand.Reader, &tmpl, &tmpl, &priv.PublicKey, priv) + if err != nil { + return err + } + + certOut, err := os.Create(certPath) + if err != nil { + return err + } + pem.Encode(certOut, &pem.Block{Type: "CERTIFICATE", Bytes: derBytes}) + certOut.Close() + log.Println("created cert file", zap.String("path", certPath)) + + b, err := x509.MarshalECPrivateKey(priv) + if err != nil { + return err + } + keyOut, err := os.OpenFile(keyPath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0600) + if err != nil { + return err + } + + pem.Encode(keyOut, &pem.Block{Type: "EC PRIVATE KEY", Bytes: b}) + keyOut.Close() + return nil +} + func newListener() (lis net.Listener) { var err error if lis, err = net.Listen("tcp", dbListenAddr); err != nil { @@ -261,7 +445,7 @@ func printFlagsWithPrefix(prefixes ...string) { } func toDKVSrvrRole(role string) dkvSrvrRole { - return dkvSrvrRole(strings.TrimSpace(strings.ToLower(dbRole))) + return dkvSrvrRole(strings.TrimSpace(strings.ToLower(role))) } func (role dkvSrvrRole) printFlags() { diff --git a/go.mod b/go.mod index 52548bf5..5f59858e 100644 --- a/go.mod +++ b/go.mod @@ -16,6 +16,7 @@ require ( github.com/gogo/protobuf v1.3.1 github.com/golang/protobuf v1.3.5 github.com/golang/snappy v0.0.2 // indirect + github.com/google/go-cmp v0.5.2 // indirect github.com/grpc-ecosystem/go-grpc-middleware v1.2.0 github.com/pkg/errors v0.9.1 // indirect github.com/prometheus/client_golang v1.5.1 // indirect @@ -28,7 +29,8 @@ require ( golang.org/x/net v0.0.0-20201002202402-0a1ea396d57c // indirect golang.org/x/sys v0.0.0-20201005065044-765f4ea38db3 // indirect golang.org/x/text v0.3.2 // indirect - golang.org/x/tools v0.0.0-20200318150045-ba25ddc85566 // indirect + golang.org/x/tools v0.0.0-20210108195828-e2f9c7f1fc8e // indirect + golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 // indirect google.golang.org/grpc v1.28.0 honnef.co/go/tools v0.0.1-2020.1.3 // indirect ) diff --git a/go.sum b/go.sum index 3ac27e1e..6469952d 100644 --- a/go.sum +++ b/go.sum @@ -11,6 +11,7 @@ github.com/alecthomas/template v0.0.0-20160405071501-a0175ee3bccc/go.mod h1:LOuy github.com/alecthomas/template v0.0.0-20190718012654-fb15b899a751 h1:JYp7IbQjafoB+tBA3gMyHYHrpOtNuDiK/uB5uXxq5wM= github.com/alecthomas/template v0.0.0-20190718012654-fb15b899a751/go.mod h1:LOuyumcjzFXgccqObfd/Ljyb9UuFJ6TxHnclSeseNhc= github.com/alecthomas/units v0.0.0-20151022065526-2efee857e7cf/go.mod h1:ybxpYRFXyAe+OPACYpWeL0wqObRcbAqCMya13uyzqw0= +github.com/alecthomas/units v0.0.0-20190717042225-c3de453c63f4 h1:Hs82Z41s6SdL1CELW+XaDYmOH4hkBN4/N9og/AsOv7E= github.com/alecthomas/units v0.0.0-20190717042225-c3de453c63f4/go.mod h1:ybxpYRFXyAe+OPACYpWeL0wqObRcbAqCMya13uyzqw0= github.com/armon/consul-api v0.0.0-20180202201655-eb2c6b5be1b6/go.mod h1:grANhF5doyWs3UAsr3K4I6qtAmlQcZDesFNEHPZAzj8= github.com/beorn7/perks v0.0.0-20180321164747-3a771d992973/go.mod h1:Dwedo/Wpr24TaqPxmxbtue+5NUziq4I4S80YR8gNf3Q= @@ -88,6 +89,7 @@ github.com/gogo/protobuf v1.1.1/go.mod h1:r8qH/GZQm5c6nD/R0oafs1akxWv10x8SbQlK7a github.com/gogo/protobuf v1.2.1/go.mod h1:hp+jE20tsWTFYpLwKvXlhS1hjn+gTNwPg2I6zVXpSg4= github.com/gogo/protobuf v1.3.1 h1:DqDEcV5aeaTmdFBePNpYsp3FlcVH/2ISVVM9Qf8PSls= github.com/gogo/protobuf v1.3.1/go.mod h1:SlYgWuQ5SjCEi6WLHjHCa1yvBfUnHcTbrrZtXPKa29o= +github.com/golang/crypto v0.0.0-20191011191535-87dc89f01550 h1:WmVrTfi89sdCb9gSzM2EZSq0fhOb/2rNQJmC8lL+86I= github.com/golang/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= github.com/golang/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:QiPdlFVtSpwGb41kD2htioToP9EZqdz/6YL/xZsmuzY= github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b h1:VKtxabqXZkF25pY9ekfRL6a582T4P37/31XEstQ5p58= @@ -111,6 +113,7 @@ github.com/golang/snappy v0.0.1 h1:Qgr9rKW7uDUkrbSmQeiDsGa8SjGyCOGtuasMWwvp2P4= github.com/golang/snappy v0.0.1/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= github.com/golang/snappy v0.0.2 h1:aeE13tS0IiQgFjYdoL8qN3K1N2bXXtI6Vi51/y7BpMw= github.com/golang/snappy v0.0.2/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= +github.com/golang/sync v0.0.0-20190911185100-cd5d95a43a6e h1:Wi0y4ZAnk+WvzvZ26frs641N6DpG0e5Se3whDjLs9zU= github.com/golang/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= github.com/golang/sys v0.0.0-20200317113312-5766fd39f98d h1:8zzUweJJxbKvLGxvjbyJsIimIUysmU3Xa27wwAG04s0= github.com/golang/sys v0.0.0-20200317113312-5766fd39f98d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= @@ -128,6 +131,8 @@ github.com/google/go-cmp v0.3.0/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMyw github.com/google/go-cmp v0.3.1/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= github.com/google/go-cmp v0.4.0 h1:xsAVV57WRhGj6kEIi8ReJzQlHHqcBYCElAvkovg3B/4= github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= +github.com/google/go-cmp v0.5.2 h1:X2ev0eStA3AbceY54o37/0PQ/UWqKEiiO2dKL5OPaFM= +github.com/google/go-cmp v0.5.2/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= github.com/google/renameio v0.1.0/go.mod h1:KWCgfxg9yswjAJkECMjeO8J8rahYeXnNhOm40UhjYkI= github.com/googleapis/google-cloud-go v0.26.0/go.mod h1:yJoOdPPE9UpqbamBhJvp7Ur6OUPPV4rUY3RnssPGNBA= @@ -201,6 +206,7 @@ github.com/prometheus/procfs v0.0.10/go.mod h1:7Qr8sr6344vo1JqZ6HhLceV9o3AJ1Ff+G github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4= github.com/russross/blackfriday v1.5.2/go.mod h1:JO/DiYxRf+HjHt06OyowR9PTA263kcR/rfWxYHBV53g= github.com/sirupsen/logrus v1.2.0/go.mod h1:LxeOpSwHxABJmUn/MG1IvRgCAasNZTLOkJPxbbu5VWo= +github.com/sirupsen/logrus v1.4.2 h1:SPIRibHv4MatM3XXNO2BJeFLZwZ2LvZgfQ5+UNI2im4= github.com/sirupsen/logrus v1.4.2/go.mod h1:tLMulIdttU9McNUspp0xgXVQah82FyeX6MwdIuYE2rE= github.com/smira/go-statsd v1.3.1 h1:JalGiHNdK7GqVAPpg7j0Kwp2jZrz/fCg/B4ZuNuBY2w= github.com/smira/go-statsd v1.3.1/go.mod h1:1srXJ9/pbnN04G8f4F1jUzsGOnwkPKXciyqpewGlkC4= @@ -246,6 +252,7 @@ google.golang.org/grpc v1.25.1 h1:wdKvqQk7IttEw92GoRyKG2IDrUIpgpj6H6m81yfeMW0= google.golang.org/grpc v1.25.1/go.mod h1:c3i+UQWmh7LiEpx4sFZnkU36qjEYZ0imhYfXVyQciAY= google.golang.org/grpc v1.28.0 h1:bO/TA4OxCOummhSf10siHuG7vJOiwh7SpRpFZDkOgl4= google.golang.org/grpc v1.28.0/go.mod h1:rpkK4SK4GF4Ach/+MFLZUBavHOvF2JJB5uozKKal+60= +gopkg.in/alecthomas/kingpin.v2 v2.2.6 h1:jMFz6MfLP0/4fUyZle81rXUoxOBFi19VUFKVDOQfozc= gopkg.in/alecthomas/kingpin.v2 v2.2.6/go.mod h1:FMv+mEhP44yOT+4EoQTLFTRgOQ1FBLkstjWtayDeSgw= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 h1:qIbj1fsPNlZgppZ+VLlY7N33q108Sa+fhmuc+sWQYwY= diff --git a/internal/master/ds_service_test.go b/internal/master/ds_service_test.go index 5730d039..976c122c 100644 --- a/internal/master/ds_service_test.go +++ b/internal/master/ds_service_test.go @@ -125,7 +125,7 @@ func testNewDKVNodeJoiningAndLeaving(t *testing.T) { <-time.After(3 * time.Second) svcAddr := fmt.Sprintf("%s:%d", dkvSvcHost, dkvPorts[newNodeID]) // Create the client for the new DKV node - if dkvCli, err := ctl.NewInSecureDKVClient(svcAddr); err != nil { + if dkvCli, err := ctl.NewDKVClient(svcAddr, grpc.WithInsecure()); err != nil { t.Fatal(err) } else { defer dkvCli.Close() @@ -149,7 +149,7 @@ func testNewDKVNodeJoiningAndLeaving(t *testing.T) { func initDKVClients(ids ...int) { for id := 1; id <= clusterSize; id++ { svcAddr := fmt.Sprintf("%s:%d", dkvSvcHost, dkvPorts[id]) - if client, err := ctl.NewInSecureDKVClient(svcAddr); err != nil { + if client, err := ctl.NewDKVClient(svcAddr, grpc.WithInsecure()); err != nil { panic(err) } else { dkvClis[id] = client diff --git a/internal/master/ss_service_test.go b/internal/master/ss_service_test.go index 5d7c1be3..476e009e 100644 --- a/internal/master/ss_service_test.go +++ b/internal/master/ss_service_test.go @@ -38,7 +38,7 @@ func TestStandaloneService(t *testing.T) { go serveStandaloneDKV() sleepInSecs(3) dkvSvcAddr := fmt.Sprintf("%s:%d", dkvSvcHost, dkvSvcPort) - if client, err := ctl.NewInSecureDKVClient(dkvSvcAddr); err != nil { + if client, err := ctl.NewDKVClient(dkvSvcAddr, grpc.WithInsecure()); err != nil { t.Fatalf("Unable to connect to DKV service at %s. Error: %v", dkvSvcAddr, err) } else { dkvCli = client diff --git a/internal/slave/service_test.go b/internal/slave/service_test.go index 7aafc7f8..793248c6 100644 --- a/internal/slave/service_test.go +++ b/internal/slave/service_test.go @@ -124,7 +124,7 @@ func getKeys(t *testing.T, dkvCli *ctl.DKVClient, numKeys int, keyPrefix, valPre func newDKVClient(port int) *ctl.DKVClient { dkvSvcAddr := fmt.Sprintf("%s:%d", dkvSvcHost, port) - if client, err := ctl.NewInSecureDKVClient(dkvSvcAddr); err != nil { + if client, err := ctl.NewDKVClient(dkvSvcAddr, grpc.WithInsecure()); err != nil { panic(err) } else { return client diff --git a/internal/utils.go b/internal/utils.go new file mode 100644 index 00000000..c72893f1 --- /dev/null +++ b/internal/utils.go @@ -0,0 +1,110 @@ +package internal + +import ( + "crypto/tls" + "crypto/x509" + "errors" + "github.com/flipkart-incubator/dkv/pkg/ctl" + "google.golang.org/grpc" + "google.golang.org/grpc/credentials" + "io/ioutil" + "log" + "strings" +) + +type ServerMode string + +const ( + ServerTLS ServerMode = "serverTLS" + AutoTLS = "autoTLS" + MutualTLS = "mutualTLS" + Insecure = "insecure" +) + +type DKVConfig struct { + ServerMode string + SrvrAddr string + KeyPath string + CertPath string + CaCertPath string +} + +func ValidateCertAndKeyPath(keyPath string, certPath string) { + if keyPath == "" { + log.Panicf("Key path(keyPath) must not be empty for non-auto TLS mode") + } + + if certPath == "" { + log.Panicf("Certificate path(certPath) must not be empty for non-auto TLS mode") + } +} + +func ValidateDKVConfig(config DKVConfig) { + serverMode := ToServerMode(config.ServerMode) + switch serverMode { + case AutoTLS: + case MutualTLS: + if config.CaCertPath == "" { + log.Panicf("CA certifacte path(caCertPath) must not be empty for mutualTLS mode") + } + ValidateCertAndKeyPath(config.KeyPath, config.CertPath) + case ServerTLS: + ValidateCertAndKeyPath(config.KeyPath, config.CertPath) + case Insecure: + default: + log.Panicf("Invalid server mode. Allowed values are insecure|autoTLS|serverTLS|mutualTLS") + } +} + + +func ToServerMode(mode string) ServerMode { + return ServerMode(strings.TrimSpace(mode)) +} + +func NewDKVClient(clientConfig DKVConfig) (*ctl.DKVClient, error) { + srvrMode := ToServerMode(clientConfig.ServerMode) + var opt grpc.DialOption + var err error = nil + var config *tls.Config + switch srvrMode { + case AutoTLS: + config = &tls.Config{ + InsecureSkipVerify: true, + } + opt = grpc.WithTransportCredentials(credentials.NewTLS(config)) + case MutualTLS: + config, err = getTLSConfigWithCertPool(clientConfig.CaCertPath) + if err == nil { + var cert tls.Certificate + cert, err = tls.LoadX509KeyPair(clientConfig.CertPath, clientConfig.KeyPath) + if err == nil { + config.Certificates = []tls.Certificate{cert} + opt = grpc.WithTransportCredentials(credentials.NewTLS(config)) + } + } + case ServerTLS: + config, err = getTLSConfigWithCertPool(clientConfig.CaCertPath) + if err == nil { + opt = grpc.WithTransportCredentials(credentials.NewTLS(config)) + } + case Insecure: + opt = grpc.WithInsecure() + } + if err != nil { + return nil, err + } + return ctl.NewDKVClient(clientConfig.SrvrAddr, opt) +} + +func getTLSConfigWithCertPool(caCertPath string) (*tls.Config, error) { + certPool := x509.NewCertPool() + bs, err := ioutil.ReadFile(caCertPath) + if err != nil { + return nil, err + } + ok := certPool.AppendCertsFromPEM(bs) + if !ok { + return nil, errors.New("failed to append certs") + } + return &tls.Config{InsecureSkipVerify: false, RootCAs: certPool}, err +} \ No newline at end of file diff --git a/pkg/ctl/client.go b/pkg/ctl/client.go index 60332fb5..9dc80a95 100644 --- a/pkg/ctl/client.go +++ b/pkg/ctl/client.go @@ -33,13 +33,17 @@ const ( ConnectTimeout = 10 * time.Second ) -// NewInSecureDKVClient creates an insecure GRPC client against the +// NewDKVClient creates an insecure GRPC client against the // given DKV service address. -func NewInSecureDKVClient(svcAddr string) (*DKVClient, error) { +func NewDKVClient(svcAddr string, opts ...grpc.DialOption) (*DKVClient, error) { var dkvClnt *DKVClient ctx, cancel := context.WithTimeout(context.Background(), ConnectTimeout) + optsCopy := opts + optsCopy = append(optsCopy, grpc.WithBlock()) + optsCopy = append(optsCopy, grpc.WithReadBufferSize(ReadBufSize)) + optsCopy = append(optsCopy, grpc.WithWriteBufferSize(WriteBufSize)) defer cancel() - conn, err := grpc.DialContext(ctx, svcAddr, grpc.WithInsecure(), grpc.WithBlock(), grpc.WithReadBufferSize(ReadBufSize), grpc.WithWriteBufferSize(WriteBufSize)) + conn, err := grpc.DialContext(ctx, svcAddr, opts...) if err == nil { dkvCli := serverpb.NewDKVClient(conn) dkvReplCli := serverpb.NewDKVReplicationClient(conn) From 84d66f243968f22d7ba752edfd48845c57b0e6fa Mon Sep 17 00:00:00 2001 From: Aadhar Sachdeva Date: Tue, 23 Feb 2021 17:44:41 +0530 Subject: [PATCH 2/6] Flag cleanup, util refactoring --- cmd/dkvctl/main.go | 22 ++-- cmd/dkvsrv/main.go | 253 ++++++++++----------------------------------- internal/utils.go | 213 +++++++++++++++++++++++++++++--------- 3 files changed, 228 insertions(+), 260 deletions(-) diff --git a/cmd/dkvctl/main.go b/cmd/dkvctl/main.go index 02d0189c..8245832a 100644 --- a/cmd/dkvctl/main.go +++ b/cmd/dkvctl/main.go @@ -189,17 +189,11 @@ func (c *cmd) replica(client *ctl.DKVClient, args ...string) { var ( dkvAddr string - dkvMode string - certPath string - keyPath string caCertPath string ) func init() { flag.StringVar(&dkvAddr, "dkvAddr", "127.0.0.1:8080", ": - DKV server address") - flag.StringVar(&dkvMode, "dkvMode", "insecure", "Security mode for the DKV server node - insecure|autoTLS|serverTLS|mutualTLS") - flag.StringVar(&certPath, "certPath", "", "Path for certificate file of this node") - flag.StringVar(&keyPath, "keyPath", "", "Path for key file of this node") flag.StringVar(&caCertPath, "caCertPath", "", "Path for root certificate of the chain, i.e. CA certificate") for _, c := range cmds { flag.StringVar(&c.value, c.name, c.value, c.cmdDesc) @@ -235,13 +229,9 @@ func main() { } flag.Parse() - utils.ValidateDKVConfig(utils.DKVConfig{ServerMode: dkvMode, - SrvrAddr: dkvAddr, KeyPath: keyPath, CertPath: certPath, - CaCertPath: caCertPath}) fmt.Printf("Connecting to DKV service at %s...", dkvAddr) - client, err := utils.NewDKVClient(utils.DKVConfig{ServerMode: dkvMode, - SrvrAddr: dkvAddr, KeyPath: keyPath, CertPath: certPath, - CaCertPath: caCertPath}) + client, err := utils.NewDKVClient(utils.DKVConfig{ConnectionMode: modeFromFlags(), + SrvrAddr: dkvAddr, CaCertPath: caCertPath}) if err != nil { fmt.Printf("\nUnable to create DKV client. Error: %v\n", err) return @@ -263,3 +253,11 @@ func main() { usage() } } + +func modeFromFlags() utils.ConnectionMode { + if caCertPath != "" { + return utils.ServerTLS + } else { + return utils.Insecure + } +} diff --git a/cmd/dkvsrv/main.go b/cmd/dkvsrv/main.go index 92b2bde1..9040fef0 100644 --- a/cmd/dkvsrv/main.go +++ b/cmd/dkvsrv/main.go @@ -1,24 +1,12 @@ package main import ( - "crypto/ecdsa" - "crypto/elliptic" - "crypto/rand" - "crypto/tls" - "crypto/x509" - "crypto/x509/pkix" - "encoding/pem" "flag" "fmt" - "google.golang.org/grpc/credentials" - "io/ioutil" "log" - "math/big" - "net" "os" "os/signal" "path" - "path/filepath" "strings" "syscall" "time" @@ -34,11 +22,8 @@ import ( "github.com/flipkart-incubator/dkv/pkg/serverpb" nexus_api "github.com/flipkart-incubator/nexus/pkg/api" nexus "github.com/flipkart-incubator/nexus/pkg/raft" - grpc_zap "github.com/grpc-ecosystem/go-grpc-middleware/logging/zap" "go.uber.org/zap" "go.uber.org/zap/zapcore" - "google.golang.org/grpc" - "google.golang.org/grpc/reflection" ) var ( @@ -46,17 +31,14 @@ var ( dbEngine string dbFolder string dbListenAddr string + peerListenAddr string dbRole string statsdAddr string replMasterAddr string - masterMode string replPollInterval time.Duration - mode string certPath string keyPath string caCertPath string - generatedCertValidity int - generatedCertDir string // Logging vars dbAccessLog string @@ -73,20 +55,17 @@ func init() { flag.BoolVar(&disklessMode, "dbDiskless", false, fmt.Sprintf("Enables diskless mode where data is stored entirely in memory.\nAvailable on Badger for standalone and slave roles. (default %v)", disklessMode)) flag.StringVar(&dbFolder, "dbFolder", "/tmp/dkvsrv", "DB folder path for storing data files") flag.StringVar(&dbListenAddr, "dbListenAddr", "127.0.0.1:8080", "Address on which the DKV service binds") + flag.StringVar(&peerListenAddr, "peerListenAddr", "127.0.0.1:8083", "Address on which the DKV replication service binds") flag.StringVar(&dbEngine, "dbEngine", "rocksdb", "Underlying DB engine for storing data - badger|rocksdb") flag.StringVar(&dbRole, "dbRole", "none", "DB role of this node - none|master|slave") flag.StringVar(&statsdAddr, "statsdAddr", "", "StatsD service address in host:port format") flag.StringVar(&replMasterAddr, "replMasterAddr", "", "Service address of DKV master node for replication") - flag.StringVar(&masterMode, "masterMode", "insecure", "Service address of DKV master node for replication") flag.DurationVar(&replPollInterval, "replPollInterval", 5*time.Second, "Interval used for polling changes from master. Eg., 10s, 5ms, 2h, etc.") flag.StringVar(&dbAccessLog, "dbAccessLog", "", "File for logging DKV accesses eg., stdout, stderr, /tmp/access.log") - flag.StringVar(&mode, "mode", "insecure", "Security mode for this node - insecure|autoTLS|serverTLS|mutualTLS") flag.StringVar(&certPath, "certPath", "", "Path for certificate file of this node") flag.StringVar(&keyPath, "keyPath", "", "Path for key file of this node") flag.StringVar(&caCertPath, "caCertPath", "", "Path for root certificate of the chain, i.e. CA certificate") flag.BoolVar(&verboseLogging, "verbose", false, fmt.Sprintf("Enable verbose logging.\nBy default, only warnings and errors are logged. (default %v)", verboseLogging)) - flag.IntVar(&generatedCertValidity, "generatedCertValidity", 365, "Validity(in days) for the generated certificate") - flag.StringVar(&generatedCertDir, "generatedCertDir", "/tmp/dkv-certs", "Directory to store the generate key and certificates") setDKVDefaultsForNexusDirs() } @@ -106,8 +85,22 @@ func main() { setFlagsForNexusDirs() setupStats() + var secure = false + if caCertPath != "" && keyPath != "" && certPath != "" { + secure = true + } + + var srvrMode utils.ConnectionMode + if secure { + srvrMode = utils.ServerTLS + } else { + srvrMode = utils.Insecure + } + kvs, cp, ca, br := newKVStore() - grpcSrvr, lstnr := newGrpcServerListener() + grpcSrvr, lstnr := utils.NewGrpcServerListener(utils.DKVConfig{ConnectionMode: srvrMode, + SrvrAddr: dbListenAddr, KeyPath: keyPath, CertPath: certPath, + CaCertPath: caCertPath}, accessLogger) defer grpcSrvr.GracefulStop() srvrRole := toDKVSrvrRole(dbRole) srvrRole.printFlags() @@ -132,9 +125,26 @@ func main() { } defer dkvSvc.Close() serverpb.RegisterDKVServer(grpcSrvr, dkvSvc) - serverpb.RegisterDKVReplicationServer(grpcSrvr, dkvSvc) + var replSrvrMode utils.ConnectionMode + if secure { + replSrvrMode = utils.MutualTLS + } else { + replSrvrMode = utils.Insecure + } + replGrpcSrvr, replLstnr := utils.NewGrpcServerListener(utils.DKVConfig{ConnectionMode: replSrvrMode, + SrvrAddr: peerListenAddr, KeyPath: keyPath, CertPath: certPath, + CaCertPath: caCertPath}, accessLogger) + defer replGrpcSrvr.GracefulStop() + serverpb.RegisterDKVReplicationServer(replGrpcSrvr, dkvSvc) + go replGrpcSrvr.Serve(replLstnr) case slaveRole: - if replCli, err := utils.NewDKVClient(utils.DKVConfig{ServerMode: masterMode, + var replClientMode utils.ConnectionMode + if secure { + replClientMode = utils.MutualTLS + } else { + replClientMode = utils.Insecure + } + if replCli, err := utils.NewDKVClient(utils.DKVConfig{ConnectionMode: replClientMode, SrvrAddr: replMasterAddr, KeyPath: keyPath, CertPath: certPath, CaCertPath: caCertPath}); @@ -158,6 +168,9 @@ func validateFlags() { if dbListenAddr != "" && strings.IndexRune(dbListenAddr, ':') < 0 { log.Panicf("given listen address: %s is invalid, must be in host:port format", dbListenAddr) } + if peerListenAddr != "" && strings.IndexRune(peerListenAddr, ':') < 0 { + log.Panicf("given listen address: %s is invalid, must be in host:port format", dbListenAddr) + } if replMasterAddr != "" && strings.IndexRune(replMasterAddr, ':') < 0 { log.Panicf("given master address: %s for replication is invalid, must be in host:port format", replMasterAddr) } @@ -167,38 +180,15 @@ func validateFlags() { if disklessMode && (strings.ToLower(dbEngine) == "rocksdb" || strings.ToLower(dbRole) == masterRole) { log.Panicf("diskless is available only on Badger storage and for standalone and slave roles") } - if strings.ToLower(dbRole) == slaveRole && (replMasterAddr == "" || masterMode == "") { + if strings.ToLower(dbRole) == slaveRole && replMasterAddr == "" { log.Panicf("replMasterAddr must be given in slave mode") } - - if strings.ToLower(dbRole) == slaveRole { - utils.ValidateDKVConfig(utils.DKVConfig{ServerMode: masterMode, - SrvrAddr: replMasterAddr, KeyPath: keyPath, CertPath: certPath, - CaCertPath: caCertPath}) - } - - validateServerModeCompliantFlags() -} - -func validateServerModeCompliantFlags() { - utils.ValidateDKVConfig(utils.DKVConfig{ServerMode: mode, - SrvrAddr: replMasterAddr, KeyPath: keyPath, CertPath: certPath, - CaCertPath: caCertPath}) - srvrMode := utils.ToServerMode(mode) - switch srvrMode { - case utils.AutoTLS: - if generatedCertValidity <= 0 { - log.Panicf("Validaty period for the generated certificates(generatedCertValidity) should be > 0") - } - - if generatedCertDir == "" { - log.Panicf("Directory for certificate generation(generatedCertDir) must not be empty") - } + nonNullAuthFlags := btou(certPath != "") + btou(keyPath != "") + btou(caCertPath != "") + if nonNullAuthFlags > 0 && nonNullAuthFlags < 3 { + log.Panicf("Missing TLS attributes, set all flags (caCertPath, keyPath, certPath) to run DKV in secure mode") } } - - func setupAccessLogger() { accessLogger = zap.NewNop() if dbAccessLog != "" { @@ -270,152 +260,6 @@ func setupDKVLogger() { } } -func newGrpcServerListener() (*grpc.Server, net.Listener) { - var grpcSrvr *grpc.Server - - if mode == utils.Insecure { - grpcSrvr = grpc.NewServer( - grpc.StreamInterceptor(grpc_zap.StreamServerInterceptor(accessLogger)), - grpc.UnaryInterceptor(grpc_zap.UnaryServerInterceptor(accessLogger)), - ) - } else { - srvrCred, err := loadTLSCredentials() - if err != nil { - dkvLogger.Sugar().Fatal("Unable to load tls credentials", err) - } - grpcSrvr = grpc.NewServer( - grpc.Creds(srvrCred), - grpc.StreamInterceptor(grpc_zap.StreamServerInterceptor(accessLogger)), - grpc.UnaryInterceptor(grpc_zap.UnaryServerInterceptor(accessLogger)), - ) - } - reflection.Register(grpcSrvr) - return grpcSrvr, newListener() -} - -func loadTLSCredentials() (credentials.TransportCredentials, error) { - var serverCert tls.Certificate - var err error - if mode == utils.AutoTLS { - err = generateSelfSignedCert() - if err != nil { - return nil, err - } - } - - serverCert, err = tls.LoadX509KeyPair(certPath, keyPath) - - if err != nil { - return nil, err - } - - var config *tls.Config - - if mode == utils.MutualTLS { - pemClientCA, err := ioutil.ReadFile(caCertPath) - if err != nil { - return nil, err - } - - certPool := x509.NewCertPool() - if !certPool.AppendCertsFromPEM(pemClientCA) { - return nil, fmt.Errorf("failed to add client CA's certificate") - } - - config = &tls.Config{ - Certificates: []tls.Certificate{serverCert}, - ClientAuth: tls.RequireAndVerifyClientCert, - ClientCAs: certPool, - } - } else { - config = &tls.Config{ - Certificates: []tls.Certificate{serverCert}, - ClientAuth: tls.NoClientCert, - } - } - return credentials.NewTLS(config), nil -} - -func generateSelfSignedCert() error { - var err error - if _, errCreate := os.Stat(generatedCertDir); os.IsNotExist(errCreate) { - err = os.MkdirAll(generatedCertDir, os.ModePerm) - } - - if err != nil { - return err - } - certPath = filepath.Join(generatedCertDir, "cert.pem") - keyPath = filepath.Join(generatedCertDir, "key.pem") - - _, errcert := os.Stat(certPath) - _, errkey := os.Stat(keyPath) - if errcert == nil && errkey == nil { - return nil - } - - serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128) - serialNumber, err := rand.Int(rand.Reader, serialNumberLimit) - if err != nil { - return err - } - - tmpl := x509.Certificate{ - SerialNumber: serialNumber, - Subject: pkix.Name{Organization: []string{"DKV"}}, - NotBefore: time.Now(), - NotAfter: time.Now().Add(time.Duration(generatedCertValidity) * (24 * time.Hour)), - - KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature, - ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, - BasicConstraintsValid: true, - } - - - log.Println("automatically generate certificates", - zap.Time("certificate-validity-bound-not-after", tmpl.NotAfter)) - - priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) - if err != nil { - return err - } - - derBytes, err := x509.CreateCertificate(rand.Reader, &tmpl, &tmpl, &priv.PublicKey, priv) - if err != nil { - return err - } - - certOut, err := os.Create(certPath) - if err != nil { - return err - } - pem.Encode(certOut, &pem.Block{Type: "CERTIFICATE", Bytes: derBytes}) - certOut.Close() - log.Println("created cert file", zap.String("path", certPath)) - - b, err := x509.MarshalECPrivateKey(priv) - if err != nil { - return err - } - keyOut, err := os.OpenFile(keyPath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0600) - if err != nil { - return err - } - - pem.Encode(keyOut, &pem.Block{Type: "EC PRIVATE KEY", Bytes: b}) - keyOut.Close() - return nil -} - -func newListener() (lis net.Listener) { - var err error - if lis, err = net.Listen("tcp", dbListenAddr); err != nil { - log.Panicf("failed to listen: %v", err) - return - } - return -} - func setupSignalHandler() <-chan os.Signal { signals := []os.Signal{syscall.SIGINT, syscall.SIGQUIT, syscall.SIGSTOP, syscall.SIGTERM} stopChan := make(chan os.Signal, len(signals)) @@ -454,9 +298,9 @@ func (role dkvSrvrRole) printFlags() { printFlagsWithPrefix("db") case masterRole: if haveFlagsWithPrefix("nexus") { - printFlagsWithPrefix("db", "nexus") + printFlagsWithPrefix("db", "nexus", "peer") } else { - printFlagsWithPrefix("db") + printFlagsWithPrefix("db", "peer") } case slaveRole: printFlagsWithPrefix("db", "repl") @@ -554,3 +398,10 @@ func newDKVReplicator(kvs storage.KVStore) nexus_api.RaftReplicator { return nexusRepl } } + +func btou(b bool) uint8 { + if b { + return 1 + } + return 0 +} \ No newline at end of file diff --git a/internal/utils.go b/internal/utils.go index c72893f1..570c423f 100644 --- a/internal/utils.go +++ b/internal/utils.go @@ -1,77 +1,56 @@ package internal import ( + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rand" "crypto/tls" "crypto/x509" + "crypto/x509/pkix" + "encoding/pem" "errors" + "fmt" "github.com/flipkart-incubator/dkv/pkg/ctl" + grpc_zap "github.com/grpc-ecosystem/go-grpc-middleware/logging/zap" + "go.uber.org/zap" "google.golang.org/grpc" "google.golang.org/grpc/credentials" + "google.golang.org/grpc/reflection" "io/ioutil" "log" + "math/big" + "net" + "os" + "path/filepath" "strings" + "time" ) -type ServerMode string +type ConnectionMode string const ( - ServerTLS ServerMode = "serverTLS" - AutoTLS = "autoTLS" - MutualTLS = "mutualTLS" - Insecure = "insecure" + ServerTLS ConnectionMode = "serverTLS" + MutualTLS = "mutualTLS" + Insecure = "insecure" ) type DKVConfig struct { - ServerMode string - SrvrAddr string - KeyPath string - CertPath string - CaCertPath string + ConnectionMode ConnectionMode + SrvrAddr string + KeyPath string + CertPath string + CaCertPath string } -func ValidateCertAndKeyPath(keyPath string, certPath string) { - if keyPath == "" { - log.Panicf("Key path(keyPath) must not be empty for non-auto TLS mode") - } - - if certPath == "" { - log.Panicf("Certificate path(certPath) must not be empty for non-auto TLS mode") - } -} - -func ValidateDKVConfig(config DKVConfig) { - serverMode := ToServerMode(config.ServerMode) - switch serverMode { - case AutoTLS: - case MutualTLS: - if config.CaCertPath == "" { - log.Panicf("CA certifacte path(caCertPath) must not be empty for mutualTLS mode") - } - ValidateCertAndKeyPath(config.KeyPath, config.CertPath) - case ServerTLS: - ValidateCertAndKeyPath(config.KeyPath, config.CertPath) - case Insecure: - default: - log.Panicf("Invalid server mode. Allowed values are insecure|autoTLS|serverTLS|mutualTLS") - } -} - - -func ToServerMode(mode string) ServerMode { - return ServerMode(strings.TrimSpace(mode)) +func ToServerMode(mode string) ConnectionMode { + return ConnectionMode(strings.TrimSpace(mode)) } func NewDKVClient(clientConfig DKVConfig) (*ctl.DKVClient, error) { - srvrMode := ToServerMode(clientConfig.ServerMode) var opt grpc.DialOption var err error = nil var config *tls.Config - switch srvrMode { - case AutoTLS: - config = &tls.Config{ - InsecureSkipVerify: true, - } - opt = grpc.WithTransportCredentials(credentials.NewTLS(config)) + switch clientConfig.ConnectionMode { case MutualTLS: config, err = getTLSConfigWithCertPool(clientConfig.CaCertPath) if err == nil { @@ -107,4 +86,144 @@ func getTLSConfigWithCertPool(caCertPath string) (*tls.Config, error) { return nil, errors.New("failed to append certs") } return &tls.Config{InsecureSkipVerify: false, RootCAs: certPool}, err +} + +func NewGrpcServerListener(config DKVConfig, loogger *zap.Logger) (*grpc.Server, net.Listener) { + var grpcSrvr *grpc.Server + + if config.ConnectionMode == Insecure { + grpcSrvr = grpc.NewServer( + grpc.StreamInterceptor(grpc_zap.StreamServerInterceptor(loogger)), + grpc.UnaryInterceptor(grpc_zap.UnaryServerInterceptor(loogger)), + ) + } else { + srvrCred, err := loadTLSCredentials(config) + if err != nil { + log.Fatal("Unable to load tls credentials", err) + } + grpcSrvr = grpc.NewServer( + grpc.Creds(srvrCred), + grpc.StreamInterceptor(grpc_zap.StreamServerInterceptor(loogger)), + grpc.UnaryInterceptor(grpc_zap.UnaryServerInterceptor(loogger)), + ) + } + reflection.Register(grpcSrvr) + return grpcSrvr, NewListener(config.SrvrAddr) +} + +func loadTLSCredentials(clientConfig DKVConfig) (credentials.TransportCredentials, error) { + var serverCert tls.Certificate + var err error + + serverCert, err = tls.LoadX509KeyPair(clientConfig.CertPath, clientConfig.KeyPath) + + if err != nil { + return nil, err + } + + var config *tls.Config + + if clientConfig.ConnectionMode == MutualTLS { + pemClientCA, err := ioutil.ReadFile(clientConfig.CaCertPath) + if err != nil { + return nil, err + } + + certPool := x509.NewCertPool() + if !certPool.AppendCertsFromPEM(pemClientCA) { + return nil, fmt.Errorf("failed to add client CA's certificate") + } + + config = &tls.Config{ + Certificates: []tls.Certificate{serverCert}, + ClientAuth: tls.RequireAndVerifyClientCert, + ClientCAs: certPool, + } + } else { + config = &tls.Config{ + Certificates: []tls.Certificate{serverCert}, + ClientAuth: tls.NoClientCert, + } + } + return credentials.NewTLS(config), nil +} + +func generateSelfSignedCert(generatedCertDir string, generatedCertValidity int) (string, string, error) { + var err error + if _, errCreate := os.Stat(generatedCertDir); os.IsNotExist(errCreate) { + err = os.MkdirAll(generatedCertDir, os.ModePerm) + } + + if err != nil { + return "", "", err + } + certPath := filepath.Join(generatedCertDir, "cert.pem") + keyPath := filepath.Join(generatedCertDir, "key.pem") + + _, errcert := os.Stat(certPath) + _, errkey := os.Stat(keyPath) + if errcert == nil && errkey == nil { + return "", "", nil + } + + serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128) + serialNumber, err := rand.Int(rand.Reader, serialNumberLimit) + if err != nil { + return "", "", err + } + + tmpl := x509.Certificate{ + SerialNumber: serialNumber, + Subject: pkix.Name{Organization: []string{"DKV"}}, + NotBefore: time.Now(), + NotAfter: time.Now().Add(time.Duration(generatedCertValidity) * (24 * time.Hour)), + + KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, + BasicConstraintsValid: true, + } + + + log.Println("automatically generate certificates", + zap.Time("certificate-validity-bound-not-after", tmpl.NotAfter)) + + priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + if err != nil { + return "", "", err + } + + derBytes, err := x509.CreateCertificate(rand.Reader, &tmpl, &tmpl, &priv.PublicKey, priv) + if err != nil { + return "", "", err + } + + certOut, err := os.Create(certPath) + if err != nil { + return "", "", err + } + pem.Encode(certOut, &pem.Block{Type: "CERTIFICATE", Bytes: derBytes}) + certOut.Close() + log.Println("created cert file", zap.String("path", certPath)) + + b, err := x509.MarshalECPrivateKey(priv) + if err != nil { + return "", "", err + } + keyOut, err := os.OpenFile(keyPath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0600) + if err != nil { + return "", "", err + } + + pem.Encode(keyOut, &pem.Block{Type: "EC PRIVATE KEY", Bytes: b}) + keyOut.Close() + return certPath, keyPath, nil +} + +func NewListener(listenAddr string) (lis net.Listener) { + var err error + if lis, err = net.Listen("tcp", listenAddr); err != nil { + log.Panicf("failed to listen: %v", err) + return + } + return } \ No newline at end of file From 996a6d4a92d6543672478fed7166a094d8a90f79 Mon Sep 17 00:00:00 2001 From: Aadhar Sachdeva Date: Tue, 23 Feb 2021 17:47:08 +0530 Subject: [PATCH 3/6] formatting fixes --- internal/utils.go | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/internal/utils.go b/internal/utils.go index 570c423f..bc3264ac 100644 --- a/internal/utils.go +++ b/internal/utils.go @@ -183,7 +183,6 @@ func generateSelfSignedCert(generatedCertDir string, generatedCertValidity int) BasicConstraintsValid: true, } - log.Println("automatically generate certificates", zap.Time("certificate-validity-bound-not-after", tmpl.NotAfter)) @@ -226,4 +225,4 @@ func NewListener(listenAddr string) (lis net.Listener) { return } return -} \ No newline at end of file +} From 55765bc4fe886e8df6ab74d24e7f0096bd469779 Mon Sep 17 00:00:00 2001 From: Aadhar Sachdeva Date: Tue, 23 Feb 2021 18:29:40 +0530 Subject: [PATCH 4/6] formatting, cleanup --- cmd/dkvctl/main.go | 6 +++--- cmd/dkvsrv/main.go | 30 +++++++++++++++--------------- internal/utils.go | 7 +------ 3 files changed, 19 insertions(+), 24 deletions(-) diff --git a/cmd/dkvctl/main.go b/cmd/dkvctl/main.go index 8245832a..a0c45162 100644 --- a/cmd/dkvctl/main.go +++ b/cmd/dkvctl/main.go @@ -188,8 +188,8 @@ func (c *cmd) replica(client *ctl.DKVClient, args ...string) { } var ( - dkvAddr string - caCertPath string + dkvAddr string + caCertPath string ) func init() { @@ -203,7 +203,7 @@ func init() { func usage() { fmt.Printf("Usage of %s:\n", os.Args[0]) - printUsage([]string {"dkvAddr", "dkvMode", "certPath", "keyPath", "caCertPath"}) + printUsage([]string{"dkvAddr", "caCertPath"}) for _, cmd := range cmds { cmd.usage() } diff --git a/cmd/dkvsrv/main.go b/cmd/dkvsrv/main.go index 9040fef0..f14b258a 100644 --- a/cmd/dkvsrv/main.go +++ b/cmd/dkvsrv/main.go @@ -27,18 +27,18 @@ import ( ) var ( - disklessMode bool - dbEngine string - dbFolder string - dbListenAddr string - peerListenAddr string - dbRole string - statsdAddr string - replMasterAddr string - replPollInterval time.Duration - certPath string - keyPath string - caCertPath string + disklessMode bool + dbEngine string + dbFolder string + dbListenAddr string + peerListenAddr string + dbRole string + statsdAddr string + replMasterAddr string + replPollInterval time.Duration + certPath string + keyPath string + caCertPath string // Logging vars dbAccessLog string @@ -147,8 +147,8 @@ func main() { if replCli, err := utils.NewDKVClient(utils.DKVConfig{ConnectionMode: replClientMode, SrvrAddr: replMasterAddr, KeyPath: keyPath, CertPath: certPath, CaCertPath: caCertPath}); - - err != nil { + + err != nil { panic(err) } else { defer replCli.Close() @@ -404,4 +404,4 @@ func btou(b bool) uint8 { return 1 } return 0 -} \ No newline at end of file +} diff --git a/internal/utils.go b/internal/utils.go index bc3264ac..c6d8977b 100644 --- a/internal/utils.go +++ b/internal/utils.go @@ -22,7 +22,6 @@ import ( "net" "os" "path/filepath" - "strings" "time" ) @@ -42,10 +41,6 @@ type DKVConfig struct { CaCertPath string } -func ToServerMode(mode string) ConnectionMode { - return ConnectionMode(strings.TrimSpace(mode)) -} - func NewDKVClient(clientConfig DKVConfig) (*ctl.DKVClient, error) { var opt grpc.DialOption var err error = nil @@ -148,7 +143,7 @@ func loadTLSCredentials(clientConfig DKVConfig) (credentials.TransportCredential return credentials.NewTLS(config), nil } -func generateSelfSignedCert(generatedCertDir string, generatedCertValidity int) (string, string, error) { +func GenerateSelfSignedCert(generatedCertDir string, generatedCertValidity int) (string, string, error) { var err error if _, errCreate := os.Stat(generatedCertDir); os.IsNotExist(errCreate) { err = os.MkdirAll(generatedCertDir, os.ModePerm) From ecfbcb104d666e3271a06064309dd820ea4b5bbf Mon Sep 17 00:00:00 2001 From: Aadhar Sachdeva Date: Wed, 24 Feb 2021 12:24:12 +0530 Subject: [PATCH 5/6] Minor refactoring --- internal/utils.go | 35 ++++++++++++----------------------- 1 file changed, 12 insertions(+), 23 deletions(-) diff --git a/internal/utils.go b/internal/utils.go index c6d8977b..61e185ac 100644 --- a/internal/utils.go +++ b/internal/utils.go @@ -84,24 +84,17 @@ func getTLSConfigWithCertPool(caCertPath string) (*tls.Config, error) { } func NewGrpcServerListener(config DKVConfig, loogger *zap.Logger) (*grpc.Server, net.Listener) { - var grpcSrvr *grpc.Server - - if config.ConnectionMode == Insecure { - grpcSrvr = grpc.NewServer( - grpc.StreamInterceptor(grpc_zap.StreamServerInterceptor(loogger)), - grpc.UnaryInterceptor(grpc_zap.UnaryServerInterceptor(loogger)), - ) - } else { + opts := []grpc.ServerOption{ + grpc.StreamInterceptor(grpc_zap.StreamServerInterceptor(loogger)), + grpc.UnaryInterceptor(grpc_zap.UnaryServerInterceptor(loogger))} + if config.ConnectionMode != Insecure { srvrCred, err := loadTLSCredentials(config) if err != nil { log.Fatal("Unable to load tls credentials", err) } - grpcSrvr = grpc.NewServer( - grpc.Creds(srvrCred), - grpc.StreamInterceptor(grpc_zap.StreamServerInterceptor(loogger)), - grpc.UnaryInterceptor(grpc_zap.UnaryServerInterceptor(loogger)), - ) + opts = append(opts, grpc.Creds(srvrCred)) } + grpcSrvr := grpc.NewServer(opts...) reflection.Register(grpcSrvr) return grpcSrvr, NewListener(config.SrvrAddr) } @@ -116,7 +109,9 @@ func loadTLSCredentials(clientConfig DKVConfig) (credentials.TransportCredential return nil, err } - var config *tls.Config + var config = &tls.Config{ + Certificates: []tls.Certificate{serverCert}, + } if clientConfig.ConnectionMode == MutualTLS { pemClientCA, err := ioutil.ReadFile(clientConfig.CaCertPath) @@ -129,16 +124,10 @@ func loadTLSCredentials(clientConfig DKVConfig) (credentials.TransportCredential return nil, fmt.Errorf("failed to add client CA's certificate") } - config = &tls.Config{ - Certificates: []tls.Certificate{serverCert}, - ClientAuth: tls.RequireAndVerifyClientCert, - ClientCAs: certPool, - } + config.ClientAuth = tls.RequireAndVerifyClientCert + config.ClientCAs = certPool } else { - config = &tls.Config{ - Certificates: []tls.Certificate{serverCert}, - ClientAuth: tls.NoClientCert, - } + config.ClientAuth = tls.NoClientCert } return credentials.NewTLS(config), nil } From 2cc348162846461e25791b4c9cd30c1a7f4a77f4 Mon Sep 17 00:00:00 2001 From: Aadhar Sachdeva Date: Wed, 24 Feb 2021 14:58:48 +0530 Subject: [PATCH 6/6] Reformat --- cmd/dkvctl/main.go | 16 ++++++++-------- pkg/ctl/client.go | 8 +++----- 2 files changed, 11 insertions(+), 13 deletions(-) diff --git a/cmd/dkvctl/main.go b/cmd/dkvctl/main.go index a0278869..ce0da7c2 100644 --- a/cmd/dkvctl/main.go +++ b/cmd/dkvctl/main.go @@ -159,8 +159,8 @@ func (c *cmd) listNodes(client *ctl.DKVClient, args ...string) { } var ( - dkvAddr string - caCertPath string + dkvAddr string + caCertPath string dkvAuthority string ) @@ -208,13 +208,13 @@ func main() { flag.Parse() fmt.Printf("Connecting to DKV service at %s...", dkvAddr) - fmt.Printf("Connecting to DKV service at %s", dkvAddr) - if dkvAuthority = strings.TrimSpace(dkvAuthority); dkvAuthority != "" { - fmt.Printf(" (:authority = %s)", dkvAuthority) - } - fmt.Printf("...") + fmt.Printf("Connecting to DKV service at %s", dkvAddr) + if dkvAuthority = strings.TrimSpace(dkvAuthority); dkvAuthority != "" { + fmt.Printf(" (:authority = %s)", dkvAuthority) + } + fmt.Printf("...") client, err := utils.NewDKVClient(utils.DKVConfig{ConnectionMode: modeFromFlags(), - SrvrAddr: dkvAddr, CaCertPath: caCertPath}, "") + SrvrAddr: dkvAddr, CaCertPath: caCertPath}, dkvAuthority) if err != nil { fmt.Printf("\nUnable to create DKV client. Error: %v\n", err) return diff --git a/pkg/ctl/client.go b/pkg/ctl/client.go index e5dc6513..7d27f593 100644 --- a/pkg/ctl/client.go +++ b/pkg/ctl/client.go @@ -39,11 +39,9 @@ func NewDKVClient(svcAddr string, authority string, opts ...grpc.DialOption) (*D var dkvClnt *DKVClient ctx, cancel := context.WithTimeout(context.Background(), ConnectTimeout) optsCopy := opts - optsCopy = append(optsCopy, grpc.WithBlock()) - optsCopy = append(optsCopy, grpc.WithReadBufferSize(ReadBufSize)) - optsCopy = append(optsCopy, grpc.WithWriteBufferSize(WriteBufSize)) - optsCopy = append(optsCopy, grpc.WithAuthority(authority)) - optsCopy = append(optsCopy, grpc.WithDefaultCallOptions(grpc.MaxCallRecvMsgSize(MaxMsgSize))) + optsCopy = append(optsCopy, grpc.WithBlock(), grpc.WithReadBufferSize(ReadBufSize), + grpc.WithWriteBufferSize(WriteBufSize), grpc.WithAuthority(authority), + grpc.WithDefaultCallOptions(grpc.MaxCallRecvMsgSize(MaxMsgSize))) defer cancel() conn, err := grpc.DialContext(ctx, svcAddr, opts...) if err == nil {