Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

support UPSERTs and UPDATEs in WriteTableQuery #31

Merged
merged 1 commit into from
Jul 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
107 changes: 82 additions & 25 deletions internal/connectors/db/yql/queries/write.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,13 +29,20 @@ type WriteSingleTableQueryImpl struct {
tableName string
upsertFields []string
tableQueryParams []table.ParameterOption
updateParam *table.ParameterOption
}

func (d *WriteSingleTableQueryImpl) AddValueParam(name string, value table_types.Value) {
d.upsertFields = append(d.upsertFields, name[1:])
d.tableQueryParams = append(d.tableQueryParams, table.ValueParam(fmt.Sprintf("%s_%d", name, d.index), value))
}

func (d *WriteSingleTableQueryImpl) AddUpdateId(value table_types.Value) {
updateParamName := fmt.Sprintf("$id_%d", d.index)
vp := table.ValueParam(updateParamName, value)
d.updateParam = &vp
}

func (d *WriteSingleTableQueryImpl) GetParamNames() []string {
res := make([]string, len(d.tableQueryParams))
for i, p := range d.tableQueryParams {
Expand Down Expand Up @@ -73,12 +80,14 @@ func BuildCreateOperationQuery(operation types.Operation, index int) WriteSingle
)
d.AddValueParam(
"$created_at",
table_types.StringValueFromString(""), //TODO
table_types.TimestampValueFromTime(tb.CreatedAt),
)
d.AddValueParam(
"$operation_id",
table_types.StringValueFromString(tb.YdbOperationId),
)
} else {
panic("Implement me")
}

return d
Expand All @@ -89,7 +98,7 @@ func BuildUpdateOperationQuery(operation types.Operation, index int) WriteSingle
index: index,
tableName: "Operations",
}
d.AddValueParam("$id", table_types.UUIDValue(operation.GetId()))
d.AddUpdateId(table_types.UUIDValue(operation.GetId()))
d.AddValueParam(
"$status", table_types.StringValueFromString(operation.GetState().String()),
)
Expand All @@ -105,7 +114,7 @@ func BuildUpdateBackupQuery(backup types.Backup, index int) WriteSingleTableQuer
index: index,
tableName: "Backups",
}
d.AddValueParam("$id", table_types.UUIDValue(backup.ID))
d.AddUpdateId(table_types.UUIDValue(backup.ID))
d.AddValueParam("$status", table_types.StringValueFromString(backup.Status))
return d
}
Expand Down Expand Up @@ -160,35 +169,83 @@ func (d *WriteTableQueryImpl) WithCreateOperation(operation types.Operation) Wri
}

