From ae1333cc70fff14bacbab6c1a0a96f63c4db92fc Mon Sep 17 00:00:00 2001 From: Farid Rener Date: Fri, 13 Dec 2024 14:51:39 -0500 Subject: [PATCH] Add CodeNode --- apps/pipelines/nodes/nodes.py | 99 +++++++++++++++++++++++++- apps/pipelines/tests/test_code_node.py | 88 +++++++++++++++++++++++ apps/pipelines/tests/utils.py | 12 ++++ requirements/requirements.in | 1 + requirements/requirements.txt | 2 + 5 files changed, 201 insertions(+), 1 deletion(-) create mode 100644 apps/pipelines/tests/test_code_node.py diff --git a/apps/pipelines/nodes/nodes.py b/apps/pipelines/nodes/nodes.py index 2d68811c3..9cf93e2ea 100644 --- a/apps/pipelines/nodes/nodes.py +++ b/apps/pipelines/nodes/nodes.py @@ -1,4 +1,6 @@ +import datetime import json +import time from typing import Literal import tiktoken @@ -14,13 +16,14 @@ from pydantic.config import ConfigDict from pydantic_core import PydanticCustomError from pydantic_core.core_schema import FieldValidationInfo +from RestrictedPython import compile_restricted_function, safe_builtins, safe_globals from apps.assistants.models import OpenAiAssistant from apps.channels.datamodels import Attachment from apps.chat.conversation import compress_chat_history, compress_pipeline_chat_history from apps.chat.models import ChatMessageType from apps.experiments.models import ExperimentSession, ParticipantData -from apps.pipelines.exceptions import PipelineNodeBuildError +from apps.pipelines.exceptions import PipelineNodeBuildError, PipelineNodeRunError from apps.pipelines.models import PipelineChatHistory, PipelineChatHistoryTypes from apps.pipelines.nodes.base import NodeSchema, OptionsSource, PipelineNode, PipelineState, UiSchema, Widgets from apps.pipelines.tasks import send_email_from_pipeline @@ -622,3 +625,97 @@ def _get_assistant_runnable(self, assistant: OpenAiAssistant, session: Experimen return AgentAssistantChat(adapter=adapter, history_manager=history_manager) else: return AssistantChat(adapter=adapter, history_manager=history_manager) + 
+
+class CodeNode(PipelineNode):
+    """Run user-supplied Python inside a RestrictedPython sandbox.
+
+    ``code`` is compiled as the body of ``main(input)``; ``str()`` of the value it
+    returns becomes this node's output.
+    """
+
+    model_config = ConfigDict(json_schema_extra=NodeSchema(label="Python Node"))
+    code: str = Field(
+        description="The code to run",
+        json_schema_extra=UiSchema(widget=Widgets.expandable_text),  # TODO: add a code widget
+    )
+
+    @staticmethod
+    def _compile(code: str):
+        """Compile ``code`` as the body of ``main(input)``.
+
+        Shared by validation and execution so that the signature checked when the
+        node is built is the same one used when it runs.
+        """
+        return compile_restricted_function("input", code, name="main", filename="")
+
+    @field_validator("code")
+    def validate_code(cls, value, info: FieldValidationInfo):
+        """Default empty code to ``return input``; reject code that fails to compile."""
+        if not value:
+            value = "return input"
+
+        compiled = cls._compile(value)
+        if compiled.errors:
+            raise PydanticCustomError("invalid_code", "{errors}", {"errors": "\n".join(compiled.errors)})
+        return value
+
+    def _process(self, input: str, state: PipelineState, node_id: str) -> PipelineState:
+        """Call the compiled ``main`` with ``input`` and return its stringified result.
+
+        Raises:
+            PipelineNodeRunError: if the code does not compile or raises when called.
+        """
+        compiled = self._compile(self.code)
+        if compiled.errors:
+            # No code object exists in this case, so fail here instead of inside exec().
+            raise PipelineNodeRunError("\n".join(compiled.errors))
+
+        namespace = {}
+        # This only defines ``main`` in ``namespace``; the user's code runs in the call below.
+        exec(compiled.code, self._get_custom_globals(), namespace)
+        try:
+            result = str(namespace["main"](input))
+        except Exception as exc:
+            raise PipelineNodeRunError(exc) from exc
+        return PipelineState.from_node_output(node_id=node_id, output=result)
+
+    def _get_custom_globals(self):
+        """Return the globals for the sandboxed function: guards, helper modules, builtins."""
+        from RestrictedPython.Eval import default_guarded_getitem, default_guarded_getiter
+        from RestrictedPython.Guards import full_write_guard
+
+        return {
+            **safe_globals,
+            "__builtins__": self._get_custom_builtins(),
+            "json": json,
+            "datetime": datetime,
+            "time": time,
+            "_getitem_": default_guarded_getitem,
+            "_getiter_": default_guarded_getiter,
+            # An identity guard here would let user code rebind attributes on shared objects
+            # (e.g. the modules above); full_write_guard still allows plain dict/list writes.
+            "_write_": full_write_guard,
+        }
+
+    def _get_custom_builtins(self):
+        """Return ``safe_builtins`` plus a few extra helpers and an allow-listed import."""
+        # The only modules user code may import.
+        allowed_modules = {"json", "re", "datetime", "time"}
+
+        def guarded_import(name, *args, **kwargs):
+            if name not in allowed_modules:
+                raise ImportError(f"Importing '{name}' is not allowed")
+            return __import__(name, *args, **kwargs)
+
+        return {
+            **safe_builtins,
+            "min": min,
+            "max": max,
+            "sum": sum,
+            "abs": abs,
+            "all": all,
+            "any": any,
+            "datetime": datetime,
+            "__import__": guarded_import,
+        }
diff --git a/apps/pipelines/tests/test_code_node.py b/apps/pipelines/tests/test_code_node.py
new file mode 100644
index 000000000..8edcc1a07
--- /dev/null
+++ b/apps/pipelines/tests/test_code_node.py
@@ -0,0 +1,88 @@
+"""Tests for CodeNode, the pipeline node that runs user-supplied Python."""
+
+import json
+from unittest import mock
+
+import pytest
+
+from apps.pipelines.exceptions import PipelineNodeBuildError, PipelineNodeRunError
+from apps.pipelines.nodes.base import PipelineState
+from apps.pipelines.tests.utils import (
+    code_node,
+    create_runnable,
+    end_node,
+    start_node,
+)
+from apps.utils.factories.pipelines import PipelineFactory
+from apps.utils.pytest import django_db_with_data
+
+
+@pytest.fixture()
+def pipeline():
+    return PipelineFactory()
+
+
+EXTRA_FUNCTION = """
+def other(foo):
+    return f"other {foo}"
+
+return other(input)
+"""
+
+IMPORTS = """
+import json
+import datetime
+import re
+import time
+return json.loads(input)
+"""
+
+
+@django_db_with_data(available_apps=("apps.service_providers",))
+@mock.patch("apps.pipelines.nodes.base.PipelineNode.logger", mock.Mock())
+@pytest.mark.parametrize(
+    ("code", "input", "output"),
+    [
+        ("return f'Hello, {input}!'", "World", "Hello, World!"),
+        ("", "foo", "foo"),  # No code just returns the input
+        (EXTRA_FUNCTION, "blah", "other blah"),  # Calling a separate function is possible
+        ("'foo'", "", "None"),  # No return value will return "None"
+        (IMPORTS, json.dumps({"a": "b"}), str(json.loads('{"a": "b"}'))),  # Importing json will work
+        ("d = {}\nd['k'] = input\nreturn d['k']", "foo", "foo"),  # Writing to the code's own dict is allowed
+    ],
+)
+def test_code_node(pipeline, code, input, output):
+    """The last message is the stringified return value of the user code."""
+    nodes = [start_node(), code_node(code), end_node()]
+    assert create_runnable(pipeline, nodes).invoke(PipelineState(messages=[input]))["messages"][-1] == output
+
+
+@django_db_with_data(available_apps=("apps.service_providers",))
+@mock.patch("apps.pipelines.nodes.base.PipelineNode.logger", mock.Mock())
+def test_code_node_syntax_error(pipeline):
+    """Code that does not compile is reported as a PipelineNodeBuildError."""
+    graph = [start_node(), code_node("this{}"), end_node()]
+    expected = "SyntaxError: invalid syntax at statement: 'this{}'"
+    # Building and invoking both stay inside the context manager, as either may raise.
+    with pytest.raises(PipelineNodeBuildError, match=expected):
+        runnable = create_runnable(pipeline, graph)
+        runnable.invoke(PipelineState(messages=["World"]))["messages"][-1]
+
+
+# (code, input, expected error message) triples for test_code_node_runtime_errors.
+RUNTIME_ERROR_CASES = [
+    ("import collections", "", "Importing 'collections' is not allowed"),
+    ("return f'Hello, {blah}!'", "", "name 'blah' is not defined"),
+]
+
+
+@django_db_with_data(available_apps=("apps.service_providers",))
+@mock.patch("apps.pipelines.nodes.base.PipelineNode.logger", mock.Mock())
+@pytest.mark.parametrize(("code", "input", "error"), RUNTIME_ERROR_CASES)
+def test_code_node_runtime_errors(pipeline, code, input, error):
+    """Errors raised while the user code runs surface as PipelineNodeRunError."""
+    graph = [start_node(), code_node(code), end_node()]
+    # Building and invoking both stay inside the context manager, as either may raise.
+    with pytest.raises(PipelineNodeRunError, match=error):
+        runnable = create_runnable(pipeline, graph)
+        runnable.invoke(PipelineState(messages=[input]))["messages"][-1]
diff --git a/apps/pipelines/tests/utils.py b/apps/pipelines/tests/utils.py
index d2513921d..8ded20c8a 100644
--- a/apps/pipelines/tests/utils.py
+++ b/apps/pipelines/tests/utils.py
@@ -174,3 +174,15 @@ def extract_structured_data_node(provider_id: str, provider_model_id: str, data_
             "data_schema": data_schema,
         },
     }
