Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[flyte-core] Flyte Connection #2297

Open
wants to merge 27 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 4 commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 9 additions & 3 deletions flytekit/extend/backend/agent_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,7 @@ async def CreateTask(self, request: CreateTaskRequest, context: grpc.ServicerCon
task_template=template,
inputs=inputs,
output_prefix=request.output_prefix,
secrets=request.secrets,
task_execution_metadata=task_execution_metadata,
)
return CreateTaskResponse(resource_meta=resource_mata.encode())
Expand All @@ -134,7 +135,9 @@ async def GetTask(self, request: GetTaskRequest, context: grpc.ServicerContext)
else:
agent = AgentRegistry.get_agent(request.task_type)
logger.info(f"{agent.name} start checking the status of the job")
res = await mirror_async_methods(agent.get, resource_meta=agent.metadata_type.decode(request.resource_meta))
res = await mirror_async_methods(
agent.get, resource_meta=agent.metadata_type.decode(request.resource_meta), secrets=request.secrets
)

if res.outputs is None:
outputs = None
Expand All @@ -154,7 +157,9 @@ async def DeleteTask(self, request: DeleteTaskRequest, context: grpc.ServicerCon
else:
agent = AgentRegistry.get_agent(request.task_type)
logger.info(f"{agent.name} start deleting the job")
await mirror_async_methods(agent.delete, resource_meta=agent.metadata_type.decode(request.resource_meta))
await mirror_async_methods(
agent.delete, resource_meta=agent.metadata_type.decode(request.resource_meta), secrets=request.secrets
)
return DeleteTaskResponse()


Expand All @@ -164,6 +169,7 @@ async def ExecuteTaskSync(
) -> typing.AsyncIterator[ExecuteTaskSyncResponse]:
request = await request_iterator.__anext__()
template = TaskTemplate.from_flyte_idl(request.header.template)
secrets = request.header.secrets
task_type = template.type
try:
with request_latency.labels(task_type=task_type, operation=do_operation).time():
Expand All @@ -173,7 +179,7 @@ async def ExecuteTaskSync(

request = await request_iterator.__anext__()
literal_map = LiteralMap.from_flyte_idl(request.inputs) if request.inputs else None
res = await mirror_async_methods(agent.do, task_template=template, inputs=literal_map)
res = await mirror_async_methods(agent.do, task_template=template, inputs=literal_map, secrets=secrets)

if res.outputs is None:
outputs = None
Expand Down
17 changes: 14 additions & 3 deletions flytekit/extend/backend/base_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,13 @@
from types import FrameType, coroutine
from typing import Any, Dict, List, Optional, Union

from flyteidl.admin.agent_pb2 import Agent
from flyteidl.admin.agent_pb2 import Agent, Secret
from flyteidl.admin.agent_pb2 import TaskCategory as _TaskCategory
from flyteidl.core import literals_pb2
from flyteidl.core.execution_pb2 import TaskExecution, TaskLog
from rich.progress import Progress

import flytekit
from flytekit import FlyteContext, PythonFunctionTask, logger
from flytekit.configuration import ImageConfig, SerializationSettings
from flytekit.core import utils
Expand Down Expand Up @@ -117,7 +118,13 @@ class SyncAgentBase(AgentBase):
name = "Base Sync Agent"

@abstractmethod
def do(self, task_template: TaskTemplate, inputs: Optional[LiteralMap], **kwargs) -> Resource:
def do(
self,
task_template: TaskTemplate,
inputs: Optional[LiteralMap] = None,
secrets: Optional[List[Secret]] = None,
**kwargs,
) -> Resource:
"""
This is the method that the agent will run.
"""
Expand Down Expand Up @@ -257,8 +264,12 @@ async def _do(
) -> Resource:
try:
ctx = FlyteContext.current_context()
secrets = []
for secret in template.security_context.secrets:
value = flytekit.current_context().secrets.get(secret.group, secret.key, secret.group_version)
secrets.append(Secret(value=value))
literal_map = TypeEngine.dict_to_literal_map(ctx, inputs or {}, self.get_input_types())
return await mirror_async_methods(agent.do, task_template=template, inputs=literal_map)
return await mirror_async_methods(agent.do, task_template=template, inputs=literal_map, secrets=secrets)
except Exception as error_message:
raise FlyteUserException(f"Failed to run the task {self.name} with error: {error_message}")

Expand Down
9 changes: 8 additions & 1 deletion flytekit/extend/backend/utils.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import asyncio
import functools
import inspect
import typing
from typing import Callable, Coroutine

from flyteidl.admin.agent_pb2 import Secret
from flyteidl.core.execution_pb2 import TaskExecution

import flytekit
Expand Down Expand Up @@ -38,7 +40,12 @@ def is_terminal_phase(phase: TaskExecution.Phase) -> bool:
return phase in [TaskExecution.SUCCEEDED, TaskExecution.ABORTED, TaskExecution.FAILED]


def get_agent_secret(secret_key: str) -> str:
def get_agent_secret(secret_key: str, secret: typing.Optional[Secret] = None) -> str:
"""
Get the secret from the context if the secret is not provided.
"""
if secret:
return secret.value
return flytekit.current_context().secrets.get(secret_key)


Expand Down
2 changes: 1 addition & 1 deletion flytekit/models/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -456,7 +456,7 @@ def config(self):
return self._config

@property
def security_context(self):
def security_context(self) -> _sec.SecurityContext:
return self._security_context

@property
Expand Down
9 changes: 8 additions & 1 deletion plugins/flytekit-openai/flytekitplugins/chatgpt/agent.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import asyncio
import logging
import typing
from typing import Optional

from flyteidl.admin.agent_pb2 import Secret
from flyteidl.core.execution_pb2 import TaskExecution

from flytekit import FlyteContextManager, lazy_module
Expand All @@ -27,16 +29,21 @@ async def do(
self,
task_template: TaskTemplate,
inputs: Optional[LiteralMap] = None,
secrets: typing.List[Secret] = None,
**kwargs,
) -> Resource:
ctx = FlyteContextManager.current_context()
input_python_value = TypeEngine.literal_map_to_kwargs(ctx, inputs, {"message": str})
message = input_python_value["message"]

api_key = None
if secrets and len(secrets) > 0:
api_key = secrets[0]
custom = task_template.custom
custom["chatgpt_config"]["messages"] = [{"role": "user", "content": message}]
client = openai.AsyncOpenAI(
organization=custom["openai_organization"],
api_key=get_agent_secret(secret_key=OPENAI_API_KEY),
api_key=get_agent_secret(secret_key=OPENAI_API_KEY, secret=api_key),
)

logger = logging.getLogger("httpx")
Expand Down
15 changes: 13 additions & 2 deletions plugins/flytekit-openai/flytekitplugins/chatgpt/task.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
import typing
from typing import Any, Dict

from flytekit.configuration import SerializationSettings
from flytekit.core.base_task import PythonTask
from flytekit.core.interface import Interface
from flytekit.extend.backend.base_agent import SyncAgentExecutorMixin
from flytekit.models.security import Secret, SecurityContext


class ChatGPTTask(SyncAgentExecutorMixin, PythonTask):
Expand All @@ -13,7 +15,14 @@ class ChatGPTTask(SyncAgentExecutorMixin, PythonTask):

_TASK_TYPE = "chatgpt"

def __init__(self, name: str, openai_organization: str, chatgpt_config: Dict[str, Any], **kwargs):
def __init__(
self,
name: str,
openai_organization: str,
chatgpt_config: Dict[str, Any],
openai_key: typing.Optional[Secret] = None,
**kwargs,
):
"""
Args:
name: Name of this task, should be unique in the project
Expand All @@ -25,15 +34,17 @@ def __init__(self, name: str, openai_organization: str, chatgpt_config: Dict[str
raise ValueError("The 'model' configuration variable is required in chatgpt_config")

task_config = {"openai_organization": openai_organization, "chatgpt_config": chatgpt_config}

inputs = {"message": str}
outputs = {"o0": str}

sec_ctx = SecurityContext(secrets=[openai_key])

super().__init__(
task_type=self._TASK_TYPE,
name=name,
task_config=task_config,
interface=Interface(inputs=inputs, outputs=outputs),
security_ctx=sec_ctx,
**kwargs,
)

Expand Down
Loading