Skip to content

Commit

Permalink
idempotent install
Browse files Browse the repository at this point in the history
  • Loading branch information
ethanndickson committed Aug 13, 2024
1 parent 802129c commit daea696
Show file tree
Hide file tree
Showing 8 changed files with 278 additions and 108 deletions.
11 changes: 7 additions & 4 deletions completion.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package serpent

import "strings"
import (
"github.com/spf13/pflag"
)

// CompletionModeEnv is a special environment variable that is
// set when the command is being run in completion mode.
Expand All @@ -12,17 +14,18 @@ func (inv *Invocation) IsCompletionMode() bool {
return ok
}

// DefaultCompletionHandler returns a handler that prints all
// known flags and subcommands that haven't already been set to valid values.
// DefaultCompletionHandler is a handler that prints all known flags and
// subcommands that haven't been exhaustively set.
func DefaultCompletionHandler(inv *Invocation) []string {
var allResps []string
for _, cmd := range inv.Command.Children {
allResps = append(allResps, cmd.Name())
}
for _, opt := range inv.Command.Options {
_, isSlice := opt.Value.(pflag.SliceValue)
if opt.ValueSource == ValueSourceNone ||
opt.ValueSource == ValueSourceDefault ||
strings.Contains(opt.Value.Type(), "array") {
isSlice {
allResps = append(allResps, "--"+opt.Flag)
}
}
Expand Down
162 changes: 134 additions & 28 deletions completion/all.go
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
package completion

import (
"bytes"
"errors"
"fmt"
"io"
"io/fs"
"os"
"os/user"
"path/filepath"
Expand All @@ -13,11 +16,16 @@ import (
"github.com/coder/serpent"
)

const (
completionStartTemplate = `# ============ BEGIN {{.Name}} COMPLETION ============`
completionEndTemplate = `# ============ END {{.Name}} COMPLETION ==============`
)

type Shell interface {
Name() string
InstallPath() (string, error)
UsesOwnFile() bool
WriteCompletion(io.Writer) error
ProgramName() string
}

const (
Expand Down Expand Up @@ -77,56 +85,154 @@ func DetectUserShell(programName string) (Shell, error) {
return nil, fmt.Errorf("default shell not found")
}

func generateCompletion(
scriptTemplate string,
) func(io.Writer, string) error {
return func(w io.Writer, programName string) error {
tmpl, err := template.New("script").Parse(scriptTemplate)
if err != nil {
return fmt.Errorf("parse template: %w", err)
}

err = tmpl.Execute(
w,
map[string]string{
"Name": programName,
},
)
if err != nil {
return fmt.Errorf("execute template: %w", err)
}
func configTemplateWriter(
w io.Writer,
cfgTemplate string,
programName string,
) error {
tmpl, err := template.New("script").Parse(cfgTemplate)
if err != nil {
return fmt.Errorf("parse template: %w", err)
}

return nil
err = tmpl.Execute(
w,
map[string]string{
"Name": programName,
},
)
if err != nil {
return fmt.Errorf("execute template: %w", err)
}

return nil
}

func InstallShellCompletion(shell Shell) error {
path, err := shell.InstallPath()
if err != nil {
return fmt.Errorf("get install path: %w", err)
}
var headerBuf bytes.Buffer
err = configTemplateWriter(&headerBuf, completionStartTemplate, shell.ProgramName())
if err != nil {
return fmt.Errorf("generate header: %w", err)
}

var footerBytes bytes.Buffer
err = configTemplateWriter(&footerBytes, completionEndTemplate, shell.ProgramName())
if err != nil {
return fmt.Errorf("generate footer: %w", err)
}

err = os.MkdirAll(filepath.Dir(path), 0o755)
if err != nil {
return fmt.Errorf("create directories: %w", err)
}

if shell.UsesOwnFile() {
err := os.WriteFile(path, nil, 0o644)
f, err := os.ReadFile(path)
if err != nil && !errors.Is(err, fs.ErrNotExist) {
return fmt.Errorf("read ssh config failed: %w", err)
}

before, after, err := templateConfigSplit(headerBuf.Bytes(), footerBytes.Bytes(), f)
if err != nil {
return err
}

outBuf := bytes.Buffer{}
_, _ = outBuf.Write(before)
if len(before) > 0 {
_, _ = outBuf.Write([]byte("\n"))
}
_, _ = outBuf.Write(headerBuf.Bytes())
err = shell.WriteCompletion(&outBuf)
if err != nil {
return fmt.Errorf("generate completion: %w", err)
}
_, _ = outBuf.Write(footerBytes.Bytes())
_, _ = outBuf.Write([]byte("\n"))
_, _ = outBuf.Write(after)

err = writeWithTempFileAndMove(path, &outBuf)
if err != nil {
return fmt.Errorf("write completion: %w", err)
}

return nil
}

func templateConfigSplit(header, footer, data []byte) (before, after []byte, err error) {
startCount := bytes.Count(data, header)
endCount := bytes.Count(data, footer)
if startCount > 1 || endCount > 1 {
return nil, nil, fmt.Errorf("Malformed config file: multiple config sections")
}

startIndex := bytes.Index(data, header)
endIndex := bytes.Index(data, footer)
if startIndex == -1 && endIndex != -1 {
return data, nil, fmt.Errorf("Malformed config file: missing completion header")
}
if startIndex != -1 && endIndex == -1 {
return data, nil, fmt.Errorf("Malformed config file: missing completion footer")
}
if startIndex != -1 && endIndex != -1 {
if startIndex > endIndex {
return data, nil, fmt.Errorf("Malformed config file: completion header after footer")
}
// Include leading and trailing newline, if present
start := startIndex
if start > 0 {
start--
}
end := endIndex + len(footer)
if end < len(data) {
end++
}
return data[:start], data[end:], nil
}
return data, nil, nil
}

// writeWithTempFileAndMove writes to a temporary file in the same
// directory as path and renames the temp file to the file provided in
// path. This ensure we avoid trashing the file we are writing due to
// unforeseen circumstance like filesystem full, command killed, etc.
func writeWithTempFileAndMove(path string, r io.Reader) (err error) {
dir := filepath.Dir(path)
name := filepath.Base(path)

if err = os.MkdirAll(dir, 0o700); err != nil {
return fmt.Errorf("create directory: %w", err)
}

// Create a tempfile in the same directory for ensuring write
// operation does not fail.
f, err := os.CreateTemp(dir, fmt.Sprintf(".%s.", name))
if err != nil {
return fmt.Errorf("create temp file failed: %w", err)
}
defer func() {
if err != nil {
return fmt.Errorf("create file: %w", err)
_ = os.Remove(f.Name()) // Cleanup in case a step failed.
}
}()

_, err = io.Copy(f, r)
if err != nil {
_ = f.Close()
return fmt.Errorf("write temp file failed: %w", err)
}

f, err := os.OpenFile(path, os.O_CREATE|os.O_APPEND|os.O_WRONLY, 0o644)
err = f.Close()
if err != nil {
return fmt.Errorf("open file for appending: %w", err)
return fmt.Errorf("close temp file failed: %w", err)
}
defer f.Close()

err = shell.WriteCompletion(f)
err = os.Rename(f.Name(), path)
if err != nil {
return fmt.Errorf("write completion script: %w", err)
return fmt.Errorf("rename temp file failed: %w", err)
}

return nil
Expand Down
16 changes: 6 additions & 10 deletions completion/bash.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,6 @@ func (b *bash) Name() string {
return "bash"
}

// UsesOwnFile implements Shell.
func (b *bash) UsesOwnFile() bool {
return false
}

// InstallPath implements Shell.
func (b *bash) InstallPath() (string, error) {
homeDir, err := home.Dir()
Expand All @@ -42,12 +37,15 @@ func (b *bash) InstallPath() (string, error) {

// WriteCompletion implements Shell.
func (b *bash) WriteCompletion(w io.Writer) error {
return generateCompletion(bashCompletionTemplate)(w, b.programName)
return configTemplateWriter(w, bashCompletionTemplate, b.programName)
}

const bashCompletionTemplate = `
// ProgramName implements Shell.
func (b *bash) ProgramName() string {
return b.programName
}

# === BEGIN {{.Name}} COMPLETION ===
const bashCompletionTemplate = `
_generate_{{.Name}}_completions() {
# Capture the line excluding the command, and everything after the current word
local args=("${COMP_WORDS[@]:1:COMP_CWORD}")
Expand All @@ -65,6 +63,4 @@ _generate_{{.Name}}_completions() {
}
# Setup Bash to use the function for completions for '{{.Name}}'
complete -F _generate_{{.Name}}_completions {{.Name}}
# === END {{.Name}} COMPLETION ===
`
12 changes: 6 additions & 6 deletions completion/fish.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,6 @@ func Fish(goos string, programName string) Shell {
return &fish{goos: goos, programName: programName}
}

// UsesOwnFile implements Shell.
func (f *fish) UsesOwnFile() bool {
return true
}

// Name implements Shell.
func (f *fish) Name() string {
return "fish"
Expand All @@ -39,7 +34,12 @@ func (f *fish) InstallPath() (string, error) {

// WriteCompletion implements Shell.
func (f *fish) WriteCompletion(w io.Writer) error {
return generateCompletion(fishCompletionTemplate)(w, f.programName)
return configTemplateWriter(w, fishCompletionTemplate, f.programName)
}

// ProgramName implements Shell.
func (f *fish) ProgramName() string {
return f.programName
}

const fishCompletionTemplate = `
Expand Down
18 changes: 7 additions & 11 deletions completion/powershell.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ type powershell struct {
programName string
}

var _ Shell = &powershell{}

// Name implements Shell.
func (p *powershell) Name() string {
return "powershell"
Expand All @@ -20,11 +22,6 @@ func Powershell(goos string, programName string) Shell {
return &powershell{goos: goos, programName: programName}
}

// UsesOwnFile implements Shell.
func (p *powershell) UsesOwnFile() bool {
return false
}

// InstallPath implements Shell.
func (p *powershell) InstallPath() (string, error) {
var (
Expand All @@ -45,14 +42,15 @@ func (p *powershell) InstallPath() (string, error) {

// WriteCompletion implements Shell.
func (p *powershell) WriteCompletion(w io.Writer) error {
return generateCompletion(pshCompletionTemplate)(w, p.programName)
return configTemplateWriter(w, pshCompletionTemplate, p.programName)
}

var _ Shell = &powershell{}
// ProgramName implements Shell.
func (p *powershell) ProgramName() string {
return p.programName
}

const pshCompletionTemplate = `
# === BEGIN {{.Name}} COMPLETION ===
# Escaping output sourced from:
# https://github.com/spf13/cobra/blob/e94f6d0dd9a5e5738dca6bce03c4b1207ffbc0ec/powershell_completions.go#L47
filter _{{.Name}}_escapeStringWithSpecialChars {
Expand Down Expand Up @@ -89,6 +87,4 @@ $_{{.Name}}_completions = {
rm env:COMPLETION_MODE
}
Register-ArgumentCompleter -CommandName {{.Name}} -ScriptBlock $_{{.Name}}_completions
# === END {{.Name}} COMPLETION ===
`
16 changes: 6 additions & 10 deletions completion/zsh.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,6 @@ func (z *zsh) Name() string {
return "zsh"
}

// UsesOwnFile implements Shell.
func (z *zsh) UsesOwnFile() bool {
return false
}

// InstallPath implements Shell.
func (z *zsh) InstallPath() (string, error) {
homeDir, err := home.Dir()
Expand All @@ -39,19 +34,20 @@ func (z *zsh) InstallPath() (string, error) {

// WriteCompletion implements Shell.
func (z *zsh) WriteCompletion(w io.Writer) error {
return generateCompletion(zshCompletionTemplate)(w, z.programName)
return configTemplateWriter(w, zshCompletionTemplate, z.programName)
}

const zshCompletionTemplate = `
// ProgramName implements Shell.
func (z *zsh) ProgramName() string {
return z.programName
}

# === BEGIN {{.Name}} COMPLETION ===
const zshCompletionTemplate = `
_{{.Name}}_completions() {
local -a args completions
args=("${words[@]:1:$#words}")
completions=($(COMPLETION_MODE=1 "{{.Name}}" "${args[@]}"))
compadd -a completions
}
compdef _{{.Name}}_completions {{.Name}}
# === END {{.Name}} COMPLETION ===
`
Loading

0 comments on commit daea696

Please sign in to comment.