Commit
feat(api): add ability to use environment variables to set spark annotations and resources (#165)

Centralised the spark submit call into a single private method so that environment variables can be used to tune pod resources and set pod annotations.

Also removed the API route that was only used to test spark submit.
maocorte authored Aug 13, 2024
1 parent 022e94c commit c0843b3
Showing 4 changed files with 52 additions and 50 deletions.
9 changes: 9 additions & 0 deletions api/README.md

````diff
@@ -140,3 +140,12 @@ SPARK_ON_K8S_EXECUTOR_MAX_INSTANCES: 2
 ```
 
 Adjust these variables as needed to allocate more resources or modify the number of executor instances for your specific use case.
+
+### Pods annotations
+
+If for some reason you need to add annotations to the driver and executor pods, set the following environment variables in the backend container, passing a valid JSON object as in the example:
+
+```
+SPARK_ON_K8S_SPARK_DRIVER_ANNOTATIONS: '{"my.annotation/driver": "my-value"}'
+SPARK_ON_K8S_SPARK_EXECUTOR_ANNOTATIONS: '{"my.annotation/executor": "my-value"}'
+```
````
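To see what the backend can do with one of these values, here is a minimal, self-contained sketch. The env-var name and value come from the README addition above; treating the value as JSON and parsing it into a plain dict is an assumption about how spark-on-k8s materialises the setting:

```python
import json
import os

# Simulate the backend container configuration (value copied from the README).
os.environ['SPARK_ON_K8S_SPARK_DRIVER_ANNOTATIONS'] = '{"my.annotation/driver": "my-value"}'

# Assumed behaviour: the JSON string becomes the dict of annotations that is
# later attached to the driver pod.
driver_annotations = json.loads(os.environ.get('SPARK_ON_K8S_SPARK_DRIVER_ANNOTATIONS', '{}'))
print(driver_annotations)  # {'my.annotation/driver': 'my-value'}
```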
2 changes: 0 additions & 2 deletions api/app/main.py

```diff
@@ -29,7 +29,6 @@
 from app.routes.infer_schema_route import InferSchemaRoute
 from app.routes.metrics_route import MetricsRoute
 from app.routes.model_route import ModelRoute
-from app.routes.spark_job_route import SparkJobRoute
 from app.routes.upload_dataset_route import UploadDatasetRoute
 from app.services.file_service import FileService
 from app.services.metrics_service import MetricsService
@@ -122,7 +121,6 @@ async def lifespan(fastapi: FastAPI):
 app.include_router(UploadDatasetRoute.get_router(file_service), prefix='/api/models')
 app.include_router(InferSchemaRoute.get_router(file_service), prefix='/api/schema')
 app.include_router(MetricsRoute.get_router(metrics_service), prefix='/api/models')
-app.include_router(SparkJobRoute.get_router(spark_k8s_service), prefix='/api/jobs')
 
 app.include_router(HealthcheckRoute.get_healthcheck_route())
 
```
15 changes: 0 additions & 15 deletions api/app/routes/spark_job_route.py

This file was deleted.

76 changes: 43 additions & 33 deletions api/app/services/file_service.py

```diff
@@ -10,7 +10,8 @@
 from fastapi import HTTPException, UploadFile
 from fastapi_pagination import Page, Params
 import pandas as pd
-from spark_on_k8s.client import SparkOnK8S
+from spark_on_k8s.client import ExecutorInstances, PodResources, SparkOnK8S
+from spark_on_k8s.utils.configuration import Configuration
 
 from app.core.config.config import create_secrets, get_config
 from app.db.dao.current_dataset_dao import CurrentDatasetDAO
@@ -108,21 +109,15 @@ def upload_reference_file(
         logger.debug('File %s has been correctly stored in the db', inserted_file)
 
         spark_config = get_config().spark_config
-        self.spark_k8s_client.submit_app(
-            image=spark_config.spark_image,
+        self.__submit_app(
+            app_name=str(model_out.uuid),
             app_path=spark_config.spark_reference_app_path,
             app_arguments=[
                 model_out.model_dump_json(),
                 path.replace('s3://', 's3a://'),
                 str(inserted_file.uuid),
                 ReferenceDatasetMetrics.__tablename__,
             ],
-            app_name=str(model_out.uuid),
-            namespace=spark_config.spark_namespace,
-            service_account=spark_config.spark_service_account,
-            image_pull_policy=spark_config.spark_image_pull_policy,
-            app_waiter='no_wait',
-            secret_values=create_secrets(),
         )
 
         return ReferenceDatasetDTO.from_reference_dataset(inserted_file)
@@ -163,21 +158,15 @@ def bind_reference_file(
         logger.debug('File %s has been correctly stored in the db', inserted_file)
 
         spark_config = get_config().spark_config
-        self.spark_k8s_client.submit_app(
-            image=spark_config.spark_image,
+        self.__submit_app(
+            app_name=str(model_out.uuid),
             app_path=spark_config.spark_reference_app_path,
             app_arguments=[
                 model_out.model_dump_json(),
                 file_ref.file_url.replace('s3://', 's3a://'),
                 str(inserted_file.uuid),
                 ReferenceDatasetMetrics.__tablename__,
             ],
-            app_name=str(model_out.uuid),
-            namespace=spark_config.spark_namespace,
-            service_account=spark_config.spark_service_account,
-            image_pull_policy=spark_config.spark_image_pull_policy,
-            app_waiter='no_wait',
-            secret_values=create_secrets(),
         )
 
         return ReferenceDatasetDTO.from_reference_dataset(inserted_file)
@@ -252,8 +241,8 @@ def upload_current_file(
         logger.debug('File %s has been correctly stored in the db', inserted_file)
 
         spark_config = get_config().spark_config
-        self.spark_k8s_client.submit_app(
-            image=spark_config.spark_image,
+        self.__submit_app(
+            app_name=str(model_out.uuid),
             app_path=spark_config.spark_current_app_path,
             app_arguments=[
                 model_out.model_dump_json(),
@@ -262,12 +251,6 @@
                 reference_dataset.path.replace('s3://', 's3a://'),
                 CurrentDatasetMetrics.__tablename__,
             ],
-            app_name=str(model_out.uuid),
-            namespace=spark_config.spark_namespace,
-            service_account=spark_config.spark_service_account,
-            image_pull_policy=spark_config.spark_image_pull_policy,
-            app_waiter='no_wait',
-            secret_values=create_secrets(),
         )
 
         return CurrentDatasetDTO.from_current_dataset(inserted_file)
@@ -311,8 +294,8 @@ def bind_current_file(
         logger.debug('File %s has been correctly stored in the db', inserted_file)
 
         spark_config = get_config().spark_config
-        self.spark_k8s_client.submit_app(
-            image=spark_config.spark_image,
+        self.__submit_app(
+            app_name=str(model_out.uuid),
             app_path=spark_config.spark_current_app_path,
             app_arguments=[
                 model_out.model_dump_json(),
@@ -321,12 +304,6 @@
                 reference_dataset.path.replace('s3://', 's3a://'),
                 CurrentDatasetMetrics.__tablename__,
             ],
-            app_name=str(model_out.uuid),
-            namespace=spark_config.spark_namespace,
-            service_account=spark_config.spark_service_account,
-            image_pull_policy=spark_config.spark_image_pull_policy,
-            app_waiter='no_wait',
-            secret_values=create_secrets(),
         )
 
         return CurrentDatasetDTO.from_current_dataset(inserted_file)
@@ -471,3 +448,36 @@ def validate_file(
 
         csv_file.file.flush()
         csv_file.file.seek(0)
+
+    def __submit_app(
+        self, app_name: str, app_path: str, app_arguments: List[str]
+    ) -> None:
+        spark_config = get_config().spark_config
+        self.spark_k8s_client.submit_app(
+            image=spark_config.spark_image,
+            app_path=app_path,
+            app_arguments=app_arguments,
+            app_name=app_name,
+            namespace=spark_config.spark_namespace,
+            service_account=spark_config.spark_service_account,
+            image_pull_policy=spark_config.spark_image_pull_policy,
+            app_waiter='no_wait',
+            secret_values=create_secrets(),
+            driver_annotations=Configuration.SPARK_ON_K8S_SPARK_DRIVER_ANNOTATIONS,
+            executor_annotations=Configuration.SPARK_ON_K8S_SPARK_EXECUTOR_ANNOTATIONS,
+            driver_resources=PodResources(
+                cpu=Configuration.SPARK_ON_K8S_DRIVER_CPU,
+                memory=Configuration.SPARK_ON_K8S_DRIVER_MEMORY,
+                memory_overhead=Configuration.SPARK_ON_K8S_DRIVER_MEMORY_OVERHEAD,
+            ),
+            executor_resources=PodResources(
+                cpu=Configuration.SPARK_ON_K8S_EXECUTOR_CPU,
+                memory=Configuration.SPARK_ON_K8S_EXECUTOR_MEMORY,
+                memory_overhead=Configuration.SPARK_ON_K8S_EXECUTOR_MEMORY_OVERHEAD,
+            ),
+            executor_instances=ExecutorInstances(
+                min=Configuration.SPARK_ON_K8S_EXECUTOR_MIN_INSTANCES,
+                max=Configuration.SPARK_ON_K8S_EXECUTOR_MAX_INSTANCES,
+                initial=Configuration.SPARK_ON_K8S_EXECUTOR_INITIAL_INSTANCES,
+            ),
+        )
```
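For readers who want to try the same environment-driven pattern outside this service class, here is a hedged, self-contained sketch. The `submit_app` keyword arguments and the `PodResources`/`ExecutorInstances` fields mirror the diff above; the image, application path, namespace, fallback defaults, the integer casts, and the no-argument `SparkOnK8S()` constructor are illustrative assumptions, and an actual run needs a reachable Kubernetes cluster:

```python
import json
import os

from spark_on_k8s.client import ExecutorInstances, PodResources, SparkOnK8S

# Every tunable below comes from an environment variable, so pod resources,
# annotations, and executor counts change per deployment without code changes.
client = SparkOnK8S()  # assumption: default constructor uses kubeconfig/in-cluster config
client.submit_app(
    image='example/spark-job:latest',     # hypothetical image
    app_path='local:///opt/app/job.py',   # hypothetical application path
    app_arguments=[],
    app_name='example-job',
    namespace='spark',                    # hypothetical namespace
    app_waiter='no_wait',
    driver_annotations=json.loads(os.getenv('SPARK_ON_K8S_SPARK_DRIVER_ANNOTATIONS', '{}')),
    executor_annotations=json.loads(os.getenv('SPARK_ON_K8S_SPARK_EXECUTOR_ANNOTATIONS', '{}')),
    driver_resources=PodResources(
        cpu=int(os.getenv('SPARK_ON_K8S_DRIVER_CPU', '1')),
        memory=int(os.getenv('SPARK_ON_K8S_DRIVER_MEMORY', '1024')),
        memory_overhead=int(os.getenv('SPARK_ON_K8S_DRIVER_MEMORY_OVERHEAD', '512')),
    ),
    executor_resources=PodResources(
        cpu=int(os.getenv('SPARK_ON_K8S_EXECUTOR_CPU', '1')),
        memory=int(os.getenv('SPARK_ON_K8S_EXECUTOR_MEMORY', '1024')),
        memory_overhead=int(os.getenv('SPARK_ON_K8S_EXECUTOR_MEMORY_OVERHEAD', '512')),
    ),
    executor_instances=ExecutorInstances(
        min=int(os.getenv('SPARK_ON_K8S_EXECUTOR_MIN_INSTANCES', '1')),
        max=int(os.getenv('SPARK_ON_K8S_EXECUTOR_MAX_INSTANCES', '2')),
        initial=int(os.getenv('SPARK_ON_K8S_EXECUTOR_INITIAL_INSTANCES', '1')),
    ),
)
```

Arguments not shown here (service account, image pull policy, secret values) keep their library defaults in this sketch; in the commit they are still taken from `spark_config` inside `__submit_app`.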
