diff --git a/flytekit/__init__.py b/flytekit/__init__.py index bba4947b38..5f26b27664 100644 --- a/flytekit/__init__.py +++ b/flytekit/__init__.py @@ -246,6 +246,8 @@ StructuredDatasetType, ) +from flytekit.extend.backend.task_executor import SyncAgentBase # isort:skip. This is for circular import avoidance. + def current_context() -> ExecutionParameters: """ diff --git a/flytekit/core/external_api_task.py b/flytekit/core/external_api_task.py new file mode 100644 index 0000000000..3d8942f346 --- /dev/null +++ b/flytekit/core/external_api_task.py @@ -0,0 +1,71 @@ +import collections +import inspect +from abc import abstractmethod +from typing import Any, Dict, Optional, TypeVar + +from flyteidl.admin.agent_pb2 import CreateTaskResponse +from typing_extensions import get_type_hints + +from flytekit.configuration import SerializationSettings +from flytekit.core.base_task import PythonTask +from flytekit.core.interface import Interface +from flytekit.extend.backend.base_agent import AsyncAgentExecutorMixin + +T = TypeVar("T") +TASK_MODULE = "task_module" +TASK_NAME = "task_name" +TASK_CONFIG = "task_config" +TASK_TYPE = "api_task" + + +class ExternalApiTask(AsyncAgentExecutorMixin, PythonTask): + """ + Base class for all external API tasks. External API tasks are tasks that are designed to run until they receive a + response from an external service. When the response is received, the task will complete. External API tasks are + designed to be run by the flyte agent. 
+ """ + + def __init__( + self, + name: str, + config: Optional[T] = None, + task_type: str = TASK_TYPE, + return_type: Optional[Any] = None, + **kwargs, + ): + type_hints = get_type_hints(self.do, include_extras=True) + signature = inspect.signature(self.do) + inputs = collections.OrderedDict() + outputs = collections.OrderedDict({"o0": return_type}) if return_type else collections.OrderedDict() + + for k, _ in signature.parameters.items(): # type: ignore + annotation = type_hints.get(k, None) + inputs[k] = annotation + + super().__init__( + task_type=task_type, + name=name, + task_config=config, + interface=Interface(inputs=inputs, outputs=outputs), + **kwargs, + ) + + self._task_config = config + + @abstractmethod + async def do(self, **kwargs) -> CreateTaskResponse: + """ + Initiate an HTTP request to an external service such as OpenAI or Vertex AI and retrieve the response. + """ + raise NotImplementedError + + def get_custom(self, settings: SerializationSettings = None) -> Dict[str, Any]: + cfg = { + TASK_MODULE: type(self).__module__, + TASK_NAME: type(self).__name__, + } + + if self._task_config is not None: + cfg[TASK_CONFIG] = self._task_config + + return cfg diff --git a/flytekit/extend/backend/agent_service.py b/flytekit/extend/backend/agent_service.py index 73737f3a6c..7b17e6e2a9 100644 --- a/flytekit/extend/backend/agent_service.py +++ b/flytekit/extend/backend/agent_service.py @@ -44,7 +44,11 @@ def agent_exception_handler(func): async def wrapper( self, - request: typing.Union[CreateTaskRequest, GetTaskRequest, DeleteTaskRequest], + request: typing.Union[ + CreateTaskRequest, + GetTaskRequest, + DeleteTaskRequest, + ], context: grpc.ServicerContext, *args, **kwargs, @@ -92,7 +96,6 @@ async def CreateTask(self, request: CreateTaskRequest, context: grpc.ServicerCon tmp = TaskTemplate.from_flyte_idl(request.template) inputs = LiteralMap.from_flyte_idl(request.inputs) if request.inputs else None agent = AgentRegistry.get_agent(tmp.type) - 
logger.info(f"{tmp.type} agent start creating the job") if agent.asynchronous: return await agent.async_create( diff --git a/flytekit/extend/backend/base_agent.py b/flytekit/extend/backend/base_agent.py index 2af8bfb29f..601bc3a1e3 100644 --- a/flytekit/extend/backend/base_agent.py +++ b/flytekit/extend/backend/base_agent.py @@ -6,7 +6,7 @@ from abc import ABC from collections import OrderedDict from functools import partial -from types import FrameType +from types import FrameType, coroutine import grpc from flyteidl.admin.agent_pb2 import ( @@ -138,7 +138,8 @@ def convert_to_flyte_state(state: str) -> State: Convert the state from the agent to the state in flyte. """ state = state.lower() - if state in ["failed", "timedout", "canceled"]: + # timedout is the state of Databricks job. https://docs.databricks.com/en/workflows/jobs/jobs-2.0-api.html#runresultstate + if state in ["failed", "timeout", "timedout", "canceled"]: return RETRYABLE_FAILURE elif state in ["done", "succeeded", "success"]: return SUCCEEDED @@ -158,15 +159,27 @@ def get_agent_secret(secret_key: str) -> str: return flytekit.current_context().secrets.get(secret_key) +def _get_grpc_context() -> grpc.ServicerContext: + from unittest.mock import MagicMock + + grpc_ctx = MagicMock(spec=grpc.ServicerContext) + return grpc_ctx + + class AsyncAgentExecutorMixin: """ This mixin class is used to run the agent task locally, and it's only used for local execution. Task should inherit from this class if the task can be run in the agent. + It can handle asynchronous tasks and synchronous tasks. + Asynchronous tasks are long-running tasks, for example a query job. + Synchronous tasks are quick tasks, for example executing something really fast or retrieving some metadata from a backend service. 
""" - _is_canceled = None - _agent = None - _entity = None + _clean_up_task: coroutine = None + _agent: AgentBase = None + _entity: PythonTask = None + _ctx: FlyteContext = FlyteContext.current_context() + _grpc_ctx: grpc.ServicerContext = _get_grpc_context() def execute(self, **kwargs) -> typing.Any: ctx = FlyteContext.current_context() @@ -180,6 +193,13 @@ def execute(self, **kwargs) -> typing.Any: self._agent = AgentRegistry.get_agent(task_template.type) res = asyncio.run(self._create(task_template, output_prefix, kwargs)) + + # If the task is synchronous, the agent will return the output from the resource literals. + if res.HasField("resource"): + if res.resource.state != SUCCEEDED: + raise FlyteUserException(f"Failed to run the task {self._entity.name}") + return LiteralMap.from_flyte_idl(res.resource.outputs) + res = asyncio.run(self._get(resource_meta=res.resource_meta)) if res.resource.state != SUCCEEDED: @@ -198,7 +218,6 @@ async def _create( self, task_template: TaskTemplate, output_prefix: str, inputs: typing.Dict[str, typing.Any] = None ) -> CreateTaskResponse: ctx = FlyteContext.current_context() - grpc_ctx = _get_grpc_context() # Convert python inputs to literals literals = inputs or {} @@ -213,9 +232,9 @@ async def _create( task_template = render_task_template(task_template, output_prefix) if self._agent.asynchronous: - res = await self._agent.async_create(grpc_ctx, output_prefix, task_template, inputs) + res = await self._agent.async_create(self._grpc_ctx, output_prefix, task_template, literal_map) else: - res = self._agent.create(grpc_ctx, output_prefix, task_template, inputs) + res = self._agent.create(self._grpc_ctx, output_prefix, task_template, literal_map) signal.signal(signal.SIGINT, partial(self.signal_handler, res.resource_meta)) # type: ignore return res @@ -232,8 +251,8 @@ async def _get(self, resource_meta: bytes) -> GetTaskResponse: time.sleep(1) if self._agent.asynchronous: res = await self._agent.async_get(grpc_ctx, resource_meta) - 
if self._is_canceled: - await self._is_canceled + if self._clean_up_task: + await self._clean_up_task sys.exit(1) else: res = self._agent.get(grpc_ctx, resource_meta) @@ -242,12 +261,11 @@ async def _get(self, resource_meta: bytes) -> GetTaskResponse: return res def signal_handler(self, resource_meta: bytes, signum: int, frame: FrameType) -> typing.Any: - grpc_ctx = _get_grpc_context() if self._agent.asynchronous: - if self._is_canceled is None: - self._is_canceled = asyncio.create_task(self._agent.async_delete(grpc_ctx, resource_meta)) + if self._clean_up_task is None: + self._clean_up_task = asyncio.create_task(self._agent.async_delete(self._grpc_ctx, resource_meta)) else: - self._agent.delete(grpc_ctx, resource_meta) + self._agent.delete(self._grpc_ctx, resource_meta) sys.exit(1) diff --git a/flytekit/extend/backend/task_executor.py b/flytekit/extend/backend/task_executor.py new file mode 100644 index 0000000000..22b6675cca --- /dev/null +++ b/flytekit/extend/backend/task_executor.py @@ -0,0 +1,71 @@ +import importlib +import typing +from dataclasses import dataclass +from typing import final + +import grpc +from flyteidl.admin.agent_pb2 import CreateTaskResponse + +from flytekit import FlyteContextManager +from flytekit.core.external_api_task import TASK_CONFIG, TASK_MODULE, TASK_NAME, TASK_TYPE +from flytekit.core.type_engine import TypeEngine +from flytekit.extend.backend.base_agent import AgentBase, AgentRegistry +from flytekit.models.literals import LiteralMap +from flytekit.models.task import TaskTemplate + +T = typing.TypeVar("T") + + +@dataclass +class IOContext: + inputs: LiteralMap + output_prefix: str + + +class SyncAgentBase(AgentBase): + """ + SyncAgentBase is an agent responsible for synchronous tasks, which are fast and quick. + + This class is meant to be subclassed when implementing plugins that require + an external API to perform the task execution. 
It provides a routing mechanism + to direct the task to the appropriate handler based on the task's specifications. + """ + + def __init__(self): + super().__init__(task_type=TASK_TYPE, asynchronous=True) + + @final + async def async_create( + self, + context: grpc.ServicerContext, + output_prefix: str, + task_template: TaskTemplate, + inputs: typing.Optional[LiteralMap] = None, + ) -> CreateTaskResponse: + return await self.do(context, output_prefix, task_template, inputs) + + async def do( + self, + context: grpc.ServicerContext, + output_prefix: str, + task_template: TaskTemplate, + inputs: typing.Optional[LiteralMap] = None, + ) -> CreateTaskResponse: + python_interface_inputs = { + name: TypeEngine.guess_python_type(lt.type) for name, lt in task_template.interface.inputs.items() + } + ctx = FlyteContextManager.current_context() + + native_inputs = {} + if inputs: + native_inputs = TypeEngine.literal_map_to_kwargs(ctx, inputs, python_interface_inputs) + + meta = task_template.custom + + task_module = importlib.import_module(name=meta[TASK_MODULE]) + task_def = getattr(task_module, meta[TASK_NAME]) + config = meta[TASK_CONFIG] if meta.get(TASK_CONFIG) else None + return await task_def(TASK_TYPE, config=config).do(**native_inputs) + + +AgentRegistry.register(SyncAgentBase()) diff --git a/flytekit/models/literals.py b/flytekit/models/literals.py index c39f8dea37..7ae03d37a6 100644 --- a/flytekit/models/literals.py +++ b/flytekit/models/literals.py @@ -8,11 +8,10 @@ from flytekit.exceptions import user as _user_exceptions from flytekit.models import common as _common from flytekit.models.core import types as _core_types -from flytekit.models.types import Error +from flytekit.models.types import Error, StructuredDatasetType from flytekit.models.types import LiteralType as _LiteralType from flytekit.models.types import OutputReference as _OutputReference from flytekit.models.types import SchemaType as _SchemaType -from flytekit.models.types import 
StructuredDatasetType class RetryStrategy(_common.FlyteIdlEntity): diff --git a/flytekit/sensor/base_sensor.py b/flytekit/sensor/base_sensor.py index 60beb6aa2b..fed5f6493b 100644 --- a/flytekit/sensor/base_sensor.py +++ b/flytekit/sensor/base_sensor.py @@ -15,6 +15,7 @@ SENSOR_MODULE = "sensor_module" SENSOR_NAME = "sensor_name" SENSOR_CONFIG_PKL = "sensor_config_pkl" +SENSOR_TYPE = "sensor" INPUTS = "inputs" @@ -35,7 +36,7 @@ def __init__( type_hints = get_type_hints(self.poke, include_extras=True) signature = inspect.signature(self.poke) inputs = collections.OrderedDict() - for k, v in signature.parameters.items(): # type: ignore + for k, _ in signature.parameters.items(): # type: ignore annotation = type_hints.get(k, None) inputs[k] = annotation diff --git a/flytekit/sensor/sensor_engine.py b/flytekit/sensor/sensor_engine.py index 79d2e0f4b4..3f7b1b7a69 100644 --- a/flytekit/sensor/sensor_engine.py +++ b/flytekit/sensor/sensor_engine.py @@ -19,7 +19,7 @@ from flytekit.extend.backend.base_agent import AgentBase, AgentRegistry from flytekit.models.literals import LiteralMap from flytekit.models.task import TaskTemplate -from flytekit.sensor.base_sensor import INPUTS, SENSOR_CONFIG_PKL, SENSOR_MODULE, SENSOR_NAME +from flytekit.sensor.base_sensor import INPUTS, SENSOR_CONFIG_PKL, SENSOR_MODULE, SENSOR_NAME, SENSOR_TYPE T = typing.TypeVar("T") @@ -39,9 +39,11 @@ async def async_create( name: TypeEngine.guess_python_type(lt.type) for name, lt in task_template.interface.inputs.items() } ctx = FlyteContextManager.current_context() + if inputs: native_inputs = TypeEngine.literal_map_to_kwargs(ctx, inputs, python_interface_inputs) task_template.custom[INPUTS] = native_inputs + return CreateTaskResponse(resource_meta=cloudpickle.dumps(task_template.custom)) async def async_get(self, context: grpc.ServicerContext, resource_meta: bytes) -> GetTaskResponse: @@ -52,7 +54,7 @@ async def async_get(self, context: grpc.ServicerContext, resource_meta: bytes) - sensor_config = 
jsonpickle.decode(meta[SENSOR_CONFIG_PKL]) if meta.get(SENSOR_CONFIG_PKL) else None inputs = meta.get(INPUTS, {}) - cur_state = SUCCEEDED if await sensor_def("sensor", config=sensor_config).poke(**inputs) else RUNNING + cur_state = SUCCEEDED if await sensor_def(SENSOR_TYPE, config=sensor_config).poke(**inputs) else RUNNING return GetTaskResponse(resource=Resource(state=cur_state, outputs=None)) async def async_delete(self, context: grpc.ServicerContext, resource_meta: bytes) -> DeleteTaskResponse: diff --git a/plugins/flytekit-aws-athena/flytekitplugins/athena/task.py b/plugins/flytekit-aws-athena/flytekitplugins/athena/task.py index 1ae47339b3..2efca76e99 100644 --- a/plugins/flytekit-aws-athena/flytekitplugins/athena/task.py +++ b/plugins/flytekit-aws-athena/flytekitplugins/athena/task.py @@ -5,6 +5,7 @@ from flytekit.configuration import SerializationSettings from flytekit.extend import SQLTask +from flytekit.extend.backend.base_agent import ASYNC_PLUGIN from flytekit.models.presto import PrestoQuery from flytekit.types.schema import FlyteSchema @@ -65,6 +66,7 @@ def __init__( inputs=inputs, outputs=outputs, task_type=self._TASK_TYPE, + runtime_flavor=ASYNC_PLUGIN, **kwargs, ) self._output_schema_type = output_schema_type diff --git a/plugins/flytekit-aws-batch/flytekitplugins/awsbatch/task.py b/plugins/flytekit-aws-batch/flytekitplugins/awsbatch/task.py index e0326f112b..d5568ab041 100644 --- a/plugins/flytekit-aws-batch/flytekitplugins/awsbatch/task.py +++ b/plugins/flytekit-aws-batch/flytekitplugins/awsbatch/task.py @@ -41,7 +41,10 @@ def __init__(self, task_config: AWSBatchConfig, task_function: Callable, **kwarg if task_config is None: task_config = AWSBatchConfig() super(AWSBatchFunctionTask, self).__init__( - task_config=task_config, task_type=self._AWS_BATCH_TASK_TYPE, task_function=task_function, **kwargs + task_config=task_config, + task_type=self._AWS_BATCH_TASK_TYPE, + task_function=task_function, + **kwargs, ) self._task_config = task_config diff 
--git a/plugins/flytekit-openai-chatgpt/dev-requirements.in b/plugins/flytekit-openai-chatgpt/dev-requirements.in new file mode 100644 index 0000000000..2d73dba5b4 --- /dev/null +++ b/plugins/flytekit-openai-chatgpt/dev-requirements.in @@ -0,0 +1 @@ +pytest-asyncio diff --git a/plugins/flytekit-openai-chatgpt/dev-requirements.txt b/plugins/flytekit-openai-chatgpt/dev-requirements.txt new file mode 100644 index 0000000000..1c37cda90d --- /dev/null +++ b/plugins/flytekit-openai-chatgpt/dev-requirements.txt @@ -0,0 +1,20 @@ +# +# This file is autogenerated by pip-compile with Python 3.9 +# by the following command: +# +# pip-compile dev-requirements.in +# +exceptiongroup==1.1.3 + # via pytest +iniconfig==2.0.0 + # via pytest +packaging==23.2 + # via pytest +pluggy==1.3.0 + # via pytest +pytest==7.4.2 + # via pytest-asyncio +pytest-asyncio==0.21.1 + # via -r dev-requirements.in +tomli==2.0.1 + # via pytest diff --git a/plugins/flytekit-openai-chatgpt/flytekitplugins/chatgpt/__init__.py b/plugins/flytekit-openai-chatgpt/flytekitplugins/chatgpt/__init__.py new file mode 100644 index 0000000000..7a47fd2ffb --- /dev/null +++ b/plugins/flytekit-openai-chatgpt/flytekitplugins/chatgpt/__init__.py @@ -0,0 +1,13 @@ +""" +.. currentmodule:: flytekitplugins.chatgpt + +This package contains things that are useful when extending Flytekit. + +.. 
autosummary:: + :template: custom.rst + :toctree: generated/ + + ChatGPTTask +""" + +from .task import ChatGPTTask diff --git a/plugins/flytekit-openai-chatgpt/flytekitplugins/chatgpt/task.py b/plugins/flytekit-openai-chatgpt/flytekitplugins/chatgpt/task.py new file mode 100644 index 0000000000..5b606bb824 --- /dev/null +++ b/plugins/flytekit-openai-chatgpt/flytekitplugins/chatgpt/task.py @@ -0,0 +1,66 @@ +import asyncio +from typing import Any, Dict + +import openai +from flyteidl.admin.agent_pb2 import SUCCEEDED, CreateTaskResponse, Resource + +from flytekit import FlyteContextManager +from flytekit.core.external_api_task import ExternalApiTask +from flytekit.core.type_engine import TypeEngine +from flytekit.extend.backend.base_agent import get_agent_secret +from flytekit.models.literals import LiteralMap + +TIMEOUT_SECONDS = 10 + + +class ChatGPTTask(ExternalApiTask): + """ + This is the simplest form of a ChatGPTTask Task, you can define the model and the input you want. + + Args: + openai_organization: OpenAI Organization. String can be found here. https://platform.openai.com/docs/api-reference/organization-optional + chatgpt_conf: ChatGPT job configuration. Config structure can be found here. 
https://platform.openai.com/docs/api-reference/completions/create + """ + + _openai_organization: str = None + _chatgpt_conf: Dict[str, Any] = None + + def __init__(self, name: str, config: Dict[str, Any], **kwargs): + if "openai_organization" not in config: + raise ValueError("The 'openai_organization' configuration variable is required") + + if "chatgpt_conf" not in config: + raise ValueError("The 'chatgpt_conf' configuration variable is required") + + if "model" not in config["chatgpt_conf"]: + raise ValueError("The 'model' configuration variable in 'chatgpt_conf' is required") + + self._openai_organization = config["openai_organization"] + self._chatgpt_conf = config["chatgpt_conf"] + + super().__init__(name=name, config=config, return_type=str, **kwargs) + + async def do( + self, + message: str = None, + ) -> CreateTaskResponse: + openai.organization = self._openai_organization + openai.api_key = get_agent_secret(secret_key="FLYTE_OPENAI_ACCESS_TOKEN") + + self._chatgpt_conf["messages"] = [{"role": "user", "content": message}] + + completion = await asyncio.wait_for(openai.ChatCompletion.acreate(**self._chatgpt_conf), TIMEOUT_SECONDS) + message = completion.choices[0].message.content + + ctx = FlyteContextManager.current_context() + outputs = LiteralMap( + { + "o0": TypeEngine.to_literal( + ctx, + message, + type(message), + TypeEngine.to_literal_type(type(message)), + ) + } + ).to_flyte_idl() + return CreateTaskResponse(resource=Resource(state=SUCCEEDED, outputs=outputs)) diff --git a/plugins/flytekit-openai-chatgpt/setup.py b/plugins/flytekit-openai-chatgpt/setup.py new file mode 100644 index 0000000000..ba27a45a4b --- /dev/null +++ b/plugins/flytekit-openai-chatgpt/setup.py @@ -0,0 +1,37 @@ +from setuptools import setup + +PLUGIN_NAME = "chatgpt" + +microlib_name = f"flytekitplugins-{PLUGIN_NAME}" + +plugin_requires = ["flytekit>=1.10.0", "openai>=0.28.1", "flyteidl>=1.10.0"] + +__version__ = "0.0.0+develop" + +setup( + name=microlib_name, + 
version=__version__, + author="flyteorg", + author_email="admin@flyte.org", + description="This package holds the ChatGPT plugins for flytekit", + namespace_packages=["flytekitplugins"], + packages=[f"flytekitplugins.{PLUGIN_NAME}"], + install_requires=plugin_requires, + license="apache2", + python_requires=">=3.8", + classifiers=[ + "Intended Audience :: Science/Research", + "Intended Audience :: Developers", + "License :: OSI Approved :: Apache Software License", + "Programming Language :: Python :: 3.8", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Topic :: Scientific/Engineering", + "Topic :: Scientific/Engineering :: Artificial Intelligence", + "Topic :: Software Development", + "Topic :: Software Development :: Libraries", + "Topic :: Software Development :: Libraries :: Python Modules", + ], + entry_points={"flytekit.plugins": [f"{PLUGIN_NAME}=flytekitplugins.{PLUGIN_NAME}"]}, +) diff --git a/plugins/flytekit-openai-chatgpt/tests/__init__.py b/plugins/flytekit-openai-chatgpt/tests/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/plugins/flytekit-openai-chatgpt/tests/test_chatgpt_task.py b/plugins/flytekit-openai-chatgpt/tests/test_chatgpt_task.py new file mode 100644 index 0000000000..c7ba185b79 --- /dev/null +++ b/plugins/flytekit-openai-chatgpt/tests/test_chatgpt_task.py @@ -0,0 +1,38 @@ +from unittest import mock + +import pytest +from flyteidl.admin.agent_pb2 import SUCCEEDED +from flytekitplugins.chatgpt import ChatGPTTask + + +async def mock_acreate(*args, **kwargs) -> str: + mock_response = mock.MagicMock() + mock_choice = mock.MagicMock() + mock_choice.message.content = "mocked_message" + mock_response.choices = [mock_choice] + return mock_response + + +@pytest.mark.asyncio +async def test_chatgpt_task_do(): + message = "TEST MESSAGE" + organization = "TEST ORGANIZATION" + + chatgpt_job = ChatGPTTask( + name="chatgpt", + config={ + 
"openai_organization": organization, + "chatgpt_conf": { + "model": "gpt-3.5-turbo", + "messages": [{"role": "user", "content": message}], + "temperature": 0.7, + }, + }, + ) + + with mock.patch("openai.ChatCompletion.acreate", new=mock_acreate): + with mock.patch("flytekit.extend.backend.base_agent.get_agent_secret", return_value="mocked_secret"): + response = await chatgpt_job.do(message=message) + + assert response.resource.state == SUCCEEDED + assert "mocked_message" in str(response.resource.outputs) diff --git a/tests/flytekit/unit/core/test_external_api_task.py b/tests/flytekit/unit/core/test_external_api_task.py new file mode 100644 index 0000000000..cbe0ffa181 --- /dev/null +++ b/tests/flytekit/unit/core/test_external_api_task.py @@ -0,0 +1,40 @@ +import collections +import json + +import pytest + +from flytekit.core.external_api_task import TASK_CONFIG, TASK_MODULE, TASK_NAME, ExternalApiTask +from flytekit.core.interface import Interface, transform_interface_to_typed_interface + + +class MockExternalApiTask(ExternalApiTask): + async def do(self, test_int_input: int, **kwargs) -> int: + return test_int_input + + +def test_init(): + task = MockExternalApiTask(name="test_task", return_type=int) + assert task.name == "test_task" + + interface = Interface( + inputs=collections.OrderedDict({"test_int_input": int, "kwargs": None}), + outputs=collections.OrderedDict({"o0": int}), + ) + assert task.interface == transform_interface_to_typed_interface(interface) + + +@pytest.mark.asyncio +async def test_do(): + input_num = 100 + task = MockExternalApiTask(name="test_task", return_type=int) + assert input_num == await task.do(test_int_input=input_num) + + +def test_get_custom(): + task = MockExternalApiTask(name="test_task", config={"key": "value"}) + custom = task.get_custom() + + expected_config = json.loads('{"key": "value"}') + assert custom[TASK_MODULE] == MockExternalApiTask.__module__ + assert custom[TASK_NAME] == MockExternalApiTask.__name__ + assert 
json.loads(custom[TASK_CONFIG]) == expected_config diff --git a/tests/flytekit/unit/core/test_task_metadata.py b/tests/flytekit/unit/core/test_task_metadata.py new file mode 100644 index 0000000000..a158a3ac31 --- /dev/null +++ b/tests/flytekit/unit/core/test_task_metadata.py @@ -0,0 +1,59 @@ +import datetime + +import pytest + +from flytekit import __version__ +from flytekit.core.base_task import TaskMetadata +from flytekit.models import literals as _literal_models +from flytekit.models import task as _task_model + + +def test_post_init_conditions(): + with pytest.raises(ValueError, match="Caching is enabled ``cache=True`` but ``cache_version`` is not set."): + TaskMetadata(cache=True, cache_version="") + + with pytest.raises( + ValueError, match="Cache serialize is enabled ``cache_serialize=True`` but ``cache`` is not enabled." + ): + TaskMetadata(cache=False, cache_serialize=True) + + with pytest.raises( + ValueError, match="timeout should be duration represented as either a datetime.timedelta or int seconds" + ): + TaskMetadata(timeout="invalid_timeout") + + tm = TaskMetadata(timeout=3600) + assert isinstance(tm.timeout, datetime.timedelta) + + +def test_retry_strategy(): + tm = TaskMetadata(retries=5) + assert tm.retry_strategy.retries == 5 + + +def test_to_task_metadata_model(): + tm = TaskMetadata( + cache=True, + cache_serialize=True, + cache_version="v1", + interruptible=True, + deprecated="TEST DEPRECATED ERROR MESSAGE", + retries=3, + timeout=3600, + pod_template_name="TEST POD TEMPLATE NAME", + ) + model = tm.to_taskmetadata_model() + + assert model.discoverable is True + assert model.runtime == _task_model.RuntimeMetadata( + _task_model.RuntimeMetadata.RuntimeType.FLYTE_SDK, + __version__, + "python", + ) + assert model.retries == _literal_models.RetryStrategy(3) + assert model.timeout == datetime.timedelta(seconds=3600) + assert model.interruptible is True + assert model.discovery_version == "v1" + assert model.deprecated_error_message == "TEST 
DEPRECATED ERROR MESSAGE" + assert model.cache_serializable is True + assert model.pod_template_name == "TEST POD TEMPLATE NAME" diff --git a/tests/flytekit/unit/extend/test_agent.py b/tests/flytekit/unit/extend/test_agent.py index c9a56ea384..75719d619d 100644 --- a/tests/flytekit/unit/extend/test_agent.py +++ b/tests/flytekit/unit/extend/test_agent.py @@ -16,6 +16,8 @@ CreateTaskResponse, DeleteTaskRequest, DeleteTaskResponse, + DoTaskRequest, + DoTaskResponse, GetTaskRequest, GetTaskResponse, Resource, @@ -23,7 +25,7 @@ from flytekit import PythonFunctionTask, task from flytekit.configuration import FastSerializationSettings, Image, ImageConfig, SerializationSettings -from flytekit.extend.backend.agent_service import AsyncAgentService +from flytekit.extend.backend.agent_service import AsyncAgentService, SyncAgentService from flytekit.extend.backend.base_agent import ( AgentBase, AgentRegistry, @@ -47,24 +49,27 @@ class Metadata: job_id: str -class DummyAgent(AgentBase): +class SyncDummyAgent(AgentBase): def __init__(self): - super().__init__(task_type="dummy", asynchronous=False) + super().__init__(task_type="sync_dummy", asynchronous=True) - def create( + async def async_do( self, context: grpc.ServicerContext, output_prefix: str, task_template: TaskTemplate, inputs: typing.Optional[LiteralMap] = None, - ) -> CreateTaskResponse: - return CreateTaskResponse(resource_meta=json.dumps(asdict(Metadata(job_id=dummy_id))).encode("utf-8")) - - def get(self, context: grpc.ServicerContext, resource_meta: bytes) -> GetTaskResponse: - return GetTaskResponse(resource=Resource(state=SUCCEEDED)) + ) -> DoTaskResponse: + return DoTaskResponse(resource=Resource(state=SUCCEEDED)) - def delete(self, context: grpc.ServicerContext, resource_meta: bytes) -> DeleteTaskResponse: - return DeleteTaskResponse() + def do( + self, + context: grpc.ServicerContext, + output_prefix: str, + task_template: TaskTemplate, + inputs: typing.Optional[LiteralMap] = None, + ) -> DoTaskResponse: + 
return DoTaskResponse(resource=Resource(state=SUCCEEDED)) class AsyncDummyAgent(AgentBase): @@ -113,27 +118,38 @@ def simple_task(i: int): ) -dummy_template = get_task_template("dummy") async_dummy_template = get_task_template("async_dummy") +sync_dummy_template = get_task_template("sync_dummy", True) def test_dummy_agent(): - AgentRegistry.register(DummyAgent()) ctx = MagicMock(spec=grpc.ServicerContext) - agent = AgentRegistry.get_agent("dummy") + async_agent = AgentRegistry.get_agent("async_dummy") + sync_agent = AgentRegistry.get_agent("sync_dummy") metadata_bytes = json.dumps(asdict(Metadata(job_id=dummy_id))).encode("utf-8") - assert agent.create(ctx, "/tmp", dummy_template, task_inputs).resource_meta == metadata_bytes - assert agent.get(ctx, metadata_bytes).resource.state == SUCCEEDED - assert agent.delete(ctx, metadata_bytes) == DeleteTaskResponse() + assert async_agent.create(ctx, "/tmp", async_dummy_template, task_inputs).resource_meta == metadata_bytes + assert async_agent.get(ctx, metadata_bytes).resource.state == SUCCEEDED + assert async_agent.delete(ctx, metadata_bytes) == DeleteTaskResponse() + assert sync_agent.do(ctx, sync_dummy_template, task_inputs) == DoTaskResponse(resource=Resource(state=SUCCEEDED)) + + class AsyncDummyTask(AsyncAgentExecutorMixin, PythonFunctionTask): + def __init__(self, **kwargs): + super().__init__( + task_type="async_dummy", + **kwargs, + ) + + t = AsyncDummyTask(task_config={}, task_function=lambda: None, container_image="dummy") + t.execute() - class DummyTask(AsyncAgentExecutorMixin, PythonFunctionTask): + class SyncDummyTask(AsyncAgentExecutorMixin, PythonFunctionTask): def __init__(self, **kwargs): super().__init__( - task_type="dummy", + task_type="sync_dummy", **kwargs, ) - t = DummyTask(task_config={}, task_function=lambda: None, container_image="dummy") + t = SyncDummyTask(task_config={}, task_function=lambda: None, container_image="sync_dummy") t.execute() t._task_type = "non-exist-type" @@ -143,46 +159,64 @@ 
def __init__(self, **kwargs): @pytest.mark.asyncio async def test_async_dummy_agent(): - AgentRegistry.register(AsyncDummyAgent()) ctx = MagicMock(spec=grpc.ServicerContext) - agent = AgentRegistry.get_agent("async_dummy") + async_agent = AgentRegistry.get_agent("async_dummy") + sync_agent = AgentRegistry.get_agent("sync_dummy") metadata_bytes = json.dumps(asdict(Metadata(job_id=dummy_id))).encode("utf-8") - res = await agent.async_create(ctx, "/tmp", async_dummy_template, task_inputs) + res = await async_agent.async_create(ctx, "/tmp", async_dummy_template, task_inputs) assert res.resource_meta == metadata_bytes - res = await agent.async_get(ctx, metadata_bytes) + res = await async_agent.async_get(ctx, metadata_bytes) assert res.resource.state == SUCCEEDED - res = await agent.async_delete(ctx, metadata_bytes) + res = await async_agent.async_delete(ctx, metadata_bytes) assert res == DeleteTaskResponse() + res = await sync_agent.async_do(ctx, "/tmp", sync_dummy_template, task_inputs) + assert res == DoTaskResponse(resource=Resource(state=SUCCEEDED)) @pytest.mark.asyncio async def run_agent_server(): - service = AsyncAgentService() + async_agent_service = AsyncAgentService() + sync_agent_service = SyncAgentService() + ctx = MagicMock(spec=grpc.ServicerContext) - request = CreateTaskRequest( - inputs=task_inputs.to_flyte_idl(), output_prefix="/tmp", template=dummy_template.to_flyte_idl() + create_request = CreateTaskRequest( + inputs=task_inputs.to_flyte_idl(), output_prefix="/tmp", template=async_dummy_template.to_flyte_idl() ) - async_request = CreateTaskRequest( + async_create_request = CreateTaskRequest( inputs=task_inputs.to_flyte_idl(), output_prefix="/tmp", template=async_dummy_template.to_flyte_idl() ) + do_request = DoTaskRequest( + inputs=task_inputs.to_flyte_idl(), output_prefix="/tmp", template=sync_dummy_template.to_flyte_idl() + ) + async_do_request = DoTaskRequest( + inputs=task_inputs.to_flyte_idl(), output_prefix="/tmp", 
template=sync_dummy_template.to_flyte_idl() + ) fake_agent = "fake" metadata_bytes = json.dumps(asdict(Metadata(job_id=dummy_id))).encode("utf-8") - res = await service.CreateTask(request, ctx) + res = await async_agent_service.CreateTask(create_request, ctx) assert res.resource_meta == metadata_bytes - res = await service.GetTask(GetTaskRequest(task_type="dummy", resource_meta=metadata_bytes), ctx) + res = await async_agent_service.GetTask(GetTaskRequest(task_type="async_dummy", resource_meta=metadata_bytes), ctx) assert res.resource.state == SUCCEEDED - res = await service.DeleteTask(DeleteTaskRequest(task_type="dummy", resource_meta=metadata_bytes), ctx) + res = await async_agent_service.DeleteTask( + DeleteTaskRequest(task_type="async_dummy", resource_meta=metadata_bytes), ctx + ) assert isinstance(res, DeleteTaskResponse) + res = await sync_agent_service.DoTask(do_request, ctx) + assert res.resource.state == SUCCEEDED - res = await service.CreateTask(async_request, ctx) + res = await async_agent_service.CreateTask(async_create_request, ctx) assert res.resource_meta == metadata_bytes - res = await service.GetTask(GetTaskRequest(task_type="async_dummy", resource_meta=metadata_bytes), ctx) + res = await async_agent_service.GetTask(GetTaskRequest(task_type="async_dummy", resource_meta=metadata_bytes), ctx) assert res.resource.state == SUCCEEDED - res = await service.DeleteTask(DeleteTaskRequest(task_type="async_dummy", resource_meta=metadata_bytes), ctx) + res = await async_agent_service.DeleteTask( + DeleteTaskRequest(task_type="async_dummy", resource_meta=metadata_bytes), ctx + ) assert isinstance(res, DeleteTaskResponse) + res = await sync_agent_service.DoTask(async_do_request, ctx) + assert res.resource.state == SUCCEEDED - res = await service.GetTask(GetTaskRequest(task_type=fake_agent, resource_meta=metadata_bytes), ctx) + res = await async_agent_service.GetTask(GetTaskRequest(task_type=fake_agent, resource_meta=metadata_bytes), ctx) assert res is None @@ 
-192,7 +226,7 @@ def test_agent_server(): def test_is_terminal_state(): assert is_terminal_state(SUCCEEDED) - assert is_terminal_state(PERMANENT_FAILURE) + assert is_terminal_state(RETRYABLE_FAILURE) assert is_terminal_state(PERMANENT_FAILURE) assert not is_terminal_state(RUNNING) @@ -247,3 +281,7 @@ def test_render_task_template(): "task-name", "simple_task", ] + + +AgentRegistry.register(AsyncDummyAgent()) +AgentRegistry.register(SyncDummyAgent()) diff --git a/tests/flytekit/unit/extend/test_task_executor.py b/tests/flytekit/unit/extend/test_task_executor.py new file mode 100644 index 0000000000..a1a9a277dd --- /dev/null +++ b/tests/flytekit/unit/extend/test_task_executor.py @@ -0,0 +1,76 @@ +import collections +from unittest.mock import MagicMock + +import grpc +import pytest +from flyteidl.admin.agent_pb2 import SUCCEEDED, DoTaskResponse, Resource + +from flytekit import FlyteContext, FlyteContextManager +from flytekit.core.external_api_task import TASK_MODULE, TASK_NAME, TASK_TYPE, ExternalApiTask +from flytekit.core.interface import Interface, transform_interface_to_typed_interface +from flytekit.core.type_engine import TypeEngine +from flytekit.extend.backend.base_agent import AgentRegistry +from flytekit.models import literals +from flytekit.models.literals import LiteralMap +from tests.flytekit.unit.extend.test_agent import get_task_template + + +class MockExternalApiTask(ExternalApiTask): + async def do(self, input: str, **kwargs) -> DoTaskResponse: + ctx = FlyteContextManager.current_context() + outputs = LiteralMap( + { + "o0": TypeEngine.to_literal( + ctx, + input, + type(input), + TypeEngine.to_literal_type(type(input)), + ) + } + ).to_flyte_idl() + return DoTaskResponse(resource=Resource(state=SUCCEEDED, outputs=outputs, message=input)) + + +@pytest.mark.asyncio +async def test_task_executor_engine(): + input = "TASK INPUT" + + interface = Interface( + inputs=collections.OrderedDict({"input": str, "kwargs": None}), + 
outputs=collections.OrderedDict({"o0": str}), + ) + tmp = get_task_template(TASK_TYPE, True) + tmp._custom = { + TASK_MODULE: MockExternalApiTask.__module__, + TASK_NAME: MockExternalApiTask.__name__, + } + + tmp._interface = transform_interface_to_typed_interface(interface) + + task_inputs = literals.LiteralMap( + { + "input": literals.Literal(scalar=literals.Scalar(primitive=literals.Primitive(string_value="TASK INPUT"))), + }, + ) + output_prefix = FlyteContext.current_context().file_access.get_random_local_directory() + + ctx = MagicMock(spec=grpc.ServicerContext) + agent = AgentRegistry.get_agent(TASK_TYPE) + + res = await agent.async_do(ctx, output_prefix, tmp, task_inputs) + assert res.resource.state == SUCCEEDED + assert ( + res.resource.outputs + == literals.LiteralMap( + { + "o0": literals.Literal( + scalar=literals.Scalar( + primitive=literals.Primitive( + string_value=input, + ) + ) + ) + } + ).to_flyte_idl() + ) + assert res.resource.message == input diff --git a/tests/flytekit/unit/sensor/test_sensor_engine.py b/tests/flytekit/unit/sensor/test_sensor_engine.py index dbb81c3f47..3078f2e6a0 100644 --- a/tests/flytekit/unit/sensor/test_sensor_engine.py +++ b/tests/flytekit/unit/sensor/test_sensor_engine.py @@ -10,7 +10,7 @@ from flytekit.extend.backend.base_agent import AgentRegistry from flytekit.models import literals, types from flytekit.sensor import FileSensor -from flytekit.sensor.base_sensor import SENSOR_MODULE, SENSOR_NAME +from flytekit.sensor.base_sensor import SENSOR_MODULE, SENSOR_NAME, SENSOR_TYPE from tests.flytekit.unit.extend.test_agent import get_task_template @@ -22,7 +22,7 @@ async def test_sensor_engine(): }, {}, ) - tmp = get_task_template("sensor") + tmp = get_task_template(SENSOR_TYPE) tmp._custom = { SENSOR_MODULE: FileSensor.__module__, SENSOR_NAME: FileSensor.__name__, @@ -37,7 +37,7 @@ async def test_sensor_engine(): }, ) ctx = MagicMock(spec=grpc.ServicerContext) - agent = AgentRegistry.get_agent("sensor") + agent = 
AgentRegistry.get_agent(SENSOR_TYPE) res = await agent.async_create(ctx, "/tmp", tmp, task_inputs)