Skip to content

Commit

Permalink
format and lintiing
Browse files Browse the repository at this point in the history
  • Loading branch information
AbstractUmbra committed Nov 17, 2023
1 parent b09e56b commit acafacc
Show file tree
Hide file tree
Showing 13 changed files with 56 additions and 55 deletions.
1 change: 0 additions & 1 deletion api/middleware/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@

import core


if TYPE_CHECKING:
from api.server import Server

Expand Down
1 change: 0 additions & 1 deletion api/routes/applications.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@

import core


if TYPE_CHECKING:
from api.server import Server

Expand Down
25 changes: 11 additions & 14 deletions api/routes/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,35 +20,36 @@
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
"""

from __future__ import annotations

import logging
from typing import TYPE_CHECKING

from starlette.requests import Request
from starlette.responses import JSONResponse, Response

import core


if TYPE_CHECKING:
from starlette.requests import Request

from api.server import Server


logger: logging.Logger = logging.getLogger(__name__)
LOGGER: logging.Logger = logging.getLogger(__name__)


class Auth(core.View):
def __init__(self, app: Server) -> None:
self.app = app

@core.route('/github', methods=['POST'])
@core.route("/github", methods=["POST"])
async def github_auth(self, request: Request) -> Response:
try:
data = await request.json()
code = data.get("code", None)
except Exception as e:
logger.debug(f'Bad JSON body in "/auth/github": {e}')
LOGGER.debug('Bad JSON body in "/auth/github": %s', e)

return JSONResponse({"error": "Bad JSON body passed"}, status_code=421)

Expand All @@ -57,7 +58,7 @@ async def github_auth(self, request: Request) -> Response:

client_id: str = core.config["OAUTH"]["github_id"]
client_secret: str = core.config["OAUTH"]["github_secret"]
url: str = core.config['OAUTH']['redirect']
url: str = core.config["OAUTH"]["redirect"]

data = {
"client_id": client_id,
Expand All @@ -72,28 +73,24 @@ async def github_auth(self, request: Request) -> Response:
"Accept": "application/json",
}

async with self.app.session.post(
"https://github.com/login/oauth/access_token", data=data, headers=headers
) as resp:
async with self.app.session.post("https://github.com/login/oauth/access_token", data=data, headers=headers) as resp:
resp.raise_for_status()

data = await resp.json()

try:
token = data["access_token"]
except KeyError:
return JSONResponse({'error': 'Bad code query sent.'}, status_code=400)
return JSONResponse({"error": "Bad code query sent."}, status_code=400)

async with self.app.session.get(
"https://api.github.com/user", headers={"Authorization": f"Bearer {token}"}
) as resp:
async with self.app.session.get("https://api.github.com/user", headers={"Authorization": f"Bearer {token}"}) as resp:
resp.raise_for_status()

data = await resp.json()
userid = data["id"]
username = data["name"] or data["login"]

user = await self.app.database.refresh_or_create_user(github_id=userid, username=username)
logger.info(f'Refreshed Bearer: id={user.uid} github_id={user.github_id} username={username}')
LOGGER.info("Refreshed Bearer: id=%s github_id=%s username=%s", user.uid, user.github_id, username)

return JSONResponse(user.as_dict(), status_code=200)
36 changes: 17 additions & 19 deletions api/routes/members.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,20 +20,21 @@
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
"""

from __future__ import annotations

import logging
from typing import Any, TYPE_CHECKING
from typing import TYPE_CHECKING, Any

from starlette.authentication import requires
from starlette.requests import Request
from starlette.responses import JSONResponse, Response
from starlette.websockets import WebSocket

import core


if TYPE_CHECKING:
from starlette.requests import Request
from starlette.websockets import WebSocket

from api.server import Server


Expand All @@ -44,23 +45,23 @@ class Members(core.View):
def __init__(self, app: Server) -> None:
self.app = app

@core.route('/dpy/modlog', methods=['POST'])
@requires('member')
@core.route("/dpy/modlog", methods=["POST"])
@requires("member")
async def post_dpy_modlog(self, request: Request) -> Response:
application: core.ApplicationModel = request.user.model

try:
data = await request.json()
except Exception as e:
logger.debug(f'Received bad JSON in "/members/dpy/modlog": {e}')
return JSONResponse({'error': 'Bad POST JSON Body.'}, status_code=400)
logger.debug('Received bad JSON in "/members/dpy/modlog": %s', e)
return JSONResponse({"error": "Bad POST JSON Body."}, status_code=400)

payload: dict[str, Any] = {
'op': core.WebsocketOPCodes.EVENT,
'subscription': core.WebsocketSubscriptions.DPY_MOD_LOG,
'application': application.uid,
'application_name': application.name,
'payload': data
"op": core.WebsocketOPCodes.EVENT,
"subscription": core.WebsocketSubscriptions.DPY_MOD_LOG,
"application": application.uid,
"application_name": application.name,
"payload": data,
}

count = 0
Expand All @@ -69,18 +70,15 @@ async def post_dpy_modlog(self, request: Request) -> Response:
websockets: list[WebSocket] = list(self.app.sockets[subscriber].values())

total += len(websockets)
payload['user_id'] = subscriber
payload["user_id"] = subscriber

for websocket in websockets:
try:
await websocket.send_json(data=payload)
except Exception as e:
logger.debug(f'Failed to send payload to a websocket for "{subscriber}": {e}')
logger.debug('Failed to send payload to a websocket for "%s": %s', subscriber, e)
else:
count += 1

