diff --git a/generate_column_desc.go b/generate_column_desc.go index 504ded7..ae4b731 100644 --- a/generate_column_desc.go +++ b/generate_column_desc.go @@ -13,6 +13,7 @@ import ( "time" "github.com/gwenwindflower/tbd/shared" + "github.com/schollz/progressbar/v3" ) type Payload struct { @@ -93,7 +94,7 @@ They will be nested under a 'tests' key in a YAML file, so no need to add a titl ` ) -func GenerateColumnDescriptions(tables shared.SourceTables) { +func GenerateColumnDescriptions(ts shared.SourceTables) { var wg sync.WaitGroup semaphore := make(chan struct{}, maxRate) @@ -101,8 +102,14 @@ func GenerateColumnDescriptions(tables shared.SourceTables) { limiter := time.NewTicker(interval / (maxRate / 2)) defer limiter.Stop() - for i := range tables.SourceTables { - for j := range tables.SourceTables[i].Columns { + bar := progressbar.NewOptions(len(ts.SourceTables), + progressbar.OptionShowCount(), + progressbar.OptionShowElapsedTimeOnFinish(), + progressbar.OptionEnableColorCodes(true), + progressbar.OptionSetDescription("🤖📝"), + ) + for i := range ts.SourceTables { + for j := range ts.SourceTables[i].Columns { semaphore <- struct{}{} <-limiter.C @@ -112,8 +119,8 @@ func GenerateColumnDescriptions(tables shared.SourceTables) { defer wg.Done() defer func() { <-semaphore }() - table_name := tables.SourceTables[i].Name - column_name := tables.SourceTables[i].Columns[j].Name + table_name := ts.SourceTables[i].Name + column_name := ts.SourceTables[i].Columns[j].Name desc_prompt := fmt.Sprintf(desc_prompt, table_name, column_name) tests_prompt := fmt.Sprintf(tests_prompt, table_name, column_name) desc_resp, err := GetGroqResponse(desc_prompt) @@ -125,16 +132,17 @@ func GenerateColumnDescriptions(tables shared.SourceTables) { log.Fatalf("Failed to get response from Groq for tests: %v\n", err) } if len(desc_resp.Choices) > 0 { - tables.SourceTables[i].Columns[j].Description = desc_resp.Choices[0].Message.Content + ts.SourceTables[i].Columns[j].Description = desc_resp.Choices[0].Message.Content } if len(tests_resp.Choices) > 0 { r := regexp.MustCompile(`unique|not_null`) matches := r.FindAllString(tests_resp.Choices[0].Message.Content, -1) matches = Deduplicate(matches) - tables.SourceTables[i].Columns[j].Tests = matches + ts.SourceTables[i].Columns[j].Tests = matches } }(i, j) } + bar.Add(1) } wg.Wait() } @@ -147,7 +155,7 @@ func GetGroqResponse(prompt string) (GroqResponse, error) { Content: prompt, }, }, - Model: "mixtral-8x7b-32768", + Model: "Llama3-70B-8192", Temp: 0.5, Tokens: 2048, TopP: 1,