class PilotLogsDB(BaseOSDB):
    """OpenSearch database definition for pilot log records.

    ``fields`` declares the OpenSearch type of every document field and
    ``index_prefix`` is the prefix shared by all pilot-log index names.
    """

    fields = {
        "PilotStamp": {"type": "keyword"},
        "PilotID": {"type": "long"},
        "SubmissionTime": {"type": "date"},
        "LineNumber": {"type": "long"},
        "Message": {"type": "text"},
        "VO": {"type": "keyword"},
        "timestamp": {"type": "date"},
    }
    index_prefix = "pilot_logs"

    def index_name(self, doc_id: int) -> str:
        """Return the index name for ``doc_id``.

        IDs are grouped in buckets of one million, so IDs 0..999_999 map to
        ``pilot_logs_0``, 1_000_000..1_999_999 to ``pilot_logs_1`` and so on.

        Args:
            doc_id: Identifier used for bucketing (currently the pilot ID).

        Returns:
            ``"<index_prefix>_<bucket>"``.
        """
        # TODO decide how to define the index name
        # use pilot ID
        # Integer floor division keeps the bucket exact for arbitrarily large
        # IDs; int() keeps the text free of a ".0" suffix if a float is passed.
        return f"{self.index_prefix}_{int(doc_id // 1_000_000)}"
async def bulk_insert(self, index_name: str, docs: list[dict[str, Any]]) -> None:
    """Insert ``docs`` into ``index_name`` through the OpenSearch bulk helper.

    Args:
        index_name: Name of the index every document is written to.
        docs: Documents to insert. The dictionaries are not modified; each
            action is a copy with ``"_index"`` set to ``index_name``.
    """
    actions = [doc | {"_index": index_name} for doc in docs]
    # async_bulk returns a (success_count, errors) tuple rather than a bare
    # count, so unpack it before logging the number of inserted documents.
    n_inserted, _errors = await async_bulk(self.client, actions=actions)
    logger.info("Inserted %s documents to %s", n_inserted, index_name)


async def delete(self, query: list[dict[str, Any]]) -> None:
    """Delete multiple documents by query.

    The request targets every index whose name starts with ``index_prefix``.

    Args:
        query: Search filters, converted with ``apply_search_filters``
            against ``self.fields``.
    """
    # NOTE(review): with an empty ``query`` the request body carries no
    # "query" clause at all - confirm how the server treats that request.
    body: dict[str, Any] = {}
    if query:
        body["query"] = apply_search_filters(self.fields, query)
    await self.client.delete_by_query(body=body, index=f"{self.index_prefix}*")
a/diracx-routers/src/diracx/routers/dependencies.py +++ b/diracx-routers/src/diracx/routers/dependencies.py @@ -8,6 +8,7 @@ "SandboxMetadataDB", "TaskQueueDB", "PilotAgentsDB", + "PilotLogsDB", "add_settings_annotation", "AvailableSecurityProperties", ) @@ -21,6 +22,7 @@ from diracx.core.properties import SecurityProperty from diracx.core.settings import DevelopmentSettings as _DevelopmentSettings from diracx.db.os import JobParametersDB as _JobParametersDB +from diracx.db.os import PilotLogsDB as _PilotLogsDB from diracx.db.sql import AuthDB as _AuthDB from diracx.db.sql import JobDB as _JobDB from diracx.db.sql import JobLoggingDB as _JobLoggingDB @@ -36,7 +38,7 @@ def add_settings_annotation(cls: T) -> T: return Annotated[cls, Depends(cls.create)] # type: ignore -# Databases +# SQL Databases AuthDB = Annotated[_AuthDB, Depends(_AuthDB.transaction)] JobDB = Annotated[_JobDB, Depends(_JobDB.transaction)] JobLoggingDB = Annotated[_JobLoggingDB, Depends(_JobLoggingDB.transaction)] @@ -46,9 +48,9 @@ def add_settings_annotation(cls: T) -> T: ] TaskQueueDB = Annotated[_TaskQueueDB, Depends(_TaskQueueDB.transaction)] -# Opensearch databases +# OpenSearch Databases JobParametersDB = Annotated[_JobParametersDB, Depends(_JobParametersDB.session)] - +PilotLogsDB = Annotated[_PilotLogsDB, Depends(_PilotLogsDB.session)] # Miscellaneous Config = Annotated[_Config, Depends(ConfigSource.create)] diff --git a/diracx-routers/src/diracx/routers/pilots/__init__.py b/diracx-routers/src/diracx/routers/pilots/__init__.py new file mode 100644 index 00000000..3e9084bc --- /dev/null +++ b/diracx-routers/src/diracx/routers/pilots/__init__.py @@ -0,0 +1,11 @@ +from __future__ import annotations + +from logging import getLogger + +from ..fastapi_classes import DiracxRouter +from .logging import router as logging_router + +logger = getLogger(__name__) + +router = DiracxRouter() +router.include_router(logging_router) diff --git a/diracx-routers/src/diracx/routers/pilots/access_policies.py 
class ActionType(StrEnum):
    #: Create/update pilot log records
    CREATE = auto()
    #: delete pilot logs
    DELETE = auto()
    #: Search
    QUERY = auto()


