Skip to content

Commit

Permalink
refactor: unify http client creation (#448)
Browse files Browse the repository at this point in the history
Signed-off-by: Suraj Shirvankar <[email protected]>
  • Loading branch information
h0lyalg0rithm authored Oct 7, 2024
1 parent 65e62d1 commit 28e1a09
Show file tree
Hide file tree
Showing 13 changed files with 60 additions and 47 deletions.
3 changes: 2 additions & 1 deletion sztp-agent/cmd/daemon.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,8 @@ func Daemon() *cobra.Command {
return fmt.Errorf("must not be folder: %q", filePath)
}
}
a := secureagent.NewAgent(bootstrapURL, serialNumber, dhcpLeaseFile, devicePassword, devicePrivateKey, deviceEndEntityCert, bootstrapTrustAnchorCert)
client := secureagent.NewHTTPClient(bootstrapTrustAnchorCert, deviceEndEntityCert, devicePrivateKey)
a := secureagent.NewAgent(bootstrapURL, serialNumber, dhcpLeaseFile, devicePassword, devicePrivateKey, deviceEndEntityCert, bootstrapTrustAnchorCert, &client)
return a.RunCommandDaemon()
},
}
Expand Down
3 changes: 2 additions & 1 deletion sztp-agent/cmd/disable.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,8 @@ func Disable() *cobra.Command {
Use: "disable",
Short: "Run the disable command",
RunE: func(_ *cobra.Command, _ []string) error {
a := secureagent.NewAgent(bootstrapURL, serialNumber, dhcpLeaseFile, devicePassword, devicePrivateKey, deviceEndEntityCert, bootstrapTrustAnchorCert)
client := secureagent.NewHTTPClient(bootstrapTrustAnchorCert, deviceEndEntityCert, devicePrivateKey)
a := secureagent.NewAgent(bootstrapURL, serialNumber, dhcpLeaseFile, devicePassword, devicePrivateKey, deviceEndEntityCert, bootstrapTrustAnchorCert, &client)
return a.RunCommandDisable()
},
}
Expand Down
3 changes: 2 additions & 1 deletion sztp-agent/cmd/enable.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,8 @@ func Enable() *cobra.Command {
Use: "enable",
Short: "Run the enable command",
RunE: func(_ *cobra.Command, _ []string) error {
a := secureagent.NewAgent(bootstrapURL, serialNumber, dhcpLeaseFile, devicePassword, devicePrivateKey, deviceEndEntityCert, bootstrapTrustAnchorCert)
client := secureagent.NewHTTPClient(bootstrapTrustAnchorCert, deviceEndEntityCert, devicePrivateKey)
a := secureagent.NewAgent(bootstrapURL, serialNumber, dhcpLeaseFile, devicePassword, devicePrivateKey, deviceEndEntityCert, bootstrapTrustAnchorCert, &client)
return a.RunCommandEnable()
},
}
Expand Down
3 changes: 2 additions & 1 deletion sztp-agent/cmd/run.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,8 @@ func Run() *cobra.Command {
return fmt.Errorf("must not be folder: %q", filePath)
}
}
a := secureagent.NewAgent(bootstrapURL, serialNumber, dhcpLeaseFile, devicePassword, devicePrivateKey, deviceEndEntityCert, bootstrapTrustAnchorCert)
client := secureagent.NewHTTPClient(bootstrapTrustAnchorCert, deviceEndEntityCert, devicePrivateKey)
a := secureagent.NewAgent(bootstrapURL, serialNumber, dhcpLeaseFile, devicePassword, devicePrivateKey, deviceEndEntityCert, bootstrapTrustAnchorCert, &client)
return a.RunCommand()
},
}
Expand Down
3 changes: 2 additions & 1 deletion sztp-agent/cmd/status.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,8 @@ func Status() *cobra.Command {
Use: "status",
Short: "Run the status command",
RunE: func(_ *cobra.Command, _ []string) error {
a := secureagent.NewAgent(bootstrapURL, serialNumber, dhcpLeaseFile, devicePassword, devicePrivateKey, deviceEndEntityCert, bootstrapTrustAnchorCert)
client := secureagent.NewHTTPClient(bootstrapTrustAnchorCert, deviceEndEntityCert, devicePrivateKey)
a := secureagent.NewAgent(bootstrapURL, serialNumber, dhcpLeaseFile, devicePassword, devicePrivateKey, deviceEndEntityCert, bootstrapTrustAnchorCert, &client)
return a.RunCommandStatus()
},
}
Expand Down
14 changes: 12 additions & 2 deletions sztp-agent/pkg/secureagent/agent.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,10 @@ Copyright (C) 2022 Red Hat.
// Package secureagent implements the secure agent
package secureagent

