diff --git a/db.go b/db.go index 8bc06d186..e3e5a55a4 100644 --- a/db.go +++ b/db.go @@ -239,12 +239,7 @@ func (tx Tx) CreateTable(name string) (*Table, error) { return nil, errors.Wrapf(err, "failed to create table %q", name) } - s, err := tx.tx.Store(name) - return &Table{ - tx: &tx, - store: s, - name: name, - }, nil + return tx.GetTable(name) } // GetTable returns a table by name. The table instance is only valid for the lifetime of the transaction. @@ -257,11 +252,18 @@ func (tx Tx) GetTable(name string) (*Table, error) { return nil, err } - return &Table{ + t := Table{ tx: &tx, store: s, name: name, - }, nil + } + + t.indexes, err = t.Indexes() + if err != nil { + return nil, err + } + + return &t, nil } // DropTable deletes a table from the database. @@ -390,9 +392,10 @@ func (tx Tx) DropIndex(name string) error { // A Table represents a collection of records. type Table struct { - tx *Tx - store engine.Store - name string + tx *Tx + store engine.Store + name string + indexes map[string]Index } type encodedRecordWithKey struct { @@ -481,12 +484,7 @@ func (t Table) Insert(r record.Record) ([]byte, error) { return nil, err } - indexes, err := t.Indexes() - if err != nil { - return nil, err - } - - for _, idx := range indexes { + for _, idx := range t.indexes { f, err := r.GetField(idx.FieldName) if err != nil { continue @@ -508,27 +506,24 @@ func (t Table) Insert(r record.Record) ([]byte, error) { // Delete a record by key. // Indexes are automatically updated. func (t Table) Delete(key []byte) error { - err := t.store.Delete(key) + r, err := t.GetRecord(key) if err != nil { - if err == engine.ErrKeyNotFound { - return ErrRecordNotFound - } return err } - indexes, err := t.Indexes() - if err != nil { - return err - } + for _, idx := range t.indexes { + f, err := r.GetField(idx.FieldName) + if err != nil { + return err + } - for _, idx := range indexes { - err = idx.Delete(key) + err = idx.Delete(f.Data, key) if err != nil { return err } } - return nil + return t.store.Delete(key) } type pkWrapper struct { @@ -545,23 +540,19 @@ func (p pkWrapper) PrimaryKey() ([]byte, error) { // Indexes are automatically updated. func (t Table) Replace(key []byte, r record.Record) error { // make sure key exists - _, err := t.store.Get(key) + old, err := t.GetRecord(key) if err != nil { - if err == engine.ErrKeyNotFound { - return ErrRecordNotFound - } - return err } // remove key from indexes - indexes, err := t.Indexes() - if err != nil { - return err - } + for _, idx := range t.indexes { + f, err := old.GetField(idx.FieldName) + if err != nil { + return err + } - for _, idx := range indexes { - err = idx.Delete(key) + err = idx.Delete(f.Data, key) if err != nil { return err } @@ -580,7 +571,7 @@ func (t Table) Replace(key []byte, r record.Record) error { } // update indexes - for _, idx := range indexes { + for _, idx := range t.indexes { f, err := r.GetField(idx.FieldName) if err != nil { continue @@ -607,11 +598,17 @@ func (t Table) TableName() string { // Indexes returns a map of all the indexes of a table. func (t Table) Indexes() (map[string]Index, error) { - tb, err := t.tx.GetTable(indexTable) + s, err := t.tx.tx.Store(indexTable) if err != nil { return nil, err } + tb := Table{ + tx: t.tx, + store: s, + name: indexTable, + } + tableName := []byte(t.name) indexes := make(map[string]Index) @@ -653,6 +650,7 @@ func (t Table) Indexes() (map[string]Index, error) { return nil, err } + t.indexes = indexes return indexes, nil } diff --git a/db_test.go b/db_test.go index f60fa0203..c7e3b4218 100644 --- a/db_test.go +++ b/db_test.go @@ -402,12 +402,15 @@ func TestTableInsert(t *testing.T) { tx, cleanup := newTestDB(t) defer cleanup() - tb, err := tx.CreateTable("test") + _, err := tx.CreateTable("test") require.NoError(t, err) idx, err := tx.CreateIndex("idxFoo", "test", "foo", index.Options{}) require.NoError(t, err) + tb, err := tx.GetTable("test") + require.NoError(t, err) + rec := newRecord() foo := record.NewFloat32Field("foo", 10) rec = append(rec, foo) diff --git a/delete.go b/delete.go index a98e79c9e..55718bdab 100644 --- a/delete.go +++ b/delete.go @@ -45,15 +45,15 @@ func (stmt deleteStmt) IsReadOnly() bool { return false } -const bufferSize = 100 +const deleteBufferSize = 100 -// Run deletes matching records by batches of bufferSize records. +// Run deletes matching records by batches of deleteBufferSize records. // Some engines can't iterate while deleting keys (https://github.com/etcd-io/bbolt/issues/146) // and some can't create more than one iterator per read-write transaction (https://github.com/dgraph-io/badger/issues/1093). // To deal with these limitations, Run will iterate on a limited number of records, copy the keys // to a buffer and delete them after the iteration is complete, and it will do that until there is no record // left to delete. -// Increasing bufferSize will occasionate less key searches (O(log n) for most engines) but will take more memory. +// Increasing deleteBufferSize will occasionate less key searches (O(log n) for most engines) but will take more memory. func (stmt deleteStmt) Run(tx *Tx, args []driver.NamedValue) (Result, error) { var res Result if stmt.tableName == "" { @@ -68,9 +68,9 @@ func (stmt deleteStmt) Run(tx *Tx, args []driver.NamedValue) (Result, error) { } st := record.NewStream(t) - st = st.Filter(whereClause(stmt.whereExpr, stack)).Limit(bufferSize) + st = st.Filter(whereClause(stmt.whereExpr, stack)).Limit(deleteBufferSize) - keys := make([][]byte, bufferSize) + keys := make([][]byte, deleteBufferSize) for { var i int @@ -98,7 +98,7 @@ func (stmt deleteStmt) Run(tx *Tx, args []driver.NamedValue) (Result, error) { } } - if i < bufferSize { + if i < deleteBufferSize { break } } diff --git a/engine/enginetest/testing.go b/engine/enginetest/testing.go index f46551283..093d2d30a 100644 --- a/engine/enginetest/testing.go +++ b/engine/enginetest/testing.go @@ -3,12 +3,14 @@ package enginetest import ( + "bytes" "errors" "fmt" "testing" "github.com/asdine/genji" "github.com/asdine/genji/engine" + "github.com/asdine/genji/record/recordutil" "github.com/stretchr/testify/require" ) @@ -878,31 +880,31 @@ func TestQueries(t *testing.T, builder Builder) { require.NoError(t, err) }) - // t.Run("UPDATE", func(t *testing.T) { - // ng, cleanup := builder() - // defer cleanup() - - // db, err := genji.New(ng) - // require.NoError(t, err) - // defer db.Close() - - // st, err := db.Query(` - // CREATE TABLE test; - // INSERT INTO test (a) VALUES (1), (2), (3), (4); - // UPDATE test SET a = 5; - // SELECT * FROM test; - // `) - // require.NoError(t, err) - // defer st.Close() - // var buf bytes.Buffer - // err = recordutil.IteratorToJSON(&buf, st) - // require.NoError(t, err) - // require.Equal(t, `{"a":5} - // {"a":5} - // {"a":5} - // {"a":5} - // `, buf.String()) - // }) + t.Run("UPDATE", func(t *testing.T) { + ng, cleanup := builder() + defer cleanup() + + db, err := genji.New(ng) + require.NoError(t, err) + defer db.Close() + + st, err := db.Query(` + CREATE TABLE test; + INSERT INTO test (a) VALUES (1), (2), (3), (4); + UPDATE test SET a = 5; + SELECT * FROM test; + `) + require.NoError(t, err) + defer st.Close() + var buf bytes.Buffer + err = recordutil.IteratorToJSON(&buf, st) + require.NoError(t, err) + require.Equal(t, `{"a":5} +{"a":5} +{"a":5} +{"a":5} +`, buf.String()) + }) t.Run("DELETE", func(t *testing.T) { ng, cleanup := builder() @@ -981,35 +983,35 @@ func TestQueriesSameTransaction(t *testing.T, builder Builder) { require.NoError(t, err) }) - // t.Run("UPDATE", func(t *testing.T) { - // ng, cleanup := builder() - // defer cleanup() - - // db, err := genji.New(ng) - // require.NoError(t, err) - // defer db.Close() - - // err = db.Update(func(tx *genji.Tx) error { - // st, err := tx.Query(` - // CREATE TABLE test; - // INSERT INTO test (a) VALUES (1), (2), (3), (4); - // UPDATE test SET a = 5; - // SELECT * FROM test; - // `) - // require.NoError(t, err) - // defer st.Close() - // var buf bytes.Buffer - // err = recordutil.IteratorToJSON(&buf, st) - // require.NoError(t, err) - // require.Equal(t, `{"a":5} - // {"a":5} - // {"a":5} - // {"a":5} - // `, buf.String()) - // return nil - // }) - // require.NoError(t, err) - // }) + t.Run("UPDATE", func(t *testing.T) { + ng, cleanup := builder() + defer cleanup() + + db, err := genji.New(ng) + require.NoError(t, err) + defer db.Close() + + err = db.Update(func(tx *genji.Tx) error { + st, err := tx.Query(` + CREATE TABLE test; + INSERT INTO test (a) VALUES (1), (2), (3), (4); + UPDATE test SET a = 5; + SELECT * FROM test; + `) + require.NoError(t, err) + defer st.Close() + var buf bytes.Buffer + err = recordutil.IteratorToJSON(&buf, st) + require.NoError(t, err) + require.Equal(t, `{"a":5} +{"a":5} +{"a":5} +{"a":5} +`, buf.String()) + return nil + }) + require.NoError(t, err) + }) t.Run("DELETE", func(t *testing.T) { ng, cleanup := builder() diff --git a/index/index.go b/index/index.go index e7ed1dce4..03d843e20 100644 --- a/index/index.go +++ b/index/index.go @@ -23,7 +23,7 @@ type Index interface { Set(value []byte, key []byte) error // Delete all the references to the key from the index. - Delete(key []byte) error + Delete(value []byte, key []byte) error // AscendGreaterOrEqual seeks for the pivot and then goes through all the subsequent key value pairs in increasing order and calls the given function for each pair. // If the given function returns an error, the iteration stops and returns that error. @@ -62,7 +62,7 @@ type listIndex struct { // Set associates a value with a key. It is possible to associate multiple keys for the same value // but a key can be associated to only one value. -func (i *listIndex) Set(value []byte, key []byte) error { +func (i *listIndex) Set(value, key []byte) error { if len(value) == 0 { return errors.New("value cannot be nil") } @@ -75,30 +75,13 @@ func (i *listIndex) Set(value []byte, key []byte) error { return i.store.Put(buf, nil) } -func (i *listIndex) Delete(key []byte) error { - suffix := make([]byte, len(key)+1) - suffix[0] = separator - copy(suffix[1:], key) - - errStop := errors.New("stop") - - err := i.store.AscendGreaterOrEqual(nil, func(k []byte, v []byte) error { - if bytes.HasSuffix(k, suffix) { - err := i.store.Delete(k) - if err != nil { - return err - } - return errStop - } - - return nil - }) - - if err != errStop { - return err - } +func (i *listIndex) Delete(value, key []byte) error { + buf := make([]byte, 0, len(value)+len(key)+1) + buf = append(buf, value...) + buf = append(buf, separator) + buf = append(buf, key...) - return nil + return i.store.Delete(buf) } func (i *listIndex) AscendGreaterOrEqual(pivot []byte, fn func(value []byte, key []byte) error) error { @@ -142,29 +125,8 @@ func (i *uniqueIndex) Set(value []byte, key []byte) error { return i.store.Put(value, key) } -func (i *uniqueIndex) Delete(key []byte) error { - var toDelete [][]byte - - err := i.store.AscendGreaterOrEqual(nil, func(value []byte, rID []byte) error { - if bytes.Equal(key, rID) { - toDelete = append(toDelete, value) - } - - return nil - }) - - if err != nil { - return err - } - - for _, v := range toDelete { - err := i.store.Delete(v) - if err != nil { - return err - } - } - - return nil +func (i *uniqueIndex) Delete(value, key []byte) error { + return i.store.Delete(value) } func (i *uniqueIndex) AscendGreaterOrEqual(pivot []byte, fn func(value []byte, key []byte) error) error { diff --git a/index/index_test.go b/index/index_test.go index 3e444e3fa..8ea94c5f6 100644 --- a/index/index_test.go +++ b/index/index_test.go @@ -69,7 +69,7 @@ func TestIndexDelete(t *testing.T) { require.NoError(t, idx.Set([]byte("value1"), []byte("key"))) require.NoError(t, idx.Set([]byte("value1"), []byte("other-key"))) require.NoError(t, idx.Set([]byte("value2"), []byte("yet-another-key"))) - require.NoError(t, idx.Delete([]byte("key"))) + require.NoError(t, idx.Delete([]byte("value1"), []byte("key"))) i := 0 err := idx.AscendGreaterOrEqual([]byte("value1"), func(v, key []byte) error { @@ -95,16 +95,20 @@ func TestIndexDelete(t *testing.T) { defer cleanup() require.NoError(t, idx.Set([]byte("value1"), []byte("key1"))) - require.NoError(t, idx.Set([]byte("value2"), []byte("key1"))) - require.NoError(t, idx.Set([]byte("value3"), []byte("key2"))) - require.NoError(t, idx.Delete([]byte("key1"))) + require.NoError(t, idx.Set([]byte("value2"), []byte("key2"))) + require.NoError(t, idx.Set([]byte("value3"), []byte("key3"))) + require.NoError(t, idx.Delete([]byte("value2"), []byte("key2"))) i := 0 err := idx.AscendGreaterOrEqual(nil, func(v, key []byte) error { - if i == 0 { + switch i { + case 0: + require.Equal(t, "value1", string(v)) + require.Equal(t, "key1", string(key)) + case 1: require.Equal(t, "value3", string(v)) - require.Equal(t, "key2", string(key)) - } else { + require.Equal(t, "key3", string(key)) + default: return errors.New("should not reach this point") } @@ -112,17 +116,17 @@ func TestIndexDelete(t *testing.T) { return nil }) require.NoError(t, err) - require.Equal(t, 1, i) + require.Equal(t, 2, i) }) for _, unique := range []bool{true, false} { text := fmt.Sprintf("Unique: %v, ", unique) - t.Run(text+"Delete non existing key succeeds", func(t *testing.T) { + t.Run(text+"Delete non existing key fails", func(t *testing.T) { idx, cleanup := getIndex(t, index.Options{Unique: unique}) defer cleanup() - require.NoError(t, idx.Delete([]byte("foo"))) + require.Error(t, idx.Delete([]byte("foo"), []byte("foo"))) }) } }