class PilotLogsAccessPolicy(BaseAccessPolicy):
    """Access rules for pilot logs.

    * CREATE is granted to PILOT and GENERIC_PILOT.
    * QUERY is granted to NORMAL_USER.
    * SERVICE_ADMINISTRATOR and OPERATOR are granted every action.

    Any other combination is rejected with HTTP 403; a missing action is
    rejected with HTTP 400.
    """

    @staticmethod
    async def policy(
        policy_name: str,
        user_info: AuthorizedUserInfo,
        /,
        *,
        action: ActionType | None = None,
    ):
        if action is None:
            raise HTTPException(
                status.HTTP_400_BAD_REQUEST, detail="Action is a mandatory argument"
            )

        granted = user_info.properties
        may_create = action == ActionType.CREATE and (
            GENERIC_PILOT in granted or PILOT in granted
        )
        may_query = action == ActionType.QUERY and NORMAL_USER in granted
        is_privileged = SERVICE_ADMINISTRATOR in granted or OPERATOR in granted

        if not (may_create or may_query or is_privileged):
            raise HTTPException(status.HTTP_403_FORBIDDEN, detail=granted)
        return user_info


CheckPilotLogsPolicyCallable = Annotated[Callable, Depends(PilotLogsAccessPolicy.check)]
logger = logging.getLogger(__name__)
router = DiracxRouter()

# The models and the two undocumented endpoints below are described with "#"
# comments rather than docstrings so the generated OpenAPI descriptions stay
# exactly as they were.


# One line of pilot output together with its line number.
class LogLine(BaseModel):
    line_no: int
    line: str


# A batch of log lines sent for the pilot identified by ``pilot_stamp``.
class LogMessage(BaseModel):
    pilot_stamp: str
    lines: list[LogLine]
    vo: str


# Optional bounds on the pilot submission time, used by the delete endpoint.
class DateRange(BaseModel):
    min: str | None = None  # expects a string in ISO 8601 ("%Y-%m-%dT%H:%M:%S.%f%z")
    max: str | None = None  # expects a string in ISO 8601 ("%Y-%m-%dT%H:%M:%S.%f%z")


