diff --git a/domain/annotation/state/state.go b/domain/annotation/state/state.go index 7367975cc07..6a9b0b80e5a 100644 --- a/domain/annotation/state/state.go +++ b/domain/annotation/state/state.go @@ -37,7 +37,7 @@ func (st *State) GetAnnotations(ctx context.Context, id annotations.ID) (map[str return nil, errors.Trace(err) } - getAnnotationsStmt, err := sqlair.Prepare(getAnnotationsQuery, Annotation{}, sqlair.M{}) + getAnnotationsStmt, err := st.Prepare(getAnnotationsQuery, Annotation{}, sqlair.M{}) if err != nil { return nil, errors.Annotatef(err, "preparing get annotations query for ID: %q", id.Name) } @@ -63,7 +63,6 @@ func (st *State) getAnnotationsForModel(ctx context.Context, id annotations.ID, err = db.Txn(ctx, func(ctx context.Context, tx *sqlair.TX) error { return tx.Query(ctx, getAnnotationsStmt).GetAll(&annotationsResults) }) - if err != nil { if errors.Is(err, sqlair.ErrNoRows) { // No errors, we return empty map if no annotation is found @@ -95,7 +94,7 @@ func (st *State) getAnnotationsForID(ctx context.Context, id annotations.ID, get if err != nil { return nil, errors.Annotatef(err, "preparing get annotations query for ID: %q", id.Name) } - kindQueryStmt, err := sqlair.Prepare(kindQuery, sqlair.M{}) + kindQueryStmt, err := st.Prepare(kindQuery, sqlair.M{}) if err != nil { return nil, errors.Annotatef(err, "preparing get annotations query for ID: %q", id.Name) } @@ -123,7 +122,6 @@ func (st *State) getAnnotationsForID(ctx context.Context, id annotations.ID, get "uuid": uuid, }).GetAll(&annotationsResults) }) - if err != nil { if errors.Is(err, sqlair.ErrNoRows) { // No errors, we return empty map if no annotation is found @@ -170,11 +168,11 @@ func (st *State) SetAnnotations(ctx context.Context, id annotations.ID, } // Prepare sqlair statements - setAnnotationsStmt, err := sqlair.Prepare(setAnnotationsQuery, Annotation{}, sqlair.M{}) + setAnnotationsStmt, err := st.Prepare(setAnnotationsQuery, Annotation{}, sqlair.M{}) if err != nil { return errors.Annotatef(err, "preparing set annotations query for ID: %q", id.Name) } - deleteAnnotationsStmt, err := sqlair.Prepare(deleteAnnotationsQuery, Annotation{}, sqlair.M{}) + deleteAnnotationsStmt, err := st.Prepare(deleteAnnotationsQuery, Annotation{}, sqlair.M{}) if err != nil { return errors.Annotatef(err, "preparing set annotations query for ID: %q", id.Name) } @@ -204,7 +202,7 @@ func (st *State) setAnnotationsForID(ctx context.Context, id annotations.ID, if err != nil { return errors.Annotatef(err, "preparing uuid retrieval query for ID: %q", id.Name) } - kindQueryStmt, err := sqlair.Prepare(kindQuery, sqlair.M{}) + kindQueryStmt, err := st.Prepare(kindQuery, sqlair.M{}) if err != nil { return errors.Annotatef(err, "preparing uuid retrieval query for ID: %q", id.Name) } diff --git a/domain/application/state/state.go b/domain/application/state/state.go index 3712d66d09a..851bb66434a 100644 --- a/domain/application/state/state.go +++ b/domain/application/state/state.go @@ -51,7 +51,7 @@ func (st *State) UpsertApplication(ctx context.Context, name string, units ...ap appNameParam := sqlair.M{"name": name} query := `SELECT &M.uuid FROM application WHERE name = $M.name` - queryStmt, err := sqlair.Prepare(query, sqlair.M{}) + queryStmt, err := st.Prepare(query, sqlair.M{}) if err != nil { return errors.Trace(err) } @@ -60,12 +60,12 @@ func (st *State) UpsertApplication(ctx context.Context, name string, units ...ap INSERT INTO application (uuid, name, life_id) VALUES ($M.application_uuid, $M.name, $M.life_id) ` - createApplicationStmt, err := sqlair.Prepare(createApplication, sqlair.M{}) + createApplicationStmt, err := st.Prepare(createApplication, sqlair.M{}) if err != nil { return errors.Trace(err) } - upsertUnitFunc, err := upsertUnitFuncGetter() + upsertUnitFunc, err := st.upsertUnitFuncGetter() if err != nil { return errors.Trace(err) } @@ -117,19 +117,19 @@ func (st *State) DeleteApplication(ctx context.Context, name string) error { appNameParam := sqlair.M{"name": name} queryApplication := `SELECT &M.uuid FROM application WHERE name = $M.name` - queryApplicationStmt, err := sqlair.Prepare(queryApplication, sqlair.M{}) + queryApplicationStmt, err := st.Prepare(queryApplication, sqlair.M{}) if err != nil { return errors.Trace(err) } queryUnits := `SELECT count(*) AS &M.count FROM unit WHERE application_uuid = $M.application_uuid` - queryUnitsStmt, err := sqlair.Prepare(queryUnits, sqlair.M{}) + queryUnitsStmt, err := st.Prepare(queryUnits, sqlair.M{}) if err != nil { return errors.Trace(err) } deleteApplication := `DELETE FROM application WHERE name = $M.name` - deleteApplicationStmt, err := sqlair.Prepare(deleteApplication, sqlair.M{}) + deleteApplicationStmt, err := st.Prepare(deleteApplication, sqlair.M{}) if err != nil { return errors.Trace(err) } @@ -175,7 +175,7 @@ func (st *State) AddUnits(ctx context.Context, applicationName string, args ...a return errors.Trace(err) } - upsertUnitFunc, err := upsertUnitFuncGetter() + upsertUnitFunc, err := st.upsertUnitFuncGetter() if err != nil { return errors.Trace(err) } @@ -196,15 +196,15 @@ type upsertUnitFunc func(ctx context.Context, tx *sqlair.TX, appName string, par // upsertUnitFuncGetter returns a function which can be called as many times // as needed to add units, ensuring that statement preparation is only done once. // TODO - this just creates a minimal row for now. -func upsertUnitFuncGetter() (upsertUnitFunc, error) { +func (st *State) upsertUnitFuncGetter() (upsertUnitFunc, error) { query := `SELECT &M.uuid FROM unit WHERE unit_id = $M.name` - queryStmt, err := sqlair.Prepare(query, sqlair.M{}) + queryStmt, err := st.Prepare(query, sqlair.M{}) if err != nil { return nil, errors.Trace(err) } queryApplication := `SELECT &M.uuid FROM application WHERE name = $M.name` - queryApplicationStmt, err := sqlair.Prepare(queryApplication, sqlair.M{}) + queryApplicationStmt, err := st.Prepare(queryApplication, sqlair.M{}) if err != nil { return nil, errors.Trace(err) } @@ -213,13 +213,13 @@ func upsertUnitFuncGetter() (upsertUnitFunc, error) { INSERT INTO unit (uuid, net_node_uuid, unit_id, life_id, application_uuid) VALUES ($M.unit_uuid, $M.net_node_uuid, $M.unit_id, $M.life_id, $M.application_uuid) ` - createUnitStmt, err := sqlair.Prepare(createUnit, sqlair.M{}) + createUnitStmt, err := st.Prepare(createUnit, sqlair.M{}) if err != nil { return nil, errors.Trace(err) } createNode := `INSERT INTO net_node (uuid) VALUES ($M.net_node_uuid)` - createNodeStmt, err := sqlair.Prepare(createNode, sqlair.M{}) + createNodeStmt, err := st.Prepare(createNode, sqlair.M{}) if err != nil { return nil, errors.Trace(err) } @@ -292,7 +292,7 @@ func (st *State) StorageDefaults(ctx context.Context) (domainstorage.StorageDefa attrs := []string{application.StorageDefaultBlockSourceKey, application.StorageDefaultFilesystemSourceKey} attrsSlice := sqlair.S(transform.Slice(attrs, func(s string) any { return any(s) })) - stmt, err := sqlair.Prepare(` + stmt, err := st.Prepare(` SELECT &KeyValue.* FROM model_config WHERE key IN ($S[:]) `, sqlair.S{}, KeyValue{}) if err != nil { @@ -303,6 +303,9 @@ SELECT &KeyValue.* FROM model_config WHERE key IN ($S[:]) var values []KeyValue err := tx.Query(ctx, stmt, attrsSlice).GetAll(&values) if err != nil { + if errors.Is(err, sqlair.ErrNoRows) { + return nil + } return fmt.Errorf("getting model config attrs for storage defaults: %w", err) } diff --git a/domain/autocert/state/state.go b/domain/autocert/state/state.go index eef4c3b7959..dd7dfe8e011 100644 --- a/domain/autocert/state/state.go +++ b/domain/autocert/state/state.go @@ -73,7 +73,7 @@ func (st *State) Get(ctx context.Context, name string) ([]byte, error) { SELECT (name, data) AS (&Autocert.*) FROM autocert_cache WHERE name = $M.name` - s, err := sqlair.Prepare(q, Autocert{}, sqlair.M{}) + s, err := st.Prepare(q, Autocert{}, sqlair.M{}) if err != nil { return nil, errors.Annotatef(err, "preparing %q", q) } diff --git a/domain/blockdevice/state/state.go b/domain/blockdevice/state/state.go index 31fdca29938..87fc7ac02ba 100644 --- a/domain/blockdevice/state/state.go +++ b/domain/blockdevice/state/state.go @@ -46,13 +46,13 @@ func (st *State) BlockDevices(ctx context.Context, machineId string) ([]blockdev var result []blockdevice.BlockDevice err = db.Txn(ctx, func(ctx context.Context, tx *sqlair.TX) error { var err error - result, err = loadBlockDevices(ctx, tx, machineId) + result, err = st.loadBlockDevices(ctx, tx, machineId) return errors.Trace(err) }) return result, errors.Trace(err) } -func loadBlockDevices(ctx context.Context, tx *sqlair.TX, machineId string) ([]blockdevice.BlockDevice, error) { +func (st *State) loadBlockDevices(ctx context.Context, tx *sqlair.TX, machineId string) ([]blockdevice.BlockDevice, error) { query := ` SELECT bd.* AS &BlockDevice.*, bdl.* AS &DeviceLink.*, @@ -71,7 +71,7 @@ WHERE machine.machine_id = $M.machine_id sqlair.M{}, } - stmt, err := sqlair.Prepare(query, types...) + stmt, err := st.Prepare(query, types...) if err != nil { return nil, errors.Trace(err) } @@ -84,19 +84,22 @@ WHERE machine.machine_id = $M.machine_id machineParam := sqlair.M{"machine_id": machineId} err = tx.Query(ctx, stmt, machineParam).GetAll(&dbRows, &dbDeviceLinks, &dbFilesystemTypes) if err != nil { + if errors.Is(err, sqlair.ErrNoRows) { + return nil, nil + } return nil, errors.Annotatef(err, "loading block devices for machine %q", machineId) } result, _, err := dbRows.toBlockDevicesAndMachines(dbDeviceLinks, dbFilesystemTypes, nil) return result, errors.Trace(err) } -func getMachineInfo(ctx context.Context, tx *sqlair.TX, machineId string) (string, life.Life, error) { +func (st *State) getMachineInfo(ctx context.Context, tx *sqlair.TX, machineId string) (string, life.Life, error) { q := ` SELECT machine.life_id AS &M.life_id, machine.uuid AS &M.machine_uuid FROM machine WHERE machine.machine_id = $M.machine_id ` - stmt, err := sqlair.Prepare(q, sqlair.M{}) + stmt, err := st.Prepare(q, sqlair.M{}) if err != nil { return "", 0, errors.Trace(err) } @@ -127,14 +130,14 @@ func (st *State) SetMachineBlockDevices(ctx context.Context, machineId string, d } err = db.Txn(ctx, func(ctx context.Context, tx *sqlair.TX) error { - machineUUID, machineLife, err := getMachineInfo(ctx, tx, machineId) + machineUUID, machineLife, err := st.getMachineInfo(ctx, tx, machineId) if err != nil { return errors.Trace(err) } if machineLife == life.Dead { return errors.Errorf("cannot update block devices on dead machine %q", machineId) } - existing, err := loadBlockDevices(ctx, tx, machineId) + existing, err := st.loadBlockDevices(ctx, tx, machineId) if err != nil { return errors.Annotatef(err, "loading block devices for machine %q", machineId) } @@ -142,7 +145,7 @@ func (st *State) SetMachineBlockDevices(ctx context.Context, machineId string, d return nil } - if err := updateBlockDevices(ctx, tx, machineUUID, devices...); err != nil { + if err := st.updateBlockDevices(ctx, tx, machineUUID, devices...); err != nil { return errors.Annotatef(err, "updating block devices on machine %q (%s)", machineId, machineUUID) } return nil @@ -151,7 +154,7 @@ func (st *State) SetMachineBlockDevices(ctx context.Context, machineId string, d return errors.Trace(err) } -func updateBlockDevices(ctx context.Context, tx *sqlair.TX, machineUUID string, devices ...blockdevice.BlockDevice) error { +func (st *State) updateBlockDevices(ctx context.Context, tx *sqlair.TX, machineUUID string, devices ...blockdevice.BlockDevice) error { if err := RemoveMachineBlockDevices(ctx, tx, machineUUID); err != nil { return errors.Annotatef(err, "removing existing block devices for machine %q", machineUUID) } @@ -161,12 +164,12 @@ func updateBlockDevices(ctx context.Context, tx *sqlair.TX, machineUUID string, } fsTypeQuery := `SELECT * AS &FilesystemType.* FROM filesystem_type` - fsTypeStmt, err := sqlair.Prepare(fsTypeQuery, FilesystemType{}) + fsTypeStmt, err := st.Prepare(fsTypeQuery, FilesystemType{}) if err != nil { return errors.Trace(err) } var fsTypes []FilesystemType - if err := tx.Query(ctx, fsTypeStmt).GetAll(&fsTypes); err != nil { + if err := tx.Query(ctx, fsTypeStmt).GetAll(&fsTypes); err != nil && !errors.Is(err, sqlair.ErrNoRows) { return errors.Trace(err) } fsTypeByName := make(map[string]int) @@ -192,7 +195,7 @@ VALUES ( $BlockDevice.in_use ) ` - insertStmt, err := sqlair.Prepare(insertQuery, BlockDevice{}) + insertStmt, err := st.Prepare(insertQuery, BlockDevice{}) if err != nil { return errors.Trace(err) } @@ -204,7 +207,7 @@ VALUES ( $DeviceLink.name ) ` - insertLinkStmt, err := sqlair.Prepare(insertLinkQuery, DeviceLink{}) + insertLinkStmt, err := st.Prepare(insertLinkQuery, DeviceLink{}) if err != nil { return errors.Trace(err) } @@ -297,7 +300,7 @@ FROM block_device bd BlockDeviceMachine{}, } - stmt, err := sqlair.Prepare(query, types...) + stmt, err := st.Prepare(query, types...) if err != nil { return nil, errors.Trace(err) } @@ -313,7 +316,7 @@ FROM block_device bd dbFilesystemTypes []FilesystemType dbMachines []BlockDeviceMachine ) - if err := tx.Query(ctx, stmt).GetAll(&dbRows, &dbDeviceLinks, &dbFilesystemTypes, &dbMachines); err != nil { + if err := tx.Query(ctx, stmt).GetAll(&dbRows, &dbDeviceLinks, &dbFilesystemTypes, &dbMachines); err != nil && !errors.Is(err, sqlair.ErrNoRows) { return errors.Annotate(err, "loading block devices") } blockDevices, machines, err = dbRows.toBlockDevicesAndMachines(dbDeviceLinks, dbFilesystemTypes, dbMachines) @@ -399,7 +402,7 @@ func (st *State) WatchBlockDevices( machineLife life.Life ) err = db.Txn(ctx, func(ctx context.Context, tx *sqlair.TX) error { - machineUUID, machineLife, err = getMachineInfo(ctx, tx, machineId) + machineUUID, machineLife, err = st.getMachineInfo(ctx, tx, machineId) return errors.Trace(err) }) diff --git a/domain/cloud/state/state.go b/domain/cloud/state/state.go index 5c392cdd4ec..c1abcb4ebe1 100644 --- a/domain/cloud/state/state.go +++ b/domain/cloud/state/state.go @@ -221,12 +221,12 @@ func (st *State) UpdateCloudDefaults( return errors.Trace(err) } - selectStmt, err := sqlair.Prepare("SELECT &Cloud.uuid FROM cloud WHERE name = $Cloud.name", Cloud{}) + selectStmt, err := st.Prepare("SELECT &Cloud.uuid FROM cloud WHERE name = $Cloud.name", Cloud{}) if err != nil { return errors.Trace(err) } - deleteStmt, err := sqlair.Prepare(` + deleteStmt, err := st.Prepare(` DELETE FROM cloud_defaults WHERE key IN ($Attrs[:]) AND cloud_uuid = $Cloud.uuid; @@ -290,7 +290,7 @@ func (st *State) CloudAllRegionDefaults( return defaults, fmt.Errorf("getting database instance for cloud region defaults: %w", err) } - stmt, err := sqlair.Prepare(` + stmt, err := st.Prepare(` SELECT (cloud_region.name, cloud_region_defaults.key, cloud_region_defaults.value) @@ -311,6 +311,9 @@ WHERE cloud.name = $Cloud.name var regionDefaultValues []CloudRegionDefaultValue if err := tx.Query(ctx, stmt, Cloud{Name: cloudName}).GetAll(®ionDefaultValues); err != nil { + if errors.Is(err, sqlair.ErrNoRows) { + return nil + } return fmt.Errorf("fetching cloud %q region defaults: %w", cloudName, domain.CoerceError(err)) } @@ -343,7 +346,7 @@ func (st *State) UpdateCloudRegionDefaults( return errors.Trace(err) } - selectStmt, err := sqlair.Prepare(` + selectStmt, err := st.Prepare(` SELECT cloud_region.uuid AS &CloudRegion.uuid FROM cloud_region INNER JOIN cloud @@ -355,7 +358,7 @@ AND cloud_region.name = $CloudRegion.name; return errors.Trace(err) } - deleteStmt, err := sqlair.Prepare(` + deleteStmt, err := st.Prepare(` DELETE FROM cloud_region_defaults WHERE key IN ($Attrs[:]) AND region_uuid = $CloudRegion.uuid; @@ -364,7 +367,7 @@ AND region_uuid = $CloudRegion.uuid; return errors.Trace(err) } - upsertStmt, err := sqlair.Prepare(` + upsertStmt, err := st.Prepare(` INSERT INTO cloud_region_defaults (region_uuid, key, value) VALUES ($CloudRegionDefaults.region_uuid, $CloudRegionDefaults.key, $CloudRegionDefaults.value) ON CONFLICT(region_uuid, key) DO UPDATE @@ -606,7 +609,7 @@ func (st *State) UpsertCloud(ctx context.Context, cloud cloud.Cloud) error { return errors.Trace(err) } - selectUUIDStmt, err := sqlair.Prepare("SELECT &Cloud.uuid FROM cloud WHERE name = $Cloud.name", Cloud{}) + selectUUIDStmt, err := st.Prepare("SELECT &Cloud.uuid FROM cloud WHERE name = $Cloud.name", Cloud{}) if err != nil { return errors.Trace(domain.CoerceError(err)) } diff --git a/domain/credential/state/state.go b/domain/credential/state/state.go index 494f1cbea2a..309b7356c97 100644 --- a/domain/credential/state/state.go +++ b/domain/credential/state/state.go @@ -53,7 +53,7 @@ AND owner_name = $M.owner AND cloud_name = $M.cloud_name ` - selectStmt, err := sqlair.Prepare(selectQ, sqlair.M{}) + selectStmt, err := st.Prepare(selectQ, sqlair.M{}) if err != nil { return "", errors.Trace(err) } @@ -87,7 +87,7 @@ WHERE cloud_name = $M.cloud_name AND name = $M.credential_name AND owner_name = $M.owner ` - stmt, err := sqlair.Prepare(q, sqlair.M{}) + stmt, err := st.Prepare(q, sqlair.M{}) if err != nil { return nil, errors.Trace(err) } @@ -266,6 +266,9 @@ func dbCredentialFromCredential(ctx context.Context, tx *sqlair.TX, credentialUU if err != nil { return nil, errors.Annotate(err, "loading cloud auth types") } + if errors.Is(err, sqlair.ErrNoRows) { + return nil, errors.Annotate(err, "no valid cloud auth types") + } var validAuthTypeNames []string for _, at := range validAuthTypes { if at.Type == credential.AuthType { @@ -321,7 +324,7 @@ AND cloud_credential.cloud_uuid = ( SELECT uuid FROM cloud WHERE name = $M.cloud_name )` - stmt, err := sqlair.Prepare(q, sqlair.M{}) + stmt, err := st.Prepare(q, sqlair.M{}) if err != nil { return errors.Trace(err) } @@ -383,12 +386,12 @@ func (st *State) CloudCredential(ctx context.Context, key corecredential.Key) (c creds, err = st.loadCloudCredentials(ctx, tx, key.Name, key.Cloud, key.Owner) return errors.Trace(err) }) - if err != nil { - return credential.CloudCredentialResult{}, errors.Trace(err) - } if len(creds) == 0 { return credential.CloudCredentialResult{}, fmt.Errorf("credential %q for cloud %q owned by %q %w", key.Name, key.Cloud, key.Owner, errors.NotFound) } + if err != nil { + return credential.CloudCredentialResult{}, errors.Trace(err) + } if len(creds) > 1 { return credential.CloudCredentialResult{}, errors.Errorf("expected 1 credential, got %d", len(creds)) } @@ -432,7 +435,7 @@ FROM cloud_credential cc queryArgs = []any{args} } - credStmt, err := sqlair.Prepare(credQuery, types...) + credStmt, err := st.Prepare(credQuery, types...) if err != nil { return nil, errors.Trace(err) } @@ -445,6 +448,9 @@ FROM cloud_credential cc ) err = tx.Query(ctx, credStmt, queryArgs...).GetAll(&dbRows, &dbAuthTypes, &dbclouds, &keyValues) if err != nil { + if errors.Is(err, sqlair.ErrNoRows) { + return nil, nil + } return nil, errors.Annotate(err, "loading cloud credentials") } return dbRows.toCloudCredentials(dbAuthTypes, dbclouds, keyValues) @@ -496,11 +502,11 @@ DELETE FROM cloud_credential WHERE cloud_credential.uuid = $M.uuid ` - credAttrDeleteStmt, err := sqlair.Prepare(credAttrDeleteQ, sqlair.M{}) + credAttrDeleteStmt, err := st.Prepare(credAttrDeleteQ, sqlair.M{}) if err != nil { return errors.Trace(err) } - credDeleteStmt, err := sqlair.Prepare(credDeleteQ, sqlair.M{}) + credDeleteStmt, err := st.Prepare(credDeleteQ, sqlair.M{}) if err != nil { return errors.Trace(err) } @@ -569,19 +575,22 @@ JOIN user ON cc.owner_uuid = user.uuid }) query = query + "WHERE " + condition - stmt, err := sqlair.Prepare(query, types...) + stmt, err := st.Prepare(query, types...) if err != nil { return nil, errors.Trace(err) } var info []sqlair.M + result := make(map[coremodel.UUID]string) err = db.Txn(ctx, func(ctx context.Context, tx *sqlair.TX) error { return tx.Query(ctx, stmt, args).GetAll(&info) }) if err != nil { + if errors.Is(err, sqlair.ErrNoRows) { + return result, nil + } return nil, errors.Trace(err) } - result := make(map[coremodel.UUID]string) for _, m := range info { name, _ := m["name"].(string) uuid, _ := m["model_uuid"].(string) diff --git a/domain/externalcontroller/state/state.go b/domain/externalcontroller/state/state.go index 66ee7ba7f37..9ca5b390b5a 100644 --- a/domain/externalcontroller/state/state.go +++ b/domain/externalcontroller/state/state.go @@ -50,7 +50,7 @@ FROM external_controller AS ctrl LEFT JOIN external_controller_address AS addrs ON ctrl.uuid = addrs.controller_uuid WHERE ctrl.uuid = $M.id` - s, err := sqlair.Prepare(q, Controller{}, sqlair.M{}) + s, err := st.Prepare(q, Controller{}, sqlair.M{}) if err != nil { return nil, errors.Annotatef(err, "preparing %q", q) } @@ -58,12 +58,10 @@ WHERE ctrl.uuid = $M.id` var rows Controllers if err := db.Txn(ctx, func(ctx context.Context, tx *sqlair.TX) error { return errors.Trace(tx.Query(ctx, s, sqlair.M{"id": controllerUUID}).GetAll(&rows)) - }); err != nil { - return nil, errors.Annotate(domain.CoerceError(err), "querying external controller") - } - - if len(rows) == 0 { + }); errors.Is(err, sqlair.ErrNoRows) || len(rows) == 0 { return nil, errors.NotFoundf("external controller %q", controllerUUID) + } else if err != nil { + return nil, errors.Annotate(domain.CoerceError(err), "querying external controller") } return &rows.ToControllerInfo()[0], nil @@ -98,7 +96,7 @@ WHERE ctrl.uuid = ( WHERE model.uuid = $M.model )` - s, err := sqlair.Prepare(q, Controller{}, sqlair.M{}) + s, err := st.Prepare(q, Controller{}, sqlair.M{}) if err != nil { return nil, errors.Annotatef(err, "preparing %q", q) } @@ -108,7 +106,7 @@ WHERE ctrl.uuid = ( for _, modelUUID := range modelUUIDs { var rows Controllers err := tx.Query(ctx, s, sqlair.M{"model": modelUUID}).GetAll(&rows) - if err != nil { + if err != nil && !errors.Is(err, sqlair.ErrNoRows) { return errors.Trace(domain.CoerceError(err)) } resultControllerInfos = append(resultControllerInfos, rows...) diff --git a/domain/machine/state/state.go b/domain/machine/state/state.go index 0a5e3d402c8..bfc6fef5c93 100644 --- a/domain/machine/state/state.go +++ b/domain/machine/state/state.go @@ -37,15 +37,15 @@ func NewState(factory coredb.TxnRunnerFactory, logger Logger) *State { // UpsertMachine creates or updates the specified machine. // TODO - this just creates a minimal row for now. -func (s *State) UpsertMachine(ctx context.Context, machineId string) error { - db, err := s.DB() +func (st *State) UpsertMachine(ctx context.Context, machineId string) error { + db, err := st.DB() if err != nil { return errors.Trace(err) } machineIDParam := sqlair.M{"machine_id": machineId} query := `SELECT &M.uuid FROM machine WHERE machine_id = $M.machine_id` - queryStmt, err := sqlair.Prepare(query, sqlair.M{}) + queryStmt, err := st.Prepare(query, sqlair.M{}) if err != nil { return errors.Trace(err) } @@ -54,13 +54,13 @@ func (s *State) UpsertMachine(ctx context.Context, machineId string) error { INSERT INTO machine (uuid, net_node_uuid, machine_id, life_id) VALUES ($M.machine_uuid, $M.net_node_uuid, $M.machine_id, $M.life_id) ` - createMachineStmt, err := sqlair.Prepare(createMachine, sqlair.M{}) + createMachineStmt, err := st.Prepare(createMachine, sqlair.M{}) if err != nil { return errors.Trace(err) } createNode := `INSERT INTO net_node (uuid) VALUES ($M.net_node_uuid)` - createNodeStmt, err := sqlair.Prepare(createNode, sqlair.M{}) + createNodeStmt, err := st.Prepare(createNode, sqlair.M{}) if err != nil { return errors.Trace(err) } @@ -104,8 +104,8 @@ VALUES ($M.machine_uuid, $M.net_node_uuid, $M.machine_id, $M.life_id) // DeleteMachine deletes the specified machine and any dependent child records. // TODO - this just deals with child block devices for now. -func (s *State) DeleteMachine(ctx context.Context, machineId string) error { - db, err := s.DB() +func (st *State) DeleteMachine(ctx context.Context, machineId string) error { + db, err := st.DB() if err != nil { return errors.Trace(err) } @@ -113,13 +113,13 @@ func (s *State) DeleteMachine(ctx context.Context, machineId string) error { machineIDParam := sqlair.M{"machine_id": machineId} queryMachine := `SELECT &M.uuid FROM machine WHERE machine_id = $M.machine_id` - queryMachineStmt, err := sqlair.Prepare(queryMachine, sqlair.M{}) + queryMachineStmt, err := st.Prepare(queryMachine, sqlair.M{}) if err != nil { return errors.Trace(err) } deleteMachine := `DELETE FROM machine WHERE machine_id = $M.machine_id` - deleteMachineStmt, err := sqlair.Prepare(deleteMachine, sqlair.M{}) + deleteMachineStmt, err := st.Prepare(deleteMachine, sqlair.M{}) if err != nil { return errors.Trace(err) } @@ -128,7 +128,7 @@ func (s *State) DeleteMachine(ctx context.Context, machineId string) error { DELETE FROM net_node WHERE uuid IN (SELECT net_node_uuid FROM machine WHERE machine_id = $M.machine_id) ` - deleteNodeStmt, err := sqlair.Prepare(deleteNode, sqlair.M{}) + deleteNodeStmt, err := st.Prepare(deleteNode, sqlair.M{}) if err != nil { return errors.Trace(err) } diff --git a/domain/modelconfig/state/state.go b/domain/modelconfig/state/state.go index 708c61b6d3f..37fa8d14bb1 100644 --- a/domain/modelconfig/state/state.go +++ b/domain/modelconfig/state/state.go @@ -59,7 +59,7 @@ func (st *State) ModelConfigHasAttributes( } attrsSlice := sqlair.S(transform.Slice(attrs, func(s string) any { return any(s) })) - stmt, err := sqlair.Prepare(` + stmt, err := st.Prepare(` SELECT &key.key FROM model_config WHERE key IN ($S[:]) `, sqlair.S{}, key{}) if err != nil { @@ -176,7 +176,7 @@ func (st *State) UpdateModelConfig( } removeAttrsSlice := sqlair.S(transform.Slice(removeAttrs, func(s string) any { return any(s) })) - deleteStmt, err := sqlair.Prepare(` + deleteStmt, err := st.Prepare(` DELETE FROM model_config WHERE key IN ($S[:]) `[1:], sqlair.S{}) @@ -184,7 +184,7 @@ WHERE key IN ($S[:]) return errors.Trace(err) } - upsertStmt, err := sqlair.Prepare(` + upsertStmt, err := st.Prepare(` INSERT INTO model_config (key, value) VALUES ($M.key, $M.value) ON CONFLICT(key) DO UPDATE SET value = excluded.value diff --git a/domain/network/state/space.go b/domain/network/state/space.go index 7a2d75edf78..6171eed458b 100644 --- a/domain/network/state/space.go +++ b/domain/network/state/space.go @@ -85,7 +85,7 @@ WHERE subnet_type.is_space_settable = FALSE AND subnet.uuid IN ($S[:])`, sqlair // that are of a type on which the space can be set. var nonSettableSubnets []Subnet - if err := tx.Query(ctx, checkInputSubnetsStmt, subnetIDsInS).GetAll(&nonSettableSubnets); err != nil { + if err := tx.Query(ctx, checkInputSubnetsStmt, subnetIDsInS).GetAll(&nonSettableSubnets); err != nil && !errors.Is(err, sqlair.ErrNoRows) { return errors.Annotatef(err, "checking if there are fan subnets for space %q", uuid) } @@ -109,7 +109,7 @@ WHERE subnet_type.is_space_settable = FALSE AND subnet.uuid IN ($S[:])`, sqlair // Retrieve the fan overlays (if any) of the passed subnet ids. var fanSubnets []Subnet err = tx.Query(ctx, findFanSubnetsStmt, subnetIDsInS).GetAll(&fanSubnets) - if err != nil { + if err != nil && !errors.Is(err, sqlair.ErrNoRows) { return errors.Annotatef(err, "retrieving the fan subnets for space %q", uuid) } // Append the fan subnet (unique) ids (if any) to the provided @@ -120,7 +120,7 @@ WHERE subnet_type.is_space_settable = FALSE AND subnet.uuid IN ($S[:])`, sqlair // Update all subnets (including their fan overlays) to include // the space uuid. for _, subnetID := range subnetIDs { - if err := updateSubnetSpaceID(ctx, tx, subnetID, uuid); err != nil { + if err := st.updateSubnetSpaceID(ctx, tx, subnetID, uuid); err != nil { return errors.Annotatef(err, "updating subnet %q using space uuid %q", subnetID, uuid) } } @@ -177,16 +177,15 @@ func (st *State) GetSpace( var spaceRows SpaceSubnetRows if err := db.Txn(ctx, func(ctx context.Context, tx *sqlair.TX) error { err := tx.Query(ctx, spacesStmt, sqlair.M{"id": uuid}).GetAll(&spaceRows) - if err != nil { + if err != nil && !errors.Is(err, sqlair.ErrNoRows) { return errors.Annotatef(err, "retrieving space %q", uuid) } return nil - }); err != nil { - return nil, errors.Annotate(err, "querying spaces") - } - if len(spaceRows) == 0 { + }); errors.Is(err, sqlair.ErrNoRows) || len(spaceRows) == 0 { return nil, errors.NotFoundf("space %q", uuid) + } else if err != nil { + return nil, errors.Annotate(err, "querying spaces") } return &spaceRows.ToSpaceInfos()[0], nil @@ -213,11 +212,10 @@ func (st *State) GetSpaceByName( var rows SpaceSubnetRows if err := db.Txn(ctx, func(ctx context.Context, tx *sqlair.TX) error { return errors.Trace(tx.Query(ctx, s, sqlair.M{"name": name}).GetAll(&rows)) - }); err != nil { - return nil, errors.Annotate(err, "querying spaces by name") - } - if len(rows) == 0 { + }); errors.Is(err, sqlair.ErrNoRows) || len(rows) == 0 { return nil, errors.NotFoundf("space with name %q", name) + } else if err != nil { + return nil, errors.Annotate(err, "querying spaces by name") } return &rows.ToSpaceInfos()[0], nil @@ -241,7 +239,9 @@ func (st *State) GetAllSpaces( var rows SpaceSubnetRows if err := db.Txn(ctx, func(ctx context.Context, tx *sqlair.TX) error { return errors.Trace(tx.Query(ctx, s).GetAll(&rows)) - }); err != nil { + }); errors.Is(err, sqlair.ErrNoRows) || len(rows) == 0 { + return nil, nil + } else if err != nil { return nil, errors.Annotate(err, "querying all spaces") } diff --git a/domain/network/state/subnet.go b/domain/network/state/subnet.go index 94d7df8f5fe..1889c76951c 100644 --- a/domain/network/state/subnet.go +++ b/domain/network/state/subnet.go @@ -30,7 +30,7 @@ func (st *State) UpsertSubnets(ctx context.Context, subnets []network.SubnetInfo return db.Txn(ctx, func(ctx context.Context, tx *sqlair.TX) error { for _, subnet := range subnets { - err := updateSubnetSpaceID( + err := st.updateSubnetSpaceID( ctx, tx, string(subnet.ID), @@ -41,7 +41,7 @@ func (st *State) UpsertSubnets(ctx context.Context, subnets []network.SubnetInfo } // If the subnet doesn't exist yet we need to create it. if errors.Is(err, errors.NotFound) { - if err := addSubnet( + if err := st.addSubnet( ctx, tx, subnet.ID.String(), @@ -55,7 +55,7 @@ func (st *State) UpsertSubnets(ctx context.Context, subnets []network.SubnetInfo }) } -func addSubnet(ctx context.Context, tx *sqlair.TX, subnetUUID string, subnetInfo network.SubnetInfo) error { +func (st *State) addSubnet(ctx context.Context, tx *sqlair.TX, subnetUUID string, subnetInfo network.SubnetInfo) error { var subnetType int if subnetInfo.FanInfo != nil { subnetType = subnetTypeFanOverlaySegment @@ -69,38 +69,38 @@ func addSubnet(ctx context.Context, tx *sqlair.TX, subnetUUID string, subnetInfo return errors.Trace(err) } - insertSubnetStmt, err := sqlair.Prepare(` + insertSubnetStmt, err := st.Prepare(` INSERT INTO subnet (uuid, cidr, vlan_tag, space_uuid, subnet_type_id) VALUES ($Subnet.uuid, $Subnet.cidr, $Subnet.vlan_tag, $Subnet.space_uuid, $Subnet.subnet_type_id)`, Subnet{}) if err != nil { return errors.Trace(err) } - insertSubnetAssociationStmt, err := sqlair.Prepare(` + insertSubnetAssociationStmt, err := st.Prepare(` INSERT INTO subnet_association (subject_subnet_uuid, associated_subnet_uuid, association_type_id) VALUES ($M.subject_subnet_uuid, $M.associated_subnet_uuid, 0)`, sqlair.M{}) // For the moment the only allowed association is 'overlay_of' and therefore its ID is hard-coded here. if err != nil { return errors.Trace(err) } - retrieveUnderlaySubnetUUIDStmt, err := sqlair.Prepare(` + retrieveUnderlaySubnetUUIDStmt, err := st.Prepare(` SELECT &Subnet.uuid FROM subnet WHERE cidr = $Subnet.cidr`, Subnet{}) if err != nil { return errors.Trace(err) } - insertSubnetProviderIDStmt, err := sqlair.Prepare(` + insertSubnetProviderIDStmt, err := st.Prepare(` INSERT INTO provider_subnet (provider_id, subnet_uuid) VALUES ($ProviderSubnet.provider_id, $ProviderSubnet.subnet_uuid)`, ProviderSubnet{}) if err != nil { return errors.Trace(err) } - insertSubnetProviderNetworkIDStmt, err := sqlair.Prepare(` + insertSubnetProviderNetworkIDStmt, err := st.Prepare(` INSERT INTO provider_network (uuid, provider_network_id) VALUES ($ProviderNetwork.uuid, $ProviderNetwork.provider_network_id)`, ProviderNetwork{}) if err != nil { return errors.Trace(err) } - insertSubnetProviderNetworkSubnetStmt, err := sqlair.Prepare(` + insertSubnetProviderNetworkSubnetStmt, err := st.Prepare(` INSERT INTO provider_network_subnet (provider_network_uuid, subnet_uuid) VALUES ($ProviderNetworkSubnet.provider_network_uuid, $ProviderNetworkSubnet.subnet_uuid)`, ProviderNetworkSubnet{}) if err != nil { @@ -165,26 +165,26 @@ VALUES ($ProviderNetworkSubnet.provider_network_uuid, $ProviderNetworkSubnet.sub return errors.Annotatef(err, "inserting association between provider network id %q and subnet %q", subnetInfo.ProviderNetworkId, subnetUUID) } - return addAvailabilityZones(ctx, tx, subnetUUID, subnetInfo) + return st.addAvailabilityZones(ctx, tx, subnetUUID, subnetInfo) } // addAvailabilityZones adds the availability zones of a subnet if they don't exist, and // update the availability_zone_subnet table with the subnet's id. -func addAvailabilityZones(ctx context.Context, tx *sqlair.TX, subnetUUID string, subnet network.SubnetInfo) error { - retrieveAvailabilityZoneStmt, err := sqlair.Prepare(` +func (st *State) addAvailabilityZones(ctx context.Context, tx *sqlair.TX, subnetUUID string, subnet network.SubnetInfo) error { + retrieveAvailabilityZoneStmt, err := st.Prepare(` SELECT &M.uuid FROM availability_zone WHERE name = $M.name`, sqlair.M{}) if err != nil { return errors.Trace(err) } - insertAvailabilityZoneStmt, err := sqlair.Prepare(` + insertAvailabilityZoneStmt, err := st.Prepare(` INSERT INTO availability_zone (uuid, name) VALUES ($M.uuid, $M.name)`, sqlair.M{}) if err != nil { return errors.Trace(err) } - insertAvailabilityZoneSubnetStmt, err := sqlair.Prepare(` + insertAvailabilityZoneSubnetStmt, err := st.Prepare(` INSERT INTO availability_zone_subnet (availability_zone_uuid, subnet_uuid) VALUES ($M.availability_zone_uuid, $M.subnet_uuid)`, sqlair.M{}) if err != nil { @@ -240,7 +240,7 @@ func (st *State) AddSubnet( return errors.Trace( db.Txn(ctx, func(ctx context.Context, tx *sqlair.TX) error { - return addSubnet(ctx, tx, subnet.ID.String(), subnet) + return st.addSubnet(ctx, tx, subnet.ID.String(), subnet) }), ) } @@ -295,7 +295,7 @@ func (st *State) GetAllSubnets( // Append the space uuid condition to the query only if it's passed to the function. q := retrieveSubnetsStmt + ";" - s, err := sqlair.Prepare(q, SpaceSubnetRow{}) + s, err := st.Prepare(q, SpaceSubnetRow{}) if err != nil { return nil, errors.Annotatef(err, "preparing %q", q) } @@ -303,7 +303,9 @@ func (st *State) GetAllSubnets( var rows SpaceSubnetRows if err := db.Txn(ctx, func(ctx context.Context, tx *sqlair.TX) error { return errors.Trace(tx.Query(ctx, s).GetAll(&rows)) - }); err != nil { + }); errors.Is(err, sqlair.ErrNoRows) { + return nil, nil + } else if err != nil { return nil, errors.Annotate(err, "querying subnets") } @@ -323,7 +325,7 @@ func (st *State) GetSubnet( // Append the space uuid condition to the query only if it's passed to the function. q := retrieveSubnetsStmt + " WHERE subnet.uuid = $M.id;" - stmt, err := sqlair.Prepare(q, SpaceSubnetRow{}, sqlair.M{}) + stmt, err := st.Prepare(q, SpaceSubnetRow{}, sqlair.M{}) if err != nil { return nil, errors.Annotatef(err, "preparing %q", q) } @@ -331,12 +333,10 @@ func (st *State) GetSubnet( var rows SpaceSubnetRows if err := db.Txn(ctx, func(ctx context.Context, tx *sqlair.TX) error { return errors.Trace(tx.Query(ctx, stmt, sqlair.M{"id": uuid}).GetAll(&rows)) - }); err != nil { - return nil, errors.Annotate(err, "querying subnets") - } - - if len(rows) == 0 { + }); errors.Is(err, sqlair.ErrNoRows) { return nil, errors.NotFoundf("subnet %q", uuid) + } else if err != nil { + return nil, errors.Annotate(err, "querying subnets") } return &rows.ToSubnetInfos()[0], nil @@ -357,7 +357,7 @@ func (st *State) GetSubnetsByCIDR( // Append the where clause to the query. q := retrieveSubnetsStmt + " WHERE subnet.cidr = $M.cidr;" - s, err := sqlair.Prepare(q, SpaceSubnetRow{}, sqlair.M{}) + s, err := st.Prepare(q, SpaceSubnetRow{}, sqlair.M{}) if err != nil { return nil, errors.Annotatef(err, "preparing %q", q) } @@ -372,20 +372,20 @@ func (st *State) GetSubnetsByCIDR( resultSubnets = append(resultSubnets, rows...) } return nil - }); err != nil { + }); err != nil && !errors.Is(err, sqlair.ErrNoRows) { return nil, errors.Annotate(err, "querying subnets") } return resultSubnets.ToSubnetInfos(), nil } -func updateSubnetSpaceID( +func (st *State) updateSubnetSpaceID( ctx context.Context, tx *sqlair.TX, uuid string, spaceID string, ) error { - updateSubnetSpaceIDStmt, err := sqlair.Prepare(` + updateSubnetSpaceIDStmt, err := st.Prepare(` UPDATE subnet SET space_uuid = $M.space_uuid WHERE uuid = $M.uuid;`, sqlair.M{}) @@ -421,7 +421,7 @@ func (st *State) UpdateSubnet( } return db.Txn(ctx, func(ctx context.Context, tx *sqlair.TX) error { - return updateSubnetSpaceID(ctx, tx, uuid, spaceID) + return st.updateSubnetSpaceID(ctx, tx, uuid, spaceID) }) } diff --git a/domain/permission/state/state.go b/domain/permission/state/state.go index ae49b8b57bd..54f8ef33574 100644 --- a/domain/permission/state/state.go +++ b/domain/permission/state/state.go @@ -50,16 +50,16 @@ func NewState(factory coredatabase.TxnRunnerFactory, logger Logger) *State { // usererrors.AuthenticationDisabled is returned. // If a permission for the user and target key already exists, // permissionerrors.AlreadyExists is returned. -func (s *State) CreatePermission(ctx context.Context, newPermissionUUID uuid.UUID, spec corepermission.UserAccessSpec) (corepermission.UserAccess, error) { +func (st *State) CreatePermission(ctx context.Context, newPermissionUUID uuid.UUID, spec corepermission.UserAccessSpec) (corepermission.UserAccess, error) { var userAccess corepermission.UserAccess - db, err := s.DB() + db, err := st.DB() if err != nil { return userAccess, errors.Trace(err) } err = db.Txn(ctx, func(ctx context.Context, tx *sqlair.TX) error { - user, err := findUser(ctx, tx, spec.User) + user, err := st.findUser(ctx, tx, spec.User) if err != nil { return errors.Trace(err) } @@ -225,7 +225,7 @@ func (s *State) ReadAllAccessTypeForUser(ctx context.Context, subject string, ac // findUser finds the user provided exists, hasn't been removed and is not // disabled. Return data needed to fill in corePermission.UserAccess. -func findUser( +func (st *State) findUser( ctx context.Context, tx *sqlair.TX, userName string, @@ -240,7 +240,7 @@ FROM v_user_auth u ON u.created_by_uuid = creator.uuid WHERE u.removed = false AND u.name = $M.name` - selectUserStmt, err := sqlair.Prepare(getUserQuery, User{}, sqlair.M{}) + selectUserStmt, err := st.Prepare(getUserQuery, User{}, sqlair.M{}) if err != nil { return result, errors.Annotate(err, "preparing select getUser query") } @@ -354,7 +354,7 @@ SELECT &M.found_it FROM ( // ErrNoRows. var foundIt = []sqlair.M{} err = tx.Query(ctx, targetExistsStmt, sqlair.M{"grant_on": targetKey}).GetAll(&foundIt) - if err != nil { + if err != nil && !errors.Is(err, sqlair.ErrNoRows) { return errors.Annotatef(err, "verifying %q target exists", targetKey) } diff --git a/domain/secretbackend/state/state.go b/domain/secretbackend/state/state.go index c74c85e012f..815aa84a439 100644 --- a/domain/secretbackend/state/state.go +++ b/domain/secretbackend/state/state.go @@ -153,7 +153,7 @@ ORDER BY b.name`, SecretBackendRow{}) var rows SecretBackendRows err = db.Txn(ctx, func(ctx context.Context, tx *sqlair.TX) error { err := tx.Query(ctx, stmt).GetAll(&rows) - if err != nil { + if err != nil && !errors.Is(err, sqlair.ErrNoRows) { return fmt.Errorf("querying secret backends: %w", err) } return nil @@ -193,12 +193,12 @@ WHERE b.%s = $M.identifier`, columName) } return nil }) + if errors.Is(err, sqlair.ErrNoRows) { + return nil, fmt.Errorf("%w: %q", backenderrors.NotFound, v) + } if err != nil { return nil, fmt.Errorf("cannot list secret backends: %w", err) } - if len(rows) == 0 { - return nil, fmt.Errorf("%w: %q", backenderrors.NotFound, v) - } return rows.toSecretBackends()[0], errors.Trace(err) } @@ -290,7 +290,7 @@ WHERE b.uuid IN ($S[:])`, args := sqlair.S(transform.Slice(backendIDs, func(s string) any { return any(s) })) err = db.Txn(ctx, func(ctx context.Context, tx *sqlair.TX) error { err := tx.Query(ctx, stmt, args).GetAll(&rows) - if err != nil { + if err != nil && !errors.Is(err, sqlair.ErrNoRows) { return fmt.Errorf("querying secret backend rotation changes: %w", err) } return nil diff --git a/domain/state.go b/domain/state.go index 7883ca50eb3..3f7f030618a 100644 --- a/domain/state.go +++ b/domain/state.go @@ -6,6 +6,7 @@ package domain import ( "sync" + "github.com/canonical/sqlair" "github.com/juju/errors" "github.com/juju/juju/core/database" @@ -17,12 +18,17 @@ type StateBase struct { mu sync.Mutex getDB database.TxnRunnerFactory db database.TxnRunner + + // stmts is a cache of sqlair statements keyed by the query string. + stmts map[string]*sqlair.Statement + stmtMutex sync.RWMutex } // NewStateBase returns a new StateBase. func NewStateBase(getDB database.TxnRunnerFactory) *StateBase { return &StateBase{ getDB: getDB, + stmts: make(map[string]*sqlair.Statement), } } @@ -44,3 +50,34 @@ func (st *StateBase) DB() (database.TxnRunner, error) { return st.db, nil } + +// Prepare prepares a SQLair query. If the query has been prepared previously it +// is retrieved from the statement cache. +// +// Note that because the type samples are not considered when retrieving a query +// from the cache, it is possible that two queries may have identical text, but +// use different types. Retrieving the wrong query would result in an error when +// the query was passed the wrong type at execution. +// +// The likelihood of this happening is low since the statement cache is scoped +// to individual domains meaning that the two identically worded statements +// would have to be in the same state package. This issue should be relatively +// rare and caught by QA if present. +func (st *StateBase) Prepare(query string, typeSamples ...any) (*sqlair.Statement, error) { + // Take a read lock to check if the statement is already prepared. + st.stmtMutex.RLock() + if stmt, ok := st.stmts[query]; ok && stmt != nil { + st.stmtMutex.RUnlock() + return stmt, nil + } + st.stmtMutex.RUnlock() + // Grab the write lock to prepare the statement. + st.stmtMutex.Lock() + defer st.stmtMutex.Unlock() + stmt, err := sqlair.Prepare(query, typeSamples...) + if err != nil { + return nil, errors.Annotate(err, "preparing:") + } + st.stmts[query] = stmt + return stmt, nil +} diff --git a/domain/state_test.go b/domain/state_test.go index eb97a865dbb..96bb60fb0b2 100644 --- a/domain/state_test.go +++ b/domain/state_test.go @@ -4,6 +4,9 @@ package domain import ( + "context" + + "github.com/canonical/sqlair" gc "gopkg.in/check.v1" schematesting "github.com/juju/juju/domain/schema/testing" @@ -28,3 +31,68 @@ func (s *stateSuite) TestStateBaseGetDBNilFactory(c *gc.C) { _, err := base.DB() c.Assert(err, gc.ErrorMatches, `nil getDB`) } + +func (s *stateSuite) TestStateBasePrepare(c *gc.C) { + f := s.TxnRunnerFactory() + base := NewStateBase(f) + db, err := base.DB() + c.Assert(err, gc.IsNil) + c.Assert(db, gc.NotNil) + + // Prepare new query. + stmt1, err := base.Prepare("SELECT name AS &M.* FROM sqlite_schema", sqlair.M{}) + c.Assert(err, gc.IsNil) + // Validate prepared statement works as expected. + err = db.Txn(context.Background(), func(ctx context.Context, tx *sqlair.TX) error { + results := sqlair.M{} + err := tx.Query(ctx, stmt1).Get(results) + if err != nil { + return err + } + c.Assert(results["name"], gc.Equals, "schema") + return nil + }) + c.Assert(err, gc.IsNil) + + // Retrieve previous statement. + stmt2, err := base.Prepare("SELECT name AS &M.* FROM sqlite_schema", sqlair.M{}) + c.Assert(err, gc.IsNil) + c.Assert(stmt1, gc.Equals, stmt2) +} + +func (s *stateSuite) TestStateBasePrepareKeyClash(c *gc.C) { + f := s.TxnRunnerFactory() + base := NewStateBase(f) + db, err := base.DB() + c.Assert(err, gc.IsNil) + c.Assert(db, gc.NotNil) + + // Prepare statement with TestType. + { + type TestType struct { + WrongName string `db:"type"` + } + _, err := base.Prepare("SELECT &TestType.* FROM sqlite_schema", TestType{}) + c.Assert(err, gc.IsNil) + } + + // Prepare statement with a different type of the same name, this will + // retrieve the previously prepared statement which used the shadowed type. + type TestType struct { + Name string `db:"name"` + } + stmt, err := base.Prepare("SELECT &TestType.* FROM sqlite_schema", TestType{}) + + // Try and run a query. + c.Assert(err, gc.IsNil) + err = db.Txn(context.Background(), func(ctx context.Context, tx *sqlair.TX) error { + results := TestType{} + err := tx.Query(ctx, stmt).Get(&results) + if err != nil { + return err + } + c.Assert(results.Name, gc.Equals, "schema") + return nil + }) + c.Assert(err, gc.ErrorMatches, `cannot get result: parameter with type "domain.TestType" missing, have type with same name: "domain.TestType"`) +} diff --git a/domain/storage/state/storagepool.go b/domain/storage/state/storagepool.go index de42b2fa48e..9d7ca88e9e2 100644 --- a/domain/storage/state/storagepool.go +++ b/domain/storage/state/storagepool.go @@ -194,11 +194,11 @@ DELETE FROM storage_pool WHERE storage_pool.uuid = (select uuid FROM storage_pool WHERE name = $M.name) ` - poolAttributeDeleteStmt, err := sqlair.Prepare(poolAttributeDeleteQ, sqlair.M{}) + poolAttributeDeleteStmt, err := st.Prepare(poolAttributeDeleteQ, sqlair.M{}) if err != nil { return errors.Trace(err) } - poolDeleteStmt, err := sqlair.Prepare(poolDeleteQ, sqlair.M{}) + poolDeleteStmt, err := st.Prepare(poolDeleteQ, sqlair.M{}) if err != nil { return errors.Trace(err) } @@ -233,7 +233,7 @@ func (st StoragePoolState) ReplaceStoragePool(ctx context.Context, pool domainst return errors.Trace(err) } - selectUUIDStmt, err := sqlair.Prepare("SELECT &StoragePool.uuid FROM storage_pool WHERE name = $StoragePool.name", StoragePool{}) + selectUUIDStmt, err := st.Prepare("SELECT &StoragePool.uuid FROM storage_pool WHERE name = $StoragePool.name", StoragePool{}) if err != nil { return errors.Trace(domain.CoerceError(err)) } @@ -304,7 +304,7 @@ FROM storage_pool sp keyValues []poolAttribute ) err = tx.Query(ctx, queryStmt, queryArgs...).GetAll(&dbRows, &keyValues) - if err != nil { + if err != nil && !errors.Is(err, sqlair.ErrNoRows) { return nil, errors.Annotate(err, "loading storage pool") } return dbRows.toStoragePools(keyValues) diff --git a/domain/unit/state/state.go b/domain/unit/state/state.go index 6dbc46d9f4e..fdec1fc1654 100644 --- a/domain/unit/state/state.go +++ b/domain/unit/state/state.go @@ -33,8 +33,8 @@ func NewState(factory coredb.TxnRunnerFactory, logger Logger) *State { } // DeleteUnit deletes the specified unit. -func (s *State) DeleteUnit(ctx context.Context, unitName string) error { - db, err := s.DB() +func (st *State) DeleteUnit(ctx context.Context, unitName string) error { + db, err := st.DB() if err != nil { return errors.Trace(err) } @@ -42,13 +42,13 @@ func (s *State) DeleteUnit(ctx context.Context, unitName string) error { unitIDParam := sqlair.M{"unit_id": unitName} queryUnit := `SELECT uuid as &M.uuid FROM unit WHERE unit_id = $M.unit_id` - queryUnitStmt, err := sqlair.Prepare(queryUnit, sqlair.M{}) + queryUnitStmt, err := st.Prepare(queryUnit, sqlair.M{}) if err != nil { return errors.Trace(err) } deleteUnit := `DELETE FROM unit WHERE unit_id = $M.unit_id` - deleteUnitStmt, err := sqlair.Prepare(deleteUnit, sqlair.M{}) + deleteUnitStmt, err := st.Prepare(deleteUnit, sqlair.M{}) if err != nil { return errors.Trace(err) } @@ -57,7 +57,7 @@ func (s *State) DeleteUnit(ctx context.Context, unitName string) error { DELETE FROM net_node WHERE uuid IN (SELECT net_node_uuid FROM unit WHERE unit_id = $M.unit_id) ` - deleteNodeStmt, err := sqlair.Prepare(deleteNode, sqlair.M{}) + deleteNodeStmt, err := st.Prepare(deleteNode, sqlair.M{}) if err != nil { return errors.Trace(err) } diff --git a/domain/upgrade/state/state.go b/domain/upgrade/state/state.go index c8af30630ec..ca89e928884 100644 --- a/domain/upgrade/state/state.go +++ b/domain/upgrade/state/state.go @@ -84,7 +84,7 @@ FROM upgrade_info_controller_node WHERE upgrade_info_uuid = $M.info_uuid AND controller_node_id = $M.controller_id; ` - lookForReadyNodeStmt, err := sqlair.Prepare(lookForReadyNodeQuery, infoControllerNode{}, sqlair.M{}) + lookForReadyNodeStmt, err := st.Prepare(lookForReadyNodeQuery, infoControllerNode{}, sqlair.M{}) if err != nil { return errors.Annotatef(err, "preparing %q", lookForReadyNodeQuery) } @@ -93,7 +93,7 @@ AND controller_node_id = $M.controller_id; INSERT INTO upgrade_info_controller_node (uuid, controller_node_id, upgrade_info_uuid) VALUES ($M.uuid, $M.controller_id, $M.info_uuid); ` - insertUpgradeNodeStmt, err := sqlair.Prepare(insertUpgradeNodeQuery, sqlair.M{}) + insertUpgradeNodeStmt, err := st.Prepare(insertUpgradeNodeQuery, sqlair.M{}) if err != nil { return errors.Annotatef(err, "preparing %q", insertUpgradeNodeQuery) } @@ -225,7 +225,7 @@ UPDATE upgrade_info SET state_type_id = $M.to_state WHERE uuid = $M.info_uuid AND state_type_id = $M.from_state;` - completedDBUpgradeStmt, err := sqlair.Prepare(q, sqlair.M{}) + completedDBUpgradeStmt, err := st.Prepare(q, sqlair.M{}) if err != nil { return errors.Annotatef(err, "preparing %q", q) } @@ -260,7 +260,7 @@ UPDATE upgrade_info SET state_type_id = $M.to_state WHERE uuid = $M.info_uuid AND state_type_id = $M.from_state;` - completedDBUpgradeStmt, err := sqlair.Prepare(q, sqlair.M{}) + completedDBUpgradeStmt, err := st.Prepare(q, sqlair.M{}) if err != nil { return errors.Annotatef(err, "preparing %q", q) } @@ -300,7 +300,7 @@ SELECT (controller_node_id, node_upgrade_completed_at) AS (&infoControllerNode.* FROM upgrade_info_controller_node WHERE upgrade_info_uuid = $M.info_uuid AND controller_node_id = $M.controller_id;` - lookForDoneNodesStmt, err := sqlair.Prepare(lookForDoneNodesQuery, infoControllerNode{}, sqlair.M{}) + lookForDoneNodesStmt, err := st.Prepare(lookForDoneNodesQuery, infoControllerNode{}, sqlair.M{}) if err != nil { return errors.Annotatef(err, "preparing select done query") } @@ -312,7 +312,7 @@ WHERE upgrade_info_uuid = $M.info_uuid AND controller_node_id = $M.controller_id AND node_upgrade_completed_at IS NULL; ` - setNodeToDoneStmt, err := sqlair.Prepare(setNodeToDoneQuery, sqlair.M{}) + setNodeToDoneStmt, err := st.Prepare(setNodeToDoneQuery, sqlair.M{}) if err != nil { return errors.Annotatef(err, "preparing update node query") } @@ -332,7 +332,7 @@ AND ( WHERE upgrade_info_uuid = $M.info_uuid ); ` - completeUpgradeStmt, err := sqlair.Prepare(completeUpgradeQuery, sqlair.M{}) + completeUpgradeStmt, err := st.Prepare(completeUpgradeQuery, sqlair.M{}) if err != nil { return errors.Annotatef(err, "preparing complete upgrade query") } diff --git a/domain/upgrade/state/state_test.go b/domain/upgrade/state/state_test.go index 981382f6a49..811117dcac9 100644 --- a/domain/upgrade/state/state_test.go +++ b/domain/upgrade/state/state_test.go @@ -6,6 +6,7 @@ package state import ( "context" "database/sql" + "errors" "github.com/canonical/sqlair" jc "github.com/juju/testing/checkers" @@ -452,6 +453,9 @@ WHERE upgrade_info_uuid = $M.info_uuid` ) err = db.Txn(context.Background(), func(ctx context.Context, tx *sqlair.TX) error { err = tx.Query(ctx, nodeInfosS, sqlair.M{"info_uuid": upgradeUUID}).GetAll(&nodeInfos) + if errors.Is(err, sqlair.ErrNoRows) { + return nil + } c.Assert(err, jc.ErrorIsNil) return nil }) diff --git a/domain/user/state/state.go b/domain/user/state/state.go index 08ed48d074a..04ba4346d00 100644 --- a/domain/user/state/state.go +++ b/domain/user/state/state.go @@ -134,14 +134,14 @@ FROM v_user_auth u WHERE u.removed = false ` - selectGetAllUsersStmt, err := sqlair.Prepare(getAllUsersQuery, User{}, sqlair.M{}) + selectGetAllUsersStmt, err := st.Prepare(getAllUsersQuery, User{}, sqlair.M{}) if err != nil { return errors.Annotate(err, "preparing select getAllUsers query") } var results []User err = tx.Query(ctx, selectGetAllUsersStmt).GetAll(&results) - if err != nil { + if err != nil && !errors.Is(err, sqlair.ErrNoRows) { return errors.Annotate(err, "getting query results") } @@ -174,7 +174,7 @@ SELECT (uuid, name, display_name, created_by_uuid, created_at, last_login, disab FROM v_user_auth WHERE uuid = $M.uuid` - selectGetUserStmt, err := sqlair.Prepare(getUserQuery, User{}, sqlair.M{}) + selectGetUserStmt, err := st.Prepare(getUserQuery, User{}, sqlair.M{}) if err != nil { return errors.Annotate(err, "preparing select getUser query") } @@ -201,6 +201,7 @@ WHERE uuid = $M.uuid` // GetUserUUIDByName will retrieve the user uuid for the user identifier by name. // If the user does not exist an error that satisfies [usererrors.NotFound] will // be returned. +// Exported for use in credential. func GetUserUUIDByName( ctx context.Context, tx *sqlair.TX, @@ -261,7 +262,7 @@ FROM v_user_auth WHERE name = $M.name AND removed = false` - selectGetUserByNameStmt, err := sqlair.Prepare(getUserByNameQuery, User{}, sqlair.M{}) + selectGetUserByNameStmt, err := st.Prepare(getUserByNameQuery, User{}, sqlair.M{}) if err != nil { return errors.Annotate(err, "preparing select getUserByName query") } @@ -308,7 +309,7 @@ WHERE user.name = $M.name AND removed = false ` - selectGetUserByAuthStmt, err := sqlair.Prepare(getUserWithAuthQuery, User{}, sqlair.M{}) + selectGetUserByAuthStmt, err := st.Prepare(getUserWithAuthQuery, User{}, sqlair.M{}) if err != nil { return user.User{}, errors.Annotate(err, "preparing select getUserWithAuth query") } @@ -360,17 +361,17 @@ func (st *State) RemoveUser(ctx context.Context, name string) error { m := make(sqlair.M, 1) - deletePassStmt, err := sqlair.Prepare("DELETE FROM user_password WHERE user_uuid = $M.uuid", m) + deletePassStmt, err := st.Prepare("DELETE FROM user_password WHERE user_uuid = $M.uuid", m) if err != nil { return errors.Annotate(err, "preparing password deletion query") } - deleteKeyStmt, err := sqlair.Prepare("DELETE FROM user_activation_key WHERE user_uuid = $M.uuid", m) + deleteKeyStmt, err := st.Prepare("DELETE FROM user_activation_key WHERE user_uuid = $M.uuid", m) if err != nil { return errors.Annotate(err, "preparing password deletion query") } - setRemovedStmt, err := sqlair.Prepare("UPDATE user SET removed = true WHERE uuid = $M.uuid", m) + setRemovedStmt, err := st.Prepare("UPDATE user SET removed = true WHERE uuid = $M.uuid", m) if err != nil { return errors.Annotate(err, "preparing password deletion query") } @@ -416,7 +417,7 @@ func (st *State) SetActivationKey(ctx context.Context, name string, activationKe m := make(sqlair.M, 1) - deletePassStmt, err := sqlair.Prepare("DELETE FROM user_password WHERE user_uuid = $M.uuid", m) + deletePassStmt, err := st.Prepare("DELETE FROM user_password WHERE user_uuid = $M.uuid", m) if err != nil { return errors.Annotate(err, "preparing password deletion query") } @@ -451,7 +452,7 @@ func (st *State) GetActivationKey(ctx context.Context, name string) ([]byte, err m := make(sqlair.M, 1) - selectKeyStmt, err := sqlair.Prepare(` + selectKeyStmt, err := st.Prepare(` SELECT (*) AS (&ActivationKey.*) FROM user_activation_key WHERE user_uuid = $M.uuid `, m, ActivationKey{}) if err != nil { @@ -498,7 +499,7 @@ func (st *State) SetPasswordHash(ctx context.Context, name string, passwordHash m := make(sqlair.M, 1) - deleteKeyStmt, err := sqlair.Prepare("DELETE FROM user_activation_key WHERE user_uuid = $M.uuid", m) + deleteKeyStmt, err := st.Prepare("DELETE FROM user_activation_key WHERE user_uuid = $M.uuid", m) if err != nil { return errors.Annotate(err, "preparing password deletion query") } @@ -540,7 +541,7 @@ VALUES ($M.uuid, false) ON CONFLICT(user_uuid) DO UPDATE SET disabled = false` - enableUserStmt, err := sqlair.Prepare(q, m) + enableUserStmt, err := st.Prepare(q, m) if err != nil { return errors.Annotate(err, "preparing enable user query") } @@ -581,7 +582,7 @@ VALUES ($M.uuid, true) ON CONFLICT(user_uuid) DO UPDATE SET disabled = true` - disableUserStmt, err := sqlair.Prepare(q, m) + disableUserStmt, err := st.Prepare(q, m) if err != nil { return errors.Annotate(err, "preparing disable user query") } @@ -645,7 +646,7 @@ UPDATE user_authentication SET last_login = datetime('now') WHERE user_uuid = $M.uuid` - updateLastLoginStmt, err := sqlair.Prepare(q, m) + updateLastLoginStmt, err := st.Prepare(q, m) if err != nil { return errors.Annotate(err, "preparing update updateLastLogin query") } @@ -861,7 +862,7 @@ func (st *State) uuidForName( func (st *State) getActiveUUIDStmt() (*sqlair.Statement, error) { var err error if st.activeUUIDStmt == nil { - st.activeUUIDStmt, err = sqlair.Prepare( + st.activeUUIDStmt, err = st.Prepare( "SELECT &M.uuid FROM user WHERE name = $M.name AND IFNULL(removed, false) = false", sqlair.M{}) } return st.activeUUIDStmt, errors.Annotate(err, "preparing user UUID statement") diff --git a/go.mod b/go.mod index 5d15c754d8d..75d5d227c06 100644 --- a/go.mod +++ b/go.mod @@ -25,7 +25,7 @@ require ( github.com/canonical/go-dqlite v1.21.0 github.com/canonical/lxd v0.0.0-20231214113525-e676fc63c50a github.com/canonical/pebble v1.9.0 - github.com/canonical/sqlair v0.0.0-20240206143658-d8b84b389a40 + github.com/canonical/sqlair v0.0.0-20240319123511-a48b89bb30aa github.com/chzyer/readline v1.5.1 github.com/coreos/go-systemd/v22 v22.5.0 github.com/docker/distribution v2.8.3+incompatible diff --git a/go.sum b/go.sum index 65d241d0b06..223ed953593 100644 --- a/go.sum +++ b/go.sum @@ -154,8 +154,8 @@ github.com/canonical/lxd v0.0.0-20231214113525-e676fc63c50a h1:Tfo/MzXK5GeG7gzSH github.com/canonical/lxd v0.0.0-20231214113525-e676fc63c50a/go.mod h1:UxfHGKFoRjgu1NUA9EFiR++dKvyAiT0h9HT0ffMlzjc= github.com/canonical/pebble v1.9.0 h1:FWVEh1fg3aaW2HNue2Z2eYMwkJEQT8mgMFW3R5Iocn4= github.com/canonical/pebble v1.9.0/go.mod h1:9Qkjmq298g0+9SvM2E5eekkEN4pjHDWhgg9eB2I0tjk= -github.com/canonical/sqlair v0.0.0-20240206143658-d8b84b389a40 h1:ZQhteS4l8oF3adpLPCvJpLFBV638+uLpZtPQ8YLN3jk= -github.com/canonical/sqlair v0.0.0-20240206143658-d8b84b389a40/go.mod h1:T+40I2sXshY3KRxx0QQpqqn6hCibSKJ2KHzjBvJj8T4= +github.com/canonical/sqlair v0.0.0-20240319123511-a48b89bb30aa h1:eRRgzybbMhVtJXwIyrTCTuS66n46Mq0Li/tQF3ZtHtU= +github.com/canonical/sqlair v0.0.0-20240319123511-a48b89bb30aa/go.mod h1:T+40I2sXshY3KRxx0QQpqqn6hCibSKJ2KHzjBvJj8T4= github.com/canonical/x-go v0.0.0-20230522092633-7947a7587f5b h1:Da2fardddn+JDlVEYtrzBLTtyzoyU3nIS0Cf0GvjmwU= github.com/canonical/x-go v0.0.0-20230522092633-7947a7587f5b/go.mod h1:upTK9n6rlqITN9rCN69hdreI37dRDFUk2thlGGD5Cg8= github.com/cenkalti/backoff/v3 v3.0.0 h1:ske+9nBpD9qZsTBoF41nW5L+AIuFBKMeze18XQ3eG1c= diff --git a/internal/worker/changestreampruner/worker.go b/internal/worker/changestreampruner/worker.go index f7e2486f331..bb2a0d6e242 100644 --- a/internal/worker/changestreampruner/worker.go +++ b/internal/worker/changestreampruner/worker.go @@ -162,7 +162,7 @@ func (w *Pruner) prune() (map[string]int64, error) { err = db.Txn(ctx, func(ctx context.Context, tx *sqlair.TX) error { return errors.Trace(tx.Query(ctx, query).GetAll(&models)) }) - if err != nil { + if err != nil && !errors.Is(err, sqlair.ErrNoRows) { return nil, errors.Trace(err) } @@ -240,13 +240,11 @@ func (w *Pruner) locateLowestWatermark(ctx context.Context, tx *sqlair.TX, names // change_log_witness table, the change stream will put the witness // back in place after the next change log is written. var watermarks []Watermark - if err := tx.Query(ctx, selectWitnessQuery).GetAll(&watermarks); err != nil { - return Watermark{}, errors.Trace(err) - } - - // Nothing to do if there are no watermarks. - if len(watermarks) == 0 { + if err := tx.Query(ctx, selectWitnessQuery).GetAll(&watermarks); errors.Is(err, sqlair.ErrNoRows) { + // Nothing to do if there are no watermarks. return Watermark{}, nil + } else if err != nil { + return Watermark{}, errors.Trace(err) } // Gather all the watermarks that are within the windowed time period. @@ -261,7 +259,7 @@ func (w *Pruner) locateLowestWatermark(ctx context.Context, tx *sqlair.TX, names if err := tx.Query(ctx, selectChangeLogQuery, sqlair.M{ "created_at": w.cfg.Clock.Now().Add(-defaultWindowDuration), "limit": changestream.DefaultNumTermWatermarks + 1, - }).GetAll(&changes); err != nil { + }).GetAll(&changes); err != nil && !errors.Is(err, sqlair.ErrNoRows) { return Watermark{}, errors.Trace(err) } // If there are less than the default number of term watermarks, then