Skip to content

Commit

Permalink
Use generic cache store
Browse files Browse the repository at this point in the history
  • Loading branch information
Kalpit Pant committed Dec 8, 2024
1 parent c301dea commit 2f92055
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 39 deletions.
52 changes: 40 additions & 12 deletions v2/distributed_gobreaker.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package gobreaker

import (
"context"
"encoding/json"
"fmt"
"time"
)
Expand All @@ -14,32 +15,59 @@ type SharedState struct {
Expiry time.Time `json:"expiry"`
}

type SharedStateStore interface {
GetState(ctx context.Context) (SharedState, error)
SetState(ctx context.Context, state SharedState) error
type SharedDataStore interface {
GetData(ctx context.Context, name string) ([]byte, error)
SetData(ctx context.Context, name string, data []byte) error
}

// DistributedCircuitBreaker extends CircuitBreaker with distributed state storage
type DistributedCircuitBreaker[T any] struct {
*CircuitBreaker[T]
store SharedStateStore
store SharedDataStore
}

// NewDistributedCircuitBreaker returns a new DistributedCircuitBreaker configured with the given StorageSettings
func NewDistributedCircuitBreaker[T any](store SharedStateStore, settings Settings) *DistributedCircuitBreaker[T] {
func NewDistributedCircuitBreaker[T any](store SharedDataStore, settings Settings) *DistributedCircuitBreaker[T] {
cb := NewCircuitBreaker[T](settings)
return &DistributedCircuitBreaker[T]{
CircuitBreaker: cb,
store: store,
}
}

func (rcb *DistributedCircuitBreaker[T]) getStorageKey() string {
return "cb:" + rcb.name
}

func (rcb *DistributedCircuitBreaker[T]) getStoredState(ctx context.Context) (SharedState, error) {
var state SharedState
data, err := rcb.store.GetData(ctx, rcb.getStorageKey())
if len(data) == 0 {
// Key doesn't exist, return default state
return SharedState{State: StateClosed}, nil
} else if err != nil {
return state, err
}

err = json.Unmarshal(data, &state)
return state, err
}

func (rcb *DistributedCircuitBreaker[T]) setStoredState(ctx context.Context, state SharedState) error {
data, err := json.Marshal(state)
if err != nil {
return err
}

return rcb.store.SetData(ctx, rcb.getStorageKey(), data)
}

func (dcb *DistributedCircuitBreaker[T]) State(ctx context.Context) State {
if dcb.store == nil {
return dcb.CircuitBreaker.State()
}

state, err := dcb.store.GetState(ctx)
state, err := dcb.getStoredState(ctx)
if err != nil {
// Fallback to in-memory state if Storage fails
return dcb.CircuitBreaker.State()
Expand All @@ -51,7 +79,7 @@ func (dcb *DistributedCircuitBreaker[T]) State(ctx context.Context) State {
// Update the state in Storage if it has changed
if currentState != state.State {
state.State = currentState
if err := dcb.store.SetState(ctx, state); err != nil {
if err := dcb.setStoredState(ctx, state); err != nil {
// Log the error, but continue with the current state
fmt.Printf("Failed to update state in storage: %v\n", err)
}
Expand Down Expand Up @@ -86,7 +114,7 @@ func (dcb *DistributedCircuitBreaker[T]) Execute(ctx context.Context, req func()
}

func (dcb *DistributedCircuitBreaker[T]) beforeRequest(ctx context.Context) (uint64, error) {
state, err := dcb.store.GetState(ctx)
state, err := dcb.getStoredState(ctx)
if err != nil {
return 0, err
}
Expand All @@ -95,7 +123,7 @@ func (dcb *DistributedCircuitBreaker[T]) beforeRequest(ctx context.Context) (uin

if currentState != state.State {
dcb.setState(&state, currentState, now)
err = dcb.store.SetState(ctx, state)
err = dcb.setStoredState(ctx, state)
if err != nil {
return 0, err
}
Expand All @@ -108,7 +136,7 @@ func (dcb *DistributedCircuitBreaker[T]) beforeRequest(ctx context.Context) (uin
}

state.Counts.onRequest()
err = dcb.store.SetState(ctx, state)
err = dcb.setStoredState(ctx, state)
if err != nil {
return 0, err
}
Expand All @@ -117,7 +145,7 @@ func (dcb *DistributedCircuitBreaker[T]) beforeRequest(ctx context.Context) (uin
}

func (dcb *DistributedCircuitBreaker[T]) afterRequest(ctx context.Context, before uint64, success bool) {
state, err := dcb.store.GetState(ctx)
state, err := dcb.getStoredState(ctx)
if err != nil {
return
}
Expand All @@ -133,7 +161,7 @@ func (dcb *DistributedCircuitBreaker[T]) afterRequest(ctx context.Context, befor
dcb.onFailure(&state, currentState, now)
}

dcb.store.SetState(ctx, state)
dcb.setStoredState(ctx, state)
}

func (dcb *DistributedCircuitBreaker[T]) onSuccess(state *SharedState, currentState State, now time.Time) {
Expand Down
38 changes: 11 additions & 27 deletions v2/distributed_gobreaker_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package gobreaker

import (
"context"
"encoding/json"
"errors"
"testing"
"time"
Expand All @@ -19,27 +18,12 @@ type storageAdapter struct {
client *redis.Client
}

func (r *storageAdapter) GetState(ctx context.Context) (SharedState, error) {
var state SharedState
data, err := r.client.Get(ctx, "gobreaker").Bytes()
if len(data) == 0 {
// Key doesn't exist, return default state
return SharedState{State: StateClosed}, nil
} else if err != nil {
return state, err
}

err = json.Unmarshal(data, &state)
return state, err
func (r *storageAdapter) GetData(ctx context.Context, key string) ([]byte, error) {
return r.client.Get(ctx, key).Bytes()
}

func (r *storageAdapter) SetState(ctx context.Context, state SharedState) error {
data, err := json.Marshal(state)
if err != nil {
return err
}

return r.client.Set(ctx, "gobreaker", data, 0).Err()
func (r *storageAdapter) SetData(ctx context.Context, key string, value []byte) error {
return r.client.Set(ctx, key, value, 0).Err()
}

func setupTestWithMiniredis() (*DistributedCircuitBreaker[any], *miniredis.Miniredis, *redis.Client) {
Expand All @@ -66,14 +50,14 @@ func setupTestWithMiniredis() (*DistributedCircuitBreaker[any], *miniredis.Minir
}

func pseudoSleepStorage(ctx context.Context, dcb *DistributedCircuitBreaker[any], period time.Duration) {
state, _ := dcb.store.GetState(ctx)
state, _ := dcb.getStoredState(ctx)

state.Expiry = state.Expiry.Add(-period)
// Reset counts if the interval has passed
if time.Now().After(state.Expiry) {
state.Counts = Counts{}
}
dcb.store.SetState(ctx, state)
dcb.setStoredState(ctx, state)
}

func successRequest(ctx context.Context, dcb *DistributedCircuitBreaker[any]) error {
Expand Down Expand Up @@ -174,11 +158,11 @@ func TestDistributedCircuitBreakerCounts(t *testing.T) {
assert.Nil(t, successRequest(ctx, dcb))
}

state, _ := dcb.store.GetState(ctx)
state, _ := dcb.getStoredState(ctx)
assert.Equal(t, Counts{5, 5, 0, 5, 0}, state.Counts)

assert.Nil(t, failRequest(ctx, dcb))
state, _ = dcb.store.GetState(ctx)
state, _ = dcb.getStoredState(ctx)
assert.Equal(t, Counts{6, 5, 1, 0, 1}, state.Counts)
}

Expand Down Expand Up @@ -240,14 +224,14 @@ func TestCustomDistributedCircuitBreaker(t *testing.T) {
assert.NoError(t, failRequest(ctx, customDCB))
}

state, err := customDCB.store.GetState(ctx)
state, err := customDCB.getStoredState(ctx)
assert.NoError(t, err)
assert.Equal(t, StateClosed, state.State)
assert.Equal(t, Counts{10, 5, 5, 0, 1}, state.Counts)

// Perform one more successful request
assert.NoError(t, successRequest(ctx, customDCB))
state, err = customDCB.store.GetState(ctx)
state, err = customDCB.getStoredState(ctx)
assert.NoError(t, err)
assert.Equal(t, Counts{11, 6, 5, 1, 0}, state.Counts)

Expand All @@ -262,7 +246,7 @@ func TestCustomDistributedCircuitBreaker(t *testing.T) {
// Check if the circuit breaker is now open
assert.Equal(t, StateOpen, customDCB.State(ctx))

state, err = customDCB.store.GetState(ctx)
state, err = customDCB.getStoredState(ctx)
assert.NoError(t, err)
assert.Equal(t, Counts{0, 0, 0, 0, 0}, state.Counts)
})
Expand Down

0 comments on commit 2f92055

Please sign in to comment.