diff --git a/cmd/jujud-controller/agent/dbrepl.go b/cmd/jujud-controller/agent/dbrepl.go index 79491a35da2..2ec3f36b299 100644 --- a/cmd/jujud-controller/agent/dbrepl.go +++ b/cmd/jujud-controller/agent/dbrepl.go @@ -146,6 +146,8 @@ This is a DB REPL (Read-Eval-Print Loop) environment. You can run arbitrary code here, including code that can modify the state of the system. Be careful! + +Type '.help' for help. ` ) diff --git a/internal/worker/dbrepl/worker.go b/internal/worker/dbrepl/worker.go index ed9a1679417..2a558f4af22 100644 --- a/internal/worker/dbrepl/worker.go +++ b/internal/worker/dbrepl/worker.go @@ -10,10 +10,11 @@ import ( "io" "os" "strings" - "text/tabwriter" "time" "github.com/chzyer/readline" + "github.com/fatih/color" + "github.com/juju/ansiterm" "github.com/juju/errors" "gopkg.in/tomb.v2" @@ -79,7 +80,10 @@ type dbReplWorker struct { cfg WorkerConfig tomb tomb.Tomb - dbGetter coredatabase.DBGetter + dbGetter coredatabase.DBGetter + controllerDB database.TxnRunner + currentDB database.TxnRunner + currentNamespace string } // NewWorker creates a new dbaccessor worker. @@ -89,9 +93,17 @@ func NewWorker(cfg WorkerConfig) (*dbReplWorker, error) { return nil, errors.Trace(err) } + controllerDB, err := cfg.DBGetter.GetDB(database.ControllerNS) + if err != nil { + return nil, errors.Annotate(err, "getting controller db") + } + w := &dbReplWorker{ - cfg: cfg, - dbGetter: cfg.DBGetter, + cfg: cfg, + dbGetter: cfg.DBGetter, + controllerDB: controllerDB, + currentDB: controllerDB, + currentNamespace: "*", } w.tomb.Go(w.loop) @@ -156,11 +168,12 @@ func (w *dbReplWorker) loop() (err error) { if err != nil { return errors.Annotate(err, "failed to get db") } - controllerDB := currentDB - currentNamespace := "controller" + w.controllerDB = currentDB + w.currentNamespace = "controller" close(done) + // Allow the line to be closed when the worker is dying. go func() { defer line.Close() @@ -180,7 +193,7 @@ func (w *dbReplWorker) loop() (err error) { default: } - line.SetPrompt("repl (" + currentNamespace + ")> ") + line.SetPrompt("repl (" + w.currentNamespace + ")> ") if err != nil { return errors.Annotate(err, "failed to read input") } @@ -206,65 +219,120 @@ func (w *dbReplWorker) loop() (err error) { } switch args[0] { - case ".exit": + case ".exit", ".quit": return worker.ErrTerminateAgent - case ".help": - fmt.Fprintf(w.cfg.Stdout, helpText) - continue + case ".help", ".h": + fmt.Fprint(w.cfg.Stdout, helpText) case ".switch": - if len(args) != 2 { - fmt.Fprintln(w.cfg.Stderr, "usage: .switch ") - continue - } + w.execSwitch(ctx, args[1:]) + case ".models": + w.execModels(ctx) + case ".tables": + w.execTables(ctx) + case ".triggers": + w.execTriggers(ctx) + case ".views": + w.execViews(ctx) + case ".ddl": + w.execShowDDL(ctx, args[1:]) - argName := args[1] - if argName == "controller" { - currentDB = controllerDB - currentNamespace = argName - continue - } - parts := strings.Split(argName, "-") - if len(parts) != 2 { - fmt.Fprintln(w.cfg.Stderr, "invalid namespace name") - continue - } else if parts[0] != "model" { - fmt.Fprintln(w.cfg.Stderr, "invalid model namespace name") - continue + default: + if err := w.executeQuery(ctx, w.currentDB, input); err != nil { + w.cfg.Logger.Errorf("failed to execute query: %v", err) } - name := parts[1] + } + } +} - var uuid string - if err := controllerDB.StdTxn(ctx, func(ctx context.Context, tx *sql.Tx) error { - row := tx.QueryRowContext(ctx, "SELECT uuid FROM model WHERE name=?", name) - if err := row.Scan(&uuid); err != nil { - return err - } - return nil - }); errors.Is(err, sql.ErrNoRows) { - fmt.Fprintf(w.cfg.Stderr, "model %q not found\n", name) - continue - } else if err != nil { - fmt.Fprintf(w.cfg.Stderr, "failed to select %q database: %v\n", name, err) - continue - } +func (w *dbReplWorker) execSwitch(ctx context.Context, args []string) { + if len(args) != 1 { + fmt.Fprintln(w.cfg.Stderr, "usage: .switch ") + return + } - currentDB, err = w.dbGetter.GetDB(uuid) - if err != nil { - fmt.Fprintf(w.cfg.Stderr, "failed to switch to namespace %q: %v\n", name, err) - continue - } - currentNamespace = argName + argName := args[0] + if argName == "controller" { + w.currentDB = w.controllerDB + w.currentNamespace = "*" + return + } - case ".models": - if err := w.executeQuery(ctx, controllerDB, "SELECT uuid, name FROM model;"); err != nil { - w.cfg.Logger.Errorf("failed to execute query: %v", err) - } + parts := strings.Split(argName, "-") + if len(parts) != 2 { + fmt.Fprintln(w.cfg.Stderr, "invalid namespace name") + return + } else if parts[0] != "model" { + fmt.Fprintln(w.cfg.Stderr, "invalid model namespace name") + return + } + name := parts[1] - default: - if err := w.executeQuery(ctx, currentDB, input); err != nil { - w.cfg.Logger.Errorf("failed to execute query: %v", err) - } + var uuid string + if err := w.controllerDB.StdTxn(ctx, func(ctx context.Context, tx *sql.Tx) error { + row := tx.QueryRowContext(ctx, "SELECT uuid FROM model WHERE name=?", name) + if err := row.Scan(&uuid); err != nil { + return err } + return nil + }); errors.Is(err, sql.ErrNoRows) { + fmt.Fprintf(w.cfg.Stderr, "model %q not found\n", name) + return + } else if err != nil { + fmt.Fprintf(w.cfg.Stderr, "failed to select %q database: %v\n", name, err) + return + } + + var err error + w.currentDB, err = w.dbGetter.GetDB(uuid) + if err != nil { + fmt.Fprintf(w.cfg.Stderr, "failed to switch to namespace %q: %v\n", name, err) + return + } + w.currentNamespace = argName +} + +func (w *dbReplWorker) execModels(ctx context.Context) { + if err := w.executeQuery(ctx, w.controllerDB, "SELECT uuid, name FROM model"); err != nil { + w.cfg.Logger.Errorf("failed to execute query: %v", err) + } +} + +func (w *dbReplWorker) execTables(ctx context.Context) { + if err := w.executeQuery(ctx, w.currentDB, "SELECT name AS table_name FROM sqlite_master WHERE type='table'"); err != nil { + w.cfg.Logger.Errorf("failed to execute query: %v", err) + } +} + +func (w *dbReplWorker) execShowDDL(ctx context.Context, args []string) { + if len(args) != 1 { + fmt.Fprintln(w.cfg.Stderr, "usage: .ddl ") + return + } + + name := args[0] + var ddl string + if err := w.currentDB.StdTxn(ctx, func(ctx context.Context, tx *sql.Tx) error { + row := tx.QueryRowContext(ctx, "SELECT sql FROM sqlite_master WHERE name=?", name) + if err := row.Scan(&ddl); err != nil { + return err + } + return nil + }); err != nil { + w.cfg.Logger.Errorf("failed to execute query: %v\n", err) + } + + fmt.Fprintln(w.cfg.Stdout, ddl) +} + +func (w *dbReplWorker) execTriggers(ctx context.Context) { + if err := w.executeQuery(ctx, w.currentDB, "SELECT name AS trigger_name FROM sqlite_master WHERE type='trigger'"); err != nil { + w.cfg.Logger.Errorf("failed to execute query: %v", err) + } +} + +func (w *dbReplWorker) execViews(ctx context.Context) { + if err := w.executeQuery(ctx, w.currentDB, "SELECT name AS view_name FROM sqlite_master WHERE type='view'"); err != nil { + w.cfg.Logger.Errorf("failed to execute query: %v", err) } } @@ -301,10 +369,15 @@ func (w *dbReplWorker) executeQuery(ctx context.Context, db database.TxnRunner, } n := len(columns) + headerStyle := color.New(color.Bold) var sb strings.Builder - writer := tabwriter.NewWriter(&sb, 0, 8, 1, '\t', 0) + + // Use the ansiterm tabwriter because the stdlib tabwriter contains a bug + // which breaks if there are color codes. Our own tabwriter implementation + // doesn't have this issue. + writer := ansiterm.NewTabWriter(&sb, 0, 8, 1, '\t', 0) for _, col := range columns { - fmt.Fprintf(writer, "%s\t", col) + headerStyle.Fprintf(writer, "%s\t", col) } fmt.Fprintln(writer) @@ -353,10 +426,14 @@ func filterInput(r rune) (rune, bool) { const helpText = ` The following commands are available: - .exit Exit the REPL. - .help Show this help message. + .exit, .quit Exit the REPL. + .help, .h Show this help message. .models Show all models. - .switch Switch to a different model (or global database). + .switch Switch to a different model (or global database). + .tables Show all standard tables in the current database. + .triggers Show all trigger tables in the current database. + .views Show all views in the current database. + .ddl Show the DDL for the specified table, trigger, or view. The global database can be accessed by using the '*' or 'global' keyword when switching databases.