From 6e6f69872c5b0e018bb852023238845a0e153bf6 Mon Sep 17 00:00:00 2001
From: Jeev B
Date: Tue, 31 Oct 2023 21:17:35 -0700
Subject: [PATCH] [Extended Resources] GPU Accelerators (#1843)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

* pip through to container
Signed-off-by: Yee Hing Tong
Signed-off-by: Jeev B

* move around
Signed-off-by: Yee Hing Tong
Signed-off-by: Jeev B

* add asserts
Signed-off-by: Yee Hing Tong
Signed-off-by: Jeev B

* delete bad line
Signed-off-by: Yee Hing Tong
Signed-off-by: Jeev B

* switch to abc and add support for gpu unpartitioned
Signed-off-by: Jeev B

* Add Azure-specific headers when uploading to blob storage (#1784)
* Add Azure-specific headers when uploading to blob storage
Signed-off-by: Victor Delépine
* Add comment about HTTP 201 check
Signed-off-by: Victor Delépine
---------
Signed-off-by: Victor Delépine
Signed-off-by: Jeev B

* Add async delete function in base_agent (#1800)
Signed-off-by: Future Outlier
Co-authored-by: Future Outlier
Signed-off-by: Jeev B

* Add support for execution name prefixes (#1803)
Signed-off-by: troychiu
Signed-off-by: Jeev B

* Remove ref in output (#1794)
Signed-off-by: Yee Hing Tong
Signed-off-by: Jeev B

* Inherit directly from DataClassJsonMixin instead of using @dataclass_json for improved static type checking (#1801)
* Inherit directly from DataClassJsonMixin instead of @dataclass_json for improved static type checking
As it says in the dataclasses-json README:
https://github.com/lidatong/dataclasses-json/blob/89578cb9ebed290e70dba8946bfdb68ff6746755/README.md?plain=1#L111-L129,
we can use inheritance for improved static type checking; this one change
eliminates something like 467 pyright errors from the flytekit module
Signed-off-by: Matthew Hoffman
Signed-off-by: Jeev B

* Async file sensor (#1790)
---------
Signed-off-by: Kevin Su
Signed-off-by: Jeev B

* Eager workflows to support async workflows (#1579)
* Eager workflows to support async workflows
Signed-off-by: Niels Bantilan
* move array node maptask to experimental/__init__.py
Signed-off-by: Niels Bantilan
* clean up docs
Signed-off-by: Niels Bantilan
* clean up
Signed-off-by: Niels Bantilan
* more clean up
Signed-off-by: Niels Bantilan
* docs cleanup
Signed-off-by: Niels Bantilan
* Update test_eager_workflows.py
* clean up timeout handling
Signed-off-by: Niels Bantilan
* fix lint
Signed-off-by: Niels Bantilan
---------
Signed-off-by: Niels Bantilan
Signed-off-by: Jeev B

* Enable SecretsManager.get to load and return bytes (#1798)
* fix secretsmanager
Signed-off-by: Yue Shang
* fix lint issue
Signed-off-by: Yue Shang
* add doc
Signed-off-by: Yue Shang
* fix github check
Signed-off-by: Yue Shang
---------
Signed-off-by: Yue Shang
Signed-off-by: Jeev B

* Batch upload flyte directory (#1806)
* Batch upload flyte directory
Signed-off-by: Kevin Su
* Update get method
Signed-off-by: Kevin Su
* Move batch size to type engine
Signed-off-by: Kevin Su
* comment
Signed-off-by: Kevin Su
* update comment
Signed-off-by: Kevin Su
* Update flytekit/core/type_engine.py
Co-authored-by: Eduardo Apolinario <653394+eapolinario@users.noreply.github.com>
* Add test
Signed-off-by: Kevin Su
---------
Signed-off-by: Kevin Su
Co-authored-by: Eduardo Apolinario <653394+eapolinario@users.noreply.github.com>
Signed-off-by: Jeev B

* Better error messaging for overrides (#1807)
- using incorrect type of overrides
- using incorrect type for resources
- using promises in overrides
Signed-off-by: Ketan Umare
Signed-off-by: Jeev B

* Run remote Launchplan from `pyflyte run` (#1785)
* Beautified pyflyte run even for every task and workflow
- identify a task or a workflow
- task or workflow help menus show types and use rich to beautify
Signed-off-by: Ketan Umare
* one more improvement
Signed-off-by: Ketan Umare
* updated
Signed-off-by: Ketan Umare
* updated command
Signed-off-by: Ketan Umare
* Updated
Signed-off-by: Ketan Umare
* updated formatting
Signed-off-by: Ketan Umare
* updated
Signed-off-by: Ketan Umare
* updated
Signed-off-by: Ketan Umare
* bug fixed in types
Signed-off-by: Ketan Umare
* Updated
Signed-off-by: Ketan Umare
* lint
Signed-off-by: Kevin Su
---------
Signed-off-by: Ketan Umare
Signed-off-by: Kevin Su
Co-authored-by: Kevin Su
Signed-off-by: Jeev B

* Add is none function (#1757)
Signed-off-by: Kevin Su
Signed-off-by: Jeev B

* Dynamic workflow should not throw nested task warning (#1812)
Signed-off-by: oliverhu
Signed-off-by: Jeev B

* Add a manual image building GH action (#1816)
Signed-off-by: Yee Hing Tong
Signed-off-by: Jeev B

* catch abfs protocol in data_persistence.py/get_filesystem and set anon to False (#1813)
Signed-off-by: Jan Fiedler
Signed-off-by: Jeev B

* None doesnt work
Signed-off-by: Jeev B

* unpartitioned selector
Signed-off-by: Jeev B

* Fix list of annotated structured dataset (#1817)
Signed-off-by: Yee Hing Tong
Signed-off-by: Jeev B

* Support the flytectl config.yaml admin.clientSecretEnvVar option in flytekit (#1819)
* Support the flytectl config.yaml admin.clientSecretEnvVar option in flytekit
Signed-off-by: Chao-Heng Lee
* remove helper of getting env var.
Signed-off-by: Chao-Heng Lee
* refactor variable name.
Signed-off-by: Chao-Heng Lee
---------
Signed-off-by: Chao-Heng Lee
Signed-off-by: Jeev B

* Async agent delete function for while loop case (#1802)
Signed-off-by: Future Outlier
Signed-off-by: Kevin Su
Co-authored-by: Future Outlier
Co-authored-by: Kevin Su
Signed-off-by: Jeev B

* refactor
Signed-off-by: Jeev B

* fix docs warnings (#1827)
Signed-off-by: Jeev B

* Fix extract_task_module (#1829)
---------
Signed-off-by: Kevin Su
Signed-off-by: Jeev B

* Feat: Add type support for pydantic BaseModels (#1660)
Signed-off-by: Adrian Rumpold
Signed-off-by: Arthur
Signed-off-by: wirthual
Signed-off-by: Kevin Su
Signed-off-by: Yee Hing Tong
Signed-off-by: eduardo apolinario
Signed-off-by: Jeev B

* add test for unspecified mig
Signed-off-by: Jeev B

* add support for overriding accelerator
Signed-off-by: Jeev B

* cleanup
Signed-off-by: Jeev B

* move from core to extras
Signed-off-by: Jeev B

* fixes
Signed-off-by: Jeev B

* fixes
Signed-off-by: Jeev B

* fixes
Signed-off-by: Jeev B

* cleanup
Signed-off-by: Jeev B

* Make FlyteRemote slightly more copy/pastable (#1830)
Signed-off-by: Katrina Rogan
Signed-off-by: Jeev B

* Pyflyte meta inputs (#1823)
* Re-orgining pyflyte run
Signed-off-by: Ketan Umare
* Pyflyte beautified and simplified
Signed-off-by: Ketan Umare
* fixed unit test
Signed-off-by: Ketan Umare
* Added Launch options
Signed-off-by: Ketan Umare
* lint fix
Signed-off-by: Ketan Umare
* test fix
Signed-off-by: Ketan Umare
* fixing docs failure
Signed-off-by: Ketan Umare
---------
Signed-off-by: Ketan Umare
Signed-off-by: Jeev B

* Use mashumaro to serialize/deserialize dataclass (#1735)
Signed-off-by: HH
Signed-off-by: hhcs9527
Signed-off-by: Matthew Hoffman
Co-authored-by: Matthew Hoffman
Signed-off-by: Jeev B

* Databricks Agent (#1797)
Signed-off-by: Future Outlier
Signed-off-by: Kevin Su
Co-authored-by: Future Outlier
Co-authored-by: Kevin Su
Signed-off-by: Jeev B

* Prometheus metrics (#1815)
Signed-off-by: Kevin Su
Signed-off-by: Jeev B

* Pyflyte register optionally activates schedule (#1832)
* Pyflyte register auto activates schedule
Signed-off-by: Ketan Umare
* comment addressed
Signed-off-by: Ketan Umare
---------
Signed-off-by: Ketan Umare
Signed-off-by: Jeev B

* Remove versions 3.9 and 3.10 (#1831)
Signed-off-by: Yee Hing Tong
Signed-off-by: Jeev B

* Snowflake agent (#1799)
Signed-off-by: hhcs9527
Signed-off-by: HH
Signed-off-by: Jeev B

* Update agent metric name (#1835)
Signed-off-by: Kevin Su
Signed-off-by: Jeev B

* MemVerge MMCloud Agent (#1821)
Signed-off-by: Edwin Yu
Signed-off-by: Jeev B

* Add download badges in readme (#1836)
Signed-off-by: Kevin Su
Signed-off-by: Jeev B

* Eager local entrypoint and support for offloaded types (#1833)
* implement eager workflow local entrypoint, support offloaded types
Signed-off-by: Niels Bantilan
* wip local entrypoint
Signed-off-by: Niels Bantilan
* add tests
Signed-off-by: Niels Bantilan
* add local entrypoint tests
Signed-off-by: Niels Bantilan
* update eager unit tests, delete test script
Signed-off-by: Niels Bantilan
* clean up tests
Signed-off-by: Niels Bantilan
* update ci
Signed-off-by: Niels Bantilan
* update ci
Signed-off-by: Niels Bantilan
* update ci
Signed-off-by: Niels Bantilan
* update ci
Signed-off-by: Niels Bantilan
* remove push step
Signed-off-by: Niels Bantilan
---------
Signed-off-by: Niels Bantilan
Signed-off-by: Jeev B

* update requirements and add snowflake agent to api reference (#1838)
* update requirements and add snowflake agent to api reference
Signed-off-by: Samhita Alla
* update requirements
Signed-off-by: Samhita Alla
* remove versions
Signed-off-by: Samhita Alla
* remove tensorflow-macos
Signed-off-by: Samhita Alla
* lint
Signed-off-by: Samhita Alla
* downgrade sphinxcontrib-youtube package
Signed-off-by: Samhita Alla
---------
Signed-off-by: Samhita Alla
Signed-off-by: Jeev B

* Fix: Make sure decks created in elastic task workers are transferred to parent process (#1837)
* Transfer decks created in the worker process to the parent process
Signed-off-by: Fabio Graetz
* Add test for decks in elastic tasks
Signed-off-by: Fabio Graetz
* Update plugins/flytekit-kf-pytorch/flytekitplugins/kfpytorch/task.py
Signed-off-by: Fabio Graetz
* Update plugins/flytekit-kf-pytorch/flytekitplugins/kfpytorch/task.py
Signed-off-by: Fabio Graetz
---------
Signed-off-by: Fabio Graetz
Signed-off-by: Jeev B

* add accept grpc (#1841)
* add accept grpc
Signed-off-by: Yee Hing Tong
Signed-off-by: Jeev B
* unpin setup.py grpc
Signed-off-by: Yee Hing Tong
Signed-off-by: Jeev B
* Revert "add accept grpc"
This reverts commit 2294592f9ca30e7758d18c900fa058049d26ddda.
Signed-off-by: Jeev B
* default headers interceptor
Signed-off-by: Jeev B
* setup.py
Signed-off-by: Jeev B
* fixes
Signed-off-by: Jeev B
* fmt
Signed-off-by: Jeev B
* move prometheus-client import
Signed-off-by: Jeev B
---------
Signed-off-by: Yee Hing Tong
Signed-off-by: Jeev B
Co-authored-by: Jeev B
Signed-off-by: Jeev B

* Feat: Enable `flytekit` to authenticate with proxy in front of FlyteAdmin (#1787)
* Introduce authenticator engine and make proxy auth work
Signed-off-by: Fabio Grätz
* Use proxy authed session for client credentials flow
Signed-off-by: Fabio Grätz
* Don't use authenticator engine but do proxy authentication via existing external command authenticator
Signed-off-by: Fabio Grätz
* Add docstring to AuthenticationHTTPAdapter
Signed-off-by: Fabio Grätz
* Address todo in docstring
Signed-off-by: Fabio Grätz
* Create blank session if none provided
Signed-off-by: Fabio Grätz
* Create blank session if none provided in get_token
Signed-off-by: Fabio Grätz
* Refresh proxy creds in session when not existing without triggering 401
Signed-off-by: Fabio Grätz
* Add test for get_session
Signed-off-by: Fabio Grätz
* Move auth helper test into existing module
Signed-off-by: Fabio Grätz
* Move auth helper test into existing module
Signed-off-by: Fabio Grätz
* Add test for upgrade_channel_to_proxy_authenticated
Signed-off-by: Fabio Grätz
* Auth helper tests without use of responses package
Signed-off-by: Fabio Grätz
* Feat: Add plugin for generating GCP IAP ID tokens via external command (#1795)
* Add external command plugin to generate id tokens for identity aware proxy
Signed-off-by: Fabio Grätz
* Retrieve desktop app client secret from gcp secret manager
Signed-off-by: Fabio Grätz
* Remove comments
Signed-off-by: Fabio Grätz
* Introduce a command group that allows adding a command to generate service account id tokens later
Signed-off-by: Fabio Grätz
* Document how to use plugin and deploy Flyte with IAP
Signed-off-by: Fabio Grätz
* Minor corrections README.md
Signed-off-by: Fabio Grätz
---------
Signed-off-by: Fabio Grätz
Co-authored-by: Fabio Grätz
Signed-off-by: Fabio Grätz
* Use proxy auth'ed session for device code auth flow
Signed-off-by: Fabio Grätz
* Fix token client tests
Signed-off-by: Fabio Grätz
* Make poll token endpoint test more specific
Signed-off-by: Fabio Grätz
* Make test_client_creds_authenticator test work and more specific
Signed-off-by: Fabio Grätz
* Make test_client_creds_authenticator_with_custom_scopes test work and more specific
Signed-off-by: Fabio Grätz
* Implement subcommand to generate id tokens for service accounts
Signed-off-by: Fabio Graetz
* Test id token generation from service accounts
Signed-off-by: Fabio Graetz
* Fix plugin requirements
Signed-off-by: Fabio Graetz
* Document usage of generate-service-account-id-token subcommand
Signed-off-by: Fabio Grätz
* Document alternative ways to obtain service account id tokens
Signed-off-by: Fabio Grätz
---------
Signed-off-by: Fabio Grätz
Signed-off-by: Fabio Graetz
Co-authored-by: Fabio Grätz
Signed-off-by: Jeev B

* bump flyteidl
Signed-off-by: Jeev B

* make requirements
Signed-off-by: Jeev B

* fix failing tests
Signed-off-by: Jeev B

* move gpu accelerator to flyteidl.core.Resources
Signed-off-by: Jeev B

* Use ResourceExtensions for extended resources
Signed-off-by: Jeev B

* cleanup
Signed-off-by: Jeev B

* Switch to using ExtendedResources in TaskTemplate
Signed-off-by: Jeev B

* cleanups
Signed-off-by: Jeev B

* update flyteidl
Signed-off-by: Jeev B

* Replace _core_task imports with tasks_pb2
Signed-off-by: Jeev B

* less verbose definitions
Signed-off-by: Jeev B

* Attempt at less confusing syntax
Signed-off-by: Jeev B

* Streamline UX
Signed-off-by: Jeev B

* Run make fmt
Signed-off-by: Jeev B

---------

Signed-off-by: Yee Hing Tong
Signed-off-by: Jeev B
Signed-off-by: Victor Delépine
Signed-off-by: Future Outlier
Signed-off-by: troychiu
Signed-off-by: Matthew Hoffman
Signed-off-by: Niels Bantilan
Signed-off-by: Yue Shang
Signed-off-by: Kevin Su
Signed-off-by: Ketan Umare
Signed-off-by: oliverhu
Signed-off-by: Jan Fiedler
Signed-off-by: Chao-Heng Lee
Signed-off-by: Adrian Rumpold
Signed-off-by: Arthur
Signed-off-by: wirthual
Signed-off-by: eduardo apolinario
Signed-off-by: Katrina Rogan
Signed-off-by: HH
Signed-off-by: hhcs9527
Signed-off-by: Edwin Yu
Signed-off-by: Samhita Alla
Signed-off-by: Fabio Graetz
Signed-off-by: Fabio Grätz
Co-authored-by: Yee Hing Tong
Co-authored-by: Victor Delépine
Co-authored-by: Future-Outlier
Co-authored-by: Future Outlier
Co-authored-by: Yi Chiu <114708546+troychiu@users.noreply.github.com>
Co-authored-by: Matthew Hoffman
Co-authored-by: Kevin Su
Co-authored-by: Niels Bantilan
Co-authored-by: Yue Shang <138256885+ysysys3074@users.noreply.github.com>
Co-authored-by: Eduardo Apolinario <653394+eapolinario@users.noreply.github.com>
Co-authored-by: Ketan Umare <16888709+kumare3@users.noreply.github.com>
Co-authored-by: Keqiu Hu
Co-authored-by: Jan Fiedler <89976021+fiedlerNr9@users.noreply.github.com>
Co-authored-by: Chao-Heng Lee
Co-authored-by: Samhita Alla
Co-authored-by: Arthur Böök <49250723+ArthurBook@users.noreply.github.com>
Co-authored-by: Katrina Rogan
Co-authored-by: Po Han(Hank) Huang
Co-authored-by: Edwin Yu <92917168+edwinyyyu@users.noreply.github.com>
Co-authored-by: Fabio M. Graetz, Ph.D
Co-authored-by: Fabio Grätz
---
 flytekit/core/base_task.py                    |  8 ++
 flytekit/core/node.py                         |  8 ++
 flytekit/core/python_auto_container.py        | 15 ++++
 flytekit/core/task.py                         |  6 ++
 flytekit/extras/accelerators.py               | 90 +++++++++++++++++++
 flytekit/models/core/workflow.py              | 16 +++-
 flytekit/models/task.py                       | 13 +++
 flytekit/tools/translator.py                  |  9 +-
 tests/flytekit/common/parameterizers.py       | 20 ++++-
 .../flytekit/unit/core/test_node_creation.py  | 27 ++++++
 .../flytekit/unit/extras/test_accelerators.py | 64 +++++++++++++
 .../unit/models/core/test_workflow.py         | 11 ++-
 tests/flytekit/unit/models/test_tasks.py      | 34 ++++++-
 .../unit/models/test_workflow_closure.py      |  4 +
 14 files changed, 315 insertions(+), 10 deletions(-)
 create mode 100644 flytekit/extras/accelerators.py
 create mode 100644 tests/flytekit/unit/extras/test_accelerators.py

diff --git a/flytekit/core/base_task.py b/flytekit/core/base_task.py
index 286097c668e..2d7938c67c2 100644
--- a/flytekit/core/base_task.py
+++ b/flytekit/core/base_task.py
@@ -24,6 +24,8 @@
 from dataclasses import dataclass
 from typing import Any, Coroutine, Dict, Generic, List, Optional, OrderedDict, Tuple, Type, TypeVar, Union, cast
 
+from flyteidl.core import tasks_pb2
+
 from flytekit.configuration import SerializationSettings
 from flytekit.core.context_manager import (
     ExecutionParameters,
@@ -344,6 +346,12 @@ def get_config(self, settings: SerializationSettings) -> Optional[Dict[str, str]
         """
         return None
 
+    def get_extended_resources(self, settings: SerializationSettings) -> Optional[tasks_pb2.ExtendedResources]:
+        """
+        Returns the extended resources to allocate to the task on hosted Flyte.
+        """
+        return None
+
     def local_execution_mode(self) -> ExecutionState.Mode:
         """ """
         return ExecutionState.Mode.LOCAL_TASK_EXECUTION
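
[Editor's note, not part of the diff] The hook above defaults to returning None; PythonAutoContainerTask overrides it later in this patch. As a purely hypothetical sketch of the same contract, a custom task type that wants to advertise an accelerator could do something like the following (MyCustomTask and its constructor are invented for illustration):

```python
from typing import Optional

from flyteidl.core import tasks_pb2

from flytekit.configuration import SerializationSettings
from flytekit.core.base_task import PythonTask
from flytekit.extras.accelerators import BaseAccelerator


class MyCustomTask(PythonTask):  # hypothetical subclass, not part of flytekit
    def __init__(self, *args, accelerator: Optional[BaseAccelerator] = None, **kwargs):
        super().__init__(*args, **kwargs)
        self._accelerator = accelerator

    def get_extended_resources(self, settings: SerializationSettings) -> Optional[tasks_pb2.ExtendedResources]:
        # Same contract as the base hook: None means "no extended resources requested".
        if self._accelerator is None:
            return None
        return tasks_pb2.ExtendedResources(gpu_accelerator=self._accelerator.to_flyte_idl())
```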
+ """ + return None + def local_execution_mode(self) -> ExecutionState.Mode: """ """ return ExecutionState.Mode.LOCAL_TASK_EXECUTION diff --git a/flytekit/core/node.py b/flytekit/core/node.py index 1038c005219..2957abe0dfb 100644 --- a/flytekit/core/node.py +++ b/flytekit/core/node.py @@ -4,6 +4,8 @@ import typing from typing import Any, List +from flyteidl.core import tasks_pb2 + from flytekit.core.resources import Resources, convert_resources_to_resource_model from flytekit.core.utils import _dnsify from flytekit.loggers import logger @@ -62,6 +64,7 @@ def __init__( self._aliases: _workflow_model.Alias = None self._outputs = None self._resources: typing.Optional[_resources_model] = None + self._extended_resources: typing.Optional[tasks_pb2.ExtendedResources] = None def runs_before(self, other: Node): """ @@ -172,6 +175,11 @@ def with_overrides(self, *args, **kwargs): assert_not_promise(v, "container_image") self.flyte_entity._container_image = v + if "accelerator" in kwargs: + v = kwargs["accelerator"] + assert_not_promise(v, "accelerator") + self._extended_resources = tasks_pb2.ExtendedResources(gpu_accelerator=v.to_flyte_idl()) + return self diff --git a/flytekit/core/python_auto_container.py b/flytekit/core/python_auto_container.py index 5335410a79f..1ad1de0216f 100644 --- a/flytekit/core/python_auto_container.py +++ b/flytekit/core/python_auto_container.py @@ -5,6 +5,8 @@ from abc import ABC from typing import Callable, Dict, List, Optional, TypeVar, Union +from flyteidl.core import tasks_pb2 + from flytekit.configuration import ImageConfig, SerializationSettings from flytekit.core.base_task import PythonTask, TaskMetadata, TaskResolverMixin from flytekit.core.context_manager import FlyteContextManager @@ -13,6 +15,7 @@ from flytekit.core.tracked_abc import FlyteTrackedABC from flytekit.core.tracker import TrackedInstance, extract_task_module from flytekit.core.utils import _get_container_definition, _serialize_pod_spec, timeit +from flytekit.extras.accelerators import BaseAccelerator from flytekit.image_spec.image_spec import ImageBuildEngine, ImageSpec from flytekit.loggers import logger from flytekit.models import task as _task_model @@ -44,6 +47,7 @@ def __init__( secret_requests: Optional[List[Secret]] = None, pod_template: Optional[PodTemplate] = None, pod_template_name: Optional[str] = None, + accelerator: Optional[BaseAccelerator] = None, **kwargs, ): """ @@ -70,6 +74,7 @@ def __init__( - `AWS Parameter store `__ :param pod_template: Custom PodTemplate for this task. :param pod_template_name: The name of the existing PodTemplate resource which will be used in this task. + :param accelerator: The accelerator to use for this task. """ sec_ctx = None if secret_requests: @@ -110,6 +115,7 @@ def __init__( self._get_command_fn = self.get_default_command self.pod_template = pod_template + self.accelerator = accelerator @property def task_resolver(self) -> TaskResolverMixin: @@ -219,6 +225,15 @@ def get_config(self, settings: SerializationSettings) -> Optional[Dict[str, str] return {} return {_PRIMARY_CONTAINER_NAME_FIELD: self.pod_template.primary_container_name} + def get_extended_resources(self, settings: SerializationSettings) -> Optional[tasks_pb2.ExtendedResources]: + """ + Returns the extended resources to allocate to the task on hosted Flyte. 
diff --git a/flytekit/core/task.py b/flytekit/core/task.py
index ce16e9634d5..547abd41fa1 100644
--- a/flytekit/core/task.py
+++ b/flytekit/core/task.py
@@ -8,6 +8,7 @@
 from flytekit.core.python_function_task import PythonFunctionTask
 from flytekit.core.reference_entity import ReferenceEntity, TaskReference
 from flytekit.core.resources import Resources
+from flytekit.extras.accelerators import BaseAccelerator
 from flytekit.image_spec.image_spec import ImageSpec
 from flytekit.models.documentation import Documentation
 from flytekit.models.security import Secret
@@ -102,6 +103,7 @@ def task(
     enable_deck: Optional[bool] = ...,
     pod_template: Optional["PodTemplate"] = ...,
     pod_template_name: Optional[str] = ...,
+    accelerator: Optional[BaseAccelerator] = ...,
 ) -> Callable[[Callable[..., FuncOut]], PythonFunctionTask[T]]:
     ...
 
@@ -129,6 +131,7 @@ def task(
     enable_deck: Optional[bool] = ...,
     pod_template: Optional["PodTemplate"] = ...,
     pod_template_name: Optional[str] = ...,
+    accelerator: Optional[BaseAccelerator] = ...,
 ) -> Union[PythonFunctionTask[T], Callable[..., FuncOut]]:
     ...
 
@@ -155,6 +158,7 @@ def task(
     enable_deck: Optional[bool] = None,
     pod_template: Optional["PodTemplate"] = None,
     pod_template_name: Optional[str] = None,
+    accelerator: Optional[BaseAccelerator] = None,
 ) -> Union[Callable[[Callable[..., FuncOut]], PythonFunctionTask[T]], PythonFunctionTask[T], Callable[..., FuncOut]]:
     """
     This is the core decorator to use for any task type in flytekit.
@@ -248,6 +252,7 @@ def foo2():
     :param docs: Documentation about this task
     :param pod_template: Custom PodTemplate for this task.
     :param pod_template_name: The name of the existing PodTemplate resource which will be used in this task.
+    :param accelerator: The accelerator to use for this task.
     """
 
     def wrapper(fn: Callable[..., Any]) -> PythonFunctionTask[T]:
@@ -277,6 +282,7 @@ def wrapper(fn: Callable[..., Any]) -> PythonFunctionTask[T]:
             docs=docs,
             pod_template=pod_template,
             pod_template_name=pod_template_name,
+            accelerator=accelerator,
         )
         update_wrapper(task_instance, fn)
         return task_instance
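
[Editor's note, not part of the diff] The user-facing flow these changes enable, mirroring test_override_accelerator further down in this patch: a task declares a default accelerator, and a specific invocation overrides it with an A100 MIG slice.

```python
from flytekit import task, workflow
from flytekit.extras.accelerators import A100, T4


@task(accelerator=T4)
def fine_tune() -> str:
    return "done"


@workflow
def pipeline() -> str:
    # Node-level override: this call requests a 1g.5gb A100 partition instead of a T4.
    return fine_tune().with_overrides(accelerator=A100.partition_1g_5gb)
```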
diff --git a/flytekit/extras/accelerators.py b/flytekit/extras/accelerators.py
new file mode 100644
index 00000000000..3615f32bdbd
--- /dev/null
+++ b/flytekit/extras/accelerators.py
@@ -0,0 +1,90 @@
+import abc
+import copy
+from typing import ClassVar, Generic, Optional, Type, TypeVar
+
+from flyteidl.core import tasks_pb2
+
+T = TypeVar("T")
+MIG = TypeVar("MIG", bound="MultiInstanceGPUAccelerator")
+
+
+class BaseAccelerator(abc.ABC, Generic[T]):
+    @abc.abstractmethod
+    def to_flyte_idl(self) -> T:
+        ...
+
+
+class GPUAccelerator(BaseAccelerator):
+    def __init__(self, device: str) -> None:
+        self._device = device
+
+    def to_flyte_idl(self) -> tasks_pb2.GPUAccelerator:
+        return tasks_pb2.GPUAccelerator(device=self._device)
+
+
+A10G = GPUAccelerator("nvidia-a10g")
+L4 = GPUAccelerator("nvidia-l4-vws")
+K80 = GPUAccelerator("nvidia-tesla-k80")
+M60 = GPUAccelerator("nvidia-tesla-m60")
+P4 = GPUAccelerator("nvidia-tesla-p4")
+P100 = GPUAccelerator("nvidia-tesla-p100")
+T4 = GPUAccelerator("nvidia-tesla-t4")
+V100 = GPUAccelerator("nvidia-tesla-v100")
+
+
+class MultiInstanceGPUAccelerator(BaseAccelerator):
+    device: ClassVar[str]
+    _partition_size: Optional[str]
+
+    @property
+    def unpartitioned(self: MIG) -> MIG:
+        instance = copy.deepcopy(self)
+        instance._partition_size = None
+        return instance
+
+    @classmethod
+    def partitioned(cls: Type[MIG], partition_size: str) -> MIG:
+        instance = cls()
+        instance._partition_size = partition_size
+        return instance
+
+    def to_flyte_idl(self) -> tasks_pb2.GPUAccelerator:
+        msg = tasks_pb2.GPUAccelerator(device=self.device)
+        if not hasattr(self, "_partition_size"):
+            return msg
+
+        if self._partition_size is None:
+            msg.unpartitioned = True
+        else:
+            msg.partition_size = self._partition_size
+        return msg
+
+
+class _A100_Base(MultiInstanceGPUAccelerator):
+    device = "nvidia-tesla-a100"
+
+
+class _A100(_A100_Base):
+    partition_1g_5gb = _A100_Base.partitioned("1g.5gb")
+    partition_2g_10gb = _A100_Base.partitioned("2g.10gb")
+    partition_3g_20gb = _A100_Base.partitioned("3g.20gb")
+    partition_4g_20gb = _A100_Base.partitioned("4g.20gb")
+    partition_7g_40gb = _A100_Base.partitioned("7g.40gb")
+
+
+A100 = _A100()
+
+
+class _A100_80GB_Base(MultiInstanceGPUAccelerator):
+    device = "nvidia-a100-80gb"
+
+
+class _A100_80GB(_A100_80GB_Base):
+    partition_1g_10gb = _A100_80GB_Base.partitioned("1g.10gb")
+    partition_2g_20gb = _A100_80GB_Base.partitioned("2g.20gb")
+    partition_3g_40gb = _A100_80GB_Base.partitioned("3g.40gb")
+    partition_4g_40gb = _A100_80GB_Base.partitioned("4g.40gb")
+    partition_7g_80gb = _A100_80GB_Base.partitioned("7g.80gb")
+
+
+A100_80GB = _A100_80GB()
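
[Editor's note, not part of the diff] A sketch of what the helpers above serialize to. Device strings are the ones defined in flytekit/extras/accelerators.py; whether a given device exists in a cluster depends on the deployment, and "nvidia-l40s" below is only an example device name, not a shipped constant.

```python
from flytekit.extras.accelerators import A100, GPUAccelerator, T4

# Plain GPU: only the device field is set.
assert T4.to_flyte_idl().device == "nvidia-tesla-t4"

# MIG-capable GPU with no partition preference: only the device field is set.
assert A100.to_flyte_idl().device == "nvidia-tesla-a100"

# Explicitly unpartitioned vs. a specific MIG slice.
assert A100.unpartitioned.to_flyte_idl().unpartitioned
assert A100.partition_1g_5gb.to_flyte_idl().partition_size == "1g.5gb"

# Devices not shipped with flytekit can still be requested directly.
l40s = GPUAccelerator("nvidia-l40s")  # hypothetical device name
```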
diff --git a/flytekit/models/core/workflow.py b/flytekit/models/core/workflow.py
index e60038c0f6d..efd18babcda 100644
--- a/flytekit/models/core/workflow.py
+++ b/flytekit/models/core/workflow.py
@@ -1,6 +1,7 @@
 import datetime
 import typing
 
+from flyteidl.core import tasks_pb2
 from flyteidl.core import workflow_pb2 as _core_workflow
 
 from flytekit.models import common as _common
@@ -562,24 +563,33 @@
 class TaskNodeOverrides(_common.FlyteIdlEntity):
-    def __init__(self, resources: typing.Optional[Resources] = None):
+    def __init__(
+        self, resources: typing.Optional[Resources], extended_resources: typing.Optional[tasks_pb2.ExtendedResources]
+    ):
         self._resources = resources
+        self._extended_resources = extended_resources
 
     @property
     def resources(self) -> Resources:
         return self._resources
 
+    @property
+    def extended_resources(self) -> tasks_pb2.ExtendedResources:
+        return self._extended_resources
+
     def to_flyte_idl(self):
         return _core_workflow.TaskNodeOverrides(
             resources=self.resources.to_flyte_idl() if self.resources is not None else None,
+            extended_resources=self.extended_resources,
         )
 
     @classmethod
     def from_flyte_idl(cls, pb2_object):
         resources = Resources.from_flyte_idl(pb2_object.resources)
+        extended_resources = pb2_object.extended_resources if pb2_object.HasField("extended_resources") else None
         if bool(resources.requests) or bool(resources.limits):
-            return cls(resources=resources)
-        return cls(resources=None)
+            return cls(resources=resources, extended_resources=extended_resources)
+        return cls(resources=None, extended_resources=extended_resources)
 
 
 class TaskNode(_common.FlyteIdlEntity):
diff --git a/flytekit/models/task.py b/flytekit/models/task.py
index f7f1d710c9f..48a8abfde17 100644
--- a/flytekit/models/task.py
+++ b/flytekit/models/task.py
@@ -336,6 +336,7 @@ def __init__(
         config=None,
         k8s_pod=None,
         sql=None,
+        extended_resources=None,
     ):
         """
         A task template represents the full set of information necessary to perform a unit of work in the Flyte system.
@@ -359,6 +360,7 @@
             in tandem with the custom.
         :param K8sPod k8s_pod: Alternative to the container used to execute this task.
         :param Sql sql: This is used to execute query in FlytePropeller instead of running container or k8s_pod.
+        :param flyteidl.core.tasks_pb2.ExtendedResources extended_resources: The extended resources to allocate to the task.
         """
         if (
             (container is not None and k8s_pod is not None)
@@ -377,6 +379,7 @@
         self._security_context = security_context
         self._k8s_pod = k8s_pod
         self._sql = sql
+        self._extended_resources = extended_resources
 
     @property
     def id(self):
@@ -451,6 +454,14 @@ def k8s_pod(self):
     def sql(self):
         return self._sql
 
+    @property
+    def extended_resources(self):
+        """
+        If not None, the extended resources to allocate to the task.
+        :rtype: flyteidl.core.tasks_pb2.ExtendedResources
+        """
+        return self._extended_resources
+
     def to_flyte_idl(self):
         """
         :rtype: flyteidl.core.tasks_pb2.TaskTemplate
@@ -464,6 +475,7 @@
             container=self.container.to_flyte_idl() if self.container else None,
             task_type_version=self.task_type_version,
             security_context=self.security_context.to_flyte_idl() if self.security_context else None,
+            extended_resources=self.extended_resources,
             config={k: v for k, v in self.config.items()} if self.config is not None else None,
             k8s_pod=self.k8s_pod.to_flyte_idl() if self.k8s_pod else None,
             sql=self.sql.to_flyte_idl() if self.sql else None,
@@ -487,6 +499,7 @@
             security_context=_sec.SecurityContext.from_flyte_idl(pb2_object.security_context)
             if pb2_object.security_context and pb2_object.security_context.ByteSize() > 0
             else None,
+            extended_resources=pb2_object.extended_resources if pb2_object.HasField("extended_resources") else None,
             config={k: v for k, v in pb2_object.config.items()} if pb2_object.config is not None else None,
             k8s_pod=K8sPod.from_flyte_idl(pb2_object.k8s_pod) if pb2_object.HasField("k8s_pod") else None,
             sql=Sql.from_flyte_idl(pb2_object.sql) if pb2_object.HasField("sql") else None,
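
[Editor's note, not part of the diff] A small sketch of the model-layer behavior added above: the new extended_resources field on TaskNodeOverrides survives a to_flyte_idl / from_flyte_idl round trip, which is what the model tests further down in this diff assert.

```python
from flyteidl.core import tasks_pb2

from flytekit.extras.accelerators import T4
from flytekit.models.core import workflow as workflow_models

overrides = workflow_models.TaskNodeOverrides(
    resources=None,
    extended_resources=tasks_pb2.ExtendedResources(gpu_accelerator=T4.to_flyte_idl()),
)
roundtripped = workflow_models.TaskNodeOverrides.from_flyte_idl(overrides.to_flyte_idl())
assert roundtripped.extended_resources.gpu_accelerator.device == "nvidia-tesla-t4"
```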
diff --git a/flytekit/tools/translator.py b/flytekit/tools/translator.py
index 87ccd2f5346..cfe43544f36 100644
--- a/flytekit/tools/translator.py
+++ b/flytekit/tools/translator.py
@@ -214,6 +214,7 @@ def get_serializable_task(
         config=entity.get_config(settings),
         k8s_pod=pod,
         sql=entity.get_sql(settings),
+        extended_resources=entity.get_extended_resources(settings),
     )
     if settings.should_fast_serialize() and isinstance(entity, PythonAutoContainerTask):
         entity.reset_command_fn()
@@ -440,7 +441,8 @@ def get_serializable_node(
         upstream_node_ids=[n.id for n in upstream_nodes],
         output_aliases=[],
         task_node=workflow_model.TaskNode(
-            reference_id=task_spec.template.id, overrides=TaskNodeOverrides(resources=entity._resources)
+            reference_id=task_spec.template.id,
+            overrides=TaskNodeOverrides(resources=entity._resources, extended_resources=entity._extended_resources),
         ),
     )
     if entity._aliases:
@@ -516,7 +518,8 @@
         upstream_node_ids=[n.id for n in upstream_nodes],
         output_aliases=[],
         task_node=workflow_model.TaskNode(
-            reference_id=entity.flyte_entity.id, overrides=TaskNodeOverrides(resources=entity._resources)
+            reference_id=entity.flyte_entity.id,
+            overrides=TaskNodeOverrides(resources=entity._resources, extended_resources=entity._extended_resources),
         ),
     )
     elif isinstance(entity.flyte_entity, FlyteWorkflow):
@@ -565,7 +568,7 @@ def get_serializable_array_node(
     task_spec = get_serializable(entity_mapping, settings, entity, options)
     task_node = workflow_model.TaskNode(
         reference_id=task_spec.template.id,
-        overrides=TaskNodeOverrides(resources=node._resources),
+        overrides=TaskNodeOverrides(resources=node._resources, extended_resources=node._extended_resources),
     )
     node = workflow_model.Node(
         id=entity.name,
diff --git a/tests/flytekit/common/parameterizers.py b/tests/flytekit/common/parameterizers.py
index d5b07fe4202..96c30b69b4e 100644
--- a/tests/flytekit/common/parameterizers.py
+++ b/tests/flytekit/common/parameterizers.py
@@ -1,6 +1,9 @@
 from datetime import timedelta
 from itertools import product
 
+from flyteidl.core import tasks_pb2
+
+from flytekit.extras.accelerators import A100, T4
 from flytekit.models import interface, literals, security, task, types
 from flytekit.models.core import identifier
 from flytekit.models.core import types as _core_types
@@ -136,7 +139,6 @@
     )
 ]
 
-
 LIST_OF_TASK_TEMPLATES = [
     task.TaskTemplate(
         identifier.Identifier(identifier.ResourceType.TASK, "project", "domain", "name", "version"),
@@ -250,3 +252,19 @@
 LIST_OF_SECURITY_CONTEXT = [
     security.SecurityContext(run_as=r, secrets=s, tokens=None) for r in LIST_RUN_AS for s in LIST_OF_SECRETS
 ] + [None]
+
+LIST_OF_ACCELERATORS = [
+    None,
+    T4,
+    A100,
+    A100.unpartitioned,
+    A100.partition_1g_5gb,
+]
+
+LIST_OF_EXTENDED_RESOURCES = [
+    None,
+    *[
+        tasks_pb2.ExtendedResources(gpu_accelerator=None if accelerator is None else accelerator.to_flyte_idl())
+        for accelerator in LIST_OF_ACCELERATORS
+    ],
+]
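
[Editor's note, not part of the diff] With the translator change above, the accelerator declared on a task ends up on the serialized TaskTemplate. A minimal sketch, mirroring tests/flytekit/unit/extras/test_accelerators.py below; the settings values are placeholders.

```python
from collections import OrderedDict

from flytekit import task
from flytekit.configuration import Image, ImageConfig, SerializationSettings
from flytekit.extras.accelerators import A100
from flytekit.tools.translator import get_serializable

settings = SerializationSettings(
    project="proj",  # placeholder values
    domain="dev",
    version="v1",
    image_config=ImageConfig(Image(name="default", fqn="example/image", tag="v1")),
    env={},
)


@task(accelerator=A100.partition_7g_40gb)
def heavy() -> None:
    ...


spec = get_serializable(OrderedDict(), settings, heavy)
assert spec.template.extended_resources.gpu_accelerator.partition_size == "7g.40gb"
```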
diff --git a/tests/flytekit/unit/core/test_node_creation.py b/tests/flytekit/unit/core/test_node_creation.py
index 81621ef3fc4..cb790f3c2e1 100644
--- a/tests/flytekit/unit/core/test_node_creation.py
+++ b/tests/flytekit/unit/core/test_node_creation.py
@@ -13,6 +13,7 @@
 from flytekit.core.task import task
 from flytekit.core.workflow import workflow
 from flytekit.exceptions.user import FlyteAssertion
+from flytekit.extras.accelerators import A100, T4
 from flytekit.models import literals as _literal_models
 from flytekit.models.task import Resources as _resources_models
 from flytekit.tools.translator import get_serializable
@@ -465,3 +466,29 @@ def wf() -> str:
         return "hi"
 
     assert wf.nodes[0].flyte_entity.container_image == "hello/world"
+
+
+def test_override_accelerator():
+    @task(accelerator=T4)
+    def bar() -> str:
+        return "hello"
+
+    @workflow
+    def my_wf() -> str:
+        return bar().with_overrides(accelerator=A100.partition_1g_5gb)
+
+    serialization_settings = flytekit.configuration.SerializationSettings(
+        project="test_proj",
+        domain="test_domain",
+        version="abc",
+        image_config=ImageConfig(Image(name="name", fqn="image", tag="name")),
+        env={},
+    )
+    wf_spec = get_serializable(OrderedDict(), serialization_settings, my_wf)
+    assert len(wf_spec.template.nodes) == 1
+    assert wf_spec.template.nodes[0].task_node.overrides is not None
+    assert wf_spec.template.nodes[0].task_node.overrides.extended_resources is not None
+    accelerator = wf_spec.template.nodes[0].task_node.overrides.extended_resources.gpu_accelerator
+    assert accelerator.device == "nvidia-tesla-a100"
+    assert accelerator.partition_size == "1g.5gb"
+    assert not accelerator.HasField("unpartitioned")
diff --git a/tests/flytekit/unit/extras/test_accelerators.py b/tests/flytekit/unit/extras/test_accelerators.py
new file mode 100644
index 00000000000..a62dff7af79
--- /dev/null
+++ b/tests/flytekit/unit/extras/test_accelerators.py
@@ -0,0 +1,64 @@
+from collections import OrderedDict
+
+from flytekit.configuration import Image, ImageConfig, SerializationSettings
+from flytekit.core.task import task
+from flytekit.extras.accelerators import A100, T4
+from flytekit.tools.translator import get_serializable
+
+serialization_settings = SerializationSettings(
+    project="proj",
+    domain="dom",
+    version="123",
+    image_config=ImageConfig(Image(name="name", fqn="asdf/fdsa", tag="123")),
+    env={},
+)
+
+
+class TestAccelerators:
+    def test_gpu_accelerator(self):
+        @task(accelerator=T4)
+        def needs_t4(a: int):
+            pass
+
+        ts = get_serializable(OrderedDict(), serialization_settings, needs_t4).to_flyte_idl()
+        gpu_accelerator = ts.template.extended_resources.gpu_accelerator
+        assert gpu_accelerator is not None
+        assert gpu_accelerator.device == "nvidia-tesla-t4"
+        assert not gpu_accelerator.HasField("unpartitioned")
+        assert not gpu_accelerator.HasField("partition_size")
+
+    def test_mig(self):
+        @task(accelerator=A100)
+        def needs_a100(a: int):
+            pass
+
+        ts = get_serializable(OrderedDict(), serialization_settings, needs_a100).to_flyte_idl()
+        gpu_accelerator = ts.template.extended_resources.gpu_accelerator
+        assert gpu_accelerator is not None
+        assert gpu_accelerator.device == "nvidia-tesla-a100"
+        assert not gpu_accelerator.HasField("unpartitioned")
+        assert not gpu_accelerator.HasField("partition_size")
+
+    def test_mig_unpartitioned(self):
+        @task(accelerator=A100.unpartitioned)
+        def needs_unpartitioned_a100(a: int):
+            pass
+
+        ts = get_serializable(OrderedDict(), serialization_settings, needs_unpartitioned_a100).to_flyte_idl()
+        gpu_accelerator = ts.template.extended_resources.gpu_accelerator
+        assert gpu_accelerator is not None
+        assert gpu_accelerator.device == "nvidia-tesla-a100"
+        assert gpu_accelerator.unpartitioned
+        assert not gpu_accelerator.HasField("partition_size")
+
+    def test_mig_partitioned(self):
+        @task(accelerator=A100.partition_1g_5gb)
+        def needs_partitioned_a100(a: int):
+            pass
+
+        ts = get_serializable(OrderedDict(), serialization_settings, needs_partitioned_a100).to_flyte_idl()
+        gpu_accelerator = ts.template.extended_resources.gpu_accelerator
+        assert gpu_accelerator is not None
+        assert gpu_accelerator.device == "nvidia-tesla-a100"
+        assert gpu_accelerator.partition_size == "1g.5gb"
+        assert not gpu_accelerator.HasField("unpartitioned")
diff --git a/tests/flytekit/unit/models/core/test_workflow.py b/tests/flytekit/unit/models/core/test_workflow.py
index 6775d589403..cd36381ea01 100644
--- a/tests/flytekit/unit/models/core/test_workflow.py
+++ b/tests/flytekit/unit/models/core/test_workflow.py
@@ -1,5 +1,8 @@
 from datetime import timedelta
 
+from flyteidl.core import tasks_pb2
+
+from flytekit.extras.accelerators import T4
 from flytekit.models import interface as _interface
 from flytekit.models import literals as _literals
 from flytekit.models import types as _types
@@ -300,10 +303,12 @@ def test_task_node_overrides():
         Resources(
             requests=[Resources.ResourceEntry(Resources.ResourceName.CPU, "1")],
             limits=[Resources.ResourceEntry(Resources.ResourceName.CPU, "2")],
-        )
+        ),
+        tasks_pb2.ExtendedResources(gpu_accelerator=T4.to_flyte_idl()),
     )
     assert overrides.resources.requests == [Resources.ResourceEntry(Resources.ResourceName.CPU, "1")]
     assert overrides.resources.limits == [Resources.ResourceEntry(Resources.ResourceName.CPU, "2")]
+    assert overrides.extended_resources.gpu_accelerator == T4.to_flyte_idl()
 
     obj = _workflow.TaskNodeOverrides.from_flyte_idl(overrides.to_flyte_idl())
     assert overrides == obj
@@ -316,12 +321,14 @@ def test_task_node_with_overrides():
             Resources(
                 requests=[Resources.ResourceEntry(Resources.ResourceName.CPU, "1")],
                 limits=[Resources.ResourceEntry(Resources.ResourceName.CPU, "2")],
-            )
+            ),
+            tasks_pb2.ExtendedResources(gpu_accelerator=T4.to_flyte_idl()),
         ),
     )
 
     assert task_node.overrides.resources.requests == [Resources.ResourceEntry(Resources.ResourceName.CPU, "1")]
     assert task_node.overrides.resources.limits == [Resources.ResourceEntry(Resources.ResourceName.CPU, "2")]
+    assert task_node.overrides.extended_resources.gpu_accelerator == T4.to_flyte_idl()
 
     obj = _workflow.TaskNode.from_flyte_idl(task_node.to_flyte_idl())
     assert task_node == obj
diff --git a/tests/flytekit/unit/models/test_tasks.py b/tests/flytekit/unit/models/test_tasks.py
index a979a39b661..b4158c38521 100644
--- a/tests/flytekit/unit/models/test_tasks.py
+++ b/tests/flytekit/unit/models/test_tasks.py
@@ -2,12 +2,13 @@
 from itertools import product
 
 import pytest
-from flyteidl.core.tasks_pb2 import TaskMetadata
+from flyteidl.core.tasks_pb2 import ExtendedResources, TaskMetadata
 from google.protobuf import text_format
 
 import flytekit.models.interface as interface_models
 import flytekit.models.literals as literal_models
 from flytekit import Description, Documentation, SourceCode
+from flytekit.extras.accelerators import T4
 from flytekit.models import literals, task, types
 from flytekit.models.core import identifier
 from tests.flytekit.common import parameterizers
@@ -108,6 +109,7 @@ def test_task_template(in_tuple):
             {"d": "e"},
         ),
         config={"a": "b"},
+        extended_resources=ExtendedResources(gpu_accelerator=T4.to_flyte_idl()),
     )
     assert obj.id.resource_type == identifier.ResourceType.TASK
     assert obj.id.project == "project"
@@ -124,6 +126,9 @@ def test_task_template(in_tuple):
         task.TaskTemplate.from_flyte_idl(obj.to_flyte_idl()).to_flyte_idl()
     )
     assert obj.config == {"a": "b"}
+    assert obj.extended_resources.gpu_accelerator.device == "nvidia-tesla-t4"
+    assert not obj.extended_resources.gpu_accelerator.HasField("unpartitioned")
+    assert not obj.extended_resources.gpu_accelerator.HasField("partition_size")
 
 
 def test_task_spec():
@@ -166,6 +171,7 @@
             {"d": "e"},
         ),
         config={"a": "b"},
+        extended_resources=ExtendedResources(gpu_accelerator=T4.to_flyte_idl()),
     )
 
     short_description = "short"
@@ -212,6 +218,7 @@ def test_task_template_k8s_pod_target():
             metadata=task.K8sObjectMetadata(labels={"label": "foo"}, annotations={"anno": "bar"}),
             pod_spec={"str": "val", "int": 1},
         ),
+        extended_resources=ExtendedResources(gpu_accelerator=T4.to_flyte_idl()),
     )
     assert obj.id.resource_type == identifier.ResourceType.TASK
     assert obj.id.project == "project"
@@ -226,6 +233,9 @@
         task.TaskTemplate.from_flyte_idl(obj.to_flyte_idl()).to_flyte_idl()
     )
     assert obj.config == {"a": "b"}
+    assert obj.extended_resources.gpu_accelerator.device == "nvidia-tesla-t4"
+    assert not obj.extended_resources.gpu_accelerator.HasField("unpartitioned")
+    assert not obj.extended_resources.gpu_accelerator.HasField("partition_size")
 
 
 @pytest.mark.parametrize("sec_ctx", parameterizers.LIST_OF_SECURITY_CONTEXT)
@@ -254,6 +264,28 @@ def test_task_template_security_context(sec_ctx):
     assert task.TaskTemplate.from_flyte_idl(obj.to_flyte_idl()).security_context == expected
 
 
+@pytest.mark.parametrize("extended_resources", parameterizers.LIST_OF_EXTENDED_RESOURCES)
+def test_task_template_extended_resources(extended_resources):
+    obj = task.TaskTemplate(
+        identifier.Identifier(identifier.ResourceType.TASK, "project", "domain", "name", "version"),
+        "python",
+        parameterizers.LIST_OF_TASK_METADATA[0],
+        parameterizers.LIST_OF_INTERFACES[0],
+        {"a": 1, "b": {"c": 2, "d": 3}},
+        container=task.Container(
+            "my_image",
+            ["this", "is", "a", "cmd"],
+            ["this", "is", "an", "arg"],
+            parameterizers.LIST_OF_RESOURCES[0],
+            {"a": "b"},
+            {"d": "e"},
+        ),
+        extended_resources=extended_resources,
+    )
+    assert obj.extended_resources == extended_resources
+    assert task.TaskTemplate.from_flyte_idl(obj.to_flyte_idl()).extended_resources == extended_resources
+
+
 @pytest.mark.parametrize("task_closure", parameterizers.LIST_OF_TASK_CLOSURES)
 def test_task(task_closure):
     obj = task.Task(
diff --git a/tests/flytekit/unit/models/test_workflow_closure.py b/tests/flytekit/unit/models/test_workflow_closure.py
index 2b5b06696b6..64e5a577139 100644
--- a/tests/flytekit/unit/models/test_workflow_closure.py
+++ b/tests/flytekit/unit/models/test_workflow_closure.py
@@ -1,5 +1,8 @@
 from datetime import timedelta
 
+from flyteidl.core import tasks_pb2
+
+from flytekit.extras.accelerators import T4
 from flytekit.models import interface as _interface
 from flytekit.models import literals as _literals
 from flytekit.models import task as _task
@@ -58,6 +61,7 @@ def test_workflow_closure():
             {},
             {},
         ),
+        extended_resources=tasks_pb2.ExtendedResources(gpu_accelerator=T4.to_flyte_idl()),
     )
 
     task_node = _workflow.TaskNode(task.id)