From d3e21a406b8d1e453d34ba8cf1e08fc6238da5c8 Mon Sep 17 00:00:00 2001 From: Ben Davies Date: Sun, 14 May 2017 02:05:37 -0300 Subject: [PATCH] Sockets: Go rewrite Work in progress. I haven't written documentation for the Go code or confirmed whether this works as intended on Windows yet. This turned into a pretty substantial refactor of the Node version of sockets-related code to not only to make using Go child processes possible, but to optimize it and make unit testing of it and any code dependent on it possible to write entirely synchronously. sockets.js and sockets-workers.js are now written in Typescript, though they won't be able to be transpiled until after Config, Users, Dnsbl, and Monitor work with it as well. Fixes #2943 --- config/config-example.js | 7 + dev-tools/sockets.js | 35 ++ pokemon-showdown | 74 +++ sockets-workers.js | 555 +++++++++++++++++ sockets.js | 1005 +++++++++++++++++-------------- sockets/lib/commands.go | 73 +++ sockets/lib/commands_test.go | 31 + sockets/lib/config.go | 29 + sockets/lib/config_test.go | 40 ++ sockets/lib/ipc.go | 98 +++ sockets/lib/ipc_test.go | 64 ++ sockets/lib/master.go | 59 ++ sockets/lib/master_test.go | 75 +++ sockets/lib/multiplexer.go | 308 ++++++++++ sockets/lib/multiplexer_test.go | 111 ++++ sockets/main.go | 118 ++++ test/application/sockets.js | 249 ++------ tsconfig.json | 4 +- users.js | 17 + 19 files changed, 2309 insertions(+), 643 deletions(-) create mode 100644 dev-tools/sockets.js create mode 100644 sockets-workers.js create mode 100644 sockets/lib/commands.go create mode 100644 sockets/lib/commands_test.go create mode 100644 sockets/lib/config.go create mode 100644 sockets/lib/config_test.go create mode 100644 sockets/lib/ipc.go create mode 100644 sockets/lib/ipc_test.go create mode 100644 sockets/lib/master.go create mode 100644 sockets/lib/master_test.go create mode 100644 sockets/lib/multiplexer.go create mode 100644 sockets/lib/multiplexer_test.go create mode 100644 sockets/main.go diff --git a/config/config-example.js b/config/config-example.js index 1d62a9a14a75a..2a8d44e7fd283 100644 --- a/config/config-example.js +++ b/config/config-example.js @@ -10,6 +10,13 @@ exports.port = 8000; // know what you are doing. exports.proxyip = false; +// Go language - whether or not to use Go instead of Node.js to host the static +// and SockJS servers. Go is more likely to be more performant than Node.js for +// this purpose, but this should be kept set to false unless you're capable of +// debugging any issues that may arise due to the additional complexity of +// the code needed for this to run. +exports.golang = false; + // Pokemon of the Day - put a pokemon's name here to make it Pokemon of the Day // The PotD will always be in the #2 slot (not #1 so it won't be a lead) // in every Random Battle team. diff --git a/dev-tools/sockets.js b/dev-tools/sockets.js new file mode 100644 index 0000000000000..a4e79d77e12c7 --- /dev/null +++ b/dev-tools/sockets.js @@ -0,0 +1,35 @@ +'use strict'; + +const {Session, SockJSConnection} = require('sockjs/lib/transport'); + +const chars = 'abcdefghijklmnopqrstuvwxyz1234567890-'; +let sessionidCount = 0; + +/** + * @return string + */ +function generateSessionid() { + let ret = ''; + let idx = sessionidCount; + for (let i = 0; i < 8; i++) { + ret = chars[idx % chars.length] + ret; + idx = idx / chars.length | 0; + } + sessionidCount++; + return ret; +} + +/** + * @param {string} sessionid + * @param {{options: {{}}} config + * @return SockJSConnection + */ +exports.createSocket = function (sessionid = generateSessionid(), config = {options: {}}) { + let session = new Session(sessionid, config); + let socket = new SockJSConnection(session); + socket.remoteAddress = '127.0.0.1'; + socket.protocol = 'websocket'; + return socket; +}; + +// TODO: move worker mocks here, use require('../sockets-workers').Multiplexer to stub IPC diff --git a/pokemon-showdown b/pokemon-showdown index 553b327e0c5d5..22d77cec4d28d 100755 --- a/pokemon-showdown +++ b/pokemon-showdown @@ -2,6 +2,8 @@ 'use strict'; const child_process = require('child_process'); +const fs = require('fs'); +const path = require('path'); // Make sure we're Node 6+ @@ -22,6 +24,78 @@ try { child_process.execSync('npm install --production', {stdio: 'inherit'}); } +// Check if the server is configured to use Go, and ensure the required +// environment variables and dependencies are available if that is the case + +let config; +try { + config = require('./config/config'); +} catch (e) {} + +if (config && config.golang) { + if (!process.env.GOPATH) { + console.log('The GOPATH environment variable is not set! It is required in order to run the server using Go.'); + process.exit(0); + } + if (!process.env.GOROOT) { + console.log('The GOROOT environment variable is not set! It is required in order to run the server using Go.'); + process.exit(0); + } + + const dependencies = ['github.com/gorilla/mux', 'github.com/igm/sockjs-go/sockjs']; + let packages = child_process.execSync('go list all', {stdio: null, encoding: 'utf8'}); + for (let dep of dependencies) { + if (!packages.includes(dep)) { + console.log(`Dependency ${dep} is not installed. Fetching...`); + child_process.execSync(`go get ${dep}`, {stdio: 'inherit'}); + } + } + + const {GOPATH} = process.env; + let stat; + let needsSrcDir = false; + try { + stat = fs.lstatSync(path.resolve(GOPATH, 'src/github.com/Zarel')); + } catch (e) { + needsSrcDir = true; + } finally { + if (stat && !stat.isDirectory()) { + needsSrcDir = true; + } + } + + if (needsSrcDir) { + try { + fs.mkdirSync(path.resolve(GOPATH, 'src/github.com/Zarel')); + } catch (e) { + console.error(`Cannot make go source directory for the sockets library files! Symlink them manually from ${__dirname} to ${path.resolve(GOPATH, 'src/github.com/Zarel/Pokemon-Showdown/')}`); + process.exit(0); + } + } + + try { + stat = fs.lstatSync(path.resolve(GOPATH, 'src/github.com/Zarel/Pokemon-Showdown')); + } catch (e) {} + + if (!stat || !stat.isSymbolicLink()) { + try { + // FIXME: does this even work on Windows? Check to see if `mklink /J` + // might be needed instead + fs.symlink(__dirname, path.resolve(GOPATH, 'src/github.com/Zarel/Pokemon-Showdown')); + } catch (e) { + console.error(`Cannot make go source directory for the sockets library files! Symlink them manually from ${__dirname} to ${path.resolve(GOPATH, './src/github.com/Zarel/Pokemon-Showdown/')}`); + process.exit(0); + } + } + + console.log('Building Go source libs...'); + try { + child_process.execSync('go install github.com/Zarel/Pokemon-Showdown/sockets', {stdio: 'inherit'}); + } catch (e) { + process.exit(0); + } +} + // Start the server. We manually load app.js so it can be configured to run as // the main module, rather than this file being considered the main module. // This ensures any dependencies that were just installed can be found when diff --git a/sockets-workers.js b/sockets-workers.js new file mode 100644 index 0000000000000..d2a559609b7d8 --- /dev/null +++ b/sockets-workers.js @@ -0,0 +1,555 @@ +/** + * Connections + * Pokemon Showdown - http://pokemonshowdown.com/ + * + * Abstraction layer for multi-process SockJS connections. + * + * This file handles all the communications between the users' browsers and + * the main process. + * + * @license MIT license + */ + +'use strict'; + +const cluster = require('cluster'); +const fs = require('fs'); +const sockjs = require('sockjs'); +const StaticServer = require('node-static').Server; + +if (!global.Config) global.Config = require('./config/config'); +if (!global.Dnsbl) global.Dnsbl = require('./dnsbl'); +if (!global.Monitor) global.Monitor = require('./monitor'); + +// IPC command tokens +const EVAL = '$'; +const SOCKET_CONNECT = '*'; +const SOCKET_DISCONNECT = '!'; +const SOCKET_RECEIVE = '<'; +const SOCKET_SEND = '>'; +const CHANNEL_ADD = '+'; +const CHANNEL_REMOVE = '-'; +const CHANNEL_BROADCAST = '#'; +const SUBCHANNEL_ID_MOVE = '.'; +const SUBCHANNEL_ID_BROADCAST = ':'; + +// Subchannel IDs +const DEFAULT_SUBCHANNEL_ID = '0'; +const P1_SUBCHANNEL_ID = '1'; +const P2_SUBCHANNEL_ID = '2'; + +// Regex for splitting subchannel broadcasts between subchannels. +const SUBCHANNEL_ID_MESSAGE_REGEX = /\n\/split(\n[^\n]*)(\n[^\n]*)(\n[^\n]*)\n[^\n]*/g; + +/* + * @typedef {Map} Channel + * @typedef {Map} Sockets + * @typedef {Map} Channels + */ + +/** + * @class Multiplexer + * @description Manages the worker's state for sockets, channels, and + * subchannels. This is responsible for parsing all outgoing and incoming + * messages. + */ +class Multiplexer { + /** + * @param {number} socketCounter + * @param {Sockets} sockets + * @param {Channels} channels + * @param {NodeJS.Timer | null} cleanupInterval + */ + constructor() { + this.socketCounter = 0; + this.sockets = new Map(); + this.channels = new Map(); + this.cleanupInterval = setInterval(() => this.sweepClosedSockets(), 10 * 60 * 1000); + } + + /** + * @description Mitigates a potential bug in SockJS or Faye-Websocket where + * sockets fail to emit a 'close' event after having disconnected. + * @returns {void} + */ + sweepClosedSockets() { + this.sockets.forEach(socket => { + if (socket.protocol === 'xhr-streaming' && + socket._session && + socket._session.recv) { + socket._session.recv.didClose(); + } + + // A ghost connection's `_session.to_tref._idlePrev` (and `_idleNext`) property is `null` while + // it is an object for normal users. Under normal circumstances, those properties should only be + // `null` when the timeout has already been called, but somehow it's not happening for some connections. + // Simply calling `_session.timeout_cb` (the function bound to the aformentioned timeout) manually + // on those connections kills those connections. For a bit of background, this timeout is the timeout + // that sockjs sets to wait for users to reconnect within that time to continue their session. + if (socket._session && + socket._session.to_tref && + !socket._session.to_tref._idlePrev) { + socket._session.timeout_cb(); + } + }); + + // Don't bother deleting the sockets from our map; their close event + // handler will deal with it. + } + + /** + * @description Sends an IPC message to the parent process. + * @param {string} token + * @param {string[]} params + * @returns {void} + */ + sendUpstream(token, ...params) { + let message = `${token}${params.join('\n')}`; + // @ts-ignore + process.send(message); + } + + /** + * @description Parses the params in a downstream message sent as a + * command. + * @param {string} params + * @param {number} count + * @returns {string[]} + */ + parseParams(params, count) { + let i = 0; + let idx = 0; + let ret = []; + while (i++ < count) { + let newIdx = params.indexOf('\n', idx); + if (newIdx < 0) { + // No remaining newlines; just use the rest of the string as + // the last parametre. + ret.push(params.slice(idx)); + break; + } + + let param = params.slice(idx, newIdx); + if (i === count) { + // We reached the number of parametres needed, but there is + // still some remaining string left. Glue it to the last one. + param += `\n${params.slice(newIdx + 1)}`; + } else { + idx = newIdx + 1; + } + + ret.push(param); + } + + return ret; + } + + /** + * @description Parses downstream messages. + * @param {string} data + * @returns {boolean} + */ + receiveDownstream(data) { + let command = data.charAt(0); + let params = data.substr(1); + let socketid; + let channelid; + let subchannelid; + let message; + switch (command) { + case EVAL: + return this.onEval(params); + case SOCKET_DISCONNECT: + return this.onSocketDisconnect(params); + case SOCKET_SEND: + [socketid, message] = this.parseParams(params, 2); + return this.onSocketSend(socketid, message); + case CHANNEL_ADD: + [channelid, socketid] = this.parseParams(params, 2); + return this.onChannelAdd(channelid, socketid); + case CHANNEL_REMOVE: + [channelid, socketid] = this.parseParams(params, 2); + return this.onChannelRemove(channelid, socketid); + case CHANNEL_BROADCAST: + [channelid, message] = this.parseParams(params, 2); + return this.onChannelBroadcast(channelid, message); + case SUBCHANNEL_ID_MOVE: + [channelid, subchannelid, socketid] = this.parseParams(params, 3); + return this.onSubchannelMove(channelid, subchannelid, socketid); + case SUBCHANNEL_ID_BROADCAST: + [channelid, message] = this.parseParams(params, 2); + return this.onSubchannelBroadcast(channelid, message); + default: + Monitor.debug(`Sockets worker IPC error: unknown command type in downstream message: ${data}`); + return false; + } + } + + /** + * @description Safely tries to destroy a socket's connection. + * @param {any} socket + * @returns {void} + */ + tryDestroySocket(socket) { + try { + socket.end(); + socket.destroy(); + } catch (e) {} + } + + /** + * @description Eval handler for downstream messages. + * @param {string} expr + * @returns {boolean} + */ + onEval(expr) { + try { + eval(expr); + return true; + } catch (e) {} + return false; + } + + /** + * @description Sockets.socketConnect message handler. + * @param {any} socket + * @returns {boolean} + */ + onSocketConnect(socket) { + if (!socket) return false; + if (!socket.remoteAddress) { + this.tryDestroySocket(socket); + return false; + } + + let socketid = '' + this.socketCounter++; + let ip = socket.remoteAddress; + let ips = socket.headers['x-forwarded-for'] || ''; + this.sockets.set(socketid, socket); + this.sendUpstream(SOCKET_CONNECT, socketid, ip, ips, socket.protocol); + + socket.on('data', /** @param {string} message */ message => { + this.onSocketReceive(socketid, message); + }); + + socket.on('close', () => { + this.sendUpstream(SOCKET_DISCONNECT, socketid); + this.sockets.delete(socketid); + this.channels.forEach((channel, channelid) => { + if (!channel.has(socketid)) return; + channel.delete(socketid); + if (!channel.size) this.channels.delete(channelid); + }); + }); + + return true; + } + + /** + * @description Sockets.socketDisconnect message handler. + * @param {string} socketid + * @returns {boolean} + */ + onSocketDisconnect(socketid) { + let socket = this.sockets.get(socketid); + if (!socket) return false; + + this.tryDestroySocket(socket); + return true; + } + + /** + * @description Sockets.socketSend message handler. + * @param {string} socketid + * @param {string} message + * @returns {boolean} + */ + onSocketSend(socketid, message) { + let socket = this.sockets.get(socketid); + if (!socket) return false; + + socket.write(message); + return true; + } + + /** + * @description onmessage event handler for sockets. Passes the message + * upstream. + * @param {string} socketid + * @param {string} message + * @returns {boolean} + */ + onSocketReceive(socketid, message) { + // Drop empty messages (DDOS?). + if (!message) return false; + + // Drop >100KB messages. + if (message.length > (1000 * 1024)) { + console.log(`Dropping client message ${message.length / 1024} KB...`); + console.log(message.slice(0, 160)); + return false; + } + + // Drop legacy JSON messages. + if ((typeof message !== 'string') || message.startsWith('{')) return false; + + // Drop invalid messages (again, DDOS?). + if (!message.includes('|') || message.endsWith('|')) return false; + + this.sendUpstream(SOCKET_RECEIVE, socketid, message); + return true; + } + + /** + * @description Sockets.channelAdd message handler. + * @param {string} channelid + * @param {string} socketid + * @returns {boolean} + */ + onChannelAdd(channelid, socketid) { + if (!this.sockets.has(socketid)) return false; + + if (this.channels.has(channelid)) { + let channel = this.channels.get(channelid); + if (channel.has(socketid)) return false; + channel.set(socketid, DEFAULT_SUBCHANNEL_ID); + } else { + let channel = new Map(); + channel.set(socketid, DEFAULT_SUBCHANNEL_ID); + this.channels.set(channelid, channel); + } + + return true; + } + + /** + * @description Sockets.channelRemove message handler. + * @param {string} channelid + * @param {string} socketid + * @returns {boolean} + */ + onChannelRemove(channelid, socketid) { + let channel = this.channels.get(channelid); + if (!channel) return false; + + channel.delete(socketid); + if (!channel.size) this.channels.delete(channelid); + + return true; + } + + /** + * @description Sockets.channelSend and Sockets.channelBroadcast message + * handler. + * @param {string} channelid + * @param {string} message + * @returns {boolean} + */ + onChannelBroadcast(channelid, message) { + let channel = this.channels.get(channelid); + if (!channel) return false; + + channel.forEach( + /** @param {string} subchannelid */ + /** @param {string} socketid */ + (subchannelid, socketid) => { + let socket = this.sockets.get(socketid); + socket.write(message); + } + ); + + return true; + } + + /** + * @description Sockets.subchannelMove message handler. + * @param {string} channelid + * @param {string} subchannelid + * @param {string} socketid + * @returns {boolean} + */ + onSubchannelMove(channelid, subchannelid, socketid) { + if (!this.sockets.has(socketid)) return false; + + if (this.channels.has(channelid)) { + let channel = new Map([[socketid, subchannelid]]); + this.channels.set(channelid, channel); + } else { + let channel = this.channels.get(channelid); + channel.set(socketid, subchannelid); + } + + return true; + } + + /** + * @description Sockets.subchannelBroadcast message handler. + * @param {string} channelid + * @param {string} message + * @returns {boolean} + */ + onSubchannelBroadcast(channelid, message) { + let channel = this.channels.get(channelid); + if (!channel) return false; + + /** @type {RegExpExecArray | null} */ + let matches = SUBCHANNEL_ID_MESSAGE_REGEX.exec(message); + if (!matches) return false; + + let [match, msg1, msg2, msg3] = matches.splice(0); + channel.forEach( + /** @param {string} subchannelid */ + /** @param {string} socketid */ + (subchannelid, socketid) => { + let socket = this.sockets.get(socketid); + if (!socket) return; + + switch (subchannelid) { + case DEFAULT_SUBCHANNEL_ID: + socket.write(msg1); + break; + case P1_SUBCHANNEL_ID: + socket.write(msg2); + break; + case P2_SUBCHANNEL_ID: + socket.write(msg3); + break; + default: + Monitor.debug(`Sockets worker ${cluster.worker.id} received a message targeted at an unknown subchannel: ${match}`); + break; + } + } + ); + + return true; + } +} + +exports.Multiplexer = Multiplexer; + +if (cluster.isWorker) { + if (process.env.PSPORT) Config.port = +process.env.PSPORT; + if (process.env.PSBINDADDR) Config.bindaddress = process.env.PSBINDADDR; + if (+process.env.PSNOSSL) Config.ssl = null; + if (Config.crashguard) { + // Graceful crash. + process.on('uncaughtException', /** @param {Error} err */ err => { + require('./crashlogger')(err, `Socket process ${cluster.worker.id} (${process.pid})`, true); + }); + } + + // This is optional. If ofe is installed, it will take a heapdump if the + // process runs out of memory. + try { + require('ofe').call(); + } catch (e) {} + + let app = require('http').createServer(); + let appssl = null; + if (Config.ssl) { + let key; + let cert; + try { + key = fs.readFileSync(Config.ssl.options.key); + cert = fs.readFileSync(Config.ssl.options.cert); + Config.ssl.options.key = key; + Config.ssl.options.cert = cert; + } catch (e) { + console.error('The configured SSL key and cert must be the filenames of their according files now in order for Go processes to be able to host over HTTPS.'); + } finally { + appssl = require('https').createServer(Config.ssl.options); + } + } + + // Launch the static server. + try { + let cssserver = new StaticServer('./config'); + let avatarserver = new StaticServer('./config/avatars'); + let staticserver = new StaticServer('./static'); + /** @param {any} request */ + /** @param {any} response */ + let staticRequestHandler = (request, response) => { + // console.log("static rq: " + request.socket.remoteAddress + ":" + request.socket.remotePort + " -> " + request.socket.localAddress + ":" + request.socket.localPort + " - " + request.method + " " + request.url + " " + request.httpVersion + " - " + request.rawHeaders.join('|')); + request.resume(); + request.addListener('end', () => { + if (Config.customhttpresponse && + Config.customhttpresponse(request, response)) { + return; + } + let server; + if (request.url === '/custom.css') { + server = cssserver; + } else if (request.url.substr(0, 9) === '/avatars/') { + request.url = request.url.substr(8); + server = avatarserver; + } else { + if (/^\/([A-Za-z0-9][A-Za-z0-9-]*)\/?$/.test(request.url)) { + request.url = '/'; + } + server = staticserver; + } + + server.serve(request, response, + /** @param {any} e */ + /** @param {any} res */ + (e, res) => { + if (e && (e.status === 404)) { + staticserver.serveFile('404.html', 404, {}, request, response); + } + } + ); + }); + }; + app.on('request', staticRequestHandler); + if (appssl) appssl.on('request', staticRequestHandler); + } catch (e) {} + + // Launch the SockJS server. + /** @type {any} */ + const server = sockjs.createServer({ + sockjs_url: '//play.pokemonshowdown.com/js/lib/sockjs-1.1.1-nwjsfix.min.js', + /** @param {string} severity */ + /** @param {string} message */ + log: (severity, message) => { + if (severity === 'error') Monitor.debug(`Sockets worker SockJS error: ${message}`); + }, + prefix: '/showdown', + }); + + // Instantiate SockJS' multiplexer. This takes messages received downstream + // from the parent process and distributes them across the sockets they are + // targeting, as well as handling user disconnects and passing user + // messages upstream. + const multiplexer = new Multiplexer(); + + process.on('message', /** @param {string} data */ data => { + // console.log('worker received: ' + data); + let ret = multiplexer.receiveDownstream(data); + if (!ret) { + Monitor.debug(`Sockets worker IPC error: failed to parse downstream message: ${data}`); + } + }); + + process.on('disconnect', () => { + process.exit(0); + }); + + server.on('connection', /** @param {any} socket */ socket => { + multiplexer.onSocketConnect(socket); + }); + + server.installHandlers(app, {}); + if (!Config.bindaddress) Config.bindaddress = '0.0.0.0'; + app.listen(Config.port, Config.bindaddress); + console.log(`Worker ${cluster.worker.id} now listening on ${Config.bindaddress}:${Config.port}`); + + if (appssl) { + server.installHandlers(appssl, {}); + appssl.listen(Config.ssl.port, Config.bindaddress); + console.log(`Worker ${cluster.worker.id} now listening for SSL on port ${Config.ssl.port}`); + } + + console.log(`Test your server at http://${Config.bindaddress === '0.0.0.0' ? 'localhost' : Config.bindaddress}:${Config.port}`); + + require('./repl').start('sockets-', `${cluster.worker.id}-${process.pid}`, /** @param {string} cmd */ cmd => eval(cmd)); +} diff --git a/sockets.js b/sockets.js index 72b87e160bc59..8e2a19aea6c24 100644 --- a/sockets.js +++ b/sockets.js @@ -4,9 +4,8 @@ * * Abstraction layer for multi-process SockJS connections. * - * This file handles all the communications between the users' - * browsers, the networking processes, and users.js in the - * main process. + * This file handles all the communications between the networking processes + * and users.js. * * @license MIT license */ @@ -14,508 +13,612 @@ 'use strict'; const cluster = require('cluster'); -global.Config = require('./config/config'); +const EventEmitter = require('events'); + +if (!global.Config) global.Config = require('./config/config'); if (cluster.isMaster) { cluster.setupMaster({ - exec: require('path').resolve(__dirname, 'sockets'), + exec: require('path').resolve(__dirname, 'sockets-workers'), }); +} - const workers = exports.workers = new Map(); - - const spawnWorker = exports.spawnWorker = function () { - let worker = cluster.fork({PSPORT: Config.port, PSBINDADDR: Config.bindaddress || '0.0.0.0', PSNOSSL: Config.ssl ? 0 : 1}); - let id = worker.id; - workers.set(id, worker); - worker.on('message', data => { - // console.log('master received: ' + data); - switch (data.charAt(0)) { - case '*': { - // *socketid, ip, protocol - // connect - let nlPos = data.indexOf('\n'); - let nlPos2 = data.indexOf('\n', nlPos + 1); - Users.socketConnect(worker, id, data.slice(1, nlPos), data.slice(nlPos + 1, nlPos2), data.slice(nlPos2 + 1)); - break; - } +/** @typedef {any} NodeJSWorker */ +cluster.Worker; // eslint-disable-line no-unused-expressions +/** @typedef {any} Socket */ +require('net').Socket; // eslint-disable-line no-unused-expressions +/** @typedef {NodeJSWorker | GoWorker} Worker */ - case '!': { - // !socketid - // disconnect - Users.socketDisconnect(worker, id, data.substr(1)); - break; +/** + * @description IPC delimiter byte. Required to parse messages sent to and from + * Go workers. + * @type {string} + */ +const DELIM = '\u0003'; + +/** + * @class WorkerWrapper + * @implements NodeJS.Cluster.Worker + * @description A wrapper class for native Node.js Worker and GoWorker + * instances. + */ +class WorkerWrapper { + /** + * @param {Worker} worker + * @prop {number} id; + * @prop {Worker} worker + * @prop {NodeJS.ChildProcess | null} process + * @prop {boolean | undefined} exitedAfterDisconnect + * @prop {(ip: string) => boolean} isTrustedProxyIp + */ + constructor(worker) { + this.id = worker.id; + this.worker = worker; + this.process = worker.process; + this.exitedAfterDisconnect = worker.exitedAfterDisconnect; + this.isTrustedProxyIp = Dnsbl.checker(Config.proxyip); + + worker.on('message', + /** @param {string} data */ + data => this.onMessage(data) + ); + worker.on('error', () => { + // Ignore. Neither kind of child process ever prints to stderr + // without throwing/panicking and emitting the diconnect/exit + // events. + }); + worker.once('disconnect', + /** @param {string} data */ + data => { + if (this.exitedAfterDisconnect !== undefined) return; + this.exitedAfterDisconnect = true; + process.nextTick(() => this.onDisconnect(data)); } + ); + worker.once('exit', + /** @param {number} code */ + /** @param {string} signal */ + (code, signal) => { + if (this.exitedAfterDisconnect !== undefined) return; + this.exitedAfterDisconnect = false; + process.nextTick(() => this.onExit(code, signal)); + } + ); + } + + /** + * @description Worker#suicide getter wrapper + * @returns {boolean | undefined} + */ + get suicide() { + return this.exitedAfterDisconnect; + } + + /** + * @description Worker#suicide setter wrapper + * @param {boolean} val + * @returns {void} + */ + set suicide(val) { + this.exitedAfterDisconnect = val; + } + + /** + * @description Worker#kill wrapper + * @param {string} signal + * @returns {void} + */ + kill(signal = 'SIGTERM') { + return this.worker.kill(signal); + } + + /** + * @description Worker#destroy wrapper + * @param {string} signal + * @returns {void} + */ + destroy(signal) { + return this.kill(signal); + } - case '<': { - // { - if (code === null && signal === 'SIGTERM') { - // worker was killed by Sockets.killWorker or Sockets.killPid - } else { - // worker crashed, try our best to clean up - require('./crashlogger')(new Error(`Worker ${worker.id} abruptly died`), "The main process"); - - // this could get called during cleanup; prevent it from crashing - // note: overwriting Worker#send is unnecessary in Node.js v7.0.0 and above - // see https://github.com/nodejs/node/commit/8c53d2fe9f102944cc1889c4efcac7a86224cf0a - worker.send = () => {}; - - let count = 0; - Users.connections.forEach(connection => { - if (connection.worker === worker) { - Users.socketDisconnect(worker, worker.id, connection.socketid); - count++; + ret.push(param); + } + + return ret; + } + + /** + * @description 'message' event handler for the worker. Parses which type + * of command the incoming IPC message uses, then parses its parametres and + * calls the appropriate Users method. + * @param {string} data + * @returns {boolean} + */ + onMessage(data) { + // console.log('master received: ' + data); + let command = data.charAt(0); + let params = data.substr(1); + switch (command) { + case '*': + let [socketid, ip, header, protocol] = this.parseParams(params, 4); + let ips; + if (this.isTrustedProxyIp(ip)) { + ips = (header || '').split(','); + for (let i = ips.length; i--;) { + ip = ips[i].trim() || ip; + if (!this.isTrustedProxyIp(ip)) break; } - }); - console.error(`${count} connections were lost.`); + } + Users.socketConnect(this.worker, this.id, socketid, ip, protocol); + break; + case '!': + Users.socketDisconnect(this.worker, this.id, params); + break; + case '<': + Users.socketReceive(this.worker, this.id, ...this.parseParams(params, 2)); + break; + default: + Monitor.debug(`Sockets: master received unknown IPC command type: ${data}`); + break; } + } - // don't delete the worker, so we can investigate it if necessary. + /** + * @description 'disconnect' event handler for the worker. Cleans up any + * remaining users whose sockets were contained by the worker's child + * process, then attempts to respawn it.. + * @param {string} data + * @returns {void} + */ + onDisconnect(data) { + require('./crashlogger')(new Error(`Worker ${this.id} abruptly died with the following stack trace: ${data}`), 'The main process'); + console.error(`${Users.socketDisconnectAll(this.worker)} connections were lost.`); + spawnWorker(); + } - // attempt to recover + /** + * @description 'exit' event handler for the worker. Only used by GoWorker + * instances, since the 'disconnect' event is only available for Node.js + * workers. + * @param {number} code + * @param {string?} signal + * @returns {void} + */ + onExit(code, signal) { + require('./crashlogger')(new Error(`Worker ${this.id} abruptly died with code ${code} and signal ${signal}`), 'The main process'); + console.error(`${Users.socketDisconnectAll(this.worker)} connections were lost.`); spawnWorker(); - }); + } +} - exports.listen = function (port, bindAddress, workerCount) { - if (port !== undefined && !isNaN(port)) { - Config.port = port; - Config.ssl = null; - } else { - port = Config.port; - // Autoconfigure the app when running in cloud hosting environments: - try { - let cloudenv = require('cloud-env'); - bindAddress = cloudenv.get('IP', bindAddress); - port = cloudenv.get('PORT', port); - } catch (e) {} - } - if (bindAddress !== undefined) { - Config.bindaddress = bindAddress; - } - if (workerCount === undefined) { - workerCount = (Config.workers !== undefined ? Config.workers : 1); - } - for (let i = 0; i < workerCount; i++) { - spawnWorker(); - } - }; - - exports.killWorker = function (worker) { - let count = 0; - Users.connections.forEach(connection => { - if (connection.worker === worker) { - Users.socketDisconnect(worker, worker.id, connection.socketid); - count++; - } - }); - console.log(`${count} connections were lost.`); +exports.WorkerWrapper = WorkerWrapper; - try { - worker.disconnect(); - worker.kill('SIGTERM'); - } catch (e) {} - workers.delete(worker.id); +/** + * @class GoWorker + * @extends NodeJS.EventEmitter + * @description A mock Worker class for Go child processes. Similarly to + * Node.js workers, it uses a TCP net server to perform IPC. After launching + * the server, it will spawn the Go child process and wait for it to make a + * connection to the worker's server before performing IPC with it. + */ +class GoWorker extends EventEmitter { + /** + * @param {number} id + * @prop {number} id + * @prop {NodeJS.ChildProcess | null} process + * @prop {boolean | undefined} exitedAfterDisconnect + * @prop {NodeJS.net.Server | null} server + * @prop {NodeJS.net.NodeSocket | null} connection + * @prop {string[]} buffer + */ + constructor(id) { + super(); + + this.id = id; + this.process = null; + this.exitedAfterDisconnect = undefined; + + this.server = null; + this.connection = null; + /** @type {string[]} */ + this.buffer = []; + + process.nextTick(() => this.spawnServer()); + } + + /** + * @description Worker#kill mock + * @param {string} signal + * @returns {void} + */ + kill(signal = 'SIGTERM') { + if (this.isConnected()) this.connection.end(); + if (!this.isDead() && this.process) this.process.kill(signal); + if (this.server) this.server.close(); + this.exitedAfterDisconnect = false; + } - return count; - }; + /** + * @description Worker#destroy mock + * @param {string=} signal + * @returns {void} + */ + destroy(signal) { + return this.kill(signal); + } - exports.killPid = function (pid) { - pid = '' + pid; - for (let [workerid, worker] of workers) { // eslint-disable-line no-unused-vars - if (pid === '' + worker.process.pid) { - return this.killWorker(worker); - } + /** + * @description Worker#send mock + * @param {string} message + * @param {any?} sendHandle + * @returns {void} + */ + send(message, sendHandle) { // eslint-disable-line no-unused-vars + if (!this.isConnected()) { + this.buffer.push(message); + return; } - return false; - }; - - exports.socketSend = function (worker, socketid, message) { - worker.send(`>${socketid}\n${message}`); - }; - exports.socketDisconnect = function (worker, socketid) { - worker.send(`!${socketid}`); - }; - - exports.channelBroadcast = function (channelid, message) { - workers.forEach(worker => { - worker.send(`#${channelid}\n${message}`); - }); - }; - exports.channelSend = function (worker, channelid, message) { - worker.send(`#${channelid}\n${message}`); - }; - exports.channelAdd = function (worker, channelid, socketid) { - worker.send(`+${channelid}\n${socketid}`); - }; - exports.channelRemove = function (worker, channelid, socketid) { - worker.send(`-${channelid}\n${socketid}`); - }; - - exports.subchannelBroadcast = function (channelid, message) { - workers.forEach(worker => { - worker.send(`:${channelid}\n${message}`); - }); - }; - exports.subchannelMove = function (worker, channelid, subchannelid, socketid) { - worker.send(`.${channelid}\n${subchannelid}\n${socketid}`); - }; -} else { - // is worker - - if (process.env.PSPORT) Config.port = +process.env.PSPORT; - if (process.env.PSBINDADDR) Config.bindaddress = process.env.PSBINDADDR; - if (+process.env.PSNOSSL) Config.ssl = null; - - // ofe is optional - // if installed, it will heap dump if the process runs out of memory - try { - require('ofe').call(); - } catch (e) {} - // Static HTTP server + if (this.buffer.length) { + this.buffer.splice(0).forEach(msg => { + this.connection.write(JSON.stringify(msg) + DELIM); + }); + } - // This handles the custom CSS and custom avatar features, and also - // redirects yourserver:8001 to yourserver-8001.psim.us + return this.connection.write(JSON.stringify(message) + DELIM); + } - // It's optional if you don't need these features. + /** + * @description Worker#isConnected mock + * @returns {boolean} + */ + isConnected() { + return this.connection && !this.connection.destroyed; + } - global.Dnsbl = require('./dnsbl'); + /** + * @description Worker#isDead mock + * @returns {boolean} + */ + isDead() { + return !this.process || this.connection.exitCode !== null || this.connection.statusCode !== null; + } - if (Config.crashguard) { - // graceful crash - process.on('uncaughtException', err => { - require('./crashlogger')(err, `Socket process ${cluster.worker.id} (${process.pid})`, true); + /** + * @description Spawns the TCP server through which IPC with the child + * process is handled. + * @returns {boolean} + */ + spawnServer() { + if (!this.isDead()) return false; + + this.server = require('net').createServer(); + this.server.on('error', console.error); + this.server.once('listening', () => { + // Spawn the child process after the TCP server has finished + // launching to allow it to connect to it for IPC. + process.nextTick(() => this.spawnChild()); + }); + // When the child process finally connects to the TCP server we can + // begin communicating with it using a random port. + this.server.listen(() => { + if (!this.server) return; + this.server.once('connection', connection => { + process.nextTick(() => this.bootstrapChild(connection)); + }); }); } - let app = require('http').createServer(); - let appssl; - if (Config.ssl) { - appssl = require('https').createServer(Config.ssl.options); - } - try { - let nodestatic = require('node-static'); - let cssserver = new nodestatic.Server('./config'); - let avatarserver = new nodestatic.Server('./config/avatars'); - let staticserver = new nodestatic.Server('./static'); - let staticRequestHandler = (request, response) => { - // console.log("static rq: " + request.socket.remoteAddress + ":" + request.socket.remotePort + " -> " + request.socket.localAddress + ":" + request.socket.localPort + " - " + request.method + " " + request.url + " " + request.httpVersion + " - " + request.rawHeaders.join('|')); - request.resume(); - request.addListener('end', () => { - if (Config.customhttpresponse && - Config.customhttpresponse(request, response)) { - return; - } - let server; - if (request.url === '/custom.css') { - server = cssserver; - } else if (request.url.substr(0, 9) === '/avatars/') { - request.url = request.url.substr(8); - server = avatarserver; - } else { - if (/^\/([A-Za-z0-9][A-Za-z0-9-]*)\/?$/.test(request.url)) { - request.url = '/'; - } - server = staticserver; - } - server.serve(request, response, (e, res) => { - if (e && (e.status === 404)) { - staticserver.serveFile('404.html', 404, {}, request, response); - } - }); - }); - }; - app.on('request', staticRequestHandler); - if (appssl) { - appssl.on('request', staticRequestHandler); - } - } catch (e) { - console.log('Could not start node-static - try `npm install` if you want to use it'); + /** + * @description Spawns the Go child process. Once the process has started, + * it will make a connection to the worker's TCP server. + * @returns {void} + */ + spawnChild() { + if (!this.server) return this.spawnServer(); + this.process = require('child_process').spawn( + `${process.env.GOPATH}/bin/sockets`, [], { + env: { + GOPATH: process.env.GOPATH || '', + GOROOT: process.env.GOROOT || '', + PS_IPC_PORT: `:${this.server.address().port}`, + PS_CONFIG: JSON.stringify({ + workers: Config.workers || 1, + port: `:${Config.port || 8000}`, + bindAddress: Config.bindaddress || '0.0.0.0', + ssl: Config.ssl || null, + }), + }, + stdio: ['inherit', 'inherit', 'pipe'], + shell: true, + } + ); + + this.process.once('exit', (code, signal) => { + process.nextTick(() => this.emit('exit', code, signal)); + }); + + this.process.stderr.setEncoding('utf8'); + this.process.stderr.once('data', data => { + process.nextTick(() => this.emit('error', data)); + }); } - // SockJS server + /** + * @description 'connection' event handler for the TCP server. Begins + * the parsing of incoming IPC messages. + * @param {Socket} connection + * @returns {void} + */ + bootstrapChild(connection) { + this.connection = connection; + this.connection.setEncoding('utf8'); + this.connection.on('data', + /** @param {string} data */ + data => { + let messages = data.slice(0, -1).split(DELIM); + messages.forEach(message => { + this.emit('message', JSON.parse(message)); + }); + } + ); - // This is the main server that handles users connecting to our server - // and doing things on our server. + // Leave the error handling to the process, not the connection. + this.connection.on('error', () => {}); + } +} - const sockjs = require('sockjs'); +exports.GoWorker = GoWorker; - const server = sockjs.createServer({ - sockjs_url: "//play.pokemonshowdown.com/js/lib/sockjs-1.1.1-nwjsfix.min.js", - log: (severity, message) => { - if (severity === 'error') console.log('ERROR: ' + message); - }, - prefix: '/showdown', - }); +/** + * @description Map of worker IDs to worker processes. + * @type {Map} + */ +const workers = exports.workers = new Map(); - const sockets = new Map(); - const channels = new Map(); - const subchannels = new Map(); - - // Deal with phantom connections. - const sweepClosedSockets = () => { - sockets.forEach(socket => { - if (socket.protocol === 'xhr-streaming' && - socket._session && - socket._session.recv) { - socket._session.recv.didClose(); - } +/** + * @description Worker ID counter used for Go workers. + * @type {number} + */ +let nextWorkerid = 0; - // A ghost connection's `_session.to_tref._idlePrev` (and `_idleNext`) property is `null` while - // it is an object for normal users. Under normal circumstances, those properties should only be - // `null` when the timeout has already been called, but somehow it's not happening for some connections. - // Simply calling `_session.timeout_cb` (the function bound to the aformentioned timeout) manually - // on those connections kills those connections. For a bit of background, this timeout is the timeout - // that sockjs sets to wait for users to reconnect within that time to continue their session. - if (socket._session && - socket._session.to_tref && - !socket._session.to_tref._idlePrev) { - socket._session.timeout_cb(); - } +/** + * @description Spawns a new worker process. + * @returns {Worker} + */ +function spawnWorker() { + let worker; + if (Config.golang) { + worker = new GoWorker(nextWorkerid); + } else { + worker = cluster.fork({ + PSPORT: Config.port, + PSBINDADDR: Config.bindaddress || '0.0.0.0', + PSNOSSL: Config.ssl ? 0 : 1, }); - }; - const interval = setInterval(sweepClosedSockets, 1000 * 60 * 10); // eslint-disable-line no-unused-vars - - process.on('message', data => { - // console.log('worker received: ' + data); - let socket = null; - let socketid = ''; - let channel = null; - let channelid = ''; - let subchannel = null; - let subchannelid = ''; - let nlLoc = -1; - let message = ''; - - switch (data.charAt(0)) { - case '$': // $code - eval(data.substr(1)); - break; + } - case '!': // !socketid - // destroy - socketid = data.substr(1); - socket = sockets.get(socketid); - if (!socket) return; - socket.end(); - // After sending the FIN packet, we make sure the I/O is totally blocked for this socket - socket.destroy(); - sockets.delete(socketid); - channels.forEach(channel => channel.delete(socketid)); - break; + let wrapper = new WorkerWrapper(worker); + workers.set(wrapper.id, wrapper); + nextWorkerid++; + return wrapper; +} - case '>': - // >socketid, message - // message - nlLoc = data.indexOf('\n'); - socketid = data.substr(1, nlLoc - 1); - socket = sockets.get(socketid); - if (!socket) return; - message = data.substr(nlLoc + 1); - socket.write(message); - break; +exports.spawnWorker = spawnWorker; - case '#': - // #channelid, message - // message to channel - nlLoc = data.indexOf('\n'); - channelid = data.substr(1, nlLoc - 1); - channel = channels.get(channelid); - if (!channel) return; - message = data.substr(nlLoc + 1); - channel.forEach(socket => socket.write(message)); - break; +/** + * @description Initializes the configured number of worker processes. + * @param {any} port + * @param {any} bindAddress + * @param {any} workerCount + * @returns {void} + */ +exports.listen = function (port, bindAddress, workerCount) { + if (port !== undefined && !isNaN(port)) { + Config.port = port; + Config.ssl = null; + } else { + port = Config.port; + // Autoconfigure the app when running in cloud hosting environments: + try { + let cloudenv = require('cloud-env'); + bindAddress = cloudenv.get('IP', bindAddress); + port = cloudenv.get('PORT', port); + } catch (e) {} + } + if (bindAddress !== undefined) { + Config.bindaddress = bindAddress; + } - case '+': - // +channelid, socketid - // add to channel - nlLoc = data.indexOf('\n'); - socketid = data.substr(nlLoc + 1); - socket = sockets.get(socketid); - if (!socket) return; - channelid = data.substr(1, nlLoc - 1); - channel = channels.get(channelid); - if (!channel) { - channel = new Map(); - channels.set(channelid, channel); - } - channel.set(socketid, socket); - break; + // Go only uses one child process since it does not share FD handles for + // serving like Node.js workers do. Workers are instead used to limit the + // number of concurrent requests that can be handled at once in the child + // process. + if (Config.golang) { + spawnWorker(); + return; + } - case '-': - // -channelid, socketid - // remove from channel - nlLoc = data.indexOf('\n'); - channelid = data.slice(1, nlLoc); - channel = channels.get(channelid); - if (!channel) return; - socketid = data.slice(nlLoc + 1); - channel.delete(socketid); - subchannel = subchannels.get(channelid); - if (subchannel) subchannel.delete(socketid); - if (!channel.size) { - channels.delete(channelid); - if (subchannel) subchannels.delete(channelid); - } - break; + if (workerCount === undefined) { + workerCount = (Config.workers !== undefined ? Config.workers : 1); + } + for (let i = 0; i < workerCount; i++) { + spawnWorker(); + } +}; - case '.': - // .channelid, subchannelid, socketid - // move subchannel - nlLoc = data.indexOf('\n'); - channelid = data.slice(1, nlLoc); - let nlLoc2 = data.indexOf('\n', nlLoc + 1); - subchannelid = data.slice(nlLoc + 1, nlLoc2); - socketid = data.slice(nlLoc2 + 1); - - subchannel = subchannels.get(channelid); - if (!subchannel) { - subchannel = new Map(); - subchannels.set(channelid, subchannel); - } - if (subchannelid === '0') { - subchannel.delete(socketid); - } else { - subchannel.set(socketid, subchannelid); - } - break; +/** + * @description Kills a worker process using the given worker object. + * @param {Worker} worker + * @returns {number} + */ +exports.killWorker = function (worker) { + let count = Users.socketDisconnectAll(worker); + try { + worker.kill(); + } catch (e) {} + workers.delete(worker.id); + return count; +}; - case ':': - // :channelid, message - // message to subchannel - nlLoc = data.indexOf('\n'); - channelid = data.slice(1, nlLoc); - channel = channels.get(channelid); - if (!channel) return; - - let messages = [null, null, null]; - message = data.substr(nlLoc + 1); - subchannel = subchannels.get(channelid); - channel.forEach((socket, socketid) => { - switch (subchannel ? subchannel.get(socketid) : '0') { - case '1': - if (!messages[1]) { - messages[1] = message.replace(/\n\|split\n[^\n]*\n([^\n]*)\n[^\n]*\n[^\n]*/g, '\n$1'); - } - socket.write(messages[1]); - break; - case '2': - if (!messages[2]) { - messages[2] = message.replace(/\n\|split\n[^\n]*\n[^\n]*\n([^\n]*)\n[^\n]*/g, '\n$1'); - } - socket.write(messages[2]); - break; - default: - if (!messages[0]) { - messages[0] = message.replace(/\n\|split\n([^\n]*)\n[^\n]*\n[^\n]*\n[^\n]*/g, '\n$1'); - } - socket.write(messages[0]); - break; - } - }); - break; +/** + * @description Kills a worker process using the given worker PID. + * @param {number} pid + * @returns {number | false} + */ +exports.killPid = function (pid) { + workers.forEach(worker => { + if (pid === worker.process.pid) { + return this.killWorker(worker); } }); + return false; +}; - // Clean up any remaining connections on disconnect. If this isn't done, - // the process will not exit until any remaining connections have been destroyed. - // Afterwards, the worker process will die on its own. - process.once('disconnect', () => { - sockets.forEach(socket => { - try { - socket.end(); - socket.destroy(); - } catch (e) {} - }); - sockets.clear(); - channels.clear(); - subchannels.clear(); - app.close(); - if (appssl) appssl.close(); - }); +/** + * @description Sends a message to a socket in a given worker by ID. + * @param {Worker} worker + * @param {string} socketid + * @param {string} message + * @returns {void} + */ +exports.socketSend = function (worker, socketid, message) { + worker.send(`>${socketid}\n${message}`); +}; - // this is global so it can be hotpatched if necessary - let isTrustedProxyIp = Dnsbl.checker(Config.proxyip); - let socketCounter = 0; - server.on('connection', socket => { - if (!socket) { - // For reasons that are not entirely clear, SockJS sometimes triggers - // this event with a null `socket` argument. - return; - } else if (!socket.remoteAddress) { - // This condition occurs several times per day. It may be a SockJS bug. - try { - socket.end(); - } catch (e) {} - return; - } +/** + * @description Forcefully disconnects a socket in a given worker by ID. + * @param {Worker} worker + * @param {string} socketid + * @returns {void} + */ +exports.socketDisconnect = function (worker, socketid) { + worker.send(`!${socketid}`); +}; - let socketid = socket.id = '' + (++socketCounter); - sockets.set(socketid, socket); - - if (isTrustedProxyIp(socket.remoteAddress)) { - let ips = (socket.headers['x-forwarded-for'] || '').split(','); - let ip; - while ((ip = ips.pop())) { - ip = ip.trim(); - if (!isTrustedProxyIp(ip)) { - socket.remoteAddress = ip; - break; - } - } - } +/** + * @description Broadcasts a message to all sockets in a given channel across + * all workers. + * @param {string} channelid + * @param {string} message + * @returns {void} + */ +exports.channelBroadcast = function (channelid, message) { + workers.forEach(worker => { + worker.send(`#${channelid}\n${message}`); + }); +}; - process.send(`*${socketid}\n${socket.remoteAddress}\n${socket.protocol}`); +/** + * @description Broadcasts a message to all sockets in a given channel and a + * given worker. + * @param {Worker} worker + * @param {string} channelid + * @param {string} message + * @returns {void} + */ +exports.channelSend = function (worker, channelid, message) { + worker.send(`#${channelid}\n${message}`); +}; - socket.on('data', message => { - // drop empty messages (DDoS?) - if (!message) return; - // drop messages over 100KB - if (message.length > 100000) { - console.log(`Dropping client message ${message.length / 1024} KB...`); - console.log(message.slice(0, 160)); - return; - } - // drop legacy JSON messages - if (typeof message !== 'string' || message.startsWith('{')) return; - // drop blank messages (DDoS?) - let pipeIndex = message.indexOf('|'); - if (pipeIndex < 0 || pipeIndex === message.length - 1) return; +/** + * @description Adds a socket to a given channel in a given worker by ID. + * @param {Worker} worker + * @param {string} channelid + * @param {string} socketid + * @returns {void} + */ +exports.channelAdd = function (worker, channelid, socketid) { + worker.send(`+${channelid}\n${socketid}`); +}; - process.send(`<${socketid}\n${message}`); - }); +/** + * @description Removes a socket from a given channel in a given worker by ID. + * @param {Worker} worker + * @param {string} channelid + * @param {string} socketid + * @returns {void} + */ +exports.channelRemove = function (worker, channelid, socketid) { + worker.send(`-${channelid}\n${socketid}`); +}; - socket.on('close', () => { - process.send(`!${socketid}`); - sockets.delete(socketid); - channels.forEach(channel => channel.delete(socketid)); - }); +/** + * @description Broadcasts a message to be demuxed into three separate messages + * across three subchannels in a given channel across all workers. + * @param {string} channelid + * @param {string} message + * @returns {void} + */ +exports.subchannelBroadcast = function (channelid, message) { + workers.forEach(worker => { + worker.send(`:${channelid}\n${message}`); }); - server.installHandlers(app, {}); - app.listen(Config.port, Config.bindaddress); - console.log(`Worker ${cluster.worker.id} now listening on ${Config.bindaddress}:${Config.port}`); - - if (appssl) { - server.installHandlers(appssl, {}); - appssl.listen(Config.ssl.port, Config.bindaddress); - console.log(`Worker ${cluster.worker.id} now listening for SSL on port ${Config.ssl.port}`); - } - - console.log(`Test your server at http://${Config.bindaddress === '0.0.0.0' ? 'localhost' : Config.bindaddress}:${Config.port}`); +}; - require('./repl').start(`sockets-${cluster.worker.id}-${process.pid}`, cmd => eval(cmd)); -} +/** + * @description Moves a given socket to a different subchannel in a channel by + * ID in the given worker. + * @param {Worker} worker + * @param {string} channelid + * @param {string} subchannelid + * @param {string} socketid + */ +exports.subchannelMove = function (worker, channelid, subchannelid, socketid) { + worker.send(`.${channelid}\n${subchannelid}\n${socketid}`); +}; diff --git a/sockets/lib/commands.go b/sockets/lib/commands.go new file mode 100644 index 0000000000000..613e0296b95d5 --- /dev/null +++ b/sockets/lib/commands.go @@ -0,0 +1,73 @@ +package sockets + +import "strings" + +const SOCKET_CONNECT string = "*" +const SOCKET_DISCONNECT string = "!" +const SOCKET_RECEIVE string = "<" +const SOCKET_SEND string = ">" +const CHANNEL_ADD string = "+" +const CHANNEL_REMOVE string = "-" +const CHANNEL_BROADCAST string = "#" +const SUBCHANNEL_MOVE string = "." +const SUBCHANNEL_BROADCAST string = ":" + +type Command struct { + token string + paramstr string + count int + target CommandIO +} + +type CommandIO interface { + Process(Command) (err error) +} + +func NewCommand(msg string, target CommandIO) Command { + var count int + token := string(msg[:1]) + paramstr := msg[1:] + + switch token { + case SOCKET_DISCONNECT: + count = 1 + case SOCKET_RECEIVE: + count = 2 + case SOCKET_SEND: + count = 2 + case CHANNEL_ADD: + count = 2 + case CHANNEL_REMOVE: + count = 2 + case CHANNEL_BROADCAST: + count = 2 + case SUBCHANNEL_BROADCAST: + count = 2 + case SUBCHANNEL_MOVE: + count = 3 + case SOCKET_CONNECT: + count = 4 + } + + return Command{ + token: token, + paramstr: paramstr, + count: count, + target: target} +} + +func (c Command) Token() string { + return c.token +} + +func (c Command) Params() []string { + return strings.SplitN(c.paramstr, "\n", c.count) +} + +func (c Command) Message() string { + return c.token + c.paramstr +} + +func (c Command) Process() { + c.target.Process(c) +} diff --git a/sockets/lib/commands_test.go b/sockets/lib/commands_test.go new file mode 100644 index 0000000000000..0a34f8b5701c1 --- /dev/null +++ b/sockets/lib/commands_test.go @@ -0,0 +1,31 @@ +package sockets + +import "testing" + +type testTarget struct { + CommandIO +} + +func TestCommands(t *testing.T) { + tokens := []string{ + SOCKET_CONNECT, + SOCKET_DISCONNECT, + SOCKET_RECEIVE, + SOCKET_SEND, + CHANNEL_ADD, + CHANNEL_REMOVE, + CHANNEL_BROADCAST, + SUBCHANNEL_MOVE, + SUBCHANNEL_BROADCAST} + + cmds := make([]Command, len(tokens)) + for i, token := range tokens { + cmds[i] = NewCommand(token+"1\n2\n3\n4", testTarget{}) + } + for _, cmd := range cmds { + params := cmd.Params() + if len(params) != cmd.count { + t.Errorf("Commands: command type %v was expected to return %v tokens but actually returned %v", cmd.token, cmd.count, len(params)) + } + } +} diff --git a/sockets/lib/config.go b/sockets/lib/config.go new file mode 100644 index 0000000000000..637ab7e72a4a0 --- /dev/null +++ b/sockets/lib/config.go @@ -0,0 +1,29 @@ +package sockets + +import ( + "encoding/json" + "os" +) + +type config struct { + Workers int `json:"workers"` + Port string `json:"port"` + BindAddress string `json:"bindAddress"` + SSL sslOpts `json:"ssl"` +} + +type sslOpts struct { + Port string `json:"port"` + Options sslKeys `json:"options"` +} + +type sslKeys struct { + Cert string `json:"cert"` + Key string `json:"key"` +} + +func NewConfig(envVar string) (c config, err error) { + configEnv := os.Getenv(envVar) + err = json.Unmarshal([]byte(configEnv), &c) + return +} diff --git a/sockets/lib/config_test.go b/sockets/lib/config_test.go new file mode 100644 index 0000000000000..b7adba05f7673 --- /dev/null +++ b/sockets/lib/config_test.go @@ -0,0 +1,40 @@ +package sockets + +import ( + "encoding/json" + "fmt" + "testing" +) + +func newTestConfig(w int, p string, ba string, s interface{}) (c config) { + c = config{ + Workers: w, + Port: p, + BindAddress: ba} + if ssl, ok := s.(sslOpts); ok { + c.SSL = ssl + } + return +} + +func TestConfig(t *testing.T) { + t.Parallel() + ws := []int{1, 2, 3, 4} + ps := []string{":1000", ":2000", ":4000", ":8000"} + bas := []string{"127.0.0.1", "0.0.0.0", "192.168.0.1", "localhost"} + ssl := sslOpts{Port: ":443", Options: sslKeys{Cert: "", Key: ""}} + for _, w := range ws { + for _, p := range ps { + for _, ba := range bas { + t.Run(fmt.Sprintf("%v %v%v", w, ba, p, ssl), func(t *testing.T) { + go func(w int, p string, ba string, ssl sslOpts) { + c := newTestConfig(w, p, ba, ssl) + if _, err := json.Marshal(c); err != nil { + t.Errorf("Config: failed to stringify config JSON: %v", err) + } + }(w, p, ba, ssl) + }) + } + } + } +} diff --git a/sockets/lib/ipc.go b/sockets/lib/ipc.go new file mode 100644 index 0000000000000..b6e2a0aaa0be4 --- /dev/null +++ b/sockets/lib/ipc.go @@ -0,0 +1,98 @@ +package sockets + +import ( + "bufio" + "encoding/json" + "fmt" + "net" + "os" +) + +const DELIM byte = '\u0003' + +type Connection struct { + port string + addr *net.TCPAddr + conn *net.TCPConn + mux *Multiplexer + listening bool +} + +func NewConnection(envVar string) (c *Connection, err error) { + port := os.Getenv(envVar) + addr, err := net.ResolveTCPAddr("tcp", "localhost"+port) + if err != nil { + return nil, fmt.Errorf("Sockets: failed to parse TCP address to connect to the parent process with: %v", err) + } + + conn, err := net.DialTCP("tcp", nil, addr) + if err != nil { + return nil, fmt.Errorf("Sockets: failed to connect to TCP server: %v", err) + } + + c = &Connection{ + port: port, + addr: addr, + conn: conn, + listening: false} + + return +} + +func (c *Connection) Listening() bool { + return c.listening +} + +func (c *Connection) Listen(mux *Multiplexer) { + if c.listening { + return + } + + c.mux = mux + c.listening = true + + go func() { + reader := bufio.NewReader(c.conn) + for { + var token []byte + token, err := reader.ReadBytes(DELIM) + if len(token) == 0 || err != nil { + continue + } + + var msg string + err = json.Unmarshal(token[:len(token)-1], &msg) + cmd := NewCommand(msg, c.mux) + CmdQueue <- cmd + } + }() + + return +} + +func (c *Connection) Process(cmd Command) (err error) { + // fmt.Printf("Sockets => IPC: %v\n", cmd.Message()) + if !c.listening { + return fmt.Errorf("Sockets: can't process connection commands when the connection isn't listening yet") + } + + msg := cmd.Message() + _, err = c.Write(msg) + return +} + +func (c *Connection) Close() error { + return c.conn.Close() +} + +func (c *Connection) Write(message string) (int, error) { + if !c.listening { + return 0, fmt.Errorf("Sockets: can't write messages over a connection that isn't listening yet...") + } + + msg, err := json.Marshal(message) + if err != nil { + return 0, fmt.Errorf("Sockets: failed to parse upstream IPC message: %v", err) + } + return c.conn.Write(append(msg, DELIM)) +} diff --git a/sockets/lib/ipc_test.go b/sockets/lib/ipc_test.go new file mode 100644 index 0000000000000..101cdafa97079 --- /dev/null +++ b/sockets/lib/ipc_test.go @@ -0,0 +1,64 @@ +package sockets + +import ( + "net" + "os" + "testing" +) + +type testMux struct { + CommandIO +} + +func (tm *testMux) Listen(conn CommandIO) (err error) { + return nil +} + +func (tm *testMux) Process(cmd Command) (err error) { + return nil +} + +func TestConnection(t *testing.T) { + port := ":3000" + ln, err := net.Listen("tcp", "localhost"+port) + defer ln.Close() + if err != nil { + t.Errorf("Sockets: failed to launch TCP server on port %v: %v", port, err) + } + + envVar := "PS_IPC_PORT" + err = os.Setenv(envVar, port) + if err != nil { + t.Errorf("Sockets: failed to set %v environment variable: %v", envVar, port) + } + + conn, err := NewConnection(envVar) + defer conn.Close() + if err != nil { + t.Errorf("%v", err) + } + if conn.port != port { + t.Errorf("Sockets: new connection expected to have port %v but had %v instead", port, conn.port) + } + + mux := NewMultiplexer() + conn.Listen(mux) + mux.Listen(conn) + + cmd := NewCommand(SOCKET_SEND+"0\n|ayy lmao", mux) + err = conn.Process(cmd) + if err != nil { + t.Errorf("%v", err) + } + + bc, err := conn.Write(string(DELIM)) + if err != nil { + t.Errorf("%v", err) + } + bc += 3 // For the escaped backslashes and additional DELIM character + + cbc := len([]byte(cmd.Message())) + if bc != cbc { + t.Errorf("Sockets: expected the number of bytes received by the connection to be %v, but actually received %v", bc, cbc) + } +} diff --git a/sockets/lib/master.go b/sockets/lib/master.go new file mode 100644 index 0000000000000..f5d2e9ca37ec3 --- /dev/null +++ b/sockets/lib/master.go @@ -0,0 +1,59 @@ +package sockets + +var CmdQueue = make(chan Command) + +type master struct { + wpool chan chan Command + count int +} + +func NewMaster(count int) *master { + wpool := make(chan chan Command, count) + return &master{ + wpool: wpool, + count: count} +} + +func (m *master) Spawn() { + for i := 0; i < m.count; i++ { + w := newWorker(m.wpool) + w.listen() + } +} + +func (m *master) Listen() { + for { + cmd := <-CmdQueue + cmdch := <-m.wpool + cmdch <- cmd + } +} + +type worker struct { + wpool chan chan Command + cmdch chan Command + quit chan bool +} + +func newWorker(wpool chan chan Command) *worker { + cmdch := make(chan Command) + quit := make(chan bool) + return &worker{ + wpool: wpool, + cmdch: cmdch, + quit: quit} +} + +func (w *worker) listen() { + go func() { + for { + w.wpool <- w.cmdch + select { + case cmd := <-w.cmdch: + cmd.target.Process(cmd) + case <-w.quit: + return + } + } + }() +} diff --git a/sockets/lib/master_test.go b/sockets/lib/master_test.go new file mode 100644 index 0000000000000..be1c44788fef4 --- /dev/null +++ b/sockets/lib/master_test.go @@ -0,0 +1,75 @@ +package sockets + +import ( + "net" + "os" + "testing" + + "github.com/igm/sockjs-go/sockjs" +) + +type testSocket struct { + sockjs.Session +} + +func (ts testSocket) Send(msg string) error { + return nil +} + +func (ts testSocket) Close(code uint32, signal string) error { + return nil +} + +func TestMasterListen(t *testing.T) { + t.Parallel() + ln, _ := net.Listen("tcp", ":3000") + defer ln.Close() + + envVar := "PS_IPC_PORT" + os.Setenv(envVar, ":3000") + conn, _ := NewConnection(envVar) + defer conn.Close() + mux := NewMultiplexer() + mux.Listen(conn) + conn.Listen(mux) + + m := NewMaster(4) + m.Spawn() + go m.Listen() + + for i := 0; i < m.count*250; i++ { + id := string(i) + t.Run("Worker/Multiplexer command #"+id, func(t *testing.T) { + go func(id string, mux *Multiplexer, conn *Connection) { + mux.smux.Lock() + sid := string(mux.nsid) + mux.sockets[sid] = testSocket{} + mux.nsid++ + mux.smux.Unlock() + + cmd := NewCommand(SOCKET_DISCONNECT+sid, mux) + cmd.Process() + if len(CmdQueue) != 0 { + t.Error("Sockets: master failed to pass command struct from worker to multiplexer") + } + }(id, mux, conn) + }) + t.Run("Worker/Connection command #"+id, func(t *testing.T) { + go func(id string, mux *Multiplexer, conn *Connection) { + mux.smux.Lock() + sid := string(mux.nsid) + mux.smux.Unlock() + + cmd := NewCommand(SOCKET_CONNECT+sid+"\n0.0.0.0\n\nwebsocket", conn) + cmd.Process() + if len(CmdQueue) != 0 { + t.Error("Sockets: master failed to pass command struct from worker to connection") + } + }(id, mux, conn) + }) + } + + for len(m.wpool) > 0 { + <-m.wpool + } +} diff --git a/sockets/lib/multiplexer.go b/sockets/lib/multiplexer.go new file mode 100644 index 0000000000000..a7e514a3cabc0 --- /dev/null +++ b/sockets/lib/multiplexer.go @@ -0,0 +1,308 @@ +package sockets + +import ( + "fmt" + "net" + "path" + "regexp" + "strconv" + "sync" + + "github.com/igm/sockjs-go/sockjs" +) + +const DEFAULT_SUBCHANNEL_ID string = "0" +const P1_SUBCHANNEL_ID string = "1" +const P2_SUBCHANNEL_ID string = "2" + +type Multiplexer struct { + nsid uint64 + sockets map[string]sockjs.Session + smux sync.Mutex + channels map[string]map[string]string + cmux sync.Mutex + scre *regexp.Regexp + conn *Connection +} + +func NewMultiplexer() *Multiplexer { + sockets := make(map[string]sockjs.Session) + channels := make(map[string]map[string]string) + scre := regexp.MustCompile(`\n/split(\n[^\n]*)(\n[^\n]*)(\n[^\n]*)\n[^\n]*`) + return &Multiplexer{ + sockets: sockets, + channels: channels, + scre: scre} +} + +func (m *Multiplexer) Listen(conn *Connection) { + m.conn = conn +} + +func (m *Multiplexer) Process(cmd Command) (err error) { + // fmt.Printf("IPC => Sockets: %v\n", cmd.Message()) + params := cmd.Params() + + switch token := cmd.Token(); token { + case SOCKET_DISCONNECT: + sid := params[0] + err = m.socketRemove(sid, true) + case SOCKET_SEND: + sid := params[0] + msg := params[1] + err = m.socketSend(sid, msg) + case SOCKET_RECEIVE: + sid := params[0] + msg := params[1] + err = m.socketReceive(sid, msg) + case CHANNEL_ADD: + cid := params[0] + sid := params[1] + err = m.channelAdd(cid, sid) + case CHANNEL_REMOVE: + cid := params[0] + sid := params[1] + err = m.channelRemove(cid, sid) + case CHANNEL_BROADCAST: + cid := params[0] + msg := params[1] + err = m.channelBroadcast(cid, msg) + case SUBCHANNEL_MOVE: + cid := params[0] + scid := params[1] + sid := params[2] + err = m.subchannelMove(cid, scid, sid) + case SUBCHANNEL_BROADCAST: + cid := params[0] + msg := params[1] + err = m.subchannelBroadcast(cid, msg) + } + + if err != nil { + // Something went wrong somewhere, but it's likely a timing issue from + // the parent process. Let's just log the error instead of crashing. + fmt.Printf("%v\n", err) + } + + return +} + +func (m *Multiplexer) socketAdd(s sockjs.Session) (sid string) { + m.smux.Lock() + defer m.smux.Unlock() + + sid = strconv.FormatUint(m.nsid, 10) + m.nsid++ + m.sockets[sid] = s + + if m.conn.Listening() { + req := s.Request() + ip, _, _ := net.SplitHostPort(req.RemoteAddr) + ips := req.Header.Get("X-Forwarded-For") + protocol := path.Base(req.URL.Path) + + cmd := NewCommand(SOCKET_CONNECT+sid+"\n"+ip+"\n"+ips+"\n"+protocol, m.conn) + CmdQueue <- cmd + } + + return +} + +func (m *Multiplexer) socketRemove(sid string, forced bool) error { + m.smux.Lock() + defer m.smux.Unlock() + + m.cmux.Lock() + for cid, c := range m.channels { + if _, ok := c[sid]; ok { + delete(c, sid) + if len(c) == 0 { + delete((*m).channels, cid) + } + } + } + m.cmux.Unlock() + + s, ok := m.sockets[sid] + if ok { + delete((*m).sockets, sid) + } else { + return fmt.Errorf("Sockets: attempted to remove socket of ID %v that doesn't exist", sid) + } + + if forced { + s.Close(2010, "Normal closure") + } else { + // User disconnected on their own. Poke the parent process to clean up. + if m.conn.Listening() { + cmd := NewCommand(SOCKET_DISCONNECT+sid, m.conn) + CmdQueue <- cmd + } + } + + return nil +} + +func (m *Multiplexer) socketReceive(sid string, msg string) error { + m.smux.Lock() + defer m.smux.Unlock() + + if _, ok := m.sockets[sid]; ok { + if m.conn.Listening() { + cmd := NewCommand(SOCKET_RECEIVE+sid+"\n"+msg, m.conn) + CmdQueue <- cmd + } + return nil + } + + return fmt.Errorf("Sockets: received a message for a socket of ID %v that does not exist: %v", sid, msg) +} + +func (m *Multiplexer) socketSend(sid string, msg string) error { + m.smux.Lock() + defer m.smux.Unlock() + + if s, ok := m.sockets[sid]; ok && m.conn.Listening() { + s.Send(msg) + return nil + } + + return fmt.Errorf("Sockets: attempted to send to socket of ID %v, which does not exist", sid) +} + +func (m *Multiplexer) channelAdd(cid string, sid string) error { + m.cmux.Lock() + defer m.cmux.Unlock() + + c, ok := m.channels[cid] + if !ok { + c = make(map[string]string) + m.channels[cid] = c + } + + c[sid] = DEFAULT_SUBCHANNEL_ID + + return nil +} + +func (m *Multiplexer) channelRemove(cid string, sid string) error { + m.cmux.Lock() + defer m.cmux.Unlock() + + c, ok := m.channels[cid] + if ok { + if _, ok = c[sid]; !ok { + return fmt.Errorf("Sockets: failed to remove nonexistent socket of ID %v from channel %v", sid, cid) + } + } else { + // This occasionally happens on user disconnect. + return nil + } + + delete(c, sid) + if len(c) == 0 { + delete((*m).channels, cid) + } + + return nil +} + +func (m *Multiplexer) channelBroadcast(cid string, msg string) error { + m.cmux.Lock() + defer m.cmux.Unlock() + + c, ok := m.channels[cid] + if !ok { + // This happens occasionally when the last user leaves a room. Mitigate + return nil + } + + m.smux.Lock() + defer m.smux.Unlock() + + for sid, _ := range c { + var s sockjs.Session + if s, ok = m.sockets[sid]; ok { + if m.conn.Listening() { + s.Send(msg) + } + } else { + delete(c, sid) + } + } + + return nil +} + +func (m *Multiplexer) subchannelMove(cid string, scid string, sid string) error { + m.cmux.Lock() + defer m.cmux.Unlock() + + c, ok := m.channels[cid] + if !ok { + return fmt.Errorf("Sockets: attempted to move socket of ID %v in channel %v, which does not exist, to subchannel %v", sid, cid, scid) + } + + c[sid] = scid + + return nil +} + +func (m *Multiplexer) subchannelBroadcast(cid string, msg string) error { + m.cmux.Lock() + defer m.cmux.Unlock() + + c, ok := m.channels[cid] + if !ok { + return fmt.Errorf("Sockets: attempted to broadcast to subchannels in channel %v, which doesn't exist: %v", cid, msg) + } + + m.smux.Lock() + defer m.smux.Unlock() + + match := m.scre.FindAllStringSubmatch(msg, len(msg)) + for sid, scid := range c { + s, ok := m.sockets[sid] + if !ok { + return fmt.Errorf("Sockets: attempted to broadcast to subchannels in channel %v, but socket of ID %v doesn't exist: %v", cid, sid, msg) + } + + var msg string + for _, msgs := range match { + switch scid { + case DEFAULT_SUBCHANNEL_ID: + msg = msgs[1] + case P1_SUBCHANNEL_ID: + msg = msgs[2] + case P2_SUBCHANNEL_ID: + msg = msgs[3] + } + } + + if m.conn.Listening() { + s.Send(msg) + } + } + + return nil +} + +func (m *Multiplexer) Handler(s sockjs.Session) { + sid := m.socketAdd(s) + for { + if msg, err := s.Recv(); err == nil { + if err = m.socketReceive(sid, msg); err != nil { + // Likely a SockJS glitch if this happens at all. + fmt.Printf("%v\n", err) + break + } + continue + } + break + } + + if err := m.socketRemove(sid, false); err != nil { + // Socket was already removed by a message from the parent process. + fmt.Printf("%v\n", err) + } +} diff --git a/sockets/lib/multiplexer_test.go b/sockets/lib/multiplexer_test.go new file mode 100644 index 0000000000000..93ad3becb74a5 --- /dev/null +++ b/sockets/lib/multiplexer_test.go @@ -0,0 +1,111 @@ +package sockets + +import ( + "net" + "testing" +) + +func TestMultiplexer(t *testing.T) { + port := ":3000" + ln, _ := net.Listen("tcp", "localhost"+port) + defer ln.Close() + conn, _ := NewConnection("PS_IPC_PORT") + defer conn.Close() + mux := NewMultiplexer() + mux.Listen(conn) + + t.Run("*Multiplexer.socketAdd", func(t *testing.T) { + sid := mux.socketAdd(testSocket{}) + if len(mux.sockets) != 1 { + t.Errorf("Sockets: adding sockets to multiplexer doesn't keep them instead") + } + mux.socketRemove(sid, true) + }) + t.Run("*Multiplexer.socketRemove", func(t *testing.T) { + sid := mux.socketAdd(testSocket{}) + if err := mux.socketRemove(sid, true); err != nil { + t.Errorf("%v", err) + } + if len(mux.sockets) != 0 { + t.Errorf("Sockets: forcibly removing sockets from multiplexer keeps them instead") + } + sid = mux.socketAdd(testSocket{}) + if err := mux.socketRemove(sid, false); err != nil { + t.Errorf("%v", err) + } + if len(mux.sockets) != 0 { + t.Fatalf("Sockets: sockets removing themselves from multiplexer keeps them instead") + } + }) + t.Run("*Multiplexer.channelAdd", func(t *testing.T) { + sid := mux.socketAdd(testSocket{}) + if err := mux.channelAdd("global", sid); err != nil { + t.Errorf("%v", err) + } + if len(mux.channels) != 1 { + t.Errorf("Sockets: adding channels to multiplexer doesn't keep them instead") + } + if err := mux.channelAdd("global", sid); err != nil { + t.Errorf("%v", err) + } + mux.channelRemove("global", sid) + mux.socketRemove(sid, true) + }) + t.Run("*Multiplexer.channelRemove", func(t *testing.T) { + sid := mux.socketAdd(testSocket{}) + mux.channelAdd("global", sid) + if err := mux.channelRemove("global", sid); err != nil { + t.Errorf("%v", err) + } + if len(mux.channels) != 0 { + t.Errorf("Sockets: removing channels from multiplexer keeps them instead") + } + if err := mux.channelRemove("global", sid); err != nil { + t.Errorf("%v", err) + } + mux.socketRemove(sid, true) + }) + t.Run("*Multiplexer.channelBroadcast", func(t *testing.T) { + sid := mux.socketAdd(testSocket{}) + mux.channelAdd("global", sid) + if err := mux.channelBroadcast("global", "|raw|ayy lmao"); err != nil { + t.Errorf("%v", err) + } + mux.channelRemove("global", sid) + if err := mux.channelBroadcast("global", "|raw|ayy lmao"); err != nil { + t.Errorf("%v", err) + } + mux.socketRemove(sid, true) + }) + t.Run("*Multiplexer.subchannelMove", func(t *testing.T) { + sid := mux.socketAdd(testSocket{}) + mux.channelAdd("global", sid) + if err := mux.subchannelMove("global", P1_SUBCHANNEL_ID, sid); err != nil { + t.Errorf("%v", err) + } + mux.channelRemove("global", sid) + mux.socketRemove(sid, true) + }) + t.Run("*Multiplexer.subchannelBroadcast", func(t *testing.T) { + msg := "\n/split\n1\n2\n3\nrest" + matches := mux.scre.FindAllStringSubmatch(msg, len(msg)) + msgs := matches[0] + if msgs[1] != "\n1" { + t.Errorf("Sockets: expected broadcast to subchannel '0' to be %v, but was actually %v", "\n1", msgs[1]) + } + if msgs[2] != "\n2" { + t.Errorf("Sockets: expected broadcast to subchannel '1' to be %v, but was actually %v", "\n2", msgs[2]) + } + if msgs[3] != "\n3" { + t.Errorf("Sockets: expected broadcast to subchannel '2' to be %v, but was actually %v", "\n3", msgs[3]) + } + + sid := mux.socketAdd(testSocket{}) + mux.channelAdd("global", sid) + if err := mux.subchannelBroadcast("global", msg); err != nil { + t.Errorf("%v", err) + } + mux.channelRemove("global", sid) + mux.socketRemove(sid, true) + }) +} diff --git a/sockets/main.go b/sockets/main.go new file mode 100644 index 0000000000000..32a17166ed921 --- /dev/null +++ b/sockets/main.go @@ -0,0 +1,118 @@ +package main + +import ( + "crypto/tls" + "fmt" + "log" + "net" + "net/http" + "path/filepath" + + "github.com/Zarel/Pokemon-Showdown/sockets/lib" + + routing "github.com/gorilla/mux" + "github.com/igm/sockjs-go/sockjs" +) + +func notFoundHandler(w http.ResponseWriter, r *http.Request) { + http.Redirect(w, r, "/404.html", http.StatusSeeOther) +} + +func main() { + // Parse our config settings passed through the $PS_CONFIG environment + // variable by the parent process. + config, err := sockets.NewConfig("PS_CONFIG") + if err != nil { + log.Fatal("Sockets: failed to read parent's config settings from environment") + } + + // Instantiate the socket multiplexer and IPC struct.. + mux := sockets.NewMultiplexer() + conn, err := sockets.NewConnection("PS_IPC_PORT") + if err != nil { + log.Fatal(err) + } + defer conn.Close() + + // Begin listening for incoming messages from sockets and the TCP + // connection to the parent process. For now, they'll just get enqueued + // for workers to manage later.. + mux.Listen(conn) + err = conn.Listen(mux) + if err != nil { + log.Fatal("%v", err) + } + + // Set up server routing. + r := routing.NewRouter() + + avatarDir, _ := filepath.Abs("./config/avatars") + r.PathPrefix("/avatars/"). + Handler(http.FileServer(http.Dir(avatarDir))) + + customCSSDir, _ := filepath.Abs("./config") + r.Handle("/custom.css", http.FileServer(http.Dir(customCSSDir))) + + // Set up the SockJS server. + opts := sockjs.Options{ + SockJSURL: "//play.pokemonshowdown.com/js/lib/sockjs-1.1.1-nwjsfix.min.js", + Websocket: true, + HeartbeatDelay: sockjs.DefaultOptions.HeartbeatDelay, + DisconnectDelay: sockjs.DefaultOptions.DisconnectDelay, + JSessionID: sockjs.DefaultOptions.JSessionID} + + r.PathPrefix("/showdown"). + Handler(sockjs.NewHandler("/showdown", opts, mux.Handler)) + + staticDir, _ := filepath.Abs("/static") + r.Handle("/", http.StripPrefix("/static", http.FileServer(http.Dir(staticDir)))) + + r.NotFoundHandler = http.HandlerFunc(notFoundHandler) + + // Begin serving over HTTPS if configured to do so. + if config.SSL.Options.Cert != "" && config.SSL.Options.Key != "" { + go func(ba string, port string, cert string, key string) { + certs, err := tls.LoadX509KeyPair(cert, key) + if err != nil { + log.Fatalf("Sockets: failed to load certificate and key files for TLS: %v", err) + } + + srv := &http.Server{ + Handler: r, + Addr: ba + port, + TLSConfig: &tls.Config{Certificates: []tls.Certificate{certs}}} + + var ln net.Listener + ln, err = tls.Listen("tcp4", srv.Addr, srv.TLSConfig) + defer ln.Close() + if err != nil { + log.Fatalf("Sockets: failed to listen on %v over HTTPS", srv.Addr) + } + + fmt.Printf("Sockets: now serving on https://%v%v/\n", ba, port) + log.Fatal(http.Serve(ln, r)) + }(config.BindAddress, config.SSL.Port, config.SSL.Options.Cert, config.SSL.Options.Key) + } + + // Begin serving over HTTP. + go func(ba string, port string) { + srv := &http.Server{ + Handler: r, + Addr: ba + port} + + ln, err := net.Listen("tcp4", srv.Addr) + defer ln.Close() + if err != nil { + log.Fatalf("Sockets: failed to listen on %v over HTTP", srv.Addr) + } + + fmt.Printf("Sockets: now serving on http://%v%v/\n", ba, port) + log.Fatal(http.Serve(ln, r)) + }(config.BindAddress, config.Port) + + // Finally, spawn workers.to pipe messages received at the multiplexer or + // IPC connection to each other concurrently. + master := sockets.NewMaster(config.Workers) + master.Spawn() + master.Listen() +} diff --git a/test/application/sockets.js b/test/application/sockets.js index 7fc963ebe9bf2..37d9284c6092a 100644 --- a/test/application/sockets.js +++ b/test/application/sockets.js @@ -1,212 +1,79 @@ 'use strict'; const assert = require('assert'); -const cluster = require('cluster'); -describe.skip('Sockets', function () { - const spawnWorker = () => ( - new Promise(resolve => { - let worker = Sockets.spawnWorker(); - worker.removeAllListeners('message'); - resolve(worker); - }) - ); +const {createSocket} = require('../../dev-tools/sockets'); +describe('Sockets workers', function () { before(function () { - cluster.settings.silent = true; - cluster.removeAllListeners('disconnect'); + this.mux = new (require('../../sockets-workers')).Multiplexer(); + clearInterval(this.mux.cleanupInterval); + this.mux.cleanupInterval = null; + this.mux.sendUpstream = () => {}; }); - afterEach(function () { - Sockets.workers.forEach((worker, workerid) => { - worker.kill(); - Sockets.workers.delete(workerid); - }); + beforeEach(function () { + this.socket = createSocket(); }); - describe('master', function () { - it('should be able to spawn workers', function () { - Sockets.spawnWorker(); - assert.strictEqual(Sockets.workers.size, 1); - }); - - it('should be able to spawn workers on listen', function () { - Sockets.listen(0, '127.0.0.1', 1); - assert.strictEqual(Sockets.workers.size, 1); - }); - - it('should be able to kill workers', function () { - return spawnWorker().then(worker => { - Sockets.killWorker(worker); - assert.strictEqual(Sockets.workers.size, 0); - }); - }); - - it('should be able to kill workers by PID', function () { - return spawnWorker().then(worker => { - Sockets.killPid(worker.process.pid); - assert.strictEqual(Sockets.workers.size, 0); - }); - }); + afterEach(function () { + this.mux.tryDestroySocket(this.socket); + this.mux.channels.clear(); }); - describe('workers', function () { - // This composes a sequence of HOFs that send a message to a worker, - // wait for its response, then return the worker for the next function - // to use. - const chain = (eventHandler, msg) => worker => { - worker.once('message', eventHandler(worker)); - msg = msg || `$ - const {Session} = require('sockjs/lib/transport'); - const socket = new Session('aaaaaaaa', server); - socket.remoteAddress = '127.0.0.1'; - if (!('headers' in socket)) socket.headers = {}; - socket.headers['x-forwarded-for'] = ''; - socket.protocol = 'websocket'; - socket.write = msg => process.send(msg); - server.emit('connection', socket);`; - worker.send(msg); - return worker; - }; - - const spawnSocket = eventHandler => spawnWorker().then(chain(eventHandler)); - - it('should allow sockets to connect', function () { - return spawnSocket(worker => data => { - let cmd = data.charAt(0); - let [sid, ip, protocol] = data.substr(1).split('\n'); - assert.strictEqual(cmd, '*'); - assert.strictEqual(sid, '1'); - assert.strictEqual(ip, '127.0.0.1'); - assert.strictEqual(protocol, 'websocket'); - }); - }); - - it('should allow sockets to disconnect', function () { - let querySocket; - return spawnSocket(worker => data => { - let sid = data.substr(1, data.indexOf('\n')); - querySocket = `$ - let socket = sockets.get(${sid}); - process.send(!socket);`; - Sockets.socketDisconnect(worker, sid); - }).then(chain(worker => data => { - assert.ok(data); - }, querySocket)); - }); - - it('should allow sockets to send messages', function () { - let msg = 'ayy lmao'; - let socketSend; - return spawnSocket(worker => data => { - let sid = data.substr(1, data.indexOf('\n')); - socketSend = `>${sid}\n${msg}`; - }).then(chain(worker => data => { - assert.strictEqual(data, msg); - }, socketSend)); - }); + after(function () { + this.socket = null; + this.mux.sockets.clear(); + this.mux = null; + }); - it('should allow sockets to receive messages', function () { - let sid; - let msg; - let mockReceive; - return spawnSocket(worker => data => { - sid = data.substr(1, data.indexOf('\n')); - msg = '|/cmd rooms'; - mockReceive = `$ - let socket = sockets.get(${sid}); - socket.emit('data', ${msg});`; - }).then(chain(worker => data => { - let cmd = data.charAt(0); - let params = data.substr(1).split('\n'); - assert.strictEqual(cmd, '<'); - assert.strictEqual(sid, params[0]); - assert.strictEqual(msg, params[1]); - }, mockReceive)); - }); + it('should parse more than two params', function () { + let params = '1\n1\n0\n'; + let ret = this.mux.parseParams(params, 4); + assert.deepStrictEqual(ret, ['1', '1', '0', '']); + }); - it('should create a channel for the first socket to get added to it', function () { - let queryChannel; - return spawnSocket(worker => data => { - let sid = data.substr(1, data.indexOf('\n')); - let cid = 'global'; - queryChannel = `$ - let channel = channels.get(${cid}); - process.send(channel && channel.has(${sid}));`; - Sockets.channelAdd(worker, cid, sid); - }).then(chain(worker => data => { - assert.ok(data); - }, queryChannel)); - }); + it('should parse params with multiple newlines', function () { + let params = '0\n|1\n|2'; + let ret = this.mux.parseParams(params, 2); + assert.deepStrictEqual(ret, ['0', '|1\n|2']); + }); - it('should remove a channel if the last socket gets removed from it', function () { - let queryChannel; - return spawnSocket(worker => data => { - let sid = data.substr(1, data.indexOf('\n')); - let cid = 'global'; - queryChannel = `$ - process.send(!sockets.has(${sid}) && !channels.has(${cid}));`; - Sockets.channelAdd(worker, cid, sid); - Sockets.channelRemove(worker, cid, sid); - }).then(chain(worker => data => { - assert.ok(data); - }, queryChannel)); - }); + it('should add sockets on connect', function () { + let res = this.mux.onSocketConnect(this.socket); + assert.ok(res); + }); - it('should send to all sockets in a channel', function () { - let msg = 'ayy lmao'; - let cid = 'global'; - let channelSend = `#${cid}\n${msg}`; - return spawnSocket(worker => data => { - let sid = data.substr(1, data.indexOf('\n')); - Sockets.channelAdd(worker, cid, sid); - }).then(chain(worker => data => { - assert.strictEqual(data, msg); - }, channelSend)); - }); + it('should remove sockets on disconnect', function () { + this.mux.onSocketConnect(this.socket); + let res = this.mux.onSocketDisconnect('0', this.socket); + assert.ok(res); + }); - it('should create a subchannel when moving a socket to it', function () { - let querySubchannel; - return spawnSocket(worker => data => { - let sid = data.substr(1, data.indexOf('\n')); - let cid = 'battle-ou-1'; - let scid = '1'; - querySubchannel = `$ - let subchannel = subchannels[${cid}]; - process.send(!!subchannel && (subchannel.get(${sid}) === ${scid}));`; - Sockets.subchannelMove(worker, cid, scid, sid); - }).then(chain(worker => data => { - assert.ok(data); - }, querySubchannel)); - }); + it('should add sockets to channels', function () { + this.mux.onSocketConnect(this.socket); + let res = this.mux.onChannelAdd('global', '0'); + assert.ok(res); + res = this.mux.onChannelAdd('global', '0'); + assert.ok(!res); + this.mux.channels.set('lobby', new Map()); + res = this.mux.onChannelAdd('lobby', '0'); + assert.ok(res); + }); - it('should remove a subchannel when removing its last socket', function () { - let querySubchannel; - return spawnSocket(worker => data => { - let sid = data.substr(1, data.indexOf('\n')); - let cid = 'battle-ou-1'; - let scid = '1'; - querySubchannel = `$ - let subchannel = subchannels.get(${cid}); - process.send(!!subchannel && (subchannel.get(${sid}) === ${scid}));`; - Sockets.subchannelMove(worker, cid, scid, sid); - Sockets.channelRemove(worker, cid, sid); - }).then(chain(worker => data => { - assert.ok(data); - }, querySubchannel)); - }); + it('should remove sockets from channels', function () { + this.mux.onSocketConnect(this.socket); + this.mux.onChannelAdd('global', '0'); + let res = this.mux.onChannelRemove('global', '0'); + assert.ok(res); + res = this.mux.onChannelRemove('global', '0'); + assert.ok(!res); + }); - it('should send to sockets in a subchannel', function () { - let cid = 'battle-ou-1'; - let msg = 'ayy lmao'; - let subchannelSend = `.${cid}\n\n|split\n\n${msg}\n\n`; - return spawnSocket(worker => data => { - let sid = data.substr(1, data.indexOf('\n')); - let scid = '1'; - Sockets.subchannelMove(worker, cid, scid, sid); - }).then(chain(worker => data => { - assert.strictEqual(data, msg); - }, subchannelSend)); - }); + it('should move sockets to subchannels', function () { + this.mux.onSocketConnect(this.socket); + this.mux.onChannelAdd('global', '0'); + let res = this.mux.onSubchannelMove('global', '1', '0'); + assert.ok(res); }); }); diff --git a/tsconfig.json b/tsconfig.json index 98ccbf6d5fae3..018c14885f740 100644 --- a/tsconfig.json +++ b/tsconfig.json @@ -15,6 +15,8 @@ "./sim/prng.js", "./crashlogger.js", "./dnsbl.js", - "./repl.js" + "./repl.js", + "./sockets.js", + "./sockets-workers.js" ] } diff --git a/users.js b/users.js index 50252b5c1869e..2b8e054ec7c37 100644 --- a/users.js +++ b/users.js @@ -1610,3 +1610,20 @@ Users.socketReceive = function (worker, workerid, socketid, message) { Monitor.warn(`[slow] ${deltaTime}ms - ${user.name} <${connection.ip}>: ${roomId}|${message}`); } }; + +/** + * @description Clears all connections whose sockets were contained by a + * worker. Called after a worker's process crashes or gets killed. + * @param {object} worker + * @returns {number} + */ +Users.socketDisconnectAll = function (worker) { + let count = 0; + connections.forEach(connection => { + if (connection.worker === worker) { + Users.socketDisconnect(worker, worker.id, connection.socketid); + count++; + } + }); + return count; +};