Skip to content

Commit

Permalink
Merge pull request #1094 from paullouisageneau/websocket-enforce-max-…
Browse files Browse the repository at this point in the history
…message-size

Enforce WebSocket max message size at reception
  • Loading branch information
paullouisageneau authored Jan 18, 2024
2 parents 080a798 + 7a25011 commit dbdfb49
Show file tree
Hide file tree
Showing 14 changed files with 128 additions and 62 deletions.
29 changes: 29 additions & 0 deletions include/rtc/configuration.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,35 @@ struct RTC_CPP_EXPORT Configuration {
optional<size_t> maxMessageSize;
};

#ifdef RTC_ENABLE_WEBSOCKET

struct WebSocketConfiguration {
bool disableTlsVerification = false; // if true, don't verify the TLS certificate
optional<ProxyServer> proxyServer; // only non-authenticated http supported for now
std::vector<string> protocols;
optional<std::chrono::milliseconds> connectionTimeout; // zero to disable
optional<std::chrono::milliseconds> pingInterval; // zero to disable
optional<int> maxOutstandingPings;
optional<string> caCertificatePemFile;
optional<string> certificatePemFile;
optional<string> keyPemFile;
optional<string> keyPemPass;
optional<size_t> maxMessageSize;
};

struct WebSocketServerConfiguration {
uint16_t port = 8080;
bool enableTls = false;
optional<string> certificatePemFile;
optional<string> keyPemFile;
optional<string> keyPemPass;
optional<string> bindAddress;
optional<std::chrono::milliseconds> connectionTimeout;
optional<size_t> maxMessageSize;
};

#endif

} // namespace rtc

#endif
4 changes: 3 additions & 1 deletion include/rtc/rtc.h
Original file line number Diff line number Diff line change
Expand Up @@ -422,6 +422,7 @@ typedef struct {
int connectionTimeoutMs; // in milliseconds, 0 means default, < 0 means disabled
int pingIntervalMs; // in milliseconds, 0 means default, < 0 means disabled
int maxOutstandingPings; // 0 means default, < 0 means disabled
int maxMessageSize; // <= 0 means default
} rtcWsConfiguration;

RTC_C_EXPORT int rtcCreateWebSocket(const char *url); // returns ws id
Expand All @@ -441,8 +442,9 @@ typedef struct {
const char *certificatePemFile; // NULL for autogenerated certificate
const char *keyPemFile; // NULL for autogenerated certificate
const char *keyPemPass; // NULL if no pass
const char *bindAddress; // NULL for IP_ANY_ADDR
const char *bindAddress; // NULL for any
int connectionTimeoutMs; // in milliseconds, 0 means default, < 0 means disabled
int maxMessageSize; // <= 0 means default
} rtcWsServerConfiguration;

