Skip to content

Commit

Permalink
rest of machines.go
Browse files Browse the repository at this point in the history
  • Loading branch information
mmetc committed Sep 18, 2024
1 parent e261ae6 commit 94f3abe
Show file tree
Hide file tree
Showing 4 changed files with 22 additions and 18 deletions.
2 changes: 1 addition & 1 deletion cmd/crowdsec-cli/climachine/machines.go
Original file line number Diff line number Diff line change
Expand Up @@ -491,7 +491,7 @@ func (cli *cliMachines) prune(ctx context.Context, duration time.Duration, notVa
}

if !notValidOnly {
if pending, err := cli.db.QueryMachinesInactiveSince(time.Now().UTC().Add(-duration)); err == nil {
if pending, err := cli.db.QueryMachinesInactiveSince(ctx, time.Now().UTC().Add(-duration)); err == nil {
machines = append(machines, pending...)
}
}
Expand Down
4 changes: 3 additions & 1 deletion pkg/apiserver/controllers/v1/heartbeat.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,9 @@ import (
func (c *Controller) HeartBeat(gctx *gin.Context) {
machineID, _ := getMachineIDFromContext(gctx)

if err := c.DBClient.UpdateMachineLastHeartBeat(machineID); err != nil {
ctx := gctx.Request.Context()

if err := c.DBClient.UpdateMachineLastHeartBeat(ctx, machineID); err != nil {
c.HandleDBErrors(gctx, err)
return
}
Expand Down
10 changes: 6 additions & 4 deletions pkg/apiserver/middlewares/v1/jwt.go
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,8 @@ func (j *JWT) Authenticator(c *gin.Context) (interface{}, error) {
auth *authInput
)

ctx := c.Request.Context()

if c.Request.TLS != nil && len(c.Request.TLS.PeerCertificates) > 0 {
auth, err = j.authTLS(c)
if err != nil {
Expand All @@ -200,7 +202,7 @@ func (j *JWT) Authenticator(c *gin.Context) (interface{}, error) {
}
}

err = j.DbClient.UpdateMachineScenarios(scenarios, auth.clientMachine.ID)
err = j.DbClient.UpdateMachineScenarios(ctx, scenarios, auth.clientMachine.ID)
if err != nil {
log.Errorf("Failed to update scenarios list for '%s': %s\n", auth.machineID, err)
return nil, jwt.ErrFailedAuthentication
Expand All @@ -210,7 +212,7 @@ func (j *JWT) Authenticator(c *gin.Context) (interface{}, error) {
clientIP := c.ClientIP()

if auth.clientMachine.IpAddress == "" {
err = j.DbClient.UpdateMachineIP(clientIP, auth.clientMachine.ID)
err = j.DbClient.UpdateMachineIP(ctx, clientIP, auth.clientMachine.ID)
if err != nil {
log.Errorf("Failed to update ip address for '%s': %s\n", auth.machineID, err)
return nil, jwt.ErrFailedAuthentication
Expand All @@ -220,7 +222,7 @@ func (j *JWT) Authenticator(c *gin.Context) (interface{}, error) {
if auth.clientMachine.IpAddress != clientIP && auth.clientMachine.IpAddress != "" {
log.Warningf("new IP address detected for machine '%s': %s (old: %s)", auth.clientMachine.MachineId, clientIP, auth.clientMachine.IpAddress)

err = j.DbClient.UpdateMachineIP(clientIP, auth.clientMachine.ID)
err = j.DbClient.UpdateMachineIP(ctx, clientIP, auth.clientMachine.ID)
if err != nil {
log.Errorf("Failed to update ip address for '%s': %s\n", auth.clientMachine.MachineId, err)
return nil, jwt.ErrFailedAuthentication
Expand All @@ -233,7 +235,7 @@ func (j *JWT) Authenticator(c *gin.Context) (interface{}, error) {
return nil, jwt.ErrFailedAuthentication
}

if err := j.DbClient.UpdateMachineVersion(useragent[1], auth.clientMachine.ID); err != nil {
if err := j.DbClient.UpdateMachineVersion(ctx, useragent[1], auth.clientMachine.ID); err != nil {
log.Errorf("unable to update machine '%s' version '%s': %s", auth.clientMachine.MachineId, useragent[1], err)
log.Errorf("bad user agent from : %s", clientIP)

Expand Down
24 changes: 12 additions & 12 deletions pkg/database/machines.go
Original file line number Diff line number Diff line change
Expand Up @@ -197,51 +197,51 @@ func (c *Client) BulkDeleteWatchers(ctx context.Context, machines []*ent.Machine
return nbDeleted, nil
}

func (c *Client) UpdateMachineLastHeartBeat(machineID string) error {
_, err := c.Ent.Machine.Update().Where(machine.MachineIdEQ(machineID)).SetLastHeartbeat(time.Now().UTC()).Save(c.CTX)
func (c *Client) UpdateMachineLastHeartBeat(ctx context.Context, machineID string) error {
_, err := c.Ent.Machine.Update().Where(machine.MachineIdEQ(machineID)).SetLastHeartbeat(time.Now().UTC()).Save(ctx)
if err != nil {
return errors.Wrapf(UpdateFail, "updating machine last_heartbeat: %s", err)
}

return nil
}

func (c *Client) UpdateMachineScenarios(scenarios string, id int) error {
func (c *Client) UpdateMachineScenarios(ctx context.Context, scenarios string, id int) error {
_, err := c.Ent.Machine.UpdateOneID(id).
SetUpdatedAt(time.Now().UTC()).
SetScenarios(scenarios).
Save(c.CTX)
Save(ctx)
if err != nil {
return fmt.Errorf("unable to update machine in database: %w", err)
}

return nil
}

func (c *Client) UpdateMachineIP(ipAddr string, id int) error {
func (c *Client) UpdateMachineIP(ctx context.Context, ipAddr string, id int) error {
_, err := c.Ent.Machine.UpdateOneID(id).
SetIpAddress(ipAddr).
Save(c.CTX)
Save(ctx)
if err != nil {
return fmt.Errorf("unable to update machine IP in database: %w", err)
}

return nil
}

func (c *Client) UpdateMachineVersion(ipAddr string, id int) error {
func (c *Client) UpdateMachineVersion(ctx context.Context, ipAddr string, id int) error {
_, err := c.Ent.Machine.UpdateOneID(id).
SetVersion(ipAddr).
Save(c.CTX)
Save(ctx)
if err != nil {
return fmt.Errorf("unable to update machine version in database: %w", err)
}

return nil
}

func (c *Client) IsMachineRegistered(machineID string) (bool, error) {
exist, err := c.Ent.Machine.Query().Where().Select(machine.FieldMachineId).Strings(c.CTX)
func (c *Client) IsMachineRegistered(ctx context.Context, machineID string) (bool, error) {
exist, err := c.Ent.Machine.Query().Where().Select(machine.FieldMachineId).Strings(ctx)
if err != nil {
return false, err
}
Expand All @@ -257,11 +257,11 @@ func (c *Client) IsMachineRegistered(machineID string) (bool, error) {
return false, nil
}

func (c *Client) QueryMachinesInactiveSince(t time.Time) ([]*ent.Machine, error) {
func (c *Client) QueryMachinesInactiveSince(ctx context.Context, t time.Time) ([]*ent.Machine, error) {
return c.Ent.Machine.Query().Where(
machine.Or(
machine.And(machine.LastHeartbeatLT(t), machine.IsValidatedEQ(true)),
machine.And(machine.LastHeartbeatIsNil(), machine.CreatedAtLT(t)),
),
).All(c.CTX)
).All(ctx)
}

0 comments on commit 94f3abe

Please sign in to comment.