Skip to content

Commit

Permalink
feat: add transaction runner (#127)
Browse files Browse the repository at this point in the history
* feat: add transaction runner

* chore: disable internal retries when using runner
  • Loading branch information
olavloite authored Nov 25, 2024
1 parent 49ad3b0 commit b446423
Show file tree
Hide file tree
Showing 3 changed files with 194 additions and 0 deletions.
64 changes: 64 additions & 0 deletions integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,13 @@ package gorm

import (
"context"
"database/sql"
"fmt"
"log"
"os"
"reflect"
"regexp"
"sync"
"testing"
"time"

Expand Down Expand Up @@ -350,3 +353,64 @@ func TestIntegration_CommitTimestamp(t *testing.T) {
t.Fatalf("missing commit timestamp for singer")
}
}

func TestIntegration_RunTransaction(t *testing.T) {
skipIfShort(t)
t.Parallel()

ctx := context.Background()
dsn, cleanup, err := testutil.CreateTestDB(ctx)
if err != nil {
log.Fatalf("could not init integration tests while creating database: %v", err)
}
defer cleanup()
// Open db.
db, err := gorm.Open(New(Config{
DriverName: "spanner",
DSN: dsn,
}), &gorm.Config{PrepareStmt: true})
if err != nil {
log.Fatal(err)
}

type Number struct {
Id int64
Name string
}

if err := db.AutoMigrate(&Number{}); err != nil {
t.Fatalf("failed to migrate: %v", err)
}

numTransactions := 20
wg := &sync.WaitGroup{}
wg.Add(numTransactions)
var countBefore int
if err := db.Raw("select count(1) from numbers").Scan(&countBefore).Error; err != nil {
t.Fatal(err)
}
for i := 0; i < numTransactions; i++ {
go func() {
defer wg.Done()
_ = RunTransaction(ctx, db, func(tx *gorm.DB) error {
var max int64
if err := tx.Raw("select coalesce(max(id), 0) from numbers").Scan(&max).Error; err != nil {
return err
}
number := Number{Id: max + 1, Name: fmt.Sprintf("Number: %d", max+1)}
if err := tx.Create(&number).Error; err != nil {
return err
}
return nil
}, &sql.TxOptions{})
}()
}
wg.Wait()
var countAfter int
if err := db.Raw("select count(1) from numbers").Scan(&countAfter).Error; err != nil {
t.Fatal(err)
}
if g, w := countAfter-countBefore, numTransactions; g != w {
t.Fatalf("number count mismatch:\n Got: %v\nWant: %v", g, w)
}
}
60 changes: 60 additions & 0 deletions retry.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
// Copyright 2024 Google LLC.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package gorm

import (
"context"
"database/sql"
"math/rand"
"time"

"cloud.google.com/go/spanner"
"github.com/googleapis/gax-go/v2"
spannerdriver "github.com/googleapis/go-sql-spanner"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"gorm.io/gorm"
)

// RunTransaction executes a transaction on Spanner using the given
// gorm database, and retries the transaction if it is aborted by Spanner.
func RunTransaction(ctx context.Context, db *gorm.DB, fc func(tx *gorm.DB) error, opts ...*sql.TxOptions) error {
// Disable internal (checksum-based) retries on the Spanner database/SQL connection.
var opt *sql.TxOptions
// Note: gorm also only uses the first option, so it is safe to pick just the first element in the slice.
if len(opts) > 0 {
opt = opts[0]
}
opt.Isolation = spannerdriver.WithDisableRetryAborts(opt.Isolation)
for {
err := db.Transaction(fc, opt)
if err == nil {
return nil
}
s, ok := status.FromError(err)
if !ok || s.Code() != codes.Aborted {
return err
}
delay, ok := spanner.ExtractRetryDelay(err)
if !ok {
// Use a random backoff time if no backoff time was included in the error.
r := rand.New(rand.NewSource(time.Now().UnixNano()))
delay = time.Duration(r.Intn(20)) * time.Millisecond
}
if err := gax.Sleep(ctx, delay); err != nil {
return err
}
}
}
70 changes: 70 additions & 0 deletions spanner_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,17 @@
package gorm

import (
"context"
"database/sql"
"reflect"
"strconv"
"testing"

"cloud.google.com/go/spanner/apiv1/spannerpb"
structpb "github.com/golang/protobuf/ptypes/struct"
"github.com/googleapis/go-sql-spanner/testutil"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"gorm.io/gorm"
"gorm.io/gorm/clause"
)
Expand Down Expand Up @@ -223,6 +227,72 @@ func TestAutoSaveAssociations(t *testing.T) {
}
}

func TestRunTransaction(t *testing.T) {
t.Parallel()

ctx := context.Background()
db, server, teardown := setupTestGormConnection(t)
defer teardown()

s := singerWithCommitTimestamp{
FirstName: "First",
LastName: "Last",
}
insertSql := "INSERT INTO `singers` (`first_name`,`last_name`,`last_updated`,`rating`) VALUES (@p1,@p2,PENDING_COMMIT_TIMESTAMP(),@p3) THEN RETURN `id`"
_ = putSingerResult(server, insertSql, s)
if err := RunTransaction(ctx, db, func(tx *gorm.DB) error {
if err := tx.Create(&s).Error; err != nil {
return err
}
return nil
}, &sql.TxOptions{}); err != nil {
t.Fatal(err)
}
// Verify that the insert was only executed once.
reqs := drainRequestsFromServer(server.TestSpanner)
execReqs := requestsOfType(reqs, reflect.TypeOf(&spannerpb.ExecuteSqlRequest{}))
insertReqs := filter(execReqs, insertSql)
if g, w := len(insertReqs), 1; g != w {
t.Fatalf("num requests mismatch\n Got: %v\nWant: %v", g, w)
}

// Run the same transaction again, but now we simulate that Spanner aborted the transaction.
server.TestSpanner.PutExecutionTime(testutil.MethodCommitTransaction, testutil.SimulatedExecutionTime{
Errors: []error{status.Error(codes.Aborted, "Aborted")},
})
attempts := 0
if err := RunTransaction(ctx, db, func(tx *gorm.DB) error {
attempts++
if err := tx.Create(&s).Error; err != nil {
return err
}
return nil
}, &sql.TxOptions{}); err != nil {
t.Fatal(err)
}
// Now verify that the insert was executed twice and that the function was called twice.
if g, w := attempts, 2; g != w {
t.Fatalf("attempts mismatch\n Got: %v\nWant: %v", g, w)
}
reqs = drainRequestsFromServer(server.TestSpanner)
execReqs = requestsOfType(reqs, reflect.TypeOf(&spannerpb.ExecuteSqlRequest{}))
insertReqs = filter(execReqs, insertSql)
if g, w := len(insertReqs), 2; g != w {
t.Fatalf("num requests mismatch\n Got: %v\nWant: %v", g, w)
}
}

func filter(requests []interface{}, sql string) (ret []*spannerpb.ExecuteSqlRequest) {
for _, i := range requests {
if req, ok := i.(*spannerpb.ExecuteSqlRequest); ok {
if req.Sql == sql {
ret = append(ret, req)
}
}
}
return ret
}

func getLastSql(server *testutil.MockedSpannerInMemTestServer) string {
return getLastSqlRequest(server).Sql
}
Expand Down

0 comments on commit b446423

Please sign in to comment.