From f6f56bc006ff91c06a9315f3a884f09927329a51 Mon Sep 17 00:00:00 2001
From: Abraham Leal <45460452+abraham-leal@users.noreply.github.com>
Date: Mon, 16 Dec 2024 17:41:00 -0600
Subject: [PATCH] fix more typing stuff

---
 .../langchain/langchain_nvidia_ai_endpoints.py | 10 ++++++++--
 1 file changed, 8 insertions(+), 2 deletions(-)

diff --git a/weave/integrations/langchain/langchain_nvidia_ai_endpoints.py b/weave/integrations/langchain/langchain_nvidia_ai_endpoints.py
index 5a135a9b84d..03f0d1953e6 100644
--- a/weave/integrations/langchain/langchain_nvidia_ai_endpoints.py
+++ b/weave/integrations/langchain/langchain_nvidia_ai_endpoints.py
@@ -5,8 +5,7 @@
 from typing import TYPE_CHECKING, Any, Callable, Optional
 
 if TYPE_CHECKING:
-    from langchain_core.messages import AIMessageChunk, convert_to_openai_messages
-    from openai.types.chat import ChatCompletion
+    pass
 
 import weave
 from weave.trace.op import Op, ProcessedInputs
@@ -16,6 +15,7 @@
 
 # NVIDIA-specific accumulator for parsing the objects of streaming interactions
 def nvidia_accumulator(acc: Optional[Any], value: Any) -> Any:
+    from langchain_core.messages import AIMessageChunk
     from langchain_core.outputs import ChatGenerationChunk
 
     if acc is None:
@@ -31,7 +31,9 @@ def nvidia_accumulator(acc: Optional[Any], value: Any) -> Any:
 
 # Post processor to transform output into OpenAI's ChatCompletion format -- need to handle stream and non-stream outputs
 def post_process_to_openai_format(output: Any) -> dict:
+    from langchain_core.messages import convert_to_openai_messages
     from langchain_core.outputs import ChatGenerationChunk, ChatResult
+    from openai.types.chat import ChatCompletion
 
     if isinstance(output, ChatResult):  ## its ChatResult
         message = output.llm_output  ## List of ChatGeneration
@@ -112,6 +114,8 @@ def post_process_to_openai_format(output: Any) -> dict:
 def process_inputs_to_openai_format_stream(
     func: Op, args: tuple, kwargs: dict
 ) -> ProcessedInputs | None:
+    from langchain_core.messages import convert_to_openai_messages
+
     original_args = args
     original_kwargs = kwargs
 
@@ -143,6 +147,8 @@ def process_inputs_to_openai_format_stream(
 def process_inputs_to_openai_format(
     func: Op, args: tuple, kwargs: dict
 ) -> ProcessedInputs | None:
+    from langchain_core.messages import convert_to_openai_messages
+
     original_args = args
     original_kwargs = kwargs
 
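
Note: the names removed from the TYPE_CHECKING block were previously available only to the type checker, so they were undefined when the functions ran; the patch instead imports them inside the functions that use them, which also keeps langchain_core and openai optional at module import time. A minimal sketch of that deferred-import pattern follows, assuming a hypothetical helper name (messages_to_openai_format is not part of the patched file):

    from typing import Any


    def messages_to_openai_format(messages: list[Any]) -> list[dict]:
        # Deferred import: langchain_core is only required when this helper
        # is actually called, not when the enclosing module is imported.
        from langchain_core.messages import convert_to_openai_messages

        # convert_to_openai_messages maps LangChain message objects to
        # OpenAI-style role/content dicts.
        return convert_to_openai_messages(messages)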