import (
"net/http"
)

const (
CONTENT_TYPE_YANG = "application/yang-data+json"
OS_RELEASE_FILE = "/etc/os-release"
Expand Down Expand Up @@ -68,6 +72,11 @@ type BootstrapServerErrorOutput struct {
} `json:"ietf-restconf:errors"`
}

type HttpClient interface {
Get(uri string) (*http.Response, error)
Do(req *http.Request) (*http.Response, error)
}

// Agent is the basic structure to define an agent instance
type Agent struct {
InputBootstrapURL string // Bootstrap complete URL given by USER
Expand All @@ -83,10 +92,10 @@ type Agent struct {
ProgressJSON ProgressJSON // ProgressJson structure
BootstrapServerOnboardingInfo BootstrapServerOnboardingInfo // BootstrapServerOnboardingInfo structure
BootstrapServerRedirectInfo BootstrapServerRedirectInfo // BootstrapServerRedirectInfo structure

HttpClient HttpClient
}

func NewAgent(bootstrapURL, serialNumber, dhcpLeaseFile, devicePassword, devicePrivateKey, deviceEndEntityCert, bootstrapTrustAnchorCert string) *Agent {
func NewAgent(bootstrapURL, serialNumber, dhcpLeaseFile, devicePassword, devicePrivateKey, deviceEndEntityCert, bootstrapTrustAnchorCert string, httpClient HttpClient) *Agent {
return &Agent{
InputBootstrapURL: bootstrapURL,
BootstrapURL: "",
Expand All @@ -101,6 +110,7 @@ func NewAgent(bootstrapURL, serialNumber, dhcpLeaseFile, devicePassword, deviceP
ProgressJSON: ProgressJSON{},
BootstrapServerRedirectInfo: BootstrapServerRedirectInfo{},
BootstrapServerOnboardingInfo: BootstrapServerOnboardingInfo{},
HttpClient: httpClient,
}
}

Expand Down
5 changes: 4 additions & 1 deletion sztp-agent/pkg/secureagent/agent_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ Copyright (C) 2022 Red Hat.
package secureagent

import (
"net/http"
"reflect"
"testing"
)
Expand Down Expand Up @@ -829,6 +830,7 @@ func TestNewAgent(t *testing.T) {
deviceEndEntityCert string
bootstrapTrustAnchorCert string
}
client := http.Client{}
tests := []struct {
name string
args args
Expand Down Expand Up @@ -856,12 +858,13 @@ func TestNewAgent(t *testing.T) {
ContentTypeReq: "application/yang-data+json",
InputJSONContent: generateInputJSONContent(),
DhcpLeaseFile: "TestDhcpLeaseFile",
HttpClient: &client,
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := NewAgent(tt.args.bootstrapURL, tt.args.serialNumber, tt.args.dhcpLeaseFile, tt.args.devicePassword, tt.args.devicePrivateKey, tt.args.deviceEndEntityCert, tt.args.bootstrapTrustAnchorCert); !reflect.DeepEqual(got, tt.want) {
if got := NewAgent(tt.args.bootstrapURL, tt.args.serialNumber, tt.args.dhcpLeaseFile, tt.args.devicePassword, tt.args.devicePrivateKey, tt.args.deviceEndEntityCert, tt.args.bootstrapTrustAnchorCert, &client); !reflect.DeepEqual(got, tt.want) {
t.Errorf("NewAgent() = %v, want %v", got, tt.want)
}
})
Expand Down
3 changes: 3 additions & 0 deletions sztp-agent/pkg/secureagent/configuration_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package secureagent

import (
"net/http"
"testing"
)

Expand Down Expand Up @@ -151,6 +152,7 @@ func TestAgent_copyConfigurationFile(t *testing.T) {
ProgressJSON: tt.fields.ProgressJSON,
BootstrapServerOnboardingInfo: tt.fields.BootstrapServerOnboardingInfo,
BootstrapServerRedirectInfo: tt.fields.BootstrapServerRedirectInfo,
HttpClient: &http.Client{},
}
if err := a.copyConfigurationFile(); (err != nil) != tt.wantErr {
t.Errorf("copyConfigurationFile() error = %v, wantErr %v", err, tt.wantErr)
Expand Down Expand Up @@ -368,6 +370,7 @@ func TestAgent_launchScriptsConfiguration(t *testing.T) {
ProgressJSON: tt.fields.ProgressJSON,
BootstrapServerOnboardingInfo: tt.fields.BootstrapServerOnboardingInfo,
BootstrapServerRedirectInfo: tt.fields.BootstrapServerRedirectInfo,
HttpClient: &http.Client{},
}
if err := a.launchScriptsConfiguration(tt.args.typeOf); (err != nil) != tt.wantErr {
t.Errorf("launchScriptsConfiguration() error = %v, wantErr %v", err, tt.wantErr)
Expand Down
1 change: 1 addition & 0 deletions sztp-agent/pkg/secureagent/daemon_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -418,6 +418,7 @@ func TestAgent_doReqBootstrap(t *testing.T) {
ContentTypeReq: tt.fields.ContentTypeReq,
InputJSONContent: tt.fields.InputJSONContent,
DhcpLeaseFile: tt.fields.DhcpLeaseFile,
HttpClient: &http.Client{},
}
if err := a.doRequestBootstrapServerOnboardingInfo(); (err != nil) != tt.wantErr {
t.Errorf("doRequestBootstrapServer() error = %v, wantErr %v", err, tt.wantErr)
Expand Down
25 changes: 1 addition & 24 deletions sztp-agent/pkg/secureagent/image.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,10 @@ Copyright (C) 2022 Red Hat.
package secureagent

import (
"crypto/tls"
"crypto/x509"
"errors"
"fmt"
"io"
"log"
"net/http"
"os"
"path/filepath"
"strconv"
Expand All @@ -37,27 +34,7 @@ func (a *Agent) downloadAndValidateImage() error {
return err
}

caCert, _ := os.ReadFile(a.GetBootstrapTrustAnchorCert())
caCertPool := x509.NewCertPool()
caCertPool.AppendCertsFromPEM(caCert)
cert, _ := tls.LoadX509KeyPair(a.GetDeviceEndEntityCert(), a.GetDevicePrivateKey())

check := http.Client{
CheckRedirect: func(r *http.Request, _ []*http.Request) error {
r.URL.Opaque = r.URL.Path
return nil
},
Transport: &http.Transport{
TLSClientConfig: &tls.Config{
//nolint:gosec
InsecureSkipVerify: true, // TODO: remove skip verify
RootCAs: caCertPool,
Certificates: []tls.Certificate{cert},
},
},
}

response, err := check.Get(item)
response, err := a.HttpClient.Get(item)
if err != nil {
return err
}
Expand Down
3 changes: 2 additions & 1 deletion sztp-agent/pkg/secureagent/image_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ import (
//nolint:funlen
func TestAgent_downloadAndValidateImage(t *testing.T) {
svr := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path == "/imageOK" {
if r.URL.Path == "/imageOK" || r.URL.Path == "/report-progress" {
w.WriteHeader(200)
} else {
w.WriteHeader(400)
Expand Down Expand Up @@ -309,6 +309,7 @@ func TestAgent_downloadAndValidateImage(t *testing.T) {
ProgressJSON: tt.fields.ProgressJSON,
BootstrapServerOnboardingInfo: tt.fields.BootstrapServerOnboardingInfo,
BootstrapServerRedirectInfo: tt.fields.BootstrapServerRedirectInfo,
HttpClient: &http.Client{},
}
if err := a.downloadAndValidateImage(); (err != nil) != tt.wantErr {
t.Errorf("downloadAndValidateImage() error = %v, wantErr %v", err, tt.wantErr)
Expand Down
1 change: 1 addition & 0 deletions sztp-agent/pkg/secureagent/progress_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,7 @@ func TestAgent_doReportProgress(t *testing.T) {
InputJSONContent: tt.fields.InputJSONContent,
DhcpLeaseFile: tt.fields.DhcpLeaseFile,
ProgressJSON: tt.fields.ProgressJSON,
HttpClient: &http.Client{},
}
if err := a.doReportProgress(ProgressTypeBootstrapInitiated, "Bootstrap Initiated"); (err != nil) != tt.wantErr {
t.Errorf("doReportProgress() error = %v, wantErr %v", err, tt.wantErr)
Expand Down
40 changes: 26 additions & 14 deletions sztp-agent/pkg/secureagent/tls.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,35 @@ import (
"log"
"net/http"
"os"
"path/filepath"
"strconv"
"strings"
)

// NewHTTPClient instantiate a new HTTP Client
func NewHTTPClient(bootstrapTrustAnchorCert string, deviceEndEntityCert string, devicePrivateKey string) http.Client {
certPath := filepath.Clean(bootstrapTrustAnchorCert)
caCert, _ := os.ReadFile(certPath)
caCertPool := x509.NewCertPool()
caCertPool.AppendCertsFromPEM(caCert)
cert, _ := tls.LoadX509KeyPair(deviceEndEntityCert, devicePrivateKey)
client := http.Client{
CheckRedirect: func(r *http.Request, _ []*http.Request) error {
r.URL.Opaque = r.URL.Path
return nil
},
Transport: &http.Transport{
TLSClientConfig: &tls.Config{
//nolint:gosec
InsecureSkipVerify: true, // TODO: remove skip verify
RootCAs: caCertPool,
Certificates: []tls.Certificate{cert},
},
},
}
return client
}

func (a *Agent) doTLSRequest(input string, url string, empty bool) (*BootstrapServerPostOutput, error) {
var postResponse BootstrapServerPostOutput
var errorResponse BootstrapServerErrorOutput
Expand All @@ -38,20 +63,7 @@ func (a *Agent) doTLSRequest(input string, url string, empty bool) (*BootstrapSe
r.SetBasicAuth(a.GetSerialNumber(), a.GetDevicePassword())
r.Header.Add("Content-Type", a.GetContentTypeReq())

caCert, _ := os.ReadFile(a.GetBootstrapTrustAnchorCert())
caCertPool := x509.NewCertPool()
caCertPool.AppendCertsFromPEM(caCert)
cert, _ := tls.LoadX509KeyPair(a.GetDeviceEndEntityCert(), a.GetDevicePrivateKey())

client := &http.Client{
Transport: &http.Transport{
TLSClientConfig: &tls.Config{ //nolint:gosec
RootCAs: caCertPool,
Certificates: []tls.Certificate{cert},
},
},
}
res, err := client.Do(r)
res, err := a.HttpClient.Do(r)
if err != nil {
log.Println("Error doing the request", err.Error())
return nil, err
Expand Down

0 comments on commit 28e1a09

Please sign in to comment.