From b9c4fa15f0f4bb6889df949dca03a714f1983d8b Mon Sep 17 00:00:00 2001 From: Asdine El Hrychy Date: Sun, 21 Jun 2020 20:41:19 +0400 Subject: [PATCH] Return a parser error for wrong number of values --- sql/parser/insert.go | 15 ++++++++++++++- sql/parser/insert_test.go | 6 ++++-- sql/query/insert.go | 18 ------------------ 3 files changed, 18 insertions(+), 21 deletions(-) diff --git a/sql/parser/insert.go b/sql/parser/insert.go index e3ad6e0c2..4a70b3d91 100644 --- a/sql/parser/insert.go +++ b/sql/parser/insert.go @@ -1,6 +1,8 @@ package parser import ( + "fmt" + "github.com/genjidb/genji/sql/query" "github.com/genjidb/genji/sql/query/expr" "github.com/genjidb/genji/sql/scanner" @@ -36,11 +38,22 @@ func (p *Parser) parseInsertStatement() (query.InsertStmt, error) { } // Parse VALUES (v1, v2, v3) - stmt.Values, err = p.parseValues(valueParser) + values, err := p.parseValues(valueParser) if err != nil { return stmt, err } + // ensure the length of field list is the same as the length of values + if withFields { + for _, l := range values { + el := l.(expr.LiteralExprList) + if len(el) != len(stmt.FieldNames) { + return stmt, fmt.Errorf("%d values for %d fields", len(el), len(stmt.FieldNames)) + } + } + } + + stmt.Values = values return stmt, nil } diff --git a/sql/parser/insert_test.go b/sql/parser/insert_test.go index be315be74..30f9a3c08 100644 --- a/sql/parser/insert_test.go +++ b/sql/parser/insert_test.go @@ -52,14 +52,16 @@ func TestParserInsert(t *testing.T) { Values: expr.LiteralExprList{expr.NamedParam("foo"), expr.NamedParam("bar")}, }, false}, - {"Values / With fields", "INSERT INTO test (a, b) VALUES ('c', 'd', 'e')", + {"Values / With fields", "INSERT INTO test (a, b) VALUES ('c', 'd')", query.InsertStmt{ TableName: "test", FieldNames: []string{"a", "b"}, Values: expr.LiteralExprList{ - expr.LiteralExprList{expr.TextValue("c"), expr.TextValue("d"), expr.TextValue("e")}, + expr.LiteralExprList{expr.TextValue("c"), expr.TextValue("d")}, }, }, false}, + {"Values / With too many values", "INSERT INTO test (a, b) VALUES ('c', 'd', 'e')", + nil, true}, {"Values / Multiple", "INSERT INTO test (a, b) VALUES ('c', 'd'), ('e', 'f')", query.InsertStmt{ TableName: "test", diff --git a/sql/query/insert.go b/sql/query/insert.go index dd8214249..7da0bef9a 100644 --- a/sql/query/insert.go +++ b/sql/query/insert.go @@ -2,7 +2,6 @@ package query import ( "errors" - "fmt" "github.com/genjidb/genji/database" "github.com/genjidb/genji/document" @@ -60,10 +59,6 @@ func (stmt InsertStmt) insertDocuments(t *database.Table, stack expr.EvalStack) return res, err } - if v.Type != document.DocumentValue { - return res, errors.New("values must be a list of documents if field list is empty") - } - d, err := v.ConvertToDocument() if err != nil { return res, err @@ -94,24 +89,11 @@ func (stmt InsertStmt) insertExprList(t *database.Table, stack expr.EvalStack) ( // each document must be a list of expressions // (e1, e2, e3, ...) or [e1, e2, e2, ....] - if v.Type != document.ArrayValue { - return res, errors.New("invalid values") - } - vlist, err := v.ConvertToArray() if err != nil { return res, err } - lenv, err := document.ArrayLength(vlist) - if err != nil { - return res, err - } - - if len(stmt.FieldNames) != lenv { - return res, fmt.Errorf("%d values for %d fields", lenv, len(stmt.FieldNames)) - } - // iterate over each value vlist.Iterate(func(i int, v document.Value) error { // get the field name