Skip to content

Commit

Permalink
feat: add buffer size limits to protocol
Browse files Browse the repository at this point in the history
Currently the largest message we might send is LIST_MESSAGES,
which can span up to 16MB (1+3+2^24). With the maximum size of a
message being 1100 bytes (33+33+1026), this would require a minimum
of 15,253 message objects to be sent. Realistically the server shouldn't
come close to this, and is more likely to be sending a few hundred
messages at a time. We will use a default buffer size of 1MiB which
should be suitable for all practical purposes.
  • Loading branch information
thegamecracks committed Apr 2, 2024
1 parent 69ac37b commit d010df9
Show file tree
Hide file tree
Showing 6 changed files with 59 additions and 4 deletions.
2 changes: 2 additions & 0 deletions src/dumdum/protocol/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,12 @@
ServerMessageSendIncompatibleVersion,
ServerState,
)
from .buffer import extend_limited_buffer
from .channel import Channel
from .constants import MAX_MESSAGE_LENGTH, MAX_NICK_LENGTH
from .enums import ClientMessageType, ServerMessageType
from .errors import (
BufferOverflowError,
InvalidLengthError,
InvalidStateError,
MalformedDataError,
Expand Down
18 changes: 18 additions & 0 deletions src/dumdum/protocol/buffer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
from .errors import BufferOverflowError


def extend_limited_buffer(
buffer: bytearray,
data: bytes | bytearray,
*,
limit: int | None,
) -> None:
if limit is None:
buffer.extend(data)
return

len_buffer, len_data = len(buffer), len(data)
if len_buffer + len_data > limit:
raise BufferOverflowError(limit, len_buffer, len_data)

buffer.extend(data)
6 changes: 4 additions & 2 deletions src/dumdum/protocol/client/protocol.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from enum import Enum

from dumdum.protocol.buffer import extend_limited_buffer
from dumdum.protocol.channel import Channel
from dumdum.protocol.constants import (
MAX_LIST_CHANNEL_LENGTH_BYTES,
Expand Down Expand Up @@ -42,14 +43,15 @@ class Client(Protocol):

PROTOCOL_VERSION = 2

def __init__(self, nick: str) -> None:
def __init__(self, nick: str, *, buffer_size: int = 2**20) -> None:
self.nick = nick
self.buffer_size = buffer_size

self._buffer = bytearray()
self._state = ClientState.AWAITING_HELLO

def receive_bytes(self, data: bytes) -> ParsedData:
self._buffer.extend(data)
extend_limited_buffer(self._buffer, data, limit=self.buffer_size)
return self._maybe_parse_buffer()

def hello(self) -> bytes:
Expand Down
15 changes: 15 additions & 0 deletions src/dumdum/protocol/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,21 @@ class ProtocolError(Exception):
"""The base class for errors related to the protocol."""


class BufferOverflowError(ProtocolError):
"""The protocol received more data than the buffer could store.
This may be the result of the buffer size being set too low,
or possibly malformed data that the protocol is unable to parse.
"""

def __init__(self, limit: int, len_buffer: int, len_data: int) -> None:
super().__init__(f"Buffer limit cannot be exceeded ({limit:,d} bytes)")
self.limit = limit
self.len_buffer = len_buffer
self.len_data = len_data


class MalformedDataError(ProtocolError):
"""Raised when there is unambiguously malformed data.
Expand Down
7 changes: 5 additions & 2 deletions src/dumdum/protocol/server/protocol.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from enum import Enum
from typing import Sequence

from dumdum.protocol.buffer import extend_limited_buffer
from dumdum.protocol.channel import Channel
from dumdum.protocol.constants import (
MAX_CHANNEL_NAME_LENGTH,
Expand Down Expand Up @@ -45,12 +46,14 @@ class Server(Protocol):

PROTOCOL_VERSION = 2

def __init__(self) -> None:
def __init__(self, *, buffer_size: int = 2**20) -> None:
self.buffer_size = buffer_size

self._buffer = bytearray()
self._state = ServerState.AWAITING_HELLO

def receive_bytes(self, data: bytes) -> ParsedData:
self._buffer.extend(data)
extend_limited_buffer(self._buffer, data, limit=self.buffer_size)
return self._maybe_parse_buffer()

def hello(self, *, using_ssl: bool) -> bytes:
Expand Down
15 changes: 15 additions & 0 deletions tests/test_protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import pytest

from dumdum.protocol import (
BufferOverflowError,
Channel,
Client,
ClientEventAuthentication,
Expand Down Expand Up @@ -256,3 +257,17 @@ def test_unicode_decode_error():
# SEND_MESSAGE, Channel name \N{EYES} but missing last 3 bytes
data = b"\x03\x01\xf0"
server.receive_bytes(data)


def test_buffer_overflow_prevention():
data = b"Hello world!\n"
size = len(data) - 1

client = Client("thegamecracks", buffer_size=size)
server = Server(buffer_size=size)

with pytest.raises(BufferOverflowError):
client.receive_bytes(data)

with pytest.raises(BufferOverflowError):
server.receive_bytes(data)

0 comments on commit d010df9

Please sign in to comment.