Skip to content

Commit

Permalink
Add CodeNode
Browse files Browse the repository at this point in the history
  • Loading branch information
proteusvacuum committed Dec 13, 2024
1 parent de8dc1e commit ae1333c
Show file tree
Hide file tree
Showing 5 changed files with 201 additions and 1 deletion.
99 changes: 98 additions & 1 deletion apps/pipelines/nodes/nodes.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import datetime
import json
import time
from typing import Literal

import tiktoken
Expand All @@ -14,13 +16,14 @@
from pydantic.config import ConfigDict
from pydantic_core import PydanticCustomError
from pydantic_core.core_schema import FieldValidationInfo
from RestrictedPython import compile_restricted_function, safe_builtins, safe_globals

from apps.assistants.models import OpenAiAssistant
from apps.channels.datamodels import Attachment
from apps.chat.conversation import compress_chat_history, compress_pipeline_chat_history
from apps.chat.models import ChatMessageType
from apps.experiments.models import ExperimentSession, ParticipantData
from apps.pipelines.exceptions import PipelineNodeBuildError
from apps.pipelines.exceptions import PipelineNodeBuildError, PipelineNodeRunError
from apps.pipelines.models import PipelineChatHistory, PipelineChatHistoryTypes
from apps.pipelines.nodes.base import NodeSchema, OptionsSource, PipelineNode, PipelineState, UiSchema, Widgets
from apps.pipelines.tasks import send_email_from_pipeline
Expand Down Expand Up @@ -622,3 +625,97 @@ def _get_assistant_runnable(self, assistant: OpenAiAssistant, session: Experimen
return AgentAssistantChat(adapter=adapter, history_manager=history_manager)
else:
return AssistantChat(adapter=adapter, history_manager=history_manager)


class CodeNode(PipelineNode):
"""Runs python"""

model_config = ConfigDict(json_schema_extra=NodeSchema(label="Python Node"))
code: str = Field(
description="The code to run",
json_schema_extra=UiSchema(widget=Widgets.expandable_text), # TODO: add a code widget
)

@field_validator("code")
def validate_code(cls, value, info: FieldValidationInfo):
if not value:
value = "return input"

byte_code = compile_restricted_function(
"input,shared_state",
value,
name="main",
filename="<inline code>",
)

if byte_code.errors:
raise PydanticCustomError("invalid_code", "{errors}", {"errors": "\n".join(byte_code.errors)})
return value

def _process(self, input: str, state: PipelineState, node_id: str) -> PipelineState:
function_name = "main"
function_args = "input"
byte_code = compile_restricted_function(
function_args,
self.code,
name=function_name,
filename="<inline code>",
)

custom_locals = {}
custom_globals = self._get_custom_globals()
exec(byte_code.code, custom_globals, custom_locals)

try:
result = str(custom_locals[function_name](input))
except Exception as exc:
raise PipelineNodeRunError(exc) from exc
return PipelineState.from_node_output(node_id=node_id, output=result)

def _get_custom_globals(self):
from RestrictedPython.Eval import (
default_guarded_getitem,
default_guarded_getiter,
)

custom_globals = safe_globals.copy()
custom_globals.update(
{
"__builtins__": self._get_custom_builtins(),
"json": json,
"datetime": datetime,
"time": time,
"_getitem_": default_guarded_getitem,
"_getiter_": default_guarded_getiter,
"_write_": lambda x: x,
}
)
return custom_globals

def _get_custom_builtins(self):
allowed_modules = {
"json",
"re",
"datetime",
"time",
}
custom_builtins = safe_builtins.copy()
custom_builtins.update(
{
"min": min,
"max": max,
"sum": sum,
"abs": abs,
"all": all,
"any": any,
"datetime": datetime,
}
)

def guarded_import(name, *args, **kwargs):
if name not in allowed_modules:
raise ImportError(f"Importing '{name}' is not allowed")
return __import__(name, *args, **kwargs)

custom_builtins["__import__"] = guarded_import
return custom_builtins
88 changes: 88 additions & 0 deletions apps/pipelines/tests/test_code_node.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
import json
from unittest import mock

import pytest

from apps.pipelines.exceptions import PipelineNodeBuildError, PipelineNodeRunError
from apps.pipelines.nodes.base import PipelineState
from apps.pipelines.tests.utils import (
code_node,
create_runnable,
end_node,
start_node,
)
from apps.utils.factories.pipelines import PipelineFactory
from apps.utils.pytest import django_db_with_data


@pytest.fixture()
def pipeline():
return PipelineFactory()


EXTRA_FUNCTION = """
def other(foo):
return f"other {foo}"
return other(input)
"""

IMPORTS = """
import json
import datetime
import re
import time
return json.loads(input)
"""


@django_db_with_data(available_apps=("apps.service_providers",))
@mock.patch("apps.pipelines.nodes.base.PipelineNode.logger", mock.Mock())
@pytest.mark.parametrize(
("code", "input", "output"),
[
("return f'Hello, {input}!'", "World", "Hello, World!"),
("", "foo", "foo"), # No code just returns the input
(EXTRA_FUNCTION, "blah", "other blah"), # Calling a separate function is possible
("'foo'", "", "None"), # No return value will return "None"
(IMPORTS, json.dumps({"a": "b"}), str(json.loads('{"a": "b"}'))), # Importing json will work
],
)
def test_code_node(pipeline, code, input, output):
nodes = [
start_node(),
code_node(code),
end_node(),
]
assert create_runnable(pipeline, nodes).invoke(PipelineState(messages=[input]))["messages"][-1] == output


@django_db_with_data(available_apps=("apps.service_providers",))
@mock.patch("apps.pipelines.nodes.base.PipelineNode.logger", mock.Mock())
def test_code_node_syntax_error(pipeline):
nodes = [
start_node(),
code_node("this{}"),
end_node(),
]
with pytest.raises(PipelineNodeBuildError, match="SyntaxError: invalid syntax at statement: 'this{}'"):
create_runnable(pipeline, nodes).invoke(PipelineState(messages=["World"]))["messages"][-1]


@django_db_with_data(available_apps=("apps.service_providers",))
@mock.patch("apps.pipelines.nodes.base.PipelineNode.logger", mock.Mock())
@pytest.mark.parametrize(
("code", "input", "error"),
[
("import collections", "", "Importing 'collections' is not allowed"),
("return f'Hello, {blah}!'", "", "name 'blah' is not defined"),
],
)
def test_code_node_runtime_errors(pipeline, code, input, error):
nodes = [
start_node(),
code_node(code),
end_node(),
]
with pytest.raises(PipelineNodeRunError, match=error):
create_runnable(pipeline, nodes).invoke(PipelineState(messages=[input]))["messages"][-1]
12 changes: 12 additions & 0 deletions apps/pipelines/tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,3 +174,15 @@ def extract_structured_data_node(provider_id: str, provider_model_id: str, data_
"data_schema": data_schema,
},
}


def code_node(code: str | None = None):
if code is None:
code = "return f'Hello, {input}!'"
return {
"id": str(uuid4()),
"type": nodes.CodeNode.__name__,
"params": {
"code": code,
},
}
1 change: 1 addition & 0 deletions requirements/requirements.in
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ psycopg[binary]
pyTelegramBotAPI==4.12.0
pydantic
pydub # Audio transcription
RestrictedPython
sentry-sdk
slack-bolt
taskbadger
Expand Down
2 changes: 2 additions & 0 deletions requirements/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -412,6 +412,8 @@ requests-oauthlib==1.3.1
# via django-allauth
requests-toolbelt==1.0.0
# via langsmith
restrictedpython==7.4
# via -r requirements.in
rich==13.6.0
# via typer
rpds-py==0.12.0
Expand Down

0 comments on commit ae1333c

Please sign in to comment.