diff --git a/sdks/python/src/opik/api_objects/opik_client.py b/sdks/python/src/opik/api_objects/opik_client.py index 3b8a18df52..eb919fd653 100644 --- a/sdks/python/src/opik/api_objects/opik_client.py +++ b/sdks/python/src/opik/api_objects/opik_client.py @@ -15,11 +15,11 @@ constants, validation_helpers, ) -from ..message_processing import streamer_constructors, messages, jsonable_encoder +from ..message_processing import streamer_constructors, messages from ..rest_api import client as rest_api_client from ..rest_api.types import dataset_public, trace_public, span_public, project_public from ..rest_api.core.api_error import ApiError -from .. import datetime_helpers, config, httpx_client, url_helpers +from .. import datetime_helpers, config, httpx_client, jsonable_encoder, url_helpers LOGGER = logging.getLogger(__name__) diff --git a/sdks/python/src/opik/integrations/langchain/__init__.py b/sdks/python/src/opik/integrations/langchain/__init__.py index 5543499bcd..9e9bfebb98 100644 --- a/sdks/python/src/opik/integrations/langchain/__init__.py +++ b/sdks/python/src/opik/integrations/langchain/__init__.py @@ -1,4 +1,3 @@ from .opik_tracer import OpikTracer - __all__ = ["OpikTracer"] diff --git a/sdks/python/src/opik/integrations/langchain/opik_encoder_extension.py b/sdks/python/src/opik/integrations/langchain/opik_encoder_extension.py new file mode 100644 index 0000000000..986c091c2d --- /dev/null +++ b/sdks/python/src/opik/integrations/langchain/opik_encoder_extension.py @@ -0,0 +1,13 @@ +from typing import Any +from langchain.load import serializable +from opik import jsonable_encoder + + +def register() -> None: + def encoder_extension(obj: serializable.Serializable) -> Any: + return obj.to_json() + + jsonable_encoder.register_encoder_extension( + obj_type=serializable.Serializable, + encoder=encoder_extension, + ) diff --git a/sdks/python/src/opik/integrations/langchain/opik_tracer.py b/sdks/python/src/opik/integrations/langchain/opik_tracer.py index 94b4ee68ce..64c1985e5c 100644 --- a/sdks/python/src/opik/integrations/langchain/opik_tracer.py +++ b/sdks/python/src/opik/integrations/langchain/opik_tracer.py @@ -7,7 +7,7 @@ from opik import dict_utils from opik import opik_context -from . import openai_run_helpers +from . import openai_run_helpers, opik_encoder_extension from ...logging_messages import NESTED_SPAN_PROJECT_NAME_MISMATCH_WARNING_MESSAGE if TYPE_CHECKING: @@ -17,6 +17,8 @@ LOGGER = logging.getLogger(__name__) +opik_encoder_extension.register() + def _get_span_type(run: "Run") -> Literal["llm", "tool", "general"]: if run.run_type in ["llm", "tool"]: diff --git a/sdks/python/src/opik/message_processing/jsonable_encoder.py b/sdks/python/src/opik/jsonable_encoder.py similarity index 82% rename from sdks/python/src/opik/message_processing/jsonable_encoder.py rename to sdks/python/src/opik/jsonable_encoder.py index cf89c5e80c..458ad2e204 100644 --- a/sdks/python/src/opik/message_processing/jsonable_encoder.py +++ b/sdks/python/src/opik/jsonable_encoder.py @@ -1,10 +1,12 @@ import logging import dataclasses import datetime as dt + +from typing import Callable, Any, Type, Set, Tuple + from enum import Enum from pathlib import PurePath from types import GeneratorType -from typing import Any import numpy as np @@ -12,6 +14,12 @@ LOGGER = logging.getLogger(__name__) +_ENCODER_EXTENSIONS: Set[Tuple[Type, Callable[[Any], Any]]] = set() + + +def register_encoder_extension(obj_type: Type, encoder: Callable[[Any], Any]) -> None: + _ENCODER_EXTENSIONS.add((obj_type, encoder)) + def jsonable_encoder(obj: Any) -> Any: """ @@ -50,11 +58,9 @@ def jsonable_encoder(obj: Any) -> Any: if isinstance(obj, np.ndarray): return jsonable_encoder(obj.tolist()) - if hasattr(obj, "to_string"): # langchain internal data objects - try: - return jsonable_encoder(obj.to_string()) - except Exception: - pass + for type_, encoder in _ENCODER_EXTENSIONS: + if isinstance(obj, type_): + return jsonable_encoder(encoder(obj)) except Exception: LOGGER.debug("Failed to serialize object.", exc_info=True) diff --git a/sdks/python/src/opik/message_processing/message_processors.py b/sdks/python/src/opik/message_processing/message_processors.py index 678a17317c..9b3afaf7b1 100644 --- a/sdks/python/src/opik/message_processing/message_processors.py +++ b/sdks/python/src/opik/message_processing/message_processors.py @@ -4,7 +4,7 @@ from opik import logging_messages from . import messages -from .jsonable_encoder import jsonable_encoder +from ..jsonable_encoder import jsonable_encoder from .. import dict_utils from ..rest_api import client as rest_api_client from ..rest_api.types import feedback_score_batch_item diff --git a/sdks/python/tests/unit/message_processing/test_jsonable_encoder.py b/sdks/python/tests/unit/message_processing/test_jsonable_encoder.py index 0fa09b5e3a..73e0879b59 100644 --- a/sdks/python/tests/unit/message_processing/test_jsonable_encoder.py +++ b/sdks/python/tests/unit/message_processing/test_jsonable_encoder.py @@ -6,7 +6,7 @@ import pytest import dataclasses -import opik.message_processing.jsonable_encoder as jsonable_encoder +import opik.jsonable_encoder as jsonable_encoder @pytest.mark.parametrize(