Skip to content

Commit

Permalink
Added support for CopyFrom with string values - closes #16
Browse files Browse the repository at this point in the history
  • Loading branch information
ankane committed Aug 5, 2024
1 parent a6ab4f8 commit b36ce70
Show file tree
Hide file tree
Showing 4 changed files with 46 additions and 3 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
## 0.2.2 (unreleased)

- Added support for `CopyFrom` with `string` values

## 0.2.1 (2024-07-23)

- Added `pgx` package
Expand Down
12 changes: 11 additions & 1 deletion pgx/sparsevec.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,14 +41,19 @@ func (encodePlanSparseVectorCodecBinary) Encode(value any, buf []byte) (newBuf [

// scanPlanSparseVectorCodecBinary scans a sparsevec value received in the
// binary wire format (selected by PlanScan for pgx.BinaryFormatCode).
type scanPlanSparseVectorCodecBinary struct{}

// scanPlanSparseVectorCodecText scans a sparsevec value received in the
// text wire format (selected by PlanScan for pgx.TextFormatCode).
type scanPlanSparseVectorCodecText struct{}

func (SparseVectorCodec) PlanScan(m *pgtype.Map, oid uint32, format int16, target any) pgtype.ScanPlan {
_, ok := target.(*pgvector.SparseVector)
if !ok {
return nil
}

if format == pgx.BinaryFormatCode {
switch format {
case pgx.BinaryFormatCode:
return scanPlanSparseVectorCodecBinary{}
case pgx.TextFormatCode:
return scanPlanSparseVectorCodecText{}
}

return nil
Expand All @@ -59,6 +64,11 @@ func (scanPlanSparseVectorCodecBinary) Scan(src []byte, dst any) error {
return v.DecodeBinary(src)
}

// Scan parses the text-format representation in src into dst, which
// PlanScan guarantees to be a *pgvector.SparseVector.
func (scanPlanSparseVectorCodecText) Scan(src []byte, dst any) error {
	return dst.(*pgvector.SparseVector).Scan(src)
}

// DecodeDatabaseSQLValue decodes src into a driver.Value by delegating to
// DecodeValue with the same OID and format.
func (c SparseVectorCodec) DecodeDatabaseSQLValue(m *pgtype.Map, oid uint32, format int16, src []byte) (driver.Value, error) {
	value, err := c.DecodeValue(m, oid, format, src)
	return value, err
}
Expand Down
12 changes: 11 additions & 1 deletion pgx/vector.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,14 +41,19 @@ func (encodePlanVectorCodecBinary) Encode(value any, buf []byte) (newBuf []byte,

// scanPlanVectorCodecBinary scans a vector value received in the binary
// wire format (selected by PlanScan for pgx.BinaryFormatCode).
type scanPlanVectorCodecBinary struct{}

// scanPlanVectorCodecText scans a vector value received in the text wire
// format (selected by PlanScan for pgx.TextFormatCode).
type scanPlanVectorCodecText struct{}

func (VectorCodec) PlanScan(m *pgtype.Map, oid uint32, format int16, target any) pgtype.ScanPlan {
_, ok := target.(*pgvector.Vector)
if !ok {
return nil
}

if format == pgx.BinaryFormatCode {
switch format {
case pgx.BinaryFormatCode:
return scanPlanVectorCodecBinary{}
case pgx.TextFormatCode:
return scanPlanVectorCodecText{}
}

return nil
Expand All @@ -59,6 +64,11 @@ func (scanPlanVectorCodecBinary) Scan(src []byte, dst any) error {
return v.DecodeBinary(src)
}

// Scan parses the text-format representation in src into dst, which
// PlanScan guarantees to be a *pgvector.Vector.
func (scanPlanVectorCodecText) Scan(src []byte, dst any) error {
	return dst.(*pgvector.Vector).Scan(src)
}

// DecodeDatabaseSQLValue decodes src into a driver.Value by delegating to
// DecodeValue with the same OID and format.
func (c VectorCodec) DecodeDatabaseSQLValue(m *pgtype.Map, oid uint32, format int16, src []byte) (driver.Value, error) {
	value, err := c.DecodeValue(m, oid, format, src)
	return value, err
}
Expand Down
21 changes: 20 additions & 1 deletion pgx_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ func TestPgx(t *testing.T) {
panic(err)
}

_, err = conn.Exec(ctx, "CREATE TABLE pgx_items (id bigserial PRIMARY KEY, embedding vector(3), half_embedding halfvec(3), binary_embedding bit(3), sparse_embedding sparsevec(3))")
_, err = conn.Exec(ctx, "CREATE TABLE pgx_items (id bigserial, embedding vector(3), half_embedding halfvec(3), binary_embedding bit(3), sparse_embedding sparsevec(3), score float8)")
if err != nil {
panic(err)
}
Expand Down Expand Up @@ -115,4 +115,23 @@ func TestPgx(t *testing.T) {
if distances[0] != 0 || distances[1] != 1 || distances[2] != math.Sqrt(3) {
t.Error()
}

var item PgxItem
row := conn.QueryRow(ctx, "SELECT embedding, sparse_embedding FROM pgx_items ORDER BY id LIMIT 1", pgx.QueryResultFormats{pgx.TextFormatCode, pgx.TextFormatCode})
err = row.Scan(&item.Embedding, &item.SparseEmbedding)
if err != nil {
panic(err)
}

_, err = conn.CopyFrom(
ctx,
pgx.Identifier{"pgx_items"},
[]string{"embedding", "binary_embedding", "sparse_embedding"},
pgx.CopyFromSlice(1, func(i int) ([]any, error) {
return []interface{}{"[1,2,3]", "101", "{1:1,2:2,3:3}/3"}, nil
}),
)
if err != nil {
panic(err)
}
}

0 comments on commit b36ce70

Please sign in to comment.