Skip to content

Commit

Permalink
CreateMachine(ctx...)
Browse files Browse the repository at this point in the history
  • Loading branch information
mmetc committed Sep 18, 2024
1 parent 37c0109 commit 1d54af3
Show file tree
Hide file tree
Showing 4 changed files with 16 additions and 11 deletions.
9 changes: 5 additions & 4 deletions cmd/crowdsec-cli/climachine/machines.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package climachine

import (
"context"
"encoding/csv"
"encoding/json"
"errors"
Expand Down Expand Up @@ -278,8 +279,8 @@ func (cli *cliMachines) newAddCmd() *cobra.Command {
cscli machines add MyTestMachine --auto
cscli machines add MyTestMachine --password MyPassword
cscli machines add -f- --auto > /tmp/mycreds.yaml`,
RunE: func(_ *cobra.Command, args []string) error {
return cli.add(args, string(password), dumpFile, apiURL, interactive, autoAdd, force)
RunE: func(cmd *cobra.Command, args []string) error {
return cli.add(cmd.Context(), args, string(password), dumpFile, apiURL, interactive, autoAdd, force)
},
}

Expand All @@ -294,7 +295,7 @@ cscli machines add -f- --auto > /tmp/mycreds.yaml`,
return cmd
}

func (cli *cliMachines) add(args []string, machinePassword string, dumpFile string, apiURL string, interactive bool, autoAdd bool, force bool) error {
func (cli *cliMachines) add(ctx context.Context, args []string, machinePassword string, dumpFile string, apiURL string, interactive bool, autoAdd bool, force bool) error {
var (
err error
machineID string
Expand Down Expand Up @@ -353,7 +354,7 @@ func (cli *cliMachines) add(args []string, machinePassword string, dumpFile stri

password := strfmt.Password(machinePassword)

_, err = cli.db.CreateMachine(&machineID, &password, "", true, force, types.PasswordAuthType)
_, err = cli.db.CreateMachine(ctx, &machineID, &password, "", true, force, types.PasswordAuthType)
if err != nil {
return fmt.Errorf("unable to create machine: %w", err)
}
Expand Down
4 changes: 3 additions & 1 deletion pkg/apiserver/controllers/v1/machines.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,8 @@ func (c *Controller) shouldAutoRegister(token string, gctx *gin.Context) (bool,
}

func (c *Controller) CreateMachine(gctx *gin.Context) {
ctx := gctx.Request.Context()

var input models.WatcherRegistrationRequest

if err := gctx.ShouldBindJSON(&input); err != nil {
Expand All @@ -66,7 +68,7 @@ func (c *Controller) CreateMachine(gctx *gin.Context) {
return
}

if _, err := c.DBClient.CreateMachine(input.MachineID, input.Password, gctx.ClientIP(), autoRegister, false, types.PasswordAuthType); err != nil {
if _, err := c.DBClient.CreateMachine(ctx, input.MachineID, input.Password, gctx.ClientIP(), autoRegister, false, types.PasswordAuthType); err != nil {
c.HandleDBErrors(gctx, err)
return
}
Expand Down
6 changes: 4 additions & 2 deletions pkg/apiserver/middlewares/v1/jwt.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,8 @@ type authInput struct {
}

func (j *JWT) authTLS(c *gin.Context) (*authInput, error) {
// XXX: should we pass ctx instead
ctx := c.Request.Context()
ret := authInput{}

if j.TlsAuth == nil {
Expand All @@ -76,7 +78,7 @@ func (j *JWT) authTLS(c *gin.Context) (*authInput, error) {

ret.clientMachine, err = j.DbClient.Ent.Machine.Query().
Where(machine.MachineId(ret.machineID)).
First(j.DbClient.CTX)
First(ctx)
if ent.IsNotFound(err) {
// Machine was not found, let's create it
logger.Infof("machine %s not found, create it", ret.machineID)
Expand All @@ -91,7 +93,7 @@ func (j *JWT) authTLS(c *gin.Context) (*authInput, error) {

password := strfmt.Password(pwd)

ret.clientMachine, err = j.DbClient.CreateMachine(&ret.machineID, &password, "", true, true, types.TlsAuthType)
ret.clientMachine, err = j.DbClient.CreateMachine(ctx, &ret.machineID, &password, "", true, true, types.TlsAuthType)
if err != nil {
return nil, fmt.Errorf("while creating machine entry for %s: %w", ret.machineID, err)
}
Expand Down
8 changes: 4 additions & 4 deletions pkg/database/machines.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ func (c *Client) MachineUpdateBaseMetrics(ctx context.Context, machineID string,
return nil
}

func (c *Client) CreateMachine(machineID *string, password *strfmt.Password, ipAddress string, isValidated bool, force bool, authType string) (*ent.Machine, error) {
func (c *Client) CreateMachine(ctx context.Context, machineID *string, password *strfmt.Password, ipAddress string, isValidated bool, force bool, authType string) (*ent.Machine, error) {
hashPassword, err := bcrypt.GenerateFromPassword([]byte(*password), bcrypt.DefaultCost)
if err != nil {
c.Log.Warningf("CreateMachine: %s", err)
Expand All @@ -82,14 +82,14 @@ func (c *Client) CreateMachine(machineID *string, password *strfmt.Password, ipA
machineExist, err := c.Ent.Machine.
Query().
Where(machine.MachineIdEQ(*machineID)).
Select(machine.FieldMachineId).Strings(c.CTX)
Select(machine.FieldMachineId).Strings(ctx)
if err != nil {
return nil, errors.Wrapf(QueryFail, "machine '%s': %s", *machineID, err)
}

if len(machineExist) > 0 {
if force {
_, err := c.Ent.Machine.Update().Where(machine.MachineIdEQ(*machineID)).SetPassword(string(hashPassword)).Save(c.CTX)
_, err := c.Ent.Machine.Update().Where(machine.MachineIdEQ(*machineID)).SetPassword(string(hashPassword)).Save(ctx)
if err != nil {
c.Log.Warningf("CreateMachine : %s", err)
return nil, errors.Wrapf(UpdateFail, "machine '%s'", *machineID)
Expand All @@ -113,7 +113,7 @@ func (c *Client) CreateMachine(machineID *string, password *strfmt.Password, ipA
SetIpAddress(ipAddress).
SetIsValidated(isValidated).
SetAuthType(authType).
Save(c.CTX)
Save(ctx)
if err != nil {
c.Log.Warningf("CreateMachine : %s", err)
return nil, errors.Wrapf(InsertFail, "creating machine '%s'", *machineID)
Expand Down

0 comments on commit 1d54af3

Please sign in to comment.