Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: support declaring DTOField via Annotated #2367

Closed
wants to merge 6 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
.scannerwork/
.unasyncd_cache/
.venv/
.venv*
.vscode/
__pycache__/
build/
Expand Down
5 changes: 5 additions & 0 deletions docs/reference/channels/backends/asyncpg.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
asyncpg
=======

.. automodule:: litestar.channels.backends.asyncpg
:members:
2 changes: 2 additions & 0 deletions docs/reference/channels/backends/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,5 @@ backends
base
memory
redis
psycopg
asyncpg
5 changes: 5 additions & 0 deletions docs/reference/channels/backends/psycopg.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
psycopg
=======

.. automodule:: litestar.channels.backends.psycopg
:members:
13 changes: 12 additions & 1 deletion docs/usage/channels.rst
Original file line number Diff line number Diff line change
Expand Up @@ -399,7 +399,7 @@ implemented are:
A basic in-memory backend, mostly useful for testing and local development, but
still fully capable. Since it stores all data in-process, it can achieve the highest
performance of all the backends, but at the same time is not suitable for
applications running on multiple processes.
applications running on multiple processes.

:class:`RedisChannelsPubSubBackend <.redis.RedisChannelsPubSubBackend>`
A Redis based backend, using `Pub/Sub <https://redis.io/docs/manual/pubsub/>`_ to
Expand All @@ -413,6 +413,17 @@ implemented are:
when history is needed


:class:`AsyncPgChannelsBackend <.asyncpg.AsyncPgChannelsBackend>`
A postgres backend using the
`asyncpg <https://magicstack.github.io/asyncpg/current/>`_ driver


:class:`PsycoPgChannelsBackend <.psycopg.PsycoPgChannelsBackend>`
A postgres backend using the `psycopg3 <https://www.psycopg.org/psycopg3/docs/>`_
async driver




Integrating with websocket handlers
-----------------------------------
Expand Down
82 changes: 82 additions & 0 deletions litestar/channels/backends/asyncpg.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
from __future__ import annotations

import asyncio
from contextlib import AsyncExitStack
from functools import partial
from typing import AsyncGenerator, Awaitable, Callable, Iterable, overload

import asyncpg

from litestar.channels import ChannelsBackend
from litestar.exceptions import ImproperlyConfiguredException


