Skip to content

Commit

Permalink
feat!: implement message caching
Browse files Browse the repository at this point in the history
Later we'll add a new command to list messages

- Bump protocol version due to new/changed commands
- Add Message dataclass with ID field
- Add message cache to HighCommand
- Add Reader.read_bigint()
- Add snowflake.create_snowflake()
- Replace ClientEventMessageReceived fields with one message field
- Replace ServerMessagePost fields with one message field
- Replace Server.send_message() parameters with one message parameter
  • Loading branch information
thegamecracks committed Mar 13, 2024
1 parent c1acd9b commit 5909c2a
Show file tree
Hide file tree
Showing 14 changed files with 162 additions and 59 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ Servers are able to send the following messages:

1. INCOMPATIBLE_VERSION: `0x00 | 1-byte version`
2. ACKNOWLEDGE_AUTHENTICATION: `0x01 | 0 or 1 success`
3. SEND_MESSAGE: `0x02 | varchar channel name (32) | varchar nickname (32) | varchar content (1024)`
3. SEND_MESSAGE: `0x02 | 8-byte snowflake | varchar channel name (32) | varchar nickname (32) | varchar content (1024)`
4. LIST_CHANNELS: `0x03 | 2-byte length | varchar channel name (32) | ...`

When the client disconnects and reconnects, they MUST re-authenticate with the server.
Expand Down
23 changes: 6 additions & 17 deletions src/dumdum/client/chat_frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
import collections

import concurrent.futures
from dataclasses import dataclass
from tkinter import Event, StringVar
from tkinter.ttk import Button, Entry, Frame, Label, Treeview
from typing import ContextManager
Expand All @@ -12,6 +11,7 @@
ClientEvent,
ClientEventChannelsListed,
ClientEventMessageReceived,
Message,
)

from .app import TkApp
Expand Down Expand Up @@ -47,10 +47,10 @@ def handle_client_event(self, event: ClientEvent) -> None:
self.channels.extend(event.channels)
self.channel_list.refresh()
elif isinstance(event, ClientEventMessageReceived):
message = self.message_cache.add_message_from_event(event)
channel = self.get_channel(event.channel_name)
self.message_cache.add_message(event.message)
channel = self.get_channel(event.message.channel_name)
if channel == self.channel_list.selected_channel:
self.messages.add_message(message)
self.messages.add_message(event.message)

def get_channel(self, name: str) -> Channel | None:
for channel in self.channels:
Expand Down Expand Up @@ -164,23 +164,12 @@ def __init__(self, frame: Frame, parent: MessageList, message: Message) -> None:
self.content.grid(row=1, column=0, sticky="ew")


@dataclass
class Message:
nick: str
content: str


class MessageCache:
def __init__(self) -> None:
self.channel_messages: dict[str, list[Message]] = collections.defaultdict(list)

def add_message(self, channel_name: str, message: Message) -> None:
self.channel_messages[channel_name].append(message)

def add_message_from_event(self, event: ClientEventMessageReceived) -> Message:
message = Message(event.nick, event.content)
self.add_message(event.channel_name, message)
return message
def add_message(self, message: Message) -> None:
self.channel_messages[message.channel_name].append(message)

def get_messages(self, channel: Channel) -> list[Message]:
messages = self.channel_messages.get(channel.name)
Expand Down
1 change: 1 addition & 0 deletions src/dumdum/protocol/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ This contains the [Sans-IO] implementation of the Dumdum protocol.
- [`highcommand.py`](highcommand.py): A server-side, in-memory datastore for channels and users.
- [`interfaces.py`](interfaces.py): Defines a common interface between the client and server.
- [`reader.py`](reader.py): Provides functions to read through bytes/bytearrays like streams.
- [`snowflake.py`](snowflake.py): Provides functions to generate snowflake identifiers.
- [`varchar.py`](varchar.py): Provides functions to de/serialize variable-length strings.

[Sans-IO]: https://sans-io.readthedocs.io/
2 changes: 2 additions & 0 deletions src/dumdum/protocol/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,4 +34,6 @@
)
from .highcommand import HighCommand
from .interfaces import Protocol
from .message import Message
from .reader import Reader, bytearray_reader, byte_reader
from .snowflake import create_snowflake
5 changes: 2 additions & 3 deletions src/dumdum/protocol/client/events.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from typing import Sequence

from dumdum.protocol.channel import Channel
from dumdum.protocol.message import Message


@dataclass
Expand All @@ -28,9 +29,7 @@ class ClientEventAuthentication(ClientEvent):
class ClientEventMessageReceived(ClientEvent):
"""The server broadcasted a message to the client."""

