Skip to content

Commit

Permalink
Merge pull request #992 from dimagi/fr/code-node
Browse files Browse the repository at this point in the history
Add CodeNode
  • Loading branch information
proteusvacuum authored Dec 30, 2024
2 parents 3794faf + 0e00b33 commit 1f1508b
Show file tree
Hide file tree
Showing 10 changed files with 655 additions and 11 deletions.
1 change: 1 addition & 0 deletions apps/pipelines/nodes/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,7 @@ def logger(self):

class Widgets(StrEnum):
expandable_text = "expandable_text"
code = "code"
toggle = "toggle"
select = "select"
float = "float"
Expand Down
121 changes: 120 additions & 1 deletion apps/pipelines/nodes/nodes.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
import datetime
import inspect
import json
import time
from typing import Literal

import tiktoken
Expand All @@ -14,13 +17,14 @@
from pydantic.config import ConfigDict
from pydantic_core import PydanticCustomError
from pydantic_core.core_schema import FieldValidationInfo
from RestrictedPython import compile_restricted, safe_builtins, safe_globals

from apps.assistants.models import OpenAiAssistant
from apps.channels.datamodels import Attachment
from apps.chat.conversation import compress_chat_history, compress_pipeline_chat_history
from apps.chat.models import ChatMessageType
from apps.experiments.models import ExperimentSession, ParticipantData
from apps.pipelines.exceptions import PipelineNodeBuildError
from apps.pipelines.exceptions import PipelineNodeBuildError, PipelineNodeRunError
from apps.pipelines.models import PipelineChatHistory, PipelineChatHistoryTypes
from apps.pipelines.nodes.base import NodeSchema, OptionsSource, PipelineNode, PipelineState, UiSchema, Widgets
from apps.pipelines.tasks import send_email_from_pipeline
Expand Down Expand Up @@ -622,3 +626,118 @@ def _get_assistant_runnable(self, assistant: OpenAiAssistant, session: Experimen
return AgentAssistantChat(adapter=adapter, history_manager=history_manager)
else:
return AssistantChat(adapter=adapter, history_manager=history_manager)


# Starter code used as the default value of ``CodeNode.code`` (and substituted
# when an empty value is submitted). It must itself be valid Python, so the
# body of ``main`` is indented with four spaces.
DEFAULT_FUNCTION = """# You must define a main function, which takes the node input as a string.
# Return a string to pass to the next node.
def main(input: str, **kwargs) -> str:
    return input
"""


class CodeNode(PipelineNode):
    """Runs python"""

    # NOTE: the class docstring is kept verbatim on purpose; the same text
    # appears as the "description" of this node in the JSON schema fixture.
    #
    # The node compiles user-supplied source with RestrictedPython and calls
    # its top-level ``main(input, **kwargs)`` function, passing the node input
    # and forwarding ``str(result)`` to the next node.

    model_config = ConfigDict(json_schema_extra=NodeSchema(label="Python Node"))
    code: str = Field(
        default=DEFAULT_FUNCTION,
        description="The code to run",
        json_schema_extra=UiSchema(widget=Widgets.code),
    )

    @field_validator("code")
    def validate_code(cls, value, info: FieldValidationInfo):
        """Validate that ``value`` compiles and defines exactly one top-level ``main``.

        An empty value is replaced by ``DEFAULT_FUNCTION``. The code is compiled
        with ``compile_restricted`` and executed once so that the names it
        defines at the top level can be inspected.

        Returns:
            The (possibly defaulted) source code.

        Raises:
            PydanticCustomError: with type ``invalid_code`` when the code does not
                compile, does not define ``main``, defines any other top-level
                function, or ``main`` does not take exactly ``(input, **kwargs)``.
        """
        if not value:
            value = DEFAULT_FUNCTION
        try:
            byte_code = compile_restricted(
                value,
                filename="<inline code>",
                mode="exec",
            )
            custom_locals = {}
            # SECURITY NOTE(review): because the globals dict passed here has no
            # "__builtins__" key, Python inserts the real builtins module, so
            # top-level statements of the submitted code run here WITHOUT the
            # guarded builtins/import hook used in `_process`. Tightening this
            # changes which errors surface at build time vs. run time, so it is
            # flagged rather than silently changed.
            exec(byte_code, {}, custom_locals)

            try:
                main = custom_locals["main"]
            except KeyError:
                # The KeyError itself carries no useful information for the user.
                raise SyntaxError("You must define a 'main' function") from None

            for name, item in custom_locals.items():
                if name != "main" and inspect.isfunction(item):
                    raise SyntaxError(
                        "You can only define a single function, 'main' at the top level. "
                        "You may use nested functions inside that function if required"
                    )

            if list(inspect.signature(main).parameters) != ["input", "kwargs"]:
                raise SyntaxError("The main function should have the signature main(input, **kwargs) only.")

        except SyntaxError as exc:
            raise PydanticCustomError("invalid_code", "{error}", {"error": exc.msg})
        return value

    def _process(self, input: str, state: PipelineState, node_id: str) -> PipelineState:
        """Run the user's ``main`` on ``input`` and wrap the result as node output.

        Args:
            input: The node input, passed as the single positional argument to ``main``.
            state: The current pipeline state (not read by this method).
            node_id: Identifier recorded on the returned node output.

        Returns:
            A ``PipelineState`` built from ``str(main(input))``.

        Raises:
            PipelineNodeRunError: wrapping any exception raised while executing
                the module body or calling ``main``.
        """
        function_name = "main"
        byte_code = compile_restricted(
            self.code,
            filename="<inline code>",
            mode="exec",
        )

        # Execute in ONE namespace. With separate globals/locals dicts, exec()
        # behaves like a class body: top-level names (e.g. `import re` or
        # constants) are stored in the locals dict, while `main` resolves free
        # names through the globals dict, so those names raised NameError when
        # `main` was called.
        namespace = self._get_custom_globals()
        try:
            exec(byte_code, namespace)
            result = str(namespace[function_name](input))
        except Exception as exc:
            raise PipelineNodeRunError(exc) from exc
        return PipelineState.from_node_output(node_id=node_id, output=result)

    def _get_custom_globals(self):
        """Return a fresh globals dict for executing user code.

        Starts from a copy of RestrictedPython's ``safe_globals`` and adds the
        guarded builtins from ``_get_custom_builtins``, the ``json``,
        ``datetime`` and ``time`` modules, and the ``_getitem_`` /
        ``_getiter_`` / ``_write_`` hooks.
        """
        from RestrictedPython.Eval import (
            default_guarded_getitem,
            default_guarded_getiter,
        )

        custom_globals = safe_globals.copy()
        custom_globals.update(
            {
                "__builtins__": self._get_custom_builtins(),
                "json": json,
                "datetime": datetime,
                "time": time,
                "_getitem_": default_guarded_getitem,
                "_getiter_": default_guarded_getiter,
                # Identity write guard: objects are returned unwrapped.
                "_write_": lambda x: x,
            }
        )
        return custom_globals

    def _get_custom_builtins(self):
        """Return the builtins exposed to user code.

        This is a copy of RestrictedPython's ``safe_builtins`` extended with a
        few aggregate helpers, the ``datetime`` module, and an ``__import__``
        that only permits the modules named in ``allowed_modules``.
        """
        allowed_modules = {
            "json",
            "re",
            "datetime",
            "time",
        }
        custom_builtins = safe_builtins.copy()
        custom_builtins.update(
            {
                "min": min,
                "max": max,
                "sum": sum,
                "abs": abs,
                "all": all,
                "any": any,
                "datetime": datetime,
            }
        )

        def guarded_import(name, *args, **kwargs):
            # Only exact module names from the allow-list may be imported.
            if name not in allowed_modules:
                raise ImportError(f"Importing '{name}' is not allowed")
            return __import__(name, *args, **kwargs)

        custom_builtins["__import__"] = guarded_import
        return custom_builtins
16 changes: 16 additions & 0 deletions apps/pipelines/tests/data/CodeNode.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
{
"description": "Runs python",
"properties": {
"code": {
"default": "# You must define a main function, which takes the node input as a string.\n# Return a string to pass to the next node.\ndef main(input: str, **kwargs) -> str:\n return input\n",
"description": "The code to run",
"title": "Code",
"type": "string",
"ui:widget": "code"
}
},
"title": "CodeNode",
"type": "object",
"ui:flow_node_type": "pipelineNode",
"ui:label": "Python Node"
}
114 changes: 114 additions & 0 deletions apps/pipelines/tests/test_code_node.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
import json
from unittest import mock

import pytest

from apps.pipelines.exceptions import PipelineNodeBuildError, PipelineNodeRunError
from apps.pipelines.nodes.base import PipelineState
from apps.pipelines.tests.utils import (
code_node,
create_runnable,
end_node,
start_node,
)
from apps.utils.factories.pipelines import PipelineFactory
from apps.utils.pytest import django_db_with_data


@pytest.fixture()
def pipeline():
    """Provide a pipeline created by ``PipelineFactory`` with default arguments."""
    created_pipeline = PipelineFactory()
    return created_pipeline


# Code sample that imports every module on the CodeNode allow-list at the top
# level and parses the node input as JSON inside ``main``. The body of
# ``main`` must be indented for the snippet to be valid Python.
IMPORTS = """
import json
import datetime
import re
import time
def main(input, **kwargs):
    return json.loads(input)
"""


@django_db_with_data(available_apps=("apps.service_providers",))
@mock.patch("apps.pipelines.nodes.base.PipelineNode.logger", mock.Mock())
@pytest.mark.parametrize(
    ("code", "input", "output"),
    [
        ("def main(input, **kwargs):\n\treturn f'Hello, {input}!'", "World", "Hello, World!"),
        ("", "foo", "foo"),  # No code just returns the input
        ("def main(input, **kwargs):\n\t'foo'", "", "None"),  # No return value will return "None"
        (IMPORTS, json.dumps({"a": "b"}), str(json.loads('{"a": "b"}'))),  # Importing json will work
    ],
)
def test_code_node(pipeline, code, input, output):
    """A start -> code -> end pipeline ends with the expected last message."""
    node_definitions = [start_node(), code_node(code), end_node()]
    runnable = create_runnable(pipeline, node_definitions)
    final_state = runnable.invoke(PipelineState(messages=[input]))
    last_message = final_state["messages"][-1]
    assert last_message == output


# Code sample that defines a second top-level function next to ``main``; used
# to check that extra top-level functions are rejected. Both bodies must be
# indented so the snippet is syntactically valid and reaches that check.
EXTRA_FUNCTION = """
def other(foo):
    return f"other {foo}"
def main(input, **kwargs):
    return other(input)
"""


@django_db_with_data(available_apps=("apps.service_providers",))
@mock.patch("apps.pipelines.nodes.base.PipelineNode.logger", mock.Mock())
@pytest.mark.parametrize(
    ("code", "input", "error"),
    [
        ("this{}", "", "SyntaxError: invalid syntax at statement: 'this{}"),
        (
            EXTRA_FUNCTION,
            "",
            (
                "You can only define a single function, 'main' at the top level. "
                "You may use nested functions inside that function if required"
            ),
        ),
        ("def other(input):\n\treturn input", "", "You must define a 'main' function"),
        (
            "def main(input, others, **kwargs):\n\treturn input",
            "",
            r"The main function should have the signature main\(input, \*\*kwargs\) only\.",
        ),
    ],
)
def test_code_node_build_errors(pipeline, code, input, error):
    """Invalid code raises ``PipelineNodeBuildError`` whose message matches ``error``."""
    node_definitions = [start_node(), code_node(code), end_node()]
    with pytest.raises(PipelineNodeBuildError, match=error):
        runnable = create_runnable(pipeline, node_definitions)
        final_state = runnable.invoke(PipelineState(messages=[input]))
        final_state["messages"][-1]


@django_db_with_data(available_apps=("apps.service_providers",))
@mock.patch("apps.pipelines.nodes.base.PipelineNode.logger", mock.Mock())
@pytest.mark.parametrize(
    ("code", "input", "error"),
    [
        (
            "import collections\ndef main(input, **kwargs):\n\treturn input",
            "",
            "Importing 'collections' is not allowed",
        ),
        ("def main(input, **kwargs):\n\treturn f'Hello, {blah}!'", "", "name 'blah' is not defined"),
    ],
)
def test_code_node_runtime_errors(pipeline, code, input, error):
    """Code that fails while running raises ``PipelineNodeRunError`` matching ``error``."""
    node_definitions = [start_node(), code_node(code), end_node()]
    with pytest.raises(PipelineNodeRunError, match=error):
        runnable = create_runnable(pipeline, node_definitions)
        final_state = runnable.invoke(PipelineState(messages=[input]))
        final_state["messages"][-1]
12 changes: 12 additions & 0 deletions apps/pipelines/tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,3 +174,15 @@ def extract_structured_data_node(provider_id: str, provider_model_id: str, data_
"data_schema": data_schema,
},
}


def code_node(code: str | None = None):
    """Build the dict definition of a ``CodeNode`` for pipeline tests.

    Args:
        code: Source to place in the node's ``code`` param. When ``None``, a
            minimal program defining ``main(input, **kwargs)`` is used.

    Returns:
        A dict with a fresh UUID string ``id``, the ``CodeNode`` class name as
        ``type``, and ``params`` holding the code.
    """
    if code is None:
        # The default has to be a complete program: a bare ``return`` statement
        # is not valid at module level and defines no ``main`` function.
        code = "def main(input, **kwargs):\n    return f'Hello, {input}!'"
    return {
        "id": str(uuid4()),
        "type": nodes.CodeNode.__name__,
        "params": {
            "code": code,
        },
    }
Loading

0 comments on commit 1f1508b

Please sign in to comment.