From c5c590de14460a725d939fd7bf11c6ebe45a4412 Mon Sep 17 00:00:00 2001 From: Robin Tang Date: Wed, 3 Jul 2024 10:27:56 -1000 Subject: [PATCH] Checkpoint. --- sources/mongo/message.go | 13 ++++--- sources/mongo/message_test.go | 4 +-- sources/mongo/mongo.go | 3 +- sources/mongo/{collection.go => snapshot.go} | 36 ++++++++++---------- writers/transfer/writer_test.go | 2 +- 5 files changed, 30 insertions(+), 28 deletions(-) rename sources/mongo/{collection.go => snapshot.go} (59%) diff --git a/sources/mongo/message.go b/sources/mongo/message.go index 33d1cef1..6aba3e17 100644 --- a/sources/mongo/message.go +++ b/sources/mongo/message.go @@ -12,12 +12,13 @@ import ( "go.mongodb.org/mongo-driver/bson" ) -type mgoMessage struct { +type Message struct { jsonExtendedString string + operation string pkMap map[string]any } -func (m *mgoMessage) ToRawMessage(collection config.Collection, database string) (lib.RawMessage, error) { +func (m *Message) ToRawMessage(collection config.Collection, database string) (lib.RawMessage, error) { evt := &mongo.SchemaEventPayload{ Schema: debezium.Schema{}, Payload: mongo.Payload{ @@ -27,7 +28,7 @@ func (m *mgoMessage) ToRawMessage(collection config.Collection, database string) Collection: collection.Name, TsMs: time.Now().UnixMilli(), }, - Operation: "r", + Operation: m.operation, }, } @@ -38,7 +39,7 @@ func (m *mgoMessage) ToRawMessage(collection config.Collection, database string) return lib.NewRawMessage(collection.TopicSuffix(database), pkMap, evt), nil } -func ParseMessage(result bson.M) (*mgoMessage, error) { +func ParseMessage(result bson.M, op string) (*Message, error) { jsonExtendedBytes, err := bson.MarshalExtJSON(result, false, false) if err != nil { return nil, fmt.Errorf("failed to marshal document to JSON extended: %w", err) @@ -58,8 +59,10 @@ func ParseMessage(result bson.M) (*mgoMessage, error) { if err != nil { return nil, fmt.Errorf("failed to marshal ext json: %w", err) } - return &mgoMessage{ + + return &Message{ jsonExtendedString: string(jsonExtendedBytes), + operation: op, pkMap: map[string]any{ "id": string(pkBytes), }, diff --git a/sources/mongo/message_test.go b/sources/mongo/message_test.go index 25ebb377..f1e6696f 100644 --- a/sources/mongo/message_test.go +++ b/sources/mongo/message_test.go @@ -18,7 +18,7 @@ import ( func TestParseMessagePartitionKey(t *testing.T) { objId, err := primitive.ObjectIDFromHex("507f1f77bcf86cd799439011") assert.NoError(t, err) - msg, err := ParseMessage(bson.M{"_id": objId}) + msg, err := ParseMessage(bson.M{"_id": objId}, "r") assert.NoError(t, err) assert.Equal(t, `{"$oid":"507f1f77bcf86cd799439011"}`, msg.pkMap["id"]) @@ -61,7 +61,7 @@ func TestParseMessage(t *testing.T) { "trueValue": true, "falseValue": false, "nullValue": nil, - }) + }, "r") assert.NoError(t, err) rawMsg, err := msg.ToRawMessage(config.Collection{}, "database") diff --git a/sources/mongo/mongo.go b/sources/mongo/mongo.go index 90a9e97e..23b5948c 100644 --- a/sources/mongo/mongo.go +++ b/sources/mongo/mongo.go @@ -59,8 +59,7 @@ func (s *Source) Run(ctx context.Context, writer writers.Writer) error { slog.Any("batchSize", collection.GetBatchSize()), ) - iterator := newIterator(s.db, collection, s.cfg) - count, err := writer.Write(ctx, iterator) + count, err := writer.Write(ctx, newSnapshotIterator(s.db, collection, s.cfg)) if err != nil { return fmt.Errorf("failed to snapshot collection %q: %w", collection.Name, err) } diff --git a/sources/mongo/collection.go b/sources/mongo/snapshot.go similarity index 59% rename from sources/mongo/collection.go rename to sources/mongo/snapshot.go index 15eb96e3..211754c9 100644 --- a/sources/mongo/collection.go +++ b/sources/mongo/snapshot.go @@ -11,7 +11,7 @@ import ( "go.mongodb.org/mongo-driver/mongo/options" ) -type collectionScanner struct { +type snapshotIterator struct { db *mongo.Database cfg config.MongoDB collection config.Collection @@ -21,48 +21,48 @@ type collectionScanner struct { done bool } -func newIterator(db *mongo.Database, collection config.Collection, cfg config.MongoDB) *collectionScanner { - return &collectionScanner{ +func newSnapshotIterator(db *mongo.Database, collection config.Collection, cfg config.MongoDB) *snapshotIterator { + return &snapshotIterator{ db: db, cfg: cfg, collection: collection, } } -func (c *collectionScanner) HasNext() bool { - return !c.done +func (s *snapshotIterator) HasNext() bool { + return !s.done } -func (c *collectionScanner) Next() ([]lib.RawMessage, error) { - if !c.HasNext() { +func (s *snapshotIterator) Next() ([]lib.RawMessage, error) { + if !s.HasNext() { return nil, fmt.Errorf("no more rows to scan") } ctx := context.Background() - if c.cursor == nil { + if s.cursor == nil { findOptions := options.Find() - findOptions.SetBatchSize(c.collection.GetBatchSize()) - cursor, err := c.db.Collection(c.collection.Name).Find(ctx, bson.D{}, findOptions) + findOptions.SetBatchSize(s.collection.GetBatchSize()) + cursor, err := s.db.Collection(s.collection.Name).Find(ctx, bson.D{}, findOptions) if err != nil { return nil, fmt.Errorf("failed to find documents: %w", err) } - c.cursor = cursor + s.cursor = cursor } var rawMsgs []lib.RawMessage - for c.collection.GetBatchSize() > int32(len(rawMsgs)) && c.cursor.Next(ctx) { + for s.collection.GetBatchSize() > int32(len(rawMsgs)) && s.cursor.Next(ctx) { var result bson.M - if err := c.cursor.Decode(&result); err != nil { + if err := s.cursor.Decode(&result); err != nil { return nil, fmt.Errorf("failed to decode document: %w", err) } - mgoMsg, err := ParseMessage(result) + mgoMsg, err := ParseMessage(result, "r") if err != nil { return nil, fmt.Errorf("failed to parse message: %w", err) } - rawMsg, err := mgoMsg.ToRawMessage(c.collection, c.cfg.Database) + rawMsg, err := mgoMsg.ToRawMessage(s.collection, s.cfg.Database) if err != nil { return nil, fmt.Errorf("failed to create raw message: %w", err) } @@ -70,13 +70,13 @@ func (c *collectionScanner) Next() ([]lib.RawMessage, error) { rawMsgs = append(rawMsgs, rawMsg) } - if err := c.cursor.Err(); err != nil { + if err := s.cursor.Err(); err != nil { return nil, fmt.Errorf("failed to iterate over documents: %w", err) } // If the number of fetched documents is less than the batch size, we are done - if c.collection.GetBatchSize() > int32(len(rawMsgs)) { - c.done = true + if s.collection.GetBatchSize() > int32(len(rawMsgs)) { + s.done = true } return rawMsgs, nil diff --git a/writers/transfer/writer_test.go b/writers/transfer/writer_test.go index 270258d5..6e4115cf 100644 --- a/writers/transfer/writer_test.go +++ b/writers/transfer/writer_test.go @@ -22,7 +22,7 @@ func TestWriter_MessageToEvent(t *testing.T) { "string": "Hello, world!", "int64": int64(1234567890), "double": 3.14159, - }) + }, "r") assert.NoError(t, err) message, err := msg.ToRawMessage(config.Collection{Name: "collection"}, "database")