Skip to content

Commit

Permalink
Merge pull request juju#18577 from jack-w-shaw/db-repl
Browse files Browse the repository at this point in the history
juju#18577

Add a number of new helper functions to the db-repl

- `.tables` lists all tables in a db
- `.triggers` lists all triggers in a db
- `.views` lists all views in a db
- `.ddl` shows the ddl for a table

## QA steps
  • Loading branch information
jujubot authored Jan 8, 2025
2 parents 67bcf52 + 0b94483 commit 72912a7
Show file tree
Hide file tree
Showing 2 changed files with 141 additions and 62 deletions.
2 changes: 2 additions & 0 deletions cmd/jujud-controller/agent/dbrepl.go
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,8 @@ This is a DB REPL (Read-Eval-Print Loop) environment.
You can run arbitrary code here, including code that can modify the
state of the system. Be careful!
Type '.help' for help.
`
)

Expand Down
201 changes: 139 additions & 62 deletions internal/worker/dbrepl/worker.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,11 @@ import (
"io"
"os"
"strings"
"text/tabwriter"
"time"

"github.com/chzyer/readline"
"github.com/fatih/color"
"github.com/juju/ansiterm"
"github.com/juju/errors"
"gopkg.in/tomb.v2"

Expand Down Expand Up @@ -79,7 +80,10 @@ type dbReplWorker struct {
cfg WorkerConfig
tomb tomb.Tomb

dbGetter coredatabase.DBGetter
dbGetter coredatabase.DBGetter
controllerDB database.TxnRunner
currentDB database.TxnRunner
currentNamespace string
}

// NewWorker creates a new dbaccessor worker.
Expand All @@ -89,9 +93,17 @@ func NewWorker(cfg WorkerConfig) (*dbReplWorker, error) {
return nil, errors.Trace(err)
}

controllerDB, err := cfg.DBGetter.GetDB(database.ControllerNS)
if err != nil {
return nil, errors.Annotate(err, "getting controller db")
}

w := &dbReplWorker{
cfg: cfg,
dbGetter: cfg.DBGetter,
cfg: cfg,
dbGetter: cfg.DBGetter,
controllerDB: controllerDB,
currentDB: controllerDB,
currentNamespace: "*",
}

w.tomb.Go(w.loop)
Expand Down Expand Up @@ -156,11 +168,12 @@ func (w *dbReplWorker) loop() (err error) {
if err != nil {
return errors.Annotate(err, "failed to get db")
}
controllerDB := currentDB
currentNamespace := "controller"
w.controllerDB = currentDB
w.currentNamespace = "controller"

close(done)

// Allow the line to be closed when the worker is dying.
go func() {
defer line.Close()

Expand All @@ -180,7 +193,7 @@ func (w *dbReplWorker) loop() (err error) {
default:
}

line.SetPrompt("repl (" + currentNamespace + ")> ")
line.SetPrompt("repl (" + w.currentNamespace + ")> ")
if err != nil {
return errors.Annotate(err, "failed to read input")
}
Expand All @@ -206,65 +219,120 @@ func (w *dbReplWorker) loop() (err error) {
}

switch args[0] {
case ".exit":
case ".exit", ".quit":
return worker.ErrTerminateAgent
case ".help":
fmt.Fprintf(w.cfg.Stdout, helpText)
continue
case ".help", ".h":
fmt.Fprint(w.cfg.Stdout, helpText)
case ".switch":
if len(args) != 2 {
fmt.Fprintln(w.cfg.Stderr, "usage: .switch <name>")
continue
}
w.execSwitch(ctx, args[1:])
case ".models":
w.execModels(ctx)
case ".tables":
w.execTables(ctx)
case ".triggers":
w.execTriggers(ctx)
case ".views":
w.execViews(ctx)
case ".ddl":
w.execShowDDL(ctx, args[1:])

argName := args[1]
if argName == "controller" {
currentDB = controllerDB
currentNamespace = argName
continue
}
parts := strings.Split(argName, "-")
if len(parts) != 2 {
fmt.Fprintln(w.cfg.Stderr, "invalid namespace name")
continue
} else if parts[0] != "model" {
fmt.Fprintln(w.cfg.Stderr, "invalid model namespace name")
continue
default:
if err := w.executeQuery(ctx, w.currentDB, input); err != nil {
w.cfg.Logger.Errorf("failed to execute query: %v", err)
}
name := parts[1]
}
}
}

var uuid string
if err := controllerDB.StdTxn(ctx, func(ctx context.Context, tx *sql.Tx) error {
row := tx.QueryRowContext(ctx, "SELECT uuid FROM model WHERE name=?", name)
if err := row.Scan(&uuid); err != nil {
return err
}
return nil
}); errors.Is(err, sql.ErrNoRows) {
fmt.Fprintf(w.cfg.Stderr, "model %q not found\n", name)
continue
} else if err != nil {
fmt.Fprintf(w.cfg.Stderr, "failed to select %q database: %v\n", name, err)
continue
}
func (w *dbReplWorker) execSwitch(ctx context.Context, args []string) {
if len(args) != 1 {
fmt.Fprintln(w.cfg.Stderr, "usage: .switch <name>")
return
}

currentDB, err = w.dbGetter.GetDB(uuid)
if err != nil {
fmt.Fprintf(w.cfg.Stderr, "failed to switch to namespace %q: %v\n", name, err)
continue
}
currentNamespace = argName
argName := args[0]
if argName == "controller" {
w.currentDB = w.controllerDB
w.currentNamespace = "*"
return
}

case ".models":
if err := w.executeQuery(ctx, controllerDB, "SELECT uuid, name FROM model;"); err != nil {
w.cfg.Logger.Errorf("failed to execute query: %v", err)
}
parts := strings.Split(argName, "-")
if len(parts) != 2 {
fmt.Fprintln(w.cfg.Stderr, "invalid namespace name")
return
} else if parts[0] != "model" {
fmt.Fprintln(w.cfg.Stderr, "invalid model namespace name")
return
}
name := parts[1]

default:
if err := w.executeQuery(ctx, currentDB, input); err != nil {
w.cfg.Logger.Errorf("failed to execute query: %v", err)
}
var uuid string
if err := w.controllerDB.StdTxn(ctx, func(ctx context.Context, tx *sql.Tx) error {
row := tx.QueryRowContext(ctx, "SELECT uuid FROM model WHERE name=?", name)
if err := row.Scan(&uuid); err != nil {
return err
}
return nil
}); errors.Is(err, sql.ErrNoRows) {
fmt.Fprintf(w.cfg.Stderr, "model %q not found\n", name)
return
} else if err != nil {
fmt.Fprintf(w.cfg.Stderr, "failed to select %q database: %v\n", name, err)
return
}

var err error
w.currentDB, err = w.dbGetter.GetDB(uuid)
if err != nil {
fmt.Fprintf(w.cfg.Stderr, "failed to switch to namespace %q: %v\n", name, err)
return
}
w.currentNamespace = argName
}

func (w *dbReplWorker) execModels(ctx context.Context) {
if err := w.executeQuery(ctx, w.controllerDB, "SELECT uuid, name FROM model"); err != nil {
w.cfg.Logger.Errorf("failed to execute query: %v", err)
}
}

func (w *dbReplWorker) execTables(ctx context.Context) {
if err := w.executeQuery(ctx, w.currentDB, "SELECT name AS table_name FROM sqlite_master WHERE type='table'"); err != nil {
w.cfg.Logger.Errorf("failed to execute query: %v", err)
}
}

func (w *dbReplWorker) execShowDDL(ctx context.Context, args []string) {
if len(args) != 1 {
fmt.Fprintln(w.cfg.Stderr, "usage: .ddl <name>")
return
}

name := args[0]
var ddl string
if err := w.currentDB.StdTxn(ctx, func(ctx context.Context, tx *sql.Tx) error {
row := tx.QueryRowContext(ctx, "SELECT sql FROM sqlite_master WHERE name=?", name)
if err := row.Scan(&ddl); err != nil {
return err
}
return nil
}); err != nil {
w.cfg.Logger.Errorf("failed to execute query: %v\n", err)
}

fmt.Fprintln(w.cfg.Stdout, ddl)
}

func (w *dbReplWorker) execTriggers(ctx context.Context) {
if err := w.executeQuery(ctx, w.currentDB, "SELECT name AS trigger_name FROM sqlite_master WHERE type='trigger'"); err != nil {
w.cfg.Logger.Errorf("failed to execute query: %v", err)
}
}

func (w *dbReplWorker) execViews(ctx context.Context) {
if err := w.executeQuery(ctx, w.currentDB, "SELECT name AS view_name FROM sqlite_master WHERE type='view'"); err != nil {
w.cfg.Logger.Errorf("failed to execute query: %v", err)
}
}

Expand Down Expand Up @@ -301,10 +369,15 @@ func (w *dbReplWorker) executeQuery(ctx context.Context, db database.TxnRunner,
}
n := len(columns)

headerStyle := color.New(color.Bold)
var sb strings.Builder
writer := tabwriter.NewWriter(&sb, 0, 8, 1, '\t', 0)

// Use the ansiterm tabwriter because the stdlib tabwriter contains a bug
// which breaks if there are color codes. Our own tabwriter implementation
// doesn't have this issue.
writer := ansiterm.NewTabWriter(&sb, 0, 8, 1, '\t', 0)
for _, col := range columns {
fmt.Fprintf(writer, "%s\t", col)
headerStyle.Fprintf(writer, "%s\t", col)
}
fmt.Fprintln(writer)

Expand Down Expand Up @@ -353,10 +426,14 @@ func filterInput(r rune) (rune, bool) {
const helpText = `
The following commands are available:
.exit Exit the REPL.
.help Show this help message.
.exit, .quit Exit the REPL.
.help, .h Show this help message.
.models Show all models.
.switch Switch to a different model (or global database).
.switch <model> Switch to a different model (or global database).
.tables Show all standard tables in the current database.
.triggers Show all trigger tables in the current database.
.views Show all views in the current database.
.ddl <name> Show the DDL for the specified table, trigger, or view.
The global database can be accessed by using the '*' or 'global' keyword
when switching databases.
Expand Down

0 comments on commit 72912a7

Please sign in to comment.