diff --git a/cmd/ydbcp/config.yaml b/cmd/ydbcp/config.yaml index 0ae64f38..2217a9ec 100644 --- a/cmd/ydbcp/config.yaml +++ b/cmd/ydbcp/config.yaml @@ -8,6 +8,10 @@ db_connection: client_connection: insecure: true discovery: false + allowed_endpoint_domains: + - .allowed-domain.com + - allowed-hostname.domain.com + allow_insecure_endpoint: false s3: endpoint: s3.endpoint.com diff --git a/cmd/ydbcp/main.go b/cmd/ydbcp/main.go index d4b2aac8..33757ce5 100644 --- a/cmd/ydbcp/main.go +++ b/cmd/ydbcp/main.go @@ -95,7 +95,14 @@ func main() { } }() - backup.NewBackupService(dbConnector, clientConnector, configInstance.S3, authProvider).Register(server) + backup.NewBackupService( + dbConnector, + clientConnector, + configInstance.S3, + authProvider, + configInstance.ClientConnection.AllowedEndpointDomains, + configInstance.ClientConnection.AllowInsecureEndpoint, + ).Register(server) operation.NewOperationService(dbConnector, authProvider).Register(server) if err := server.Start(ctx, &wg); err != nil { diff --git a/internal/auth/mock.go b/internal/auth/mock.go index c2bd649c..8e80ee02 100644 --- a/internal/auth/mock.go +++ b/internal/auth/mock.go @@ -184,13 +184,10 @@ func (p *MockAuthProvider) Authorize( return results, subject, nil } -func NewMockAuthProvider(ctx context.Context, options ...Option) (auth.AuthProvider, error) { +func NewMockAuthProvider(options ...Option) *MockAuthProvider { p := &MockAuthProvider{} for _, opt := range options { opt(p) } - if err := p.Init(ctx, ""); err != nil { - return nil, err - } - return p, nil + return p } diff --git a/internal/config/config.go b/internal/config/config.go index 578fc5dc..137dd9d7 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -5,6 +5,7 @@ import ( "errors" "fmt" "os" + "regexp" "strings" "ydbcp/internal/util/xlog" @@ -30,10 +31,12 @@ type YDBConnectionConfig struct { } type ClientConnectionConfig struct { - Insecure bool `yaml:"insecure"` - Discovery bool `yaml:"discovery" default:"true"` - DialTimeoutSeconds uint32 `yaml:"dial_timeout_seconds" default:"5"` - OAuth2KeyFile string `yaml:"oauth2_key_file"` + Insecure bool `yaml:"insecure"` + Discovery bool `yaml:"discovery" default:"true"` + DialTimeoutSeconds uint32 `yaml:"dial_timeout_seconds" default:"5"` + OAuth2KeyFile string `yaml:"oauth2_key_file"` + AllowedEndpointDomains []string `yaml:"allowed_endpoint_domains"` + AllowInsecureEndpoint bool `yaml:"allow_insecure_endpoint"` } type AuthConfig struct { @@ -57,6 +60,10 @@ type Config struct { GRPCServer GRPCServerConfig `yaml:"grpc_server"` } +var ( + validDomainFilter = regexp.MustCompile(`^[A-Za-z\.][A-Za-z0-9\-\.]+[A-Za-z]$`) +) + func (config Config) ToString() (string, error) { data, err := yaml.Marshal(&config) if err != nil { @@ -82,11 +89,20 @@ func InitConfig(ctx context.Context, confPath string) (Config, error) { zap.Error(err)) return Config{}, err } - return config, nil + return config, config.Validate() } return Config{}, errors.New("configuration file path is empty") } +func (c *Config) Validate() error { + for _, domain := range c.ClientConnection.AllowedEndpointDomains { + if !validDomainFilter.MatchString(domain) { + return fmt.Errorf("incorrect domain filter in allowed_endpoint_domains: %s", domain) + } + } + return nil +} + func readSecret(filename string) (string, error) { rawSecret, err := os.ReadFile(filename) if err != nil { diff --git a/internal/server/services/backup/backupservice.go b/internal/server/services/backup/backupservice.go index d8187f7d..fd97a3bc 100644 --- a/internal/server/services/backup/backupservice.go +++ b/internal/server/services/backup/backupservice.go @@ -3,6 +3,7 @@ package backup import ( "context" "path" + "regexp" "strings" "time" @@ -26,10 +27,39 @@ import ( type BackupService struct { pb.UnimplementedBackupServiceServer - driver db.DBConnector - clientConn client.ClientConnector - s3 config.S3Config - auth ap.AuthProvider + driver db.DBConnector + clientConn client.ClientConnector + s3 config.S3Config + auth ap.AuthProvider + allowedEndpointDomains []string + allowInsecureEndpoint bool +} + +var ( + validEndpoint = regexp.MustCompile(`^(grpcs://|grpc://)?([A-Za-z0-9\-\.]+)(:[0-9]+)?$`) +) + +func (s *BackupService) isAllowedEndpoint(e string) bool { + groups := validEndpoint.FindStringSubmatch(e) + if len(groups) < 3 { + return false + } + tls := groups[1] == "grpcs://" + if !tls && !s.allowInsecureEndpoint { + return false + } + fqdn := groups[2] + + for _, domain := range s.allowedEndpointDomains { + if strings.HasPrefix(domain, ".") { + if strings.HasSuffix(fqdn, domain) { + return true + } + } else if fqdn == domain { + return true + } + } + return false } func (s *BackupService) GetBackup(ctx context.Context, request *pb.GetBackupRequest) (*pb.Backup, error) { @@ -75,9 +105,13 @@ func (s *BackupService) MakeBackup(ctx context.Context, req *pb.MakeBackupReques } xlog.Debug(ctx, "MakeBackup", zap.String("subject", subject)) + if !s.isAllowedEndpoint(req.DatabaseEndpoint) { + return nil, status.Errorf(codes.InvalidArgument, "endpoint of database is invalid or not allowed, endpoint %s", req.DatabaseEndpoint) + } + clientConnectionParams := types.YdbConnectionParams{ - Endpoint: req.GetDatabaseEndpoint(), - DatabaseName: req.GetDatabaseName(), + Endpoint: req.DatabaseEndpoint, + DatabaseName: req.DatabaseName, } dsn := types.MakeYdbConnectionString(clientConnectionParams) client, err := s.clientConn.Open(ctx, dsn) @@ -183,9 +217,13 @@ func (s *BackupService) MakeRestore(ctx context.Context, req *pb.MakeRestoreRequ } xlog.Debug(ctx, "MakeRestore", zap.String("subject", subject)) + if !s.isAllowedEndpoint(req.DatabaseEndpoint) { + return nil, status.Errorf(codes.InvalidArgument, "endpoint of database is invalid or not allowed, endpoint %s", req.DatabaseEndpoint) + } + clientConnectionParams := types.YdbConnectionParams{ - Endpoint: req.GetDatabaseEndpoint(), - DatabaseName: req.GetDatabaseName(), + Endpoint: req.DatabaseEndpoint, + DatabaseName: req.DatabaseName, } dsn := types.MakeYdbConnectionString(clientConnectionParams) client, err := s.clientConn.Open(ctx, dsn) @@ -309,11 +347,15 @@ func NewBackupService( clientConn client.ClientConnector, s3 config.S3Config, auth ap.AuthProvider, + allowedEndpointDomains []string, + allowInsecureEndpoint bool, ) *BackupService { return &BackupService{ - driver: driver, - clientConn: clientConn, - s3: s3, - auth: auth, + driver: driver, + clientConn: clientConn, + s3: s3, + auth: auth, + allowedEndpointDomains: allowedEndpointDomains, + allowInsecureEndpoint: allowInsecureEndpoint, } } diff --git a/internal/server/services/backup/backupservice_test.go b/internal/server/services/backup/backupservice_test.go new file mode 100644 index 00000000..56bc5116 --- /dev/null +++ b/internal/server/services/backup/backupservice_test.go @@ -0,0 +1,60 @@ +package backup + +import ( + "testing" + "ydbcp/internal/auth" + "ydbcp/internal/config" + "ydbcp/internal/connectors/client" + "ydbcp/internal/connectors/db" + + "github.com/stretchr/testify/assert" +) + +func TestEndpointValidation(t *testing.T) { + dbConnector := db.NewMockDBConnector() + clientConnector := client.NewMockClientConnector() + auth := auth.NewMockAuthProvider() + + s := NewBackupService( + dbConnector, + clientConnector, + config.S3Config{}, + auth, + []string{".valid.com", "hostname.good.com"}, + true, + ) + + assert.True(t, s.isAllowedEndpoint("grpc://some-host.zone.valid.com")) + assert.False(t, s.isAllowedEndpoint("grpcs://host.zone.invalid.com")) + assert.True(t, s.isAllowedEndpoint("grpcs://hostname.good.com:1234")) + assert.True(t, s.isAllowedEndpoint("example.valid.com:1234")) + assert.False(t, s.isAllowedEndpoint("grpcs://something.hostname.good.com:1234")) + assert.False(t, s.isAllowedEndpoint("")) + assert.False(t, s.isAllowedEndpoint("grpcs://evilvalid.com:1234")) + assert.False(t, s.isAllowedEndpoint("badhostname.good.com")) + assert.False(t, s.isAllowedEndpoint("some^bad$symbols.valid.com")) +} + +func TestEndpointSecureValidation(t *testing.T) { + dbConnector := db.NewMockDBConnector() + clientConnector := client.NewMockClientConnector() + auth := auth.NewMockAuthProvider() + + s := NewBackupService( + dbConnector, + clientConnector, + config.S3Config{}, + auth, + []string{".valid.com", "hostname.good.com"}, + false, + ) + + assert.False(t, s.isAllowedEndpoint("grpc://some-host.zone.valid.com")) + assert.False(t, s.isAllowedEndpoint("grpcs://host.zone.invalid.com")) + assert.False(t, s.isAllowedEndpoint("host.zone.valid.com")) + assert.True(t, s.isAllowedEndpoint("grpcs://hostname.good.com:1234")) + assert.False(t, s.isAllowedEndpoint("grpcs://something.hostname.good.com:1234")) + assert.False(t, s.isAllowedEndpoint("")) + assert.False(t, s.isAllowedEndpoint("grpcs://evilvalid.com:1234")) + assert.False(t, s.isAllowedEndpoint("badhostname.good.com")) +}