diff --git a/cmd/tea.go b/cmd/tea.go index 1888b224..866f592b 100644 --- a/cmd/tea.go +++ b/cmd/tea.go @@ -17,7 +17,7 @@ import ( "github.com/overmindtech/sdp-go/auth" stdlibsource "github.com/overmindtech/stdlib-source/sources" log "github.com/sirupsen/logrus" - "github.com/sourcegraph/conc" + "github.com/sourcegraph/conc/pool" "github.com/spf13/cobra" "github.com/spf13/viper" "go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp" @@ -213,24 +213,27 @@ func InitializeSources(ctx context.Context, oi OvermindInstance, aws_config, aws "me-south-1", "me-central-1", "sa-east-1"} - working_regions := make(chan string, len(all_regions)) // enough space for all regions - region_checkers := conc.NewWaitGroup() + configCtx, configCancel := context.WithTimeout(ctx, 10*time.Second) + defer configCancel() + + region_checkers := pool. + NewWithResults[string](). + WithContext(configCtx). + WithMaxGoroutines(len(all_regions)). + WithFirstError() for _, r := range all_regions { r := r // loopvar saver; TODO: update golangci-lint or vscode validator to understand this is not required anymore - configCtx, configCancel := context.WithTimeout(ctx, 10*time.Second) - defer configCancel() - lf := log.Fields{ "region": r, "strategy": awsAuthConfig.Strategy, } - region_checkers.Go(func() { + region_checkers.Go(func(ctx context.Context) (string, error) { cfg, err := awsAuthConfig.GetAWSConfig(r) if err != nil { log.WithError(err).WithFields(lf).Debug("skipping region") - return + return "", err } // Add OTel instrumentation @@ -249,18 +252,19 @@ func InitializeSources(ctx context.Context, oi OvermindInstance, aws_config, aws lf["externalID"] = awsAuthConfig.ExternalID } log.WithError(err).WithFields(lf).Debug("skipping region") - return + return "", err } - working_regions <- r + return r, nil }) } - region_checkers.Wait() - close(working_regions) - for r := range working_regions { - awsAuthConfig.Regions = append(awsAuthConfig.Regions, r) + working_regions, err := region_checkers.Wait() + // errors are only relevant if no region remained + if len(working_regions) == 0 { + return func() {}, fmt.Errorf("no regions available: %w", err) } + awsAuthConfig.Regions = append(awsAuthConfig.Regions, working_regions...) log.WithField("regions", awsAuthConfig.Regions).Debug("Using regions") awsEngine, err := proc.InitializeAwsSourceEngine(ctx, natsOptions, awsAuthConfig, 2_000) diff --git a/cmd/tea_initialisesources.go b/cmd/tea_initialisesources.go index 6561832e..e39d3abe 100644 --- a/cmd/tea_initialisesources.go +++ b/cmd/tea_initialisesources.go @@ -14,6 +14,7 @@ import ( "github.com/charmbracelet/huh" "github.com/charmbracelet/lipgloss" "github.com/overmindtech/sdp-go" + log "github.com/sirupsen/logrus" "github.com/spf13/viper" "golang.org/x/oauth2" ) @@ -25,7 +26,12 @@ type loadSourcesConfigMsg struct { token *oauth2.Token } -type askForAwsConfigMsg struct{} +type askForAwsConfigMsg struct { + // an optional message when requesting a new config to explain why a new + // config is required. This is used for example when a source does not start + // up correctly. + retryMsg string +} type configStoredMsg struct{} type sourceInitialisationFailedMsg struct{ err error } type sourcesInitialisedMsg struct{} @@ -132,6 +138,10 @@ func (m initialiseSourcesModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) { Title("Choose how to access your AWS account (read-only):"). Options(options...) m.awsConfigForm = huh.NewForm(huh.NewGroup(selector)) + m.awsConfigFormDone = false + if msg.retryMsg != "" { + m.errorHints = append(m.errorHints, msg.retryMsg) + } cmds = append(cmds, selector.Focus()) } else { m.awsConfigFormDone = true @@ -150,7 +160,7 @@ func (m initialiseSourcesModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) { } } case configStoredMsg: - m.title += " (config stored)" + m.title = "Configuring AWS Access (config stored)" case sourcesInitialisedMsg: m.awsSourceRunning = true m.stdlibSourceRunning = true @@ -261,7 +271,7 @@ func (m initialiseSourcesModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) { func (m initialiseSourcesModel) View() string { bits := []string{m.taskModel.View()} for _, hint := range m.errorHints { - bits = append(bits, wrap(fmt.Sprintf(" %v", hint), m.width, 2)) + bits = append(bits, wrap(fmt.Sprintf(" %v %v", lipgloss.NewStyle().Foreground(ColorPalette.BgDanger).Render("✗"), hint), m.width, 2)) } if m.awsConfigForm != nil && !m.awsConfigFormDone { bits = append(bits, m.awsConfigForm.View()) @@ -335,7 +345,9 @@ func (m initialiseSourcesModel) startSourcesCmd(aws_config, aws_profile string) // should sources require more teardown, we'll have to figure something out. _, err := InitializeSources(m.ctx, m.oi, aws_config, aws_profile, m.token) if err != nil { - return sourceInitialisationFailedMsg{err} + log.WithError(err).Error("failed to initialise sources") + viper.Set("reset-stored-config", true) + return askForAwsConfigMsg{retryMsg: fmt.Sprintf("Error initialising sources: %v", err)} } return sourcesInitialisedMsg{} }