Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Websocket pubsub #85

Open
wants to merge 10 commits into
base: job-queue
Choose a base branch
from
38 changes: 38 additions & 0 deletions alembic/versions/5d464f430f43_add_job_logging_table.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
"""add job logging table

Revision ID: 5d464f430f43
Revises: bdf5e21a88df
Create Date: 2024-11-11 19:04:46.977449+00:00

"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql

# revision identifiers, used by Alembic.
revision = '5d464f430f43'
down_revision = 'bdf5e21a88df'
branch_labels = None
depends_on = None


def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.create_table('job_status',
sa.Column('internal_id', sa.Integer(), nullable=False),
sa.Column('id', postgresql.UUID(as_uuid=True), nullable=False),
sa.Column('status', sa.Enum('PENDING', 'RECEIVED', 'STARTED', 'RETRY', 'FAILURE', 'SUCCESS', 'REVOKED', name='jobstatus'), nullable=False),
sa.Column('type', sa.Text(), nullable=True),
sa.Column('status_date', sa.DateTime(timezone=True), nullable=False, server_default=sa.text('CURRENT_TIMESTAMP')),
sa.PrimaryKeyConstraint('internal_id')
)
op.create_index(op.f('ix_job_status_internal_id'), 'job_status', ['internal_id'], unique=False)
# ### end Alembic commands ###


def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_index(op.f('ix_job_status_internal_id'), table_name='job_status')
op.drop_table('job_status')
op.execute(sa.text('DROP TYPE jobstatus'))
# ### end Alembic commands ###
6 changes: 4 additions & 2 deletions app/api/api_v1/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@
from .endpoints import (
submission_router, assignment_router, user_router,
student_router, instructor_router, course_router,
settings_router, auth_router, lms_router, job_router
settings_router, auth_router, lms_router, job_router,
websocket_router
)

api_router = APIRouter()
Expand All @@ -15,4 +16,5 @@
api_router.include_router(settings_router.router, tags=["settings"])
api_router.include_router(auth_router.router, tags=["auth"])
api_router.include_router(lms_router.router, tags=["lms"])
api_router.include_router(job_router.router, tags=["jobs"])
api_router.include_router(job_router.router, tags=["jobs"])
api_router.include_router(websocket_router.router, tags=["websocket"])
18 changes: 6 additions & 12 deletions app/api/api_v1/endpoints/lms_router.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from fastapi import APIRouter, Request, Depends, UploadFile, File
from sqlalchemy.orm import Session
from app.celery.tasks import downsync_task
from app.services import LmsSyncService, AssignmentService
from app.services import LmsSyncService, AssignmentService, JobStatusService
from app.schemas import JobSchema
from app.core.dependencies import (
get_db, PermissionDependency,
Expand All @@ -29,18 +29,12 @@ async def downsync(
task = downsync_task.delay()
return JobSchema.from_async_result(task)

@router.post("/lms/downsync/students")
async def downsync_students(
@router.get("/lms/downsync/status", response_model=JobSchema | None)
async def get_downsync_job(
*,
db: Session = Depends(get_db),
perm: None = Depends(PermissionDependency(UserIsInstructorPermission))
):
return await LmsSyncService(db).sync_students()

@router.post("/lms/downsync/assignments")
async def downsync_assignments(
*,
db: Session = Depends(get_db),
perm: None = Depends(PermissionDependency(UserIsInstructorPermission))
):
return await LmsSyncService(db).sync_assignments()
status = JobStatusService(db).get_singleton_job_status(downsync_task)
if status is None: return None
return JobSchema.from_orm(status)
27 changes: 27 additions & 0 deletions app/api/api_v1/endpoints/websocket_router.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
from fastapi import APIRouter, WebSocket, WebSocketDisconnect, WebSocketException, Depends
from sqlalchemy.orm import Session
from websockets.exceptions import ConnectionClosed
from app.core.dependencies import get_db, PermissionDependency, RequireLoginPermission
from app.services import WebsocketManagerService

router = APIRouter()

@router.websocket("/websocket")
async def accept_websocket(
websocket: WebSocket,
db: Session = Depends(get_db),
perm: None = Depends(PermissionDependency(RequireLoginPermission))
):
websocket_manager = WebsocketManagerService(db)
await websocket_manager.connect(websocket)
try:
while True:
await websocket_manager.handle_client_message(websocket)

except (WebSocketDisconnect, ConnectionClosed) as e:
# NOTE: `WebSocketDisconnect` and `ConnectionClosed` are both expected errors that may be thrown.
await websocket_manager.disconnect(websocket)

except Exception as e:
# We could be doing some more complex logic here, but simply disconnecting the client is sufficient.
await websocket_manager.disconnect(websocket)
73 changes: 52 additions & 21 deletions app/celery/signals.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,63 @@
import inspect
import asyncio
from celery.signals import (
task_received, task_internal_error, task_revoked,
task_success, task_failure, task_retry,
task_success, task_failure, task_retry, task_prerun
)
from celery_singleton.util import generate_lock
from celery_singleton.config import Config
from celery_singleton.backends import get_backend
from app.celery.worker import celery_app
from app.database import SessionLocal
from app.services import WebsocketManagerService, JobStatusService, InstructorService
from app.models.job_status import JobStatusModel, JobStatus
from app.schemas import JobSchema, WebsocketJobStatusMessage

def handle_job_status_update(job_id: str, job_type: str | None, job_status: JobStatus) -> None:
""" Store a job status update in the database and dispatch a websocket event. """
async def _handle_job_status_update():
with SessionLocal() as session:
job_status_model = JobStatusService(session).create_job_status_update(job_id, job_type, job_status)
ws_message = WebsocketJobStatusMessage(
data=JobSchema.from_orm(job_status_model)
)

if job_type == "downsync":
user_list = await InstructorService(session).list_instructors()
else:
user_list = []

WebsocketManagerService.publish_sync_pubsub_ws_message(ws_message, user_list)

asyncio.run(_handle_job_status_update())

@task_received.connect
def task_received_handler(request, **kw):
handle_job_status_update(request.task_id, request.task_name, JobStatus.RECEIVED)

@task_prerun.connect
def task_prerun_handler(sender=None, task_id=None, **kw):
handle_job_status_update(task_id, sender.name, JobStatus.STARTED)

@task_success.connect
def task_success_handler(sender=None, result=None, **kw):
handle_job_status_update(sender.request.id, sender.name, JobStatus.SUCCESS)

@task_failure.connect
def task_failure_handler(sender=None, task_id=None, **kw):
handle_job_status_update(task_id, sender.name, JobStatus.FAILURE)

@task_internal_error.connect
def task_internal_error_handler(sender=None, task_id=None, **kw):
handle_job_status_update(task_id, sender.name, JobStatus.FAILURE)

@task_retry.connect
def task_retry_handler(sender=None, request=None, **kw):
handle_job_status_update(request.task_id, request.task_name, JobStatus.RETRY)

@task_revoked.connect
def task_revoked_handler(sender=None, request=None, **kw):
handle_job_status_update(request.task_id, request.task_name, JobStatus.REVOKED)

""" Source: https://github.com/steinitzu/celery-singleton/issues/29 """
@task_revoked.connect
Expand Down Expand Up @@ -50,23 +101,3 @@ def clean_singleton_lock_on_revoke(sender=None, request=None, reason=None, **kwa
if cache_task_id and cache_task_id.startswith(task_id):
print(f"Cleaning singletion lock: {redis_key}, task_id: {task_id}")
backend.clear(redis_key)

# @task_received.connect
# def task_received_handler(sender=None, request=None, reason=None, **kw):
# TaskLog.objects.(sender=sender, request=request, reason=reason)

# @task_success.connect
# def task_success_handler(sender=None, result=None, **kw):
# TaskLog.objects.create_from_success_signal(sender=sender, result=result)

# @task_failure.connect
# def task_failure_handler(sender=None, task_id=None, exception=None, **kw):
# TaskLog.objects.create_from_failure_signal(sender, task_id, exception)

# @task_retry.connect
# def task_retry_handler(sender=None, request=None, reason=None, **kw):
# TaskLog.objects.create_from_retry_signal(sender=sender, request=request, reason=reason)

# @task_internal_error.connect
# def task_internal_error_handler(sender=None, request=None, reason=None, **kw):
# TaskLog.objects.(sender=sender, request=request, reason=reason)
7 changes: 3 additions & 4 deletions app/celery/worker.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,13 @@
import json
from redis import Redis
from celery import Celery
from celery.result import AsyncResult
from app.database import redis_broker_client
from app.core.config import settings

redis_client = Redis.from_url(settings.CELERY_BROKER_URI)
celery_app = Celery(
__name__,
broker=settings.CELERY_BROKER_URI,
backend=settings.CELERY_RESULT_BACKEND,
backend=settings.CELERY_RESULT_BACKEND_URI,
result_extended=True
)
celery_app.conf.update(
Expand Down Expand Up @@ -45,7 +44,7 @@ def get_task_ids_by_name(

# Load pending tasks (still in queue)
if include_pending:
for item in redis_client.lrange("celery", 0, -1):
for item in redis_broker_client.lrange("celery", 0, -1):
task_data = json.loads(item)
if task_data.get("headers", {}).get("task") == task_name:
task_id = task_data["headers"].get("id")
Expand Down
18 changes: 15 additions & 3 deletions app/core/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,10 @@ class Settings(BaseSettings):
BROKER_PASSWORD: str

CELERY_BROKER_URI: Optional[str] = None # computed
CELERY_RESULT_BACKEND: Optional[str] = None # computed
CELERY_RESULT_BACKEND_URI: Optional[str] = None # computed
# Techincally, we could use either of the above as the general database as well,
# but using a new database helps with organization
REDIS_GENERAL_DATABASE_URI: Optional[str] = None # computed

# Gitea
GITEA_SSH_URL: str
Expand Down Expand Up @@ -109,15 +112,24 @@ def assemble_broker_uri(cls, v: Optional[str], values: Dict[str, Any]) -> str:
port = values.get("BROKER_PORT")
return f"redis://{ user }:{ pw }@{ host }:{ port }/0"

@validator("CELERY_RESULT_BACKEND", pre=True)
def assemble_result_backend(cls, v: Optional[str], values: Dict[str, Any]) -> str:
@validator("CELERY_RESULT_BACKEND_URI", pre=True)
def assemble_result_backend_uri(cls, v: Optional[str], values: Dict[str, Any]) -> str:
if isinstance(v, str): return v
user = values.get("BROKER_USER")
pw = values.get("BROKER_PASSWORD")
host = values.get("BROKER_HOST")
port = values.get("BROKER_PORT")
return f"redis://{ user }:{ pw }@{ host }:{ port }/1"

@validator("REDIS_GENERAL_DATABASE_URI", pre=True)
def assemble_redis_general_connection_uri(cls, v: Optional[str], values: Dict[str, Any]) -> str:
if isinstance(v, str): return v
user = values.get("BROKER_USER")
pw = values.get("BROKER_PASSWORD")
host = values.get("BROKER_HOST")
port = values.get("BROKER_PORT")
return f"redis://{ user }:{ pw }@{ host }:{ port }/2"

@validator("SQLALCHEMY_DATABASE_URI", pre=True)
def assemble_db_connection(cls, v: Optional[str], values: Dict[str, Any]) -> str:
if isinstance(v, str): return v
Expand Down
30 changes: 15 additions & 15 deletions app/core/dependencies/permission.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from abc import ABC, abstractmethod
from typing import List, Type

from fastapi import Request
from fastapi.requests import HTTPConnection
from fastapi.openapi.models import APIKey, APIKeyIn
from fastapi.security.base import SecurityBase

Expand All @@ -20,25 +20,25 @@ def __init__(self, db, user):
self.user = user

@abstractmethod
async def verify_permission(self, request: Request):
async def verify_permission(self, conn: HTTPConnection):
pass

# For endpoints that require a user to be logged in, but nothing beyond that.
class RequireLoginPermission(BasePermission):
async def verify_permission(self, request: Request):
async def verify_permission(self, conn: HTTPConnection):
if self.user is None or self.user.role is None:
raise UnauthorizedException()

class UserIsStudentPermission(RequireLoginPermission):
async def verify_permission(self, request: Request):
await super().verify_permission(request)
async def verify_permission(self, conn: HTTPConnection):
await super().verify_permission(conn)

if not isinstance(self.user, StudentModel):
raise NotAStudentException()

class UserIsInstructorPermission(RequireLoginPermission):
async def verify_permission(self, request: Request):
await super().verify_permission(request)
async def verify_permission(self, conn: HTTPConnection):
await super().verify_permission(conn)

if not isinstance(self.user, InstructorModel):
raise NotAnInstructorException()
Expand All @@ -48,8 +48,8 @@ async def verify_permission(self, request: Request):
# access to everyone other than students without restricting access
# to just instructors.
class UserIsSuperuserPermission(RequireLoginPermission):
async def verify_permission(self, request: Request):
await super().verify_permission(request)
async def verify_permission(self, conn: HTTPConnection):
await super().verify_permission(conn)

if isinstance(self.user, StudentModel):
raise NotASuperuserException()
Expand All @@ -58,8 +58,8 @@ async def verify_permission(self, request: Request):
class BaseRolePermission(RequireLoginPermission):
permission: UserPermission

async def verify_permission(self, request: Request):
await super().verify_permission(request)
async def verify_permission(self, conn: HTTPConnection):
await super().verify_permission(conn)

for permission in self.user.role.permissions:
if permission == self.permission:
Expand Down Expand Up @@ -121,22 +121,22 @@ def __init__(self, *permissions: List[Type[BasePermission]]):
self.model: APIKey = APIKey(**{"in": APIKeyIn.header}, name="Authorization")
self.scheme_name = self.__class__.__name__

async def __call__(self, request: Request):
async def __call__(self, conn: HTTPConnection):
from app.services import UserService

if settings.DISABLE_AUTHENTICATION and settings.IMPERSONATE_USER is not None:
if request.user.onyen is None:
if conn.user.onyen is None:
raise UserNotFoundException(f'The impersonated user "{ settings.IMPERSONATE_USER }" does not exist.')
elif settings.DISABLE_AUTHENTICATION and settings.IMPERSONATE_USER is None:
# If authentication is disabled, we treat the anonymous user as if they have every permission.
return

with SessionLocal() as session:
try:
user = await UserService(session).get_user_by_onyen(request.user.onyen)
user = await UserService(session).get_user_by_onyen(conn.user.onyen)
except UserNotFoundException:
user = None

for permission in self.permissions:
cls = permission(session, user)
await cls.verify_permission(request=request)
await cls.verify_permission(conn=conn)
Loading