diff --git a/README.md b/README.md index 8ba7dff9..8ff71570 100644 --- a/README.md +++ b/README.md @@ -12,6 +12,7 @@ bmclib v2 is a library to abstract interacting with baseboard management control - [IPMItool](https://github.com/bmc-toolbox/bmclib/tree/main/providers/ipmitool) - [Intel AMT](https://github.com/bmc-toolbox/bmclib/tree/main/providers/intelamt) - [Asrockrack](https://github.com/bmc-toolbox/bmclib/tree/main/providers/asrockrack) + - [RPC](providers/rpc/) ## Installation diff --git a/client.go b/client.go index c5ecadde..b552cce5 100644 --- a/client.go +++ b/client.go @@ -4,11 +4,13 @@ package bmclib import ( "context" + "fmt" "io" "net/http" "sync" "time" + "dario.cat/mergo" "github.com/bmc-toolbox/bmclib/v2/bmc" "github.com/bmc-toolbox/bmclib/v2/internal/httpclient" "github.com/bmc-toolbox/bmclib/v2/providers/asrockrack" @@ -16,6 +18,7 @@ import ( "github.com/bmc-toolbox/bmclib/v2/providers/intelamt" "github.com/bmc-toolbox/bmclib/v2/providers/ipmitool" "github.com/bmc-toolbox/bmclib/v2/providers/redfish" + "github.com/bmc-toolbox/bmclib/v2/providers/rpc" "github.com/bmc-toolbox/bmclib/v2/providers/supermicro" "github.com/bmc-toolbox/common" "github.com/go-logr/logr" @@ -58,6 +61,7 @@ type providerConfig struct { intelamt intelamt.Config dell dell.Config supermicro supermicro.Config + rpc rpc.Provider } // NewClient returns a new Client struct @@ -90,6 +94,7 @@ func NewClient(host, user, pass string, opts ...Option) *Client { supermicro: supermicro.Config{ Port: "443", }, + rpc: rpc.Provider{}, }, } @@ -129,7 +134,29 @@ func (c *Client) defaultTimeout(ctx context.Context) time.Duration { return time.Until(deadline) / time.Duration(l) } +func (c *Client) registerRPCProvider() error { + driverRPC := rpc.New(c.providerConfig.rpc.ConsumerURL, c.Auth.Host, c.providerConfig.rpc.Opts.HMAC.Secrets) + c.providerConfig.rpc.Logger = c.Logger + if err := mergo.Merge(driverRPC, c.providerConfig.rpc, mergo.WithOverride, mergo.WithTransformers(&rpc.Provider{})); err != nil { + return fmt.Errorf("failed to merge user specified rpc config with the config defaults, rpc provider not available: %w", err) + } + c.Registry.Register(rpc.ProviderName, rpc.ProviderProtocol, rpc.Features, nil, driverRPC) + + return nil +} + func (c *Client) registerProviders() { + // register the rpc provider + // without the consumer URL there is no way to send RPC requests. + if c.providerConfig.rpc.ConsumerURL != "" { + // when the rpc provider is to be used, we won't register any other providers. + err := c.registerRPCProvider() + if err == nil { + c.Logger.Info("note: with the rpc provider registered, no other providers will be registered and available") + return + } + c.Logger.Info("failed to register rpc provider, falling back to registering all other providers", "error", err.Error()) + } // register ipmitool provider ipmiOpts := []ipmitool.Option{ ipmitool.WithLogger(c.Logger), diff --git a/client_test.go b/client_test.go index 324ed826..321a7e25 100644 --- a/client_test.go +++ b/client_test.go @@ -21,8 +21,7 @@ func TestBMC(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 20*time.Second) defer cancel() cl := NewClient(host, user, pass, WithLogger(log), WithPerProviderTimeout(5*time.Second)) - err := cl.Open(ctx) - if err != nil { + if err := cl.Open(ctx); err != nil { t.Logf("%+v", cl.GetMetadata()) t.Fatal(err) } diff --git a/examples/rpc/main.go b/examples/rpc/main.go new file mode 100644 index 00000000..097ccfbc --- /dev/null +++ b/examples/rpc/main.go @@ -0,0 +1,94 @@ +package main + +import ( + "context" + "encoding/json" + "net/http" + "time" + + "github.com/bmc-toolbox/bmclib/v2" + "github.com/bmc-toolbox/bmclib/v2/logging" + "github.com/bmc-toolbox/bmclib/v2/providers/rpc" +) + +func main() { + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + // Start the test consumer + go testConsumer(ctx) + time.Sleep(100 * time.Millisecond) + + log := logging.ZeroLogger("info") + opts := []bmclib.Option{ + bmclib.WithLogger(log), + bmclib.WithPerProviderTimeout(5 * time.Second), + bmclib.WithRPCOpt(rpc.Provider{ + ConsumerURL: "http://localhost:8800", + // Opts are not required. + Opts: rpc.Opts{ + HMAC: rpc.HMACOpts{ + Secrets: rpc.Secrets{rpc.SHA256: {"superSecret1"}}, + }, + Signature: rpc.SignatureOpts{ + HeaderName: "X-Bespoke-Signature", + IncludedPayloadHeaders: []string{"X-Bespoke-Timestamp"}, + }, + Request: rpc.RequestOpts{ + TimestampHeader: "X-Bespoke-Timestamp", + }, + }, + }), + } + host := "127.0.1.1" + user := "admin" + pass := "admin" + c := bmclib.NewClient(host, user, pass, opts...) + if err := c.Open(ctx); err != nil { + panic(err) + } + defer c.Close(ctx) + + state, err := c.GetPowerState(ctx) + if err != nil { + panic(err) + } + log.Info("power state", "state", state) + log.Info("metadata for GetPowerState", "metadata", c.GetMetadata()) + + ok, err := c.SetPowerState(ctx, "on") + if err != nil { + panic(err) + } + log.Info("set power state", "ok", ok) + log.Info("metadata for SetPowerState", "metadata", c.GetMetadata()) + + <-ctx.Done() +} + +func testConsumer(ctx context.Context) error { + http.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { + req := rpc.RequestPayload{} + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + w.WriteHeader(http.StatusBadRequest) + return + } + rp := rpc.ResponsePayload{ + ID: req.ID, + Host: req.Host, + } + switch req.Method { + case rpc.PowerGetMethod: + rp.Result = "on" + case rpc.PowerSetMethod: + + case rpc.BootDeviceMethod: + + default: + w.WriteHeader(http.StatusNotFound) + } + b, _ := json.Marshal(rp) + w.Write(b) + }) + + return http.ListenAndServe(":8800", nil) +} diff --git a/go.mod b/go.mod index ff0b55c8..d73adff9 100644 --- a/go.mod +++ b/go.mod @@ -3,14 +3,19 @@ module github.com/bmc-toolbox/bmclib/v2 go 1.18 require ( + dario.cat/mergo v1.0.0 + github.com/Jeffail/gabs/v2 v2.7.0 github.com/bmc-toolbox/common v0.0.0-20230220061748-93ff001f4a1d github.com/bombsimon/logrusr/v2 v2.0.1 + github.com/ghodss/yaml v1.0.0 github.com/go-logr/logr v1.2.4 + github.com/go-logr/zerologr v1.2.3 github.com/google/go-cmp v0.5.9 github.com/hashicorp/go-multierror v1.1.1 github.com/jacobweinstock/iamt v0.0.0-20230502042727-d7cdbe67d9ef github.com/jacobweinstock/registrar v0.4.7 github.com/pkg/errors v0.9.1 + github.com/rs/zerolog v1.30.0 github.com/sirupsen/logrus v1.8.1 github.com/stmcginnis/gofish v0.14.0 github.com/stretchr/testify v1.8.0 @@ -26,8 +31,11 @@ require ( github.com/VictorLowther/soap v0.0.0-20150314151524-8e36fca84b22 // indirect github.com/davecgh/go-spew v1.1.1 // indirect github.com/hashicorp/errwrap v1.1.0 // indirect + github.com/mattn/go-colorable v0.1.12 // indirect + github.com/mattn/go-isatty v0.0.14 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect github.com/satori/go.uuid v1.2.0 // indirect golang.org/x/sys v0.1.0 // indirect + gopkg.in/yaml.v2 v2.4.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index ba4b45b7..cdb371e3 100644 --- a/go.sum +++ b/go.sum @@ -1,3 +1,7 @@ +dario.cat/mergo v1.0.0 h1:AGCNq9Evsj31mOgNPcLyXc+4PNABt905YmuqPYYpBWk= +dario.cat/mergo v1.0.0/go.mod h1:uNxQE+84aUszobStD9th8a29P2fMDhsBdgRYvZOxGmk= +github.com/Jeffail/gabs/v2 v2.7.0 h1:Y2edYaTcE8ZpRsR2AtmPu5xQdFDIthFG0jYhu5PY8kg= +github.com/Jeffail/gabs/v2 v2.7.0/go.mod h1:dp5ocw1FvBBQYssgHsG7I1WYsiLRtkUaB1FEtSwvNUw= github.com/VictorLowther/simplexml v0.0.0-20180716164440-0bff93621230 h1:t95Grn2mOPfb3+kPDWsNnj4dlNcxnvuR72IjY8eYjfQ= github.com/VictorLowther/simplexml v0.0.0-20180716164440-0bff93621230/go.mod h1:t2EzW1qybnPDQ3LR/GgeF0GOzHUXT5IVMLP2gkW1cmc= github.com/VictorLowther/soap v0.0.0-20150314151524-8e36fca84b22 h1:a0MBqYm44o0NcthLKCljZHe1mxlN6oahCQHHThnSwB4= @@ -6,13 +10,19 @@ github.com/bmc-toolbox/common v0.0.0-20230220061748-93ff001f4a1d h1:cQ30Wa8mhLzK github.com/bmc-toolbox/common v0.0.0-20230220061748-93ff001f4a1d/go.mod h1:SY//n1PJjZfbFbmAsB6GvEKbc7UXz3d30s3kWxfJQ/c= github.com/bombsimon/logrusr/v2 v2.0.1 h1:1VgxVNQMCvjirZIYaT9JYn6sAVGVEcNtRE0y4mvaOAM= github.com/bombsimon/logrusr/v2 v2.0.1/go.mod h1:ByVAX+vHdLGAfdroiMg6q0zgq2FODY2lc5YJvzmOJio= +github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc= github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/ghodss/yaml v1.0.0 h1:wQHKEahhL6wmXdzwWG11gIVCkOv05bNOh+Rxn0yngAk= +github.com/ghodss/yaml v1.0.0/go.mod h1:4dBDuWmgqj2HViK6kFavaiC9ZROes6MMH2rRYeMEF04= github.com/go-logr/logr v1.0.0/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A= github.com/go-logr/logr v1.2.4 h1:g01GSCwiDw2xSZfjJ2/T9M+S6pFdcNtFYsp+Y43HYDQ= github.com/go-logr/logr v1.2.4/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A= +github.com/go-logr/zerologr v1.2.3 h1:up5N9vcH9Xck3jJkXzgyOxozT14R47IyDODz8LM1KSs= +github.com/go-logr/zerologr v1.2.3/go.mod h1:BxwGo7y5zgSHYR1BjbnHPyF/5ZjVKfKxAZANVu6E8Ho= +github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA= github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38= github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= github.com/hashicorp/errwrap v1.0.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4= @@ -31,10 +41,17 @@ github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= +github.com/mattn/go-colorable v0.1.12 h1:jF+Du6AlPIjs2BiUiQlKOX0rt3SujHxPnksPKZbaA40= +github.com/mattn/go-colorable v0.1.12/go.mod h1:u5H1YNBxpqRaxsYJYSkiCWKzEfiAb1Gb520KVy5xxl4= +github.com/mattn/go-isatty v0.0.14 h1:yVuAays6BHfxijgZPzw+3Zlu5yQgKGP2/hcQbHb7S9Y= +github.com/mattn/go-isatty v0.0.14/go.mod h1:7GGIvUiUoEMVVmxf/4nioHXj79iQHKdU27kJ6hsGG94= github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/rs/xid v1.5.0/go.mod h1:trrq9SKmegXys3aeAKXMUTdJsYXVwGY3RLcfgqegfbg= +github.com/rs/zerolog v1.30.0 h1:SymVODrcRsaRaSInD9yQtKbtWqwsfoPcRff/oRXLj4c= +github.com/rs/zerolog v1.30.0/go.mod h1:/tk+P47gFdPXq4QYjvCmT5/Gsug2nagsFWBWhAiSi1w= github.com/satori/go.uuid v1.2.0 h1:0uYX9dsZ2yD7q2RtLRtPSdGDWzjeM3TbMJP9utgA0ww= github.com/satori/go.uuid v1.2.0/go.mod h1:dA0hQrYB0VpLJoorglMZABFdXlWrHn1NEOzdhQKdks0= github.com/sirupsen/logrus v1.8.1 h1:dJKuHgqk1NNQlqoA6BTlM1Wf9DOH3NBjQyu0h9+AZZE= @@ -59,6 +76,8 @@ golang.org/x/net v0.1.0 h1:hZ/3BUoy5aId7sCpA/Tc5lt8DkFgdVS2onTpJsZ/fl0= golang.org/x/net v0.1.0/go.mod h1:Cx3nUiGt4eDBEyega/BKRp+/AlGL8hYe7U9odMt2Cco= golang.org/x/sys v0.0.0-20191026070338-33540a1f6037/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210608053332-aa57babbf139/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20210927094055-39ccf1dd6fa6/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.1.0 h1:kunALQeHf1/185U1i0GOB/fy1IPRDDpuoOOqRReG57U= golang.org/x/sys v0.1.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/term v0.1.0 h1:g6Z6vPFA9dYBAF7DWcH6sCcOntplXsDKcliusYijMlw= @@ -67,6 +86,7 @@ gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntN gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= gopkg.in/go-playground/assert.v1 v1.2.1 h1:xoYuJVE7KT85PYWrN730RguIQO0ePzVRfFMXadIrXTM= gopkg.in/go-playground/assert.v1 v1.2.1/go.mod h1:9RXL0bg/zibRAgZUYszZSwO/z8Y/a8bDuhia5mkpMnE= +gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY= gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/lint.mk b/lint.mk index 1bac7a79..d492cc70 100644 --- a/lint.mk +++ b/lint.mk @@ -20,7 +20,7 @@ LINTERS := FIXERS := GOLANGCI_LINT_CONFIG := $(LINT_ROOT)/.golangci.yml -GOLANGCI_LINT_VERSION ?= v1.51.2 +GOLANGCI_LINT_VERSION ?= v1.53.3 GOLANGCI_LINT_BIN := $(LINT_ROOT)/out/linters/golangci-lint-$(GOLANGCI_LINT_VERSION)-$(LINT_ARCH) $(GOLANGCI_LINT_BIN): mkdir -p $(LINT_ROOT)/out/linters diff --git a/logging/logging.go b/logging/logging.go index 47289d7f..9edb4728 100644 --- a/logging/logging.go +++ b/logging/logging.go @@ -5,6 +5,8 @@ import ( "github.com/bombsimon/logrusr/v2" "github.com/go-logr/logr" + "github.com/go-logr/zerologr" + "github.com/rs/zerolog" "github.com/sirupsen/logrus" ) @@ -26,3 +28,20 @@ func DefaultLogger() logr.Logger { return logrusr.New(logrusLog) } + +// ZeroLogger is a logr.Logger implementation that uses zerolog. +// This logger handles nested structs better than the logrus implementation. +func ZeroLogger(level string) logr.Logger { + zl := zerolog.New(os.Stdout) + zl = zl.With().Caller().Timestamp().Logger() + var l zerolog.Level + switch level { + case "debug": + l = zerolog.DebugLevel + default: + l = zerolog.InfoLevel + } + zl = zl.Level(l) + + return zerologr.New(&zl) +} diff --git a/option.go b/option.go index abdec1fb..f79cab5f 100644 --- a/option.go +++ b/option.go @@ -7,6 +7,7 @@ import ( "time" "github.com/bmc-toolbox/bmclib/v2/internal/httpclient" + "github.com/bmc-toolbox/bmclib/v2/providers/rpc" "github.com/go-logr/logr" "github.com/jacobweinstock/registrar" ) @@ -137,3 +138,9 @@ func WithDellRedfishUseBasicAuth(useBasicAuth bool) Option { args.providerConfig.dell.UseBasicAuth = useBasicAuth } } + +func WithRPCOpt(opt rpc.Provider) Option { + return func(args *Client) { + args.providerConfig.rpc = opt + } +} diff --git a/providers/rpc/doc.go b/providers/rpc/doc.go new file mode 100644 index 00000000..52e56886 --- /dev/null +++ b/providers/rpc/doc.go @@ -0,0 +1,12 @@ +/* +Package rpc is a provider that defines an HTTP request/response contract for handling BMC interactions. +It allows users a simple way to interoperate with an existing/bespoke out-of-band management solution. + +The rpc provider request/response payloads are modeled after JSON-RPC 2.0, but are not JSON-RPC 2.0 +compliant so as to allow for more flexibility and interoperability with existing systems. + +The rpc provider has options that can be set to include an HMAC signature in the request header. +It follows the features found at https://webhooks.fyi/security/hmac, this includes hash algorithms sha256 +and sha512, replay prevention, versioning, and key rotation. +*/ +package rpc diff --git a/providers/rpc/experimental.go b/providers/rpc/experimental.go new file mode 100644 index 00000000..0464a281 --- /dev/null +++ b/providers/rpc/experimental.go @@ -0,0 +1,26 @@ +package rpc + +import ( + "github.com/Jeffail/gabs/v2" + "github.com/ghodss/yaml" +) + +// embedPayload will embed the RequestPayload into the given JSON object at the dot path notation location ("object.data"). +func (p *RequestPayload) embedPayload(rawJSON []byte, dotPath string) ([]byte, error) { + if rawJSON == nil { + return rawJSON, nil + } + jdata2, err := yaml.YAMLToJSON(rawJSON) + if err != nil { + return nil, err + } + g, err := gabs.ParseJSON(jdata2) + if err != nil { + return nil, err + } + if _, err := g.SetP(p, dotPath); err != nil { + return nil, err + } + + return g.Bytes(), nil +} diff --git a/providers/rpc/http.go b/providers/rpc/http.go new file mode 100644 index 00000000..bde373ab --- /dev/null +++ b/providers/rpc/http.go @@ -0,0 +1,69 @@ +package rpc + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "net/http" + "strings" + "time" +) + +// createRequest +func (p *Provider) createRequest(ctx context.Context, rp RequestPayload) (*http.Request, error) { + var data []byte + if rj := p.Opts.Experimental.CustomRequestPayload; rj != nil && p.Opts.Experimental.DotPath != "" { + d, err := rp.embedPayload(rj, p.Opts.Experimental.DotPath) + if err != nil { + return nil, err + } + data = d + } else { + d, err := json.Marshal(rp) + if err != nil { + return nil, err + } + data = d + } + + req, err := http.NewRequestWithContext(ctx, p.Opts.Request.HTTPMethod, p.listenerURL.String(), bytes.NewReader(data)) + if err != nil { + return nil, err + } + for k, v := range p.Opts.Request.StaticHeaders { + req.Header.Add(k, strings.Join(v, ",")) + } + if p.Opts.Request.HTTPContentType != "" { + req.Header.Set("Content-Type", p.Opts.Request.HTTPContentType) + } + if p.Opts.Request.TimestampHeader != "" { + req.Header.Add(p.Opts.Request.TimestampHeader, time.Now().Format(p.Opts.Request.TimestampFormat)) + } + + return req, nil +} + +func (p *Provider) handleResponse(statusCode int, headers http.Header, body *bytes.Buffer, reqKeysAndValues []any) (ResponsePayload, error) { + kvs := reqKeysAndValues + defer func() { + if !p.LogNotificationsDisabled { + kvs = append(kvs, responseKVS(statusCode, headers, body)...) + p.Logger.Info("rpc notification details", kvs...) + } + }() + + res := ResponsePayload{} + if err := json.Unmarshal(body.Bytes(), &res); err != nil { + if statusCode != http.StatusOK { + return ResponsePayload{}, fmt.Errorf("unexpected status code: %d, response error(optional): %v", statusCode, res.Error) + } + example, _ := json.Marshal(ResponsePayload{ID: 123, Host: p.Host, Error: &ResponseError{Code: 1, Message: "error message"}}) + return ResponsePayload{}, fmt.Errorf("failed to parse response: got: %q, error: %w, expected response json spec: %v", body.String(), err, string(example)) + } + if statusCode != http.StatusOK { + return ResponsePayload{}, fmt.Errorf("unexpected status code: %d, response error(optional): %v", statusCode, res.Error) + } + + return res, nil +} diff --git a/providers/rpc/http_test.go b/providers/rpc/http_test.go new file mode 100644 index 00000000..c211ff62 --- /dev/null +++ b/providers/rpc/http_test.go @@ -0,0 +1,168 @@ +package rpc + +import ( + "bytes" + "context" + "encoding/json" + "io" + "net/http" + "net/url" + "strings" + "testing" + + "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" +) + +func testRequest(method, reqURL string, body RequestPayload, headers http.Header) *http.Request { + buf := new(bytes.Buffer) + _ = json.NewEncoder(buf).Encode(body) + req, _ := http.NewRequestWithContext(context.Background(), method, reqURL, buf) + req.Header = headers + return req +} + +func TestRequestKVS(t *testing.T) { + tests := map[string]struct { + req *http.Request + expected []interface{} + }{ + "success": { + req: testRequest( + http.MethodPost, "http://example.com", + RequestPayload{ID: 1, Host: "127.0.0.1", Method: "POST", Params: nil}, + http.Header{"Content-Type": []string{"application/json"}}, + ), + expected: []interface{}{"request", requestDetails{ + Body: RequestPayload{ + ID: 1, + Host: "127.0.0.1", + Method: "POST", + Params: nil, + }, + Headers: http.Header{"Content-Type": {"application/json"}}, + URL: "http://example.com", + Method: "POST", + }}, + }, + } + for name, tc := range tests { + t.Run(name, func(t *testing.T) { + kvs := requestKVS(tc.req) + if diff := cmp.Diff(kvs, tc.expected); diff != "" { + t.Fatalf("requestKVS() mismatch (-want +got):\n%s", diff) + } + }) + } +} + +func TestRequestKVSOneOffs(t *testing.T) { + t.Run("nil body", func(t *testing.T) { + req, _ := http.NewRequestWithContext(context.Background(), http.MethodPost, "http://example.com", nil) + got := requestKVS(req) + if diff := cmp.Diff(got, []interface{}{"request", requestDetails{}}); diff != "" { + t.Logf("got: %+v", got) + t.Fatalf("requestKVS(req) mismatch (-want +got):\n%s", diff) + } + }) + t.Run("nil request", func(t *testing.T) { + if diff := cmp.Diff(requestKVS(nil), []interface{}{"request", requestDetails{}}); diff != "" { + t.Fatalf("requestKVS(nil) mismatch (-want +got):\n%s", diff) + } + }) + + t.Run("failed to unmarshal body", func(t *testing.T) { + req, _ := http.NewRequestWithContext(context.Background(), http.MethodPost, "http://example.com", bytes.NewBufferString("invalid")) + got := requestKVS(req) + if diff := cmp.Diff(got, []interface{}{"request", requestDetails{URL: "http://example.com", Method: http.MethodPost, Headers: http.Header{}}}); diff != "" { + t.Logf("got: %+v", got) + t.Fatalf("requestKVS(req) mismatch (-want +got):\n%s", diff) + } + }) +} + +func TestResponseKVS(t *testing.T) { + tests := map[string]struct { + resp *http.Response + expected []interface{} + }{ + "one": { + resp: &http.Response{StatusCode: 200, Header: http.Header{"Content-Type": []string{"application/json"}}, Body: io.NopCloser(strings.NewReader(`{"id":1,"host":"127.0.0.1"}`))}, + expected: []interface{}{"response", responseDetails{ + StatusCode: 200, + Headers: http.Header{"Content-Type": {"application/json"}}, + Body: ResponsePayload{ID: 1, Host: "127.0.0.1"}, + }}, + }, + } + for name, tc := range tests { + t.Run(name, func(t *testing.T) { + buf := new(bytes.Buffer) + _, _ = io.Copy(buf, tc.resp.Body) + kvs := responseKVS(tc.resp.StatusCode, tc.resp.Header, buf) + if diff := cmp.Diff(kvs, tc.expected); diff != "" { + t.Fatalf("requestKVS() mismatch (-want +got):\n%s", diff) + } + }) + } +} + +func TestCreateRequest(t *testing.T) { + tests := map[string]struct { + cfg Provider + body RequestPayload + expected *http.Request + }{ + "success": { + cfg: Provider{ + Opts: Opts{ + Request: RequestOpts{ + HTTPMethod: http.MethodPost, + HTTPContentType: "application/json", + StaticHeaders: http.Header{"X-Test": []string{"test"}}, + }, + }, + listenerURL: &url.URL{Scheme: "http", Host: "example.com"}, + }, + body: RequestPayload{ID: 1, Host: "127.0.0.1", Method: PowerSetMethod}, + expected: &http.Request{ + ContentLength: 52, + Host: "example.com", + Proto: "HTTP/1.1", + ProtoMajor: 1, + ProtoMinor: 1, + Method: http.MethodPost, + URL: &url.URL{Scheme: "http", Host: "example.com"}, + Header: http.Header{"X-Test": {"test"}, "Content-Type": {"application/json"}}, + }, + }, + } + for name, tc := range tests { + t.Run(name, func(t *testing.T) { + data, _ := json.Marshal(tc.body) + body := bytes.NewReader(data) + tc.expected.Body = io.NopCloser(body) + req, err := tc.cfg.createRequest(context.Background(), tc.body) + if err != nil { + t.Fatal(err) + } + + if diff := cmp.Diff(req, tc.expected, cmpopts.IgnoreUnexported(http.Request{}, bytes.Reader{}), cmpopts.IgnoreFields(http.Request{}, "GetBody")); diff != "" { + t.Fatalf("createRequest() mismatch (-got +want):\n%s", diff) + } + }) + } +} + +func TestContentSize(t *testing.T) { + prov := New("http://127.0.0.1/rpc", "127.0.2.1", Secrets{SHA256: {"superSecret1"}}) + _ = prov.Open(context.Background()) + reqPayload := RequestPayload{ID: 1, Host: "127.0.0.1", Method: PowerGetMethod} + req, err := prov.createRequest(context.Background(), reqPayload) + if err != nil { + t.Fatal(err) + } + if req.ContentLength > maxContentLenAllowed { + t.Fatalf("unexpected content length: got: %d, want: %v", req.ContentLength, maxContentLenAllowed) + } +} diff --git a/providers/rpc/logging.go b/providers/rpc/logging.go new file mode 100644 index 00000000..5e689c68 --- /dev/null +++ b/providers/rpc/logging.go @@ -0,0 +1,58 @@ +package rpc + +import ( + "bytes" + "encoding/json" + "io" + "net/http" +) + +type requestDetails struct { + Body RequestPayload `json:"body"` + Headers http.Header `json:"headers"` + URL string `json:"url"` + Method string `json:"method"` +} + +type responseDetails struct { + StatusCode int `json:"statusCode"` + Body ResponsePayload `json:"body"` + Headers http.Header `json:"headers"` +} + +// requestKVS returns a slice of key, value sets. Used for logging. +func requestKVS(req *http.Request) []interface{} { + var r requestDetails + if req != nil && req.Body != nil { + var p RequestPayload + reqBody, err := io.ReadAll(req.Body) + if err == nil { + req.Body = io.NopCloser(bytes.NewBuffer(reqBody)) + _ = json.Unmarshal(reqBody, &p) + } + r = requestDetails{ + Body: p, + Headers: req.Header, + URL: req.URL.String(), + Method: req.Method, + } + } + + return []interface{}{"request", r} +} + +// responseKVS returns a slice of key, value sets. Used for logging. +func responseKVS(statusCode int, headers http.Header, body *bytes.Buffer) []interface{} { + var r responseDetails + if body.Len() > 0 { + var p ResponsePayload + _ = json.Unmarshal(body.Bytes(), &p) + r = responseDetails{ + StatusCode: statusCode, + Body: p, + Headers: headers, + } + } + + return []interface{}{"response", r} +} diff --git a/providers/rpc/payload.go b/providers/rpc/payload.go new file mode 100644 index 00000000..560dc2fe --- /dev/null +++ b/providers/rpc/payload.go @@ -0,0 +1,69 @@ +package rpc + +import "fmt" + +type Method string + +const ( + BootDeviceMethod Method = "setBootDevice" + PowerSetMethod Method = "setPowerState" + PowerGetMethod Method = "getPowerState" + VirtualMediaMethod Method = "setVirtualMedia" +) + +// RequestPayload is the payload sent to the ConsumerURL. +type RequestPayload struct { + ID int64 `json:"id"` + Host string `json:"host"` + Method Method `json:"method"` + Params any `json:"params,omitempty"` +} + +// BootDeviceParams are the parameters options used when setting a boot device. +type BootDeviceParams struct { + Device string `json:"device"` + Persistent bool `json:"persistent"` + EFIBoot bool `json:"efiBoot"` +} + +// PowerSetParams are the parameters options used when setting the power state. +type PowerSetParams struct { + State string `json:"state"` +} + +// PowerGetParams are the parameters options used when getting the power state. +type VirtualMediaParams struct { + MediaURL string `json:"mediaUrl"` + Kind string `json:"kind"` +} + +// ResponsePayload is the payload received from the ConsumerURL. +// The Result field is an interface{} so that different methods +// can define the contract according to their needs. +type ResponsePayload struct { + // ID is the ID of the response. It should match the ID of the request but is not enforced. + ID int64 `json:"id"` + Host string `json:"host"` + Result any `json:"result,omitempty"` + Error *ResponseError `json:"error,omitempty"` +} + +type ResponseError struct { + Code int `json:"code"` + Message string `json:"message"` +} + +type PowerGetResult string + +const ( + PoweredOn PowerGetResult = "on" + PoweredOff PowerGetResult = "off" +) + +func (p PowerGetResult) String() string { + return string(p) +} + +func (r *ResponseError) String() string { + return fmt.Sprintf("code: %v, message: %v", r.Code, r.Message) +} diff --git a/providers/rpc/rpc.go b/providers/rpc/rpc.go new file mode 100644 index 00000000..4651caaa --- /dev/null +++ b/providers/rpc/rpc.go @@ -0,0 +1,369 @@ +package rpc + +import ( + "bytes" + "context" + "encoding/json" + "errors" + "fmt" + "hash" + "io" + "net/http" + "net/url" + "reflect" + "strings" + "time" + + "github.com/bmc-toolbox/bmclib/v2/providers" + "github.com/go-logr/logr" + "github.com/jacobweinstock/registrar" +) + +const ( + // ProviderName for the RPC implementation. + ProviderName = "rpc" + // ProviderProtocol for the rpc implementation. + ProviderProtocol = "http" + + // defaults + timestampHeader = "X-BMCLIB-Timestamp" + signatureHeader = "X-BMCLIB-Signature" + contentType = "application/json" + maxContentLenAllowed = 512 << (10 * 1) // 512KB + + // SHA256 is the SHA256 algorithm. + SHA256 Algorithm = "sha256" + // SHA256Short is the short version of the SHA256 algorithm. + SHA256Short Algorithm = "256" + // SHA512 is the SHA512 algorithm. + SHA512 Algorithm = "sha512" + // SHA512Short is the short version of the SHA512 algorithm. + SHA512Short Algorithm = "512" +) + +// Features implemented by the RPC provider. +var Features = registrar.Features{ + providers.FeaturePowerSet, + providers.FeaturePowerState, + providers.FeatureBootDeviceSet, +} + +// Algorithm is the type for HMAC algorithms. +type Algorithm string + +// Secrets hold per algorithm slice secrets. +// These secrets will be used to create HMAC signatures. +type Secrets map[Algorithm][]string + +// Signatures hold per algorithm slice of signatures. +type Signatures map[Algorithm][]string + +// Provider defines the configuration for sending rpc notifications. +type Provider struct { + // ConsumerURL is the URL where an rpc consumer/listener is running + // and to which we will send and receive all notifications. + ConsumerURL string + // Host is the BMC ip address or hostname or identifier. + Host string + // Client is the http client used for all HTTP calls. + Client *http.Client + // Logger is the logger to use for logging. + Logger logr.Logger + // LogNotificationsDisabled determines whether responses from rpc consumer/listeners will be logged or not. + LogNotificationsDisabled bool + // Opts are the options for the rpc provider. + Opts Opts + + // listenerURL is the URL of the rpc consumer/listener. + listenerURL *url.URL +} + +type Opts struct { + // Request is the options used to create the rpc HTTP request. + Request RequestOpts + // Signature is the options used for adding an HMAC signature to an HTTP request. + Signature SignatureOpts + // HMAC is the options used to create a HMAC signature. + HMAC HMACOpts + // Experimental options. + Experimental Experimental +} + +type RequestOpts struct { + // HTTPContentType is the content type to use for the rpc request notification. + HTTPContentType string + // HTTPMethod is the HTTP method to use for the rpc request notification. + HTTPMethod string + // StaticHeaders are predefined headers that will be added to every request. + StaticHeaders http.Header + // TimestampFormat is the time format for the timestamp header. + TimestampFormat string + // TimestampHeader is the header name that should contain the timestamp. Example: X-BMCLIB-Timestamp + TimestampHeader string +} + +type SignatureOpts struct { + // HeaderName is the header name that should contain the signature(s). Example: X-BMCLIB-Signature + HeaderName string + // AppendAlgoToHeaderDisabled decides whether to append the algorithm to the signature header or not. + // Example: X-BMCLIB-Signature becomes X-BMCLIB-Signature-256 + // When set to true, a header will be added for each algorithm. Example: X-BMCLIB-Signature-256 and X-BMCLIB-Signature-512 + AppendAlgoToHeaderDisabled bool + // IncludedPayloadHeaders are headers whose values will be included in the signature payload. Example: given these headers in a request: + // X-My-Header=123,X-Another=456, and IncludedPayloadHeaders := []string{"X-Another"}, the value of "X-Another" will be included in the signature payload. + // All headers will be deduplicated. + IncludedPayloadHeaders []string +} + +type HMACOpts struct { + // Hashes is a map of algorithms to a slice of hash.Hash (these are the hashed secrets). The slice is used to support multiple secrets. + Hashes map[Algorithm][]hash.Hash + // PrefixSigDisabled determines whether the algorithm will be prefixed to the signature. Example: sha256=abc123 + PrefixSigDisabled bool + // Secrets are a map of algorithms to secrets used for signing. + Secrets Secrets +} + +type Experimental struct { + // CustomRequestPayload must be in json. + CustomRequestPayload []byte + // DotPath is the path to where the bmclib RequestPayload{} will be embedded. For example: object.data.body + DotPath string +} + +// New returns a new Config containing all the defaults for the rpc provider. +func New(consumerURL string, host string, secrets Secrets) *Provider { + // defaults + c := &Provider{ + Host: host, + ConsumerURL: consumerURL, + Client: http.DefaultClient, + Logger: logr.Discard(), + Opts: Opts{ + Request: RequestOpts{ + HTTPContentType: contentType, + HTTPMethod: http.MethodPost, + TimestampFormat: time.RFC3339, + TimestampHeader: timestampHeader, + }, + Signature: SignatureOpts{ + HeaderName: signatureHeader, + IncludedPayloadHeaders: []string{}, + }, + HMAC: HMACOpts{ + Hashes: map[Algorithm][]hash.Hash{}, + Secrets: secrets, + }, + Experimental: Experimental{}, + }, + } + + if len(secrets) > 0 { + c.Opts.HMAC.Hashes = CreateHashes(secrets) + } + + return c +} + +// Name returns the name of this rpc provider. +// Implements bmc.Provider interface +func (p *Provider) Name() string { + return ProviderName +} + +// Open a connection to the rpc consumer. +// For the rpc provider, Open means validating the Config and +// that communication with the rpc consumer can be established. +func (p *Provider) Open(ctx context.Context) error { + // 1. validate consumerURL is a properly formatted URL. + // 2. validate that we can communicate with the rpc consumer. + u, err := url.Parse(p.ConsumerURL) + if err != nil { + return err + } + p.listenerURL = u + buf := new(bytes.Buffer) + _ = json.NewEncoder(buf).Encode(RequestPayload{}) + testReq, err := http.NewRequestWithContext(ctx, p.Opts.Request.HTTPMethod, p.listenerURL.String(), buf) + if err != nil { + return err + } + // test that we can communicate with the rpc consumer. + resp, err := p.Client.Do(testReq) + if err != nil { + return err + } + if resp.StatusCode >= http.StatusInternalServerError { + return fmt.Errorf("issue on the rpc consumer side, status code: %d", resp.StatusCode) + } + + // test that the consumer responses with the expected contract (ResponsePayload{}). + if err := json.NewDecoder(resp.Body).Decode(&ResponsePayload{}); err != nil { + return fmt.Errorf("issue with the rpc consumer response: %v", err) + } + + return resp.Body.Close() +} + +// Close a connection to the rpc consumer. +func (p *Provider) Close(_ context.Context) (err error) { + return nil +} + +// BootDeviceSet sends a next boot device rpc notification. +func (p *Provider) BootDeviceSet(ctx context.Context, bootDevice string, setPersistent, efiBoot bool) (ok bool, err error) { + rp := RequestPayload{ + ID: time.Now().UnixNano(), + Host: p.Host, + Method: BootDeviceMethod, + Params: BootDeviceParams{ + Device: bootDevice, + Persistent: setPersistent, + EFIBoot: efiBoot, + }, + } + resp, err := p.process(ctx, rp) + if err != nil { + return false, err + } + if resp.Error != nil && resp.Error.Code != 0 { + return false, fmt.Errorf("error from rpc consumer: %v", resp.Error) + } + + return true, nil +} + +// PowerSet sets the power state of a BMC machine. +func (p *Provider) PowerSet(ctx context.Context, state string) (ok bool, err error) { + switch strings.ToLower(state) { + case "on", "off", "cycle": + rp := RequestPayload{ + ID: time.Now().UnixNano(), + Host: p.Host, + Method: PowerSetMethod, + Params: PowerSetParams{ + State: strings.ToLower(state), + }, + } + resp, err := p.process(ctx, rp) + if err != nil { + return ok, err + } + if resp.Error != nil && resp.Error.Code != 0 { + return ok, fmt.Errorf("error from rpc consumer: %v", resp.Error) + } + + return true, nil + } + + return false, errors.New("requested power state is not supported") +} + +// PowerStateGet gets the power state of a BMC machine. +func (p *Provider) PowerStateGet(ctx context.Context) (state string, err error) { + rp := RequestPayload{ + ID: time.Now().UnixNano(), + Host: p.Host, + Method: PowerGetMethod, + } + resp, err := p.process(ctx, rp) + if err != nil { + return "", err + } + if resp.Error != nil && resp.Error.Code != 0 { + return "", fmt.Errorf("error from rpc consumer: %v", resp.Error) + } + + return resp.Result.(string), nil +} + +// process is the main function for the roundtrip of rpc calls to the ConsumerURL. +func (p *Provider) process(ctx context.Context, rp RequestPayload) (ResponsePayload, error) { + // 1. create the HTTP request. + // 2. create the signature payload. + // 3. sign the signature payload. + // 4. add signatures to the request as headers. + // 5. request/response round trip. + // 6. handle the response. + req, err := p.createRequest(ctx, rp) + if err != nil { + return ResponsePayload{}, err + } + + // create the signature payload + reqBuf := new(bytes.Buffer) + if _, err := io.Copy(reqBuf, req.Body); err != nil { + return ResponsePayload{}, fmt.Errorf("failed to read request body: %w", err) + } + headersForSig := http.Header{} + for _, h := range p.Opts.Signature.IncludedPayloadHeaders { + if val := req.Header.Get(h); val != "" { + headersForSig.Add(h, val) + } + } + sigPay := createSignaturePayload(reqBuf.Bytes(), headersForSig) + + // sign the signature payload + sigs, err := sign(sigPay, p.Opts.HMAC.Hashes, p.Opts.HMAC.PrefixSigDisabled) + if err != nil { + return ResponsePayload{}, err + } + + // add signatures to the request as headers. + for algo, keys := range sigs { + if len(sigs) > 0 { + h := p.Opts.Signature.HeaderName + if !p.Opts.Signature.AppendAlgoToHeaderDisabled { + h = fmt.Sprintf("%s-%s", h, algo.ToShort()) + } + req.Header.Add(h, strings.Join(keys, ",")) + } + } + + // request/response round trip. + kvs := requestKVS(req) + kvs = append(kvs, []interface{}{"host", p.Host, "method", rp.Method, "consumerURL", p.ConsumerURL}...) + if rp.Params != nil { + kvs = append(kvs, []interface{}{"params", rp.Params}...) + } + + resp, err := p.Client.Do(req) + if err != nil { + p.Logger.Error(err, "failed to send rpc notification", kvs...) + return ResponsePayload{}, err + } + defer resp.Body.Close() + + // handle the response + if resp.ContentLength > maxContentLenAllowed || resp.ContentLength < 0 { + return ResponsePayload{}, fmt.Errorf("response body is too large: %d bytes, max allowed: %d bytes", resp.ContentLength, maxContentLenAllowed) + } + respBuf := new(bytes.Buffer) + if _, err := io.CopyN(respBuf, resp.Body, resp.ContentLength); err != nil { + return ResponsePayload{}, fmt.Errorf("failed to read response body: %w", err) + } + respPayload, err := p.handleResponse(resp.StatusCode, resp.Header, respBuf, kvs) + if err != nil { + return ResponsePayload{}, err + } + + return respPayload, nil +} + +// Transformer implements the mergo interfaces for merging custom types. +func (p *Provider) Transformer(typ reflect.Type) func(dst, src reflect.Value) error { + switch typ { + case reflect.TypeOf(logr.Logger{}): + return func(dst, src reflect.Value) error { + if dst.CanSet() { + isZero := dst.MethodByName("GetSink") + result := isZero.Call(nil) + if result[0].IsNil() { + dst.Set(src) + } + } + return nil + } + } + return nil +} diff --git a/providers/rpc/rpc_test.go b/providers/rpc/rpc_test.go new file mode 100644 index 00000000..5aecf053 --- /dev/null +++ b/providers/rpc/rpc_test.go @@ -0,0 +1,204 @@ +package rpc + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + + "github.com/google/go-cmp/cmp" +) + +func TestOpen(t *testing.T) { + tests := map[string]struct { + url string + shouldErr bool + }{ + "success": {}, + "bad url": {url: "%", shouldErr: true}, + "failed request": {url: "127.1.1.1", shouldErr: true}, + } + + for name, tc := range tests { + t.Run(name, func(t *testing.T) { + svr := testConsumer{rp: ResponsePayload{}}.testServer() + defer svr.Close() + + u := svr.URL + if tc.url != "" { + u = tc.url + } + c := New(u, "127.0.1.1", Secrets{SHA256: []string{"superSecret1"}}) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + if err := c.Open(ctx); err != nil && !tc.shouldErr { + t.Fatal(err) + } + c.Close(ctx) + }) + } +} + +func TestBootDeviceSet(t *testing.T) { + tests := map[string]struct { + url string + shouldErr bool + }{ + "success": {}, + "failure from consumer": {shouldErr: true}, + "failed request": {url: "127.1.1.1", shouldErr: true}, + } + + for name, tc := range tests { + t.Run(name, func(t *testing.T) { + rsp := testConsumer{ + rp: ResponsePayload{}, + } + if tc.shouldErr { + rsp.rp.Error = &ResponseError{Code: 500, Message: "failed"} + } + svr := rsp.testServer() + defer svr.Close() + + u := svr.URL + if tc.url != "" { + u = tc.url + } + c := New(u, "127.0.1.1", Secrets{SHA256: {"superSecret1"}}) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + _ = c.Open(ctx) + if _, err := c.BootDeviceSet(ctx, "pxe", false, false); err != nil && !tc.shouldErr { + t.Fatal(err) + } else if err == nil && tc.shouldErr { + t.Fatal("expected error, got none") + } + }) + } +} + +func TestPowerSet(t *testing.T) { + tests := map[string]struct { + url string + powerState string + shouldErr bool + }{ + "success": {powerState: "on"}, + "failed request": {powerState: "on", url: "127.1.1.1", shouldErr: true}, + "unknown state": {powerState: "unknown", shouldErr: true}, + "failure from consumer": {powerState: "on", shouldErr: true}, + } + + for name, tc := range tests { + t.Run(name, func(t *testing.T) { + rsp := testConsumer{ + rp: ResponsePayload{Result: tc.powerState}, + } + if tc.shouldErr { + rsp.rp.Error = &ResponseError{Code: 500, Message: "failed"} + } + svr := rsp.testServer() + defer svr.Close() + + u := svr.URL + if tc.url != "" { + u = tc.url + } + c := New(u, "127.0.1.1", Secrets{SHA256: {"superSecret1"}}) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + _ = c.Open(ctx) + + _, err := c.PowerSet(ctx, tc.powerState) + if err != nil && !tc.shouldErr { + t.Fatal(err) + } + }) + } +} + +func TestPowerStateGet(t *testing.T) { + tests := map[string]struct { + powerState string + shouldErr bool + url string + }{ + "success": {}, + "unknown state": {shouldErr: true}, + } + + for name, tc := range tests { + t.Run(name, func(t *testing.T) { + rsp := testConsumer{ + rp: ResponsePayload{Result: tc.powerState}, + } + if tc.shouldErr { + rsp.rp.Error = &ResponseError{Code: 500, Message: "failed"} + } + svr := rsp.testServer() + defer svr.Close() + + u := svr.URL + if tc.url != "" { + u = tc.url + } + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + c := New(u, "127.0.1.1", Secrets{SHA256: {"superSecret1"}}) + _ = c.Open(ctx) + gotState, err := c.PowerStateGet(ctx) + if err != nil && !tc.shouldErr { + t.Fatal(err) + } + if diff := cmp.Diff(gotState, tc.powerState); diff != "" { + t.Fatal(diff) + } + }) + } +} + +func TestServerErrors(t *testing.T) { + tests := map[string]struct { + statusCode int + shouldErr bool + }{ + "bad request": {statusCode: http.StatusBadRequest, shouldErr: true}, + "not found": {statusCode: http.StatusNotFound, shouldErr: true}, + "internal": {statusCode: http.StatusInternalServerError, shouldErr: true}, + } + + for name, tc := range tests { + t.Run(name, func(t *testing.T) { + rsp := testConsumer{ + rp: ResponsePayload{Result: "on"}, + statusCode: tc.statusCode, + } + svr := rsp.testServer() + defer svr.Close() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + c := New(svr.URL, "127.0.0.1", Secrets{SHA256: {"superSecret1"}}) + if err := c.Open(ctx); err == nil { + t.Fatal("expected error, got none") + } + }) + } +} + +type testConsumer struct { + rp ResponsePayload + statusCode int +} + +func (t testConsumer) testServer() *httptest.Server { + return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if t.statusCode != 0 { + w.WriteHeader(t.statusCode) + return + } + b, _ := json.Marshal(t.rp) + _, _ = w.Write(b) + })) +} diff --git a/providers/rpc/signature.go b/providers/rpc/signature.go new file mode 100644 index 00000000..724187ee --- /dev/null +++ b/providers/rpc/signature.go @@ -0,0 +1,102 @@ +package rpc + +import ( + "crypto/hmac" + "crypto/sha256" + "crypto/sha512" + "encoding/hex" + "fmt" + "hash" + "net/http" + "strings" +) + +type Hashes map[Algorithm][]hash.Hash + +// createSignaturePayload a signature payload is created by appending header values to the request body. +// there is no delimiter between the body and the header values and all header values. +func createSignaturePayload(body []byte, h http.Header) []byte { + // add headers to signature payload, no space between values. + for _, val := range h { + body = append(body, []byte(strings.Join(val, ""))...) + + } + + return body +} + +// sign signs the data with all the given hashes and returns the signatures. +func sign(data []byte, h Hashes, prefixSigDisabled bool) (Signatures, error) { + sigs := map[Algorithm][]string{} + for algo, hshs := range h { + for _, hsh := range hshs { + if _, err := hsh.Write(data); err != nil { + return nil, err + } + sig := hex.EncodeToString(hsh.Sum(nil)) + if !prefixSigDisabled { + sig = fmt.Sprintf("%s=%s", algo, sig) + } + sigs[algo] = append(sigs[algo], sig) + // reset so Sign can be called multiple times. Otherwise, the next call will append to the previous one. + hsh.Reset() + } + } + + return sigs, nil +} + +// ToShort returns the short version of an algorithm. +func (a Algorithm) ToShort() Algorithm { + switch a { + case SHA256: + return SHA256Short + case SHA512: + return SHA512Short + default: + return a + } +} + +// NewSHA256 returns a map of SHA256 HMACs from the given secrets. +func NewSHA256(secret ...string) Hashes { + var hsh []hash.Hash + for _, s := range secret { + hsh = append(hsh, hmac.New(sha256.New, []byte(s))) + } + return Hashes{SHA256: hsh} +} + +// NewSHA512 returns a map of SHA512 HMACs from the given secrets. +func NewSHA512(secret ...string) Hashes { + var hsh []hash.Hash + for _, s := range secret { + hsh = append(hsh, hmac.New(sha512.New, []byte(s))) + } + return Hashes{SHA512: hsh} +} + +func mergeHashes(hs ...Hashes) Hashes { + m := Hashes{} + for _, h := range hs { + for k, v := range h { + m[k] = append(m[k], v...) + } + } + return m +} + +// CreateHashes creates a new hash for all secrets provided. +func CreateHashes(s Secrets) map[Algorithm][]hash.Hash { + h := map[Algorithm][]hash.Hash{} + for algo, secrets := range s { + switch algo { + case SHA256, SHA256Short: + h = mergeHashes(h, NewSHA256(secrets...)) + case SHA512, SHA512Short: + h = mergeHashes(h, NewSHA512(secrets...)) + } + } + + return h +}