RTC_C_EXPORT int rtcCreateWebSocketServer(const rtcWsServerConfiguration *config,
Expand Down
16 changes: 2 additions & 14 deletions include/rtc/websocket.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@

#include "channel.hpp"
#include "common.hpp"
#include "configuration.hpp" // for ProxyServer
#include "configuration.hpp"

namespace rtc {

Expand All @@ -32,19 +32,7 @@ class RTC_CPP_EXPORT WebSocket final : private CheshireCat<impl::WebSocket>, pub
Closed = 3,
};

struct Configuration {
bool disableTlsVerification = false; // if true, don't verify the TLS certificate
optional<ProxyServer> proxyServer; // only non-authenticated http supported for now
std::vector<string> protocols;
optional<std::chrono::milliseconds> connectionTimeout; // zero to disable
optional<std::chrono::milliseconds> pingInterval; // zero to disable
optional<int> maxOutstandingPings;
optional<string> caCertificatePemFile;
optional<string> certificatePemFile;
optional<string> keyPemFile;
optional<string> keyPemPass;
optional<size_t> maxMessageSize;
};
using Configuration = WebSocketConfiguration;

WebSocket();
WebSocket(Configuration config);
Expand Down
11 changes: 2 additions & 9 deletions include/rtc/websocketserver.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
#if RTC_ENABLE_WEBSOCKET

#include "common.hpp"
#include "configuration.hpp"
#include "websocket.hpp"

namespace rtc {
Expand All @@ -24,15 +25,7 @@ struct WebSocketServer;

class RTC_CPP_EXPORT WebSocketServer final : private CheshireCat<impl::WebSocketServer> {
public:
struct Configuration {
uint16_t port = 8080;
bool enableTls = false;
optional<string> certificatePemFile;
optional<string> keyPemFile;
optional<string> keyPemPass;
optional<string> bindAddress;
optional<std::chrono::milliseconds> connectionTimeout;
};
using Configuration = WebSocketServerConfiguration;

WebSocketServer();
WebSocketServer(Configuration config);
Expand Down
7 changes: 7 additions & 0 deletions src/capi.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1479,6 +1479,9 @@ int rtcCreateWebSocketEx(const char *url, const rtcWsConfiguration *config) {
else if (config->maxOutstandingPings < 0)
c.maxOutstandingPings = 0; // setting to 0 disables, not setting keeps default

if(config->maxMessageSize > 0)
c.maxMessageSize = size_t(config->maxMessageSize);

auto webSocket = std::make_shared<WebSocket>(std::move(c));
webSocket->open(url);
return emplaceWebSocket(webSocket);
Expand Down Expand Up @@ -1533,6 +1536,10 @@ RTC_C_EXPORT int rtcCreateWebSocketServer(const rtcWsServerConfiguration *config
c.keyPemFile = config->keyPemFile ? make_optional(string(config->keyPemFile)) : nullopt;
c.keyPemPass = config->keyPemPass ? make_optional(string(config->keyPemPass)) : nullopt;
c.bindAddress = config->bindAddress ? make_optional(string(config->bindAddress)) : nullopt;

if(config->maxMessageSize > 0)
c.maxMessageSize = size_t(config->maxMessageSize);

auto webSocketServer = std::make_shared<WebSocketServer>(std::move(c));
int wsserver = emplaceWebSocketServer(webSocketServer);

Expand Down
2 changes: 1 addition & 1 deletion src/channel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ Channel::~Channel() { impl()->resetCallbacks(); }

Channel::Channel(impl_ptr<impl::Channel> impl) : CheshireCat<impl::Channel>(std::move(impl)) {}

size_t Channel::maxMessageSize() const { return DEFAULT_MAX_MESSAGE_SIZE; }
size_t Channel::maxMessageSize() const { return 0; }

size_t Channel::bufferedAmount() const { return impl()->bufferedAmount; }

Expand Down
2 changes: 1 addition & 1 deletion src/impl/datachannel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ bool DataChannel::isClosed(void) const { return mIsClosed; }

size_t DataChannel::maxMessageSize() const {
auto pc = mPeerConnection.lock();
return pc ? pc->remoteMaxMessageSize() : DEFAULT_MAX_MESSAGE_SIZE;
return pc ? pc->remoteMaxMessageSize() : DEFAULT_REMOTE_MAX_MESSAGE_SIZE;
}

void DataChannel::assignStream(uint16_t stream) {
Expand Down
4 changes: 3 additions & 1 deletion src/impl/internals.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,9 @@ const uint16_t MAX_SCTP_STREAMS_COUNT = 1024; // Max number of negotiated SCTP s
// of memory, Chromium historically limits to 1024.

const size_t DEFAULT_LOCAL_MAX_MESSAGE_SIZE = 256 * 1024; // Default local max message size
const size_t DEFAULT_MAX_MESSAGE_SIZE = 65536; // Remote max message size if not specified in SDP
const size_t DEFAULT_REMOTE_MAX_MESSAGE_SIZE = 65536; // Remote max message size if not in SDP

const size_t DEFAULT_WS_MAX_MESSAGE_SIZE = 256 * 1024; // Default max message size for WebSockets

const size_t RECV_QUEUE_LIMIT = 1024 * 1024; // Max per-channel queue size

Expand Down
2 changes: 1 addition & 1 deletion src/impl/peerconnection.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ optional<Description> PeerConnection::remoteDescription() const {
size_t PeerConnection::remoteMaxMessageSize() const {
const size_t localMax = config.maxMessageSize.value_or(DEFAULT_LOCAL_MAX_MESSAGE_SIZE);

size_t remoteMax = DEFAULT_MAX_MESSAGE_SIZE;
size_t remoteMax = DEFAULT_REMOTE_MAX_MESSAGE_SIZE;
std::lock_guard lock(mRemoteDescriptionMutex);
if (mRemoteDescription)
if (auto *application = mRemoteDescription->application())
Expand Down
5 changes: 2 additions & 3 deletions src/impl/websocket.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@ bool WebSocket::isOpen() const { return state == State::Open; }

bool WebSocket::isClosed() const { return state == State::Closed; }

size_t WebSocket::maxMessageSize() const { return config.maxMessageSize.value_or(DEFAULT_MAX_MESSAGE_SIZE); }
size_t WebSocket::maxMessageSize() const { return config.maxMessageSize.value_or(DEFAULT_WS_MAX_MESSAGE_SIZE); }

optional<message_variant> WebSocket::receive() {
auto next = mRecvQueue.pop();
Expand Down Expand Up @@ -443,8 +443,7 @@ shared_ptr<WsTransport> WebSocket::initWsTransport() {
}
};

auto maxOutstandingPings = config.maxOutstandingPings.value_or(0);
auto transport = std::make_shared<WsTransport>(lower, mWsHandshake, maxOutstandingPings,
auto transport = std::make_shared<WsTransport>(lower, mWsHandshake, config,
weak_bind(&WebSocket::incoming, this, _1),
stateChangeCallback);

Expand Down
1 change: 1 addition & 0 deletions src/impl/websocketserver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ void WebSocketServer::runLoop() {

WebSocket::Configuration clientConfig;
clientConfig.connectionTimeout = config.connectionTimeout;
clientConfig.maxMessageSize = config.maxMessageSize;

auto impl = std::make_shared<WebSocket>(std::move(clientConfig), mCertificate);
impl->changeState(WebSocket::State::Connecting);
Expand Down
71 changes: 51 additions & 20 deletions src/impl/wstransport.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,19 +43,18 @@ using std::to_integer;
using std::to_string;
using std::chrono::system_clock;

WsTransport::WsTransport(
variant<shared_ptr<TcpTransport>, shared_ptr<HttpProxyTransport>, shared_ptr<TlsTransport>>
lower,
shared_ptr<WsHandshake> handshake, int maxOutstandingPings, message_callback recvCallback,
state_callback stateCallback)
WsTransport::WsTransport(LowerTransport lower, shared_ptr<WsHandshake> handshake,
const WebSocketConfiguration &config, message_callback recvCallback,
state_callback stateCallback)
: Transport(std::visit([](auto l) { return std::static_pointer_cast<Transport>(l); }, lower),
std::move(stateCallback)),
mHandshake(std::move(handshake)),
mIsClient(
std::visit(rtc::overloaded{[](auto l) { return l->isActive(); },
[](shared_ptr<TlsTransport> l) { return l->isClient(); }},
lower)),
mMaxOutstandingPings(maxOutstandingPings) {
mMaxMessageSize(config.maxMessageSize.value_or(DEFAULT_WS_MAX_MESSAGE_SIZE)),
mMaxOutstandingPings(config.maxOutstandingPings.value_or(0)) {

onRecv(std::move(recvCallback));

Expand All @@ -75,7 +74,10 @@ void WsTransport::start() {
void WsTransport::stop() { close(); }

bool WsTransport::send(message_ptr message) {
if (!message || state() != State::Connected)
if (state() != State::Connected)
throw std::runtime_error("WebSocket is not open");

if (!message)
return false;

PLOG_VERBOSE << "Send size=" << message->size();
Expand Down Expand Up @@ -146,10 +148,22 @@ void WsTransport::incoming(message_ptr message) {
sendFrame({PING, reinterpret_cast<byte *>(&dummy), 4, true, mIsClient});
addOutstandingPing();
} else {
Frame frame;
while (size_t len = readFrame(mBuffer.data(), mBuffer.size(), frame)) {
recvFrame(frame);
if (mIgnoreLength > 0) {
size_t len = std::min(mIgnoreLength, mBuffer.size());
mBuffer.erase(mBuffer.begin(), mBuffer.begin() + len);
mIgnoreLength -= len;
}
if (mIgnoreLength == 0) {
Frame frame;
while (size_t len = parseFrame(mBuffer.data(), mBuffer.size(), frame)) {
recvFrame(frame);
if (len > mBuffer.size()) {
mIgnoreLength = len - mBuffer.size();
mBuffer.clear();
break;
}
mBuffer.erase(mBuffer.begin(), mBuffer.begin() + len);
}
}
}
}
Expand Down Expand Up @@ -229,7 +243,7 @@ bool WsTransport::sendHttpError(int code) {
// | Payload Data continued ... |
// +---------------------------------------------------------------+

size_t WsTransport::readFrame(byte *buffer, size_t size, Frame &frame) {
size_t WsTransport::parseFrame(byte *buffer, size_t size, Frame &frame) {
const byte *end = buffer + size;
if (end - buffer < 2)
return 0;
Expand Down Expand Up @@ -263,16 +277,25 @@ size_t WsTransport::readFrame(byte *buffer, size_t size, Frame &frame) {
cur += 4;
}

if (size_t(end - cur) < frame.length)
const size_t maxControlFrameLength = 125;
const size_t maxFrameLength = std::max(maxControlFrameLength, mMaxMessageSize);
if (size_t(end - cur) < std::min(frame.length, maxFrameLength))
return 0;

size_t length = frame.length;
if (frame.length > maxFrameLength) {
PLOG_WARNING << "WebSocket frame is too large (length=" << frame.length
<< "), truncating it";
frame.length = maxFrameLength;
}

frame.payload = cur;

if (maskingKey)
for (size_t i = 0; i < frame.length; ++i)
frame.payload[i] ^= maskingKey[i % 4];
cur += frame.length;

return size_t(cur - buffer);
return frame.payload + length - buffer; // can be more than buffer size
}

void WsTransport::recvFrame(const Frame &frame) {
Expand All @@ -282,32 +305,40 @@ void WsTransport::recvFrame(const Frame &frame) {
switch (frame.opcode) {
case TEXT_FRAME:
case BINARY_FRAME: {
size_t size = frame.length;
if (size > mMaxMessageSize) {
PLOG_WARNING << "WebSocket message is too large, truncating it";
size = mMaxMessageSize;
}
if (!mPartial.empty()) {
PLOG_WARNING << "WebSocket unfinished message: type="
<< (mPartialOpcode == TEXT_FRAME ? "text" : "binary")
<< ", length=" << mPartial.size();
<< ", size=" << mPartial.size();
auto type = mPartialOpcode == TEXT_FRAME ? Message::String : Message::Binary;
recv(make_message(mPartial.begin(), mPartial.end(), type));
mPartial.clear();
}
mPartialOpcode = frame.opcode;
if (frame.fin) {
PLOG_DEBUG << "WebSocket finished message: type="
<< (frame.opcode == TEXT_FRAME ? "text" : "binary")
<< ", length=" << frame.length;
<< (frame.opcode == TEXT_FRAME ? "text" : "binary") << ", size=" << size;
auto type = frame.opcode == TEXT_FRAME ? Message::String : Message::Binary;
recv(make_message(frame.payload, frame.payload + frame.length, type));
recv(make_message(frame.payload, frame.payload + size, type));
} else {
mPartial.insert(mPartial.end(), frame.payload, frame.payload + frame.length);
mPartial.insert(mPartial.end(), frame.payload, frame.payload + size);
}
break;
}
case CONTINUATION: {
mPartial.insert(mPartial.end(), frame.payload, frame.payload + frame.length);
if (mPartial.size() > mMaxMessageSize) {
PLOG_WARNING << "WebSocket message is too large, truncating it";
mPartial.resize(mMaxMessageSize);
}
if (frame.fin) {
PLOG_DEBUG << "WebSocket finished message: type="
<< (frame.opcode == TEXT_FRAME ? "text" : "binary")
<< ", length=" << mPartial.size();
<< ", size=" << mPartial.size();
auto type = mPartialOpcode == TEXT_FRAME ? Message::String : Message::Binary;
recv(make_message(mPartial.begin(), mPartial.end(), type));
mPartial.clear();
Expand Down
16 changes: 10 additions & 6 deletions src/impl/wstransport.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

#include "common.hpp"
#include "transport.hpp"
#include "configuration.hpp"
#include "wshandshake.hpp"

#if RTC_ENABLE_WEBSOCKET
Expand All @@ -25,11 +26,12 @@ class TlsTransport;

class WsTransport final : public Transport, public std::enable_shared_from_this<WsTransport> {
public:
WsTransport(
variant<shared_ptr<TcpTransport>, shared_ptr<HttpProxyTransport>, shared_ptr<TlsTransport>>
lower,
shared_ptr<WsHandshake> handshake, int maxOutstandingPings, message_callback recvCallback,
state_callback stateCallback);
using LowerTransport =
variant<shared_ptr<TcpTransport>, shared_ptr<HttpProxyTransport>, shared_ptr<TlsTransport>>;

WsTransport(LowerTransport lower, shared_ptr<WsHandshake> handshake,
const WebSocketConfiguration &config, message_callback recvCallback,
state_callback stateCallback);
~WsTransport();

void start() override;
Expand Down Expand Up @@ -62,19 +64,21 @@ class WsTransport final : public Transport, public std::enable_shared_from_this<
bool sendHttpError(int code);
bool sendHttpResponse();

size_t readFrame(byte *buffer, size_t size, Frame &frame);
size_t parseFrame(byte *buffer, size_t size, Frame &frame);
void recvFrame(const Frame &frame);
bool sendFrame(const Frame &frame);

void addOutstandingPing();

const shared_ptr<WsHandshake> mHandshake;
const bool mIsClient;
const size_t mMaxMessageSize;
const int mMaxOutstandingPings;

binary mBuffer;
binary mPartial;
Opcode mPartialOpcode;
size_t mIgnoreLength = 0;
std::mutex mSendMutex;
int mOutstandingPings = 0;
std::atomic<bool> mCloseSent = false;
Expand Down
Loading

0 comments on commit dbdfb49

Please sign in to comment.