diff --git a/internal/collector/scheduled_task/scheduled_task.go b/internal/collector/scheduled_task/scheduled_task.go index cb9bb82ec..189e3fbbd 100644 --- a/internal/collector/scheduled_task/scheduled_task.go +++ b/internal/collector/scheduled_task/scheduled_task.go @@ -20,30 +20,24 @@ import ( "fmt" "log/slog" "regexp" + "runtime" "strings" - "sync" "github.com/alecthomas/kingpin/v2" "github.com/go-ole/go-ole" "github.com/go-ole/go-ole/oleutil" - "github.com/prometheus-community/windows_exporter/internal/headers/schedule_service" "github.com/prometheus-community/windows_exporter/internal/mi" "github.com/prometheus-community/windows_exporter/internal/types" "github.com/prometheus/client_golang/prometheus" ) -const ( - Name = "scheduled_task" - - workerCount = 4 -) +const Name = "scheduled_task" type Config struct { TaskExclude *regexp.Regexp `yaml:"task_exclude"` TaskInclude *regexp.Regexp `yaml:"task_include"` } -//nolint:gochecknoglobals var ConfigDefaults = Config{ TaskExclude: types.RegExpEmpty, TaskInclude: types.RegExpAny, @@ -52,12 +46,6 @@ var ConfigDefaults = Config{ type Collector struct { config Config - logger *slog.Logger - - scheduledTasksReqCh chan struct{} - scheduledTasksWorker chan scheduledTaskWorkerRequest - scheduledTasksCh chan scheduledTaskResults - lastResult *prometheus.Desc missedRuns *prometheus.Desc state *prometheus.Desc @@ -82,10 +70,7 @@ const ( SCHED_S_TASK_HAS_NOT_RUN TaskResult = 0x00041303 ) -//nolint:gochecknoglobals -var taskStates = []string{"disabled", "queued", "ready", "running", "unknown"} - -type scheduledTask struct { +type ScheduledTask struct { Name string Path string Enabled bool @@ -94,15 +79,7 @@ type scheduledTask struct { LastTaskResult TaskResult } -type scheduledTaskResults struct { - tasks []scheduledTask - err error -} - -type scheduledTaskWorkerRequest struct { - folderPath string - results chan<- scheduledTaskResults -} +type ScheduledTasks []ScheduledTask func New(config *Config) *Collector { if config == nil { @@ -165,27 +142,10 @@ func (c *Collector) GetName() string { } func (c *Collector) Close() error { - close(c.scheduledTasksReqCh) - - c.scheduledTasksReqCh = nil - return nil } -func (c *Collector) Build(logger *slog.Logger, _ *mi.Session) error { - c.logger = logger.With(slog.String("collector", Name)) - - initErrCh := make(chan error) - c.scheduledTasksReqCh = make(chan struct{}) - c.scheduledTasksCh = make(chan scheduledTaskResults) - c.scheduledTasksWorker = make(chan scheduledTaskWorkerRequest, 100) - - go c.initializeScheduleService(initErrCh) - - if err := <-initErrCh; err != nil { - return fmt.Errorf("initialize schedule service: %w", err) - } - +func (c *Collector) Build(_ *slog.Logger, _ *mi.Session) error { c.lastResult = prometheus.NewDesc( prometheus.BuildFQName(types.Namespace, Name, "last_result"), "The result that was returned the last time the registered task was run", @@ -211,7 +171,13 @@ func (c *Collector) Build(logger *slog.Logger, _ *mi.Session) error { } func (c *Collector) Collect(ch chan<- prometheus.Metric) error { - scheduledTasks, err := c.getScheduledTasks() + return c.collect(ch) +} + +var TASK_STATES = []string{"disabled", "queued", "ready", "running", "unknown"} + +func (c *Collector) collect(ch chan<- prometheus.Metric) error { + scheduledTasks, err := getScheduledTasks() if err != nil { return fmt.Errorf("get scheduled tasks: %w", err) } @@ -222,7 +188,7 @@ func (c *Collector) Collect(ch chan<- prometheus.Metric) error { continue } - for _, state := range taskStates { + for _, state := range TASK_STATES { var stateValue float64 if strings.ToLower(task.State.String()) == state { @@ -265,202 +231,71 @@ func (c *Collector) Collect(ch chan<- prometheus.Metric) error { return nil } -func (c *Collector) getScheduledTasks() ([]scheduledTask, error) { - c.scheduledTasksReqCh <- struct{}{} - - scheduledTasks, ok := <-c.scheduledTasksCh +const SCHEDULED_TASK_PROGRAM_ID = "Schedule.Service.1" - if !ok { - return []scheduledTask{}, nil - } - - return scheduledTasks.tasks, scheduledTasks.err -} - -func (c *Collector) initializeScheduleService(initErrCh chan<- error) { - service := schedule_service.New() - if err := service.Connect(); err != nil { - initErrCh <- fmt.Errorf("failed to connect to schedule service: %w", err) - - return - } +// S_FALSE is returned by CoInitialize if it was already called on this thread. +const S_FALSE = 0x00000001 - defer service.Close() +func getScheduledTasks() (ScheduledTasks, error) { + var scheduledTasks ScheduledTasks - errs := make([]error, 0, workerCount) + // The only way to run WMI queries in parallel while being thread-safe is to + // ensure the CoInitialize[Ex]() call is bound to its current OS thread. + // Otherwise, attempting to initialize and run parallel queries across + // goroutines will result in protected memory errors. + runtime.LockOSThread() + defer runtime.UnlockOSThread() - for range workerCount { - errCh := make(chan error, workerCount) - - go c.collectWorker(errCh) - - if err := <-errCh; err != nil { - errs = append(errs, err) + if err := ole.CoInitializeEx(0, ole.COINIT_MULTITHREADED); err != nil { + var oleCode *ole.OleError + if errors.As(err, &oleCode) && oleCode.Code() != ole.S_OK && oleCode.Code() != S_FALSE { + return nil, err } } + defer ole.CoUninitialize() - if err := errors.Join(errs...); err != nil { - initErrCh <- err - - return + schedClassID, err := ole.ClassIDFrom(SCHEDULED_TASK_PROGRAM_ID) + if err != nil { + return scheduledTasks, err } - close(initErrCh) - - taskServiceObj := service.GetOLETaskServiceObj() - scheduledTasks := make([]scheduledTask, 0, 500) - - for range c.scheduledTasksReqCh { - func() { - // Clear the slice to avoid memory leaks - clear(scheduledTasks) - scheduledTasks = scheduledTasks[:0] - - res, err := oleutil.CallMethod(taskServiceObj, "GetFolder", `\`) - if err != nil { - c.scheduledTasksCh <- scheduledTaskResults{err: err} - } - - rootFolderObj := res.ToIDispatch() - defer rootFolderObj.Release() - - errs := make([]error, 0) - scheduledTasksWorkerResults := make(chan scheduledTaskResults) - - wg := &sync.WaitGroup{} - - go func() { - for workerResults := range scheduledTasksWorkerResults { - wg.Done() - - if workerResults.err != nil { - errs = append(errs, workerResults.err) - } - - if workerResults.tasks != nil { - errs = append(errs, workerResults.err) - - scheduledTasks = append(scheduledTasks, workerResults.tasks...) - } - } - }() - - if err := c.fetchRecursively(rootFolderObj, wg, scheduledTasksWorkerResults); err != nil { - errs = append(errs, err) - } - - wg.Wait() - - close(scheduledTasksWorkerResults) - - c.scheduledTasksCh <- scheduledTaskResults{tasks: scheduledTasks, err: errors.Join(errs...)} - }() + taskSchedulerObj, err := ole.CreateInstance(schedClassID, nil) + if err != nil || taskSchedulerObj == nil { + return scheduledTasks, err } + defer taskSchedulerObj.Release() - close(c.scheduledTasksCh) - close(c.scheduledTasksWorker) - - c.scheduledTasksCh = nil - c.scheduledTasksWorker = nil -} - -func (c *Collector) collectWorker(errCh chan<- error) { - defer func() { - if r := recover(); r != nil { - c.logger.Error("worker panic", - slog.Any("panic", r), - ) + taskServiceObj := taskSchedulerObj.MustQueryInterface(ole.IID_IDispatch) - errCh := make(chan error, 1) - // Restart the collectWorker - go c.collectWorker(errCh) - - if err := <-errCh; err != nil { - c.logger.Error("failed to restart worker", - slog.Any("err", err), - ) - } - } - }() - - service := schedule_service.New() - if err := service.Connect(); err != nil { - errCh <- fmt.Errorf("failed to connect to schedule service: %w", err) - - return - } - - close(errCh) - - defer service.Close() - - taskServiceObj := service.GetOLETaskServiceObj() - - for task := range c.scheduledTasksWorker { - scheduledTasks, err := fetchTasksInFolder(taskServiceObj, task.folderPath) - - task.results <- scheduledTaskResults{tasks: scheduledTasks, err: err} - } -} - -func (c *Collector) fetchRecursively(folder *ole.IDispatch, wg *sync.WaitGroup, results chan<- scheduledTaskResults) error { - folderPathVariant, err := oleutil.GetProperty(folder, "Path") + _, err = oleutil.CallMethod(taskServiceObj, "Connect") if err != nil { - return fmt.Errorf("failed to get folder path: %w", err) + return scheduledTasks, err } - folderPath := folderPathVariant.ToString() - - wg.Add(1) - c.scheduledTasksWorker <- scheduledTaskWorkerRequest{folderPath: folderPath, results: results} + defer taskServiceObj.Release() - res, err := oleutil.CallMethod(folder, "GetFolders", 1) + res, err := oleutil.CallMethod(taskServiceObj, "GetFolder", `\`) if err != nil { - return err + return scheduledTasks, err } - subFolders := res.ToIDispatch() - defer subFolders.Release() + rootFolderObj := res.ToIDispatch() + defer rootFolderObj.Release() - return oleutil.ForEach(subFolders, func(v *ole.VARIANT) error { - subFolder := v.ToIDispatch() - defer subFolder.Release() + err = fetchTasksRecursively(rootFolderObj, &scheduledTasks) - return c.fetchRecursively(subFolder, wg, results) - }) + return scheduledTasks, err } -func fetchTasksInFolder(taskServiceObj *ole.IDispatch, folderPath string) ([]scheduledTask, error) { - folderObjRes, err := oleutil.CallMethod(taskServiceObj, "GetFolder", folderPath) - if err != nil { - return nil, fmt.Errorf("failed to get folder %s: %w", folderPath, err) - } - - folderObj := folderObjRes.ToIDispatch() - defer folderObj.Release() - - tasksRes, err := oleutil.CallMethod(folderObj, "GetTasks", 1) +func fetchTasksInFolder(folder *ole.IDispatch, scheduledTasks *ScheduledTasks) error { + res, err := oleutil.CallMethod(folder, "GetTasks", 1) if err != nil { - return nil, fmt.Errorf("failed to get tasks in folder %s: %w", folderPath, err) + return err } - tasks := tasksRes.ToIDispatch() + tasks := res.ToIDispatch() defer tasks.Release() - // Get task count - countVariant, err := oleutil.GetProperty(tasks, "Count") - if err != nil { - return nil, fmt.Errorf("failed to get task count: %w", err) - } - - taskCount := int(countVariant.Val) - - defer func(countVariant *ole.VARIANT) { - _ = countVariant.Clear() - }(countVariant) - - scheduledTasks := make([]scheduledTask, 0, taskCount) - err = oleutil.ForEach(tasks, func(v *ole.VARIANT) error { task := v.ToIDispatch() defer task.Release() @@ -470,19 +305,39 @@ func fetchTasksInFolder(taskServiceObj *ole.IDispatch, folderPath string) ([]sch return err } - scheduledTasks = append(scheduledTasks, parsedTask) + *scheduledTasks = append(*scheduledTasks, parsedTask) return nil }) + + return err +} + +func fetchTasksRecursively(folder *ole.IDispatch, scheduledTasks *ScheduledTasks) error { + if err := fetchTasksInFolder(folder, scheduledTasks); err != nil { + return err + } + + res, err := oleutil.CallMethod(folder, "GetFolders", 1) if err != nil { - return nil, fmt.Errorf("failed to iterate over tasks: %w", err) + return err } - return scheduledTasks, nil + subFolders := res.ToIDispatch() + defer subFolders.Release() + + err = oleutil.ForEach(subFolders, func(v *ole.VARIANT) error { + subFolder := v.ToIDispatch() + defer subFolder.Release() + + return fetchTasksRecursively(subFolder, scheduledTasks) + }) + + return err } -func parseTask(task *ole.IDispatch) (scheduledTask, error) { - var scheduledTask scheduledTask +func parseTask(task *ole.IDispatch) (ScheduledTask, error) { + var scheduledTask ScheduledTask taskNameVar, err := oleutil.GetProperty(task, "Name") if err != nil {