Skip to content

Commit

Permalink
Merge pull request #36 from DelineaXPM/tokenCaching.ds
Browse files Browse the repository at this point in the history
Added simple access token caching
  • Loading branch information
delinea-sagar authored Nov 25, 2024
2 parents 680813d + a4fb757 commit 5730b57
Show file tree
Hide file tree
Showing 2 changed files with 108 additions and 35 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -28,3 +28,4 @@
.vscode
.idea
test_config.json
.DS_Store
142 changes: 107 additions & 35 deletions server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,14 @@ import (
"io"
"io/ioutil"
"log"
"math"
"mime/multipart"
"net/http"
"net/url"
"os"
"regexp"
"strings"
"time"
)

const (
Expand Down Expand Up @@ -40,6 +43,11 @@ type Server struct {
Configuration
}

type TokenCache struct {
AccessToken string `json:"access_token"`
ExpiresIn int `json:"expires_in"`
}

// New returns an initialized Secrets object
func New(config Configuration) (*Server, error) {
if config.ServerURL == "" && config.Tenant == "" || config.ServerURL != "" && config.Tenant != "" {
Expand Down Expand Up @@ -158,7 +166,13 @@ func (s Server) accessResource(method, resource, path string, input interface{})

log.Printf("[DEBUG] calling %s %s", method, req.URL.String())

data, _, err := handleResponse((&http.Client{}).Do(req))
data, statusCode, err := handleResponse((&http.Client{}).Do(req))

// Check for unauthorized or access denied
if statusCode.StatusCode == http.StatusUnauthorized || statusCode.StatusCode == http.StatusForbidden {
s.clearTokenCache()
log.Printf("[ERROR] Token cache cleared due to unauthorized or access denied response.")
}

return data, err
}
Expand Down Expand Up @@ -252,17 +266,69 @@ func (s Server) uploadFile(secretId int, fileField SecretField) error {
return err
}

func (s *Server) setCacheAccessToken(value string, expiresIn int, baseURL string) error {
cache := TokenCache{}
cache.AccessToken = value
cache.ExpiresIn = (int(time.Now().Unix()) + expiresIn) - int(math.Floor(float64(expiresIn)*0.9))

data, _ := json.Marshal(cache)
os.Setenv("SS_AT_"+url.QueryEscape(baseURL), string(data))
return nil
}

func (s *Server) getCacheAccessToken(baseURL string) (string, bool) {
data, ok := os.LookupEnv("SS_AT_" + url.QueryEscape(baseURL))
if !ok {
s.clearTokenCache()
return "", ok
}
cache := TokenCache{}
if err := json.Unmarshal([]byte(data), &cache); err != nil {
return "", false
}
if time.Now().Unix() < int64(cache.ExpiresIn) {
return cache.AccessToken, true
}
return "", false
}

func (s *Server) clearTokenCache() {
var baseURL string

if s.ServerURL == "" {
baseURL = fmt.Sprintf(cloudBaseURLTemplate, s.Tenant, s.TLD)
} else {
baseURL = s.ServerURL
}

os.Setenv("SS_AT_"+url.QueryEscape(baseURL), "")
}

// getAccessToken gets an OAuth2 Access Grant and returns the token
// endpoint and get an accessGrant.
func (s *Server) getAccessToken() (string, error) {
if s.Credentials.Token != "" {
return s.Credentials.Token, nil
}
response, err := s.checkPlatformDetails()
var baseURL string

if s.ServerURL == "" {
baseURL = fmt.Sprintf(cloudBaseURLTemplate, s.Tenant, s.TLD)
} else {
baseURL = s.ServerURL
}

response, err := s.checkPlatformDetails(baseURL)
if err != nil {
log.Print("Error while checking server details:", err)
return "", err
} else if err == nil && response == "" {

accessToken, found := s.getCacheAccessToken(baseURL)
if found {
return accessToken, nil
}

values := url.Values{
"username": {s.Credentials.Username},
"password": {s.Credentials.Password},
Expand Down Expand Up @@ -292,21 +358,17 @@ func (s *Server) getAccessToken() (string, error) {
log.Print("[ERROR] parsing grant response:", err)
return "", err
}
if err = s.setCacheAccessToken(grant.AccessToken, grant.ExpiresIn, baseURL); err != nil {
log.Print("[ERROR] caching access token:", err)
return "", err
}
return grant.AccessToken, nil
} else {
return response, nil
}
}

func (s *Server) checkPlatformDetails() (string, error) {
var baseURL string

if s.ServerURL == "" {
baseURL = fmt.Sprintf(cloudBaseURLTemplate, s.Tenant, s.TLD)
} else {
baseURL = s.ServerURL
}

func (s *Server) checkPlatformDetails(baseURL string) (string, error) {
platformHelthCheckUrl := fmt.Sprintf("%s/%s", strings.Trim(baseURL, "/"), "health")
ssHealthCheckUrl := fmt.Sprintf("%s/%s", strings.Trim(baseURL, "/"), "healthcheck.aspx")

Expand All @@ -316,40 +378,50 @@ func (s *Server) checkPlatformDetails() (string, error) {
} else {
isHealthy := checkJSONResponse(platformHelthCheckUrl)
if isHealthy {
requestData := url.Values{}
requestData.Set("grant_type", "client_credentials")
requestData.Set("client_id", s.Credentials.Username)
requestData.Set("client_secret", s.Credentials.Password)
requestData.Set("scope", "xpmheadless")

req, err := http.NewRequest("POST", fmt.Sprintf("%s/%s", strings.Trim(baseURL, "/"), "identity/api/oauth2/token/xpmplatform"), bytes.NewBufferString(requestData.Encode()))
if err != nil {
log.Print("Error creating HTTP request:", err)
return "", err
}
accessToken, found := s.getCacheAccessToken(baseURL)
if !found {
requestData := url.Values{}
requestData.Set("grant_type", "client_credentials")
requestData.Set("client_id", s.Credentials.Username)
requestData.Set("client_secret", s.Credentials.Password)
requestData.Set("scope", "xpmheadless")

req, err := http.NewRequest("POST", fmt.Sprintf("%s/%s", strings.Trim(baseURL, "/"), "identity/api/oauth2/token/xpmplatform"), bytes.NewBufferString(requestData.Encode()))
if err != nil {
log.Print("Error creating HTTP request:", err)
return "", err
}

req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")

data, _, err := handleResponse((&http.Client{}).Do(req))
if err != nil {
log.Print("[ERROR] get token response error:", err)
return "", err
}
data, _, err := handleResponse((&http.Client{}).Do(req))
if err != nil {
log.Print("[ERROR] get token response error:", err)
return "", err
}

var tokenjsonResponse OAuthTokens
if err = json.Unmarshal(data, &tokenjsonResponse); err != nil {
log.Print("[ERROR] parsing get token response:", err)
return "", err
var tokenjsonResponse OAuthTokens
if err = json.Unmarshal(data, &tokenjsonResponse); err != nil {
log.Print("[ERROR] parsing get token response:", err)
return "", err
}
accessToken = tokenjsonResponse.AccessToken

if err = s.setCacheAccessToken(tokenjsonResponse.AccessToken, tokenjsonResponse.ExpiresIn, baseURL); err != nil {
log.Print("[ERROR] caching access token:", err)
return "", err
}
}

req, err = http.NewRequest("GET", fmt.Sprintf("%s/%s", strings.Trim(baseURL, "/"), "vaultbroker/api/vaults"), bytes.NewBuffer([]byte{}))
req, err := http.NewRequest("GET", fmt.Sprintf("%s/%s", strings.Trim(baseURL, "/"), "vaultbroker/api/vaults"), bytes.NewBuffer([]byte{}))
if err != nil {
log.Print("Error creating HTTP request:", err)
return "", err
}
req.Header.Add("Authorization", "Bearer "+tokenjsonResponse.AccessToken)
req.Header.Add("Authorization", "Bearer "+accessToken)

data, _, err = handleResponse((&http.Client{}).Do(req))
data, _, err := handleResponse((&http.Client{}).Do(req))
if err != nil {
log.Print("[ERROR] get vaults response error:", err)
return "", err
Expand All @@ -374,7 +446,7 @@ func (s *Server) checkPlatformDetails() (string, error) {
return "", fmt.Errorf("no configured vault found")
}

return tokenjsonResponse.AccessToken, nil
return accessToken, nil
}
}
return "", fmt.Errorf("invalid URL")
Expand Down

0 comments on commit 5730b57

Please sign in to comment.