diff --git a/sztp-agent/cmd/daemon.go b/sztp-agent/cmd/daemon.go index d309d9a6..92d49efd 100644 --- a/sztp-agent/cmd/daemon.go +++ b/sztp-agent/cmd/daemon.go @@ -59,7 +59,8 @@ func Daemon() *cobra.Command { return fmt.Errorf("must not be folder: %q", filePath) } } - a := secureagent.NewAgent(bootstrapURL, serialNumber, dhcpLeaseFile, devicePassword, devicePrivateKey, deviceEndEntityCert, bootstrapTrustAnchorCert) + client := secureagent.NewHTTPClient(bootstrapTrustAnchorCert, deviceEndEntityCert, devicePrivateKey) + a := secureagent.NewAgent(bootstrapURL, serialNumber, dhcpLeaseFile, devicePassword, devicePrivateKey, deviceEndEntityCert, bootstrapTrustAnchorCert, &client) return a.RunCommandDaemon() }, } diff --git a/sztp-agent/cmd/disable.go b/sztp-agent/cmd/disable.go index 3ce70a4d..157776d9 100644 --- a/sztp-agent/cmd/disable.go +++ b/sztp-agent/cmd/disable.go @@ -34,7 +34,8 @@ func Disable() *cobra.Command { Use: "disable", Short: "Run the disable command", RunE: func(_ *cobra.Command, _ []string) error { - a := secureagent.NewAgent(bootstrapURL, serialNumber, dhcpLeaseFile, devicePassword, devicePrivateKey, deviceEndEntityCert, bootstrapTrustAnchorCert) + client := secureagent.NewHTTPClient(bootstrapTrustAnchorCert, deviceEndEntityCert, devicePrivateKey) + a := secureagent.NewAgent(bootstrapURL, serialNumber, dhcpLeaseFile, devicePassword, devicePrivateKey, deviceEndEntityCert, bootstrapTrustAnchorCert, &client) return a.RunCommandDisable() }, } diff --git a/sztp-agent/cmd/enable.go b/sztp-agent/cmd/enable.go index 745bd795..49f53b4a 100644 --- a/sztp-agent/cmd/enable.go +++ b/sztp-agent/cmd/enable.go @@ -34,7 +34,8 @@ func Enable() *cobra.Command { Use: "enable", Short: "Run the enable command", RunE: func(_ *cobra.Command, _ []string) error { - a := secureagent.NewAgent(bootstrapURL, serialNumber, dhcpLeaseFile, devicePassword, devicePrivateKey, deviceEndEntityCert, bootstrapTrustAnchorCert) + client := secureagent.NewHTTPClient(bootstrapTrustAnchorCert, deviceEndEntityCert, devicePrivateKey) + a := secureagent.NewAgent(bootstrapURL, serialNumber, dhcpLeaseFile, devicePassword, devicePrivateKey, deviceEndEntityCert, bootstrapTrustAnchorCert, &client) return a.RunCommandEnable() }, } diff --git a/sztp-agent/cmd/run.go b/sztp-agent/cmd/run.go index f3b02c1f..99a9b1f9 100644 --- a/sztp-agent/cmd/run.go +++ b/sztp-agent/cmd/run.go @@ -59,7 +59,8 @@ func Run() *cobra.Command { return fmt.Errorf("must not be folder: %q", filePath) } } - a := secureagent.NewAgent(bootstrapURL, serialNumber, dhcpLeaseFile, devicePassword, devicePrivateKey, deviceEndEntityCert, bootstrapTrustAnchorCert) + client := secureagent.NewHTTPClient(bootstrapTrustAnchorCert, deviceEndEntityCert, devicePrivateKey) + a := secureagent.NewAgent(bootstrapURL, serialNumber, dhcpLeaseFile, devicePassword, devicePrivateKey, deviceEndEntityCert, bootstrapTrustAnchorCert, &client) return a.RunCommand() }, } diff --git a/sztp-agent/cmd/status.go b/sztp-agent/cmd/status.go index cf5043a7..dbc80dfa 100644 --- a/sztp-agent/cmd/status.go +++ b/sztp-agent/cmd/status.go @@ -34,7 +34,8 @@ func Status() *cobra.Command { Use: "status", Short: "Run the status command", RunE: func(_ *cobra.Command, _ []string) error { - a := secureagent.NewAgent(bootstrapURL, serialNumber, dhcpLeaseFile, devicePassword, devicePrivateKey, deviceEndEntityCert, bootstrapTrustAnchorCert) + client := secureagent.NewHTTPClient(bootstrapTrustAnchorCert, deviceEndEntityCert, devicePrivateKey) + a := secureagent.NewAgent(bootstrapURL, serialNumber, dhcpLeaseFile, devicePassword, devicePrivateKey, deviceEndEntityCert, bootstrapTrustAnchorCert, &client) return a.RunCommandStatus() }, } diff --git a/sztp-agent/pkg/secureagent/agent.go b/sztp-agent/pkg/secureagent/agent.go index e248dc42..67970468 100644 --- a/sztp-agent/pkg/secureagent/agent.go +++ b/sztp-agent/pkg/secureagent/agent.go @@ -9,6 +9,10 @@ Copyright (C) 2022 Red Hat. // Package secureagent implements the secure agent package secureagent +import ( + "net/http" +) + const ( CONTENT_TYPE_YANG = "application/yang-data+json" OS_RELEASE_FILE = "/etc/os-release" @@ -68,6 +72,11 @@ type BootstrapServerErrorOutput struct { } `json:"ietf-restconf:errors"` } +type HttpClient interface { + Get(uri string) (*http.Response, error) + Do(req *http.Request) (*http.Response, error) +} + // Agent is the basic structure to define an agent instance type Agent struct { InputBootstrapURL string // Bootstrap complete URL given by USER @@ -83,10 +92,10 @@ type Agent struct { ProgressJSON ProgressJSON // ProgressJson structure BootstrapServerOnboardingInfo BootstrapServerOnboardingInfo // BootstrapServerOnboardingInfo structure BootstrapServerRedirectInfo BootstrapServerRedirectInfo // BootstrapServerRedirectInfo structure - + HttpClient HttpClient } -func NewAgent(bootstrapURL, serialNumber, dhcpLeaseFile, devicePassword, devicePrivateKey, deviceEndEntityCert, bootstrapTrustAnchorCert string) *Agent { +func NewAgent(bootstrapURL, serialNumber, dhcpLeaseFile, devicePassword, devicePrivateKey, deviceEndEntityCert, bootstrapTrustAnchorCert string, httpClient HttpClient) *Agent { return &Agent{ InputBootstrapURL: bootstrapURL, BootstrapURL: "", @@ -101,6 +110,7 @@ func NewAgent(bootstrapURL, serialNumber, dhcpLeaseFile, devicePassword, deviceP ProgressJSON: ProgressJSON{}, BootstrapServerRedirectInfo: BootstrapServerRedirectInfo{}, BootstrapServerOnboardingInfo: BootstrapServerOnboardingInfo{}, + HttpClient: httpClient, } } diff --git a/sztp-agent/pkg/secureagent/agent_test.go b/sztp-agent/pkg/secureagent/agent_test.go index ad2ebb09..e8234a8a 100644 --- a/sztp-agent/pkg/secureagent/agent_test.go +++ b/sztp-agent/pkg/secureagent/agent_test.go @@ -9,6 +9,7 @@ Copyright (C) 2022 Red Hat. package secureagent import ( + "net/http" "reflect" "testing" ) @@ -829,6 +830,7 @@ func TestNewAgent(t *testing.T) { deviceEndEntityCert string bootstrapTrustAnchorCert string } + client := http.Client{} tests := []struct { name string args args @@ -856,12 +858,13 @@ func TestNewAgent(t *testing.T) { ContentTypeReq: "application/yang-data+json", InputJSONContent: generateInputJSONContent(), DhcpLeaseFile: "TestDhcpLeaseFile", + HttpClient: &client, }, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - if got := NewAgent(tt.args.bootstrapURL, tt.args.serialNumber, tt.args.dhcpLeaseFile, tt.args.devicePassword, tt.args.devicePrivateKey, tt.args.deviceEndEntityCert, tt.args.bootstrapTrustAnchorCert); !reflect.DeepEqual(got, tt.want) { + if got := NewAgent(tt.args.bootstrapURL, tt.args.serialNumber, tt.args.dhcpLeaseFile, tt.args.devicePassword, tt.args.devicePrivateKey, tt.args.deviceEndEntityCert, tt.args.bootstrapTrustAnchorCert, &client); !reflect.DeepEqual(got, tt.want) { t.Errorf("NewAgent() = %v, want %v", got, tt.want) } }) diff --git a/sztp-agent/pkg/secureagent/configuration_test.go b/sztp-agent/pkg/secureagent/configuration_test.go index 8c8be267..3aa25815 100644 --- a/sztp-agent/pkg/secureagent/configuration_test.go +++ b/sztp-agent/pkg/secureagent/configuration_test.go @@ -1,6 +1,7 @@ package secureagent import ( + "net/http" "testing" ) @@ -151,6 +152,7 @@ func TestAgent_copyConfigurationFile(t *testing.T) { ProgressJSON: tt.fields.ProgressJSON, BootstrapServerOnboardingInfo: tt.fields.BootstrapServerOnboardingInfo, BootstrapServerRedirectInfo: tt.fields.BootstrapServerRedirectInfo, + HttpClient: &http.Client{}, } if err := a.copyConfigurationFile(); (err != nil) != tt.wantErr { t.Errorf("copyConfigurationFile() error = %v, wantErr %v", err, tt.wantErr) @@ -368,6 +370,7 @@ func TestAgent_launchScriptsConfiguration(t *testing.T) { ProgressJSON: tt.fields.ProgressJSON, BootstrapServerOnboardingInfo: tt.fields.BootstrapServerOnboardingInfo, BootstrapServerRedirectInfo: tt.fields.BootstrapServerRedirectInfo, + HttpClient: &http.Client{}, } if err := a.launchScriptsConfiguration(tt.args.typeOf); (err != nil) != tt.wantErr { t.Errorf("launchScriptsConfiguration() error = %v, wantErr %v", err, tt.wantErr) diff --git a/sztp-agent/pkg/secureagent/daemon_test.go b/sztp-agent/pkg/secureagent/daemon_test.go index 4e9ba1e4..c9341c04 100644 --- a/sztp-agent/pkg/secureagent/daemon_test.go +++ b/sztp-agent/pkg/secureagent/daemon_test.go @@ -418,6 +418,7 @@ func TestAgent_doReqBootstrap(t *testing.T) { ContentTypeReq: tt.fields.ContentTypeReq, InputJSONContent: tt.fields.InputJSONContent, DhcpLeaseFile: tt.fields.DhcpLeaseFile, + HttpClient: &http.Client{}, } if err := a.doRequestBootstrapServerOnboardingInfo(); (err != nil) != tt.wantErr { t.Errorf("doRequestBootstrapServer() error = %v, wantErr %v", err, tt.wantErr) diff --git a/sztp-agent/pkg/secureagent/image.go b/sztp-agent/pkg/secureagent/image.go index daff80d3..9af486c7 100644 --- a/sztp-agent/pkg/secureagent/image.go +++ b/sztp-agent/pkg/secureagent/image.go @@ -8,13 +8,10 @@ Copyright (C) 2022 Red Hat. package secureagent import ( - "crypto/tls" - "crypto/x509" "errors" "fmt" "io" "log" - "net/http" "os" "path/filepath" "strconv" @@ -37,27 +34,7 @@ func (a *Agent) downloadAndValidateImage() error { return err } - caCert, _ := os.ReadFile(a.GetBootstrapTrustAnchorCert()) - caCertPool := x509.NewCertPool() - caCertPool.AppendCertsFromPEM(caCert) - cert, _ := tls.LoadX509KeyPair(a.GetDeviceEndEntityCert(), a.GetDevicePrivateKey()) - - check := http.Client{ - CheckRedirect: func(r *http.Request, _ []*http.Request) error { - r.URL.Opaque = r.URL.Path - return nil - }, - Transport: &http.Transport{ - TLSClientConfig: &tls.Config{ - //nolint:gosec - InsecureSkipVerify: true, // TODO: remove skip verify - RootCAs: caCertPool, - Certificates: []tls.Certificate{cert}, - }, - }, - } - - response, err := check.Get(item) + response, err := a.HttpClient.Get(item) if err != nil { return err } diff --git a/sztp-agent/pkg/secureagent/image_test.go b/sztp-agent/pkg/secureagent/image_test.go index 24cd2854..bc73f5f4 100644 --- a/sztp-agent/pkg/secureagent/image_test.go +++ b/sztp-agent/pkg/secureagent/image_test.go @@ -9,7 +9,7 @@ import ( //nolint:funlen func TestAgent_downloadAndValidateImage(t *testing.T) { svr := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if r.URL.Path == "/imageOK" { + if r.URL.Path == "/imageOK" || r.URL.Path == "/report-progress" { w.WriteHeader(200) } else { w.WriteHeader(400) @@ -309,6 +309,7 @@ func TestAgent_downloadAndValidateImage(t *testing.T) { ProgressJSON: tt.fields.ProgressJSON, BootstrapServerOnboardingInfo: tt.fields.BootstrapServerOnboardingInfo, BootstrapServerRedirectInfo: tt.fields.BootstrapServerRedirectInfo, + HttpClient: &http.Client{}, } if err := a.downloadAndValidateImage(); (err != nil) != tt.wantErr { t.Errorf("downloadAndValidateImage() error = %v, wantErr %v", err, tt.wantErr) diff --git a/sztp-agent/pkg/secureagent/progress_test.go b/sztp-agent/pkg/secureagent/progress_test.go index 7a55c0cb..a821dcf9 100644 --- a/sztp-agent/pkg/secureagent/progress_test.go +++ b/sztp-agent/pkg/secureagent/progress_test.go @@ -158,6 +158,7 @@ func TestAgent_doReportProgress(t *testing.T) { InputJSONContent: tt.fields.InputJSONContent, DhcpLeaseFile: tt.fields.DhcpLeaseFile, ProgressJSON: tt.fields.ProgressJSON, + HttpClient: &http.Client{}, } if err := a.doReportProgress(ProgressTypeBootstrapInitiated, "Bootstrap Initiated"); (err != nil) != tt.wantErr { t.Errorf("doReportProgress() error = %v, wantErr %v", err, tt.wantErr) diff --git a/sztp-agent/pkg/secureagent/tls.go b/sztp-agent/pkg/secureagent/tls.go index 44bdaf68..b17911ac 100644 --- a/sztp-agent/pkg/secureagent/tls.go +++ b/sztp-agent/pkg/secureagent/tls.go @@ -18,10 +18,35 @@ import ( "log" "net/http" "os" + "path/filepath" "strconv" "strings" ) +// NewHTTPClient instantiate a new HTTP Client +func NewHTTPClient(bootstrapTrustAnchorCert string, deviceEndEntityCert string, devicePrivateKey string) http.Client { + certPath := filepath.Clean(bootstrapTrustAnchorCert) + caCert, _ := os.ReadFile(certPath) + caCertPool := x509.NewCertPool() + caCertPool.AppendCertsFromPEM(caCert) + cert, _ := tls.LoadX509KeyPair(deviceEndEntityCert, devicePrivateKey) + client := http.Client{ + CheckRedirect: func(r *http.Request, _ []*http.Request) error { + r.URL.Opaque = r.URL.Path + return nil + }, + Transport: &http.Transport{ + TLSClientConfig: &tls.Config{ + //nolint:gosec + InsecureSkipVerify: true, // TODO: remove skip verify + RootCAs: caCertPool, + Certificates: []tls.Certificate{cert}, + }, + }, + } + return client +} + func (a *Agent) doTLSRequest(input string, url string, empty bool) (*BootstrapServerPostOutput, error) { var postResponse BootstrapServerPostOutput var errorResponse BootstrapServerErrorOutput @@ -38,20 +63,7 @@ func (a *Agent) doTLSRequest(input string, url string, empty bool) (*BootstrapSe r.SetBasicAuth(a.GetSerialNumber(), a.GetDevicePassword()) r.Header.Add("Content-Type", a.GetContentTypeReq()) - caCert, _ := os.ReadFile(a.GetBootstrapTrustAnchorCert()) - caCertPool := x509.NewCertPool() - caCertPool.AppendCertsFromPEM(caCert) - cert, _ := tls.LoadX509KeyPair(a.GetDeviceEndEntityCert(), a.GetDevicePrivateKey()) - - client := &http.Client{ - Transport: &http.Transport{ - TLSClientConfig: &tls.Config{ //nolint:gosec - RootCAs: caCertPool, - Certificates: []tls.Certificate{cert}, - }, - }, - } - res, err := client.Do(r) + res, err := a.HttpClient.Do(r) if err != nil { log.Println("Error doing the request", err.Error()) return nil, err