Skip to content

Commit

Permalink
parse config on start, remove os.Getenv
Browse files Browse the repository at this point in the history
  • Loading branch information
agalitsyn committed Jan 19, 2023
1 parent cedd808 commit aa892cb
Show file tree
Hide file tree
Showing 9 changed files with 228 additions and 219 deletions.
78 changes: 54 additions & 24 deletions cmd/api/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,11 @@ import (
"flag"
"fmt"
"io"
"net/url"
"os"
"os/signal"
"path"
"path/filepath"
"syscall"
"time"

Expand Down Expand Up @@ -35,13 +37,14 @@ import (
const serviceName = "vxapi"

type Config struct {
Debug bool `config:"debug"`
Develop bool `config:"is_develop"`
Log LogConfig
DB DBConfig
Tracing TracingConfig
PublicAPI PublicAPIConfig
EventWorker EventWorkerConfig
Debug bool `config:"debug"`
Develop bool `config:"is_develop"`
Log LogConfig
DB DBConfig
Tracing TracingConfig
PublicAPI PublicAPIConfig
EventWorker EventWorkerConfig
ServerEventWorker ServerEventWorkerConfig
}

type LogConfig struct {
Expand All @@ -51,11 +54,12 @@ type LogConfig struct {
}

type DBConfig struct {
User string `config:"db_user,required"`
Pass string `config:"db_pass,required"`
Name string `config:"db_name,required"`
Host string `config:"db_host,required"`
Port int `config:"db_port,required"`
User string `config:"db_user,required"`
Pass string `config:"db_pass,required"`
Name string `config:"db_name,required"`
Host string `config:"db_host,required"`
Port int `config:"db_port,required"`
MigrationDir string `config:"migration_dir"`
}

type PublicAPIConfig struct {
Expand All @@ -65,6 +69,10 @@ type PublicAPIConfig struct {
CertFile string `config:"api_ssl_crt"`
KeyFile string `config:"api_ssl_key"`
GracefulTimeout time.Duration `config:"public_api_graceful_timeout"`
StaticPath string `config:"api_static_path"`
StaticURL string `config:"api_static_url"`
TemplatesDir string `config:"templates_dir"`
CertsPath string `config:"certs_path"`
}

type TracingConfig struct {
Expand All @@ -75,6 +83,10 @@ type EventWorkerConfig struct {
PollInterval time.Duration `config:"event_worker_poll_interval"`
}

type ServerEventWorkerConfig struct {
KeepDays int `config:"retention_events"`
}

func defaultConfig() Config {
return Config{
Log: LogConfig{
Expand All @@ -85,14 +97,23 @@ func defaultConfig() Config {
Tracing: TracingConfig{
Addr: "otel.local:8148",
},
DB: DBConfig{
MigrationDir: "db/api/migrations",
},
PublicAPI: PublicAPIConfig{
Addr: ":8080",
AddrHTTPS: ":8443",
GracefulTimeout: time.Minute,
TemplatesDir: "templates",
StaticPath: "static",
CertsPath: filepath.Join("security", "certs", "api"),
},
EventWorker: EventWorkerConfig{
PollInterval: 30 * time.Second,
},
ServerEventWorker: ServerEventWorkerConfig{
KeepDays: 7,
},
}
}

Expand Down Expand Up @@ -160,12 +181,7 @@ func main() {
logrus.WithError(err).Error("could not connect to database")
return
}

migrationDir := "db/api/migrations"
if dir, ok := os.LookupEnv("MIGRATION_DIR"); ok {
migrationDir = dir
}
if err = db.Migrate(migrationDir); err != nil {
if err = db.Migrate(cfg.DB.MigrationDir); err != nil {
logrus.WithError(err).Error("could not apply migrations")
return
}
Expand Down Expand Up @@ -263,27 +279,41 @@ func main() {
go worker.SyncModulesToPolicies(ctx, dbWithORM)

// run worker to synchronize events retention policy to all instance DB
go worker.SyncRetentionEvents(ctx, dbWithORM)
go worker.SyncRetentionEvents(ctx, dbWithORM, cfg.ServerEventWorker.KeepDays)

uiStaticURL, err := url.Parse(cfg.PublicAPI.StaticURL)
if err != nil {
logrus.WithError(err).Error("error on parsing URL to redirect requests to the UI static")
return
}
userActionWriter := useraction.NewLogWriter()

router := server.NewRouter(
server.RouterConfig{
BaseURL: "/api/v1",
Debug: cfg.Debug,
UseSSL: cfg.PublicAPI.UseSSL,
StaticPath: cfg.PublicAPI.StaticPath,
StaticURL: uiStaticURL,
TemplatesDir: cfg.PublicAPI.TemplatesDir,
CertsPath: cfg.PublicAPI.CertsPath,
},
dbWithORM,
exchanger,
userActionWriter,
dbConnectionStorage,
s3ConnectionStorage,
)

srvg, ctx := errgroup.WithContext(ctx)
srvg.Go(func() error {
group, ctx := errgroup.WithContext(ctx)
group.Go(func() error {
return server.Server{
Addr: cfg.PublicAPI.Addr,
GracefulTimeout: cfg.PublicAPI.GracefulTimeout,
}.ListenAndServe(ctx, router)
})
if cfg.PublicAPI.UseSSL {
srvg.Go(func() error {
group.Go(func() error {
return server.Server{
Addr: cfg.PublicAPI.AddrHTTPS,
CertFile: cfg.PublicAPI.CertFile,
Expand All @@ -292,7 +322,7 @@ func main() {
}.ListenAndServeTLS(ctx, router)
})
}
if err = srvg.Wait(); err != nil {
logrus.WithError(err).Error("failed to start server")
if err = group.Wait(); err != nil {
logrus.WithError(err).Error("could not start services")
}
}
5 changes: 2 additions & 3 deletions pkg/app/api/server/middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,13 @@ import (
"soldr/pkg/app/api/server/context"
"soldr/pkg/app/api/server/private"
"soldr/pkg/app/api/server/response"
"soldr/pkg/app/api/utils"
"soldr/pkg/app/api/utils/dbencryptor"
)

func authTokenProtoRequired() gin.HandlerFunc {
func authTokenProtoRequired(apiBaseURL string) gin.HandlerFunc {
privInteractive := "vxapi.modules.interactive"
connTypeRegexp := regexp.MustCompile(
fmt.Sprintf("%s/vxpws/(aggregate|browser|external)/.*", utils.PrefixPathAPI),
fmt.Sprintf("%s/vxpws/(aggregate|browser|external)/.*", apiBaseURL),
)
return func(c *gin.Context) {
if c.IsAborted() {
Expand Down
11 changes: 5 additions & 6 deletions pkg/app/api/server/private/modules.go
Original file line number Diff line number Diff line change
Expand Up @@ -580,7 +580,7 @@ func BuildModuleSConfig(module *models.ModuleS) (map[string][]byte, error) {
return files, nil
}

func LoadModuleSTemplate(mi *models.ModuleInfo) (Template, *models.ModuleS, error) {
func LoadModuleSTemplate(mi *models.ModuleInfo, templatesDir string) (Template, *models.ModuleS, error) {
fs, err := storage.NewFS()
if err != nil {
return nil, nil, errors.New("failed initialize FS driver: " + err.Error())
Expand All @@ -589,10 +589,6 @@ func LoadModuleSTemplate(mi *models.ModuleInfo) (Template, *models.ModuleS, erro
var module *models.ModuleS
template := make(Template)
loadModuleDir := func(dir string) (map[string][]byte, error) {
templatesDir := "templates"
if dir, ok := os.LookupEnv("TEMPLATES_DIR"); ok {
templatesDir = dir
}
tpath := joinPath(templatesDir, mi.Template, dir)
if fs.IsNotExist(tpath) {
return nil, errors.New("template directory not found")
Expand Down Expand Up @@ -1298,17 +1294,20 @@ type ModuleService struct {
db *gorm.DB
serverConnector *client.AgentServerClient
userActionWriter useraction.Writer
templatesDir string
}

func NewModuleService(
db *gorm.DB,
serverConnector *client.AgentServerClient,
userActionWriter useraction.Writer,
templatesDir string,
) *ModuleService {
return &ModuleService{
db: db,
serverConnector: serverConnector,
userActionWriter: userActionWriter,
templatesDir: templatesDir,
}
}

Expand Down Expand Up @@ -3050,7 +3049,7 @@ func (s *ModuleService) CreateModule(c *gin.Context) {
info.System = false

var err error
if template, module, err = LoadModuleSTemplate(&info); err != nil {
if template, module, err = LoadModuleSTemplate(&info, s.templatesDir); err != nil {
logrus.WithError(err).Errorf("error loading module")
response.Error(c, response.ErrCreateModuleLoadFail, err)
return
Expand Down
39 changes: 22 additions & 17 deletions pkg/app/api/server/proto/proto.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ import (
"fmt"
"net/http"
"net/url"
"os"
"path/filepath"

"github.com/gin-gonic/gin"
Expand Down Expand Up @@ -126,12 +125,9 @@ func doVXServerConnection(
ctxConn *ctxVXConnection,
agentInfo *protoagent.Information,
ltacGetter vxcommonVM.LTACGetter,
certsPath string,
) (*socket, error) {
certsDir := filepath.Join("security", "certs", "api")
if dir, ok := os.LookupEnv("CERTS_PATH"); ok {
certsDir = dir
}
hardeningVM, packEncryptor, err := prepareVM(ctxConn, NewCertProvider(certsDir), ltacGetter)
hardeningVM, packEncryptor, err := prepareVM(ctxConn, NewCertProvider(certsPath), ltacGetter)
if err != nil {
return nil, fmt.Errorf("failed to prepare VM: %w", err)
}
Expand Down Expand Up @@ -208,7 +204,14 @@ func sendAuthResp(ctx context.Context, conn vxproto.IConnection, authRespMessage
return nil
}

func wsConnectToVXServer(c *gin.Context, connType vxproto.AgentType, sockID, sockType string, uaf useraction.Fields) {
func wsConnectToVXServer(
c *gin.Context,
connType vxproto.AgentType,
sockID string,
sockType string,
certsPath string,
uaf useraction.Fields,
) {
var (
serverConn *socket
sv *models.Service
Expand Down Expand Up @@ -275,16 +278,15 @@ func wsConnectToVXServer(c *gin.Context, connType vxproto.AgentType, sockID, soc
sv: sv,
logger: logger,
}
certsDir := filepath.Join("security", "certs", "api")
if dir, ok := os.LookupEnv("CERTS_PATH"); ok {
certsDir = dir
}

logger.WithField("auth_req", authReq).Debug("try doVXServerConnection")
if serverConn, err = doVXServerConnection(
serverConn, err = doVXServerConnection(
ctxConn,
agentInfo,
NewStore(filepath.Join(certsDir, sockType)),
); err != nil {
NewStore(filepath.Join(certsPath, sockType)),
certsPath,
)
if err != nil {
clientConn.Close(c.Request.Context())
logger.WithError(err).Error("failed to initialize connection to server")
uaf.FailReason = "failed to initialize connection to server"
Expand Down Expand Up @@ -348,17 +350,20 @@ type ProtoService struct {
db *gorm.DB
serverConnector *client.AgentServerClient
userActionWriter useraction.Writer
certsPath string
}

func NewProtoService(
db *gorm.DB,
serverConnector *client.AgentServerClient,
userActionWriter useraction.Writer,
certsPath string,
) *ProtoService {
return &ProtoService{
db: db,
serverConnector: serverConnector,
userActionWriter: userActionWriter,
certsPath: certsPath,
}
}

Expand Down Expand Up @@ -407,7 +412,7 @@ func (s *ProtoService) AggregateWSConnect(c *gin.Context) {
return
}

wsConnectToVXServer(c, vxproto.Aggregate, sockID, sockType, uaf)
wsConnectToVXServer(c, vxproto.Aggregate, sockID, sockType, s.certsPath, uaf)
}

func (s *ProtoService) BrowserWSConnect(c *gin.Context) {
Expand Down Expand Up @@ -456,7 +461,7 @@ func (s *ProtoService) BrowserWSConnect(c *gin.Context) {
return
}

wsConnectToVXServer(c, vxproto.Browser, sockID, sockType, uaf)
wsConnectToVXServer(c, vxproto.Browser, sockID, sockType, s.certsPath, uaf)
}

func (s *ProtoService) ExternalWSConnect(c *gin.Context) {
Expand Down Expand Up @@ -503,5 +508,5 @@ func (s *ProtoService) ExternalWSConnect(c *gin.Context) {
return
}

wsConnectToVXServer(c, vxproto.External, sockID, sockType, uaf)
wsConnectToVXServer(c, vxproto.External, sockID, sockType, s.certsPath, uaf)
}
Loading

0 comments on commit aa892cb

Please sign in to comment.