+
+
+def code_node(code: str | None = None):
+    """Build a CodeNode definition for tests; ``code`` defaults to a greeting snippet."""
+    default_code = "return f'Hello, {input}!'"
+    source = default_code if code is None else code
+    return {
+        # Each call gets its own random id.
+        "id": str(uuid4()),
+        "type": nodes.CodeNode.__name__,
+        "params": {"code": source},
+    }
diff --git a/requirements/requirements.in b/requirements/requirements.in
index d734a7d3b..b0a4a53a8 100644
--- a/requirements/requirements.in
+++ b/requirements/requirements.in
@@ -47,6 +47,7 @@ psycopg[binary]
 pyTelegramBotAPI==4.12.0
 pydantic
 pydub # Audio transcription
+RestrictedPython
 sentry-sdk
 slack-bolt
 taskbadger
diff --git a/requirements/requirements.txt
b/requirements/requirements.txt index d43f13dd9..41485e94e 100644 --- a/requirements/requirements.txt +++ b/requirements/requirements.txt @@ -412,6 +412,8 @@ requests-oauthlib==1.3.1 # via django-allauth requests-toolbelt==1.0.0 # via langsmith +restrictedpython==7.4 + # via -r requirements.in rich==13.6.0 # via typer rpds-py==0.12.0