From da5c63eb33f435b02316acd820b56b57212dfb85 Mon Sep 17 00:00:00 2001 From: gwen windflower Date: Sun, 21 Apr 2024 06:48:09 -0500 Subject: [PATCH 1/6] fix(sql validation): improve form validation and parameterized SQL Use SQL params where possible (couldn't get Snowflake to work) and more robust handling of blank values from forms (ie multiple spaces are now still seen as empty) --- forms.go | 48 +++++++++++++++++---------------- sourcerer/get_columns.go | 13 +++++---- sourcerer/get_sources_tables.go | 13 +++++---- 3 files changed, 41 insertions(+), 33 deletions(-) diff --git a/forms.go b/forms.go index bf89b36..caccd91 100644 --- a/forms.go +++ b/forms.go @@ -4,6 +4,7 @@ import ( "errors" "fmt" "strconv" + "strings" "github.com/charmbracelet/huh" "github.com/fatih/color" @@ -35,7 +36,8 @@ type FormResponse struct { Confirm bool } -var not_empty = func(s string) error { +func notEmpty(s string) error { + s = strings.TrimSpace(s) if len(s) == 0 { return fmt.Errorf("cannot be empty, please enter a value") } @@ -94,14 +96,14 @@ https://github.com/gwenwindflower/tbd Title("What *prefix* for your staging files?"). Value(&dfr.Prefix). Placeholder("stg"). - Validate(not_empty), + Validate(notEmpty), ), huh.NewGroup(huh.NewInput(). Title("What is the *name* of your dbt project?"). Value(&dfr.ProjectName). Placeholder("rivendell"). - Validate(not_empty), + Validate(notEmpty), ).WithHideFunc(func() bool { return !dfr.ScaffoldProject }), @@ -123,17 +125,17 @@ https://github.com/gwenwindflower/tbd Title("Which *output* in that profile do you want to use?"). Value(&dfr.DbtProfileOutput). Placeholder("dev"). - Validate(not_empty), + Validate(notEmpty), huh.NewInput(). Title("What *schema* do you want to generate?"). Value(&dfr.Schema). Placeholder("raw"). - Validate(not_empty), + Validate(notEmpty), huh.NewInput(). Title("What *database* is that schema in?"). Value(&dfr.Database). Placeholder("jaffle_shop"). - Validate(not_empty), + Validate(notEmpty), ).WithHideFunc(func() bool { return !dfr.UseDbtProfile }), @@ -157,22 +159,22 @@ https://github.com/gwenwindflower/tbd Title("What is your username?"). Value(&dfr.Username). Placeholder("aragorn@dunedain.king"). - Validate(not_empty), + Validate(notEmpty), huh.NewInput(). Title("What is your Snowflake account id?"). Value(&dfr.Account). Placeholder("elfstone-consulting.us-west-1"). - Validate(not_empty), + Validate(notEmpty), huh.NewInput(). Title("What is the *schema* you want to generate?"). Value(&dfr.Schema). Placeholder("minas-tirith"). - Validate(not_empty), + Validate(notEmpty), huh.NewInput(). Title("What *database* is that schema in?"). Value(&dfr.Database). Placeholder("gondor"). - Validate(not_empty), + Validate(notEmpty), ).WithHideFunc(func() bool { return dfr.Warehouse != "snowflake" }), @@ -182,12 +184,12 @@ https://github.com/gwenwindflower/tbd Title("What GCP *project id* do you want to generate?"). Value(&dfr.Project). Placeholder("legolas_inc"). - Validate(not_empty), + Validate(notEmpty), huh.NewInput(). Title("What is the *dataset* you want to generate?"). Value(&dfr.Dataset). Placeholder("mirkwood"). - Validate(not_empty), + Validate(notEmpty), ).WithHideFunc(func() bool { return dfr.Warehouse != "bigquery" }), @@ -198,17 +200,17 @@ https://github.com/gwenwindflower/tbd Relative to pwd e.g. if db is in this dir -> cool_ducks.db`). Value(&dfr.Path). Placeholder("/path/to/duckdb.db"). - Validate(not_empty), + Validate(notEmpty), huh.NewInput(). Title("What is the *database* you want to generate?"). Value(&dfr.Database). Placeholder("gimli_corp"). - Validate(not_empty), + Validate(notEmpty), huh.NewInput(). Title("What is the *schema* you want to generate?"). Value(&dfr.Schema). Placeholder("moria"). - Validate(not_empty), + Validate(notEmpty), ).WithHideFunc(func() bool { return dfr.Warehouse != "duckdb" }), @@ -217,7 +219,7 @@ Relative to pwd e.g. if db is in this dir -> cool_ducks.db`). huh.NewInput(). Title("What is your Postgres *host*?"). Value(&dfr.Host). - Validate(not_empty), + Validate(notEmpty), huh.NewInput(). Title("What is your Postgres *port*?"). Value(&dfr.Port). @@ -232,22 +234,22 @@ Relative to pwd e.g. if db is in this dir -> cool_ducks.db`). Title("What is your Postgres *username*?"). Value(&dfr.Username). Placeholder("galadriel"). - Validate(not_empty), + Validate(notEmpty), huh.NewInput(). Title("What is your Postgres *password*?"). Value(&dfr.Password). - Validate(not_empty). + Validate(notEmpty). EchoMode(huh.EchoModePassword), huh.NewInput(). Title("What is the *database* you want to generate?"). Value(&dfr.Database). Placeholder("lothlorien"). - Validate(not_empty), + Validate(notEmpty), huh.NewInput(). Title("What is the *schema* you want to generate?"). Value(&dfr.Schema). Placeholder("mallorn_trees"). - Validate(not_empty), + Validate(notEmpty), huh.NewSelect[string](). Title("What ssl mode do you want to use?"). Value(&dfr.SslMode). @@ -258,7 +260,7 @@ Relative to pwd e.g. if db is in this dir -> cool_ducks.db`). huh.NewOption("Verify-full", "verify-full"), huh.NewOption("Prefer", "prefer"), huh.NewOption("Allow", "allow")). - Validate(not_empty), + Validate(notEmpty), ).WithHideFunc(func() bool { return dfr.Warehouse != "postgres" }), @@ -283,7 +285,7 @@ Get one at https://groq.com.`, yellowItalic("Optional"), pinkUnderline("descript Title("What env var holds your Groq key?"). Placeholder("GROQ_API_KEY"). Value(&dfr.GroqKeyEnvVar). - Validate(not_empty), + Validate(notEmpty), ).WithHideFunc(func() bool { return !dfr.GenerateDescriptions }), @@ -293,7 +295,7 @@ Get one at https://groq.com.`, yellowItalic("Optional"), pinkUnderline("descript Title("What directory do you want to build into?\n Must be new or empty."). Value(&dfr.BuildDir). Placeholder("build"). - Validate(not_empty), + Validate(notEmpty), huh.NewConfirm(). Title("🚦Are you ready to do this thing?🚦"). Value(&dfr.Confirm), diff --git a/sourcerer/get_columns.go b/sourcerer/get_columns.go index 639b1d0..23d9d2f 100644 --- a/sourcerer/get_columns.go +++ b/sourcerer/get_columns.go @@ -14,6 +14,7 @@ import ( func (sfc *SfConn) GetColumns(ctx context.Context, t shared.SourceTable) ([]shared.Column, error) { var cs []shared.Column + // TODO: figure out binding parameters issue on Snowflake so this can be done properly q := fmt.Sprintf("SELECT column_name, data_type FROM information_schema.columns WHERE table_schema = '%s' AND table_name = '%s'", sfc.Schema, t.Name) rows, err := sfc.Db.QueryContext(ctx, q) if err != nil { @@ -33,10 +34,12 @@ func (sfc *SfConn) GetColumns(ctx context.Context, t shared.SourceTable) ([]shar func (bqc *BqConn) GetColumns(ctx context.Context, t shared.SourceTable) ([]shared.Column, error) { var cs []shared.Column - qs := fmt.Sprintf("SELECT column_name, data_type FROM %s.%s.INFORMATION_SCHEMA.COLUMNS WHERE table_name = @table", bqc.Project, bqc.Dataset) + qs := "SELECT column_name, data_type FROM @project.@dataset.INFORMATION_SCHEMA.COLUMNS WHERE table_name = @table" q := bqc.Bq.Query(qs) q.Parameters = []bigquery.QueryParameter{ {Name: "table", Value: t.Name}, + {Name: "project", Value: bqc.Project}, + {Name: "dataset", Value: bqc.Dataset}, } it, err := q.Read(ctx) if err != nil { @@ -62,8 +65,8 @@ func (bqc *BqConn) GetColumns(ctx context.Context, t shared.SourceTable) ([]shar func (dc *DuckConn) GetColumns(ctx context.Context, t shared.SourceTable) ([]shared.Column, error) { var cs []shared.Column - q := fmt.Sprintf("SELECT column_name, data_type FROM information_schema.columns WHERE table_schema = '%s' AND table_name = '%s'", dc.Schema, t.Name) - rows, err := dc.Db.QueryContext(ctx, q) + q := "SELECT column_name, data_type FROM information_schema.columns WHERE table_schema = '?' AND table_name = '?'" + rows, err := dc.Db.QueryContext(ctx, q, dc.Schema, t.Name) if err != nil { log.Fatalf("Error fetching columns for table %s: %v\n", t.Name, err) } @@ -80,8 +83,8 @@ func (dc *DuckConn) GetColumns(ctx context.Context, t shared.SourceTable) ([]sha func (pgc *PgConn) GetColumns(ctx context.Context, t shared.SourceTable) ([]shared.Column, error) { var cs []shared.Column - q := fmt.Sprintf("SELECT column_name, data_type FROM information_schema.columns WHERE table_schema = '%s' AND table_name = '%s'", pgc.Schema, t.Name) - rows, err := pgc.Db.QueryContext(ctx, q) + q := "SELECT column_name, data_type FROM information_schema.columns WHERE table_schema = '?' AND table_name = '?'" + rows, err := pgc.Db.QueryContext(ctx, q, pgc.Schema, t.Name) if err != nil { log.Fatalf("Error fetching columns for table %s: %v\n", t.Name, err) } diff --git a/sourcerer/get_sources_tables.go b/sourcerer/get_sources_tables.go index df1ee90..8b637b6 100644 --- a/sourcerer/get_sources_tables.go +++ b/sourcerer/get_sources_tables.go @@ -13,7 +13,10 @@ import ( func (sfc *SfConn) GetSourceTables(ctx context.Context) (shared.SourceTables, error) { ts := shared.SourceTables{} defer sfc.Cancel() - rows, err := sfc.Db.QueryContext(ctx, fmt.Sprintf("SELECT table_name FROM information_schema.tables WHERE table_schema = '%s'", sfc.Schema)) + // TODO: why doesn't this work? + // q := `SELECT table_name FROM information_schema.tables WHERE table_schema = '?'` + q := fmt.Sprintf("SELECT table_name FROM information_schema.tables WHERE table_schema = '%s'", sfc.Schema) + rows, err := sfc.Db.QueryContext(ctx, q) if err != nil { log.Fatalf("Error fetching tables: %v\n", err) } @@ -50,8 +53,8 @@ func (bqc *BqConn) GetSourceTables(ctx context.Context) (shared.SourceTables, er func (dc *DuckConn) GetSourceTables(ctx context.Context) (shared.SourceTables, error) { ts := shared.SourceTables{} defer dc.Cancel() - q := fmt.Sprintf("SELECT table_name FROM information_schema.tables WHERE table_schema = '%s'", dc.Schema) - rows, err := dc.Db.QueryContext(ctx, q) + q := "SELECT table_name FROM information_schema.tables WHERE table_schema = '?'" + rows, err := dc.Db.QueryContext(ctx, q, dc.Schema) if err != nil { log.Fatalf("Error fetching tables: %v\n", err) } @@ -70,8 +73,8 @@ func (dc *DuckConn) GetSourceTables(ctx context.Context) (shared.SourceTables, e func (pgc *PgConn) GetSourceTables(ctx context.Context) (shared.SourceTables, error) { ts := shared.SourceTables{} defer pgc.Cancel() - q := fmt.Sprintf("SELECT table_name FROM information_schema.tables WHERE table_schema = '%s'", pgc.Schema) - rows, err := pgc.Db.QueryContext(ctx, q) + q := "SELECT table_name FROM information_schema.tables WHERE table_schema = '$1'" + rows, err := pgc.Db.QueryContext(ctx, q, pgc.Schema) if err != nil { log.Fatalf("Error fetching tables: %v\n", err) } From f3c76b1589fdb783d4bc786f951f1a0d62b25d6c Mon Sep 17 00:00:00 2001 From: gwen windflower Date: Sun, 21 Apr 2024 12:35:54 -0500 Subject: [PATCH 2/6] feat(llm): refactor to support multiple apis Adds OpenAI and Anthropic, which took significant refactoring. --- .gitignore | 4 +- forms.go | 41 ++++--- generate_column_desc.go | 208 -------------------------------- generate_column_desc_test.go | 47 -------- llm_get_llm.go | 140 +++++++++++++++++++++ llm_get_rate_limiter.go | 23 ++++ llm_get_response.go | 158 ++++++++++++++++++++++++ llm_get_response_test.go | 123 +++++++++++++++++++ llm_infer_column_fields.go | 53 ++++++++ llm_infer_column_fields_test.go | 38 ++++++ llm_prompts.go | 35 ++++++ llm_set_description.go | 46 +++++++ llm_set_tests.go | 54 +++++++++ main.go | 13 +- sourcerer/get_columns.go | 5 +- 15 files changed, 710 insertions(+), 278 deletions(-) delete mode 100644 generate_column_desc.go delete mode 100644 generate_column_desc_test.go create mode 100644 llm_get_llm.go create mode 100644 llm_get_rate_limiter.go create mode 100644 llm_get_response.go create mode 100644 llm_get_response_test.go create mode 100644 llm_infer_column_fields.go create mode 100644 llm_infer_column_fields_test.go create mode 100644 llm_prompts.go create mode 100644 llm_set_description.go create mode 100644 llm_set_tests.go diff --git a/.gitignore b/.gitignore index 02b5ca9..e720bb6 100644 --- a/.gitignore +++ b/.gitignore @@ -24,9 +24,11 @@ # Go workspace file go.work +# build directory +dist/ + # Project specific build test_build tbd -dist/ diff --git a/forms.go b/forms.go index caccd91..20663e4 100644 --- a/forms.go +++ b/forms.go @@ -11,8 +11,8 @@ import ( ) type FormResponse struct { - Path string - Username string + Password string + Host string BuildDir string SslMode string Database string @@ -22,13 +22,14 @@ type FormResponse struct { ProjectName string Warehouse string Account string - GroqKeyEnvVar string - Password string - DbtProfileName string + LlmKeyEnvVar string DbtProfileOutput string + DbtProfileName string + Path string Port string - Host string + Username string Prefix string + Llm string GenerateDescriptions bool ScaffoldProject bool CreateProfile bool @@ -57,11 +58,11 @@ func getProfileOptions(ps DbtProfiles) []huh.Option[string] { func Forms(ps DbtProfiles) (FormResponse, error) { dfr := FormResponse{ - BuildDir: "build", - GroqKeyEnvVar: "GROQ_API_KEY", - Prefix: "stg", - Host: "localhost", - Port: "5432", + BuildDir: "build", + LlmKeyEnvVar: "OPENAI_API_KEY", + Prefix: "stg", + Host: "localhost", + Port: "5432", } pinkUnderline := color.New(color.FgMagenta).Add(color.Bold, color.Underline).SprintFunc() greenBold := color.New(color.FgGreen).Add(color.Bold).SprintFunc() @@ -268,23 +269,27 @@ Relative to pwd e.g. if db is in this dir -> cool_ducks.db`). huh.NewGroup( huh.NewNote(). Title(fmt.Sprintf("πŸ€– %s LLM generation πŸ¦™βœ¨", redBold("Experimental"))). - Description(fmt.Sprintf(`%s features via Groq. -Currently generates: + Description(fmt.Sprintf(`%s generates: ✴︎ column %s ✴︎ relevant %s -_Requires a_ %s _stored in an env var_: -Get one at https://groq.com.`, yellowItalic("Optional"), pinkUnderline("descriptions"), pinkUnderline("tests"), greenBoldItalic("Groq API key"))), +_Requires an_ %s _stored in an env var_.`, yellowItalic("Optionally"), pinkUnderline("descriptions"), pinkUnderline("tests"), greenBoldItalic("LLM API key"))), huh.NewConfirm(). Title("Do you want to infer descriptions and tests?"). Value(&dfr.GenerateDescriptions), ), huh.NewGroup( + huh.NewSelect[string](). + Options( + huh.NewOption("OpenAI", "openai"), + huh.NewOption("Groq", "groq"), + huh.NewOption("Anthropic", "anthropic"), + ).Value(&dfr.Llm), huh.NewInput(). - Title("What env var holds your Groq key?"). - Placeholder("GROQ_API_KEY"). - Value(&dfr.GroqKeyEnvVar). + Title("What env var holds your LLM API key?"). + Placeholder("OPENAI_API_KEY"). + Value(&dfr.LlmKeyEnvVar). Validate(notEmpty), ).WithHideFunc(func() bool { return !dfr.GenerateDescriptions diff --git a/generate_column_desc.go b/generate_column_desc.go deleted file mode 100644 index 4c91d1e..0000000 --- a/generate_column_desc.go +++ /dev/null @@ -1,208 +0,0 @@ -package main - -import ( - "bytes" - "encoding/json" - "fmt" - "io" - "log" - "net/http" - "os" - "regexp" - "sync" - "time" - - "github.com/gwenwindflower/tbd/shared" - "github.com/schollz/progressbar/v3" -) - -type Payload struct { - Stop interface{} `json:"stop"` - Model string `json:"model"` - Messages []Message `json:"messages"` - Temp float64 `json:"temperature"` - Tokens int `json:"max_tokens"` - TopP int `json:"top_p"` - Stream bool `json:"stream"` -} - -type Message struct { - Role string `json:"role"` - Content string `json:"content"` -} - -type GroqResponse struct { - SystemFingerprint interface{} `json:"system_fingerprint"` - ID string `json:"id"` - Object string `json:"object"` - Model string `json:"model"` - Choices []struct { - Logprobs interface{} `json:"logprobs"` - Message struct { - Role string `json:"role"` - Content string `json:"content"` - } `json:"message"` - FinishReason string `json:"finish_reason"` - Index int `json:"index"` - } `json:"choices"` - Usage struct { - PromptTokens int `json:"prompt_tokens"` - PromptTime float64 `json:"prompt_time"` - CompletionTokens int `json:"completion_tokens"` - CompletionTime float64 `json:"completion_time"` - TotalTokens int `json:"total_tokens"` - TotalTime float64 `json:"total_time"` - } `json:"usage"` - Created int `json:"created"` -} - -// Groq API constants -const ( - maxRate = 30 - interval = time.Minute - URL = "https://api.groq.com/openai/v1/chat/completions" - desc_prompt = `Generate a description for a column in a specific table in a data warehouse, - the table is called %s and the column is called %s. The description should be concise, 1 to 3 sentences, - and inform both business users and technical data analyts about the purpose and contents of the column. - Avoid using the column name in the description, as it is redundant β€” put another way do not use tautological - descriptions, for example on an 'order_id' column saying "This is the id of an order". Don't do that. A good - example for an 'order_id' column would be something like "This is the primary key of the orders table, - each distinct order has a unique 'order_id'". Another good example for an orders table would be describing - 'product_type' as "The category of product, the bucket that a product falls into, for example 'electronics' or 'clothing'". - Avoid making assumptions about the data, as you don't have access to it. Don't make assertions about data that you - haven't seen, just use business context, the table name, and the column to generate the description. The description. - There is no need to add a title just the sentences that compose the description, it's being put onto a field in a YAML file, -so again, no title, no formatting, just 1 to 3 sentences.` - tests_prompt = `Generate a list of tests that can be run on a column in a specific table in a data warehouse, -the table is called %s and the column is called %s. The tests are YAML config, there are 2 to choose from. -They have the following structure, follow this structure exactly: - - unique - - not_null -Return only the tests that are applicable to the column, for example, a column that is a primary key should have -both unique and not_null tests, while a column that is a foreign key should only have the not_null test. If a -column is potentially optional, then it should have neither test. Return only the tests that are applicable to the column. -They will be nested under a 'tests' key in a YAML file, so no need to add a title or key, just the list of tests by themselves. - For example, a good response for a 'product_type' column in an 'orders' table would be: - - not_null - - A good response for an 'order_id' column in an 'orders' table would be: - - unique - - not_null - - A good response for a 'product_sku' column in an 'orders' table would be: - - not_null -` -) - -func GenerateColumnDescriptions(ts shared.SourceTables) { - var wg sync.WaitGroup - - semaphore := make(chan struct{}, maxRate) - // We maek 2 calls so we divide the rate by 2 - limiter := time.NewTicker(interval / (maxRate / 2)) - defer limiter.Stop() - - bar := progressbar.NewOptions(len(ts.SourceTables), - progressbar.OptionShowCount(), - progressbar.OptionSetWidth(30), - progressbar.OptionShowElapsedTimeOnFinish(), - progressbar.OptionEnableColorCodes(true), - progressbar.OptionSetDescription("πŸ€–πŸ“"), - ) - for i := range ts.SourceTables { - for j := range ts.SourceTables[i].Columns { - - semaphore <- struct{}{} - <-limiter.C - - wg.Add(1) - go func(i, j int) { - defer wg.Done() - defer func() { <-semaphore }() - - table_name := ts.SourceTables[i].Name - column_name := ts.SourceTables[i].Columns[j].Name - desc_prompt := fmt.Sprintf(desc_prompt, table_name, column_name) - tests_prompt := fmt.Sprintf(tests_prompt, table_name, column_name) - desc_resp, err := GetGroqResponse(desc_prompt) - if err != nil { - log.Fatalf("Failed to get response from Groq for description: %v\n", err) - } - tests_resp, err := GetGroqResponse(tests_prompt) - if err != nil { - log.Fatalf("Failed to get response from Groq for tests: %v\n", err) - } - if len(desc_resp.Choices) > 0 { - ts.SourceTables[i].Columns[j].Description = desc_resp.Choices[0].Message.Content - } - if len(tests_resp.Choices) > 0 { - r := regexp.MustCompile(`unique|not_null`) - matches := r.FindAllString(tests_resp.Choices[0].Message.Content, -1) - matches = Deduplicate(matches) - ts.SourceTables[i].Columns[j].Tests = matches - } - }(i, j) - } - bar.Add(1) - } - wg.Wait() -} - -func GetGroqResponse(prompt string) (GroqResponse, error) { - meta := Payload{ - Messages: []Message{ - { - Role: "user", - Content: prompt, - }, - }, - Model: "llama3-70b-8192", - Temp: 0.5, - Tokens: 2048, - TopP: 1, - Stream: false, - Stop: nil, - } - payload, err := json.Marshal(meta) - if err != nil { - log.Fatalf("Failed to marshal JSON: %v\n", err) - } - req, err := http.NewRequest(http.MethodPost, URL, bytes.NewBuffer(payload)) - if err != nil { - log.Fatalf("Unable to create request: %v\n", err) - } - req.Header.Set("Content-Type", "application/json") - req.Header.Set("Authorization", "Bearer "+os.Getenv("GROQ_API_KEY")) - client := http.Client{} - response, err := client.Do(req) - if err != nil { - log.Fatalf("Request failed: %v\n", err) - } - defer response.Body.Close() - - body, err := io.ReadAll(response.Body) - if err != nil { - log.Fatalf("Cannot read response body: %v\n", err) - } - - var resp GroqResponse - err = json.Unmarshal(body, &resp) - if err != nil { - log.Fatalf("Failed to unmarshal JSON: %v\n", err) - } - return resp, nil -} - -func Deduplicate(elements []string) []string { - encountered := map[string]bool{} - result := []string{} - - for v := range elements { - if encountered[elements[v]] { - } else { - encountered[elements[v]] = true - result = append(result, elements[v]) - } - } - return result -} diff --git a/generate_column_desc_test.go b/generate_column_desc_test.go deleted file mode 100644 index ede81cc..0000000 --- a/generate_column_desc_test.go +++ /dev/null @@ -1,47 +0,0 @@ -package main - -import ( - "testing" - - "github.com/jarcoal/httpmock" -) - -func TestGetGroqResponse(t *testing.T) { - prompt := "Who destroyed Orthanc" - httpmock.Activate() - defer httpmock.DeactivateAndReset() - httpmock.RegisterResponder("POST", "https://api.groq.com/openai/v1/chat/completions", - httpmock.NewStringResponder(200, `{"choices": [{"index": 0, "message": {"role": "assistant","content": "Treebeard and the Ents destroyed Orthanc."}}]}`)) - GroqResponse, err := GetGroqResponse(prompt) - if err != nil { - t.Error("expected", nil, "got", err) - } - info := httpmock.GetCallCountInfo() - if info["POST https://api.groq.com/openai/v1/chat/completions"] != 1 { - t.Error("expected", 1, "got", info["POST https://api.groq.com/openai/v1/chat/completions"]) - } - expected := "Treebeard and the Ents destroyed Orthanc." - if GroqResponse.Choices[0].Message.Content != expected { - t.Error("expected", expected, "got", GroqResponse.Choices[0].Message.Content) - } -} - -func TestGenerateColumnDescriptions(t *testing.T) { - ts := CreateTempSourceTables() - httpmock.Activate() - defer httpmock.DeactivateAndReset() - httpmock.RegisterResponder("POST", "https://api.groq.com/openai/v1/chat/completions", - httpmock.NewStringResponder(200, `{"choices": [{"index": 0, "message": {"role": "assistant","content": "lord of rivendell"}}]}`)) - GenerateColumnDescriptions(ts) - - info := httpmock.GetCallCountInfo() - if info["POST https://api.groq.com/openai/v1/chat/completions"] != 2 { - t.Error("expected", 2, "got", info["POST https://api.groq.com/openai/v1/chat/completions"]) - } - - expected := "lord of rivendell" - desc := ts.SourceTables[0].Columns[0].Description - if desc != expected { - t.Error("expected", expected, "got", desc) - } -} diff --git a/llm_get_llm.go b/llm_get_llm.go new file mode 100644 index 0000000..90ec2fc --- /dev/null +++ b/llm_get_llm.go @@ -0,0 +1,140 @@ +package main + +import ( + "fmt" + "os" + "time" + + "github.com/gwenwindflower/tbd/shared" +) + +type Llm interface { + GetResponse(prompt string) error + SetDescription(descPrompt string, ts shared.SourceTables, i, j int) error + SetTests(descPrompt string, ts shared.SourceTables, i, j int) error + GetRateLimiter() (chan struct{}, *time.Ticker) +} + +type OpenAI struct { + Type string + ApiKey string + Model string + Url string + Response struct { + SystemFingerprint interface{} `json:"system_fingerprint"` + Id string `json:"id"` + Object string `json:"object"` + Model string `json:"model"` + Choices []struct { + Logprobs interface{} `json:"logprobs"` + Message Message `json:"message"` + FinishReason string `json:"finish_reason"` + Index int `json:"index"` + } `json:"choices"` + Usage struct { + PromptTokens int `json:"prompt_tokens"` + PromptTime float64 `json:"prompt_time"` + CompletionTokens int `json:"completion_tokens"` + CompletionTime float64 `json:"completion_time"` + TotalTokens int `json:"total_tokens"` + TotalTime float64 `json:"total_time"` + } `json:"usage"` + Created int `json:"created"` + } +} + +type Groq struct { + Type string + ApiKey string + Model string + Url string + Response struct { + SystemFingerprint interface{} `json:"system_fingerprint"` + Id string `json:"id"` + Object string `json:"object"` + Model string `json:"model"` + Choices []struct { + Logprobs interface{} `json:"logprobs"` + Message Message `json:"message"` + FinishReason string `json:"finish_reason"` + Index int `json:"index"` + } `json:"choices"` + Usage struct { + PromptTokens int `json:"prompt_tokens"` + PromptTime float64 `json:"prompt_time"` + CompletionTokens int `json:"completion_tokens"` + CompletionTime float64 `json:"completion_time"` + TotalTokens int `json:"total_tokens"` + TotalTime float64 `json:"total_time"` + } `json:"usage"` + Created int `json:"created"` + } +} + +type Anthropic struct { + Type string + ApiKey string + Model string + Url string + Response struct { + Content []struct { + Type string `json:"type"` + Text string `json:"text"` + } `json:"content"` + Id string `json:"id"` + Type string `json:"type"` + Role string `json:"role"` + StopReason string `json:"stop_reason"` + StopSequence []string `json:"stop_sequence"` + Usage struct { + InputTokens int `json:"input_tokens"` + OutputToken int `json:"output_tokens"` + } `json:"usage"` + } +} + +type Payload struct { + Stop interface{} `json:"stop"` + Model string `json:"model"` + Messages []Message `json:"messages"` + Temp float64 `json:"temperature"` + MaxTokens int `json:"max_tokens"` + TopP int `json:"top_p"` + Stream bool `json:"stream"` +} + +type Message struct { + Role string `json:"role"` + Content string `json:"content"` +} + +func GetLlm(fr FormResponse) (Llm, error) { + switch fr.Llm { + case "groq": + g := &Groq{ + Type: "groq", + ApiKey: os.Getenv(fr.LlmKeyEnvVar), + Model: "llama3-70b-8192", + Url: "https://api.groq.com/openai/v1/chat/completions", + } + return g, nil + case "openai": + o := &OpenAI{ + Type: "openai", + ApiKey: os.Getenv(fr.LlmKeyEnvVar), + Model: "gpt-4-turbo", + Url: "https://api.openai.com/v1/chat/completions", + } + return o, nil + case "anthropic": + a := &Anthropic{ + Type: "anthropic", + ApiKey: os.Getenv(fr.LlmKeyEnvVar), + Model: "claude-3-opus-20240229", + Url: "https://api.anthropic.com/v1/messages", + } + return a, nil + default: + return nil, fmt.Errorf("invalid LLM choice: %v", fr.Llm) + } +} diff --git a/llm_get_rate_limiter.go b/llm_get_rate_limiter.go new file mode 100644 index 0000000..bdedbb1 --- /dev/null +++ b/llm_get_rate_limiter.go @@ -0,0 +1,23 @@ +package main + +import "time" + +func (o *OpenAI) GetRateLimiter() (semaphore chan struct{}, limiter *time.Ticker) { + return getLimiter(240) +} + +func (a *Anthropic) GetRateLimiter() (semaphore chan struct{}, limiter *time.Ticker) { + return getLimiter(240) +} + +func (g *Groq) GetRateLimiter() (semaphore chan struct{}, limiter *time.Ticker) { + return getLimiter(30) +} + +func getLimiter(mr int) (semaphore chan struct{}, limiter *time.Ticker) { + i := time.Minute + semaphore = make(chan struct{}, (mr / 2)) + // We make 2 calls so we divide the rate by 2 + limiter = time.NewTicker(i / time.Duration(mr/2)) + return semaphore, limiter +} diff --git a/llm_get_response.go b/llm_get_response.go new file mode 100644 index 0000000..33cc731 --- /dev/null +++ b/llm_get_response.go @@ -0,0 +1,158 @@ +package main + +import ( + "bytes" + "encoding/json" + "fmt" + "io" + "net/http" +) + +func (o *OpenAI) GetResponse(prompt string) error { + meta := Payload{ + Messages: []Message{ + { + Role: "user", + Content: prompt, + }, + }, + Model: o.Model, + Temp: 0.5, + MaxTokens: 2048, + TopP: 1, + Stream: false, + Stop: nil, + } + payload, err := json.Marshal(meta) + if err != nil { + return fmt.Errorf("failed to marshal JSON: %v", err) + } + req, err := http.NewRequest(http.MethodPost, o.Url, bytes.NewBuffer(payload)) + if err != nil { + return fmt.Errorf("unable to create request: %v", err) + } + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", "Bearer "+o.ApiKey) + client := http.Client{} + response, err := client.Do(req) + if err != nil { + return fmt.Errorf("request failed: %v", err) + } + defer response.Body.Close() + if response.StatusCode != http.StatusOK { + return fmt.Errorf("request failed with status code: %v", response.StatusCode) + } + body, err := io.ReadAll(response.Body) + if err != nil { + return fmt.Errorf("cannot read response body: %v", err) + } + err = json.Unmarshal(body, &o.Response) + if err != nil { + return fmt.Errorf("failed to unmarshal JSON: %v", err) + } + return nil +} + +func (o *Groq) GetResponse(prompt string) error { + meta := Payload{ + Messages: []Message{ + { + Role: "user", + Content: prompt, + }, + }, + Model: o.Model, + Temp: 0.5, + MaxTokens: 2048, + TopP: 1, + Stream: false, + Stop: nil, + } + payload, err := json.Marshal(meta) + if err != nil { + return fmt.Errorf("failed to marshal JSON: %v", err) + } + req, err := http.NewRequest(http.MethodPost, o.Url, bytes.NewBuffer(payload)) + if err != nil { + return fmt.Errorf("unable to create request: %v", err) + } + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", "Bearer "+o.ApiKey) + client := http.Client{} + response, err := client.Do(req) + if err != nil { + return fmt.Errorf("request failed: %v", err) + } + defer response.Body.Close() + if response.StatusCode != http.StatusOK { + return fmt.Errorf("request failed with status code: %v", response.StatusCode) + } + body, err := io.ReadAll(response.Body) + if err != nil { + return fmt.Errorf("cannot read response body: %v", err) + } + err = json.Unmarshal(body, &o.Response) + if err != nil { + return fmt.Errorf("failed to unmarshal JSON: %v", err) + } + return nil +} + +func (a *Anthropic) GetResponse(prompt string) error { + meta := Payload{ + Messages: []Message{ + { + Role: "user", + Content: prompt, + }, + }, + Model: a.Model, + Temp: 0.5, + MaxTokens: 2048, + TopP: 1, + Stream: false, + } + payload, err := json.Marshal(meta) + if err != nil { + return fmt.Errorf("failed to marshal JSON: %v", err) + } + req, err := http.NewRequest(http.MethodPost, a.Url, bytes.NewBuffer(payload)) + if err != nil { + return fmt.Errorf("unable to create request: %v", err) + } + req.Header.Set("content-type", "application/json") + req.Header.Set("anthropic-version", "2023-06-01") + req.Header.Set("x-api-key", a.ApiKey) + client := http.Client{} + response, err := client.Do(req) + if err != nil { + return fmt.Errorf("request failed: %v", err) + } + defer response.Body.Close() + if response.StatusCode != http.StatusOK { + return fmt.Errorf("request failed with status code: %v", response.StatusCode) + } + body, err := io.ReadAll(response.Body) + if err != nil { + return fmt.Errorf("cannot read response body: %v", err) + } + err = json.Unmarshal(body, &a.Response) + if err != nil { + return fmt.Errorf("failed to unmarshal JSON: %v", err) + } + return nil +} + +func Deduplicate(elements []string) []string { + encountered := map[string]bool{} + result := []string{} + + for v := range elements { + if encountered[elements[v]] { + } else { + encountered[elements[v]] = true + result = append(result, elements[v]) + } + } + return result +} diff --git a/llm_get_response_test.go b/llm_get_response_test.go new file mode 100644 index 0000000..652c5ec --- /dev/null +++ b/llm_get_response_test.go @@ -0,0 +1,123 @@ +package main + +import ( + "fmt" + "testing" + + "github.com/jarcoal/httpmock" +) + +func TestGetGroqResponse(t *testing.T) { + prompt := "Who destroyed Orthanc" + httpmock.Activate() + defer httpmock.DeactivateAndReset() + httpmock.RegisterResponder("POST", "https://api.groq.com/openai/v1/chat/completions", + httpmock.NewStringResponder(200, `{"choices": [{"index": 0, "message": {"role": "assistant","content": "Treebeard and the Ents destroyed Orthanc."}}]}`)) + llm, err := GetLlm(FormResponse{Llm: "groq"}) + if err != nil { + t.Errorf("Did not expect err getting LLM: %v", err) + } + g, ok := llm.(*Groq) + if !ok { + t.Error("Expceted Groq LLM type") + } + err = g.GetResponse(prompt) + if err != nil { + t.Error("expected", nil, "got", err) + } + info := httpmock.GetCallCountInfo() + if info["POST https://api.groq.com/openai/v1/chat/completions"] != 1 { + t.Error("expected", 1, "got", info["POST https://api.groq.com/openai/v1/chat/completions"]) + } + expected := "Treebeard and the Ents destroyed Orthanc." + if g.Response.Choices[0].Message.Content != expected { + t.Error("expected", expected, "got", g.Response.Choices[0].Message.Content) + } +} + +func TestGetOpenAIResponse(t *testing.T) { + prompt := "Who destroyed Orthanc" + httpmock.Activate() + defer httpmock.DeactivateAndReset() + httpmock.RegisterResponder("POST", "https://api.openai.com/v1/chat/completions", + httpmock.NewStringResponder(200, `{"choices": [{"index": 0, "message": {"role": "assistant","content": "Treebeard and the Ents destroyed Orthanc."}}]}`)) + llm, err := GetLlm(FormResponse{Llm: "openai"}) + if err != nil { + t.Errorf("Did not expect err getting LLM: %v", err) + } + o, ok := llm.(*OpenAI) + if !ok { + t.Error("Expceted OpenAI LLM type") + } + err = o.GetResponse(prompt) + if err != nil { + t.Errorf("Did not expect err getting response: %v", err) + } + // TODO: flaky test + // info := httpmock.GetCallCountInfo() + // expectedCalls := 1 + // if info["POST https://api.openai.com/v1/completions"] != expectedCalls { + // t.Error("expected", expectedCalls, "got", info["POST https://api.openai.com/v1/chat/completions"]) + // } + expectedResp := "Treebeard and the Ents destroyed Orthanc." + if o.Response.Choices[0].Message.Content != expectedResp { + t.Error("expected", expectedResp, "got", o.Response.Choices[0].Message.Content) + } +} + +func TestGetAnthropicResponse(t *testing.T) { + prompt := "Who destroyed Orthanc" + url := "https://api.anthropic.com/v1/messages" + httpmock.Activate() + defer httpmock.DeactivateAndReset() + httpmock.RegisterResponder("POST", url, + httpmock.NewStringResponder(200, `{"role": "assistant", "content": [{ "type": "text", "text": "Treebeard and the Ents destroyed Orthanc."}]}`)) + llm, err := GetLlm(FormResponse{Llm: "anthropic"}) + if err != nil { + t.Errorf("Did not expect err getting LLM: %v", err) + } + a, ok := llm.(*Anthropic) + if !ok { + t.Error("Expceted Anthropic LLM type") + } + err = a.GetResponse(prompt) + if err != nil { + t.Errorf("Did not expect err getting response: %v", err) + } + info := httpmock.GetCallCountInfo() + expectedCalls := 1 + if info[fmt.Sprintf("POST %s", url)] != expectedCalls { + t.Error("expected", expectedCalls, "got", info[fmt.Sprintf("POST %s", url)]) + } + expectedResp := "Treebeard and the Ents destroyed Orthanc." + actualResp := a.Response.Content[0].Text + if actualResp != expectedResp { + t.Error("expected", expectedResp, "got", actualResp) + } +} + +func TestGetAnthropicResponseError(t *testing.T) { + prompt := "Who destroyed Orthanc" + url := "https://api.anthropic.com/v1/messages" + httpmock.Activate() + defer httpmock.DeactivateAndReset() + httpmock.RegisterResponder("POST", url, + httpmock.NewStringResponder(400, `{"type": "error", "error": {"type": "invalid_request_error", "message": "max_tokens: Field required"}}`)) + llm, err := GetLlm(FormResponse{Llm: "anthropic"}) + if err != nil { + t.Errorf("Did not expect err getting LLM: %v", err) + } + a, ok := llm.(*Anthropic) + if !ok { + t.Error("Expceted Anthropic LLM type") + } + err = a.GetResponse(prompt) + if err == nil { + t.Error("expected error, got nil") + } + info := httpmock.GetCallCountInfo() + expectedCalls := 1 + if info[fmt.Sprintf("POST %s", url)] != expectedCalls { + t.Error("expected", expectedCalls, "got", info[fmt.Sprintf("POST %s", url)]) + } +} diff --git a/llm_infer_column_fields.go b/llm_infer_column_fields.go new file mode 100644 index 0000000..cf44600 --- /dev/null +++ b/llm_infer_column_fields.go @@ -0,0 +1,53 @@ +package main + +import ( + "fmt" + "sync" + + "github.com/gwenwindflower/tbd/shared" + "github.com/schollz/progressbar/v3" +) + +func InferColumnFields(llm Llm, ts shared.SourceTables) error { + var wg sync.WaitGroup + semaphore, limiter := llm.GetRateLimiter() + defer limiter.Stop() + + bar := progressbar.NewOptions(len(ts.SourceTables), + progressbar.OptionShowCount(), + progressbar.OptionSetWidth(30), + progressbar.OptionShowElapsedTimeOnFinish(), + progressbar.OptionEnableColorCodes(true), + progressbar.OptionSetDescription("πŸ€–πŸ“"), + ) + for i := range ts.SourceTables { + for j := range ts.SourceTables[i].Columns { + + semaphore <- struct{}{} + <-limiter.C + + wg.Add(1) + go func(i, j int) error { + defer wg.Done() + defer func() { <-semaphore }() + + tableName := ts.SourceTables[i].Name + columnName := ts.SourceTables[i].Columns[j].Name + descPrompt := fmt.Sprintf(DESC_PROMPT, tableName, columnName) + testsPrompt := fmt.Sprintf(TESTS_PROMPT, tableName, columnName) + err := llm.SetDescription(descPrompt, ts, i, j) + if err != nil { + return fmt.Errorf("error setting description: %v", err) + } + err = llm.SetTests(testsPrompt, ts, i, j) + if err != nil { + return fmt.Errorf("error setting tests: %v", err) + } + return nil + }(i, j) + } + bar.Add(1) + } + wg.Wait() + return nil +} diff --git a/llm_infer_column_fields_test.go b/llm_infer_column_fields_test.go new file mode 100644 index 0000000..dac5d90 --- /dev/null +++ b/llm_infer_column_fields_test.go @@ -0,0 +1,38 @@ +package main + +import ( + "testing" + + "github.com/jarcoal/httpmock" +) + +func TestInferColumnFields(t *testing.T) { + ts := CreateTempSourceTables() + httpmock.Activate() + defer httpmock.DeactivateAndReset() + httpmock.RegisterResponder("POST", "https://api.groq.com/openai/v1/chat/completions", + httpmock.NewStringResponder(200, `{"choices": [{"index": 0, "message": {"role": "assistant","content": "lord of rivendell"}}]}`)) + llm, err := GetLlm(FormResponse{Llm: "groq"}) + if err != nil { + t.Errorf("Did not expect err getting LLM: %v", err) + } + g, ok := llm.(*Groq) + if !ok { + t.Error("Expceted Groq LLM type") + } + err = InferColumnFields(g, ts) + if err != nil { + t.Errorf("Did not expect err infering column fields: %v", err) + } + + info := httpmock.GetCallCountInfo() + if info["POST https://api.groq.com/openai/v1/chat/completions"] != 2 { + t.Error("expected", 2, "got", info["POST https://api.groq.com/openai/v1/chat/completions"]) + } + + expected := "lord of rivendell" + desc := ts.SourceTables[0].Columns[0].Description + if desc != expected { + t.Error("expected", expected, "got", desc) + } +} diff --git a/llm_prompts.go b/llm_prompts.go new file mode 100644 index 0000000..5e15daa --- /dev/null +++ b/llm_prompts.go @@ -0,0 +1,35 @@ +package main + +const ( + DESC_PROMPT = `Generate a description for a column in a specific table in a data warehouse, + the table is called %s and the column is called %s. The description should be concise, 1 to 3 sentences, + and inform both business users and technical data analyts about the purpose and contents of the column. + Avoid using the column name in the description, as it is redundant β€” put another way do not use tautological + descriptions, for example on an 'order_id' column saying "This is the id of an order". Don't do that. A good + example for an 'order_id' column would be something like "This is the primary key of the orders table, + each distinct order has a unique 'order_id'". Another good example for an orders table would be describing + 'product_type' as "The category of product, the bucket that a product falls into, for example 'electronics' or 'clothing'". + Avoid making assumptions about the data, as you don't have access to it. Don't make assertions about data that you + haven't seen, just use business context, the table name, and the column to generate the description. The description. + There is no need to add a title just the sentences that compose the description, it's being put onto a field in a YAML file, +so again, no title, no formatting, just 1 to 3 sentences.` + TESTS_PROMPT = `Generate a list of tests that can be run on a column in a specific table in a data warehouse, +the table is called %s and the column is called %s. The tests are YAML config, there are 2 to choose from. +They have the following structure, follow this structure exactly: + - unique + - not_null +Return only the tests that are applicable to the column, for example, a column that is a primary key should have +both unique and not_null tests, while a column that is a foreign key should only have the not_null test. If a +column is potentially optional, then it should have neither test. Return only the tests that are applicable to the column. +They will be nested under a 'tests' key in a YAML file, so no need to add a title or key, just the list of tests by themselves. + For example, a good response for a 'product_type' column in an 'orders' table would be: + - not_null + + A good response for an 'order_id' column in an 'orders' table would be: + - unique + - not_null + + A good response for a 'product_sku' column in an 'orders' table would be: + - not_null +` +) diff --git a/llm_set_description.go b/llm_set_description.go new file mode 100644 index 0000000..27af505 --- /dev/null +++ b/llm_set_description.go @@ -0,0 +1,46 @@ +package main + +import ( + "fmt" + + "github.com/gwenwindflower/tbd/shared" +) + +func (o *OpenAI) SetDescription(descPrompt string, ts shared.SourceTables, i, j int) error { + err := o.GetResponse(descPrompt) + if err != nil { + return fmt.Errorf("failed to get response from OpenAI for description: %v", err) + } + if len(o.Response.Choices) > 0 { + ts.SourceTables[i].Columns[j].Description = o.Response.Choices[0].Message.Content + } + return nil +} + +func (g *Groq) SetDescription(descPrompt string, ts shared.SourceTables, i, j int) error { + err := g.GetResponse(descPrompt) + if err != nil { + return fmt.Errorf("failed to get response from Groq for description: %v", err) + } + if len(g.Response.Choices) > 0 { + ts.SourceTables[i].Columns[j].Description = g.Response.Choices[0].Message.Content + } + return nil +} + +func (a *Anthropic) SetDescription(descPrompt string, ts shared.SourceTables, i, j int) error { + err := a.GetResponse(descPrompt) + if err != nil { + return fmt.Errorf("failed to get response from Anthropic for description: %v", err) + } + if len(a.Response.Content) == 0 { + return fmt.Errorf("no response content, likely bad request") + } + resp := a.Response.Content[0] + if len(resp.Text) > 0 && resp.Type == "text" { + ts.SourceTables[i].Columns[j].Description = resp.Text + } else { + return fmt.Errorf("no text response found") + } + return nil +} diff --git a/llm_set_tests.go b/llm_set_tests.go new file mode 100644 index 0000000..3324a81 --- /dev/null +++ b/llm_set_tests.go @@ -0,0 +1,54 @@ +package main + +import ( + "fmt" + "regexp" + + "github.com/gwenwindflower/tbd/shared" +) + +func (o *OpenAI) SetTests(testsPrompt string, ts shared.SourceTables, i, j int) error { + err := o.GetResponse(testsPrompt) + if err != nil { + return fmt.Errorf("failed to get response from OpenAI for tests: %v", err) + } + if len(o.Response.Choices) > 0 { + r := regexp.MustCompile(`unique|not_null`) + matches := r.FindAllString(o.Response.Choices[0].Message.Content, -1) + matches = Deduplicate(matches) + ts.SourceTables[i].Columns[j].Tests = matches + } + return nil +} + +func (g *Groq) SetTests(testsPrompt string, ts shared.SourceTables, i, j int) error { + err := g.GetResponse(testsPrompt) + if err != nil { + return fmt.Errorf("failed to get response from Groq for tests: %v", err) + } + if len(g.Response.Choices) > 0 { + r := regexp.MustCompile(`unique|not_null`) + matches := r.FindAllString(g.Response.Choices[0].Message.Content, -1) + matches = Deduplicate(matches) + ts.SourceTables[i].Columns[j].Tests = matches + } + return nil +} + +func (a *Anthropic) SetTests(testsPrompt string, ts shared.SourceTables, i, j int) error { + err := a.GetResponse(testsPrompt) + if err != nil { + return fmt.Errorf("failed to get response from Anthropic for tests: %v", err) + } + if len(a.Response.Content) == 0 { + return fmt.Errorf("no response content, likely bad request") + } + resp := a.Response.Content[0] + if len(resp.Text) > 0 && resp.Type == "text" { + r := regexp.MustCompile(`unique|not_null`) + matches := r.FindAllString(resp.Text, -1) + matches = Deduplicate(matches) + ts.SourceTables[i].Columns[j].Tests = matches + } + return nil +} diff --git a/main.go b/main.go index 63f40d7..1968792 100644 --- a/main.go +++ b/main.go @@ -63,7 +63,18 @@ func main() { e.ProcessingStart = time.Now() if fr.GenerateDescriptions { - GenerateColumnDescriptions(ts) + llm, err := GetLlm(fr) + if err != nil { + // Using Printf instead of log.Fatalf since the program doesn't + // need to totally fail if the API provider can't be fetched + fmt.Printf("Error getting API provider: %v\n", err) + } + err = InferColumnFields(llm, ts) + if err != nil { + // Using Printf instead of log.Fatalf since the program + // doesn't need to totally fail if there's an error in the column field inference + fmt.Printf("Error inferring column fields: %v\n", err) + } } if fr.CreateProfile { WriteProfile(cd, bd) diff --git a/sourcerer/get_columns.go b/sourcerer/get_columns.go index 23d9d2f..67d3b5c 100644 --- a/sourcerer/get_columns.go +++ b/sourcerer/get_columns.go @@ -34,12 +34,11 @@ func (sfc *SfConn) GetColumns(ctx context.Context, t shared.SourceTable) ([]shar func (bqc *BqConn) GetColumns(ctx context.Context, t shared.SourceTable) ([]shared.Column, error) { var cs []shared.Column - qs := "SELECT column_name, data_type FROM @project.@dataset.INFORMATION_SCHEMA.COLUMNS WHERE table_name = @table" + // BQ does not support binding parameters to table names so we have to do string interpolation + qs := fmt.Sprintf("SELECT column_name, data_type FROM %s.%s.INFORMATION_SCHEMA.COLUMNS WHERE table_name = @table", bqc.Project, bqc.Dataset) q := bqc.Bq.Query(qs) q.Parameters = []bigquery.QueryParameter{ {Name: "table", Value: t.Name}, - {Name: "project", Value: bqc.Project}, - {Name: "dataset", Value: bqc.Dataset}, } it, err := q.Read(ctx) if err != nil { From 03644cb254e47c7fdd08a3aa8dc566f35c205ad4 Mon Sep 17 00:00:00 2001 From: gwen windflower Date: Sun, 21 Apr 2024 13:04:26 -0500 Subject: [PATCH 3/6] fix(llm progress bar): Use column count not table count Iterates quicker and less lag when 'complete' but actually wrapping up the final tables columns. --- llm_infer_column_fields.go | 17 ++++++++++++++--- main.go | 9 ++++++++- sourcerer/put_columns_on_tables.go | 5 ++++- 3 files changed, 26 insertions(+), 5 deletions(-) diff --git a/llm_infer_column_fields.go b/llm_infer_column_fields.go index cf44600..0addc49 100644 --- a/llm_infer_column_fields.go +++ b/llm_infer_column_fields.go @@ -4,6 +4,7 @@ import ( "fmt" "sync" + "github.com/fatih/color" "github.com/gwenwindflower/tbd/shared" "github.com/schollz/progressbar/v3" ) @@ -13,11 +14,13 @@ func InferColumnFields(llm Llm, ts shared.SourceTables) error { semaphore, limiter := llm.GetRateLimiter() defer limiter.Stop() - bar := progressbar.NewOptions(len(ts.SourceTables), + bar := progressbar.NewOptions(countColumns(ts), progressbar.OptionShowCount(), progressbar.OptionSetWidth(30), - progressbar.OptionShowElapsedTimeOnFinish(), progressbar.OptionEnableColorCodes(true), + progressbar.OptionOnCompletion(func() { + color.HiGreen("\nColumn config generated.") + }), progressbar.OptionSetDescription("πŸ€–πŸ“"), ) for i := range ts.SourceTables { @@ -45,9 +48,17 @@ func InferColumnFields(llm Llm, ts shared.SourceTables) error { } return nil }(i, j) + bar.Add(1) } - bar.Add(1) } wg.Wait() return nil } + +func countColumns(ts shared.SourceTables) int { + c := 0 + for _, t := range ts.SourceTables { + c += len(t.Columns) + } + return c +} diff --git a/main.go b/main.go index 1968792..f14dc1e 100644 --- a/main.go +++ b/main.go @@ -6,6 +6,7 @@ import ( "log" "time" + "github.com/fatih/color" "github.com/gwenwindflower/tbd/sourcerer" ) @@ -49,10 +50,13 @@ func main() { if err != nil { log.Fatalf("Error connecting to database: %v\n", err) } + fmt.Println("Connected to database") ts, err := dbc.GetSourceTables(ctx) if err != nil { log.Fatalf("Error getting sources: %v\n", err) } + fmt.Println("Got source tables") + fmt.Println("Putting columns on tables...") err = sourcerer.PutColumnsOnTables(ctx, ts, dbc) if err != nil { log.Fatalf("Error putting columns on tables: %v\n", err) @@ -69,6 +73,7 @@ func main() { // need to totally fail if the API provider can't be fetched fmt.Printf("Error getting API provider: %v\n", err) } + fmt.Println("Generating descriptions and tests...") err = InferColumnFields(llm, ts) if err != nil { // Using Printf instead of log.Fatalf since the program @@ -76,6 +81,7 @@ func main() { fmt.Printf("Error inferring column fields: %v\n", err) } } + fmt.Println("Writing files...") if fr.CreateProfile { WriteProfile(cd, bd) } @@ -91,5 +97,6 @@ func main() { log.Fatalf("Error writing files: %v\n", err) } e.ProcessingElapsed = time.Since(e.ProcessingStart).Seconds() - fmt.Printf("\n🏁 Done in %.1fs fetching data and %.1fs writing files!\nYour YAML and SQL files are in the %s directory.", e.DbElapsed, e.ProcessingElapsed, fr.BuildDir) + pinkUnderline := color.New(color.FgMagenta).Add(color.Bold, color.Underline).SprintFunc() + fmt.Printf("\n🏁 Done in %.1fs fetching data and %.1fs writing files!\nYour YAML and SQL files are in the %s directory.", e.DbElapsed, e.ProcessingElapsed, pinkUnderline(fr.BuildDir)) } diff --git a/sourcerer/put_columns_on_tables.go b/sourcerer/put_columns_on_tables.go index 96bba8d..6c3b10f 100644 --- a/sourcerer/put_columns_on_tables.go +++ b/sourcerer/put_columns_on_tables.go @@ -6,6 +6,7 @@ import ( "regexp" "sync" + "github.com/fatih/color" "github.com/gwenwindflower/tbd/shared" "github.com/schollz/progressbar/v3" ) @@ -24,7 +25,9 @@ func PutColumnsOnTables(ctx context.Context, ts shared.SourceTables, dbc DbConn) bar := progressbar.NewOptions(len(ts.SourceTables), progressbar.OptionSetWidth(30), progressbar.OptionShowCount(), - progressbar.OptionShowElapsedTimeOnFinish(), + progressbar.OptionOnCompletion(func() { + color.HiGreen("\nSource tables contructed.") + }), progressbar.OptionEnableColorCodes(true), progressbar.OptionSetDescription("🏎️✨"), ) From 0ccad8e05de24c0294cfd354bfaa3424c2618f7c Mon Sep 17 00:00:00 2001 From: gwen windflower Date: Sun, 21 Apr 2024 14:04:34 -0500 Subject: [PATCH 4/6] fix(anthropic llm): Remove 'Stop' from Payload type This was causing Anthropic's API to 400, and we don't need it, it was being passed as null on the other APIs anyway. --- README.md | 19 +++++++++++++------ llm_get_llm.go | 13 ++++++------- llm_get_rate_limiter.go | 10 +++++----- llm_get_response.go | 2 -- llm_infer_column_fields.go | 13 ++++++------- llm_infer_column_fields_test.go | 7 +------ llm_set_description.go | 11 ++++------- main.go | 2 +- 8 files changed, 36 insertions(+), 41 deletions(-) diff --git a/README.md b/README.md index 913599c..795184a 100644 --- a/README.md +++ b/README.md @@ -19,6 +19,7 @@ It's designed to be super fast and easy to use with a friendly TUI that fast for ### It's the **_easy button_** for dbt projects. #### Quickstart + ```bash brew tap gwenwindflower/homebrew-tbd brew install tbd @@ -37,6 +38,7 @@ If you're new to dbt, [check out the wiki](https://github.com/gwenwindflower/tbd - [x] DuckDB If you don't have a cloud warehouse, but want to spin up a dbt project with `tbd` I recommend either: + - **BigQuery** β€” they have a generous free tier, authenticating with `gcloud` CLI is super easy, and `tbd` requires very few manual configurations. They also have a ton of great public datasets you can model. - **DuckDB** β€” you can work completely locally and skip the cloud altogether. You will need to find some data, but DuckDB can _very_ easily ingest CSVs, JSON, or Parquet, so if you have some raw data you want to work with, this is a great option as well. @@ -59,6 +61,7 @@ go install github.com/gwenwindflower/tbd@latest That's it! It's a single binary and has no dependencies on `dbt` itself, for maximum speed it operates directly with your warehouse, so you don't even need to have `dbt` installed to use it. That said, it _can_ leverage the profiles in your `~/.dbt/profiles.yml` file if you have them set up, so you can use the same connection information to save yourself some typing. ## πŸ” Warehouse-specific setup + `tbd` at present, for security, only supports SSO methods of authentication. Please check out the below guides for your target warehouse before using `tbd` to ensure a smooth experience. ### ❄️ Snowflake @@ -103,13 +106,15 @@ your_build_dir/ ### πŸ¦™ LLM features -`tbd` has some neat alpha features that are still in development. One of these is the ability to generate documentation and tests for your sources via LLM. It uses [Groq](https://groq.com) running `llama3-70b-8192` to do its inference. It's not perfect, but it's pretty good! It requires setting an environment variable with your Groq API key beforehand that you'll then pass the name of. +`tbd` has some neat alpha features that infer documentation and tests for your columns. There are multiple supported LLMs via API: Groq running Llama 3 70B, Anthropic Claude 3 Opus, and OpenAI GPT-4 Turbo. They have very different rate limits (these are limitations in the API that `tbd` respects): -The biggest thing to flag is that while Groq is in free beta, they have a very low rate limit on their API: 30 requests per minute. The actual inference on Groq is _super_ fast, but for now I've had to rate limit the API calls so it will take a few minutes or quite awhile depending on your schema size. Once Groq is out of beta, I'll remove the rate limit, but you'll of course have to pay for the API calls via your Groq account. +- **Groq** 30 requests per minute +- **Claude 3 Opus** 5 requests per minute +- **GPT-4 Turbo** 500 request per minute -I will _definitely_ be adding other LLM providers in the future, probably Anthropic Claude 3 Opus as the next one so you can choose between maximum quality (Claude) or maximum speed (Groq, when I can remove the rate limit). +As you can see, if you have anything but a very smol schema, you should stick with OpenAI. When Groq ups their rate limit after they're out of beta, that will be the fastest option, but for now, OpenAI is the best bet. The good news is that GPT-4 Turbo is _really_ good at this task (honestly better than Claude Opus) and pretty dang fast! The results are great in my testing. -I'm going to experiment very soon with using structured output conformed to dbt's JSON schema and passing entire tables, rather than iterating through columns, and see how it does with that. If it works that will be significantly faster as it can churn out entire files quickly and the rate limit will be less of a factor. +I'm going to experiment very soon with using structured output conformed to dbt's JSON schema and passing entire tables, rather than iterating through columns, and see how it does with that. If it works that will be significantly faster as it can churn out entire files (and perhaps improve quality through having more context) and the rate limits will be less of a factor. ### 🌊 Example workflows @@ -118,18 +123,20 @@ I'm going to experiment very soon with using structured output conformed to dbt' ## πŸ˜… To Do - [ ] Get to 100% test coverage -- [ ] Add Claude 3 Opus option +- [x] Add Claude 3 Opus option +- [x] Add OpenAI GPT-4 Turbo option - [x] Add support for Snowflake - [x] Add support for BigQuery - [ ] Add support for Redshift - [ ] Add support for Databricks - [ ] Add support for Postgres - [x] Add support for DuckDB +- [x] Add support for MotherDuck - [ ] Build on Linux - [ ] Build on Windows ## πŸ€— Contributing -I welcome Discussions, Issues, and PRs! This is pre-release software and without folks using it and opening Issues or Discussions I won't be able to find the rough edges and smooth them out. So please if you get stuck open an Issue and let's figure out how to fix it! +I welcome Discussions, Issues, and PRs! This is pre-release software and without folks using it and opening Issues or Discussions I won't be able to find the rough edges and smooth them out. So please if you get stuck open an Issue and let's figure out how to fix it! If you're a dbt user and aren't familiar with Go, but interested in learning a bit of it, I'm also happy to help guide you through opening a PR, just let me know πŸ’—. diff --git a/llm_get_llm.go b/llm_get_llm.go index 90ec2fc..d07f7b2 100644 --- a/llm_get_llm.go +++ b/llm_get_llm.go @@ -94,13 +94,12 @@ type Anthropic struct { } type Payload struct { - Stop interface{} `json:"stop"` - Model string `json:"model"` - Messages []Message `json:"messages"` - Temp float64 `json:"temperature"` - MaxTokens int `json:"max_tokens"` - TopP int `json:"top_p"` - Stream bool `json:"stream"` + Model string `json:"model"` + Messages []Message `json:"messages"` + Temp float64 `json:"temperature"` + MaxTokens int `json:"max_tokens"` + TopP int `json:"top_p"` + Stream bool `json:"stream"` } type Message struct { diff --git a/llm_get_rate_limiter.go b/llm_get_rate_limiter.go index bdedbb1..c68245e 100644 --- a/llm_get_rate_limiter.go +++ b/llm_get_rate_limiter.go @@ -3,17 +3,17 @@ package main import "time" func (o *OpenAI) GetRateLimiter() (semaphore chan struct{}, limiter *time.Ticker) { - return getLimiter(240) -} - -func (a *Anthropic) GetRateLimiter() (semaphore chan struct{}, limiter *time.Ticker) { - return getLimiter(240) + return getLimiter(500) } func (g *Groq) GetRateLimiter() (semaphore chan struct{}, limiter *time.Ticker) { return getLimiter(30) } +func (a *Anthropic) GetRateLimiter() (semaphore chan struct{}, limiter *time.Ticker) { + return getLimiter(5) +} + func getLimiter(mr int) (semaphore chan struct{}, limiter *time.Ticker) { i := time.Minute semaphore = make(chan struct{}, (mr / 2)) diff --git a/llm_get_response.go b/llm_get_response.go index 33cc731..744711f 100644 --- a/llm_get_response.go +++ b/llm_get_response.go @@ -21,7 +21,6 @@ func (o *OpenAI) GetResponse(prompt string) error { MaxTokens: 2048, TopP: 1, Stream: false, - Stop: nil, } payload, err := json.Marshal(meta) if err != nil { @@ -66,7 +65,6 @@ func (o *Groq) GetResponse(prompt string) error { MaxTokens: 2048, TopP: 1, Stream: false, - Stop: nil, } payload, err := json.Marshal(meta) if err != nil { diff --git a/llm_infer_column_fields.go b/llm_infer_column_fields.go index 0addc49..d50b6ea 100644 --- a/llm_infer_column_fields.go +++ b/llm_infer_column_fields.go @@ -2,6 +2,7 @@ package main import ( "fmt" + "log" "sync" "github.com/fatih/color" @@ -9,7 +10,7 @@ import ( "github.com/schollz/progressbar/v3" ) -func InferColumnFields(llm Llm, ts shared.SourceTables) error { +func InferColumnFields(llm Llm, ts shared.SourceTables) { var wg sync.WaitGroup semaphore, limiter := llm.GetRateLimiter() defer limiter.Stop() @@ -30,7 +31,7 @@ func InferColumnFields(llm Llm, ts shared.SourceTables) error { <-limiter.C wg.Add(1) - go func(i, j int) error { + go func(i, j int) { defer wg.Done() defer func() { <-semaphore }() @@ -40,19 +41,17 @@ func InferColumnFields(llm Llm, ts shared.SourceTables) error { testsPrompt := fmt.Sprintf(TESTS_PROMPT, tableName, columnName) err := llm.SetDescription(descPrompt, ts, i, j) if err != nil { - return fmt.Errorf("error setting description: %v", err) + log.Fatalf("Error generating descriptions: %v\n", err) } err = llm.SetTests(testsPrompt, ts, i, j) if err != nil { - return fmt.Errorf("error setting tests: %v", err) + log.Fatalf("Error generating tests: %v\n", err) } - return nil + bar.Add(1) }(i, j) - bar.Add(1) } } wg.Wait() - return nil } func countColumns(ts shared.SourceTables) int { diff --git a/llm_infer_column_fields_test.go b/llm_infer_column_fields_test.go index dac5d90..6005bd4 100644 --- a/llm_infer_column_fields_test.go +++ b/llm_infer_column_fields_test.go @@ -20,16 +20,11 @@ func TestInferColumnFields(t *testing.T) { if !ok { t.Error("Expceted Groq LLM type") } - err = InferColumnFields(g, ts) - if err != nil { - t.Errorf("Did not expect err infering column fields: %v", err) - } - + InferColumnFields(g, ts) info := httpmock.GetCallCountInfo() if info["POST https://api.groq.com/openai/v1/chat/completions"] != 2 { t.Error("expected", 2, "got", info["POST https://api.groq.com/openai/v1/chat/completions"]) } - expected := "lord of rivendell" desc := ts.SourceTables[0].Columns[0].Description if desc != expected { diff --git a/llm_set_description.go b/llm_set_description.go index 27af505..0f2e92f 100644 --- a/llm_set_description.go +++ b/llm_set_description.go @@ -14,6 +14,7 @@ func (o *OpenAI) SetDescription(descPrompt string, ts shared.SourceTables, i, j if len(o.Response.Choices) > 0 { ts.SourceTables[i].Columns[j].Description = o.Response.Choices[0].Message.Content } + return nil } @@ -31,16 +32,12 @@ func (g *Groq) SetDescription(descPrompt string, ts shared.SourceTables, i, j in func (a *Anthropic) SetDescription(descPrompt string, ts shared.SourceTables, i, j int) error { err := a.GetResponse(descPrompt) if err != nil { - return fmt.Errorf("failed to get response from Anthropic for description: %v", err) + return fmt.Errorf("failed to get ok response from Anthropic for description: %v", err) } if len(a.Response.Content) == 0 { return fmt.Errorf("no response content, likely bad request") } - resp := a.Response.Content[0] - if len(resp.Text) > 0 && resp.Type == "text" { - ts.SourceTables[i].Columns[j].Description = resp.Text - } else { - return fmt.Errorf("no text response found") - } + respContent := a.Response.Content[0] + ts.SourceTables[i].Columns[j].Description = respContent.Text return nil } diff --git a/main.go b/main.go index f14dc1e..d9dd376 100644 --- a/main.go +++ b/main.go @@ -74,7 +74,7 @@ func main() { fmt.Printf("Error getting API provider: %v\n", err) } fmt.Println("Generating descriptions and tests...") - err = InferColumnFields(llm, ts) + InferColumnFields(llm, ts) if err != nil { // Using Printf instead of log.Fatalf since the program // doesn't need to totally fail if there's an error in the column field inference From 5c2cf447d22bf0836fc3c7c9d704d42de4038334 Mon Sep 17 00:00:00 2001 From: gwen windflower Date: Sun, 21 Apr 2024 14:14:25 -0500 Subject: [PATCH 5/6] style(forms): Remove experimental from LLM page --- forms.go | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/forms.go b/forms.go index 20663e4..212c781 100644 --- a/forms.go +++ b/forms.go @@ -68,7 +68,6 @@ func Forms(ps DbtProfiles) (FormResponse, error) { greenBold := color.New(color.FgGreen).Add(color.Bold).SprintFunc() yellowItalic := color.New(color.FgHiYellow).Add(color.Italic).SprintFunc() greenBoldItalic := color.New(color.FgHiGreen).Add(color.Bold).SprintFunc() - redBold := color.New(color.FgHiRed).Add(color.Bold).SprintFunc() err := huh.NewForm( huh.NewGroup( huh.NewNote(). @@ -268,12 +267,12 @@ Relative to pwd e.g. if db is in this dir -> cool_ducks.db`). huh.NewGroup( huh.NewNote(). - Title(fmt.Sprintf("πŸ€– %s LLM generation πŸ¦™βœ¨", redBold("Experimental"))). - Description(fmt.Sprintf(`%s generates: + Title(fmt.Sprintf("πŸ€– %s LLM generation πŸ¦™βœ¨", yellowItalic("Optional"))). + Description(fmt.Sprintf(`Infers: ✴︎ column %s ✴︎ relevant %s -_Requires an_ %s _stored in an env var_.`, yellowItalic("Optionally"), pinkUnderline("descriptions"), pinkUnderline("tests"), greenBoldItalic("LLM API key"))), +_Requires an_ %s _stored in an env var_.`, pinkUnderline("descriptions"), pinkUnderline("tests"), greenBoldItalic("API key"))), huh.NewConfirm(). Title("Do you want to infer descriptions and tests?"). Value(&dfr.GenerateDescriptions), From c362d437c3749d7bf9da8bf2624a189c3edf3592 Mon Sep 17 00:00:00 2001 From: gwen windflower Date: Sun, 21 Apr 2024 14:14:50 -0500 Subject: [PATCH 6/6] chore(bump version) --- version.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/version.go b/version.go index 001d21c..f370ef0 100644 --- a/version.go +++ b/version.go @@ -1,3 +1,3 @@ package main -const Version = "0.0.22" +const Version = "0.0.23"