-
Notifications
You must be signed in to change notification settings - Fork 4
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
1 changed file
with
200 additions
and
113 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,136 +1,223 @@ | ||
package mssql | ||
|
||
import (
	"database/sql"
	"fmt"
	"slices"
	"strings"
	"time"

	"github.com/artie-labs/transfer/clients/mssql/dialect"

	"github.com/artie-labs/reader/lib/mssql/parse"
	"github.com/artie-labs/reader/lib/mssql/schema"
	"github.com/artie-labs/reader/lib/rdbms"
	"github.com/artie-labs/reader/lib/rdbms/column"
	"github.com/artie-labs/reader/lib/rdbms/primary_key"
	"github.com/artie-labs/reader/lib/rdbms/scan"
)
|
||
const ( | ||
TimeMicro = "15:04:05.000000" | ||
TimeNano = "15:04:05.000000000" | ||
DateTimeMicro = "2006-01-02 15:04:05.000000" | ||
DateTimeNano = "2006-01-02 15:04:05.000000000" | ||
DateTimeOffset = "2006-01-02 15:04:05.0000000 -07:00" | ||
) | ||
|
||
var supportedPrimaryKeyDataType = []schema.DataType{ | ||
schema.Bit, | ||
schema.Bytes, | ||
schema.Int16, | ||
schema.Int32, | ||
schema.Int64, | ||
schema.Numeric, | ||
schema.Float, | ||
schema.Money, | ||
schema.Date, | ||
schema.String, | ||
schema.Time, | ||
schema.TimeMicro, | ||
schema.TimeNano, | ||
schema.Datetime2, | ||
schema.Datetime2Micro, | ||
schema.Datetime2Nano, | ||
schema.DatetimeOffset, | ||
schema.Bit, | ||
schema.Bytes, | ||
schema.Int16, | ||
schema.Int32, | ||
schema.Int64, | ||
schema.Numeric, | ||
schema.Float, | ||
schema.Money, | ||
schema.Date, | ||
schema.String, | ||
schema.Time, | ||
schema.TimeMicro, | ||
schema.TimeNano, | ||
schema.Datetime2, | ||
schema.Datetime2Micro, | ||
schema.Datetime2Nano, | ||
schema.DatetimeOffset, | ||
} | ||
|
||
func NewScanner(db *sql.DB, table Table, columns []schema.Column, cfg scan.ScannerConfig) (*scan.Scanner, error) { | ||
for _, key := range table.PrimaryKeys() { | ||
_column, err := column.ByName(columns, key) | ||
if err != nil { | ||
return nil, fmt.Errorf("missing column with name: %q", key) | ||
} | ||
|
||
if !slices.Contains(supportedPrimaryKeyDataType, _column.Type) { | ||
return nil, fmt.Errorf("DataType(%d) for column %q is not supported for use as a primary key", _column.Type, _column.Name) | ||
} | ||
} | ||
|
||
primaryKeyBounds, err := table.FetchPrimaryKeysBounds(db) | ||
if err != nil { | ||
return nil, err | ||
} | ||
|
||
adapter := scanAdapter{schema: table.Schema, tableName: table.Name, columns: columns} | ||
return scan.NewScanner(db, primaryKeyBounds, cfg, adapter) | ||
for _, key := range table.PrimaryKeys() { | ||
_column, err := column.ByName(columns, key) | ||
if err != nil { | ||
return nil, fmt.Errorf("missing column with name: %q", key) | ||
} | ||
|
||
if !slices.Contains(supportedPrimaryKeyDataType, _column.Type) { | ||
return nil, fmt.Errorf("DataType(%d) for column %q is not supported for use as a primary key", _column.Type, _column.Name) | ||
} | ||
} | ||
|
||
primaryKeyBounds, err := table.FetchPrimaryKeysBounds(db) | ||
if err != nil { | ||
return nil, err | ||
} | ||
|
||
adapter := scanAdapter{schema: table.Schema, tableName: table.Name, columns: columns} | ||
return scan.NewScanner(db, primaryKeyBounds, cfg, adapter) | ||
} | ||
|
||
type scanAdapter struct { | ||
schema string | ||
tableName string | ||
columns []schema.Column | ||
schema string | ||
tableName string | ||
columns []schema.Column | ||
} | ||
|
||
func (s scanAdapter) ParsePrimaryKeyValueForOverrides(_ string, value string) (any, error) { | ||
// We don't need to cast it at all. | ||
return value, nil | ||
} | ||
|
||
func (s scanAdapter) ParsePrimaryKeyValueForOverrides(columnName string, value string) (any, error) { | ||
// TODO: Implement Date, Time, Datetime for primary key types. | ||
columnIdx := slices.IndexFunc(s.columns, func(x schema.Column) bool { return x.Name == columnName }) | ||
if columnIdx < 0 { | ||
return nil, fmt.Errorf("primary key column does not exist: %q", columnName) | ||
} | ||
|
||
_column := s.columns[columnIdx] | ||
if !slices.Contains(supportedPrimaryKeyDataType, _column.Type) { | ||
return nil, fmt.Errorf("DataType(%d) for column %q is not supported for use as a primary key", _column.Type, _column.Name) | ||
} | ||
|
||
switch _column.Type { | ||
case schema.Bit: | ||
return value == "1", nil | ||
default: | ||
return value, nil | ||
} | ||
func (s scanAdapter) parsePrimaryKeyValues(columnName string, value any) (any, error) { | ||
columnIdx := slices.IndexFunc(s.columns, func(x schema.Column) bool { return x.Name == columnName }) | ||
if columnIdx < 0 { | ||
return nil, fmt.Errorf("primary key column does not exist: %q", columnName) | ||
} | ||
|
||
_column := s.columns[columnIdx] | ||
if !slices.Contains(supportedPrimaryKeyDataType, _column.Type) { | ||
return nil, fmt.Errorf("DataType(%d) for column %q is not supported for use as a primary key", _column.Type, _column.Name) | ||
} | ||
|
||
switch _column.Type { | ||
case schema.Time: | ||
switch castedVal := value.(type) { | ||
case time.Time: | ||
return castedVal.Format(time.TimeOnly), nil | ||
case string: | ||
return castedVal, nil | ||
default: | ||
return nil, fmt.Errorf("expected time.Time type, received: %T", value) | ||
} | ||
case schema.TimeMicro: | ||
switch castedVal := value.(type) { | ||
case time.Time: | ||
return castedVal.Format(TimeMicro), nil | ||
case string: | ||
return castedVal, nil | ||
default: | ||
return nil, fmt.Errorf("expected time.Time type, received: %T", value) | ||
} | ||
case schema.TimeNano: | ||
switch castedVal := value.(type) { | ||
case time.Time: | ||
return castedVal.Format(TimeNano), nil | ||
case string: | ||
return castedVal, nil | ||
default: | ||
return nil, fmt.Errorf("expected time.Time type, received: %T", value) | ||
} | ||
case schema.Datetime2: | ||
switch castedVal := value.(type) { | ||
case time.Time: | ||
return castedVal.Format(time.DateTime), nil | ||
case string: | ||
return castedVal, nil | ||
default: | ||
return nil, fmt.Errorf("expected time.Time type, received: %T", value) | ||
} | ||
case schema.Datetime2Micro: | ||
switch castedVal := value.(type) { | ||
case time.Time: | ||
return castedVal.Format(DateTimeMicro), nil | ||
case string: | ||
return castedVal, nil | ||
default: | ||
return nil, fmt.Errorf("expected time.Time type, received: %T", value) | ||
} | ||
case schema.Datetime2Nano: | ||
switch castedVal := value.(type) { | ||
case time.Time: | ||
return castedVal.Format(DateTimeNano), nil | ||
case string: | ||
return castedVal, nil | ||
default: | ||
return nil, fmt.Errorf("expected time.Time type, received: %T", value) | ||
} | ||
case schema.DatetimeOffset: | ||
switch castedVal := value.(type) { | ||
case time.Time: | ||
return castedVal.Format(DateTimeOffset), nil | ||
case string: | ||
return castedVal, nil | ||
default: | ||
return nil, fmt.Errorf("expected time.Time type, received: %T", value) | ||
} | ||
default: | ||
return value, nil | ||
} | ||
} | ||
|
||
func (s scanAdapter) BuildQuery(primaryKeys []primary_key.Key, isFirstBatch bool, batchSize uint) (string, []any, error) { | ||
mssqlDialect := dialect.MSSQLDialect{} | ||
colNames := make([]string, len(s.columns)) | ||
for idx, col := range s.columns { | ||
colNames[idx] = mssqlDialect.QuoteIdentifier(col.Name) | ||
} | ||
|
||
startingValues := make([]any, len(primaryKeys)) | ||
endingValues := make([]any, len(primaryKeys)) | ||
for i, pk := range primaryKeys { | ||
startingValues[i] = pk.StartingValue | ||
endingValues[i] = pk.EndingValue | ||
} | ||
|
||
quotedKeyNames := make([]string, len(primaryKeys)) | ||
for i, key := range primaryKeys { | ||
quotedKeyNames[i] = mssqlDialect.QuoteIdentifier(key.Name) | ||
} | ||
|
||
lowerBoundComparison := ">" | ||
if isFirstBatch { | ||
lowerBoundComparison = ">=" | ||
} | ||
|
||
return fmt.Sprintf(`SELECT TOP %d %s FROM %s.%s WHERE (%s) %s (%s) AND (%s) <= (%s) ORDER BY %s`, | ||
// TOP | ||
batchSize, | ||
// SELECT | ||
strings.Join(colNames, ","), | ||
// FROM | ||
mssqlDialect.QuoteIdentifier(s.schema), mssqlDialect.QuoteIdentifier(s.tableName), | ||
// WHERE (pk) > (123) | ||
strings.Join(quotedKeyNames, ","), lowerBoundComparison, strings.Join(rdbms.QueryPlaceholders("?", len(startingValues)), ","), | ||
strings.Join(quotedKeyNames, ","), strings.Join(rdbms.QueryPlaceholders("?", len(endingValues)), ","), | ||
// ORDER BY | ||
strings.Join(quotedKeyNames, ","), | ||
), slices.Concat(startingValues, endingValues), nil | ||
mssqlDialect := dialect.MSSQLDialect{} | ||
colNames := make([]string, len(s.columns)) | ||
for idx, col := range s.columns { | ||
colNames[idx] = mssqlDialect.QuoteIdentifier(col.Name) | ||
} | ||
|
||
startingValues := make([]any, len(primaryKeys)) | ||
endingValues := make([]any, len(primaryKeys)) | ||
for i, pk := range primaryKeys { | ||
pkStartVal, err := s.parsePrimaryKeyValues(pk.Name, pk.StartingValue) | ||
if err != nil { | ||
return "", nil, err | ||
} | ||
|
||
pkEndVal, err := s.parsePrimaryKeyValues(pk.Name, pk.EndingValue) | ||
if err != nil { | ||
return "", nil, err | ||
} | ||
|
||
fmt.Println("pkStartVal", pkStartVal) | ||
fmt.Println("pkEndVal", pkEndVal) | ||
|
||
startingValues[i] = pkStartVal | ||
endingValues[i] = pkEndVal | ||
} | ||
|
||
quotedKeyNames := make([]string, len(primaryKeys)) | ||
for i, key := range primaryKeys { | ||
quotedKeyNames[i] = mssqlDialect.QuoteIdentifier(key.Name) | ||
} | ||
|
||
lowerBoundComparison := ">" | ||
if isFirstBatch { | ||
lowerBoundComparison = ">=" | ||
} | ||
|
||
return fmt.Sprintf(`SELECT TOP %d %s FROM %s.%s WHERE (%s) %s (%s) AND (%s) <= (%s) ORDER BY %s`, | ||
// TOP | ||
batchSize, | ||
// SELECT | ||
strings.Join(colNames, ","), | ||
// FROM | ||
mssqlDialect.QuoteIdentifier(s.schema), mssqlDialect.QuoteIdentifier(s.tableName), | ||
// WHERE (pk) > (123) | ||
strings.Join(quotedKeyNames, ","), lowerBoundComparison, strings.Join(rdbms.QueryPlaceholders("?", len(startingValues)), ","), | ||
strings.Join(quotedKeyNames, ","), strings.Join(rdbms.QueryPlaceholders("?", len(endingValues)), ","), | ||
// ORDER BY | ||
strings.Join(quotedKeyNames, ","), | ||
), slices.Concat(startingValues, endingValues), nil | ||
} | ||
|
||
func (s scanAdapter) ParseRow(values []any) error { | ||
for i, value := range values { | ||
parsedValue, err := parse.ParseValue(s.columns[i].Type, value) | ||
if err != nil { | ||
return fmt.Errorf("failed to parse column: %q: %w", s.columns[i].Name, err) | ||
} | ||
for i, value := range values { | ||
parsedValue, err := parse.ParseValue(s.columns[i].Type, value) | ||
if err != nil { | ||
return fmt.Errorf("failed to parse column: %q: %w", s.columns[i].Name, err) | ||
} | ||
|
||
values[i] = parsedValue | ||
} | ||
values[i] = parsedValue | ||
} | ||
|
||
return nil | ||
return nil | ||
} |