Skip to content

Commit

Permalink
context propagation: pkg/database/config
Browse files Browse the repository at this point in the history
  • Loading branch information
mmetc committed Sep 18, 2024
1 parent 8e369f5 commit 6991c67
Show file tree
Hide file tree
Showing 6 changed files with 40 additions and 35 deletions.
2 changes: 1 addition & 1 deletion cmd/crowdsec-cli/clipapi/papi.go
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ func (cli *cliPapi) Status(ctx context.Context, out io.Writer, db *database.Clie
return fmt.Errorf("unable to get PAPI permissions: %w", err)
}

lastTimestampStr, err := db.GetConfigItem(apiserver.PapiPullKey)
lastTimestampStr, err := db.GetConfigItem(ctx, apiserver.PapiPullKey)
if err != nil {
lastTimestampStr = ptr.Of("never")
}
Expand Down
28 changes: 14 additions & 14 deletions pkg/apiserver/apic.go
Original file line number Diff line number Diff line change
Expand Up @@ -614,7 +614,7 @@ func fillAlertsWithDecisions(alerts []*models.Alert, decisions []*models.Decisio
// we receive a list of decisions and links for blocklist and we need to create a list of alerts :
// one alert for "community blocklist"
// one alert per list we're subscribed to
func (a *apic) PullTop(forcePull bool) error {
func (a *apic) PullTop(ctx context.Context, forcePull bool) error {
var err error

// A mutex with TryLock would be a bit simpler
Expand Down Expand Up @@ -655,7 +655,7 @@ func (a *apic) PullTop(forcePull bool) error {

log.Infof("Starting community-blocklist update")

data, _, err := a.apiClient.Decisions.GetStreamV3(context.Background(), apiclient.DecisionsStreamOpts{Startup: a.startup})
data, _, err := a.apiClient.Decisions.GetStreamV3(ctx, apiclient.DecisionsStreamOpts{Startup: a.startup})
if err != nil {
return fmt.Errorf("get stream: %w", err)
}
Expand Down Expand Up @@ -700,17 +700,17 @@ func (a *apic) PullTop(forcePull bool) error {
}

// update blocklists
if err := a.UpdateBlocklists(data.Links, addCounters, forcePull); err != nil {
if err := a.UpdateBlocklists(ctx, data.Links, addCounters, forcePull); err != nil {
return fmt.Errorf("while updating blocklists: %w", err)
}

return nil
}

// we receive a link to a blocklist, we pull the content of the blocklist and we create one alert
func (a *apic) PullBlocklist(blocklist *modelscapi.BlocklistLink, forcePull bool) error {
func (a *apic) PullBlocklist(ctx context.Context, blocklist *modelscapi.BlocklistLink, forcePull bool) error {
addCounters, _ := makeAddAndDeleteCounters()
if err := a.UpdateBlocklists(&modelscapi.GetDecisionsStreamResponseLinks{
if err := a.UpdateBlocklists(ctx, &modelscapi.GetDecisionsStreamResponseLinks{
Blocklists: []*modelscapi.BlocklistLink{blocklist},
}, addCounters, forcePull); err != nil {
return fmt.Errorf("while pulling blocklist: %w", err)
Expand Down Expand Up @@ -820,7 +820,7 @@ func (a *apic) ShouldForcePullBlocklist(blocklist *modelscapi.BlocklistLink) (bo
return false, nil
}

func (a *apic) updateBlocklist(client *apiclient.ApiClient, blocklist *modelscapi.BlocklistLink, addCounters map[string]map[string]int, forcePull bool) error {
func (a *apic) updateBlocklist(ctx context.Context, client *apiclient.ApiClient, blocklist *modelscapi.BlocklistLink, addCounters map[string]map[string]int, forcePull bool) error {
if blocklist.Scope == nil {
log.Warningf("blocklist has no scope")
return nil
Expand Down Expand Up @@ -848,13 +848,13 @@ func (a *apic) updateBlocklist(client *apiclient.ApiClient, blocklist *modelscap
)

if !forcePull {
lastPullTimestamp, err = a.dbClient.GetConfigItem(blocklistConfigItemName)
lastPullTimestamp, err = a.dbClient.GetConfigItem(ctx, blocklistConfigItemName)
if err != nil {
return fmt.Errorf("while getting last pull timestamp for blocklist %s: %w", *blocklist.Name, err)
}
}

decisions, hasChanged, err := client.Decisions.GetDecisionsFromBlocklist(context.Background(), blocklist, lastPullTimestamp)
decisions, hasChanged, err := client.Decisions.GetDecisionsFromBlocklist(ctx, blocklist, lastPullTimestamp)
if err != nil {
return fmt.Errorf("while getting decisions from blocklist %s: %w", *blocklist.Name, err)
}
Expand All @@ -869,7 +869,7 @@ func (a *apic) updateBlocklist(client *apiclient.ApiClient, blocklist *modelscap
return nil
}

err = a.dbClient.SetConfigItem(blocklistConfigItemName, time.Now().UTC().Format(http.TimeFormat))
err = a.dbClient.SetConfigItem(ctx, blocklistConfigItemName, time.Now().UTC().Format(http.TimeFormat))
if err != nil {
return fmt.Errorf("while setting last pull timestamp for blocklist %s: %w", *blocklist.Name, err)
}
Expand All @@ -892,7 +892,7 @@ func (a *apic) updateBlocklist(client *apiclient.ApiClient, blocklist *modelscap
return nil
}

func (a *apic) UpdateBlocklists(links *modelscapi.GetDecisionsStreamResponseLinks, addCounters map[string]map[string]int, forcePull bool) error {
func (a *apic) UpdateBlocklists(ctx context.Context, links *modelscapi.GetDecisionsStreamResponseLinks, addCounters map[string]map[string]int, forcePull bool) error {
if links == nil {
return nil
}
Expand All @@ -908,7 +908,7 @@ func (a *apic) UpdateBlocklists(links *modelscapi.GetDecisionsStreamResponseLink
}

for _, blocklist := range links.Blocklists {
if err := a.updateBlocklist(defaultClient, blocklist, addCounters, forcePull); err != nil {
if err := a.updateBlocklist(ctx, defaultClient, blocklist, addCounters, forcePull); err != nil {
return err
}
}
Expand All @@ -931,7 +931,7 @@ func setAlertScenario(alert *models.Alert, addCounters map[string]map[string]int
}
}

func (a *apic) Pull() error {
func (a *apic) Pull(ctx context.Context) error {
defer trace.CatchPanic("lapi/pullFromAPIC")

toldOnce := false
Expand All @@ -955,7 +955,7 @@ func (a *apic) Pull() error {
time.Sleep(1 * time.Second)
}

if err := a.PullTop(false); err != nil {
if err := a.PullTop(ctx, false); err != nil {
log.Errorf("capi pull top: %s", err)
}

Expand All @@ -967,7 +967,7 @@ func (a *apic) Pull() error {
case <-ticker.C:
ticker.Reset(a.pullInterval)

if err := a.PullTop(false); err != nil {
if err := a.PullTop(ctx, false); err != nil {
log.Errorf("capi pull top: %s", err)
continue
}
Expand Down
18 changes: 10 additions & 8 deletions pkg/apiserver/apiserver.go
Original file line number Diff line number Diff line change
Expand Up @@ -310,17 +310,17 @@ func (s *APIServer) apicPush() error {
return nil
}

func (s *APIServer) apicPull() error {
if err := s.apic.Pull(); err != nil {
func (s *APIServer) apicPull(ctx context.Context) error {
if err := s.apic.Pull(ctx); err != nil {
log.Errorf("capi pull: %s", err)
return err
}

return nil
}

func (s *APIServer) papiPull() error {
if err := s.papi.Pull(); err != nil {
func (s *APIServer) papiPull(ctx context.Context) error {
if err := s.papi.Pull(ctx); err != nil {
log.Errorf("papi pull: %s", err)
return err
}
Expand All @@ -337,16 +337,16 @@ func (s *APIServer) papiSync() error {
return nil
}

func (s *APIServer) initAPIC() {
func (s *APIServer) initAPIC(ctx context.Context) {
s.apic.pushTomb.Go(s.apicPush)
s.apic.pullTomb.Go(s.apicPull)
s.apic.pullTomb.Go(func() error { return s.apicPull(ctx) })

// csConfig.API.Server.ConsoleConfig.ShareCustomScenarios
if s.apic.apiClient.IsEnrolled() {
if s.consoleConfig.IsPAPIEnabled() {
if s.papi.URL != "" {
log.Info("Starting PAPI decision receiver")
s.papi.pullTomb.Go(s.papiPull)
s.papi.pullTomb.Go(func() error { return s.papiPull(ctx) })
s.papi.syncTomb.Go(s.papiSync)
} else {
log.Warnf("papi_url is not set in online_api_credentials.yaml, can't synchronize with the console. Run cscli console enable console_management to add it.")
Expand Down Expand Up @@ -381,8 +381,10 @@ func (s *APIServer) Run(apiReady chan bool) error {
TLSConfig: tlsCfg,
}

ctx := context.TODO()

if s.apic != nil {
s.initAPIC()
s.initAPIC(ctx)
}

s.httpServerTomb.Go(func() error {
Expand Down
8 changes: 4 additions & 4 deletions pkg/apiserver/papi.go
Original file line number Diff line number Diff line change
Expand Up @@ -230,13 +230,13 @@ func (p *Papi) PullOnce(since time.Time, sync bool) error {
}

// PullPAPI is the long polling client for real-time decisions from PAPI
func (p *Papi) Pull() error {
func (p *Papi) Pull(ctx context.Context) error {
defer trace.CatchPanic("lapi/PullPAPI")
p.Logger.Infof("Starting Polling API Pull")

lastTimestamp := time.Time{}

lastTimestampStr, err := p.DBClient.GetConfigItem(PapiPullKey)
lastTimestampStr, err := p.DBClient.GetConfigItem(ctx, PapiPullKey)
if err != nil {
p.Logger.Warningf("failed to get last timestamp for papi pull: %s", err)
}
Expand All @@ -248,7 +248,7 @@ func (p *Papi) Pull() error {
return fmt.Errorf("failed to serialize last timestamp: %w", err)
}

if err := p.DBClient.SetConfigItem(PapiPullKey, string(binTime)); err != nil {
if err := p.DBClient.SetConfigItem(ctx, PapiPullKey, string(binTime)); err != nil {
p.Logger.Errorf("error setting papi pull last key: %s", err)
} else {
p.Logger.Debugf("config item '%s' set in database with value '%s'", PapiPullKey, string(binTime))
Expand Down Expand Up @@ -277,7 +277,7 @@ func (p *Papi) Pull() error {
continue
}

if err := p.DBClient.SetConfigItem(PapiPullKey, string(binTime)); err != nil {
if err := p.DBClient.SetConfigItem(ctx, PapiPullKey, string(binTime)); err != nil {
return fmt.Errorf("failed to update last timestamp: %w", err)
}

Expand Down
7 changes: 5 additions & 2 deletions pkg/apiserver/papi_cmd.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package apiserver

import (
"context"
"encoding/json"
"fmt"
"time"
Expand Down Expand Up @@ -215,17 +216,19 @@ func ManagementCmd(message *Message, p *Papi, sync bool) error {
return fmt.Errorf("message for '%s' contains bad data format: %w", message.Header.OperationType, err)
}

ctx := context.TODO()

if forcePullMsg.Blocklist == nil {
p.Logger.Infof("Received force_pull command from PAPI, pulling community and 3rd-party blocklists")

err = p.apic.PullTop(true)
err = p.apic.PullTop(ctx, true)
if err != nil {
return fmt.Errorf("failed to force pull operation: %w", err)
}
} else {
p.Logger.Infof("Received force_pull command from PAPI, pulling blocklist %s", forcePullMsg.Blocklist.Name)

err = p.apic.PullBlocklist(&modelscapi.BlocklistLink{
err = p.apic.PullBlocklist(ctx, &modelscapi.BlocklistLink{
Name: &forcePullMsg.Blocklist.Name,
URL: &forcePullMsg.Blocklist.Url,
Remediation: &forcePullMsg.Blocklist.Remediation,
Expand Down
12 changes: 6 additions & 6 deletions pkg/database/config.go
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
package database

import (
"context"
"github.com/pkg/errors"

"github.com/crowdsecurity/crowdsec/pkg/database/ent"
"github.com/crowdsecurity/crowdsec/pkg/database/ent/configitem"
)

func (c *Client) GetConfigItem(key string) (*string, error) {
result, err := c.Ent.ConfigItem.Query().Where(configitem.NameEQ(key)).First(c.CTX)
func (c *Client) GetConfigItem(ctx context.Context, key string) (*string, error) {
result, err := c.Ent.ConfigItem.Query().Where(configitem.NameEQ(key)).First(ctx)
if err != nil && ent.IsNotFound(err) {
return nil, nil
}
Expand All @@ -19,11 +20,10 @@ func (c *Client) GetConfigItem(key string) (*string, error) {
return &result.Value, nil
}

func (c *Client) SetConfigItem(key string, value string) error {

nbUpdated, err := c.Ent.ConfigItem.Update().SetValue(value).Where(configitem.NameEQ(key)).Save(c.CTX)
func (c *Client) SetConfigItem(ctx context.Context, key string, value string) error {
nbUpdated, err := c.Ent.ConfigItem.Update().SetValue(value).Where(configitem.NameEQ(key)).Save(ctx)
if (err != nil && ent.IsNotFound(err)) || nbUpdated == 0 { //not found, create
err := c.Ent.ConfigItem.Create().SetName(key).SetValue(value).Exec(c.CTX)
err := c.Ent.ConfigItem.Create().SetName(key).SetValue(value).Exec(ctx)
if err != nil {
return errors.Wrapf(QueryFail, "insert config item: %s", err)
}
Expand Down

0 comments on commit 6991c67

Please sign in to comment.