Skip to content

Commit

Permalink
Merge pull request juju#17009 from Aflynn50/reuse-prepare
Browse files Browse the repository at this point in the history
juju#17009

When using SQLair statements in domain, there is often little to no reuse of the prepared statements. The statements are created just before they are used then dropped when the method finishes. SQLair statements are designed to be prepared once then reused.

This PR adds a `Prepare` function to state base that can be used instead of `sqlair.Prepare`. This function saves the prepared statements in a map keyed by the query itself.

`StateBase.Prepare` can be used right next to where the query is run since it is relatively cheap. This allows readers of the code to quickly grok what is going on.

This code also updates to the latest version of SQLair which returns `ErrNoRows` when `GetAll` returns no rows (previously it just left the list it was passed empty) [SQLair PR here](canonical/sqlair#138).

Because of this checks are added to calls of `GetAll`.

<!-- Why this change is needed and what it does. -->

## Checklist

<!-- If an item is not applicable, use `~strikethrough~`. -->

- [x] Code style: imports ordered, good names, simple structure, etc
- [x] Comments saying why design decisions were made
- [x] Go unit tests, with comments saying what you're testing
- [x] [Integration tests](https://github.com/juju/juju/tree/main/tests), with comments saying what you're testing
- [x] [doc.go](https://discourse.charmhub.io/t/readme-in-packages/451) added or updated in changed packages

## QA steps
  • Loading branch information
jujubot authored Mar 22, 2024
2 parents e6fd53b + e7c5620 commit a01d3b4
Show file tree
Hide file tree
Showing 23 changed files with 292 additions and 170 deletions.
12 changes: 5 additions & 7 deletions domain/annotation/state/state.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ func (st *State) GetAnnotations(ctx context.Context, id annotations.ID) (map[str
return nil, errors.Trace(err)
}

getAnnotationsStmt, err := sqlair.Prepare(getAnnotationsQuery, Annotation{}, sqlair.M{})
getAnnotationsStmt, err := st.Prepare(getAnnotationsQuery, Annotation{}, sqlair.M{})
if err != nil {
return nil, errors.Annotatef(err, "preparing get annotations query for ID: %q", id.Name)
}
Expand All @@ -63,7 +63,6 @@ func (st *State) getAnnotationsForModel(ctx context.Context, id annotations.ID,
err = db.Txn(ctx, func(ctx context.Context, tx *sqlair.TX) error {
return tx.Query(ctx, getAnnotationsStmt).GetAll(&annotationsResults)
})

if err != nil {
if errors.Is(err, sqlair.ErrNoRows) {
// No errors, we return empty map if no annotation is found
Expand Down Expand Up @@ -95,7 +94,7 @@ func (st *State) getAnnotationsForID(ctx context.Context, id annotations.ID, get
if err != nil {
return nil, errors.Annotatef(err, "preparing get annotations query for ID: %q", id.Name)
}
kindQueryStmt, err := sqlair.Prepare(kindQuery, sqlair.M{})
kindQueryStmt, err := st.Prepare(kindQuery, sqlair.M{})
if err != nil {
return nil, errors.Annotatef(err, "preparing get annotations query for ID: %q", id.Name)
}
Expand Down Expand Up @@ -123,7 +122,6 @@ func (st *State) getAnnotationsForID(ctx context.Context, id annotations.ID, get
"uuid": uuid,
}).GetAll(&annotationsResults)
})

if err != nil {
if errors.Is(err, sqlair.ErrNoRows) {
// No errors, we return empty map if no annotation is found
Expand Down Expand Up @@ -170,11 +168,11 @@ func (st *State) SetAnnotations(ctx context.Context, id annotations.ID,
}

// Prepare sqlair statements
setAnnotationsStmt, err := sqlair.Prepare(setAnnotationsQuery, Annotation{}, sqlair.M{})
setAnnotationsStmt, err := st.Prepare(setAnnotationsQuery, Annotation{}, sqlair.M{})
if err != nil {
return errors.Annotatef(err, "preparing set annotations query for ID: %q", id.Name)
}
deleteAnnotationsStmt, err := sqlair.Prepare(deleteAnnotationsQuery, Annotation{}, sqlair.M{})
deleteAnnotationsStmt, err := st.Prepare(deleteAnnotationsQuery, Annotation{}, sqlair.M{})
if err != nil {
return errors.Annotatef(err, "preparing set annotations query for ID: %q", id.Name)
}
Expand Down Expand Up @@ -204,7 +202,7 @@ func (st *State) setAnnotationsForID(ctx context.Context, id annotations.ID,
if err != nil {
return errors.Annotatef(err, "preparing uuid retrieval query for ID: %q", id.Name)
}
kindQueryStmt, err := sqlair.Prepare(kindQuery, sqlair.M{})
kindQueryStmt, err := st.Prepare(kindQuery, sqlair.M{})
if err != nil {
return errors.Annotatef(err, "preparing uuid retrieval query for ID: %q", id.Name)
}
Expand Down
29 changes: 16 additions & 13 deletions domain/application/state/state.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ func (st *State) UpsertApplication(ctx context.Context, name string, units ...ap

appNameParam := sqlair.M{"name": name}
query := `SELECT &M.uuid FROM application WHERE name = $M.name`
queryStmt, err := sqlair.Prepare(query, sqlair.M{})
queryStmt, err := st.Prepare(query, sqlair.M{})
if err != nil {
return errors.Trace(err)
}
Expand All @@ -60,12 +60,12 @@ func (st *State) UpsertApplication(ctx context.Context, name string, units ...ap
INSERT INTO application (uuid, name, life_id)
VALUES ($M.application_uuid, $M.name, $M.life_id)
`
createApplicationStmt, err := sqlair.Prepare(createApplication, sqlair.M{})
createApplicationStmt, err := st.Prepare(createApplication, sqlair.M{})
if err != nil {
return errors.Trace(err)
}

upsertUnitFunc, err := upsertUnitFuncGetter()
upsertUnitFunc, err := st.upsertUnitFuncGetter()
if err != nil {
return errors.Trace(err)
}
Expand Down Expand Up @@ -117,19 +117,19 @@ func (st *State) DeleteApplication(ctx context.Context, name string) error {
appNameParam := sqlair.M{"name": name}

queryApplication := `SELECT &M.uuid FROM application WHERE name = $M.name`
queryApplicationStmt, err := sqlair.Prepare(queryApplication, sqlair.M{})
queryApplicationStmt, err := st.Prepare(queryApplication, sqlair.M{})
if err != nil {
return errors.Trace(err)
}

queryUnits := `SELECT count(*) AS &M.count FROM unit WHERE application_uuid = $M.application_uuid`
queryUnitsStmt, err := sqlair.Prepare(queryUnits, sqlair.M{})
queryUnitsStmt, err := st.Prepare(queryUnits, sqlair.M{})
if err != nil {
return errors.Trace(err)
}

deleteApplication := `DELETE FROM application WHERE name = $M.name`
deleteApplicationStmt, err := sqlair.Prepare(deleteApplication, sqlair.M{})
deleteApplicationStmt, err := st.Prepare(deleteApplication, sqlair.M{})
if err != nil {
return errors.Trace(err)
}
Expand Down Expand Up @@ -175,7 +175,7 @@ func (st *State) AddUnits(ctx context.Context, applicationName string, args ...a
return errors.Trace(err)
}

upsertUnitFunc, err := upsertUnitFuncGetter()
upsertUnitFunc, err := st.upsertUnitFuncGetter()
if err != nil {
return errors.Trace(err)
}
Expand All @@ -196,15 +196,15 @@ type upsertUnitFunc func(ctx context.Context, tx *sqlair.TX, appName string, par
// upsertUnitFuncGetter returns a function which can be called as many times
// as needed to add units, ensuring that statement preparation is only done once.
// TODO - this just creates a minimal row for now.
func upsertUnitFuncGetter() (upsertUnitFunc, error) {
func (st *State) upsertUnitFuncGetter() (upsertUnitFunc, error) {
query := `SELECT &M.uuid FROM unit WHERE unit_id = $M.name`
queryStmt, err := sqlair.Prepare(query, sqlair.M{})
queryStmt, err := st.Prepare(query, sqlair.M{})
if err != nil {
return nil, errors.Trace(err)
}

queryApplication := `SELECT &M.uuid FROM application WHERE name = $M.name`
queryApplicationStmt, err := sqlair.Prepare(queryApplication, sqlair.M{})
queryApplicationStmt, err := st.Prepare(queryApplication, sqlair.M{})
if err != nil {
return nil, errors.Trace(err)
}
Expand All @@ -213,13 +213,13 @@ func upsertUnitFuncGetter() (upsertUnitFunc, error) {
INSERT INTO unit (uuid, net_node_uuid, unit_id, life_id, application_uuid)
VALUES ($M.unit_uuid, $M.net_node_uuid, $M.unit_id, $M.life_id, $M.application_uuid)
`
createUnitStmt, err := sqlair.Prepare(createUnit, sqlair.M{})
createUnitStmt, err := st.Prepare(createUnit, sqlair.M{})
if err != nil {
return nil, errors.Trace(err)
}

createNode := `INSERT INTO net_node (uuid) VALUES ($M.net_node_uuid)`
createNodeStmt, err := sqlair.Prepare(createNode, sqlair.M{})
createNodeStmt, err := st.Prepare(createNode, sqlair.M{})
if err != nil {
return nil, errors.Trace(err)
}
Expand Down Expand Up @@ -292,7 +292,7 @@ func (st *State) StorageDefaults(ctx context.Context) (domainstorage.StorageDefa

attrs := []string{application.StorageDefaultBlockSourceKey, application.StorageDefaultFilesystemSourceKey}
attrsSlice := sqlair.S(transform.Slice(attrs, func(s string) any { return any(s) }))
stmt, err := sqlair.Prepare(`
stmt, err := st.Prepare(`
SELECT &KeyValue.* FROM model_config WHERE key IN ($S[:])
`, sqlair.S{}, KeyValue{})
if err != nil {
Expand All @@ -303,6 +303,9 @@ SELECT &KeyValue.* FROM model_config WHERE key IN ($S[:])
var values []KeyValue
err := tx.Query(ctx, stmt, attrsSlice).GetAll(&values)
if err != nil {
if errors.Is(err, sqlair.ErrNoRows) {
return nil
}
return fmt.Errorf("getting model config attrs for storage defaults: %w", err)
}

Expand Down
2 changes: 1 addition & 1 deletion domain/autocert/state/state.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ func (st *State) Get(ctx context.Context, name string) ([]byte, error) {
SELECT (name, data) AS (&Autocert.*)
FROM autocert_cache
WHERE name = $M.name`
s, err := sqlair.Prepare(q, Autocert{}, sqlair.M{})
s, err := st.Prepare(q, Autocert{}, sqlair.M{})
if err != nil {
return nil, errors.Annotatef(err, "preparing %q", q)
}
Expand Down
35 changes: 19 additions & 16 deletions domain/blockdevice/state/state.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,13 +46,13 @@ func (st *State) BlockDevices(ctx context.Context, machineId string) ([]blockdev
var result []blockdevice.BlockDevice
err = db.Txn(ctx, func(ctx context.Context, tx *sqlair.TX) error {
var err error
result, err = loadBlockDevices(ctx, tx, machineId)
result, err = st.loadBlockDevices(ctx, tx, machineId)
return errors.Trace(err)
})
return result, errors.Trace(err)
}

func loadBlockDevices(ctx context.Context, tx *sqlair.TX, machineId string) ([]blockdevice.BlockDevice, error) {
func (st *State) loadBlockDevices(ctx context.Context, tx *sqlair.TX, machineId string) ([]blockdevice.BlockDevice, error) {
query := `
SELECT bd.* AS &BlockDevice.*,
bdl.* AS &DeviceLink.*,
Expand All @@ -71,7 +71,7 @@ WHERE machine.machine_id = $M.machine_id
sqlair.M{},
}

stmt, err := sqlair.Prepare(query, types...)
stmt, err := st.Prepare(query, types...)
if err != nil {
return nil, errors.Trace(err)
}
Expand All @@ -84,19 +84,22 @@ WHERE machine.machine_id = $M.machine_id
machineParam := sqlair.M{"machine_id": machineId}
err = tx.Query(ctx, stmt, machineParam).GetAll(&dbRows, &dbDeviceLinks, &dbFilesystemTypes)
if err != nil {
if errors.Is(err, sqlair.ErrNoRows) {
return nil, nil
}
return nil, errors.Annotatef(err, "loading block devices for machine %q", machineId)
}
result, _, err := dbRows.toBlockDevicesAndMachines(dbDeviceLinks, dbFilesystemTypes, nil)
return result, errors.Trace(err)
}

func getMachineInfo(ctx context.Context, tx *sqlair.TX, machineId string) (string, life.Life, error) {
func (st *State) getMachineInfo(ctx context.Context, tx *sqlair.TX, machineId string) (string, life.Life, error) {
q := `
SELECT machine.life_id AS &M.life_id, machine.uuid AS &M.machine_uuid
FROM machine
WHERE machine.machine_id = $M.machine_id
`
stmt, err := sqlair.Prepare(q, sqlair.M{})
stmt, err := st.Prepare(q, sqlair.M{})
if err != nil {
return "", 0, errors.Trace(err)
}
Expand Down Expand Up @@ -127,22 +130,22 @@ func (st *State) SetMachineBlockDevices(ctx context.Context, machineId string, d
}

err = db.Txn(ctx, func(ctx context.Context, tx *sqlair.TX) error {
machineUUID, machineLife, err := getMachineInfo(ctx, tx, machineId)
machineUUID, machineLife, err := st.getMachineInfo(ctx, tx, machineId)
if err != nil {
return errors.Trace(err)
}
if machineLife == life.Dead {
return errors.Errorf("cannot update block devices on dead machine %q", machineId)
}
existing, err := loadBlockDevices(ctx, tx, machineId)
existing, err := st.loadBlockDevices(ctx, tx, machineId)
if err != nil {
return errors.Annotatef(err, "loading block devices for machine %q", machineId)
}
if !blockDevicesChanged(existing, devices) {
return nil
}

if err := updateBlockDevices(ctx, tx, machineUUID, devices...); err != nil {
if err := st.updateBlockDevices(ctx, tx, machineUUID, devices...); err != nil {
return errors.Annotatef(err, "updating block devices on machine %q (%s)", machineId, machineUUID)
}
return nil
Expand All @@ -151,7 +154,7 @@ func (st *State) SetMachineBlockDevices(ctx context.Context, machineId string, d
return errors.Trace(err)
}

func updateBlockDevices(ctx context.Context, tx *sqlair.TX, machineUUID string, devices ...blockdevice.BlockDevice) error {
func (st *State) updateBlockDevices(ctx context.Context, tx *sqlair.TX, machineUUID string, devices ...blockdevice.BlockDevice) error {
if err := RemoveMachineBlockDevices(ctx, tx, machineUUID); err != nil {
return errors.Annotatef(err, "removing existing block devices for machine %q", machineUUID)
}
Expand All @@ -161,12 +164,12 @@ func updateBlockDevices(ctx context.Context, tx *sqlair.TX, machineUUID string,
}

fsTypeQuery := `SELECT * AS &FilesystemType.* FROM filesystem_type`
fsTypeStmt, err := sqlair.Prepare(fsTypeQuery, FilesystemType{})
fsTypeStmt, err := st.Prepare(fsTypeQuery, FilesystemType{})
if err != nil {
return errors.Trace(err)
}
var fsTypes []FilesystemType
if err := tx.Query(ctx, fsTypeStmt).GetAll(&fsTypes); err != nil {
if err := tx.Query(ctx, fsTypeStmt).GetAll(&fsTypes); err != nil && !errors.Is(err, sqlair.ErrNoRows) {
return errors.Trace(err)
}
fsTypeByName := make(map[string]int)
Expand All @@ -192,7 +195,7 @@ VALUES (
$BlockDevice.in_use
)
`
insertStmt, err := sqlair.Prepare(insertQuery, BlockDevice{})
insertStmt, err := st.Prepare(insertQuery, BlockDevice{})
if err != nil {
return errors.Trace(err)
}
Expand All @@ -204,7 +207,7 @@ VALUES (
$DeviceLink.name
)
`
insertLinkStmt, err := sqlair.Prepare(insertLinkQuery, DeviceLink{})
insertLinkStmt, err := st.Prepare(insertLinkQuery, DeviceLink{})
if err != nil {
return errors.Trace(err)
}
Expand Down Expand Up @@ -297,7 +300,7 @@ FROM block_device bd
BlockDeviceMachine{},
}

stmt, err := sqlair.Prepare(query, types...)
stmt, err := st.Prepare(query, types...)
if err != nil {
return nil, errors.Trace(err)
}
Expand All @@ -313,7 +316,7 @@ FROM block_device bd
dbFilesystemTypes []FilesystemType
dbMachines []BlockDeviceMachine
)
if err := tx.Query(ctx, stmt).GetAll(&dbRows, &dbDeviceLinks, &dbFilesystemTypes, &dbMachines); err != nil {
if err := tx.Query(ctx, stmt).GetAll(&dbRows, &dbDeviceLinks, &dbFilesystemTypes, &dbMachines); err != nil && !errors.Is(err, sqlair.ErrNoRows) {
return errors.Annotate(err, "loading block devices")
}
blockDevices, machines, err = dbRows.toBlockDevicesAndMachines(dbDeviceLinks, dbFilesystemTypes, dbMachines)
Expand Down Expand Up @@ -399,7 +402,7 @@ func (st *State) WatchBlockDevices(
machineLife life.Life
)
err = db.Txn(ctx, func(ctx context.Context, tx *sqlair.TX) error {
machineUUID, machineLife, err = getMachineInfo(ctx, tx, machineId)
machineUUID, machineLife, err = st.getMachineInfo(ctx, tx, machineId)
return errors.Trace(err)
})

Expand Down
17 changes: 10 additions & 7 deletions domain/cloud/state/state.go
Original file line number Diff line number Diff line change
Expand Up @@ -221,12 +221,12 @@ func (st *State) UpdateCloudDefaults(
return errors.Trace(err)
}

selectStmt, err := sqlair.Prepare("SELECT &Cloud.uuid FROM cloud WHERE name = $Cloud.name", Cloud{})
selectStmt, err := st.Prepare("SELECT &Cloud.uuid FROM cloud WHERE name = $Cloud.name", Cloud{})
if err != nil {
return errors.Trace(err)
}

deleteStmt, err := sqlair.Prepare(`
deleteStmt, err := st.Prepare(`
DELETE FROM cloud_defaults
WHERE key IN ($Attrs[:])
AND cloud_uuid = $Cloud.uuid;
Expand Down Expand Up @@ -290,7 +290,7 @@ func (st *State) CloudAllRegionDefaults(
return defaults, fmt.Errorf("getting database instance for cloud region defaults: %w", err)
}

stmt, err := sqlair.Prepare(`
stmt, err := st.Prepare(`
SELECT (cloud_region.name,
cloud_region_defaults.key,
cloud_region_defaults.value)
Expand All @@ -311,6 +311,9 @@ WHERE cloud.name = $Cloud.name
var regionDefaultValues []CloudRegionDefaultValue

if err := tx.Query(ctx, stmt, Cloud{Name: cloudName}).GetAll(&regionDefaultValues); err != nil {
if errors.Is(err, sqlair.ErrNoRows) {
return nil
}
return fmt.Errorf("fetching cloud %q region defaults: %w", cloudName, domain.CoerceError(err))
}

Expand Down Expand Up @@ -343,7 +346,7 @@ func (st *State) UpdateCloudRegionDefaults(
return errors.Trace(err)
}

selectStmt, err := sqlair.Prepare(`
selectStmt, err := st.Prepare(`
SELECT cloud_region.uuid AS &CloudRegion.uuid
FROM cloud_region
INNER JOIN cloud
Expand All @@ -355,7 +358,7 @@ AND cloud_region.name = $CloudRegion.name;
return errors.Trace(err)
}

deleteStmt, err := sqlair.Prepare(`
deleteStmt, err := st.Prepare(`
DELETE FROM cloud_region_defaults
WHERE key IN ($Attrs[:])
AND region_uuid = $CloudRegion.uuid;
Expand All @@ -364,7 +367,7 @@ AND region_uuid = $CloudRegion.uuid;
return errors.Trace(err)
}

upsertStmt, err := sqlair.Prepare(`
upsertStmt, err := st.Prepare(`
INSERT INTO cloud_region_defaults (region_uuid, key, value)
VALUES ($CloudRegionDefaults.region_uuid, $CloudRegionDefaults.key, $CloudRegionDefaults.value)
ON CONFLICT(region_uuid, key) DO UPDATE
Expand Down Expand Up @@ -606,7 +609,7 @@ func (st *State) UpsertCloud(ctx context.Context, cloud cloud.Cloud) error {
return errors.Trace(err)
}

selectUUIDStmt, err := sqlair.Prepare("SELECT &Cloud.uuid FROM cloud WHERE name = $Cloud.name", Cloud{})
selectUUIDStmt, err := st.Prepare("SELECT &Cloud.uuid FROM cloud WHERE name = $Cloud.name", Cloud{})
if err != nil {
return errors.Trace(domain.CoerceError(err))
}
Expand Down
Loading

0 comments on commit a01d3b4

Please sign in to comment.