diff --git a/go/go.sum b/go/go.sum index 15390626451..1977a366523 100644 --- a/go/go.sum +++ b/go/go.sum @@ -12,6 +12,8 @@ github.com/AthenZ/athenz v1.10.39/go.mod h1:3Tg8HLsiQZp81BJY58JBeU2BR6B/H4/0MQGf github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= github.com/DataDog/zstd v1.5.0 h1:+K/VEwIAaPcHiMtQvpLD4lqW7f0Gk3xdYZmI1hD+CXo= github.com/DataDog/zstd v1.5.0/go.mod h1:g4AWEaM3yOg3HYfnJ3YIawPnVdXJh9QME85blwSAmyw= +github.com/alecthomas/kong v0.7.1 h1:azoTh0IOfwlAX3qN9sHWTxACE2oV8Bg2gAwBsMwDQY4= +github.com/alecthomas/kong v0.7.1/go.mod h1:n1iCIO2xS46oE8ZfYCNDqdR0b0wZNrXAIAqro/2132U= github.com/alecthomas/template v0.0.0-20160405071501-a0175ee3bccc/go.mod h1:LOuyumcjzFXgccqObfd/Ljyb9UuFJ6TxHnclSeseNhc= github.com/alecthomas/template v0.0.0-20190718012654-fb15b899a751/go.mod h1:LOuyumcjzFXgccqObfd/Ljyb9UuFJ6TxHnclSeseNhc= github.com/alecthomas/units v0.0.0-20151022065526-2efee857e7cf/go.mod h1:ybxpYRFXyAe+OPACYpWeL0wqObRcbAqCMya13uyzqw0= @@ -344,6 +346,7 @@ golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtn golang.org/x/tools v0.0.0-20200619180055-7c47624df98f/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE= golang.org/x/tools v0.0.0-20210106214847-113979e3529a/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA= golang.org/x/tools v0.10.0 h1:tvDr/iQoUqNdohiYm0LmmKcBk+q86lb9EprIUFhHHGg= +golang.org/x/tools v0.10.0/go.mod h1:UJwyiVBsOA2uwvK/e5OY3GTpDUJriEd+/YlqAwLPmyM= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= diff --git a/go/internal/metastore/db/dao/record_log.go b/go/internal/metastore/db/dao/record_log.go index afff8ee2c08..538fa992020 100644 --- a/go/internal/metastore/db/dao/record_log.go +++ b/go/internal/metastore/db/dao/record_log.go @@ -1,6 +1,7 @@ package dao import ( + "errors" "github.com/chroma/chroma-coordinator/internal/metastore/db/dbmodel" "github.com/chroma/chroma-coordinator/internal/types" "github.com/pingcap/log" @@ -22,15 +23,26 @@ func (s *recordLogDb) PushLogs(collectionID types.UniqueID, recordsContent [][]b zap.Int64("timestamp", timestamp), zap.Int("count", len(recordsContent))) + var lastLog *dbmodel.RecordLog + err := tx.Select("id").Where("collection_id = ?", collectionIDStr).Last(&lastLog).Error + if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) { + log.Error("Get last log id error", zap.Error(err)) + tx.Rollback() + return err + } + var lastLogId = lastLog.ID + log.Info("PushLogs", zap.Int64("lastLogId", lastLogId)) + var recordLogs []*dbmodel.RecordLog for index := range recordsContent { recordLogs = append(recordLogs, &dbmodel.RecordLog{ CollectionID: collectionIDStr, + ID: lastLogId + int64(index) + 1, Timestamp: timestamp, Record: &recordsContent[index], }) } - err := tx.CreateInBatches(recordLogs, len(recordLogs)).Error + err = tx.CreateInBatches(recordLogs, len(recordLogs)).Error if err != nil { log.Error("Batch insert error", zap.Error(err)) tx.Rollback() @@ -53,7 +65,11 @@ func (s *recordLogDb) PullLogs(collectionID types.UniqueID, id int64, batchSize zap.Int("batch_size", batchSize)) var recordLogs []*dbmodel.RecordLog - s.db.Where("collection_id = ? AND id >= ?", collectionIDStr, id).Order("id").Limit(batchSize).Find(&recordLogs) + result := s.db.Where("collection_id = ? AND id >= ?", collectionIDStr, id).Order("id").Limit(batchSize).Find(&recordLogs) + if result.Error != nil && !errors.Is(result.Error, gorm.ErrRecordNotFound) { + log.Error("PullLogs error", zap.Error(result.Error)) + return nil, result.Error + } log.Info("PullLogs", zap.String("collectionID", *collectionIDStr), zap.Int64("ID", id), diff --git a/go/internal/metastore/db/dao/record_log_test.go b/go/internal/metastore/db/dao/record_log_test.go index 091ccb8e2cd..3c536aafa92 100644 --- a/go/internal/metastore/db/dao/record_log_test.go +++ b/go/internal/metastore/db/dao/record_log_test.go @@ -7,17 +7,19 @@ import ( "github.com/pingcap/log" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/suite" + "go.uber.org/zap" "gorm.io/gorm" "testing" ) type RecordLogDbTestSuite struct { suite.Suite - db *gorm.DB - Db *recordLogDb - t *testing.T - collectionId types.UniqueID - records [][]byte + db *gorm.DB + Db *recordLogDb + t *testing.T + collectionId1 types.UniqueID + collectionId2 types.UniqueID + records [][]byte } func (suite *RecordLogDbTestSuite) SetupSuite() { @@ -26,7 +28,8 @@ func (suite *RecordLogDbTestSuite) SetupSuite() { suite.Db = &recordLogDb{ db: suite.db, } - suite.collectionId = types.NewUniqueID() + suite.collectionId1 = types.NewUniqueID() + suite.collectionId2 = types.NewUniqueID() suite.records = make([][]byte, 0, 5) suite.records = append(suite.records, []byte("test1"), []byte("test2"), []byte("test3"), []byte("test4"), []byte("test5")) @@ -34,12 +37,50 @@ func (suite *RecordLogDbTestSuite) SetupSuite() { func (suite *RecordLogDbTestSuite) SetupTest() { log.Info("setup test") + suite.db.Migrator().DropTable(&dbmodel.Segment{}) + suite.db.Migrator().CreateTable(&dbmodel.Segment{}) + suite.db.Migrator().DropTable(&dbmodel.Collection{}) + suite.db.Migrator().CreateTable(&dbmodel.Collection{}) suite.db.Migrator().DropTable(&dbmodel.RecordLog{}) suite.db.Migrator().CreateTable(&dbmodel.RecordLog{}) + + // create test collection + collectionName := "collection1" + collectionTopic := "topic1" + var collectionDimension int32 = 6 + collection := &dbmodel.Collection{ + ID: suite.collectionId1.String(), + Name: &collectionName, + Topic: &collectionTopic, + Dimension: &collectionDimension, + DatabaseID: types.NewUniqueID().String(), + } + err := suite.db.Create(collection).Error + if err != nil { + log.Error("create collection error", zap.Error(err)) + } + + collectionName = "collection2" + collectionTopic = "topic2" + collection = &dbmodel.Collection{ + ID: suite.collectionId2.String(), + Name: &collectionName, + Topic: &collectionTopic, + Dimension: &collectionDimension, + DatabaseID: types.NewUniqueID().String(), + } + err = suite.db.Create(collection).Error + if err != nil { + log.Error("create collection error", zap.Error(err)) + } } func (suite *RecordLogDbTestSuite) TearDownTest() { log.Info("teardown test") + suite.db.Migrator().DropTable(&dbmodel.Segment{}) + suite.db.Migrator().CreateTable(&dbmodel.Segment{}) + suite.db.Migrator().DropTable(&dbmodel.Collection{}) + suite.db.Migrator().CreateTable(&dbmodel.Collection{}) suite.db.Migrator().DropTable(&dbmodel.RecordLog{}) suite.db.Migrator().CreateTable(&dbmodel.RecordLog{}) } @@ -47,15 +88,14 @@ func (suite *RecordLogDbTestSuite) TearDownTest() { func (suite *RecordLogDbTestSuite) TestRecordLogDb_PushLogs() { // run push logs in transaction // id: 0, - // offset: 0, 1, 2 // records: test1, test2, test3 - count, err := suite.Db.PushLogs(suite.collectionId, suite.records[:3]) + count, err := suite.Db.PushLogs(suite.collectionId1, suite.records[:3]) assert.NoError(suite.t, err) assert.Equal(suite.t, 3, count) // verify logs are pushed var recordLogs []*dbmodel.RecordLog - suite.db.Where("collection_id = ?", types.FromUniqueID(suite.collectionId)).Find(&recordLogs) + suite.db.Where("collection_id = ?", types.FromUniqueID(suite.collectionId1)).Find(&recordLogs) assert.Len(suite.t, recordLogs, 3) for index := range recordLogs { assert.Equal(suite.t, int64(index+1), recordLogs[index].ID) @@ -64,14 +104,28 @@ func (suite *RecordLogDbTestSuite) TestRecordLogDb_PushLogs() { // run push logs in transaction // id: 1, - // offset: 0, 1 // records: test4, test5 - count, err = suite.Db.PushLogs(suite.collectionId, suite.records[3:]) + count, err = suite.Db.PushLogs(suite.collectionId1, suite.records[3:]) assert.NoError(suite.t, err) assert.Equal(suite.t, 2, count) // verify logs are pushed - suite.db.Where("collection_id = ?", types.FromUniqueID(suite.collectionId)).Find(&recordLogs) + suite.db.Where("collection_id = ?", types.FromUniqueID(suite.collectionId1)).Find(&recordLogs) + assert.Len(suite.t, recordLogs, 5) + for index := range recordLogs { + assert.Equal(suite.t, int64(index+1), recordLogs[index].ID, "id mismatch for index %d", index) + assert.Equal(suite.t, suite.records[index], *recordLogs[index].Record, "record mismatch for index %d", index) + } + + // run push logs in transaction + // id: 0, + // records: test1, test2, test3, test4, test5 + count, err = suite.Db.PushLogs(suite.collectionId2, suite.records) + assert.NoError(suite.t, err) + assert.Equal(suite.t, 5, count) + + // verify logs are pushed + suite.db.Where("collection_id = ?", types.FromUniqueID(suite.collectionId2)).Find(&recordLogs) assert.Len(suite.t, recordLogs, 5) for index := range recordLogs { assert.Equal(suite.t, int64(index+1), recordLogs[index].ID, "id mismatch for index %d", index) @@ -80,17 +134,22 @@ func (suite *RecordLogDbTestSuite) TestRecordLogDb_PushLogs() { } func (suite *RecordLogDbTestSuite) TestRecordLogDb_PullLogsFromID() { + // pull empty logs + var recordLogs []*dbmodel.RecordLog + recordLogs, err := suite.Db.PullLogs(suite.collectionId1, 0, 3) + assert.NoError(suite.t, err) + assert.Len(suite.t, recordLogs, 0) + // push some logs - count, err := suite.Db.PushLogs(suite.collectionId, suite.records[:3]) + count, err := suite.Db.PushLogs(suite.collectionId1, suite.records[:3]) assert.NoError(suite.t, err) assert.Equal(suite.t, 3, count) - count, err = suite.Db.PushLogs(suite.collectionId, suite.records[3:]) + count, err = suite.Db.PushLogs(suite.collectionId1, suite.records[3:]) assert.NoError(suite.t, err) assert.Equal(suite.t, 2, count) // pull logs from id 0 batch_size 3 - var recordLogs []*dbmodel.RecordLog - recordLogs, err = suite.Db.PullLogs(suite.collectionId, 0, 3) + recordLogs, err = suite.Db.PullLogs(suite.collectionId1, 0, 3) assert.NoError(suite.t, err) assert.Len(suite.t, recordLogs, 3) for index := range recordLogs { @@ -99,7 +158,7 @@ func (suite *RecordLogDbTestSuite) TestRecordLogDb_PullLogsFromID() { } // pull logs from id 0 batch_size 6 - recordLogs, err = suite.Db.PullLogs(suite.collectionId, 0, 6) + recordLogs, err = suite.Db.PullLogs(suite.collectionId1, 0, 6) assert.NoError(suite.t, err) assert.Len(suite.t, recordLogs, 5) @@ -109,7 +168,7 @@ func (suite *RecordLogDbTestSuite) TestRecordLogDb_PullLogsFromID() { } // pull logs from id 3 batch_size 4 - recordLogs, err = suite.Db.PullLogs(suite.collectionId, 3, 4) + recordLogs, err = suite.Db.PullLogs(suite.collectionId1, 3, 4) assert.NoError(suite.t, err) assert.Len(suite.t, recordLogs, 3) for index := range recordLogs { diff --git a/go/internal/metastore/db/dbmodel/record_log.go b/go/internal/metastore/db/dbmodel/record_log.go index 15cbfb1f8d8..17537af0083 100644 --- a/go/internal/metastore/db/dbmodel/record_log.go +++ b/go/internal/metastore/db/dbmodel/record_log.go @@ -4,7 +4,7 @@ import "github.com/chroma/chroma-coordinator/internal/types" type RecordLog struct { CollectionID *string `gorm:"collection_id;primaryKey;autoIncrement:false"` - ID int64 `gorm:"id;primaryKey;"` // auto_increment id + ID int64 `gorm:"id;primaryKey;autoIncrement:false"` Timestamp int64 `gorm:"timestamp;"` Record *[]byte `gorm:"record;type:bytea"` } diff --git a/go/migrations/20240216211350.sql b/go/migrations/20240226214452.sql similarity index 99% rename from go/migrations/20240216211350.sql rename to go/migrations/20240226214452.sql index 2d4b286c681..ae9d6c04920 100644 --- a/go/migrations/20240216211350.sql +++ b/go/migrations/20240226214452.sql @@ -49,7 +49,7 @@ CREATE TABLE "public"."notifications" ( -- Create "record_logs" table CREATE TABLE "public"."record_logs" ( "collection_id" text NOT NULL, - "id" bigserial NOT NULL, + "id" bigint NOT NULL, "timestamp" bigint NULL, "record" bytea NULL, PRIMARY KEY ("collection_id", "id") diff --git a/go/migrations/atlas.sum b/go/migrations/atlas.sum index 6d1a0e5baaa..b56b6992da3 100644 --- a/go/migrations/atlas.sum +++ b/go/migrations/atlas.sum @@ -1,2 +1,2 @@ -h1:0AmSHt0xnRVJjHv8/LoOph5FzyVC5io1/O1lOY/Ihdo= -20240216211350.sql h1:yoz9m9lOVG1g7JPG0sWW+PXOb5sNg1W7Y5kLqhibGqg= +h1:do3nf7bNLB1RKM9w0yUfQjQ1W9Wn0qDnZXrlod4o8fo= +20240226214452.sql h1:KL5un7kPJrACxerAeDZR4rY9cylUI2huxoby6SMtfso=