Skip to content

Commit

Permalink
Return a parser error for wrong number of values
Browse files Browse the repository at this point in the history
  • Loading branch information
asdine committed Jun 28, 2020
1 parent 00dcef1 commit b9c4fa1
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 21 deletions.
15 changes: 14 additions & 1 deletion sql/parser/insert.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package parser

import (
"fmt"

"github.com/genjidb/genji/sql/query"
"github.com/genjidb/genji/sql/query/expr"
"github.com/genjidb/genji/sql/scanner"
Expand Down Expand Up @@ -36,11 +38,22 @@ func (p *Parser) parseInsertStatement() (query.InsertStmt, error) {
}

// Parse VALUES (v1, v2, v3)
stmt.Values, err = p.parseValues(valueParser)
values, err := p.parseValues(valueParser)
if err != nil {
return stmt, err
}

// ensure the length of field list is the same as the length of values
if withFields {
for _, l := range values {
el := l.(expr.LiteralExprList)
if len(el) != len(stmt.FieldNames) {
return stmt, fmt.Errorf("%d values for %d fields", len(el), len(stmt.FieldNames))
}
}
}

stmt.Values = values
return stmt, nil
}

Expand Down
6 changes: 4 additions & 2 deletions sql/parser/insert_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,14 +52,16 @@ func TestParserInsert(t *testing.T) {
Values: expr.LiteralExprList{expr.NamedParam("foo"), expr.NamedParam("bar")},
},
false},
{"Values / With fields", "INSERT INTO test (a, b) VALUES ('c', 'd', 'e')",
{"Values / With fields", "INSERT INTO test (a, b) VALUES ('c', 'd')",
query.InsertStmt{
TableName: "test",
FieldNames: []string{"a", "b"},
Values: expr.LiteralExprList{
expr.LiteralExprList{expr.TextValue("c"), expr.TextValue("d"), expr.TextValue("e")},
expr.LiteralExprList{expr.TextValue("c"), expr.TextValue("d")},
},
}, false},
{"Values / With too many values", "INSERT INTO test (a, b) VALUES ('c', 'd', 'e')",
nil, true},
{"Values / Multiple", "INSERT INTO test (a, b) VALUES ('c', 'd'), ('e', 'f')",
query.InsertStmt{
TableName: "test",
Expand Down
18 changes: 0 additions & 18 deletions sql/query/insert.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package query

import (
"errors"
"fmt"

"github.com/genjidb/genji/database"
"github.com/genjidb/genji/document"
Expand Down Expand Up @@ -60,10 +59,6 @@ func (stmt InsertStmt) insertDocuments(t *database.Table, stack expr.EvalStack)
return res, err
}

if v.Type != document.DocumentValue {
return res, errors.New("values must be a list of documents if field list is empty")
}

d, err := v.ConvertToDocument()
if err != nil {
return res, err
Expand Down Expand Up @@ -94,24 +89,11 @@ func (stmt InsertStmt) insertExprList(t *database.Table, stack expr.EvalStack) (

// each document must be a list of expressions
// (e1, e2, e3, ...) or [e1, e2, e2, ....]
if v.Type != document.ArrayValue {
return res, errors.New("invalid values")
}

vlist, err := v.ConvertToArray()
if err != nil {
return res, err
}

lenv, err := document.ArrayLength(vlist)
if err != nil {
return res, err
}

if len(stmt.FieldNames) != lenv {
return res, fmt.Errorf("%d values for %d fields", lenv, len(stmt.FieldNames))
}

// iterate over each value
vlist.Iterate(func(i int, v document.Value) error {
// get the field name
Expand Down

0 comments on commit b9c4fa1

Please sign in to comment.