Skip to content

Commit

Permalink
add redis for tldraw
Browse files Browse the repository at this point in the history
  • Loading branch information
Tomasz Wiaderek authored and Tomasz Wiaderek committed Nov 9, 2023
1 parent 785714d commit 962bbc2
Show file tree
Hide file tree
Showing 7 changed files with 107 additions and 21 deletions.
2 changes: 2 additions & 0 deletions apps/server/src/modules/tldraw/config.ts
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ export interface TldrawConfig {
FEATURE_TLDRAW_ENABLED: boolean;
TLDRAW_PING_TIMEOUT: number;
TLDRAW_GC_ENABLED: number;
REDIS_URI: string;
}

const tldrawConnectionString: string = Configuration.get('TLDRAW_DB_URL') as string;
Expand All @@ -24,6 +25,7 @@ const tldrawConfig = {
CONNECTION_STRING: tldrawConnectionString,
TLDRAW_PING_TIMEOUT: Configuration.get('TLDRAW__PING_TIMEOUT') as number,
TLDRAW_GC_ENABLED: Configuration.get('TLDRAW__GC_ENABLED') as boolean,
REDIS_URI: Configuration.get('REDIS_URI') as string,
};

export const SOCKET_PORT = Configuration.get('TLDRAW__SOCKET_PORT') as number;
Expand Down
4 changes: 2 additions & 2 deletions apps/server/src/modules/tldraw/controller/tldraw.ws.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,11 @@ export class TldrawWs implements OnGatewayInit, OnGatewayConnection {
private readonly tldrawWsService: TldrawWsService
) {}