# Store a batch of log lines for one pilot and return that pilot's ID.
# Responds with HTTP 400 when no pilot matches ``data.pilot_stamp``.
@router.post("/")
async def send_message(
    data: LogMessage,
    pilot_logs_db: PilotLogsDB,
    check_permissions: CheckPilotLogsPolicyCallable,
) -> int:
    logger.info("Message received '%s'", data)
    user_info = await check_permissions(action=ActionType.CREATE)

    # The pilot ID is looked up from the pilot stamp via PilotAgentsDB.
    # NOTE(review): a database object and engine context are created on every
    # request here - consider injecting PilotAgentsDB as a dependency instead.
    pilot_agents_db = BaseSQLDB.available_implementations("PilotAgentsDB")[0]
    url = BaseSQLDB.available_urls()["PilotAgentsDB"]
    db = pilot_agents_db(url)

    stmt = select(PilotAgents.pilot_id, PilotAgents.submission_time).where(
        PilotAgents.pilot_stamp == data.pilot_stamp
    )
    try:
        async with db.engine_context():
            async with db:
                pilot_id, submission_time = (await db.conn.execute(stmt)).one()
    except NoResultFound as exc:
        logger.error(
            "Cannot determine PilotID for requested PilotStamp: %s, Error: %s.",
            data.pilot_stamp,
            exc,
        )
        raise HTTPException(status.HTTP_400_BAD_REQUEST, detail=str(exc)) from exc

    # The submission time is stored with every line so logs can be selected and
    # deleted by pilot creation date even after the pilot itself has been
    # removed from PilotAgentsDB (the logs can live longer than the pilots).
    # The VO comes from the authenticated user, not from the request payload.
    docs = [
        {
            "PilotStamp": data.pilot_stamp,
            "PilotID": pilot_id,
            "SubmissionTime": submission_time,
            "VO": user_info.vo,
            "LineNumber": line.line_no,
            "Message": line.line,
        }
        for line in data.lines
    ]
    await pilot_logs_db.bulk_insert(pilot_logs_db.index_name(pilot_id), docs)
    return pilot_id


# Return the stored messages of one pilot ordered by line number.
@router.get("/logs")
async def get_logs(
    pilot_id: int,
    db: PilotLogsDB,
    check_permissions: CheckPilotLogsPolicyCallable,
) -> list[dict]:
    logger.info("Retrieving logs for pilot ID '%s'", pilot_id)
    user_info = await check_permissions(action=ActionType.QUERY)

    # here, users with privileged properties will see logs from all VOs. Is it what we want ?
    search_params = [{"parameter": "PilotID", "operator": "eq", "value": pilot_id}]
    if _non_privileged(user_info):
        search_params.append(
            {"parameter": "VO", "operator": "eq", "value": user_info.vo}
        )
    # NOTE(review): search() is called with its default per_page (100), so
    # longer logs are presumably returned only in part - confirm.
    result = await db.search(
        ["Message"],
        search_params,
        [{"parameter": "LineNumber", "direction": "asc"}],
    )
    if not result:
        return [{"Message": f"No logs for pilot ID = {pilot_id}"}]
    return result


@router.delete("/logs")
async def delete(
    pilot_id: int,
    data: DateRange,
    db: PilotLogsDB,
    check_permissions: CheckPilotLogsPolicyCallable,
) -> str:
    """Delete either logs for a specific PilotID or a creation date range.
    Non-privileged users can only delete log files within their own VO.
    """
    user_info = await check_permissions(action=ActionType.DELETE)

    # If pilot_id is provided, data.min and data.max are ignored.
    if data.min and data.max and not pilot_id:
        raise InvalidQueryError(
            "This query requires a range operator definition in DiracX"
        )

    if pilot_id:
        search_params = [{"parameter": "PilotID", "operator": "eq", "value": pilot_id}]
        message = f"Logs for pilot ID '{pilot_id}' successfully deleted"
    elif data.min:
        # The filter is a strict "gt", so the messages say ">" and not ">=".
        logger.warning(
            "Deleting logs for pilots with submission date > '%s'", data.min
        )
        search_params = [
            {"parameter": "SubmissionTime", "operator": "gt", "value": data.min}
        ]
        message = (
            f"Logs for pilots with submission date > '{data.min}' "
            "successfully deleted"
        )
    else:
        # Nothing to select on (a lone data.max is not supported).
        return "no-op"

    # NOTE(review): PilotLogsAccessPolicy currently grants DELETE only to
    # SERVICE_ADMINISTRATOR and OPERATOR, so this VO restriction is defensive.
    if _non_privileged(user_info):
        search_params.append(
            {"parameter": "VO", "operator": "eq", "value": user_info.vo}
        )
    await db.delete(search_params)
    return message


