diff --git a/weave/integrations/langchain/langchain_nv_ai_endpoints.py b/weave/integrations/langchain/langchain_nv_ai_endpoints.py index d9cfd65941d..ea0a01c158a 100644 --- a/weave/integrations/langchain/langchain_nv_ai_endpoints.py +++ b/weave/integrations/langchain/langchain_nv_ai_endpoints.py @@ -2,10 +2,15 @@ import time from collections.abc import Iterator from functools import wraps -from typing import TYPE_CHECKING, Any, Callable, Optional +from typing import Any, Callable, Optional -if TYPE_CHECKING: - pass +import_failed = False + +try: + from langchain_core.messages import AIMessageChunk, convert_to_openai_messages + from langchain_core.outputs import ChatGenerationChunk, ChatResult +except ImportError: + import_failed = True import weave from weave.trace.op import Op, ProcessedInputs @@ -15,9 +20,6 @@ # NVIDIA-specific accumulator for parsing the objects of streaming interactions def nvidia_accumulator(acc: Optional[Any], value: Any) -> Any: - from langchain_core.messages import AIMessageChunk - from langchain_core.outputs import ChatGenerationChunk - if acc is None: acc = ChatGenerationChunk(message=AIMessageChunk(content="")) acc = acc + value @@ -31,8 +33,6 @@ def nvidia_accumulator(acc: Optional[Any], value: Any) -> Any: # Post processor to transform output into OpenAI's ChatCompletion format -- need to handle stream and non-stream outputs def post_process_to_openai_format(output: Any) -> dict: - from langchain_core.messages import convert_to_openai_messages - from langchain_core.outputs import ChatGenerationChunk, ChatResult from openai.types.chat import ChatCompletion if isinstance(output, ChatResult): ## its ChatResult @@ -114,8 +114,6 @@ def post_process_to_openai_format(output: Any) -> dict: def process_inputs_to_openai_format_stream( func: Op, args: tuple, kwargs: dict ) -> ProcessedInputs: - from langchain_core.messages import convert_to_openai_messages - original_args = args original_kwargs = kwargs @@ -147,8 +145,6 @@ def process_inputs_to_openai_format_stream( def process_inputs_to_openai_format( func: Op, args: tuple, kwargs: dict ) -> ProcessedInputs: - from langchain_core.messages import convert_to_openai_messages - original_args = args original_kwargs = kwargs @@ -202,7 +198,6 @@ def invoke_fn(*args: Any, **kwargs: Any) -> Any: # Wrap streaming methods (synchronous) def create_stream_wrapper(name: str) -> Callable[[Callable], Callable]: """Wrap a synchronous streaming method for ChatNVIDIA.""" - from langchain_core.outputs import ChatGenerationChunk def wrapper(fn: Callable) -> Callable: @wraps(fn)