Skip to content

Commit

Permalink
refactor!: decouple HighCommand from server protocol
Browse files Browse the repository at this point in the history
This makes it easier to separate how channels and users are stored.
HighCommand will eventually be replaced with an on-disk store.

- Remove Server(hc=) parameter
- Remove Server.nick
  - This is now stored in the asyncio Connection
- Add Server.acknowledge_authentication()
  - Authentication is no longer automatic
- Add Server.list_channels()
  - Listing channels is no longer automatic
- Only relay channel name in message posts
  - Server.send_message() now only takes the channel name
- Remove Server.close()
- Remove ServerEventAuthentication.success
- Replace ServerEventMessageReceived.channel with .channel_name
- Add ServerEventListChannels type
  • Loading branch information
thegamecracks committed Mar 13, 2024
1 parent 519bf56 commit d09794c
Show file tree
Hide file tree
Showing 10 changed files with 125 additions and 120 deletions.
19 changes: 11 additions & 8 deletions src/dumdum/client/chat_frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,9 +48,15 @@ def handle_client_event(self, event: ClientEvent) -> None:
self.channel_list.refresh()
elif isinstance(event, ClientEventMessageReceived):
message = self.message_cache.add_message_from_event(event)
if event.channel == self.channel_list.selected_channel:
channel = self.get_channel(event.channel_name)
if channel == self.channel_list.selected_channel:
self.messages.add_message(message)

def get_channel(self, name: str) -> Channel | None:
for channel in self.channels:
if channel.name == name:
return channel


class ChannelList(Frame):
def __init__(self, parent: ChatFrame) -> None:
Expand Down Expand Up @@ -82,10 +88,7 @@ def selected_channel(self) -> Channel | None:
if len(selection) < 1:
return None

name = selection[0]
for channel in self.parent.channels:
if channel.name == name:
return channel
return self.parent.get_channel(selection[0])

def refresh(self) -> None:
selection = self.tree.selection()
Expand Down Expand Up @@ -171,12 +174,12 @@ class MessageCache:
def __init__(self) -> None:
self.channel_messages: dict[str, list[Message]] = collections.defaultdict(list)

def add_message(self, channel: Channel, message: Message) -> None:
self.channel_messages[channel.name].append(message)
def add_message(self, channel_name: str, message: Message) -> None:
self.channel_messages[channel_name].append(message)

def add_message_from_event(self, event: ClientEventMessageReceived) -> Message:
message = Message(event.nick, event.content)
self.add_message(event.channel, message)
self.add_message(event.channel_name, message)
return message

def get_messages(self, channel: Channel) -> list[Message]:
Expand Down
1 change: 1 addition & 0 deletions src/dumdum/protocol/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
ServerEvent,
ServerEventAuthentication,
ServerEventIncompatibleVersion,
ServerEventListChannels,
ServerEventMessageReceived,
ServerMessageAcknowledgeAuthentication,
ServerMessageListChannels,
Expand Down
2 changes: 1 addition & 1 deletion src/dumdum/protocol/client/events.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ class ClientEventAuthentication(ClientEvent):
class ClientEventMessageReceived(ClientEvent):
"""The server broadcasted a message to the client."""

channel: Channel
channel_name: str
nick: str
content: str

