diff --git a/cli/cmd/runtime/start.go b/cli/cmd/runtime/start.go index b0d25c83687..e5bee35898f 100644 --- a/cli/cmd/runtime/start.go +++ b/cli/cmd/runtime/start.go @@ -23,6 +23,7 @@ import ( "golang.org/x/sync/errgroup" // Load connectors and reconcilers for runtime + _ "github.com/rilldata/rill/runtime/drivers/athena" _ "github.com/rilldata/rill/runtime/drivers/bigquery" _ "github.com/rilldata/rill/runtime/drivers/druid" _ "github.com/rilldata/rill/runtime/drivers/duckdb" diff --git a/go.mod b/go.mod index 9b3236e1d9c..1d8be9ab430 100644 --- a/go.mod +++ b/go.mod @@ -14,6 +14,7 @@ require ( github.com/apache/arrow/go/v13 v13.0.0 github.com/apache/calcite-avatica-go/v5 v5.2.0 github.com/aws/aws-sdk-go v1.44.268 + github.com/aws/aws-sdk-go-v2/service/athena v1.31.6 github.com/benbjohnson/clock v1.3.5 github.com/bmatcuk/doublestar/v4 v4.6.0 github.com/bradleyfalzon/ghinstallation/v2 v2.4.0 @@ -108,25 +109,25 @@ require ( github.com/alicebob/gopher-json v0.0.0-20230218143504-906a9b012302 // indirect github.com/andybalholm/brotli v1.0.5 // indirect github.com/apache/thrift v0.18.1 // indirect - github.com/aws/aws-sdk-go-v2 v1.18.0 // indirect + github.com/aws/aws-sdk-go-v2 v1.21.0 github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.4.10 // indirect - github.com/aws/aws-sdk-go-v2/config v1.18.25 // indirect - github.com/aws/aws-sdk-go-v2/credentials v1.13.24 // indirect + github.com/aws/aws-sdk-go-v2/config v1.18.25 + github.com/aws/aws-sdk-go-v2/credentials v1.13.24 github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.13.3 // indirect github.com/aws/aws-sdk-go-v2/feature/s3/manager v1.11.67 // indirect - github.com/aws/aws-sdk-go-v2/internal/configsources v1.1.33 // indirect - github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.4.27 // indirect + github.com/aws/aws-sdk-go-v2/internal/configsources v1.1.41 // indirect + github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.4.35 // indirect github.com/aws/aws-sdk-go-v2/internal/ini v1.3.34 // indirect github.com/aws/aws-sdk-go-v2/internal/v4a v1.0.25 // indirect github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.9.11 // indirect github.com/aws/aws-sdk-go-v2/service/internal/checksum v1.1.28 // indirect github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.9.27 // indirect github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.14.2 // indirect - github.com/aws/aws-sdk-go-v2/service/s3 v1.33.1 // indirect + github.com/aws/aws-sdk-go-v2/service/s3 v1.33.1 github.com/aws/aws-sdk-go-v2/service/sso v1.12.10 // indirect github.com/aws/aws-sdk-go-v2/service/ssooidc v1.14.10 // indirect github.com/aws/aws-sdk-go-v2/service/sts v1.19.0 // indirect - github.com/aws/smithy-go v1.13.5 // indirect + github.com/aws/smithy-go v1.14.2 // indirect github.com/beorn7/perks v1.0.1 // indirect github.com/cenkalti/backoff/v4 v4.2.1 // indirect github.com/cespare/xxhash/v2 v2.2.0 // indirect diff --git a/go.sum b/go.sum index 15929a79067..ee0eb330dd3 100644 --- a/go.sum +++ b/go.sum @@ -592,8 +592,9 @@ github.com/aws/aws-sdk-go v1.44.268 h1:WoK20tlAvsvQzTcE6TajoprbXmTbcud6MjhErL4P/ github.com/aws/aws-sdk-go v1.44.268/go.mod h1:aVsgQcEevwlmQ7qHE9I3h+dtQgpqhFB+i8Phjh7fkwI= github.com/aws/aws-sdk-go-v2 v1.9.1/go.mod h1:cK/D0BBs0b/oWPIcX/Z/obahJK1TT7IPVjy53i/mX/4= github.com/aws/aws-sdk-go-v2 v1.17.4/go.mod h1:uzbQtefpm44goOPmdKyAlXSNcwlRgF3ePWVW6EtJvvw= -github.com/aws/aws-sdk-go-v2 v1.18.0 h1:882kkTpSFhdgYRKVZ/VCgf7sd0ru57p2JCxz4/oN5RY= github.com/aws/aws-sdk-go-v2 v1.18.0/go.mod h1:uzbQtefpm44goOPmdKyAlXSNcwlRgF3ePWVW6EtJvvw= +github.com/aws/aws-sdk-go-v2 v1.21.0 h1:gMT0IW+03wtYJhRqTVYn0wLzwdnK9sRMcxmtfGzRdJc= +github.com/aws/aws-sdk-go-v2 v1.21.0/go.mod h1:/RfNgGmRxI+iFOB1OeJUyxiU+9s88k3pfHvDagGEp0M= github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.4.10 h1:dK82zF6kkPeCo8J1e+tGx4JdvDIQzj7ygIoLg8WMuGs= github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.4.10/go.mod h1:VeTZetY5KRJLuD/7fkQXMU6Mw7H5m/KP2J5Iy9osMno= github.com/aws/aws-sdk-go-v2/config v1.18.12/go.mod h1:J36fOhj1LQBr+O4hJCiT8FwVvieeoSGOtPuvhKlsNu8= @@ -609,17 +610,21 @@ github.com/aws/aws-sdk-go-v2/feature/s3/manager v1.11.51/go.mod h1:7Grl2gV+dx9SW github.com/aws/aws-sdk-go-v2/feature/s3/manager v1.11.67 h1:fI9/5BDEaAv/pv1VO1X1n3jfP9it+IGqWsCuuBQI8wM= github.com/aws/aws-sdk-go-v2/feature/s3/manager v1.11.67/go.mod h1:zQClPRIwQZfJlZq6WZve+s4Tb4JW+3V6eS+4+KrYeP8= github.com/aws/aws-sdk-go-v2/internal/configsources v1.1.28/go.mod h1:3lwChorpIM/BhImY/hy+Z6jekmN92cXGPI1QJasVPYY= -github.com/aws/aws-sdk-go-v2/internal/configsources v1.1.33 h1:kG5eQilShqmJbv11XL1VpyDbaEJzWxd4zRiCG30GSn4= github.com/aws/aws-sdk-go-v2/internal/configsources v1.1.33/go.mod h1:7i0PF1ME/2eUPFcjkVIwq+DOygHEoK92t5cDqNgYbIw= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.1.41 h1:22dGT7PneFMx4+b3pz7lMTRyN8ZKH7M2cW4GP9yUS2g= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.1.41/go.mod h1:CrObHAuPneJBlfEJ5T3szXOUkLEThaGfvnhTf33buas= github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.4.22/go.mod h1:EqK7gVrIGAHyZItrD1D8B0ilgwMD1GiWAmbU4u/JHNk= -github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.4.27 h1:vFQlirhuM8lLlpI7imKOMsjdQLuN9CPi+k44F/OFVsk= github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.4.27/go.mod h1:UrHnn3QV/d0pBZ6QBAEQcqFLf8FAzLmoUfPVIueOvoM= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.4.35 h1:SijA0mgjV8E+8G45ltVHs0fvKpTj8xmZJ3VwhGKtUSI= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.4.35/go.mod h1:SJC1nEVVva1g3pHAIdCp7QsRIkMmLAgoDquQ9Rr8kYw= github.com/aws/aws-sdk-go-v2/internal/ini v1.3.29/go.mod h1:TwuqRBGzxjQJIwH16/fOZodwXt2Zxa9/cwJC5ke4j7s= github.com/aws/aws-sdk-go-v2/internal/ini v1.3.34 h1:gGLG7yKaXG02/jBlg210R7VgQIotiQntNhsCFejawx8= github.com/aws/aws-sdk-go-v2/internal/ini v1.3.34/go.mod h1:Etz2dj6UHYuw+Xw830KfzCfWGMzqvUTCjUj5b76GVDc= github.com/aws/aws-sdk-go-v2/internal/v4a v1.0.19/go.mod h1:8W88sW3PjamQpKFUQvHWWKay6ARsNvZnzU7+a4apubw= github.com/aws/aws-sdk-go-v2/internal/v4a v1.0.25 h1:AzwRi5OKKwo4QNqPf7TjeO+tK8AyOK3GVSwmRPo7/Cs= github.com/aws/aws-sdk-go-v2/internal/v4a v1.0.25/go.mod h1:SUbB4wcbSEyCvqBxv/O/IBf93RbEze7U7OnoTlpPB+g= +github.com/aws/aws-sdk-go-v2/service/athena v1.31.6 h1:EFaTu1rBt+KQglDeYRpP1PHot/6xlYzvouxm2aRmrG8= +github.com/aws/aws-sdk-go-v2/service/athena v1.31.6/go.mod h1:DHafyhR8x70ANJZ2RkJx8oeJsfEBqaGwZ591vlihVFQ= github.com/aws/aws-sdk-go-v2/service/cloudwatch v1.8.1/go.mod h1:CM+19rL1+4dFWnOQKwDc7H1KwXTz+h61oUSHyhV0b3o= github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.9.11 h1:y2+VQzC6Zh2ojtV2LoC0MNwHWc6qXv/j2vrQtlftkdA= github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.9.11/go.mod h1:iV4q2hsqtNECrfmlXyord9u4zyuFEJX9eLgLpSPzWA8= @@ -650,8 +655,9 @@ github.com/aws/aws-sdk-go-v2/service/sts v1.18.3/go.mod h1:b+psTJn33Q4qGoDaM7ZiO github.com/aws/aws-sdk-go-v2/service/sts v1.19.0 h1:2DQLAKDteoEDI8zpCzqBMaZlJuoE9iTYD0gFmXVax9E= github.com/aws/aws-sdk-go-v2/service/sts v1.19.0/go.mod h1:BgQOMsg8av8jset59jelyPW7NoZcZXLVpDsXunGDrk8= github.com/aws/smithy-go v1.8.0/go.mod h1:SObp3lf9smib00L/v3U2eAKG8FyQ7iLrJnQiAmR5n+E= -github.com/aws/smithy-go v1.13.5 h1:hgz0X/DX0dGqTYpGALqXJoRKRj5oQ7150i5FdTePzO8= github.com/aws/smithy-go v1.13.5/go.mod h1:Tg+OJXh4MB2R/uN61Ko2f6hTZwB/ZYGOtib8J3gBHzA= +github.com/aws/smithy-go v1.14.2 h1:MJU9hqBGbvWZdApzpvoF2WAIJDbtjK2NDJSiJP7HblQ= +github.com/aws/smithy-go v1.14.2/go.mod h1:Tg+OJXh4MB2R/uN61Ko2f6hTZwB/ZYGOtib8J3gBHzA= github.com/benbjohnson/clock v1.0.3/go.mod h1:bGMdMPoPVvcYyt1gHDf4J2KE153Yf9BuiUKYMaxlTDM= github.com/benbjohnson/clock v1.1.0/go.mod h1:J11/hYXuz8f4ySSvYwY0FKfm+ezbsZBKZxNJlLklBHA= github.com/benbjohnson/clock v1.3.0/go.mod h1:J11/hYXuz8f4ySSvYwY0FKfm+ezbsZBKZxNJlLklBHA= diff --git a/runtime/connections.go b/runtime/connections.go index 4ca6ff5344a..9710a9fc48b 100644 --- a/runtime/connections.go +++ b/runtime/connections.go @@ -227,7 +227,7 @@ func (r *Runtime) connectorConfig(ctx context.Context, instanceID, name string) // For backwards compatibility, certain root-level variables apply to certain implicit connectors. // NOTE: This switches on connector.Name, not connector.Type, because this only applies to implicit connectors. switch connector.Name { - case "s3": + case "s3", "athena": setIfNil(cfg, "aws_access_key_id", vars["aws_access_key_id"]) setIfNil(cfg, "aws_secret_access_key", vars["aws_secret_access_key"]) setIfNil(cfg, "aws_session_token", vars["aws_session_token"]) diff --git a/runtime/drivers/athena/athena.go b/runtime/drivers/athena/athena.go new file mode 100644 index 00000000000..76a7b470c59 --- /dev/null +++ b/runtime/drivers/athena/athena.go @@ -0,0 +1,177 @@ +package athena + +import ( + "context" + "fmt" + + "github.com/mitchellh/mapstructure" + "github.com/rilldata/rill/runtime/drivers" + "github.com/rilldata/rill/runtime/pkg/activity" + "go.uber.org/zap" +) + +func init() { + drivers.Register("athena", driver{}) + drivers.RegisterAsConnector("athena", driver{}) +} + +var spec = drivers.Spec{ + DisplayName: "Amazon Athena", + Description: "Connect to Amazon Athena database.", + ServiceAccountDocs: "", + SourceProperties: []drivers.PropertySchema{ + { + Key: "sql", + Type: drivers.StringPropertyType, + Required: true, + DisplayName: "SQL", + Description: "Query to extract data from Athena.", + Placeholder: "select * from catalog.table;", + }, + { + Key: "output_location", + DisplayName: "S3 output location", + Description: "Output location for query results in S3.", + Placeholder: "s3://bucket-name/path/", + Type: drivers.StringPropertyType, + Required: false, + }, + { + Key: "workgroup", + DisplayName: "AWS Athena workgroup", + Description: "AWS Athena workgroup to use for queries.", + Placeholder: "primary", + Type: drivers.StringPropertyType, + Required: false, + }, + { + Key: "region", + DisplayName: "AWS region", + Description: "AWS region to connect to Athena and the output location.", + Placeholder: "us-east-1", + Type: drivers.StringPropertyType, + Required: false, + }, + }, + ConfigProperties: []drivers.PropertySchema{ + { + Key: "aws_access_key_id", + Secret: true, + }, + { + Key: "aws_secret_access_key", + Secret: true, + }, + }, +} + +type driver struct{} + +func (d driver) Open(config map[string]any, shared bool, _ activity.Client, logger *zap.Logger) (drivers.Handle, error) { + if shared { + return nil, fmt.Errorf("athena driver can't be shared") + } + conf := &configProperties{} + err := mapstructure.Decode(config, conf) + if err != nil { + return nil, err + } + + conn := &Connection{ + config: conf, + logger: logger, + } + return conn, nil +} + +func (d driver) Drop(config map[string]any, logger *zap.Logger) error { + return drivers.ErrDropNotSupported +} + +func (d driver) Spec() drivers.Spec { + return spec +} + +func (d driver) HasAnonymousSourceAccess(ctx context.Context, src map[string]any, logger *zap.Logger) (bool, error) { + return false, nil +} + +type Connection struct { + config *configProperties + logger *zap.Logger +} + +var _ drivers.Handle = &Connection{} + +// Driver implements drivers.Connection. +func (c *Connection) Driver() string { + return "athena" +} + +// Config implements drivers.Connection. +func (c *Connection) Config() map[string]any { + m := make(map[string]any, 0) + _ = mapstructure.Decode(c.config, m) + return m +} + +// Close implements drivers.Connection. +func (c *Connection) Close() error { + return nil +} + +// Registry implements drivers.Connection. +func (c *Connection) AsRegistry() (drivers.RegistryStore, bool) { + return nil, false +} + +// Catalog implements drivers.Connection. +func (c *Connection) AsCatalogStore(instanceID string) (drivers.CatalogStore, bool) { + return nil, false +} + +// Repo implements drivers.Connection. +func (c *Connection) AsRepoStore(instanceID string) (drivers.RepoStore, bool) { + return nil, false +} + +// OLAP implements drivers.Connection. +func (c *Connection) AsOLAP(instanceID string) (drivers.OLAPStore, bool) { + return nil, false +} + +// Migrate implements drivers.Connection. +func (c *Connection) Migrate(ctx context.Context) (err error) { + return nil +} + +// MigrationStatus implements drivers.Connection. +func (c *Connection) MigrationStatus(ctx context.Context) (current, desired int, err error) { + return 0, 0, nil +} + +// AsObjectStore implements drivers.Connection. +func (c *Connection) AsObjectStore() (drivers.ObjectStore, bool) { + return nil, false +} + +// AsTransporter implements drivers.Connection. +func (c *Connection) AsTransporter(from, to drivers.Handle) (drivers.Transporter, bool) { + return nil, false +} + +func (c *Connection) AsFileStore() (drivers.FileStore, bool) { + return nil, false +} + +// AsSQLStore implements drivers.Connection. +func (c *Connection) AsSQLStore() (drivers.SQLStore, bool) { + return c, true +} + +type configProperties struct { + AccessKeyID string `mapstructure:"aws_access_key_id"` + SecretAccessKey string `mapstructure:"aws_secret_access_key"` + SessionToken string `mapstructure:"aws_access_token"` + AllowHostAccess bool `mapstructure:"allow_host_access"` +} diff --git a/runtime/drivers/athena/sql_store.go b/runtime/drivers/athena/sql_store.go new file mode 100644 index 00000000000..32b533961dc --- /dev/null +++ b/runtime/drivers/athena/sql_store.go @@ -0,0 +1,282 @@ +package athena + +import ( + "context" + "errors" + "fmt" + "net/url" + "strings" + "time" + + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/config" + "github.com/aws/aws-sdk-go-v2/credentials" + "github.com/aws/aws-sdk-go-v2/service/athena" + types2 "github.com/aws/aws-sdk-go-v2/service/athena/types" + "github.com/aws/aws-sdk-go-v2/service/s3" + "github.com/aws/aws-sdk-go-v2/service/s3/types" + "github.com/google/uuid" + "github.com/mitchellh/mapstructure" + "github.com/rilldata/rill/runtime/drivers" + rillblob "github.com/rilldata/rill/runtime/drivers/blob" + "gocloud.dev/blob" + "gocloud.dev/blob/s3blob" +) + +func (c *Connection) Query(_ context.Context, _ map[string]any) (drivers.RowIterator, error) { + return nil, drivers.ErrNotImplemented +} + +func (c *Connection) QueryAsFiles(ctx context.Context, props map[string]any, _ *drivers.QueryOption, _ drivers.Progress) (outIt drivers.FileIterator, outErr error) { + conf, err := parseSourceProperties(props) + if err != nil { + return nil, fmt.Errorf("failed to parse config: %w", err) + } + + awsConfig, err := c.awsConfig(ctx, conf.AWSRegion) + if err != nil { + return nil, err + } + + client := athena.NewFromConfig(awsConfig) + outputLocation, err := resolveOutputLocation(ctx, client, conf) + if err != nil { + return nil, err + } + + outputURL, err := url.Parse(outputLocation) + if err != nil { + return nil, err + } + + // outputLocation s3://bucket/path + // unloadLocation s3://bucket/path/rill-tmp- + // unloadPath path/rill-tmp- + unloadFolderName := "rill-tmp-" + uuid.New().String() + bucketName := outputURL.Hostname() + unloadURL := outputURL.JoinPath(unloadFolderName) + unloadLocation := unloadURL.String() + unloadPath := strings.TrimPrefix(unloadURL.Path, "/") + + cleanupFn := func() error { + return deleteObjectsInPrefix(ctx, awsConfig, bucketName, unloadPath) + } + + err = c.unload(ctx, client, conf, unloadLocation) + if err != nil { + unloadErr := fmt.Errorf("failed to unload: %w", err) + cleanupErr := cleanupFn() + if cleanupErr != nil { + cleanupErr = fmt.Errorf("cleanup error: %w", cleanupErr) + } + return nil, errors.Join(unloadErr, cleanupErr) + } + + defer func() { + if outErr != nil { + cleanupErr := cleanupFn() + if cleanupErr != nil { + outErr = errors.Join(outErr, fmt.Errorf("cleanup error: %w", cleanupErr)) + } + } + }() + + bucketObj, err := openBucket(ctx, awsConfig, bucketName) + if err != nil { + return nil, fmt.Errorf("cannot open bucket %q: %w", bucketName, err) + } + + opts := rillblob.Options{ + GlobPattern: unloadPath + "/**", + Format: "parquet", + } + + it, err := rillblob.NewIterator(ctx, bucketObj, opts, c.logger) + if err != nil { + return nil, fmt.Errorf("cannot download parquet output %q %w", opts.GlobPattern, err) + } + + return autoDeleteFileIterator{ + FileIterator: it, + cleanupFn: cleanupFn, + }, nil +} + +func (c *Connection) awsConfig(ctx context.Context, awsRegion string) (aws.Config, error) { + loadOptions := []func(*config.LoadOptions) error{ + // Setting the default region to an empty string, will result in the default region value being ignored + config.WithDefaultRegion("us-east-1"), + // Setting the region to an empty string, will result in the region value being ignored + config.WithRegion(awsRegion), + } + + // If one of the static properties is specified: access key, secret key, or session token, use static credentials, + // Else fallback to the SDK's default credential chain (environment, instance, etc) unless AllowHostAccess is false + if c.config.AccessKeyID != "" || c.config.SecretAccessKey != "" || c.config.SessionToken != "" { + p := credentials.NewStaticCredentialsProvider(c.config.AccessKeyID, c.config.SecretAccessKey, c.config.SessionToken) + loadOptions = append(loadOptions, config.WithCredentialsProvider(p)) + } else if !c.config.AllowHostAccess { + return aws.Config{}, fmt.Errorf("static creds are not provided, and host access is not allowed") + } + + return config.LoadDefaultConfig(ctx, loadOptions...) +} + +func (c *Connection) unload(ctx context.Context, client *athena.Client, conf *sourceProperties, unloadLocation string) error { + finalSQL := fmt.Sprintf("UNLOAD (%s\n) TO '%s' WITH (format = 'PARQUET')", conf.SQL, unloadLocation) + + executeParams := &athena.StartQueryExecutionInput{ + QueryString: aws.String(finalSQL), + } + + if conf.OutputLocation != "" { + executeParams.ResultConfiguration = &types2.ResultConfiguration{ + OutputLocation: aws.String(conf.OutputLocation), + } + } + + if conf.Workgroup != "" { // primary is used if nothing is set + executeParams.WorkGroup = aws.String(conf.Workgroup) + } + + queryExecutionOutput, err := client.StartQueryExecution(ctx, executeParams) + if err != nil { + return err + } + + for { + select { + case <-ctx.Done(): + _, err = client.StopQueryExecution(ctx, &athena.StopQueryExecutionInput{ + QueryExecutionId: queryExecutionOutput.QueryExecutionId, + }) + return errors.Join(ctx.Err(), err) + default: + status, err := client.GetQueryExecution(ctx, &athena.GetQueryExecutionInput{ + QueryExecutionId: queryExecutionOutput.QueryExecutionId, + }) + if err != nil { + return err + } + + switch status.QueryExecution.Status.State { + case types2.QueryExecutionStateSucceeded: + return nil + case types2.QueryExecutionStateCancelled: + return fmt.Errorf("Athena query execution cancelled") + case types2.QueryExecutionStateFailed: + return fmt.Errorf("Athena query execution failed %s", *status.QueryExecution.Status.AthenaError.ErrorMessage) + } + } + time.Sleep(time.Second) + } +} + +func parseSourceProperties(props map[string]any) (*sourceProperties, error) { + conf := &sourceProperties{} + err := mapstructure.Decode(props, conf) + if err != nil { + return nil, err + } + + return conf, nil +} + +func resolveOutputLocation(ctx context.Context, client *athena.Client, conf *sourceProperties) (string, error) { + if conf.OutputLocation != "" { + return conf.OutputLocation, nil + } + + workgroup := conf.Workgroup + // fallback to "primary" (default) workgroup if no workgroup is specified + if workgroup == "" { + workgroup = "primary" + } + + wo, err := client.GetWorkGroup(ctx, &athena.GetWorkGroupInput{ + WorkGroup: aws.String(workgroup), + }) + if err != nil { + return "", err + } + + resultConfiguration := wo.WorkGroup.Configuration.ResultConfiguration + if resultConfiguration != nil && resultConfiguration.OutputLocation != nil && *resultConfiguration.OutputLocation != "" { + return *resultConfiguration.OutputLocation, nil + } + + return "", fmt.Errorf("either output_location or workgroup with an output location must be set") +} + +func openBucket(ctx context.Context, cfg aws.Config, bucket string) (*blob.Bucket, error) { + s3client := s3.NewFromConfig(cfg) + return s3blob.OpenBucketV2(ctx, s3client, bucket, nil) +} + +func deleteObjectsInPrefix(ctx context.Context, cfg aws.Config, bucketName, prefix string) error { + s3client := s3.NewFromConfig(cfg) + + deleteBatch := func(objects []types.ObjectIdentifier) error { + _, err := s3client.DeleteObjects(ctx, &s3.DeleteObjectsInput{ + Bucket: &bucketName, + Delete: &types.Delete{ + Objects: objects, + }, + }) + return err + } + + var continuationToken *string + for { + out, err := s3client.ListObjectsV2(ctx, &s3.ListObjectsV2Input{ + Bucket: &bucketName, + Prefix: &prefix, + ContinuationToken: continuationToken, + }) + if err != nil { + return err + } + + ids := make([]types.ObjectIdentifier, 0, len(out.Contents)) + for _, o := range out.Contents { + ids = append(ids, types.ObjectIdentifier{ + Key: o.Key, + }) + } + + if len(ids) > 0 { + if err := deleteBatch(ids); err != nil { + return err + } + } + + if out.IsTruncated && out.NextContinuationToken != nil { + continuationToken = out.NextContinuationToken + } else { + break + } + } + + return nil +} + +type sourceProperties struct { + SQL string `mapstructure:"sql"` + OutputLocation string `mapstructure:"output_location"` + Workgroup string `mapstructure:"workgroup"` + AWSRegion string `mapstructure:"region"` +} + +type autoDeleteFileIterator struct { + drivers.FileIterator + cleanupFn func() error +} + +func (i autoDeleteFileIterator) Close() error { + err := i.FileIterator.Close() + if err != nil { + return err + } + + return i.cleanupFn() +} diff --git a/runtime/drivers/bigquery/sql_store.go b/runtime/drivers/bigquery/sql_store.go index ffa1860e68d..789179a2538 100644 --- a/runtime/drivers/bigquery/sql_store.go +++ b/runtime/drivers/bigquery/sql_store.go @@ -219,6 +219,10 @@ func (f *fileIterator) Size(unit drivers.ProgressUnit) (int64, bool) { } } +func (f *fileIterator) Format() string { + return "" +} + func (f *fileIterator) downloadAsJSONFile() error { tf := time.Now() defer func() { diff --git a/runtime/drivers/blob/blobdownloader.go b/runtime/drivers/blob/blobdownloader.go index cceab7fbf1a..5674ba7b913 100644 --- a/runtime/drivers/blob/blobdownloader.go +++ b/runtime/drivers/blob/blobdownloader.go @@ -50,6 +50,8 @@ type Options struct { KeepFilesUntilClose bool // BatchSizeBytes is the combined size of all files returned in one call to next() BatchSizeBytes int64 + // General blob format (json, csv, parquet, etc) + Format string } // sets defaults if not set by user @@ -240,6 +242,10 @@ func (it *blobIterator) Next() ([]string, error) { return result, nil } +func (it *blobIterator) Format() string { + return it.opts.Format +} + // TODO: Ideally planner should take ownership of the bucket and return an iterator with next returning objectWithPlan func (it *blobIterator) plan() ([]*objectWithPlan, error) { var ( @@ -463,6 +469,10 @@ func (it *prefetchedIterator) Next() ([]string, error) { return it.batch, nil } +func (it *prefetchedIterator) Format() string { + return it.underlying.Format() +} + // downloadResult represents a successfully downloaded file type downloadResult struct { path string diff --git a/runtime/drivers/duckdb/transporter/sqlstore_to_duckDB.go b/runtime/drivers/duckdb/transporter/sqlstore_to_duckDB.go index b4818037f6e..fb78e5cd480 100644 --- a/runtime/drivers/duckdb/transporter/sqlstore_to_duckDB.go +++ b/runtime/drivers/duckdb/transporter/sqlstore_to_duckDB.go @@ -87,6 +87,10 @@ func (s *sqlStoreToDuckDB) Transfer(ctx context.Context, srcProps, sinkProps map } format := fileutil.FullExt(files[0]) + if iter.Format() != "" { + format += "." + iter.Format() + } + from, err := sourceReader(files, format, make(map[string]any)) if err != nil { return err diff --git a/runtime/drivers/duckdb/transporter/transporter_test.go b/runtime/drivers/duckdb/transporter/transporter_test.go index 49ead478f3d..240208c2572 100644 --- a/runtime/drivers/duckdb/transporter/transporter_test.go +++ b/runtime/drivers/duckdb/transporter/transporter_test.go @@ -48,6 +48,10 @@ func (m *mockIterator) Size(unit drivers.ProgressUnit) (int64, bool) { func (m *mockIterator) KeepFilesUntilClose(keepFilesUntilClose bool) { } +func (m *mockIterator) Format() string { + return "" +} + var _ drivers.FileIterator = &mockIterator{} func TestIterativeCSVIngestionWithVariableSchema(t *testing.T) { diff --git a/runtime/drivers/object_store.go b/runtime/drivers/object_store.go index 52f39aca597..2621155c0b3 100644 --- a/runtime/drivers/object_store.go +++ b/runtime/drivers/object_store.go @@ -21,4 +21,7 @@ type FileIterator interface { // KeepFilesUntilClose marks the iterator to keep the files until close is called. // This is used when the entire list of files is used at once in certain cases. KeepFilesUntilClose(keepFilesUntilClose bool) + // Format returns general file format (json, csv, parquet, etc) + // Returns an empty string if there is no general format + Format() string } diff --git a/runtime/services/catalog/artifacts/yaml/objects.go b/runtime/services/catalog/artifacts/yaml/objects.go index 05b58ecec6e..543d1e200e8 100644 --- a/runtime/services/catalog/artifacts/yaml/objects.go +++ b/runtime/services/catalog/artifacts/yaml/objects.go @@ -46,6 +46,8 @@ type Source struct { DB string `yaml:"db,omitempty" mapstructure:"db,omitempty"` ProjectID string `yaml:"project_id,omitempty" mapstructure:"project_id,omitempty"` PostgresDatabaseURL string `yaml:"database_url,omitempty" mapstructure:"database_url,omitempty"` + AthenaOutputLocation string `yaml:"output_location,omitempty" mapstructure:"output_location,omitempty"` + AthenaWorkgroup string `yaml:"workgroup,omitempty" mapstructure:"workgroup,omitempty"` } type MetricsView struct { @@ -212,6 +214,14 @@ func fromSourceArtifact(source *Source, path string) (*drivers.CatalogEntry, err props["database_url"] = source.PostgresDatabaseURL } + if source.AthenaOutputLocation != "" { + props["output_location"] = source.AthenaOutputLocation + } + + if source.AthenaWorkgroup != "" { + props["workgroup"] = source.AthenaWorkgroup + } + propsPB, err := structpb.NewStruct(props) if err != nil { return nil, err diff --git a/runtime/services/catalog/migrator/sources/sources.go b/runtime/services/catalog/migrator/sources/sources.go index 57d04b497a1..81c17ab8ea4 100644 --- a/runtime/services/catalog/migrator/sources/sources.go +++ b/runtime/services/catalog/migrator/sources/sources.go @@ -376,7 +376,7 @@ func connectorVariables(src *runtimev1.Source, env map[string]string, repoRoot s "allow_host_access": strings.EqualFold(env["allow_host_access"], "true"), } switch connector { - case "s3": + case "s3", "athena": vars["aws_access_key_id"] = env["aws_access_key_id"] vars["aws_secret_access_key"] = env["aws_secret_access_key"] vars["aws_session_token"] = env["aws_session_token"] diff --git a/web-common/src/features/sources/modal/AddSourceModal.svelte b/web-common/src/features/sources/modal/AddSourceModal.svelte index f12bfcacd54..69dd6984d0b 100644 --- a/web-common/src/features/sources/modal/AddSourceModal.svelte +++ b/web-common/src/features/sources/modal/AddSourceModal.svelte @@ -23,6 +23,7 @@ import LocalSourceUpload from "./LocalSourceUpload.svelte"; import RemoteSourceForm from "./RemoteSourceForm.svelte"; import RequestConnectorForm from "./RequestConnectorForm.svelte"; + import AmazonAthena from "@rilldata/web-common/components/icons/connectors/AmazonAthena.svelte"; export let open: boolean; @@ -36,7 +37,7 @@ // azure_blob_storage // duckdb "bigquery", - // athena + "athena", "motherduck", "postgres", "local_file", @@ -49,7 +50,7 @@ // azure_blob_storage: MicrosoftAzureBlobStorage, // duckdb: DuckDB, bigquery: GoogleBigQuery, - // athena: AmazonAthena, + athena: AmazonAthena, motherduck: MotherDuck, postgres: Postgres, local_file: LocalFile, diff --git a/web-common/src/features/sources/modal/submitRemoteSourceForm.ts b/web-common/src/features/sources/modal/submitRemoteSourceForm.ts index 7cff3de34ed..7800deea80c 100644 --- a/web-common/src/features/sources/modal/submitRemoteSourceForm.ts +++ b/web-common/src/features/sources/modal/submitRemoteSourceForm.ts @@ -64,6 +64,8 @@ export async function submitRemoteSourceForm( Object.entries(values).map(([key, value]) => { switch (key) { case "project_id": + case "output_location": + case "workgroup": return [key, value]; case "database_url": return [key, value]; diff --git a/web-common/src/features/sources/modal/yupSchemas.ts b/web-common/src/features/sources/modal/yupSchemas.ts index 0fba22f26d7..0527719533f 100644 --- a/web-common/src/features/sources/modal/yupSchemas.ts +++ b/web-common/src/features/sources/modal/yupSchemas.ts @@ -81,6 +81,19 @@ export function getYupSchema(connector: V1ConnectorSpec) { .required("Source name is required"), database_url: yup.string(), }); + case "athena": + return yup.object().shape({ + sql: yup.string().required("sql is required"), + sourceName: yup + .string() + .matches( + /^[a-zA-Z_][a-zA-Z0-9_]*$/, + "Source name must start with a letter or underscore and contain only letters, numbers, and underscores" + ) + .required("Source name is required"), + output_location: yup.string(), + workgroup: yup.string(), + }); default: throw new Error(`Unknown connector: ${connector.name}`); }