Skip to content

Commit

Permalink
Reset stored config if all regions fail to set up
Browse files Browse the repository at this point in the history
Fixes #378
  • Loading branch information
DavidS-ovm committed Jun 10, 2024
1 parent d732816 commit 25f8b1a
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 18 deletions.
32 changes: 18 additions & 14 deletions cmd/tea.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ import (
"github.com/overmindtech/sdp-go/auth"
stdlibsource "github.com/overmindtech/stdlib-source/sources"
log "github.com/sirupsen/logrus"
"github.com/sourcegraph/conc"
"github.com/sourcegraph/conc/pool"
"github.com/spf13/cobra"
"github.com/spf13/viper"
"go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp"
Expand Down Expand Up @@ -213,24 +213,27 @@ func InitializeSources(ctx context.Context, oi OvermindInstance, aws_config, aws
"me-south-1",
"me-central-1",
"sa-east-1"}
working_regions := make(chan string, len(all_regions)) // enough space for all regions
region_checkers := conc.NewWaitGroup()
configCtx, configCancel := context.WithTimeout(ctx, 10*time.Second)
defer configCancel()

region_checkers := pool.
NewWithResults[string]().
WithContext(configCtx).
WithMaxGoroutines(len(all_regions)).
WithFirstError()

for _, r := range all_regions {
r := r // loopvar saver; TODO: update golangci-lint or vscode validator to understand this is not required anymore
configCtx, configCancel := context.WithTimeout(ctx, 10*time.Second)
defer configCancel()

lf := log.Fields{
"region": r,
"strategy": awsAuthConfig.Strategy,
}

region_checkers.Go(func() {
region_checkers.Go(func(ctx context.Context) (string, error) {
cfg, err := awsAuthConfig.GetAWSConfig(r)
if err != nil {
log.WithError(err).WithFields(lf).Debug("skipping region")
return
return "", err
}

// Add OTel instrumentation
Expand All @@ -249,18 +252,19 @@ func InitializeSources(ctx context.Context, oi OvermindInstance, aws_config, aws
lf["externalID"] = awsAuthConfig.ExternalID
}
log.WithError(err).WithFields(lf).Debug("skipping region")
return
return "", err
}
working_regions <- r
return r, nil
})
}

region_checkers.Wait()
close(working_regions)
for r := range working_regions {
awsAuthConfig.Regions = append(awsAuthConfig.Regions, r)
working_regions, err := region_checkers.Wait()
// errors are only relevant if no region remained
if len(working_regions) == 0 {
return func() {}, fmt.Errorf("no regions available: %w", err)
}

awsAuthConfig.Regions = append(awsAuthConfig.Regions, working_regions...)
log.WithField("regions", awsAuthConfig.Regions).Debug("Using regions")

awsEngine, err := proc.InitializeAwsSourceEngine(ctx, natsOptions, awsAuthConfig, 2_000)
Expand Down
20 changes: 16 additions & 4 deletions cmd/tea_initialisesources.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import (
"github.com/charmbracelet/huh"
"github.com/charmbracelet/lipgloss"
"github.com/overmindtech/sdp-go"
log "github.com/sirupsen/logrus"
"github.com/spf13/viper"
"golang.org/x/oauth2"
)
Expand All @@ -25,7 +26,12 @@ type loadSourcesConfigMsg struct {
token *oauth2.Token
}

type askForAwsConfigMsg struct{}
type askForAwsConfigMsg struct {
// an optional message when requesting a new config to explain why a new
// config is required. This is used for example when a source does not start
// up correctly.
retryMsg string
}
type configStoredMsg struct{}
type sourceInitialisationFailedMsg struct{ err error }
type sourcesInitialisedMsg struct{}
Expand Down Expand Up @@ -132,6 +138,10 @@ func (m initialiseSourcesModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
Title("Choose how to access your AWS account (read-only):").
Options(options...)
m.awsConfigForm = huh.NewForm(huh.NewGroup(selector))
m.awsConfigFormDone = false
if msg.retryMsg != "" {
m.errorHints = append(m.errorHints, msg.retryMsg)
}
cmds = append(cmds, selector.Focus())
} else {
m.awsConfigFormDone = true
Expand All @@ -150,7 +160,7 @@ func (m initialiseSourcesModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
}
}
case configStoredMsg:
m.title += " (config stored)"
m.title = "Configuring AWS Access (config stored)"
case sourcesInitialisedMsg:
m.awsSourceRunning = true
m.stdlibSourceRunning = true
Expand Down Expand Up @@ -261,7 +271,7 @@ func (m initialiseSourcesModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
func (m initialiseSourcesModel) View() string {
bits := []string{m.taskModel.View()}
for _, hint := range m.errorHints {
bits = append(bits, wrap(fmt.Sprintf(" %v", hint), m.width, 2))
bits = append(bits, wrap(fmt.Sprintf(" %v %v", lipgloss.NewStyle().Foreground(ColorPalette.BgDanger).Render("✗"), hint), m.width, 2))
}
if m.awsConfigForm != nil && !m.awsConfigFormDone {
bits = append(bits, m.awsConfigForm.View())
Expand Down Expand Up @@ -335,7 +345,9 @@ func (m initialiseSourcesModel) startSourcesCmd(aws_config, aws_profile string)
// should sources require more teardown, we'll have to figure something out.
_, err := InitializeSources(m.ctx, m.oi, aws_config, aws_profile, m.token)
if err != nil {
return sourceInitialisationFailedMsg{err}
log.WithError(err).Error("failed to initialise sources")
viper.Set("reset-stored-config", true)
return askForAwsConfigMsg{retryMsg: fmt.Sprintf("Error initialising sources: %v", err)}
}
return sourcesInitialisedMsg{}
}
Expand Down

0 comments on commit 25f8b1a

Please sign in to comment.