Skip to content

Commit

Permalink
Merge pull request #567 from Gencaster/notifications-refactoring
Browse files Browse the repository at this point in the history
Notifications refactoring
  • Loading branch information
vin-ni authored Sep 14, 2023
2 parents 1928881 + bcb2443 commit 90dee61
Show file tree
Hide file tree
Showing 6 changed files with 74 additions and 129 deletions.
27 changes: 17 additions & 10 deletions caster-back/gencaster/distributor.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from dataclasses import asdict, dataclass, field
from typing import AsyncGenerator, Awaitable, Callable, List, Optional, Union

from channels.layers import get_channel_layer
from channels_redis.core import RedisChannelLayer
from strawberry.channels import GraphQLWSConsumer

Expand Down Expand Up @@ -70,30 +71,36 @@ def __init__(self) -> None:
pass

@staticmethod
async def send_graph_update(layer: RedisChannelLayer, graph_uuid: uuid.UUID):
def _get_layer() -> RedisChannelLayer:
if layer := get_channel_layer():
return layer
raise Exception("Could not obtain redis channel layer")

@staticmethod
async def send_graph_update(graph_uuid: uuid.UUID):
return await GenCasterChannel.send_message(
layer=layer,
layer=GenCasterChannel._get_layer(),
message=GraphUpdateMessage(uuid=str(graph_uuid)),
)

@staticmethod
async def send_node_update(layer: RedisChannelLayer, node_uuid: uuid.UUID):
async def send_node_update(node_uuid: uuid.UUID):
return await GenCasterChannel.send_message(
layer=layer, message=NodeUpdateMessage(uuid=str(node_uuid))
layer=GenCasterChannel._get_layer(),
message=NodeUpdateMessage(uuid=str(node_uuid)),
)

@staticmethod
async def send_log_update(
layer: RedisChannelLayer, stream_log_message: "StreamLogUpdateMessage"
):
async def send_log_update(stream_log_message: "StreamLogUpdateMessage"):
return await GenCasterChannel.send_message(
layer=layer, message=stream_log_message
layer=GenCasterChannel._get_layer(), message=stream_log_message
)

@staticmethod
async def send_streams_update(layer: RedisChannelLayer, stream_uuid: str):
async def send_streams_update(stream_uuid: str):
return await GenCasterChannel.send_message(
layer=layer, message=StreamsUpdateMessage(uuid=str(stream_uuid))
layer=GenCasterChannel._get_layer(),
message=StreamsUpdateMessage(uuid=str(stream_uuid)),
)

@staticmethod
Expand Down
90 changes: 6 additions & 84 deletions caster-back/gencaster/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,7 +221,7 @@ async def update_audio_file(
audio_file.name = update_audio_file.name
if update_audio_file.description:
audio_file.description = update_audio_file.description
await sync_to_async(audio_file.save)()
await audio_file.asave()
return audio_file # type: ignore

@strawberry.mutation
Expand All @@ -243,14 +243,7 @@ async def add_node(self, info: Info, new_node: NodeCreate) -> None:
if new_value := getattr(new_node, field):
setattr(node, field, new_value)

# asave not yet implemented in django 4.1
await sync_to_async(node.save)()

await GenCasterChannel.send_graph_update(
layer=info.context.channel_layer,
graph_uuid=graph.uuid,
)

await node.asave()
return None

@strawberry.mutation
Expand All @@ -268,18 +261,7 @@ async def update_node(self, info: Info, node_update: NodeUpdate) -> None:
if new_value := getattr(node_update, field):
setattr(node, field, new_value)

await sync_to_async(node.save)()

await GenCasterChannel.send_graph_update(
layer=info.context.channel_layer,
graph_uuid=node.graph.uuid,
)

await GenCasterChannel.send_node_update(
layer=info.context.channel_layer,
node_uuid=node.uuid,
)

await node.asave()
return None

@strawberry.mutation
Expand All @@ -299,57 +281,19 @@ async def add_edge(self, info: Info, new_edge: EdgeInput) -> Edge:
in_node_door=in_node_door,
out_node_door=out_node_door,
)
await GenCasterChannel.send_graph_update(
layer=info.context.channel_layer,
graph_uuid=in_node_door.node.graph.uuid,
)
return edge # type: ignore

@strawberry.mutation
async def delete_edge(self, info, edge_uuid: uuid.UUID) -> None:
"""Deletes a given :class:`~story_graph.models.Edge`."""
await graphql_check_authenticated(info)
try:
edge: story_graph_models.Edge = (
await story_graph_models.Edge.objects.select_related(
"in_node_door__node__graph"
).aget(uuid=edge_uuid)
)
await story_graph_models.Edge.objects.filter(uuid=edge_uuid).adelete()
except Exception:
raise Exception(f"Could not delete edge {edge_uuid}")
if edge.in_node_door:
await GenCasterChannel.send_graph_update(
layer=info.context.channel_layer,
graph_uuid=edge.in_node_door.node.graph.uuid,
)
return None
await story_graph_models.Edge.objects.filter(uuid=edge_uuid).adelete()

