From c135737a553c54c002b9b7074ae705b15b4c7f10 Mon Sep 17 00:00:00 2001 From: elraphty Date: Fri, 8 Nov 2024 08:51:47 +0100 Subject: [PATCH] refactored connection codes test --- handlers/auth.go | 3 +++ handlers/auth_test.go | 58 ++++++++----------------------------------- 2 files changed, 14 insertions(+), 47 deletions(-) diff --git a/handlers/auth.go b/handlers/auth.go index 6e734e945..3ef2d1346 100644 --- a/handlers/auth.go +++ b/handlers/auth.go @@ -69,6 +69,8 @@ func (ah *authHandler) CreateConnectionCode(w http.ResponseWriter, r *http.Reque if err != nil { fmt.Println("Could not umarshal connection code body") + w.WriteHeader(http.StatusNotAcceptable) + return } for i := 0; i < int(codeBody.Number); i++ { @@ -82,6 +84,7 @@ func (ah *authHandler) CreateConnectionCode(w http.ResponseWriter, r *http.Reque codeArr = append(codeArr, newCode) } } + _, err = ah.db.CreateConnectionCode(codeArr) if err != nil { diff --git a/handlers/auth_test.go b/handlers/auth_test.go index 2c65a58ad..fe6e91267 100644 --- a/handlers/auth_test.go +++ b/handlers/auth_test.go @@ -4,9 +4,6 @@ import ( "bytes" "context" "encoding/json" - "github.com/google/uuid" - "github.com/lib/pq" - mocks "github.com/stakwork/sphinx-tribes/mocks" "net/http" "net/http/httptest" "os" @@ -14,6 +11,10 @@ import ( "testing" "time" + "github.com/google/uuid" + "github.com/lib/pq" + mocks "github.com/stakwork/sphinx-tribes/mocks" + "github.com/form3tech-oss/jwt-go" "github.com/stakwork/sphinx-tribes/auth" "github.com/stakwork/sphinx-tribes/config" @@ -58,61 +59,23 @@ func TestCreateConnectionCode(t *testing.T) { t.Run("should create connection code successful", func(t *testing.T) { rr := httptest.NewRecorder() handler := http.HandlerFunc(aHandler.CreateConnectionCode) - codeStrArr := []string{"sampleCode1"} - - codeArr := []db.ConnectionCodes{} - now := time.Now() + body := []byte(`{"number": 2`) - for i, code := range codeStrArr { - code := db.ConnectionCodes{ - ID: uint(i), - ConnectionString: code, - IsUsed: false, - DateCreated: &now, - } - - codeArr = append(codeArr, code) - } - - codeShort := db.ConnectionCodesShort{ - ConnectionString: codeArr[0].ConnectionString, - DateCreated: codeArr[0].DateCreated, - } - - db.TestDB.CreateConnectionCode(codeArr) - - body, _ := json.Marshal(codeStrArr) req, err := http.NewRequest("POST", "/connectioncodes", bytes.NewBuffer(body)) if err != nil { t.Fatal(err) } - codes := db.TestDB.GetConnectionCode() handler.ServeHTTP(rr, req) assert.Equal(t, http.StatusOK, rr.Code) - assert.EqualValues(t, codeShort.ConnectionString, codes.ConnectionString) - tolerance := time.Millisecond - timeDifference := codeShort.DateCreated.Sub(*codes.DateCreated) - if timeDifference < 0 { - timeDifference = -timeDifference - } - assert.True(t, timeDifference <= tolerance, "Expected DateCreated to be within tolerance") + codes := db.TestDB.GetConnectionCode() + assert.NotEmpty(t, codes) }) t.Run("should return error if failed to add connection code", func(t *testing.T) { - codeToBeInserted := []string{} + body := []byte(`{"number":0`) - codeArr := []db.ConnectionCodes{} - for _, code := range codeToBeInserted { - code := db.ConnectionCodes{ - ConnectionString: code, - IsUsed: false, - } - codeArr = append(codeArr, code) - } - - body, _ := json.Marshal(codeToBeInserted) req, err := http.NewRequest("POST", "/connectioncodes", bytes.NewBuffer(body)) if err != nil { t.Fatal(err) @@ -125,7 +88,7 @@ func TestCreateConnectionCode(t *testing.T) { }) t.Run("should return error for malformed request body", func(t *testing.T) { - body := []byte(`{"id":0,"connection_string":"string","is_used":false,"date_created":"5T11:50:00Z"}`) + body := []byte(`{"id":0,"connection_string":"string", "number": 0}`) req, err := http.NewRequest("POST", "/connectioncodes", bytes.NewBuffer(body)) if err != nil { t.Fatal(err) @@ -138,7 +101,8 @@ func TestCreateConnectionCode(t *testing.T) { }) t.Run("should return error for invalid json", func(t *testing.T) { - body := []byte(`{"id":0,"connection_string":"string"`) + body := []byte(`{"nonumber":0`) + req, err := http.NewRequest("POST", "/connectioncodes", bytes.NewBuffer(body)) if err != nil { t.Fatal(err)