From 32ffad5eca17df5d24e3c9d575fc806a3f28f4df Mon Sep 17 00:00:00 2001 From: Kevin Su Date: Tue, 26 Mar 2024 12:10:22 -0700 Subject: [PATCH 01/18] wip Signed-off-by: Kevin Su --- flytekit/extend/backend/utils.py | 6 +++++- plugins/flytekit-openai/flytekitplugins/chatgpt/agent.py | 4 ++++ plugins/flytekit-openai/flytekitplugins/chatgpt/task.py | 7 ++++++- 3 files changed, 15 insertions(+), 2 deletions(-) diff --git a/flytekit/extend/backend/utils.py b/flytekit/extend/backend/utils.py index 5199536b5d..22769539b9 100644 --- a/flytekit/extend/backend/utils.py +++ b/flytekit/extend/backend/utils.py @@ -1,7 +1,9 @@ import asyncio import inspect +import typing from typing import Callable, Coroutine +from flyteidl.admin.agent_pb2 import Secret from flyteidl.core.execution_pb2 import TaskExecution import flytekit @@ -38,7 +40,9 @@ def is_terminal_phase(phase: TaskExecution.Phase) -> bool: return phase in [TaskExecution.SUCCEEDED, TaskExecution.ABORTED, TaskExecution.FAILED] -def get_agent_secret(secret_key: str) -> str: +def get_agent_secret(secret_key: str, secrets: typing.Optional[typing.Dict[str, Secret]] = None) -> str: + if secrets and secret_key in secrets: + return secrets[secret_key].value return flytekit.current_context().secrets.get(secret_key) diff --git a/plugins/flytekit-openai/flytekitplugins/chatgpt/agent.py b/plugins/flytekit-openai/flytekitplugins/chatgpt/agent.py index afd3af1321..90c3faf59f 100644 --- a/plugins/flytekit-openai/flytekitplugins/chatgpt/agent.py +++ b/plugins/flytekit-openai/flytekitplugins/chatgpt/agent.py @@ -1,8 +1,10 @@ import asyncio import logging +import typing from typing import Optional from flyteidl.core.execution_pb2 import TaskExecution +from flyteidl.admin.agent_pb2 import Secret from flytekit import FlyteContextManager, lazy_module from flytekit.core.type_engine import TypeEngine @@ -27,6 +29,8 @@ async def do( self, task_template: TaskTemplate, inputs: Optional[LiteralMap] = None, + secrets: Optional[typing.Dict[str, Secret]] = None, + **kwargs, ) -> Resource: ctx = FlyteContextManager.current_context() input_python_value = TypeEngine.literal_map_to_kwargs(ctx, inputs, {"message": str}) diff --git a/plugins/flytekit-openai/flytekitplugins/chatgpt/task.py b/plugins/flytekit-openai/flytekitplugins/chatgpt/task.py index c37a40650d..87d8645b06 100644 --- a/plugins/flytekit-openai/flytekitplugins/chatgpt/task.py +++ b/plugins/flytekit-openai/flytekitplugins/chatgpt/task.py @@ -1,9 +1,11 @@ +import typing from typing import Any, Dict from flytekit.configuration import SerializationSettings from flytekit.core.base_task import PythonTask from flytekit.core.interface import Interface from flytekit.extend.backend.base_agent import SyncAgentExecutorMixin +from flytekit.models.security import Secret, SecurityContext class ChatGPTTask(SyncAgentExecutorMixin, PythonTask): @@ -13,7 +15,7 @@ class ChatGPTTask(SyncAgentExecutorMixin, PythonTask): _TASK_TYPE = "chatgpt" - def __init__(self, name: str, openai_organization: str, chatgpt_config: Dict[str, Any], **kwargs): + def __init__(self, name: str, openai_organization: str, chatgpt_config: Dict[str, Any], openai_key: typing.Optional[Secret] = None, **kwargs): """ Args: name: Name of this task, should be unique in the project @@ -29,11 +31,14 @@ def __init__(self, name: str, openai_organization: str, chatgpt_config: Dict[str inputs = {"message": str} outputs = {"o0": str} + sec_ctx = SecurityContext(secrets=[openai_key]) + super().__init__( task_type=self._TASK_TYPE, name=name, task_config=task_config, interface=Interface(inputs=inputs, outputs=outputs), + security_ctx=sec_ctx, **kwargs, ) From 0ecc89bf4376c2c21aa22b4eb9be885cd9222900 Mon Sep 17 00:00:00 2001 From: Kevin Su Date: Tue, 26 Mar 2024 14:19:53 -0700 Subject: [PATCH 02/18] Add secrets to the agent interface Signed-off-by: Kevin Su --- flytekit/extend/backend/agent_service.py | 12 +++++++++--- flytekit/extend/backend/base_agent.py | 17 ++++++++++++++--- flytekit/extend/backend/utils.py | 13 ++++++++----- flytekit/models/task.py | 2 +- .../flytekitplugins/chatgpt/agent.py | 9 ++++++--- .../flytekitplugins/chatgpt/task.py | 10 ++++++++-- 6 files changed, 46 insertions(+), 17 deletions(-) diff --git a/flytekit/extend/backend/agent_service.py b/flytekit/extend/backend/agent_service.py index 3e1527c5c5..d6e7e310ee 100644 --- a/flytekit/extend/backend/agent_service.py +++ b/flytekit/extend/backend/agent_service.py @@ -122,6 +122,7 @@ async def CreateTask(self, request: CreateTaskRequest, context: grpc.ServicerCon task_template=template, inputs=inputs, output_prefix=request.output_prefix, + secrets=request.secrets, ) return CreateTaskResponse(resource_meta=resource_mata.encode()) @@ -132,7 +133,9 @@ async def GetTask(self, request: GetTaskRequest, context: grpc.ServicerContext) else: agent = AgentRegistry.get_agent(request.task_type) logger.info(f"{agent.name} start checking the status of the job") - res = await mirror_async_methods(agent.get, resource_meta=agent.metadata_type.decode(request.resource_meta)) + res = await mirror_async_methods( + agent.get, resource_meta=agent.metadata_type.decode(request.resource_meta), secrets=request.secrets + ) if res.outputs is None: outputs = None @@ -152,7 +155,9 @@ async def DeleteTask(self, request: DeleteTaskRequest, context: grpc.ServicerCon else: agent = AgentRegistry.get_agent(request.task_type) logger.info(f"{agent.name} start deleting the job") - await mirror_async_methods(agent.delete, resource_meta=agent.metadata_type.decode(request.resource_meta)) + await mirror_async_methods( + agent.delete, resource_meta=agent.metadata_type.decode(request.resource_meta), secrets=request.secrets + ) return DeleteTaskResponse() @@ -162,6 +167,7 @@ async def ExecuteTaskSync( ) -> typing.AsyncIterator[ExecuteTaskSyncResponse]: request = await request_iterator.__anext__() template = TaskTemplate.from_flyte_idl(request.header.template) + secrets = request.header.secrets task_type = template.type try: with request_latency.labels(task_type=task_type, operation=do_operation).time(): @@ -171,7 +177,7 @@ async def ExecuteTaskSync( request = await request_iterator.__anext__() literal_map = LiteralMap.from_flyte_idl(request.inputs) if request.inputs else None - res = await mirror_async_methods(agent.do, task_template=template, inputs=literal_map) + res = await mirror_async_methods(agent.do, task_template=template, inputs=literal_map, secrets=secrets) if res.outputs is None: outputs = None diff --git a/flytekit/extend/backend/base_agent.py b/flytekit/extend/backend/base_agent.py index 3c1a149abc..511c7c2374 100644 --- a/flytekit/extend/backend/base_agent.py +++ b/flytekit/extend/backend/base_agent.py @@ -11,12 +11,13 @@ from types import FrameType, coroutine from typing import Any, Dict, List, Optional, Union -from flyteidl.admin.agent_pb2 import Agent +from flyteidl.admin.agent_pb2 import Agent, Secret from flyteidl.admin.agent_pb2 import TaskCategory as _TaskCategory from flyteidl.core import literals_pb2 from flyteidl.core.execution_pb2 import TaskExecution, TaskLog from rich.progress import Progress +import flytekit from flytekit import FlyteContext, PythonFunctionTask, logger from flytekit.configuration import ImageConfig, SerializationSettings from flytekit.core import utils @@ -117,7 +118,13 @@ class SyncAgentBase(AgentBase): name = "Base Sync Agent" @abstractmethod - def do(self, task_template: TaskTemplate, inputs: Optional[LiteralMap], **kwargs) -> Resource: + def do( + self, + task_template: TaskTemplate, + inputs: Optional[LiteralMap] = None, + secrets: Optional[List[Secret]] = None, + **kwargs, + ) -> Resource: """ This is the method that the agent will run. """ @@ -252,8 +259,12 @@ async def _do( ) -> Resource: try: ctx = FlyteContext.current_context() + secrets = [] + for secret in template.security_context.secrets: + value = flytekit.current_context().secrets.get(secret.group, secret.key, secret.group_version) + secrets.append(Secret(value=value)) literal_map = TypeEngine.dict_to_literal_map(ctx, inputs or {}, self.get_input_types()) - return await mirror_async_methods(agent.do, task_template=template, inputs=literal_map) + return await mirror_async_methods(agent.do, task_template=template, inputs=literal_map, secrets=secrets) except Exception as error_message: raise FlyteUserException(f"Failed to run the task {self.name} with error: {error_message}") diff --git a/flytekit/extend/backend/utils.py b/flytekit/extend/backend/utils.py index 22769539b9..32cc37c2b8 100644 --- a/flytekit/extend/backend/utils.py +++ b/flytekit/extend/backend/utils.py @@ -1,4 +1,5 @@ import asyncio +import functools import inspect import typing from typing import Callable, Coroutine @@ -13,8 +14,7 @@ def mirror_async_methods(func: Callable, **kwargs) -> Coroutine: if inspect.iscoroutinefunction(func): return func(**kwargs) - args = [v for _, v in kwargs.items()] - return asyncio.get_running_loop().run_in_executor(None, func, *args) + return asyncio.get_running_loop().run_in_executor(None, functools.partial(func, **kwargs)) def convert_to_flyte_phase(state: str) -> TaskExecution.Phase: @@ -40,9 +40,12 @@ def is_terminal_phase(phase: TaskExecution.Phase) -> bool: return phase in [TaskExecution.SUCCEEDED, TaskExecution.ABORTED, TaskExecution.FAILED] -def get_agent_secret(secret_key: str, secrets: typing.Optional[typing.Dict[str, Secret]] = None) -> str: - if secrets and secret_key in secrets: - return secrets[secret_key].value +def get_agent_secret(secret_key: str, secret: typing.Optional[Secret] = None) -> str: + """ + Get the secret from the context if the secret is not provided. + """ + if secret: + return secret.value return flytekit.current_context().secrets.get(secret_key) diff --git a/flytekit/models/task.py b/flytekit/models/task.py index b6e8222fb9..d4021df383 100644 --- a/flytekit/models/task.py +++ b/flytekit/models/task.py @@ -455,7 +455,7 @@ def config(self): return self._config @property - def security_context(self): + def security_context(self) -> _sec.SecurityContext: return self._security_context @property diff --git a/plugins/flytekit-openai/flytekitplugins/chatgpt/agent.py b/plugins/flytekit-openai/flytekitplugins/chatgpt/agent.py index 90c3faf59f..0db990383b 100644 --- a/plugins/flytekit-openai/flytekitplugins/chatgpt/agent.py +++ b/plugins/flytekit-openai/flytekitplugins/chatgpt/agent.py @@ -3,8 +3,8 @@ import typing from typing import Optional -from flyteidl.core.execution_pb2 import TaskExecution from flyteidl.admin.agent_pb2 import Secret +from flyteidl.core.execution_pb2 import TaskExecution from flytekit import FlyteContextManager, lazy_module from flytekit.core.type_engine import TypeEngine @@ -29,18 +29,21 @@ async def do( self, task_template: TaskTemplate, inputs: Optional[LiteralMap] = None, - secrets: Optional[typing.Dict[str, Secret]] = None, + secrets: typing.List[Secret] = None, **kwargs, ) -> Resource: ctx = FlyteContextManager.current_context() input_python_value = TypeEngine.literal_map_to_kwargs(ctx, inputs, {"message": str}) message = input_python_value["message"] + api_key = None + if secrets and len(secrets) > 0: + api_key = secrets[0] custom = task_template.custom custom["chatgpt_config"]["messages"] = [{"role": "user", "content": message}] client = openai.AsyncOpenAI( organization=custom["openai_organization"], - api_key=get_agent_secret(secret_key=OPENAI_API_KEY), + api_key=get_agent_secret(secret_key=OPENAI_API_KEY, secret=api_key), ) logger = logging.getLogger("httpx") diff --git a/plugins/flytekit-openai/flytekitplugins/chatgpt/task.py b/plugins/flytekit-openai/flytekitplugins/chatgpt/task.py index 87d8645b06..c1c720d010 100644 --- a/plugins/flytekit-openai/flytekitplugins/chatgpt/task.py +++ b/plugins/flytekit-openai/flytekitplugins/chatgpt/task.py @@ -15,7 +15,14 @@ class ChatGPTTask(SyncAgentExecutorMixin, PythonTask): _TASK_TYPE = "chatgpt" - def __init__(self, name: str, openai_organization: str, chatgpt_config: Dict[str, Any], openai_key: typing.Optional[Secret] = None, **kwargs): + def __init__( + self, + name: str, + openai_organization: str, + chatgpt_config: Dict[str, Any], + openai_key: typing.Optional[Secret] = None, + **kwargs, + ): """ Args: name: Name of this task, should be unique in the project @@ -27,7 +34,6 @@ def __init__(self, name: str, openai_organization: str, chatgpt_config: Dict[str raise ValueError("The 'model' configuration variable is required in chatgpt_config") task_config = {"openai_organization": openai_organization, "chatgpt_config": chatgpt_config} - inputs = {"message": str} outputs = {"o0": str} From c78142b5aa24a97694150d9a6714703c4b3d56fd Mon Sep 17 00:00:00 2001 From: Kevin Su Date: Thu, 28 Mar 2024 01:59:36 -0700 Subject: [PATCH 03/18] nit Signed-off-by: Kevin Su --- flytekit/core/python_auto_container.py | 2 ++ .../flytekitplugins/chatgpt/agent.py | 5 ++--- .../flytekitplugins/spark/agent.py | 17 ++++++++++------- 3 files changed, 14 insertions(+), 10 deletions(-) diff --git a/flytekit/core/python_auto_container.py b/flytekit/core/python_auto_container.py index 7099456e5b..0cfdc43599 100644 --- a/flytekit/core/python_auto_container.py +++ b/flytekit/core/python_auto_container.py @@ -78,6 +78,8 @@ def __init__( """ sec_ctx = None if secret_requests: + if not isinstance(secret_requests, list): + raise AssertionError(f"secret_requests should be of type list, received {type(secret_requests)}") for s in secret_requests: if not isinstance(s, Secret): raise AssertionError(f"Secret {s} should be of type flytekit.Secret, received {type(s)}") diff --git a/plugins/flytekit-openai/flytekitplugins/chatgpt/agent.py b/plugins/flytekit-openai/flytekitplugins/chatgpt/agent.py index 0db990383b..5db48088e9 100644 --- a/plugins/flytekit-openai/flytekitplugins/chatgpt/agent.py +++ b/plugins/flytekit-openai/flytekitplugins/chatgpt/agent.py @@ -1,7 +1,6 @@ import asyncio import logging -import typing -from typing import Optional +from typing import Optional, List from flyteidl.admin.agent_pb2 import Secret from flyteidl.core.execution_pb2 import TaskExecution @@ -29,7 +28,7 @@ async def do( self, task_template: TaskTemplate, inputs: Optional[LiteralMap] = None, - secrets: typing.List[Secret] = None, + secrets: Optional[List[Secret]] = None, **kwargs, ) -> Resource: ctx = FlyteContextManager.current_context() diff --git a/plugins/flytekit-spark/flytekitplugins/spark/agent.py b/plugins/flytekit-spark/flytekitplugins/spark/agent.py index d367f3f04a..ed3e91918f 100644 --- a/plugins/flytekit-spark/flytekitplugins/spark/agent.py +++ b/plugins/flytekit-spark/flytekitplugins/spark/agent.py @@ -1,8 +1,7 @@ import http import json -import typing from dataclasses import dataclass -from typing import Optional +from typing import Optional, List, Dict from flyteidl.core.execution_pb2 import TaskExecution @@ -12,6 +11,7 @@ from flytekit.models.core.execution import TaskLog from flytekit.models.literals import LiteralMap from flytekit.models.task import TaskTemplate +from flyteidl.admin.agent_pb2 import Secret aiohttp = lazy_module("aiohttp") @@ -31,7 +31,7 @@ def __init__(self): super().__init__(task_type_name="spark", metadata_type=DatabricksJobMetadata) async def create( - self, task_template: TaskTemplate, inputs: Optional[LiteralMap] = None, **kwargs + self, task_template: TaskTemplate, inputs: Optional[LiteralMap] = None, secrets: Optional[List[Secret]] = None, **kwargs ) -> DatabricksJobMetadata: custom = task_template.custom container = task_template.container @@ -69,7 +69,7 @@ async def create( return DatabricksJobMetadata(databricks_instance=databricks_instance, run_id=str(response["run_id"])) - async def get(self, resource_meta: DatabricksJobMetadata, **kwargs) -> Resource: + async def get(self, resource_meta: DatabricksJobMetadata, secrets: Optional[List[Secret]] = None, **kwargs) -> Resource: databricks_instance = resource_meta.databricks_instance databricks_url = ( f"https://{databricks_instance}{DATABRICKS_API_ENDPOINT}/runs/get?run_id={resource_meta.run_id}" @@ -102,7 +102,7 @@ async def get(self, resource_meta: DatabricksJobMetadata, **kwargs) -> Resource: return Resource(phase=cur_phase, message=message, log_links=log_links) - async def delete(self, resource_meta: DatabricksJobMetadata, **kwargs): + async def delete(self, resource_meta: DatabricksJobMetadata, secrets: Optional[List[Secret]] = None, **kwargs): databricks_url = f"https://{resource_meta.databricks_instance}{DATABRICKS_API_ENDPOINT}/runs/cancel" data = json.dumps({"run_id": resource_meta.run_id}) @@ -113,8 +113,11 @@ async def delete(self, resource_meta: DatabricksJobMetadata, **kwargs): await resp.json() -def get_header() -> typing.Dict[str, str]: - token = get_agent_secret("FLYTE_DATABRICKS_ACCESS_TOKEN") +def get_header(secrets: Optional[List[Secret]] = None) -> Dict[str, str]: + access_token = None + if secrets and len(secrets) > 0: + access_token = secrets[0] + token = get_agent_secret(secret_key="FLYTE_DATABRICKS_ACCESS_TOKEN", secret=access_token) return {"Authorization": f"Bearer {token}", "content-type": "application/json"} From 037fa2c9f9a9b57fc49e8d10a3f33c39d7c918fb Mon Sep 17 00:00:00 2001 From: Kevin Su Date: Thu, 4 Apr 2024 12:38:26 -0700 Subject: [PATCH 04/18] update databricks agent Signed-off-by: Kevin Su --- flytekit/extend/backend/utils.py | 4 ++-- flytekit/models/task.py | 2 +- .../flytekitplugins/openai/chatgpt/agent.py | 2 +- .../flytekitplugins/spark/agent.py | 20 ++++++++++++------- 4 files changed, 17 insertions(+), 11 deletions(-) diff --git a/flytekit/extend/backend/utils.py b/flytekit/extend/backend/utils.py index 32cc37c2b8..a016232118 100644 --- a/flytekit/extend/backend/utils.py +++ b/flytekit/extend/backend/utils.py @@ -7,7 +7,6 @@ from flyteidl.admin.agent_pb2 import Secret from flyteidl.core.execution_pb2 import TaskExecution -import flytekit from flytekit.models.task import TaskTemplate @@ -44,9 +43,10 @@ def get_agent_secret(secret_key: str, secret: typing.Optional[Secret] = None) -> """ Get the secret from the context if the secret is not provided. """ + print(f"Getting secret for: {secret}") if secret: return secret.value - return flytekit.current_context().secrets.get(secret_key) + # return flytekit.current_context().secrets.get(secret_key) def render_task_template(tt: TaskTemplate, file_prefix: str) -> TaskTemplate: diff --git a/flytekit/models/task.py b/flytekit/models/task.py index f27f0169c4..708f74ae98 100644 --- a/flytekit/models/task.py +++ b/flytekit/models/task.py @@ -939,7 +939,7 @@ def from_flyte_idl(cls, pb2_object): return cls( image=pb2_object.image, command=pb2_object.command, - args=pb2_object.args, + args=[arg for arg in pb2_object.args], resources=Resources.from_flyte_idl(pb2_object.resources), env={kv.key: kv.value for kv in pb2_object.env}, config={kv.key: kv.value for kv in pb2_object.config}, diff --git a/plugins/flytekit-openai/flytekitplugins/openai/chatgpt/agent.py b/plugins/flytekit-openai/flytekitplugins/openai/chatgpt/agent.py index 5db48088e9..85d08882ce 100644 --- a/plugins/flytekit-openai/flytekitplugins/openai/chatgpt/agent.py +++ b/plugins/flytekit-openai/flytekitplugins/openai/chatgpt/agent.py @@ -1,6 +1,6 @@ import asyncio import logging -from typing import Optional, List +from typing import List, Optional from flyteidl.admin.agent_pb2 import Secret from flyteidl.core.execution_pb2 import TaskExecution diff --git a/plugins/flytekit-spark/flytekitplugins/spark/agent.py b/plugins/flytekit-spark/flytekitplugins/spark/agent.py index ed3e91918f..65a42522bc 100644 --- a/plugins/flytekit-spark/flytekitplugins/spark/agent.py +++ b/plugins/flytekit-spark/flytekitplugins/spark/agent.py @@ -1,8 +1,9 @@ import http import json from dataclasses import dataclass -from typing import Optional, List, Dict +from typing import Dict, List, Optional +from flyteidl.admin.agent_pb2 import Secret from flyteidl.core.execution_pb2 import TaskExecution from flytekit import lazy_module @@ -11,7 +12,6 @@ from flytekit.models.core.execution import TaskLog from flytekit.models.literals import LiteralMap from flytekit.models.task import TaskTemplate -from flyteidl.admin.agent_pb2 import Secret aiohttp = lazy_module("aiohttp") @@ -31,7 +31,11 @@ def __init__(self): super().__init__(task_type_name="spark", metadata_type=DatabricksJobMetadata) async def create( - self, task_template: TaskTemplate, inputs: Optional[LiteralMap] = None, secrets: Optional[List[Secret]] = None, **kwargs + self, + task_template: TaskTemplate, + inputs: Optional[LiteralMap] = None, + secrets: Optional[List[Secret]] = None, + **kwargs, ) -> DatabricksJobMetadata: custom = task_template.custom container = task_template.container @@ -62,21 +66,23 @@ async def create( data = json.dumps(databricks_job) async with aiohttp.ClientSession() as session: - async with session.post(databricks_url, headers=get_header(), data=data) as resp: + async with session.post(databricks_url, headers=get_header(secrets), data=data) as resp: response = await resp.json() if resp.status != http.HTTPStatus.OK: raise Exception(f"Failed to create databricks job with error: {response}") return DatabricksJobMetadata(databricks_instance=databricks_instance, run_id=str(response["run_id"])) - async def get(self, resource_meta: DatabricksJobMetadata, secrets: Optional[List[Secret]] = None, **kwargs) -> Resource: + async def get( + self, resource_meta: DatabricksJobMetadata, secrets: Optional[List[Secret]] = None, **kwargs + ) -> Resource: databricks_instance = resource_meta.databricks_instance databricks_url = ( f"https://{databricks_instance}{DATABRICKS_API_ENDPOINT}/runs/get?run_id={resource_meta.run_id}" ) async with aiohttp.ClientSession() as session: - async with session.get(databricks_url, headers=get_header()) as resp: + async with session.get(databricks_url, headers=get_header(secrets)) as resp: if resp.status != http.HTTPStatus.OK: raise Exception(f"Failed to get databricks job {resource_meta.run_id} with error: {resp.reason}") response = await resp.json() @@ -107,7 +113,7 @@ async def delete(self, resource_meta: DatabricksJobMetadata, secrets: Optional[L data = json.dumps({"run_id": resource_meta.run_id}) async with aiohttp.ClientSession() as session: - async with session.post(databricks_url, headers=get_header(), data=data) as resp: + async with session.post(databricks_url, headers=get_header(secrets), data=data) as resp: if resp.status != http.HTTPStatus.OK: raise Exception(f"Failed to cancel databricks job {resource_meta.run_id} with error: {resp.reason}") await resp.json() From 44c4866384e7b59d27f929559e246f45d982fd0c Mon Sep 17 00:00:00 2001 From: Kevin Su Date: Thu, 4 Apr 2024 14:16:03 -0700 Subject: [PATCH 05/18] update base agent Signed-off-by: Kevin Su --- flytekit/extend/backend/base_agent.py | 26 ++++++++++++++++++++------ flytekit/extend/backend/utils.py | 4 ++-- 2 files changed, 22 insertions(+), 8 deletions(-) diff --git a/flytekit/extend/backend/base_agent.py b/flytekit/extend/backend/base_agent.py index 36a61675c4..d2d27a1ff3 100644 --- a/flytekit/extend/backend/base_agent.py +++ b/flytekit/extend/backend/base_agent.py @@ -295,8 +295,9 @@ def execute(self: PythonTask, **kwargs) -> LiteralMap: task_template = get_serializable(OrderedDict(), ss, self).template self._agent = AgentRegistry.get_agent(task_template.type, task_template.task_type_version) - resource_mata = asyncio.run(self._create(task_template, output_prefix, kwargs)) - resource = asyncio.run(self._get(resource_meta=resource_mata)) + secrets = get_secrets(task_template) + resource_mata = asyncio.run(self._create(task_template, output_prefix, kwargs, secrets)) + resource = asyncio.run(self._get(resource_meta=resource_mata, secrets=secrets)) if resource.phase != TaskExecution.SUCCEEDED: raise FlyteUserException(f"Failed to run the task {self.name} with error: {resource.message}") @@ -314,7 +315,11 @@ def execute(self: PythonTask, **kwargs) -> LiteralMap: return resource.outputs async def _create( - self: PythonTask, task_template: TaskTemplate, output_prefix: str, inputs: Dict[str, Any] = None + self: PythonTask, + task_template: TaskTemplate, + output_prefix: str, + inputs: Dict[str, Any] = None, + secrets: List[Secret] = None, ) -> ResourceMeta: ctx = FlyteContext.current_context() @@ -331,12 +336,13 @@ async def _create( task_template=task_template, inputs=literal_map, output_prefix=output_prefix, + secrets=secrets, ) - signal.signal(signal.SIGINT, partial(self.signal_handler, resource_meta)) # type: ignore + signal.signal(signal.SIGINT, partial(self.signal_handler, resource_meta, secrets)) # type: ignore return resource_meta - async def _get(self: PythonTask, resource_meta: ResourceMeta) -> Resource: + async def _get(self: PythonTask, resource_meta: ResourceMeta, secrets: List[Secret] = None) -> Resource: phase = TaskExecution.RUNNING progress = Progress(transient=True) @@ -347,7 +353,7 @@ async def _get(self: PythonTask, resource_meta: ResourceMeta) -> Resource: while not is_terminal_phase(phase): progress.start_task(task) time.sleep(1) - resource = await mirror_async_methods(self._agent.get, resource_meta=resource_meta) + resource = await mirror_async_methods(self._agent.get, resource_meta=resource_meta, secrets=secrets) if self._clean_up_task: await self._clean_up_task sys.exit(1) @@ -371,3 +377,11 @@ def signal_handler(self, resource_meta: ResourceMeta, signum: int, frame: FrameT if self._clean_up_task is None: co = mirror_async_methods(self._agent.delete, resource_meta=resource_meta) self._clean_up_task = asyncio.create_task(co) + + +def get_secrets(task_template: TaskTemplate) -> List[Secret]: + secrets = [] + for secret in task_template.security_context.secrets: + value = flytekit.current_context().secrets.get(secret.group, secret.key, secret.group_version) + secrets.append(Secret(value=value)) + return secrets diff --git a/flytekit/extend/backend/utils.py b/flytekit/extend/backend/utils.py index a016232118..32cc37c2b8 100644 --- a/flytekit/extend/backend/utils.py +++ b/flytekit/extend/backend/utils.py @@ -7,6 +7,7 @@ from flyteidl.admin.agent_pb2 import Secret from flyteidl.core.execution_pb2 import TaskExecution +import flytekit from flytekit.models.task import TaskTemplate @@ -43,10 +44,9 @@ def get_agent_secret(secret_key: str, secret: typing.Optional[Secret] = None) -> """ Get the secret from the context if the secret is not provided. """ - print(f"Getting secret for: {secret}") if secret: return secret.value - # return flytekit.current_context().secrets.get(secret_key) + return flytekit.current_context().secrets.get(secret_key) def render_task_template(tt: TaskTemplate, file_prefix: str) -> TaskTemplate: From 0a7e7761427378516a5e089bb03ba02fa422c76e Mon Sep 17 00:00:00 2001 From: Kevin Su Date: Thu, 4 Apr 2024 14:45:35 -0700 Subject: [PATCH 06/18] nit Signed-off-by: Kevin Su --- flytekit/extend/backend/base_agent.py | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/flytekit/extend/backend/base_agent.py b/flytekit/extend/backend/base_agent.py index d2d27a1ff3..ef5024f0fe 100644 --- a/flytekit/extend/backend/base_agent.py +++ b/flytekit/extend/backend/base_agent.py @@ -158,6 +158,7 @@ def create( inputs: Optional[LiteralMap], output_prefix: Optional[str], task_execution_metadata: Optional[TaskExecutionMetadata], + secrets: Optional[List[Secret]] = None, **kwargs, ) -> ResourceMeta: """ @@ -166,7 +167,7 @@ def create( raise NotImplementedError @abstractmethod - def get(self, resource_meta: ResourceMeta, **kwargs) -> Resource: + def get(self, resource_meta: ResourceMeta, secrets: Optional[List[Secret]] = None, **kwargs) -> Resource: """ Return the status of the task, and return the outputs in some cases. For example, bigquery job can't write the structured dataset to the output location, so it returns the output literals to the propeller, @@ -175,9 +176,9 @@ def get(self, resource_meta: ResourceMeta, **kwargs) -> Resource: raise NotImplementedError @abstractmethod - def delete(self, resource_meta: ResourceMeta, **kwargs): + def delete(self, resource_meta: ResourceMeta, secrets: Optional[List[Secret]] = None, **kwargs): """ - Delete the task. This call should be idempotent. It should raise an error if fails to delete the task. + Delete the task. This call should be idempotent. It should raise an error if it fails to delete the task. """ raise NotImplementedError @@ -264,10 +265,7 @@ async def _do( ) -> Resource: try: ctx = FlyteContext.current_context() - secrets = [] - for secret in template.security_context.secrets: - value = flytekit.current_context().secrets.get(secret.group, secret.key, secret.group_version) - secrets.append(Secret(value=value)) + secrets = get_secrets(template) literal_map = TypeEngine.dict_to_literal_map(ctx, inputs or {}, self.get_input_types()) return await mirror_async_methods(agent.do, task_template=template, inputs=literal_map, secrets=secrets) except Exception as error_message: From 0a5bd774aa1c9e870c30824e190d9d3b629f6d44 Mon Sep 17 00:00:00 2001 From: Kevin Su Date: Wed, 17 Apr 2024 01:30:50 -0700 Subject: [PATCH 07/18] add connection Signed-off-by: Kevin Su --- flytekit/extend/backend/agent_service.py | 12 ++++--- flytekit/extend/backend/base_agent.py | 45 ++++++++++-------------- flytekit/models/security.py | 5 ++- 3 files changed, 30 insertions(+), 32 deletions(-) diff --git a/flytekit/extend/backend/agent_service.py b/flytekit/extend/backend/agent_service.py index bbaada60ac..fee285df25 100644 --- a/flytekit/extend/backend/agent_service.py +++ b/flytekit/extend/backend/agent_service.py @@ -123,7 +123,7 @@ async def CreateTask(self, request: CreateTaskRequest, context: grpc.ServicerCon task_template=template, inputs=inputs, output_prefix=request.output_prefix, - secrets=request.secrets, + connection=request.connection, task_execution_metadata=task_execution_metadata, ) return CreateTaskResponse(resource_meta=resource_mata.encode()) @@ -136,7 +136,7 @@ async def GetTask(self, request: GetTaskRequest, context: grpc.ServicerContext) agent = AgentRegistry.get_agent(request.task_type) logger.info(f"{agent.name} start checking the status of the job") res = await mirror_async_methods( - agent.get, resource_meta=agent.metadata_type.decode(request.resource_meta), secrets=request.secrets + agent.get, resource_meta=agent.metadata_type.decode(request.resource_meta), connection=request.connection ) if res.outputs is None: @@ -158,7 +158,7 @@ async def DeleteTask(self, request: DeleteTaskRequest, context: grpc.ServicerCon agent = AgentRegistry.get_agent(request.task_type) logger.info(f"{agent.name} start deleting the job") await mirror_async_methods( - agent.delete, resource_meta=agent.metadata_type.decode(request.resource_meta), secrets=request.secrets + agent.delete, resource_meta=agent.metadata_type.decode(request.resource_meta), connection=request.connection ) return DeleteTaskResponse() @@ -169,7 +169,7 @@ async def ExecuteTaskSync( ) -> typing.AsyncIterator[ExecuteTaskSyncResponse]: request = await request_iterator.__anext__() template = TaskTemplate.from_flyte_idl(request.header.template) - secrets = request.header.secrets + connection = request.header.connection task_type = template.type try: with request_latency.labels(task_type=task_type, operation=do_operation).time(): @@ -179,7 +179,9 @@ async def ExecuteTaskSync( request = await request_iterator.__anext__() literal_map = LiteralMap.from_flyte_idl(request.inputs) if request.inputs else None - res = await mirror_async_methods(agent.do, task_template=template, inputs=literal_map, secrets=secrets) + res = await mirror_async_methods( + agent.do, task_template=template, inputs=literal_map, connection=connection + ) if res.outputs is None: outputs = None diff --git a/flytekit/extend/backend/base_agent.py b/flytekit/extend/backend/base_agent.py index ef5024f0fe..fcd2fa0d51 100644 --- a/flytekit/extend/backend/base_agent.py +++ b/flytekit/extend/backend/base_agent.py @@ -11,13 +11,13 @@ from types import FrameType, coroutine from typing import Any, Dict, List, Optional, Union -from flyteidl.admin.agent_pb2 import Agent, Secret +from flyteidl.admin.agent_pb2 import Agent from flyteidl.admin.agent_pb2 import TaskCategory as _TaskCategory from flyteidl.core import literals_pb2 from flyteidl.core.execution_pb2 import TaskExecution, TaskLog +from flyteidl.core.security_pb2 import Connection from rich.progress import Progress -import flytekit from flytekit import FlyteContext, PythonFunctionTask, logger from flytekit.configuration import ImageConfig, SerializationSettings from flytekit.core import utils @@ -122,7 +122,7 @@ def do( self, task_template: TaskTemplate, inputs: Optional[LiteralMap] = None, - secrets: Optional[List[Secret]] = None, + connection: Optional[List[Connection]] = None, **kwargs, ) -> Resource: """ @@ -158,7 +158,7 @@ def create( inputs: Optional[LiteralMap], output_prefix: Optional[str], task_execution_metadata: Optional[TaskExecutionMetadata], - secrets: Optional[List[Secret]] = None, + connection: Optional[Connection] = None, **kwargs, ) -> ResourceMeta: """ @@ -167,7 +167,7 @@ def create( raise NotImplementedError @abstractmethod - def get(self, resource_meta: ResourceMeta, secrets: Optional[List[Secret]] = None, **kwargs) -> Resource: + def get(self, resource_meta: ResourceMeta, connection: Optional[Connection] = None, **kwargs) -> Resource: """ Return the status of the task, and return the outputs in some cases. For example, bigquery job can't write the structured dataset to the output location, so it returns the output literals to the propeller, @@ -176,7 +176,7 @@ def get(self, resource_meta: ResourceMeta, secrets: Optional[List[Secret]] = Non raise NotImplementedError @abstractmethod - def delete(self, resource_meta: ResourceMeta, secrets: Optional[List[Secret]] = None, **kwargs): + def delete(self, resource_meta: ResourceMeta, connection: Optional[Connection] = None, **kwargs): """ Delete the task. This call should be idempotent. It should raise an error if it fails to delete the task. """ @@ -243,7 +243,9 @@ class SyncAgentExecutorMixin: Sending a prompt to ChatGPT and getting a response, or retrieving some metadata from a backend system. """ - def execute(self: PythonTask, **kwargs) -> LiteralMap: + T = typing.TypeVar("T", "SyncAgentExecutorMixin", PythonTask) + + def execute(self: T, **kwargs) -> LiteralMap: from flytekit.tools.translator import get_serializable ctx = FlyteContext.current_context() @@ -265,9 +267,8 @@ async def _do( ) -> Resource: try: ctx = FlyteContext.current_context() - secrets = get_secrets(template) literal_map = TypeEngine.dict_to_literal_map(ctx, inputs or {}, self.get_input_types()) - return await mirror_async_methods(agent.do, task_template=template, inputs=literal_map, secrets=secrets) + return await mirror_async_methods(agent.do, task_template=template, inputs=literal_map) except Exception as error_message: raise FlyteUserException(f"Failed to run the task {self.name} with error: {error_message}") @@ -280,10 +281,12 @@ class AsyncAgentExecutorMixin: Asynchronous tasks are tasks that take a long time to complete, such as running a query. """ + T = typing.TypeVar("T", "AsyncAgentExecutorMixin", PythonTask) + _clean_up_task: coroutine = None _agent: AsyncAgentBase = None - def execute(self: PythonTask, **kwargs) -> LiteralMap: + def execute(self: T, **kwargs) -> LiteralMap: ctx = FlyteContext.current_context() ss = ctx.serialization_settings or SerializationSettings(ImageConfig()) output_prefix = ctx.file_access.get_random_remote_directory() @@ -293,9 +296,8 @@ def execute(self: PythonTask, **kwargs) -> LiteralMap: task_template = get_serializable(OrderedDict(), ss, self).template self._agent = AgentRegistry.get_agent(task_template.type, task_template.task_type_version) - secrets = get_secrets(task_template) - resource_mata = asyncio.run(self._create(task_template, output_prefix, kwargs, secrets)) - resource = asyncio.run(self._get(resource_meta=resource_mata, secrets=secrets)) + resource_mata = asyncio.run(self._create(task_template, output_prefix, kwargs)) + resource = asyncio.run(self._get(resource_meta=resource_mata)) if resource.phase != TaskExecution.SUCCEEDED: raise FlyteUserException(f"Failed to run the task {self.name} with error: {resource.message}") @@ -313,11 +315,11 @@ def execute(self: PythonTask, **kwargs) -> LiteralMap: return resource.outputs async def _create( - self: PythonTask, + self: T, task_template: TaskTemplate, output_prefix: str, inputs: Dict[str, Any] = None, - secrets: List[Secret] = None, + connection: Optional[Connection] = None, ) -> ResourceMeta: ctx = FlyteContext.current_context() @@ -334,13 +336,12 @@ async def _create( task_template=task_template, inputs=literal_map, output_prefix=output_prefix, - secrets=secrets, ) signal.signal(signal.SIGINT, partial(self.signal_handler, resource_meta, secrets)) # type: ignore return resource_meta - async def _get(self: PythonTask, resource_meta: ResourceMeta, secrets: List[Secret] = None) -> Resource: + async def _get(self: T, resource_meta: ResourceMeta) -> Resource: phase = TaskExecution.RUNNING progress = Progress(transient=True) @@ -351,7 +352,7 @@ async def _get(self: PythonTask, resource_meta: ResourceMeta, secrets: List[Secr while not is_terminal_phase(phase): progress.start_task(task) time.sleep(1) - resource = await mirror_async_methods(self._agent.get, resource_meta=resource_meta, secrets=secrets) + resource = await mirror_async_methods(self._agent.get, resource_meta=resource_meta) if self._clean_up_task: await self._clean_up_task sys.exit(1) @@ -375,11 +376,3 @@ def signal_handler(self, resource_meta: ResourceMeta, signum: int, frame: FrameT if self._clean_up_task is None: co = mirror_async_methods(self._agent.delete, resource_meta=resource_meta) self._clean_up_task = asyncio.create_task(co) - - -def get_secrets(task_template: TaskTemplate) -> List[Secret]: - secrets = [] - for secret in task_template.security_context.secrets: - value = flytekit.current_context().secrets.get(secret.group, secret.key, secret.group_version) - secrets.append(Secret(value=value)) - return secrets diff --git a/flytekit/models/security.py b/flytekit/models/security.py index a9ee7e7cb9..c95a8b981f 100644 --- a/flytekit/models/security.py +++ b/flytekit/models/security.py @@ -148,6 +148,7 @@ class SecurityContext(_common.FlyteIdlEntity): run_as: Optional[Identity] = None secrets: Optional[List[Secret]] = None tokens: Optional[List[OAuth2TokenRequest]] = None + connection: Optional[str] = None def __post_init__(self): if self.secrets and not isinstance(self.secrets, list): @@ -156,12 +157,13 @@ def __post_init__(self): self.tokens = [self.tokens] def to_flyte_idl(self) -> _sec.SecurityContext: - if self.run_as is None and self.secrets is None and self.tokens is None: + if self.run_as is None and self.secrets is None and self.tokens is None and self.connection is None: return None return _sec.SecurityContext( run_as=self.run_as.to_flyte_idl() if self.run_as else None, secrets=[s.to_flyte_idl() for s in self.secrets] if self.secrets else None, tokens=[t.to_flyte_idl() for t in self.tokens] if self.tokens else None, + connection=self.connection, ) @classmethod @@ -172,4 +174,5 @@ def from_flyte_idl(cls, pb2_object: _sec.SecurityContext) -> "SecurityContext": else None, secrets=[Secret.from_flyte_idl(s) for s in pb2_object.secrets] if pb2_object.secrets else None, tokens=[OAuth2TokenRequest.from_flyte_idl(t) for t in pb2_object.tokens] if pb2_object.tokens else None, + connection=pb2_object.connection if pb2_object.connection else None, ) From d765e26142144bd7ea5412f129b42b4ade0a2b78 Mon Sep 17 00:00:00 2001 From: Kevin Su Date: Wed, 17 Apr 2024 10:33:45 -0700 Subject: [PATCH 08/18] add connection Signed-off-by: Kevin Su --- flytekit/core/python_auto_container.py | 5 +++- flytekit/extend/backend/base_agent.py | 23 ++++++++++++++++--- flytekit/extend/backend/utils.py | 8 ++----- .../flytekitplugins/openai/chatgpt/agent.py | 20 ++++++++-------- 4 files changed, 37 insertions(+), 19 deletions(-) diff --git a/flytekit/core/python_auto_container.py b/flytekit/core/python_auto_container.py index 0cfdc43599..0c8bbd642b 100644 --- a/flytekit/core/python_auto_container.py +++ b/flytekit/core/python_auto_container.py @@ -79,7 +79,10 @@ def __init__( sec_ctx = None if secret_requests: if not isinstance(secret_requests, list): - raise AssertionError(f"secret_requests should be of type list, received {type(secret_requests)}") + if isinstance(secret_requests, Secret): + secret_requests = [secret_requests] + else: + raise AssertionError(f"Secret {secret_requests} should be of type flytekit.Secret.") for s in secret_requests: if not isinstance(s, Secret): raise AssertionError(f"Secret {s} should be of type flytekit.Secret, received {type(s)}") diff --git a/flytekit/extend/backend/base_agent.py b/flytekit/extend/backend/base_agent.py index fcd2fa0d51..fd1a77eacb 100644 --- a/flytekit/extend/backend/base_agent.py +++ b/flytekit/extend/backend/base_agent.py @@ -13,9 +13,8 @@ from flyteidl.admin.agent_pb2 import Agent from flyteidl.admin.agent_pb2 import TaskCategory as _TaskCategory -from flyteidl.core import literals_pb2 +from flyteidl.core import literals_pb2, security_pb2 from flyteidl.core.execution_pb2 import TaskExecution, TaskLog -from flyteidl.core.security_pb2 import Connection from rich.progress import Progress from flytekit import FlyteContext, PythonFunctionTask, logger @@ -53,6 +52,20 @@ def __str__(self): return f"{self._name}_v{self._version}" +@dataclass +class Connection: + """ + This is the connection object that the agent can use to connect to the external services. + """ + + @classmethod + def decode(cls, data: security_pb2.Connection) -> "Connection": + """ + Decode the resource meta from bytes. + """ + return dataclass_from_dict(cls, {k: v for k, v in data.secrets.items()}) + + @dataclass class ResourceMeta: """ @@ -117,12 +130,16 @@ class SyncAgentBase(AgentBase): name = "Base Sync Agent" + def __init__(self, connection_type: Connection, **kwargs): + super().__init__(**kwargs) + self._connection_type = connection_type + @abstractmethod def do( self, task_template: TaskTemplate, inputs: Optional[LiteralMap] = None, - connection: Optional[List[Connection]] = None, + connection: Optional[Connection] = None, **kwargs, ) -> Resource: """ diff --git a/flytekit/extend/backend/utils.py b/flytekit/extend/backend/utils.py index 32cc37c2b8..ac88089b26 100644 --- a/flytekit/extend/backend/utils.py +++ b/flytekit/extend/backend/utils.py @@ -1,10 +1,8 @@ import asyncio import functools import inspect -import typing from typing import Callable, Coroutine -from flyteidl.admin.agent_pb2 import Secret from flyteidl.core.execution_pb2 import TaskExecution import flytekit @@ -40,12 +38,10 @@ def is_terminal_phase(phase: TaskExecution.Phase) -> bool: return phase in [TaskExecution.SUCCEEDED, TaskExecution.ABORTED, TaskExecution.FAILED] -def get_agent_secret(secret_key: str, secret: typing.Optional[Secret] = None) -> str: +def get_agent_secret(secret_key: str) -> str: """ - Get the secret from the context if the secret is not provided. + Get the secret from the Flyte context. """ - if secret: - return secret.value return flytekit.current_context().secrets.get(secret_key) diff --git a/plugins/flytekit-openai/flytekitplugins/openai/chatgpt/agent.py b/plugins/flytekit-openai/flytekitplugins/openai/chatgpt/agent.py index 85d08882ce..4eb3ef9c7a 100644 --- a/plugins/flytekit-openai/flytekitplugins/openai/chatgpt/agent.py +++ b/plugins/flytekit-openai/flytekitplugins/openai/chatgpt/agent.py @@ -1,13 +1,13 @@ import asyncio import logging -from typing import List, Optional +from dataclasses import dataclass +from typing import Optional -from flyteidl.admin.agent_pb2 import Secret from flyteidl.core.execution_pb2 import TaskExecution from flytekit import FlyteContextManager, lazy_module from flytekit.core.type_engine import TypeEngine -from flytekit.extend.backend.base_agent import AgentRegistry, Resource, SyncAgentBase +from flytekit.extend.backend.base_agent import AgentRegistry, Connection, Resource, SyncAgentBase from flytekit.extend.backend.utils import get_agent_secret from flytekit.models.literals import LiteralMap from flytekit.models.task import TaskTemplate @@ -18,31 +18,33 @@ OPENAI_API_KEY = "FLYTE_OPENAI_API_KEY" +@dataclass +class ChatGPTConnection(Connection): + openai_api_key: str + + class ChatGPTAgent(SyncAgentBase): name = "ChatGPT Agent" def __init__(self): - super().__init__(task_type_name="chatgpt") + super().__init__(task_type_name="chatgpt", connection_type=ChatGPTConnection) async def do( self, task_template: TaskTemplate, inputs: Optional[LiteralMap] = None, - secrets: Optional[List[Secret]] = None, + connection: Optional[ChatGPTConnection] = None, **kwargs, ) -> Resource: ctx = FlyteContextManager.current_context() input_python_value = TypeEngine.literal_map_to_kwargs(ctx, inputs, {"message": str}) message = input_python_value["message"] - api_key = None - if secrets and len(secrets) > 0: - api_key = secrets[0] custom = task_template.custom custom["chatgpt_config"]["messages"] = [{"role": "user", "content": message}] client = openai.AsyncOpenAI( organization=custom["openai_organization"], - api_key=get_agent_secret(secret_key=OPENAI_API_KEY, secret=api_key), + api_key=connection.openai_api_key or get_agent_secret(secret_key=OPENAI_API_KEY), ) logger = logging.getLogger("httpx") From c09e8c83beaca97d07bd330f8a7e2ee4b1bd9b03 Mon Sep 17 00:00:00 2001 From: Kevin Su Date: Wed, 17 Apr 2024 11:17:31 -0700 Subject: [PATCH 09/18] cleanup Signed-off-by: Kevin Su --- flytekit/extend/backend/agent_service.py | 14 +++++-- flytekit/extend/backend/base_agent.py | 21 +++++++--- .../flytekitplugins/openai/chatgpt/task.py | 7 ++-- .../flytekitplugins/spark/agent.py | 41 +++++++++++-------- 4 files changed, 54 insertions(+), 29 deletions(-) diff --git a/flytekit/extend/backend/agent_service.py b/flytekit/extend/backend/agent_service.py index fee285df25..ba7dcc12d4 100644 --- a/flytekit/extend/backend/agent_service.py +++ b/flytekit/extend/backend/agent_service.py @@ -134,9 +134,11 @@ async def GetTask(self, request: GetTaskRequest, context: grpc.ServicerContext) agent = AgentRegistry.get_agent(request.task_category.name, request.task_category.version) else: agent = AgentRegistry.get_agent(request.task_type) - logger.info(f"{agent.name} start checking the status of the job") + res = await mirror_async_methods( - agent.get, resource_meta=agent.metadata_type.decode(request.resource_meta), connection=request.connection + agent.get, + resource_meta=agent.metadata_type.decode(request.resource_meta), + connection=agent.connection_type.decode(request.connection) if request.connection else None, ) if res.outputs is None: @@ -158,7 +160,9 @@ async def DeleteTask(self, request: DeleteTaskRequest, context: grpc.ServicerCon agent = AgentRegistry.get_agent(request.task_type) logger.info(f"{agent.name} start deleting the job") await mirror_async_methods( - agent.delete, resource_meta=agent.metadata_type.decode(request.resource_meta), connection=request.connection + agent.delete, + resource_meta=agent.metadata_type.decode(request.resource_meta), + connection=agent.connection_type.decode(request.connection) if request.connection else None, ) return DeleteTaskResponse() @@ -169,7 +173,6 @@ async def ExecuteTaskSync( ) -> typing.AsyncIterator[ExecuteTaskSyncResponse]: request = await request_iterator.__anext__() template = TaskTemplate.from_flyte_idl(request.header.template) - connection = request.header.connection task_type = template.type try: with request_latency.labels(task_type=task_type, operation=do_operation).time(): @@ -179,6 +182,9 @@ async def ExecuteTaskSync( request = await request_iterator.__anext__() literal_map = LiteralMap.from_flyte_idl(request.inputs) if request.inputs else None + connection = ( + agent.connection_type.decode(request.header.connection) if request.header.connection else None + ) res = await mirror_async_methods( agent.do, task_template=template, inputs=literal_map, connection=connection ) diff --git a/flytekit/extend/backend/base_agent.py b/flytekit/extend/backend/base_agent.py index fd1a77eacb..dcac2fe756 100644 --- a/flytekit/extend/backend/base_agent.py +++ b/flytekit/extend/backend/base_agent.py @@ -59,11 +59,13 @@ class Connection: """ @classmethod - def decode(cls, data: security_pb2.Connection) -> "Connection": + def decode(cls, connection: security_pb2.Connection) -> "Connection": """ Decode the resource meta from bytes. """ - return dataclass_from_dict(cls, {k: v for k, v in data.secrets.items()}) + data = {k: v for k, v in connection.secrets.items()} + data.update({k: v for k, v in connection.config.items()}) + return dataclass_from_dict(cls, data) @dataclass @@ -130,10 +132,14 @@ class SyncAgentBase(AgentBase): name = "Base Sync Agent" - def __init__(self, connection_type: Connection, **kwargs): + def __init__(self, connection_type: Optional[Connection] = None, **kwargs): super().__init__(**kwargs) self._connection_type = connection_type + @property + def connection_type(self) -> Connection: + return self._connection_type + @abstractmethod def do( self, @@ -160,14 +166,19 @@ class AsyncAgentBase(AgentBase): name = "Base Async Agent" - def __init__(self, metadata_type: ResourceMeta, **kwargs): + def __init__(self, metadata_type: ResourceMeta, connection_type: Optional[Connection] = None, **kwargs): super().__init__(**kwargs) self._metadata_type = metadata_type + self._connection_type = connection_type @property def metadata_type(self) -> ResourceMeta: return self._metadata_type + @property + def connection_type(self) -> Connection: + return self._connection_type + @abstractmethod def create( self, @@ -355,7 +366,7 @@ async def _create( output_prefix=output_prefix, ) - signal.signal(signal.SIGINT, partial(self.signal_handler, resource_meta, secrets)) # type: ignore + signal.signal(signal.SIGINT, partial(self.signal_handler, resource_meta)) # type: ignore return resource_meta async def _get(self: T, resource_meta: ResourceMeta) -> Resource: diff --git a/plugins/flytekit-openai/flytekitplugins/openai/chatgpt/task.py b/plugins/flytekit-openai/flytekitplugins/openai/chatgpt/task.py index c1c720d010..027f4c494e 100644 --- a/plugins/flytekit-openai/flytekitplugins/openai/chatgpt/task.py +++ b/plugins/flytekit-openai/flytekitplugins/openai/chatgpt/task.py @@ -1,11 +1,10 @@ -import typing from typing import Any, Dict from flytekit.configuration import SerializationSettings from flytekit.core.base_task import PythonTask from flytekit.core.interface import Interface from flytekit.extend.backend.base_agent import SyncAgentExecutorMixin -from flytekit.models.security import Secret, SecurityContext +from flytekit.models.security import SecurityContext class ChatGPTTask(SyncAgentExecutorMixin, PythonTask): @@ -20,7 +19,7 @@ def __init__( name: str, openai_organization: str, chatgpt_config: Dict[str, Any], - openai_key: typing.Optional[Secret] = None, + connection: str = "chatgpt", **kwargs, ): """ @@ -37,7 +36,7 @@ def __init__( inputs = {"message": str} outputs = {"o0": str} - sec_ctx = SecurityContext(secrets=[openai_key]) + sec_ctx = SecurityContext(connection=connection) super().__init__( task_type=self._TASK_TYPE, diff --git a/plugins/flytekit-spark/flytekitplugins/spark/agent.py b/plugins/flytekit-spark/flytekitplugins/spark/agent.py index 65a42522bc..dfaff833fd 100644 --- a/plugins/flytekit-spark/flytekitplugins/spark/agent.py +++ b/plugins/flytekit-spark/flytekitplugins/spark/agent.py @@ -1,13 +1,12 @@ import http import json from dataclasses import dataclass -from typing import Dict, List, Optional +from typing import Dict, Optional -from flyteidl.admin.agent_pb2 import Secret from flyteidl.core.execution_pb2 import TaskExecution from flytekit import lazy_module -from flytekit.extend.backend.base_agent import AgentRegistry, AsyncAgentBase, Resource, ResourceMeta +from flytekit.extend.backend.base_agent import AgentRegistry, AsyncAgentBase, Connection, Resource, ResourceMeta from flytekit.extend.backend.utils import convert_to_flyte_phase, get_agent_secret from flytekit.models.core.execution import TaskLog from flytekit.models.literals import LiteralMap @@ -18,6 +17,12 @@ DATABRICKS_API_ENDPOINT = "/api/2.1/jobs" +@dataclass +class DatabricksConnection(Connection): + access_token: str + databricks_instance: str + + @dataclass class DatabricksJobMetadata(ResourceMeta): databricks_instance: str @@ -28,13 +33,15 @@ class DatabricksAgent(AsyncAgentBase): name = "Databricks Agent" def __init__(self): - super().__init__(task_type_name="spark", metadata_type=DatabricksJobMetadata) + super().__init__( + task_type_name="spark", metadata_type=DatabricksJobMetadata, connection_type=DatabricksConnection + ) async def create( self, task_template: TaskTemplate, inputs: Optional[LiteralMap] = None, - secrets: Optional[List[Secret]] = None, + connection: Optional[DatabricksConnection] = None, **kwargs, ) -> DatabricksJobMetadata: custom = task_template.custom @@ -66,7 +73,7 @@ async def create( data = json.dumps(databricks_job) async with aiohttp.ClientSession() as session: - async with session.post(databricks_url, headers=get_header(secrets), data=data) as resp: + async with session.post(databricks_url, headers=get_header(connection), data=data) as resp: response = await resp.json() if resp.status != http.HTTPStatus.OK: raise Exception(f"Failed to create databricks job with error: {response}") @@ -74,7 +81,7 @@ async def create( return DatabricksJobMetadata(databricks_instance=databricks_instance, run_id=str(response["run_id"])) async def get( - self, resource_meta: DatabricksJobMetadata, secrets: Optional[List[Secret]] = None, **kwargs + self, resource_meta: DatabricksJobMetadata, connection: Optional[DatabricksConnection] = None, **kwargs ) -> Resource: databricks_instance = resource_meta.databricks_instance databricks_url = ( @@ -82,7 +89,7 @@ async def get( ) async with aiohttp.ClientSession() as session: - async with session.get(databricks_url, headers=get_header(secrets)) as resp: + async with session.get(databricks_url, headers=get_header(connection)) as resp: if resp.status != http.HTTPStatus.OK: raise Exception(f"Failed to get databricks job {resource_meta.run_id} with error: {resp.reason}") response = await resp.json() @@ -108,23 +115,25 @@ async def get( return Resource(phase=cur_phase, message=message, log_links=log_links) - async def delete(self, resource_meta: DatabricksJobMetadata, secrets: Optional[List[Secret]] = None, **kwargs): + async def delete( + self, resource_meta: DatabricksJobMetadata, connection: Optional[DatabricksConnection] = None, **kwargs + ): databricks_url = f"https://{resource_meta.databricks_instance}{DATABRICKS_API_ENDPOINT}/runs/cancel" data = json.dumps({"run_id": resource_meta.run_id}) async with aiohttp.ClientSession() as session: - async with session.post(databricks_url, headers=get_header(secrets), data=data) as resp: + async with session.post(databricks_url, headers=get_header(connection), data=data) as resp: if resp.status != http.HTTPStatus.OK: raise Exception(f"Failed to cancel databricks job {resource_meta.run_id} with error: {resp.reason}") await resp.json() -def get_header(secrets: Optional[List[Secret]] = None) -> Dict[str, str]: - access_token = None - if secrets and len(secrets) > 0: - access_token = secrets[0] - token = get_agent_secret(secret_key="FLYTE_DATABRICKS_ACCESS_TOKEN", secret=access_token) - return {"Authorization": f"Bearer {token}", "content-type": "application/json"} +def get_header(connection: Optional[DatabricksConnection] = None) -> Dict[str, str]: + if connection: + access_token = connection.access_token + else: + access_token = get_agent_secret(secret_key="FLYTE_DATABRICKS_ACCESS_TOKEN") + return {"Authorization": f"Bearer {access_token}", "content-type": "application/json"} def result_state_is_available(life_cycle_state: str) -> bool: From 9ede49abeeab4b974cbbbac100c60470071a081a Mon Sep 17 00:00:00 2001 From: Kevin Su Date: Wed, 17 Apr 2024 16:19:46 -0700 Subject: [PATCH 10/18] mot Signed-off-by: Kevin Su --- flytekit/extend/backend/agent_service.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/flytekit/extend/backend/agent_service.py b/flytekit/extend/backend/agent_service.py index ba7dcc12d4..04a749618e 100644 --- a/flytekit/extend/backend/agent_service.py +++ b/flytekit/extend/backend/agent_service.py @@ -172,6 +172,7 @@ async def ExecuteTaskSync( self, request_iterator: typing.AsyncIterator[ExecuteTaskSyncRequest], context: grpc.ServicerContext ) -> typing.AsyncIterator[ExecuteTaskSyncResponse]: request = await request_iterator.__anext__() + connection_pb = request.header.connection template = TaskTemplate.from_flyte_idl(request.header.template) task_type = template.type try: @@ -182,9 +183,8 @@ async def ExecuteTaskSync( request = await request_iterator.__anext__() literal_map = LiteralMap.from_flyte_idl(request.inputs) if request.inputs else None - connection = ( - agent.connection_type.decode(request.header.connection) if request.header.connection else None - ) + connection = agent.connection_type.decode(connection_pb) if connection_pb else None + res = await mirror_async_methods( agent.do, task_template=template, inputs=literal_map, connection=connection ) From 961c5681a1cabbbec7d25deeabd7eb414b8387bf Mon Sep 17 00:00:00 2001 From: Kevin Su Date: Mon, 29 Apr 2024 16:35:00 +0800 Subject: [PATCH 11/18] fix tests Signed-off-by: Kevin Su --- dev-requirements.in | 2 +- .../flytekit-openai/flytekitplugins/openai/chatgpt/agent.py | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/dev-requirements.in b/dev-requirements.in index fb90c597b9..871394beb7 100644 --- a/dev-requirements.in +++ b/dev-requirements.in @@ -1,5 +1,5 @@ -e file:.#egg=flytekit -git+https://github.com/flyteorg/flyte.git@master#subdirectory=flyteidl +git+https://github.com/flyteorg/flyte.git@df830eaf6e0d7632adec0069341a67c5a2eade00#subdirectory=flyteidl coverage[toml] hypothesis diff --git a/plugins/flytekit-openai/flytekitplugins/openai/chatgpt/agent.py b/plugins/flytekit-openai/flytekitplugins/openai/chatgpt/agent.py index 4eb3ef9c7a..b4d5f76290 100644 --- a/plugins/flytekit-openai/flytekitplugins/openai/chatgpt/agent.py +++ b/plugins/flytekit-openai/flytekitplugins/openai/chatgpt/agent.py @@ -42,9 +42,10 @@ async def do( custom = task_template.custom custom["chatgpt_config"]["messages"] = [{"role": "user", "content": message}] + openai_api_key = connection.openai_api_key if connection else None client = openai.AsyncOpenAI( organization=custom["openai_organization"], - api_key=connection.openai_api_key or get_agent_secret(secret_key=OPENAI_API_KEY), + api_key=openai_api_key or get_agent_secret(secret_key=OPENAI_API_KEY), ) logger = logging.getLogger("httpx") From 9d2e8f38f6f114344829b3759e23861b1f5c968d Mon Sep 17 00:00:00 2001 From: Kevin Su Date: Mon, 29 Apr 2024 17:24:50 +0800 Subject: [PATCH 12/18] fix tests Signed-off-by: Kevin Su --- Dockerfile.dev | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Dockerfile.dev b/Dockerfile.dev index f4f56d0d4a..b15a3ddc41 100644 --- a/Dockerfile.dev +++ b/Dockerfile.dev @@ -30,7 +30,7 @@ RUN apt-get update && apt-get install build-essential vim libmagic1 git -y RUN pip install scikit-learn COPY . /flytekit RUN SETUPTOOLS_SCM_PRETEND_VERSION_FOR_FLYTEKIT=$PSEUDO_VERSION pip install --no-cache-dir -U \ - "git+https://github.com/flyteorg/flyte.git@master#subdirectory=flyteidl" \ + "git+https://github.com/flyteorg/flyte.git@9574f595c617fba18ef230da780e840f2b90b22d#subdirectory=flyteidl" \ -e /flytekit \ -e /flytekit/plugins/flytekit-k8s-pod \ -e /flytekit/plugins/flytekit-deck-standard \ From 5e378a95f76bd1520b4fcc49d0866f0a6c29ec58 Mon Sep 17 00:00:00 2001 From: Kevin Su Date: Mon, 29 Apr 2024 19:50:36 +0800 Subject: [PATCH 13/18] add openai_organization Signed-off-by: Kevin Su --- flytekit/extend/backend/base_agent.py | 2 +- plugins/flytekit-openai/flytekitplugins/openai/chatgpt/agent.py | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/flytekit/extend/backend/base_agent.py b/flytekit/extend/backend/base_agent.py index 834ae85f91..90336c2481 100644 --- a/flytekit/extend/backend/base_agent.py +++ b/flytekit/extend/backend/base_agent.py @@ -66,7 +66,7 @@ def decode(cls, connection: security_pb2.Connection) -> "Connection": Decode the resource meta from bytes. """ data = {k: v for k, v in connection.secrets.items()} - data.update({k: v for k, v in connection.config.items()}) + data.update({k: v for k, v in connection.configs.items()}) return dataclass_from_dict(cls, data) diff --git a/plugins/flytekit-openai/flytekitplugins/openai/chatgpt/agent.py b/plugins/flytekit-openai/flytekitplugins/openai/chatgpt/agent.py index b4d5f76290..c66901566e 100644 --- a/plugins/flytekit-openai/flytekitplugins/openai/chatgpt/agent.py +++ b/plugins/flytekit-openai/flytekitplugins/openai/chatgpt/agent.py @@ -21,6 +21,7 @@ @dataclass class ChatGPTConnection(Connection): openai_api_key: str + openai_organization: str class ChatGPTAgent(SyncAgentBase): From 50bac66dc5f02d4c4e07a22b33a1e582f91b15ca Mon Sep 17 00:00:00 2001 From: Kevin Su Date: Mon, 29 Apr 2024 20:23:49 +0800 Subject: [PATCH 14/18] nit Signed-off-by: Kevin Su --- .../flytekit-openai/flytekitplugins/openai/chatgpt/agent.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/plugins/flytekit-openai/flytekitplugins/openai/chatgpt/agent.py b/plugins/flytekit-openai/flytekitplugins/openai/chatgpt/agent.py index c66901566e..19682f9064 100644 --- a/plugins/flytekit-openai/flytekitplugins/openai/chatgpt/agent.py +++ b/plugins/flytekit-openai/flytekitplugins/openai/chatgpt/agent.py @@ -43,9 +43,12 @@ async def do( custom = task_template.custom custom["chatgpt_config"]["messages"] = [{"role": "user", "content": message}] + openai_api_key = connection.openai_api_key if connection else None + openai_organization = connection.openai_organization if connection else None + client = openai.AsyncOpenAI( - organization=custom["openai_organization"], + organization=openai_organization or custom["openai_organization"], api_key=openai_api_key or get_agent_secret(secret_key=OPENAI_API_KEY), ) From 864214152f77d81641cfb7e378242e083c132e96 Mon Sep 17 00:00:00 2001 From: Kevin Su Date: Mon, 29 Apr 2024 22:46:52 +0800 Subject: [PATCH 15/18] nit Signed-off-by: Kevin Su --- .../flytekit-openai/flytekitplugins/openai/chatgpt/task.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/plugins/flytekit-openai/flytekitplugins/openai/chatgpt/task.py b/plugins/flytekit-openai/flytekitplugins/openai/chatgpt/task.py index 027f4c494e..a66287fde5 100644 --- a/plugins/flytekit-openai/flytekitplugins/openai/chatgpt/task.py +++ b/plugins/flytekit-openai/flytekitplugins/openai/chatgpt/task.py @@ -1,4 +1,4 @@ -from typing import Any, Dict +from typing import Any, Dict, Optional from flytekit.configuration import SerializationSettings from flytekit.core.base_task import PythonTask @@ -17,9 +17,9 @@ class ChatGPTTask(SyncAgentExecutorMixin, PythonTask): def __init__( self, name: str, - openai_organization: str, chatgpt_config: Dict[str, Any], - connection: str = "chatgpt", + openai_organization: Optional[str] = None, + connection: Optional[str] = None, **kwargs, ): """ From 6f851b37f9b4d30341d41c92daf0a30a0cfe981d Mon Sep 17 00:00:00 2001 From: Kevin Su Date: Thu, 16 May 2024 17:51:25 +0800 Subject: [PATCH 16/18] use connection_ref Signed-off-by: Kevin Su --- flytekit/models/security.py | 8 ++++---- .../flytekitplugins/openai/chatgpt/task.py | 4 ++-- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/flytekit/models/security.py b/flytekit/models/security.py index b0f1c78086..b860b2f987 100644 --- a/flytekit/models/security.py +++ b/flytekit/models/security.py @@ -152,7 +152,7 @@ class SecurityContext(_common.FlyteIdlEntity): run_as: Optional[Identity] = None secrets: Optional[List[Secret]] = None tokens: Optional[List[OAuth2TokenRequest]] = None - connection: Optional[str] = None + connection_ref: Optional[str] = None def __post_init__(self): if self.secrets and not isinstance(self.secrets, list): @@ -161,13 +161,13 @@ def __post_init__(self): self.tokens = [self.tokens] def to_flyte_idl(self) -> _sec.SecurityContext: - if self.run_as is None and self.secrets is None and self.tokens is None and self.connection is None: + if self.run_as is None and self.secrets is None and self.tokens is None and self.connection_ref is None: return None return _sec.SecurityContext( run_as=self.run_as.to_flyte_idl() if self.run_as else None, secrets=[s.to_flyte_idl() for s in self.secrets] if self.secrets else None, tokens=[t.to_flyte_idl() for t in self.tokens] if self.tokens else None, - connection=self.connection, + connection_ref=self.connection_ref, ) @classmethod @@ -178,5 +178,5 @@ def from_flyte_idl(cls, pb2_object: _sec.SecurityContext) -> "SecurityContext": else None, secrets=[Secret.from_flyte_idl(s) for s in pb2_object.secrets] if pb2_object.secrets else None, tokens=[OAuth2TokenRequest.from_flyte_idl(t) for t in pb2_object.tokens] if pb2_object.tokens else None, - connection=pb2_object.connection if pb2_object.connection else None, + connection_ref=pb2_object.connection_ref if pb2_object.connection_ref else None, ) diff --git a/plugins/flytekit-openai/flytekitplugins/openai/chatgpt/task.py b/plugins/flytekit-openai/flytekitplugins/openai/chatgpt/task.py index a66287fde5..e78446f70f 100644 --- a/plugins/flytekit-openai/flytekitplugins/openai/chatgpt/task.py +++ b/plugins/flytekit-openai/flytekitplugins/openai/chatgpt/task.py @@ -19,7 +19,7 @@ def __init__( name: str, chatgpt_config: Dict[str, Any], openai_organization: Optional[str] = None, - connection: Optional[str] = None, + connection_ref: Optional[str] = None, **kwargs, ): """ @@ -36,7 +36,7 @@ def __init__( inputs = {"message": str} outputs = {"o0": str} - sec_ctx = SecurityContext(connection=connection) + sec_ctx = SecurityContext(connection_ref=connection_ref) super().__init__( task_type=self._TASK_TYPE, From e7a090ba87b0883db70c365c2e30f16fc40221fa Mon Sep 17 00:00:00 2001 From: Kevin Su Date: Fri, 17 May 2024 18:19:41 +0800 Subject: [PATCH 17/18] nit Signed-off-by: Kevin Su --- plugins/flytekit-openai/flytekitplugins/openai/chatgpt/agent.py | 2 +- plugins/flytekit-openai/flytekitplugins/openai/chatgpt/task.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/plugins/flytekit-openai/flytekitplugins/openai/chatgpt/agent.py b/plugins/flytekit-openai/flytekitplugins/openai/chatgpt/agent.py index 19682f9064..1c1f926c8e 100644 --- a/plugins/flytekit-openai/flytekitplugins/openai/chatgpt/agent.py +++ b/plugins/flytekit-openai/flytekitplugins/openai/chatgpt/agent.py @@ -44,7 +44,7 @@ async def do( custom = task_template.custom custom["chatgpt_config"]["messages"] = [{"role": "user", "content": message}] - openai_api_key = connection.openai_api_key if connection else None + openai_api_key = connection.openai_api_key.strip() if connection else None openai_organization = connection.openai_organization if connection else None client = openai.AsyncOpenAI( diff --git a/plugins/flytekit-openai/flytekitplugins/openai/chatgpt/task.py b/plugins/flytekit-openai/flytekitplugins/openai/chatgpt/task.py index e78446f70f..d91d401c60 100644 --- a/plugins/flytekit-openai/flytekitplugins/openai/chatgpt/task.py +++ b/plugins/flytekit-openai/flytekitplugins/openai/chatgpt/task.py @@ -17,7 +17,7 @@ class ChatGPTTask(SyncAgentExecutorMixin, PythonTask): def __init__( self, name: str, - chatgpt_config: Dict[str, Any], + chatgpt_config: Optional[Dict[str, Any]] = None, openai_organization: Optional[str] = None, connection_ref: Optional[str] = None, **kwargs, From 31a6fc814682db508fd49bb35cb6a882ec936832 Mon Sep 17 00:00:00 2001 From: Kevin Su Date: Wed, 26 Jun 2024 15:38:02 -0700 Subject: [PATCH 18/18] update idl Signed-off-by: Kevin Su --- dev-requirements.in | 2 +- flytekit/extend/backend/agent_service.py | 6 +++++- plugins/flytekit-spark/flytekitplugins/spark/agent.py | 2 +- 3 files changed, 7 insertions(+), 3 deletions(-) diff --git a/dev-requirements.in b/dev-requirements.in index 1e098b1445..9d3d0ec0a0 100644 --- a/dev-requirements.in +++ b/dev-requirements.in @@ -1,5 +1,5 @@ -e file:. -git+https://github.com/flyteorg/flyte.git@5f6d682cb7f5d417f217c61393ce62fff4dbeb8c#subdirectory=flyteidl +flyteidl @ git+https://github.com/flyteorg/flyte.git@5f6d682cb7f5d417f217c61393ce62fff4dbeb8c#subdirectory=flyteidl coverage[toml] hypothesis diff --git a/flytekit/extend/backend/agent_service.py b/flytekit/extend/backend/agent_service.py index 1e10c9cf11..a720022307 100644 --- a/flytekit/extend/backend/agent_service.py +++ b/flytekit/extend/backend/agent_service.py @@ -188,7 +188,11 @@ async def ExecuteTaskSync( connection = agent.connection_type.decode(connection_pb) if connection_pb else None res = await mirror_async_methods( - agent.do, task_template=template, inputs=literal_map, output_prefix=output_prefix, connection=connection + agent.do, + task_template=template, + inputs=literal_map, + output_prefix=output_prefix, + connection=connection, ) if res.outputs is None: diff --git a/plugins/flytekit-spark/flytekitplugins/spark/agent.py b/plugins/flytekit-spark/flytekitplugins/spark/agent.py index dfaff833fd..9a02c62da2 100644 --- a/plugins/flytekit-spark/flytekitplugins/spark/agent.py +++ b/plugins/flytekit-spark/flytekitplugins/spark/agent.py @@ -68,7 +68,7 @@ async def create( "git_commit": "aff8a9f2adbf5deda81d36d59a0b8fa3b1fc3679", } - databricks_instance = custom["databricksInstance"] + databricks_instance = connection.databricks_instance or custom["databricksInstance"] databricks_url = f"https://{databricks_instance}{DATABRICKS_API_ENDPOINT}/runs/submit" data = json.dumps(databricks_job)