class AsyncPgChannelsBackend(ChannelsBackend):
    """Channels backend using PostgreSQL ``LISTEN``/``NOTIFY`` via the ``asyncpg`` driver.

    A single long-lived connection receives notifications; every ``publish`` call opens
    a short-lived connection of its own. History is not supported, since PostgreSQL does
    not persist NOTIFY payloads.
    """

    # Long-lived connection used exclusively for LISTEN; set in on_startup()
    _listener_conn: asyncpg.Connection
    # Queue of (channel, payload) events. ``None`` until on_startup() and again
    # after on_shutdown() — the sentinel lets stream_events() terminate cleanly.
    _queue: asyncio.Queue[tuple[str, bytes]] | None

    @overload
    def __init__(self, dsn: str) -> None:
        ...

    @overload
    def __init__(
        self,
        *,
        make_connection: Callable[[], Awaitable[asyncpg.Connection]],
    ) -> None:
        ...

    def __init__(
        self,
        dsn: str | None = None,
        *,
        make_connection: Callable[[], Awaitable[asyncpg.Connection]] | None = None,
    ) -> None:
        """Initialize the backend.

        Args:
            dsn: A postgres DSN passed to :func:`asyncpg.connect`.
            make_connection: Alternatively, a callable returning a new
                :class:`asyncpg.Connection`. Takes precedence over ``dsn``.

        Raises:
            ImproperlyConfiguredException: If neither ``dsn`` nor ``make_connection``
                is provided.
        """
        if not (dsn or make_connection):
            raise ImproperlyConfiguredException("Need to specify dsn or make_connection")

        self._queue = None
        self._subscribed_channels: set[str] = set()
        self._exit_stack = AsyncExitStack()
        self._connect = make_connection or partial(asyncpg.connect, dsn=dsn)

    async def on_startup(self) -> None:
        """Create the event queue and establish the listener connection."""
        self._queue = asyncio.Queue()
        self._listener_conn = await self._connect()

    async def on_shutdown(self) -> None:
        """Close the listener connection and retire the event queue.

        The queue is reset to ``None`` (the original ``del self._queue`` made any
        still-running ``stream_events`` generator raise ``AttributeError`` instead of
        stopping, because ``asyncio.Queue`` has no ``__bool__``/``__len__`` and the
        ``while self._queue`` condition never became false).
        """
        await self._listener_conn.close()
        self._queue = None

    async def publish(self, data: bytes, channels: Iterable[str]) -> None:
        """Send ``data`` to every channel in ``channels`` via ``pg_notify``.

        Uses a dedicated short-lived connection, closed even if a notify fails.
        """
        # pg_notify takes a text payload, so the bytes are decoded up front
        dec_data = data.decode("utf-8")

        conn = await self._connect()
        try:
            for channel in channels:
                await conn.execute("SELECT pg_notify($1, $2);", channel, dec_data)
        finally:
            await conn.close()

    async def subscribe(self, channels: Iterable[str]) -> None:
        """Start listening on each channel not already subscribed to."""
        for channel in set(channels) - self._subscribed_channels:
            await self._listener_conn.add_listener(channel, self._listener)  # type: ignore[arg-type]
            self._subscribed_channels.add(channel)

    async def unsubscribe(self, channels: Iterable[str]) -> None:
        """Stop listening on the given channels."""
        for channel in channels:
            await self._listener_conn.remove_listener(channel, self._listener)  # type: ignore[arg-type]
        self._subscribed_channels -= set(channels)

    async def stream_events(self) -> AsyncGenerator[tuple[str, bytes], None]:
        """Yield ``(channel, payload)`` tuples until the backend shuts down."""
        # Bind the queue locally each iteration: if on_shutdown() runs while the
        # consumer holds a yielded item, task_done() must still target the queue
        # the item came from rather than hit the None sentinel.
        while (queue := self._queue) is not None:
            yield await queue.get()
            queue.task_done()

    async def get_history(self, channel: str, limit: int | None = None) -> list[bytes]:
        """Not supported: PostgreSQL NOTIFY does not retain messages."""
        raise NotImplementedError()

    def _listener(self, /, connection: asyncpg.Connection, pid: int, channel: str, payload: object) -> None:
        """asyncpg notification callback; enqueue the event for stream_events()."""
        if not isinstance(payload, str):
            raise RuntimeError("Invalid data received")
        self._queue.put_nowait((channel, payload.encode("utf-8")))  # type: ignore[union-attr]
54 changes: 54 additions & 0 deletions litestar/channels/backends/psycopg.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
from __future__ import annotations

from contextlib import AsyncExitStack
from typing import AsyncGenerator, Iterable

import psycopg

from .base import ChannelsBackend


def _safe_quote(ident: str) -> str:
return '"{}"'.format(ident.replace('"', '""')) # sourcery skip


