Skip to content

Commit

Permalink
support UPSERTs and UPDATEs in WriteTableQuery
Browse files Browse the repository at this point in the history
  • Loading branch information
qrort committed Jul 29, 2024
1 parent 3ca6dc1 commit 90a23a4
Show file tree
Hide file tree
Showing 2 changed files with 270 additions and 26 deletions.
107 changes: 82 additions & 25 deletions internal/connectors/db/yql/queries/write.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,13 +29,20 @@ type WriteSingleTableQueryImpl struct {
tableName string
upsertFields []string
tableQueryParams []table.ParameterOption
updateParam *table.ParameterOption
}

func (d *WriteSingleTableQueryImpl) AddValueParam(name string, value table_types.Value) {
d.upsertFields = append(d.upsertFields, name[1:])
d.tableQueryParams = append(d.tableQueryParams, table.ValueParam(fmt.Sprintf("%s_%d", name, d.index), value))
}

func (d *WriteSingleTableQueryImpl) AddUpdateId(value table_types.Value) {
updateParamName := fmt.Sprintf("$id_%d", d.index)
vp := table.ValueParam(updateParamName, value)
d.updateParam = &vp
}

func (d *WriteSingleTableQueryImpl) GetParamNames() []string {
res := make([]string, len(d.tableQueryParams))
for i, p := range d.tableQueryParams {
Expand Down Expand Up @@ -73,12 +80,14 @@ func BuildCreateOperationQuery(operation types.Operation, index int) WriteSingle
)
d.AddValueParam(
"$created_at",
table_types.StringValueFromString(""), //TODO
table_types.TimestampValueFromTime(tb.CreatedAt),
)
d.AddValueParam(
"$operation_id",
table_types.StringValueFromString(tb.YdbOperationId),
)
} else {
panic("Implement me")
}

return d
Expand All @@ -89,7 +98,7 @@ func BuildUpdateOperationQuery(operation types.Operation, index int) WriteSingle
index: index,
tableName: "Operations",
}
d.AddValueParam("$id", table_types.UUIDValue(operation.GetId()))
d.AddUpdateId(table_types.UUIDValue(operation.GetId()))
d.AddValueParam(
"$status", table_types.StringValueFromString(operation.GetState().String()),
)
Expand All @@ -105,7 +114,7 @@ func BuildUpdateBackupQuery(backup types.Backup, index int) WriteSingleTableQuer
index: index,
tableName: "Backups",
}
d.AddValueParam("$id", table_types.UUIDValue(backup.ID))
d.AddUpdateId(table_types.UUIDValue(backup.ID))
d.AddValueParam("$status", table_types.StringValueFromString(backup.Status))
return d
}
Expand Down Expand Up @@ -160,35 +169,83 @@ func (d *WriteTableQueryImpl) WithCreateOperation(operation types.Operation) Wri
}

