-
Notifications
You must be signed in to change notification settings - Fork 0
/
migration.go
138 lines (113 loc) · 2.7 KB
/
migration.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
package gotelem
import (
"embed"
"errors"
"io"
"io/fs"
"path"
"regexp"
"sort"
"strconv"
)
// migrationsFs holds the SQL migration scripts from the migrations/
// directory, embedded into the binary at build time so applications can
// update databases without shipping the script files separately.
//
// The go:embed directive below must stay directly above the variable it
// initializes; only line comments or blank lines may sit between them.
//go:embed migrations/*
var migrationsFs embed.FS
// migrationRegex matches migration file names of the form
// "<version>_<name>_<up|down>.sql". Submatch 1 is the decimal version,
// submatch 2 the descriptive name (may itself contain underscores), and
// submatch 3 the direction, "up" or "down".
var migrationRegex = regexp.MustCompile(`^([0-9]+)_(.*)_(down|up)\.sql$`)
type (
	// Migration describes a single migration script file.
	Migration struct {
		// Name is the descriptive part of the file name, between the
		// version prefix and the direction suffix.
		Name string
		// Version is the numeric prefix of the file name.
		Version uint
		// FileName is the base name of the script file.
		FileName string
	}

	// MigrationError is a migration-related error type. It currently has no
	// fields. NOTE(review): looks like a placeholder — confirm intended use.
	MigrationError struct{}
)
// getMigrations walks files and indexes every migration script it finds,
// first by version number and then by direction ("up" or "down").
//
// Every regular file in the tree must be named
// "<version>_<name>_<up|down>.sql" (see migrationRegex). Only the base name
// of each file is stored in Migration.FileName. No version is special-cased:
// the map simply contains whatever versions the file names declare.
//
// Because the signature has no error result, problems are reported by
// panicking: a file name that does not match the pattern, a version number
// that does not fit in an int, or an error while walking files.
func getMigrations(files fs.FS) map[int]map[string]Migration {
	res := make(map[int]map[string]Migration) // version number -> direction -> migration.
	err := fs.WalkDir(files, ".", func(p string, d fs.DirEntry, walkErr error) error {
		// The error must be checked before touching d: when the root
		// itself cannot be read, WalkDir passes a nil DirEntry.
		if walkErr != nil {
			return walkErr
		}
		if d.IsDir() {
			return nil
		}
		name := d.Name()
		m := migrationRegex.FindStringSubmatch(name)
		if len(m) != 4 {
			panic("error parsing migration name: " + p)
		}
		// m[1] is all digits, so the only possible failure is overflow.
		ver, convErr := strconv.Atoi(m[1])
		if convErr != nil {
			panic("error parsing migration version: " + p)
		}
		byDirection, ok := res[ver]
		if !ok {
			byDirection = make(map[string]Migration)
			res[ver] = byDirection
		}
		byDirection[m[3]] = Migration{
			Name:     m[2],
			Version:  uint(ver),
			FileName: name,
		}
		return nil
	})
	if err != nil {
		panic("error reading migrations: " + err.Error())
	}
	return res
}
// RunMigrations brings the database schema up to the newest embedded
// migration.
//
// It reads the current version with tdb.GetVersion, checks that the embedded
// migrations are numbered 1, 2, ..., n without gaps, and applies every "up"
// migration newer than the current version inside a single transaction.
// After the transaction commits, the stored version is updated with
// tdb.SetVersion.
//
// It returns the newest migration version on success and 0 with a non-nil
// error on failure. It returns an error without modifying anything when the
// database reports a version newer than any embedded migration.
func RunMigrations(tdb *TelemDb) (finalVer int, err error) {
	currentVer, err := tdb.GetVersion()
	if err != nil {
		return 0, err
	}

	migrations := getMigrations(migrationsFs)
	vers, err := migrationVersions(migrations)
	if err != nil {
		return 0, err
	}
	latest := vers[len(vers)-1]

	switch {
	case currentVer > latest:
		// Writing latest back would silently downgrade the recorded version
		// of a database created by a newer build.
		return 0, errors.New("database version " + strconv.Itoa(currentVer) +
			" is newer than the latest known migration " + strconv.Itoa(latest))
	case currentVer == latest:
		return latest, nil
	}

	tx, err := tdb.db.Begin()
	if err != nil {
		return 0, err
	}
	// Undo partial work on any early return; after a successful Commit this
	// Rollback does nothing.
	defer tx.Rollback()

	for v := currentVer + 1; v <= latest; v++ {
		stmt, err := readUpMigration(migrations, v)
		if err != nil {
			return 0, err
		}
		if _, err := tx.Exec(stmt); err != nil {
			return 0, err
		}
	}

	if err := tx.Commit(); err != nil {
		return 0, err
	}
	// NOTE(review): SetVersion runs after the commit, outside the
	// transaction. If it fails, the schema changes stay applied while the
	// stored version is stale.
	if err := tdb.SetVersion(latest); err != nil {
		return 0, err
	}
	return latest, nil
}

// migrationVersions returns the version numbers in migrations in ascending
// order. It returns an error if there are none, or if they are not exactly
// 1, 2, ..., n.
func migrationVersions(migrations map[int]map[string]Migration) ([]int, error) {
	if len(migrations) == 0 {
		return nil, errors.New("no migrations found")
	}
	vers := make([]int, 0, len(migrations))
	for v := range migrations {
		vers = append(vers, v)
	}
	sort.Ints(vers)
	for i, v := range vers {
		if want := i + 1; v != want {
			return nil, errors.New("missing update between versions: expected version " +
				strconv.Itoa(want) + ", found " + strconv.Itoa(v))
		}
	}
	return vers, nil
}

// readUpMigration returns the SQL text of the "up" migration for version v,
// read from the embedded migrations directory.
func readUpMigration(migrations map[int]map[string]Migration, v int) (string, error) {
	byDirection, ok := migrations[v]
	if !ok {
		return "", errors.New("could not find migration for version " + strconv.Itoa(v))
	}
	up, ok := byDirection["up"]
	if !ok {
		return "", errors.New("could not get up migration for version " + strconv.Itoa(v))
	}
	f, err := migrationsFs.Open(path.Join("migrations", up.FileName))
	if err != nil {
		return "", err
	}
	defer f.Close()
	data, err := io.ReadAll(f)
	if err != nil {
		return "", err
	}
	return string(data), nil
}