class PsycoPgChannelsBackend(ChannelsBackend):
    """Channels backend using PostgreSQL ``LISTEN``/``NOTIFY`` via the async
    ``psycopg`` (v3) driver.

    A single long-lived connection receives notifications; each ``publish`` call
    opens a short-lived connection of its own. Message history is not supported,
    as PostgreSQL does not persist NOTIFY payloads.
    """

    # Long-lived connection used for LISTEN/notifications; set in on_startup()
    _listener_conn: psycopg.AsyncConnection

    def __init__(self, pg_dsn: str) -> None:
        """Initialize the backend with a postgres DSN."""
        self._pg_dsn = pg_dsn
        self._subscribed_channels: set[str] = set()
        self._exit_stack = AsyncExitStack()

    async def on_startup(self) -> None:
        # autocommit=True so LISTEN/UNLISTEN take effect immediately rather than
        # waiting on an open transaction. The connection is registered on the
        # exit stack so on_shutdown() closes it.
        self._listener_conn = await psycopg.AsyncConnection.connect(self._pg_dsn, autocommit=True)
        await self._exit_stack.enter_async_context(self._listener_conn)

    async def on_shutdown(self) -> None:
        # Closes the listener connection via the exit stack
        await self._exit_stack.aclose()

    async def publish(self, data: bytes, channels: Iterable[str]) -> None:
        """Send ``data`` to every channel in ``channels`` via ``pg_notify``."""
        # pg_notify takes a text payload, so the bytes are decoded up front
        dec_data = data.decode("utf-8")
        async with await psycopg.AsyncConnection.connect(self._pg_dsn) as conn:
            for channel in channels:
                await conn.execute("SELECT pg_notify(%s, %s);", (channel, dec_data))

    async def subscribe(self, channels: Iterable[str]) -> None:
        """Start listening on each channel not already subscribed to."""
        for channel in set(channels) - self._subscribed_channels:
            # can't use placeholders in LISTEN
            await self._listener_conn.execute(f"LISTEN {_safe_quote(channel)};")  # pyright: ignore

            self._subscribed_channels.add(channel)

    async def unsubscribe(self, channels: Iterable[str]) -> None:
        """Stop listening on the given channels."""
        for channel in channels:
            # can't use placeholders in UNLISTEN
            await self._listener_conn.execute(f"UNLISTEN {_safe_quote(channel)};")  # pyright: ignore
        self._subscribed_channels = self._subscribed_channels - set(channels)

    async def stream_events(self) -> AsyncGenerator[tuple[str, bytes], None]:
        """Yield ``(channel, payload)`` tuples as notifications arrive.

        NOTE(review): iterates the connection's notification generator directly;
        presumably this runs until the connection is closed — confirm shutdown
        behavior against psycopg's ``notifies()`` semantics.
        """
        async for notify in self._listener_conn.notifies():
            yield notify.channel, notify.payload.encode("utf-8")

    async def get_history(self, channel: str, limit: int | None = None) -> list[bytes]:
        """Not supported: PostgreSQL NOTIFY does not retain messages."""
        raise NotImplementedError()
2 changes: 1 addition & 1 deletion litestar/channels/plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -311,10 +311,10 @@ async def _sub_worker(self) -> None:
subscriber.put_nowait(payload)

async def _on_startup(self) -> None:
await self._backend.on_startup()
self._pub_queue = Queue()
self._pub_task = create_task(self._pub_worker())
self._sub_task = create_task(self._sub_worker())
await self._backend.on_startup()
if self._channels:
await self._backend.subscribe(list(self._channels))

Expand Down
5 changes: 3 additions & 2 deletions litestar/contrib/pydantic/pydantic_dto_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from litestar.contrib.pydantic.utils import is_pydantic_undefined
from litestar.dto.base_dto import AbstractDTO
from litestar.dto.data_structures import DTOFieldDefinition
from litestar.dto.field import DTO_FIELD_META_KEY, DTOField
from litestar.dto.field import DTO_FIELD_META_KEY, extract_dto_field
from litestar.exceptions import MissingDependencyException, ValidationException
from litestar.types.empty import Empty

