Skip to content

Commit

Permalink
Checkpoint
Browse files Browse the repository at this point in the history
  • Loading branch information
Tang8330 committed May 24, 2024
1 parent 11186be commit d9b9770
Showing 1 changed file with 200 additions and 113 deletions.
313 changes: 200 additions & 113 deletions lib/mssql/scanner.go
Original file line number Diff line number Diff line change
@@ -1,136 +1,223 @@
package mssql

import (
	"database/sql"
	"fmt"
	"slices"
	"strings"
	"time"

	"github.com/artie-labs/transfer/clients/mssql/dialect"

	"github.com/artie-labs/reader/lib/mssql/parse"
	"github.com/artie-labs/reader/lib/mssql/schema"
	"github.com/artie-labs/reader/lib/rdbms"
	"github.com/artie-labs/reader/lib/rdbms/column"
	"github.com/artie-labs/reader/lib/rdbms/primary_key"
	"github.com/artie-labs/reader/lib/rdbms/scan"
)

// Go reference-time layouts used to format SQL Server temporal primary key
// values. The zero-padded fractional-second patterns (".000000") emit a fixed
// number of digits, mirroring each column type's fractional precision.
// NOTE(review): DateTimeOffset uses seven fractional digits — presumably to
// match SQL Server's default datetimeoffset(7) precision; confirm.
const (
	TimeMicro      = "15:04:05.000000"
	TimeNano       = "15:04:05.000000000"
	DateTimeMicro  = "2006-01-02 15:04:05.000000"
	DateTimeNano   = "2006-01-02 15:04:05.000000000"
	DateTimeOffset = "2006-01-02 15:04:05.0000000 -07:00"
)

// supportedPrimaryKeyDataType enumerates the schema data types that may be
// used as primary key columns for keyset pagination.
// Fix: the pasted diff duplicated every element; each type is listed once.
var supportedPrimaryKeyDataType = []schema.DataType{
	schema.Bit,
	schema.Bytes,
	schema.Int16,
	schema.Int32,
	schema.Int64,
	schema.Numeric,
	schema.Float,
	schema.Money,
	schema.Date,
	schema.String,
	schema.Time,
	schema.TimeMicro,
	schema.TimeNano,
	schema.Datetime2,
	schema.Datetime2Micro,
	schema.Datetime2Nano,
	schema.DatetimeOffset,
}

// NewScanner validates that every primary key column exists in columns and
// has a data type supported for keyset pagination, fetches the primary key
// bounds from the database, and returns a scan.Scanner driven by scanAdapter.
// Fix: the pasted diff duplicated the entire function body after the return;
// a single copy is kept.
func NewScanner(db *sql.DB, table Table, columns []schema.Column, cfg scan.ScannerConfig) (*scan.Scanner, error) {
	for _, key := range table.PrimaryKeys() {
		_column, err := column.ByName(columns, key)
		if err != nil {
			return nil, fmt.Errorf("missing column with name: %q", key)
		}

		if !slices.Contains(supportedPrimaryKeyDataType, _column.Type) {
			return nil, fmt.Errorf("DataType(%d) for column %q is not supported for use as a primary key", _column.Type, _column.Name)
		}
	}

	primaryKeyBounds, err := table.FetchPrimaryKeysBounds(db)
	if err != nil {
		return nil, err
	}

	adapter := scanAdapter{schema: table.Schema, tableName: table.Name, columns: columns}
	return scan.NewScanner(db, primaryKeyBounds, cfg, adapter)
}

// scanAdapter adapts one MSSQL table (schema-qualified name plus its column
// metadata) to the generic scan.Scanner interface.
// Fix: the pasted diff duplicated all three fields; each is declared once.
type scanAdapter struct {
	schema    string
	tableName string
	columns   []schema.Column
}

func (s scanAdapter) ParsePrimaryKeyValueForOverrides(_ string, value string) (any, error) {
// We don't need to cast it at all.
return value, nil
}

