Skip to content
This repository has been archived by the owner on May 29, 2024. It is now read-only.

Commit

Permalink
Address most pr comments
Browse files Browse the repository at this point in the history
  • Loading branch information
adrain-cb committed Feb 9, 2024
1 parent 2851d2c commit 74f2414
Show file tree
Hide file tree
Showing 11 changed files with 158 additions and 63 deletions.
4 changes: 2 additions & 2 deletions docker-compose.yml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ version: "3.8"
services:
localstack:
container_name: "${LOCALSTACK_DOCKER_NAME:-localstack-main}"
image: localstack/localstack
image: localstack/localstack:3.1.0
ports:
- "127.0.0.1:4566:4566" # LocalStack Gateway
- "127.0.0.1:4510-4559:4510-4559" # external services port range
Expand All @@ -12,4 +12,4 @@ services:
- DEBUG=${DEBUG:-0}
volumes:
- "/var/run/docker.sock:/var/run/docker.sock"
- "./scripts/localstack-setup.sh:/etc/localstack/init/ready.d/script.sh"
- "./scripts/localstack-e2e-test-setup.sh:/etc/localstack/init/ready.d/script.sh"
14 changes: 6 additions & 8 deletions e2e/alerting_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ func TestMultiDirectiveRouting(t *testing.T) {
return height != nil && height.Uint64() > receipt.BlockNumber.Uint64(), nil
}))

sqsMessages, err := e2e.GetMessages(ts.AppCfg.AlertConfig.SNSConfig.Endpoint, "multi-directive-test-queue")
sqsMessages, err := e2e.GetSNSMessages(ts.AppCfg.AlertConfig.SNSConfig.Endpoint, "multi-directive-test-queue")
require.NoError(t, err)

assert.Len(t, sqsMessages.Messages, 1, "Incorrect number of SNS messages sent")
Expand Down Expand Up @@ -172,20 +172,18 @@ func TestCoolDown(t *testing.T) {
return height != nil && height.Uint64() > receipt.BlockNumber.Uint64(), nil
}))

time.Sleep(1 * time.Second)

// Check that the balance enforcement was triggered using the mocked server cache.
posts := ts.TestSlackSvr.SlackAlerts()

require.Equal(t, 1, len(posts), "No balance enforcement alert was sent")
assert.Contains(t, posts[0].Text, "balance_enforcement", "Balance enforcement alert was not sent")
assert.Contains(t, posts[0].Text, alertMsg)

sqsMessages, err := e2e.GetMessages(ts.AppCfg.AlertConfig.SNSConfig.Endpoint, "alert-cooldown-test-queue")
sqsMessages, err := e2e.GetSNSMessages(ts.AppCfg.AlertConfig.SNSConfig.Endpoint, "alert-cooldown-test-queue")
assert.NoError(t, err)
assert.Len(t, sqsMessages.Messages, 1, "Incorrect number of SNS messages sent")
assert.Contains(t, *sqsMessages.Messages[0].Body, "balance_enforcement", "Balance enforcement alert was not sent")

require.Equal(t, 1, len(posts), "No balance enforcement alert was sent")
assert.Contains(t, posts[0].Text, "balance_enforcement", "Balance enforcement alert was not sent")
assert.Contains(t, posts[0].Text, alertMsg)

// Ensure that no new alerts are sent for provided cooldown period.
time.Sleep(1 * time.Second)
posts = ts.TestSlackSvr.SlackAlerts()
Expand Down
25 changes: 15 additions & 10 deletions e2e/setup.go
Original file line number Diff line number Diff line change
Expand Up @@ -115,15 +115,7 @@ func CreateSysTestSuite(t *testing.T, topicArn string) *SysTestSuite {

pagerdutyServer := NewTestPagerDutyServer("127.0.0.1", 0)

if err := os.Setenv("AWS_REGION", "us-east-1"); err != nil { //nolint:tenv // Cannot use t.SetEnv here
t.Fatal(err)
}
if err := os.Setenv("AWS_SECRET_ACCESS_KEY", "test"); err != nil { //nolint:tenv // Cannot t.Setenv here
t.Fatal(err)
}
if err := os.Setenv("AWS_ACCESS_KEY_ID", "test"); err != nil { //nolint:tenv // Cannot use t.SetEnv here
t.Fatal(err)
}
setAwsVars(t)

slackURL := fmt.Sprintf("http://127.0.0.1:%d", slackServer.Port)
pagerdutyURL := fmt.Sprintf("http://127.0.0.1:%d", pagerdutyServer.Port)
Expand Down Expand Up @@ -203,7 +195,20 @@ func DefaultTestConfig() *config.Config {
}
}

