[SDL-5646] Added sandbox_features_stats endpoint and calculate_feature_stats to pipeline async result (#34)
mkaliberda authored Oct 25, 2024
1 parent 08cadb8 commit 71f05a0
Showing 2 changed files with 162 additions and 9 deletions.
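
For reference, a minimal client-side sketch of calling the features-stats endpoint added in this commit; the route and response keys are taken from the sandbox.py diff below, while the base URL, API token, UUIDs, and step index are placeholder assumptions, not part of the commit.

import requests

# Placeholder values -- swap in a real server URL, API token, and UUIDs.
BASE_URL = "https://server.example.com"
HEADERS = {"Authorization": "Token <api-token>"}
project_uuid = "<project-uuid>"
sandbox_uuid = "<sandbox-uuid>"

# GET project/<project_uuid>/sandbox/<uuid>/features-stats/ with the required
# pipeline_step query parameter (index of an executed step with cached feature outputs).
response = requests.get(
    f"{BASE_URL}/project/{project_uuid}/sandbox/{sandbox_uuid}/features-stats/",
    params={"pipeline_step": 2},
    headers=HEADERS,
)
response.raise_for_status()
payload = response.json()

# Per sandbox_features_stats(), the payload carries label_column, feature_data,
# feature_statistics, and feature_summary.
print(payload["label_column"])
print(len(payload["feature_summary"]))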
6 changes: 6 additions & 0 deletions src/server/datamanager/pipeline_queue.py
@@ -283,6 +283,12 @@ def get_pipeline_status(project_id, sandbox_id):
    return pipeline_status, message, detail


def get_is_pipeline_active(project_uuid, sandbox_uuid):
    """Return True if the sandbox pipeline task is queued or currently running."""
    result = get_pipeline_task_id_status(project_uuid, sandbox_uuid)
    pipeline_status = result.status
    return pipeline_status in ["PENDING", "SENT", "STARTED"]


def set_pipeline_to_active(project_uuid, sandbox_uuid, execution_type):
    sandbox = Sandbox.objects.get(uuid=sandbox_uuid)
    sandbox.active = True
165 changes: 156 additions & 9 deletions src/server/datamanager/sandbox.py
@@ -26,7 +26,13 @@
from io import BytesIO

import datamanager.pipeline_queue as pipeline_queue
from datamanager.models import Project, Sandbox, TeamMember, delete_caches_from_disk
from datamanager.models import (
    Project,
    Sandbox,
    TeamMember,
    Query,
    delete_caches_from_disk,
)
from datamanager.serializers import SandboxConfigSerializer, SandboxSerializer
from datamanager.serializers.utils import SandboxAsyncSerializer
from datamanager.tasks import (
@@ -54,9 +60,11 @@
from engine.base.pipeline_utils import (
    make_pipeline_linear,
)
from engine.drivers import get_selected_features
from library.model_generators import model_generator
from logger.data_logger import usage_log
from logger.log_handler import LogHandler
from pandas import DataFrame
from pandas import DataFrame, notnull, NA
from rest_framework import generics, permissions
from rest_framework.decorators import api_view, permission_classes
from rest_framework.exceptions import NotFound, PermissionDenied, ValidationError
@@ -111,6 +119,36 @@ def validate_pipeline(pipeline):
    return pipeline


def get_label_column_from_pipeline(sandbox):
    """Return the label column of the pipeline's query step, or None if there is no query step."""
    label_column = None
    if len(sandbox.pipeline) > 0:
        query_step = next((x for x in sandbox.pipeline if x["type"] == "query"), None)
        if query_step is not None:
            query_object = Query.objects.get(name=query_step["name"])
            label_column = query_object.label_column

    return label_column


def calculate_feature_stats(feature_data, feature_table, label_column, sandbox_uuid):
    """Compute per-feature statistics and a feature summary for the selected feature columns."""
    (_ft, selected_features, selected_feature_cols) = get_selected_features(
        feature_table, sandbox_uuid
    )
    selected_feature_cols.append(label_column)

    return {
        "feature_statistics": model_generator.compute_feature_stats(
            feature_data[selected_feature_cols]
        ),
        "feature_summary": (
            selected_features.where(notnull(selected_features), NA)
            .fillna("")
            .to_dict(orient="records")
        ),
    }


class SandboxCrudMixin(object):
"""Adds CRUD operations to GenericAPIView derived classes"""

@@ -289,6 +327,9 @@ def delete_sandbox(self, request, project_uuid):
        else:
            raise NotFound(FileErrors.fil_inv_path)

    def __get_step_params(self):
        pass

    def sandbox_data(self, request, project_uuid):
        """Return the columns of data requested."""

@@ -319,12 +360,8 @@ def sandbox_data(self, request, project_uuid):
                status=400,
            )

        status, message, detail = pipeline_queue.get_pipeline_status(
            project_uuid, sandbox.uuid
        )

        if status in ["PENDING", "SENT"]:
            return Response({"detail": "Pipeline is currently starting."}, status=400)
        if pipeline_queue.get_is_pipeline_active(project_uuid, sandbox.uuid):
            return Response({"detail": "Pipeline is currently running."}, status=400)

        cache_manager = CacheManager(sandbox, pipeline, pipeline_id=sandbox.uuid)

@@ -354,6 +391,87 @@ def sandbox_data(self, request, project_uuid):
            }
        )

    def sandbox_features_stats(self, request, project_uuid):
        """Return feature statistics and a feature summary for a cached pipeline step."""
        sandbox = self.get_object()
        pipeline = sandbox.pipeline
        step = request.query_params.get("pipeline_step", None)
        if step is None:
            step = request.data.get("pipeline_step", None)
        if step:
            step = int(step)

        if step is None:
            return Response(
                {"detail": "The pipeline_step query parameter was not provided."},
                status=400,
            )
        if step not in range(len(pipeline)):
            return Response(
                {"detail": "The selected pipeline_step is not in the executed pipeline."},
                status=400,
            )
        if len(pipeline[step]["outputs"]) < 2:
            return Response(
                {
                    "detail": f"No features data stored for pipeline step {pipeline[step]['name']}"
                },
                status=400,
            )
        if pipeline_queue.get_is_pipeline_active(project_uuid, sandbox.uuid):
            return Response({"detail": "Pipeline is currently running."}, status=400)

        cache_manager = CacheManager(sandbox, pipeline, pipeline_id=sandbox.uuid)

        data_name, features_name = pipeline[step]["outputs"]
        fd_index = ft_index = 0

        feature_data, data_pages = cache_manager.get_result_from_cache(
            data_name, fd_index, cache_key="data"
        )
        feature_table, ft_pages = cache_manager.get_result_from_cache(
            features_name, ft_index, cache_key="data"
        )

        if feature_data is None or feature_table is None:
            return Response(
                {"detail": "No features data stored for this pipeline"}, status=400
            )

        # Collect the remaining cached pages of the data and feature table outputs.
        while fd_index < data_pages:
            fd_item, _ = cache_manager.get_result_from_cache(
                data_name, fd_index, cache_key="data"
            )
            feature_data.append(fd_item)
            fd_index += 1

        while ft_index < ft_pages:
            ft_item, _ = cache_manager.get_result_from_cache(
                features_name, ft_index, cache_key="data"
            )
            feature_table.append(ft_item)
            ft_index += 1

        project = Project.objects.get(uuid=project_uuid)
        label_column = get_label_column_from_pipeline(sandbox=sandbox) or "Label"

        usage_log(
            PJID=project,
            operation="pipeline_statistic",
            detail={"page_index": "1"},
            team=project.team,
            team_member=request.user.teammember,
        )

        return Response(
            {
                "label_column": label_column,
                "feature_data": feature_data,
                **calculate_feature_stats(
                    feature_data, feature_table, label_column, sandbox.uuid
                ),
            }
        )


class SandboxAsyncMixin(object):
"""Adds async operations to GenericAPIView derived classes"""
@@ -501,6 +619,7 @@ def async_retrieve(self, request, project_uuid, sandbox_uuid, err_queue=deque())
            sandbox, sandbox.pipeline, pipeline_id=sandbox_uuid
        )
        errors = cache_manager.get_cache_list("errors")
        statistics_summary = {}

        if status in ["FAILURE"]:
            return Response(
@@ -544,8 +663,17 @@ def async_retrieve(self, request, project_uuid, sandbox_uuid, err_queue=deque())
            if sandbox.result_type == "pipeline":
                result_name = "pipeline_result.{}".format(sandbox_uuid)
                feature_name = "feature_table.{}".format(sandbox_uuid)
                (summary, _) = cache_manager.get_result_from_cache(feature_name)
                (feature_table, _) = cache_manager.get_result_from_cache(
                    feature_name
                )
                (feature_data, _) = cache_manager.get_result_from_cache(result_name)
                summary_key = "feature_table"
                label_column = (
                    get_label_column_from_pipeline(sandbox=sandbox) or "Label"
                )
                statistics_summary = calculate_feature_stats(
                    feature_data, feature_table, label_column, sandbox.uuid
                )

            elif sandbox.result_type == "grid_search":
                result_name = "grid_result.{}".format(sandbox_uuid)
@@ -610,6 +738,7 @@ def sanitize_return(value):
"results": sanitize_return(data),
"status": "SUCCESS",
summary_key: sanitize_return(summary),
"statistics_summary": sanitize_return(statistics_summary),
"execution_summary": sanitize_return(execution_summary),
"page_index": page_index,
"number_of_pages": number_of_pages,
@@ -875,6 +1004,19 @@ def get(self, request, *args, **kwargs):
        return self.sandbox_data(request, self.kwargs["project_uuid"])


class SandboxFeatureStatsView(SandboxCrudMixin, generics.GenericAPIView):
    permission_classes = (IsAuthenticated, DjangoModelPermissions)
    lookup_field = "uuid"

    def get_queryset(self):
        return filter_query_statistic_sandbox(
            self.request.user, self.kwargs["project_uuid"]
        )

    def get(self, request, *args, **kwargs):
        return self.sandbox_features_stats(request, self.kwargs["project_uuid"])


@extend_schema_view(
    get=extend_schema(
        summary="Retrieve sandbox status or results if completed",
@@ -934,6 +1076,11 @@ def delete(self, request, *args, **kwargs):
        SandboxDataView.as_view(),
        name="sandbox-data",
    ),
    url(
        r"^project/(?P<project_uuid>[^/]+)/sandbox/(?P<uuid>[^/]+)/features-stats/",
        SandboxFeatureStatsView.as_view(),
        name="sandbox-features-statistic",
    ),
    url(
        r"^project/(?P<project_uuid>[^/]+)/sandbox-async/(?P<uuid>[^/]+)/$",
        SandboxAsyncView.as_view(),

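The async_retrieve change above also folds the computed statistics into the pipeline result payload under statistics_summary. A minimal polling sketch under the same placeholder assumptions (server URL, API token, and UUIDs are not part of this commit):

import requests

BASE_URL = "https://server.example.com"  # placeholder
HEADERS = {"Authorization": "Token <api-token>"}  # placeholder
project_uuid = "<project-uuid>"
sandbox_uuid = "<sandbox-uuid>"

# Read the async sandbox result; once the pipeline has finished, the payload
# includes the new statistics_summary block alongside the existing result keys.
result = requests.get(
    f"{BASE_URL}/project/{project_uuid}/sandbox-async/{sandbox_uuid}/",
    headers=HEADERS,
).json()

if result.get("status") == "SUCCESS":
    stats = result.get("statistics_summary", {})
    # Per calculate_feature_stats(), this block holds feature_statistics and feature_summary.
    print(list(stats.get("feature_statistics", [])))
    print(len(stats.get("feature_summary", [])))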