Pass context.Context through for MERGE and APPEND #956

Merged: 5 commits, Oct 8, 2024
Changes from 3 commits
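In short, this PR threads a caller-supplied context.Context through the MERGE and APPEND write paths: the shared helpers (shared.Merge, shared.Append) and each store's Merge, Append, and PrepareTemporaryTable now take ctx as their first argument, and the context.Background() calls previously used for staging uploads are replaced with the propagated ctx. The MSSQL and Snowflake PrepareTemporaryTable implementations accept the new parameter but discard it for now. As orientation before the per-file diffs, the post-PR signatures look roughly like the excerpt below. The destination.DataWarehouse interface itself is not part of this diff, so treat the excerpt as an approximation assembled from the store implementations; parameter names marked as guesses are mine.

    // Sketch only: an excerpt reconstructed from the per-store changes below,
    // not the actual interface definition in lib/destination.
    package destination

    import (
        "context"

        "github.com/artie-labs/transfer/lib/destination/types"
        "github.com/artie-labs/transfer/lib/optimization"
        "github.com/artie-labs/transfer/lib/sql"
    )

    type DataWarehouse interface {
        Merge(ctx context.Context, tableData *optimization.TableData) error
        Append(ctx context.Context, tableData *optimization.TableData, useTempTable bool) error
        PrepareTemporaryTable(
            ctx context.Context,
            tableData *optimization.TableData,
            tableConfig *types.DwhTableConfig,
            tempTableID sql.TableIdentifier,
            parentTableID sql.TableIdentifier, // name is a guess; several stores ignore this argument
            settings types.AdditionalSettings,
            createTempTable bool,
        ) error
        // Remaining methods (IdentifierFor, GetTableConfig, ...) are unchanged.
    }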
10 changes: 5 additions & 5 deletions clients/bigquery/bigquery.go
@@ -42,9 +42,9 @@ type Store struct {
db.Store
}

func (s *Store) Append(tableData *optimization.TableData, useTempTable bool) error {
func (s *Store) Append(ctx context.Context, tableData *optimization.TableData, useTempTable bool) error {
if !useTempTable {
return shared.Append(s, tableData, types.AdditionalSettings{})
return shared.Append(ctx, s, tableData, types.AdditionalSettings{})
}

// We can simplify this once Google has fully rolled out the ability to execute DML on recently streamed data
@@ -55,7 +55,7 @@ func (s *Store) Append(tableData *optimization.TableData, useTempTable bool) err

defer func() { _ = ddl.DropTemporaryTable(s, temporaryTableID, false) }()

err := shared.Append(s, tableData, types.AdditionalSettings{
err := shared.Append(ctx, s, tableData, types.AdditionalSettings{
UseTempTable: true,
TempTableID: temporaryTableID,
})
@@ -78,7 +78,7 @@ func (s *Store) Append(tableData *optimization.TableData, useTempTable bool) err
return nil
}

func (s *Store) PrepareTemporaryTable(tableData *optimization.TableData, tableConfig *types.DwhTableConfig, tempTableID sql.TableIdentifier, _ sql.TableIdentifier, _ types.AdditionalSettings, createTempTable bool) error {
func (s *Store) PrepareTemporaryTable(ctx context.Context, tableData *optimization.TableData, tableConfig *types.DwhTableConfig, tempTableID sql.TableIdentifier, _ sql.TableIdentifier, _ types.AdditionalSettings, createTempTable bool) error {
if createTempTable {
tempAlterTableArgs := ddl.AlterTableArgs{
Dialect: s.Dialect(),
Expand All @@ -100,7 +100,7 @@ func (s *Store) PrepareTemporaryTable(tableData *optimization.TableData, tableCo
return err
}

return s.putTable(context.Background(), bqTempTableID, tableData)
return s.putTable(ctx, bqTempTableID, tableData)
}

func (s *Store) IdentifierFor(topicConfig kafkalib.TopicConfig, table string) sql.TableIdentifier {
5 changes: 3 additions & 2 deletions clients/bigquery/merge.go
@@ -1,6 +1,7 @@
package bigquery

import (
"context"
"fmt"
"strings"

@@ -14,7 +15,7 @@ import (
"github.com/artie-labs/transfer/lib/typing/columns"
)

func (s *Store) Merge(tableData *optimization.TableData) error {
func (s *Store) Merge(ctx context.Context, tableData *optimization.TableData) error {
var additionalEqualityStrings []string
if tableData.TopicConfig().BigQueryPartitionSettings != nil {
distinctDates, err := tableData.DistinctDates(tableData.TopicConfig().BigQueryPartitionSettings.PartitionField)
@@ -30,7 +31,7 @@ func (s *Store) Merge(tableData *optimization.TableData) error {
additionalEqualityStrings = []string{mergeString}
}

return shared.Merge(s, tableData, types.MergeOpts{
return shared.Merge(ctx, s, tableData, types.MergeOpts{
AdditionalEqualityStrings: additionalEqualityStrings,
// BigQuery has DDL quotas.
RetryColBackfill: true,
13 changes: 6 additions & 7 deletions clients/databricks/store.go
@@ -36,12 +36,12 @@ func describeTableQuery(tableID TableIdentifier) (string, []any) {
return fmt.Sprintf("DESCRIBE TABLE %s", tableID.FullyQualifiedName()), nil
}

func (s Store) Merge(tableData *optimization.TableData) error {
return shared.Merge(s, tableData, types.MergeOpts{})
func (s Store) Merge(ctx context.Context, tableData *optimization.TableData) error {
return shared.Merge(ctx, s, tableData, types.MergeOpts{})
}

func (s Store) Append(tableData *optimization.TableData, useTempTable bool) error {
return shared.Append(s, tableData, types.AdditionalSettings{UseTempTable: useTempTable})
func (s Store) Append(ctx context.Context, tableData *optimization.TableData, useTempTable bool) error {
return shared.Append(ctx, s, tableData, types.AdditionalSettings{UseTempTable: useTempTable})
}

func (s Store) IdentifierFor(topicConfig kafkalib.TopicConfig, table string) sql.TableIdentifier {
@@ -89,8 +89,7 @@ func (s Store) GetTableConfig(tableData *optimization.TableData) (*types.DwhTabl
}.GetTableConfig()
}

func (s Store) PrepareTemporaryTable(tableData *optimization.TableData, tableConfig *types.DwhTableConfig, tempTableID sql.TableIdentifier, _ sql.TableIdentifier, _ types.AdditionalSettings, createTempTable bool) error {
// TODO: Update PrepareTemporaryTable interface to include context
func (s Store) PrepareTemporaryTable(ctx context.Context, tableData *optimization.TableData, tableConfig *types.DwhTableConfig, tempTableID sql.TableIdentifier, _ sql.TableIdentifier, _ types.AdditionalSettings, createTempTable bool) error {
if createTempTable {
tempAlterTableArgs := ddl.AlterTableArgs{
Dialect: s.Dialect(),
@@ -120,7 +119,7 @@ func (s Store) PrepareTemporaryTable(tableData *optimization.TableData, tableCon
}()

// Upload the local file to DBFS
ctx := driverctx.NewContextWithStagingInfo(context.Background(), []string{"/var"})
ctx = driverctx.NewContextWithStagingInfo(ctx, []string{"/var"})

castedTempTableID, isOk := tempTableID.(TableIdentifier)
if !isOk {
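One detail worth calling out in the Databricks diff: the DBFS staging context used to be rooted at context.Background(), so a deadline or cancellation set by the caller never reached the file upload. It is now layered onto the incoming ctx via driverctx.NewContextWithStagingInfo. A minimal caller-side sketch of what that buys (the helper, the anonymous interface, and the timeout are illustrative, not repo code):

    package example // hypothetical package; not part of this PR

    import (
        "context"
        "time"

        "github.com/artie-labs/transfer/lib/optimization"
    )

    // mergeWithDeadline works against any store exposing the post-PR Merge
    // signature, the Databricks Store included; the deadline set here now
    // bounds the DBFS staging upload as well.
    func mergeWithDeadline(
        ctx context.Context,
        store interface {
            Merge(context.Context, *optimization.TableData) error
        },
        tableData *optimization.TableData,
        timeout time.Duration,
    ) error {
        ctx, cancel := context.WithTimeout(ctx, timeout)
        defer cancel()
        return store.Merge(ctx, tableData)
    }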
3 changes: 2 additions & 1 deletion clients/mssql/staging.go
@@ -1,6 +1,7 @@
package mssql

import (
"context"
"fmt"

mssql "github.com/microsoft/go-mssqldb"
@@ -13,7 +14,7 @@ import (
"github.com/artie-labs/transfer/lib/typing/columns"
)

func (s *Store) PrepareTemporaryTable(tableData *optimization.TableData, tableConfig *types.DwhTableConfig, tempTableID sql.TableIdentifier, _ sql.TableIdentifier, _ types.AdditionalSettings, createTempTable bool) error {
func (s *Store) PrepareTemporaryTable(_ context.Context, tableData *optimization.TableData, tableConfig *types.DwhTableConfig, tempTableID sql.TableIdentifier, _ sql.TableIdentifier, _ types.AdditionalSettings, createTempTable bool) error {
if createTempTable {
tempAlterTableArgs := ddl.AlterTableArgs{
Dialect: s.Dialect(),
9 changes: 5 additions & 4 deletions clients/mssql/store.go
@@ -1,6 +1,7 @@
package mssql

import (
"context"
"strings"

_ "github.com/microsoft/go-mssqldb"
@@ -38,12 +39,12 @@ func (s *Store) dialect() dialect.MSSQLDialect {
return dialect.MSSQLDialect{}
}

func (s *Store) Merge(tableData *optimization.TableData) error {
return shared.Merge(s, tableData, types.MergeOpts{})
func (s *Store) Merge(ctx context.Context, tableData *optimization.TableData) error {
return shared.Merge(ctx, s, tableData, types.MergeOpts{})
}

func (s *Store) Append(tableData *optimization.TableData, _ bool) error {
return shared.Append(s, tableData, types.AdditionalSettings{})
func (s *Store) Append(ctx context.Context, tableData *optimization.TableData, _ bool) error {
return shared.Append(ctx, s, tableData, types.AdditionalSettings{})
}

// specificIdentifierFor returns a MS SQL [TableIdentifier] for a [TopicConfig] + table name.
9 changes: 5 additions & 4 deletions clients/redshift/redshift.go
@@ -1,6 +1,7 @@
package redshift

import (
"context"
"fmt"

_ "github.com/jackc/pgx/v5/stdlib"
@@ -26,12 +27,12 @@ type Store struct {
db.Store
}

func (s *Store) Append(tableData *optimization.TableData, _ bool) error {
return shared.Append(s, tableData, types.AdditionalSettings{})
func (s *Store) Append(ctx context.Context, tableData *optimization.TableData, _ bool) error {
return shared.Append(ctx, s, tableData, types.AdditionalSettings{})
}

func (s *Store) Merge(tableData *optimization.TableData) error {
return shared.Merge(s, tableData, types.MergeOpts{
func (s *Store) Merge(ctx context.Context, tableData *optimization.TableData) error {
return shared.Merge(ctx, s, tableData, types.MergeOpts{
// We are adding SELECT DISTINCT here for the temporary table as an extra guardrail.
// Redshift does not enforce any row uniqueness and there could be potential LOAD errors which will cause duplicate rows to arise.
SubQueryDedupe: true,
4 changes: 2 additions & 2 deletions clients/redshift/staging.go
@@ -17,7 +17,7 @@ import (
"github.com/artie-labs/transfer/lib/sql"
)

func (s *Store) PrepareTemporaryTable(tableData *optimization.TableData, tableConfig *types.DwhTableConfig, tempTableID sql.TableIdentifier, _ sql.TableIdentifier, _ types.AdditionalSettings, createTempTable bool) error {
func (s *Store) PrepareTemporaryTable(ctx context.Context, tableData *optimization.TableData, tableConfig *types.DwhTableConfig, tempTableID sql.TableIdentifier, _ sql.TableIdentifier, _ types.AdditionalSettings, createTempTable bool) error {
if createTempTable {
tempAlterTableArgs := ddl.AlterTableArgs{
Dialect: s.Dialect(),
@@ -47,7 +47,7 @@ func (s *Store) PrepareTemporaryTable(tableData *optimization.TableData, tableCo
}()

// Load fp into s3, get S3 URI and pass it down.
s3Uri, err := s3lib.UploadLocalFileToS3(context.Background(), s3lib.UploadArgs{
s3Uri, err := s3lib.UploadLocalFileToS3(ctx, s3lib.UploadArgs{
OptionalS3Prefix: s.optionalS3Prefix,
Bucket: s.bucket,
FilePath: fp,
8 changes: 4 additions & 4 deletions clients/s3/s3.go
@@ -58,17 +58,17 @@ func (s *Store) ObjectPrefix(tableData *optimization.TableData) string {
return strings.Join([]string{fqTableName, yyyyMMDDFormat}, "/")
}

func (s *Store) Append(tableData *optimization.TableData, _ bool) error {
func (s *Store) Append(ctx context.Context, tableData *optimization.TableData, _ bool) error {
// There's no difference in appending or merging for S3.
return s.Merge(tableData)
return s.Merge(ctx, tableData)
}

// Merge - will take tableData, write it into a particular file in the specified format, in these steps:
// 1. Load a ParquetWriter from a JSON schema (auto-generated)
// 2. Load the temporary file, under this format: s3://bucket/folderName/fullyQualifiedTableName/YYYY-MM-DD/{{unix_timestamp}}.parquet.gz
// 3. It will then upload this to S3
// 4. Delete the temporary file
func (s *Store) Merge(tableData *optimization.TableData) error {
func (s *Store) Merge(ctx context.Context, tableData *optimization.TableData) error {
if tableData.ShouldSkipUpdate() {
return nil
}
@@ -127,7 +127,7 @@ func (s *Store) Merge(tableData *optimization.TableData) error {
}
}()

if _, err = s3lib.UploadLocalFileToS3(context.Background(), s3lib.UploadArgs{
if _, err = s3lib.UploadLocalFileToS3(ctx, s3lib.UploadArgs{
Bucket: s.config.S3.Bucket,
OptionalS3Prefix: s.ObjectPrefix(tableData),
FilePath: fp,
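The Merge doc comment above lists four steps; only step 3 changes in substance, since the S3 upload now runs under the propagated ctx rather than a fresh context.Background(). A minimal sketch of that call (placeholder bucket, prefix, and file path; the s3lib import path is assumed, and only the UploadArgs fields visible in this diff are set):

    package example // sketch only, not repo code

    import (
        "context"
        "fmt"

        "github.com/artie-labs/transfer/lib/s3lib" // import path assumed
    )

    // uploadSnapshot mirrors step 3 of the Merge flow with a caller-scoped ctx.
    func uploadSnapshot(ctx context.Context, fp string) (string, error) {
        s3URI, err := s3lib.UploadLocalFileToS3(ctx, s3lib.UploadArgs{
            Bucket:           "artie-example-bucket",       // placeholder
            OptionalS3Prefix: "db.schema.table/2024-10-08", // placeholder
            FilePath:         fp,
        })
        if err != nil {
            return "", fmt.Errorf("failed to upload file to s3: %w", err)
        }
        return s3URI, nil
    }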
4 changes: 3 additions & 1 deletion clients/shared/append.go
@@ -1,6 +1,7 @@
package shared

import (
"context"
"fmt"

"github.com/artie-labs/transfer/lib/config/constants"
@@ -11,7 +12,7 @@ import (
"github.com/artie-labs/transfer/lib/typing/columns"
)

func Append(dwh destination.DataWarehouse, tableData *optimization.TableData, opts types.AdditionalSettings) error {
func Append(ctx context.Context, dwh destination.DataWarehouse, tableData *optimization.TableData, opts types.AdditionalSettings) error {
if tableData.ShouldSkipUpdate() {
return nil
}
@@ -58,6 +59,7 @@ func Append(dwh destination.DataWarehouse, tableData *optimization.TableData, op
}

return dwh.PrepareTemporaryTable(
ctx,
tableData,
tableConfig,
tempTableID,
5 changes: 3 additions & 2 deletions clients/shared/merge.go
@@ -1,6 +1,7 @@
package shared

import (
"context"
"fmt"
"log/slog"
"time"
@@ -17,7 +18,7 @@ import (

const backfillMaxRetries = 1000

func Merge(dwh destination.DataWarehouse, tableData *optimization.TableData, opts types.MergeOpts) error {
func Merge(ctx context.Context, dwh destination.DataWarehouse, tableData *optimization.TableData, opts types.MergeOpts) error {
if tableData.ShouldSkipUpdate() {
return nil
}
@@ -76,7 +77,7 @@ func Merge(dwh destination.DataWarehouse, tableData *optimization.TableData, opt
}
}()

if err = dwh.PrepareTemporaryTable(tableData, tableConfig, temporaryTableID, tableID, types.AdditionalSettings{}, true); err != nil {
if err = dwh.PrepareTemporaryTable(ctx, tableData, tableConfig, temporaryTableID, tableID, types.AdditionalSettings{}, true); err != nil {
return fmt.Errorf("failed to prepare temporary table: %w", err)
}

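With shared.Append and shared.Merge both taking ctx up front, a caller can drive either path from a single context and have cancellation or deadlines reach table preparation and staging wherever the store honors them. A minimal caller sketch, assuming an already-constructed destination.DataWarehouse and table data (the lib/destination import path is assumed):

    package example // sketch only, not repo code

    import (
        "context"

        "github.com/artie-labs/transfer/clients/shared"
        "github.com/artie-labs/transfer/lib/destination" // import path assumed
        "github.com/artie-labs/transfer/lib/destination/types"
        "github.com/artie-labs/transfer/lib/optimization"
    )

    // flush routes one table's rows through the post-PR shared helpers.
    func flush(ctx context.Context, dwh destination.DataWarehouse, td *optimization.TableData, useMerge bool) error {
        if useMerge {
            return shared.Merge(ctx, dwh, td, types.MergeOpts{})
        }
        return shared.Append(ctx, dwh, td, types.AdditionalSettings{})
    }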
27 changes: 11 additions & 16 deletions clients/snowflake/snowflake_test.go
@@ -1,6 +1,7 @@
package snowflake

import (
"context"
"fmt"
"strconv"
"strings"
@@ -9,19 +10,16 @@

"github.com/artie-labs/transfer/clients/shared"
"github.com/artie-labs/transfer/lib/config"
"github.com/artie-labs/transfer/lib/kafkalib/partition"
"github.com/artie-labs/transfer/lib/sql"

"github.com/artie-labs/transfer/lib/typing/columns"

"github.com/stretchr/testify/assert"

"github.com/artie-labs/transfer/lib/config/constants"
"github.com/artie-labs/transfer/lib/destination/types"
"github.com/artie-labs/transfer/lib/kafkalib"
"github.com/artie-labs/transfer/lib/kafkalib/partition"
"github.com/artie-labs/transfer/lib/optimization"
"github.com/artie-labs/transfer/lib/sql"
"github.com/artie-labs/transfer/lib/typing"
"github.com/artie-labs/transfer/lib/typing/columns"
"github.com/artie-labs/transfer/lib/typing/ext"
"github.com/stretchr/testify/assert"
[Review comment from a Contributor on the import above: "Move this one up?"]

)

func (s *SnowflakeTestSuite) identifierFor(tableData *optimization.TableData) sql.TableIdentifier {
@@ -80,7 +78,7 @@ func (s *SnowflakeTestSuite) TestExecuteMergeNilEdgeCase() {

s.stageStore.configMap.AddTableToConfig(s.identifierFor(tableData), types.NewDwhTableConfig(&anotherCols, nil, false, true))

err := s.stageStore.Merge(tableData)
err := s.stageStore.Merge(context.Background(), tableData)
_col, isOk := tableData.ReadOnlyInMemoryCols().GetColumn("first_name")
assert.True(s.T(), isOk)
assert.Equal(s.T(), _col.KindDetails, typing.String)
@@ -126,7 +124,7 @@ func (s *SnowflakeTestSuite) TestExecuteMergeReestablishAuth() {

s.stageStore.configMap.AddTableToConfig(s.identifierFor(tableData), types.NewDwhTableConfig(&cols, nil, false, true))

assert.NoError(s.T(), s.stageStore.Merge(tableData))
assert.NoError(s.T(), s.stageStore.Merge(context.Background(), tableData))
assert.Equal(s.T(), 5, s.fakeStageStore.ExecCallCount())
}

@@ -174,7 +172,7 @@ func (s *SnowflakeTestSuite) TestExecuteMerge() {
tableID := s.identifierFor(tableData)
fqName := tableID.FullyQualifiedName()
s.stageStore.configMap.AddTableToConfig(tableID, types.NewDwhTableConfig(&cols, nil, false, true))
err := s.stageStore.Merge(tableData)
err := s.stageStore.Merge(context.Background(), tableData)
assert.Nil(s.T(), err)
s.fakeStageStore.ExecReturns(nil, nil)
// CREATE TABLE IF NOT EXISTS customer.public.orders___artie_Mwv9YADmRy (id int,name string,__artie_delete boolean,created_at timestamp_tz) STAGE_COPY_OPTIONS = ( PURGE = TRUE ) STAGE_FILE_FORMAT = ( TYPE = 'csv' FIELD_DELIMITER= '\t' FIELD_OPTIONALLY_ENCLOSED_BY='"' NULL_IF='\\N' EMPTY_FIELD_AS_NULL=FALSE) COMMENT='expires:2023-06-27 11:54:03 UTC'
@@ -258,8 +256,7 @@ func (s *SnowflakeTestSuite) TestExecuteMergeDeletionFlagRemoval() {
_config := types.NewDwhTableConfig(&sflkCols, nil, false, true)
s.stageStore.configMap.AddTableToConfig(s.identifierFor(tableData), _config)

err := s.stageStore.Merge(tableData)
assert.Nil(s.T(), err)
assert.NoError(s.T(), s.stageStore.Merge(context.Background(), tableData))
s.fakeStageStore.ExecReturns(nil, nil)
assert.Equal(s.T(), s.fakeStageStore.ExecCallCount(), 5, "called merge")

@@ -282,8 +279,7 @@ func (s *SnowflakeTestSuite) TestExecuteMergeDeletionFlagRemoval() {
break
}

err = s.stageStore.Merge(tableData)
assert.NoError(s.T(), err)
assert.NoError(s.T(), s.stageStore.Merge(context.Background(), tableData))
s.fakeStageStore.ExecReturns(nil, nil)
assert.Equal(s.T(), s.fakeStageStore.ExecCallCount(), 10, "called merge again")

@@ -294,8 +290,7 @@

func (s *SnowflakeTestSuite) TestExecuteMergeExitEarly() {
tableData := optimization.NewTableData(nil, config.Replication, nil, kafkalib.TopicConfig{}, "foo")
err := s.stageStore.Merge(tableData)
assert.Nil(s.T(), err)
assert.NoError(s.T(), s.stageStore.Merge(context.Background(), tableData))
}

func (s *SnowflakeTestSuite) TestStore_AdditionalEqualityStrings() {
3 changes: 2 additions & 1 deletion clients/snowflake/staging.go
@@ -1,6 +1,7 @@
package snowflake

import (
"context"
"encoding/csv"
"fmt"
"log/slog"
@@ -54,7 +55,7 @@ func castColValStaging(colVal any, colKind typing.KindDetails) (string, error) {
return replaceExceededValues(value, colKind), nil
}

func (s *Store) PrepareTemporaryTable(tableData *optimization.TableData, tableConfig *types.DwhTableConfig, tempTableID sql.TableIdentifier, _ sql.TableIdentifier, additionalSettings types.AdditionalSettings, createTempTable bool) error {
func (s *Store) PrepareTemporaryTable(_ context.Context, tableData *optimization.TableData, tableConfig *types.DwhTableConfig, tempTableID sql.TableIdentifier, _ sql.TableIdentifier, additionalSettings types.AdditionalSettings, createTempTable bool) error {
if createTempTable {
tempAlterTableArgs := ddl.AlterTableArgs{
Dialect: s.Dialect(),