diff --git a/src/server/datamanager/pipeline_queue.py b/src/server/datamanager/pipeline_queue.py index 1cf1137..a252197 100644 --- a/src/server/datamanager/pipeline_queue.py +++ b/src/server/datamanager/pipeline_queue.py @@ -283,6 +283,12 @@ def get_pipeline_status(project_id, sandbox_id): return pipeline_status, message, detail +def get_is_pipeline_active(project_uuid, sandbox_uuid): + result = get_pipeline_task_id_status(project_uuid, sandbox_uuid) + pipeline_status = result.status + return pipeline_status in ["PENDING", "SENT", "STARTED"] + + def set_pipeline_to_active(project_uuid, sandbox_uuid, execution_type): sandbox = Sandbox.objects.get(uuid=sandbox_uuid) sandbox.active = True diff --git a/src/server/datamanager/sandbox.py b/src/server/datamanager/sandbox.py index 3e752ca..81fc9f9 100644 --- a/src/server/datamanager/sandbox.py +++ b/src/server/datamanager/sandbox.py @@ -26,7 +26,13 @@ from io import BytesIO import datamanager.pipeline_queue as pipeline_queue -from datamanager.models import Project, Sandbox, TeamMember, delete_caches_from_disk +from datamanager.models import ( + Project, + Sandbox, + TeamMember, + Query, + delete_caches_from_disk, +) from datamanager.serializers import SandboxConfigSerializer, SandboxSerializer from datamanager.serializers.utils import SandboxAsyncSerializer from datamanager.tasks import ( @@ -54,9 +60,11 @@ from engine.base.pipeline_utils import ( make_pipeline_linear, ) +from engine.drivers import get_selected_features +from library.model_generators import model_generator from logger.data_logger import usage_log from logger.log_handler import LogHandler -from pandas import DataFrame +from pandas import DataFrame, notnull, NA from rest_framework import generics, permissions from rest_framework.decorators import api_view, permission_classes from rest_framework.exceptions import NotFound, PermissionDenied, ValidationError @@ -111,6 +119,36 @@ def validate_pipeline(pipeline): return pipeline +def get_label_column_from_pipeline(sandbox): + label_column = None + if len(sandbox.pipeline) > 0: + query_step = next((x for x in sandbox.pipeline if x["type"] == "query"), None) + if query_step is not None: + query_object = Query.objects.get(name=query_step["name"]) + label_column = query_object.label_column + + return label_column + + +def calculate_feature_stats(feature_data, feature_table, label_column, sandbox_uuid): + (_ft, selected_features, selected_feature_cols) = get_selected_features( + feature_table, sandbox_uuid + ) + selected_feature_cols.append(label_column) + print(label_column) + + return { + "feature_statistics": model_generator.compute_feature_stats( + feature_data[selected_feature_cols] + ), + "feature_summary": ( + selected_features.where(notnull(selected_features), NA) + .fillna("") + .to_dict(orient="records") + ), + } + + class SandboxCrudMixin(object): """Adds CRUD operations to GenericAPIView derived classes""" @@ -289,6 +327,9 @@ def delete_sandbox(self, request, project_uuid): else: raise NotFound(FileErrors.fil_inv_path) + def __get_step_params(self): + pass + def sandbox_data(self, request, project_uuid): """Return the columns of data requested.""" @@ -319,12 +360,8 @@ def sandbox_data(self, request, project_uuid): status=400, ) - status, message, detail = pipeline_queue.get_pipeline_status( - project_uuid, sandbox.uuid - ) - - if status in ["PENDING", "SENT"]: - return Response({"detail": "Pipeline is currently starting."}, status=400) + if pipeline_queue.get_is_pipeline_active(project_uuid, sandbox.uuid): + return Response({"detail": "Pipeline is currently running."}, status=400) cache_manager = CacheManager(sandbox, pipeline, pipeline_id=sandbox.uuid) @@ -354,6 +391,87 @@ def sandbox_data(self, request, project_uuid): } ) + def sandbox_features_stats(self, request, project_uuid): + sandbox = self.get_object() + pipeline = sandbox.pipeline + step = request.query_params.get("pipeline_step", None) + if step is None: + step = request.data.get("pipeline_step", None) + if step: + step = int(step) + + if step is None: + return Response( + {"detail": "pipeline_step queryparam was not provided."}, + status=400, + ) + if step not in range(len(pipeline)): + return Response( + {"detail": "Selected pipeline_step is not in executed pipeline."}, + status=400, + ) + if len(pipeline[step]["outputs"]) < 2: + return Response( + { + "detail": f"No features data stored for pipeline step {pipeline[step]['name']}" + }, + status=400, + ) + if pipeline_queue.get_is_pipeline_active(project_uuid, sandbox.uuid): + return Response({"detail": "Pipeline is currently running."}, status=400) + + cache_manager = CacheManager(sandbox, pipeline, pipeline_id=sandbox.uuid) + + data_name, features_name = pipeline[step]["outputs"] + fd_index = ft_index = 0 + + feature_data, data_pages = cache_manager.get_result_from_cache( + data_name, fd_index, cache_key="data" + ) + feature_table, ft_pages = cache_manager.get_result_from_cache( + features_name, ft_index, cache_key="data" + ) + + if feature_data is None or feature_table is None: + return Response( + {"detail": "No features data stored for this pipeline"}, status=400 + ) + + while fd_index < data_pages: + fd_item, _ = cache_manager.get_result_from_cache( + data_name, fd_index, cache_key="data" + ) + feature_data.append(fd_item) + fd_index += 1 + + while ft_index < ft_pages: + ft_item, _ = cache_manager.get_result_from_cache( + data_name, ft_index, cache_key="data" + ) + feature_table.append(ft_item) + ft_index += 1 + + project = Project.objects.get(uuid=project_uuid) + label_column = get_label_column_from_pipeline(sandbox=sandbox) or "Label" + + usage_log( + PJID=project, + operation="pipeline_statistic", + detail={"page_index": "1"}, + team=project.team, + team_member=request.user.teammember, + ) + + return Response( + { + "label_column": label_column, + "feature_data": feature_data, + **calculate_feature_stats( + feature_data, feature_table, label_column, sandbox.uuid + ), + } + ) + class SandboxAsyncMixin(object): """Adds async operations to GenericAPIView derived classes""" @@ -501,6 +619,7 @@ def async_retrieve(self, request, project_uuid, sandbox_uuid, err_queue=deque()) sandbox, sandbox.pipeline, pipeline_id=sandbox_uuid ) errors = cache_manager.get_cache_list("errors") + statistics_summary = {} if status in ["FAILURE"]: return Response( @@ -544,8 +663,17 @@ def async_retrieve(self, request, project_uuid, sandbox_uuid, err_queue=deque()) if sandbox.result_type == "pipeline": result_name = "pipeline_result.{}".format(sandbox_uuid) feature_name = "feature_table.{}".format(sandbox_uuid) - (summary, _) = cache_manager.get_result_from_cache(feature_name) + (feature_table, _) = cache_manager.get_result_from_cache( + feature_name + ) + (feature_data, _) = cache_manager.get_result_from_cache(result_name) summary_key = "feature_table" + label_column = ( + get_label_column_from_pipeline(sandbox=sandbox) or "Label" + ) + statistics_summary = calculate_feature_stats( + feature_data, feature_table, label_column, sandbox.uuid + ) elif sandbox.result_type == "grid_search": result_name = "grid_result.{}".format(sandbox_uuid) @@ -610,6 +738,7 @@ def sanitize_return(value): "results": sanitize_return(data), "status": "SUCCESS", summary_key: sanitize_return(summary), + "statistics_summary": sanitize_return(statistics_summary), "execution_summary": sanitize_return(execution_summary), "page_index": page_index, "number_of_pages": number_of_pages, @@ -875,6 +1004,19 @@ def get(self, request, *args, **kwargs): return self.sandbox_data(request, self.kwargs["project_uuid"]) +class SandboxFeatureStatsView(SandboxCrudMixin, generics.GenericAPIView): + permission_classes = (IsAuthenticated, DjangoModelPermissions) + lookup_field = "uuid" + + def get_queryset(self): + return filter_query_statistic_sandbox( + self.request.user, self.kwargs["project_uuid"] + ) + + def get(self, request, *args, **kwargs): + return self.sandbox_features_stats(request, self.kwargs["project_uuid"]) + + @extend_schema_view( get=extend_schema( summary="Retrieve sandbox status or results if completed", @@ -934,6 +1076,11 @@ def delete(self, request, *args, **kwargs): SandboxDataView.as_view(), name="sandbox-data", ), + url( + r"^project/(?P[^/]+)/sandbox/(?P[^/]+)/features-stats/", + SandboxFeatureStatsView.as_view(), + name="sandbox-features-statistic", + ), url( r"^project/(?P[^/]+)/sandbox-async/(?P[^/]+)/$", SandboxAsyncView.as_view(),