diff --git a/cmd/pggen/pggen.go b/cmd/pggen/pggen.go index b228e814..2a006aed 100644 --- a/cmd/pggen/pggen.go +++ b/cmd/pggen/pggen.go @@ -94,6 +94,8 @@ func newGenCmd() *ffcli.Command { goTypes := flags.Strings(fset, "go-type", nil, "custom type mapping from Postgres to fully qualified Go type, "+ "like 'device_type=github.com/jschaf/pggen.DeviceType'") + inlineParamCount := fset.Int("inline-param-count", 2, + "number of params (inclusive) to inline when calling querier methods; 0 always generates a struct") logLvl := zap.InfoLevel fset.Var(&logLvl, "log", "log level: debug, info, or error") goSubCmd := &ffcli.Command{ @@ -159,14 +161,15 @@ func newGenCmd() *ffcli.Command { // Codegen. err = pggen.Generate(pggen.GenerateOptions{ - Language: pggen.LangGo, - ConnString: *postgresConn, - SchemaFiles: schemas, - QueryFiles: queries, - OutputDir: outDir, - Acronyms: acros, - TypeOverrides: typeOverrides, - LogLevel: logLvl, + Language: pggen.LangGo, + ConnString: *postgresConn, + SchemaFiles: schemas, + QueryFiles: queries, + OutputDir: outDir, + Acronyms: acros, + TypeOverrides: typeOverrides, + LogLevel: logLvl, + InlineParamCount: *inlineParamCount, }) if err != nil { return err diff --git a/example/acceptance_test.go b/example/acceptance_test.go index 7e9f1069..f45842da 100644 --- a/example/acceptance_test.go +++ b/example/acceptance_test.go @@ -1,5 +1,4 @@ //go:build acceptance_test -// +build acceptance_test package example @@ -140,6 +139,42 @@ func TestExamples(t *testing.T) { "--go-type", "_int4=[]int", }, }, + { + name: "example/inline_param_count/inline0", + args: []string{ + "--schema-glob", "example/inline_param_count/schema.sql", + "--query-glob", "example/inline_param_count/query.sql", + "--output-dir", "example/inline_param_count/inline0", + "--inline-param-count", "0", + }, + }, + { + name: "example/inline_param_count/inline1", + args: []string{ + "--schema-glob", "example/inline_param_count/schema.sql", + "--query-glob", "example/inline_param_count/query.sql", + "--output-dir", "example/inline_param_count/inline1", + "--inline-param-count", "1", + }, + }, + { + name: "example/inline_param_count/inline2", + args: []string{ + "--schema-glob", "example/inline_param_count/schema.sql", + "--query-glob", "example/inline_param_count/query.sql", + "--output-dir", "example/inline_param_count/inline2", + "--inline-param-count", "2", + }, + }, + { + name: "example/inline_param_count/inline3", + args: []string{ + "--schema-glob", "example/inline_param_count/schema.sql", + "--query-glob", "example/inline_param_count/query.sql", + "--output-dir", "example/inline_param_count/inline3", + "--inline-param-count", "3", + }, + }, { name: "example/ltree", args: []string{ @@ -251,7 +286,7 @@ func TestExamples(t *testing.T) { args := append(tt.args, "--postgres-connection", connStr) runPggen(t, pggen, args...) 
if !*update { - assertNoDiff(t) + assertNoGitDiff(t) } }) } @@ -263,7 +298,7 @@ func runPggen(t *testing.T, pggen string, args ...string) string { Args: append([]string{pggen, "gen", "go"}, args...), Dir: projDir, } - t.Log("running pggen") + t.Logf("running pggen: %s", cmd.String()) output, err := cmd.CombinedOutput() if err != nil { t.Log("pggen output:\n" + string(bytes.TrimSpace(output))) @@ -300,7 +335,7 @@ var ( gitBinOnce = &sync.Once{} ) -func assertNoDiff(t *testing.T) { +func assertNoGitDiff(t *testing.T) { gitBinOnce.Do(func() { gitBin, gitBinErr = exec.LookPath("git") if gitBinErr != nil { diff --git a/example/author/codegen_test.go b/example/author/codegen_test.go index 43a66a8b..55d47855 100644 --- a/example/author/codegen_test.go +++ b/example/author/codegen_test.go @@ -16,11 +16,12 @@ func TestGenerate_Go_Example_Author(t *testing.T) { tmpDir := t.TempDir() err := pggen.Generate( pggen.GenerateOptions{ - ConnString: conn.Config().ConnString(), - QueryFiles: []string{"query.sql"}, - OutputDir: tmpDir, - GoPackage: "author", - Language: pggen.LangGo, + ConnString: conn.Config().ConnString(), + QueryFiles: []string{"query.sql"}, + OutputDir: tmpDir, + GoPackage: "author", + Language: pggen.LangGo, + InlineParamCount: 2, }) if err != nil { t.Fatalf("Generate() example/author: %s", err) diff --git a/example/complex_params/codegen_test.go b/example/complex_params/codegen_test.go index ce779907..b4e6e7f2 100644 --- a/example/complex_params/codegen_test.go +++ b/example/complex_params/codegen_test.go @@ -16,11 +16,12 @@ func TestGenerate_Go_Example_ComplexParams(t *testing.T) { tmpDir := t.TempDir() err := pggen.Generate( pggen.GenerateOptions{ - ConnString: conn.Config().ConnString(), - QueryFiles: []string{"query.sql"}, - OutputDir: tmpDir, - GoPackage: "complex_params", - Language: pggen.LangGo, + ConnString: conn.Config().ConnString(), + QueryFiles: []string{"query.sql"}, + OutputDir: tmpDir, + GoPackage: "complex_params", + Language: pggen.LangGo, + InlineParamCount: 2, TypeOverrides: map[string]string{ "int4": "int", "text": "string", diff --git a/example/composite/codegen_test.go b/example/composite/codegen_test.go index 879035e8..68334b1f 100644 --- a/example/composite/codegen_test.go +++ b/example/composite/codegen_test.go @@ -16,11 +16,12 @@ func TestGenerate_Go_Example_Composite(t *testing.T) { tmpDir := t.TempDir() err := pggen.Generate( pggen.GenerateOptions{ - ConnString: conn.Config().ConnString(), - QueryFiles: []string{"query.sql"}, - OutputDir: tmpDir, - GoPackage: "composite", - Language: pggen.LangGo, + ConnString: conn.Config().ConnString(), + QueryFiles: []string{"query.sql"}, + OutputDir: tmpDir, + GoPackage: "composite", + Language: pggen.LangGo, + InlineParamCount: 2, TypeOverrides: map[string]string{ "_bool": "[]bool", "bool": "bool", diff --git a/example/custom_types/codegen_test.go b/example/custom_types/codegen_test.go index 4b28b56f..e80b54af 100644 --- a/example/custom_types/codegen_test.go +++ b/example/custom_types/codegen_test.go @@ -16,11 +16,12 @@ func TestGenerate_Go_Example_CustomTypes(t *testing.T) { tmpDir := t.TempDir() err := pggen.Generate( pggen.GenerateOptions{ - ConnString: conn.Config().ConnString(), - QueryFiles: []string{"query.sql"}, - OutputDir: tmpDir, - GoPackage: "custom_types", - Language: pggen.LangGo, + ConnString: conn.Config().ConnString(), + QueryFiles: []string{"query.sql"}, + OutputDir: tmpDir, + GoPackage: "custom_types", + Language: pggen.LangGo, + InlineParamCount: 2, TypeOverrides: map[string]string{ "text": 
"github.com/jschaf/pggen/example/custom_types/mytype.String", "int8": "github.com/jschaf/pggen/example/custom_types.CustomInt", diff --git a/example/device/codegen_test.go b/example/device/codegen_test.go index e6e015fa..a5582b4b 100644 --- a/example/device/codegen_test.go +++ b/example/device/codegen_test.go @@ -16,11 +16,12 @@ func TestGenerate_Go_Example_Device(t *testing.T) { tmpDir := t.TempDir() err := pggen.Generate( pggen.GenerateOptions{ - ConnString: conn.Config().ConnString(), - QueryFiles: []string{"query.sql"}, - OutputDir: tmpDir, - GoPackage: "device", - Language: pggen.LangGo, + ConnString: conn.Config().ConnString(), + QueryFiles: []string{"query.sql"}, + OutputDir: tmpDir, + GoPackage: "device", + Language: pggen.LangGo, + InlineParamCount: 2, }) if err != nil { t.Fatalf("Generate() example/device: %s", err) diff --git a/example/domain/codegen_test.go b/example/domain/codegen_test.go index bcc90aef..b5565673 100644 --- a/example/domain/codegen_test.go +++ b/example/domain/codegen_test.go @@ -16,11 +16,12 @@ func TestGenerate_Go_Example_Domain(t *testing.T) { tmpDir := t.TempDir() err := pggen.Generate( pggen.GenerateOptions{ - ConnString: conn.Config().ConnString(), - QueryFiles: []string{"query.sql"}, - OutputDir: tmpDir, - GoPackage: "domain", - Language: pggen.LangGo, + ConnString: conn.Config().ConnString(), + QueryFiles: []string{"query.sql"}, + OutputDir: tmpDir, + GoPackage: "domain", + Language: pggen.LangGo, + InlineParamCount: 2, }) if err != nil { t.Fatalf("Generate() example/domain: %s", err) diff --git a/example/enums/codegen_test.go b/example/enums/codegen_test.go index 868690c4..43994fb4 100644 --- a/example/enums/codegen_test.go +++ b/example/enums/codegen_test.go @@ -16,11 +16,12 @@ func TestGenerate_Go_Example_Enums(t *testing.T) { tmpDir := t.TempDir() err := pggen.Generate( pggen.GenerateOptions{ - ConnString: conn.Config().ConnString(), - QueryFiles: []string{"query.sql"}, - OutputDir: tmpDir, - GoPackage: "enums", - Language: pggen.LangGo, + ConnString: conn.Config().ConnString(), + QueryFiles: []string{"query.sql"}, + OutputDir: tmpDir, + GoPackage: "enums", + Language: pggen.LangGo, + InlineParamCount: 2, }) if err != nil { t.Fatalf("Generate() example/enums: %s", err) diff --git a/example/erp/order/codegen_test.go b/example/erp/order/codegen_test.go index deb18b18..e7d09153 100644 --- a/example/erp/order/codegen_test.go +++ b/example/erp/order/codegen_test.go @@ -9,7 +9,7 @@ import ( "testing" ) -func TestGenerate_Go_Example_Order(t *testing.T) { +func TestGenerate_Go_Example_ERP_Order(t *testing.T) { conn, cleanupFunc := pgtest.NewPostgresSchema(t, []string{ "../01_schema.sql", "../02_schema.sql", @@ -24,11 +24,12 @@ func TestGenerate_Go_Example_Order(t *testing.T) { "customer.sql", "price.sql", }, - OutputDir: tmpDir, - GoPackage: "order", - Language: pggen.LangGo, - Acronyms: map[string]string{"mrr": "MRR"}, - TypeOverrides: map[string]string{"tenant_id": "int"}, + OutputDir: tmpDir, + GoPackage: "order", + Language: pggen.LangGo, + InlineParamCount: 2, + Acronyms: map[string]string{"mrr": "MRR"}, + TypeOverrides: map[string]string{"tenant_id": "int"}, }) if err != nil { t.Fatalf("Generate() example/erp/order: %s", err) diff --git a/example/function/codegen_test.go b/example/function/codegen_test.go index 1ca77016..313bb3f1 100644 --- a/example/function/codegen_test.go +++ b/example/function/codegen_test.go @@ -16,11 +16,12 @@ func TestGenerate_Go_Example_Function(t *testing.T) { tmpDir := t.TempDir() err := pggen.Generate( 
pggen.GenerateOptions{ - ConnString: conn.Config().ConnString(), - QueryFiles: []string{"query.sql"}, - OutputDir: tmpDir, - GoPackage: "function", - Language: pggen.LangGo, + ConnString: conn.Config().ConnString(), + QueryFiles: []string{"query.sql"}, + OutputDir: tmpDir, + GoPackage: "function", + Language: pggen.LangGo, + InlineParamCount: 2, }) if err != nil { t.Fatalf("Generate() example/function: %s", err) diff --git a/example/go_pointer_types/codegen_test.go b/example/go_pointer_types/codegen_test.go index 5a07ae20..775697de 100644 --- a/example/go_pointer_types/codegen_test.go +++ b/example/go_pointer_types/codegen_test.go @@ -16,11 +16,12 @@ func TestGenerate_Go_Example_GoPointerTypes(t *testing.T) { tmpDir := t.TempDir() err := pggen.Generate( pggen.GenerateOptions{ - ConnString: conn.Config().ConnString(), - QueryFiles: []string{"query.sql"}, - OutputDir: tmpDir, - GoPackage: "go_pointer_types", - Language: pggen.LangGo, + ConnString: conn.Config().ConnString(), + QueryFiles: []string{"query.sql"}, + OutputDir: tmpDir, + GoPackage: "go_pointer_types", + Language: pggen.LangGo, + InlineParamCount: 2, TypeOverrides: map[string]string{ "int4": "*int", "_int4": "[]int", diff --git a/example/inline_param_count/codegen_test.go b/example/inline_param_count/codegen_test.go new file mode 100644 index 00000000..1599ebd6 --- /dev/null +++ b/example/inline_param_count/codegen_test.go @@ -0,0 +1,90 @@ +package author + +import ( + "github.com/jschaf/pggen" + "github.com/jschaf/pggen/internal/pgtest" + "github.com/stretchr/testify/assert" + "os" + "path/filepath" + "testing" +) + +func TestGenerate_Go_Example_InlineParamCount(t *testing.T) { + conn, cleanupFunc := pgtest.NewPostgresSchema(t, []string{"schema.sql"}) + defer cleanupFunc() + + tests := []struct { + name string + opts pggen.GenerateOptions + wantQueryPath string + }{ + { + name: "inline0", + opts: pggen.GenerateOptions{ + ConnString: conn.Config().ConnString(), + QueryFiles: []string{"query.sql"}, + GoPackage: "inline0", + Language: pggen.LangGo, + InlineParamCount: 0, + }, + wantQueryPath: "inline0/query.sql.go", + }, + { + name: "inline1", + opts: pggen.GenerateOptions{ + ConnString: conn.Config().ConnString(), + QueryFiles: []string{"query.sql"}, + GoPackage: "inline1", + Language: pggen.LangGo, + InlineParamCount: 1, + }, + wantQueryPath: "inline1/query.sql.go", + }, + { + name: "inline2", + opts: pggen.GenerateOptions{ + ConnString: conn.Config().ConnString(), + QueryFiles: []string{"query.sql"}, + GoPackage: "inline2", + Language: pggen.LangGo, + InlineParamCount: 2, + }, + wantQueryPath: "inline2/query.sql.go", + }, + { + name: "inline3", + opts: pggen.GenerateOptions{ + ConnString: conn.Config().ConnString(), + QueryFiles: []string{"query.sql"}, + GoPackage: "inline3", + Language: pggen.LangGo, + InlineParamCount: 3, + }, + wantQueryPath: "inline3/query.sql.go", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tmpDir := t.TempDir() + tt.opts.OutputDir = tmpDir + err := pggen.Generate(tt.opts) + if err != nil { + t.Fatalf("Generate() example/author %s: %s", tt.name, err.Error()) + } + + gotQueryFile := filepath.Join(tmpDir, "query.sql.go") + assert.FileExists(t, gotQueryFile, "Generate() should emit query.sql.go") + wantQueries, err := os.ReadFile(tt.wantQueryPath) + if err != nil { + t.Fatalf("read wanted query.go.sql: %s", err) + } + gotQueries, err := os.ReadFile(gotQueryFile) + if err != nil { + t.Fatalf("read generated query.go.sql: %s", err) + } + assert.Equalf(t, string(wantQueries), 
string(gotQueries), + "Got file %s; does not match contents of %s", + gotQueryFile, tt.wantQueryPath) + }) + } +} diff --git a/example/inline_param_count/inline0/query.sql.go b/example/inline_param_count/inline0/query.sql.go new file mode 100644 index 00000000..88319cb8 --- /dev/null +++ b/example/inline_param_count/inline0/query.sql.go @@ -0,0 +1,339 @@ +// Code generated by pggen. DO NOT EDIT. + +package inline0 + +import ( + "context" + "fmt" + "github.com/jackc/pgconn" + "github.com/jackc/pgtype" + "github.com/jackc/pgx/v4" +) + +// Querier is a typesafe Go interface backed by SQL queries. +// +// Methods ending with Batch enqueue a query to run later in a pgx.Batch. After +// calling SendBatch on pgx.Conn, pgxpool.Pool, or pgx.Tx, use the Scan methods +// to parse the results. +type Querier interface { + // CountAuthors returns the number of authors (zero params). + CountAuthors(ctx context.Context) (*int, error) + // CountAuthorsBatch enqueues a CountAuthors query into batch to be executed + // later by the batch. + CountAuthorsBatch(batch genericBatch) + // CountAuthorsScan scans the result of an executed CountAuthorsBatch query. + CountAuthorsScan(results pgx.BatchResults) (*int, error) + + // FindAuthorById finds one (or zero) authors by ID (one param). + FindAuthorByID(ctx context.Context, params FindAuthorByIDParams) (FindAuthorByIDRow, error) + // FindAuthorByIDBatch enqueues a FindAuthorByID query into batch to be executed + // later by the batch. + FindAuthorByIDBatch(batch genericBatch, params FindAuthorByIDParams) + // FindAuthorByIDScan scans the result of an executed FindAuthorByIDBatch query. + FindAuthorByIDScan(results pgx.BatchResults) (FindAuthorByIDRow, error) + + // InsertAuthor inserts an author by name and returns the ID (two params). + InsertAuthor(ctx context.Context, params InsertAuthorParams) (int32, error) + // InsertAuthorBatch enqueues a InsertAuthor query into batch to be executed + // later by the batch. + InsertAuthorBatch(batch genericBatch, params InsertAuthorParams) + // InsertAuthorScan scans the result of an executed InsertAuthorBatch query. + InsertAuthorScan(results pgx.BatchResults) (int32, error) + + // DeleteAuthorsByFullName deletes authors by the full name (three params). + DeleteAuthorsByFullName(ctx context.Context, params DeleteAuthorsByFullNameParams) (pgconn.CommandTag, error) + // DeleteAuthorsByFullNameBatch enqueues a DeleteAuthorsByFullName query into batch to be executed + // later by the batch. + DeleteAuthorsByFullNameBatch(batch genericBatch, params DeleteAuthorsByFullNameParams) + // DeleteAuthorsByFullNameScan scans the result of an executed DeleteAuthorsByFullNameBatch query. + DeleteAuthorsByFullNameScan(results pgx.BatchResults) (pgconn.CommandTag, error) +} + +type DBQuerier struct { + conn genericConn // underlying Postgres transport to use + types *typeResolver // resolve types by name +} + +var _ Querier = &DBQuerier{} + +// genericConn is a connection to a Postgres database. This is usually backed by +// *pgx.Conn, pgx.Tx, or *pgxpool.Pool. +type genericConn interface { + // Query executes sql with args. If there is an error the returned Rows will + // be returned in an error state. So it is allowed to ignore the error + // returned from Query and handle it in Rows. + Query(ctx context.Context, sql string, args ...interface{}) (pgx.Rows, error) + + // QueryRow is a convenience wrapper over Query. Any error that occurs while + // querying is deferred until calling Scan on the returned Row. 
That Row will + // error with pgx.ErrNoRows if no rows are returned. + QueryRow(ctx context.Context, sql string, args ...interface{}) pgx.Row + + // Exec executes sql. sql can be either a prepared statement name or an SQL + // string. arguments should be referenced positionally from the sql string + // as $1, $2, etc. + Exec(ctx context.Context, sql string, arguments ...interface{}) (pgconn.CommandTag, error) +} + +// genericBatch batches queries to send in a single network request to a +// Postgres server. This is usually backed by *pgx.Batch. +type genericBatch interface { + // Queue queues a query to batch b. query can be an SQL query or the name of a + // prepared statement. See Queue on *pgx.Batch. + Queue(query string, arguments ...interface{}) +} + +// NewQuerier creates a DBQuerier that implements Querier. conn is typically +// *pgx.Conn, pgx.Tx, or *pgxpool.Pool. +func NewQuerier(conn genericConn) *DBQuerier { + return NewQuerierConfig(conn, QuerierConfig{}) +} + +type QuerierConfig struct { + // DataTypes contains pgtype.Value to use for encoding and decoding instead + // of pggen-generated pgtype.ValueTranscoder. + // + // If OIDs are available for an input parameter type and all of its + // transitive dependencies, pggen will use the binary encoding format for + // the input parameter. + DataTypes []pgtype.DataType +} + +// NewQuerierConfig creates a DBQuerier that implements Querier with the given +// config. conn is typically *pgx.Conn, pgx.Tx, or *pgxpool.Pool. +func NewQuerierConfig(conn genericConn, cfg QuerierConfig) *DBQuerier { + return &DBQuerier{conn: conn, types: newTypeResolver(cfg.DataTypes)} +} + +// WithTx creates a new DBQuerier that uses the transaction to run all queries. +func (q *DBQuerier) WithTx(tx pgx.Tx) (*DBQuerier, error) { + return &DBQuerier{conn: tx}, nil +} + +// preparer is any Postgres connection transport that provides a way to prepare +// a statement, most commonly *pgx.Conn. +type preparer interface { + Prepare(ctx context.Context, name, sql string) (sd *pgconn.StatementDescription, err error) +} + +// PrepareAllQueries executes a PREPARE statement for all pggen generated SQL +// queries in querier files. Typical usage is as the AfterConnect callback +// for pgxpool.Config +// +// pgx will use the prepared statement if available. Calling PrepareAllQueries +// is an optional optimization to avoid a network round-trip the first time pgx +// runs a query if pgx statement caching is enabled. +func PrepareAllQueries(ctx context.Context, p preparer) error { + if _, err := p.Prepare(ctx, countAuthorsSQL, countAuthorsSQL); err != nil { + return fmt.Errorf("prepare query 'CountAuthors': %w", err) + } + if _, err := p.Prepare(ctx, findAuthorByIDSQL, findAuthorByIDSQL); err != nil { + return fmt.Errorf("prepare query 'FindAuthorByID': %w", err) + } + if _, err := p.Prepare(ctx, insertAuthorSQL, insertAuthorSQL); err != nil { + return fmt.Errorf("prepare query 'InsertAuthor': %w", err) + } + if _, err := p.Prepare(ctx, deleteAuthorsByFullNameSQL, deleteAuthorsByFullNameSQL); err != nil { + return fmt.Errorf("prepare query 'DeleteAuthorsByFullName': %w", err) + } + return nil +} + +// typeResolver looks up the pgtype.ValueTranscoder by Postgres type name. 
+type typeResolver struct { + connInfo *pgtype.ConnInfo // types by Postgres type name +} + +func newTypeResolver(types []pgtype.DataType) *typeResolver { + ci := pgtype.NewConnInfo() + for _, typ := range types { + if txt, ok := typ.Value.(textPreferrer); ok && typ.OID != unknownOID { + typ.Value = txt.ValueTranscoder + } + ci.RegisterDataType(typ) + } + return &typeResolver{connInfo: ci} +} + +// findValue find the OID, and pgtype.ValueTranscoder for a Postgres type name. +func (tr *typeResolver) findValue(name string) (uint32, pgtype.ValueTranscoder, bool) { + typ, ok := tr.connInfo.DataTypeForName(name) + if !ok { + return 0, nil, false + } + v := pgtype.NewValue(typ.Value) + return typ.OID, v.(pgtype.ValueTranscoder), true +} + +// setValue sets the value of a ValueTranscoder to a value that should always +// work and panics if it fails. +func (tr *typeResolver) setValue(vt pgtype.ValueTranscoder, val interface{}) pgtype.ValueTranscoder { + if err := vt.Set(val); err != nil { + panic(fmt.Sprintf("set ValueTranscoder %T to %+v: %s", vt, val, err)) + } + return vt +} + +const countAuthorsSQL = `SELECT count(*) FROM author;` + +// CountAuthors implements Querier.CountAuthors. +func (q *DBQuerier) CountAuthors(ctx context.Context) (*int, error) { + ctx = context.WithValue(ctx, "pggen_query_name", "CountAuthors") + row := q.conn.QueryRow(ctx, countAuthorsSQL) + var item *int + if err := row.Scan(&item); err != nil { + return item, fmt.Errorf("query CountAuthors: %w", err) + } + return item, nil +} + +// CountAuthorsBatch implements Querier.CountAuthorsBatch. +func (q *DBQuerier) CountAuthorsBatch(batch genericBatch) { + batch.Queue(countAuthorsSQL) +} + +// CountAuthorsScan implements Querier.CountAuthorsScan. +func (q *DBQuerier) CountAuthorsScan(results pgx.BatchResults) (*int, error) { + row := results.QueryRow() + var item *int + if err := row.Scan(&item); err != nil { + return item, fmt.Errorf("scan CountAuthorsBatch row: %w", err) + } + return item, nil +} + +const findAuthorByIDSQL = `SELECT * FROM author WHERE author_id = $1;` + +type FindAuthorByIDParams struct { + AuthorID int32 +} + +type FindAuthorByIDRow struct { + AuthorID int32 `json:"author_id"` + FirstName string `json:"first_name"` + LastName string `json:"last_name"` + Suffix *string `json:"suffix"` +} + +// FindAuthorByID implements Querier.FindAuthorByID. +func (q *DBQuerier) FindAuthorByID(ctx context.Context, params FindAuthorByIDParams) (FindAuthorByIDRow, error) { + ctx = context.WithValue(ctx, "pggen_query_name", "FindAuthorByID") + row := q.conn.QueryRow(ctx, findAuthorByIDSQL, params.AuthorID) + var item FindAuthorByIDRow + if err := row.Scan(&item.AuthorID, &item.FirstName, &item.LastName, &item.Suffix); err != nil { + return item, fmt.Errorf("query FindAuthorByID: %w", err) + } + return item, nil +} + +// FindAuthorByIDBatch implements Querier.FindAuthorByIDBatch. +func (q *DBQuerier) FindAuthorByIDBatch(batch genericBatch, params FindAuthorByIDParams) { + batch.Queue(findAuthorByIDSQL, params.AuthorID) +} + +// FindAuthorByIDScan implements Querier.FindAuthorByIDScan. 
+func (q *DBQuerier) FindAuthorByIDScan(results pgx.BatchResults) (FindAuthorByIDRow, error) { + row := results.QueryRow() + var item FindAuthorByIDRow + if err := row.Scan(&item.AuthorID, &item.FirstName, &item.LastName, &item.Suffix); err != nil { + return item, fmt.Errorf("scan FindAuthorByIDBatch row: %w", err) + } + return item, nil +} + +const insertAuthorSQL = `INSERT INTO author (first_name, last_name) +VALUES ($1, $2) +RETURNING author_id;` + +type InsertAuthorParams struct { + FirstName string + LastName string +} + +// InsertAuthor implements Querier.InsertAuthor. +func (q *DBQuerier) InsertAuthor(ctx context.Context, params InsertAuthorParams) (int32, error) { + ctx = context.WithValue(ctx, "pggen_query_name", "InsertAuthor") + row := q.conn.QueryRow(ctx, insertAuthorSQL, params.FirstName, params.LastName) + var item int32 + if err := row.Scan(&item); err != nil { + return item, fmt.Errorf("query InsertAuthor: %w", err) + } + return item, nil +} + +// InsertAuthorBatch implements Querier.InsertAuthorBatch. +func (q *DBQuerier) InsertAuthorBatch(batch genericBatch, params InsertAuthorParams) { + batch.Queue(insertAuthorSQL, params.FirstName, params.LastName) +} + +// InsertAuthorScan implements Querier.InsertAuthorScan. +func (q *DBQuerier) InsertAuthorScan(results pgx.BatchResults) (int32, error) { + row := results.QueryRow() + var item int32 + if err := row.Scan(&item); err != nil { + return item, fmt.Errorf("scan InsertAuthorBatch row: %w", err) + } + return item, nil +} + +const deleteAuthorsByFullNameSQL = `DELETE +FROM author +WHERE first_name = $1 + AND last_name = $2 + AND CASE WHEN $3 = '' THEN suffix IS NULL ELSE suffix = $3 END;` + +type DeleteAuthorsByFullNameParams struct { + FirstName string + LastName string + Suffix string +} + +// DeleteAuthorsByFullName implements Querier.DeleteAuthorsByFullName. +func (q *DBQuerier) DeleteAuthorsByFullName(ctx context.Context, params DeleteAuthorsByFullNameParams) (pgconn.CommandTag, error) { + ctx = context.WithValue(ctx, "pggen_query_name", "DeleteAuthorsByFullName") + cmdTag, err := q.conn.Exec(ctx, deleteAuthorsByFullNameSQL, params.FirstName, params.LastName, params.Suffix) + if err != nil { + return cmdTag, fmt.Errorf("exec query DeleteAuthorsByFullName: %w", err) + } + return cmdTag, err +} + +// DeleteAuthorsByFullNameBatch implements Querier.DeleteAuthorsByFullNameBatch. +func (q *DBQuerier) DeleteAuthorsByFullNameBatch(batch genericBatch, params DeleteAuthorsByFullNameParams) { + batch.Queue(deleteAuthorsByFullNameSQL, params.FirstName, params.LastName, params.Suffix) +} + +// DeleteAuthorsByFullNameScan implements Querier.DeleteAuthorsByFullNameScan. +func (q *DBQuerier) DeleteAuthorsByFullNameScan(results pgx.BatchResults) (pgconn.CommandTag, error) { + cmdTag, err := results.Exec() + if err != nil { + return cmdTag, fmt.Errorf("exec DeleteAuthorsByFullNameBatch: %w", err) + } + return cmdTag, err +} + +// textPreferrer wraps a pgtype.ValueTranscoder and sets the preferred encoding +// format to text instead binary (the default). pggen uses the text format +// when the OID is unknownOID because the binary format requires the OID. +// Typically occurs if the results from QueryAllDataTypes aren't passed to +// NewQuerierConfig. +type textPreferrer struct { + pgtype.ValueTranscoder + typeName string +} + +// PreferredParamFormat implements pgtype.ParamFormatPreferrer. 
+func (t textPreferrer) PreferredParamFormat() int16 { return pgtype.TextFormatCode } + +func (t textPreferrer) NewTypeValue() pgtype.Value { + return textPreferrer{ValueTranscoder: pgtype.NewValue(t.ValueTranscoder).(pgtype.ValueTranscoder), typeName: t.typeName} +} + +func (t textPreferrer) TypeName() string { + return t.typeName +} + +// unknownOID means we don't know the OID for a type. This is okay for decoding +// because pgx call DecodeText or DecodeBinary without requiring the OID. For +// encoding parameters, pggen uses textPreferrer if the OID is unknown. +const unknownOID = 0 diff --git a/example/inline_param_count/inline0/query.sql_test.go b/example/inline_param_count/inline0/query.sql_test.go new file mode 100644 index 00000000..9c8e5181 --- /dev/null +++ b/example/inline_param_count/inline0/query.sql_test.go @@ -0,0 +1,91 @@ +package inline0 + +import ( + "context" + "errors" + "github.com/jschaf/pggen/internal/errs" + "github.com/stretchr/testify/require" + "testing" + + "github.com/jackc/pgx/v4" + "github.com/jschaf/pggen/internal/pgtest" + "github.com/stretchr/testify/assert" +) + +func TestNewQuerier_FindAuthorByID(t *testing.T) { + conn, cleanup := pgtest.NewPostgresSchema(t, []string{"../schema.sql"}) + defer cleanup() + + q := NewQuerier(conn) + adamsID := insertAuthor(t, q, "john", "adams") + insertAuthor(t, q, "george", "washington") + + t.Run("CountAuthors two", func(t *testing.T) { + got, err := q.CountAuthors(context.Background()) + require.NoError(t, err) + assert.Equal(t, 2, *got) + }) + + t.Run("FindAuthorByID", func(t *testing.T) { + authorByID, err := q.FindAuthorByID(context.Background(), FindAuthorByIDParams{AuthorID: adamsID}) + require.NoError(t, err) + assert.Equal(t, FindAuthorByIDRow{ + AuthorID: adamsID, + FirstName: "john", + LastName: "adams", + Suffix: nil, + }, authorByID) + }) + + t.Run("FindAuthorByID - none-exists", func(t *testing.T) { + missingAuthorByID, err := q.FindAuthorByID(context.Background(), FindAuthorByIDParams{AuthorID: 888}) + require.Error(t, err, "expected error when finding author ID that doesn't match") + assert.Zero(t, missingAuthorByID, "expected zero value when error") + if !errors.Is(err, pgx.ErrNoRows) { + t.Fatalf("expected no rows error to wrap pgx.ErrNoRows; got %s", err) + } + }) + + t.Run("FindAuthorByIDBatch", func(t *testing.T) { + batch := &pgx.Batch{} + q.FindAuthorByIDBatch(batch, FindAuthorByIDParams{AuthorID: adamsID}) + results := conn.SendBatch(context.Background(), batch) + defer errs.CaptureT(t, results.Close, "close batch results") + authors, err := q.FindAuthorByIDScan(results) + require.NoError(t, err) + assert.Equal(t, FindAuthorByIDRow{ + AuthorID: adamsID, + FirstName: "john", + LastName: "adams", + Suffix: nil, + }, authors) + }) +} + +func TestNewQuerier_DeleteAuthorsByFullName(t *testing.T) { + conn, cleanup := pgtest.NewPostgresSchema(t, []string{"../schema.sql"}) + defer cleanup() + q := NewQuerier(conn) + insertAuthor(t, q, "george", "washington") + + t.Run("DeleteAuthorsByFullName", func(t *testing.T) { + tag, err := q.DeleteAuthorsByFullName(context.Background(), DeleteAuthorsByFullNameParams{ + FirstName: "george", + LastName: "washington", + Suffix: "", + }) + require.NoError(t, err) + assert.Truef(t, tag.Delete(), "expected delete tag; got %s", tag.String()) + assert.Equal(t, int64(1), tag.RowsAffected()) + }) +} + +func insertAuthor(t *testing.T, q *DBQuerier, first, last string) int32 { + t.Helper() + authorID, err := q.InsertAuthor(context.Background(), InsertAuthorParams{ + FirstName: 
first, + LastName: last, + }) + require.NoError(t, err, "insert author") + return authorID +} diff --git a/example/inline_param_count/inline1/query.sql.go b/example/inline_param_count/inline1/query.sql.go new file mode 100644 index 00000000..f2631191 --- /dev/null +++ b/example/inline_param_count/inline1/query.sql.go @@ -0,0 +1,335 @@ +// Code generated by pggen. DO NOT EDIT. + +package inline1 + +import ( + "context" + "fmt" + "github.com/jackc/pgconn" + "github.com/jackc/pgtype" + "github.com/jackc/pgx/v4" +) + +// Querier is a typesafe Go interface backed by SQL queries. +// +// Methods ending with Batch enqueue a query to run later in a pgx.Batch. After +// calling SendBatch on pgx.Conn, pgxpool.Pool, or pgx.Tx, use the Scan methods +// to parse the results. +type Querier interface { + // CountAuthors returns the number of authors (zero params). + CountAuthors(ctx context.Context) (*int, error) + // CountAuthorsBatch enqueues a CountAuthors query into batch to be executed + // later by the batch. + CountAuthorsBatch(batch genericBatch) + // CountAuthorsScan scans the result of an executed CountAuthorsBatch query. + CountAuthorsScan(results pgx.BatchResults) (*int, error) + + // FindAuthorById finds one (or zero) authors by ID (one param). + FindAuthorByID(ctx context.Context, authorID int32) (FindAuthorByIDRow, error) + // FindAuthorByIDBatch enqueues a FindAuthorByID query into batch to be executed + // later by the batch. + FindAuthorByIDBatch(batch genericBatch, authorID int32) + // FindAuthorByIDScan scans the result of an executed FindAuthorByIDBatch query. + FindAuthorByIDScan(results pgx.BatchResults) (FindAuthorByIDRow, error) + + // InsertAuthor inserts an author by name and returns the ID (two params). + InsertAuthor(ctx context.Context, params InsertAuthorParams) (int32, error) + // InsertAuthorBatch enqueues a InsertAuthor query into batch to be executed + // later by the batch. + InsertAuthorBatch(batch genericBatch, params InsertAuthorParams) + // InsertAuthorScan scans the result of an executed InsertAuthorBatch query. + InsertAuthorScan(results pgx.BatchResults) (int32, error) + + // DeleteAuthorsByFullName deletes authors by the full name (three params). + DeleteAuthorsByFullName(ctx context.Context, params DeleteAuthorsByFullNameParams) (pgconn.CommandTag, error) + // DeleteAuthorsByFullNameBatch enqueues a DeleteAuthorsByFullName query into batch to be executed + // later by the batch. + DeleteAuthorsByFullNameBatch(batch genericBatch, params DeleteAuthorsByFullNameParams) + // DeleteAuthorsByFullNameScan scans the result of an executed DeleteAuthorsByFullNameBatch query. + DeleteAuthorsByFullNameScan(results pgx.BatchResults) (pgconn.CommandTag, error) +} + +type DBQuerier struct { + conn genericConn // underlying Postgres transport to use + types *typeResolver // resolve types by name +} + +var _ Querier = &DBQuerier{} + +// genericConn is a connection to a Postgres database. This is usually backed by +// *pgx.Conn, pgx.Tx, or *pgxpool.Pool. +type genericConn interface { + // Query executes sql with args. If there is an error the returned Rows will + // be returned in an error state. So it is allowed to ignore the error + // returned from Query and handle it in Rows. + Query(ctx context.Context, sql string, args ...interface{}) (pgx.Rows, error) + + // QueryRow is a convenience wrapper over Query. Any error that occurs while + // querying is deferred until calling Scan on the returned Row. That Row will + // error with pgx.ErrNoRows if no rows are returned. 
+ QueryRow(ctx context.Context, sql string, args ...interface{}) pgx.Row + + // Exec executes sql. sql can be either a prepared statement name or an SQL + // string. arguments should be referenced positionally from the sql string + // as $1, $2, etc. + Exec(ctx context.Context, sql string, arguments ...interface{}) (pgconn.CommandTag, error) +} + +// genericBatch batches queries to send in a single network request to a +// Postgres server. This is usually backed by *pgx.Batch. +type genericBatch interface { + // Queue queues a query to batch b. query can be an SQL query or the name of a + // prepared statement. See Queue on *pgx.Batch. + Queue(query string, arguments ...interface{}) +} + +// NewQuerier creates a DBQuerier that implements Querier. conn is typically +// *pgx.Conn, pgx.Tx, or *pgxpool.Pool. +func NewQuerier(conn genericConn) *DBQuerier { + return NewQuerierConfig(conn, QuerierConfig{}) +} + +type QuerierConfig struct { + // DataTypes contains pgtype.Value to use for encoding and decoding instead + // of pggen-generated pgtype.ValueTranscoder. + // + // If OIDs are available for an input parameter type and all of its + // transitive dependencies, pggen will use the binary encoding format for + // the input parameter. + DataTypes []pgtype.DataType +} + +// NewQuerierConfig creates a DBQuerier that implements Querier with the given +// config. conn is typically *pgx.Conn, pgx.Tx, or *pgxpool.Pool. +func NewQuerierConfig(conn genericConn, cfg QuerierConfig) *DBQuerier { + return &DBQuerier{conn: conn, types: newTypeResolver(cfg.DataTypes)} +} + +// WithTx creates a new DBQuerier that uses the transaction to run all queries. +func (q *DBQuerier) WithTx(tx pgx.Tx) (*DBQuerier, error) { + return &DBQuerier{conn: tx}, nil +} + +// preparer is any Postgres connection transport that provides a way to prepare +// a statement, most commonly *pgx.Conn. +type preparer interface { + Prepare(ctx context.Context, name, sql string) (sd *pgconn.StatementDescription, err error) +} + +// PrepareAllQueries executes a PREPARE statement for all pggen generated SQL +// queries in querier files. Typical usage is as the AfterConnect callback +// for pgxpool.Config +// +// pgx will use the prepared statement if available. Calling PrepareAllQueries +// is an optional optimization to avoid a network round-trip the first time pgx +// runs a query if pgx statement caching is enabled. +func PrepareAllQueries(ctx context.Context, p preparer) error { + if _, err := p.Prepare(ctx, countAuthorsSQL, countAuthorsSQL); err != nil { + return fmt.Errorf("prepare query 'CountAuthors': %w", err) + } + if _, err := p.Prepare(ctx, findAuthorByIDSQL, findAuthorByIDSQL); err != nil { + return fmt.Errorf("prepare query 'FindAuthorByID': %w", err) + } + if _, err := p.Prepare(ctx, insertAuthorSQL, insertAuthorSQL); err != nil { + return fmt.Errorf("prepare query 'InsertAuthor': %w", err) + } + if _, err := p.Prepare(ctx, deleteAuthorsByFullNameSQL, deleteAuthorsByFullNameSQL); err != nil { + return fmt.Errorf("prepare query 'DeleteAuthorsByFullName': %w", err) + } + return nil +} + +// typeResolver looks up the pgtype.ValueTranscoder by Postgres type name. 
+type typeResolver struct { + connInfo *pgtype.ConnInfo // types by Postgres type name +} + +func newTypeResolver(types []pgtype.DataType) *typeResolver { + ci := pgtype.NewConnInfo() + for _, typ := range types { + if txt, ok := typ.Value.(textPreferrer); ok && typ.OID != unknownOID { + typ.Value = txt.ValueTranscoder + } + ci.RegisterDataType(typ) + } + return &typeResolver{connInfo: ci} +} + +// findValue find the OID, and pgtype.ValueTranscoder for a Postgres type name. +func (tr *typeResolver) findValue(name string) (uint32, pgtype.ValueTranscoder, bool) { + typ, ok := tr.connInfo.DataTypeForName(name) + if !ok { + return 0, nil, false + } + v := pgtype.NewValue(typ.Value) + return typ.OID, v.(pgtype.ValueTranscoder), true +} + +// setValue sets the value of a ValueTranscoder to a value that should always +// work and panics if it fails. +func (tr *typeResolver) setValue(vt pgtype.ValueTranscoder, val interface{}) pgtype.ValueTranscoder { + if err := vt.Set(val); err != nil { + panic(fmt.Sprintf("set ValueTranscoder %T to %+v: %s", vt, val, err)) + } + return vt +} + +const countAuthorsSQL = `SELECT count(*) FROM author;` + +// CountAuthors implements Querier.CountAuthors. +func (q *DBQuerier) CountAuthors(ctx context.Context) (*int, error) { + ctx = context.WithValue(ctx, "pggen_query_name", "CountAuthors") + row := q.conn.QueryRow(ctx, countAuthorsSQL) + var item *int + if err := row.Scan(&item); err != nil { + return item, fmt.Errorf("query CountAuthors: %w", err) + } + return item, nil +} + +// CountAuthorsBatch implements Querier.CountAuthorsBatch. +func (q *DBQuerier) CountAuthorsBatch(batch genericBatch) { + batch.Queue(countAuthorsSQL) +} + +// CountAuthorsScan implements Querier.CountAuthorsScan. +func (q *DBQuerier) CountAuthorsScan(results pgx.BatchResults) (*int, error) { + row := results.QueryRow() + var item *int + if err := row.Scan(&item); err != nil { + return item, fmt.Errorf("scan CountAuthorsBatch row: %w", err) + } + return item, nil +} + +const findAuthorByIDSQL = `SELECT * FROM author WHERE author_id = $1;` + +type FindAuthorByIDRow struct { + AuthorID int32 `json:"author_id"` + FirstName string `json:"first_name"` + LastName string `json:"last_name"` + Suffix *string `json:"suffix"` +} + +// FindAuthorByID implements Querier.FindAuthorByID. +func (q *DBQuerier) FindAuthorByID(ctx context.Context, authorID int32) (FindAuthorByIDRow, error) { + ctx = context.WithValue(ctx, "pggen_query_name", "FindAuthorByID") + row := q.conn.QueryRow(ctx, findAuthorByIDSQL, authorID) + var item FindAuthorByIDRow + if err := row.Scan(&item.AuthorID, &item.FirstName, &item.LastName, &item.Suffix); err != nil { + return item, fmt.Errorf("query FindAuthorByID: %w", err) + } + return item, nil +} + +// FindAuthorByIDBatch implements Querier.FindAuthorByIDBatch. +func (q *DBQuerier) FindAuthorByIDBatch(batch genericBatch, authorID int32) { + batch.Queue(findAuthorByIDSQL, authorID) +} + +// FindAuthorByIDScan implements Querier.FindAuthorByIDScan. 
+func (q *DBQuerier) FindAuthorByIDScan(results pgx.BatchResults) (FindAuthorByIDRow, error) { + row := results.QueryRow() + var item FindAuthorByIDRow + if err := row.Scan(&item.AuthorID, &item.FirstName, &item.LastName, &item.Suffix); err != nil { + return item, fmt.Errorf("scan FindAuthorByIDBatch row: %w", err) + } + return item, nil +} + +const insertAuthorSQL = `INSERT INTO author (first_name, last_name) +VALUES ($1, $2) +RETURNING author_id;` + +type InsertAuthorParams struct { + FirstName string + LastName string +} + +// InsertAuthor implements Querier.InsertAuthor. +func (q *DBQuerier) InsertAuthor(ctx context.Context, params InsertAuthorParams) (int32, error) { + ctx = context.WithValue(ctx, "pggen_query_name", "InsertAuthor") + row := q.conn.QueryRow(ctx, insertAuthorSQL, params.FirstName, params.LastName) + var item int32 + if err := row.Scan(&item); err != nil { + return item, fmt.Errorf("query InsertAuthor: %w", err) + } + return item, nil +} + +// InsertAuthorBatch implements Querier.InsertAuthorBatch. +func (q *DBQuerier) InsertAuthorBatch(batch genericBatch, params InsertAuthorParams) { + batch.Queue(insertAuthorSQL, params.FirstName, params.LastName) +} + +// InsertAuthorScan implements Querier.InsertAuthorScan. +func (q *DBQuerier) InsertAuthorScan(results pgx.BatchResults) (int32, error) { + row := results.QueryRow() + var item int32 + if err := row.Scan(&item); err != nil { + return item, fmt.Errorf("scan InsertAuthorBatch row: %w", err) + } + return item, nil +} + +const deleteAuthorsByFullNameSQL = `DELETE +FROM author +WHERE first_name = $1 + AND last_name = $2 + AND CASE WHEN $3 = '' THEN suffix IS NULL ELSE suffix = $3 END;` + +type DeleteAuthorsByFullNameParams struct { + FirstName string + LastName string + Suffix string +} + +// DeleteAuthorsByFullName implements Querier.DeleteAuthorsByFullName. +func (q *DBQuerier) DeleteAuthorsByFullName(ctx context.Context, params DeleteAuthorsByFullNameParams) (pgconn.CommandTag, error) { + ctx = context.WithValue(ctx, "pggen_query_name", "DeleteAuthorsByFullName") + cmdTag, err := q.conn.Exec(ctx, deleteAuthorsByFullNameSQL, params.FirstName, params.LastName, params.Suffix) + if err != nil { + return cmdTag, fmt.Errorf("exec query DeleteAuthorsByFullName: %w", err) + } + return cmdTag, err +} + +// DeleteAuthorsByFullNameBatch implements Querier.DeleteAuthorsByFullNameBatch. +func (q *DBQuerier) DeleteAuthorsByFullNameBatch(batch genericBatch, params DeleteAuthorsByFullNameParams) { + batch.Queue(deleteAuthorsByFullNameSQL, params.FirstName, params.LastName, params.Suffix) +} + +// DeleteAuthorsByFullNameScan implements Querier.DeleteAuthorsByFullNameScan. +func (q *DBQuerier) DeleteAuthorsByFullNameScan(results pgx.BatchResults) (pgconn.CommandTag, error) { + cmdTag, err := results.Exec() + if err != nil { + return cmdTag, fmt.Errorf("exec DeleteAuthorsByFullNameBatch: %w", err) + } + return cmdTag, err +} + +// textPreferrer wraps a pgtype.ValueTranscoder and sets the preferred encoding +// format to text instead binary (the default). pggen uses the text format +// when the OID is unknownOID because the binary format requires the OID. +// Typically occurs if the results from QueryAllDataTypes aren't passed to +// NewQuerierConfig. +type textPreferrer struct { + pgtype.ValueTranscoder + typeName string +} + +// PreferredParamFormat implements pgtype.ParamFormatPreferrer. 
+func (t textPreferrer) PreferredParamFormat() int16 { return pgtype.TextFormatCode } + +func (t textPreferrer) NewTypeValue() pgtype.Value { + return textPreferrer{ValueTranscoder: pgtype.NewValue(t.ValueTranscoder).(pgtype.ValueTranscoder), typeName: t.typeName} +} + +func (t textPreferrer) TypeName() string { + return t.typeName +} + +// unknownOID means we don't know the OID for a type. This is okay for decoding +// because pgx call DecodeText or DecodeBinary without requiring the OID. For +// encoding parameters, pggen uses textPreferrer if the OID is unknown. +const unknownOID = 0 diff --git a/example/inline_param_count/inline1/query.sql_test.go b/example/inline_param_count/inline1/query.sql_test.go new file mode 100644 index 00000000..92792e56 --- /dev/null +++ b/example/inline_param_count/inline1/query.sql_test.go @@ -0,0 +1,91 @@ +package inline1 + +import ( + "context" + "errors" + "github.com/jschaf/pggen/internal/errs" + "github.com/stretchr/testify/require" + "testing" + + "github.com/jackc/pgx/v4" + "github.com/jschaf/pggen/internal/pgtest" + "github.com/stretchr/testify/assert" +) + +func TestNewQuerier_FindAuthorByID(t *testing.T) { + conn, cleanup := pgtest.NewPostgresSchema(t, []string{"../schema.sql"}) + defer cleanup() + + q := NewQuerier(conn) + adamsID := insertAuthor(t, q, "john", "adams") + insertAuthor(t, q, "george", "washington") + + t.Run("CountAuthors two", func(t *testing.T) { + got, err := q.CountAuthors(context.Background()) + require.NoError(t, err) + assert.Equal(t, 2, *got) + }) + + t.Run("FindAuthorByID", func(t *testing.T) { + authorByID, err := q.FindAuthorByID(context.Background(), adamsID) + require.NoError(t, err) + assert.Equal(t, FindAuthorByIDRow{ + AuthorID: adamsID, + FirstName: "john", + LastName: "adams", + Suffix: nil, + }, authorByID) + }) + + t.Run("FindAuthorByID - none-exists", func(t *testing.T) { + missingAuthorByID, err := q.FindAuthorByID(context.Background(), 888) + require.Error(t, err, "expected error when finding author ID that doesn't match") + assert.Zero(t, missingAuthorByID, "expected zero value when error") + if !errors.Is(err, pgx.ErrNoRows) { + t.Fatalf("expected no rows error to wrap pgx.ErrNoRows; got %s", err) + } + }) + + t.Run("FindAuthorByIDBatch", func(t *testing.T) { + batch := &pgx.Batch{} + q.FindAuthorByIDBatch(batch, adamsID) + results := conn.SendBatch(context.Background(), batch) + defer errs.CaptureT(t, results.Close, "close batch results") + authors, err := q.FindAuthorByIDScan(results) + require.NoError(t, err) + assert.Equal(t, FindAuthorByIDRow{ + AuthorID: adamsID, + FirstName: "john", + LastName: "adams", + Suffix: nil, + }, authors) + }) +} + +func TestNewQuerier_DeleteAuthorsByFullName(t *testing.T) { + conn, cleanup := pgtest.NewPostgresSchema(t, []string{"../schema.sql"}) + defer cleanup() + q := NewQuerier(conn) + insertAuthor(t, q, "george", "washington") + + t.Run("DeleteAuthorsByFullName", func(t *testing.T) { + tag, err := q.DeleteAuthorsByFullName(context.Background(), DeleteAuthorsByFullNameParams{ + FirstName: "george", + LastName: "washington", + Suffix: "", + }) + require.NoError(t, err) + assert.Truef(t, tag.Delete(), "expected delete tag; got %s", tag.String()) + assert.Equal(t, int64(1), tag.RowsAffected()) + }) +} + +func insertAuthor(t *testing.T, q *DBQuerier, first, last string) int32 { + t.Helper() + authorID, err := q.InsertAuthor(context.Background(), InsertAuthorParams{ + FirstName: first, + LastName: last, + }) + require.NoError(t, err, "insert author") + return authorID +} 
diff --git a/example/inline_param_count/inline2/query.sql.go b/example/inline_param_count/inline2/query.sql.go new file mode 100644 index 00000000..8005ab46 --- /dev/null +++ b/example/inline_param_count/inline2/query.sql.go @@ -0,0 +1,330 @@ +// Code generated by pggen. DO NOT EDIT. + +package inline2 + +import ( + "context" + "fmt" + "github.com/jackc/pgconn" + "github.com/jackc/pgtype" + "github.com/jackc/pgx/v4" +) + +// Querier is a typesafe Go interface backed by SQL queries. +// +// Methods ending with Batch enqueue a query to run later in a pgx.Batch. After +// calling SendBatch on pgx.Conn, pgxpool.Pool, or pgx.Tx, use the Scan methods +// to parse the results. +type Querier interface { + // CountAuthors returns the number of authors (zero params). + CountAuthors(ctx context.Context) (*int, error) + // CountAuthorsBatch enqueues a CountAuthors query into batch to be executed + // later by the batch. + CountAuthorsBatch(batch genericBatch) + // CountAuthorsScan scans the result of an executed CountAuthorsBatch query. + CountAuthorsScan(results pgx.BatchResults) (*int, error) + + // FindAuthorById finds one (or zero) authors by ID (one param). + FindAuthorByID(ctx context.Context, authorID int32) (FindAuthorByIDRow, error) + // FindAuthorByIDBatch enqueues a FindAuthorByID query into batch to be executed + // later by the batch. + FindAuthorByIDBatch(batch genericBatch, authorID int32) + // FindAuthorByIDScan scans the result of an executed FindAuthorByIDBatch query. + FindAuthorByIDScan(results pgx.BatchResults) (FindAuthorByIDRow, error) + + // InsertAuthor inserts an author by name and returns the ID (two params). + InsertAuthor(ctx context.Context, firstName string, lastName string) (int32, error) + // InsertAuthorBatch enqueues a InsertAuthor query into batch to be executed + // later by the batch. + InsertAuthorBatch(batch genericBatch, firstName string, lastName string) + // InsertAuthorScan scans the result of an executed InsertAuthorBatch query. + InsertAuthorScan(results pgx.BatchResults) (int32, error) + + // DeleteAuthorsByFullName deletes authors by the full name (three params). + DeleteAuthorsByFullName(ctx context.Context, params DeleteAuthorsByFullNameParams) (pgconn.CommandTag, error) + // DeleteAuthorsByFullNameBatch enqueues a DeleteAuthorsByFullName query into batch to be executed + // later by the batch. + DeleteAuthorsByFullNameBatch(batch genericBatch, params DeleteAuthorsByFullNameParams) + // DeleteAuthorsByFullNameScan scans the result of an executed DeleteAuthorsByFullNameBatch query. + DeleteAuthorsByFullNameScan(results pgx.BatchResults) (pgconn.CommandTag, error) +} + +type DBQuerier struct { + conn genericConn // underlying Postgres transport to use + types *typeResolver // resolve types by name +} + +var _ Querier = &DBQuerier{} + +// genericConn is a connection to a Postgres database. This is usually backed by +// *pgx.Conn, pgx.Tx, or *pgxpool.Pool. +type genericConn interface { + // Query executes sql with args. If there is an error the returned Rows will + // be returned in an error state. So it is allowed to ignore the error + // returned from Query and handle it in Rows. + Query(ctx context.Context, sql string, args ...interface{}) (pgx.Rows, error) + + // QueryRow is a convenience wrapper over Query. Any error that occurs while + // querying is deferred until calling Scan on the returned Row. That Row will + // error with pgx.ErrNoRows if no rows are returned. 
+ QueryRow(ctx context.Context, sql string, args ...interface{}) pgx.Row + + // Exec executes sql. sql can be either a prepared statement name or an SQL + // string. arguments should be referenced positionally from the sql string + // as $1, $2, etc. + Exec(ctx context.Context, sql string, arguments ...interface{}) (pgconn.CommandTag, error) +} + +// genericBatch batches queries to send in a single network request to a +// Postgres server. This is usually backed by *pgx.Batch. +type genericBatch interface { + // Queue queues a query to batch b. query can be an SQL query or the name of a + // prepared statement. See Queue on *pgx.Batch. + Queue(query string, arguments ...interface{}) +} + +// NewQuerier creates a DBQuerier that implements Querier. conn is typically +// *pgx.Conn, pgx.Tx, or *pgxpool.Pool. +func NewQuerier(conn genericConn) *DBQuerier { + return NewQuerierConfig(conn, QuerierConfig{}) +} + +type QuerierConfig struct { + // DataTypes contains pgtype.Value to use for encoding and decoding instead + // of pggen-generated pgtype.ValueTranscoder. + // + // If OIDs are available for an input parameter type and all of its + // transitive dependencies, pggen will use the binary encoding format for + // the input parameter. + DataTypes []pgtype.DataType +} + +// NewQuerierConfig creates a DBQuerier that implements Querier with the given +// config. conn is typically *pgx.Conn, pgx.Tx, or *pgxpool.Pool. +func NewQuerierConfig(conn genericConn, cfg QuerierConfig) *DBQuerier { + return &DBQuerier{conn: conn, types: newTypeResolver(cfg.DataTypes)} +} + +// WithTx creates a new DBQuerier that uses the transaction to run all queries. +func (q *DBQuerier) WithTx(tx pgx.Tx) (*DBQuerier, error) { + return &DBQuerier{conn: tx}, nil +} + +// preparer is any Postgres connection transport that provides a way to prepare +// a statement, most commonly *pgx.Conn. +type preparer interface { + Prepare(ctx context.Context, name, sql string) (sd *pgconn.StatementDescription, err error) +} + +// PrepareAllQueries executes a PREPARE statement for all pggen generated SQL +// queries in querier files. Typical usage is as the AfterConnect callback +// for pgxpool.Config +// +// pgx will use the prepared statement if available. Calling PrepareAllQueries +// is an optional optimization to avoid a network round-trip the first time pgx +// runs a query if pgx statement caching is enabled. +func PrepareAllQueries(ctx context.Context, p preparer) error { + if _, err := p.Prepare(ctx, countAuthorsSQL, countAuthorsSQL); err != nil { + return fmt.Errorf("prepare query 'CountAuthors': %w", err) + } + if _, err := p.Prepare(ctx, findAuthorByIDSQL, findAuthorByIDSQL); err != nil { + return fmt.Errorf("prepare query 'FindAuthorByID': %w", err) + } + if _, err := p.Prepare(ctx, insertAuthorSQL, insertAuthorSQL); err != nil { + return fmt.Errorf("prepare query 'InsertAuthor': %w", err) + } + if _, err := p.Prepare(ctx, deleteAuthorsByFullNameSQL, deleteAuthorsByFullNameSQL); err != nil { + return fmt.Errorf("prepare query 'DeleteAuthorsByFullName': %w", err) + } + return nil +} + +// typeResolver looks up the pgtype.ValueTranscoder by Postgres type name. 
+type typeResolver struct { + connInfo *pgtype.ConnInfo // types by Postgres type name +} + +func newTypeResolver(types []pgtype.DataType) *typeResolver { + ci := pgtype.NewConnInfo() + for _, typ := range types { + if txt, ok := typ.Value.(textPreferrer); ok && typ.OID != unknownOID { + typ.Value = txt.ValueTranscoder + } + ci.RegisterDataType(typ) + } + return &typeResolver{connInfo: ci} +} + +// findValue find the OID, and pgtype.ValueTranscoder for a Postgres type name. +func (tr *typeResolver) findValue(name string) (uint32, pgtype.ValueTranscoder, bool) { + typ, ok := tr.connInfo.DataTypeForName(name) + if !ok { + return 0, nil, false + } + v := pgtype.NewValue(typ.Value) + return typ.OID, v.(pgtype.ValueTranscoder), true +} + +// setValue sets the value of a ValueTranscoder to a value that should always +// work and panics if it fails. +func (tr *typeResolver) setValue(vt pgtype.ValueTranscoder, val interface{}) pgtype.ValueTranscoder { + if err := vt.Set(val); err != nil { + panic(fmt.Sprintf("set ValueTranscoder %T to %+v: %s", vt, val, err)) + } + return vt +} + +const countAuthorsSQL = `SELECT count(*) FROM author;` + +// CountAuthors implements Querier.CountAuthors. +func (q *DBQuerier) CountAuthors(ctx context.Context) (*int, error) { + ctx = context.WithValue(ctx, "pggen_query_name", "CountAuthors") + row := q.conn.QueryRow(ctx, countAuthorsSQL) + var item *int + if err := row.Scan(&item); err != nil { + return item, fmt.Errorf("query CountAuthors: %w", err) + } + return item, nil +} + +// CountAuthorsBatch implements Querier.CountAuthorsBatch. +func (q *DBQuerier) CountAuthorsBatch(batch genericBatch) { + batch.Queue(countAuthorsSQL) +} + +// CountAuthorsScan implements Querier.CountAuthorsScan. +func (q *DBQuerier) CountAuthorsScan(results pgx.BatchResults) (*int, error) { + row := results.QueryRow() + var item *int + if err := row.Scan(&item); err != nil { + return item, fmt.Errorf("scan CountAuthorsBatch row: %w", err) + } + return item, nil +} + +const findAuthorByIDSQL = `SELECT * FROM author WHERE author_id = $1;` + +type FindAuthorByIDRow struct { + AuthorID int32 `json:"author_id"` + FirstName string `json:"first_name"` + LastName string `json:"last_name"` + Suffix *string `json:"suffix"` +} + +// FindAuthorByID implements Querier.FindAuthorByID. +func (q *DBQuerier) FindAuthorByID(ctx context.Context, authorID int32) (FindAuthorByIDRow, error) { + ctx = context.WithValue(ctx, "pggen_query_name", "FindAuthorByID") + row := q.conn.QueryRow(ctx, findAuthorByIDSQL, authorID) + var item FindAuthorByIDRow + if err := row.Scan(&item.AuthorID, &item.FirstName, &item.LastName, &item.Suffix); err != nil { + return item, fmt.Errorf("query FindAuthorByID: %w", err) + } + return item, nil +} + +// FindAuthorByIDBatch implements Querier.FindAuthorByIDBatch. +func (q *DBQuerier) FindAuthorByIDBatch(batch genericBatch, authorID int32) { + batch.Queue(findAuthorByIDSQL, authorID) +} + +// FindAuthorByIDScan implements Querier.FindAuthorByIDScan. +func (q *DBQuerier) FindAuthorByIDScan(results pgx.BatchResults) (FindAuthorByIDRow, error) { + row := results.QueryRow() + var item FindAuthorByIDRow + if err := row.Scan(&item.AuthorID, &item.FirstName, &item.LastName, &item.Suffix); err != nil { + return item, fmt.Errorf("scan FindAuthorByIDBatch row: %w", err) + } + return item, nil +} + +const insertAuthorSQL = `INSERT INTO author (first_name, last_name) +VALUES ($1, $2) +RETURNING author_id;` + +// InsertAuthor implements Querier.InsertAuthor. 
+func (q *DBQuerier) InsertAuthor(ctx context.Context, firstName string, lastName string) (int32, error) { + ctx = context.WithValue(ctx, "pggen_query_name", "InsertAuthor") + row := q.conn.QueryRow(ctx, insertAuthorSQL, firstName, lastName) + var item int32 + if err := row.Scan(&item); err != nil { + return item, fmt.Errorf("query InsertAuthor: %w", err) + } + return item, nil +} + +// InsertAuthorBatch implements Querier.InsertAuthorBatch. +func (q *DBQuerier) InsertAuthorBatch(batch genericBatch, firstName string, lastName string) { + batch.Queue(insertAuthorSQL, firstName, lastName) +} + +// InsertAuthorScan implements Querier.InsertAuthorScan. +func (q *DBQuerier) InsertAuthorScan(results pgx.BatchResults) (int32, error) { + row := results.QueryRow() + var item int32 + if err := row.Scan(&item); err != nil { + return item, fmt.Errorf("scan InsertAuthorBatch row: %w", err) + } + return item, nil +} + +const deleteAuthorsByFullNameSQL = `DELETE +FROM author +WHERE first_name = $1 + AND last_name = $2 + AND CASE WHEN $3 = '' THEN suffix IS NULL ELSE suffix = $3 END;` + +type DeleteAuthorsByFullNameParams struct { + FirstName string + LastName string + Suffix string +} + +// DeleteAuthorsByFullName implements Querier.DeleteAuthorsByFullName. +func (q *DBQuerier) DeleteAuthorsByFullName(ctx context.Context, params DeleteAuthorsByFullNameParams) (pgconn.CommandTag, error) { + ctx = context.WithValue(ctx, "pggen_query_name", "DeleteAuthorsByFullName") + cmdTag, err := q.conn.Exec(ctx, deleteAuthorsByFullNameSQL, params.FirstName, params.LastName, params.Suffix) + if err != nil { + return cmdTag, fmt.Errorf("exec query DeleteAuthorsByFullName: %w", err) + } + return cmdTag, err +} + +// DeleteAuthorsByFullNameBatch implements Querier.DeleteAuthorsByFullNameBatch. +func (q *DBQuerier) DeleteAuthorsByFullNameBatch(batch genericBatch, params DeleteAuthorsByFullNameParams) { + batch.Queue(deleteAuthorsByFullNameSQL, params.FirstName, params.LastName, params.Suffix) +} + +// DeleteAuthorsByFullNameScan implements Querier.DeleteAuthorsByFullNameScan. +func (q *DBQuerier) DeleteAuthorsByFullNameScan(results pgx.BatchResults) (pgconn.CommandTag, error) { + cmdTag, err := results.Exec() + if err != nil { + return cmdTag, fmt.Errorf("exec DeleteAuthorsByFullNameBatch: %w", err) + } + return cmdTag, err +} + +// textPreferrer wraps a pgtype.ValueTranscoder and sets the preferred encoding +// format to text instead binary (the default). pggen uses the text format +// when the OID is unknownOID because the binary format requires the OID. +// Typically occurs if the results from QueryAllDataTypes aren't passed to +// NewQuerierConfig. +type textPreferrer struct { + pgtype.ValueTranscoder + typeName string +} + +// PreferredParamFormat implements pgtype.ParamFormatPreferrer. +func (t textPreferrer) PreferredParamFormat() int16 { return pgtype.TextFormatCode } + +func (t textPreferrer) NewTypeValue() pgtype.Value { + return textPreferrer{ValueTranscoder: pgtype.NewValue(t.ValueTranscoder).(pgtype.ValueTranscoder), typeName: t.typeName} +} + +func (t textPreferrer) TypeName() string { + return t.typeName +} + +// unknownOID means we don't know the OID for a type. This is okay for decoding +// because pgx call DecodeText or DecodeBinary without requiring the OID. For +// encoding parameters, pggen uses textPreferrer if the OID is unknown. 
+const unknownOID = 0 diff --git a/example/inline_param_count/inline2/query.sql_test.go b/example/inline_param_count/inline2/query.sql_test.go new file mode 100644 index 00000000..2920c2b3 --- /dev/null +++ b/example/inline_param_count/inline2/query.sql_test.go @@ -0,0 +1,88 @@ +package inline2 + +import ( + "context" + "errors" + "github.com/jschaf/pggen/internal/errs" + "github.com/stretchr/testify/require" + "testing" + + "github.com/jackc/pgx/v4" + "github.com/jschaf/pggen/internal/pgtest" + "github.com/stretchr/testify/assert" +) + +func TestNewQuerier_FindAuthorByID(t *testing.T) { + conn, cleanup := pgtest.NewPostgresSchema(t, []string{"../schema.sql"}) + defer cleanup() + + q := NewQuerier(conn) + adamsID := insertAuthor(t, q, "john", "adams") + insertAuthor(t, q, "george", "washington") + + t.Run("CountAuthors two", func(t *testing.T) { + got, err := q.CountAuthors(context.Background()) + require.NoError(t, err) + assert.Equal(t, 2, *got) + }) + + t.Run("FindAuthorByID", func(t *testing.T) { + authorByID, err := q.FindAuthorByID(context.Background(), adamsID) + require.NoError(t, err) + assert.Equal(t, FindAuthorByIDRow{ + AuthorID: adamsID, + FirstName: "john", + LastName: "adams", + Suffix: nil, + }, authorByID) + }) + + t.Run("FindAuthorByID - none-exists", func(t *testing.T) { + missingAuthorByID, err := q.FindAuthorByID(context.Background(), 888) + require.Error(t, err, "expected error when finding author ID that doesn't match") + assert.Zero(t, missingAuthorByID, "expected zero value when error") + if !errors.Is(err, pgx.ErrNoRows) { + t.Fatalf("expected no rows error to wrap pgx.ErrNoRows; got %s", err) + } + }) + + t.Run("FindAuthorByIDBatch", func(t *testing.T) { + batch := &pgx.Batch{} + q.FindAuthorByIDBatch(batch, adamsID) + results := conn.SendBatch(context.Background(), batch) + defer errs.CaptureT(t, results.Close, "close batch results") + authors, err := q.FindAuthorByIDScan(results) + require.NoError(t, err) + assert.Equal(t, FindAuthorByIDRow{ + AuthorID: adamsID, + FirstName: "john", + LastName: "adams", + Suffix: nil, + }, authors) + }) +} + +func TestNewQuerier_DeleteAuthorsByFullName(t *testing.T) { + conn, cleanup := pgtest.NewPostgresSchema(t, []string{"../schema.sql"}) + defer cleanup() + q := NewQuerier(conn) + insertAuthor(t, q, "george", "washington") + + t.Run("DeleteAuthorsByFullName", func(t *testing.T) { + tag, err := q.DeleteAuthorsByFullName(context.Background(), DeleteAuthorsByFullNameParams{ + FirstName: "george", + LastName: "washington", + Suffix: "", + }) + require.NoError(t, err) + assert.Truef(t, tag.Delete(), "expected delete tag; got %s", tag.String()) + assert.Equal(t, int64(1), tag.RowsAffected()) + }) +} + +func insertAuthor(t *testing.T, q *DBQuerier, first, last string) int32 { + t.Helper() + authorID, err := q.InsertAuthor(context.Background(), first, last) + require.NoError(t, err, "insert author") + return authorID +} diff --git a/example/inline_param_count/inline3/query.sql.go b/example/inline_param_count/inline3/query.sql.go new file mode 100644 index 00000000..606da209 --- /dev/null +++ b/example/inline_param_count/inline3/query.sql.go @@ -0,0 +1,324 @@ +// Code generated by pggen. DO NOT EDIT. + +package inline3 + +import ( + "context" + "fmt" + "github.com/jackc/pgconn" + "github.com/jackc/pgtype" + "github.com/jackc/pgx/v4" +) + +// Querier is a typesafe Go interface backed by SQL queries. +// +// Methods ending with Batch enqueue a query to run later in a pgx.Batch. 
After +// calling SendBatch on pgx.Conn, pgxpool.Pool, or pgx.Tx, use the Scan methods +// to parse the results. +type Querier interface { + // CountAuthors returns the number of authors (zero params). + CountAuthors(ctx context.Context) (*int, error) + // CountAuthorsBatch enqueues a CountAuthors query into batch to be executed + // later by the batch. + CountAuthorsBatch(batch genericBatch) + // CountAuthorsScan scans the result of an executed CountAuthorsBatch query. + CountAuthorsScan(results pgx.BatchResults) (*int, error) + + // FindAuthorById finds one (or zero) authors by ID (one param). + FindAuthorByID(ctx context.Context, authorID int32) (FindAuthorByIDRow, error) + // FindAuthorByIDBatch enqueues a FindAuthorByID query into batch to be executed + // later by the batch. + FindAuthorByIDBatch(batch genericBatch, authorID int32) + // FindAuthorByIDScan scans the result of an executed FindAuthorByIDBatch query. + FindAuthorByIDScan(results pgx.BatchResults) (FindAuthorByIDRow, error) + + // InsertAuthor inserts an author by name and returns the ID (two params). + InsertAuthor(ctx context.Context, firstName string, lastName string) (int32, error) + // InsertAuthorBatch enqueues a InsertAuthor query into batch to be executed + // later by the batch. + InsertAuthorBatch(batch genericBatch, firstName string, lastName string) + // InsertAuthorScan scans the result of an executed InsertAuthorBatch query. + InsertAuthorScan(results pgx.BatchResults) (int32, error) + + // DeleteAuthorsByFullName deletes authors by the full name (three params). + DeleteAuthorsByFullName(ctx context.Context, firstName string, lastName string, suffix string) (pgconn.CommandTag, error) + // DeleteAuthorsByFullNameBatch enqueues a DeleteAuthorsByFullName query into batch to be executed + // later by the batch. + DeleteAuthorsByFullNameBatch(batch genericBatch, firstName string, lastName string, suffix string) + // DeleteAuthorsByFullNameScan scans the result of an executed DeleteAuthorsByFullNameBatch query. + DeleteAuthorsByFullNameScan(results pgx.BatchResults) (pgconn.CommandTag, error) +} + +type DBQuerier struct { + conn genericConn // underlying Postgres transport to use + types *typeResolver // resolve types by name +} + +var _ Querier = &DBQuerier{} + +// genericConn is a connection to a Postgres database. This is usually backed by +// *pgx.Conn, pgx.Tx, or *pgxpool.Pool. +type genericConn interface { + // Query executes sql with args. If there is an error the returned Rows will + // be returned in an error state. So it is allowed to ignore the error + // returned from Query and handle it in Rows. + Query(ctx context.Context, sql string, args ...interface{}) (pgx.Rows, error) + + // QueryRow is a convenience wrapper over Query. Any error that occurs while + // querying is deferred until calling Scan on the returned Row. That Row will + // error with pgx.ErrNoRows if no rows are returned. + QueryRow(ctx context.Context, sql string, args ...interface{}) pgx.Row + + // Exec executes sql. sql can be either a prepared statement name or an SQL + // string. arguments should be referenced positionally from the sql string + // as $1, $2, etc. + Exec(ctx context.Context, sql string, arguments ...interface{}) (pgconn.CommandTag, error) +} + +// genericBatch batches queries to send in a single network request to a +// Postgres server. This is usually backed by *pgx.Batch. +type genericBatch interface { + // Queue queues a query to batch b. query can be an SQL query or the name of a + // prepared statement. 
See Queue on *pgx.Batch. + Queue(query string, arguments ...interface{}) +} + +// NewQuerier creates a DBQuerier that implements Querier. conn is typically +// *pgx.Conn, pgx.Tx, or *pgxpool.Pool. +func NewQuerier(conn genericConn) *DBQuerier { + return NewQuerierConfig(conn, QuerierConfig{}) +} + +type QuerierConfig struct { + // DataTypes contains pgtype.Value to use for encoding and decoding instead + // of pggen-generated pgtype.ValueTranscoder. + // + // If OIDs are available for an input parameter type and all of its + // transitive dependencies, pggen will use the binary encoding format for + // the input parameter. + DataTypes []pgtype.DataType +} + +// NewQuerierConfig creates a DBQuerier that implements Querier with the given +// config. conn is typically *pgx.Conn, pgx.Tx, or *pgxpool.Pool. +func NewQuerierConfig(conn genericConn, cfg QuerierConfig) *DBQuerier { + return &DBQuerier{conn: conn, types: newTypeResolver(cfg.DataTypes)} +} + +// WithTx creates a new DBQuerier that uses the transaction to run all queries. +func (q *DBQuerier) WithTx(tx pgx.Tx) (*DBQuerier, error) { + return &DBQuerier{conn: tx}, nil +} + +// preparer is any Postgres connection transport that provides a way to prepare +// a statement, most commonly *pgx.Conn. +type preparer interface { + Prepare(ctx context.Context, name, sql string) (sd *pgconn.StatementDescription, err error) +} + +// PrepareAllQueries executes a PREPARE statement for all pggen generated SQL +// queries in querier files. Typical usage is as the AfterConnect callback +// for pgxpool.Config +// +// pgx will use the prepared statement if available. Calling PrepareAllQueries +// is an optional optimization to avoid a network round-trip the first time pgx +// runs a query if pgx statement caching is enabled. +func PrepareAllQueries(ctx context.Context, p preparer) error { + if _, err := p.Prepare(ctx, countAuthorsSQL, countAuthorsSQL); err != nil { + return fmt.Errorf("prepare query 'CountAuthors': %w", err) + } + if _, err := p.Prepare(ctx, findAuthorByIDSQL, findAuthorByIDSQL); err != nil { + return fmt.Errorf("prepare query 'FindAuthorByID': %w", err) + } + if _, err := p.Prepare(ctx, insertAuthorSQL, insertAuthorSQL); err != nil { + return fmt.Errorf("prepare query 'InsertAuthor': %w", err) + } + if _, err := p.Prepare(ctx, deleteAuthorsByFullNameSQL, deleteAuthorsByFullNameSQL); err != nil { + return fmt.Errorf("prepare query 'DeleteAuthorsByFullName': %w", err) + } + return nil +} + +// typeResolver looks up the pgtype.ValueTranscoder by Postgres type name. +type typeResolver struct { + connInfo *pgtype.ConnInfo // types by Postgres type name +} + +func newTypeResolver(types []pgtype.DataType) *typeResolver { + ci := pgtype.NewConnInfo() + for _, typ := range types { + if txt, ok := typ.Value.(textPreferrer); ok && typ.OID != unknownOID { + typ.Value = txt.ValueTranscoder + } + ci.RegisterDataType(typ) + } + return &typeResolver{connInfo: ci} +} + +// findValue find the OID, and pgtype.ValueTranscoder for a Postgres type name. +func (tr *typeResolver) findValue(name string) (uint32, pgtype.ValueTranscoder, bool) { + typ, ok := tr.connInfo.DataTypeForName(name) + if !ok { + return 0, nil, false + } + v := pgtype.NewValue(typ.Value) + return typ.OID, v.(pgtype.ValueTranscoder), true +} + +// setValue sets the value of a ValueTranscoder to a value that should always +// work and panics if it fails. 
+func (tr *typeResolver) setValue(vt pgtype.ValueTranscoder, val interface{}) pgtype.ValueTranscoder { + if err := vt.Set(val); err != nil { + panic(fmt.Sprintf("set ValueTranscoder %T to %+v: %s", vt, val, err)) + } + return vt +} + +const countAuthorsSQL = `SELECT count(*) FROM author;` + +// CountAuthors implements Querier.CountAuthors. +func (q *DBQuerier) CountAuthors(ctx context.Context) (*int, error) { + ctx = context.WithValue(ctx, "pggen_query_name", "CountAuthors") + row := q.conn.QueryRow(ctx, countAuthorsSQL) + var item *int + if err := row.Scan(&item); err != nil { + return item, fmt.Errorf("query CountAuthors: %w", err) + } + return item, nil +} + +// CountAuthorsBatch implements Querier.CountAuthorsBatch. +func (q *DBQuerier) CountAuthorsBatch(batch genericBatch) { + batch.Queue(countAuthorsSQL) +} + +// CountAuthorsScan implements Querier.CountAuthorsScan. +func (q *DBQuerier) CountAuthorsScan(results pgx.BatchResults) (*int, error) { + row := results.QueryRow() + var item *int + if err := row.Scan(&item); err != nil { + return item, fmt.Errorf("scan CountAuthorsBatch row: %w", err) + } + return item, nil +} + +const findAuthorByIDSQL = `SELECT * FROM author WHERE author_id = $1;` + +type FindAuthorByIDRow struct { + AuthorID int32 `json:"author_id"` + FirstName string `json:"first_name"` + LastName string `json:"last_name"` + Suffix *string `json:"suffix"` +} + +// FindAuthorByID implements Querier.FindAuthorByID. +func (q *DBQuerier) FindAuthorByID(ctx context.Context, authorID int32) (FindAuthorByIDRow, error) { + ctx = context.WithValue(ctx, "pggen_query_name", "FindAuthorByID") + row := q.conn.QueryRow(ctx, findAuthorByIDSQL, authorID) + var item FindAuthorByIDRow + if err := row.Scan(&item.AuthorID, &item.FirstName, &item.LastName, &item.Suffix); err != nil { + return item, fmt.Errorf("query FindAuthorByID: %w", err) + } + return item, nil +} + +// FindAuthorByIDBatch implements Querier.FindAuthorByIDBatch. +func (q *DBQuerier) FindAuthorByIDBatch(batch genericBatch, authorID int32) { + batch.Queue(findAuthorByIDSQL, authorID) +} + +// FindAuthorByIDScan implements Querier.FindAuthorByIDScan. +func (q *DBQuerier) FindAuthorByIDScan(results pgx.BatchResults) (FindAuthorByIDRow, error) { + row := results.QueryRow() + var item FindAuthorByIDRow + if err := row.Scan(&item.AuthorID, &item.FirstName, &item.LastName, &item.Suffix); err != nil { + return item, fmt.Errorf("scan FindAuthorByIDBatch row: %w", err) + } + return item, nil +} + +const insertAuthorSQL = `INSERT INTO author (first_name, last_name) +VALUES ($1, $2) +RETURNING author_id;` + +// InsertAuthor implements Querier.InsertAuthor. +func (q *DBQuerier) InsertAuthor(ctx context.Context, firstName string, lastName string) (int32, error) { + ctx = context.WithValue(ctx, "pggen_query_name", "InsertAuthor") + row := q.conn.QueryRow(ctx, insertAuthorSQL, firstName, lastName) + var item int32 + if err := row.Scan(&item); err != nil { + return item, fmt.Errorf("query InsertAuthor: %w", err) + } + return item, nil +} + +// InsertAuthorBatch implements Querier.InsertAuthorBatch. +func (q *DBQuerier) InsertAuthorBatch(batch genericBatch, firstName string, lastName string) { + batch.Queue(insertAuthorSQL, firstName, lastName) +} + +// InsertAuthorScan implements Querier.InsertAuthorScan. 
+func (q *DBQuerier) InsertAuthorScan(results pgx.BatchResults) (int32, error) { + row := results.QueryRow() + var item int32 + if err := row.Scan(&item); err != nil { + return item, fmt.Errorf("scan InsertAuthorBatch row: %w", err) + } + return item, nil +} + +const deleteAuthorsByFullNameSQL = `DELETE +FROM author +WHERE first_name = $1 + AND last_name = $2 + AND CASE WHEN $3 = '' THEN suffix IS NULL ELSE suffix = $3 END;` + +// DeleteAuthorsByFullName implements Querier.DeleteAuthorsByFullName. +func (q *DBQuerier) DeleteAuthorsByFullName(ctx context.Context, firstName string, lastName string, suffix string) (pgconn.CommandTag, error) { + ctx = context.WithValue(ctx, "pggen_query_name", "DeleteAuthorsByFullName") + cmdTag, err := q.conn.Exec(ctx, deleteAuthorsByFullNameSQL, firstName, lastName, suffix) + if err != nil { + return cmdTag, fmt.Errorf("exec query DeleteAuthorsByFullName: %w", err) + } + return cmdTag, err +} + +// DeleteAuthorsByFullNameBatch implements Querier.DeleteAuthorsByFullNameBatch. +func (q *DBQuerier) DeleteAuthorsByFullNameBatch(batch genericBatch, firstName string, lastName string, suffix string) { + batch.Queue(deleteAuthorsByFullNameSQL, firstName, lastName, suffix) +} + +// DeleteAuthorsByFullNameScan implements Querier.DeleteAuthorsByFullNameScan. +func (q *DBQuerier) DeleteAuthorsByFullNameScan(results pgx.BatchResults) (pgconn.CommandTag, error) { + cmdTag, err := results.Exec() + if err != nil { + return cmdTag, fmt.Errorf("exec DeleteAuthorsByFullNameBatch: %w", err) + } + return cmdTag, err +} + +// textPreferrer wraps a pgtype.ValueTranscoder and sets the preferred encoding +// format to text instead binary (the default). pggen uses the text format +// when the OID is unknownOID because the binary format requires the OID. +// Typically occurs if the results from QueryAllDataTypes aren't passed to +// NewQuerierConfig. +type textPreferrer struct { + pgtype.ValueTranscoder + typeName string +} + +// PreferredParamFormat implements pgtype.ParamFormatPreferrer. +func (t textPreferrer) PreferredParamFormat() int16 { return pgtype.TextFormatCode } + +func (t textPreferrer) NewTypeValue() pgtype.Value { + return textPreferrer{ValueTranscoder: pgtype.NewValue(t.ValueTranscoder).(pgtype.ValueTranscoder), typeName: t.typeName} +} + +func (t textPreferrer) TypeName() string { + return t.typeName +} + +// unknownOID means we don't know the OID for a type. This is okay for decoding +// because pgx call DecodeText or DecodeBinary without requiring the OID. For +// encoding parameters, pggen uses textPreferrer if the OID is unknown. 
+const unknownOID = 0 diff --git a/example/inline_param_count/inline3/query.sql_test.go b/example/inline_param_count/inline3/query.sql_test.go new file mode 100644 index 00000000..ed33fe67 --- /dev/null +++ b/example/inline_param_count/inline3/query.sql_test.go @@ -0,0 +1,84 @@ +package inline3 + +import ( + "context" + "errors" + "github.com/jschaf/pggen/internal/errs" + "github.com/stretchr/testify/require" + "testing" + + "github.com/jackc/pgx/v4" + "github.com/jschaf/pggen/internal/pgtest" + "github.com/stretchr/testify/assert" +) + +func TestNewQuerier_FindAuthorByID(t *testing.T) { + conn, cleanup := pgtest.NewPostgresSchema(t, []string{"../schema.sql"}) + defer cleanup() + + q := NewQuerier(conn) + adamsID := insertAuthor(t, q, "john", "adams") + insertAuthor(t, q, "george", "washington") + + t.Run("CountAuthors two", func(t *testing.T) { + got, err := q.CountAuthors(context.Background()) + require.NoError(t, err) + assert.Equal(t, 2, *got) + }) + + t.Run("FindAuthorByID", func(t *testing.T) { + authorByID, err := q.FindAuthorByID(context.Background(), adamsID) + require.NoError(t, err) + assert.Equal(t, FindAuthorByIDRow{ + AuthorID: adamsID, + FirstName: "john", + LastName: "adams", + Suffix: nil, + }, authorByID) + }) + + t.Run("FindAuthorByID - none-exists", func(t *testing.T) { + missingAuthorByID, err := q.FindAuthorByID(context.Background(), 888) + require.Error(t, err, "expected error when finding author ID that doesn't match") + assert.Zero(t, missingAuthorByID, "expected zero value when error") + if !errors.Is(err, pgx.ErrNoRows) { + t.Fatalf("expected no rows error to wrap pgx.ErrNoRows; got %s", err) + } + }) + + t.Run("FindAuthorByIDBatch", func(t *testing.T) { + batch := &pgx.Batch{} + q.FindAuthorByIDBatch(batch, adamsID) + results := conn.SendBatch(context.Background(), batch) + defer errs.CaptureT(t, results.Close, "close batch results") + authors, err := q.FindAuthorByIDScan(results) + require.NoError(t, err) + assert.Equal(t, FindAuthorByIDRow{ + AuthorID: adamsID, + FirstName: "john", + LastName: "adams", + Suffix: nil, + }, authors) + }) +} + +func TestNewQuerier_DeleteAuthorsByFullName(t *testing.T) { + conn, cleanup := pgtest.NewPostgresSchema(t, []string{"../schema.sql"}) + defer cleanup() + q := NewQuerier(conn) + insertAuthor(t, q, "george", "washington") + + t.Run("DeleteAuthorsByFullName", func(t *testing.T) { + tag, err := q.DeleteAuthorsByFullName(context.Background(), "george", "washington", "") + require.NoError(t, err) + assert.Truef(t, tag.Delete(), "expected delete tag; got %s", tag.String()) + assert.Equal(t, int64(1), tag.RowsAffected()) + }) +} + +func insertAuthor(t *testing.T, q *DBQuerier, first, last string) int32 { + t.Helper() + authorID, err := q.InsertAuthor(context.Background(), first, last) + require.NoError(t, err, "insert author") + return authorID +} diff --git a/example/inline_param_count/query.sql b/example/inline_param_count/query.sql new file mode 100644 index 00000000..0b9c5ad0 --- /dev/null +++ b/example/inline_param_count/query.sql @@ -0,0 +1,21 @@ +-- CountAuthors returns the number of authors (zero params). +-- name: CountAuthors :one +SELECT count(*) FROM author; + +-- FindAuthorById finds one (or zero) authors by ID (one param). +-- name: FindAuthorByID :one +SELECT * FROM author WHERE author_id = pggen.arg('AuthorID'); + +-- InsertAuthor inserts an author by name and returns the ID (two params). 
+-- name: InsertAuthor :one +INSERT INTO author (first_name, last_name) +VALUES (pggen.arg('FirstName'), pggen.arg('LastName')) +RETURNING author_id; + +-- DeleteAuthorsByFullName deletes authors by the full name (three params). +-- name: DeleteAuthorsByFullName :exec +DELETE +FROM author +WHERE first_name = pggen.arg('FirstName') + AND last_name = pggen.arg('LastName') + AND CASE WHEN pggen.arg('Suffix') = '' THEN suffix IS NULL ELSE suffix = pggen.arg('Suffix') END; \ No newline at end of file diff --git a/example/inline_param_count/schema.sql b/example/inline_param_count/schema.sql new file mode 100644 index 00000000..98cd0f82 --- /dev/null +++ b/example/inline_param_count/schema.sql @@ -0,0 +1,6 @@ +CREATE TABLE author ( + author_id serial PRIMARY KEY, + first_name text NOT NULL, + last_name text NOT NULL, + suffix text NULL +); diff --git a/example/ltree/codegen_test.go b/example/ltree/codegen_test.go index 4b157502..f5f1fec1 100644 --- a/example/ltree/codegen_test.go +++ b/example/ltree/codegen_test.go @@ -16,11 +16,12 @@ func TestGenerate_Go_Example_ltree(t *testing.T) { tmpDir := t.TempDir() err := pggen.Generate( pggen.GenerateOptions{ - ConnString: conn.Config().ConnString(), - QueryFiles: []string{"query.sql"}, - OutputDir: tmpDir, - GoPackage: "ltree", - Language: pggen.LangGo, + ConnString: conn.Config().ConnString(), + QueryFiles: []string{"query.sql"}, + OutputDir: tmpDir, + GoPackage: "ltree", + Language: pggen.LangGo, + InlineParamCount: 2, TypeOverrides: map[string]string{ "ltree": "github.com/jackc/pgtype.Text", "_ltree": "github.com/jackc/pgtype.TextArray", diff --git a/example/nested/codegen_test.go b/example/nested/codegen_test.go index 780d6c1b..8422ec0d 100644 --- a/example/nested/codegen_test.go +++ b/example/nested/codegen_test.go @@ -16,11 +16,12 @@ func TestGenerate_Go_Example_nested(t *testing.T) { tmpDir := t.TempDir() err := pggen.Generate( pggen.GenerateOptions{ - ConnString: conn.Config().ConnString(), - QueryFiles: []string{"query.sql"}, - OutputDir: tmpDir, - GoPackage: "nested", - Language: pggen.LangGo, + ConnString: conn.Config().ConnString(), + QueryFiles: []string{"query.sql"}, + OutputDir: tmpDir, + GoPackage: "nested", + Language: pggen.LangGo, + InlineParamCount: 2, TypeOverrides: map[string]string{ "int4": "int", "text": "string", diff --git a/example/numeric_external/codegen_test.go b/example/numeric_external/codegen_test.go index a6db08bb..5423b96b 100644 --- a/example/numeric_external/codegen_test.go +++ b/example/numeric_external/codegen_test.go @@ -16,11 +16,12 @@ func TestGenerate_Go_Example_Numeric_External(t *testing.T) { tmpDir := t.TempDir() err := pggen.Generate( pggen.GenerateOptions{ - ConnString: conn.Config().ConnString(), - QueryFiles: []string{"query.sql"}, - OutputDir: tmpDir, - GoPackage: "numeric_external", - Language: pggen.LangGo, + ConnString: conn.Config().ConnString(), + QueryFiles: []string{"query.sql"}, + OutputDir: tmpDir, + GoPackage: "numeric_external", + Language: pggen.LangGo, + InlineParamCount: 2, TypeOverrides: map[string]string{ "int4": "int", "int8": "int", diff --git a/example/pgcrypto/codegen_test.go b/example/pgcrypto/codegen_test.go index e2c9d1bd..cd20cd22 100644 --- a/example/pgcrypto/codegen_test.go +++ b/example/pgcrypto/codegen_test.go @@ -16,11 +16,12 @@ func TestGenerate_Go_Example_Pgcrypto(t *testing.T) { tmpDir := t.TempDir() err := pggen.Generate( pggen.GenerateOptions{ - ConnString: conn.Config().ConnString(), - QueryFiles: []string{"query.sql"}, - OutputDir: tmpDir, - GoPackage: "pgcrypto", - 
Language: pggen.LangGo, + ConnString: conn.Config().ConnString(), + QueryFiles: []string{"query.sql"}, + OutputDir: tmpDir, + GoPackage: "pgcrypto", + Language: pggen.LangGo, + InlineParamCount: 2, }) if err != nil { t.Fatalf("Generate() example/pgcrypto: %s", err) diff --git a/example/slices/codegen_test.go b/example/slices/codegen_test.go index fc51fb0f..a8eb468c 100644 --- a/example/slices/codegen_test.go +++ b/example/slices/codegen_test.go @@ -17,11 +17,12 @@ func TestGenerate_Go_Example_Slices(t *testing.T) { tmpDir := t.TempDir() err := pggen.Generate( pggen.GenerateOptions{ - ConnString: conn.Config().ConnString(), - QueryFiles: []string{"query.sql"}, - OutputDir: tmpDir, - GoPackage: "slices", - Language: pggen.LangGo, + ConnString: conn.Config().ConnString(), + QueryFiles: []string{"query.sql"}, + OutputDir: tmpDir, + GoPackage: "slices", + Language: pggen.LangGo, + InlineParamCount: 2, TypeOverrides: map[string]string{ "_bool": "[]bool", "bool": "bool", diff --git a/example/syntax/codegen_test.go b/example/syntax/codegen_test.go index 01298537..e690e70e 100644 --- a/example/syntax/codegen_test.go +++ b/example/syntax/codegen_test.go @@ -20,9 +20,10 @@ func TestGenerate_Go_Example_Syntax(t *testing.T) { QueryFiles: []string{ "query.sql", }, - OutputDir: tmpDir, - GoPackage: "syntax", - Language: pggen.LangGo, + OutputDir: tmpDir, + GoPackage: "syntax", + Language: pggen.LangGo, + InlineParamCount: 2, }) if err != nil { t.Fatalf("Generate() example/syntax: %s", err) diff --git a/example/void/codegen_test.go b/example/void/codegen_test.go index df958c8d..e531f690 100644 --- a/example/void/codegen_test.go +++ b/example/void/codegen_test.go @@ -16,11 +16,12 @@ func TestGenerate_Go_Example_void(t *testing.T) { tmpDir := t.TempDir() err := pggen.Generate( pggen.GenerateOptions{ - ConnString: conn.Config().ConnString(), - QueryFiles: []string{"query.sql"}, - OutputDir: tmpDir, - GoPackage: "void", - Language: pggen.LangGo, + ConnString: conn.Config().ConnString(), + QueryFiles: []string{"query.sql"}, + OutputDir: tmpDir, + GoPackage: "void", + Language: pggen.LangGo, + InlineParamCount: 2, }) if err != nil { t.Fatalf("Generate() example/void: %s", err) diff --git a/generate.go b/generate.go index 3081cb7b..7bdfc963 100644 --- a/generate.go +++ b/generate.go @@ -61,6 +61,9 @@ type GenerateOptions struct { TypeOverrides map[string]string // What log level to log at. LogLevel zapcore.Level + // How many params to inline when calling querier methods. + // Set to 0 to always create a struct for params. + InlineParamCount int } // Generate generates language specific code to safely wrap each SQL @@ -115,10 +118,11 @@ func Generate(opts GenerateOptions) (mErr error) { switch opts.Language { case LangGo: goOpts := golang.GenerateOptions{ - GoPkg: opts.GoPackage, - OutputDir: opts.OutputDir, - Acronyms: opts.Acronyms, - TypeOverrides: opts.TypeOverrides, + GoPkg: opts.GoPackage, + OutputDir: opts.OutputDir, + Acronyms: opts.Acronyms, + TypeOverrides: opts.TypeOverrides, + InlineParamCount: opts.InlineParamCount, } if err := golang.Generate(goOpts, queryFiles); err != nil { return fmt.Errorf("generate go code: %w", err) diff --git a/internal/codegen/golang/generate.go b/internal/codegen/golang/generate.go index f9358712..1df69865 100644 --- a/internal/codegen/golang/generate.go +++ b/internal/codegen/golang/generate.go @@ -19,6 +19,9 @@ type GenerateOptions struct { Acronyms map[string]string // A map from a Postgres type name to a fully qualified Go type. 
TypeOverrides map[string]string + // How many params to inline when calling querier methods. + // Set to 0 to always create a struct for params. + InlineParamCount int } // Generate emits generated Go files for each of the queryFiles. @@ -30,9 +33,10 @@ func Generate(opts GenerateOptions, queryFiles []codegen.QueryFile) error { caser := casing.NewCaser() caser.AddAcronyms(opts.Acronyms) templater := NewTemplater(TemplaterOpts{ - Caser: caser, - Resolver: NewTypeResolver(caser, opts.TypeOverrides), - Pkg: pkgName, + Caser: caser, + Resolver: NewTypeResolver(caser, opts.TypeOverrides), + Pkg: pkgName, + InlineParamCount: opts.InlineParamCount, }) templatedFiles, err := templater.TemplateAll(queryFiles) if err != nil { diff --git a/internal/codegen/golang/templated_file.go b/internal/codegen/golang/templated_file.go index c70d3e7b..f3bf317c 100644 --- a/internal/codegen/golang/templated_file.go +++ b/internal/codegen/golang/templated_file.go @@ -33,13 +33,14 @@ type TemplatedFile struct { // TemplatedQuery is a query with all information required to execute the // codegen template. type TemplatedQuery struct { - Name string // name of the query, from the comment preceding the query - SQLVarName string // name of the string variable containing the SQL - ResultKind ast.ResultKind // kind of result: :one, :many, or :exec - Doc string // doc from the source query file, formatted for Go - PreparedSQL string // SQL query, ready to run with PREPARE statement - Inputs []TemplatedParam // input parameters to the query - Outputs []TemplatedColumn // output columns of the query + Name string // name of the query, from the comment preceding the query + SQLVarName string // name of the string variable containing the SQL + ResultKind ast.ResultKind // kind of result: :one, :many, or :exec + Doc string // doc from the source query file, formatted for Go + PreparedSQL string // SQL query, ready to run with PREPARE statement + Inputs []TemplatedParam // input parameters to the query + Outputs []TemplatedColumn // output columns of the query + InlineParamCount int // inclusive count of params that will be inlined } type TemplatedParam struct { @@ -82,21 +83,17 @@ func (tq TemplatedQuery) EmitPreparedSQL() string { // a name and type based on the number of params. For use in a method // definition. func (tq TemplatedQuery) EmitParams() string { - switch len(tq.Inputs) { - case 0: - return "" - case 1, 2: - sb := strings.Builder{} - for _, input := range tq.Inputs { - sb.WriteString(", ") - sb.WriteString(input.LowerName) - sb.WriteRune(' ') - sb.WriteString(input.QualType) - } - return sb.String() - default: + if !tq.isInlineParams() { return ", params " + tq.Name + "Params" } + sb := strings.Builder{} + for _, input := range tq.Inputs { + sb.WriteString(", ") + sb.WriteString(input.LowerName) + sb.WriteRune(' ') + sb.WriteString(input.QualType) + } + return sb.String() } func getLongestInput(inputs []TemplatedParam) int { @@ -111,7 +108,7 @@ func getLongestInput(inputs []TemplatedParam) int { // EmitParamStruct emits the struct definition for query params if needed. 
func (tq TemplatedQuery) EmitParamStruct() string { - if len(tq.Inputs) < 3 { + if tq.isInlineParams() { return "" } sb := &strings.Builder{} @@ -160,10 +157,8 @@ func (tq TemplatedQuery) EmitParamNames() string { sb.WriteString(name) } } - switch len(tq.Inputs) { - case 0: - return "" - case 1, 2: + switch { + case tq.isInlineParams(): sb := &strings.Builder{} for _, input := range tq.Inputs { sb.WriteString(", ") @@ -180,6 +175,10 @@ func (tq TemplatedQuery) EmitParamNames() string { } } +func (tq TemplatedQuery) isInlineParams() bool { + return len(tq.Inputs) <= tq.InlineParamCount +} + // EmitRowScanArgs emits the args to scan a single row from a pgx.Row or // pgx.Rows. func (tq TemplatedQuery) EmitRowScanArgs() (string, error) { diff --git a/internal/codegen/golang/templater.go b/internal/codegen/golang/templater.go index 8a1e603c..9dac3b31 100644 --- a/internal/codegen/golang/templater.go +++ b/internal/codegen/golang/templater.go @@ -13,9 +13,10 @@ import ( // Templater creates query file templates. type Templater struct { - caser casing.Caser - resolver TypeResolver - pkg string // Go package name + caser casing.Caser + resolver TypeResolver + pkg string // Go package name + inlineParamCount int } // TemplaterOpts is options to control the template logic. @@ -23,13 +24,16 @@ type TemplaterOpts struct { Caser casing.Caser Resolver TypeResolver Pkg string // Go package name + // How many params to inline when calling querier methods. + InlineParamCount int } func NewTemplater(opts TemplaterOpts) Templater { return Templater{ - pkg: opts.Pkg, - caser: opts.Caser, - resolver: opts.Resolver, + pkg: opts.Pkg, + caser: opts.Caser, + resolver: opts.Resolver, + inlineParamCount: opts.InlineParamCount, } } @@ -179,13 +183,14 @@ func (tm Templater) templateFile(file codegen.QueryFile, isLeader bool) (Templat } queries = append(queries, TemplatedQuery{ - Name: tm.caser.ToUpperGoIdent(query.Name), - SQLVarName: tm.caser.ToLowerGoIdent(query.Name) + "SQL", - ResultKind: query.ResultKind, - Doc: docs.String(), - PreparedSQL: query.PreparedSQL, - Inputs: inputs, - Outputs: outputs, + Name: tm.caser.ToUpperGoIdent(query.Name), + SQLVarName: tm.caser.ToLowerGoIdent(query.Name) + "SQL", + ResultKind: query.ResultKind, + Doc: docs.String(), + PreparedSQL: query.PreparedSQL, + Inputs: inputs, + Outputs: outputs, + InlineParamCount: tm.inlineParamCount, }) }
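Taken together, these templater changes reduce the inline-versus-struct decision to one inclusive threshold: parameters are inlined into the method signature when their count is at or below InlineParamCount, and otherwise pggen emits a Params struct. That is why the three-parameter DeleteAuthorsByFullName query takes a DeleteAuthorsByFullNameParams struct in the inline2 output but three plain string arguments in the inline3 output above. A runnable sketch of the rule (a stand-alone illustration, not the pggen implementation):

```go
package main

import "fmt"

// isInline mirrors the isInlineParams check added above: a query's params are
// inlined when their count is at or below the configured InlineParamCount.
func isInline(paramCount, inlineParamCount int) bool {
	return paramCount <= inlineParamCount
}

func main() {
	// DeleteAuthorsByFullName has three input params.
	for _, threshold := range []int{0, 1, 2, 3} {
		if isInline(3, threshold) {
			fmt.Printf("InlineParamCount=%d: inline firstName, lastName, suffix\n", threshold)
		} else {
			fmt.Printf("InlineParamCount=%d: emit DeleteAuthorsByFullNameParams\n", threshold)
		}
	}
}
```

Setting InlineParamCount to 0 therefore always produces a params struct, even for single-parameter queries such as FindAuthorByID.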