From 962bbc2f6667076a6379a632414febe75cef5116 Mon Sep 17 00:00:00 2001 From: Tomasz Wiaderek Date: Thu, 9 Nov 2023 16:27:07 +0100 Subject: [PATCH] add redis for tldraw --- apps/server/src/modules/tldraw/config.ts | 2 + .../modules/tldraw/controller/tldraw.ws.ts | 4 +- .../modules/tldraw/domain/ws-shared-doc.do.ts | 25 ++++++-- apps/server/src/modules/tldraw/redis/index.ts | 2 + apps/server/src/modules/tldraw/redis/redis.ts | 25 ++++++++ .../tldraw/service/tldraw.ws.service.ts | 64 +++++++++++++++---- .../src/modules/tldraw/tldraw.module.ts | 6 +- 7 files changed, 107 insertions(+), 21 deletions(-) create mode 100644 apps/server/src/modules/tldraw/redis/index.ts create mode 100644 apps/server/src/modules/tldraw/redis/redis.ts diff --git a/apps/server/src/modules/tldraw/config.ts b/apps/server/src/modules/tldraw/config.ts index a892ee6c843..b22fe78d2bd 100644 --- a/apps/server/src/modules/tldraw/config.ts +++ b/apps/server/src/modules/tldraw/config.ts @@ -10,6 +10,7 @@ export interface TldrawConfig { FEATURE_TLDRAW_ENABLED: boolean; TLDRAW_PING_TIMEOUT: number; TLDRAW_GC_ENABLED: number; + REDIS_URI: string; } const tldrawConnectionString: string = Configuration.get('TLDRAW_DB_URL') as string; @@ -24,6 +25,7 @@ const tldrawConfig = { CONNECTION_STRING: tldrawConnectionString, TLDRAW_PING_TIMEOUT: Configuration.get('TLDRAW__PING_TIMEOUT') as number, TLDRAW_GC_ENABLED: Configuration.get('TLDRAW__GC_ENABLED') as boolean, + REDIS_URI: Configuration.get('REDIS_URI') as string, }; export const SOCKET_PORT = Configuration.get('TLDRAW__SOCKET_PORT') as number; diff --git a/apps/server/src/modules/tldraw/controller/tldraw.ws.ts b/apps/server/src/modules/tldraw/controller/tldraw.ws.ts index 1851b1d565d..bb688d5f4f6 100644 --- a/apps/server/src/modules/tldraw/controller/tldraw.ws.ts +++ b/apps/server/src/modules/tldraw/controller/tldraw.ws.ts @@ -15,11 +15,11 @@ export class TldrawWs implements OnGatewayInit, OnGatewayConnection { private readonly tldrawWsService: TldrawWsService ) {} - public handleConnection(client: WebSocket, request: Request): void { + public async handleConnection(client: WebSocket, request: Request): Promise { const docName = this.getDocNameFromRequest(request); if (docName.length > 0 && this.configService.get('FEATURE_TLDRAW_ENABLED')) { - this.tldrawWsService.setupWSConnection(client, docName); + await this.tldrawWsService.setupWSConnection(client, docName); } else { client.close( WsCloseCodeEnum.WS_CLIENT_BAD_REQUEST_CODE, diff --git a/apps/server/src/modules/tldraw/domain/ws-shared-doc.do.ts b/apps/server/src/modules/tldraw/domain/ws-shared-doc.do.ts index a84c0e7a6b5..eacfa2b696b 100644 --- a/apps/server/src/modules/tldraw/domain/ws-shared-doc.do.ts +++ b/apps/server/src/modules/tldraw/domain/ws-shared-doc.do.ts @@ -1,8 +1,8 @@ -import { Doc } from 'yjs'; +import { applyUpdate, Doc } from 'yjs'; import WebSocket from 'ws'; -import { Awareness, encodeAwarenessUpdate } from 'y-protocols/awareness'; +import { applyAwarenessUpdate, Awareness, encodeAwarenessUpdate } from 'y-protocols/awareness'; import { encoding } from 'lib0'; -import { WSMessageType } from '../types/connection-enum'; +import { WSMessageType } from '@modules/tldraw/types'; import { TldrawWsService } from '../service'; export class WsSharedDocDo extends Doc { @@ -12,6 +12,8 @@ export class WsSharedDocDo extends Doc { public awareness: Awareness; + public awarenessChannel: string; + /** * @param {string} name * @param {TldrawWsService} tldrawService @@ -23,11 +25,25 @@ export class WsSharedDocDo extends Doc { this.conns = new Map(); this.awareness = new Awareness(this); this.awareness.setLocalState(null); + this.awarenessChannel = `${name}-awareness`; this.awareness.on('update', this.awarenessChangeHandler); this.on('update', (update: Uint8Array, origin, doc: WsSharedDocDo) => { this.tldrawService.updateHandler(update, origin, doc); }); + + // eslint-disable-next-line promise/always-return + void this.tldrawService.sub.subscribe([this.name, this.awarenessChannel]).then(() => { + this.tldrawService.sub.on('messageBuffer', (channel: string, update: Uint8Array) => { + const channelId = channel; + + if (channelId === this.name) { + applyUpdate(this, update, this.tldrawService.sub); + } else if (channelId === this.awarenessChannel) { + applyAwarenessUpdate(this.awareness, update, this.tldrawService.sub); + } + }); + }); } /** @@ -73,8 +89,7 @@ export class WsSharedDocDo extends Doc { const encoder = encoding.createEncoder(); encoding.writeVarUint(encoder, WSMessageType.AWARENESS); encoding.writeVarUint8Array(encoder, encodeAwarenessUpdate(this.awareness, changedClients)); - const message = encoding.toUint8Array(encoder); - return message; + return encoding.toUint8Array(encoder); } /** diff --git a/apps/server/src/modules/tldraw/redis/index.ts b/apps/server/src/modules/tldraw/redis/index.ts new file mode 100644 index 00000000000..f30d14b2a94 --- /dev/null +++ b/apps/server/src/modules/tldraw/redis/index.ts @@ -0,0 +1,2 @@ +export * from './config'; +export * from './redis'; diff --git a/apps/server/src/modules/tldraw/redis/redis.ts b/apps/server/src/modules/tldraw/redis/redis.ts new file mode 100644 index 00000000000..f69332f74c9 --- /dev/null +++ b/apps/server/src/modules/tldraw/redis/redis.ts @@ -0,0 +1,25 @@ +import Redis from 'ioredis'; +import { WsSharedDocDo } from '@modules/tldraw/domain/ws-shared-doc.do'; + +export const getDocUpdatesKey = (doc: WsSharedDocDo) => `doc:${doc.name}:updates`; + +export const getDocUpdatesFromQueue = async (redis: Redis.Redis, doc: WsSharedDocDo) => + redis.lrangeBuffer(getDocUpdatesKey(doc), 0, -1); + +export const pushDocUpdatesToQueue = async (redis: Redis.Redis, doc: WsSharedDocDo, update: Uint8Array) => { + const len = await redis.llen(getDocUpdatesKey(doc)); + if (len > 100) { + void redis + .pipeline() + .lpopBuffer(getDocUpdatesKey(doc)) + .rpushBuffer(getDocUpdatesKey(doc), Buffer.from(update)) + .expire(getDocUpdatesKey(doc), 300) + .exec(); + } else { + await redis + .pipeline() + .rpushBuffer(getDocUpdatesKey(doc), Buffer.from(update)) + .expire(getDocUpdatesKey(doc), 300) + .exec(); + } +}; diff --git a/apps/server/src/modules/tldraw/service/tldraw.ws.service.ts b/apps/server/src/modules/tldraw/service/tldraw.ws.service.ts index 77d6bd9afbd..207edfd2910 100644 --- a/apps/server/src/modules/tldraw/service/tldraw.ws.service.ts +++ b/apps/server/src/modules/tldraw/service/tldraw.ws.service.ts @@ -7,6 +7,11 @@ import { applyAwarenessUpdate, encodeAwarenessUpdate, removeAwarenessStates } fr import { encoding, decoding, map } from 'lib0'; import { readSyncMessage, writeSyncStep1, writeUpdate } from 'y-protocols/sync'; import { TldrawBoardRepo } from '@src/modules/tldraw/repo'; +import * as mutex from 'lib0/mutex'; +import Redis from 'ioredis'; +import { Buffer } from 'node:buffer'; +import { getDocUpdatesFromQueue, pushDocUpdatesToQueue } from '@modules/tldraw/redis'; +import { applyUpdate, Doc } from 'yjs'; @Injectable() export class TldrawWsService { @@ -16,11 +21,23 @@ export class TldrawWsService { public docs = new Map(); + private mux: mutex.mutex; + + private readonly redisUri: string; + + private pub: Redis.Redis; + + sub: Redis.Redis; + constructor( private readonly configService: ConfigService, private readonly tldrawBoardRepo: TldrawBoardRepo ) { this.pingTimeout = this.configService.get('TLDRAW_PING_TIMEOUT'); + this.redisUri = this.configService.get('REDIS_URI'); + this.mux = mutex.createMutex(); + this.pub = new Redis(this.redisUri); + this.sub = new Redis(this.redisUri); } public setPersistence(persistence_: Persitence): void { @@ -82,13 +99,18 @@ export class TldrawWsService { * @param {WsSharedDocDo} doc */ public updateHandler(update: Uint8Array, origin, doc: WsSharedDocDo): void { - const encoder = encoding.createEncoder(); - encoding.writeVarUint(encoder, WSMessageType.SYNC); - writeUpdate(encoder, update); - const message = encoding.toUint8Array(encoder); - doc.conns.forEach((_, conn) => { - this.send(doc, conn, message); - }); + const isOriginWSConn = origin instanceof WebSocket && doc.conns.has(origin); + + if (isOriginWSConn) { + void Promise.all([ + this.pub.publishBuffer(doc.name, Buffer.from(update)), + pushDocUpdatesToQueue(this.pub, doc, update), + ]); + + this.propagateUpdate(update, doc); + } else { + this.propagateUpdate(update, doc); + } } /** @@ -109,7 +131,7 @@ export class TldrawWsService { }); } - public messageHandler(conn: WebSocket, doc: WsSharedDocDo, message: Uint8Array): void { + public async messageHandler(conn: WebSocket, doc: WsSharedDocDo, message: Uint8Array): Promise { try { const encoder = encoding.createEncoder(); const decoder = decoding.createDecoder(message); @@ -127,7 +149,9 @@ export class TldrawWsService { } break; case WSMessageType.AWARENESS: { - applyAwarenessUpdate(doc.awareness, decoding.readVarUint8Array(decoder), conn); + const update = decoding.readVarUint8Array(decoder); + await this.pub.publishBuffer(doc.awarenessChannel, Buffer.from(update)); + applyAwarenessUpdate(doc.awareness, update, conn); break; } default: @@ -142,7 +166,7 @@ export class TldrawWsService { * @param {WebSocket} ws * @param {string} docName */ - public setupWSConnection(ws: WebSocket, docName = 'GLOBAL'): void { + public async setupWSConnection(ws: WebSocket, docName = 'GLOBAL'): Promise { ws.binaryType = 'arraybuffer'; // get doc, initialize if it does not exist yet const doc = this.getYDoc(docName, true); @@ -150,7 +174,15 @@ export class TldrawWsService { // listen and reply to events ws.on('message', (message: ArrayBufferLike) => { - this.messageHandler(ws, doc, new Uint8Array(message)); + void this.messageHandler(ws, doc, new Uint8Array(message)); + }); + + const redisUpdates = await getDocUpdatesFromQueue(this.sub, doc); + const redisYDoc = new Doc(); + redisYDoc.transact(() => { + for (const u of redisUpdates) { + applyUpdate(redisYDoc, u); + } }); // Check if connection is still alive @@ -205,4 +237,14 @@ export class TldrawWsService { public async flushDocument(docName: string): Promise { await this.tldrawBoardRepo.flushDocument(docName); } + + private propagateUpdate(update: Uint8Array, doc: WsSharedDocDo): void { + const encoder = encoding.createEncoder(); + encoding.writeVarUint(encoder, WSMessageType.SYNC); + writeUpdate(encoder, update); + const message = encoding.toUint8Array(encoder); + doc.conns.forEach((_, conn) => { + this.send(doc, conn, message); + }); + } } diff --git a/apps/server/src/modules/tldraw/tldraw.module.ts b/apps/server/src/modules/tldraw/tldraw.module.ts index 2529e057526..6330d8a61a2 100644 --- a/apps/server/src/modules/tldraw/tldraw.module.ts +++ b/apps/server/src/modules/tldraw/tldraw.module.ts @@ -6,14 +6,14 @@ import { Logger } from '@src/core/logger'; import { MikroOrmModule, MikroOrmModuleSyncOptions } from '@mikro-orm/nestjs'; import { TldrawDrawing } from '@src/modules/tldraw/entities'; import { AuthenticationModule } from '@src/modules/authentication/authentication.module'; -import { RabbitMQWrapperTestModule } from '@infra/rabbitmq'; +import { RabbitMQWrapperTestModule } from '@shared/infra/rabbitmq'; import { Dictionary, IPrimaryKey } from '@mikro-orm/core'; import { AuthorizationModule } from '@modules/authorization'; import { config } from './config'; -import { TldrawService } from './service/tldraw.service'; -import { TldrawBoardRepo } from './repo'; import { TldrawController } from './controller/tldraw.controller'; +import { TldrawService } from './service/tldraw.service'; import { TldrawRepo } from './repo/tldraw.repo'; +import { TldrawBoardRepo } from './repo'; const defaultMikroOrmOptions: MikroOrmModuleSyncOptions = { findOneOrFailHandler: (entityName: string, where: Dictionary | IPrimaryKey) =>