diff --git a/README.md b/README.md index 0498eb5..8bf018b 100644 --- a/README.md +++ b/README.md @@ -52,7 +52,7 @@ Servers are able to send the following messages: 1. INCOMPATIBLE_VERSION: `0x00 | 1-byte version` 2. ACKNOWLEDGE_AUTHENTICATION: `0x01 | 0 or 1 success` -3. SEND_MESSAGE: `0x02 | varchar channel name (32) | varchar nickname (32) | varchar content (1024)` +3. SEND_MESSAGE: `0x02 | 8-byte snowflake | varchar channel name (32) | varchar nickname (32) | varchar content (1024)` 4. LIST_CHANNELS: `0x03 | 2-byte length | varchar channel name (32) | ...` When the client disconnects and reconnects, they MUST re-authenticate with the server. diff --git a/src/dumdum/client/chat_frame.py b/src/dumdum/client/chat_frame.py index 9c54864..f91e166 100644 --- a/src/dumdum/client/chat_frame.py +++ b/src/dumdum/client/chat_frame.py @@ -2,7 +2,6 @@ import collections import concurrent.futures -from dataclasses import dataclass from tkinter import Event, StringVar from tkinter.ttk import Button, Entry, Frame, Label, Treeview from typing import ContextManager @@ -12,6 +11,7 @@ ClientEvent, ClientEventChannelsListed, ClientEventMessageReceived, + Message, ) from .app import TkApp @@ -47,10 +47,10 @@ def handle_client_event(self, event: ClientEvent) -> None: self.channels.extend(event.channels) self.channel_list.refresh() elif isinstance(event, ClientEventMessageReceived): - message = self.message_cache.add_message_from_event(event) - channel = self.get_channel(event.channel_name) + self.message_cache.add_message(event.message) + channel = self.get_channel(event.message.channel_name) if channel == self.channel_list.selected_channel: - self.messages.add_message(message) + self.messages.add_message(event.message) def get_channel(self, name: str) -> Channel | None: for channel in self.channels: @@ -164,23 +164,12 @@ def __init__(self, frame: Frame, parent: MessageList, message: Message) -> None: self.content.grid(row=1, column=0, sticky="ew") -@dataclass -class Message: - nick: str - content: str - - class MessageCache: def __init__(self) -> None: self.channel_messages: dict[str, list[Message]] = collections.defaultdict(list) - def add_message(self, channel_name: str, message: Message) -> None: - self.channel_messages[channel_name].append(message) - - def add_message_from_event(self, event: ClientEventMessageReceived) -> Message: - message = Message(event.nick, event.content) - self.add_message(event.channel_name, message) - return message + def add_message(self, message: Message) -> None: + self.channel_messages[message.channel_name].append(message) def get_messages(self, channel: Channel) -> list[Message]: messages = self.channel_messages.get(channel.name) diff --git a/src/dumdum/protocol/README.md b/src/dumdum/protocol/README.md index 2f01596..7dd905d 100644 --- a/src/dumdum/protocol/README.md +++ b/src/dumdum/protocol/README.md @@ -9,6 +9,7 @@ This contains the [Sans-IO] implementation of the Dumdum protocol. - [`highcommand.py`](highcommand.py): A server-side, in-memory datastore for channels and users. - [`interfaces.py`](interfaces.py): Defines a common interface between the client and server. - [`reader.py`](reader.py): Provides functions to read through bytes/bytearrays like streams. +- [`snowflake.py`](snowflake.py): Provides functions to generate snowflake identifiers. - [`varchar.py`](varchar.py): Provides functions to de/serialize variable-length strings. [Sans-IO]: https://sans-io.readthedocs.io/ diff --git a/src/dumdum/protocol/__init__.py b/src/dumdum/protocol/__init__.py index ba5a206..6579944 100644 --- a/src/dumdum/protocol/__init__.py +++ b/src/dumdum/protocol/__init__.py @@ -34,4 +34,6 @@ ) from .highcommand import HighCommand from .interfaces import Protocol +from .message import Message from .reader import Reader, bytearray_reader, byte_reader +from .snowflake import create_snowflake diff --git a/src/dumdum/protocol/client/events.py b/src/dumdum/protocol/client/events.py index c852374..8adf2bc 100644 --- a/src/dumdum/protocol/client/events.py +++ b/src/dumdum/protocol/client/events.py @@ -2,6 +2,7 @@ from typing import Sequence from dumdum.protocol.channel import Channel +from dumdum.protocol.message import Message @dataclass @@ -28,9 +29,7 @@ class ClientEventAuthentication(ClientEvent): class ClientEventMessageReceived(ClientEvent): """The server broadcasted a message to the client.""" - channel_name: str - nick: str - content: str + message: Message @dataclass diff --git a/src/dumdum/protocol/client/protocol.py b/src/dumdum/protocol/client/protocol.py index 4541494..4edd15c 100644 --- a/src/dumdum/protocol/client/protocol.py +++ b/src/dumdum/protocol/client/protocol.py @@ -1,15 +1,11 @@ from enum import Enum from dumdum.protocol.channel import Channel -from dumdum.protocol.constants import ( - MAX_CHANNEL_NAME_LENGTH, - MAX_LIST_CHANNEL_LENGTH_BYTES, - MAX_MESSAGE_LENGTH, - MAX_NICK_LENGTH, -) +from dumdum.protocol.constants import MAX_LIST_CHANNEL_LENGTH_BYTES from dumdum.protocol.enums import ServerMessageType from dumdum.protocol.errors import InvalidStateError from dumdum.protocol.interfaces import Protocol +from dumdum.protocol.message import Message from dumdum.protocol.reader import Reader, byte_reader, bytearray_reader from .events import ( @@ -19,7 +15,11 @@ ClientEventIncompatibleVersion, ClientEventMessageReceived, ) -from .messages import ClientMessageAuthenticate, ClientMessageListChannels, ClientMessagePost +from .messages import ( + ClientMessageAuthenticate, + ClientMessageListChannels, + ClientMessagePost, +) ParsedData = tuple[list[ClientEvent], bytes] @@ -32,7 +32,7 @@ class ClientState(Enum): class Client(Protocol): """The client connected to a server.""" - PROTOCOL_VERSION = 0 + PROTOCOL_VERSION = 1 def __init__(self, nick: str) -> None: self.nick = nick @@ -113,10 +113,8 @@ def _accept_authentication(self, reader: Reader) -> ParsedData: def _parse_message(self, reader: Reader) -> ParsedData: self._assert_state(ClientState.READY) - channel_name = reader.read_varchar(max_length=MAX_CHANNEL_NAME_LENGTH) - nick = reader.read_varchar(max_length=MAX_NICK_LENGTH) - content = reader.read_varchar(max_length=MAX_MESSAGE_LENGTH) - event = ClientEventMessageReceived(channel_name, nick, content) + message = Message.from_reader(reader) + event = ClientEventMessageReceived(message) return [event], b"" def _parse_channel_list(self, reader: Reader) -> ParsedData: diff --git a/src/dumdum/protocol/highcommand.py b/src/dumdum/protocol/highcommand.py index 53b4dad..8e50c5a 100644 --- a/src/dumdum/protocol/highcommand.py +++ b/src/dumdum/protocol/highcommand.py @@ -1,6 +1,9 @@ -from typing import TypeAlias +import bisect +import collections +from typing import Sequence, TypeAlias from .channel import Channel +from .message import Message User: TypeAlias = str @@ -8,6 +11,7 @@ class HighCommand: def __init__(self) -> None: self._channels: dict[str, Channel] = {} + self._messages: dict[str, list[Message]] = collections.defaultdict(list) self._users: dict[str, User] = {} @property @@ -23,6 +27,37 @@ def get_channel(self, name: str) -> Channel | None: def remove_channel(self, name: str) -> Channel | None: return self._channels.pop(name, None) + def get_messages(self, channel_name: str) -> Sequence[Message]: + return self._messages[channel_name] + + def add_message(self, message: Message) -> None: + messages = self._messages[message.channel_name] + bisect.insort(messages, message, key=lambda m: m.id) + + def get_message(self, channel_name: str, id: int) -> Message | None: + messages = self._messages[channel_name] + if len(messages) < 1: + return None + + i = self._index_message(messages, id) + message = messages[i] + if message.id == id: + return message + + def remove_message(self, channel_name: str, id: int) -> Message | None: + messages = self._messages[channel_name] + if len(messages) < 1: + return None + + i = self._index_message(messages, id) + message = messages[i] + if message.id == id: + del messages[i] + return message + + def _index_message(self, messages: Sequence[Message], id: int) -> int: + return bisect.bisect_left(messages, id, key=lambda m: m.id) + @property def users(self) -> tuple[User, ...]: return tuple(self._users.values()) diff --git a/src/dumdum/protocol/message.py b/src/dumdum/protocol/message.py new file mode 100644 index 0000000..c7ed347 --- /dev/null +++ b/src/dumdum/protocol/message.py @@ -0,0 +1,32 @@ +from dataclasses import dataclass +from typing import Self + +from . import varchar +from .constants import MAX_CHANNEL_NAME_LENGTH, MAX_MESSAGE_LENGTH, MAX_NICK_LENGTH +from .reader import Reader + + +@dataclass +class Message: + id: int + channel_name: str + nick: str + content: str + + def __bytes__(self) -> bytes: + return bytes( + [ + *self.id.to_bytes(8, byteorder="big"), + *varchar.dumps(self.channel_name, max_length=MAX_CHANNEL_NAME_LENGTH), + *varchar.dumps(self.nick, max_length=MAX_NICK_LENGTH), + *varchar.dumps(self.content, max_length=MAX_MESSAGE_LENGTH), + ] + ) + + @classmethod + def from_reader(cls, reader: Reader) -> Self: + id = reader.read_bigint() + channel_name = reader.read_varchar(max_length=MAX_CHANNEL_NAME_LENGTH) + nick = reader.read_varchar(max_length=MAX_NICK_LENGTH) + content = reader.read_varchar(max_length=MAX_MESSAGE_LENGTH) + return cls(id=id, channel_name=channel_name, nick=nick, content=content) diff --git a/src/dumdum/protocol/reader.py b/src/dumdum/protocol/reader.py index 279e1ea..8ca8a6f 100644 --- a/src/dumdum/protocol/reader.py +++ b/src/dumdum/protocol/reader.py @@ -34,6 +34,10 @@ def readexactly(self, n: int) -> bytes: return data + def read_bigint(self) -> int: + data = self.readexactly(8) + return int.from_bytes(data, byteorder="big") + def read_varchar(self, *, max_length: int) -> str: return varchar.load(self, max_length=max_length) diff --git a/src/dumdum/protocol/server/messages.py b/src/dumdum/protocol/server/messages.py index 19e8575..3af3fa7 100644 --- a/src/dumdum/protocol/server/messages.py +++ b/src/dumdum/protocol/server/messages.py @@ -1,15 +1,10 @@ from dataclasses import dataclass from typing import Sequence -from dumdum.protocol import varchar from dumdum.protocol.channel import Channel -from dumdum.protocol.constants import ( - MAX_CHANNEL_NAME_LENGTH, - MAX_LIST_CHANNEL_LENGTH_BYTES, - MAX_MESSAGE_LENGTH, - MAX_NICK_LENGTH, -) +from dumdum.protocol.constants import MAX_LIST_CHANNEL_LENGTH_BYTES from dumdum.protocol.enums import ServerMessageType +from dumdum.protocol.message import Message @dataclass @@ -40,17 +35,13 @@ def __bytes__(self) -> bytes: @dataclass class ServerMessagePost: - channel_name: str - nick: str - content: str + message: Message def __bytes__(self) -> bytes: return bytes( [ ServerMessageType.SEND_MESSAGE.value, - *varchar.dumps(self.channel_name, max_length=MAX_CHANNEL_NAME_LENGTH), - *varchar.dumps(self.nick, max_length=MAX_NICK_LENGTH), - *varchar.dumps(self.content, max_length=MAX_MESSAGE_LENGTH), + *bytes(self.message), ] ) diff --git a/src/dumdum/protocol/server/protocol.py b/src/dumdum/protocol/server/protocol.py index dcf7255..a7925ec 100644 --- a/src/dumdum/protocol/server/protocol.py +++ b/src/dumdum/protocol/server/protocol.py @@ -10,6 +10,7 @@ from dumdum.protocol.enums import ClientMessageType from dumdum.protocol.errors import InvalidStateError from dumdum.protocol.interfaces import Protocol +from dumdum.protocol.message import Message from dumdum.protocol.reader import Reader, bytearray_reader from .events import ( @@ -37,7 +38,7 @@ class ServerState(Enum): class Server(Protocol): """The server for a single client.""" - PROTOCOL_VERSION = 0 + PROTOCOL_VERSION = 1 def __init__(self) -> None: self._buffer = bytearray() @@ -55,9 +56,9 @@ def acknowledge_authentication(self, *, success: bool) -> bytes: return bytes(ServerMessageAcknowledgeAuthentication(success)) - def send_message(self, channel_name: str, nick: str, content: str) -> bytes: + def send_message(self, message: Message) -> bytes: self._assert_state(ServerState.READY) - return bytes(ServerMessagePost(channel_name, nick, content)) + return bytes(ServerMessagePost(message)) def list_channels(self, channels: Sequence[Channel]) -> bytes: return bytes(ServerMessageListChannels(channels)) diff --git a/src/dumdum/protocol/snowflake.py b/src/dumdum/protocol/snowflake.py new file mode 100644 index 0000000..1182411 --- /dev/null +++ b/src/dumdum/protocol/snowflake.py @@ -0,0 +1,45 @@ +""" +64 63 19 12 0 + 0 00011000111000111000000101110101101101000100 1110101 000000000111 + +64-63: Always 0 +63-19: Unix timestamp in milliseconds +19-12: Process ID +12-00: Incrementing per-process ID +""" + +import datetime +import os +import time + +_incrementing_id = 0 + + +def create_snowflake( + t: float | int | datetime.datetime | None = None, + pid: int | None = None, + increment: int | None = None, +) -> int: + if t is None: + # WARNING: this may roll back or forwards depending on system time + t = time.time_ns() // 1000000 + elif isinstance(t, datetime.datetime): + t = int(t.timestamp() * 1000) + elif isinstance(t, float): + t = int(t * 1000) + + if t.bit_length() > 44: + raise OverflowError(f"Timestamp {t} out of bounds (are we in 2527?)") + + if pid is None: + pid = os.getpid() + + if increment is None: + global _incrementing_id + increment = _incrementing_id + _incrementing_id += 1 + + pid = pid % 128 + increment = increment % 4096 + + return (t << 19) + (pid << 12) + increment diff --git a/src/dumdum/server.py b/src/dumdum/server.py index 06edaf4..d094f62 100644 --- a/src/dumdum/server.py +++ b/src/dumdum/server.py @@ -12,11 +12,13 @@ Channel, HighCommand, InvalidStateError, + Message, Server, ServerEvent, ServerEventAuthentication, ServerEventListChannels, ServerEventMessageReceived, + create_snowflake, ) log = logging.getLogger(__name__) @@ -162,13 +164,17 @@ def _broadcast_message( if self.hc.get_channel(event.channel_name) is None: return + message = Message( + create_snowflake(), + event.channel_name, + conn.nick, + event.content, + ) + self.hc.add_message(message) + for peer in self.connections: with contextlib.suppress(InvalidStateError): - data = peer.server.send_message( - event.channel_name, - conn.nick, - event.content, - ) + data = peer.server.send_message(message) peer.writer.write(data) def _list_channels(self, conn: Connection, event: ServerEventListChannels) -> None: diff --git a/tests/test_protocol.py b/tests/test_protocol.py index d84eabe..325afc1 100644 --- a/tests/test_protocol.py +++ b/tests/test_protocol.py @@ -9,8 +9,8 @@ ClientEventChannelsListed, ClientEventIncompatibleVersion, ClientState, - HighCommand, InvalidStateError, + Message, Protocol, Server, ServerEventAuthentication, @@ -168,16 +168,16 @@ def test_unauthenticated_send_message(): def test_unauthenticated_server_send_message(): nick = "thegamecracks" content = "Hello world!" - channel = Channel("general") + message = Message(0, "general", nick, content) client = Client(nick=nick) server = Server() with pytest.raises(InvalidStateError): - server.send_message(channel.name, nick, content) + server.send_message(message) server._state = ServerState.READY - data = server.send_message(channel.name, nick, content) + data = server.send_message(message) with pytest.raises(InvalidStateError): communicate(server, data, client)