def _non_privileged(user_info: AuthorizedUserInfo) -> bool:
    """Return True when the user is neither SERVICE_ADMINISTRATOR nor OPERATOR."""
    return (
        SERVICE_ADMINISTRATOR not in user_info.properties
        and OPERATOR not in user_info.properties
    )
a/diracx-routers/tests/pilots/test_access_policies.py b/diracx-routers/tests/pilots/test_access_policies.py new file mode 100644 index 00000000..c11a26ea --- /dev/null +++ b/diracx-routers/tests/pilots/test_access_policies.py @@ -0,0 +1,61 @@ +from __future__ import annotations + +from contextlib import nullcontext +from unittest.mock import MagicMock + +import pytest +from fastapi import HTTPException + +from diracx.core.properties import ( + GENERIC_PILOT, + NORMAL_USER, + OPERATOR, + PILOT, + SERVICE_ADMINISTRATOR, +) +from diracx.routers.pilots.access_policies import ( + ActionType, + PilotLogsAccessPolicy, +) + + +@pytest.mark.parametrize( + "user, action, expectation", + [ + (PILOT, ActionType.CREATE, nullcontext()), + (PILOT, ActionType.QUERY, pytest.raises(HTTPException, match="403")), + (PILOT, ActionType.DELETE, pytest.raises(HTTPException, match="403")), + (GENERIC_PILOT, ActionType.CREATE, nullcontext()), + (GENERIC_PILOT, ActionType.QUERY, pytest.raises(HTTPException, match="403")), + (GENERIC_PILOT, ActionType.DELETE, pytest.raises(HTTPException, match="403")), + (SERVICE_ADMINISTRATOR, ActionType.CREATE, nullcontext()), + (SERVICE_ADMINISTRATOR, ActionType.QUERY, nullcontext()), + (SERVICE_ADMINISTRATOR, ActionType.DELETE, nullcontext()), + (OPERATOR, ActionType.CREATE, nullcontext()), + (OPERATOR, ActionType.QUERY, nullcontext()), + (OPERATOR, ActionType.DELETE, nullcontext()), + (NORMAL_USER, ActionType.CREATE, pytest.raises(HTTPException, match="403")), + (NORMAL_USER, ActionType.QUERY, nullcontext()), + (NORMAL_USER, ActionType.DELETE, pytest.raises(HTTPException, match="403")), + ( + "malicious_user", + ActionType.CREATE, + pytest.raises(HTTPException, match="403"), + ), + ("malicious_user", ActionType.QUERY, pytest.raises(HTTPException, match="403")), + ( + "malicious_user", + ActionType.DELETE, + pytest.raises(HTTPException, match="403"), + ), + ("any_user", None, pytest.raises(HTTPException, match="400")), + ], +) +async def 
@pytest.fixture
async def pilot_agents_db():
    """Yield a PilotAgentsDB on in-memory SQLite with its tables created.

    The engine context stays open while the test runs.
    """
    agents_db = PilotAgentsDB("sqlite+aiosqlite:///:memory:")
    async with agents_db.engine_context():
        async with agents_db.engine.begin() as conn:
            await conn.run_sync(agents_db.metadata.create_all)
        yield agents_db


@pytest.fixture
async def pilot_logs_db():
    """Yield a PilotLogsDB whose OpenSearch backend is replaced by SQLite."""
    # create a class that has sqlite backend replacing OpenSearch PilotLogsDB
    mock_cls = type("MockPilotLogsDB", (MockOSDBMixin, PilotLogsDB), {})

    db = mock_cls(
        connection_kwargs={"sqlalchemy_dsn": "sqlite+aiosqlite:///:memory:"}
    )
    async with db.client_context():
        await db.create_index_template()
        yield db