func (s scanAdapter) ParsePrimaryKeyValueForOverrides(columnName string, value string) (any, error) {
// TODO: Implement Date, Time, Datetime for primary key types.
columnIdx := slices.IndexFunc(s.columns, func(x schema.Column) bool { return x.Name == columnName })
if columnIdx < 0 {
return nil, fmt.Errorf("primary key column does not exist: %q", columnName)
}

_column := s.columns[columnIdx]
if !slices.Contains(supportedPrimaryKeyDataType, _column.Type) {
return nil, fmt.Errorf("DataType(%d) for column %q is not supported for use as a primary key", _column.Type, _column.Name)
}

switch _column.Type {
case schema.Bit:
return value == "1", nil
default:
return value, nil
}
// parsePrimaryKeyValues normalizes a primary key bound value for use as a
// query parameter. Temporal columns may arrive as time.Time (from the driver)
// or string (from overrides): time.Time values are formatted with the layout
// matching the column type's precision, strings pass through untouched, and
// anything else is rejected. Non-temporal supported types are returned as-is.
//
// Refactor: the seven original switch cases were identical except for the
// layout constant; they are collapsed into a type→layout mapping followed by
// a single format step. Error messages are unchanged.
func (s scanAdapter) parsePrimaryKeyValues(columnName string, value any) (any, error) {
	columnIdx := slices.IndexFunc(s.columns, func(x schema.Column) bool { return x.Name == columnName })
	if columnIdx < 0 {
		return nil, fmt.Errorf("primary key column does not exist: %q", columnName)
	}

	_column := s.columns[columnIdx]
	if !slices.Contains(supportedPrimaryKeyDataType, _column.Type) {
		return nil, fmt.Errorf("DataType(%d) for column %q is not supported for use as a primary key", _column.Type, _column.Name)
	}

	// Pick the time layout for temporal column types; all other supported
	// types need no conversion.
	var layout string
	switch _column.Type {
	case schema.Time:
		layout = time.TimeOnly
	case schema.TimeMicro:
		layout = TimeMicro
	case schema.TimeNano:
		layout = TimeNano
	case schema.Datetime2:
		layout = time.DateTime
	case schema.Datetime2Micro:
		layout = DateTimeMicro
	case schema.Datetime2Nano:
		layout = DateTimeNano
	case schema.DatetimeOffset:
		layout = DateTimeOffset
	default:
		return value, nil
	}

	switch castedVal := value.(type) {
	case time.Time:
		return castedVal.Format(layout), nil
	case string:
		return castedVal, nil
	default:
		return nil, fmt.Errorf("expected time.Time type, received: %T", value)
	}
}

// BuildQuery constructs one keyset-pagination batch query:
// SELECT TOP <batchSize> <cols> FROM <schema>.<table>
// WHERE (pks) >(=) (start) AND (pks) <= (end) ORDER BY (pks),
// returning the SQL, its ordered bind parameters (start values then end
// values), or an error if a primary key bound cannot be normalized.
// Fixes: the pasted diff duplicated the whole body (old version before the
// new one), and the new version carried leftover debug fmt.Println calls for
// pkStartVal/pkEndVal — both removed.
func (s scanAdapter) BuildQuery(primaryKeys []primary_key.Key, isFirstBatch bool, batchSize uint) (string, []any, error) {
	mssqlDialect := dialect.MSSQLDialect{}
	colNames := make([]string, len(s.columns))
	for idx, col := range s.columns {
		colNames[idx] = mssqlDialect.QuoteIdentifier(col.Name)
	}

	// Normalize both bounds of every primary key into driver-friendly values.
	startingValues := make([]any, len(primaryKeys))
	endingValues := make([]any, len(primaryKeys))
	for i, pk := range primaryKeys {
		pkStartVal, err := s.parsePrimaryKeyValues(pk.Name, pk.StartingValue)
		if err != nil {
			return "", nil, err
		}

		pkEndVal, err := s.parsePrimaryKeyValues(pk.Name, pk.EndingValue)
		if err != nil {
			return "", nil, err
		}

		startingValues[i] = pkStartVal
		endingValues[i] = pkEndVal
	}

	quotedKeyNames := make([]string, len(primaryKeys))
	for i, key := range primaryKeys {
		quotedKeyNames[i] = mssqlDialect.QuoteIdentifier(key.Name)
	}

	// The first batch must include the starting key itself; later batches
	// resume strictly after the last row already returned.
	lowerBoundComparison := ">"
	if isFirstBatch {
		lowerBoundComparison = ">="
	}

	return fmt.Sprintf(`SELECT TOP %d %s FROM %s.%s WHERE (%s) %s (%s) AND (%s) <= (%s) ORDER BY %s`,
		// TOP
		batchSize,
		// SELECT
		strings.Join(colNames, ","),
		// FROM
		mssqlDialect.QuoteIdentifier(s.schema), mssqlDialect.QuoteIdentifier(s.tableName),
		// WHERE (pk) > (123)
		strings.Join(quotedKeyNames, ","), lowerBoundComparison, strings.Join(rdbms.QueryPlaceholders("?", len(startingValues)), ","),
		strings.Join(quotedKeyNames, ","), strings.Join(rdbms.QueryPlaceholders("?", len(endingValues)), ","),
		// ORDER BY
		strings.Join(quotedKeyNames, ","),
	), slices.Concat(startingValues, endingValues), nil
}

// ParseRow converts raw driver values in place: each element of values is
// replaced with its parsed form according to the corresponding column's
// schema type. Returns a wrapped error naming the first column that fails.
// Fix: the pasted diff duplicated the loop and the final return; one copy is
// kept.
func (s scanAdapter) ParseRow(values []any) error {
	for i, value := range values {
		parsedValue, err := parse.ParseValue(s.columns[i].Type, value)
		if err != nil {
			return fmt.Errorf("failed to parse column: %q: %w", s.columns[i].Name, err)
		}

		values[i] = parsedValue
	}

	return nil
}

0 comments on commit d9b9770

Please sign in to comment.