diff --git a/internal/connectors/db/yql/queries/write.go b/internal/connectors/db/yql/queries/write.go index 4a211a2a..62bae968 100644 --- a/internal/connectors/db/yql/queries/write.go +++ b/internal/connectors/db/yql/queries/write.go @@ -29,6 +29,7 @@ type WriteSingleTableQueryImpl struct { tableName string upsertFields []string tableQueryParams []table.ParameterOption + updateParam *table.ParameterOption } func (d *WriteSingleTableQueryImpl) AddValueParam(name string, value table_types.Value) { @@ -36,6 +37,12 @@ func (d *WriteSingleTableQueryImpl) AddValueParam(name string, value table_types d.tableQueryParams = append(d.tableQueryParams, table.ValueParam(fmt.Sprintf("%s_%d", name, d.index), value)) } +func (d *WriteSingleTableQueryImpl) AddUpdateId(value table_types.Value) { + updateParamName := fmt.Sprintf("$id_%d", d.index) + vp := table.ValueParam(updateParamName, value) + d.updateParam = &vp +} + func (d *WriteSingleTableQueryImpl) GetParamNames() []string { res := make([]string, len(d.tableQueryParams)) for i, p := range d.tableQueryParams { @@ -73,12 +80,14 @@ func BuildCreateOperationQuery(operation types.Operation, index int) WriteSingle ) d.AddValueParam( "$created_at", - table_types.StringValueFromString(""), //TODO + table_types.TimestampValueFromTime(tb.CreatedAt), ) d.AddValueParam( "$operation_id", table_types.StringValueFromString(tb.YdbOperationId), ) + } else { + panic("Implement me") } return d @@ -89,7 +98,7 @@ func BuildUpdateOperationQuery(operation types.Operation, index int) WriteSingle index: index, tableName: "Operations", } - d.AddValueParam("$id", table_types.UUIDValue(operation.GetId())) + d.AddUpdateId(table_types.UUIDValue(operation.GetId())) d.AddValueParam( "$status", table_types.StringValueFromString(operation.GetState().String()), ) @@ -105,7 +114,7 @@ func BuildUpdateBackupQuery(backup types.Backup, index int) WriteSingleTableQuer index: index, tableName: "Backups", } - d.AddValueParam("$id", table_types.UUIDValue(backup.ID)) + d.AddUpdateId(table_types.UUIDValue(backup.ID)) d.AddValueParam("$status", table_types.StringValueFromString(backup.Status)) return d } @@ -160,35 +169,83 @@ func (d *WriteTableQueryImpl) WithCreateOperation(operation types.Operation) Wri } func (d *WriteSingleTableQueryImpl) DeclareParameters() string { - declares := make([]string, len(d.tableQueryParams)) - for i, param := range d.tableQueryParams { - declares[i] = fmt.Sprintf("DECLARE %s AS %s", param.Name(), param.Value().Type().String()) + declares := make([]string, 0) + if d.updateParam != nil { + declares = append( + declares, + fmt.Sprintf("DECLARE %s AS %s", (*d.updateParam).Name(), (*d.updateParam).Value().Type().String()), + ) + } + for _, param := range d.tableQueryParams { + declares = append( + declares, fmt.Sprintf("DECLARE %s AS %s", param.Name(), param.Value().Type().String()), + ) } return strings.Join(declares, ";\n") } +func ProcessUpsertQuery( + queryStrings *[]string, allParams *[]table.ParameterOption, t *WriteSingleTableQueryImpl, +) error { + if len(t.upsertFields) == 0 { + return errors.New("No fields to upsert") + } + if len(t.tableName) == 0 { + return errors.New("No table") + } + declares := t.DeclareParameters() + *queryStrings = append( + *queryStrings, fmt.Sprintf( + "%s;\nUPSERT INTO %s (%s) VALUES (%s)", declares, t.tableName, strings.Join(t.upsertFields, ", "), + strings.Join(t.GetParamNames(), ", "), + ), + ) + for _, p := range t.tableQueryParams { + *allParams = append(*allParams, p) + } + return nil +} + +func ProcessUpdateQuery( + queryStrings *[]string, allParams *[]table.ParameterOption, t *WriteSingleTableQueryImpl, +) error { + if len(t.upsertFields) == 0 { + return errors.New("No fields to upsert") + } + if len(t.tableName) == 0 { + return errors.New("No table") + } + declares := t.DeclareParameters() + paramNames := t.GetParamNames() + keyParam := fmt.Sprintf("id = %s", (*t.updateParam).Name()) + updates := make([]string, 0) + for i := range t.upsertFields { + updates = append(updates, fmt.Sprintf("%s = %s", t.upsertFields[i], paramNames[i])) + } + *queryStrings = append( + *queryStrings, fmt.Sprintf( + "%s;\nUPDATE %s SET %s WHERE %s", declares, t.tableName, strings.Join(updates, ", "), keyParam, + ), + ) + *allParams = append(*allParams, *t.updateParam) + for _, p := range t.tableQueryParams { + *allParams = append(*allParams, p) + } + return nil +} + func (d *WriteTableQueryImpl) FormatQuery(ctx context.Context) (*FormatQueryResult, error) { - queryStrings := make([]string, len(d.tableQueries)) + queryStrings := make([]string, 0) allParams := make([]table.ParameterOption, 0) - for i, t := range d.tableQueries { - if len(t.upsertFields) == 0 { - return nil, errors.New("No fields to upsert") - } - if len(t.tableName) == 0 { - return nil, errors.New("No table") + for _, t := range d.tableQueries { + var err error + if t.updateParam == nil { + err = ProcessUpsertQuery(&queryStrings, &allParams, &t) + } else { + err = ProcessUpdateQuery(&queryStrings, &allParams, &t) } - declares := t.DeclareParameters() - paramNames := t.GetParamNames() - keyParam := fmt.Sprintf("%s = %s", t.upsertFields[0], paramNames[0]) - updates := make([]string, 0) - for j := 1; j < len(t.upsertFields); j++ { - updates = append(updates, fmt.Sprintf("%s = %s", t.upsertFields[j], paramNames[j])) - } - queryStrings[i] = fmt.Sprintf( - "%s;\nUPDATE %s SET %s WHERE %s", declares, t.tableName, strings.Join(updates, ", "), keyParam, - ) - for _, p := range t.tableQueryParams { - allParams = append(allParams, p) + if err != nil { + return nil, err } } res := strings.Join(queryStrings, ";\n") diff --git a/internal/connectors/db/yql/queries/write_test.go b/internal/connectors/db/yql/queries/write_test.go index 95b4b7c0..2c04b200 100644 --- a/internal/connectors/db/yql/queries/write_test.go +++ b/internal/connectors/db/yql/queries/write_test.go @@ -6,10 +6,11 @@ import ( "github.com/ydb-platform/ydb-go-sdk/v3/table" table_types "github.com/ydb-platform/ydb-go-sdk/v3/table/types" "testing" + "time" "ydbcp/internal/types" ) -func TestQueryBuilder_Write(t *testing.T) { +func TestQueryBuilder_UpdateUpdate(t *testing.T) { const ( queryString = `DECLARE $id_0 AS Uuid; DECLARE $status_0 AS String; @@ -50,3 +51,189 @@ UPDATE Operations SET status = $status_1, message = $message_1 WHERE id = $id_1` ) assert.Equal(t, queryParams, query.QueryParams, "bad query params") } + +func TestQueryBuilder_CreateCreate(t *testing.T) { + const ( + queryString = `DECLARE $id_0 AS Uuid; +DECLARE $container_id_0 AS String; +DECLARE $database_0 AS String; +DECLARE $initiated_0 AS String; +DECLARE $s3_endpoint_0 AS String; +DECLARE $s3_region_0 AS String; +DECLARE $s3_bucket_0 AS String; +DECLARE $s3_path_prefix_0 AS String; +DECLARE $status_0 AS String; +UPSERT INTO Backups (id, container_id, database, initiated, s3_endpoint, s3_region, s3_bucket, s3_path_prefix, status) VALUES ($id_0, $container_id_0, $database_0, $initiated_0, $s3_endpoint_0, $s3_region_0, $s3_bucket_0, $s3_path_prefix_0, $status_0); +DECLARE $id_1 AS Uuid; +DECLARE $type_1 AS String; +DECLARE $status_1 AS String; +DECLARE $container_id_1 AS String; +DECLARE $database_1 AS String; +DECLARE $backup_id_1 AS Uuid; +DECLARE $initiated_1 AS String; +DECLARE $created_at_1 AS Timestamp; +DECLARE $operation_id_1 AS String; +UPSERT INTO Operations (id, type, status, container_id, database, backup_id, initiated, created_at, operation_id) VALUES ($id_1, $type_1, $status_1, $container_id_1, $database_1, $backup_id_1, $initiated_1, $created_at_1, $operation_id_1)` + ) + opId := types.GenerateObjectID() + backupId := types.GenerateObjectID() + tbOp := types.TakeBackupOperation{ + Id: opId, + ContainerID: "a", + BackupId: backupId, + State: "PENDING", + Message: "Message", + YdbConnectionParams: types.YdbConnectionParams{ + Endpoint: "", + DatabaseName: "dbname", + }, + YdbOperationId: "1234", + SourcePaths: nil, + SourcePathToExclude: nil, + CreatedAt: time.Unix(0, 0), + } + backup := types.Backup{ + ID: backupId, + ContainerID: "a", + DatabaseName: "b", + S3Endpoint: "c", + S3Region: "d", + S3Bucket: "e", + S3PathPrefix: "f", + Status: "Available", + } + builder := NewWriteTableQuery(). + WithCreateBackup(backup). + WithCreateOperation(&tbOp) + var ( + queryParams = table.NewQueryParameters( + table.ValueParam("$id_0", table_types.UUIDValue(backupId)), + table.ValueParam("$container_id_0", table_types.StringValueFromString("a")), + table.ValueParam("$database_0", table_types.StringValueFromString("b")), + table.ValueParam("$initiated_0", table_types.StringValueFromString("")), + table.ValueParam("$s3_endpoint_0", table_types.StringValueFromString("c")), + table.ValueParam("$s3_region_0", table_types.StringValueFromString("d")), + table.ValueParam("$s3_bucket_0", table_types.StringValueFromString("e")), + table.ValueParam("$s3_path_prefix_0", table_types.StringValueFromString("f")), + table.ValueParam("$status_0", table_types.StringValueFromString("Available")), + table.ValueParam("$id_1", table_types.UUIDValue(opId)), + table.ValueParam("$type_1", table_types.StringValueFromString("TB")), + table.ValueParam( + "$status_1", table_types.StringValueFromString(string(tbOp.State)), + ), + table.ValueParam( + "$container_id_1", table_types.StringValueFromString(tbOp.ContainerID), + ), + table.ValueParam( + "$database_1", + table_types.StringValueFromString(tbOp.YdbConnectionParams.DatabaseName), + ), + table.ValueParam( + "$backup_id_1", + table_types.UUIDValue(tbOp.BackupId), + ), + table.ValueParam( + "$initiated_1", + table_types.StringValueFromString(""), + ), + table.ValueParam( + "$created_at_1", + table_types.TimestampValueFromTime(tbOp.CreatedAt), + ), + table.ValueParam( + "$operation_id_1", + table_types.StringValueFromString(tbOp.YdbOperationId), + ), + ) + ) + query, err := builder.FormatQuery(context.Background()) + assert.Empty(t, err) + assert.Equal( + t, queryString, query.QueryText, + "bad query format", + ) + assert.Equal(t, queryParams, query.QueryParams, "bad query params") +} + +func TestQueryBuilder_UpdateCreate(t *testing.T) { + const ( + queryString = `DECLARE $id_0 AS Uuid; +DECLARE $status_0 AS String; +UPDATE Backups SET status = $status_0 WHERE id = $id_0; +DECLARE $id_1 AS Uuid; +DECLARE $type_1 AS String; +DECLARE $status_1 AS String; +DECLARE $container_id_1 AS String; +DECLARE $database_1 AS String; +DECLARE $backup_id_1 AS Uuid; +DECLARE $initiated_1 AS String; +DECLARE $created_at_1 AS Timestamp; +DECLARE $operation_id_1 AS String; +UPSERT INTO Operations (id, type, status, container_id, database, backup_id, initiated, created_at, operation_id) VALUES ($id_1, $type_1, $status_1, $container_id_1, $database_1, $backup_id_1, $initiated_1, $created_at_1, $operation_id_1)` + ) + opId := types.GenerateObjectID() + backupId := types.GenerateObjectID() + tbOp := types.TakeBackupOperation{ + Id: opId, + ContainerID: "a", + BackupId: backupId, + State: "PENDING", + Message: "Message", + YdbConnectionParams: types.YdbConnectionParams{ + Endpoint: "", + DatabaseName: "dbname", + }, + YdbOperationId: "1234", + SourcePaths: nil, + SourcePathToExclude: nil, + CreatedAt: time.Unix(0, 0), + } + backup := types.Backup{ + ID: backupId, + Status: "Available", + } + builder := NewWriteTableQuery(). + WithUpdateBackup(backup). + WithCreateOperation(&tbOp) + var ( + queryParams = table.NewQueryParameters( + table.ValueParam("$id_0", table_types.UUIDValue(backupId)), + table.ValueParam("$status_0", table_types.StringValueFromString("Available")), + table.ValueParam("$id_1", table_types.UUIDValue(opId)), + table.ValueParam("$type_1", table_types.StringValueFromString("TB")), + table.ValueParam( + "$status_1", table_types.StringValueFromString(string(tbOp.State)), + ), + table.ValueParam( + "$container_id_1", table_types.StringValueFromString(tbOp.ContainerID), + ), + table.ValueParam( + "$database_1", + table_types.StringValueFromString(tbOp.YdbConnectionParams.DatabaseName), + ), + table.ValueParam( + "$backup_id_1", + table_types.UUIDValue(tbOp.BackupId), + ), + table.ValueParam( + "$initiated_1", + table_types.StringValueFromString(""), + ), + table.ValueParam( + "$created_at_1", + table_types.TimestampValueFromTime(tbOp.CreatedAt), + ), + table.ValueParam( + "$operation_id_1", + table_types.StringValueFromString(tbOp.YdbOperationId), + ), + ) + ) + query, err := builder.FormatQuery(context.Background()) + assert.Empty(t, err) + assert.Equal( + t, queryString, query.QueryText, + "bad query format", + ) + assert.Equal(t, queryParams, query.QueryParams, "bad query params") +}