public handleConnection(client: WebSocket, request: Request): void {
public async handleConnection(client: WebSocket, request: Request): Promise<void> {
const docName = this.getDocNameFromRequest(request);

if (docName.length > 0 && this.configService.get<string>('FEATURE_TLDRAW_ENABLED')) {
this.tldrawWsService.setupWSConnection(client, docName);
await this.tldrawWsService.setupWSConnection(client, docName);
} else {
client.close(
WsCloseCodeEnum.WS_CLIENT_BAD_REQUEST_CODE,
Expand Down
25 changes: 20 additions & 5 deletions apps/server/src/modules/tldraw/domain/ws-shared-doc.do.ts
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import { Doc } from 'yjs';
import { applyUpdate, Doc } from 'yjs';
import WebSocket from 'ws';
import { Awareness, encodeAwarenessUpdate } from 'y-protocols/awareness';
import { applyAwarenessUpdate, Awareness, encodeAwarenessUpdate } from 'y-protocols/awareness';
import { encoding } from 'lib0';
import { WSMessageType } from '../types/connection-enum';
import { WSMessageType } from '@modules/tldraw/types';
import { TldrawWsService } from '../service';

export class WsSharedDocDo extends Doc {
Expand All @@ -12,6 +12,8 @@ export class WsSharedDocDo extends Doc {

public awareness: Awareness;

public awarenessChannel: string;

/**
* @param {string} name
* @param {TldrawWsService} tldrawService
Expand All @@ -23,11 +25,25 @@ export class WsSharedDocDo extends Doc {
this.conns = new Map();
this.awareness = new Awareness(this);
this.awareness.setLocalState(null);
this.awarenessChannel = `${name}-awareness`;

this.awareness.on('update', this.awarenessChangeHandler);
this.on('update', (update: Uint8Array, origin, doc: WsSharedDocDo) => {
this.tldrawService.updateHandler(update, origin, doc);
});

// eslint-disable-next-line promise/always-return
void this.tldrawService.sub.subscribe([this.name, this.awarenessChannel]).then(() => {
this.tldrawService.sub.on('messageBuffer', (channel: string, update: Uint8Array) => {
const channelId = channel;

if (channelId === this.name) {
applyUpdate(this, update, this.tldrawService.sub);
} else if (channelId === this.awarenessChannel) {
applyAwarenessUpdate(this.awareness, update, this.tldrawService.sub);
}
});
});
}

/**
Expand Down Expand Up @@ -73,8 +89,7 @@ export class WsSharedDocDo extends Doc {
const encoder = encoding.createEncoder();
encoding.writeVarUint(encoder, WSMessageType.AWARENESS);
encoding.writeVarUint8Array(encoder, encodeAwarenessUpdate(this.awareness, changedClients));
const message = encoding.toUint8Array(encoder);
return message;
return encoding.toUint8Array(encoder);
}

/**
Expand Down
2 changes: 2 additions & 0 deletions apps/server/src/modules/tldraw/redis/index.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
export * from './config';

Check failure on line 1 in apps/server/src/modules/tldraw/redis/index.ts

View workflow job for this annotation

GitHub Actions / nest_lint

Missing file extension for "./config"
export * from './redis';
25 changes: 25 additions & 0 deletions apps/server/src/modules/tldraw/redis/redis.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
import Redis from 'ioredis';
import { WsSharedDocDo } from '@modules/tldraw/domain/ws-shared-doc.do';

export const getDocUpdatesKey = (doc: WsSharedDocDo) => `doc:${doc.name}:updates`;

export const getDocUpdatesFromQueue = async (redis: Redis.Redis, doc: WsSharedDocDo) =>
redis.lrangeBuffer(getDocUpdatesKey(doc), 0, -1);

export const pushDocUpdatesToQueue = async (redis: Redis.Redis, doc: WsSharedDocDo, update: Uint8Array) => {
const len = await redis.llen(getDocUpdatesKey(doc));
if (len > 100) {
void redis
.pipeline()
.lpopBuffer(getDocUpdatesKey(doc))
.rpushBuffer(getDocUpdatesKey(doc), Buffer.from(update))
.expire(getDocUpdatesKey(doc), 300)
.exec();
} else {
await redis
.pipeline()
.rpushBuffer(getDocUpdatesKey(doc), Buffer.from(update))
.expire(getDocUpdatesKey(doc), 300)
.exec();
}
};
64 changes: 53 additions & 11 deletions apps/server/src/modules/tldraw/service/tldraw.ws.service.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,11 @@ import { applyAwarenessUpdate, encodeAwarenessUpdate, removeAwarenessStates } fr
import { encoding, decoding, map } from 'lib0';
import { readSyncMessage, writeSyncStep1, writeUpdate } from 'y-protocols/sync';
import { TldrawBoardRepo } from '@src/modules/tldraw/repo';
import * as mutex from 'lib0/mutex';
import Redis from 'ioredis';
import { Buffer } from 'node:buffer';
import { getDocUpdatesFromQueue, pushDocUpdatesToQueue } from '@modules/tldraw/redis';
import { applyUpdate, Doc } from 'yjs';

@Injectable()
export class TldrawWsService {
Expand All @@ -16,11 +21,23 @@ export class TldrawWsService {

public docs = new Map();

private mux: mutex.mutex;

private readonly redisUri: string;

private pub: Redis.Redis;

sub: Redis.Redis;

constructor(
private readonly configService: ConfigService<TldrawConfig, true>,
private readonly tldrawBoardRepo: TldrawBoardRepo
) {
this.pingTimeout = this.configService.get<number>('TLDRAW_PING_TIMEOUT');
this.redisUri = this.configService.get<string>('REDIS_URI');
this.mux = mutex.createMutex();
this.pub = new Redis(this.redisUri);
this.sub = new Redis(this.redisUri);
}

public setPersistence(persistence_: Persitence): void {
Expand Down Expand Up @@ -82,13 +99,18 @@ export class TldrawWsService {
* @param {WsSharedDocDo} doc
*/
public updateHandler(update: Uint8Array, origin, doc: WsSharedDocDo): void {
const encoder = encoding.createEncoder();
encoding.writeVarUint(encoder, WSMessageType.SYNC);
writeUpdate(encoder, update);
const message = encoding.toUint8Array(encoder);
doc.conns.forEach((_, conn) => {
this.send(doc, conn, message);
});
const isOriginWSConn = origin instanceof WebSocket && doc.conns.has(origin);

if (isOriginWSConn) {
void Promise.all([
this.pub.publishBuffer(doc.name, Buffer.from(update)),
pushDocUpdatesToQueue(this.pub, doc, update),
]);

this.propagateUpdate(update, doc);
} else {
this.propagateUpdate(update, doc);
}
}

/**
Expand All @@ -109,7 +131,7 @@ export class TldrawWsService {
});
}

public messageHandler(conn: WebSocket, doc: WsSharedDocDo, message: Uint8Array): void {
public async messageHandler(conn: WebSocket, doc: WsSharedDocDo, message: Uint8Array): Promise<void> {
try {
const encoder = encoding.createEncoder();
const decoder = decoding.createDecoder(message);
Expand All @@ -127,7 +149,9 @@ export class TldrawWsService {
}
break;
case WSMessageType.AWARENESS: {
applyAwarenessUpdate(doc.awareness, decoding.readVarUint8Array(decoder), conn);
const update = decoding.readVarUint8Array(decoder);
await this.pub.publishBuffer(doc.awarenessChannel, Buffer.from(update));
applyAwarenessUpdate(doc.awareness, update, conn);
break;
}
default:
Expand All @@ -142,15 +166,23 @@ export class TldrawWsService {
* @param {WebSocket} ws
* @param {string} docName
*/
public setupWSConnection(ws: WebSocket, docName = 'GLOBAL'): void {
public async setupWSConnection(ws: WebSocket, docName = 'GLOBAL'): Promise<void> {
ws.binaryType = 'arraybuffer';
// get doc, initialize if it does not exist yet
const doc = this.getYDoc(docName, true);
doc.conns.set(ws, new Set());

// listen and reply to events
ws.on('message', (message: ArrayBufferLike) => {
this.messageHandler(ws, doc, new Uint8Array(message));
void this.messageHandler(ws, doc, new Uint8Array(message));
});

const redisUpdates = await getDocUpdatesFromQueue(this.sub, doc);
const redisYDoc = new Doc();
redisYDoc.transact(() => {
for (const u of redisUpdates) {
applyUpdate(redisYDoc, u);
}
});

// Check if connection is still alive
Expand Down Expand Up @@ -205,4 +237,14 @@ export class TldrawWsService {
public async flushDocument(docName: string): Promise<void> {
await this.tldrawBoardRepo.flushDocument(docName);
}

private propagateUpdate(update: Uint8Array, doc: WsSharedDocDo): void {
const encoder = encoding.createEncoder();
encoding.writeVarUint(encoder, WSMessageType.SYNC);
writeUpdate(encoder, update);
const message = encoding.toUint8Array(encoder);
doc.conns.forEach((_, conn) => {
this.send(doc, conn, message);
});
}
}
6 changes: 3 additions & 3 deletions apps/server/src/modules/tldraw/tldraw.module.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,14 @@ import { Logger } from '@src/core/logger';
import { MikroOrmModule, MikroOrmModuleSyncOptions } from '@mikro-orm/nestjs';
import { TldrawDrawing } from '@src/modules/tldraw/entities';
import { AuthenticationModule } from '@src/modules/authentication/authentication.module';
import { RabbitMQWrapperTestModule } from '@infra/rabbitmq';
import { RabbitMQWrapperTestModule } from '@shared/infra/rabbitmq';
import { Dictionary, IPrimaryKey } from '@mikro-orm/core';
import { AuthorizationModule } from '@modules/authorization';
import { config } from './config';
import { TldrawService } from './service/tldraw.service';
import { TldrawBoardRepo } from './repo';
import { TldrawController } from './controller/tldraw.controller';
import { TldrawService } from './service/tldraw.service';
import { TldrawRepo } from './repo/tldraw.repo';
import { TldrawBoardRepo } from './repo';

const defaultMikroOrmOptions: MikroOrmModuleSyncOptions = {
findOneOrFailHandler: (entityName: string, where: Dictionary | IPrimaryKey) =>
Expand Down

0 comments on commit 962bbc2

Please sign in to comment.