Expand Down
5 changes: 3 additions & 2 deletions src/dumdum/protocol/client/protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from dumdum.protocol.channel import Channel
from dumdum.protocol.constants import (
MAX_CHANNEL_NAME_LENGTH,
MAX_LIST_CHANNEL_LENGTH_BYTES,
MAX_MESSAGE_LENGTH,
MAX_NICK_LENGTH,
Expand Down Expand Up @@ -112,10 +113,10 @@ def _accept_authentication(self, reader: Reader) -> ParsedData:

def _parse_message(self, reader: Reader) -> ParsedData:
self._assert_state(ClientState.READY)
channel = Channel.from_reader(reader)
channel_name = reader.read_varchar(max_length=MAX_CHANNEL_NAME_LENGTH)
nick = reader.read_varchar(max_length=MAX_NICK_LENGTH)
content = reader.read_varchar(max_length=MAX_MESSAGE_LENGTH)
event = ClientEventMessageReceived(channel, nick, content)
event = ClientEventMessageReceived(channel_name, nick, content)
return [event], b""

def _parse_channel_list(self, reader: Reader) -> ParsedData:
Expand Down
1 change: 1 addition & 0 deletions src/dumdum/protocol/server/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
ServerEvent,
ServerEventAuthentication,
ServerEventIncompatibleVersion,
ServerEventListChannels,
ServerEventMessageReceived,
)
from .messages import (
Expand Down
10 changes: 6 additions & 4 deletions src/dumdum/protocol/server/events.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
from dataclasses import dataclass

from dumdum.protocol.channel import Channel


@dataclass
class ServerEvent:
Expand All @@ -19,13 +17,17 @@ class ServerEventIncompatibleVersion(ServerEvent):
class ServerEventAuthentication(ServerEvent):
"""The client attempted to authenticate with the server."""

success: bool
nick: str


@dataclass
class ServerEventMessageReceived(ServerEvent):
"""The client sent a message to the server."""

channel: Channel
channel_name: str
content: str


@dataclass
class ServerEventListChannels(ServerEvent):
"""The client requested a list of channels."""
5 changes: 3 additions & 2 deletions src/dumdum/protocol/server/messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from dumdum.protocol import varchar
from dumdum.protocol.channel import Channel
from dumdum.protocol.constants import (
MAX_CHANNEL_NAME_LENGTH,
MAX_LIST_CHANNEL_LENGTH_BYTES,
MAX_MESSAGE_LENGTH,
MAX_NICK_LENGTH,
Expand Down Expand Up @@ -39,15 +40,15 @@ def __bytes__(self) -> bytes:

@dataclass
class ServerMessagePost:
channel: Channel
channel_name: str
nick: str
content: str

def __bytes__(self) -> bytes:
return bytes(
[
ServerMessageType.SEND_MESSAGE.value,
*bytes(self.channel),
*varchar.dumps(self.channel_name, max_length=MAX_CHANNEL_NAME_LENGTH),
*varchar.dumps(self.nick, max_length=MAX_NICK_LENGTH),
*varchar.dumps(self.content, max_length=MAX_MESSAGE_LENGTH),
]
Expand Down
56 changes: 20 additions & 36 deletions src/dumdum/protocol/server/protocol.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from enum import Enum
from typing import Sequence

from dumdum.protocol.channel import Channel
from dumdum.protocol.constants import (
Expand All @@ -8,14 +9,14 @@
)
from dumdum.protocol.enums import ClientMessageType
from dumdum.protocol.errors import InvalidStateError
from dumdum.protocol.highcommand import HighCommand
from dumdum.protocol.interfaces import Protocol
from dumdum.protocol.reader import Reader, bytearray_reader

from .events import (
ServerEvent,
ServerEventAuthentication,
ServerEventIncompatibleVersion,
ServerEventListChannels,
ServerEventMessageReceived,
)
from .messages import (
Expand All @@ -38,25 +39,28 @@ class Server(Protocol):

PROTOCOL_VERSION = 0

nick: str | None

def __init__(self, hc: HighCommand) -> None:
self.hc = hc
self.nick = None
def __init__(self) -> None:
self._buffer = bytearray()
self._state = ServerState.AWAITING_AUTHENTICATION

def receive_bytes(self, data: bytes) -> ParsedData:
self._buffer.extend(data)
return self._maybe_parse_buffer()

def send_message(self, channel: Channel, nick: str, content: str) -> bytes:
def acknowledge_authentication(self, *, success: bool) -> bytes:
self._assert_state(ServerState.AWAITING_AUTHENTICATION)

if success:
self._state = ServerState.READY

return bytes(ServerMessageAcknowledgeAuthentication(success))

def send_message(self, channel_name: str, nick: str, content: str) -> bytes:
self._assert_state(ServerState.READY)
return bytes(ServerMessagePost(channel, nick, content))
return bytes(ServerMessagePost(channel_name, nick, content))

def close(self) -> None:
if self.nick is not None:
self.hc.remove_user(self.nick)
def list_channels(self, channels: Sequence[Channel]) -> bytes:
return bytes(ServerMessageListChannels(channels))

def _assert_state(self, *states: ServerState) -> None:
if self._state not in states:
Expand Down Expand Up @@ -101,39 +105,19 @@ def _authenticate(self, reader: Reader) -> ParsedData:
response = ServerMessageSendIncompatibleVersion(self.PROTOCOL_VERSION)
return [event], bytes(response)

success = True

user = self.hc.get_user(nick)
if user is not None:
# TODO: maybe add message type for taken nickname
success = False

if success:
self.hc.add_user(nick)
self.nick = nick
self._state = ServerState.READY

event = ServerEventAuthentication(success=success, nick=nick)
response = ServerMessageAcknowledgeAuthentication(success)
return [event], bytes(response)
event = ServerEventAuthentication(nick=nick)
return [event], b""

def _send_message(self, reader: Reader) -> ParsedData:
self._assert_state(ServerState.READY)
channel_name = reader.read_varchar(max_length=MAX_CHANNEL_NAME_LENGTH)
content = reader.read_varchar(max_length=MAX_MESSAGE_LENGTH)

channel = self.hc.get_channel(channel_name)
if channel is None:
# TODO: maybe add event for invalid channel
return [], b""

event = ServerEventMessageReceived(channel, content)
event = ServerEventMessageReceived(channel_name, content)
# TODO: broadcast message to all users
assert self.nick is not None
return [event], b""

def _list_channels(self, reader: Reader) -> ParsedData:
# TODO: maybe add event for listing channels
self._assert_state(ServerState.READY)
response = ServerMessageListChannels(self.hc.channels)
return [], bytes(response)
event = ServerEventListChannels()
return [event], b""
68 changes: 50 additions & 18 deletions src/dumdum/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
InvalidStateError,
Server,
ServerEvent,
ServerEventAuthentication,
ServerEventListChannels,
ServerEventMessageReceived,
)

Expand Down Expand Up @@ -115,10 +117,10 @@ async def accept_connection(
log.info("Connection %s has disconnected", addr)
writer.close()
await writer.wait_closed()
self.connections.remove(connection)
self._close_connection(connection)

def _create_server(self) -> Server:
return Server(self.hc)
return Server()

def _handle_events(self, conn: Connection, events: list[ServerEvent]) -> None:
for event in events:
Expand All @@ -131,26 +133,57 @@ def _handle_event(self, conn: Connection, event: ServerEvent) -> None:
conn.writer.get_extra_info("peername"),
)

if isinstance(event, ServerEventMessageReceived):
if isinstance(event, ServerEventAuthentication):
self._authenticate(conn, event)
elif isinstance(event, ServerEventMessageReceived):
self._broadcast_message(conn, event)
elif isinstance(event, ServerEventListChannels):
self._list_channels(conn, event)

def _authenticate(self, conn: Connection, event: ServerEventAuthentication) -> None:
user = self.hc.get_user(event.nick)
if user is None:
self.hc.add_user(event.nick)
conn.nick = event.nick
success = True
else:
success = False

data = conn.server.acknowledge_authentication(success=success)
conn.writer.write(data)

def _broadcast_message(
self,
conn: Connection,
event: ServerEventMessageReceived,
) -> None:
assert conn.server.nick is not None
assert conn.nick is not None

if self.hc.get_channel(event.channel_name) is None:
return

for peer in self.connections:
with contextlib.suppress(InvalidStateError):
data = peer.server.send_message(
event.channel,
conn.server.nick,
event.channel_name,
conn.nick,
event.content,
)
peer.writer.write(data)

def _list_channels(self, conn: Connection, event: ServerEventListChannels) -> None:
data = conn.server.list_channels(self.hc.channels)
conn.writer.write(data)

def _close_connection(self, conn: Connection) -> None:
self.connections.remove(conn)
if conn.nick is not None:
self.hc.remove_user(conn.nick)


class Connection:
nick: str | None

def __init__(
self,
manager: Manager,
Expand All @@ -163,19 +196,18 @@ def __init__(
self.writer = writer
self.server = server

self.nick = None

async def communicate(self) -> None:
try:
while True:
data = await self.reader.read(1024)
if len(data) == 0:
break

events, outgoing = self.server.receive_bytes(data)
self.writer.write(outgoing)
self._handle_events(events)
await self.writer.drain() # exert backpressure
finally:
self.server.close()
while True:
data = await self.reader.read(1024)
if len(data) == 0:
break

events, outgoing = self.server.receive_bytes(data)
self.writer.write(outgoing)
self._handle_events(events)
await self.writer.drain() # exert backpressure

def _handle_events(self, events: list[ServerEvent]) -> None:
self.manager._handle_events(self, events)
Expand Down
Loading

0 comments on commit d09794c

Please sign in to comment.