@patch("diracx.routers.pilots.logging.BaseSQLDB.available_implementations")
@patch("diracx.routers.pilots.logging.BaseSQLDB.available_urls")
async def test_logging(
    mock_url, mock_impl, pilot_logs_db: "PilotLogsDB", pilot_agents_db: PilotAgentsDB
):
    """Send three log lines for one pilot and read them back."""
    upper_limit = 6  # pilots are numbered 1 .. upper_limit - 1
    pilot_numbers = range(1, upper_limit)

    async with pilot_agents_db as db:
        # Add the pilot references
        refs = [f"ref_{i}" for i in pilot_numbers]
        stamps = [f"stamp_{i}" for i in pilot_numbers]

        await db.add_pilot_references(
            refs, "test_vo", grid_type="DIRAC", pilot_stamps=dict(zip(refs, stamps))
        )
        tables = await db.conn.run_sync(
            lambda sync_conn: inspect(sync_conn).get_table_names()
        )
        assert "PilotAgents" in tables

        # move submission time back in time
        now = datetime.now(tz=timezone.utc)
        for i in pilot_numbers:
            sub_time = now - timedelta(hours=2 * i - 1)
            stmt = (
                update(PilotAgents)
                .where(PilotAgents.pilot_stamp == f"stamp_{i}")
                .values(SubmissionTime=sub_time)
            )
            await db.conn.execute(stmt)

    # 3 message records for the first pilot.
    expected_records = [{"Message": f"Message_no_{i}"} for i in range(1, 4)]
    log_lines = [
        LogLine(line_no=line_no, line=record["Message"])
        for line_no, record in enumerate(expected_records, start=1)
    ]
    message = LogMessage(pilot_stamp="stamp_1", lines=log_lines, vo="gridpp")

    check_permissions_mock = AsyncMock()
    check_permissions_mock.return_value.vo = "gridpp"
    # TODO add user properties dict return_value above
    mock_url.return_value = {"PilotAgentsDB": "sqlite+aiosqlite:///:memory:"}
    # use the existing context (we have a DB already):
    pilot_agents_db.engine_context = nullcontext
    mock_impl.return_value = [lambda x: pilot_agents_db]

    # send logs for stamp_1, pilot id = 1
    pilot_id = await send_message(message, pilot_logs_db, check_permissions_mock)
    assert pilot_id == 1

    # get logs for pilot_id=1
    log_records = await get_logs(pilot_id, pilot_logs_db, check_permissions_mock)
    assert log_records == expected_records

    # delete logs for pilot_id = 1
    check_permissions_mock.return_value.properties = [PILOT]
    # TODO: exercise the delete endpoint once the mock OSDB implements delete.
implementation... + # res = await delete(pilot_id, DateRange(), pilot_logs_db, check_permissions_mock) diff --git a/diracx-testing/src/diracx/testing/mock_osdb.py b/diracx-testing/src/diracx/testing/mock_osdb.py index 6e181a79..f92f3968 100644 --- a/diracx-testing/src/diracx/testing/mock_osdb.py +++ b/diracx-testing/src/diracx/testing/mock_osdb.py @@ -96,6 +96,20 @@ async def upsert(self, doc_id, document) -> None: stmt = stmt.on_conflict_do_update(index_elements=["doc_id"], set_=values) await self._sql_db.conn.execute(stmt) + async def bulk_insert(self, index_name: str, docs: list[dict[str, Any]]) -> None: + async with self: + rows = [] + for item, doc in enumerate(docs): + values = {"doc_id": item + 1} + for key, value in doc.items(): + if key in self.fields: + values[key] = value + else: + values.setdefault("extra", {})[key] = value + rows.append(values) + stmt = sqlite_insert(self._table).values(rows) + await self._sql_db.conn.execute(stmt) + async def search( self, parameters: list[str] | None,