channel_name: str
nick: str
content: str
message: Message


@dataclass
Expand Down
22 changes: 10 additions & 12 deletions src/dumdum/protocol/client/protocol.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,11 @@
from enum import Enum

from dumdum.protocol.channel import Channel
from dumdum.protocol.constants import (
MAX_CHANNEL_NAME_LENGTH,
MAX_LIST_CHANNEL_LENGTH_BYTES,
MAX_MESSAGE_LENGTH,
MAX_NICK_LENGTH,
)
from dumdum.protocol.constants import MAX_LIST_CHANNEL_LENGTH_BYTES
from dumdum.protocol.enums import ServerMessageType
from dumdum.protocol.errors import InvalidStateError
from dumdum.protocol.interfaces import Protocol
from dumdum.protocol.message import Message
from dumdum.protocol.reader import Reader, byte_reader, bytearray_reader

from .events import (
Expand All @@ -19,7 +15,11 @@
ClientEventIncompatibleVersion,
ClientEventMessageReceived,
)
from .messages import ClientMessageAuthenticate, ClientMessageListChannels, ClientMessagePost
from .messages import (
ClientMessageAuthenticate,
ClientMessageListChannels,
ClientMessagePost,
)

ParsedData = tuple[list[ClientEvent], bytes]

Expand All @@ -32,7 +32,7 @@ class ClientState(Enum):
class Client(Protocol):
"""The client connected to a server."""

PROTOCOL_VERSION = 0
PROTOCOL_VERSION = 1

def __init__(self, nick: str) -> None:
self.nick = nick
Expand Down Expand Up @@ -113,10 +113,8 @@ def _accept_authentication(self, reader: Reader) -> ParsedData:

def _parse_message(self, reader: Reader) -> ParsedData:
self._assert_state(ClientState.READY)
channel_name = reader.read_varchar(max_length=MAX_CHANNEL_NAME_LENGTH)
nick = reader.read_varchar(max_length=MAX_NICK_LENGTH)
content = reader.read_varchar(max_length=MAX_MESSAGE_LENGTH)
event = ClientEventMessageReceived(channel_name, nick, content)
message = Message.from_reader(reader)
event = ClientEventMessageReceived(message)
return [event], b""

def _parse_channel_list(self, reader: Reader) -> ParsedData:
Expand Down
37 changes: 36 additions & 1 deletion src/dumdum/protocol/highcommand.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,17 @@
from typing import TypeAlias
import bisect
import collections
from typing import Sequence, TypeAlias

from .channel import Channel
from .message import Message

User: TypeAlias = str


class HighCommand:
def __init__(self) -> None:
self._channels: dict[str, Channel] = {}
self._messages: dict[str, list[Message]] = collections.defaultdict(list)
self._users: dict[str, User] = {}

@property
Expand All @@ -23,6 +27,37 @@ def get_channel(self, name: str) -> Channel | None:
def remove_channel(self, name: str) -> Channel | None:
return self._channels.pop(name, None)

def get_messages(self, channel_name: str) -> Sequence[Message]:
return self._messages[channel_name]

def add_message(self, message: Message) -> None:
messages = self._messages[message.channel_name]
bisect.insort(messages, message, key=lambda m: m.id)

def get_message(self, channel_name: str, id: int) -> Message | None:
messages = self._messages[channel_name]
if len(messages) < 1:
return None

i = self._index_message(messages, id)
message = messages[i]
if message.id == id:
return message

def remove_message(self, channel_name: str, id: int) -> Message | None:
messages = self._messages[channel_name]
if len(messages) < 1:
return None

i = self._index_message(messages, id)
message = messages[i]
if message.id == id:
del messages[i]
return message

def _index_message(self, messages: Sequence[Message], id: int) -> int:
return bisect.bisect_left(messages, id, key=lambda m: m.id)

@property
def users(self) -> tuple[User, ...]:
return tuple(self._users.values())
Expand Down
32 changes: 32 additions & 0 deletions src/dumdum/protocol/message.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
from dataclasses import dataclass
from typing import Self

from . import varchar
from .constants import MAX_CHANNEL_NAME_LENGTH, MAX_MESSAGE_LENGTH, MAX_NICK_LENGTH
from .reader import Reader


@dataclass
class Message:
id: int
channel_name: str
nick: str
content: str

def __bytes__(self) -> bytes:
return bytes(
[
*self.id.to_bytes(8, byteorder="big"),
*varchar.dumps(self.channel_name, max_length=MAX_CHANNEL_NAME_LENGTH),
*varchar.dumps(self.nick, max_length=MAX_NICK_LENGTH),
*varchar.dumps(self.content, max_length=MAX_MESSAGE_LENGTH),
]
)

