Skip to content

Commit

Permalink
Escape MySQL identifiers (#16)
Browse files Browse the repository at this point in the history
Related to gobuffalo/pop#42
  • Loading branch information
stanislas-m authored Aug 16, 2018
1 parent 8ef750f commit f57f22e
Show file tree
Hide file tree
Showing 5 changed files with 158 additions and 115 deletions.
8 changes: 4 additions & 4 deletions translators/cockroach.go
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ func (p *Cockroach) RenameTable(t []fizz.Table) (string, error) {

func (p *Cockroach) ChangeColumn(t fizz.Table) (string, error) {
if len(t.Columns) == 0 {
return "", errors.New("Not enough columns supplied!")
return "", errors.New("not enough columns supplied")
}
c := t.Columns[0]

Expand Down Expand Up @@ -130,7 +130,7 @@ func (p *Cockroach) ChangeColumn(t fizz.Table) (string, error) {

func (p *Cockroach) AddColumn(t fizz.Table) (string, error) {
if len(t.Columns) == 0 {
return "", errors.New("Not enough columns supplied!")
return "", errors.New("not enough columns supplied")
}
c := t.Columns[0]
s := fmt.Sprintf("ALTER TABLE \"%s\" ADD COLUMN %s;COMMIT TRANSACTION;BEGIN TRANSACTION;", t.Name, p.buildAddColumn(c))
Expand All @@ -155,7 +155,7 @@ func (p *Cockroach) AddColumn(t fizz.Table) (string, error) {

func (p *Cockroach) DropColumn(t fizz.Table) (string, error) {
if len(t.Columns) == 0 {
return "", errors.New("Not enough columns supplied!")
return "", errors.New("not enough columns supplied")
}
c := t.Columns[0]
p.Schema.DeleteColumn(t.Name, c.Name)
Expand All @@ -164,7 +164,7 @@ func (p *Cockroach) DropColumn(t fizz.Table) (string, error) {

func (p *Cockroach) RenameColumn(t fizz.Table) (string, error) {
if len(t.Columns) < 2 {
return "", errors.New("Not enough columns supplied!")
return "", errors.New("not enough columns supplied")
}

oc := t.Columns[0]
Expand Down
78 changes: 50 additions & 28 deletions translators/mysql.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,12 @@ import (
"github.com/gobuffalo/fizz"
)

// MySQL is a MySQL-specific translator.
type MySQL struct {
Schema SchemaQuery
}

// NewMySQL constructs a new MySQL translator.
func NewMySQL(url, name string) *MySQL {
schema := &mysqlSchema{Schema{URL: url, Name: name, schema: map[string]*fizz.Table{}}}
schema.Builder = schema
Expand All @@ -22,21 +24,22 @@ func NewMySQL(url, name string) *MySQL {
}
}

// CreateTable translates a fizz Table to its MySQL SQL definition.
func (p *MySQL) CreateTable(t fizz.Table) (string, error) {
sql := []string{}
cols := []string{}
for _, c := range t.Columns {
cols = append(cols, p.buildColumn(c))
if c.Primary {
cols = append(cols, fmt.Sprintf("PRIMARY KEY(%s)", c.Name))
cols = append(cols, fmt.Sprintf("PRIMARY KEY(`%s`)", c.Name))
}
}

for _, fk := range t.ForeignKeys {
cols = append(cols, p.buildForeignKey(t, fk, true))
}

s := fmt.Sprintf("CREATE TABLE %s (\n%s\n) ENGINE=InnoDB;", t.Name, strings.Join(cols, ",\n"))
s := fmt.Sprintf("CREATE TABLE %s (\n%s\n) ENGINE=InnoDB;", p.escapeIdentifier(t.Name), strings.Join(cols, ",\n"))

sql = append(sql, s)

Expand All @@ -55,58 +58,58 @@ func (p *MySQL) CreateTable(t fizz.Table) (string, error) {
}

func (p *MySQL) DropTable(t fizz.Table) (string, error) {
return fmt.Sprintf("DROP TABLE %s;", t.Name), nil
return fmt.Sprintf("DROP TABLE %s;", p.escapeIdentifier(t.Name)), nil
}

func (p *MySQL) RenameTable(t []fizz.Table) (string, error) {
if len(t) < 2 {
return "", errors.New("Not enough table names supplied!")
return "", errors.New("not enough table names supplied")
}
return fmt.Sprintf("ALTER TABLE %s RENAME TO %s;", t[0].Name, t[1].Name), nil
return fmt.Sprintf("ALTER TABLE %s RENAME TO %s;", p.escapeIdentifier(t[0].Name), p.escapeIdentifier(t[1].Name)), nil
}

func (p *MySQL) ChangeColumn(t fizz.Table) (string, error) {
if len(t.Columns) == 0 {
return "", errors.New("Not enough columns supplied!")
return "", errors.New("not enough columns supplied")
}
c := t.Columns[0]
s := fmt.Sprintf("ALTER TABLE %s MODIFY %s;", t.Name, p.buildColumn(c))
s := fmt.Sprintf("ALTER TABLE %s MODIFY %s;", p.escapeIdentifier(t.Name), p.buildColumn(c))
return s, nil
}

func (p *MySQL) AddColumn(t fizz.Table) (string, error) {
if len(t.Columns) == 0 {
return "", errors.New("Not enough columns supplied!")
return "", errors.New("not enough columns supplied")
}

if _, ok := t.Columns[0].Options["first"]; ok {
c := t.Columns[0]
s := fmt.Sprintf("ALTER TABLE %s ADD COLUMN %s FIRST;", t.Name, p.buildColumn(c))
s := fmt.Sprintf("ALTER TABLE %s ADD COLUMN %s FIRST;", p.escapeIdentifier(t.Name), p.buildColumn(c))
return s, nil
}

if val, ok := t.Columns[0].Options["after"]; ok {
c := t.Columns[0]
s := fmt.Sprintf("ALTER TABLE %s ADD COLUMN %s AFTER %s;", t.Name, p.buildColumn(c), val)
s := fmt.Sprintf("ALTER TABLE %s ADD COLUMN %s AFTER `%s`;", p.escapeIdentifier(t.Name), p.buildColumn(c), val)
return s, nil
}

c := t.Columns[0]
s := fmt.Sprintf("ALTER TABLE %s ADD COLUMN %s;", t.Name, p.buildColumn(c))
s := fmt.Sprintf("ALTER TABLE %s ADD COLUMN %s;", p.escapeIdentifier(t.Name), p.buildColumn(c))
return s, nil
}

func (p *MySQL) DropColumn(t fizz.Table) (string, error) {
if len(t.Columns) == 0 {
return "", errors.New("Not enough columns supplied!")
return "", errors.New("not enough columns supplied")
}
c := t.Columns[0]
return fmt.Sprintf("ALTER TABLE %s DROP COLUMN %s;", t.Name, c.Name), nil
return fmt.Sprintf("ALTER TABLE %s DROP COLUMN `%s`;", p.escapeIdentifier(t.Name), c.Name), nil
}

func (p *MySQL) RenameColumn(t fizz.Table) (string, error) {
if len(t.Columns) < 2 {
return "", errors.New("Not enough columns supplied!")
return "", errors.New("not enough columns supplied")
}
oc := t.Columns[0]
nc := t.Columns[1]
Expand All @@ -123,16 +126,20 @@ func (p *MySQL) RenameColumn(t fizz.Table) (string, error) {
}
col := p.buildColumn(c)
col = strings.Replace(col, oc.Name, fmt.Sprintf("%s %s", oc.Name, nc.Name), -1)
s := fmt.Sprintf("ALTER TABLE %s CHANGE %s;", t.Name, col)
s := fmt.Sprintf("ALTER TABLE %s CHANGE %s;", p.escapeIdentifier(t.Name), col)
return s, nil
}

func (p *MySQL) AddIndex(t fizz.Table) (string, error) {
if len(t.Indexes) == 0 {
return "", errors.New("Not enough indexes supplied!")
return "", errors.New("not enough indexes supplied")
}
i := t.Indexes[0]
s := fmt.Sprintf("CREATE INDEX %s ON %s (%s);", i.Name, t.Name, strings.Join(i.Columns, ", "))
cols := []string{}
for _, c := range i.Columns {
cols = append(cols, fmt.Sprintf("`%s`", c))
}
s := fmt.Sprintf("CREATE INDEX `%s` ON %s (%s);", i.Name, p.escapeIdentifier(t.Name), strings.Join(cols, ", "))
if i.Unique {
s = strings.Replace(s, "CREATE", "CREATE UNIQUE", 1)
}
Expand All @@ -141,10 +148,10 @@ func (p *MySQL) AddIndex(t fizz.Table) (string, error) {

func (p *MySQL) DropIndex(t fizz.Table) (string, error) {
if len(t.Indexes) == 0 {
return "", errors.New("Not enough indexes supplied!")
return "", errors.New("not enough indexes supplied")
}
i := t.Indexes[0]
return fmt.Sprintf("DROP INDEX %s ON %s;", i.Name, t.Name), nil
return fmt.Sprintf("DROP INDEX `%s` ON %s;", i.Name, p.escapeIdentifier(t.Name)), nil
}

func (p *MySQL) RenameIndex(t fizz.Table) (string, error) {
Expand All @@ -158,24 +165,24 @@ func (p *MySQL) RenameIndex(t fizz.Table) (string, error) {
}
ix := t.Indexes
if len(ix) < 2 {
return "", errors.New("Not enough indexes supplied!")
return "", errors.New("not enough indexes supplied")
}
oi := ix[0]
ni := ix[1]
return fmt.Sprintf("ALTER TABLE %s RENAME INDEX %s TO %s;", t.Name, oi.Name, ni.Name), nil
return fmt.Sprintf("ALTER TABLE %s RENAME INDEX `%s` TO `%s`;", p.escapeIdentifier(t.Name), oi.Name, ni.Name), nil
}

func (p *MySQL) AddForeignKey(t fizz.Table) (string, error) {
if len(t.ForeignKeys) == 0 {
return "", errors.New("Not enough foreign keys supplied!")
return "", errors.New("not enough foreign keys supplied")
}

return p.buildForeignKey(t, t.ForeignKeys[0], false), nil
}

func (p *MySQL) DropForeignKey(t fizz.Table) (string, error) {
if len(t.ForeignKeys) == 0 {
return "", errors.New("Not enough foreign keys supplied!")
return "", errors.New("not enough foreign keys supplied")
}

fk := t.ForeignKeys[0]
Expand All @@ -185,12 +192,12 @@ func (p *MySQL) DropForeignKey(t fizz.Table) (string, error) {
ifExists = "IF EXISTS"
}

s := fmt.Sprintf("ALTER TABLE %s DROP FOREIGN KEY %s %s;", t.Name, ifExists, fk.Name)
s := fmt.Sprintf("ALTER TABLE %s DROP FOREIGN KEY %s `%s`;", p.escapeIdentifier(t.Name), ifExists, fk.Name)
return s, nil
}

func (p *MySQL) buildColumn(c fizz.Column) string {
s := fmt.Sprintf("%s %s", c.Name, p.colType(c))
s := fmt.Sprintf("`%s` %s", c.Name, p.colType(c))
if c.Options["null"] == nil || c.Primary {
s = fmt.Sprintf("%s NOT NULL", s)
}
Expand Down Expand Up @@ -231,8 +238,12 @@ func (p *MySQL) colType(c fizz.Column) string {
}

func (p *MySQL) buildForeignKey(t fizz.Table, fk fizz.ForeignKey, onCreate bool) string {
refs := fmt.Sprintf("%s (%s)", fk.References.Table, strings.Join(fk.References.Columns, ", "))
s := fmt.Sprintf("FOREIGN KEY (%s) REFERENCES %s", fk.Column, refs)
rcols := []string{}
for _, c := range fk.References.Columns {
rcols = append(rcols, fmt.Sprintf("`%s`", c))
}
refs := fmt.Sprintf("%s (%s)", p.escapeIdentifier(fk.References.Table), strings.Join(rcols, ", "))
s := fmt.Sprintf("FOREIGN KEY (`%s`) REFERENCES %s", fk.Column, refs)

if onUpdate, ok := fk.Options["on_update"]; ok {
s += fmt.Sprintf(" ON UPDATE %s", onUpdate)
Expand All @@ -243,8 +254,19 @@ func (p *MySQL) buildForeignKey(t fizz.Table, fk fizz.ForeignKey, onCreate bool)
}

if !onCreate {
s = fmt.Sprintf("ALTER TABLE %s ADD CONSTRAINT %s %s;", t.Name, fk.Name, s)
s = fmt.Sprintf("ALTER TABLE %s ADD CONSTRAINT `%s` %s;", p.escapeIdentifier(t.Name), fk.Name, s)
}

return s
}

func (p *MySQL) escapeIdentifier(s string) string {
if !strings.ContainsRune(s, '.') {
return fmt.Sprintf("`%s`", s)
}
parts := strings.Split(s, ".")
for _, p := range parts {
p = fmt.Sprintf("`%s`", p)
}
return strings.Join(parts, ".")
}
Loading

0 comments on commit f57f22e

Please sign in to comment.