to_send: dict[str, int] = {
'subscribers': total,
'successful': count
}
to_send: dict[str, int] = {"subscribers": total, "successful": count}
return JSONResponse(to_send, status_code=200)
1 change: 0 additions & 1 deletion api/routes/users.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@

import core


if TYPE_CHECKING:
from api.server import Server

Expand Down
1 change: 0 additions & 1 deletion core/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@
from .tokens import *
from .utils import *


# Setup root logging formatter...
handler: logging.StreamHandler[TextIO] = logging.StreamHandler()
handler.setFormatter(ColourFormatter())
Expand Down
8 changes: 7 additions & 1 deletion core/config.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,11 @@
import pathlib

import tomllib

_PATH = pathlib.Path("config.toml")

if not _PATH.exists():
raise RuntimeError("No config file found.")

with open('config.toml', 'rb') as fp:
with _PATH.open("rb") as fp:
config = tomllib.load(fp)
30 changes: 18 additions & 12 deletions core/database/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,29 +20,35 @@
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
"""

from __future__ import annotations

import datetime
import itertools
import logging
from typing import Any, Self
import pathlib
from typing import TYPE_CHECKING, Any, Self

import asyncpg
from starlette.requests import Request
from starlette.responses import Response

import core
from core.config import config

from .models import *

if TYPE_CHECKING:
from starlette.requests import Request
from starlette.responses import Response

logger: logging.Logger = logging.getLogger(__name__)
LOGGER: logging.Logger = logging.getLogger(__name__)


class Database:
_pool: asyncpg.Pool[asyncpg.Record]

def __init__(self) -> None:
self.schema_file = pathlib.Path("core/databases/SCHEMA.sql")

async def __aenter__(self) -> Self:
await self.setup()
return self
Expand All @@ -51,16 +57,16 @@ async def __aexit__(self, *args: Any) -> None:
await self._pool.close()

async def setup(self) -> Self:
logger.info('Setting up Database.')
LOGGER.info("Setting up Database.")

self._pool = await asyncpg.create_pool(dsn=config['DATABASE']['dsn']) # type: ignore
self._pool = await asyncpg.create_pool(dsn=config["DATABASE"]["dsn"]) # type: ignore
assert self._pool

async with self._pool.acquire() as connection:
with open('core/database/SCHEMA.sql', 'r') as schema:
with self.schema_file.open() as schema:
await connection.execute(schema.read())

logger.info('Completed Database Setup.')
LOGGER.info("Completed Database Setup.")

return self

Expand Down Expand Up @@ -163,7 +169,7 @@ async def create_application(self, *, user_id: int, name: str, description: str)

query: str = """
WITH create_application AS (
INSERT INTO tokens(user_id, token_name, token_description, token) VALUES ($1, $2, $3, $4) RETURNING *
INSERT INTO tokens(user_id, token_name, token_description, token) VALUES ($1, $2, $3, $4) RETURNING *
)
SELECT * FROM create_application
JOIN users u ON u.uid = create_application.user_id
Expand All @@ -179,7 +185,7 @@ async def add_log(self, *, request: Request, response: Response) -> None:
query: str = """INSERT INTO logs VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10)"""

try:
body: str | None = str(request._body.decode(encoding='UTF-8')) # pyright: ignore [reportPrivateUsage]
body: str | None = str(request._body.decode(encoding="UTF-8")) # pyright: ignore [reportPrivateUsage]
except AttributeError:
body = None

Expand All @@ -195,7 +201,7 @@ async def add_log(self, *, request: Request, response: Response) -> None:
tid = model.tid
uid = model.uid

host: str | None = getattr(request.client, 'host', None)
host: str | None = getattr(request.client, "host", None)
ip: str | None = request.headers.get("X-Forwarded-For", host)

async with self._pool.acquire() as connection:
Expand Down Expand Up @@ -236,7 +242,7 @@ async def fetch_all_user_uses(self, *, user_id: int) -> dict[Any, int]:
logs.sort(key=lambda l: (l.tid is None, l.tid))

grouped = [(k, len(list(group))) for k, group in itertools.groupby(logs, lambda l: l.tid)]
base = {'total': len(logs)}
base = {"total": len(logs)}
base.update(grouped) # type: ignore

return base
1 change: 0 additions & 1 deletion core/database/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@

import asyncpg


__all__ = ('UserModel', 'ApplicationModel', 'LogModel')


Expand Down
1 change: 0 additions & 1 deletion core/logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
"""
import logging


__all__ = ('ColourFormatter',)


Expand Down
1 change: 0 additions & 1 deletion core/tokens.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
import base64
import secrets


__all__ = ('EPOCH', 'generate_token', 'id_from_token')


Expand Down
1 change: 0 additions & 1 deletion core/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@
from starlette.routing import Route
from starlette.types import Receive, Scope, Send


__all__ = (
'route',
'View',
Expand Down
4 changes: 3 additions & 1 deletion launcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,9 @@ async def main() -> None:
async with aiohttp.ClientSession() as session, core.Database() as database:
app: api.Server = api.Server(session=session, database=database)

config = uvicorn.Config(app, port=core.config['SERVER']['port'], ws_ping_interval=10, ws_ping_timeout=None)
config = uvicorn.Config(
app, host="0.0.0.0", port=core.config['SERVER']['port'], ws_ping_interval=10, ws_ping_timeout=None
)
server = uvicorn.Server(config)
await server.serve()

Expand Down

0 comments on commit acafacc

Please sign in to comment.