From bcb2443b5f9e0a83b5a1b40e3f4962ef023bc453 Mon Sep 17 00:00:00 2001 From: Dennis Scheiba Date: Thu, 14 Sep 2023 17:11:32 +0200 Subject: [PATCH] refactor notifications from schema to models --- caster-back/gencaster/distributor.py | 27 +++++---- caster-back/gencaster/schema.py | 90 ++-------------------------- caster-back/gencaster/tests.py | 5 +- caster-back/story_graph/models.py | 73 ++++++++++++++-------- caster-back/stream/apps.py | 3 - caster-back/stream/models.py | 5 +- 6 files changed, 74 insertions(+), 129 deletions(-) diff --git a/caster-back/gencaster/distributor.py b/caster-back/gencaster/distributor.py index 22ddba1a..c2f964b7 100644 --- a/caster-back/gencaster/distributor.py +++ b/caster-back/gencaster/distributor.py @@ -11,6 +11,7 @@ from dataclasses import asdict, dataclass, field from typing import AsyncGenerator, Awaitable, Callable, List, Optional, Union +from channels.layers import get_channel_layer from channels_redis.core import RedisChannelLayer from strawberry.channels import GraphQLWSConsumer @@ -70,30 +71,36 @@ def __init__(self) -> None: pass @staticmethod - async def send_graph_update(layer: RedisChannelLayer, graph_uuid: uuid.UUID): + def _get_layer() -> RedisChannelLayer: + if layer := get_channel_layer(): + return layer + raise Exception("Could not obtain redis channel layer") + + @staticmethod + async def send_graph_update(graph_uuid: uuid.UUID): return await GenCasterChannel.send_message( - layer=layer, + layer=GenCasterChannel._get_layer(), message=GraphUpdateMessage(uuid=str(graph_uuid)), ) @staticmethod - async def send_node_update(layer: RedisChannelLayer, node_uuid: uuid.UUID): + async def send_node_update(node_uuid: uuid.UUID): return await GenCasterChannel.send_message( - layer=layer, message=NodeUpdateMessage(uuid=str(node_uuid)) + layer=GenCasterChannel._get_layer(), + message=NodeUpdateMessage(uuid=str(node_uuid)), ) @staticmethod - async def send_log_update( - layer: RedisChannelLayer, stream_log_message: "StreamLogUpdateMessage" - ): + async def send_log_update(stream_log_message: "StreamLogUpdateMessage"): return await GenCasterChannel.send_message( - layer=layer, message=stream_log_message + layer=GenCasterChannel._get_layer(), message=stream_log_message ) @staticmethod - async def send_streams_update(layer: RedisChannelLayer, stream_uuid: str): + async def send_streams_update(stream_uuid: str): return await GenCasterChannel.send_message( - layer=layer, message=StreamsUpdateMessage(uuid=str(stream_uuid)) + layer=GenCasterChannel._get_layer(), + message=StreamsUpdateMessage(uuid=str(stream_uuid)), ) @staticmethod diff --git a/caster-back/gencaster/schema.py b/caster-back/gencaster/schema.py index b1231b21..e37526f6 100644 --- a/caster-back/gencaster/schema.py +++ b/caster-back/gencaster/schema.py @@ -221,7 +221,7 @@ async def update_audio_file( audio_file.name = update_audio_file.name if update_audio_file.description: audio_file.description = update_audio_file.description - await sync_to_async(audio_file.save)() + await audio_file.asave() return audio_file # type: ignore @strawberry.mutation @@ -243,14 +243,7 @@ async def add_node(self, info: Info, new_node: NodeCreate) -> None: if new_value := getattr(new_node, field): setattr(node, field, new_value) - # asave not yet implemented in django 4.1 - await sync_to_async(node.save)() - - await GenCasterChannel.send_graph_update( - layer=info.context.channel_layer, - graph_uuid=graph.uuid, - ) - + await node.asave() return None @strawberry.mutation @@ -268,18 +261,7 @@ async def update_node(self, info: Info, node_update: NodeUpdate) -> None: if new_value := getattr(node_update, field): setattr(node, field, new_value) - await sync_to_async(node.save)() - - await GenCasterChannel.send_graph_update( - layer=info.context.channel_layer, - graph_uuid=node.graph.uuid, - ) - - await GenCasterChannel.send_node_update( - layer=info.context.channel_layer, - node_uuid=node.uuid, - ) - + await node.asave() return None @strawberry.mutation @@ -299,57 +281,19 @@ async def add_edge(self, info: Info, new_edge: EdgeInput) -> Edge: in_node_door=in_node_door, out_node_door=out_node_door, ) - await GenCasterChannel.send_graph_update( - layer=info.context.channel_layer, - graph_uuid=in_node_door.node.graph.uuid, - ) return edge # type: ignore @strawberry.mutation async def delete_edge(self, info, edge_uuid: uuid.UUID) -> None: """Deletes a given :class:`~story_graph.models.Edge`.""" await graphql_check_authenticated(info) - try: - edge: story_graph_models.Edge = ( - await story_graph_models.Edge.objects.select_related( - "in_node_door__node__graph" - ).aget(uuid=edge_uuid) - ) - await story_graph_models.Edge.objects.filter(uuid=edge_uuid).adelete() - except Exception: - raise Exception(f"Could not delete edge {edge_uuid}") - if edge.in_node_door: - await GenCasterChannel.send_graph_update( - layer=info.context.channel_layer, - graph_uuid=edge.in_node_door.node.graph.uuid, - ) - return None + await story_graph_models.Edge.objects.filter(uuid=edge_uuid).adelete() @strawberry.mutation async def delete_node(self, info, node_uuid: uuid.UUID) -> None: """Deletes a given :class:`~story_graph.models.Node`.""" await graphql_check_authenticated(info) - try: - node: story_graph_models.Node = ( - await story_graph_models.Node.objects.select_related("graph").aget( - uuid=node_uuid - ) - ) - await story_graph_models.Node.objects.filter(uuid=node_uuid).adelete() - except Exception: - raise Exception(f"Could delete node {node_uuid}") - - await GenCasterChannel.send_graph_update( - layer=info.context.channel_layer, - graph_uuid=node.graph.uuid, - ) - - await GenCasterChannel.send_node_update( - layer=info.context.channel_layer, - node_uuid=node.uuid, - ) - - return None + await story_graph_models.Node.objects.filter(uuid=node_uuid).adelete() @strawberry.mutation async def create_script_cells( @@ -367,7 +311,6 @@ async def create_script_cells( ) except story_graph_models.Node.DoesNotExist as e: log.error(f"Received update on unknown node {node_uuid}") - # @todo return error raise e script_cells: List[story_graph_models.ScriptCell] = [] @@ -398,9 +341,6 @@ async def create_script_cells( log.debug(f"Created script cell {script_cell.uuid}") script_cells.append(script_cell) - await GenCasterChannel.send_node_update( - layer=info.context.channel_layer, node_uuid=node.uuid - ) return script_cells # type: ignore @strawberry.mutation @@ -439,13 +379,6 @@ async def update_script_cells( ).aupdate(**updates) script_cells.append(script_cell) - # send update to subscription if something was updated - if len(script_cells) > 0: - await GenCasterChannel.send_node_update( - layer=info.context.channel_layer, - node_uuid=await sync_to_async(lambda: script_cells[0].node.uuid)(), - ) - return script_cells # type: ignore @strawberry.mutation @@ -453,21 +386,10 @@ async def delete_script_cell(self, info, script_cell_uuid: uuid.UUID) -> None: """Deletes a given :class:`~story_graph.models.ScriptCell`.""" await graphql_check_authenticated(info) - # first get the node before the cell is deleted - node = await story_graph_models.Node.objects.filter( - script_cells__uuid=script_cell_uuid - ).afirst() - await story_graph_models.ScriptCell.objects.filter( uuid=script_cell_uuid ).adelete() - if node: - await GenCasterChannel.send_node_update( - layer=info.context.channel_layer, - node_uuid=node.uuid, - ) - @strawberry.mutation async def add_graph(self, info, graph_input: AddGraphInput) -> Graph: await graphql_check_authenticated(info) @@ -574,7 +496,7 @@ async def update_node_door( self, info, node_door_input: NodeDoorInputUpdate, - ) -> NodeDoorResponse: + ) -> NodeDoorResponse: # type: ignore await graphql_check_authenticated(info) node_door = await story_graph_models.NodeDoor.objects.aget( uuid=node_door_input.uuid diff --git a/caster-back/gencaster/tests.py b/caster-back/gencaster/tests.py index 4974a497..ad8c3151 100644 --- a/caster-back/gencaster/tests.py +++ b/caster-back/gencaster/tests.py @@ -194,8 +194,7 @@ async def test_delete_unavailable_edge(self): variable_values={"edgeUuid": str(uuid.uuid4())}, context_value=self.get_login_context(), ) - - self.assertGreaterEqual(len(resp.errors), 1) # type: ignore + self.assertIsNone(resp.data["deleteEdge"]) # type: ignore NODE_DELETE_MUTATION = """ mutation deleteNode($nodeUuid: UUID!) { @@ -226,7 +225,7 @@ async def test_delete_unavailable_node(self): context_value=self.get_login_context(), ) - self.assertGreaterEqual(len(resp.errors), 1) # type: ignore + self.assertIsNone(resp.data["deleteNode"]) # type: ignore CREATE_SCRIPT_CELL = """ mutation CreateScriptCells($nodeUuid: UUID!, $scriptCellInputs: [ScriptCellInputCreate!]!) { diff --git a/caster-back/story_graph/models.py b/caster-back/story_graph/models.py index 30f6c9f5..5816eb96 100644 --- a/caster-back/story_graph/models.py +++ b/caster-back/story_graph/models.py @@ -8,7 +8,6 @@ import uuid from asgiref.sync import async_to_sync, sync_to_async -from channels.layers import get_channel_layer from django.core.exceptions import ValidationError from django.db import models, transaction from django.db.models import Q, signals @@ -160,6 +159,20 @@ def __str__(self) -> str: return self.name +def update_graph_db_to_ws(graph_uuid: uuid.UUID): + # sorry for this atrocity - there seems to be race conditions with signals + # which makes updates out-dated, see + # https://docs.djangoproject.com/en/dev/topics/db/transactions/#performing-actions-after-commit + transaction.on_commit( + lambda: async_to_sync(GenCasterChannel.send_graph_update)(graph_uuid) + ) + + +@receiver(signals.post_save, sender=Graph, dispatch_uid="update_graph_ws") +def update_graph_ws(sender, instance: Graph, **kwargs) -> None: + update_graph_db_to_ws(instance.uuid) + + class Node(models.Model): """ A node. @@ -273,6 +286,22 @@ def __str__(self) -> str: return self.name +def update_node_db_to_ws(node_uuid: uuid.UUID): + # sorry for this atrocity - there seems to be race conditions with signals + # which makes updates out-dated, see + # https://docs.djangoproject.com/en/dev/topics/db/transactions/#performing-actions-after-commit + transaction.on_commit( + lambda: async_to_sync(GenCasterChannel.send_node_update)(node_uuid) + ) + + +@receiver(signals.post_delete, sender=Node, dispatch_uid="delete_node_ws") +@receiver(signals.post_save, sender=Node, dispatch_uid="update_node_ws") +def update_node_ws(sender, instance: Node, **kwargs) -> None: + update_node_db_to_ws(instance.uuid) + update_graph_db_to_ws(instance.graph.uuid) + + class NodeDoorMissing(Exception): """Exception that can be thrown if a node door is missing. Normally each node should have a default in- and out @@ -401,30 +430,8 @@ def __str__(self) -> str: @receiver(signals.post_save, sender=NodeDoor, dispatch_uid="update_node_door_ws") def update_node_door_ws(sender, instance: NodeDoor, **kwargs) -> None: - channel_layer = get_channel_layer() - if channel_layer is None: - log.error( - "Failed to obtain a handle on the channel layer to distribute node_door updates" - ) - return - # sorry for this atrocity - there seems to be race conditions with signals - # which makes updates out-dated, see - # https://docs.djangoproject.com/en/dev/topics/db/transactions/#performing-actions-after-commit - # it is possible that the instance is not available anymore after the commit, so - # so we store it here in memory - node_uuid = instance.node.uuid - graph_uuid = instance.node.graph.uuid - transaction.on_commit( - lambda: async_to_sync(GenCasterChannel.send_node_update)( - channel_layer, node_uuid - ) - ) - transaction.on_commit( - lambda: async_to_sync(GenCasterChannel.send_graph_update)( - channel_layer, - graph_uuid, - ) - ) + update_graph_db_to_ws(instance.node.graph.uuid) + update_node_db_to_ws(instance.node.uuid) @receiver(signals.post_delete, sender=NodeDoor, dispatch_uid="delete_node_door_ws") @@ -546,6 +553,16 @@ def __str__(self) -> str: return f"{self.in_node_door} -> {self.out_node_door}" +@receiver(signals.pre_delete, sender=Edge, dispatch_uid="delete_edge_ws") +@receiver(signals.post_save, sender=Edge, dispatch_uid="update_edge_ws") +def update_edge_ws(sender, instance: Edge, **kwargs) -> None: + if instance.out_node_door: + update_graph_db_to_ws(instance.out_node_door.node.graph.uuid) + update_node_db_to_ws(instance.out_node_door.node.uuid) + if instance.in_node_door: + update_node_db_to_ws(instance.in_node_door.node.uuid) + + class AudioCell(models.Model): """Stores information for playback of static audio files.""" @@ -718,3 +735,9 @@ class Meta: def __str__(self) -> str: return f"{self.node}-{self.cell_order} ({self.cell_type})" + + +@receiver(signals.post_delete, sender=ScriptCell, dispatch_uid="delete_script_cell") +@receiver(signals.post_save, sender=ScriptCell, dispatch_uid="update_script_cell_ws") +def update_script_cell_ws(sender, instance: ScriptCell, **kwargs) -> None: + update_node_db_to_ws(instance.node.uuid) diff --git a/caster-back/stream/apps.py b/caster-back/stream/apps.py index d33f56a1..af1851a4 100644 --- a/caster-back/stream/apps.py +++ b/caster-back/stream/apps.py @@ -6,7 +6,6 @@ from threading import Thread from asgiref.sync import async_to_sync -from channels.layers import get_channel_layer from django.apps import AppConfig @@ -32,7 +31,6 @@ def __init__(self, level: int = logging.DEBUG) -> None: self._thread = Thread(target=self._loop) self._thread.daemon = True self._thread.start() - self._channel = get_channel_layer() self._event_loop = asyncio.get_event_loop() super().__init__(level) @@ -61,7 +59,6 @@ def _loop(self): ) async_to_sync(GenCasterChannel.send_log_update)( - self._channel, StreamLogUpdateMessage( uuid=str(stream_log.uuid), stream_point_uuid=str(stream_point.uuid) diff --git a/caster-back/stream/models.py b/caster-back/stream/models.py index 1e88bf5f..9cd8c099 100644 --- a/caster-back/stream/models.py +++ b/caster-back/stream/models.py @@ -5,7 +5,6 @@ from typing import Optional from asgiref.sync import async_to_sync -from channels.layers import get_channel_layer from django.conf import settings from django.contrib import admin from django.core.files import File @@ -335,9 +334,7 @@ def __str__(self) -> str: @receiver(signals.post_save, sender=Stream, dispatch_uid="update_streams_ws") def update_streams_ws(sender, instance: Stream, **kwargs): - async_to_sync(GenCasterChannel.send_streams_update)( - get_channel_layer(), str(instance.uuid) - ) + async_to_sync(GenCasterChannel.send_streams_update)(str(instance.uuid)) class StreamVariable(models.Model):