Skip to content

Commit

Permalink
fix langchain core
Browse files Browse the repository at this point in the history
  • Loading branch information
abraham-leal committed Dec 17, 2024
1 parent 26c2120 commit a66e352
Showing 1 changed file with 8 additions and 13 deletions.
21 changes: 8 additions & 13 deletions weave/integrations/langchain/langchain_nv_ai_endpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,15 @@
import time
from collections.abc import Iterator
from functools import wraps
from typing import TYPE_CHECKING, Any, Callable, Optional
from typing import Any, Callable, Optional

if TYPE_CHECKING:
pass
import_failed = False

try:
from langchain_core.messages import AIMessageChunk, convert_to_openai_messages
from langchain_core.outputs import ChatGenerationChunk, ChatResult
except ImportError:
import_failed = True

import weave
from weave.trace.op import Op, ProcessedInputs
Expand All @@ -15,9 +20,6 @@

# NVIDIA-specific accumulator for parsing the objects of streaming interactions
def nvidia_accumulator(acc: Optional[Any], value: Any) -> Any:
from langchain_core.messages import AIMessageChunk
from langchain_core.outputs import ChatGenerationChunk

if acc is None:
acc = ChatGenerationChunk(message=AIMessageChunk(content=""))
acc = acc + value
Expand All @@ -31,8 +33,6 @@ def nvidia_accumulator(acc: Optional[Any], value: Any) -> Any:

# Post processor to transform output into OpenAI's ChatCompletion format -- need to handle stream and non-stream outputs
def post_process_to_openai_format(output: Any) -> dict:
from langchain_core.messages import convert_to_openai_messages
from langchain_core.outputs import ChatGenerationChunk, ChatResult
from openai.types.chat import ChatCompletion

if isinstance(output, ChatResult): ## its ChatResult
Expand Down Expand Up @@ -114,8 +114,6 @@ def post_process_to_openai_format(output: Any) -> dict:
def process_inputs_to_openai_format_stream(
func: Op, args: tuple, kwargs: dict
) -> ProcessedInputs:
from langchain_core.messages import convert_to_openai_messages

original_args = args
original_kwargs = kwargs

Expand Down Expand Up @@ -147,8 +145,6 @@ def process_inputs_to_openai_format_stream(
def process_inputs_to_openai_format(
func: Op, args: tuple, kwargs: dict
) -> ProcessedInputs:
from langchain_core.messages import convert_to_openai_messages

original_args = args
original_kwargs = kwargs

Expand Down Expand Up @@ -202,7 +198,6 @@ def invoke_fn(*args: Any, **kwargs: Any) -> Any:
# Wrap streaming methods (synchronous)
def create_stream_wrapper(name: str) -> Callable[[Callable], Callable]:
"""Wrap a synchronous streaming method for ChatNVIDIA."""
from langchain_core.outputs import ChatGenerationChunk

def wrapper(fn: Callable) -> Callable:
@wraps(fn)
Expand Down

0 comments on commit a66e352

Please sign in to comment.