Skip to content

Commit

Permalink
Use row-level locking to handle database operations (#118)
Browse files Browse the repository at this point in the history
* pglock: remove unnecessary transaction to execute heartbeats

* pglock: remove unnecessary transaction to execute releases

* pglock: remove unnecessary transaction to execute lock acquisitions
  • Loading branch information
ucirello authored Oct 27, 2024
1 parent 1b6d0d1 commit a79a1e1
Show file tree
Hide file tree
Showing 3 changed files with 62 additions and 163 deletions.
112 changes: 52 additions & 60 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -223,51 +223,51 @@ func (c *Client) storeAcquire(ctx context.Context, l *Lock) error {
return typedError(err, "cannot run query to read record version number")
}

tx, err := c.db.BeginTx(ctx, &sql.TxOptions{Isolation: sql.LevelSerializable})
if err != nil {
return typedError(err, "cannot create transaction for lock acquisition")
}
c.log.Debug("storeAcquire in: %v %v %v %v", l.name, rvn, l.data, l.recordVersionNumber)
defer func() {
c.log.Debug("storeAcquire out: %v %v %v %v", l.name, rvn, l.data, l.recordVersionNumber)
}()
_, err = tx.ExecContext(ctx, `
rowLockInfo := c.db.QueryRowContext(ctx, `
INSERT INTO `+c.tableName+`
("name", "record_version_number", "data", "owner")
VALUES
($1, $2, $3, $6)
ON CONFLICT ("name") DO UPDATE
SET
"record_version_number" = $2,
"record_version_number" = CASE
WHEN COALESCE(`+c.tableName+`."record_version_number" = $4, TRUE) THEN $2
ELSE `+c.tableName+`."record_version_number"
END,
"data" = CASE
WHEN $5 THEN $3
WHEN COALESCE(`+c.tableName+`."record_version_number" = $4, TRUE) THEN
CASE
WHEN $5 THEN $3
ELSE `+c.tableName+`."data"
END
ELSE `+c.tableName+`."data"
END,
"owner" = $6
WHERE
`+c.tableName+`."record_version_number" IS NULL
OR `+c.tableName+`."record_version_number" = $4
"owner" = CASE
WHEN COALESCE(`+c.tableName+`."record_version_number" = $4, TRUE) THEN $6
ELSE `+c.tableName+`."owner"
END
RETURNING
"record_version_number", "data", "owner"
`, l.name, rvn, l.data, l.recordVersionNumber, l.replaceData, c.owner)
if err != nil {
return typedError(err, "cannot run query to acquire lock")
}
rowLockInfo := tx.QueryRowContext(ctx, `SELECT "record_version_number", "data", "owner" FROM `+c.tableName+` WHERE name = $1 FOR UPDATE`, l.name)
var actualRVN int64
var data []byte
var actualOwner string
if err := rowLockInfo.Scan(&actualRVN, &data, &actualOwner); err != nil {
var (
actualRVN int64
actualData []byte
actualOwner string
)
if err := rowLockInfo.Scan(&actualRVN, &actualData, &actualOwner); err != nil && !errors.Is(err, sql.ErrNoRows) {
return typedError(err, "cannot load information for lock acquisition")
}
l.owner = actualOwner
if actualRVN != rvn {
l.recordVersionNumber = actualRVN
return ErrNotAcquired
}
if err := tx.Commit(); err != nil {
return typedError(err, "cannot commit lock acquisition")
}
l.recordVersionNumber = rvn
l.data = data
l.data = actualData
return nil
}

Expand Down Expand Up @@ -312,21 +312,34 @@ func (c *Client) storeRelease(ctx context.Context, l *Lock) error {
defer l.mu.Unlock()
ctx, cancel := context.WithTimeout(ctx, l.leaseDuration)
defer cancel()
tx, err := c.db.BeginTx(ctx, &sql.TxOptions{Isolation: sql.LevelSerializable})
if err != nil {
return typedError(err, "cannot create transaction for lock acquisition")
}
result, err := tx.ExecContext(ctx, `
UPDATE
`+c.tableName+`
SET
"record_version_number" = NULL
WHERE
"name" = $1
AND "record_version_number" = $2
`, l.name, l.recordVersionNumber)
if err != nil {
return typedError(err, "cannot run query to release lock")
var result sql.Result
switch l.keepOnRelease {
case true:
res, err := c.db.ExecContext(ctx, `
UPDATE
`+c.tableName+`
SET
"record_version_number" = NULL
WHERE
"name" = $1
AND "record_version_number" = $2
`, l.name, l.recordVersionNumber)
if err != nil {
return typedError(err, "cannot run query to release lock (keep)")
}
result = res
case false:
res, err := c.db.ExecContext(ctx, `
DELETE FROM
`+c.tableName+`
WHERE
"name" = $1
AND "record_version_number" = $2
`, l.name, l.recordVersionNumber)
if err != nil {
return typedError(err, "cannot run query to delete lock (delete)")
}
result = res
}
affected, err := result.RowsAffected()
if err != nil {
Expand All @@ -335,20 +348,6 @@ func (c *Client) storeRelease(ctx context.Context, l *Lock) error {
l.isReleased = true
return ErrLockAlreadyReleased
}
if !l.keepOnRelease {
_, err := tx.ExecContext(ctx, `
DELETE FROM
`+c.tableName+`
WHERE
"name" = $1
AND "record_version_number" IS NULL`, l.name)
if err != nil {
return typedError(err, "cannot run query to delete lock")
}
}
if err := tx.Commit(); err != nil {
return typedError(err, "cannot commit lock release")
}
l.isReleased = true
l.heartbeatCancel()
return nil
Expand Down Expand Up @@ -392,11 +391,7 @@ func (c *Client) storeHeartbeat(ctx context.Context, l *Lock) error {
if err != nil {
return typedError(err, "cannot run query to read record version number")
}
tx, err := c.db.BeginTx(ctx, &sql.TxOptions{Isolation: sql.LevelSerializable})
if err != nil {
return typedError(err, "cannot create transaction for lock acquisition")
}
result, err := tx.ExecContext(ctx, `
result, err := c.db.ExecContext(ctx, `
UPDATE
`+c.tableName+`
SET
Expand All @@ -414,9 +409,6 @@ func (c *Client) storeHeartbeat(ctx context.Context, l *Lock) error {
} else if affected == 0 {
return ErrLockAlreadyReleased
}
if err := tx.Commit(); err != nil {
return typedError(err, "cannot commit lock heartbeat")
}
l.recordVersionNumber = rvn
return nil
}
Expand Down
109 changes: 8 additions & 101 deletions client_internal_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -120,15 +120,6 @@ func TestDBErrorHandling(t *testing.T) {
}
}
t.Run("acquire", func(t *testing.T) {
t.Run("bad tx", func(t *testing.T) {
client, mock, _ := setup()
badTx := errors.New("transaction begin error")
mock.ExpectQuery(`SELECT nextval\('locks_rvn'\)`).WillReturnRows(sqlmock.NewRows([]string{"nextval"}).AddRow(1))
mock.ExpectBegin().WillReturnError(badTx)
if _, err := client.Acquire("bad-tx"); !errors.Is(err, badTx) {
t.Errorf("expected tx error missing: %v", err)
}
})
t.Run("bad rvn", func(t *testing.T) {
client, mock, _ := setup()
badRVN := errors.New("cannot load next RVN")
Expand All @@ -141,103 +132,32 @@ func TestDBErrorHandling(t *testing.T) {
client, mock, _ := setup()
badInsert := errors.New("cannot insert")
mock.ExpectQuery(`SELECT nextval\('locks_rvn'\)`).WillReturnRows(sqlmock.NewRows([]string{"nextval"}).AddRow(1))
mock.ExpectBegin()
mock.ExpectExec(`INSERT INTO locks (.+)`).WithArgs(sqlmock.AnyArg(), sqlmock.AnyArg(), sqlmock.AnyArg(), sqlmock.AnyArg(), sqlmock.AnyArg(), sqlmock.AnyArg()).WillReturnError(badInsert)
mock.ExpectQuery(`INSERT INTO locks (.+)`).WithArgs(sqlmock.AnyArg(), sqlmock.AnyArg(), sqlmock.AnyArg(), sqlmock.AnyArg(), sqlmock.AnyArg(), sqlmock.AnyArg()).WillReturnError(badInsert)
if _, err := client.Acquire("bad-insert"); !errors.Is(err, badInsert) {
t.Errorf("expected RVN error missing: %v", err)
}
})
t.Run("bad RVN confirmation", func(t *testing.T) {
client, mock, _ := setup()
badRVN := errors.New("cannot confirm RVN")
mock.ExpectQuery(`SELECT nextval\('locks_rvn'\)`).WillReturnRows(sqlmock.NewRows([]string{"nextval"}).AddRow(1))
mock.ExpectBegin()
mock.ExpectExec(`INSERT INTO locks (.+)`).WithArgs(sqlmock.AnyArg(), sqlmock.AnyArg(), sqlmock.AnyArg(), sqlmock.AnyArg(), sqlmock.AnyArg(), sqlmock.AnyArg()).WillReturnResult(sqlmock.NewResult(0, 1))
mock.ExpectQuery(`SELECT "record_version_number", "data", "owner" FROM locks WHERE name = (.+)`).WithArgs(sqlmock.AnyArg()).WillReturnError(badRVN)
if _, err := client.Acquire("bad-insert"); !errors.Is(err, badRVN) {
t.Errorf("expected RVN confirmation error missing: %v", err)
}
})
t.Run("bad commit", func(t *testing.T) {
client, mock, _ := setup()
badCommit := errors.New("cannot confirm RVN")
mock.ExpectQuery(`SELECT nextval\('locks_rvn'\)`).WillReturnRows(sqlmock.NewRows([]string{"nextval"}).AddRow(1))
mock.ExpectBegin()
mock.ExpectExec(`INSERT INTO locks (.+)`).WithArgs(sqlmock.AnyArg(), sqlmock.AnyArg(), sqlmock.AnyArg(), sqlmock.AnyArg(), sqlmock.AnyArg(), sqlmock.AnyArg()).WillReturnResult(sqlmock.NewResult(0, 1))
mock.ExpectQuery(`SELECT "record_version_number", "data", "owner" FROM locks WHERE name = (.+)`).
WithArgs(sqlmock.AnyArg()).
WillReturnRows(
sqlmock.NewRows([]string{
"record_version_number",
"data",
"owner",
}).AddRow(1, []byte{}, "owner"),
)
mock.ExpectCommit().WillReturnError(badCommit)
if _, err := client.Acquire("bad-insert"); !errors.Is(err, badCommit) {
t.Errorf("expected commit error missing: %v", err)
}
})
})
t.Run("release", func(t *testing.T) {
t.Run("bad tx", func(t *testing.T) {
client, mock, fakeLock := setup()
badTx := errors.New("transaction begin error")
mock.ExpectBegin().WillReturnError(badTx)
if err := client.Release(fakeLock); !errors.Is(err, badTx) {
t.Errorf("expected tx error missing: %v", err)
}
})
t.Run("bad update", func(t *testing.T) {
client, mock, fakeLock := setup()
badUpdate := errors.New("cannot update")
mock.ExpectBegin()
mock.ExpectExec(`UPDATE locks (.+)`).WithArgs(sqlmock.AnyArg(), sqlmock.AnyArg()).WillReturnError(badUpdate)
if err := client.Release(fakeLock); !errors.Is(err, badUpdate) {
t.Errorf("expected update error missing: %v", err)
}
})
t.Run("bad update result", func(t *testing.T) {
client, mock, fakeLock := setup()
badUpdateResult := errors.New("cannot grab update result")
mock.ExpectBegin()
mock.ExpectExec(`UPDATE locks (.+)`).WithArgs(sqlmock.AnyArg(), sqlmock.AnyArg()).WillReturnResult(sqlmock.NewErrorResult(badUpdateResult))
if err := client.Release(fakeLock); !errors.Is(err, badUpdateResult) {
t.Errorf("expected update result error missing: %v", err)
badDelete := errors.New("cannot delete lock entry")
fakeLock.keepOnRelease = true
mock.ExpectExec(`UPDATE locks (.+)`).WithArgs(sqlmock.AnyArg(), sqlmock.AnyArg()).WillReturnError(badDelete)
if err := client.Release(fakeLock); !errors.Is(err, badDelete) {
t.Errorf("expected delete error missing: %v", err)
}
})
t.Run("bad delete", func(t *testing.T) {
client, mock, fakeLock := setup()
badDelete := errors.New("cannot delete lock entry")
mock.ExpectBegin()
mock.ExpectExec(`UPDATE locks (.+)`).WithArgs(sqlmock.AnyArg(), sqlmock.AnyArg()).WillReturnResult(sqlmock.NewResult(0, 1))
mock.ExpectExec(`DELETE FROM locks (.+)`).WithArgs(sqlmock.AnyArg()).WillReturnError(badDelete)
mock.ExpectExec(`DELETE FROM locks (.+)`).WithArgs(sqlmock.AnyArg(), sqlmock.AnyArg()).WillReturnError(badDelete)
if err := client.Release(fakeLock); !errors.Is(err, badDelete) {
t.Errorf("expected delete error missing: %v", err)
}
})
t.Run("bad commit", func(t *testing.T) {
client, mock, fakeLock := setup()
badCommit := errors.New("cannot commit release")
mock.ExpectBegin()
mock.ExpectExec(`UPDATE locks (.+)`).WithArgs(sqlmock.AnyArg(), sqlmock.AnyArg()).WillReturnResult(sqlmock.NewResult(0, 1))
mock.ExpectExec(`DELETE FROM locks (.+)`).WithArgs(sqlmock.AnyArg()).WillReturnResult(sqlmock.NewResult(0, 1))
mock.ExpectCommit().WillReturnError(badCommit)
if err := client.Release(fakeLock); !errors.Is(err, badCommit) {
t.Errorf("expected commit error missing: %v", err)
}
})
})
t.Run("heartbeat", func(t *testing.T) {
t.Run("bad tx", func(t *testing.T) {
client, mock, fakeLock := setup()
badTx := errors.New("transaction begin error")
mock.ExpectQuery(`SELECT nextval\('locks_rvn'\)`).WillReturnRows(sqlmock.NewRows([]string{"nextval"}).AddRow(1))
mock.ExpectBegin().WillReturnError(badTx)
if err := client.SendHeartbeat(context.Background(), fakeLock); !errors.Is(err, badTx) {
t.Errorf("expected tx error missing: %v", err)
}
})
t.Run("bad rvn", func(t *testing.T) {
client, mock, fakeLock := setup()
badRVN := errors.New("cannot load next RVN")
Expand All @@ -246,11 +166,10 @@ func TestDBErrorHandling(t *testing.T) {
t.Errorf("expected RVN error missing: %v", err)
}
})
t.Run("bad insert", func(t *testing.T) {
t.Run("bad update", func(t *testing.T) {
client, mock, fakeLock := setup()
badUpdate := errors.New("cannot insert")
mock.ExpectQuery(`SELECT nextval\('locks_rvn'\)`).WillReturnRows(sqlmock.NewRows([]string{"nextval"}).AddRow(1))
mock.ExpectBegin()
mock.ExpectExec(`UPDATE locks (.+)`).WithArgs(sqlmock.AnyArg(), sqlmock.AnyArg(), sqlmock.AnyArg()).WillReturnError(badUpdate)
if err := client.SendHeartbeat(context.Background(), fakeLock); !errors.Is(err, badUpdate) {
t.Errorf("expected RVN error missing: %v", err)
Expand All @@ -260,23 +179,11 @@ func TestDBErrorHandling(t *testing.T) {
client, mock, fakeLock := setup()
badRVN := errors.New("cannot confirm RVN")
mock.ExpectQuery(`SELECT nextval\('locks_rvn'\)`).WillReturnRows(sqlmock.NewRows([]string{"nextval"}).AddRow(1))
mock.ExpectBegin()
mock.ExpectExec(`UPDATE locks (.+)`).WithArgs(sqlmock.AnyArg(), sqlmock.AnyArg(), sqlmock.AnyArg()).WillReturnResult(sqlmock.NewErrorResult(badRVN))
if err := client.SendHeartbeat(context.Background(), fakeLock); !errors.Is(err, badRVN) {
t.Errorf("expected RVN confirmation error missing: %v", err)
}
})
t.Run("bad commit", func(t *testing.T) {
client, mock, fakeLock := setup()
badCommit := errors.New("cannot confirm RVN")
mock.ExpectQuery(`SELECT nextval\('locks_rvn'\)`).WillReturnRows(sqlmock.NewRows([]string{"nextval"}).AddRow(1))
mock.ExpectBegin()
mock.ExpectExec(`UPDATE locks (.+)`).WithArgs(sqlmock.AnyArg(), sqlmock.AnyArg(), sqlmock.AnyArg()).WillReturnResult(sqlmock.NewResult(0, 1))
mock.ExpectCommit().WillReturnError(badCommit)
if err := client.SendHeartbeat(context.Background(), fakeLock); !errors.Is(err, badCommit) {
t.Errorf("expected commit error missing: %v", err)
}
})
})

t.Run("GetAllLocks", func(t *testing.T) {
Expand Down
4 changes: 2 additions & 2 deletions client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -350,14 +350,14 @@ func TestKeepOnRelease(t *testing.T) {
expected := []byte("42")
l, err := c.Acquire(name, pglock.KeepOnRelease(), pglock.WithData(expected))
if err != nil {
t.Fatal("unexpected error while acquiring lock:", err)
t.Fatal("unexpected error while acquiring lock (take 1):", err)
}
t.Log("lock acquired")
l.Close()

l2, err := c.Acquire(name)
if err != nil {
t.Fatal("unexpected error while acquiring lock:", err)
t.Fatal("unexpected error while acquiring lock (take 2):", err)
}
defer l2.Close()
t.Log("lock reacquired")
Expand Down

0 comments on commit a79a1e1

Please sign in to comment.