From 9fd0b45640db6ad514fd5c6e824041667d65f889 Mon Sep 17 00:00:00 2001 From: Bhoopesh Date: Fri, 20 Sep 2024 01:13:52 +0530 Subject: [PATCH 1/4] feat: status command Signed-off-by: Bhoopesh --- docker-compose.dpu.yml | 28 ++- scripts/run_agent.sh | 3 + sztp-agent/cmd/daemon.go | 19 +- sztp-agent/cmd/disable.go | 9 +- sztp-agent/cmd/enable.go | 8 +- sztp-agent/cmd/run.go | 19 +- sztp-agent/cmd/status.go | 10 +- sztp-agent/pkg/secureagent/agent.go | 32 ++- sztp-agent/pkg/secureagent/agent_test.go | 11 +- sztp-agent/pkg/secureagent/configuration.go | 4 + sztp-agent/pkg/secureagent/daemon.go | 7 + sztp-agent/pkg/secureagent/image.go | 2 + sztp-agent/pkg/secureagent/progress_test.go | 2 +- sztp-agent/pkg/secureagent/run.go | 4 + sztp-agent/pkg/secureagent/status.go | 233 +++++++++++++++++++- sztp-agent/pkg/secureagent/status_test.go | 10 + sztp-agent/pkg/secureagent/utils.go | 18 ++ 17 files changed, 399 insertions(+), 20 deletions(-) diff --git a/docker-compose.dpu.yml b/docker-compose.dpu.yml index 28c9600f..5e281dda 100644 --- a/docker-compose.dpu.yml +++ b/docker-compose.dpu.yml @@ -43,6 +43,9 @@ services: - dhcp-leases-folder:/var/lib/dhclient/ - /etc/os-release:/etc/os-release - /etc/ssh:/etc/ssh + - /var/lib/sztp:/var/lib/sztp + - /run/sztp:/run/sztp + privileged: true networks: - opi command: ['/opi-sztp-agent', 'daemon', @@ -50,7 +53,10 @@ services: '--bootstrap-trust-anchor-cert', '/certs/opi.pem', '--device-end-entity-cert', '/certs/third_my_cert.pem', '--device-private-key', '/certs/third_private_key.pem', - '--serial-number', 'third-serial-number'] + '--serial-number', 'third-serial-number', + '--status-file-path', '/var/lib/sztp/status.json', + '--result-file-path', '/var/lib/sztp/result.json', + '--sym-link-dir', '/run/sztp'] agent2: <<: *agent @@ -59,7 +65,10 @@ services: '--bootstrap-trust-anchor-cert', '/certs/opi.pem', '--device-end-entity-cert', '/certs/second_my_cert.pem', '--device-private-key', '/certs/second_private_key.pem', - '--serial-number', 'second-serial-number'] + '--serial-number', 'second-serial-number', + '--status-file-path', '/var/lib/sztp/status.json', + '--result-file-path', '/var/lib/sztp/result.json', + '--sym-link-dir', '/run/sztp'] agent1: <<: *agent @@ -68,7 +77,10 @@ services: '--bootstrap-trust-anchor-cert', '/certs/opi.pem', '--device-end-entity-cert', '/certs/first_my_cert.pem', '--device-private-key', '/certs/first_private_key.pem', - '--serial-number', 'first-serial-number'] + '--serial-number', 'first-serial-number', + '--status-file-path', '/var/lib/sztp/status.json', + '--result-file-path', '/var/lib/sztp/result.json', + '--sym-link-dir', '/run/sztp'] agent4: <<: *agent @@ -77,7 +89,10 @@ services: '--bootstrap-trust-anchor-cert', '/certs/opi.pem', '--device-end-entity-cert', '/certs/first_my_cert.pem', '--device-private-key', '/certs/first_private_key.pem', - '--serial-number', 'first-serial-number'] + '--serial-number', 'first-serial-number', + '--status-file-path', '/var/lib/sztp/status.json', + '--result-file-path', '/var/lib/sztp/result.json', + '--sym-link-dir', '/run/sztp'] agent5: <<: *agent @@ -86,7 +101,10 @@ services: '--bootstrap-trust-anchor-cert', '/certs/opi.pem', '--device-end-entity-cert', '/certs/first_my_cert.pem', '--device-private-key', '/certs/first_private_key.pem', - '--serial-number', 'first-serial-number'] + '--serial-number', 'first-serial-number', + '--status-file-path', '/var/lib/sztp/status.json', + '--result-file-path', '/var/lib/sztp/result.json', + '--sym-link-dir', '/run/sztp'] volumes: client-certs: diff --git a/scripts/run_agent.sh b/scripts/run_agent.sh index 7b31a631..2c35965c 100755 --- a/scripts/run_agent.sh +++ b/scripts/run_agent.sh @@ -21,6 +21,9 @@ docker run --rm -it --network=host \ --mount type=bind,source=/etc/ssh,target=/etc/ssh,readonly \ --mount type=bind,source=/etc/os-release,target=/etc/os-release,readonly \ --mount type=bind,source=/var/lib/NetworkManager,target=/var/lib/NetworkManager,readonly \ + --mount type=bind,source=/var/lib/sztp,target=/var/lib/sztp \ + --mount type=bind,source=/run/sztp,target=/run/sztp \ + --privileged \ ${DOCKER_SZTP_IMAGE} \ /opi-sztp-agent daemon \ --dhcp-lease-file /var/lib/NetworkManager/dhclient-eth0.lease \ diff --git a/sztp-agent/cmd/daemon.go b/sztp-agent/cmd/daemon.go index 92d49efd..0dc0a0b1 100644 --- a/sztp-agent/cmd/daemon.go +++ b/sztp-agent/cmd/daemon.go @@ -32,13 +32,16 @@ func Daemon() *cobra.Command { devicePrivateKey string deviceEndEntityCert string bootstrapTrustAnchorCert string + statusFilePath string + resultFilePath string + symLinkDir string ) cmd := &cobra.Command{ Use: "daemon", Short: "Run the daemon command", RunE: func(_ *cobra.Command, _ []string) error { - arrayChecker := []string{devicePrivateKey, deviceEndEntityCert, bootstrapTrustAnchorCert} + arrayChecker := []string{devicePrivateKey, deviceEndEntityCert, bootstrapTrustAnchorCert, statusFilePath} if bootstrapURL != "" && dhcpLeaseFile != "" { return fmt.Errorf("'--bootstrap-url' and '--dhcp-lease-file' are mutualy exclusive") } @@ -52,6 +55,15 @@ func Daemon() *cobra.Command { _, err := url.ParseRequestURI(bootstrapURL) cobra.CheckErr(err) } + if statusFilePath == "" { + return fmt.Errorf("'--status-file-path' is required") + } + if resultFilePath == "" { + return fmt.Errorf("'--result-file-path' is required") + } + if symLinkDir == "" { + return fmt.Errorf("'--symlink-dir' is required") + } for _, filePath := range arrayChecker { info, err := os.Stat(filePath) cobra.CheckErr(err) @@ -60,7 +72,7 @@ func Daemon() *cobra.Command { } } client := secureagent.NewHTTPClient(bootstrapTrustAnchorCert, deviceEndEntityCert, devicePrivateKey) - a := secureagent.NewAgent(bootstrapURL, serialNumber, dhcpLeaseFile, devicePassword, devicePrivateKey, deviceEndEntityCert, bootstrapTrustAnchorCert, &client) + a := secureagent.NewAgent(bootstrapURL, serialNumber, dhcpLeaseFile, devicePassword, devicePrivateKey, deviceEndEntityCert, bootstrapTrustAnchorCert, statusFilePath, resultFilePath, symLinkDir, &client) return a.RunCommandDaemon() }, } @@ -75,6 +87,9 @@ func Daemon() *cobra.Command { flags.StringVar(&devicePrivateKey, "device-private-key", "/certs/private_key.pem", "Device's private key") flags.StringVar(&deviceEndEntityCert, "device-end-entity-cert", "/certs/my_cert.pem", "Device's End Entity cert") flags.StringVar(&bootstrapTrustAnchorCert, "bootstrap-trust-anchor-cert", "/certs/opi.pem", "Bootstrap server trust anchor Cert") + flags.StringVar(&statusFilePath, "status-file-path", "/var/lib/sztp/status.json", "Path to the status file") + flags.StringVar(&resultFilePath, "result-file-path", "/var/lib/sztp/result.json", "Path to the result file") + flags.StringVar(&symLinkDir, "sym-link-dir", "/run/sztp", "Path to the symlink directory") return cmd } diff --git a/sztp-agent/cmd/disable.go b/sztp-agent/cmd/disable.go index 157776d9..b086aeb3 100644 --- a/sztp-agent/cmd/disable.go +++ b/sztp-agent/cmd/disable.go @@ -28,6 +28,9 @@ func Disable() *cobra.Command { devicePrivateKey string deviceEndEntityCert string bootstrapTrustAnchorCert string + statusFilePath string + resultFilePath string + symLinkDir string ) cmd := &cobra.Command{ @@ -35,7 +38,7 @@ func Disable() *cobra.Command { Short: "Run the disable command", RunE: func(_ *cobra.Command, _ []string) error { client := secureagent.NewHTTPClient(bootstrapTrustAnchorCert, deviceEndEntityCert, devicePrivateKey) - a := secureagent.NewAgent(bootstrapURL, serialNumber, dhcpLeaseFile, devicePassword, devicePrivateKey, deviceEndEntityCert, bootstrapTrustAnchorCert, &client) + a := secureagent.NewAgent(bootstrapURL, serialNumber, dhcpLeaseFile, devicePassword, devicePrivateKey, deviceEndEntityCert, bootstrapTrustAnchorCert, statusFilePath, resultFilePath, symLinkDir, &client) return a.RunCommandDisable() }, } @@ -50,5 +53,9 @@ func Disable() *cobra.Command { flags.StringVar(&devicePrivateKey, "device-private-key", "", "Device's private key") flags.StringVar(&deviceEndEntityCert, "device-end-entity-cert", "", "Device's End Entity cert") flags.StringVar(&bootstrapTrustAnchorCert, "bootstrap-trust-anchor-cert", "", "Bootstrap server trust anchor Cert") + flags.StringVar(&statusFilePath, "status-file-path", "", "Status file path") + flags.StringVar(&resultFilePath, "result-file-path", "", "Result file path") + flags.StringVar(&symLinkDir, "sym-link-dir", "", "Sym Link Directory") + return cmd } diff --git a/sztp-agent/cmd/enable.go b/sztp-agent/cmd/enable.go index 49f53b4a..2c6513d9 100644 --- a/sztp-agent/cmd/enable.go +++ b/sztp-agent/cmd/enable.go @@ -28,6 +28,9 @@ func Enable() *cobra.Command { devicePrivateKey string deviceEndEntityCert string bootstrapTrustAnchorCert string + statusFilePath string + resultFilePath string + symLinkDir string ) cmd := &cobra.Command{ @@ -35,7 +38,7 @@ func Enable() *cobra.Command { Short: "Run the enable command", RunE: func(_ *cobra.Command, _ []string) error { client := secureagent.NewHTTPClient(bootstrapTrustAnchorCert, deviceEndEntityCert, devicePrivateKey) - a := secureagent.NewAgent(bootstrapURL, serialNumber, dhcpLeaseFile, devicePassword, devicePrivateKey, deviceEndEntityCert, bootstrapTrustAnchorCert, &client) + a := secureagent.NewAgent(bootstrapURL, serialNumber, dhcpLeaseFile, devicePassword, devicePrivateKey, deviceEndEntityCert, bootstrapTrustAnchorCert, statusFilePath, resultFilePath, symLinkDir, &client) return a.RunCommandEnable() }, } @@ -50,6 +53,9 @@ func Enable() *cobra.Command { flags.StringVar(&devicePrivateKey, "device-private-key", "", "Device's private key") flags.StringVar(&deviceEndEntityCert, "device-end-entity-cert", "", "Device's End Entity cert") flags.StringVar(&bootstrapTrustAnchorCert, "bootstrap-trust-anchor-cert", "", "Bootstrap server trust anchor Cert") + flags.StringVar(&statusFilePath, "status-file-path", "", "Status file path") + flags.StringVar(&resultFilePath, "result-file-path", "", "Result file path") + flags.StringVar(&symLinkDir, "sym-link-dir", "", "Sym Link Directory") return cmd } diff --git a/sztp-agent/cmd/run.go b/sztp-agent/cmd/run.go index 99a9b1f9..a76cb791 100644 --- a/sztp-agent/cmd/run.go +++ b/sztp-agent/cmd/run.go @@ -32,13 +32,16 @@ func Run() *cobra.Command { devicePrivateKey string deviceEndEntityCert string bootstrapTrustAnchorCert string + statusFilePath string + resultFilePath string + symLinkDir string ) cmd := &cobra.Command{ Use: "run", Short: "Exec the run command", RunE: func(_ *cobra.Command, _ []string) error { - arrayChecker := []string{devicePrivateKey, deviceEndEntityCert, bootstrapTrustAnchorCert} + arrayChecker := []string{devicePrivateKey, deviceEndEntityCert, bootstrapTrustAnchorCert, statusFilePath, resultFilePath} if bootstrapURL != "" && dhcpLeaseFile != "" { return fmt.Errorf("'--bootstrap-url' and '--dhcp-lease-file' are mutualy exclusive") } @@ -52,6 +55,15 @@ func Run() *cobra.Command { _, err := url.ParseRequestURI(bootstrapURL) cobra.CheckErr(err) } + if statusFilePath == "" { + return fmt.Errorf("'--status-file-path' is required") + } + if resultFilePath == "" { + return fmt.Errorf("'--result-file-path' is required") + } + if symLinkDir == "" { + return fmt.Errorf("'--symlink-dir' is required") + } for _, filePath := range arrayChecker { info, err := os.Stat(filePath) cobra.CheckErr(err) @@ -60,7 +72,7 @@ func Run() *cobra.Command { } } client := secureagent.NewHTTPClient(bootstrapTrustAnchorCert, deviceEndEntityCert, devicePrivateKey) - a := secureagent.NewAgent(bootstrapURL, serialNumber, dhcpLeaseFile, devicePassword, devicePrivateKey, deviceEndEntityCert, bootstrapTrustAnchorCert, &client) + a := secureagent.NewAgent(bootstrapURL, serialNumber, dhcpLeaseFile, devicePassword, devicePrivateKey, deviceEndEntityCert, bootstrapTrustAnchorCert, statusFilePath, resultFilePath, symLinkDir, &client) return a.RunCommand() }, } @@ -75,6 +87,9 @@ func Run() *cobra.Command { flags.StringVar(&devicePrivateKey, "device-private-key", "/certs/private_key.pem", "Device's private key") flags.StringVar(&deviceEndEntityCert, "device-end-entity-cert", "/certs/my_cert.pem", "Device's End Entity cert") flags.StringVar(&bootstrapTrustAnchorCert, "bootstrap-trust-anchor-cert", "/certs/opi.pem", "Bootstrap server trust anchor Cert") + flags.StringVar(&statusFilePath, "status-file-path", "/var/lib/sztp/status.json", "Status file path") + flags.StringVar(&resultFilePath, "result-file-path", "/var/lib/sztp/result.json", "Result file path") + flags.StringVar(&symLinkDir, "sym-link-dir", "", "Sym Link Directory") return cmd } diff --git a/sztp-agent/cmd/status.go b/sztp-agent/cmd/status.go index dbc80dfa..debbaa42 100644 --- a/sztp-agent/cmd/status.go +++ b/sztp-agent/cmd/status.go @@ -28,6 +28,9 @@ func Status() *cobra.Command { devicePrivateKey string deviceEndEntityCert string bootstrapTrustAnchorCert string + statusFilePath string + resultFilePath string + symLinkDir string ) cmd := &cobra.Command{ @@ -35,7 +38,7 @@ func Status() *cobra.Command { Short: "Run the status command", RunE: func(_ *cobra.Command, _ []string) error { client := secureagent.NewHTTPClient(bootstrapTrustAnchorCert, deviceEndEntityCert, devicePrivateKey) - a := secureagent.NewAgent(bootstrapURL, serialNumber, dhcpLeaseFile, devicePassword, devicePrivateKey, deviceEndEntityCert, bootstrapTrustAnchorCert, &client) + a := secureagent.NewAgent(bootstrapURL, serialNumber, dhcpLeaseFile, devicePassword, devicePrivateKey, deviceEndEntityCert, bootstrapTrustAnchorCert, statusFilePath, resultFilePath, symLinkDir, &client) return a.RunCommandStatus() }, } @@ -46,10 +49,13 @@ func Status() *cobra.Command { flags.StringVar(&bootstrapURL, "bootstrap-url", "", "Bootstrap server URL") flags.StringVar(&serialNumber, "serial-number", "", "Device's serial number") flags.StringVar(&dhcpLeaseFile, "dhcp-lease-file", "/var/lib/dhclient/dhclient.leases", "Device's dhclient leases file") - flags.StringVar(&devicePassword, "device-password", "", "Device's password") + flags.StringVar(&devicePassword, "device-password", "", "Dehomevice's password") flags.StringVar(&devicePrivateKey, "device-private-key", "", "Device's private key") flags.StringVar(&deviceEndEntityCert, "device-end-entity-cert", "", "Device's End Entity cert") flags.StringVar(&bootstrapTrustAnchorCert, "bootstrap-trust-anchor-cert", "", "Bootstrap server trust anchor Cert") + flags.StringVar(&statusFilePath, "status-file-path", "", "Status file path") + flags.StringVar(&resultFilePath, "result-file-path", "", "Result file path") + flags.StringVar(&symLinkDir, "sym-link-dir", "", "Sym Link Directory") return cmd } diff --git a/sztp-agent/pkg/secureagent/agent.go b/sztp-agent/pkg/secureagent/agent.go index 67970468..eaf33831 100644 --- a/sztp-agent/pkg/secureagent/agent.go +++ b/sztp-agent/pkg/secureagent/agent.go @@ -93,9 +93,12 @@ type Agent struct { BootstrapServerOnboardingInfo BootstrapServerOnboardingInfo // BootstrapServerOnboardingInfo structure BootstrapServerRedirectInfo BootstrapServerRedirectInfo // BootstrapServerRedirectInfo structure HttpClient HttpClient + StatusFilePath string // Path to the status file + ResultFilePath string // Path to the result file + SymLinkDir string // Path to the symlink directory for the status file } -func NewAgent(bootstrapURL, serialNumber, dhcpLeaseFile, devicePassword, devicePrivateKey, deviceEndEntityCert, bootstrapTrustAnchorCert string, httpClient HttpClient) *Agent { +func NewAgent(bootstrapURL, serialNumber, dhcpLeaseFile, devicePassword, devicePrivateKey, deviceEndEntityCert, bootstrapTrustAnchorCert, statusFilePath, resultFilePath, symLinkDir string, httpClient HttpClient) *Agent { return &Agent{ InputBootstrapURL: bootstrapURL, BootstrapURL: "", @@ -111,6 +114,9 @@ func NewAgent(bootstrapURL, serialNumber, dhcpLeaseFile, devicePassword, deviceP BootstrapServerRedirectInfo: BootstrapServerRedirectInfo{}, BootstrapServerOnboardingInfo: BootstrapServerOnboardingInfo{}, HttpClient: httpClient, + StatusFilePath: statusFilePath, + ResultFilePath: resultFilePath, + SymLinkDir: symLinkDir, } } @@ -150,6 +156,18 @@ func (a *Agent) GetProgressJSON() ProgressJSON { return a.ProgressJSON } +func (a *Agent) GetStatusFilePath() string { + return a.StatusFilePath +} + +func (a *Agent) GetResultFilePath() string { + return a.ResultFilePath +} + +func (a *Agent) GetSymLinkDir() string { + return a.SymLinkDir +} + func (a *Agent) SetBootstrapURL(url string) { a.BootstrapURL = url } @@ -181,3 +199,15 @@ func (a *Agent) SetContentTypeReq(ct string) { func (a *Agent) SetProgressJSON(p ProgressJSON) { a.ProgressJSON = p } + +func (a *Agent) SetStatusFilePath(path string) { + a.StatusFilePath = path +} + +func (a *Agent) SetResultFilePath(path string) { + a.ResultFilePath = path +} + +func (a *Agent) SetSymLinkDir(path string) { + a.SymLinkDir = path +} diff --git a/sztp-agent/pkg/secureagent/agent_test.go b/sztp-agent/pkg/secureagent/agent_test.go index e8234a8a..25d75f96 100644 --- a/sztp-agent/pkg/secureagent/agent_test.go +++ b/sztp-agent/pkg/secureagent/agent_test.go @@ -829,6 +829,9 @@ func TestNewAgent(t *testing.T) { devicePrivateKey string deviceEndEntityCert string bootstrapTrustAnchorCert string + statusFilePath string + resultFilePath string + symLinkDir string } client := http.Client{} tests := []struct { @@ -846,6 +849,9 @@ func TestNewAgent(t *testing.T) { devicePrivateKey: "TestDevicePrivateKey", deviceEndEntityCert: "TestDeviceEndEntityCert", bootstrapTrustAnchorCert: "TestBootstrapTrustCert", + statusFilePath: "TestStatusFilePath", + resultFilePath: "TestResultFilePath", + symLinkDir: "TestSymLinkDir", }, want: &Agent{ InputBootstrapURL: "TestBootstrap", @@ -858,13 +864,16 @@ func TestNewAgent(t *testing.T) { ContentTypeReq: "application/yang-data+json", InputJSONContent: generateInputJSONContent(), DhcpLeaseFile: "TestDhcpLeaseFile", + StatusFilePath: "TestStatusFilePath", + ResultFilePath: "TestResultFilePath", + SymLinkDir: "TestSymLinkDir", HttpClient: &client, }, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - if got := NewAgent(tt.args.bootstrapURL, tt.args.serialNumber, tt.args.dhcpLeaseFile, tt.args.devicePassword, tt.args.devicePrivateKey, tt.args.deviceEndEntityCert, tt.args.bootstrapTrustAnchorCert, &client); !reflect.DeepEqual(got, tt.want) { + if got := NewAgent(tt.args.bootstrapURL, tt.args.serialNumber, tt.args.dhcpLeaseFile, tt.args.devicePassword, tt.args.devicePrivateKey, tt.args.deviceEndEntityCert, tt.args.bootstrapTrustAnchorCert, tt.args.statusFilePath, tt.args.resultFilePath, tt.args.symLinkDir, &client); !reflect.DeepEqual(got, tt.want) { t.Errorf("NewAgent() = %v, want %v", got, tt.want) } }) diff --git a/sztp-agent/pkg/secureagent/configuration.go b/sztp-agent/pkg/secureagent/configuration.go index 1a6f0133..00a569a6 100644 --- a/sztp-agent/pkg/secureagent/configuration.go +++ b/sztp-agent/pkg/secureagent/configuration.go @@ -10,6 +10,7 @@ import ( func (a *Agent) copyConfigurationFile() error { log.Println("[INFO] Starting the Copy Configuration.") _ = a.doReportProgress(ProgressTypeConfigInitiated, "Configuration Initiated") + _ = a.UpdateAndSaveStatus("config", true, "") // Copy the configuration file to the device file, err := os.Create(ARTIFACTS_PATH + a.BootstrapServerOnboardingInfo.IetfSztpConveyedInfoOnboardingInformation.InfoTimestampReference + "-config") if err != nil { @@ -36,6 +37,7 @@ func (a *Agent) copyConfigurationFile() error { } log.Println("[INFO] Configuration file copied successfully") _ = a.doReportProgress(ProgressTypeConfigComplete, "Configuration Complete") + _ = a.UpdateAndSaveStatus("config", false, "") return nil } @@ -56,6 +58,7 @@ func (a *Agent) launchScriptsConfiguration(typeOf string) error { } log.Println("[INFO] Starting the " + scriptName + "-configuration.") _ = a.doReportProgress(reportStart, "Report starting") + _ = a.UpdateAndSaveStatus(scriptName+"-script", true, "") // nolint:gosec file, err := os.Create(ARTIFACTS_PATH + a.BootstrapServerOnboardingInfo.IetfSztpConveyedInfoOnboardingInformation.InfoTimestampReference + scriptName + "configuration.sh") if err != nil { @@ -89,6 +92,7 @@ func (a *Agent) launchScriptsConfiguration(typeOf string) error { } log.Println(string(out)) // remove it _ = a.doReportProgress(reportEnd, "Report end") + _ = a.UpdateAndSaveStatus(scriptName+"-script", false, "") log.Println("[INFO] " + scriptName + "-Configuration script executed successfully") return nil } diff --git a/sztp-agent/pkg/secureagent/daemon.go b/sztp-agent/pkg/secureagent/daemon.go index 2064d90e..a829a324 100644 --- a/sztp-agent/pkg/secureagent/daemon.go +++ b/sztp-agent/pkg/secureagent/daemon.go @@ -33,6 +33,10 @@ const ( // RunCommandDaemon runs the command in the background func (a *Agent) RunCommandDaemon() error { + if err := a.PrepareStatus(); err != nil { + log.Println("failed to prepare status: ", err) + return err + } for { err := a.performBootstrapSequence() if err != nil { @@ -46,6 +50,7 @@ func (a *Agent) RunCommandDaemon() error { } func (a *Agent) performBootstrapSequence() error { + _ = a.UpdateAndSaveStatus("bootstrap", true, "") var err error err = a.discoverBootstrapURLs() if err != nil { @@ -76,6 +81,7 @@ func (a *Agent) performBootstrapSequence() error { return err } _ = a.doReportProgress(ProgressTypeBootstrapComplete, "Bootstrap Complete") + _ = a.UpdateAndSaveStatus("bootstrap", false, "") return nil } @@ -142,6 +148,7 @@ func (a *Agent) doRequestBootstrapServerOnboardingInfo() error { } log.Println("[INFO] Response retrieved successfully") _ = a.doReportProgress(ProgressTypeBootstrapInitiated, "Bootstrap Initiated") + _ = a.UpdateAndSaveStatus("bootstrap", true, "") crypto := res.IetfSztpBootstrapServerOutput.ConveyedInformation newVal, err := base64.StdEncoding.DecodeString(crypto) if err != nil { diff --git a/sztp-agent/pkg/secureagent/image.go b/sztp-agent/pkg/secureagent/image.go index 9af486c7..a918bd51 100644 --- a/sztp-agent/pkg/secureagent/image.go +++ b/sztp-agent/pkg/secureagent/image.go @@ -23,6 +23,7 @@ import ( func (a *Agent) downloadAndValidateImage() error { log.Printf("[INFO] Starting the Download Image: %v", a.BootstrapServerOnboardingInfo.IetfSztpConveyedInfoOnboardingInformation.BootImage.DownloadURI) _ = a.doReportProgress(ProgressTypeBootImageInitiated, "BootImage Initiated") + _ = a.UpdateAndSaveStatus("boot-image", true, "") // Download the image from DownloadURI and save it to a file a.BootstrapServerOnboardingInfo.IetfSztpConveyedInfoOnboardingInformation.InfoTimestampReference = fmt.Sprintf("%8d", time.Now().Unix()) for i, item := range a.BootstrapServerOnboardingInfo.IetfSztpConveyedInfoOnboardingInformation.BootImage.DownloadURI { @@ -78,6 +79,7 @@ func (a *Agent) downloadAndValidateImage() error { } log.Println("[INFO] Checksum verified successfully") _ = a.doReportProgress(ProgressTypeBootImageComplete, "BootImage Complete") + _ = a.UpdateAndSaveStatus("boot-image", false, "") return nil default: return errors.New("unsupported hash algorithm") diff --git a/sztp-agent/pkg/secureagent/progress_test.go b/sztp-agent/pkg/secureagent/progress_test.go index a821dcf9..78def06b 100644 --- a/sztp-agent/pkg/secureagent/progress_test.go +++ b/sztp-agent/pkg/secureagent/progress_test.go @@ -142,7 +142,7 @@ func TestAgent_doReportProgress(t *testing.T) { DhcpLeaseFile: "DHCPLEASEFILE", ProgressJSON: ProgressJSON{}, }, - wantErr: true, + wantErr: true, }, } for _, tt := range tests { diff --git a/sztp-agent/pkg/secureagent/run.go b/sztp-agent/pkg/secureagent/run.go index b3b2e4c9..978fc036 100644 --- a/sztp-agent/pkg/secureagent/run.go +++ b/sztp-agent/pkg/secureagent/run.go @@ -13,6 +13,10 @@ import "log" // RunCommand runs the command in the background func (a *Agent) RunCommand() error { log.Println("runCommand started") + if err := a.PrepareStatus(); err != nil { + log.Println("failed to prepare status: ", err) + return err + } err := a.performBootstrapSequence() if err != nil { log.Println("Error in performBootstrapSequence inside runCommand: ", err) diff --git a/sztp-agent/pkg/secureagent/status.go b/sztp-agent/pkg/secureagent/status.go index e5341fdf..9dcc15d5 100644 --- a/sztp-agent/pkg/secureagent/status.go +++ b/sztp-agent/pkg/secureagent/status.go @@ -8,19 +8,244 @@ Copyright (C) 2022 Red Hat. // Package secureagent implements the secure agent package secureagent -import "log" +import ( + "encoding/json" + "fmt" + "log" + "os" + "path/filepath" + "time" +) + +// Status represents the status of the provisioning process. +type Status struct { + Init StageStatus `json:"init"` + DownloadingFile StageStatus `json:"downloading-file"` // not sure if this is needed + PendingReboot StageStatus `json:"pending-reboot"` + Parsing StageStatus `json:"parsing"` + BootImage StageStatus `json:"boot-image"` + PreScript StageStatus `json:"pre-script"` + Config StageStatus `json:"config"` + PostScript StageStatus `json:"post-script"` + Bootstrap StageStatus `json:"bootstrap"` + IsCompleted StageStatus `json:"is-completed"` + Informational string `json:"informational"` + DataSource string `json:"datasource"` + Stage string `json:"stage"` +} + +type Result struct { + DataSource string `json:"dat asource"` + Errors []string `json:"errors"` +} + +type StageStatus struct { + Errors []string `json:"errors"` + Start float64 `json:"start"` + End float64 `json:"end"` +} + +// LoadStatusFile loads the current status.json from the filesystem. +func (a *Agent) loadStatusFile() (*Status, error) { + file, err := os.ReadFile(a.GetStatusFilePath()) + if err != nil { + return nil, err + } + var status Status + err = json.Unmarshal(file, &status) + if err != nil { + return nil, err + } + return &status, nil +} + +func (a *Agent) UpdateAndSaveStatus(stage string, isStart bool, errMsg string) error { + status, err := a.loadStatusFile() + if err != nil { + fmt.Println("Creating a new status file.") + status = a.createNewStatus() + } + + if err := a.updateStageStatus(status, stage, isStart, errMsg); err != nil { + return err + } + + return a.saveStatus(status) +} + +// createNewStatus initializes a new Status object when status.json doesn't exist. +func (a *Agent) createNewStatus() *Status { + return &Status{ + DataSource: "ds", + Stage: "", + } +} + +// updateStageStatus updates the status object for a specific stage. +func (a *Agent) updateStageStatus(status *Status, stage string, isStart bool, errMsg string) error { + now := float64(time.Now().Unix()) + + switch stage { + case "init": + updateStage(&status.Init, isStart, now, errMsg) + case "downloading-file": + updateStage(&status.DownloadingFile, isStart, now, errMsg) + case "pending-reboot": + updateStage(&status.PendingReboot, isStart, now, errMsg) + case "is-completed": + updateStage(&status.IsCompleted, isStart, now, errMsg) + case "parsing": + updateStage(&status.Parsing, isStart, now, errMsg) + case "boot-image": + updateStage(&status.BootImage, isStart, now, errMsg) + case "pre-script": + updateStage(&status.PreScript, isStart, now, errMsg) + case "config": + updateStage(&status.Config, isStart, now, errMsg) + case "post-script": + updateStage(&status.PostScript, isStart, now, errMsg) + case "bootstrap": + updateStage(&status.Bootstrap, isStart, now, errMsg) + + default: + return fmt.Errorf("unknown stage: %s", stage) + } + + // Update the current stage + if isStart { + status.Stage = stage + } else { + status.Stage = "" + } + + return nil +} + +func updateStage(stageStatus *StageStatus, isStart bool, now float64, errMsg string) { + if isStart { + stageStatus.Start = now + stageStatus.End = 0 + } else { + stageStatus.End = now + if errMsg != "" { + stageStatus.Errors = append(stageStatus.Errors, errMsg) + } + } +} + +// SaveStatusToFile saves the Status object to the status.json file. +func (a *Agent) saveStatus(status *Status) error { + return saveToFile(status, a.GetStatusFilePath()) +} + +// SaveResultFile saves the Result object to the result.json file. +func (a *Agent) saveResult(result *Result) error { + return saveToFile(result, a.GetResultFilePath()) +} + +// EnsureDirExists checks if a directory exists, and creates it if it doesn't. +func EnsureDirExists(dir string) error { + if _, err := os.Stat(dir); os.IsNotExist(err) { + err := os.MkdirAll(dir, 0755) // Create the directory with appropriate permissions + if err != nil { + return fmt.Errorf("failed to create directory %s: %v", dir, err) + } + } + return nil +} + +// EnsureFile ensures that a file exists; creates it if it does not. +func EnsureFileExists(filePath string) error { + // Ensure the directory exists + dir := filepath.Dir(filePath) + if err := EnsureDirExists(dir); err != nil { + return err + } + + // Check if the file already exists + if _, err := os.Stat(filePath); os.IsNotExist(err) { + // File does not exist, create it + file, err := os.Create(filePath) + if err != nil { + return fmt.Errorf("failed to create file %s: %v", filePath, err) + } + defer file.Close() + fmt.Printf("File %s created successfully.\n", filePath) + } else { + fmt.Printf("File %s already exists.\n", filePath) + } + return nil +} + +// CreateSymlink creates a symlink for a file from target to link location. +func CreateSymlink(targetFile, linkFile string) error { + // Ensure the directory for the symlink exists + linkDir := filepath.Dir(linkFile) + if err := EnsureDirExists(linkDir); err != nil { + return err + } + + // Remove any existing symlink + if _, err := os.Lstat(linkFile); err == nil { + os.Remove(linkFile) + } + + // Create a new symlink + return os.Symlink(targetFile, linkFile) +} // RunCommandStatus runs the command in the background func (a *Agent) RunCommandStatus() error { log.Println("RunCommandStatus") + // read the status file and print the status in command line + status, err := a.loadStatusFile() + if err != nil { + log.Println("failed to load status file: ", err) + return err + } + fmt.Printf("Current status: %+v\n", status) return nil } -/* -func (a *Agent) prepareEnvStatus() error { - log.Println("prepareEnvStatus") +func (a *Agent) PrepareStatus() error { + log.Println("prepareStatus") + + // Ensure /run/sztp directory exists + if err := EnsureDirExists(a.GetSymLinkDir()); err != nil { + fmt.Printf("Failed to create directory %s: %v\n", a.GetSymLinkDir(), err) + return err + } + + if err := EnsureFileExists(a.GetStatusFilePath()); err != nil { + return err + } + if err := EnsureFileExists(a.GetResultFilePath()); err != nil { + return err + } + + statusSymlinkPath := filepath.Join(a.GetSymLinkDir(), "status.json") + resultSymlinkPath := filepath.Join(a.GetSymLinkDir(), "result.json") + + // Create symlinks for status.json and result.json + if err := CreateSymlink(a.GetStatusFilePath(), statusSymlinkPath); err != nil { + fmt.Printf("Failed to create symlink for status.json: %v\n", err) + return err + } + if err := CreateSymlink(a.GetResultFilePath(), resultSymlinkPath); err != nil { + fmt.Printf("Failed to create symlink for result.json: %v\n", err) + return err + } + + fmt.Println("Symlinks created successfully.") + + if err := a.UpdateAndSaveStatus("init", true, ""); err != nil { + return err + } + return nil } + +/* func (a *Agent) configureStatus() error { log.Println("configureStatus") return nil diff --git a/sztp-agent/pkg/secureagent/status_test.go b/sztp-agent/pkg/secureagent/status_test.go index 2aa11331..ece4925e 100644 --- a/sztp-agent/pkg/secureagent/status_test.go +++ b/sztp-agent/pkg/secureagent/status_test.go @@ -20,6 +20,9 @@ func TestAgent_RunCommandStatus(t *testing.T) { ProgressJSON ProgressJSON BootstrapServerOnboardingInfo BootstrapServerOnboardingInfo BootstrapServerRedirectInfo BootstrapServerRedirectInfo + StatusFilePath string + ResultFilePath string + SymLinkDir string } tests := []struct { name string @@ -41,6 +44,9 @@ func TestAgent_RunCommandStatus(t *testing.T) { ProgressJSON: ProgressJSON{}, BootstrapServerRedirectInfo: BootstrapServerRedirectInfo{}, BootstrapServerOnboardingInfo: BootstrapServerOnboardingInfo{}, + StatusFilePath: "/var/lib/sztp/status.json", + ResultFilePath: "/var/lib/sztp/result.json", + SymLinkDir: "/run/sztp", }, }, } @@ -59,7 +65,11 @@ func TestAgent_RunCommandStatus(t *testing.T) { ProgressJSON: tt.fields.ProgressJSON, BootstrapServerOnboardingInfo: tt.fields.BootstrapServerOnboardingInfo, BootstrapServerRedirectInfo: tt.fields.BootstrapServerRedirectInfo, + StatusFilePath: tt.fields.StatusFilePath, + ResultFilePath: tt.fields.ResultFilePath, + SymLinkDir: tt.fields.SymLinkDir, } + a.PrepareStatus() if err := a.RunCommandStatus(); (err != nil) != tt.wantErr { t.Errorf("RunCommandStatus() error = %v, wantErr %v", err, tt.wantErr) } diff --git a/sztp-agent/pkg/secureagent/utils.go b/sztp-agent/pkg/secureagent/utils.go index e9cc7a6b..8ed65069 100644 --- a/sztp-agent/pkg/secureagent/utils.go +++ b/sztp-agent/pkg/secureagent/utils.go @@ -88,3 +88,21 @@ func calculateSHA256File(filePath string) (string, error) { checkSum := fmt.Sprintf("%x", h.Sum(nil)) return checkSum, nil } + +// saveToFile writes the given data to a specified file path. +func saveToFile(data interface{}, filePath string) error { + tempPath := filePath + ".tmp" + file, err := os.Create(tempPath) + if err != nil { + return err + } + defer file.Close() + + encoder := json.NewEncoder(file) + if err := encoder.Encode(data); err != nil { + return err + } + + // Atomic move of temp file to replace the original. + return os.Rename(tempPath, filePath) +} From 5b861cee3fd04b3e884fd501d019cdea74d1f843 Mon Sep 17 00:00:00 2001 From: Bhoopesh Date: Sun, 20 Oct 2024 16:41:13 +0530 Subject: [PATCH 2/4] feat: tests for status util funtions Signed-off-by: Bhoopesh --- sztp-agent/pkg/secureagent/configuration.go | 8 +- sztp-agent/pkg/secureagent/daemon.go | 8 +- sztp-agent/pkg/secureagent/image.go | 4 +- sztp-agent/pkg/secureagent/run.go | 2 +- sztp-agent/pkg/secureagent/status.go | 67 ++------- sztp-agent/pkg/secureagent/status_test.go | 2 +- sztp-agent/pkg/secureagent/utils.go | 51 +++++++ sztp-agent/pkg/secureagent/utils_test.go | 150 ++++++++++++++++++++ 8 files changed, 221 insertions(+), 71 deletions(-) diff --git a/sztp-agent/pkg/secureagent/configuration.go b/sztp-agent/pkg/secureagent/configuration.go index 00a569a6..95b35b6a 100644 --- a/sztp-agent/pkg/secureagent/configuration.go +++ b/sztp-agent/pkg/secureagent/configuration.go @@ -10,7 +10,7 @@ import ( func (a *Agent) copyConfigurationFile() error { log.Println("[INFO] Starting the Copy Configuration.") _ = a.doReportProgress(ProgressTypeConfigInitiated, "Configuration Initiated") - _ = a.UpdateAndSaveStatus("config", true, "") + _ = a.updateAndSaveStatus("config", true, "") // Copy the configuration file to the device file, err := os.Create(ARTIFACTS_PATH + a.BootstrapServerOnboardingInfo.IetfSztpConveyedInfoOnboardingInformation.InfoTimestampReference + "-config") if err != nil { @@ -37,7 +37,7 @@ func (a *Agent) copyConfigurationFile() error { } log.Println("[INFO] Configuration file copied successfully") _ = a.doReportProgress(ProgressTypeConfigComplete, "Configuration Complete") - _ = a.UpdateAndSaveStatus("config", false, "") + _ = a.updateAndSaveStatus("config", false, "") return nil } @@ -58,7 +58,7 @@ func (a *Agent) launchScriptsConfiguration(typeOf string) error { } log.Println("[INFO] Starting the " + scriptName + "-configuration.") _ = a.doReportProgress(reportStart, "Report starting") - _ = a.UpdateAndSaveStatus(scriptName+"-script", true, "") + _ = a.updateAndSaveStatus(scriptName+"-script", true, "") // nolint:gosec file, err := os.Create(ARTIFACTS_PATH + a.BootstrapServerOnboardingInfo.IetfSztpConveyedInfoOnboardingInformation.InfoTimestampReference + scriptName + "configuration.sh") if err != nil { @@ -92,7 +92,7 @@ func (a *Agent) launchScriptsConfiguration(typeOf string) error { } log.Println(string(out)) // remove it _ = a.doReportProgress(reportEnd, "Report end") - _ = a.UpdateAndSaveStatus(scriptName+"-script", false, "") + _ = a.updateAndSaveStatus(scriptName+"-script", false, "") log.Println("[INFO] " + scriptName + "-Configuration script executed successfully") return nil } diff --git a/sztp-agent/pkg/secureagent/daemon.go b/sztp-agent/pkg/secureagent/daemon.go index a829a324..066e5714 100644 --- a/sztp-agent/pkg/secureagent/daemon.go +++ b/sztp-agent/pkg/secureagent/daemon.go @@ -33,7 +33,7 @@ const ( // RunCommandDaemon runs the command in the background func (a *Agent) RunCommandDaemon() error { - if err := a.PrepareStatus(); err != nil { + if err := a.prepareStatus(); err != nil { log.Println("failed to prepare status: ", err) return err } @@ -50,7 +50,7 @@ func (a *Agent) RunCommandDaemon() error { } func (a *Agent) performBootstrapSequence() error { - _ = a.UpdateAndSaveStatus("bootstrap", true, "") + _ = a.updateAndSaveStatus("bootstrap", true, "") var err error err = a.discoverBootstrapURLs() if err != nil { @@ -81,7 +81,7 @@ func (a *Agent) performBootstrapSequence() error { return err } _ = a.doReportProgress(ProgressTypeBootstrapComplete, "Bootstrap Complete") - _ = a.UpdateAndSaveStatus("bootstrap", false, "") + _ = a.updateAndSaveStatus("bootstrap", false, "") return nil } @@ -148,7 +148,7 @@ func (a *Agent) doRequestBootstrapServerOnboardingInfo() error { } log.Println("[INFO] Response retrieved successfully") _ = a.doReportProgress(ProgressTypeBootstrapInitiated, "Bootstrap Initiated") - _ = a.UpdateAndSaveStatus("bootstrap", true, "") + _ = a.updateAndSaveStatus("bootstrap", true, "") crypto := res.IetfSztpBootstrapServerOutput.ConveyedInformation newVal, err := base64.StdEncoding.DecodeString(crypto) if err != nil { diff --git a/sztp-agent/pkg/secureagent/image.go b/sztp-agent/pkg/secureagent/image.go index a918bd51..44666985 100644 --- a/sztp-agent/pkg/secureagent/image.go +++ b/sztp-agent/pkg/secureagent/image.go @@ -23,7 +23,7 @@ import ( func (a *Agent) downloadAndValidateImage() error { log.Printf("[INFO] Starting the Download Image: %v", a.BootstrapServerOnboardingInfo.IetfSztpConveyedInfoOnboardingInformation.BootImage.DownloadURI) _ = a.doReportProgress(ProgressTypeBootImageInitiated, "BootImage Initiated") - _ = a.UpdateAndSaveStatus("boot-image", true, "") + _ = a.updateAndSaveStatus("boot-image", true, "") // Download the image from DownloadURI and save it to a file a.BootstrapServerOnboardingInfo.IetfSztpConveyedInfoOnboardingInformation.InfoTimestampReference = fmt.Sprintf("%8d", time.Now().Unix()) for i, item := range a.BootstrapServerOnboardingInfo.IetfSztpConveyedInfoOnboardingInformation.BootImage.DownloadURI { @@ -79,7 +79,7 @@ func (a *Agent) downloadAndValidateImage() error { } log.Println("[INFO] Checksum verified successfully") _ = a.doReportProgress(ProgressTypeBootImageComplete, "BootImage Complete") - _ = a.UpdateAndSaveStatus("boot-image", false, "") + _ = a.updateAndSaveStatus("boot-image", false, "") return nil default: return errors.New("unsupported hash algorithm") diff --git a/sztp-agent/pkg/secureagent/run.go b/sztp-agent/pkg/secureagent/run.go index 978fc036..09dbf8c4 100644 --- a/sztp-agent/pkg/secureagent/run.go +++ b/sztp-agent/pkg/secureagent/run.go @@ -13,7 +13,7 @@ import "log" // RunCommand runs the command in the background func (a *Agent) RunCommand() error { log.Println("runCommand started") - if err := a.PrepareStatus(); err != nil { + if err := a.prepareStatus(); err != nil { log.Println("failed to prepare status: ", err) return err } diff --git a/sztp-agent/pkg/secureagent/status.go b/sztp-agent/pkg/secureagent/status.go index 9dcc15d5..656471cb 100644 --- a/sztp-agent/pkg/secureagent/status.go +++ b/sztp-agent/pkg/secureagent/status.go @@ -59,7 +59,7 @@ func (a *Agent) loadStatusFile() (*Status, error) { return &status, nil } -func (a *Agent) UpdateAndSaveStatus(stage string, isStart bool, errMsg string) error { +func (a *Agent) updateAndSaveStatus(stage string, isStart bool, errMsg string) error { status, err := a.loadStatusFile() if err != nil { fmt.Println("Creating a new status file.") @@ -143,57 +143,6 @@ func (a *Agent) saveResult(result *Result) error { return saveToFile(result, a.GetResultFilePath()) } -// EnsureDirExists checks if a directory exists, and creates it if it doesn't. -func EnsureDirExists(dir string) error { - if _, err := os.Stat(dir); os.IsNotExist(err) { - err := os.MkdirAll(dir, 0755) // Create the directory with appropriate permissions - if err != nil { - return fmt.Errorf("failed to create directory %s: %v", dir, err) - } - } - return nil -} - -// EnsureFile ensures that a file exists; creates it if it does not. -func EnsureFileExists(filePath string) error { - // Ensure the directory exists - dir := filepath.Dir(filePath) - if err := EnsureDirExists(dir); err != nil { - return err - } - - // Check if the file already exists - if _, err := os.Stat(filePath); os.IsNotExist(err) { - // File does not exist, create it - file, err := os.Create(filePath) - if err != nil { - return fmt.Errorf("failed to create file %s: %v", filePath, err) - } - defer file.Close() - fmt.Printf("File %s created successfully.\n", filePath) - } else { - fmt.Printf("File %s already exists.\n", filePath) - } - return nil -} - -// CreateSymlink creates a symlink for a file from target to link location. -func CreateSymlink(targetFile, linkFile string) error { - // Ensure the directory for the symlink exists - linkDir := filepath.Dir(linkFile) - if err := EnsureDirExists(linkDir); err != nil { - return err - } - - // Remove any existing symlink - if _, err := os.Lstat(linkFile); err == nil { - os.Remove(linkFile) - } - - // Create a new symlink - return os.Symlink(targetFile, linkFile) -} - // RunCommandStatus runs the command in the background func (a *Agent) RunCommandStatus() error { log.Println("RunCommandStatus") @@ -207,19 +156,19 @@ func (a *Agent) RunCommandStatus() error { return nil } -func (a *Agent) PrepareStatus() error { +func (a *Agent) prepareStatus() error { log.Println("prepareStatus") // Ensure /run/sztp directory exists - if err := EnsureDirExists(a.GetSymLinkDir()); err != nil { + if err := ensureDirExists(a.GetSymLinkDir()); err != nil { fmt.Printf("Failed to create directory %s: %v\n", a.GetSymLinkDir(), err) return err } - if err := EnsureFileExists(a.GetStatusFilePath()); err != nil { + if err := ensureFileExists(a.GetStatusFilePath()); err != nil { return err } - if err := EnsureFileExists(a.GetResultFilePath()); err != nil { + if err := ensureFileExists(a.GetResultFilePath()); err != nil { return err } @@ -227,18 +176,18 @@ func (a *Agent) PrepareStatus() error { resultSymlinkPath := filepath.Join(a.GetSymLinkDir(), "result.json") // Create symlinks for status.json and result.json - if err := CreateSymlink(a.GetStatusFilePath(), statusSymlinkPath); err != nil { + if err := createSymlink(a.GetStatusFilePath(), statusSymlinkPath); err != nil { fmt.Printf("Failed to create symlink for status.json: %v\n", err) return err } - if err := CreateSymlink(a.GetResultFilePath(), resultSymlinkPath); err != nil { + if err := createSymlink(a.GetResultFilePath(), resultSymlinkPath); err != nil { fmt.Printf("Failed to create symlink for result.json: %v\n", err) return err } fmt.Println("Symlinks created successfully.") - if err := a.UpdateAndSaveStatus("init", true, ""); err != nil { + if err := a.updateAndSaveStatus("init", true, ""); err != nil { return err } diff --git a/sztp-agent/pkg/secureagent/status_test.go b/sztp-agent/pkg/secureagent/status_test.go index ece4925e..d16d14c8 100644 --- a/sztp-agent/pkg/secureagent/status_test.go +++ b/sztp-agent/pkg/secureagent/status_test.go @@ -69,7 +69,7 @@ func TestAgent_RunCommandStatus(t *testing.T) { ResultFilePath: tt.fields.ResultFilePath, SymLinkDir: tt.fields.SymLinkDir, } - a.PrepareStatus() + a.prepareStatus() if err := a.RunCommandStatus(); (err != nil) != tt.wantErr { t.Errorf("RunCommandStatus() error = %v, wantErr %v", err, tt.wantErr) } diff --git a/sztp-agent/pkg/secureagent/utils.go b/sztp-agent/pkg/secureagent/utils.go index 8ed65069..0eddb2f3 100644 --- a/sztp-agent/pkg/secureagent/utils.go +++ b/sztp-agent/pkg/secureagent/utils.go @@ -106,3 +106,54 @@ func saveToFile(data interface{}, filePath string) error { // Atomic move of temp file to replace the original. return os.Rename(tempPath, filePath) } + +// EnsureDirExists checks if a directory exists, and creates it if it doesn't. +func ensureDirExists(dir string) error { + if _, err := os.Stat(dir); os.IsNotExist(err) { + err := os.MkdirAll(dir, 0755) // Create the directory with appropriate permissions + if err != nil { + return fmt.Errorf("failed to create directory %s: %v", dir, err) + } + } + return nil +} + +// EnsureFile ensures that a file exists; creates it if it does not. +func ensureFileExists(filePath string) error { + // Ensure the directory exists + dir := filepath.Dir(filePath) + if err := ensureDirExists(dir); err != nil { + return err + } + + // Check if the file already exists + if _, err := os.Stat(filePath); os.IsNotExist(err) { + // File does not exist, create it + file, err := os.Create(filePath) + if err != nil { + return fmt.Errorf("failed to create file %s: %v", filePath, err) + } + defer file.Close() + fmt.Printf("File %s created successfully.\n", filePath) + } else { + fmt.Printf("File %s already exists.\n", filePath) + } + return nil +} + +// CreateSymlink creates a symlink for a file from target to link location. +func createSymlink(targetFile, linkFile string) error { + // Ensure the directory for the symlink exists + linkDir := filepath.Dir(linkFile) + if err := ensureDirExists(linkDir); err != nil { + return err + } + + // Remove any existing symlink + if _, err := os.Lstat(linkFile); err == nil { + os.Remove(linkFile) + } + + // Create a new symlink + return os.Symlink(targetFile, linkFile) +} diff --git a/sztp-agent/pkg/secureagent/utils_test.go b/sztp-agent/pkg/secureagent/utils_test.go index bf32fec3..42e134f1 100644 --- a/sztp-agent/pkg/secureagent/utils_test.go +++ b/sztp-agent/pkg/secureagent/utils_test.go @@ -5,7 +5,9 @@ package secureagent import ( + "encoding/json" "os" + "path/filepath" "testing" ) @@ -76,3 +78,151 @@ func Test_calculateSHA256File(t *testing.T) { t.Errorf("Checksum did not match %s %s", checksum, expected) } } + +func Test_saveToFile(t *testing.T) { + tempDir, err := os.MkdirTemp("", "test_save_to_file") + if err != nil { + t.Fatalf("failed to create temp directory: %v", err) + } + defer os.RemoveAll(tempDir) + + filePath := filepath.Join(tempDir, "test.json") + data := map[string]string{"key": "value"} + + err = saveToFile(data, filePath) + if err != nil { + t.Fatalf("saveToFile returned an error: %v", err) + } + + _, err = os.Stat(filePath) + if os.IsNotExist(err) { + t.Fatalf("file %s was not created", filePath) + } + + file, err := os.Open(filePath) + if err != nil { + t.Fatalf("failed to open the file: %v", err) + } + defer file.Close() + + var readData map[string]string + decoder := json.NewDecoder(file) + err = decoder.Decode(&readData) + if err != nil { + t.Fatalf("failed to decode JSON data: %v", err) + } + + if readData["key"] != "value" { + t.Errorf("expected 'key' to be 'value', got %s", readData["key"]) + } +} + +func TestEnsureDirExists(t *testing.T) { + tempDir, err := os.MkdirTemp("", "test_ensure_dir_exists") + if err != nil { + t.Fatalf("failed to create temp directory: %v", err) + } + defer os.RemoveAll(tempDir) + + newDir := filepath.Join(tempDir, "newdir") + + if _, err := os.Stat(newDir); !os.IsNotExist(err) { + t.Fatalf("expected directory %s to not exist", newDir) + } + + err = ensureDirExists(newDir) + if err != nil { + t.Fatalf("ensureDirExists returned an error: %v", err) + } + + if _, err := os.Stat(newDir); os.IsNotExist(err) { + t.Fatalf("expected directory %s to be created", newDir) + } + + err = ensureDirExists(newDir) + if err != nil { + t.Fatalf("ensureDirExists returned an error when directory already exists: %v", err) + } +} + +func TestEnsureFileExists(t *testing.T) { + tempDir, err := os.MkdirTemp("", "test_ensure_file_exists") + if err != nil { + t.Fatalf("failed to create temp directory: %v", err) + } + defer os.RemoveAll(tempDir) + + newFilePath := filepath.Join(tempDir, "newdir", "testfile.txt") + + err = ensureFileExists(newFilePath) + if err != nil { + t.Fatalf("ensureFileExists returned an error: %v", err) + } + + if _, err := os.Stat(newFilePath); os.IsNotExist(err) { + t.Fatalf("expected file %s to be created", newFilePath) + } + + err = ensureFileExists(newFilePath) + if err != nil { + t.Fatalf("ensureFileExists returned an error when file already exists: %v", err) + } +} + +func TestCreateSymlink(t *testing.T) { + tempDir, err := os.MkdirTemp("", "test_create_symlink") + if err != nil { + t.Fatalf("failed to create temp directory: %v", err) + } + defer os.RemoveAll(tempDir) + + targetFile := filepath.Join(tempDir, "target.txt") + linkFile := filepath.Join(tempDir, "link.txt") + + err = os.WriteFile(targetFile, []byte("test data"), 0644) + if err != nil { + t.Fatalf("failed to create target file: %v", err) + } + + err = createSymlink(targetFile, linkFile) + if err != nil { + t.Fatalf("createSymlink returned an error: %v", err) + } + + linkInfo, err := os.Lstat(linkFile) + t.Logf("linkInfo: %v", linkInfo) /// + if err != nil { + t.Fatalf("failed to stat symlink: %v", err) + } + if linkInfo.Mode()&os.ModeSymlink == 0 { + t.Errorf("expected %s to be a symlink", linkFile) + } + + target, err := os.Readlink(linkFile) + if err != nil { + t.Fatalf("failed to read symlink: %v", err) + } + if target != targetFile { + t.Errorf("expected symlink to point to %s, got %s", targetFile, target) + } + + newTargetFile := filepath.Join(tempDir, "new_target.txt") + err = os.WriteFile(newTargetFile, []byte("new data"), 0644) + if err != nil { + t.Fatalf("failed to create new target file: %v", err) + } + + err = createSymlink(newTargetFile, linkFile) + if err != nil { + t.Fatalf("createSymlink returned an error when replacing symlink: %v", err) + } + + newTarget, err := os.Readlink(linkFile) + if err != nil { + t.Fatalf("failed to read new symlink: %v", err) + } + + if newTarget != newTargetFile { + t.Errorf("expected symlink to point to %s, got %s", newTargetFile, newTarget) + } +} From 501e683218677acaf50f86fd05dd75375038bd64 Mon Sep 17 00:00:00 2001 From: Bhoopesh Date: Sun, 20 Oct 2024 17:00:15 +0530 Subject: [PATCH 3/4] fix: result and status Signed-off-by: Bhoopesh --- docker-compose.dpu.yml | 25 +- sztp-agent/cmd/daemon.go | 21 +- sztp-agent/cmd/disable.go | 10 +- sztp-agent/cmd/enable.go | 10 +- sztp-agent/cmd/run.go | 15 +- sztp-agent/cmd/status.go | 12 +- sztp-agent/pkg/secureagent/agent.go | 8 +- sztp-agent/pkg/secureagent/agent_test.go | 10 +- sztp-agent/pkg/secureagent/configuration.go | 22 +- sztp-agent/pkg/secureagent/daemon.go | 15 +- sztp-agent/pkg/secureagent/image.go | 4 +- sztp-agent/pkg/secureagent/progress_test.go | 2 +- sztp-agent/pkg/secureagent/status.go | 284 +++++++++++++------- sztp-agent/pkg/secureagent/status_test.go | 22 +- sztp-agent/pkg/secureagent/utils.go | 111 +++++--- sztp-agent/pkg/secureagent/utils_test.go | 42 ++- 16 files changed, 360 insertions(+), 253 deletions(-) diff --git a/docker-compose.dpu.yml b/docker-compose.dpu.yml index 5e281dda..7e9bd162 100644 --- a/docker-compose.dpu.yml +++ b/docker-compose.dpu.yml @@ -53,10 +53,7 @@ services: '--bootstrap-trust-anchor-cert', '/certs/opi.pem', '--device-end-entity-cert', '/certs/third_my_cert.pem', '--device-private-key', '/certs/third_private_key.pem', - '--serial-number', 'third-serial-number', - '--status-file-path', '/var/lib/sztp/status.json', - '--result-file-path', '/var/lib/sztp/result.json', - '--sym-link-dir', '/run/sztp'] + '--serial-number', 'third-serial-number'] agent2: <<: *agent @@ -65,10 +62,7 @@ services: '--bootstrap-trust-anchor-cert', '/certs/opi.pem', '--device-end-entity-cert', '/certs/second_my_cert.pem', '--device-private-key', '/certs/second_private_key.pem', - '--serial-number', 'second-serial-number', - '--status-file-path', '/var/lib/sztp/status.json', - '--result-file-path', '/var/lib/sztp/result.json', - '--sym-link-dir', '/run/sztp'] + '--serial-number', 'second-serial-number'] agent1: <<: *agent @@ -77,10 +71,7 @@ services: '--bootstrap-trust-anchor-cert', '/certs/opi.pem', '--device-end-entity-cert', '/certs/first_my_cert.pem', '--device-private-key', '/certs/first_private_key.pem', - '--serial-number', 'first-serial-number', - '--status-file-path', '/var/lib/sztp/status.json', - '--result-file-path', '/var/lib/sztp/result.json', - '--sym-link-dir', '/run/sztp'] + '--serial-number', 'first-serial-number'] agent4: <<: *agent @@ -89,10 +80,7 @@ services: '--bootstrap-trust-anchor-cert', '/certs/opi.pem', '--device-end-entity-cert', '/certs/first_my_cert.pem', '--device-private-key', '/certs/first_private_key.pem', - '--serial-number', 'first-serial-number', - '--status-file-path', '/var/lib/sztp/status.json', - '--result-file-path', '/var/lib/sztp/result.json', - '--sym-link-dir', '/run/sztp'] + '--serial-number', 'first-serial-number'] agent5: <<: *agent @@ -101,10 +89,7 @@ services: '--bootstrap-trust-anchor-cert', '/certs/opi.pem', '--device-end-entity-cert', '/certs/first_my_cert.pem', '--device-private-key', '/certs/first_private_key.pem', - '--serial-number', 'first-serial-number', - '--status-file-path', '/var/lib/sztp/status.json', - '--result-file-path', '/var/lib/sztp/result.json', - '--sym-link-dir', '/run/sztp'] + '--serial-number', 'first-serial-number'] volumes: client-certs: diff --git a/sztp-agent/cmd/daemon.go b/sztp-agent/cmd/daemon.go index 0dc0a0b1..7cc4c22e 100644 --- a/sztp-agent/cmd/daemon.go +++ b/sztp-agent/cmd/daemon.go @@ -33,15 +33,15 @@ func Daemon() *cobra.Command { deviceEndEntityCert string bootstrapTrustAnchorCert string statusFilePath string - resultFilePath string - symLinkDir string + resultFilePath string + symLinkDir string ) cmd := &cobra.Command{ Use: "daemon", Short: "Run the daemon command", RunE: func(_ *cobra.Command, _ []string) error { - arrayChecker := []string{devicePrivateKey, deviceEndEntityCert, bootstrapTrustAnchorCert, statusFilePath} + arrayChecker := []string{devicePrivateKey, deviceEndEntityCert, bootstrapTrustAnchorCert, statusFilePath, resultFilePath} if bootstrapURL != "" && dhcpLeaseFile != "" { return fmt.Errorf("'--bootstrap-url' and '--dhcp-lease-file' are mutualy exclusive") } @@ -55,15 +55,6 @@ func Daemon() *cobra.Command { _, err := url.ParseRequestURI(bootstrapURL) cobra.CheckErr(err) } - if statusFilePath == "" { - return fmt.Errorf("'--status-file-path' is required") - } - if resultFilePath == "" { - return fmt.Errorf("'--result-file-path' is required") - } - if symLinkDir == "" { - return fmt.Errorf("'--symlink-dir' is required") - } for _, filePath := range arrayChecker { info, err := os.Stat(filePath) cobra.CheckErr(err) @@ -87,9 +78,9 @@ func Daemon() *cobra.Command { flags.StringVar(&devicePrivateKey, "device-private-key", "/certs/private_key.pem", "Device's private key") flags.StringVar(&deviceEndEntityCert, "device-end-entity-cert", "/certs/my_cert.pem", "Device's End Entity cert") flags.StringVar(&bootstrapTrustAnchorCert, "bootstrap-trust-anchor-cert", "/certs/opi.pem", "Bootstrap server trust anchor Cert") - flags.StringVar(&statusFilePath, "status-file-path", "/var/lib/sztp/status.json", "Path to the status file") - flags.StringVar(&resultFilePath, "result-file-path", "/var/lib/sztp/result.json", "Path to the result file") - flags.StringVar(&symLinkDir, "sym-link-dir", "/run/sztp", "Path to the symlink directory") + flags.StringVar(&statusFilePath, "status-file-path", "/var/lib/sztp/status.json", "Status file path") + flags.StringVar(&resultFilePath, "result-file-path", "/var/lib/sztp/result.json", "Result file path") + flags.StringVar(&symLinkDir, "sym-link-dir", "/run/sztp", "Sym Link Directory") return cmd } diff --git a/sztp-agent/cmd/disable.go b/sztp-agent/cmd/disable.go index b086aeb3..732fb1c9 100644 --- a/sztp-agent/cmd/disable.go +++ b/sztp-agent/cmd/disable.go @@ -29,8 +29,8 @@ func Disable() *cobra.Command { deviceEndEntityCert string bootstrapTrustAnchorCert string statusFilePath string - resultFilePath string - symLinkDir string + resultFilePath string + symLinkDir string ) cmd := &cobra.Command{ @@ -53,9 +53,9 @@ func Disable() *cobra.Command { flags.StringVar(&devicePrivateKey, "device-private-key", "", "Device's private key") flags.StringVar(&deviceEndEntityCert, "device-end-entity-cert", "", "Device's End Entity cert") flags.StringVar(&bootstrapTrustAnchorCert, "bootstrap-trust-anchor-cert", "", "Bootstrap server trust anchor Cert") - flags.StringVar(&statusFilePath, "status-file-path", "", "Status file path") - flags.StringVar(&resultFilePath, "result-file-path", "", "Result file path") - flags.StringVar(&symLinkDir, "sym-link-dir", "", "Sym Link Directory") + flags.StringVar(&statusFilePath, "status-file-path", "/var/lib/sztp/status.json", "Status file path") + flags.StringVar(&resultFilePath, "result-file-path", "/var/lib/sztp/result.json", "Result file path") + flags.StringVar(&symLinkDir, "sym-link-dir", "/run/sztp", "Sym Link Directory") return cmd } diff --git a/sztp-agent/cmd/enable.go b/sztp-agent/cmd/enable.go index 2c6513d9..2a750316 100644 --- a/sztp-agent/cmd/enable.go +++ b/sztp-agent/cmd/enable.go @@ -29,8 +29,8 @@ func Enable() *cobra.Command { deviceEndEntityCert string bootstrapTrustAnchorCert string statusFilePath string - resultFilePath string - symLinkDir string + resultFilePath string + symLinkDir string ) cmd := &cobra.Command{ @@ -53,9 +53,9 @@ func Enable() *cobra.Command { flags.StringVar(&devicePrivateKey, "device-private-key", "", "Device's private key") flags.StringVar(&deviceEndEntityCert, "device-end-entity-cert", "", "Device's End Entity cert") flags.StringVar(&bootstrapTrustAnchorCert, "bootstrap-trust-anchor-cert", "", "Bootstrap server trust anchor Cert") - flags.StringVar(&statusFilePath, "status-file-path", "", "Status file path") - flags.StringVar(&resultFilePath, "result-file-path", "", "Result file path") - flags.StringVar(&symLinkDir, "sym-link-dir", "", "Sym Link Directory") + flags.StringVar(&statusFilePath, "status-file-path", "/var/lib/sztp/status.json", "Status file path") + flags.StringVar(&resultFilePath, "result-file-path", "/var/lib/sztp/result.json", "Result file path") + flags.StringVar(&symLinkDir, "sym-link-dir", "/run/sztp", "Sym Link Directory") return cmd } diff --git a/sztp-agent/cmd/run.go b/sztp-agent/cmd/run.go index a76cb791..0178e621 100644 --- a/sztp-agent/cmd/run.go +++ b/sztp-agent/cmd/run.go @@ -33,8 +33,8 @@ func Run() *cobra.Command { deviceEndEntityCert string bootstrapTrustAnchorCert string statusFilePath string - resultFilePath string - symLinkDir string + resultFilePath string + symLinkDir string ) cmd := &cobra.Command{ @@ -55,15 +55,6 @@ func Run() *cobra.Command { _, err := url.ParseRequestURI(bootstrapURL) cobra.CheckErr(err) } - if statusFilePath == "" { - return fmt.Errorf("'--status-file-path' is required") - } - if resultFilePath == "" { - return fmt.Errorf("'--result-file-path' is required") - } - if symLinkDir == "" { - return fmt.Errorf("'--symlink-dir' is required") - } for _, filePath := range arrayChecker { info, err := os.Stat(filePath) cobra.CheckErr(err) @@ -89,7 +80,7 @@ func Run() *cobra.Command { flags.StringVar(&bootstrapTrustAnchorCert, "bootstrap-trust-anchor-cert", "/certs/opi.pem", "Bootstrap server trust anchor Cert") flags.StringVar(&statusFilePath, "status-file-path", "/var/lib/sztp/status.json", "Status file path") flags.StringVar(&resultFilePath, "result-file-path", "/var/lib/sztp/result.json", "Result file path") - flags.StringVar(&symLinkDir, "sym-link-dir", "", "Sym Link Directory") + flags.StringVar(&symLinkDir, "sym-link-dir", "/run/sztp", "Sym Link Directory") return cmd } diff --git a/sztp-agent/cmd/status.go b/sztp-agent/cmd/status.go index debbaa42..b8b1e1fc 100644 --- a/sztp-agent/cmd/status.go +++ b/sztp-agent/cmd/status.go @@ -29,8 +29,8 @@ func Status() *cobra.Command { deviceEndEntityCert string bootstrapTrustAnchorCert string statusFilePath string - resultFilePath string - symLinkDir string + resultFilePath string + symLinkDir string ) cmd := &cobra.Command{ @@ -49,13 +49,13 @@ func Status() *cobra.Command { flags.StringVar(&bootstrapURL, "bootstrap-url", "", "Bootstrap server URL") flags.StringVar(&serialNumber, "serial-number", "", "Device's serial number") flags.StringVar(&dhcpLeaseFile, "dhcp-lease-file", "/var/lib/dhclient/dhclient.leases", "Device's dhclient leases file") - flags.StringVar(&devicePassword, "device-password", "", "Dehomevice's password") + flags.StringVar(&devicePassword, "device-password", "", "Device's password") flags.StringVar(&devicePrivateKey, "device-private-key", "", "Device's private key") flags.StringVar(&deviceEndEntityCert, "device-end-entity-cert", "", "Device's End Entity cert") flags.StringVar(&bootstrapTrustAnchorCert, "bootstrap-trust-anchor-cert", "", "Bootstrap server trust anchor Cert") - flags.StringVar(&statusFilePath, "status-file-path", "", "Status file path") - flags.StringVar(&resultFilePath, "result-file-path", "", "Result file path") - flags.StringVar(&symLinkDir, "sym-link-dir", "", "Sym Link Directory") + flags.StringVar(&statusFilePath, "status-file-path", "/var/lib/sztp/status.json", "Status file path") + flags.StringVar(&resultFilePath, "result-file-path", "/var/lib/sztp/result.json", "Result file path") + flags.StringVar(&symLinkDir, "sym-link-dir", "/run/sztp", "Sym Link Directory") return cmd } diff --git a/sztp-agent/pkg/secureagent/agent.go b/sztp-agent/pkg/secureagent/agent.go index eaf33831..5c43e7d0 100644 --- a/sztp-agent/pkg/secureagent/agent.go +++ b/sztp-agent/pkg/secureagent/agent.go @@ -93,9 +93,9 @@ type Agent struct { BootstrapServerOnboardingInfo BootstrapServerOnboardingInfo // BootstrapServerOnboardingInfo structure BootstrapServerRedirectInfo BootstrapServerRedirectInfo // BootstrapServerRedirectInfo structure HttpClient HttpClient - StatusFilePath string // Path to the status file - ResultFilePath string // Path to the result file - SymLinkDir string // Path to the symlink directory for the status file + StatusFilePath string // Path to the status file + ResultFilePath string // Path to the result file + SymLinkDir string // Path to the symlink directory for the status file } func NewAgent(bootstrapURL, serialNumber, dhcpLeaseFile, devicePassword, devicePrivateKey, deviceEndEntityCert, bootstrapTrustAnchorCert, statusFilePath, resultFilePath, symLinkDir string, httpClient HttpClient) *Agent { @@ -116,7 +116,7 @@ func NewAgent(bootstrapURL, serialNumber, dhcpLeaseFile, devicePassword, deviceP HttpClient: httpClient, StatusFilePath: statusFilePath, ResultFilePath: resultFilePath, - SymLinkDir: symLinkDir, + SymLinkDir: symLinkDir, } } diff --git a/sztp-agent/pkg/secureagent/agent_test.go b/sztp-agent/pkg/secureagent/agent_test.go index 25d75f96..fec69b9b 100644 --- a/sztp-agent/pkg/secureagent/agent_test.go +++ b/sztp-agent/pkg/secureagent/agent_test.go @@ -829,9 +829,9 @@ func TestNewAgent(t *testing.T) { devicePrivateKey string deviceEndEntityCert string bootstrapTrustAnchorCert string - statusFilePath string - resultFilePath string - symLinkDir string + statusFilePath string + resultFilePath string + symLinkDir string } client := http.Client{} tests := []struct { @@ -851,7 +851,7 @@ func TestNewAgent(t *testing.T) { bootstrapTrustAnchorCert: "TestBootstrapTrustCert", statusFilePath: "TestStatusFilePath", resultFilePath: "TestResultFilePath", - symLinkDir: "TestSymLinkDir", + symLinkDir: "TestSymLinkDir", }, want: &Agent{ InputBootstrapURL: "TestBootstrap", @@ -866,7 +866,7 @@ func TestNewAgent(t *testing.T) { DhcpLeaseFile: "TestDhcpLeaseFile", StatusFilePath: "TestStatusFilePath", ResultFilePath: "TestResultFilePath", - SymLinkDir: "TestSymLinkDir", + SymLinkDir: "TestSymLinkDir", HttpClient: &client, }, }, diff --git a/sztp-agent/pkg/secureagent/configuration.go b/sztp-agent/pkg/secureagent/configuration.go index 95b35b6a..f6594de2 100644 --- a/sztp-agent/pkg/secureagent/configuration.go +++ b/sztp-agent/pkg/secureagent/configuration.go @@ -10,7 +10,7 @@ import ( func (a *Agent) copyConfigurationFile() error { log.Println("[INFO] Starting the Copy Configuration.") _ = a.doReportProgress(ProgressTypeConfigInitiated, "Configuration Initiated") - _ = a.updateAndSaveStatus("config", true, "") + _ = a.updateAndSaveStatus(StageTypeConfig, true, "") // Copy the configuration file to the device file, err := os.Create(ARTIFACTS_PATH + a.BootstrapServerOnboardingInfo.IetfSztpConveyedInfoOnboardingInformation.InfoTimestampReference + "-config") if err != nil { @@ -37,7 +37,7 @@ func (a *Agent) copyConfigurationFile() error { } log.Println("[INFO] Configuration file copied successfully") _ = a.doReportProgress(ProgressTypeConfigComplete, "Configuration Complete") - _ = a.updateAndSaveStatus("config", false, "") + _ = a.updateAndSaveStatus(StageTypeConfig, false, "") return nil } @@ -45,20 +45,24 @@ func (a *Agent) launchScriptsConfiguration(typeOf string) error { var script, scriptName string var reportStart, reportEnd ProgressType switch typeOf { - case "post": + case POST: script = a.BootstrapServerOnboardingInfo.IetfSztpConveyedInfoOnboardingInformation.PostConfigurationScript - scriptName = "post" + scriptName = POST reportStart = ProgressTypePostScriptInitiated reportEnd = ProgressTypePostScriptComplete default: // pre or default script = a.BootstrapServerOnboardingInfo.IetfSztpConveyedInfoOnboardingInformation.PreConfigurationScript - scriptName = "pre" + scriptName = PRE reportStart = ProgressTypePreScriptInitiated reportEnd = ProgressTypePreScriptComplete } log.Println("[INFO] Starting the " + scriptName + "-configuration.") _ = a.doReportProgress(reportStart, "Report starting") - _ = a.updateAndSaveStatus(scriptName+"-script", true, "") + if scriptName == PRE { + _ = a.updateAndSaveStatus(StageTypePreScript, true, "") + } else if scriptName == POST { + _ = a.updateAndSaveStatus(StageTypePostScript, true, "") + } // nolint:gosec file, err := os.Create(ARTIFACTS_PATH + a.BootstrapServerOnboardingInfo.IetfSztpConveyedInfoOnboardingInformation.InfoTimestampReference + scriptName + "configuration.sh") if err != nil { @@ -92,7 +96,11 @@ func (a *Agent) launchScriptsConfiguration(typeOf string) error { } log.Println(string(out)) // remove it _ = a.doReportProgress(reportEnd, "Report end") - _ = a.updateAndSaveStatus(scriptName+"-script", false, "") + if scriptName == PRE { + _ = a.updateAndSaveStatus(StageTypePreScript, false, "") + } else if scriptName == POST { + _ = a.updateAndSaveStatus(StageTypePostScript, false, "") + } log.Println("[INFO] " + scriptName + "-Configuration script executed successfully") return nil } diff --git a/sztp-agent/pkg/secureagent/daemon.go b/sztp-agent/pkg/secureagent/daemon.go index 066e5714..a8abd7b9 100644 --- a/sztp-agent/pkg/secureagent/daemon.go +++ b/sztp-agent/pkg/secureagent/daemon.go @@ -37,51 +37,60 @@ func (a *Agent) RunCommandDaemon() error { log.Println("failed to prepare status: ", err) return err } + _ = a.updateAndSaveStatus(StageTypeIsCompleted, true, "") for { err := a.performBootstrapSequence() if err != nil { log.Println("[ERROR] Failed to perform the bootstrap sequence: ", err.Error()) log.Println("[INFO] Retrying in 5 seconds") time.Sleep(5 * time.Second) + _ = a.updateAndSaveStatus(StageTypeIsCompleted, false, err.Error()) continue } + _ = a.updateAndSaveStatus(StageTypeIsCompleted, false, "") return nil } } func (a *Agent) performBootstrapSequence() error { - _ = a.updateAndSaveStatus("bootstrap", true, "") var err error err = a.discoverBootstrapURLs() if err != nil { + _ = a.updateAndSaveStatus(StageTypeParsing, false, err.Error()) return err } err = a.doRequestBootstrapServerOnboardingInfo() if err != nil { + _ = a.updateAndSaveStatus(StageTypeOnboarding, false, err.Error()) return err } err = a.doHandleBootstrapRedirect() if err != nil { + _ = a.updateAndSaveStatus(StageTypeBootImage, false, err.Error()) return err } err = a.downloadAndValidateImage() if err != nil { + _ = a.updateAndSaveStatus(StageTypeBootImage, false, err.Error()) return err } err = a.copyConfigurationFile() if err != nil { + _ = a.updateAndSaveStatus(StageTypeConfig, false, err.Error()) return err } err = a.launchScriptsConfiguration(PRE) if err != nil { + _ = a.updateAndSaveStatus(StageTypePreScript, false, err.Error()) return err } err = a.launchScriptsConfiguration(POST) if err != nil { + _ = a.updateAndSaveStatus(StageTypePostScript, false, err.Error()) return err } _ = a.doReportProgress(ProgressTypeBootstrapComplete, "Bootstrap Complete") - _ = a.updateAndSaveStatus("bootstrap", false, "") + _ = a.updateAndSaveStatus(StageTypeBootstrap, false, "") return nil } @@ -148,7 +157,7 @@ func (a *Agent) doRequestBootstrapServerOnboardingInfo() error { } log.Println("[INFO] Response retrieved successfully") _ = a.doReportProgress(ProgressTypeBootstrapInitiated, "Bootstrap Initiated") - _ = a.updateAndSaveStatus("bootstrap", true, "") + _ = a.updateAndSaveStatus(StageTypeBootstrap, true, "") crypto := res.IetfSztpBootstrapServerOutput.ConveyedInformation newVal, err := base64.StdEncoding.DecodeString(crypto) if err != nil { diff --git a/sztp-agent/pkg/secureagent/image.go b/sztp-agent/pkg/secureagent/image.go index 44666985..7362d778 100644 --- a/sztp-agent/pkg/secureagent/image.go +++ b/sztp-agent/pkg/secureagent/image.go @@ -23,7 +23,7 @@ import ( func (a *Agent) downloadAndValidateImage() error { log.Printf("[INFO] Starting the Download Image: %v", a.BootstrapServerOnboardingInfo.IetfSztpConveyedInfoOnboardingInformation.BootImage.DownloadURI) _ = a.doReportProgress(ProgressTypeBootImageInitiated, "BootImage Initiated") - _ = a.updateAndSaveStatus("boot-image", true, "") + _ = a.updateAndSaveStatus(StageTypeBootImage, true, "") // Download the image from DownloadURI and save it to a file a.BootstrapServerOnboardingInfo.IetfSztpConveyedInfoOnboardingInformation.InfoTimestampReference = fmt.Sprintf("%8d", time.Now().Unix()) for i, item := range a.BootstrapServerOnboardingInfo.IetfSztpConveyedInfoOnboardingInformation.BootImage.DownloadURI { @@ -79,7 +79,7 @@ func (a *Agent) downloadAndValidateImage() error { } log.Println("[INFO] Checksum verified successfully") _ = a.doReportProgress(ProgressTypeBootImageComplete, "BootImage Complete") - _ = a.updateAndSaveStatus("boot-image", false, "") + _ = a.updateAndSaveStatus(StageTypeBootImage, false, "") return nil default: return errors.New("unsupported hash algorithm") diff --git a/sztp-agent/pkg/secureagent/progress_test.go b/sztp-agent/pkg/secureagent/progress_test.go index 78def06b..a821dcf9 100644 --- a/sztp-agent/pkg/secureagent/progress_test.go +++ b/sztp-agent/pkg/secureagent/progress_test.go @@ -142,7 +142,7 @@ func TestAgent_doReportProgress(t *testing.T) { DhcpLeaseFile: "DHCPLEASEFILE", ProgressJSON: ProgressJSON{}, }, - wantErr: true, + wantErr: true, }, } for _, tt := range tests { diff --git a/sztp-agent/pkg/secureagent/status.go b/sztp-agent/pkg/secureagent/status.go index 656471cb..6738c271 100644 --- a/sztp-agent/pkg/secureagent/status.go +++ b/sztp-agent/pkg/secureagent/status.go @@ -4,108 +4,166 @@ Copyright (C) 2022-2023 Intel Corporation Copyright (c) 2022 Dell Inc, or its subsidiaries. Copyright (C) 2022 Red Hat. */ - +// nolint // Package secureagent implements the secure agent package secureagent import ( - "encoding/json" "fmt" "log" - "os" "path/filepath" "time" ) +type StageType int64 + +const ( + StageTypeInit StageType = iota + StageTypeDownloadingFile + StageTypePendingReboot + StageTypeParsing + StageTypeOnboarding + StageTypeRedirect + StageTypeBootImage + StageTypePreScript + StageTypeConfig + StageTypePostScript + StageTypeBootstrap + StageTypeIsCompleted +) + +func (s StageType) String() string { + switch s { + case StageTypeInit: + return "init" + case StageTypeDownloadingFile: + return "downloading-file" + case StageTypePendingReboot: + return "pending-reboot" + case StageTypeParsing: + return "parsing" + case StageTypeOnboarding: + return "onboarding" + case StageTypeRedirect: + return "redirect" + case StageTypeBootImage: + return "boot-image" + case StageTypePreScript: + return "pre-script" + case StageTypeConfig: + return "config" + case StageTypePostScript: + return "post-script" + case StageTypeBootstrap: + return "bootstrap" + case StageTypeIsCompleted: + return "is-completed" + default: + return "unknown" + } +} + // Status represents the status of the provisioning process. type Status struct { - Init StageStatus `json:"init"` - DownloadingFile StageStatus `json:"downloading-file"` // not sure if this is needed - PendingReboot StageStatus `json:"pending-reboot"` - Parsing StageStatus `json:"parsing"` - BootImage StageStatus `json:"boot-image"` - PreScript StageStatus `json:"pre-script"` - Config StageStatus `json:"config"` - PostScript StageStatus `json:"post-script"` - Bootstrap StageStatus `json:"bootstrap"` - IsCompleted StageStatus `json:"is-completed"` - Informational string `json:"informational"` - DataSource string `json:"datasource"` - Stage string `json:"stage"` + Init StageStatus `json:"init"` + DownloadingFile StageStatus `json:"downloading-file"` + PendingReboot StageStatus `json:"pending-reboot"` + Parsing StageStatus `json:"parsing"` + Onboarding StageStatus `json:"onboarding"` + Redirect StageStatus `json:"redirect"` + BootImage StageStatus `json:"boot-image"` + PreScript StageStatus `json:"pre-script"` + Config StageStatus `json:"config"` + PostScript StageStatus `json:"post-script"` + Bootstrap StageStatus `json:"bootstrap"` + IsCompleted StageStatus `json:"is-completed"` + Informational string `json:"informational"` + Stage string `json:"stage"` } +// Result represents the result of the provisioning process. type Result struct { - DataSource string `json:"dat asource"` - Errors []string `json:"errors"` + Errors []string `json:"errors"` } +// StageStatus represents the status of a specific stage. type StageStatus struct { - Errors []string `json:"errors"` - Start float64 `json:"start"` - End float64 `json:"end"` -} - -// LoadStatusFile loads the current status.json from the filesystem. -func (a *Agent) loadStatusFile() (*Status, error) { - file, err := os.ReadFile(a.GetStatusFilePath()) - if err != nil { - return nil, err - } - var status Status - err = json.Unmarshal(file, &status) - if err != nil { - return nil, err - } - return &status, nil -} - -func (a *Agent) updateAndSaveStatus(stage string, isStart bool, errMsg string) error { - status, err := a.loadStatusFile() + Errors []string `json:"errors"` + Start float64 `json:"start"` + End float64 `json:"end"` +} + +func (a *Agent) getCurrStatus() (*Status, error) { + var status Status + err := loadFile(a.GetStatusFilePath(), &status) if err != nil { - fmt.Println("Creating a new status file.") - status = a.createNewStatus() + return nil, err } + return &status, nil +} - if err := a.updateStageStatus(status, stage, isStart, errMsg); err != nil { - return err +func (a *Agent) getCurrResult() (*Result, error) { + var result Result + err := loadFile(a.GetResultFilePath(), &result) + if err != nil { + return nil, err } - - return a.saveStatus(status) + return &result, nil } -// createNewStatus initializes a new Status object when status.json doesn't exist. func (a *Agent) createNewStatus() *Status { return &Status{ - DataSource: "ds", - Stage: "", + Stage: "", + IsCompleted: StageStatus{}, } } +// updateAndSaveStatus updates the status object for a specific stage and saves it to the status.json file. +func (a *Agent) updateAndSaveStatus(s StageType, isStart bool, errMsg string) error { + status, err := a.getCurrStatus() + if err != nil { + fmt.Println("Creating a new status file.") + status = a.createNewStatus() + } + + err = a.updateStageStatus(status, s, isStart, errMsg) + if err != nil { + return err + } + + return a.saveStatus(status) +} + // updateStageStatus updates the status object for a specific stage. -func (a *Agent) updateStageStatus(status *Status, stage string, isStart bool, errMsg string) error { +func (a *Agent) updateStageStatus(status *Status, stageType StageType, isStart bool, errMsg string) error { now := float64(time.Now().Unix()) + stage := stageType.String() - switch stage { - case "init": - updateStage(&status.Init, isStart, now, errMsg) - case "downloading-file": - updateStage(&status.DownloadingFile, isStart, now, errMsg) - case "pending-reboot": - updateStage(&status.PendingReboot, isStart, now, errMsg) - case "is-completed": - updateStage(&status.IsCompleted, isStart, now, errMsg) - case "parsing": - updateStage(&status.Parsing, isStart, now, errMsg) - case "boot-image": - updateStage(&status.BootImage, isStart, now, errMsg) - case "pre-script": - updateStage(&status.PreScript, isStart, now, errMsg) - case "config": - updateStage(&status.Config, isStart, now, errMsg) - case "post-script": - updateStage(&status.PostScript, isStart, now, errMsg) - case "bootstrap": - updateStage(&status.Bootstrap, isStart, now, errMsg) + switch stageType { + case StageTypeInit: + a.updateStage(&status.Init, isStart, now, errMsg) + case StageTypeDownloadingFile: + a.updateStage(&status.DownloadingFile, isStart, now, errMsg) + case StageTypePendingReboot: + a.updateStage(&status.PendingReboot, isStart, now, errMsg) + case StageTypeIsCompleted: + a.updateStage(&status.IsCompleted, isStart, now, errMsg) + case StageTypeParsing: + a.updateStage(&status.Parsing, isStart, now, errMsg) + case StageTypeOnboarding: + a.updateStage(&status.Onboarding, isStart, now, errMsg) + case StageTypeRedirect: + a.updateStage(&status.Redirect, isStart, now, errMsg) + case StageTypeBootImage: + a.updateStage(&status.BootImage, isStart, now, errMsg) + case StageTypePreScript: + a.updateStage(&status.PreScript, isStart, now, errMsg) + case StageTypeConfig: + a.updateStage(&status.Config, isStart, now, errMsg) + case StageTypePostScript: + a.updateStage(&status.PostScript, isStart, now, errMsg) + case StageTypeBootstrap: + a.updateStage(&status.Bootstrap, isStart, now, errMsg) default: return fmt.Errorf("unknown stage: %s", stage) @@ -113,15 +171,15 @@ func (a *Agent) updateStageStatus(status *Status, stage string, isStart bool, er // Update the current stage if isStart { - status.Stage = stage + status.Stage = stage + "-in-progress" } else { - status.Stage = "" + status.Stage = stage + "-completed" } return nil } -func updateStage(stageStatus *StageStatus, isStart bool, now float64, errMsg string) { +func (a *Agent) updateStage(stageStatus *StageStatus, isStart bool, now float64, errMsg string) { if isStart { stageStatus.Start = now stageStatus.End = 0 @@ -129,30 +187,47 @@ func updateStage(stageStatus *StageStatus, isStart bool, now float64, errMsg str stageStatus.End = now if errMsg != "" { stageStatus.Errors = append(stageStatus.Errors, errMsg) + err := a.updateAndSaveResult(errMsg) + if err != nil { + fmt.Printf("Failed to update and save result: %v\n", err) + } } } } -// SaveStatusToFile saves the Status object to the status.json file. func (a *Agent) saveStatus(status *Status) error { return saveToFile(status, a.GetStatusFilePath()) } -// SaveResultFile saves the Result object to the result.json file. func (a *Agent) saveResult(result *Result) error { return saveToFile(result, a.GetResultFilePath()) } +func (a *Agent) updateAndSaveResult(errMsg string) error { + result, err := a.getCurrResult() + if err != nil { + fmt.Println("Creating a new result file.") + result = &Result{ + Errors: []string{}, + } + } + + if errMsg != "" { + result.Errors = append(result.Errors, errMsg) + } + + return a.saveResult(result) +} + // RunCommandStatus runs the command in the background func (a *Agent) RunCommandStatus() error { log.Println("RunCommandStatus") - // read the status file and print the status in command line - status, err := a.loadStatusFile() - if err != nil { - log.Println("failed to load status file: ", err) - return err - } - fmt.Printf("Current status: %+v\n", status) + status, err := a.getCurrStatus() + if err != nil { + log.Println("failed to load status file: ", err) + return err + } + fmt.Printf("Current status: %+v\n", status) return nil } @@ -165,31 +240,34 @@ func (a *Agent) prepareStatus() error { return err } - if err := ensureFileExists(a.GetStatusFilePath()); err != nil { - return err - } - if err := ensureFileExists(a.GetResultFilePath()); err != nil { - return err - } + fmt.Println("Status File Path", a.GetStatusFilePath()) + fmt.Println("Result File Path", a.GetResultFilePath()) + + if err := ensureFileExists(a.GetStatusFilePath()); err != nil { + return err + } + if err := ensureFileExists(a.GetResultFilePath()); err != nil { + return err + } - statusSymlinkPath := filepath.Join(a.GetSymLinkDir(), "status.json") - resultSymlinkPath := filepath.Join(a.GetSymLinkDir(), "result.json") + statusSymlinkPath := filepath.Join(a.GetSymLinkDir(), "status.json") + resultSymlinkPath := filepath.Join(a.GetSymLinkDir(), "result.json") - // Create symlinks for status.json and result.json - if err := createSymlink(a.GetStatusFilePath(), statusSymlinkPath); err != nil { - fmt.Printf("Failed to create symlink for status.json: %v\n", err) - return err - } - if err := createSymlink(a.GetResultFilePath(), resultSymlinkPath); err != nil { - fmt.Printf("Failed to create symlink for result.json: %v\n", err) - return err - } + // Create symlinks for status.json and result.json + if err := createSymlink(a.GetStatusFilePath(), statusSymlinkPath); err != nil { + fmt.Printf("Failed to create symlink for status.json: %v\n", err) + return err + } + if err := createSymlink(a.GetResultFilePath(), resultSymlinkPath); err != nil { + fmt.Printf("Failed to create symlink for result.json: %v\n", err) + return err + } - fmt.Println("Symlinks created successfully.") + fmt.Println("Symlinks created successfully.") - if err := a.updateAndSaveStatus("init", true, ""); err != nil { - return err - } + if err := a.updateAndSaveStatus(StageTypeInit, true, ""); err != nil { + return err + } return nil } diff --git a/sztp-agent/pkg/secureagent/status_test.go b/sztp-agent/pkg/secureagent/status_test.go index d16d14c8..afdf405d 100644 --- a/sztp-agent/pkg/secureagent/status_test.go +++ b/sztp-agent/pkg/secureagent/status_test.go @@ -20,9 +20,9 @@ func TestAgent_RunCommandStatus(t *testing.T) { ProgressJSON ProgressJSON BootstrapServerOnboardingInfo BootstrapServerOnboardingInfo BootstrapServerRedirectInfo BootstrapServerRedirectInfo - StatusFilePath string - ResultFilePath string - SymLinkDir string + StatusFilePath string + ResultFilePath string + SymLinkDir string } tests := []struct { name string @@ -44,9 +44,9 @@ func TestAgent_RunCommandStatus(t *testing.T) { ProgressJSON: ProgressJSON{}, BootstrapServerRedirectInfo: BootstrapServerRedirectInfo{}, BootstrapServerOnboardingInfo: BootstrapServerOnboardingInfo{}, - StatusFilePath: "/var/lib/sztp/status.json", - ResultFilePath: "/var/lib/sztp/result.json", - SymLinkDir: "/run/sztp", + StatusFilePath: "/var/lib/sztp/status.json", + ResultFilePath: "/var/lib/sztp/result.json", + SymLinkDir: "/run/sztp", }, }, } @@ -65,11 +65,13 @@ func TestAgent_RunCommandStatus(t *testing.T) { ProgressJSON: tt.fields.ProgressJSON, BootstrapServerOnboardingInfo: tt.fields.BootstrapServerOnboardingInfo, BootstrapServerRedirectInfo: tt.fields.BootstrapServerRedirectInfo, - StatusFilePath: tt.fields.StatusFilePath, - ResultFilePath: tt.fields.ResultFilePath, - SymLinkDir: tt.fields.SymLinkDir, + StatusFilePath: tt.fields.StatusFilePath, + ResultFilePath: tt.fields.ResultFilePath, + SymLinkDir: tt.fields.SymLinkDir, + } + if err := a.prepareStatus(); err != nil { + t.Errorf("prepareStatus() error = %v", err) } - a.prepareStatus() if err := a.RunCommandStatus(); (err != nil) != tt.wantErr { t.Errorf("RunCommandStatus() error = %v, wantErr %v", err, tt.wantErr) } diff --git a/sztp-agent/pkg/secureagent/utils.go b/sztp-agent/pkg/secureagent/utils.go index 0eddb2f3..23ad4f65 100644 --- a/sztp-agent/pkg/secureagent/utils.go +++ b/sztp-agent/pkg/secureagent/utils.go @@ -89,14 +89,18 @@ func calculateSHA256File(filePath string) (string, error) { return checkSum, nil } -// saveToFile writes the given data to a specified file path. func saveToFile(data interface{}, filePath string) error { tempPath := filePath + ".tmp" + tempPath = filepath.Clean(tempPath) file, err := os.Create(tempPath) if err != nil { return err } - defer file.Close() + defer func() { + if err := file.Close(); err != nil { + log.Println("[ERROR] Error when closing:", err) + } + }() encoder := json.NewEncoder(file) if err := encoder.Encode(data); err != nil { @@ -107,53 +111,72 @@ func saveToFile(data interface{}, filePath string) error { return os.Rename(tempPath, filePath) } -// EnsureDirExists checks if a directory exists, and creates it if it doesn't. func ensureDirExists(dir string) error { - if _, err := os.Stat(dir); os.IsNotExist(err) { - err := os.MkdirAll(dir, 0755) // Create the directory with appropriate permissions - if err != nil { - return fmt.Errorf("failed to create directory %s: %v", dir, err) - } - } - return nil + if _, err := os.Stat(dir); os.IsNotExist(err) { + err := os.MkdirAll(dir, 0750) // Create the directory with appropriate permissions + if err != nil { + return fmt.Errorf("failed to create directory %s: %v", dir, err) + } + } + return nil } -// EnsureFile ensures that a file exists; creates it if it does not. func ensureFileExists(filePath string) error { - // Ensure the directory exists - dir := filepath.Dir(filePath) - if err := ensureDirExists(dir); err != nil { - return err - } - - // Check if the file already exists - if _, err := os.Stat(filePath); os.IsNotExist(err) { - // File does not exist, create it - file, err := os.Create(filePath) - if err != nil { - return fmt.Errorf("failed to create file %s: %v", filePath, err) - } - defer file.Close() - fmt.Printf("File %s created successfully.\n", filePath) - } else { - fmt.Printf("File %s already exists.\n", filePath) - } - return nil + dir := filepath.Dir(filePath) + if err := ensureDirExists(dir); err != nil { + return err + } + + if _, err := os.Stat(filePath); os.IsNotExist(err) { + filePath = filepath.Clean(filePath) + file, err := os.Create(filePath) + if err != nil { + return fmt.Errorf("failed to create file %s: %v", filePath, err) + } + defer func() { + if err := file.Close(); err != nil { + log.Println("[ERROR] Error when closing:", err) + } + }() + fmt.Printf("File %s created successfully.\n", filePath) + } else { + fmt.Printf("File %s already exists.\n", filePath) + } + return nil } -// CreateSymlink creates a symlink for a file from target to link location. func createSymlink(targetFile, linkFile string) error { - // Ensure the directory for the symlink exists - linkDir := filepath.Dir(linkFile) - if err := ensureDirExists(linkDir); err != nil { - return err - } - - // Remove any existing symlink - if _, err := os.Lstat(linkFile); err == nil { - os.Remove(linkFile) - } - - // Create a new symlink - return os.Symlink(targetFile, linkFile) + targetFile = filepath.Clean(targetFile) + linkFile = filepath.Clean(linkFile) + + linkDir := filepath.Dir(linkFile) + if err := ensureDirExists(linkDir); err != nil { + return err + } + + // Check if linkFile exists and is a symlink to targetFile + if existingTarget, err := os.Readlink(linkFile); err == nil { + if existingTarget == targetFile { + return nil // Symlink already points to the target; skip creation + } + // Remove the existing file (even if it's a wrong symlink or regular file) + if err := os.Remove(linkFile); err != nil { + return err + } + } + + return os.Symlink(targetFile, linkFile) +} + +func loadFile(filePath string, v interface{}) error { + filePath = filepath.Clean(filePath) + file, err := os.ReadFile(filePath) + if err != nil { + return err + } + err = json.Unmarshal(file, v) + if err != nil { + return err + } + return nil } diff --git a/sztp-agent/pkg/secureagent/utils_test.go b/sztp-agent/pkg/secureagent/utils_test.go index 42e134f1..a3217bf0 100644 --- a/sztp-agent/pkg/secureagent/utils_test.go +++ b/sztp-agent/pkg/secureagent/utils_test.go @@ -84,7 +84,11 @@ func Test_saveToFile(t *testing.T) { if err != nil { t.Fatalf("failed to create temp directory: %v", err) } - defer os.RemoveAll(tempDir) + defer func() { + if err := os.RemoveAll(tempDir); err != nil { + t.Fatalf("failed to remove temp directory: %v", err) + } + }() filePath := filepath.Join(tempDir, "test.json") data := map[string]string{"key": "value"} @@ -99,11 +103,16 @@ func Test_saveToFile(t *testing.T) { t.Fatalf("file %s was not created", filePath) } + filePath = filepath.Clean(filePath) file, err := os.Open(filePath) if err != nil { t.Fatalf("failed to open the file: %v", err) } - defer file.Close() + defer func() { + if err := file.Close(); err != nil { + t.Fatalf("failed to close the file: %v", err) + } + }() var readData map[string]string decoder := json.NewDecoder(file) @@ -117,12 +126,16 @@ func Test_saveToFile(t *testing.T) { } } -func TestEnsureDirExists(t *testing.T) { +func Test_ensureDirExists(t *testing.T) { tempDir, err := os.MkdirTemp("", "test_ensure_dir_exists") if err != nil { t.Fatalf("failed to create temp directory: %v", err) } - defer os.RemoveAll(tempDir) + defer func() { + if err := os.RemoveAll(tempDir); err != nil { + t.Fatalf("failed to remove temp directory: %v", err) + } + }() newDir := filepath.Join(tempDir, "newdir") @@ -145,12 +158,16 @@ func TestEnsureDirExists(t *testing.T) { } } -func TestEnsureFileExists(t *testing.T) { +func Test_ensureFileExists(t *testing.T) { tempDir, err := os.MkdirTemp("", "test_ensure_file_exists") if err != nil { t.Fatalf("failed to create temp directory: %v", err) } - defer os.RemoveAll(tempDir) + defer func() { + if err := os.RemoveAll(tempDir); err != nil { + t.Fatalf("failed to remove temp directory: %v", err) + } + }() newFilePath := filepath.Join(tempDir, "newdir", "testfile.txt") @@ -169,17 +186,21 @@ func TestEnsureFileExists(t *testing.T) { } } -func TestCreateSymlink(t *testing.T) { +func Test_createSymlink(t *testing.T) { tempDir, err := os.MkdirTemp("", "test_create_symlink") if err != nil { t.Fatalf("failed to create temp directory: %v", err) } - defer os.RemoveAll(tempDir) + defer func() { + if err := os.RemoveAll(tempDir); err != nil { + t.Fatalf("failed to remove temp directory: %v", err) + } + }() targetFile := filepath.Join(tempDir, "target.txt") linkFile := filepath.Join(tempDir, "link.txt") - err = os.WriteFile(targetFile, []byte("test data"), 0644) + err = os.WriteFile(targetFile, []byte("test data"), 0600) if err != nil { t.Fatalf("failed to create target file: %v", err) } @@ -190,7 +211,6 @@ func TestCreateSymlink(t *testing.T) { } linkInfo, err := os.Lstat(linkFile) - t.Logf("linkInfo: %v", linkInfo) /// if err != nil { t.Fatalf("failed to stat symlink: %v", err) } @@ -207,7 +227,7 @@ func TestCreateSymlink(t *testing.T) { } newTargetFile := filepath.Join(tempDir, "new_target.txt") - err = os.WriteFile(newTargetFile, []byte("new data"), 0644) + err = os.WriteFile(newTargetFile, []byte("new data"), 0600) if err != nil { t.Fatalf("failed to create new target file: %v", err) } From b1113e56205dbbec6d6392cd1339f6f5253ccc34 Mon Sep 17 00:00:00 2001 From: Bhoopesh Date: Sat, 26 Oct 2024 03:13:08 +0530 Subject: [PATCH 4/4] fix: use rand no for temp file Signed-off-by: Bhoopesh --- sztp-agent/pkg/secureagent/status.go | 1 - sztp-agent/pkg/secureagent/status_test.go | 35 ++++++++++++++++++++--- sztp-agent/pkg/secureagent/utils.go | 17 ++++++++--- 3 files changed, 44 insertions(+), 9 deletions(-) diff --git a/sztp-agent/pkg/secureagent/status.go b/sztp-agent/pkg/secureagent/status.go index 6738c271..4922a45f 100644 --- a/sztp-agent/pkg/secureagent/status.go +++ b/sztp-agent/pkg/secureagent/status.go @@ -169,7 +169,6 @@ func (a *Agent) updateStageStatus(status *Status, stageType StageType, isStart b return fmt.Errorf("unknown stage: %s", stage) } - // Update the current stage if isStart { status.Stage = stage + "-in-progress" } else { diff --git a/sztp-agent/pkg/secureagent/status_test.go b/sztp-agent/pkg/secureagent/status_test.go index afdf405d..75c94705 100644 --- a/sztp-agent/pkg/secureagent/status_test.go +++ b/sztp-agent/pkg/secureagent/status_test.go @@ -4,9 +4,36 @@ // Package secureagent implements the secure agent package secureagent -import "testing" +import ( + "testing" +) + +const StatusTestContent = `{ + "init": {"errors": [], "start": 1729891263, "end": 0}, + "downloading-file": {"errors": [], "start": 0, "end": 0}, + "pending-reboot": {"errors": [], "start": 0, "end": 0}, + "parsing": {"errors": [], "start": 0, "end": 0}, + "onboarding": {"errors": [], "start": 0, "end": 0}, + "redirect": {"errors": [], "start": 0, "end": 0}, + "boot-image": {"errors": [], "start": 1729891263, "end": 1729891263}, + "pre-script": {"errors": [], "start": 1729891264, "end": 1729891264}, + "config": {"errors": [], "start": 1729891264, "end": 1729891264}, + "post-script": {"errors": [], "start": 1729891264, "end": 1729891264}, + "bootstrap": {"errors": [], "start": 1729891263, "end": 1729891264}, + "is-completed": {"errors": [], "start": 1729891263, "end": 1729891264}, + "informational": "", + "stage": "is-completed-completed" +}` + +const ResultTestContent = `{ + "errors": ["error1", "error2"], +}` func TestAgent_RunCommandStatus(t *testing.T) { + testStatusFile := "/tmp/sztp/status.json" + testResultFile := "/tmp/sztp/result.json" + testSymLinkDir := "/tmp/symlink" + type fields struct { BootstrapURL string SerialNumber string @@ -44,9 +71,9 @@ func TestAgent_RunCommandStatus(t *testing.T) { ProgressJSON: ProgressJSON{}, BootstrapServerRedirectInfo: BootstrapServerRedirectInfo{}, BootstrapServerOnboardingInfo: BootstrapServerOnboardingInfo{}, - StatusFilePath: "/var/lib/sztp/status.json", - ResultFilePath: "/var/lib/sztp/result.json", - SymLinkDir: "/run/sztp", + StatusFilePath: testStatusFile, + ResultFilePath: testResultFile, + SymLinkDir: testSymLinkDir, }, }, } diff --git a/sztp-agent/pkg/secureagent/utils.go b/sztp-agent/pkg/secureagent/utils.go index 23ad4f65..21aa3c73 100644 --- a/sztp-agent/pkg/secureagent/utils.go +++ b/sztp-agent/pkg/secureagent/utils.go @@ -9,6 +9,7 @@ Copyright (C) 2022 Red Hat. package secureagent import ( + "crypto/rand" "crypto/sha256" "encoding/json" "fmt" @@ -90,7 +91,9 @@ func calculateSHA256File(filePath string) (string, error) { } func saveToFile(data interface{}, filePath string) error { - tempPath := filePath + ".tmp" + filePath = filepath.Clean(filePath) + random, _ := rand.Prime(rand.Reader, 64) + tempPath := fmt.Sprintf("%s.%d.tmp", filePath, random) // rand number to avoid conflicts when multiple agents are running tempPath = filepath.Clean(tempPath) file, err := os.Create(tempPath) if err != nil { @@ -108,7 +111,11 @@ func saveToFile(data interface{}, filePath string) error { } // Atomic move of temp file to replace the original. - return os.Rename(tempPath, filePath) + if err := os.Rename(tempPath, filePath); err != nil { + return fmt.Errorf("failed to rename %s to %s: %v", tempPath, filePath, err) + } + + return nil } func ensureDirExists(dir string) error { @@ -127,11 +134,13 @@ func ensureFileExists(filePath string) error { return err } + fmt.Printf("Checking if file %s exists...\n", filePath) + if _, err := os.Stat(filePath); os.IsNotExist(err) { filePath = filepath.Clean(filePath) file, err := os.Create(filePath) if err != nil { - return fmt.Errorf("failed to create file %s: %v", filePath, err) + return fmt.Errorf("[ERROR] failed to create file %s: %v", filePath, err) } defer func() { if err := file.Close(); err != nil { @@ -157,7 +166,7 @@ func createSymlink(targetFile, linkFile string) error { // Check if linkFile exists and is a symlink to targetFile if existingTarget, err := os.Readlink(linkFile); err == nil { if existingTarget == targetFile { - return nil // Symlink already points to the target; skip creation + return nil // Symlink already points to the target -> skip creation } // Remove the existing file (even if it's a wrong symlink or regular file) if err := os.Remove(linkFile); err != nil {