Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Commit

Permalink
Return attrs for more media repo APIs. (#16611)
Browse files Browse the repository at this point in the history
  • Loading branch information
clokep authored Nov 9, 2023
1 parent 91587d4 commit ff716b4
Show file tree
Hide file tree
Showing 10 changed files with 148 additions and 110 deletions.
1 change: 1 addition & 0 deletions changelog.d/16611.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Improve type hints.
15 changes: 9 additions & 6 deletions synapse/handlers/profile.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# limitations under the License.
import logging
import random
from typing import TYPE_CHECKING, Optional
from typing import TYPE_CHECKING, Optional, Union

from synapse.api.errors import (
AuthError,
Expand All @@ -23,6 +23,7 @@
StoreError,
SynapseError,
)
from synapse.storage.databases.main.media_repository import LocalMedia, RemoteMedia
from synapse.types import JsonDict, Requester, UserID, create_requester
from synapse.util.caches.descriptors import cached
from synapse.util.stringutils import parse_and_validate_mxc_uri
Expand Down Expand Up @@ -306,7 +307,9 @@ async def check_avatar_size_and_mime_type(self, mxc: str) -> bool:
server_name = host

if self._is_mine_server_name(server_name):
media_info = await self.store.get_local_media(media_id)
media_info: Optional[
Union[LocalMedia, RemoteMedia]
] = await self.store.get_local_media(media_id)
else:
media_info = await self.store.get_cached_remote_media(server_name, media_id)

Expand All @@ -322,25 +325,25 @@ async def check_avatar_size_and_mime_type(self, mxc: str) -> bool:

if self.max_avatar_size:
# Ensure avatar does not exceed max allowed avatar size
if media_info["media_length"] > self.max_avatar_size:
if media_info.media_length > self.max_avatar_size:
logger.warning(
"Forbidding avatar change to %s: %d bytes is above the allowed size "
"limit",
mxc,
media_info["media_length"],
media_info.media_length,
)
return False

if self.allowed_avatar_mimetypes:
# Ensure the avatar's file type is allowed
if (
self.allowed_avatar_mimetypes
and media_info["media_type"] not in self.allowed_avatar_mimetypes
and media_info.media_type not in self.allowed_avatar_mimetypes
):
logger.warning(
"Forbidding avatar change to %s: mimetype %s not allowed",
mxc,
media_info["media_type"],
media_info.media_type,
)
return False

Expand Down
2 changes: 1 addition & 1 deletion synapse/handlers/sso.py
Original file line number Diff line number Diff line change
Expand Up @@ -806,7 +806,7 @@ def is_allowed_mime_type(content_type: str) -> bool:
media_id = profile["avatar_url"].split("/")[-1]
if self._is_mine_server_name(server_name):
media = await self._media_repo.store.get_local_media(media_id)
if media is not None and upload_name == media["upload_name"]:
if media is not None and upload_name == media.upload_name:
logger.info("skipping saving the user avatar")
return True

Expand Down
70 changes: 40 additions & 30 deletions synapse/media/media_repository.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from io import BytesIO
from typing import IO, TYPE_CHECKING, Dict, List, Optional, Set, Tuple

import attr
from matrix_common.types.mxc_uri import MXCUri

import twisted.internet.error
Expand Down Expand Up @@ -50,6 +51,7 @@
from synapse.media.thumbnailer import Thumbnailer, ThumbnailError
from synapse.media.url_previewer import UrlPreviewer
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage.databases.main.media_repository import RemoteMedia
from synapse.types import UserID
from synapse.util.async_helpers import Linearizer
from synapse.util.retryutils import NotRetryingDestination
Expand Down Expand Up @@ -245,18 +247,18 @@ async def get_local_media(
Resolves once a response has successfully been written to request
"""
media_info = await self.store.get_local_media(media_id)
if not media_info or media_info["quarantined_by"]:
if not media_info or media_info.quarantined_by:
respond_404(request)
return

self.mark_recently_accessed(None, media_id)

media_type = media_info["media_type"]
media_type = media_info.media_type
if not media_type:
media_type = "application/octet-stream"
media_length = media_info["media_length"]
upload_name = name if name else media_info["upload_name"]
url_cache = media_info["url_cache"]
media_length = media_info.media_length
upload_name = name if name else media_info.upload_name
url_cache = media_info.url_cache

file_info = FileInfo(None, media_id, url_cache=bool(url_cache))

Expand Down Expand Up @@ -310,16 +312,20 @@ async def get_remote_media(

# We deliberately stream the file outside the lock
if responder:
media_type = media_info["media_type"]
media_length = media_info["media_length"]
upload_name = name if name else media_info["upload_name"]
upload_name = name if name else media_info.upload_name
await respond_with_responder(
request, responder, media_type, media_length, upload_name
request,
responder,
media_info.media_type,
media_info.media_length,
upload_name,
)
else:
respond_404(request)

async def get_remote_media_info(self, server_name: str, media_id: str) -> dict:
async def get_remote_media_info(
self, server_name: str, media_id: str
) -> RemoteMedia:
"""Gets the media info associated with the remote file, downloading
if necessary.
Expand Down Expand Up @@ -353,7 +359,7 @@ async def get_remote_media_info(self, server_name: str, media_id: str) -> dict:

async def _get_remote_media_impl(
self, server_name: str, media_id: str
) -> Tuple[Optional[Responder], dict]:
) -> Tuple[Optional[Responder], RemoteMedia]:
"""Looks for media in local cache, if not there then attempt to
download from remote server.
Expand All @@ -373,15 +379,17 @@ async def _get_remote_media_impl(

# If we have an entry in the DB, try and look for it
if media_info:
file_id = media_info["filesystem_id"]
file_id = media_info.filesystem_id
file_info = FileInfo(server_name, file_id)

if media_info["quarantined_by"]:
if media_info.quarantined_by:
logger.info("Media is quarantined")
raise NotFoundError()

if not media_info["media_type"]:
media_info["media_type"] = "application/octet-stream"
if not media_info.media_type:
media_info = attr.evolve(
media_info, media_type="application/octet-stream"
)

responder = await self.media_storage.fetch_media(file_info)
if responder:
Expand All @@ -403,9 +411,9 @@ async def _get_remote_media_impl(
if not media_info:
raise e

file_id = media_info["filesystem_id"]
if not media_info["media_type"]:
media_info["media_type"] = "application/octet-stream"
file_id = media_info.filesystem_id
if not media_info.media_type:
media_info = attr.evolve(media_info, media_type="application/octet-stream")
file_info = FileInfo(server_name, file_id)

# We generate thumbnails even if another process downloaded the media
Expand All @@ -415,7 +423,7 @@ async def _get_remote_media_impl(
# otherwise they'll request thumbnails and get a 404 if they're not
# ready yet.
await self._generate_thumbnails(
server_name, media_id, file_id, media_info["media_type"]
server_name, media_id, file_id, media_info.media_type
)

responder = await self.media_storage.fetch_media(file_info)
Expand All @@ -425,7 +433,7 @@ async def _download_remote_file(
self,
server_name: str,
media_id: str,
) -> dict:
) -> RemoteMedia:
"""Attempt to download the remote file from the given server name,
using the given file_id as the local id.
Expand Down Expand Up @@ -518,23 +526,25 @@ async def _download_remote_file(
origin=server_name,
media_id=media_id,
media_type=media_type,
time_now_ms=self.clock.time_msec(),
time_now_ms=time_now_ms,
upload_name=upload_name,
media_length=length,
filesystem_id=file_id,
)

logger.info("Stored remote media in file %r", fname)

media_info = {
"media_type": media_type,
"media_length": length,
"upload_name": upload_name,
"created_ts": time_now_ms,
"filesystem_id": file_id,
}

return media_info
return RemoteMedia(
media_origin=server_name,
media_id=media_id,
media_type=media_type,
media_length=length,
upload_name=upload_name,
created_ts=time_now_ms,
filesystem_id=file_id,
last_access_ts=time_now_ms,
quarantined_by=None,
)

def _get_thumbnail_requirements(
self, media_type: str
Expand Down
11 changes: 5 additions & 6 deletions synapse/media/url_previewer.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,15 +240,14 @@ async def _do_preview(self, url: str, user: UserID, ts: int) -> bytes:
cache_result = await self.store.get_url_cache(url, ts)
if (
cache_result
and cache_result["expires_ts"] > ts
and cache_result["response_code"] / 100 == 2
and cache_result.expires_ts > ts
and cache_result.response_code // 100 == 2
):
# It may be stored as text in the database, not as bytes (such as
# PostgreSQL). If so, encode it back before handing it on.
og = cache_result["og"]
if isinstance(og, str):
og = og.encode("utf8")
return og
if isinstance(cache_result.og, str):
return cache_result.og.encode("utf8")
return cache_result.og

# If this URL can be accessed via an allowed oEmbed, use that instead.
url_to_download = url
Expand Down
16 changes: 8 additions & 8 deletions synapse/rest/media/thumbnail_resource.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ async def _respond_local_thumbnail(
if not media_info:
respond_404(request)
return
if media_info["quarantined_by"]:
if media_info.quarantined_by:
logger.info("Media is quarantined")
respond_404(request)
return
Expand All @@ -134,7 +134,7 @@ async def _respond_local_thumbnail(
thumbnail_infos,
media_id,
media_id,
url_cache=bool(media_info["url_cache"]),
url_cache=bool(media_info.url_cache),
server_name=None,
)

Expand All @@ -152,7 +152,7 @@ async def _select_or_generate_local_thumbnail(
if not media_info:
respond_404(request)
return
if media_info["quarantined_by"]:
if media_info.quarantined_by:
logger.info("Media is quarantined")
respond_404(request)
return
Expand All @@ -168,7 +168,7 @@ async def _select_or_generate_local_thumbnail(
file_info = FileInfo(
server_name=None,
file_id=media_id,
url_cache=media_info["url_cache"],
url_cache=bool(media_info.url_cache),
thumbnail=info,
)

Expand All @@ -188,7 +188,7 @@ async def _select_or_generate_local_thumbnail(
desired_height,
desired_method,
desired_type,
url_cache=bool(media_info["url_cache"]),
url_cache=bool(media_info.url_cache),
)

if file_path:
Expand All @@ -213,7 +213,7 @@ async def _select_or_generate_remote_thumbnail(
server_name, media_id
)

file_id = media_info["filesystem_id"]
file_id = media_info.filesystem_id

for info in thumbnail_infos:
t_w = info.width == desired_width
Expand All @@ -224,7 +224,7 @@ async def _select_or_generate_remote_thumbnail(
if t_w and t_h and t_method and t_type:
file_info = FileInfo(
server_name=server_name,
file_id=media_info["filesystem_id"],
file_id=file_id,
thumbnail=info,
)

Expand Down Expand Up @@ -280,7 +280,7 @@ async def _respond_remote_thumbnail(
m_type,
thumbnail_infos,
media_id,
media_info["filesystem_id"],
media_info.filesystem_id,
url_cache=False,
server_name=server_name,
)
Expand Down
Loading

0 comments on commit ff716b4

Please sign in to comment.