From 53f9559e0432faa5ae757135bae97e1922dac7d2 Mon Sep 17 00:00:00 2001 From: MoritzWeber Date: Wed, 13 Nov 2024 13:27:26 +0100 Subject: [PATCH] refactor: Consolidate session hook request arguments into one This makes it easier in the future to extend the requests and allows better type annotations. --- .../capellacollab/sessions/hooks/__init__.py | 24 +- .../sessions/hooks/authentication.py | 8 +- .../capellacollab/sessions/hooks/guacamole.py | 39 ++- backend/capellacollab/sessions/hooks/http.py | 23 +- .../capellacollab/sessions/hooks/interface.py | 227 ++++++++++-------- .../capellacollab/sessions/hooks/jupyter.py | 11 +- .../sessions/hooks/log_collector.py | 51 ++-- .../sessions/hooks/networking.py | 29 +-- .../sessions/hooks/persistent_workspace.py | 24 +- .../sessions/hooks/provisioning.py | 29 ++- .../sessions/hooks/pure_variants.py | 19 +- .../sessions/hooks/read_only_workspace.py | 8 +- .../sessions/hooks/session_preparation.py | 19 +- backend/capellacollab/sessions/hooks/t4c.py | 41 ++-- backend/capellacollab/sessions/routes.py | 56 +++-- backend/capellacollab/sessions/util.py | 12 +- backend/tests/sessions/hooks/conftest.py | 88 +++++++ .../sessions/hooks/test_guacamole_hook.py | 63 +---- .../tests/sessions/hooks/test_http_hook.py | 32 +-- .../tests/sessions/hooks/test_jupyter_hook.py | 34 +-- .../sessions/hooks/test_networking_hook.py | 17 +- .../hooks/test_persistent_workspace.py | 51 ++-- .../hooks/test_pre_authentiation_hook.py | 13 +- .../sessions/hooks/test_provisioning_hook.py | 142 +++++------ .../sessions/hooks/test_pure_variants.py | 22 +- .../hooks/test_session_preparation.py | 21 +- backend/tests/sessions/hooks/test_t4c_hook.py | 96 +++----- backend/tests/sessions/test_session_hooks.py | 45 +--- backend/tests/sessions/test_session_routes.py | 2 +- 29 files changed, 582 insertions(+), 664 deletions(-) create mode 100644 backend/tests/sessions/hooks/conftest.py diff --git a/backend/capellacollab/sessions/hooks/__init__.py b/backend/capellacollab/sessions/hooks/__init__.py index bfebfb9c1a..18d5719a92 100644 --- a/backend/capellacollab/sessions/hooks/__init__.py +++ b/backend/capellacollab/sessions/hooks/__init__.py @@ -25,17 +25,17 @@ "pure_variants": pure_variants.PureVariantsIntegration(), } -REGISTER_HOOKS_AUTO_USE: dict[str, interface.HookRegistration] = { - "persistent_workspace": persistent_workspace.PersistentWorkspaceHook(), - "guacamole": guacamole.GuacamoleIntegration(), - "http": http.HTTPIntegration(), - "read_only_hook": read_only_workspace.ReadOnlyWorkspaceHook(), - "provisioning": provisioning.ProvisionWorkspaceHook(), - "session_preparation": session_preparation.GitRepositoryCloningHook(), - "networking": networking.NetworkingIntegration(), - "authentication": authentication.PreAuthenticationHook(), - "log_collector": log_collector.LogCollectorIntegration(), -} +REGISTER_HOOKS_AUTO_USE: list[interface.HookRegistration] = [ + persistent_workspace.PersistentWorkspaceHook(), + guacamole.GuacamoleIntegration(), + http.HTTPIntegration(), + read_only_workspace.ReadOnlyWorkspaceHook(), + provisioning.ProvisionWorkspaceHook(), + session_preparation.GitRepositoryCloningHook(), + networking.NetworkingIntegration(), + authentication.PreAuthenticationHook(), + log_collector.LogCollectorIntegration(), +] def get_activated_integration_hooks( @@ -46,4 +46,4 @@ def get_activated_integration_hooks( hook for integration, hook in REGISTERED_HOOKS.items() if getattr(tool.integrations, integration, False) - ] + list(REGISTER_HOOKS_AUTO_USE.values()) + ] + REGISTER_HOOKS_AUTO_USE diff --git a/backend/capellacollab/sessions/hooks/authentication.py b/backend/capellacollab/sessions/hooks/authentication.py index 16b0aaf2f3..5a8d4a564c 100644 --- a/backend/capellacollab/sessions/hooks/authentication.py +++ b/backend/capellacollab/sessions/hooks/authentication.py @@ -13,18 +13,16 @@ class PreAuthenticationHook(interface.HookRegistration): - def session_connection_hook( # type: ignore[override] + def session_connection_hook( self, - db_session: sessions_models.DatabaseSession, - user: users_models.DatabaseUser, - **kwargs, + request: interface.SessionConnectionHookRequest, ) -> interface.SessionConnectionHookResult: """Issue pre-authentication tokens for sessions""" return interface.SessionConnectionHookResult( cookies={ "ccm_session_token": self._issue_session_token( - user, db_session + request.user, request.db_session ) } ) diff --git a/backend/capellacollab/sessions/hooks/guacamole.py b/backend/capellacollab/sessions/hooks/guacamole.py index dce8f001c6..4ea12c157c 100644 --- a/backend/capellacollab/sessions/hooks/guacamole.py +++ b/backend/capellacollab/sessions/hooks/guacamole.py @@ -12,9 +12,6 @@ from capellacollab.config import config from capellacollab.core import credentials -from capellacollab.sessions import models as sessions_models -from capellacollab.sessions.operators import k8s -from capellacollab.tools import models as tools_models from . import interface @@ -40,14 +37,11 @@ class GuacamoleIntegration(interface.HookRegistration): "https": None, } - def post_session_creation_hook( # type: ignore[override] + def post_session_creation_hook( self, - session: k8s.Session, - db_session: sessions_models.DatabaseSession, - connection_method: tools_models.ToolSessionConnectionMethod, - **kwargs, + request: interface.PostSessionCreationHookRequest, ) -> interface.PostSessionCreationHookResult: - if connection_method.type != "guacamole": + if request.connection_method.type != "guacamole": return interface.PostSessionCreationHookResult() guacamole_username = credentials.generate_password() @@ -60,9 +54,9 @@ def post_session_creation_hook( # type: ignore[override] guacamole_identifier = self._create_connection( guacamole_token, - db_session.environment["CAPELLACOLLAB_SESSION_TOKEN"], - session["host"], - session["port"], + request.db_session.environment["CAPELLACOLLAB_SESSION_TOKEN"], + request.session["host"], + request.session["port"], )["identifier"] self._assign_user_to_connection( @@ -79,16 +73,14 @@ def post_session_creation_hook( # type: ignore[override] config=guacamole_config, ) - def session_connection_hook( # type: ignore[override] + def session_connection_hook( self, - db_session: sessions_models.DatabaseSession, - connection_method: tools_models.ToolSessionConnectionMethod, - **kwargs, + request: interface.SessionConnectionHookRequest, ) -> interface.SessionConnectionHookResult: - if connection_method.type != "guacamole": + if request.connection_method.type != "guacamole": return interface.SessionConnectionHookResult() - session_config = db_session.config + session_config = request.db_session.config if not session_config or not session_config.get("guacamole_username"): return interface.SessionConnectionHookResult() @@ -102,16 +94,13 @@ def session_connection_hook( # type: ignore[override] redirect_url=config.extensions.guacamole.public_uri + "/#/", ) - def pre_session_termination_hook( # type: ignore[override] - self, - session: sessions_models.DatabaseSession, - connection_method: tools_models.ToolSessionConnectionMethod, - **kwargs, + def pre_session_termination_hook( + self, request: interface.PreSessionTerminationHookRequest ) -> interface.PreSessionTerminationHookResult: - if connection_method.type != "guacamole": + if request.connection_method.type != "guacamole": return interface.SessionConnectionHookResult() - session_config = session.config + session_config = request.session.config if session_config and session_config.get("guacamole_username"): guacamole_token = self._get_admin_token() diff --git a/backend/capellacollab/sessions/hooks/http.py b/backend/capellacollab/sessions/hooks/http.py index aae9953021..e97b36b15b 100644 --- a/backend/capellacollab/sessions/hooks/http.py +++ b/backend/capellacollab/sessions/hooks/http.py @@ -1,35 +1,28 @@ # SPDX-FileCopyrightText: Copyright DB InfraGO AG and contributors # SPDX-License-Identifier: Apache-2.0 -import logging - from capellacollab.core import models as core_models from capellacollab.tools import models as tools_models -from .. import models as sessions_models from .. import util as sessions_util from . import interface class HTTPIntegration(interface.HookRegistration): - def session_connection_hook( # type: ignore[override] - self, - db_session: sessions_models.DatabaseSession, - connection_method: tools_models.ToolSessionConnectionMethod, - logger: logging.LoggerAdapter, - **kwargs, + def session_connection_hook( + self, request: interface.SessionConnectionHookRequest ) -> interface.SessionConnectionHookResult: if not isinstance( - connection_method, tools_models.HTTPConnectionMethod + request.connection_method, tools_models.HTTPConnectionMethod ): return interface.SessionConnectionHookResult() try: - redirect_url = connection_method.redirect_url.format( - **db_session.environment + redirect_url = request.connection_method.redirect_url.format( + **request.db_session.environment ) except Exception: - logger.error( + request.logger.error( "Error while formatting the redirect URL", exc_info=True ) return interface.SessionConnectionHookResult( @@ -43,7 +36,9 @@ def session_connection_hook( # type: ignore[override] ) cookies, warnings = sessions_util.resolve_environment_variables( - logger, db_session.environment, connection_method.cookies + request.logger, + request.db_session.environment, + request.connection_method.cookies, ) return interface.SessionConnectionHookResult( diff --git a/backend/capellacollab/sessions/hooks/interface.py b/backend/capellacollab/sessions/hooks/interface.py index 743a3c2c8b..b06f59ca3b 100644 --- a/backend/capellacollab/sessions/hooks/interface.py +++ b/backend/capellacollab/sessions/hooks/interface.py @@ -2,6 +2,7 @@ # SPDX-License-Identifier: Apache-2.0 import abc +import dataclasses import logging import typing as t @@ -17,6 +18,43 @@ from .. import models as sessions_models +@dataclasses.dataclass() +class ConfigurationHookRequest: + """Request type of the configuration hook + + Attributes + ---------- + db : sqlalchemy.orm.Session + Database session. Can be used to access the database + operator : operators.KubernetesOperator + Operator, which is used to spawn the session + user : users_models.DatabaseUser + User who has requested the session + tool : tools_models.DatabaseTool + Tool of the requested session + tool_version : tools_models.DatabaseVersion + Tool version of the requested session + session_type : sessions_models.SessionType + Type of the session (persistent, read-only, etc.) + connection_method : tools_models.ToolSessionConnectionMethod + Requested connection method for the session + provisioning : list[sessions_models.SessionProvisioningRequest] + List of workspace provisioning requests + session_id: str + ID of the session to be created + """ + + db: orm.Session + operator: operators.KubernetesOperator + user: users_models.DatabaseUser + tool: tools_models.DatabaseTool + tool_version: tools_models.DatabaseVersion + session_type: sessions_models.SessionType + connection_method: tools_models.ToolSessionConnectionMethod + provisioning: list[sessions_models.SessionProvisioningRequest] + session_id: str + + class ConfigurationHookResult(t.TypedDict): """Return type of the configuration hook @@ -42,6 +80,37 @@ class ConfigurationHookResult(t.TypedDict): init_environment: t.NotRequired[t.Mapping] +@dataclasses.dataclass() +class PostSessionCreationHookRequest: + """Request type of the post session creation hook + + Attributes + ---------- + session_id : str + ID of the session + session : k8s.Session + Session object (contains connection information) + db_session : sessions_models.DatabaseSession + Collaboration Manager session in the database + operator : operators.KubernetesOperator + Operator, which is used to spawn the session + user : users_models.DatabaseUser + User who has requested the session + connection_method : tools_models.ToolSessionConnectionMethod + Requested connection method for the session + db : orm.Session + Database session. Can be used to access the database + """ + + session_id: str + session: k8s.Session + db_session: sessions_models.DatabaseSession + operator: operators.KubernetesOperator + user: users_models.DatabaseUser + connection_method: tools_models.ToolSessionConnectionMethod + db: orm.Session + + class PostSessionCreationHookResult(t.TypedDict): """Return type of the post session creation hook @@ -55,6 +124,31 @@ class PostSessionCreationHookResult(t.TypedDict): config: t.NotRequired[t.Mapping] +@dataclasses.dataclass() +class SessionConnectionHookRequest: + """Request type of the session connection hook + + Attributes + ---------- + db : sqlalchemy.orm.Session + Database session. Can be used to access the database + db_session : sessions_models.DatabaseSession + Collaboration Manager session in the database + connection_method : tools_models.ToolSessionConnectionMethod + Connection method of the session + logger : logging.LoggerAdapter + Logger for the specific request + user : users_models.DatabaseUser + User who is connecting to the session + """ + + db: orm.Session + db_session: sessions_models.DatabaseSession + connection_method: tools_models.ToolSessionConnectionMethod + logger: logging.LoggerAdapter + user: users_models.DatabaseUser + + class SessionConnectionHookResult(t.TypedDict): """Return type of the session connection hook @@ -80,6 +174,28 @@ class SessionConnectionHookResult(t.TypedDict): warnings: t.NotRequired[list[core_models.Message]] +@dataclasses.dataclass() +class PreSessionTerminationHookRequest: + """Request type of the pre session termination hook + + Attributes + ---------- + db : sqlalchemy.orm.Session + Database session. Can be used to access the database + operator : operators.KubernetesOperator + Operator, which is used to spawn the session + session : sessions_models.DatabaseSession + Session which is to be terminated + connection_method : tools_models.ToolSessionConnectionMethod + Connection method of the session + """ + + db: orm.Session + operator: operators.KubernetesOperator + session: sessions_models.DatabaseSession + connection_method: tools_models.ToolSessionConnectionMethod + + class PreSessionTerminationHookResult(t.TypedDict): """Return type of the pre session termination hook""" @@ -101,146 +217,47 @@ class HookRegistration(metaclass=abc.ABCMeta): # pylint: disable=unused-argument def configuration_hook( - self, - db: orm.Session, - operator: operators.KubernetesOperator, - user: users_models.DatabaseUser, - tool: tools_models.DatabaseTool, - tool_version: tools_models.DatabaseVersion, - session_type: sessions_models.SessionType, - connection_method: tools_models.ToolSessionConnectionMethod, - provisioning: list[sessions_models.SessionProvisioningRequest], - session_id: str, - **kwargs, + self, request: ConfigurationHookRequest ) -> ConfigurationHookResult: """Hook to determine session configuration This hook is executed before the creation of persistent sessions. - - Parameters - ---------- - db : sqlalchemy.orm.Session - Database session. Can be used to access the database - operator : operators.KubernetesOperator - Operator, which is used to spawn the session - user : users_models.DatabaseUser - User who has requested the session - tool : tools_models.DatabaseTool - Tool of the requested session - tool_version : tools_models.DatabaseVersion - Tool version of the requested session - session_type : sessions_models.SessionType - Type of the session (persistent, read-only, etc.) - connection_method : tools_models.ToolSessionConnectionMethod - Requested connection method for the session - provisioning : list[sessions_models.SessionProvisioningRequest] - List of workspace provisioning requests - session_id: str - ID of the session to be created - Returns - ------- - result : ConfigurationHookResult """ return ConfigurationHookResult() + # pylint: disable=unused-argument def post_session_creation_hook( self, - session_id: str, - session: k8s.Session, - db_session: sessions_models.DatabaseSession, - operator: operators.KubernetesOperator, - user: users_models.DatabaseUser, - connection_method: tools_models.ToolSessionConnectionMethod, - db: orm.Session, - **kwargs, + request: PostSessionCreationHookRequest, ) -> PostSessionCreationHookResult: """Hook executed after session creation This hook is executed after a persistent session was created by the operator. - - Parameters - ---------- - session_id : str - ID of the session - session : k8s.Session - Session object (contains connection information) - db_session : sessions_models.DatabaseSession - Collaboration Manager session in the database - operator : operators.KubernetesOperator - Operator, which is used to spawn the session - user : users_models.DatabaseUser - User who has requested the session - connection_method : tools_models.ToolSessionConnectionMethod - Requested connection method for the session - db : sqlalchemy.orm.Session - Database session. Can be used to access the database - - Returns - ------- - result : PostSessionCreationHookResult """ return PostSessionCreationHookResult() # pylint: disable=unused-argument def session_connection_hook( - self, - db: orm.Session, - db_session: sessions_models.DatabaseSession, - connection_method: tools_models.ToolSessionConnectionMethod, - logger: logging.LoggerAdapter, - **kwargs, + self, request: SessionConnectionHookRequest ) -> SessionConnectionHookResult: """Hook executed while connecting to a session The hook is executed each time the GET `/sessions/{session_id}/connection` endpoint is called. - - Parameters - ---------- - db : sqlalchemy.orm.Session - Database session. Can be used to access the database - db_session : sessions_models.DatabaseSession - Collaboration Manager session in the database - connection_method : tools_models.ToolSessionConnectionMethod - Connection method of the session - logger : logging.LoggerAdapter - Logger for the specific request - Returns - ------- - result : SessionConnectionHookResult """ return SessionConnectionHookResult() + # pylint: disable=unused-argument def pre_session_termination_hook( - self, - db: orm.Session, - operator: operators.KubernetesOperator, - session: sessions_models.DatabaseSession, - connection_method: tools_models.ToolSessionConnectionMethod, - **kwargs, + self, request: PreSessionTerminationHookRequest ) -> PreSessionTerminationHookResult: """Hook executed directly before session termination This hook is executed before a read-only or persistent session is terminated by the operator. - - Parameters - ---------- - db : sqlalchemy.orm.Session - Database session. Can be used to access the database - operator : operators.KubernetesOperator - Operator, which is used to spawn the session - session : sessions_models.DatabaseSession - Session which is to be terminated - connection_method : tools_models.ToolSessionConnectionMethod - Connection method of the session - - Returns - ------- - result : PreSessionTerminationHookResult """ return PreSessionTerminationHookResult() diff --git a/backend/capellacollab/sessions/hooks/jupyter.py b/backend/capellacollab/sessions/hooks/jupyter.py index 4e0fa7050a..5e208291da 100644 --- a/backend/capellacollab/sessions/hooks/jupyter.py +++ b/backend/capellacollab/sessions/hooks/jupyter.py @@ -14,7 +14,6 @@ from capellacollab.sessions import operators from capellacollab.sessions.operators import models as operators_models from capellacollab.tools import models as tools_models -from capellacollab.users import models as users_models from . import interface @@ -22,16 +21,12 @@ class JupyterIntegration(interface.HookRegistration): - def configuration_hook( # type: ignore[override] + def configuration_hook( self, - db: orm.Session, - user: users_models.DatabaseUser, - tool: tools_models.DatabaseTool, - operator: operators.KubernetesOperator, - **kwargs, + request: interface.ConfigurationHookRequest, ) -> interface.ConfigurationHookResult: volumes, warnings = self._get_project_share_volume_mounts( - db, user.name, tool, operator + request.db, request.user.name, request.tool, request.operator ) return interface.ConfigurationHookResult( volumes=volumes, warnings=warnings diff --git a/backend/capellacollab/sessions/hooks/log_collector.py b/backend/capellacollab/sessions/hooks/log_collector.py index ab1ffa6604..fea9bb802b 100644 --- a/backend/capellacollab/sessions/hooks/log_collector.py +++ b/backend/capellacollab/sessions/hooks/log_collector.py @@ -5,13 +5,10 @@ import pathlib import yaml -from sqlalchemy import orm from capellacollab.config import config from capellacollab.sessions import models as sessions_models -from capellacollab.sessions import operators from capellacollab.sessions.operators import models as operators_models -from capellacollab.users import models as users_models from capellacollab.users.workspaces import crud as users_workspaces_crud from . import interface @@ -22,38 +19,36 @@ class LogCollectorIntegration(interface.HookRegistration): _loki_enabled: bool = config.k8s.promtail.loki_enabled - def post_session_creation_hook( # type: ignore[override] + def post_session_creation_hook( self, - db_session: sessions_models.DatabaseSession, - operator: operators.KubernetesOperator, - user: users_models.DatabaseUser, - db: orm.Session, - **kwargs, + request: interface.PostSessionCreationHookRequest, ) -> interface.PostSessionCreationHookResult: if ( not self._loki_enabled - or db_session.type == sessions_models.SessionType.READONLY + or request.db_session.type == sessions_models.SessionType.READONLY ): return interface.PostSessionCreationHookResult() - workspaces = users_workspaces_crud.get_workspaces_for_user(db, user) + workspaces = users_workspaces_crud.get_workspaces_for_user( + request.db, request.user + ) if not workspaces: return interface.PostSessionCreationHookResult() - operator._create_configmap( - name=db_session.id, + request.operator._create_configmap( + name=request.db_session.id, data=self._promtail_configuration( - username=user.name, - session_type=db_session.type.value, - tool_name=db_session.tool.name, - version_name=db_session.version.name, + username=request.user.name, + session_type=request.db_session.type.value, + tool_name=request.db_session.tool.name, + version_name=request.db_session.version.name, ), ) labels: dict[str, str] = { "capellacollab/workload": "session-sidecar", - "capellacollab/session-id": db_session.id, - "capellacollab/owner-id": str(user.id), + "capellacollab/session-id": request.db_session.id, + "capellacollab/owner-id": str(request.user.id), } volumes = [ @@ -61,7 +56,7 @@ def post_session_creation_hook( # type: ignore[override] name="prom-config", read_only=True, container_path=pathlib.PurePosixPath("/etc/promtail"), - config_map_name=db_session.id, + config_map_name=request.db_session.id, optional=False, ), operators_models.PersistentVolume( @@ -72,9 +67,9 @@ def post_session_creation_hook( # type: ignore[override] ), ] - operator._create_sidecar_pod( + request.operator._create_sidecar_pod( image=f"{config.docker.external_registry}/grafana/promtail", - name=f"{db_session.id}-promtail", + name=f"{request.db_session.id}-promtail", labels=labels, args=[ "--config.file=/etc/promtail/promtail.yaml", @@ -85,20 +80,18 @@ def post_session_creation_hook( # type: ignore[override] return interface.PostSessionCreationHookResult() - def pre_session_termination_hook( # type: ignore[override] + def pre_session_termination_hook( self, - session: sessions_models.DatabaseSession, - operator: operators.KubernetesOperator, - **kwargs, + request: interface.PreSessionTerminationHookRequest, ) -> interface.PreSessionTerminationHookResult: if ( not self._loki_enabled - or session.type == sessions_models.SessionType.READONLY + or request.session.type == sessions_models.SessionType.READONLY ): return interface.PostSessionCreationHookResult() - operator._delete_config_map(name=session.id) - operator._delete_pod(name=f"{session.id}-promtail") + request.operator._delete_config_map(name=request.session.id) + request.operator._delete_pod(name=f"{request.session.id}-promtail") return interface.PreSessionTerminationHookResult() diff --git a/backend/capellacollab/sessions/hooks/networking.py b/backend/capellacollab/sessions/hooks/networking.py index 4b613331b2..7c7e693506 100644 --- a/backend/capellacollab/sessions/hooks/networking.py +++ b/backend/capellacollab/sessions/hooks/networking.py @@ -2,43 +2,32 @@ # SPDX-License-Identifier: Apache-2.0 -from capellacollab.sessions import operators -from capellacollab.users import models as users_models - -from .. import models as sessions_models from . import interface class NetworkingIntegration(interface.HookRegistration): """Allow sessions of the same user to talk to each other.""" - def post_session_creation_hook( # type: ignore - self, - session_id: str, - operator: operators.KubernetesOperator, - user: users_models.DatabaseUser, - **kwargs, + def post_session_creation_hook( + self, request: interface.PostSessionCreationHookRequest ) -> interface.PostSessionCreationHookResult: """Allow sessions of the user to talk to each other.""" - operator.create_network_policy_from_pod_to_label( - session_id, + request.operator.create_network_policy_from_pod_to_label( + request.session_id, match_labels_from={ - "capellacollab/session-id": session_id, + "capellacollab/session-id": request.session_id, "capellacollab/workload": "session", }, match_labels_to={ - "capellacollab/owner-id": str(user.id), + "capellacollab/owner-id": str(request.user.id), "capellacollab/workload": "session", }, ) return interface.PostSessionCreationHookResult() - def pre_session_termination_hook( # type: ignore - self, - operator: operators.KubernetesOperator, - session: sessions_models.DatabaseSession, - **kwargs, + def pre_session_termination_hook( + self, request: interface.PreSessionTerminationHookRequest ): - operator.delete_network_policy(session.id) + request.operator.delete_network_policy(request.session.id) diff --git a/backend/capellacollab/sessions/hooks/persistent_workspace.py b/backend/capellacollab/sessions/hooks/persistent_workspace.py index 5086f57a79..98e7ab37ed 100644 --- a/backend/capellacollab/sessions/hooks/persistent_workspace.py +++ b/backend/capellacollab/sessions/hooks/persistent_workspace.py @@ -2,7 +2,6 @@ # SPDX-License-Identifier: Apache-2.0 import pathlib -import typing as t import uuid from sqlalchemy import orm @@ -19,32 +18,25 @@ from . import interface -class PersistentWorkspacEnvironment(t.TypedDict): - pass - - class PersistentWorkspaceHook(interface.HookRegistration): """Takes care of the persistent workspace of a user. Is responsible for mounting the persistent workspace into persistent sessions. """ - def configuration_hook( # type: ignore + def configuration_hook( self, - db: orm.Session, - operator: operators.KubernetesOperator, - user: users_models.DatabaseUser, - session_type: sessions_models.SessionType, - tool: tools_models.DatabaseTool, - **kwargs, + request: interface.ConfigurationHookRequest, ) -> interface.ConfigurationHookResult: - if session_type == sessions_models.SessionType.READONLY: + if request.session_type == sessions_models.SessionType.READONLY: # Skip read-only sessions, no persistent workspace needed. return interface.ConfigurationHookResult() - self._check_that_persistent_workspace_is_allowed(tool) + self._check_that_persistent_workspace_is_allowed(request.tool) - volume_name = self._create_persistent_workspace(db, operator, user) + volume_name = self._create_persistent_workspace( + request.db, request.operator, request.user + ) volume = operators_models.PersistentVolume( name="workspace", read_only=False, @@ -53,7 +45,7 @@ def configuration_hook( # type: ignore ) return interface.ConfigurationHookResult( - volumes=[volume], + volumes=[volume], init_volumes=[volume] ) def _check_that_persistent_workspace_is_allowed( diff --git a/backend/capellacollab/sessions/hooks/provisioning.py b/backend/capellacollab/sessions/hooks/provisioning.py index 4d1b92c902..9174c52f9f 100644 --- a/backend/capellacollab/sessions/hooks/provisioning.py +++ b/backend/capellacollab/sessions/hooks/provisioning.py @@ -40,26 +40,29 @@ class ProvisionWorkspaceHook(interface.HookRegistration): """Takes care of the provisioning of user workspaces.""" @classmethod - def configuration_hook( # type: ignore - cls, - db: orm.Session, - tool: tools_models.DatabaseTool, - tool_version: tools_models.DatabaseVersion, - user: users_models.DatabaseUser, - provisioning: list[sessions_models.SessionProvisioningRequest], - **kwargs, + def configuration_hook( + cls, request: interface.ConfigurationHookRequest ) -> interface.ConfigurationHookResult: - max_number_of_models = tool.config.provisioning.max_number_of_models - if max_number_of_models and len(provisioning) > max_number_of_models: + max_number_of_models = ( + request.tool.config.provisioning.max_number_of_models + ) + if ( + max_number_of_models + and len(request.provisioning) > max_number_of_models + ): raise sessions_exceptions.TooManyModelsRequestedToProvisionError( max_number_of_models ) - resolved_entries = cls._resolve_provisioning_request(db, provisioning) + resolved_entries = cls._resolve_provisioning_request( + request.db, request.provisioning + ) cls._verify_matching_tool_version_and_model( - db, tool_version, resolved_entries + request.db, request.tool_version, resolved_entries + ) + cls._verify_model_permissions( + request.db, request.user, resolved_entries ) - cls._verify_model_permissions(db, user, resolved_entries) init_environment = { "CAPELLACOLLAB_PROVISIONING": cls._get_git_repos_json( diff --git a/backend/capellacollab/sessions/hooks/pure_variants.py b/backend/capellacollab/sessions/hooks/pure_variants.py index 75d8b8c19e..f0016ee620 100644 --- a/backend/capellacollab/sessions/hooks/pure_variants.py +++ b/backend/capellacollab/sessions/hooks/pure_variants.py @@ -5,8 +5,6 @@ import pathlib import typing as t -from sqlalchemy import orm - from capellacollab.core import models as core_models from capellacollab.projects.toolmodels import models as toolmodels_models from capellacollab.sessions import models as sessions_models @@ -26,20 +24,17 @@ class PureVariantsConfigEnvironment(t.TypedDict): class PureVariantsIntegration(interface.HookRegistration): - def configuration_hook( # type: ignore + def configuration_hook( self, - db: orm.Session, - user: users_models.DatabaseUser, - session_type: sessions_models.SessionType, - **kwargs, + request: interface.ConfigurationHookRequest, ) -> interface.ConfigurationHookResult: - if session_type == sessions_models.SessionType.READONLY: + if request.session_type == sessions_models.SessionType.READONLY: # Skip read-only sessions, no pure::variants integration supported. return interface.ConfigurationHookResult() if ( - not self._user_has_project_with_pure_variants_model(user) - and user.role == users_models.Role.USER + not self._user_has_project_with_pure_variants_model(request.user) + and request.user.role == users_models.Role.USER ): warnings = [ core_models.Message( @@ -57,7 +52,9 @@ def configuration_hook( # type: ignore warnings=warnings, ) - pv_license = purevariants_crud.get_pure_variants_configuration(db) + pv_license = purevariants_crud.get_pure_variants_configuration( + request.db + ) if not pv_license or pv_license.license_server_url is None: warnings = [ core_models.Message( diff --git a/backend/capellacollab/sessions/hooks/read_only_workspace.py b/backend/capellacollab/sessions/hooks/read_only_workspace.py index 10109a4812..cbe1b45119 100644 --- a/backend/capellacollab/sessions/hooks/read_only_workspace.py +++ b/backend/capellacollab/sessions/hooks/read_only_workspace.py @@ -12,12 +12,10 @@ class ReadOnlyWorkspaceHook(interface.HookRegistration): """Mounts an empty workspace to the container for read-only sessions.""" - def configuration_hook( # type: ignore - self, - session_type: sessions_models.SessionType, - **kwargs, + def configuration_hook( + self, request: interface.ConfigurationHookRequest ) -> interface.ConfigurationHookResult: - if session_type != sessions_models.SessionType.READONLY: + if request.session_type != sessions_models.SessionType.READONLY: # Configuration for persistent workspace sessions happens in the PersistentWorkspaceHook. return interface.ConfigurationHookResult() diff --git a/backend/capellacollab/sessions/hooks/session_preparation.py b/backend/capellacollab/sessions/hooks/session_preparation.py index 4e27ec7f4f..e81c7f53a7 100644 --- a/backend/capellacollab/sessions/hooks/session_preparation.py +++ b/backend/capellacollab/sessions/hooks/session_preparation.py @@ -2,39 +2,30 @@ # SPDX-License-Identifier: Apache-2.0 import pathlib -import typing as t from capellacollab.sessions import models as sessions_models from capellacollab.sessions.operators import models as operators_models -from capellacollab.tools import models as tools_models from . import interface -class PersistentWorkspacEnvironment(t.TypedDict): - pass - - class GitRepositoryCloningHook(interface.HookRegistration): """Creates a volume that is shared between the actual container and the session preparation. The volume is used to clone Git repositories as preparation for the session. """ - def configuration_hook( # type: ignore + def configuration_hook( self, - session_type: sessions_models.SessionType, - session_id: str, - tool: tools_models.DatabaseTool, - **kwargs, + request: interface.ConfigurationHookRequest, ) -> interface.ConfigurationHookResult: - if session_type != sessions_models.SessionType.READONLY: + if request.session_type != sessions_models.SessionType.READONLY: return interface.ConfigurationHookResult() shared_model_volume = operators_models.EmptyVolume( - name=f"{session_id}-models", + name=f"{request.session_id}-models", container_path=pathlib.PurePosixPath( - tool.config.provisioning.directory + request.tool.config.provisioning.directory ), read_only=False, ) diff --git a/backend/capellacollab/sessions/hooks/t4c.py b/backend/capellacollab/sessions/hooks/t4c.py index e6ffd4a884..92f03ba81e 100644 --- a/backend/capellacollab/sessions/hooks/t4c.py +++ b/backend/capellacollab/sessions/hooks/t4c.py @@ -17,7 +17,6 @@ from capellacollab.settings.modelsources.t4c.instance.repositories import ( interface as repo_interface, ) -from capellacollab.tools import models as tools_models from capellacollab.users import models as users_models from .. import models as sessions_models @@ -34,22 +33,19 @@ class T4CConfigEnvironment(t.TypedDict): class T4CIntegration(interface.HookRegistration): - def configuration_hook( # type: ignore - self, - db: orm.Session, - user: users_models.DatabaseUser, - tool_version: tools_models.DatabaseVersion, - session_type: sessions_models.SessionType, - **kwargs, + def configuration_hook( + self, request: interface.ConfigurationHookRequest ) -> interface.ConfigurationHookResult: - if session_type != sessions_models.SessionType.PERSISTENT: + user = request.user + + if request.session_type != sessions_models.SessionType.PERSISTENT: # Skip non-persistent sessions, no T4C integration needed. return interface.ConfigurationHookResult() warnings: list[core_models.Message] = [] t4c_repositories = repo_crud.get_user_t4c_repositories( - db, tool_version, user + request.db, request.tool_version, user ) t4c_json = json.dumps( @@ -91,7 +87,7 @@ def configuration_hook( # type: ignore password=environment["T4C_PASSWORD"], is_admin=auth_injectables.RoleVerification( required_role=users_models.Role.ADMIN, verify=False - )(user.name, db), + )(user.name, request.db), ) except requests.RequestException: warnings.append( @@ -116,30 +112,25 @@ def configuration_hook( # type: ignore environment=environment, warnings=warnings ) - def pre_session_termination_hook( # type: ignore - self, - db: orm.Session, - session: sessions_models.DatabaseSession, - **kwargs, + def pre_session_termination_hook( + self, request: interface.PreSessionTerminationHookRequest ): - if session.type == sessions_models.SessionType.PERSISTENT: - self._revoke_session_tokens(db, session) + if request.session.type == sessions_models.SessionType.PERSISTENT: + self._revoke_session_tokens(request.db, request.session) - def session_connection_hook( # type: ignore[override] + def session_connection_hook( self, - db_session: sessions_models.DatabaseSession, - user: users_models.DatabaseUser, - **kwargs, + request: interface.SessionConnectionHookRequest, ) -> interface.SessionConnectionHookResult: - if db_session.type != sessions_models.SessionType.PERSISTENT: + if request.db_session.type != sessions_models.SessionType.PERSISTENT: return interface.SessionConnectionHookResult() - if db_session.owner != user: + if request.db_session.owner != request.user: # The session is shared, don't provide the T4C token. return interface.SessionConnectionHookResult() return interface.SessionConnectionHookResult( - t4c_token=db_session.environment.get("T4C_PASSWORD") + t4c_token=request.db_session.environment.get("T4C_PASSWORD") ) def _revoke_session_tokens( diff --git a/backend/capellacollab/sessions/routes.py b/backend/capellacollab/sessions/routes.py index fb4c0546ed..04e810de7b 100644 --- a/backend/capellacollab/sessions/routes.py +++ b/backend/capellacollab/sessions/routes.py @@ -19,6 +19,7 @@ from capellacollab.core.authentication import injectables as auth_injectables from capellacollab.sessions import hooks from capellacollab.sessions.files import routes as files_routes +from capellacollab.sessions.hooks import interface as hooks_interface from capellacollab.tools import exceptions as tools_exceptions from capellacollab.tools import injectables as tools_injectables from capellacollab.tools import models as tools_models @@ -125,19 +126,21 @@ def request_session( init_volumes: list[operators_models.Volume] = [] init_environment: dict[str, str] = {} + hook_request = hooks_interface.ConfigurationHookRequest( + db=db, + user=user, + tool_version=version, + tool=tool, + operator=operator, + session_type=body.session_type, + connection_method=connection_method, + provisioning=body.provisioning, + session_id=session_id, + ) + for hook in hooks.get_activated_integration_hooks(tool): - hook_result = hook.configuration_hook( - db=db, - user=user, - tool_version=version, - tool=tool, - username=user.name, - operator=operator, - session_type=body.session_type, - connection_method=connection_method, - provisioning=body.provisioning, - session_id=session_id, - ) + hook_result = hook.configuration_hook(hook_request) + environment |= hook_result.get("environment", {}) init_environment |= hook_result.get("init_environment", {}) volumes += hook_result.get("volumes", []) @@ -224,15 +227,18 @@ def request_session( ) hook_config: dict[str, str] = {} + for hook in hooks.get_activated_integration_hooks(tool): result = hook.post_session_creation_hook( - session_id=session_id, - operator=operator, - user=user, - session=session, - db_session=db_session, - connection_method=connection_method, - db=db, + hooks_interface.PostSessionCreationHookRequest( + session_id=session_id, + operator=operator, + user=user, + session=session, + db_session=db_session, + connection_method=connection_method, + db=db, + ), ) hook_config |= result.get("config", {}) @@ -370,11 +376,13 @@ def get_session_connection_information( for hook in hooks.get_activated_integration_hooks(session.tool): hook_result = hook.session_connection_hook( - db=db, - user=user, - db_session=session, - connection_method=connection_method, - logger=logger, + hooks_interface.SessionConnectionHookRequest( + db=db, + db_session=session, + connection_method=connection_method, + logger=logger, + user=user, + ) ) local_storage |= hook_result.get("local_storage", {}) diff --git a/backend/capellacollab/sessions/util.py b/backend/capellacollab/sessions/util.py index 60c8a69855..a4102107f0 100644 --- a/backend/capellacollab/sessions/util.py +++ b/backend/capellacollab/sessions/util.py @@ -13,6 +13,7 @@ from capellacollab.core import credentials from capellacollab.core import models as core_models from capellacollab.sessions import hooks +from capellacollab.sessions.hooks import interface as hooks_interface from capellacollab.sessions.operators import k8s from capellacollab.tools import models as tools_models from capellacollab.users import models as users_models @@ -33,11 +34,12 @@ def terminate_session( ) for hook in hooks.get_activated_integration_hooks(session.tool): hook.pre_session_termination_hook( - db=db, - session=session, - operator=operator, - user=session.owner, - connection_method=connection_method, + hooks_interface.PreSessionTerminationHookRequest( + db=db, + session=session, + operator=operator, + connection_method=connection_method, + ) ) crud.delete_session(db, session) diff --git a/backend/tests/sessions/hooks/conftest.py b/backend/tests/sessions/hooks/conftest.py new file mode 100644 index 0000000000..2a8f5d2d55 --- /dev/null +++ b/backend/tests/sessions/hooks/conftest.py @@ -0,0 +1,88 @@ +# SPDX-FileCopyrightText: Copyright DB InfraGO AG and contributors +# SPDX-License-Identifier: Apache-2.0 + +import datetime +import logging + +import pytest +from sqlalchemy import orm + +from capellacollab.sessions import models as sessions_models +from capellacollab.sessions.hooks import interface as hooks_interface +from capellacollab.sessions.operators import k8s as k8s_operator +from capellacollab.tools import models as tools_models +from capellacollab.users import models as users_models + + +@pytest.fixture(name="configuration_hook_request") +def fixture_configuration_hook_request( + db: orm.Session, + user: users_models.DatabaseUser, + capella_tool: tools_models.DatabaseTool, + capella_tool_version: tools_models.DatabaseVersion, +) -> hooks_interface.ConfigurationHookRequest: + return hooks_interface.ConfigurationHookRequest( + db=db, + operator=k8s_operator.KubernetesOperator(), + user=user, + tool=capella_tool, + tool_version=capella_tool_version, + session_type=sessions_models.SessionType.PERSISTENT, + connection_method=tools_models.GuacamoleConnectionMethod(), + provisioning=[], + session_id="nxylxqbmfqwvswlqlcbsirvrt", + ) + + +@pytest.fixture(name="post_session_creation_hook_request") +def fixture_post_session_creation_hook_request( + session: sessions_models.DatabaseSession, + user: users_models.DatabaseUser, + db: orm.Session, +) -> hooks_interface.PostSessionCreationHookRequest: + return hooks_interface.PostSessionCreationHookRequest( + session_id="test", + db_session=session, + session={ + "id": "test", + "port": 8080, + "created_at": datetime.datetime.fromisoformat( + "2021-01-01T00:00:00" + ), + "host": "test", + }, + user=user, + connection_method=tools_models.GuacamoleConnectionMethod(), + operator=k8s_operator.KubernetesOperator(), + db=db, + ) + + +@pytest.fixture(name="session_connection_hook_request") +def fixture_session_connection_hook_request( + db: orm.Session, + user: users_models.DatabaseUser, + session: sessions_models.DatabaseSession, + logger: logging.LoggerAdapter, +) -> hooks_interface.SessionConnectionHookRequest: + + return hooks_interface.SessionConnectionHookRequest( + db=db, + db_session=session, + connection_method=tools_models.GuacamoleConnectionMethod(), + logger=logger, + user=user, + ) + + +@pytest.fixture(name="pre_session_termination_hook_request") +def fixture_pre_session_termination_hook_request( + db: orm.Session, + session: sessions_models.DatabaseSession, +) -> hooks_interface.PreSessionTerminationHookRequest: + return hooks_interface.PreSessionTerminationHookRequest( + db=db, + connection_method=tools_models.GuacamoleConnectionMethod(), + operator=k8s_operator.KubernetesOperator(), + session=session, + ) diff --git a/backend/tests/sessions/hooks/test_guacamole_hook.py b/backend/tests/sessions/hooks/test_guacamole_hook.py index 3a6013cfee..c79dece563 100644 --- a/backend/tests/sessions/hooks/test_guacamole_hook.py +++ b/backend/tests/sessions/hooks/test_guacamole_hook.py @@ -127,23 +127,12 @@ def match_user_creation_body( "guacamole_create_token", "guacamole_delete_user", "guacamole_apis" ) def test_guacamole_configuration_hook( - session: sessions_models.DatabaseSession, + post_session_creation_hook_request: session_hooks_interface.PostSessionCreationHookRequest, ): """Test that the Guacamole hook creates a user and a connection""" response = guacamole.GuacamoleIntegration().post_session_creation_hook( - session_id="test", - db_session=session, - session={ - "id": "test", - "port": "8080", - "created_at": "2021-01-01T00:00:00", - "host": "test", - }, - user=users_models.DatabaseUser( - name="test", idp_identifier="test", role=users_models.Role.USER - ), - connection_method=tools_models.GuacamoleConnectionMethod(), + post_session_creation_hook_request ) assert response["config"]["guacamole_username"] @@ -157,7 +146,7 @@ def test_guacamole_configuration_hook( "guacamole_create_token", "guacamole_delete_user", "guacamole_apis" ) def test_fail_if_guacamole_unreachable( - session: sessions_models.DatabaseSession, + post_session_creation_hook_request: session_hooks_interface.PostSessionCreationHookRequest, ): """If Guacamole is unreachable, the session hook will abort the session creation""" @@ -171,18 +160,7 @@ def test_fail_if_guacamole_unreachable( with pytest.raises(guacamole.GuacamoleError): guacamole.GuacamoleIntegration().post_session_creation_hook( - session_id="test", - db_session=session, - session={ - "id": "test", - "port": "8080", - "created_at": "2021-01-01T00:00:00", - "host": "test", - }, - user=users_models.DatabaseUser( - name="test", idp_identifier="test", role=users_models.Role.USER - ), - connection_method=tools_models.GuacamoleConnectionMethod(), + post_session_creation_hook_request ) @@ -191,26 +169,18 @@ def test_fail_if_guacamole_unreachable( "guacamole_create_token", "guacamole_delete_user", "guacamole_apis" ) def test_guacamole_hook_not_executed_for_http_method( - session: sessions_models.DatabaseSession, + post_session_creation_hook_request: session_hooks_interface.PostSessionCreationHookRequest, ): """Skip if connection method is not Guacamole If the connection method is not Guacamole, the hook should skip the preparation. """ + post_session_creation_hook_request.connection_method = ( + tools_models.HTTPConnectionMethod() + ) response = guacamole.GuacamoleIntegration().post_session_creation_hook( - session_id="test", - session={ - "id": "test", - "port": "8080", - "created_at": "2021-01-01T00:00:00", - "host": "test", - }, - db_session=session, - user=users_models.DatabaseUser( - name="test", idp_identifier="test", role=users_models.Role.USER - ), - connection_method=tools_models.HTTPConnectionMethod(), + post_session_creation_hook_request ) assert session_hooks_interface.PostSessionCreationHookResult() == response @@ -222,23 +192,12 @@ def test_guacamole_hook_not_executed_for_http_method( "guacamole_create_token", "guacamole_delete_user", "guacamole_apis" ) def test_skip_guacamole_user_deletion_on_404( - session: sessions_models.DatabaseSession, + post_session_creation_hook_request: session_hooks_interface.PostSessionCreationHookRequest, ): """If the user does not exist, the hook should not fail""" response = guacamole.GuacamoleIntegration().post_session_creation_hook( - session_id="test", - db_session=session, - session={ - "id": "test", - "port": "8080", - "created_at": "2021-01-01T00:00:00", - "host": "test", - }, - user=users_models.DatabaseUser( - name="test", idp_identifier="test", role=users_models.Role.USER - ), - connection_method=tools_models.GuacamoleConnectionMethod(), + post_session_creation_hook_request ) assert response["config"] diff --git a/backend/tests/sessions/hooks/test_http_hook.py b/backend/tests/sessions/hooks/test_http_hook.py index 79acf10ed3..1fc65cc55a 100644 --- a/backend/tests/sessions/hooks/test_http_hook.py +++ b/backend/tests/sessions/hooks/test_http_hook.py @@ -1,19 +1,15 @@ # SPDX-FileCopyrightText: Copyright DB InfraGO AG and contributors # SPDX-License-Identifier: Apache-2.0 -import logging - from capellacollab.sessions import models as sessions_models from capellacollab.sessions.hooks import http from capellacollab.sessions.hooks import interface as sessions_hooks_interface from capellacollab.tools import models as tools_models -from capellacollab.users import models as users_models def test_http_hook( session: sessions_models.DatabaseSession, - user: users_models.DatabaseUser, - logger: logging.LoggerAdapter, + session_connection_hook_request: sessions_hooks_interface.SessionConnectionHookRequest, ): session.environment = { "TEST": "test", @@ -23,11 +19,10 @@ def test_http_hook( redirect_url="http://localhost:8000/{TEST}", cookies={"test": "{TEST}"}, ) + session_connection_hook_request.connection_method = connection_method + session_connection_hook_request.db_session = session result = http.HTTPIntegration().session_connection_hook( - db_session=session, - user=user, - connection_method=connection_method, - logger=logger, + session_connection_hook_request ) assert result["cookies"]["test"] == "test" @@ -36,33 +31,26 @@ def test_http_hook( def test_skip_http_hook_if_guacamole( - session: sessions_models.DatabaseSession, - user: users_models.DatabaseUser, - logger: logging.LoggerAdapter, + session_connection_hook_request: sessions_hooks_interface.SessionConnectionHookRequest, ): result = http.HTTPIntegration().session_connection_hook( - db_session=session, - connection_method=tools_models.GuacamoleConnectionMethod(), - user=user, - logger=logger, + session_connection_hook_request ) assert result == sessions_hooks_interface.SessionConnectionHookResult() def test_fail_derive_redirect_url( session: sessions_models.DatabaseSession, - user: users_models.DatabaseUser, - logger: logging.LoggerAdapter, + session_connection_hook_request: sessions_hooks_interface.SessionConnectionHookRequest, ): session.environment = {"TEST": "test"} connection_method = tools_models.HTTPConnectionMethod( redirect_url="http://localhost:8000/{TEST2}" ) + session_connection_hook_request.connection_method = connection_method + session_connection_hook_request.db_session = session result = http.HTTPIntegration().session_connection_hook( - db_session=session, - connection_method=connection_method, - user=user, - logger=logger, + session_connection_hook_request ) assert len(result["warnings"]) == 1 diff --git a/backend/tests/sessions/hooks/test_jupyter_hook.py b/backend/tests/sessions/hooks/test_jupyter_hook.py index a4333561bb..8c9a36817b 100644 --- a/backend/tests/sessions/hooks/test_jupyter_hook.py +++ b/backend/tests/sessions/hooks/test_jupyter_hook.py @@ -3,34 +3,40 @@ import pytest -from sqlalchemy import orm import capellacollab.projects.toolmodels.models as toolmodels_models +from capellacollab.sessions.hooks import interface as hooks_interface from capellacollab.sessions.hooks import jupyter as jupyter_hook +from capellacollab.sessions.operators import models as operators_models from capellacollab.tools import models as tools_models -from capellacollab.users import models as users_models @pytest.mark.usefixtures("project_user") def test_jupyter_successful_volume_mount( jupyter_model: toolmodels_models.DatabaseToolModel, jupyter_tool: tools_models.DatabaseTool, - user: users_models.DatabaseUser, - db: orm.Session, + configuration_hook_request: hooks_interface.ConfigurationHookRequest, ): + class MockOperator: # pylint: disable=unused-argument def persistent_volume_exists(self, name: str) -> bool: return True + configuration_hook_request.operator = MockOperator() # type: ignore + configuration_hook_request.tool = jupyter_tool + result = jupyter_hook.JupyterIntegration().configuration_hook( - db=db, user=user, tool=jupyter_tool, operator=MockOperator() + configuration_hook_request ) assert not result["warnings"] assert len(result["volumes"]) == 1 + + volume = result["volumes"][0] + assert isinstance(volume, operators_models.PersistentVolume) assert ( - result["volumes"][0].volume_name + volume.volume_name == "shared-workspace-" + jupyter_model.configuration["workspace"] ) @@ -38,16 +44,18 @@ def persistent_volume_exists(self, name: str) -> bool: @pytest.mark.usefixtures("project_user", "jupyter_model") def test_jupyter_volume_mount_not_found( jupyter_tool: tools_models.DatabaseTool, - user: users_models.DatabaseUser, - db: orm.Session, + configuration_hook_request: hooks_interface.ConfigurationHookRequest, ): class MockOperator: # pylint: disable=unused-argument def persistent_volume_exists(self, name: str) -> bool: return False + configuration_hook_request.operator = MockOperator() # type: ignore + configuration_hook_request.tool = jupyter_tool + result = jupyter_hook.JupyterIntegration().configuration_hook( - db=db, user=user, tool=jupyter_tool, operator=MockOperator() + configuration_hook_request ) assert not result["volumes"] @@ -60,14 +68,12 @@ def persistent_volume_exists(self, name: str) -> bool: @pytest.mark.usefixtures("jupyter_model") def test_jupyter_volume_mount_without_project_access( jupyter_tool: tools_models.DatabaseTool, - user: users_models.DatabaseUser, - db: orm.Session, + configuration_hook_request: hooks_interface.ConfigurationHookRequest, ): - class MockOperator: - pass + configuration_hook_request.tool = jupyter_tool result = jupyter_hook.JupyterIntegration().configuration_hook( - db=db, user=user, tool=jupyter_tool, operator=MockOperator() + configuration_hook_request ) assert not result["volumes"] diff --git a/backend/tests/sessions/hooks/test_networking_hook.py b/backend/tests/sessions/hooks/test_networking_hook.py index 93934c255c..8219e7ad04 100644 --- a/backend/tests/sessions/hooks/test_networking_hook.py +++ b/backend/tests/sessions/hooks/test_networking_hook.py @@ -5,14 +5,13 @@ import kubernetes.client import pytest -from capellacollab.sessions import models as sessions_models -from capellacollab.sessions import operators +from capellacollab.sessions.hooks import interface as session_hooks_interface from capellacollab.sessions.hooks import networking as networking_hook -from capellacollab.users import models as users_models def test_network_policy_created( - user: users_models.DatabaseUser, monkeypatch: pytest.MonkeyPatch + monkeypatch: pytest.MonkeyPatch, + post_session_creation_hook_request: session_hooks_interface.PostSessionCreationHookRequest, ): network_policy_counter = 0 @@ -32,16 +31,15 @@ def mock_create_namespaced_network_policy( ) networking_hook.NetworkingIntegration().post_session_creation_hook( - session_id="test", - operator=operators.KubernetesOperator(), - user=user, + post_session_creation_hook_request ) assert network_policy_counter == 1 def test_network_policy_deleted( - session: sessions_models.DatabaseSession, monkeypatch: pytest.MonkeyPatch + monkeypatch: pytest.MonkeyPatch, + pre_session_termination_hook_request: session_hooks_interface.PreSessionTerminationHookRequest, ): network_policy_del_counter = 0 @@ -61,8 +59,7 @@ def mock_delete_namespaced_network_policy( ) networking_hook.NetworkingIntegration().pre_session_termination_hook( - operator=operators.KubernetesOperator(), - session=session, + pre_session_termination_hook_request ) assert network_policy_del_counter == 1 diff --git a/backend/tests/sessions/hooks/test_persistent_workspace.py b/backend/tests/sessions/hooks/test_persistent_workspace.py index 56c6ca303f..597039ce27 100644 --- a/backend/tests/sessions/hooks/test_persistent_workspace.py +++ b/backend/tests/sessions/hooks/test_persistent_workspace.py @@ -11,41 +11,34 @@ from capellacollab.sessions import operators from capellacollab.sessions.hooks import interface as hooks_interface from capellacollab.sessions.hooks import persistent_workspace -from capellacollab.tools import models as tools_models from capellacollab.users import models as users_models from capellacollab.users.workspaces import crud as users_workspaces_crud from capellacollab.users.workspaces import models as users_workspaces_models def test_persistent_workspace_mounting_not_allowed( - db: orm.Session, - tool: tools_models.DatabaseTool, - test_user: users_models.DatabaseUser, + configuration_hook_request: hooks_interface.ConfigurationHookRequest, ): - tool.config.persistent_workspaces.mounting_enabled = False + configuration_hook_request.tool.config.persistent_workspaces.mounting_enabled = ( + False + ) with pytest.raises(sessions_exceptions.WorkspaceMountingNotAllowedError): persistent_workspace.PersistentWorkspaceHook().configuration_hook( - db=db, - operator=operators.KubernetesOperator(), - user=test_user, - session_type=sessions_models.SessionType.PERSISTENT, - tool=tool, + configuration_hook_request ) def persistent_workspace_mounting_readonly_session( - db: orm.Session, - tool: tools_models.DatabaseTool, - test_user: users_models.DatabaseUser, + configuration_hook_request: hooks_interface.ConfigurationHookRequest, ): + configuration_hook_request.session_type = ( + sessions_models.SessionType.READONLY + ) + response = ( persistent_workspace.PersistentWorkspaceHook().configuration_hook( - db=db, - operator=operators.KubernetesOperator(), - user=test_user, - session_type=sessions_models.SessionType.READONLY, - tool=tool, + configuration_hook_request ) ) @@ -54,9 +47,9 @@ def persistent_workspace_mounting_readonly_session( def test_workspace_is_created( db: orm.Session, - tool: tools_models.DatabaseTool, test_user: users_models.DatabaseUser, monkeypatch: pytest.MonkeyPatch, + configuration_hook_request: hooks_interface.ConfigurationHookRequest, ): created_volumes = 0 volume_name = None @@ -80,12 +73,11 @@ def mock_create_namespaced_persistent_volume_claim( assert ( len(users_workspaces_crud.get_workspaces_for_user(db, test_user)) == 0 ) + + configuration_hook_request.operator = operators.KubernetesOperator() + configuration_hook_request.user = test_user persistent_workspace.PersistentWorkspaceHook().configuration_hook( - db=db, - operator=operators.KubernetesOperator(), - user=test_user, - session_type=sessions_models.SessionType.PERSISTENT, - tool=tool, + configuration_hook_request ) assert created_volumes == 1 assert isinstance(volume_name, str) @@ -97,10 +89,10 @@ def mock_create_namespaced_persistent_volume_claim( def test_existing_workspace_is_mounted( db: orm.Session, - tool: tools_models.DatabaseTool, test_user: users_models.DatabaseUser, user_workspace: users_workspaces_models.DatabaseWorkspace, monkeypatch: pytest.MonkeyPatch, + configuration_hook_request: hooks_interface.ConfigurationHookRequest, ): created_volumes = 0 volume_name = None @@ -120,12 +112,11 @@ def mock_create_namespaced_persistent_volume_claim(self, ns, pvc): assert ( len(users_workspaces_crud.get_workspaces_for_user(db, test_user)) == 1 ) + + configuration_hook_request.user = test_user + configuration_hook_request.operator = operators.KubernetesOperator() persistent_workspace.PersistentWorkspaceHook().configuration_hook( - db=db, - operator=operators.KubernetesOperator(), - user=test_user, - session_type=sessions_models.SessionType.PERSISTENT, - tool=tool, + configuration_hook_request ) assert created_volumes == 1 assert isinstance(volume_name, str) diff --git a/backend/tests/sessions/hooks/test_pre_authentiation_hook.py b/backend/tests/sessions/hooks/test_pre_authentiation_hook.py index 50ef300b58..72b08f82e6 100644 --- a/backend/tests/sessions/hooks/test_pre_authentiation_hook.py +++ b/backend/tests/sessions/hooks/test_pre_authentiation_hook.py @@ -1,29 +1,22 @@ # SPDX-FileCopyrightText: Copyright DB InfraGO AG and contributors # SPDX-License-Identifier: Apache-2.0 -import logging - import pytest from capellacollab.sessions import auth as sessions_auth -from capellacollab.sessions import models as sessions_models from capellacollab.sessions.hooks import authentication -from capellacollab.users import models as users_models +from capellacollab.sessions.hooks import interface as sessions_hooks_interface def test_pre_authentication_hook( - session: sessions_models.DatabaseSession, - user: users_models.DatabaseUser, - logger: logging.LoggerAdapter, + session_connection_hook_request: sessions_hooks_interface.SessionConnectionHookRequest, monkeypatch: pytest.MonkeyPatch, ): private_key = sessions_auth.generate_private_key() monkeypatch.setattr(sessions_auth, "PRIVATE_KEY", private_key) result = authentication.PreAuthenticationHook().session_connection_hook( - db_session=session, - user=user, - logger=logger, + session_connection_hook_request ) assert "ccm_session_token" in result["cookies"] diff --git a/backend/tests/sessions/hooks/test_provisioning_hook.py b/backend/tests/sessions/hooks/test_provisioning_hook.py index fe83c0c2cd..9cf3de2307 100644 --- a/backend/tests/sessions/hooks/test_provisioning_hook.py +++ b/backend/tests/sessions/hooks/test_provisioning_hook.py @@ -12,6 +12,7 @@ ) from capellacollab.sessions import exceptions as sessions_exceptions from capellacollab.sessions import models as sessions_models +from capellacollab.sessions.hooks import interface as hooks_interface from capellacollab.sessions.hooks import provisioning as hooks_provisioning from capellacollab.tools import models as tools_models from capellacollab.users import models as users_models @@ -19,30 +20,27 @@ @pytest.mark.usefixtures("project_user") def test_git_models_are_resolved_correctly( - db: orm.Session, - user: users_models.DatabaseUser, - capella_tool: tools_models.DatabaseTool, - capella_tool_version: tools_models.DatabaseVersion, project: projects_models.DatabaseProject, capella_model: toolmodels_models.DatabaseToolModel, git_model: git_models.DatabaseGitModel, + configuration_hook_request: hooks_interface.ConfigurationHookRequest, ): - """Make sure that the Git models are correctly translated to GIT_MODELS environment""" + """Make sure that the Git models are correctly translated to environment""" + configuration_hook_request.session_type = ( + sessions_models.SessionType.READONLY + ) + configuration_hook_request.provisioning = [ + sessions_models.SessionProvisioningRequest( + project_slug=project.slug, + model_slug=capella_model.slug, + git_model_id=git_model.id, + revision="test", + deep_clone=False, + ) + ] response = hooks_provisioning.ProvisionWorkspaceHook().configuration_hook( - db=db, - tool=capella_tool, - tool_version=capella_tool_version, - user=user, - provisioning=[ - sessions_models.SessionProvisioningRequest( - project_slug=project.slug, - model_slug=capella_model.slug, - git_model_id=git_model.id, - revision="test", - deep_clone=False, - ) - ], + configuration_hook_request ) expected_response_dict = { @@ -67,43 +65,35 @@ def test_git_models_are_resolved_correctly( ] -def test_provisioning_fails_missing_permission( - db: orm.Session, - user: users_models.DatabaseUser, - capella_tool: tools_models.DatabaseTool, - capella_tool_version: tools_models.DatabaseVersion, +async def test_provisioning_fails_missing_permission( project: projects_models.DatabaseProject, capella_model: toolmodels_models.DatabaseToolModel, git_model: git_models.DatabaseGitModel, + configuration_hook_request: hooks_interface.ConfigurationHookRequest, ): """Make sure that provisioning fails when the user does not have the correct permissions""" + configuration_hook_request.provisioning = [ + sessions_models.SessionProvisioningRequest( + project_slug=project.slug, + model_slug=capella_model.slug, + git_model_id=git_model.id, + revision="main", + deep_clone=False, + ) + ] with pytest.raises(fastapi.HTTPException): hooks_provisioning.ProvisionWorkspaceHook().configuration_hook( - db=db, - tool=capella_tool, - tool_version=capella_tool_version, - user=user, - provisioning=[ - sessions_models.SessionProvisioningRequest( - project_slug=project.slug, - model_slug=capella_model.slug, - git_model_id=git_model.id, - revision="main", - deep_clone=False, - ) - ], + configuration_hook_request ) @pytest.mark.usefixtures("project_user") def test_provisioning_fails_too_many_models_requested( - db: orm.Session, - user: users_models.DatabaseUser, capella_tool: tools_models.DatabaseTool, - capella_tool_version: tools_models.DatabaseVersion, project: projects_models.DatabaseProject, capella_model: toolmodels_models.DatabaseToolModel, git_model: git_models.DatabaseGitModel, + configuration_hook_request: hooks_interface.ConfigurationHookRequest, ): capella_tool.config.provisioning.max_number_of_models = 1 @@ -115,58 +105,50 @@ def test_provisioning_fails_too_many_models_requested( deep_clone=False, ) + configuration_hook_request.provisioning = [ + session_provisioning_request, + session_provisioning_request, + ] with pytest.raises( sessions_exceptions.TooManyModelsRequestedToProvisionError ): hooks_provisioning.ProvisionWorkspaceHook().configuration_hook( - db=db, - tool=capella_tool, - tool_version=capella_tool_version, - user=user, - provisioning=[ - session_provisioning_request, - session_provisioning_request, - ], + configuration_hook_request ) def test_tool_model_mismatch( - db: orm.Session, - user: users_models.DatabaseUser, - tool: tools_models.DatabaseTool, - tool_version: tools_models.DatabaseVersion, project: projects_models.DatabaseProject, capella_model: toolmodels_models.DatabaseToolModel, + tool_version: tools_models.DatabaseVersion, git_model: git_models.DatabaseGitModel, + configuration_hook_request: hooks_interface.ConfigurationHookRequest, ): """Make sure that provisioning fails when the provided model doesn't match the selected tool""" + configuration_hook_request.tool_version = tool_version + configuration_hook_request.provisioning = [ + sessions_models.SessionProvisioningRequest( + project_slug=project.slug, + model_slug=capella_model.slug, + git_model_id=git_model.id, + revision="main", + deep_clone=False, + ) + ] with pytest.raises(sessions_exceptions.ToolAndModelMismatchError): hooks_provisioning.ProvisionWorkspaceHook().configuration_hook( - db=db, - tool=tool, - tool_version=tool_version, - user=user, - provisioning=[ - sessions_models.SessionProvisioningRequest( - project_slug=project.slug, - model_slug=capella_model.slug, - git_model_id=git_model.id, - revision="main", - deep_clone=False, - ) - ], + configuration_hook_request ) def test_provision_session_with_compatible_tool_versions( db: orm.Session, - admin: users_models.DatabaseUser, tool_version: tools_models.DatabaseVersion, - tool: tools_models.DatabaseTool, capella_tool_version: tools_models.DatabaseVersion, project: projects_models.DatabaseProject, capella_model: toolmodels_models.DatabaseToolModel, git_model: git_models.DatabaseGitModel, + configuration_hook_request: hooks_interface.ConfigurationHookRequest, ): """Make sure that provisioning is successful when the tool is compatible with the tool of the model""" @@ -174,19 +156,21 @@ def test_provision_session_with_compatible_tool_versions( orm.attributes.flag_modified(tool_version, "config") db.commit() + configuration_hook_request.session_type = ( + sessions_models.SessionType.READONLY + ) + + configuration_hook_request.provisioning = [ + sessions_models.SessionProvisioningRequest( + project_slug=project.slug, + model_slug=capella_model.slug, + git_model_id=git_model.id, + revision="main", + deep_clone=False, + ) + ] + configuration_hook_request.user.role = users_models.Role.ADMIN response = hooks_provisioning.ProvisionWorkspaceHook().configuration_hook( - db=db, - tool=tool, - tool_version=tool_version, - user=admin, - provisioning=[ - sessions_models.SessionProvisioningRequest( - project_slug=project.slug, - model_slug=capella_model.slug, - git_model_id=git_model.id, - revision="main", - deep_clone=False, - ) - ], + configuration_hook_request ) assert response["environment"]["CAPELLACOLLAB_SESSION_PROVISIONING"] diff --git a/backend/tests/sessions/hooks/test_pure_variants.py b/backend/tests/sessions/hooks/test_pure_variants.py index 6f23039e6e..a16f3e79fe 100644 --- a/backend/tests/sessions/hooks/test_pure_variants.py +++ b/backend/tests/sessions/hooks/test_pure_variants.py @@ -18,7 +18,6 @@ from capellacollab.sessions.hooks import pure_variants from capellacollab.tools import crud as tools_crud from capellacollab.tools import models as tools_models -from capellacollab.users import models as users_models @pytest.fixture(name="pure_variants_tool") @@ -51,16 +50,17 @@ def fixture_pure_variants_model( def test_skip_for_read_only_sessions( - db: orm.Session, - user: users_models.DatabaseUser, + configuration_hook_request: hooks_interface.ConfigurationHookRequest, ): """pure::variants has no read-only support Therefore, the hook also shouldn't do anything for read-only sessions. """ - + configuration_hook_request.session_type = ( + sessions_models.SessionType.READONLY + ) result = pure_variants.PureVariantsIntegration().configuration_hook( - db, user, sessions_models.SessionType.READONLY + configuration_hook_request ) assert result == hooks_interface.ConfigurationHookResult() @@ -69,8 +69,8 @@ def test_skip_for_read_only_sessions( @pytest.mark.usefixtures("project_user") def test_skip_when_user_has_no_pv_access( db: orm.Session, - user: users_models.DatabaseUser, pure_variants_model: toolmodels_models.DatabaseToolModel, + configuration_hook_request: hooks_interface.ConfigurationHookRequest, ): """If a user has no access to a project with a model that has the pure::variants restriction enabled, skip loading of the license. @@ -84,7 +84,7 @@ def test_skip_when_user_has_no_pv_access( ) result = pure_variants.PureVariantsIntegration().configuration_hook( - db, user, sessions_models.SessionType.PERSISTENT + configuration_hook_request ) assert "environment" not in result @@ -95,8 +95,8 @@ def test_skip_when_user_has_no_pv_access( @pytest.mark.usefixtures("project_user") def test_skip_when_license_server_not_configured( db: orm.Session, - user: users_models.DatabaseUser, pure_variants_model: toolmodels_models.DatabaseToolModel, + configuration_hook_request: hooks_interface.ConfigurationHookRequest, ): """If no pure::variants license is configured in the settings, skip loading of the license. @@ -111,7 +111,7 @@ def test_skip_when_license_server_not_configured( ) result = pure_variants.PureVariantsIntegration().configuration_hook( - db, user, sessions_models.SessionType.PERSISTENT + configuration_hook_request ) assert "environment" not in result @@ -122,8 +122,8 @@ def test_skip_when_license_server_not_configured( @pytest.mark.usefixtures("project_user", "pure_variants_license") def test_inject_pure_variants_license_information( db: orm.Session, - user: users_models.DatabaseUser, pure_variants_model: toolmodels_models.DatabaseToolModel, + configuration_hook_request: hooks_interface.ConfigurationHookRequest, ): """Test that the configured license information is properly injected in the session container. @@ -138,7 +138,7 @@ def test_inject_pure_variants_license_information( ) result = pure_variants.PureVariantsIntegration().configuration_hook( - db, user, sessions_models.SessionType.PERSISTENT + configuration_hook_request ) assert result["environment"] == { diff --git a/backend/tests/sessions/hooks/test_session_preparation.py b/backend/tests/sessions/hooks/test_session_preparation.py index ab12c919a4..2677774b34 100644 --- a/backend/tests/sessions/hooks/test_session_preparation.py +++ b/backend/tests/sessions/hooks/test_session_preparation.py @@ -5,33 +5,32 @@ from capellacollab.sessions import models as sessions_models from capellacollab.sessions.hooks import interface as hooks_interface from capellacollab.sessions.hooks import session_preparation -from capellacollab.tools import models as tools_models -def test_session_preparation_hook(tool: tools_models.DatabaseTool): +def test_session_preparation_hook( + configuration_hook_request: hooks_interface.ConfigurationHookRequest, +): """Test that the session preparation hook registers a shared volume""" - + configuration_hook_request.session_type = ( + sessions_models.SessionType.READONLY + ) result = session_preparation.GitRepositoryCloningHook().configuration_hook( - session_type=sessions_models.SessionType.READONLY, - session_id="session-id", - tool=tool, + configuration_hook_request ) assert len(result["volumes"]) == 1 assert len(result["init_volumes"]) == 1 assert result["volumes"][0] == result["init_volumes"][0] - assert result["volumes"][0].name == "session-id-models" + assert result["volumes"][0].name == "nxylxqbmfqwvswlqlcbsirvrt-models" def test_session_preparation_hook_with_persistent_session( - tool: tools_models.DatabaseTool, + configuration_hook_request: hooks_interface.ConfigurationHookRequest, ): """Test that the session preparation hook doesn't do anything for persistent sessions""" result = session_preparation.GitRepositoryCloningHook().configuration_hook( - session_type=sessions_models.SessionType.PERSISTENT, - session_id="session-id", - tool=tool, + configuration_hook_request ) assert result == hooks_interface.ConfigurationHookResult() diff --git a/backend/tests/sessions/hooks/test_t4c_hook.py b/backend/tests/sessions/hooks/test_t4c_hook.py index e058e96ae2..3a26641d1d 100644 --- a/backend/tests/sessions/hooks/test_t4c_hook.py +++ b/backend/tests/sessions/hooks/test_t4c_hook.py @@ -53,16 +53,12 @@ def fixture_mock_add_user_to_repository_failed( @responses.activate @pytest.mark.usefixtures("t4c_model", "project_user") def test_t4c_configuration_hook( - db: orm.Session, user: users_models.DatabaseUser, - capella_tool_version: tools_models.DatabaseVersion, mock_add_user_to_repository: responses.BaseResponse, + configuration_hook_request: sessions_hooks_interface.ConfigurationHookRequest, ): result = t4c.T4CIntegration().configuration_hook( - db=db, - user=user, - tool_version=capella_tool_version, - session_type=sessions_models.SessionType.PERSISTENT, + configuration_hook_request ) assert result["environment"]["T4C_LICENCE_SECRET"] @@ -76,21 +72,20 @@ def test_t4c_configuration_hook( @responses.activate @pytest.mark.usefixtures("t4c_model") def test_t4c_configuration_hook_as_admin( - db: orm.Session, - admin: users_models.DatabaseUser, - capella_tool_version: tools_models.DatabaseVersion, mock_add_user_to_repository: responses.BaseResponse, + configuration_hook_request: sessions_hooks_interface.ConfigurationHookRequest, ): + configuration_hook_request.user.role = users_models.Role.ADMIN result = t4c.T4CIntegration().configuration_hook( - db=db, - user=admin, - tool_version=capella_tool_version, - session_type=sessions_models.SessionType.PERSISTENT, + configuration_hook_request ) assert result["environment"]["T4C_LICENCE_SECRET"] assert len(json.loads(result["environment"]["T4C_JSON"])) == 1 - assert result["environment"]["T4C_USERNAME"] == admin.name + assert ( + result["environment"]["T4C_USERNAME"] + == configuration_hook_request.user.name + ) assert result["environment"]["T4C_PASSWORD"] assert not result["warnings"] assert mock_add_user_to_repository.call_count == 1 @@ -100,11 +95,11 @@ def test_t4c_configuration_hook_as_admin( @pytest.mark.usefixtures("t4c_model") def test_t4c_configuration_hook_with_same_repository_used_twice( db: orm.Session, - admin: users_models.DatabaseUser, project: projects_models.DatabaseProject, capella_tool_version: tools_models.DatabaseVersion, mock_add_user_to_repository: responses.BaseResponse, t4c_repository: settings_t4c_repositories_models.DatabaseT4CRepository, + configuration_hook_request: sessions_hooks_interface.ConfigurationHookRequest, ): model = toolmodels_models.PostToolModel( name="test2", description="test", tool_id=capella_tool_version.tool.id @@ -113,11 +108,9 @@ def test_t4c_configuration_hook_with_same_repository_used_twice( db, project, model, capella_tool_version.tool, capella_tool_version ) models_t4c_crud.create_t4c_model(db, db_model, t4c_repository, "default2") + configuration_hook_request.user.role = users_models.Role.ADMIN result = t4c.T4CIntegration().configuration_hook( - db=db, - user=admin, - tool_version=capella_tool_version, - session_type=sessions_models.SessionType.PERSISTENT, + configuration_hook_request ) assert len(json.loads(result["environment"]["T4C_JSON"])) == 1 @@ -128,16 +121,14 @@ def test_t4c_configuration_hook_with_same_repository_used_twice( @responses.activate @pytest.mark.usefixtures("t4c_model", "project_user") def test_t4c_configuration_hook_failure( - db: orm.Session, user: users_models.DatabaseUser, - capella_tool_version: tools_models.DatabaseVersion, mock_add_user_to_repository_failed: responses.BaseResponse, + configuration_hook_request: sessions_hooks_interface.ConfigurationHookRequest, ): + """Test behavior when T4C API call fails""" + result = t4c.T4CIntegration().configuration_hook( - db=db, - user=user, - tool_version=capella_tool_version, - session_type=sessions_models.SessionType.PERSISTENT, + configuration_hook_request ) assert result["environment"]["T4C_LICENCE_SECRET"] @@ -154,17 +145,14 @@ def test_configuration_hook_for_archived_project( project: projects_models.DatabaseProject, db: orm.Session, user: users_models.DatabaseUser, - capella_tool_version: tools_models.DatabaseVersion, mock_add_user_to_repository: responses.BaseResponse, + configuration_hook_request: sessions_hooks_interface.ConfigurationHookRequest, ): project.is_archived = True db.commit() result = t4c.T4CIntegration().configuration_hook( - db=db, - user=user, - tool_version=capella_tool_version, - session_type=sessions_models.SessionType.PERSISTENT, + configuration_hook_request ) assert not result["environment"]["T4C_LICENCE_SECRET"] @@ -180,18 +168,15 @@ def test_configuration_hook_for_archived_project( def test_configuration_hook_as_rw_user( db: orm.Session, user: users_models.DatabaseUser, - capella_tool_version: tools_models.DatabaseVersion, mock_add_user_to_repository: responses.BaseResponse, project_user: projects_users_models.ProjectUserAssociation, + configuration_hook_request: sessions_hooks_interface.ConfigurationHookRequest, ): project_user.permission = projects_users_models.ProjectUserPermission.READ db.commit() result = t4c.T4CIntegration().configuration_hook( - db=db, - user=user, - tool_version=capella_tool_version, - session_type=sessions_models.SessionType.PERSISTENT, + configuration_hook_request ) assert not result["environment"]["T4C_LICENCE_SECRET"] @@ -209,6 +194,7 @@ def test_configuration_hook_for_compatible_tool( user: users_models.DatabaseUser, capella_tool_version: tools_models.DatabaseVersion, mock_add_user_to_repository: responses.BaseResponse, + configuration_hook_request: sessions_hooks_interface.ConfigurationHookRequest, ): custom_tool = tools_crud.create_tool( db, tools_models.CreateTool(name="custom") @@ -223,11 +209,9 @@ def test_configuration_hook_for_compatible_tool( db, custom_tool, create_compatible_tool_version ) + configuration_hook_request.tool_version = compatible_tool_version result = t4c.T4CIntegration().configuration_hook( - db=db, - user=user, - tool_version=compatible_tool_version, - session_type=sessions_models.SessionType.PERSISTENT, + configuration_hook_request ) assert result["environment"]["T4C_LICENCE_SECRET"] @@ -239,28 +223,25 @@ def test_configuration_hook_for_compatible_tool( def test_t4c_configuration_hook_non_persistent( - db: orm.Session, - user: users_models.DatabaseUser, - tool_version: tools_models.DatabaseVersion, + configuration_hook_request: sessions_hooks_interface.ConfigurationHookRequest, ): + configuration_hook_request.session_type = ( + sessions_models.SessionType.READONLY + ) result = t4c.T4CIntegration().configuration_hook( - db=db, - user=user, - tool_version=tool_version, - session_type=sessions_models.SessionType.READONLY, + configuration_hook_request ) assert result == sessions_hooks_interface.ConfigurationHookResult() def test_t4c_connection_hook_non_persistent( - user: users_models.DatabaseUser, session: sessions_models.DatabaseSession, + session_connection_hook_request: sessions_hooks_interface.SessionConnectionHookRequest, ): session.type = sessions_models.SessionType.READONLY result = t4c.T4CIntegration().session_connection_hook( - db_session=session, - user=user, + session_connection_hook_request ) assert result == sessions_hooks_interface.SessionConnectionHookResult() @@ -268,8 +249,8 @@ def test_t4c_connection_hook_non_persistent( def test_t4c_connection_hook_shared_session( db: orm.Session, - user: users_models.DatabaseUser, session: sessions_models.DatabaseSession, + session_connection_hook_request: sessions_hooks_interface.SessionConnectionHookRequest, ): user2 = users_crud.create_user( db, @@ -280,21 +261,19 @@ def test_t4c_connection_hook_shared_session( ) session.owner = user2 result = t4c.T4CIntegration().session_connection_hook( - db_session=session, - user=user, + session_connection_hook_request ) assert result == sessions_hooks_interface.SessionConnectionHookResult() def test_t4c_connection_hook( - user: users_models.DatabaseUser, session: sessions_models.DatabaseSession, + session_connection_hook_request: sessions_hooks_interface.SessionConnectionHookRequest, ): session.environment = {"T4C_PASSWORD": "test"} result = t4c.T4CIntegration().session_connection_hook( - db_session=session, - user=user, + session_connection_hook_request ) assert result["t4c_token"] == "test" @@ -303,18 +282,21 @@ def test_t4c_connection_hook( @responses.activate @pytest.mark.usefixtures("t4c_model", "project_user") def test_t4c_termination_hook( - db: orm.Session, session: sessions_models.DatabaseSession, user: users_models.DatabaseUser, t4c_instance: t4c_models.DatabaseT4CInstance, capella_tool_version: tools_models.DatabaseVersion, + pre_session_termination_hook_request: sessions_hooks_interface.PreSessionTerminationHookRequest, ): session.version = capella_tool_version + pre_session_termination_hook_request.session = session rsp = responses.delete( f"{t4c_instance.rest_api}/users/{user.name}?repositoryName=test", status=200, ) - t4c.T4CIntegration().pre_session_termination_hook(db=db, session=session) + t4c.T4CIntegration().pre_session_termination_hook( + pre_session_termination_hook_request + ) assert rsp.call_count == 1 diff --git a/backend/tests/sessions/test_session_hooks.py b/backend/tests/sessions/test_session_hooks.py index 376d9ba9bb..8e6a122a18 100644 --- a/backend/tests/sessions/test_session_hooks.py +++ b/backend/tests/sessions/test_session_hooks.py @@ -44,52 +44,25 @@ class TestSessionHook(hooks_interface.HookRegistration): post_termination_hook_counter = 0 def configuration_hook( - self, - db: orm.Session, - operator: operators.KubernetesOperator, - user: users_models.DatabaseUser, - tool: tools_models.DatabaseTool, - tool_version: tools_models.DatabaseVersion, - session_type: sessions_models.SessionType, - connection_method: tools_models.ToolSessionConnectionMethod, - provisioning: list[sessions_models.SessionProvisioningRequest], - session_id: str, - **kwargs, + self, request: hooks_interface.ConfigurationHookRequest ) -> hooks_interface.ConfigurationHookResult: self.configuration_hook_counter += 1 return hooks_interface.ConfigurationHookResult() def post_session_creation_hook( - self, - session_id: str, - session: k8s.Session, - db_session: sessions_models.DatabaseSession, - operator: operators.KubernetesOperator, - user: users_models.DatabaseUser, - connection_method: tools_models.ToolSessionConnectionMethod, - **kwargs, + self, request: hooks_interface.PostSessionCreationHookRequest ) -> hooks_interface.PostSessionCreationHookResult: self.post_session_creation_hook_counter += 1 return hooks_interface.PostSessionCreationHookResult() def session_connection_hook( - self, - db: orm.Session, - db_session: sessions_models.DatabaseSession, - connection_method: tools_models.ToolSessionConnectionMethod, - logger: logging.LoggerAdapter, - **kwargs, + self, request: hooks_interface.SessionConnectionHookRequest ) -> hooks_interface.SessionConnectionHookResult: self.session_connection_hook_counter += 1 return hooks_interface.SessionConnectionHookResult() def pre_session_termination_hook( - self, - db: orm.Session, - operator: operators.KubernetesOperator, - session: sessions_models.DatabaseSession, - connection_method: tools_models.ToolSessionConnectionMethod, - **kwargs, + self, request: hooks_interface.PreSessionTerminationHookRequest ) -> hooks_interface.PreSessionTerminationHookResult: self.post_termination_hook_counter += 1 return hooks_interface.PreSessionTerminationHookResult() @@ -99,9 +72,7 @@ def pre_session_termination_hook( def fixture_session_hook(monkeypatch: pytest.MonkeyPatch) -> TestSessionHook: hook = TestSessionHook() - REGISTER_HOOKS_AUTO_USE: dict[str, hooks_interface.HookRegistration] = { - "test": hook, - } + REGISTER_HOOKS_AUTO_USE: list[hooks_interface.HookRegistration] = [hook] monkeypatch.setattr( sessions_hooks, "REGISTER_HOOKS_AUTO_USE", REGISTER_HOOKS_AUTO_USE @@ -131,6 +102,7 @@ def test_hook_calls_during_session_request( mockoperator: MockOperator, session_hook: TestSessionHook, tool: tools_models.DatabaseTool, + logger: logging.LoggerAdapter, ): """Test that the relevant session hooks are called during a session request. @@ -156,7 +128,7 @@ def test_hook_calls_during_session_request( user, db, mockoperator, # type: ignore - logging.getLogger("test"), + logger, ) assert session_hook.configuration_hook_counter == 1 @@ -168,6 +140,7 @@ def test_hook_calls_during_session_request( def test_hook_call_during_session_connection( db: orm.Session, session: sessions_models.DatabaseSession, + logger: logging.LoggerAdapter, ): """Test that the session hook is called when connecting to a session""" @@ -176,7 +149,7 @@ def test_hook_call_during_session_connection( db, session, session.owner, - logging.getLogger("test"), + logger, ) diff --git a/backend/tests/sessions/test_session_routes.py b/backend/tests/sessions/test_session_routes.py index aa8982226b..3e0e582fa0 100644 --- a/backend/tests/sessions/test_session_routes.py +++ b/backend/tests/sessions/test_session_routes.py @@ -59,7 +59,7 @@ def get_mock_operator(): @pytest.fixture(autouse=True, name="session_hook") def fixture_session_hook(monkeypatch: pytest.MonkeyPatch): - monkeypatch.setattr(sessions_hooks, "REGISTER_HOOKS_AUTO_USE", {}) + monkeypatch.setattr(sessions_hooks, "REGISTER_HOOKS_AUTO_USE", []) @pytest.mark.usefixtures("user", "session")