diff --git a/src/dumdum/protocol/__init__.py b/src/dumdum/protocol/__init__.py index 8036f65..6f5814a 100644 --- a/src/dumdum/protocol/__init__.py +++ b/src/dumdum/protocol/__init__.py @@ -31,10 +31,12 @@ ServerMessageSendIncompatibleVersion, ServerState, ) +from .buffer import extend_limited_buffer from .channel import Channel from .constants import MAX_MESSAGE_LENGTH, MAX_NICK_LENGTH from .enums import ClientMessageType, ServerMessageType from .errors import ( + BufferOverflowError, InvalidLengthError, InvalidStateError, MalformedDataError, diff --git a/src/dumdum/protocol/buffer.py b/src/dumdum/protocol/buffer.py new file mode 100644 index 0000000..a1f254f --- /dev/null +++ b/src/dumdum/protocol/buffer.py @@ -0,0 +1,18 @@ +from .errors import BufferOverflowError + + +def extend_limited_buffer( + buffer: bytearray, + data: bytes | bytearray, + *, + limit: int | None, +) -> None: + if limit is None: + buffer.extend(data) + return + + len_buffer, len_data = len(buffer), len(data) + if len_buffer + len_data > limit: + raise BufferOverflowError(limit, len_buffer, len_data) + + buffer.extend(data) diff --git a/src/dumdum/protocol/client/protocol.py b/src/dumdum/protocol/client/protocol.py index b7f9a6b..0761cfe 100644 --- a/src/dumdum/protocol/client/protocol.py +++ b/src/dumdum/protocol/client/protocol.py @@ -1,5 +1,6 @@ from enum import Enum +from dumdum.protocol.buffer import extend_limited_buffer from dumdum.protocol.channel import Channel from dumdum.protocol.constants import ( MAX_LIST_CHANNEL_LENGTH_BYTES, @@ -42,14 +43,15 @@ class Client(Protocol): PROTOCOL_VERSION = 2 - def __init__(self, nick: str) -> None: + def __init__(self, nick: str, *, buffer_size: int = 2**20) -> None: self.nick = nick + self.buffer_size = buffer_size self._buffer = bytearray() self._state = ClientState.AWAITING_HELLO def receive_bytes(self, data: bytes) -> ParsedData: - self._buffer.extend(data) + extend_limited_buffer(self._buffer, data, limit=self.buffer_size) return self._maybe_parse_buffer() def hello(self) -> bytes: diff --git a/src/dumdum/protocol/errors.py b/src/dumdum/protocol/errors.py index 14b0213..045e921 100644 --- a/src/dumdum/protocol/errors.py +++ b/src/dumdum/protocol/errors.py @@ -5,6 +5,21 @@ class ProtocolError(Exception): """The base class for errors related to the protocol.""" +class BufferOverflowError(ProtocolError): + """The protocol received more data than the buffer could store. + + This may be the result of the buffer size being set too low, + or possibly malformed data that the protocol is unable to parse. + + """ + + def __init__(self, limit: int, len_buffer: int, len_data: int) -> None: + super().__init__(f"Buffer limit cannot be exceeded ({limit:,d} bytes)") + self.limit = limit + self.len_buffer = len_buffer + self.len_data = len_data + + class MalformedDataError(ProtocolError): """Raised when there is unambiguously malformed data. diff --git a/src/dumdum/protocol/server/protocol.py b/src/dumdum/protocol/server/protocol.py index 2a592c5..4ec6e58 100644 --- a/src/dumdum/protocol/server/protocol.py +++ b/src/dumdum/protocol/server/protocol.py @@ -1,6 +1,7 @@ from enum import Enum from typing import Sequence +from dumdum.protocol.buffer import extend_limited_buffer from dumdum.protocol.channel import Channel from dumdum.protocol.constants import ( MAX_CHANNEL_NAME_LENGTH, @@ -45,12 +46,14 @@ class Server(Protocol): PROTOCOL_VERSION = 2 - def __init__(self) -> None: + def __init__(self, *, buffer_size: int = 2**20) -> None: + self.buffer_size = buffer_size + self._buffer = bytearray() self._state = ServerState.AWAITING_HELLO def receive_bytes(self, data: bytes) -> ParsedData: - self._buffer.extend(data) + extend_limited_buffer(self._buffer, data, limit=self.buffer_size) return self._maybe_parse_buffer() def hello(self, *, using_ssl: bool) -> bytes: diff --git a/tests/test_protocol.py b/tests/test_protocol.py index 72399c8..cbb3748 100644 --- a/tests/test_protocol.py +++ b/tests/test_protocol.py @@ -3,6 +3,7 @@ import pytest from dumdum.protocol import ( + BufferOverflowError, Channel, Client, ClientEventAuthentication, @@ -256,3 +257,17 @@ def test_unicode_decode_error(): # SEND_MESSAGE, Channel name \N{EYES} but missing last 3 bytes data = b"\x03\x01\xf0" server.receive_bytes(data) + + +def test_buffer_overflow_prevention(): + data = b"Hello world!\n" + size = len(data) - 1 + + client = Client("thegamecracks", buffer_size=size) + server = Server(buffer_size=size) + + with pytest.raises(BufferOverflowError): + client.receive_bytes(data) + + with pytest.raises(BufferOverflowError): + server.receive_bytes(data)