Expand Down Expand Up @@ -81,7 +81,8 @@ def generate_field_definitions(

for field_name, field_info in model_fields.items():
field_definition = model_field_definitions[field_name]
dto_field = (field_definition.extra or {}).pop(DTO_FIELD_META_KEY, DTOField())
dto_field = extract_dto_field(field_definition, field_definition.extra)
field_definition.extra.pop(DTO_FIELD_META_KEY, None)

if not is_pydantic_undefined(field_info.default):
default = field_info.default
Expand Down
4 changes: 2 additions & 2 deletions litestar/dto/dataclass_dto.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from litestar.dto.base_dto import AbstractDTO
from litestar.dto.data_structures import DTOFieldDefinition
from litestar.dto.field import DTO_FIELD_META_KEY, DTOField
from litestar.dto.field import extract_dto_field
from litestar.params import DependencyKwarg, KwargDefinition
from litestar.types.empty import Empty

Expand Down Expand Up @@ -40,7 +40,7 @@ def generate_field_definitions(
DTOFieldDefinition.from_field_definition(
field_definition=field_definition,
default_factory=default_factory,
dto_field=dc_field.metadata.get(DTO_FIELD_META_KEY, DTOField()),
dto_field=extract_dto_field(field_definition, dc_field.metadata),
model_name=model_type.__name__,
),
name=key,
Expand Down
35 changes: 33 additions & 2 deletions litestar/dto/field.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,19 @@

from dataclasses import dataclass
from enum import Enum
from typing import Literal
from typing import TYPE_CHECKING

if TYPE_CHECKING:
from typing import Any, Literal, Mapping

from litestar.typing import FieldDefinition

__all__ = (
"DTO_FIELD_META_KEY",
"DTOField",
"Mark",
"dto_field",
"extract_dto_field",
)

DTO_FIELD_META_KEY = "__dto__"
Expand All @@ -26,7 +32,7 @@ class Mark(str, Enum):
"""To mark a field that can neither be read or updated by clients."""


@dataclass
@dataclass(unsafe_hash=True)
class DTOField:
"""For configuring DTO behavior on model fields."""

Expand All @@ -47,3 +53,28 @@ def dto_field(mark: Literal["read-only", "write-only", "private"] | Mark) -> dic
Marking a field automates its inclusion/exclusion from DTO field definitions, depending on the DTO's purpose.
"""
return {DTO_FIELD_META_KEY: DTOField(mark=Mark(mark))}


def extract_dto_field(field_definition: FieldDefinition, field_info_mapping: Mapping[str, Any]) -> DTOField:
    """Resolve the :class:`DTOField` configuration for a model field.

    Supports ``DTOField`` to be set via ``Annotated`` or via a field info/metadata
    mapping, e.g. ``Annotated[str, DTOField(mark="read-only")]`` or
    ``info=dto_field(mark="read-only")``. A value found in ``field_info_mapping``
    takes precedence over the field definition's metadata.

    Args:
        field_definition: A field definition.
        field_info_mapping: A field metadata/info attribute mapping, e.g., SQLAlchemy's
            ``info`` attribute, or dataclasses ``metadata`` attribute.

    Returns:
        DTO field info, if any; otherwise a default ``DTOField``.

    Raises:
        TypeError: If the mapping holds a truthy value under the DTO key that is not
            a ``DTOField`` instance.
    """
    from_mapping = field_info_mapping.get(DTO_FIELD_META_KEY)
    if from_mapping:
        if isinstance(from_mapping, DTOField):
            return from_mapping
        raise TypeError(f"DTO field info must be an instance of DTOField, got '{from_mapping}'")

    # Fall back to the first DTOField found in Annotated metadata, else a default
    for meta in field_definition.metadata:
        if isinstance(meta, DTOField):
            return meta
    return DTOField()
5 changes: 3 additions & 2 deletions litestar/dto/msgspec_dto.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

from litestar.dto.base_dto import AbstractDTO
from litestar.dto.data_structures import DTOFieldDefinition
from litestar.dto.field import DTO_FIELD_META_KEY, DTOField
from litestar.dto.field import DTO_FIELD_META_KEY, extract_dto_field
from litestar.types.empty import Empty

if TYPE_CHECKING:
Expand Down Expand Up @@ -36,7 +36,8 @@ def default_or_none(value: Any) -> Any:

for key, field_definition in cls.get_model_type_hints(model_type).items():
msgspec_field = msgspec_fields[key]
dto_field = (field_definition.extra or {}).pop(DTO_FIELD_META_KEY, DTOField())
dto_field = extract_dto_field(field_definition, field_definition.extra)
field_definition.extra.pop(DTO_FIELD_META_KEY, None)

yield replace(
DTOFieldDefinition.from_field_definition(
Expand Down
Loading
Loading