From 9312c6d2850b0d480c8434ac7dc7391eee3edae3 Mon Sep 17 00:00:00 2001 From: ChrisFriesen <26197502+ChrisFriesen@users.noreply.github.com> Date: Wed, 12 Apr 2023 10:35:45 -0700 Subject: [PATCH 1/3] Adding tests for cape key --- cmd/cape/cmd/key.go | 18 ++++--- cmd/cape/cmd/key_test.go | 77 ++++++++++++++++++++++++++++ mocks/mocks.go | 66 ++++++++++++++++++++++-- pcrs/pcrs_test.go | 2 + sdk/key.go | 12 ++--- sdk/key_test.go | 108 +++++++++++++++++++++++++++++++++++++++ sdk/test_test.go | 43 +++------------- 7 files changed, 271 insertions(+), 55 deletions(-) create mode 100644 cmd/cape/cmd/key_test.go create mode 100644 sdk/key_test.go diff --git a/cmd/cape/cmd/key.go b/cmd/cape/cmd/key.go index b39b29ad..1d8bf8a8 100644 --- a/cmd/cape/cmd/key.go +++ b/cmd/cape/cmd/key.go @@ -3,7 +3,6 @@ package cmd import ( "crypto/x509" "encoding/pem" - "fmt" "github.com/spf13/cobra" @@ -38,7 +37,7 @@ func key(cmd *cobra.Command, args []string) error { return UserError{Msg: "error retrieving pcr flags", Err: err} } - token, err := getAuthToken() + token, err := authTokenFunc() if err != nil { return err } @@ -48,7 +47,7 @@ func key(cmd *cobra.Command, args []string) error { return err } - capeKey, err := sdk.Key(keyReq) + capeKey, err := keyFunc(keyReq) if err != nil { return err } @@ -57,18 +56,20 @@ func key(cmd *cobra.Command, args []string) error { // ...but NOTE that Cape will only support decryption if envelope encryption is used. p, err := x509.ParsePKIXPublicKey(capeKey) if err != nil { - return err + return UserError{Msg: "error: key in unexpected format", Err: err} } m, err := x509.MarshalPKIXPublicKey(p) if err != nil { - return err + return UserError{Msg: "error: key in unexpected format", Err: err} } - fmt.Println(string(pem.EncodeToMemory(&pem.Block{ + if _, err := cmd.OutOrStdout().Write(pem.EncodeToMemory(&pem.Block{ Type: "PUBLIC KEY", Bytes: m, - }))) + })); err != nil { + return err + } return nil } @@ -83,3 +84,6 @@ func GetKeyRequest(pcrSlice []string, token string) (sdk.KeyRequest, error) { PcrSlice: pcrSlice, }, nil } + +var authTokenFunc = getAuthToken +var keyFunc = sdk.Key diff --git a/cmd/cape/cmd/key_test.go b/cmd/cape/cmd/key_test.go new file mode 100644 index 00000000..ad425f67 --- /dev/null +++ b/cmd/cape/cmd/key_test.go @@ -0,0 +1,77 @@ +package cmd + +import ( + "errors" + "testing" + + "github.com/capeprivacy/cli/sdk" +) + +func TestKeyNoArgs(t *testing.T) { + cmd, stdout, stderr := getCmd() + cmd.SetArgs([]string{"key"}) + + want := `-----BEGIN PUBLIC KEY----- +MIICIjANBgkqhkiG9w0BAQEFAAOCAg8AMIICCgKCAgEAoM2iIWF9ocxYGsSUnyt9 +P7NLq3gv39uNJCdIGee/y8EQHhFEg6cJONPJP60E/3Zt4hnrYh4a4lx7rV0aWks5 +KxpQi6LPP98sUKLkZO/ZTcY5Ugtn7FAQNj19ohtI39c2WCgxUB/1IR485jE1SLFn +x351mcog4V3pdU6THK1ZQTNhkonsLwyaP5TzpKySpz+OlgOBNDxqm6iRb7BQrc/w +hYj8Fpfj92m83cWk+jhlqUQwjMZ3b0B9jmSfzUNmEZEng/+Bw9hFpMH48LOsAHwg +z5tC1RhuGI5Is6VaKUeKbnptZQREIcXcs2857h+1i6EVW11shn4IRpOl3nvFoU+t +SDwpOQXs7oFcsEWz+qhpknMcQfd/fv/z4FSUuvStzlNO6bsGm8KBNtXLjTbhK4V7 +Q44KcYulow/Dp4Rq3Pf+ZHgoqpfqujspWV1Sh++u6rPCte8lMozEIVd1scaCWw9S +w1id8sguJTfgccx1HBbp76q0U2zojfQf+EAyMHwN0/4JnqZ1mJZPhi9nGpnINZuy +wBsNPjtORYiDYLdLY7VL/O/tXsX03uVKfu6mQZNxhOSR2sD6AoEi/LaECMM+L96Q +EhOGvy7wILr1Zjc6KlUksXKOlXeKhJ0xxwcBWMznJG82WzeNczQ14I+I9RdkLUtP +vbU0SB7H7aX/bxqvQ+MOwS8CAwEAAQ== +-----END PUBLIC KEY----- +` + + keyFunc = func(keyReq sdk.KeyRequest) ([]byte, error) { + return []byte("0\x82\x02\"0\r\x06\t*\x86H\x86\xf7\r\x01\x01\x01\x05\x00\x03\x82\x02\x0f\x000\x82\x02\n\x02\x82\x02\x01\x00\xa0͢!a}\xa1\xccX\x1aĔ\x9f+}?\xb3K\xabx/\xdfۍ$'H\x19\xe7\xbf\xcb\xc1\x10\x1e\x11D\x83\xa7\t8\xd3\xc9?\xad\x04\xffvm\xe2\x19\xebb\x1e\x1a\xe2\\{\xad]\x1aZK9+\x1aP\x8b\xa2\xcf?\xdf,P\xa2\xe4d\xef\xd9M\xc69R\vg\xecP\x106=}\xa2\x1bH\xdf\xd76X(1P\x1f\xf5!\x1e<\xe615H\xb1g\xc7~u\x99\xca \xe1]\xe9uN\x93\x1c\xadYA3a\x92\x89\xec/\f\x9a?\x94\U000e4b12\xa7?\x8e\x96\x03\x814;NE\x88\x83`\xb7Kc\xb5K\xfc\xef\xed^\xc5\xf4\xde\xe5J~\xee\xa6A\x93q\x84\xe4\x91\xda\xc0\xfa\x02\x81\"\xfc\xb6\x84\b\xc3>/ސ\x12\x13\x86\xbf.\xf0 \xba\xf5f7:*U$\xb1r\x8e\x95w\x8a\x84\x9d1\xc7\a\x01X\xcc\xe7$o6[7\x8ds45\xe0\x8f\x88\xf5\x17d-KO\xbd\xb54H\x1e\xc7\xed\xa5\xffo\x1a\xafC\xe3\x0e\xc1/\x02\x03\x01\x00\x01"), nil + } + authTokenFunc = func() (string, error) { + return "you're you", nil + } + defer func() { + keyFunc = sdk.Key + authTokenFunc = getAuthToken + }() + + if err := cmd.Execute(); err != nil { + t.Fatalf("received unexpected error: %s", err) + } + + if got, want := stderr.String(), ""; got != want { + t.Fatalf("didn't get expected stderr, got %s, wanted %s", got, want) + } + + if got, want := stdout.String(), want; got != want { + t.Fatalf("didn't get expected stdout, got %s, wanted %s", got, want) + } +} + +func TestKeyInvalidFormat(t *testing.T) { + cmd, _, _ := getCmd() + cmd.SetArgs([]string{"key"}) + + keyFunc = func(keyReq sdk.KeyRequest) ([]byte, error) { + return []byte("-----BEGIN PUBLIC KEY-----\nTestKey\n-----END PUBLIC KEY-----\n"), nil + } + authTokenFunc = func() (string, error) { + return "you're you", nil + } + defer func() { + keyFunc = sdk.Key + authTokenFunc = getAuthToken + }() + + err := cmd.Execute() + if err == nil { + t.Fatalf("expected an error: %s", err) + } + + if !errors.As(err, &UserError{}) { + t.Fatalf("expected different error: %s", err) + } +} diff --git a/mocks/mocks.go b/mocks/mocks.go index 024c0290..056cfcfa 100644 --- a/mocks/mocks.go +++ b/mocks/mocks.go @@ -1,11 +1,69 @@ package mocks -import "github.com/capeprivacy/attest/attest" +import ( + "github.com/capeprivacy/attest/attest" + "github.com/capeprivacy/cli" + "github.com/capeprivacy/cli/entities" +) -type MockVerifier struct { +type Verifier struct { VerifyFn func(attestation []byte, nonce []byte) (*attest.AttestationDoc, error) } -func (m MockVerifier) Verify(attestation []byte, nonce []byte) (*attest.AttestationDoc, error) { - return m.VerifyFn(attestation, nonce) +func (v Verifier) Verify(attestation []byte, nonce []byte) (*attest.AttestationDoc, error) { + if v.VerifyFn != nil { + return v.VerifyFn(attestation, nonce) + } + return &attest.AttestationDoc{}, nil +} + +type Protocol struct { + WriteStartFn func(req entities.StartRequest) error + ReadAttestationDocFn func() ([]byte, error) + ReadRunResultsFn func() (*cli.RunResult, error) + WriteBinaryFn func(b []byte) error + WriteFunctionInfoFn func(name string, public bool) error + ReadDeploymentResultsFn func() (*entities.SetDeploymentIDRequest, error) +} + +func (p Protocol) WriteStart(req entities.StartRequest) error { + if p.WriteStartFn != nil { + return p.WriteStartFn(req) + } + return nil +} + +func (p Protocol) ReadAttestationDoc() ([]byte, error) { + if p.ReadAttestationDocFn != nil { + return p.ReadAttestationDocFn() + } + return []byte{}, nil +} + +func (p Protocol) ReadRunResults() (*cli.RunResult, error) { + if p.ReadRunResultsFn != nil { + return p.ReadRunResultsFn() + } + return &cli.RunResult{}, nil +} + +func (p Protocol) WriteBinary(b []byte) error { + if p.WriteBinaryFn != nil { + return p.WriteBinaryFn(b) + } + return nil +} + +func (p Protocol) WriteFunctionInfo(name string, public bool) error { + if p.WriteFunctionInfoFn != nil { + return p.WriteFunctionInfoFn(name, public) + } + return nil +} + +func (p Protocol) ReadDeploymentResults() (*entities.SetDeploymentIDRequest, error) { + if p.ReadDeploymentResultsFn != nil { + return p.ReadDeploymentResultsFn() + } + return &entities.SetDeploymentIDRequest{}, nil } diff --git a/pcrs/pcrs_test.go b/pcrs/pcrs_test.go index 90840dc7..e03ef5ea 100644 --- a/pcrs/pcrs_test.go +++ b/pcrs/pcrs_test.go @@ -1,3 +1,5 @@ +//go:build integration + package pcrs import ( diff --git a/sdk/key.go b/sdk/key.go index a5fee631..ecb83b70 100644 --- a/sdk/key.go +++ b/sdk/key.go @@ -39,7 +39,7 @@ func Key(keyReq KeyRequest) ([]byte, error) { // If the key file isn't present we download it, but log this error anyway in case something else happened. log.Debugf("Unable to open cape key file: %s", err) - capeKey, err = downloadAndSaveKey(keyReq) + capeKey, err = downloadAndSaveKey(keyReq, attest.NewVerifier()) if err != nil { return nil, err } @@ -48,10 +48,10 @@ func Key(keyReq KeyRequest) ([]byte, error) { return capeKey, nil } -func downloadAndSaveKey(keyReq KeyRequest) ([]byte, error) { +func downloadAndSaveKey(keyReq KeyRequest, verifier Verifier) ([]byte, error) { log.Debug("Downloading cape key...") - _, userData, err := ConnectAndAttest(keyReq) + _, userData, err := ConnectAndAttest(keyReq, verifier) if err != nil { log.Println("failed to attest") return nil, err @@ -71,7 +71,7 @@ func downloadAndSaveKey(keyReq KeyRequest) ([]byte, error) { } // TODO: Run, deploy and test could use this function. -func ConnectAndAttest(keyReq KeyRequest) (*attest.AttestationDoc, *AttestationUserData, error) { +func ConnectAndAttest(keyReq KeyRequest, verifier Verifier) (*attest.AttestationDoc, *AttestationUserData, error) { endpoint := fmt.Sprintf("%s/v1/key", keyReq.URL) authProtocolType := "cape.runtime" @@ -91,7 +91,7 @@ func ConnectAndAttest(keyReq KeyRequest) (*attest.AttestationDoc, *AttestationUs return nil, nil, err } - p := getProtocol(conn) + p := getProtocolFn(conn) req := entities.StartRequest{Nonce: nonce} log.Debug("\n> Sending Nonce and Auth Token") @@ -107,8 +107,6 @@ func ConnectAndAttest(keyReq KeyRequest) (*attest.AttestationDoc, *AttestationUs return nil, nil, err } - verifier := attest.NewVerifier() - log.Debug("< Auth Completed. Received Attestation Document") doc, err := verifier.Verify(attestDoc, nonce) if err != nil { diff --git a/sdk/key_test.go b/sdk/key_test.go new file mode 100644 index 00000000..7975e90a --- /dev/null +++ b/sdk/key_test.go @@ -0,0 +1,108 @@ +package sdk + +import ( + "bytes" + "encoding/json" + "net/http" + "net/http/httptest" + "os" + "path/filepath" + "testing" + + "github.com/capeprivacy/attest/attest" + "github.com/capeprivacy/cli/mocks" + "github.com/gorilla/websocket" +) + +func TestKeyNotPresent(t *testing.T) { + dir := t.TempDir() + file := "testKey.pub.der" + want := []byte("key") + + // Set up attestation + userData := AttestationUserData{ + CapeKey: want, + } + data, err := json.Marshal(userData) + if err != nil { + t.Errorf("unable to set up attestation user data: %s", err) + } + doc := attest.AttestationDoc{ + UserData: data, + } + + verifier := mocks.Verifier{ + VerifyFn: func(attestation []byte, nonce []byte) (*attest.AttestationDoc, error) { + return &doc, nil + }, + } + + getProtocolFn = func(ws *websocket.Conn) protocol { + return mocks.Protocol{} + } + defer func() { + getProtocolFn = getProtocol + }() + + s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + upgrader := websocket.Upgrader{ + ReadBufferSize: 1024, + WriteBufferSize: 1024, + } + + _, err := upgrader.Upgrade(w, r, nil) + if err != nil { + t.Fatal(err) + } + })) + defer s.Close() + + req := KeyRequest{ + ConfigDir: dir, + CapeKeyFile: file, + URL: wsURL(s.URL), + } + + got, err := downloadAndSaveKey(req, verifier) + if err != nil { + t.Fatalf("Unexpected error: %s", err) + } + + if !bytes.Equal(want, got) { + t.Fatalf("Key does not match, want: %s, got: %s", want, got) + } + + savedKey, err := readFile(dir, file) + if err != nil { + t.Fatalf("Unexpected error reading file %s", err) + } + + if !bytes.Equal(want, savedKey) { + t.Fatalf("Saved key does not match, want: %s, got: %s", want, got) + } +} + +func TestKeyPresent(t *testing.T) { + dir := t.TempDir() + file := "testKey.pub.der" + + want := []byte("-----BEGIN PUBLIC KEY-----\nTestKey\n-----END PUBLIC KEY-----") + err := os.WriteFile(filepath.Join(dir, file), want, os.ModePerm) + if err != nil { + t.Fatalf("Unable to setup key file err: %s", err) + } + + req := KeyRequest{ + ConfigDir: dir, + CapeKeyFile: file, + } + + got, err := Key(req) + if err != nil { + t.Fatalf("Unexpected error: %s", err) + } + + if !bytes.Equal(want, got) { + t.Fatalf("Key retrieved from file does not match, want: %s, got: %s", want, got) + } +} diff --git a/sdk/test_test.go b/sdk/test_test.go index 12b90c2f..9ae915f0 100644 --- a/sdk/test_test.go +++ b/sdk/test_test.go @@ -11,34 +11,9 @@ import ( "github.com/capeprivacy/cli" "github.com/capeprivacy/attest/attest" - "github.com/capeprivacy/cli/entities" "github.com/capeprivacy/cli/mocks" ) -type testProtocol struct { - start func(req entities.StartRequest) error - attest func() ([]byte, error) - results func() (*cli.RunResult, error) - binary func(b []byte) error -} - -func (t testProtocol) WriteFunctionInfo(name string, public bool) error { - return nil -} - -func (t testProtocol) ReadDeploymentResults() (*entities.SetDeploymentIDRequest, error) { - return nil, nil -} - -func (t testProtocol) WriteStart(request entities.StartRequest) error { - return t.start(request) -} -func (t testProtocol) ReadAttestationDoc() ([]byte, error) { return t.attest() } -func (t testProtocol) ReadRunResults() (*cli.RunResult, error) { - return t.results() -} -func (t testProtocol) WriteBinary(bytes []byte) error { return t.binary(bytes) } - func wsURL(origURL string) string { u, _ := url.Parse(origURL) u.Scheme = "ws" @@ -47,24 +22,18 @@ func wsURL(origURL string) string { } func TestCapeTest(t *testing.T) { - verifier := mocks.MockVerifier{ - VerifyFn: func(attestation []byte, nonce []byte) (*attest.AttestationDoc, error) { - return &attest.AttestationDoc{}, nil - }, - } - localEncrypt = func(doc attest.AttestationDoc, plaintext []byte) ([]byte, error) { return plaintext, nil } getProtocolFn = func(ws *websocket.Conn) protocol { - return testProtocol{ - start: func(req entities.StartRequest) error { return nil }, - attest: func() ([]byte, error) { return []byte{}, nil }, - results: func() (*cli.RunResult, error) { + return mocks.Protocol{ + ReadRunResultsFn: func() (*cli.RunResult, error) { return &cli.RunResult{Message: []byte("good job")}, nil }, - binary: func(b []byte) error { return nil }, } } + defer func() { + getProtocolFn = getProtocol + }() s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { upgrader := websocket.Upgrader{ @@ -85,7 +54,7 @@ func TestCapeTest(t *testing.T) { Insecure: true, } - res, err := Test(test, verifier, wsURL(s.URL), []string{}) + res, err := Test(test, mocks.Verifier{}, wsURL(s.URL), []string{}) if err != nil { t.Fatal(err) } From 9eefe8b9f903439ed9915d703af6869e6fc6970d Mon Sep 17 00:00:00 2001 From: ChrisFriesen <26197502+ChrisFriesen@users.noreply.github.com> Date: Wed, 12 Apr 2023 10:45:02 -0700 Subject: [PATCH 2/3] Add token argument to key --- cmd/cape/cmd/key.go | 12 +++++++++--- sdk/key_test.go | 3 ++- 2 files changed, 11 insertions(+), 4 deletions(-) diff --git a/cmd/cape/cmd/key.go b/cmd/cape/cmd/key.go index 1d8bf8a8..629145ec 100644 --- a/cmd/cape/cmd/key.go +++ b/cmd/cape/cmd/key.go @@ -29,6 +29,7 @@ func init() { rootCmd.AddCommand(keyCmd) keyCmd.PersistentFlags().StringSliceP("pcr", "p", []string{""}, "pass multiple PCRs to validate against, used while getting key for the first time") + keyCmd.PersistentFlags().StringP("token", "t", "", "authorization token to use") } func key(cmd *cobra.Command, args []string) error { @@ -37,9 +38,14 @@ func key(cmd *cobra.Command, args []string) error { return UserError{Msg: "error retrieving pcr flags", Err: err} } - token, err := authTokenFunc() - if err != nil { - return err + token, _ := cmd.Flags().GetString("token") + if token == "" { + t, err := authTokenFunc() + if err != nil { + return err + } + + token = t } keyReq, err := GetKeyRequest(pcrSlice, token) diff --git a/sdk/key_test.go b/sdk/key_test.go index 7975e90a..f95bc2dc 100644 --- a/sdk/key_test.go +++ b/sdk/key_test.go @@ -9,9 +9,10 @@ import ( "path/filepath" "testing" + "github.com/gorilla/websocket" + "github.com/capeprivacy/attest/attest" "github.com/capeprivacy/cli/mocks" - "github.com/gorilla/websocket" ) func TestKeyNotPresent(t *testing.T) { From b35515142ecad50e8351ca63a09ccd32bdbd838d Mon Sep 17 00:00:00 2001 From: ChrisFriesen <26197502+ChrisFriesen@users.noreply.github.com> Date: Thu, 13 Apr 2023 07:36:23 -0700 Subject: [PATCH 3/3] Use existing auth function var --- cmd/cape/cmd/key.go | 3 +-- cmd/cape/cmd/key_test.go | 8 ++++---- 2 files changed, 5 insertions(+), 6 deletions(-) diff --git a/cmd/cape/cmd/key.go b/cmd/cape/cmd/key.go index 629145ec..6538d864 100644 --- a/cmd/cape/cmd/key.go +++ b/cmd/cape/cmd/key.go @@ -40,7 +40,7 @@ func key(cmd *cobra.Command, args []string) error { token, _ := cmd.Flags().GetString("token") if token == "" { - t, err := authTokenFunc() + t, err := authToken() if err != nil { return err } @@ -91,5 +91,4 @@ func GetKeyRequest(pcrSlice []string, token string) (sdk.KeyRequest, error) { }, nil } -var authTokenFunc = getAuthToken var keyFunc = sdk.Key diff --git a/cmd/cape/cmd/key_test.go b/cmd/cape/cmd/key_test.go index ad425f67..edf53138 100644 --- a/cmd/cape/cmd/key_test.go +++ b/cmd/cape/cmd/key_test.go @@ -30,12 +30,12 @@ vbU0SB7H7aX/bxqvQ+MOwS8CAwEAAQ== keyFunc = func(keyReq sdk.KeyRequest) ([]byte, error) { return []byte("0\x82\x02\"0\r\x06\t*\x86H\x86\xf7\r\x01\x01\x01\x05\x00\x03\x82\x02\x0f\x000\x82\x02\n\x02\x82\x02\x01\x00\xa0͢!a}\xa1\xccX\x1aĔ\x9f+}?\xb3K\xabx/\xdfۍ$'H\x19\xe7\xbf\xcb\xc1\x10\x1e\x11D\x83\xa7\t8\xd3\xc9?\xad\x04\xffvm\xe2\x19\xebb\x1e\x1a\xe2\\{\xad]\x1aZK9+\x1aP\x8b\xa2\xcf?\xdf,P\xa2\xe4d\xef\xd9M\xc69R\vg\xecP\x106=}\xa2\x1bH\xdf\xd76X(1P\x1f\xf5!\x1e<\xe615H\xb1g\xc7~u\x99\xca \xe1]\xe9uN\x93\x1c\xadYA3a\x92\x89\xec/\f\x9a?\x94\U000e4b12\xa7?\x8e\x96\x03\x814;NE\x88\x83`\xb7Kc\xb5K\xfc\xef\xed^\xc5\xf4\xde\xe5J~\xee\xa6A\x93q\x84\xe4\x91\xda\xc0\xfa\x02\x81\"\xfc\xb6\x84\b\xc3>/ސ\x12\x13\x86\xbf.\xf0 \xba\xf5f7:*U$\xb1r\x8e\x95w\x8a\x84\x9d1\xc7\a\x01X\xcc\xe7$o6[7\x8ds45\xe0\x8f\x88\xf5\x17d-KO\xbd\xb54H\x1e\xc7\xed\xa5\xffo\x1a\xafC\xe3\x0e\xc1/\x02\x03\x01\x00\x01"), nil } - authTokenFunc = func() (string, error) { + authToken = func() (string, error) { return "you're you", nil } defer func() { keyFunc = sdk.Key - authTokenFunc = getAuthToken + authToken = getAuthToken }() if err := cmd.Execute(); err != nil { @@ -58,12 +58,12 @@ func TestKeyInvalidFormat(t *testing.T) { keyFunc = func(keyReq sdk.KeyRequest) ([]byte, error) { return []byte("-----BEGIN PUBLIC KEY-----\nTestKey\n-----END PUBLIC KEY-----\n"), nil } - authTokenFunc = func() (string, error) { + authToken = func() (string, error) { return "you're you", nil } defer func() { keyFunc = sdk.Key - authTokenFunc = getAuthToken + authToken = getAuthToken }() err := cmd.Execute()