@strawberry.mutation
async def delete_node(self, info, node_uuid: uuid.UUID) -> None:
"""Deletes a given :class:`~story_graph.models.Node`."""
await graphql_check_authenticated(info)
try:
node: story_graph_models.Node = (
await story_graph_models.Node.objects.select_related("graph").aget(
uuid=node_uuid
)
)
await story_graph_models.Node.objects.filter(uuid=node_uuid).adelete()
except Exception:
raise Exception(f"Could delete node {node_uuid}")

await GenCasterChannel.send_graph_update(
layer=info.context.channel_layer,
graph_uuid=node.graph.uuid,
)

await GenCasterChannel.send_node_update(
layer=info.context.channel_layer,
node_uuid=node.uuid,
)

return None
await story_graph_models.Node.objects.filter(uuid=node_uuid).adelete()

@strawberry.mutation
async def create_script_cells(
Expand All @@ -367,7 +311,6 @@ async def create_script_cells(
)
except story_graph_models.Node.DoesNotExist as e:
log.error(f"Received update on unknown node {node_uuid}")
# @todo return error
raise e

script_cells: List[story_graph_models.ScriptCell] = []
Expand Down Expand Up @@ -398,9 +341,6 @@ async def create_script_cells(
log.debug(f"Created script cell {script_cell.uuid}")
script_cells.append(script_cell)

await GenCasterChannel.send_node_update(
layer=info.context.channel_layer, node_uuid=node.uuid
)
return script_cells # type: ignore

@strawberry.mutation
Expand Down Expand Up @@ -439,35 +379,17 @@ async def update_script_cells(
).aupdate(**updates)
script_cells.append(script_cell)

# send update to subscription if something was updated
if len(script_cells) > 0:
await GenCasterChannel.send_node_update(
layer=info.context.channel_layer,
node_uuid=await sync_to_async(lambda: script_cells[0].node.uuid)(),
)

return script_cells # type: ignore

@strawberry.mutation
async def delete_script_cell(self, info, script_cell_uuid: uuid.UUID) -> None:
"""Deletes a given :class:`~story_graph.models.ScriptCell`."""
await graphql_check_authenticated(info)

# first get the node before the cell is deleted
node = await story_graph_models.Node.objects.filter(
script_cells__uuid=script_cell_uuid
).afirst()

await story_graph_models.ScriptCell.objects.filter(
uuid=script_cell_uuid
).adelete()

if node:
await GenCasterChannel.send_node_update(
layer=info.context.channel_layer,
node_uuid=node.uuid,
)

@strawberry.mutation
async def add_graph(self, info, graph_input: AddGraphInput) -> Graph:
await graphql_check_authenticated(info)
Expand Down Expand Up @@ -574,7 +496,7 @@ async def update_node_door(
self,
info,
node_door_input: NodeDoorInputUpdate,
) -> NodeDoorResponse:
) -> NodeDoorResponse: # type: ignore
await graphql_check_authenticated(info)
node_door = await story_graph_models.NodeDoor.objects.aget(
uuid=node_door_input.uuid
Expand Down
5 changes: 2 additions & 3 deletions caster-back/gencaster/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,8 +194,7 @@ async def test_delete_unavailable_edge(self):
variable_values={"edgeUuid": str(uuid.uuid4())},
context_value=self.get_login_context(),
)

self.assertGreaterEqual(len(resp.errors), 1) # type: ignore
self.assertIsNone(resp.data["deleteEdge"]) # type: ignore

NODE_DELETE_MUTATION = """
mutation deleteNode($nodeUuid: UUID!) {
Expand Down Expand Up @@ -226,7 +225,7 @@ async def test_delete_unavailable_node(self):
context_value=self.get_login_context(),
)

self.assertGreaterEqual(len(resp.errors), 1) # type: ignore
self.assertIsNone(resp.data["deleteNode"]) # type: ignore

CREATE_SCRIPT_CELL = """
mutation CreateScriptCells($nodeUuid: UUID!, $scriptCellInputs: [ScriptCellInputCreate!]!) {
Expand Down
73 changes: 48 additions & 25 deletions caster-back/story_graph/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
import uuid

from asgiref.sync import async_to_sync, sync_to_async
from channels.layers import get_channel_layer
from django.core.exceptions import ValidationError
from django.db import models, transaction
from django.db.models import Q, signals
Expand Down Expand Up @@ -160,6 +159,20 @@ def __str__(self) -> str:
return self.name


def update_graph_db_to_ws(graph_uuid: uuid.UUID):
# sorry for this atrocity - there seems to be race conditions with signals
# which makes updates out-dated, see
# https://docs.djangoproject.com/en/dev/topics/db/transactions/#performing-actions-after-commit
transaction.on_commit(
lambda: async_to_sync(GenCasterChannel.send_graph_update)(graph_uuid)
)


@receiver(signals.post_save, sender=Graph, dispatch_uid="update_graph_ws")
def update_graph_ws(sender, instance: Graph, **kwargs) -> None:
update_graph_db_to_ws(instance.uuid)


class Node(models.Model):
"""
A node.
Expand Down Expand Up @@ -273,6 +286,22 @@ def __str__(self) -> str:
return self.name


def update_node_db_to_ws(node_uuid: uuid.UUID):
# sorry for this atrocity - there seems to be race conditions with signals
# which makes updates out-dated, see
# https://docs.djangoproject.com/en/dev/topics/db/transactions/#performing-actions-after-commit
transaction.on_commit(
lambda: async_to_sync(GenCasterChannel.send_node_update)(node_uuid)
)


@receiver(signals.post_delete, sender=Node, dispatch_uid="delete_node_ws")
@receiver(signals.post_save, sender=Node, dispatch_uid="update_node_ws")
def update_node_ws(sender, instance: Node, **kwargs) -> None:
update_node_db_to_ws(instance.uuid)
update_graph_db_to_ws(instance.graph.uuid)


class NodeDoorMissing(Exception):
"""Exception that can be thrown if a node door is missing.
Normally each node should have a default in- and out
Expand Down Expand Up @@ -401,30 +430,8 @@ def __str__(self) -> str:

@receiver(signals.post_save, sender=NodeDoor, dispatch_uid="update_node_door_ws")
def update_node_door_ws(sender, instance: NodeDoor, **kwargs) -> None:
channel_layer = get_channel_layer()
if channel_layer is None:
log.error(
"Failed to obtain a handle on the channel layer to distribute node_door updates"
)
return
# sorry for this atrocity - there seems to be race conditions with signals
# which makes updates out-dated, see
# https://docs.djangoproject.com/en/dev/topics/db/transactions/#performing-actions-after-commit
# it is possible that the instance is not available anymore after the commit, so
# so we store it here in memory
node_uuid = instance.node.uuid
graph_uuid = instance.node.graph.uuid
transaction.on_commit(
lambda: async_to_sync(GenCasterChannel.send_node_update)(
channel_layer, node_uuid
)
)
transaction.on_commit(
lambda: async_to_sync(GenCasterChannel.send_graph_update)(
channel_layer,
graph_uuid,
)
)
update_graph_db_to_ws(instance.node.graph.uuid)
update_node_db_to_ws(instance.node.uuid)


@receiver(signals.post_delete, sender=NodeDoor, dispatch_uid="delete_node_door_ws")
Expand Down Expand Up @@ -546,6 +553,16 @@ def __str__(self) -> str:
return f"{self.in_node_door} -> {self.out_node_door}"


@receiver(signals.pre_delete, sender=Edge, dispatch_uid="delete_edge_ws")
@receiver(signals.post_save, sender=Edge, dispatch_uid="update_edge_ws")
def update_edge_ws(sender, instance: Edge, **kwargs) -> None:
if instance.out_node_door:
update_graph_db_to_ws(instance.out_node_door.node.graph.uuid)
update_node_db_to_ws(instance.out_node_door.node.uuid)
if instance.in_node_door:
update_node_db_to_ws(instance.in_node_door.node.uuid)


class AudioCell(models.Model):
"""Stores information for playback of static audio files."""

Expand Down Expand Up @@ -718,3 +735,9 @@ class Meta:

def __str__(self) -> str:
return f"{self.node}-{self.cell_order} ({self.cell_type})"


@receiver(signals.post_delete, sender=ScriptCell, dispatch_uid="delete_script_cell")
@receiver(signals.post_save, sender=ScriptCell, dispatch_uid="update_script_cell_ws")
def update_script_cell_ws(sender, instance: ScriptCell, **kwargs) -> None:
update_node_db_to_ws(instance.node.uuid)
3 changes: 0 additions & 3 deletions caster-back/stream/apps.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
from threading import Thread

from asgiref.sync import async_to_sync
from channels.layers import get_channel_layer
from django.apps import AppConfig


Expand All @@ -32,7 +31,6 @@ def __init__(self, level: int = logging.DEBUG) -> None:
self._thread = Thread(target=self._loop)
self._thread.daemon = True
self._thread.start()
self._channel = get_channel_layer()
self._event_loop = asyncio.get_event_loop()
super().__init__(level)

Expand Down Expand Up @@ -61,7 +59,6 @@ def _loop(self):
)

async_to_sync(GenCasterChannel.send_log_update)(
self._channel,
StreamLogUpdateMessage(
uuid=str(stream_log.uuid),
stream_point_uuid=str(stream_point.uuid)
Expand Down
5 changes: 1 addition & 4 deletions caster-back/stream/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
from typing import Optional

from asgiref.sync import async_to_sync
from channels.layers import get_channel_layer
from django.conf import settings
from django.contrib import admin
from django.core.files import File
Expand Down Expand Up @@ -335,9 +334,7 @@ def __str__(self) -> str:

@receiver(signals.post_save, sender=Stream, dispatch_uid="update_streams_ws")
def update_streams_ws(sender, instance: Stream, **kwargs):
async_to_sync(GenCasterChannel.send_streams_update)(
get_channel_layer(), str(instance.uuid)
)
async_to_sync(GenCasterChannel.send_streams_update)(str(instance.uuid))


class StreamVariable(models.Model):
Expand Down

0 comments on commit 90dee61

Please sign in to comment.