diff --git a/server/logger/logger.go b/server/logger/logger.go index 590e167..25c50f0 100644 --- a/server/logger/logger.go +++ b/server/logger/logger.go @@ -221,6 +221,57 @@ func (log *Logger) Log(entry LogEntry) error { return err } +// Implement a streaming iterator for log entries +// This will allow us to iterate through the log entries without loading all of them into memory +// This is useful for large logs +type LogEntryIterator struct { + rows *sql.Rows +} + +func (it *LogEntryIterator) Next() bool { + res := it.rows.Next() + if !res { + it.rows.Close() + } + return res +} + +func (it *LogEntryIterator) Value() (LogEntry, error) { + var entry LogEntry + if err := it.rows.Scan(&entry.ParticipantId, &entry.ExecutedCommand, &entry.ErrorCode, &entry.ErrorMessage, &entry.GeneratedOutput, &entry.FilePath, &entry.CreatedAt); err != nil { + return LogEntry{}, err + } + return entry, nil +} + +func (it *LogEntryIterator) List() ([]LogEntry, error) { + var entries []LogEntry + for it.Next() { + entry, err := it.Value() + if err != nil { + return nil, err + } + entries = append(entries, entry) + } + return entries, nil +} + +func (log *Logger) Entries() (*LogEntryIterator, error) { + rows, err := log.db.Query("SELECT participant_id, executed_command, error_code, error_message, generated_output, file_path, created_at FROM logs") + if err != nil { + return nil, err + } + return &LogEntryIterator{rows: rows}, nil +} + +func (log *Logger) EntriesByParticipantId(participantId string) (*LogEntryIterator, error) { + rows, err := log.db.Query("SELECT participant_id, executed_command, error_code, error_message, generated_output, file_path, created_at FROM logs WHERE participant_id = ?", participantId) + if err != nil { + return nil, err + } + return &LogEntryIterator{rows: rows}, nil +} + func (log *Logger) Reset() error { // delete logs if _, err := log.db.Exec("DELETE FROM logs WHERE participant_id = ?", log.ParticipantId()); err != nil { diff --git a/server/logger/logger_test.go b/server/logger/logger_test.go index 4dd0acc..e0fc0ed 100644 --- a/server/logger/logger_test.go +++ b/server/logger/logger_test.go @@ -35,6 +35,104 @@ func TestLogger_Log(t *testing.T) { } } +func TestLogger_Entries(t *testing.T) { + // Create a temporary database file for testing + dbPath := "test.db" + defer os.Remove(dbPath) + + // Create a new logger + log, err := logger.NewMemoryLogger() + if err != nil { + t.Fatal(err) + } + defer log.Close() + + // Add multiple logs + for i := 0; i < 5; i++ { + params := logger.LogEntry{ + ExecutedCommand: "go test", + ErrorCode: 1, + ErrorMessage: "Test failed", + GeneratedOutput: "Some output", + FilePath: fmt.Sprintf("/path/to/file%d.go", i), + } + + err = log.Log(params) + if err != nil { + t.Fatal(err) + } + } + + // Test Entries + entriesIter, err := log.Entries() + if err != nil { + t.Fatal(err) + } + + entries, err := entriesIter.List() + if err != nil { + t.Fatal(err) + } + + // Compare the retrieved entries with the original entries + if len(entries) != 5 { + t.Errorf("expected 5 entries, got %d", len(entries)) + } +} + +func TestLogger_EntriesByParticipantId(t *testing.T) { + // Create a temporary database file for testing + dbPath := "test.db" + defer os.Remove(dbPath) + + // Create a new logger + log, err := logger.NewMemoryLogger() + if err != nil { + t.Fatal(err) + } + defer log.Close() + + // Add multiple logs (first loop is the participant id and the second loop is the log entry) + participantIds := []string{ + log.ParticipantId(), + "participant2", + } + + for _, participantId := range participantIds { + for i := 0; i < 5; i++ { + params := logger.LogEntry{ + ParticipantId: participantId, + ExecutedCommand: "go test", + ErrorCode: 1, + ErrorMessage: "Test failed", + GeneratedOutput: "Some output", + FilePath: fmt.Sprintf("/path/to/file%d.go", i), + } + + err = log.Log(params) + if err != nil { + t.Fatal(err) + } + } + } + + // Test EntriesByParticipantId + entriesIter, err := log.EntriesByParticipantId(log.ParticipantId()) + if err != nil { + t.Fatal(err) + } + + entries, err := entriesIter.List() + if err != nil { + t.Fatal(err) + } + + // Compare the retrieved entries with the original entries + if len(entries) != 5 { + t.Errorf("expected 5 entries, got %d", len(entries)) + } +} + func TestLogger_AddSetting(t *testing.T) { // Create a temporary database file for testing dbPath := "test.db"