Skip to content

Commit

Permalink
feat: adds retries to azure metadata service.
Browse files Browse the repository at this point in the history
This adds retries to the Azure metadata service client using the github.com/hashicorp/go-retryablehttp package.
It also adds unit tests for the Azure metadata provider.
  • Loading branch information
VAveryanov8 committed Dec 17, 2024
1 parent 1a854fe commit 926c561
Show file tree
Hide file tree
Showing 3 changed files with 174 additions and 17 deletions.
68 changes: 53 additions & 15 deletions pkg/cloudmeta/azure.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,40 +8,53 @@ import (
"net/http"
"time"

"github.com/hashicorp/go-retryablehttp"
"github.com/pkg/errors"
"github.com/scylladb/go-log"
)

// AzureBaseURL is a base url of azure metadata service.
const AzureBaseURL = "http://169.254.169.254/metadata"
// azureBaseURL is a base url of azure metadata service.
const azureBaseURL = "http://169.254.169.254/metadata"

// AzureMetadata is a wrapper around azure metadata service.
type AzureMetadata struct {
// azureMetadata is a wrapper around azure metadata service.
type azureMetadata struct {
client *http.Client

baseURL string
}

// NewAzureMetadata returns AzureMetadata service.
func NewAzureMetadata() *AzureMetadata {
return &AzureMetadata{
client: defaultClient(),
baseURL: AzureBaseURL,
// newAzureMetadata returns an azureMetadata service that logs retries via the given logger.
func newAzureMetadata(logger log.Logger) *azureMetadata {
return &azureMetadata{
client: defaultClient(logger),
baseURL: azureBaseURL,
}
}

func defaultClient() *http.Client {
func defaultClient(logger log.Logger) *http.Client {
client := retryablehttp.NewClient()

client.RetryMax = 3
client.RetryWaitMin = 500 * time.Millisecond
client.RetryWaitMax = 5 * time.Second
client.Logger = &logWrapper{
logger: logger,
}

transport := http.DefaultTransport.(*http.Transport).Clone()
// we must not use proxy for the metadata requests - see https://learn.microsoft.com/en-us/azure/virtual-machines/instance-metadata-service?tabs=linux#proxies.
transport.Proxy = nil
return &http.Client{
// setting some generous timeout, it can be overwritten by using context.WithTimeout.
Timeout: 10 * time.Second,

client.HTTPClient = &http.Client{
// Quite small timeout per request, because we have retries and also it's a local network call.
Timeout: 1 * time.Second,
Transport: transport,
}
return client.StandardClient()
}

// Metadata returns InstanceMetadata from Azure if available.
func (azure *AzureMetadata) Metadata(ctx context.Context) (InstanceMetadata, error) {
func (azure *azureMetadata) Metadata(ctx context.Context) (InstanceMetadata, error) {
vmSize, err := azure.getVMSize(ctx)
if err != nil {
return InstanceMetadata{}, errors.Wrap(err, "azure.getVMSize")
Expand All @@ -58,7 +71,7 @@ func (azure *AzureMetadata) Metadata(ctx context.Context) (InstanceMetadata, err
// azureAPIVersion should be present in every request to metadata service in query parameter.
const azureAPIVersion = "2023-07-01"

func (azure *AzureMetadata) getVMSize(ctx context.Context) (string, error) {
func (azure *azureMetadata) getVMSize(ctx context.Context) (string, error) {
req, err := http.NewRequestWithContext(ctx, http.MethodGet, azure.baseURL+"/instance", http.NoBody)
if err != nil {
return "", errors.Wrap(err, "http new request")
Expand Down Expand Up @@ -97,3 +110,28 @@ type azureMetadataResponse struct {
type azureCompute struct {
VMSize string `json:"vmSize"`
}

// logWrapper implements go-retryablehttp.LeveledLogger interface on top of
// a log.Logger.
//
// The wrapped logger's methods take a context as their first argument, while
// the leveled-logger methods below do not receive one, so every call is
// forwarded with context.Background().
type logWrapper struct {
	logger log.Logger
}

// Info forwards msg and keyVals to the wrapped logger's Info method.
func (w *logWrapper) Info(msg string, keyVals ...interface{}) {
	w.logger.Info(context.Background(), msg, keyVals...)
}

// Error forwards msg and keyVals to the wrapped logger's Error method.
func (w *logWrapper) Error(msg string, keyVals ...interface{}) {
	w.logger.Error(context.Background(), msg, keyVals...)
}

// Warn forwards msg and keyVals to the wrapped logger's Error method.
// NOTE(review): warnings are reported at error level, presumably because the
// wrapped logger exposes no dedicated warn level - confirm.
func (w *logWrapper) Warn(msg string, keyVals ...interface{}) {
	w.logger.Error(context.Background(), msg, keyVals...)
}

// Debug forwards msg and keyVals to the wrapped logger's Debug method.
func (w *logWrapper) Debug(msg string, keyVals ...interface{}) {
	w.logger.Debug(context.Background(), msg, keyVals...)
}
117 changes: 117 additions & 0 deletions pkg/cloudmeta/azure_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
// Copyright (C) 2024 ScyllaDB

package cloudmeta

import (
"context"
"net/http"
"net/http/httptest"
"testing"

"github.com/scylladb/go-log"
)

// TestAzureMetadata runs azureMetadata.Metadata against a local HTTP test
// server and checks, for each scenario, whether an error is returned, how
// many requests reached the server (to verify the retry policy), and the
// returned cloud provider and instance type.
func TestAzureMetadata(t *testing.T) {
	testCases := []struct {
		name    string
		handler http.Handler

		expectedCalls int
		expectedErr   bool
		expectedMeta  InstanceMetadata
	}{
		{
			name: "when response is 200",
			handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
				testCheckRequireParams(t, r)

				w.Write([]byte(`{"compute":{"vmSize":"Standard-A3"}}`))
			}),
			expectedCalls: 1,
			expectedErr:   false,
			expectedMeta: InstanceMetadata{
				CloudProvider: CloudProviderAzure,
				InstanceType:  "Standard-A3",
			},
		},
		{
			name: "when response is 404: not retryable",
			handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
				testCheckRequireParams(t, r)

				w.WriteHeader(http.StatusNotFound)
				w.Write([]byte(`not found`))
			}),
			// A 404 is expected to fail on the first attempt without retries.
			expectedCalls: 1,
			expectedErr:   true,
			expectedMeta:  InstanceMetadata{},
		},
		{
			name: "when response is 500: retryable",
			handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
				testCheckRequireParams(t, r)

				w.WriteHeader(http.StatusInternalServerError)
				w.Write([]byte(`internal server error`))
			}),
			// One initial attempt plus three retries.
			expectedCalls: 4,
			expectedErr:   true,
			expectedMeta:  InstanceMetadata{},
		},
	}

	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			// Wrap the scenario handler so the number of requests is counted.
			handler := &testHandler{Handler: tc.handler}
			testSrv := httptest.NewServer(handler)
			defer testSrv.Close()

			// Point the client at the test server instead of the real endpoint.
			azureMeta := newAzureMetadata(log.NewDevelopment())
			azureMeta.baseURL = testSrv.URL

			meta, err := azureMeta.Metadata(context.Background())
			if tc.expectedErr && err == nil {
				t.Fatal("expected an error, got nil")
			}
			if !tc.expectedErr && err != nil {
				t.Fatalf("unexpected err: %v", err)
			}

			if tc.expectedCalls != handler.calls {
				t.Fatalf("unexpected number of calls: got %d, want %d", handler.calls, tc.expectedCalls)
			}

			if meta.CloudProvider != tc.expectedMeta.CloudProvider {
				t.Fatalf("unexpected cloud provider: %s", meta.CloudProvider)
			}

			if meta.InstanceType != tc.expectedMeta.InstanceType {
				t.Fatalf("unexpected instance type: %s", meta.InstanceType)
			}
		})
	}
}

// testHandler decorates an http.Handler with a request counter.
type testHandler struct {
	http.Handler
	// calls is the number of requests served so far; tests compare it with
	// the expected number of attempts to verify the retry policy.
	calls int
}

// ServeHTTP counts the request and then delegates to the embedded handler.
func (h *testHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
	h.calls++
	h.Handler.ServeHTTP(w, req)
}

// testCheckRequireParams verifies that r carries the "Metadata: true" header
// and an "api-version" query parameter equal to azureAPIVersion.
//
// Failures are reported with t.Error/t.Errorf rather than t.Fatalf: this
// helper is called from HTTP handler functions, which the test server runs
// outside the test goroutine, and FailNow (used by Fatalf) must only be
// called from the goroutine running the test. Both checks always run, so a
// single request can report both problems.
func testCheckRequireParams(t *testing.T, r *http.Request) {
	t.Helper()
	if header := r.Header.Get("Metadata"); header != "true" {
		t.Error("Metadata: true header is required")
	}
	if version := r.URL.Query().Get("api-version"); version != azureAPIVersion {
		t.Errorf("unexpected ?api-version: %s", version)
	}
}
6 changes: 4 additions & 2 deletions pkg/cloudmeta/metadata.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ import (
"context"
"time"

"github.com/scylladb/go-log"

"github.com/pkg/errors"
"go.uber.org/multierr"
)
Expand Down Expand Up @@ -41,7 +43,7 @@ type CloudMeta struct {
}

// NewCloudMeta creates new CloudMeta provider.
func NewCloudMeta() (*CloudMeta, error) {
func NewCloudMeta(logger log.Logger) (*CloudMeta, error) {
const defaultTimeout = 5 * time.Second

awsMeta, err := newAWSMetadata()
Expand All @@ -51,7 +53,7 @@ func NewCloudMeta() (*CloudMeta, error) {

gcpMeta := newGCPMetadata()

azureMeta := NewAzureMetadata()
azureMeta := newAzureMetadata(logger)

return &CloudMeta{
providers: []CloudMetadataProvider{
Expand Down

0 comments on commit 926c561

Please sign in to comment.