func (d *WriteSingleTableQueryImpl) DeclareParameters() string {
declares := make([]string, len(d.tableQueryParams))
for i, param := range d.tableQueryParams {
declares[i] = fmt.Sprintf("DECLARE %s AS %s", param.Name(), param.Value().Type().String())
declares := make([]string, 0)
if d.updateParam != nil {
declares = append(
declares,
fmt.Sprintf("DECLARE %s AS %s", (*d.updateParam).Name(), (*d.updateParam).Value().Type().String()),
)
}
for _, param := range d.tableQueryParams {
declares = append(
declares, fmt.Sprintf("DECLARE %s AS %s", param.Name(), param.Value().Type().String()),
)
}
return strings.Join(declares, ";\n")
}

func ProcessUpsertQuery(
queryStrings *[]string, allParams *[]table.ParameterOption, t *WriteSingleTableQueryImpl,
) error {
if len(t.upsertFields) == 0 {
return errors.New("No fields to upsert")
}
if len(t.tableName) == 0 {
return errors.New("No table")
}
declares := t.DeclareParameters()
*queryStrings = append(
*queryStrings, fmt.Sprintf(
"%s;\nUPSERT INTO %s (%s) VALUES (%s)", declares, t.tableName, strings.Join(t.upsertFields, ", "),
strings.Join(t.GetParamNames(), ", "),
),
)
for _, p := range t.tableQueryParams {
*allParams = append(*allParams, p)
}
return nil
}

func ProcessUpdateQuery(
queryStrings *[]string, allParams *[]table.ParameterOption, t *WriteSingleTableQueryImpl,
) error {
if len(t.upsertFields) == 0 {
return errors.New("No fields to upsert")
}
if len(t.tableName) == 0 {
return errors.New("No table")
}
declares := t.DeclareParameters()
paramNames := t.GetParamNames()
keyParam := fmt.Sprintf("id = %s", (*t.updateParam).Name())
updates := make([]string, 0)
for i := range t.upsertFields {
updates = append(updates, fmt.Sprintf("%s = %s", t.upsertFields[i], paramNames[i]))
}
*queryStrings = append(
*queryStrings, fmt.Sprintf(
"%s;\nUPDATE %s SET %s WHERE %s", declares, t.tableName, strings.Join(updates, ", "), keyParam,
),
)
*allParams = append(*allParams, *t.updateParam)
for _, p := range t.tableQueryParams {
*allParams = append(*allParams, p)
}
return nil
}

func (d *WriteTableQueryImpl) FormatQuery(ctx context.Context) (*FormatQueryResult, error) {
queryStrings := make([]string, len(d.tableQueries))
queryStrings := make([]string, 0)
allParams := make([]table.ParameterOption, 0)
for i, t := range d.tableQueries {
if len(t.upsertFields) == 0 {
return nil, errors.New("No fields to upsert")
}
if len(t.tableName) == 0 {
return nil, errors.New("No table")
for _, t := range d.tableQueries {
var err error
if t.updateParam == nil {
err = ProcessUpsertQuery(&queryStrings, &allParams, &t)
} else {
err = ProcessUpdateQuery(&queryStrings, &allParams, &t)
}
declares := t.DeclareParameters()
paramNames := t.GetParamNames()
keyParam := fmt.Sprintf("%s = %s", t.upsertFields[0], paramNames[0])
updates := make([]string, 0)
for j := 1; j < len(t.upsertFields); j++ {
updates = append(updates, fmt.Sprintf("%s = %s", t.upsertFields[j], paramNames[j]))
}
queryStrings[i] = fmt.Sprintf(
"%s;\nUPDATE %s SET %s WHERE %s", declares, t.tableName, strings.Join(updates, ", "), keyParam,
)
for _, p := range t.tableQueryParams {
allParams = append(allParams, p)
if err != nil {
return nil, err
}
}
res := strings.Join(queryStrings, ";\n")
Expand Down
189 changes: 188 additions & 1 deletion internal/connectors/db/yql/queries/write_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,11 @@ import (
"github.com/ydb-platform/ydb-go-sdk/v3/table"
table_types "github.com/ydb-platform/ydb-go-sdk/v3/table/types"
"testing"
"time"
"ydbcp/internal/types"
)

func TestQueryBuilder_Write(t *testing.T) {
func TestQueryBuilder_UpdateUpdate(t *testing.T) {
const (
queryString = `DECLARE $id_0 AS Uuid;
DECLARE $status_0 AS String;
Expand Down Expand Up @@ -50,3 +51,189 @@ UPDATE Operations SET status = $status_1, message = $message_1 WHERE id = $id_1`
)
assert.Equal(t, queryParams, query.QueryParams, "bad query params")
}

func TestQueryBuilder_CreateCreate(t *testing.T) {
const (
queryString = `DECLARE $id_0 AS Uuid;
DECLARE $container_id_0 AS String;
DECLARE $database_0 AS String;
DECLARE $initiated_0 AS String;
DECLARE $s3_endpoint_0 AS String;
DECLARE $s3_region_0 AS String;
DECLARE $s3_bucket_0 AS String;
DECLARE $s3_path_prefix_0 AS String;
DECLARE $status_0 AS String;
UPSERT INTO Backups (id, container_id, database, initiated, s3_endpoint, s3_region, s3_bucket, s3_path_prefix, status) VALUES ($id_0, $container_id_0, $database_0, $initiated_0, $s3_endpoint_0, $s3_region_0, $s3_bucket_0, $s3_path_prefix_0, $status_0);
DECLARE $id_1 AS Uuid;
DECLARE $type_1 AS String;
DECLARE $status_1 AS String;
DECLARE $container_id_1 AS String;
DECLARE $database_1 AS String;
DECLARE $backup_id_1 AS Uuid;
DECLARE $initiated_1 AS String;
DECLARE $created_at_1 AS Timestamp;
DECLARE $operation_id_1 AS String;
UPSERT INTO Operations (id, type, status, container_id, database, backup_id, initiated, created_at, operation_id) VALUES ($id_1, $type_1, $status_1, $container_id_1, $database_1, $backup_id_1, $initiated_1, $created_at_1, $operation_id_1)`
)
opId := types.GenerateObjectID()
backupId := types.GenerateObjectID()
tbOp := types.TakeBackupOperation{
Id: opId,
ContainerID: "a",
BackupId: backupId,
State: "PENDING",
Message: "Message",
YdbConnectionParams: types.YdbConnectionParams{
Endpoint: "",
DatabaseName: "dbname",
},
YdbOperationId: "1234",
SourcePaths: nil,
SourcePathToExclude: nil,
CreatedAt: time.Unix(0, 0),
}
backup := types.Backup{
ID: backupId,
ContainerID: "a",
DatabaseName: "b",
S3Endpoint: "c",
S3Region: "d",
S3Bucket: "e",
S3PathPrefix: "f",
Status: "Available",
}
builder := NewWriteTableQuery().
WithCreateBackup(backup).
WithCreateOperation(&tbOp)
var (
queryParams = table.NewQueryParameters(
table.ValueParam("$id_0", table_types.UUIDValue(backupId)),
table.ValueParam("$container_id_0", table_types.StringValueFromString("a")),
table.ValueParam("$database_0", table_types.StringValueFromString("b")),
table.ValueParam("$initiated_0", table_types.StringValueFromString("")),
table.ValueParam("$s3_endpoint_0", table_types.StringValueFromString("c")),
table.ValueParam("$s3_region_0", table_types.StringValueFromString("d")),
table.ValueParam("$s3_bucket_0", table_types.StringValueFromString("e")),
table.ValueParam("$s3_path_prefix_0", table_types.StringValueFromString("f")),
table.ValueParam("$status_0", table_types.StringValueFromString("Available")),
table.ValueParam("$id_1", table_types.UUIDValue(opId)),
table.ValueParam("$type_1", table_types.StringValueFromString("TB")),
table.ValueParam(
"$status_1", table_types.StringValueFromString(string(tbOp.State)),
),
table.ValueParam(
"$container_id_1", table_types.StringValueFromString(tbOp.ContainerID),
),
table.ValueParam(
"$database_1",
table_types.StringValueFromString(tbOp.YdbConnectionParams.DatabaseName),
),
table.ValueParam(
"$backup_id_1",
table_types.UUIDValue(tbOp.BackupId),
),
table.ValueParam(
"$initiated_1",
table_types.StringValueFromString(""),
),
table.ValueParam(
"$created_at_1",
table_types.TimestampValueFromTime(tbOp.CreatedAt),
),
table.ValueParam(
"$operation_id_1",
table_types.StringValueFromString(tbOp.YdbOperationId),
),
)
)
query, err := builder.FormatQuery(context.Background())
assert.Empty(t, err)
assert.Equal(
t, queryString, query.QueryText,
"bad query format",
)
assert.Equal(t, queryParams, query.QueryParams, "bad query params")
}

func TestQueryBuilder_UpdateCreate(t *testing.T) {
const (
queryString = `DECLARE $id_0 AS Uuid;
DECLARE $status_0 AS String;
UPDATE Backups SET status = $status_0 WHERE id = $id_0;
DECLARE $id_1 AS Uuid;
DECLARE $type_1 AS String;
DECLARE $status_1 AS String;
DECLARE $container_id_1 AS String;
DECLARE $database_1 AS String;
DECLARE $backup_id_1 AS Uuid;
DECLARE $initiated_1 AS String;
DECLARE $created_at_1 AS Timestamp;
DECLARE $operation_id_1 AS String;
UPSERT INTO Operations (id, type, status, container_id, database, backup_id, initiated, created_at, operation_id) VALUES ($id_1, $type_1, $status_1, $container_id_1, $database_1, $backup_id_1, $initiated_1, $created_at_1, $operation_id_1)`
)
opId := types.GenerateObjectID()
backupId := types.GenerateObjectID()
tbOp := types.TakeBackupOperation{
Id: opId,
ContainerID: "a",
BackupId: backupId,
State: "PENDING",
Message: "Message",
YdbConnectionParams: types.YdbConnectionParams{
Endpoint: "",
DatabaseName: "dbname",
},
YdbOperationId: "1234",
SourcePaths: nil,
SourcePathToExclude: nil,
CreatedAt: time.Unix(0, 0),
}
backup := types.Backup{
ID: backupId,
Status: "Available",
}
builder := NewWriteTableQuery().
WithUpdateBackup(backup).
WithCreateOperation(&tbOp)
var (
queryParams = table.NewQueryParameters(
table.ValueParam("$id_0", table_types.UUIDValue(backupId)),
table.ValueParam("$status_0", table_types.StringValueFromString("Available")),
table.ValueParam("$id_1", table_types.UUIDValue(opId)),
table.ValueParam("$type_1", table_types.StringValueFromString("TB")),
table.ValueParam(
"$status_1", table_types.StringValueFromString(string(tbOp.State)),
),
table.ValueParam(
"$container_id_1", table_types.StringValueFromString(tbOp.ContainerID),
),
table.ValueParam(
"$database_1",
table_types.StringValueFromString(tbOp.YdbConnectionParams.DatabaseName),
),
table.ValueParam(
"$backup_id_1",
table_types.UUIDValue(tbOp.BackupId),
),
table.ValueParam(
"$initiated_1",
table_types.StringValueFromString(""),
),
table.ValueParam(
"$created_at_1",
table_types.TimestampValueFromTime(tbOp.CreatedAt),
),
table.ValueParam(
"$operation_id_1",
table_types.StringValueFromString(tbOp.YdbOperationId),
),
)
)
query, err := builder.FormatQuery(context.Background())
assert.Empty(t, err)
assert.Equal(
t, queryString, query.QueryText,
"bad query format",
)
assert.Equal(t, queryParams, query.QueryParams, "bad query params")
}

0 comments on commit 90a23a4

Please sign in to comment.