func (d *WriteSingleTableQueryImpl) DeclareParameters() string {
declares := make([]string, len(d.tableQueryParams))
for i, param := range d.tableQueryParams {
declares[i] = fmt.Sprintf("DECLARE %s AS %s", param.Name(), param.Value().Type().String())
declares := make([]string, 0)
if d.updateParam != nil {
declares = append(
declares,
fmt.Sprintf("DECLARE %s AS %s", (*d.updateParam).Name(), (*d.updateParam).Value().Type().String()),
)
}
for _, param := range d.tableQueryParams {
declares = append(
declares, fmt.Sprintf("DECLARE %s AS %s", param.Name(), param.Value().Type().String()),
)
}
return strings.Join(declares, ";\n")
}

func ProcessUpsertQuery(
queryStrings *[]string, allParams *[]table.ParameterOption, t *WriteSingleTableQueryImpl,
) error {
if len(t.upsertFields) == 0 {
return errors.New("No fields to upsert")
}
if len(t.tableName) == 0 {
return errors.New("No table")
}
declares := t.DeclareParameters()
*queryStrings = append(
*queryStrings, fmt.Sprintf(
"%s;\nUPSERT INTO %s (%s) VALUES (%s)", declares, t.tableName, strings.Join(t.upsertFields, ", "),
strings.Join(t.GetParamNames(), ", "),
),
)
for _, p := range t.tableQueryParams {
*allParams = append(*allParams, p)
}
return nil
}

func ProcessUpdateQuery(
queryStrings *[]string, allParams *[]table.ParameterOption, t *WriteSingleTableQueryImpl,
) error {
if len(t.upsertFields) == 0 {
return errors.New("No fields to upsert")
}
if len(t.tableName) == 0 {
return errors.New("No table")
}
declares := t.DeclareParameters()
paramNames := t.GetParamNames()
keyParam := fmt.Sprintf("id = %s", (*t.updateParam).Name())
updates := make([]string, 0)
for i := range t.upsertFields {
updates = append(updates, fmt.Sprintf("%s = %s", t.upsertFields[i], paramNames[i]))
}
*queryStrings = append(
*queryStrings, fmt.Sprintf(
"%s;\nUPDATE %s SET %s WHERE %s", declares, t.tableName, strings.Join(updates, ", "), keyParam,
),
)
*allParams = append(*allParams, *t.updateParam)
for _, p := range t.tableQueryParams {
*allParams = append(*allParams, p)
}
return nil
}

func (d *WriteTableQueryImpl) FormatQuery(ctx context.Context) (*FormatQueryResult, error) {
queryStrings := make([]string, len(d.tableQueries))
queryStrings := make([]string, 0)
allParams := make([]table.ParameterOption, 0)
for i, t := range d.tableQueries {
if len(t.upsertFields) == 0 {
return nil, errors.New("No fields to upsert")
}
if len(t.tableName) == 0 {
return nil, errors.New("No table")
for _, t := range d.tableQueries {
var err error
if t.updateParam == nil {
err = ProcessUpsertQuery(&queryStrings, &allParams, &t)
} else {
err = ProcessUpdateQuery(&queryStrings, &allParams, &t)
}
declares := t.DeclareParameters()
paramNames := t.GetParamNames()
keyParam := fmt.Sprintf("%s = %s", t.upsertFields[0], paramNames[0])
updates := make([]string, 0)
for j := 1; j < len(t.upsertFields); j++ {
updates = append(updates, fmt.Sprintf("%s = %s", t.upsertFields[j], paramNames[j]))
}
queryStrings[i] = fmt.Sprintf(
"%s;\nUPDATE %s SET %s WHERE %s", declares, t.tableName, strings.Join(updates, ", "), keyParam,
)
for _, p := range t.tableQueryParams {
allParams = append(allParams, p)
if err != nil {
return nil, err
}
}
res := strings.Join(queryStrings, ";\n")
Expand Down
189 changes: 188 additions & 1 deletion internal/connectors/db/yql/queries/write_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,11 @@ import (
"github.com/ydb-platform/ydb-go-sdk/v3/table"
table_types "github.com/ydb-platform/ydb-go-sdk/v3/table/types"
"testing"
"time"
"ydbcp/internal/types"
)

func TestQueryBuilder_Write(t *testing.T) {
func TestQueryBuilder_UpdateUpdate(t *testing.T) {
const (
queryString = `DECLARE $id_0 AS Uuid;
DECLARE $status_0 AS String;
Expand Down Expand Up @@ -50,3 +51,189 @@ UPDATE Operations SET status = $status_1, message = $message_1 WHERE id = $id_1`
)
assert.Equal(t, queryParams, query.QueryParams, "bad query params")
}

func TestQueryBuilder_CreateCreate(t *testing.T) {
const (
queryString = `DECLARE $id_0 AS Uuid;
DECLARE $container_id_0 AS String;
DECLARE $database_0 AS String;
DECLARE $initiated_0 AS String;
DECLARE $s3_endpoint_0 AS String;
DECLARE $s3_region_0 AS String;
DECLARE $s3_bucket_0 AS String;
DECLARE $s3_path_prefix_0 AS String;
DECLARE $status_0 AS String;
UPSERT INTO Backups (id, container_id, database, initiated, s3_endpoint, s3_region, s3_bucket, s3_path_prefix, status) VALUES ($id_0, $container_id_0, $database_0, $initiated_0, $s3_endpoint_0, $s3_region_0, $s3_bucket_0, $s3_path_prefix_0, $status_0);
DECLARE $id_1 AS Uuid;
DECLARE $type_1 AS String;
DECLARE $status_1 AS String;
DECLARE $container_id_1 AS String;
DECLARE $database_1 AS String;
DECLARE $backup_id_1 AS Uuid;
DECLARE $initiated_1 AS String;
DECLARE $created_at_1 AS Timestamp;
DECLARE $operation_id_1 AS String;
UPSERT INTO Operations (id, type, status, container_id, database, backup_id, initiated, created_at, operation_id) VALUES ($id_1, $type_1, $status_1, $container_id_1, $database_1, $backup_id_1, $initiated_1, $created_at_1, $operation_id_1)`
)
opId := types.GenerateObjectID()
backupId := types.GenerateObjectID()
tbOp := types.TakeBackupOperation{
Id: opId,
ContainerID: "a",
BackupId: backupId,
State: "PENDING",
Message: "Message",
YdbConnectionParams: types.YdbConnectionParams{
Endpoint: "",
DatabaseName: "dbname",
},
YdbOperationId: "1234",
SourcePaths: nil,
SourcePathToExclude: nil,
CreatedAt: time.Unix(0, 0),
}
backup := types.Backup{
ID: backupId,
ContainerID: "a",
DatabaseName: "b",
S3Endpoint: "c",
S3Region: "d",
S3Bucket: "e",
S3PathPrefix: "f",
Status: "Available",
}
builder := NewWriteTableQuery().
WithCreateBackup(backup).
WithCreateOperation(&tbOp)
var (
queryParams = table.NewQueryParameters(
table.ValueParam("$id_0", table_types.UUIDValue(backupId)),
table.ValueParam("$container_id_0", table_types.StringValueFromString("a")),
table.ValueParam("$database_0", table_types.StringValueFromString("b")),
table.ValueParam("$initiated_0", table_types.StringValueFromString("")),
table.ValueParam("$s3_endpoint_0", table_types.StringValueFromString("c")),
table.ValueParam("$s3_region_0", table_types.StringValueFromString("d")),
table.ValueParam("$s3_bucket_0", table_types.StringValueFromString("e")),
table.ValueParam("$s3_path_prefix_0", table_types.StringValueFromString("f")),
table.ValueParam("$status_0", table_types.StringValueFromString("Available")),
table.ValueParam("$id_1", table_types.UUIDValue(opId)),
table.ValueParam("$type_1", table_types.StringValueFromString("TB")),
table.ValueParam(
"$status_1", table_types.StringValueFromString(string(tbOp.State)),
),
table.ValueParam(
"$container_id_1", table_types.StringValueFromString(tbOp.ContainerID),
),
table.ValueParam(
"$database_1",
table_types.StringValueFromString(tbOp.YdbConnectionParams.DatabaseName),
),
table.ValueParam(
"$backup_id_1",
table_types.UUIDValue(tbOp.BackupId),
),
table.ValueParam(
"$initiated_1",
table_types.StringValueFromString(""),
),
table.ValueParam(
"$created_at_1",
table_types.TimestampValueFromTime(tbOp.CreatedAt),
),
table.ValueParam(
"$operation_id_1",
table_types.StringValueFromString(tbOp.YdbOperationId),
),
)
)
query, err := builder.FormatQuery(context.Background())
assert.Empty(t, err)
assert.Equal(
t, queryString, query.QueryText,
"bad query format",
)
assert.Equal(t, queryParams, query.QueryParams, "bad query params")
}

func TestQueryBuilder_UpdateCreate(t *testing.T) {
const (
queryString = `DECLARE $id_0 AS Uuid;
DECLARE $status_0 AS String;
UPDATE Backups SET status = $status_0 WHERE id = $id_0;
DECLARE $id_1 AS Uuid;
DECLARE $type_1 AS String;
DECLARE $status_1 AS String;
DECLARE $container_id_1 AS String;
DECLARE $database_1 AS String;
DECLARE $backup_id_1 AS Uuid;
DECLARE $initiated_1 AS String;
DECLARE $created_at_1 AS Timestamp;
DECLARE $operation_id_1 AS String;
UPSERT INTO Operations (id, type, status, container_id, database, backup_id, initiated, created_at, operation_id) VALUES ($id_1, $type_1, $status_1, $container_id_1, $database_1, $backup_id_1, $initiated_1, $created_at_1, $operation_id_1)`
)
opId := types.GenerateObjectID()
backupId := types.GenerateObjectID()
tbOp := types.TakeBackupOperation{
Id: opId,
ContainerID: "a",
BackupId: backupId,
State: "PENDING",
Message: "Message",
YdbConnectionParams: types.YdbConnectionParams{
Endpoint: "",
DatabaseName: "dbname",
},
YdbOperationId: "1234",
SourcePaths: nil,
SourcePathToExclude: nil,
CreatedAt: time.Unix(0, 0),
}
backup := types.Backup{
ID: backupId,
Status: "Available",
}
builder := NewWriteTableQuery().
WithUpdateBackup(backup).
WithCreateOperation(&tbOp)
var (
queryParams = table.NewQueryParameters(
table.ValueParam("$id_0", table_types.UUIDValue(backupId)),
table.ValueParam("$status_0", table_types.StringValueFromString("Available")),
table.ValueParam("$id_1", table_types.UUIDValue(opId)),
table.ValueParam("$type_1", table_types.StringValueFromString("TB")),
table.ValueParam(
"$status_1", table_types.StringValueFromString(string(tbOp.State)),
),
table.ValueParam(
"$container_id_1", table_types.StringValueFromString(tbOp.ContainerID),
),
table.ValueParam(
"$database_1",
table_types.StringValueFromString(tbOp.YdbConnectionParams.DatabaseName),
),
table.ValueParam(
"$backup_id_1",
table_types.UUIDValue(tbOp.BackupId),
),
table.ValueParam(
"$initiated_1",
table_types.StringValueFromString(""),
),
table.ValueParam(
"$created_at_1",
table_types.TimestampValueFromTime(tbOp.CreatedAt),
),
table.ValueParam(
"$operation_id_1",
table_types.StringValueFromString(tbOp.YdbOperationId),
),
)
)
query, err := builder.FormatQuery(context.Background())
assert.Empty(t, err)
assert.Equal(
t, queryString, query.QueryText,
"bad query format",
)
assert.Equal(t, queryParams, query.QueryParams, "bad query params")
}
Loading