@classmethod
def from_reader(cls, reader: Reader) -> Self:
id = reader.read_bigint()
channel_name = reader.read_varchar(max_length=MAX_CHANNEL_NAME_LENGTH)
nick = reader.read_varchar(max_length=MAX_NICK_LENGTH)
content = reader.read_varchar(max_length=MAX_MESSAGE_LENGTH)
return cls(id=id, channel_name=channel_name, nick=nick, content=content)
4 changes: 4 additions & 0 deletions src/dumdum/protocol/reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,10 @@ def readexactly(self, n: int) -> bytes:

return data

def read_bigint(self) -> int:
data = self.readexactly(8)
return int.from_bytes(data, byteorder="big")

def read_varchar(self, *, max_length: int) -> str:
return varchar.load(self, max_length=max_length)

Expand Down
17 changes: 4 additions & 13 deletions src/dumdum/protocol/server/messages.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,10 @@
from dataclasses import dataclass
from typing import Sequence

from dumdum.protocol import varchar
from dumdum.protocol.channel import Channel
from dumdum.protocol.constants import (
MAX_CHANNEL_NAME_LENGTH,
MAX_LIST_CHANNEL_LENGTH_BYTES,
MAX_MESSAGE_LENGTH,
MAX_NICK_LENGTH,
)
from dumdum.protocol.constants import MAX_LIST_CHANNEL_LENGTH_BYTES
from dumdum.protocol.enums import ServerMessageType
from dumdum.protocol.message import Message


@dataclass
Expand Down Expand Up @@ -40,17 +35,13 @@ def __bytes__(self) -> bytes:

@dataclass
class ServerMessagePost:
channel_name: str
nick: str
content: str
message: Message

def __bytes__(self) -> bytes:
return bytes(
[
ServerMessageType.SEND_MESSAGE.value,
*varchar.dumps(self.channel_name, max_length=MAX_CHANNEL_NAME_LENGTH),
*varchar.dumps(self.nick, max_length=MAX_NICK_LENGTH),
*varchar.dumps(self.content, max_length=MAX_MESSAGE_LENGTH),
*bytes(self.message),
]
)

Expand Down
7 changes: 4 additions & 3 deletions src/dumdum/protocol/server/protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from dumdum.protocol.enums import ClientMessageType
from dumdum.protocol.errors import InvalidStateError
from dumdum.protocol.interfaces import Protocol
from dumdum.protocol.message import Message
from dumdum.protocol.reader import Reader, bytearray_reader

from .events import (
Expand Down Expand Up @@ -37,7 +38,7 @@ class ServerState(Enum):
class Server(Protocol):
"""The server for a single client."""

PROTOCOL_VERSION = 0
PROTOCOL_VERSION = 1

def __init__(self) -> None:
self._buffer = bytearray()
Expand All @@ -55,9 +56,9 @@ def acknowledge_authentication(self, *, success: bool) -> bytes:

return bytes(ServerMessageAcknowledgeAuthentication(success))

def send_message(self, channel_name: str, nick: str, content: str) -> bytes:
def send_message(self, message: Message) -> bytes:
self._assert_state(ServerState.READY)
return bytes(ServerMessagePost(channel_name, nick, content))
return bytes(ServerMessagePost(message))

def list_channels(self, channels: Sequence[Channel]) -> bytes:
return bytes(ServerMessageListChannels(channels))
Expand Down
45 changes: 45 additions & 0 deletions src/dumdum/protocol/snowflake.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
"""
64 63 19 12 0
0 00011000111000111000000101110101101101000100 1110101 000000000111
64-63: Always 0
63-19: Unix timestamp in milliseconds
19-12: Process ID
12-00: Incrementing per-process ID
"""

import datetime
import os
import time

_incrementing_id = 0


def create_snowflake(
t: float | int | datetime.datetime | None = None,
pid: int | None = None,
increment: int | None = None,
) -> int:
if t is None:
# WARNING: this may roll back or forwards depending on system time
t = time.time_ns() // 1000000
elif isinstance(t, datetime.datetime):
t = int(t.timestamp() * 1000)
elif isinstance(t, float):
t = int(t * 1000)

if t.bit_length() > 44:
raise OverflowError(f"Timestamp {t} out of bounds (are we in 2527?)")

if pid is None:
pid = os.getpid()

if increment is None:
global _incrementing_id
increment = _incrementing_id
_incrementing_id += 1

pid = pid % 128
increment = increment % 4096

return (t << 19) + (pid << 12) + increment
Loading

0 comments on commit 5909c2a

Please sign in to comment.