diff --git a/agent/agentbootstrap/bootstrap.go b/agent/agentbootstrap/bootstrap.go index 390cf1379d8..73200dfd4cd 100644 --- a/agent/agentbootstrap/bootstrap.go +++ b/agent/agentbootstrap/bootstrap.go @@ -243,6 +243,7 @@ func (b *AgentBootstrap) Initialize(ctx stdcontext.Context) (_ *state.Controller ccbootstrap.InsertInitialControllerConfig(stateParams.ControllerConfig), cloudbootstrap.InsertCloud(stateParams.ControllerCloud), credbootstrap.InsertCredential(credential.IdFromTag(cloudCredTag), cloudCred), + cloudbootstrap.SetCloudDefaults(stateParams.ControllerCloud.Name, stateParams.ControllerInheritedConfig), modelbootstrap.CreateModel(controllerUUID, controllerModelArgs), ), database.BootstrapModelConcern(controllerUUID, diff --git a/domain/cloud/bootstrap/bootstrap.go b/domain/cloud/bootstrap/bootstrap.go index 1948dad2333..00150d48af3 100644 --- a/domain/cloud/bootstrap/bootstrap.go +++ b/domain/cloud/bootstrap/bootstrap.go @@ -6,6 +6,7 @@ package bootstrap import ( "context" "database/sql" + "fmt" "github.com/juju/errors" "github.com/juju/utils/v3" @@ -13,6 +14,7 @@ import ( "github.com/juju/juju/cloud" "github.com/juju/juju/core/database" "github.com/juju/juju/domain/cloud/state" + modelconfigservice "github.com/juju/juju/domain/modelconfig/service" ) // InsertCloud inserts the initial cloud during bootstrap. @@ -30,3 +32,29 @@ func InsertCloud(cloud cloud.Cloud) func(context.Context, database.TxnRunner) er })) } } + +// SetCloudDefaults is responsible for setting a previously inserted cloud's +// default config values that will be used as part of the default values +// supplied to a models config. If no cloud exists for the specified name an +// error satisfying [github.com/juju/juju/domain/cloud/errors.NotFound] will be +// returned. +func SetCloudDefaults( + cloudName string, + defaults map[string]any, +) func(context.Context, database.TxnRunner) error { + return func(ctx context.Context, db database.TxnRunner) error { + strDefaults, err := modelconfigservice.CoerceConfigForStorage(defaults) + if err != nil { + return fmt.Errorf("coercing cloud %q default values for storage: %w", cloudName, err) + } + + err = db.StdTxn(ctx, func(ctx context.Context, tx *sql.Tx) error { + return state.SetCloudDefaults(ctx, tx, cloudName, strDefaults) + }) + + if err != nil { + return fmt.Errorf("setting cloud %q bootstrap defaults: %w", cloudName, err) + } + return nil + } +} diff --git a/domain/cloud/bootstrap/bootstrap_test.go b/domain/cloud/bootstrap/bootstrap_test.go index 8018c0f210b..2b480aeb77b 100644 --- a/domain/cloud/bootstrap/bootstrap_test.go +++ b/domain/cloud/bootstrap/bootstrap_test.go @@ -10,6 +10,8 @@ import ( gc "gopkg.in/check.v1" "github.com/juju/juju/cloud" + clouderrors "github.com/juju/juju/domain/cloud/errors" + "github.com/juju/juju/domain/cloud/state" schematesting "github.com/juju/juju/domain/schema/testing" ) @@ -29,3 +31,87 @@ func (s *bootstrapSuite) TestInsertCloud(c *gc.C) { c.Assert(row.Scan(&name), jc.ErrorIsNil) c.Assert(name, gc.Equals, "cirrus") } + +// TestSetCloudDefaultsNoExist is check that if we try and set cloud defaults +// for a cloud that doesn't exist we get a [clouderrors.NotFound] error back +func (s *bootstrapSuite) TestSetCloudDefaultsNoExist(c *gc.C) { + set := SetCloudDefaults("noexist", map[string]any{ + "HTTP_PROXY": "[2001:0DB8::1]:80", + }) + + err := set(context.Background(), s.TxnRunner()) + c.Check(err, jc.ErrorIs, clouderrors.NotFound) + + var count int + row := s.DB().QueryRow("SELECT count(*) FROM cloud_defaults") + err = row.Scan(&count) + c.Check(err, jc.ErrorIsNil) + c.Check(count, gc.Equals, 0) +} + +// TestSetCloudDefaults is testing the happy path for setting cloud defaults. +func (s *bootstrapSuite) TestSetCloudDefaults(c *gc.C) { + cld := cloud.Cloud{ + Name: "cirrus", + Type: "ec2", + AuthTypes: cloud.AuthTypes{cloud.UserPassAuthType}, + } + err := InsertCloud(cld)(context.Background(), s.TxnRunner()) + c.Check(err, jc.ErrorIsNil) + + set := SetCloudDefaults("cirrus", map[string]any{ + "HTTP_PROXY": "[2001:0DB8::1]:80", + }) + + err = set(context.Background(), s.TxnRunner()) + c.Check(err, jc.ErrorIsNil) + + st := state.NewState(s.TxnRunnerFactory()) + defaults, err := st.CloudDefaults(context.Background(), "cirrus") + c.Check(err, jc.ErrorIsNil) + c.Check(defaults, jc.DeepEquals, map[string]string{ + "HTTP_PROXY": "[2001:0DB8::1]:80", + }) +} + +// TestSetCloudDefaultsOverrides is testing that repeated calls to +// [SetCloudDefaults] overrides existing cloud defaults that have been set. +func (s *bootstrapSuite) TestSetCloudDefaultsOverides(c *gc.C) { + cld := cloud.Cloud{ + Name: "cirrus", + Type: "ec2", + AuthTypes: cloud.AuthTypes{cloud.UserPassAuthType}, + } + err := InsertCloud(cld)(context.Background(), s.TxnRunner()) + c.Check(err, jc.ErrorIsNil) + + set := SetCloudDefaults("cirrus", map[string]any{ + "HTTP_PROXY": "[2001:0DB8::1]:80", + }) + + err = set(context.Background(), s.TxnRunner()) + c.Check(err, jc.ErrorIsNil) + + st := state.NewState(s.TxnRunnerFactory()) + defaults, err := st.CloudDefaults(context.Background(), "cirrus") + c.Check(err, jc.ErrorIsNil) + c.Check(defaults, jc.DeepEquals, map[string]string{ + "HTTP_PROXY": "[2001:0DB8::1]:80", + }) + + // Second time around + + set = SetCloudDefaults("cirrus", map[string]any{ + "foo": "bar", + }) + + err = set(context.Background(), s.TxnRunner()) + c.Check(err, jc.ErrorIsNil) + + st = state.NewState(s.TxnRunnerFactory()) + defaults, err = st.CloudDefaults(context.Background(), "cirrus") + c.Check(err, jc.ErrorIsNil) + c.Check(defaults, jc.DeepEquals, map[string]string{ + "foo": "bar", + }) +} diff --git a/domain/cloud/errors/errors.go b/domain/cloud/errors/errors.go new file mode 100644 index 00000000000..b36f970fc55 --- /dev/null +++ b/domain/cloud/errors/errors.go @@ -0,0 +1,14 @@ +// Copyright 2023 Canonical Ltd. +// Licensed under the AGPLv3, see LICENCE file for details. + +package errors + +import ( + "github.com/juju/errors" +) + +const ( + // NotFound describes an error that occurs when the cloud being operated on + // does not exist. + NotFound = errors.ConstError("cloud not found") +) diff --git a/domain/cloud/state/state.go b/domain/cloud/state/state.go index 1d24c69de37..faace8bebcc 100644 --- a/domain/cloud/state/state.go +++ b/domain/cloud/state/state.go @@ -17,6 +17,7 @@ import ( coredatabase "github.com/juju/juju/core/database" "github.com/juju/juju/core/watcher" "github.com/juju/juju/domain" + clouderrors "github.com/juju/juju/domain/cloud/errors" "github.com/juju/juju/domain/model" "github.com/juju/juju/internal/database" ) @@ -53,36 +54,71 @@ func (st *State) ListClouds(ctx context.Context, name string) ([]cloud.Cloud, er // cloud has no defaults or the cloud does not exist a nil error is returned // with an empty defaults map. func (st *State) CloudDefaults(ctx context.Context, cloudName string) (map[string]string, error) { - defaults := map[string]string{} - db, err := st.DB() if err != nil { - return defaults, errors.Trace(err) - } - + return nil, fmt.Errorf("getting database for setting cloud %q defaults: %w", cloudName, err) + } + + // This might look like an odd way to query for cloud defaults but by doing + // a left join onto the cloud table we are always guaranteed at least one + // row to be returned. This lets us confirm that a cloud actually exists + // for the name. + // The reason for going to so much effort for seeing if the cloud exists is + // so we can return an error if a cloud has been asked for that doesn't + // exist. This is important as it will let us potentially identify bad logic + // problems in Juju early where we have logic that might go off the rails + // with bad values that make their way down to state. stmt := ` -SELECT key, value -FROM cloud_defaults - INNER JOIN cloud - ON cloud_defaults.cloud_uuid = cloud.uuid -WHERE cloud.name = ? +SELECT cloud_defaults.key, + cloud_defaults.value, + cloud.uuid +FROM cloud +LEFT JOIN cloud_defaults ON cloud.uuid = cloud_defaults.cloud_uuid +WHERE cloud.name = ? ` - return defaults, db.StdTxn(ctx, func(ctx context.Context, tx *sql.Tx) error { + rval := make(map[string]string) + err = db.StdTxn(ctx, func(ctx context.Context, tx *sql.Tx) error { rows, err := tx.QueryContext(ctx, stmt, cloudName) - if err != nil { - return fmt.Errorf("fetching cloud %q defaults: %w", cloudName, err) + if errors.Is(err, sql.ErrNoRows) { + return fmt.Errorf("%w %q", clouderrors.NotFound, cloudName) + } else if err != nil { + return fmt.Errorf("getting cloud %q defaults: %w", cloudName, err) } + defer func() { _ = rows.Close() }() - var key, value string + var ( + cloudUUID string + key, value sql.NullString + ) for rows.Next() { - if err := rows.Scan(&key, &value); err != nil { - return fmt.Errorf("compiling cloud %q defaults: %w", cloudName, stderrors.Join(err, rows.Close())) + if err := rows.Scan(&key, &value, &cloudUUID); err != nil { + return fmt.Errorf("reading cloud %q default: %w", cloudName, err) } - defaults[key] = value + if !key.Valid { + // If the key is null it means there is no defaults set for the + // cloud. We can safely just continue because the next iteration + // of rows will return done. + continue + } + rval[key.String] = value.String + } + + if err := rows.Err(); err != nil { + return fmt.Errorf("reading cloud %q defaults: %w", cloudName, err) } + // If cloudUUID is the zero value it means no cloud exists for cloudName. + if cloudUUID == "" { + return fmt.Errorf("%w %q", clouderrors.NotFound, cloudName) + } + return nil }) + + if err != nil { + return nil, err + } + return rval, nil } // UpdateCloudDefaults is responsible for updating default config values for a @@ -816,3 +852,53 @@ func (st *State) WatchCloud( result, err := getWatcher("cloud", uuid, changestream.All) return result, errors.Annotatef(err, "watching cloud") } + +// SetCloudDefaults is responsible for removing any previously set cloud +// default values and setting the new cloud defaults to use. If no defaults are +// supplied to this function then the currently set cloud default values will be +// removed and no further operations will be be +// performed. If no cloud exists for the cloud name then an error satisfying +// [clouderrors.NotFound] is returned. +func SetCloudDefaults( + ctx context.Context, + tx *sql.Tx, + cloudName string, + defaults map[string]string, +) error { + cloudUUIDStmt := "SELECT uuid FROM cloud WHERE name = ?" + + var cloudUUID string + row := tx.QueryRowContext(ctx, cloudUUIDStmt, cloudName) + err := row.Scan(&cloudUUID) + if errors.Is(err, sql.ErrNoRows) { + return fmt.Errorf("%w %q", clouderrors.NotFound, cloudName) + } else if err != nil { + return fmt.Errorf("getting cloud %q uuid to set cloud model defaults: %w", cloudName, err) + } + + deleteStmt := "DELETE FROM cloud_defaults WHERE cloud_defaults.cloud_uuid = ?" + _, err = tx.ExecContext(ctx, deleteStmt, cloudUUID) + if err != nil { + return fmt.Errorf("removing previously set cloud %q model defaults: %w", cloudName, err) + } + + if len(defaults) == 0 { + return nil + } + + bindStr, args := database.MapToMultiPlaceholderTransform(defaults, func(k, v string) []any { + return []any{cloudUUID, k, v} + }) + + insertStmt := fmt.Sprintf( + "INSERT INTO cloud_defaults (cloud_uuid, key, value) VALUES %s", + bindStr, + ) + + _, err = tx.ExecContext(ctx, insertStmt, args...) + if err != nil { + return fmt.Errorf("setting cloud %q model defaults: %w", cloudName, err) + } + + return nil +} diff --git a/domain/cloud/state/state_test.go b/domain/cloud/state/state_test.go index 7aa89e80dba..60a229e0f02 100644 --- a/domain/cloud/state/state_test.go +++ b/domain/cloud/state/state_test.go @@ -22,6 +22,7 @@ import ( "github.com/juju/juju/core/watcher" "github.com/juju/juju/core/watcher/eventsource" "github.com/juju/juju/core/watcher/watchertest" + clouderrors "github.com/juju/juju/domain/cloud/errors" "github.com/juju/juju/domain/model" modelstate "github.com/juju/juju/domain/model/state" modeltesting "github.com/juju/juju/domain/model/testing" @@ -527,10 +528,13 @@ func (s *stateSuite) TestEmptyCloudDefaults(c *gc.C) { c.Assert(len(defaults), gc.Equals, 0) } -func (s *stateSuite) TestNonFoundCloudDefaults(c *gc.C) { +// TestNotFoundCloudDefaults is testing what happens if we request a cloud +// defaults for a cloud that doesn't exist. It should result in a +// [clouderrors.NotFound] error. +func (s *stateSuite) TestNotFoundCloudDefaults(c *gc.C) { st := NewState(s.TxnRunnerFactory()) defaults, err := st.CloudDefaults(context.Background(), "notfound") - c.Assert(err, jc.ErrorIsNil) + c.Assert(err, jc.ErrorIs, clouderrors.NotFound) c.Assert(len(defaults), gc.Equals, 0) } @@ -737,3 +741,109 @@ func (s *stateSuite) TestNullCloudType(c *gc.C) { }) c.Assert(jujudb.IsErrConstraintNotNull(err), jc.IsTrue) } + +// TestSetCloudDefaults is testing the happy path for [SetCloudDefaults] +func (s *stateSuite) TestSetCloudDefaults(c *gc.C) { + cld := testCloud + st := NewState(s.TxnRunnerFactory()) + err := st.UpsertCloud(ctx.Background(), cld) + c.Assert(err, jc.ErrorIsNil) + + err = s.TxnRunner().StdTxn(context.Background(), func(ctx context.Context, tx *sql.Tx) error { + return SetCloudDefaults(ctx, tx, cld.Name, map[string]string{ + "clouddefault": "one", + }) + }) + c.Check(err, jc.ErrorIsNil) + + defaults, err := st.CloudDefaults(context.Background(), cld.Name) + c.Check(err, jc.ErrorIsNil) + c.Check(defaults, jc.DeepEquals, map[string]string{ + "clouddefault": "one", + }) +} + +// TestSetCloudDefaultsNotFound is asserting that if we try and set cloud +// defaults for a cloud that doesn't exist we get back an error that satisfies +// [clouderrors.NotFound]. +func (s *stateSuite) TestSetCloudDefaultsNotFound(c *gc.C) { + st := NewState(s.TxnRunnerFactory()) + + err := s.TxnRunner().StdTxn(context.Background(), func(ctx context.Context, tx *sql.Tx) error { + return SetCloudDefaults(ctx, tx, "noexist", map[string]string{ + "clouddefault": "one", + }) + }) + c.Check(err, jc.ErrorIs, clouderrors.NotFound) + + defaults, err := st.CloudDefaults(context.Background(), "noexist") + c.Check(err, jc.ErrorIs, clouderrors.NotFound) + c.Check(len(defaults), gc.Equals, 0) +} + +// TestSetCloudDefaultsOverrides checks that successive calls to +// SetCloudDefaults overrides the previously set values for cloud defaults. +func (s *stateSuite) TestSetCloudDefaultsOverrides(c *gc.C) { + cld := testCloud + st := NewState(s.TxnRunnerFactory()) + err := st.UpsertCloud(ctx.Background(), cld) + c.Assert(err, jc.ErrorIsNil) + + err = s.TxnRunner().StdTxn(context.Background(), func(ctx context.Context, tx *sql.Tx) error { + return SetCloudDefaults(ctx, tx, cld.Name, map[string]string{ + "clouddefault": "one", + }) + }) + c.Check(err, jc.ErrorIsNil) + + defaults, err := st.CloudDefaults(context.Background(), cld.Name) + c.Check(err, jc.ErrorIsNil) + c.Check(defaults, jc.DeepEquals, map[string]string{ + "clouddefault": "one", + }) + + err = s.TxnRunner().StdTxn(context.Background(), func(ctx context.Context, tx *sql.Tx) error { + return SetCloudDefaults(ctx, tx, cld.Name, map[string]string{ + "clouddefaultnew": "two", + }) + }) + c.Check(err, jc.ErrorIsNil) + + defaults, err = st.CloudDefaults(context.Background(), cld.Name) + c.Check(err, jc.ErrorIsNil) + c.Check(defaults, jc.DeepEquals, map[string]string{ + "clouddefaultnew": "two", + }) +} + +// TestSetCloudDefaultsDelete is testing that if we call [SetCloudDefaults] with +// a empty map of defaults the existing cloud defaults are removed and no +// further actions are taken. +func (s *stateSuite) TestSetCloudDefaultsDelete(c *gc.C) { + cld := testCloud + st := NewState(s.TxnRunnerFactory()) + err := st.UpsertCloud(ctx.Background(), cld) + c.Assert(err, jc.ErrorIsNil) + + err = s.TxnRunner().StdTxn(context.Background(), func(ctx context.Context, tx *sql.Tx) error { + return SetCloudDefaults(ctx, tx, cld.Name, map[string]string{ + "clouddefault": "one", + }) + }) + c.Check(err, jc.ErrorIsNil) + + defaults, err := st.CloudDefaults(context.Background(), cld.Name) + c.Check(err, jc.ErrorIsNil) + c.Check(defaults, jc.DeepEquals, map[string]string{ + "clouddefault": "one", + }) + + err = s.TxnRunner().StdTxn(context.Background(), func(ctx context.Context, tx *sql.Tx) error { + return SetCloudDefaults(ctx, tx, cld.Name, nil) + }) + c.Check(err, jc.ErrorIsNil) + + defaults, err = st.CloudDefaults(context.Background(), cld.Name) + c.Check(err, jc.ErrorIsNil) + c.Check(len(defaults), gc.Equals, 0) +} diff --git a/internal/database/statement.go b/internal/database/statement.go index f899ac2ff62..f8fe783926d 100644 --- a/internal/database/statement.go +++ b/internal/database/statement.go @@ -63,6 +63,38 @@ func MapToMultiPlaceholder[K comparable, V any](m map[K]V) (string, []any) { return strings.Join(binds, ","), vals } +// MapToMultiPlaceholderTransform returns a string of bind args for map key +// value inserts. A transform function is supplied so that for each key value +// pair in the map a slice of values to be inserted can be returned to build bind +// arguments from. The number bind arguments for each bind set will based on the +// number of values in the returned slice. +// +// Example usage: +// +// myMap := map[string]string{"one": "two", "three": "four"} +// MapToMultiPlaceholderTransform(myMap, func(k, v string) []any { +// return []any{"staticval", k, v} +// }) +func MapToMultiPlaceholderTransform[K comparable, V any](m map[K]V, trans func(K, V) []any) (string, []any) { + binds := make([]string, 0, len(m)) + vals := make([]any, 0, len(m)*2) + for k, v := range m { + tupleVals := trans(k, v) + if len(tupleVals) == 0 { + continue + } + + tupleBinds := make([]string, len(tupleVals)) + for i := range tupleBinds { + tupleBinds[i] = "?" + } + binds = append(binds, fmt.Sprintf("(%s)", strings.Join(tupleBinds, ","))) + vals = append(vals, tupleVals...) + } + + return strings.Join(binds, ","), vals +} + // SqlairClauseAnd creates a sqlair query condition where each // of the non-empty map values becomes an AND operator. func SqlairClauseAnd(columnValues map[string]any) (string, sqlair.M) { diff --git a/internal/database/statement_test.go b/internal/database/statement_test.go index 81160081dde..8a21b646ef4 100644 --- a/internal/database/statement_test.go +++ b/internal/database/statement_test.go @@ -4,6 +4,8 @@ package database import ( + "fmt" + "github.com/canonical/sqlair" "github.com/juju/testing" jc "github.com/juju/testing/checkers" @@ -122,3 +124,52 @@ func (s *statementSuite) TestMapToMultiPlaceholder(c *gc.C) { } c.Assert(count, gc.Equals, 9) } + +func ExampleMapToMultiPlaceholderTransform() { + myMap := map[string]string{"one": "two"} + bindStmt, args := MapToMultiPlaceholderTransform(myMap, func(k, v string) []any { + return []any{"staticvalue", k, v} + }) + + fmt.Println(bindStmt) + fmt.Println(args) + // Output: + // (?,?,?) + // [staticvalue one two] +} + +// TestMapToMultiPlaceholderTransformNoVals is testing that if the transform +// func returns no values that an empty string and no arg values are being +// returned. +func (s *statementSuite) TestMapToMultiPlaceholderTransformNoVals(c *gc.C) { + bind, args := MapToMultiPlaceholderTransform( + map[string]string{"test": "test"}, + func(k, v string) []any { return nil }, + ) + c.Check(bind, gc.Equals, "") + c.Check(len(args), gc.Equals, 0) +} + +// TestMapToMultiPlaceholderTransform is testing the happy path. +func (s *statementSuite) TestMapToMultiPlaceholderTransform(c *gc.C) { + bind, args := MapToMultiPlaceholderTransform( + map[string]string{"test": "foobar"}, + func(k, v string) []any { return []any{k, v} }, + ) + c.Check(bind, gc.Equals, "(?,?)") + c.Check(args, jc.DeepEquals, []any{"test", "foobar"}) + + bind, args = MapToMultiPlaceholderTransform( + map[string]string{"test": "foobar", "ipv6": "isgreat"}, + func(k, v string) []any { return []any{k, v} }, + ) + c.Check(bind, gc.Equals, "(?,?),(?,?)") + c.Check(len(args), gc.Equals, 4) + + bind, args = MapToMultiPlaceholderTransform( + map[string]string{"test": "foobar", "ipv6": "isgreat"}, + func(k, v string) []any { return []any{k, v, "staticval"} }, + ) + c.Check(bind, gc.Equals, "(?,?,?),(?,?,?)") + c.Check(len(args), gc.Equals, 6) +}