-
Notifications
You must be signed in to change notification settings - Fork 5
/
postgres.go
139 lines (123 loc) · 4.12 KB
/
postgres.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
package schema
import (
"context"
"fmt"
"hash/crc32"
"strings"
"time"
"unicode"
)
// postgresAdvisoryLockSalt is multiplied into the CRC-32 of the migrations
// table name by advisoryLockID to form the Postgres advisory lock key.
const postgresAdvisoryLockSalt uint32 = 542384964

// Postgres is the dialect for Postgres-compatible databases.
var Postgres postgresDialect

// postgresDialect carries no state (it is an empty struct); all of its
// behavior lives in its methods.
type postgresDialect struct{}
// Lock implements the Locker interface to obtain a global lock before the
// migrations are run. It executes `SELECT pg_advisory_lock(<key>)`, where the
// key is derived from tableName by advisoryLockID, and returns only the error
// reported by ExecContext.
//
// NOTE(review): Postgres documents pg_advisory_lock as a session-level lock.
// If tx is a pooled handle rather than a single connection/transaction, the
// matching Unlock could run on a different session — confirm at call sites.
func (p postgresDialect) Lock(ctx context.Context, tx Queryer, tableName string) error {
	stmt := fmt.Sprintf("SELECT pg_advisory_lock(%s)", p.advisoryLockID(tableName))
	if _, err := tx.ExecContext(ctx, stmt); err != nil {
		return err
	}
	return nil
}
// Unlock implements the Locker interface to release the global lock after the
// migrations are run. It executes `SELECT pg_advisory_unlock(<key>)` using the
// same tableName-derived key that Lock uses.
//
// Only the ExecContext error is reported. The boolean that pg_advisory_unlock
// returns (false when the lock was not held, per the Postgres docs) is never
// read, so unlocking a lock that is not held is not surfaced as an error.
func (p postgresDialect) Unlock(ctx context.Context, tx Queryer, tableName string) error {
	if _, execErr := tx.ExecContext(ctx, fmt.Sprintf("SELECT pg_advisory_unlock(%s)", p.advisoryLockID(tableName))); execErr != nil {
		return execErr
	}
	return nil
}
// CreateMigrationsTable implements the Dialect interface to create the table
// which tracks applied migrations. The DDL uses CREATE TABLE IF NOT EXISTS,
// so calling it when the table already exists is harmless.
//
// tableName is interpolated into the DDL verbatim; it is presumably already
// quoted (e.g. via QuotedTableName) — confirm at call sites.
//
// NOTE(review): the DDL declares no primary key or unique constraint on id,
// so the database itself does not prevent duplicate migration rows.
func (p postgresDialect) CreateMigrationsTable(ctx context.Context, tx Queryer, tableName string) error {
	const ddlTemplate = `
CREATE TABLE IF NOT EXISTS %s (
id VARCHAR(255) NOT NULL,
checksum VARCHAR(32) NOT NULL DEFAULT '',
execution_time_in_millis INTEGER NOT NULL DEFAULT 0,
applied_at TIMESTAMP WITH TIME ZONE NOT NULL
)
`
	ddl := fmt.Sprintf(ddlTemplate, tableName)
	if _, err := tx.ExecContext(ctx, ddl); err != nil {
		return err
	}
	return nil
}
// InsertAppliedMigration implements the Dialect interface to insert a record
// into the migrations tracking table *after* a migration has successfully
// run.
//
// The row stores am.ID, the checksum returned by am.MD5(),
// am.ExecutionTimeInMillis and am.AppliedAt. These are bound as $1..$4, so
// only tableName is interpolated into the SQL text.
func (p postgresDialect) InsertAppliedMigration(ctx context.Context, tx Queryer, tableName string, am *AppliedMigration) error {
	const insertTemplate = `
INSERT INTO %s
( id, checksum, execution_time_in_millis, applied_at )
VALUES
( $1, $2, $3, $4 )`
	stmt := fmt.Sprintf(insertTemplate, tableName)
	if _, err := tx.ExecContext(ctx, stmt, am.ID, am.MD5(), am.ExecutionTimeInMillis, am.AppliedAt); err != nil {
		return err
	}
	return nil
}
// GetAppliedMigrations retrieves all data from the migrations tracking table,
// ordered by id ascending. Each row's AppliedAt is converted to time.Local.
//
// tableName is interpolated into the query verbatim; it is presumably already
// quoted (e.g. via QuotedTableName) — confirm at call sites.
//
// The returned slice is never nil. If the query, a row scan, or row iteration
// fails, the error is returned together with whatever rows were read before
// the failure.
func (p postgresDialect) GetAppliedMigrations(ctx context.Context, tx Queryer, tableName string) (migrations []*AppliedMigration, err error) {
	migrations = make([]*AppliedMigration, 0)
	query := fmt.Sprintf(`
SELECT id, checksum, execution_time_in_millis, applied_at
FROM %s ORDER BY id ASC
`, tableName)
	rows, err := tx.QueryContext(ctx, query)
	if err != nil {
		return migrations, err
	}
	defer rows.Close()
	for rows.Next() {
		migration := AppliedMigration{}
		err = rows.Scan(&migration.ID, &migration.Checksum, &migration.ExecutionTimeInMillis, &migration.AppliedAt)
		if err != nil {
			err = fmt.Errorf("failed to GetAppliedMigrations. Did somebody change the structure of the %s table?: %w", tableName, err)
			return migrations, err
		}
		migration.AppliedAt = migration.AppliedAt.In(time.Local)
		migrations = append(migrations, &migration)
	}
	// rows.Next returns false both at the end of the result set and when
	// iteration fails (e.g. a dropped connection); only rows.Err tells the
	// two apart. Without this check a truncated read was reported as success.
	if err = rows.Err(); err != nil {
		return migrations, fmt.Errorf("reading applied migrations from %s: %w", tableName, err)
	}
	return migrations, nil
}
// QuotedTableName returns the name of the migration tracking table quoted for
// Postgres. With an empty schemaName it is just the quoted table name;
// otherwise it is "<quoted schema>.<quoted table>".
//
// Both parts go through QuotedIdent, which maps "" to "". An empty tableName
// with a non-empty schema therefore yields a trailing ".".
func (p postgresDialect) QuotedTableName(schemaName, tableName string) string {
	quotedTable := p.QuotedIdent(tableName)
	if schemaName != "" {
		return p.QuotedIdent(schemaName) + "." + quotedTable
	}
	return quotedTable
}
// QuotedIdent wraps ident in Postgres double quotes. Before quoting it:
//   - drops every Unicode whitespace rune (unicode.IsSpace),
//   - drops every ';' (the statement terminator),
//   - escapes each embedded '"' by doubling it.
//
// An empty ident returns "" (no quotes at all).
//
// NOTE(review): an ident made only of whitespace and/or ';' yields `""`, an
// empty quoted identifier that Postgres is expected to reject — confirm
// whether callers can pass such input.
func (p postgresDialect) QuotedIdent(ident string) string {
	if len(ident) == 0 {
		return ""
	}
	var b strings.Builder
	b.WriteByte('"')
	for _, ch := range ident {
		if unicode.IsSpace(ch) || ch == ';' {
			continue
		}
		if ch == '"' {
			// Emit an extra quote first so the embedded one is escaped.
			b.WriteByte('"')
		}
		b.WriteRune(ch)
	}
	b.WriteByte('"')
	return b.String()
}
// advisoryLockID derives the table-specific key passed to pg_advisory_lock
// and pg_advisory_unlock. The key is the IEEE CRC-32 of tableName multiplied
// by postgresAdvisoryLockSalt, with uint32 wrap-around, written in decimal.
//
// NOTE(review): the salt (542384964 = 4 * 135596241) is divisible by 4, so
// every key is a multiple of 4 and only 2^30 distinct keys are reachable.
// Changing the formula would make old and new binaries take different locks,
// so it is deliberately left unchanged.
func (p postgresDialect) advisoryLockID(tableName string) string {
	salted := crc32.ChecksumIEEE([]byte(tableName)) * postgresAdvisoryLockSalt
	return fmt.Sprintf("%d", salted)
}