Skip to content

Commit

Permalink
feat(yaml inference): Use Llama3 as model
Browse files Browse the repository at this point in the history
This commit sets the Groq model to Llama3-70B and adds a progress bar for the slow generation.
  • Loading branch information
gwenwindflower committed Apr 20, 2024
1 parent 44dba80 commit 70cf11b
Showing 1 changed file with 16 additions and 8 deletions.
24 changes: 16 additions & 8 deletions generate_column_desc.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import (
"time"

"github.com/gwenwindflower/tbd/shared"
"github.com/schollz/progressbar/v3"
)

type Payload struct {
Expand Down Expand Up @@ -93,16 +94,22 @@ They will be nested under a 'tests' key in a YAML file, so no need to add a titl
`
)

func GenerateColumnDescriptions(tables shared.SourceTables) {
func GenerateColumnDescriptions(ts shared.SourceTables) {
var wg sync.WaitGroup

semaphore := make(chan struct{}, maxRate)
// We maek 2 calls so we divide the rate by 2
limiter := time.NewTicker(interval / (maxRate / 2))
defer limiter.Stop()

for i := range tables.SourceTables {
for j := range tables.SourceTables[i].Columns {
bar := progressbar.NewOptions(len(ts.SourceTables),
progressbar.OptionShowCount(),
progressbar.OptionShowElapsedTimeOnFinish(),
progressbar.OptionEnableColorCodes(true),
progressbar.OptionSetDescription("🤖📝"),
)
for i := range ts.SourceTables {
for j := range ts.SourceTables[i].Columns {

semaphore <- struct{}{}
<-limiter.C
Expand All @@ -112,8 +119,8 @@ func GenerateColumnDescriptions(tables shared.SourceTables) {
defer wg.Done()
defer func() { <-semaphore }()

table_name := tables.SourceTables[i].Name
column_name := tables.SourceTables[i].Columns[j].Name
table_name := ts.SourceTables[i].Name
column_name := ts.SourceTables[i].Columns[j].Name
desc_prompt := fmt.Sprintf(desc_prompt, table_name, column_name)
tests_prompt := fmt.Sprintf(tests_prompt, table_name, column_name)
desc_resp, err := GetGroqResponse(desc_prompt)
Expand All @@ -125,16 +132,17 @@ func GenerateColumnDescriptions(tables shared.SourceTables) {
log.Fatalf("Failed to get response from Groq for tests: %v\n", err)
}
if len(desc_resp.Choices) > 0 {
tables.SourceTables[i].Columns[j].Description = desc_resp.Choices[0].Message.Content
ts.SourceTables[i].Columns[j].Description = desc_resp.Choices[0].Message.Content
}
if len(tests_resp.Choices) > 0 {
r := regexp.MustCompile(`unique|not_null`)
matches := r.FindAllString(tests_resp.Choices[0].Message.Content, -1)
matches = Deduplicate(matches)
tables.SourceTables[i].Columns[j].Tests = matches
ts.SourceTables[i].Columns[j].Tests = matches
}
}(i, j)
}
bar.Add(1)
}
wg.Wait()
}
Expand All @@ -147,7 +155,7 @@ func GetGroqResponse(prompt string) (GroqResponse, error) {
Content: prompt,
},
},
Model: "mixtral-8x7b-32768",
Model: "Llama3-70B-8192",
Temp: 0.5,
Tokens: 2048,
TopP: 1,
Expand Down

0 comments on commit 70cf11b

Please sign in to comment.