func GetMessages(endpoint string, queueName string) (*sqs.ReceiveMessageOutput, error) {
func setAwsVars(t *testing.T) {
awsEnvVariables := map[string]string{
"AWS_REGION": "us-east-1",
"AWS_SECRET_ACCESS_KEY": "test",
"AWS_ACCESS_KEY_ID": "test",
}
for key, value := range awsEnvVariables {
if err := os.Setenv(key, value); err != nil {
t.Fatalf("Error setting %s environment variable: %s", key, err)
}
}
}

func GetSNSMessages(endpoint string, queueName string) (*sqs.ReceiveMessageOutput, error) {
sess, err := session.NewSession(&aws.Config{
Endpoint: aws.String(endpoint),
})
Expand Down
23 changes: 11 additions & 12 deletions internal/alert/manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ type Manager interface {
}

// Config ... Alert manager configuration
// SNSConfig is not part of the RoutingParams as we only support publishing to one SNS client
type Config struct {
RoutingCfgPath string
PagerdutyAlertEventsURL string
Expand Down Expand Up @@ -91,8 +92,8 @@ func (am *alertManager) handleSlackPost(alert core.Alert, policy *core.AlertPoli

// Create event trigger
event := &client.AlertEventTrigger{
Message: am.interpolator.SlackMessage(alert, policy.Msg),
Severity: alert.Sev,
Message: am.interpolator.SlackMessage(alert, policy.Msg),
Alert: alert,
}

for _, sc := range slackClients {
Expand Down Expand Up @@ -121,9 +122,8 @@ func (am *alertManager) handlePagerDutyPost(alert core.Alert) error {
}

event := &client.AlertEventTrigger{
Message: am.interpolator.PagerDutyMessage(alert),
DedupKey: alert.PathID,
Severity: alert.Sev,
Message: am.interpolator.PagerDutyMessage(alert),
Alert: alert,
}

for _, pdc := range pdClients {
Expand All @@ -145,24 +145,23 @@ func (am *alertManager) handlePagerDutyPost(alert core.Alert) error {

func (am *alertManager) handleSNSPublish(alert core.Alert, policy *core.AlertPolicy) error {
event := &client.AlertEventTrigger{
Message: am.interpolator.SlackMessage(alert, policy.Msg),
DedupKey: alert.PathID,
Severity: alert.Sev,
Message: am.interpolator.SlackMessage(alert, policy.Msg),
Alert: alert,
}

cli := am.cm.GetSNSClient()
c := am.cm.GetSNSClient()

resp, err := cli.PostEvent(am.ctx, event)
resp, err := c.PostEvent(am.ctx, event)
if err != nil {
return err
}

if resp.Status != core.SuccessStatus {
return fmt.Errorf("client %s could not post to sns: %s", cli.GetName(), resp.Message)
return fmt.Errorf("client %s could not post to sns: %s", c.GetName(), resp.Message)
}

am.logger.Debug("Successfully posted to SNS", zap.Any("resp", resp))
am.metrics.RecordAlertGenerated(alert, core.SNS, cli.GetName())
am.metrics.RecordAlertGenerated(alert, core.SNS, c.GetName())
return nil
}

Expand Down
1 change: 1 addition & 0 deletions internal/alert/routing.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ type RoutingDirectory interface {
// routingDirectory ... Routing directory implementation
// NOTE: This implementation works for now, but if we add more routing clients in the future,
// we should consider refactoring this to be more generic
// Only one SNS client is needed in most cases. If we need to support multiple SNS clients, we can refactor this
type routingDirectory struct {
pagerDutyClients map[core.Severity][]client.PagerDutyClient
slackClients map[core.Severity][]client.SlackClient
Expand Down
21 changes: 16 additions & 5 deletions internal/client/alert.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,8 @@ type AlertClient interface {

// AlertEventTrigger ... A standardized event trigger for alert clients
type AlertEventTrigger struct {
Message string
Severity core.Severity
DedupKey core.PathID
Message string
Alert core.Alert
}

// AlertAPIResponse ... A standardized response for alert clients
Expand All @@ -30,8 +29,20 @@ type AlertAPIResponse struct {
// ToPagerdutyEvent ... Converts an AlertEventTrigger to a PagerDutyEventTrigger
func (a *AlertEventTrigger) ToPagerdutyEvent() *PagerDutyEventTrigger {
return &PagerDutyEventTrigger{
DedupKey: a.DedupKey.String(),
Severity: a.Severity.ToPagerDutySev(),
DedupKey: a.Alert.PathID.String(),
Severity: a.Alert.Sev.ToPagerDutySev(),
Message: a.Message,
}
}

func (a *AlertEventTrigger) ToSNSMessagePayload() *SNSMessagePayload {
return &SNSMessagePayload{
Network: a.Alert.Net.String(),
HeuristicType: a.Alert.HT.String(),
Severity: a.Alert.Sev.String(),
PathID: a.Alert.PathID.String(),
HeuristicID: a.Alert.HeuristicID.String(),
Timestamp: a.Alert.Timestamp,
Content: a.Message,
}
}
14 changes: 8 additions & 6 deletions internal/client/alert_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,22 +11,24 @@ import (

func TestToPagerDutyEvent(t *testing.T) {
alert := &client.AlertEventTrigger{
Message: "test",
Severity: core.HIGH,
DedupKey: core.PathID{},
Message: "test",
Alert: core.Alert{
Sev: core.HIGH,
PathID: core.PathID{},
},
}

sPathID := alert.DedupKey.String()
sPathID := alert.Alert.PathID.String()
res := alert.ToPagerdutyEvent()
assert.Equal(t, core.Critical, res.Severity)
assert.Equal(t, "test", res.Message)
assert.Equal(t, sPathID, res.DedupKey)

alert.Severity = core.MEDIUM
alert.Alert.Sev = core.MEDIUM
res = alert.ToPagerdutyEvent()
assert.Equal(t, core.Error, res.Severity)

alert.Severity = core.LOW
alert.Alert.Sev = core.LOW
res = alert.ToPagerdutyEvent()
assert.Equal(t, core.Warning, res.Severity)
}
64 changes: 46 additions & 18 deletions internal/client/sns.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ package client

import (
"context"
"encoding/json"
"time"

"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/session"
Expand All @@ -25,6 +27,22 @@ type SNSConfig struct {
Endpoint string
}

// SNSMessagePayload ... The json message payload published to SNS
type SNSMessagePayload struct {
Network string `json:"network"`
HeuristicType string `json:"heuristic_type"`
Severity string `json:"severity"`
PathID string `json:"path_id"`
HeuristicID string `json:"heuristic_id"`
Timestamp time.Time `json:"timestamp"`
Content string `json:"content"`
}

// SNSMessage ... The SNS message structure. Required for SNS Publish API
type SNSMessage struct {
Default string `json:"default"`
}

type snsClient struct {
svc *sns.SNS
name string
Expand Down Expand Up @@ -55,13 +73,37 @@ func NewSNSClient(cfg *SNSConfig, name string) SNSClient {
}
}

// PostEvent ... Posts an event to an SNS topic ARN
// Marshal ... Marshals the SNS message payload
func (p *SNSMessagePayload) Marshal() ([]byte, error) {
payloadBytes, err := json.Marshal(p)
if err != nil {
return nil, err
}

msg := &SNSMessage{
Default: string(payloadBytes),
}

msgBytes, err := json.Marshal(msg)
if err != nil {
return nil, err
}

return msgBytes, nil
}

// PostEvent ... Publishes an event to an SNS topic ARN
func (sc snsClient) PostEvent(_ context.Context, event *AlertEventTrigger) (*AlertAPIResponse, error) {
msgPayload, err := event.ToSNSMessagePayload().Marshal()
if err != nil {
return nil, err
}

// Publish a message to the topic
result, err := sc.svc.Publish(&sns.PublishInput{
MessageAttributes: getAttributesFromEvent(event),
Message: &event.Message,
TopicArn: &sc.topicArn,
Message: aws.String(string(msgPayload)),
MessageStructure: aws.String("json"),
TopicArn: &sc.topicArn,
})
if err != nil {
return &AlertAPIResponse{
Expand All @@ -76,20 +118,6 @@ func (sc snsClient) PostEvent(_ context.Context, event *AlertEventTrigger) (*Ale
}, nil
}

// getAttributesFromEvent ... Helper method to get attributes from an AlertEventTrigger
func getAttributesFromEvent(event *AlertEventTrigger) map[string]*sns.MessageAttributeValue {
return map[string]*sns.MessageAttributeValue{
"severity": {
DataType: aws.String("String"),
StringValue: aws.String(event.Severity.String()),
},
"dedup_key": {
DataType: aws.String("String"),
StringValue: aws.String(event.DedupKey.String()),
},
}
}

func (sc snsClient) GetName() string {
return sc.name
}
51 changes: 51 additions & 0 deletions internal/client/sns_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
package client

import (
"encoding/json"
"github.com/base-org/pessimism/internal/core"

Check failure on line 5 in internal/client/sns_test.go

View workflow job for this annotation

GitHub Actions / lint

File is not `goimports`-ed (goimports)
"github.com/stretchr/testify/assert"
"testing"
"time"
)

func TestSNSMessagePayload_Marshal(t *testing.T) {
alert := core.Alert{
Net: core.Layer1,
HT: core.BalanceEnforcement,
Sev: core.HIGH,
PathID: core.MakePathID(0, core.MakeProcessID(core.Live, 0, 0, 0), core.MakeProcessID(core.Live, 0, 0, 0)),
HeuristicID: core.UUID{},
Timestamp: time.Time{},
Content: "test",
}

event := &AlertEventTrigger{
Message: "test",
Alert: alert,
}

payload, err := event.ToSNSMessagePayload().Marshal()
if err != nil {
t.Fatal(err)
}

var snsPayload SNSMessage
err = json.Unmarshal(payload, &snsPayload)
if err != nil {
t.Fatal(err)
}

var snsMsgPayload SNSMessagePayload
err = json.Unmarshal([]byte(snsPayload.Default), &snsMsgPayload)
if err != nil {
t.Fatal(err)
}

assert.Equal(t, core.Layer1.String(), snsMsgPayload.Network)
assert.Equal(t, core.BalanceEnforcement.String(), snsMsgPayload.HeuristicType)
assert.Equal(t, core.HIGH.String(), snsMsgPayload.Severity)
assert.Equal(t, "test", snsMsgPayload.Content)
assert.Equal(t, alert.PathID.String(), snsMsgPayload.PathID)
assert.Equal(t, alert.HeuristicID.String(), snsMsgPayload.HeuristicID)
assert.Equal(t, alert.Timestamp, snsMsgPayload.Timestamp)
}
4 changes: 2 additions & 2 deletions internal/metrics/metrics.go
Original file line number Diff line number Diff line change
Expand Up @@ -247,9 +247,9 @@ func (m *Metrics) RecordAlertGenerated(alert core.Alert, dest core.AlertDestinat
net := alert.PathID.Network().String()
h := alert.HT.String()
sev := alert.Sev.String()
path := alert.PathID.String()
id := alert.PathID.String()

m.AlertsGenerated.WithLabelValues(net, h, path, sev, dest.String(), clientName).Inc()
m.AlertsGenerated.WithLabelValues(net, h, id, sev, dest.String(), clientName).Inc()
}

func (m *Metrics) RecordNodeError(n core.Network) {
Expand Down
File renamed without changes.

0 comments on commit 74f2414

Please sign in to comment.