diff --git a/lib/mssql/scanner.go b/lib/mssql/scanner.go
index 03c40fef..4e57dbf1 100644
--- a/lib/mssql/scanner.go
+++ b/lib/mssql/scanner.go
@@ -1,136 +1,223 @@
 package mssql

 import (
-	"database/sql"
-	"fmt"
-	"slices"
-	"strings"
-
-	"github.com/artie-labs/transfer/clients/mssql/dialect"
-
-	"github.com/artie-labs/reader/lib/mssql/parse"
-	"github.com/artie-labs/reader/lib/mssql/schema"
-	"github.com/artie-labs/reader/lib/rdbms"
-	"github.com/artie-labs/reader/lib/rdbms/column"
-	"github.com/artie-labs/reader/lib/rdbms/primary_key"
-	"github.com/artie-labs/reader/lib/rdbms/scan"
+	"database/sql"
+	"fmt"
+	"slices"
+	"strings"
+	"time"
+
+	"github.com/artie-labs/transfer/clients/mssql/dialect"
+
+	"github.com/artie-labs/reader/lib/mssql/parse"
+	"github.com/artie-labs/reader/lib/mssql/schema"
+	"github.com/artie-labs/reader/lib/rdbms"
+	"github.com/artie-labs/reader/lib/rdbms/column"
+	"github.com/artie-labs/reader/lib/rdbms/primary_key"
+	"github.com/artie-labs/reader/lib/rdbms/scan"
+)
+
+const (
+	TimeMicro      = "15:04:05.000000"
+	TimeNano       = "15:04:05.000000000"
+	DateTimeMicro  = "2006-01-02 15:04:05.000000"
+	DateTimeNano   = "2006-01-02 15:04:05.000000000"
+	DateTimeOffset = "2006-01-02 15:04:05.0000000 -07:00"
 )

 var supportedPrimaryKeyDataType = []schema.DataType{
-	schema.Bit,
-	schema.Bytes,
-	schema.Int16,
-	schema.Int32,
-	schema.Int64,
-	schema.Numeric,
-	schema.Float,
-	schema.Money,
-	schema.Date,
-	schema.String,
-	schema.Time,
-	schema.TimeMicro,
-	schema.TimeNano,
-	schema.Datetime2,
-	schema.Datetime2Micro,
-	schema.Datetime2Nano,
-	schema.DatetimeOffset,
+	schema.Bit,
+	schema.Bytes,
+	schema.Int16,
+	schema.Int32,
+	schema.Int64,
+	schema.Numeric,
+	schema.Float,
+	schema.Money,
+	schema.Date,
+	schema.String,
+	schema.Time,
+	schema.TimeMicro,
+	schema.TimeNano,
+	schema.Datetime2,
+	schema.Datetime2Micro,
+	schema.Datetime2Nano,
+	schema.DatetimeOffset,
 }

 func NewScanner(db *sql.DB, table Table, columns []schema.Column, cfg scan.ScannerConfig) (*scan.Scanner, error) {
-	for _, key := range table.PrimaryKeys() {
-		_column, err := column.ByName(columns, key)
-		if err != nil {
-			return nil, fmt.Errorf("missing column with name: %q", key)
-		}
-
-		if !slices.Contains(supportedPrimaryKeyDataType, _column.Type) {
-			return nil, fmt.Errorf("DataType(%d) for column %q is not supported for use as a primary key", _column.Type, _column.Name)
-		}
-	}
-
-	primaryKeyBounds, err := table.FetchPrimaryKeysBounds(db)
-	if err != nil {
-		return nil, err
-	}
-
-	adapter := scanAdapter{schema: table.Schema, tableName: table.Name, columns: columns}
-	return scan.NewScanner(db, primaryKeyBounds, cfg, adapter)
+	for _, key := range table.PrimaryKeys() {
+		_column, err := column.ByName(columns, key)
+		if err != nil {
+			return nil, fmt.Errorf("missing column with name: %q", key)
+		}
+
+		if !slices.Contains(supportedPrimaryKeyDataType, _column.Type) {
+			return nil, fmt.Errorf("DataType(%d) for column %q is not supported for use as a primary key", _column.Type, _column.Name)
+		}
+	}
+
+	primaryKeyBounds, err := table.FetchPrimaryKeysBounds(db)
+	if err != nil {
+		return nil, err
+	}
+
+	adapter := scanAdapter{schema: table.Schema, tableName: table.Name, columns: columns}
+	return scan.NewScanner(db, primaryKeyBounds, cfg, adapter)
 }

 type scanAdapter struct {
-	schema    string
-	tableName string
-	columns   []schema.Column
+	schema    string
+	tableName string
+	columns   []schema.Column
+}
+
+func (s scanAdapter) ParsePrimaryKeyValueForOverrides(_ string, value string) (any, error) {
+	// We don't need to cast it at all.
+	return value, nil
 }

-func (s scanAdapter) ParsePrimaryKeyValueForOverrides(columnName string, value string) (any, error) {
-	// TODO: Implement Date, Time, Datetime for primary key types.
-	columnIdx := slices.IndexFunc(s.columns, func(x schema.Column) bool { return x.Name == columnName })
-	if columnIdx < 0 {
-		return nil, fmt.Errorf("primary key column does not exist: %q", columnName)
-	}
-
-	_column := s.columns[columnIdx]
-	if !slices.Contains(supportedPrimaryKeyDataType, _column.Type) {
-		return nil, fmt.Errorf("DataType(%d) for column %q is not supported for use as a primary key", _column.Type, _column.Name)
-	}
-
-	switch _column.Type {
-	case schema.Bit:
-		return value == "1", nil
-	default:
-		return value, nil
-	}
+// parsePrimaryKeyValues normalizes a primary key bound before it is bound as a query parameter:
+// time-typed columns are formatted with the layouts above, and all other types pass through unchanged.
+func (s scanAdapter) parsePrimaryKeyValues(columnName string, value any) (any, error) {
+	columnIdx := slices.IndexFunc(s.columns, func(x schema.Column) bool { return x.Name == columnName })
+	if columnIdx < 0 {
+		return nil, fmt.Errorf("primary key column does not exist: %q", columnName)
+	}
+
+	_column := s.columns[columnIdx]
+	if !slices.Contains(supportedPrimaryKeyDataType, _column.Type) {
+		return nil, fmt.Errorf("DataType(%d) for column %q is not supported for use as a primary key", _column.Type, _column.Name)
+	}
+
+	switch _column.Type {
+	case schema.Time:
+		switch castedVal := value.(type) {
+		case time.Time:
+			return castedVal.Format(time.TimeOnly), nil
+		case string:
+			return castedVal, nil
+		default:
+			return nil, fmt.Errorf("expected time.Time type, received: %T", value)
+		}
+	case schema.TimeMicro:
+		switch castedVal := value.(type) {
+		case time.Time:
+			return castedVal.Format(TimeMicro), nil
+		case string:
+			return castedVal, nil
+		default:
+			return nil, fmt.Errorf("expected time.Time type, received: %T", value)
+		}
+	case schema.TimeNano:
+		switch castedVal := value.(type) {
+		case time.Time:
+			return castedVal.Format(TimeNano), nil
+		case string:
+			return castedVal, nil
+		default:
+			return nil, fmt.Errorf("expected time.Time type, received: %T", value)
+		}
+	case schema.Datetime2:
+		switch castedVal := value.(type) {
+		case time.Time:
+			return castedVal.Format(time.DateTime), nil
+		case string:
+			return castedVal, nil
+		default:
+			return nil, fmt.Errorf("expected time.Time type, received: %T", value)
+		}
+	case schema.Datetime2Micro:
+		switch castedVal := value.(type) {
+		case time.Time:
+			return castedVal.Format(DateTimeMicro), nil
+		case string:
+			return castedVal, nil
+		default:
+			return nil, fmt.Errorf("expected time.Time type, received: %T", value)
+		}
+	case schema.Datetime2Nano:
+		switch castedVal := value.(type) {
+		case time.Time:
+			return castedVal.Format(DateTimeNano), nil
+		case string:
+			return castedVal, nil
+		default:
+			return nil, fmt.Errorf("expected time.Time type, received: %T", value)
+		}
+	case schema.DatetimeOffset:
+		switch castedVal := value.(type) {
+		case time.Time:
+			return castedVal.Format(DateTimeOffset), nil
+		case string:
+			return castedVal, nil
+		default:
+			return nil, fmt.Errorf("expected time.Time type, received: %T", value)
+		}
+	default:
+		return value, nil
+	}
 }

 func (s scanAdapter) BuildQuery(primaryKeys []primary_key.Key, isFirstBatch bool, batchSize uint) (string, []any, error) {
-	mssqlDialect := dialect.MSSQLDialect{}
-	colNames := make([]string, len(s.columns))
-	for idx, col := range s.columns {
-		colNames[idx] = mssqlDialect.QuoteIdentifier(col.Name)
-	}
-
-	startingValues := make([]any, len(primaryKeys))
-	endingValues := make([]any, len(primaryKeys))
-	for i, pk := range primaryKeys {
-		startingValues[i] = pk.StartingValue
-		endingValues[i] = pk.EndingValue
-	}
-
-	quotedKeyNames := make([]string, len(primaryKeys))
-	for i, key := range primaryKeys {
-		quotedKeyNames[i] = mssqlDialect.QuoteIdentifier(key.Name)
-	}
-
-	lowerBoundComparison := ">"
-	if isFirstBatch {
-		lowerBoundComparison = ">="
-	}
-
-	return fmt.Sprintf(`SELECT TOP %d %s FROM %s.%s WHERE (%s) %s (%s) AND (%s) <= (%s) ORDER BY %s`,
-		// TOP
-		batchSize,
-		// SELECT
-		strings.Join(colNames, ","),
-		// FROM
-		mssqlDialect.QuoteIdentifier(s.schema), mssqlDialect.QuoteIdentifier(s.tableName),
-		// WHERE (pk) > (123)
-		strings.Join(quotedKeyNames, ","), lowerBoundComparison, strings.Join(rdbms.QueryPlaceholders("?", len(startingValues)), ","),
-		strings.Join(quotedKeyNames, ","), strings.Join(rdbms.QueryPlaceholders("?", len(endingValues)), ","),
-		// ORDER BY
-		strings.Join(quotedKeyNames, ","),
-	), slices.Concat(startingValues, endingValues), nil
+	mssqlDialect := dialect.MSSQLDialect{}
+	colNames := make([]string, len(s.columns))
+	for idx, col := range s.columns {
+		colNames[idx] = mssqlDialect.QuoteIdentifier(col.Name)
+	}
+
+	startingValues := make([]any, len(primaryKeys))
+	endingValues := make([]any, len(primaryKeys))
+	for i, pk := range primaryKeys {
+		pkStartVal, err := s.parsePrimaryKeyValues(pk.Name, pk.StartingValue)
+		if err != nil {
+			return "", nil, err
+		}
+
+		pkEndVal, err := s.parsePrimaryKeyValues(pk.Name, pk.EndingValue)
+		if err != nil {
+			return "", nil, err
+		}
+
+		startingValues[i] = pkStartVal
+		endingValues[i] = pkEndVal
+	}
+
+	quotedKeyNames := make([]string, len(primaryKeys))
+	for i, key := range primaryKeys {
+		quotedKeyNames[i] = mssqlDialect.QuoteIdentifier(key.Name)
+	}
+
+	lowerBoundComparison := ">"
+	if isFirstBatch {
+		lowerBoundComparison = ">="
+	}
+
+	return fmt.Sprintf(`SELECT TOP %d %s FROM %s.%s WHERE (%s) %s (%s) AND (%s) <= (%s) ORDER BY %s`,
+		// TOP
+		batchSize,
+		// SELECT
+		strings.Join(colNames, ","),
+		// FROM
+		mssqlDialect.QuoteIdentifier(s.schema), mssqlDialect.QuoteIdentifier(s.tableName),
+		// WHERE (pk) > (123)
+		strings.Join(quotedKeyNames, ","), lowerBoundComparison, strings.Join(rdbms.QueryPlaceholders("?", len(startingValues)), ","),
+		strings.Join(quotedKeyNames, ","), strings.Join(rdbms.QueryPlaceholders("?", len(endingValues)), ","),
+		// ORDER BY
+		strings.Join(quotedKeyNames, ","),
+	), slices.Concat(startingValues, endingValues), nil
 }

 func (s scanAdapter) ParseRow(values []any) error {
-	for i, value := range values {
-		parsedValue, err := parse.ParseValue(s.columns[i].Type, value)
-		if err != nil {
-			return fmt.Errorf("failed to parse column: %q: %w", s.columns[i].Name, err)
-		}
+	for i, value := range values {
+		parsedValue, err := parse.ParseValue(s.columns[i].Type, value)
+		if err != nil {
+			return fmt.Errorf("failed to parse column: %q: %w", s.columns[i].Name, err)
+		}

-		values[i] = parsedValue
-	}
+		values[i] = parsedValue
+	}

-	return nil
+	return nil
 }
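
For reviewers less familiar with Go's reference-time layouts, here is a minimal, self-contained sketch (not part of the diff; the timestamp and zone below are made up for illustration) showing what the new DateTimeMicro and DateTimeOffset layouts produce when a time.Time bound is formatted:

package main

import (
	"fmt"
	"time"
)

// Copies of two of the layout constants introduced in the diff above.
const (
	DateTimeMicro  = "2006-01-02 15:04:05.000000"
	DateTimeOffset = "2006-01-02 15:04:05.0000000 -07:00"
)

func main() {
	// Hypothetical DATETIMEOFFSET primary key bound, as a driver might return it.
	ts := time.Date(2024, 3, 1, 10, 30, 0, 123456700, time.FixedZone("UTC-5", -5*60*60))

	fmt.Println(ts.Format(DateTimeMicro))  // 2024-03-01 10:30:00.123456
	fmt.Println(ts.Format(DateTimeOffset)) // 2024-03-01 10:30:00.1234567 -05:00
}

A zero-padded fractional layout such as ".000000" always emits exactly that many digits (truncating any extra precision), which is presumably why separate micro- and nano-precision layouts are declared instead of a single shared format.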