diff --git a/domain/cloud/state/state.go b/domain/cloud/state/state.go index 1d24c69de37..66c35b7431f 100644 --- a/domain/cloud/state/state.go +++ b/domain/cloud/state/state.go @@ -17,6 +17,7 @@ import ( coredatabase "github.com/juju/juju/core/database" "github.com/juju/juju/core/watcher" "github.com/juju/juju/domain" + clouderrors "github.com/juju/juju/domain/cloud/errors" "github.com/juju/juju/domain/model" "github.com/juju/juju/internal/database" ) @@ -53,36 +54,71 @@ func (st *State) ListClouds(ctx context.Context, name string) ([]cloud.Cloud, er // cloud has no defaults or the cloud does not exist a nil error is returned // with an empty defaults map. func (st *State) CloudDefaults(ctx context.Context, cloudName string) (map[string]string, error) { - defaults := map[string]string{} - db, err := st.DB() if err != nil { - return defaults, errors.Trace(err) - } - + return nil, fmt.Errorf("getting database for setting cloud %q defaults: %w", cloudName, err) + } + + // This might look like an odd way to query for cloud defaults but by doing + // a left join onto the cloud table we are always guaranteed at least one + // row to be returned. This lets us confirm that a cloud actually exists + // for the name. + // The reason for going to so much effort for seeing if the cloud exists is + // so we can return an error if a cloud has been asked for that doesn't + // exist. This is important as it will let us potentially identify bad logic + // problems in Juju early where we have logic that might go off the rails + // with bad values that make their way down to state. stmt := ` -SELECT key, value -FROM cloud_defaults - INNER JOIN cloud - ON cloud_defaults.cloud_uuid = cloud.uuid -WHERE cloud.name = ? +SELECT cloud_defaults.key, + cloud_defaults.value, + cloud.uuid +FROM cloud +LEFT JOIN cloud_defaults ON cloud.uuid = cloud_defaults.cloud_uuid +WHERE cloud.name = ? ` - return defaults, db.StdTxn(ctx, func(ctx context.Context, tx *sql.Tx) error { + rval := make(map[string]string) + err = db.StdTxn(ctx, func(ctx context.Context, tx *sql.Tx) error { rows, err := tx.QueryContext(ctx, stmt, cloudName) - if err != nil { - return fmt.Errorf("fetching cloud %q defaults: %w", cloudName, err) + if errors.Is(err, sql.ErrNoRows) { + return fmt.Errorf("%w %q", clouderrors.NotFound, cloudName) + } else if err != nil { + return fmt.Errorf("getting cloud %q defaults: %w", cloudName, err) } + defer func() { _ = rows.Close() }() - var key, value string + var ( + cloudUUID string + key, value sql.NullString + ) for rows.Next() { - if err := rows.Scan(&key, &value); err != nil { - return fmt.Errorf("compiling cloud %q defaults: %w", cloudName, stderrors.Join(err, rows.Close())) + if err := rows.Scan(&key, &value, &cloudUUID); err != nil { + return fmt.Errorf("reading cloud %q default: %w", cloudName, err) } - defaults[key] = value + if !key.Valid { + // If the key is null it means there is no defaults set for the + // cloud. We can safely just continue because the next iteration + // of rows will return done. + continue + } + rval[key.String] = value.String + } + + if err := rows.Err(); err != nil { + return fmt.Errorf("reading cloud %q defaults: %w", cloudName, err) } + // If cloudUUID is the zero value it means no cloud exists for cloudName. + if cloudUUID == "" { + return fmt.Errorf("%w %q", clouderrors.NotFound, cloudName) + } + return nil }) + + if err != nil { + return nil, err + } + return rval, nil } // UpdateCloudDefaults is responsible for updating default config values for a @@ -816,3 +852,53 @@ func (st *State) WatchCloud( result, err := getWatcher("cloud", uuid, changestream.All) return result, errors.Annotatef(err, "watching cloud") } + +// SetCloudDefaults is responsible for removing any previously set cloud +// default values and setting the new cloud defaults to use. If no defaults are +// supplied to this function then the currently set cloud default values will be +// removed and no further operations will be be +// performed. If no cloud exists for the cloud name then an error satisfying +// [clouderrors.NotFound] is returned. +func SetCloudDefaults( + ctx context.Context, + tx *sql.Tx, + cloudName string, + defaults map[string]string, +) error { + cloudUUIDStmt := "SELECT uuid FROM cloud WHERE name = ?" + + var cloudUUID string + row := tx.QueryRowContext(ctx, cloudUUIDStmt, cloudName) + err := row.Scan(&cloudUUID) + if errors.Is(err, sql.ErrNoRows) { + return fmt.Errorf("%w %q", clouderrors.NotFound, cloudName) + } else if err != nil { + return fmt.Errorf("getting cloud %q uuid to set cloud model defaults: %w", cloudName, err) + } + + deleteStmt := "DELETE FROM cloud_defaults WHERE cloud_defaults.cloud_uuid = ?" + _, err = tx.ExecContext(ctx, deleteStmt, cloudUUID) + if err != nil { + return fmt.Errorf("removing previously set cloud %q model defaults: %w", cloudName, err) + } + + if len(defaults) == 0 { + return nil + } + + bindStr, args := database.MapToMultiPlaceholderTransform(defaults, func(k, v string) []any { + return []any{cloudUUID, k, v} + }) + + insertStmt := fmt.Sprintf( + "INSERT INTO cloud_defaults (cloud_uuid, key, value) VALUES %s", + bindStr, + ) + + _, err = tx.ExecContext(ctx, insertStmt, args...) + if err != nil { + return fmt.Errorf("setting cloud %q model defaults: %w", cloudName, err) + } + + return nil +} diff --git a/domain/cloud/state/state_test.go b/domain/cloud/state/state_test.go index 7aa89e80dba..60a229e0f02 100644 --- a/domain/cloud/state/state_test.go +++ b/domain/cloud/state/state_test.go @@ -22,6 +22,7 @@ import ( "github.com/juju/juju/core/watcher" "github.com/juju/juju/core/watcher/eventsource" "github.com/juju/juju/core/watcher/watchertest" + clouderrors "github.com/juju/juju/domain/cloud/errors" "github.com/juju/juju/domain/model" modelstate "github.com/juju/juju/domain/model/state" modeltesting "github.com/juju/juju/domain/model/testing" @@ -527,10 +528,13 @@ func (s *stateSuite) TestEmptyCloudDefaults(c *gc.C) { c.Assert(len(defaults), gc.Equals, 0) } -func (s *stateSuite) TestNonFoundCloudDefaults(c *gc.C) { +// TestNotFoundCloudDefaults is testing what happens if we request a cloud +// defaults for a cloud that doesn't exist. It should result in a +// [clouderrors.NotFound] error. +func (s *stateSuite) TestNotFoundCloudDefaults(c *gc.C) { st := NewState(s.TxnRunnerFactory()) defaults, err := st.CloudDefaults(context.Background(), "notfound") - c.Assert(err, jc.ErrorIsNil) + c.Assert(err, jc.ErrorIs, clouderrors.NotFound) c.Assert(len(defaults), gc.Equals, 0) } @@ -737,3 +741,109 @@ func (s *stateSuite) TestNullCloudType(c *gc.C) { }) c.Assert(jujudb.IsErrConstraintNotNull(err), jc.IsTrue) } + +// TestSetCloudDefaults is testing the happy path for [SetCloudDefaults] +func (s *stateSuite) TestSetCloudDefaults(c *gc.C) { + cld := testCloud + st := NewState(s.TxnRunnerFactory()) + err := st.UpsertCloud(ctx.Background(), cld) + c.Assert(err, jc.ErrorIsNil) + + err = s.TxnRunner().StdTxn(context.Background(), func(ctx context.Context, tx *sql.Tx) error { + return SetCloudDefaults(ctx, tx, cld.Name, map[string]string{ + "clouddefault": "one", + }) + }) + c.Check(err, jc.ErrorIsNil) + + defaults, err := st.CloudDefaults(context.Background(), cld.Name) + c.Check(err, jc.ErrorIsNil) + c.Check(defaults, jc.DeepEquals, map[string]string{ + "clouddefault": "one", + }) +} + +// TestSetCloudDefaultsNotFound is asserting that if we try and set cloud +// defaults for a cloud that doesn't exist we get back an error that satisfies +// [clouderrors.NotFound]. +func (s *stateSuite) TestSetCloudDefaultsNotFound(c *gc.C) { + st := NewState(s.TxnRunnerFactory()) + + err := s.TxnRunner().StdTxn(context.Background(), func(ctx context.Context, tx *sql.Tx) error { + return SetCloudDefaults(ctx, tx, "noexist", map[string]string{ + "clouddefault": "one", + }) + }) + c.Check(err, jc.ErrorIs, clouderrors.NotFound) + + defaults, err := st.CloudDefaults(context.Background(), "noexist") + c.Check(err, jc.ErrorIs, clouderrors.NotFound) + c.Check(len(defaults), gc.Equals, 0) +} + +// TestSetCloudDefaultsOverrides checks that successive calls to +// SetCloudDefaults overrides the previously set values for cloud defaults. +func (s *stateSuite) TestSetCloudDefaultsOverrides(c *gc.C) { + cld := testCloud + st := NewState(s.TxnRunnerFactory()) + err := st.UpsertCloud(ctx.Background(), cld) + c.Assert(err, jc.ErrorIsNil) + + err = s.TxnRunner().StdTxn(context.Background(), func(ctx context.Context, tx *sql.Tx) error { + return SetCloudDefaults(ctx, tx, cld.Name, map[string]string{ + "clouddefault": "one", + }) + }) + c.Check(err, jc.ErrorIsNil) + + defaults, err := st.CloudDefaults(context.Background(), cld.Name) + c.Check(err, jc.ErrorIsNil) + c.Check(defaults, jc.DeepEquals, map[string]string{ + "clouddefault": "one", + }) + + err = s.TxnRunner().StdTxn(context.Background(), func(ctx context.Context, tx *sql.Tx) error { + return SetCloudDefaults(ctx, tx, cld.Name, map[string]string{ + "clouddefaultnew": "two", + }) + }) + c.Check(err, jc.ErrorIsNil) + + defaults, err = st.CloudDefaults(context.Background(), cld.Name) + c.Check(err, jc.ErrorIsNil) + c.Check(defaults, jc.DeepEquals, map[string]string{ + "clouddefaultnew": "two", + }) +} + +// TestSetCloudDefaultsDelete is testing that if we call [SetCloudDefaults] with +// a empty map of defaults the existing cloud defaults are removed and no +// further actions are taken. +func (s *stateSuite) TestSetCloudDefaultsDelete(c *gc.C) { + cld := testCloud + st := NewState(s.TxnRunnerFactory()) + err := st.UpsertCloud(ctx.Background(), cld) + c.Assert(err, jc.ErrorIsNil) + + err = s.TxnRunner().StdTxn(context.Background(), func(ctx context.Context, tx *sql.Tx) error { + return SetCloudDefaults(ctx, tx, cld.Name, map[string]string{ + "clouddefault": "one", + }) + }) + c.Check(err, jc.ErrorIsNil) + + defaults, err := st.CloudDefaults(context.Background(), cld.Name) + c.Check(err, jc.ErrorIsNil) + c.Check(defaults, jc.DeepEquals, map[string]string{ + "clouddefault": "one", + }) + + err = s.TxnRunner().StdTxn(context.Background(), func(ctx context.Context, tx *sql.Tx) error { + return SetCloudDefaults(ctx, tx, cld.Name, nil) + }) + c.Check(err, jc.ErrorIsNil) + + defaults, err = st.CloudDefaults(context.Background(), cld.Name) + c.Check(err, jc.ErrorIsNil) + c.Check(len(defaults), gc.Equals, 0) +}