Skip to content

Commit

Permalink
Merge branch 'master' into bounty-handler-UTs-V3
Browse files Browse the repository at this point in the history
  • Loading branch information
AbdulWahab3181 authored Mar 7, 2024
2 parents 60b9a33 + 2a77241 commit bbf20c3
Show file tree
Hide file tree
Showing 18 changed files with 433 additions and 66 deletions.
21 changes: 21 additions & 0 deletions auth/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,27 @@ func PubKeyContextSuperAdmin(next http.Handler) http.Handler {
})
}

// ConnectionContext parses token for connection code
func ConnectionCodeContext(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
token := r.Header.Get("token")

if token == "" {
fmt.Println("[auth] no token")
http.Error(w, http.StatusText(401), 401)
return
}

if token != config.Connection_Auth {
fmt.Println("Not a super admin : auth")
http.Error(w, http.StatusText(401), 401)
return
}
ctx := context.WithValue(r.Context(), ContextKey, token)
next.ServeHTTP(w, r.WithContext(ctx))
})
}

func AdminCheck(pubkey string) bool {
for _, val := range config.SuperAdmins {
if val == pubkey {
Expand Down
2 changes: 2 additions & 0 deletions config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ var S3FolderName string
var S3Url string
var AdminCheck string
var AdminDevFreePass = "FREE_PASS"
var Connection_Auth string

var S3Client *s3.Client
var PresignClient *s3.PresignClient
Expand All @@ -51,6 +52,7 @@ func InitConfig() {
S3FolderName = os.Getenv("S3_FOLDER_NAME")
S3Url = os.Getenv("S3_URL")
AdminCheck = os.Getenv("ADMIN_CHECK")
Connection_Auth = os.Getenv("CONNECTION_AUTH")

// Add to super admins
SuperAdmins = StripSuperAdmins(AdminStrings)
Expand Down
10 changes: 6 additions & 4 deletions db/db.go
Original file line number Diff line number Diff line change
Expand Up @@ -1495,10 +1495,12 @@ func (db database) GetPeopleListShort(count uint32) *[]PersonInShort {
return &p
}

func (db database) CreateConnectionCode(c ConnectionCodes) (ConnectionCodes, error) {
if c.DateCreated == nil {
now := time.Now()
c.DateCreated = &now
func (db database) CreateConnectionCode(c []ConnectionCodes) ([]ConnectionCodes, error) {
now := time.Now()
for _, code := range c {
if code.DateCreated.IsZero() {
code.DateCreated = &now
}
}
db.db.Create(&c)
return c, nil
Expand Down
2 changes: 1 addition & 1 deletion db/interface.go
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ type Database interface {
CountBounties() uint64
GetPeopleListShort(count uint32) *[]PersonInShort
GetConnectionCode() ConnectionCodesShort
CreateConnectionCode(c ConnectionCodes) (ConnectionCodes, error)
CreateConnectionCode(c []ConnectionCodes) ([]ConnectionCodes, error)
GetLnUser(lnKey string) int64
CreateLnUser(lnKey string) (Person, error)
GetBountiesLeaderboard() []LeaderData
Expand Down
20 changes: 13 additions & 7 deletions handlers/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ import (
"fmt"
"io"
"net/http"
"time"

"github.com/form3tech-oss/jwt-go"
"github.com/stakwork/sphinx-tribes/auth"
Expand Down Expand Up @@ -54,30 +53,37 @@ func (ah *authHandler) GetIsAdmin(w http.ResponseWriter, r *http.Request) {
}

func (ah *authHandler) CreateConnectionCode(w http.ResponseWriter, r *http.Request) {
code := db.ConnectionCodes{}
now := time.Now()
codeArr := []db.ConnectionCodes{}
codeStrArr := []string{}

body, err := io.ReadAll(r.Body)
r.Body.Close()

err = json.Unmarshal(body, &code)
err = json.Unmarshal(body, &codeStrArr)

code.IsUsed = false
code.DateCreated = &now
for _, code := range codeStrArr {
code := db.ConnectionCodes{
ConnectionString: code,
IsUsed: false,
}
codeArr = append(codeArr, code)
}

if err != nil {
fmt.Println(err)
w.WriteHeader(http.StatusNotAcceptable)
return
}

_, err = ah.db.CreateConnectionCode(code)
_, err = ah.db.CreateConnectionCode(codeArr)

if err != nil {
fmt.Println("=> ERR create connection code", err)
w.WriteHeader(http.StatusBadRequest)
return
}
w.WriteHeader(http.StatusOK)
json.NewEncoder(w).Encode("Codes created successfully")
}

func (ah *authHandler) GetConnectionCode(w http.ResponseWriter, _ *http.Request) {
Expand Down
34 changes: 22 additions & 12 deletions handlers/auth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@ import (
"github.com/stakwork/sphinx-tribes/db"
mocks "github.com/stakwork/sphinx-tribes/mocks"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
)

func TestGetAdminPubkeys(t *testing.T) {
Expand Down Expand Up @@ -52,16 +51,21 @@ func TestGetAdminPubkeys(t *testing.T) {
}

func TestCreateConnectionCode(t *testing.T) {

mockDb := mocks.NewDatabase(t)
aHandler := NewAuthHandler(mockDb)
t.Run("should create connection code successful", func(t *testing.T) {
codeToBeInserted := db.ConnectionCodes{
ConnectionString: "custom connection string",
codeToBeInserted := []string{"custom connection string", "custom connection string 2"}

codeArr := []db.ConnectionCodes{}
for _, code := range codeToBeInserted {
code := db.ConnectionCodes{
ConnectionString: code,
IsUsed: false,
}
codeArr = append(codeArr, code)
}
mockDb.On("CreateConnectionCode", mock.MatchedBy(func(code db.ConnectionCodes) bool {
return code.IsUsed == false && code.ConnectionString == codeToBeInserted.ConnectionString
})).Return(codeToBeInserted, nil).Once()

mockDb.On("CreateConnectionCode", codeArr).Return(codeArr, nil).Once()

body, _ := json.Marshal(codeToBeInserted)
req, err := http.NewRequest("POST", "/connectioncodes", bytes.NewBuffer(body))
Expand All @@ -77,12 +81,18 @@ func TestCreateConnectionCode(t *testing.T) {
})

t.Run("should return error if failed to add connection code", func(t *testing.T) {
codeToBeInserted := db.ConnectionCodes{
ConnectionString: "custom connection string",
codeToBeInserted := []string{"custom connection string", "custom connection string 2"}

codeArr := []db.ConnectionCodes{}
for _, code := range codeToBeInserted {
code := db.ConnectionCodes{
ConnectionString: code,
IsUsed: false,
}
codeArr = append(codeArr, code)
}
mockDb.On("CreateConnectionCode", mock.MatchedBy(func(code db.ConnectionCodes) bool {
return code.IsUsed == false && code.ConnectionString == codeToBeInserted.ConnectionString
})).Return(codeToBeInserted, errors.New("failed to create connection")).Once()

mockDb.On("CreateConnectionCode", codeArr).Return(codeArr, errors.New("failed to create connection")).Once()

body, _ := json.Marshal(codeToBeInserted)
req, err := http.NewRequest("POST", "/connectioncodes", bytes.NewBuffer(body))
Expand Down
30 changes: 14 additions & 16 deletions handlers/bounty.go
Original file line number Diff line number Diff line change
Expand Up @@ -608,15 +608,14 @@ func formatPayError(errorMsg string) db.InvoicePayError {
}
}

func GetLightningInvoice(payment_request string) (db.InvoiceResult, db.InvoiceError) {
func (h *bountyHandler) GetLightningInvoice(payment_request string) (db.InvoiceResult, db.InvoiceError) {
url := fmt.Sprintf("%s/invoice?payment_request=%s", config.RelayUrl, payment_request)

client := &http.Client{}
req, err := http.NewRequest(http.MethodGet, url, nil)

req.Header.Set("x-user-token", config.RelayAuthKey)
req.Header.Set("Content-Type", "application/json")
res, _ := client.Do(req)
res, _ := h.httpClient.Do(req)

if err != nil {
log.Printf("Request Failed: %s", err)
Expand Down Expand Up @@ -695,9 +694,9 @@ func (h *bountyHandler) PayLightningInvoice(payment_request string) (db.InvoiceP
}
}

func GetInvoiceData(w http.ResponseWriter, r *http.Request) {
func (h *bountyHandler) GetInvoiceData(w http.ResponseWriter, r *http.Request) {
paymentRequest := chi.URLParam(r, "paymentRequest")
invoiceData, invoiceErr := GetLightningInvoice(paymentRequest)
invoiceData, invoiceErr := h.GetLightningInvoice(paymentRequest)

if invoiceErr.Error != "" {
w.WriteHeader(http.StatusForbidden)
Expand All @@ -709,7 +708,7 @@ func GetInvoiceData(w http.ResponseWriter, r *http.Request) {
json.NewEncoder(w).Encode(invoiceData)
}

func PollInvoice(w http.ResponseWriter, r *http.Request) {
func (h *bountyHandler) PollInvoice(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
pubKeyFromAuth, _ := ctx.Value(auth.ContextKey).(string)
paymentRequest := chi.URLParam(r, "paymentRequest")
Expand All @@ -721,7 +720,7 @@ func PollInvoice(w http.ResponseWriter, r *http.Request) {
return
}

invoiceRes, invoiceErr := GetLightningInvoice(paymentRequest)
invoiceRes, invoiceErr := h.GetLightningInvoice(paymentRequest)

if invoiceErr.Error != "" {
w.WriteHeader(http.StatusForbidden)
Expand All @@ -731,14 +730,14 @@ func PollInvoice(w http.ResponseWriter, r *http.Request) {

if invoiceRes.Response.Settled {
// Todo if an invoice is settled
invoice := db.DB.GetInvoice(paymentRequest)
invData := db.DB.GetUserInvoiceData(paymentRequest)
dbInvoice := db.DB.GetInvoice(paymentRequest)
invoice := h.db.GetInvoice(paymentRequest)
invData := h.db.GetUserInvoiceData(paymentRequest)
dbInvoice := h.db.GetInvoice(paymentRequest)

// Make any change only if the invoice has not been settled
if !dbInvoice.Status {
if invoice.Type == "BUDGET" {
db.DB.AddAndUpdateBudget(invoice)
h.db.AddAndUpdateBudget(invoice)
} else if invoice.Type == "KEYSEND" {
url := fmt.Sprintf("%s/payment", config.RelayUrl)

Expand All @@ -748,12 +747,11 @@ func PollInvoice(w http.ResponseWriter, r *http.Request) {

jsonBody := []byte(bodyData)

client := &http.Client{}
req, _ := http.NewRequest(http.MethodPost, url, bytes.NewBuffer(jsonBody))

req.Header.Set("x-user-token", config.RelayAuthKey)
req.Header.Set("Content-Type", "application/json")
res, _ := client.Do(req)
res, _ := h.httpClient.Do(req)

if err != nil {
log.Printf("Request Failed: %s", err)
Expand All @@ -769,13 +767,13 @@ func PollInvoice(w http.ResponseWriter, r *http.Request) {
keysendRes := db.KeysendSuccess{}
err = json.Unmarshal(body, &keysendRes)

bounty, err := db.DB.GetBountyByCreated(uint(invData.Created))
bounty, err := h.db.GetBountyByCreated(uint(invData.Created))

if err == nil {
bounty.Paid = true
}

db.DB.UpdateBounty(bounty)
h.db.UpdateBounty(bounty)
} else {
// Unmarshal result
keysendError := db.KeysendError{}
Expand All @@ -784,7 +782,7 @@ func PollInvoice(w http.ResponseWriter, r *http.Request) {
}
}
// Update the invoice status
db.DB.UpdateInvoice(paymentRequest)
h.db.UpdateInvoice(paymentRequest)
}

}
Expand Down
Loading

0 comments on commit bbf20c3

Please sign in to comment.