From a79a1e1b2caa08df2c25c754885d129330f13820 Mon Sep 17 00:00:00 2001 From: U Cirello Date: Sun, 27 Oct 2024 10:34:40 -0700 Subject: [PATCH] Use row-level locking to handle database operations (#118) * pglock: remove unnecessary transaction to execute heartbeats * pglock: remove unnecessary transaction to execute releases * pglock: remove unnecessary transaction to execute lock acquisitions --- client.go | 112 +++++++++++++++++++--------------------- client_internal_test.go | 109 +++----------------------------------- client_test.go | 4 +- 3 files changed, 62 insertions(+), 163 deletions(-) diff --git a/client.go b/client.go index 2ad418c..84c7e8b 100644 --- a/client.go +++ b/client.go @@ -223,39 +223,42 @@ func (c *Client) storeAcquire(ctx context.Context, l *Lock) error { return typedError(err, "cannot run query to read record version number") } - tx, err := c.db.BeginTx(ctx, &sql.TxOptions{Isolation: sql.LevelSerializable}) - if err != nil { - return typedError(err, "cannot create transaction for lock acquisition") - } c.log.Debug("storeAcquire in: %v %v %v %v", l.name, rvn, l.data, l.recordVersionNumber) defer func() { c.log.Debug("storeAcquire out: %v %v %v %v", l.name, rvn, l.data, l.recordVersionNumber) }() - _, err = tx.ExecContext(ctx, ` + rowLockInfo := c.db.QueryRowContext(ctx, ` INSERT INTO `+c.tableName+` ("name", "record_version_number", "data", "owner") VALUES ($1, $2, $3, $6) ON CONFLICT ("name") DO UPDATE SET - "record_version_number" = $2, + "record_version_number" = CASE + WHEN COALESCE(`+c.tableName+`."record_version_number" = $4, TRUE) THEN $2 + ELSE `+c.tableName+`."record_version_number" + END, "data" = CASE - WHEN $5 THEN $3 + WHEN COALESCE(`+c.tableName+`."record_version_number" = $4, TRUE) THEN + CASE + WHEN $5 THEN $3 + ELSE `+c.tableName+`."data" + END ELSE `+c.tableName+`."data" END, - "owner" = $6 - WHERE - `+c.tableName+`."record_version_number" IS NULL - OR `+c.tableName+`."record_version_number" = $4 + "owner" = CASE + WHEN COALESCE(`+c.tableName+`."record_version_number" = $4, TRUE) THEN $6 + ELSE `+c.tableName+`."owner" + END + RETURNING + "record_version_number", "data", "owner" `, l.name, rvn, l.data, l.recordVersionNumber, l.replaceData, c.owner) - if err != nil { - return typedError(err, "cannot run query to acquire lock") - } - rowLockInfo := tx.QueryRowContext(ctx, `SELECT "record_version_number", "data", "owner" FROM `+c.tableName+` WHERE name = $1 FOR UPDATE`, l.name) - var actualRVN int64 - var data []byte - var actualOwner string - if err := rowLockInfo.Scan(&actualRVN, &data, &actualOwner); err != nil { + var ( + actualRVN int64 + actualData []byte + actualOwner string + ) + if err := rowLockInfo.Scan(&actualRVN, &actualData, &actualOwner); err != nil && !errors.Is(err, sql.ErrNoRows) { return typedError(err, "cannot load information for lock acquisition") } l.owner = actualOwner @@ -263,11 +266,8 @@ func (c *Client) storeAcquire(ctx context.Context, l *Lock) error { l.recordVersionNumber = actualRVN return ErrNotAcquired } - if err := tx.Commit(); err != nil { - return typedError(err, "cannot commit lock acquisition") - } l.recordVersionNumber = rvn - l.data = data + l.data = actualData return nil } @@ -312,21 +312,34 @@ func (c *Client) storeRelease(ctx context.Context, l *Lock) error { defer l.mu.Unlock() ctx, cancel := context.WithTimeout(ctx, l.leaseDuration) defer cancel() - tx, err := c.db.BeginTx(ctx, &sql.TxOptions{Isolation: sql.LevelSerializable}) - if err != nil { - return typedError(err, "cannot create transaction for lock acquisition") - } - result, err := tx.ExecContext(ctx, ` - UPDATE - `+c.tableName+` - SET - "record_version_number" = NULL - WHERE - "name" = $1 - AND "record_version_number" = $2 - `, l.name, l.recordVersionNumber) - if err != nil { - return typedError(err, "cannot run query to release lock") + var result sql.Result + switch l.keepOnRelease { + case true: + res, err := c.db.ExecContext(ctx, ` + UPDATE + `+c.tableName+` + SET + "record_version_number" = NULL + WHERE + "name" = $1 + AND "record_version_number" = $2 + `, l.name, l.recordVersionNumber) + if err != nil { + return typedError(err, "cannot run query to release lock (keep)") + } + result = res + case false: + res, err := c.db.ExecContext(ctx, ` + DELETE FROM + `+c.tableName+` + WHERE + "name" = $1 + AND "record_version_number" = $2 + `, l.name, l.recordVersionNumber) + if err != nil { + return typedError(err, "cannot run query to delete lock (delete)") + } + result = res } affected, err := result.RowsAffected() if err != nil { @@ -335,20 +348,6 @@ func (c *Client) storeRelease(ctx context.Context, l *Lock) error { l.isReleased = true return ErrLockAlreadyReleased } - if !l.keepOnRelease { - _, err := tx.ExecContext(ctx, ` - DELETE FROM - `+c.tableName+` - WHERE - "name" = $1 - AND "record_version_number" IS NULL`, l.name) - if err != nil { - return typedError(err, "cannot run query to delete lock") - } - } - if err := tx.Commit(); err != nil { - return typedError(err, "cannot commit lock release") - } l.isReleased = true l.heartbeatCancel() return nil @@ -392,11 +391,7 @@ func (c *Client) storeHeartbeat(ctx context.Context, l *Lock) error { if err != nil { return typedError(err, "cannot run query to read record version number") } - tx, err := c.db.BeginTx(ctx, &sql.TxOptions{Isolation: sql.LevelSerializable}) - if err != nil { - return typedError(err, "cannot create transaction for lock acquisition") - } - result, err := tx.ExecContext(ctx, ` + result, err := c.db.ExecContext(ctx, ` UPDATE `+c.tableName+` SET @@ -414,9 +409,6 @@ func (c *Client) storeHeartbeat(ctx context.Context, l *Lock) error { } else if affected == 0 { return ErrLockAlreadyReleased } - if err := tx.Commit(); err != nil { - return typedError(err, "cannot commit lock heartbeat") - } l.recordVersionNumber = rvn return nil } diff --git a/client_internal_test.go b/client_internal_test.go index 33b86c5..024e975 100644 --- a/client_internal_test.go +++ b/client_internal_test.go @@ -120,15 +120,6 @@ func TestDBErrorHandling(t *testing.T) { } } t.Run("acquire", func(t *testing.T) { - t.Run("bad tx", func(t *testing.T) { - client, mock, _ := setup() - badTx := errors.New("transaction begin error") - mock.ExpectQuery(`SELECT nextval\('locks_rvn'\)`).WillReturnRows(sqlmock.NewRows([]string{"nextval"}).AddRow(1)) - mock.ExpectBegin().WillReturnError(badTx) - if _, err := client.Acquire("bad-tx"); !errors.Is(err, badTx) { - t.Errorf("expected tx error missing: %v", err) - } - }) t.Run("bad rvn", func(t *testing.T) { client, mock, _ := setup() badRVN := errors.New("cannot load next RVN") @@ -141,103 +132,32 @@ func TestDBErrorHandling(t *testing.T) { client, mock, _ := setup() badInsert := errors.New("cannot insert") mock.ExpectQuery(`SELECT nextval\('locks_rvn'\)`).WillReturnRows(sqlmock.NewRows([]string{"nextval"}).AddRow(1)) - mock.ExpectBegin() - mock.ExpectExec(`INSERT INTO locks (.+)`).WithArgs(sqlmock.AnyArg(), sqlmock.AnyArg(), sqlmock.AnyArg(), sqlmock.AnyArg(), sqlmock.AnyArg(), sqlmock.AnyArg()).WillReturnError(badInsert) + mock.ExpectQuery(`INSERT INTO locks (.+)`).WithArgs(sqlmock.AnyArg(), sqlmock.AnyArg(), sqlmock.AnyArg(), sqlmock.AnyArg(), sqlmock.AnyArg(), sqlmock.AnyArg()).WillReturnError(badInsert) if _, err := client.Acquire("bad-insert"); !errors.Is(err, badInsert) { t.Errorf("expected RVN error missing: %v", err) } }) - t.Run("bad RVN confirmation", func(t *testing.T) { - client, mock, _ := setup() - badRVN := errors.New("cannot confirm RVN") - mock.ExpectQuery(`SELECT nextval\('locks_rvn'\)`).WillReturnRows(sqlmock.NewRows([]string{"nextval"}).AddRow(1)) - mock.ExpectBegin() - mock.ExpectExec(`INSERT INTO locks (.+)`).WithArgs(sqlmock.AnyArg(), sqlmock.AnyArg(), sqlmock.AnyArg(), sqlmock.AnyArg(), sqlmock.AnyArg(), sqlmock.AnyArg()).WillReturnResult(sqlmock.NewResult(0, 1)) - mock.ExpectQuery(`SELECT "record_version_number", "data", "owner" FROM locks WHERE name = (.+)`).WithArgs(sqlmock.AnyArg()).WillReturnError(badRVN) - if _, err := client.Acquire("bad-insert"); !errors.Is(err, badRVN) { - t.Errorf("expected RVN confirmation error missing: %v", err) - } - }) - t.Run("bad commit", func(t *testing.T) { - client, mock, _ := setup() - badCommit := errors.New("cannot confirm RVN") - mock.ExpectQuery(`SELECT nextval\('locks_rvn'\)`).WillReturnRows(sqlmock.NewRows([]string{"nextval"}).AddRow(1)) - mock.ExpectBegin() - mock.ExpectExec(`INSERT INTO locks (.+)`).WithArgs(sqlmock.AnyArg(), sqlmock.AnyArg(), sqlmock.AnyArg(), sqlmock.AnyArg(), sqlmock.AnyArg(), sqlmock.AnyArg()).WillReturnResult(sqlmock.NewResult(0, 1)) - mock.ExpectQuery(`SELECT "record_version_number", "data", "owner" FROM locks WHERE name = (.+)`). - WithArgs(sqlmock.AnyArg()). - WillReturnRows( - sqlmock.NewRows([]string{ - "record_version_number", - "data", - "owner", - }).AddRow(1, []byte{}, "owner"), - ) - mock.ExpectCommit().WillReturnError(badCommit) - if _, err := client.Acquire("bad-insert"); !errors.Is(err, badCommit) { - t.Errorf("expected commit error missing: %v", err) - } - }) }) t.Run("release", func(t *testing.T) { - t.Run("bad tx", func(t *testing.T) { - client, mock, fakeLock := setup() - badTx := errors.New("transaction begin error") - mock.ExpectBegin().WillReturnError(badTx) - if err := client.Release(fakeLock); !errors.Is(err, badTx) { - t.Errorf("expected tx error missing: %v", err) - } - }) t.Run("bad update", func(t *testing.T) { client, mock, fakeLock := setup() - badUpdate := errors.New("cannot update") - mock.ExpectBegin() - mock.ExpectExec(`UPDATE locks (.+)`).WithArgs(sqlmock.AnyArg(), sqlmock.AnyArg()).WillReturnError(badUpdate) - if err := client.Release(fakeLock); !errors.Is(err, badUpdate) { - t.Errorf("expected update error missing: %v", err) - } - }) - t.Run("bad update result", func(t *testing.T) { - client, mock, fakeLock := setup() - badUpdateResult := errors.New("cannot grab update result") - mock.ExpectBegin() - mock.ExpectExec(`UPDATE locks (.+)`).WithArgs(sqlmock.AnyArg(), sqlmock.AnyArg()).WillReturnResult(sqlmock.NewErrorResult(badUpdateResult)) - if err := client.Release(fakeLock); !errors.Is(err, badUpdateResult) { - t.Errorf("expected update result error missing: %v", err) + badDelete := errors.New("cannot delete lock entry") + fakeLock.keepOnRelease = true + mock.ExpectExec(`UPDATE locks (.+)`).WithArgs(sqlmock.AnyArg(), sqlmock.AnyArg()).WillReturnError(badDelete) + if err := client.Release(fakeLock); !errors.Is(err, badDelete) { + t.Errorf("expected delete error missing: %v", err) } }) t.Run("bad delete", func(t *testing.T) { client, mock, fakeLock := setup() badDelete := errors.New("cannot delete lock entry") - mock.ExpectBegin() - mock.ExpectExec(`UPDATE locks (.+)`).WithArgs(sqlmock.AnyArg(), sqlmock.AnyArg()).WillReturnResult(sqlmock.NewResult(0, 1)) - mock.ExpectExec(`DELETE FROM locks (.+)`).WithArgs(sqlmock.AnyArg()).WillReturnError(badDelete) + mock.ExpectExec(`DELETE FROM locks (.+)`).WithArgs(sqlmock.AnyArg(), sqlmock.AnyArg()).WillReturnError(badDelete) if err := client.Release(fakeLock); !errors.Is(err, badDelete) { t.Errorf("expected delete error missing: %v", err) } }) - t.Run("bad commit", func(t *testing.T) { - client, mock, fakeLock := setup() - badCommit := errors.New("cannot commit release") - mock.ExpectBegin() - mock.ExpectExec(`UPDATE locks (.+)`).WithArgs(sqlmock.AnyArg(), sqlmock.AnyArg()).WillReturnResult(sqlmock.NewResult(0, 1)) - mock.ExpectExec(`DELETE FROM locks (.+)`).WithArgs(sqlmock.AnyArg()).WillReturnResult(sqlmock.NewResult(0, 1)) - mock.ExpectCommit().WillReturnError(badCommit) - if err := client.Release(fakeLock); !errors.Is(err, badCommit) { - t.Errorf("expected commit error missing: %v", err) - } - }) }) t.Run("heartbeat", func(t *testing.T) { - t.Run("bad tx", func(t *testing.T) { - client, mock, fakeLock := setup() - badTx := errors.New("transaction begin error") - mock.ExpectQuery(`SELECT nextval\('locks_rvn'\)`).WillReturnRows(sqlmock.NewRows([]string{"nextval"}).AddRow(1)) - mock.ExpectBegin().WillReturnError(badTx) - if err := client.SendHeartbeat(context.Background(), fakeLock); !errors.Is(err, badTx) { - t.Errorf("expected tx error missing: %v", err) - } - }) t.Run("bad rvn", func(t *testing.T) { client, mock, fakeLock := setup() badRVN := errors.New("cannot load next RVN") @@ -246,11 +166,10 @@ func TestDBErrorHandling(t *testing.T) { t.Errorf("expected RVN error missing: %v", err) } }) - t.Run("bad insert", func(t *testing.T) { + t.Run("bad update", func(t *testing.T) { client, mock, fakeLock := setup() badUpdate := errors.New("cannot insert") mock.ExpectQuery(`SELECT nextval\('locks_rvn'\)`).WillReturnRows(sqlmock.NewRows([]string{"nextval"}).AddRow(1)) - mock.ExpectBegin() mock.ExpectExec(`UPDATE locks (.+)`).WithArgs(sqlmock.AnyArg(), sqlmock.AnyArg(), sqlmock.AnyArg()).WillReturnError(badUpdate) if err := client.SendHeartbeat(context.Background(), fakeLock); !errors.Is(err, badUpdate) { t.Errorf("expected RVN error missing: %v", err) @@ -260,23 +179,11 @@ func TestDBErrorHandling(t *testing.T) { client, mock, fakeLock := setup() badRVN := errors.New("cannot confirm RVN") mock.ExpectQuery(`SELECT nextval\('locks_rvn'\)`).WillReturnRows(sqlmock.NewRows([]string{"nextval"}).AddRow(1)) - mock.ExpectBegin() mock.ExpectExec(`UPDATE locks (.+)`).WithArgs(sqlmock.AnyArg(), sqlmock.AnyArg(), sqlmock.AnyArg()).WillReturnResult(sqlmock.NewErrorResult(badRVN)) if err := client.SendHeartbeat(context.Background(), fakeLock); !errors.Is(err, badRVN) { t.Errorf("expected RVN confirmation error missing: %v", err) } }) - t.Run("bad commit", func(t *testing.T) { - client, mock, fakeLock := setup() - badCommit := errors.New("cannot confirm RVN") - mock.ExpectQuery(`SELECT nextval\('locks_rvn'\)`).WillReturnRows(sqlmock.NewRows([]string{"nextval"}).AddRow(1)) - mock.ExpectBegin() - mock.ExpectExec(`UPDATE locks (.+)`).WithArgs(sqlmock.AnyArg(), sqlmock.AnyArg(), sqlmock.AnyArg()).WillReturnResult(sqlmock.NewResult(0, 1)) - mock.ExpectCommit().WillReturnError(badCommit) - if err := client.SendHeartbeat(context.Background(), fakeLock); !errors.Is(err, badCommit) { - t.Errorf("expected commit error missing: %v", err) - } - }) }) t.Run("GetAllLocks", func(t *testing.T) { diff --git a/client_test.go b/client_test.go index b349c51..049a350 100644 --- a/client_test.go +++ b/client_test.go @@ -350,14 +350,14 @@ func TestKeepOnRelease(t *testing.T) { expected := []byte("42") l, err := c.Acquire(name, pglock.KeepOnRelease(), pglock.WithData(expected)) if err != nil { - t.Fatal("unexpected error while acquiring lock:", err) + t.Fatal("unexpected error while acquiring lock (take 1):", err) } t.Log("lock acquired") l.Close() l2, err := c.Acquire(name) if err != nil { - t.Fatal("unexpected error while acquiring lock:", err) + t.Fatal("unexpected error while acquiring lock (take 2):", err) } defer l2.Close() t.Log("lock reacquired")