Skip to content

Commit

Permalink
[Postgres] Improved handling of bit (#503)
Browse files Browse the repository at this point in the history
  • Loading branch information
Tang8330 authored Oct 5, 2024
1 parent 901b6e6 commit 1e06d1a
Show file tree
Hide file tree
Showing 7 changed files with 287 additions and 65 deletions.
36 changes: 29 additions & 7 deletions integration_tests/postgres/main.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package main

import (
"cmp"
"database/sql"
"encoding/json"
"errors"
Expand All @@ -26,13 +27,8 @@ func main() {
}
slog.SetDefault(slog.New(tint.NewHandler(os.Stderr, &tint.Options{})))

var pgHost string = os.Getenv("PG_HOST")
if pgHost == "" {
pgHost = "localhost"
}

var pgConfig = config.PostgreSQL{
Host: pgHost,
pgConfig := config.PostgreSQL{
Host: cmp.Or(os.Getenv("PG_HOST"), "localhost"),
Port: 5432,
Username: "postgres",
Password: "postgres",
Expand Down Expand Up @@ -80,6 +76,8 @@ CREATE TABLE %s (
c_bigint bigint,
c_bigserial bigserial,
c_bit bit,
c_bit1 bit(1),
c_bit5 bit(5),
c_boolean boolean,
-- c_box box,
c_bytea bytea,
Expand Down Expand Up @@ -150,6 +148,10 @@ INSERT INTO %s VALUES (
100000123100000123,
-- c_bit
B'1',
-- c_bit1
B'1',
-- c_bit5
B'10101',
-- c_boolean
true,
-- c_box
Expand Down Expand Up @@ -300,6 +302,24 @@ const expectedPayloadTemplate = `{
"name": "",
"parameters": null
},
{
"type": "boolean",
"optional": false,
"default": null,
"field": "c_bit1",
"name": "",
"parameters": null
},
{
"type": "bytes",
"optional": false,
"default": null,
"field": "c_bit5",
"name": "io.debezium.data.Bits",
"parameters": {
"length": "5"
}
},
{
"type": "boolean",
"optional": false,
Expand Down Expand Up @@ -664,6 +684,8 @@ const expectedPayloadTemplate = `{
"c_bigint": 9009900990099009000,
"c_bigserial": 100000123100000123,
"c_bit": true,
"c_bit1": true,
"c_bit5": "FQ==",
"c_boolean": true,
"c_bytea": "YWJjIGtsbSAqqVQ=",
"c_character": "X",
Expand Down
77 changes: 65 additions & 12 deletions lib/debezium/converters/bit.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,28 +4,81 @@ import (
"fmt"
"github.com/artie-labs/transfer/lib/debezium"
"github.com/artie-labs/transfer/lib/typing"
"math/big"
)

type BitConverter struct{}
func NewBitConverter(charMaxLength int) BitConverter {
return BitConverter{charMaxLength: charMaxLength}
}

type BitConverter struct {
charMaxLength int
}

func (BitConverter) ToField(name string) debezium.Field {
return debezium.Field{
FieldName: name,
Type: debezium.Boolean,
func (b BitConverter) ToField(name string) debezium.Field {
switch b.charMaxLength {
case 1:
return debezium.Field{FieldName: name, Type: debezium.Boolean}
default:
return debezium.Field{
FieldName: name,
DebeziumType: debezium.Bits,
Type: debezium.Bytes,
Parameters: map[string]any{"length": fmt.Sprint(b.charMaxLength)},
}
}
}

func (BitConverter) Convert(value any) (any, error) {
// This will be 0 (false) or 1 (true)
func (b BitConverter) Convert(value any) (any, error) {
stringValue, err := typing.AssertType[string](value)
if err != nil {
return nil, err
}

if stringValue == "0" {
return false, nil
} else if stringValue == "1" {
return true, nil
if b.charMaxLength == 0 {
return nil, fmt.Errorf("bit converter failed: invalid char max length")
}

if len(stringValue) != b.charMaxLength {
return nil, fmt.Errorf("bit converter failed: mismatched char max length, value: %q, length: %d", stringValue, len(stringValue))
}

switch b.charMaxLength {
case 1:
// For bit, bit(1) - We will convert these to booleans
if stringValue == "0" {
return false, nil
} else if stringValue == "1" {
return true, nil
}
return nil, fmt.Errorf(`string value %q is not in ["0", "1"]`, value)
default:
for _, char := range stringValue {
if char != '0' && char != '1' {
return nil, fmt.Errorf("invalid binary string %q: contains non-binary characters", stringValue)
}
}

return stringToByteA(stringValue)
}
}

// stringToByteA - Converts an integer to a byte array of the specified length, using little endian, which mirrors the same logic as java.util.BitSet
// Ref: https://docs.oracle.com/javase/7/docs/api/java/util/BitSet.html
func stringToByteA(stringValue string) ([]byte, error) {
var intValue big.Int
_, isOk := intValue.SetString(stringValue, 2)
if !isOk {
return nil, fmt.Errorf("failed to parse binary string: %q", stringValue)
}

// Reverse the byte array to get little-endian order as Go's big.Int uses big-endian
return reverseBytes(intValue.Bytes()), nil
}

func reverseBytes(b []byte) []byte {
for i, j := 0, len(b)-1; i < j; i, j = i+1, j-1 {
b[i], b[j] = b[j], b[i]
}
return nil, fmt.Errorf(`string value %q is not in ["0", "1"]`, value)
return b
}
120 changes: 105 additions & 15 deletions lib/debezium/converters/bit_test.go
Original file line number Diff line number Diff line change
@@ -1,33 +1,123 @@
package converters

import (
"testing"

"github.com/artie-labs/transfer/lib/debezium"
"github.com/stretchr/testify/assert"
"testing"
)

func TestBitConverter_ToField(t *testing.T) {
{
// char size not specified
field := NewBitConverter(0).ToField("foo")
assert.Equal(t, "foo", field.FieldName)
assert.Equal(t, "bytes", string(field.Type))
assert.Equal(t, debezium.Bits, field.DebeziumType)
assert.Equal(t, map[string]interface{}{"length": "0"}, field.Parameters)
}
{
// char max size 1
field := NewBitConverter(1).ToField("foo")
assert.Equal(t, "foo", field.FieldName)
assert.Equal(t, "boolean", string(field.Type))
}
{
// char max size 5
field := NewBitConverter(5).ToField("foo")
assert.Equal(t, "foo", field.FieldName)
assert.Equal(t, "bytes", string(field.Type))
assert.Equal(t, debezium.Bits, field.DebeziumType)
assert.Equal(t, map[string]interface{}{"length": "5"}, field.Parameters)
}
}

func TestBitConverter_Convert(t *testing.T) {
converter := BitConverter{}
{
// Invalid value - wrong type
_, err := converter.Convert(1234)
assert.ErrorContains(t, err, "expected type string, got int")
// char size not specified
_, err := BitConverter{}.Convert("foo")
assert.ErrorContains(t, err, "bit converter failed: invalid char max length")
}
{
// char max size 1
converter := NewBitConverter(1)
{
// Invalid value - wrong type
_, err := converter.Convert(1234)
assert.ErrorContains(t, err, "expected type string, got int")
}
{
// Valid value - 0
value, err := converter.Convert("0")
assert.NoError(t, err)
assert.Equal(t, false, value)
}
{
// Valid value - 1
value, err := converter.Convert("1")
assert.NoError(t, err)
assert.Equal(t, true, value)
}
{
// Invalid value - 2
_, err := converter.Convert("2")
assert.ErrorContains(t, err, `string value "2" is not in ["0", "1"]`)
}
}
{
// char max size - 5
{
// Invalid, length not matching
converter := NewBitConverter(5)
_, err := converter.Convert("101111")
assert.ErrorContains(t, err, "bit converter failed: mismatched char max length")
}
{
// Invalid, value contains non 0s and 1s
converter := NewBitConverter(5)
_, err := converter.Convert("1011a")
assert.ErrorContains(t, err, "invalid binary string")
}
{
// Valid
converter := NewBitConverter(5)
value, err := converter.Convert("10101")
assert.NoError(t, err)
assert.Equal(t, []byte{21}, value)
}
{
// Valid #2
converter := NewBitConverter(5)
value, err := converter.Convert("10011")
assert.NoError(t, err)
assert.Equal(t, []byte{19}, value)
}
}
{
// Valid value - 0
value, err := converter.Convert("0")
// char max size - 10
converter := NewBitConverter(10)
value, err := converter.Convert("1000000011")
assert.NoError(t, err)
assert.Equal(t, false, value)
assert.Equal(t, []byte{3, 2}, value)
}
{
// Valid value - 1
value, err := converter.Convert("1")
// char max size - 17
converter := NewBitConverter(17)
value, err := converter.Convert("10000000111111111")
assert.NoError(t, err)
assert.Equal(t, true, value)
assert.Equal(t, []byte{255, 1, 1}, value)
}
{
// Invalid value - 2
_, err := converter.Convert("2")
assert.ErrorContains(t, err, `string value "2" is not in ["0", "1"]`)
// char max size - 24
converter := NewBitConverter(24)
value, err := converter.Convert("110110101111000111100101")
assert.NoError(t, err)
assert.Equal(t, []byte{229, 241, 218}, value)
}
{
// char max size - 240 (which exceeds int64)
converter := NewBitConverter(240)
value, err := converter.Convert("110110101111000111100101110110101111000111100101110110101111000111100101110110101111000111100101110110101111000111100101110110101111000111100101110110101111000111100101110110101111000111100101110110101111000111100101110110101111000111100101")
assert.NoError(t, err)
assert.Equal(t, []byte{229, 241, 218, 229, 241, 218, 229, 241, 218, 229, 241, 218, 229, 241, 218, 229, 241, 218, 229, 241, 218, 229, 241, 218, 229, 241, 218, 229, 241, 218}, value)
}
}
20 changes: 13 additions & 7 deletions lib/postgres/schema/schema.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,14 +47,15 @@ const (
)

type Opts struct {
Scale uint16
Precision int
Scale uint16
Precision int
CharMaxLength int
}

type Column = column.Column[DataType, Opts]

const describeTableQuery = `
SELECT column_name, data_type, numeric_precision, numeric_scale, udt_name
SELECT column_name, data_type, numeric_precision, numeric_scale, udt_name, character_maximum_length
FROM information_schema.columns
WHERE table_schema = $1 AND table_name = $2`

Expand All @@ -73,12 +74,13 @@ func DescribeTable(db *sql.DB, _schema, table string) ([]Column, error) {
var numericPrecision *int
var numericScale *uint16
var udtName *string
err = rows.Scan(&colName, &colType, &numericPrecision, &numericScale, &udtName)
var charMaxLength *int
err = rows.Scan(&colName, &colType, &numericPrecision, &numericScale, &udtName, &charMaxLength)
if err != nil {
return nil, err
}

dataType, opts, err := ParseColumnDataType(colType, numericPrecision, numericScale, udtName)
dataType, opts, err := parseColumnDataType(colType, numericPrecision, numericScale, charMaxLength, udtName)
if err != nil {
return nil, fmt.Errorf("unable to identify type %q for column %q", colType, colName)
}
Expand All @@ -92,11 +94,15 @@ func DescribeTable(db *sql.DB, _schema, table string) ([]Column, error) {
return cols, nil
}

func ParseColumnDataType(colKind string, precision *int, scale *uint16, udtName *string) (DataType, *Opts, error) {
func parseColumnDataType(colKind string, precision *int, scale *uint16, charMaxLength *int, udtName *string) (DataType, *Opts, error) {
colKind = strings.ToLower(colKind)
switch colKind {
case "bit":
return Bit, nil, nil
if charMaxLength == nil {
return -1, nil, fmt.Errorf("invalid bit column: missing character maximum length")
}

return Bit, &Opts{CharMaxLength: *charMaxLength}, nil
case "boolean":
return Boolean, nil, nil
case "smallint":
Expand Down
Loading

0 comments on commit 1e06d1a

Please sign in to comment.