Skip to content

Commit

Permalink
Supporting MongoDB streaming part 1 (#430)
Browse files Browse the repository at this point in the history
  • Loading branch information
Tang8330 authored Jul 5, 2024
1 parent f2af92b commit cd051ea
Show file tree
Hide file tree
Showing 5 changed files with 30 additions and 28 deletions.
13 changes: 8 additions & 5 deletions sources/mongo/message.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,13 @@ import (
"go.mongodb.org/mongo-driver/bson"
)

type mgoMessage struct {
type Message struct {
jsonExtendedString string
operation string
pkMap map[string]any
}

func (m *mgoMessage) ToRawMessage(collection config.Collection, database string) (lib.RawMessage, error) {
func (m *Message) ToRawMessage(collection config.Collection, database string) (lib.RawMessage, error) {
evt := &mongo.SchemaEventPayload{
Schema: debezium.Schema{},
Payload: mongo.Payload{
Expand All @@ -27,7 +28,7 @@ func (m *mgoMessage) ToRawMessage(collection config.Collection, database string)
Collection: collection.Name,
TsMs: time.Now().UnixMilli(),
},
Operation: "r",
Operation: m.operation,
},
}

Expand All @@ -38,7 +39,7 @@ func (m *mgoMessage) ToRawMessage(collection config.Collection, database string)
return lib.NewRawMessage(collection.TopicSuffix(database), pkMap, evt), nil
}

func ParseMessage(result bson.M) (*mgoMessage, error) {
func ParseMessage(result bson.M, op string) (*Message, error) {
jsonExtendedBytes, err := bson.MarshalExtJSON(result, false, false)
if err != nil {
return nil, fmt.Errorf("failed to marshal document to JSON extended: %w", err)
Expand All @@ -58,8 +59,10 @@ func ParseMessage(result bson.M) (*mgoMessage, error) {
if err != nil {
return nil, fmt.Errorf("failed to marshal ext json: %w", err)
}
return &mgoMessage{

return &Message{
jsonExtendedString: string(jsonExtendedBytes),
operation: op,
pkMap: map[string]any{
"id": string(pkBytes),
},
Expand Down
4 changes: 2 additions & 2 deletions sources/mongo/message_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ import (
func TestParseMessagePartitionKey(t *testing.T) {
objId, err := primitive.ObjectIDFromHex("507f1f77bcf86cd799439011")
assert.NoError(t, err)
msg, err := ParseMessage(bson.M{"_id": objId})
msg, err := ParseMessage(bson.M{"_id": objId}, "r")
assert.NoError(t, err)
assert.Equal(t, `{"$oid":"507f1f77bcf86cd799439011"}`, msg.pkMap["id"])

Expand Down Expand Up @@ -61,7 +61,7 @@ func TestParseMessage(t *testing.T) {
"trueValue": true,
"falseValue": false,
"nullValue": nil,
})
}, "r")
assert.NoError(t, err)

rawMsg, err := msg.ToRawMessage(config.Collection{}, "database")
Expand Down
3 changes: 1 addition & 2 deletions sources/mongo/mongo.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,8 +59,7 @@ func (s *Source) Run(ctx context.Context, writer writers.Writer) error {
slog.Any("batchSize", collection.GetBatchSize()),
)

iterator := newIterator(s.db, collection, s.cfg)
count, err := writer.Write(ctx, iterator)
count, err := writer.Write(ctx, newSnapshotIterator(s.db, collection, s.cfg))
if err != nil {
return fmt.Errorf("failed to snapshot collection %q: %w", collection.Name, err)
}
Expand Down
36 changes: 18 additions & 18 deletions sources/mongo/collection.go → sources/mongo/snapshot.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ import (
"go.mongodb.org/mongo-driver/mongo/options"
)

type collectionScanner struct {
type snapshotIterator struct {
db *mongo.Database
cfg config.MongoDB
collection config.Collection
Expand All @@ -21,62 +21,62 @@ type collectionScanner struct {
done bool
}

func newIterator(db *mongo.Database, collection config.Collection, cfg config.MongoDB) *collectionScanner {
return &collectionScanner{
func newSnapshotIterator(db *mongo.Database, collection config.Collection, cfg config.MongoDB) *snapshotIterator {
return &snapshotIterator{
db: db,
cfg: cfg,
collection: collection,
}
}

func (c *collectionScanner) HasNext() bool {
return !c.done
func (s *snapshotIterator) HasNext() bool {
return !s.done
}

func (c *collectionScanner) Next() ([]lib.RawMessage, error) {
if !c.HasNext() {
func (s *snapshotIterator) Next() ([]lib.RawMessage, error) {
if !s.HasNext() {
return nil, fmt.Errorf("no more rows to scan")
}

ctx := context.Background()
if c.cursor == nil {
if s.cursor == nil {
findOptions := options.Find()
findOptions.SetBatchSize(c.collection.GetBatchSize())
cursor, err := c.db.Collection(c.collection.Name).Find(ctx, bson.D{}, findOptions)
findOptions.SetBatchSize(s.collection.GetBatchSize())
cursor, err := s.db.Collection(s.collection.Name).Find(ctx, bson.D{}, findOptions)
if err != nil {
return nil, fmt.Errorf("failed to find documents: %w", err)
}

c.cursor = cursor
s.cursor = cursor
}

var rawMsgs []lib.RawMessage
for c.collection.GetBatchSize() > int32(len(rawMsgs)) && c.cursor.Next(ctx) {
for s.collection.GetBatchSize() > int32(len(rawMsgs)) && s.cursor.Next(ctx) {
var result bson.M
if err := c.cursor.Decode(&result); err != nil {
if err := s.cursor.Decode(&result); err != nil {
return nil, fmt.Errorf("failed to decode document: %w", err)
}

mgoMsg, err := ParseMessage(result)
mgoMsg, err := ParseMessage(result, "r")
if err != nil {
return nil, fmt.Errorf("failed to parse message: %w", err)
}

rawMsg, err := mgoMsg.ToRawMessage(c.collection, c.cfg.Database)
rawMsg, err := mgoMsg.ToRawMessage(s.collection, s.cfg.Database)
if err != nil {
return nil, fmt.Errorf("failed to create raw message: %w", err)
}

rawMsgs = append(rawMsgs, rawMsg)
}

if err := c.cursor.Err(); err != nil {
if err := s.cursor.Err(); err != nil {
return nil, fmt.Errorf("failed to iterate over documents: %w", err)
}

// If the number of fetched documents is less than the batch size, we are done
if c.collection.GetBatchSize() > int32(len(rawMsgs)) {
c.done = true
if s.collection.GetBatchSize() > int32(len(rawMsgs)) {
s.done = true
}

return rawMsgs, nil
Expand Down
2 changes: 1 addition & 1 deletion writers/transfer/writer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ func TestWriter_MessageToEvent(t *testing.T) {
"string": "Hello, world!",
"int64": int64(1234567890),
"double": 3.14159,
})
}, "r")
assert.NoError(t, err)

message, err := msg.ToRawMessage(config.Collection{Name: "collection"}, "database")
Expand Down

0 comments on commit cd051ea

Please sign in to comment.