Skip to content

Commit

Permalink
weave_client
Browse files Browse the repository at this point in the history
  • Loading branch information
andrewtruong committed Aug 21, 2024
1 parent bac6631 commit e1db3b8
Show file tree
Hide file tree
Showing 32 changed files with 76 additions and 68 deletions.
2 changes: 1 addition & 1 deletion docs/scripts/generate_python_sdk_docs.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,8 +263,8 @@ def doc_module(module, root_path="./docs/reference/python-sdk", module_root_path
def main():
import weave
from weave import feedback
from weave import weave_client as client
from weave.trace import util
from weave.trace import weave_client as client
from weave.trace_server import (
remote_http_trace_server,
trace_server_interface,
Expand Down
2 changes: 1 addition & 1 deletion weave/client_context/weave_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from weave.legacy import context_state

if TYPE_CHECKING:
from weave.weave_client import WeaveClient
from weave.trace.weave_client import WeaveClient

_global_weave_client: Optional["WeaveClient"] = None
lock = threading.Lock()
Expand Down
4 changes: 1 addition & 3 deletions weave/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,7 @@
remote_http_trace_server,
sqlite_trace_server,
)
from weave.trace_server import (
trace_server_interface as tsi,
)
from weave.trace_server import trace_server_interface as tsi

from . import environment, logs
from .tests import fixture_fakewandb
Expand Down
2 changes: 1 addition & 1 deletion weave/flow/eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from weave.trace.errors import OpCallError
from weave.trace.op import Op
from weave.trace.vals import WeaveObject
from weave.weave_client import get_ref
from weave.trace.weave_client import get_ref

console = Console()

Expand Down
2 changes: 1 addition & 1 deletion weave/flow/obj.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

from weave.trace.op import ObjectRef, Op, call
from weave.trace.vals import WeaveObject, pydantic_getattribute
from weave.weave_client import get_ref
from weave.trace.weave_client import get_ref


class Object(BaseModel):
Expand Down
10 changes: 5 additions & 5 deletions weave/integrations/anthropic/anthropic_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def _get_call_output(call: tsi.CallSchema) -> Any:
allowed_hosts=["api.wandb.ai", "localhost"],
)
def test_anthropic(
client: weave.weave_client.WeaveClient,
client: weave.trace.weave_client.WeaveClient,
) -> None:
api_key = os.environ.get("ANTHROPIC_API_KEY", "DUMMY_API_KEY")
anthropic_client = Anthropic(
Expand Down Expand Up @@ -67,7 +67,7 @@ def test_anthropic(
allowed_hosts=["api.wandb.ai", "localhost"],
)
def test_anthropic_stream(
client: weave.weave_client.WeaveClient,
client: weave.trace.weave_client.WeaveClient,
) -> None:
api_key = os.environ.get("ANTHROPIC_API_KEY", "DUMMY_API_KEY")
anthropic_client = Anthropic(
Expand Down Expand Up @@ -115,7 +115,7 @@ def test_anthropic_stream(
)
@pytest.mark.asyncio
async def test_async_anthropic(
client: weave.weave_client.WeaveClient,
client: weave.trace.weave_client.WeaveClient,
) -> None:
anthropic_client = AsyncAnthropic(
# This is the default and can be omitted
Expand Down Expand Up @@ -156,7 +156,7 @@ async def test_async_anthropic(
)
@pytest.mark.asyncio
async def test_async_anthropic_stream(
client: weave.weave_client.WeaveClient,
client: weave.trace.weave_client.WeaveClient,
) -> None:
anthropic_client = AsyncAnthropic(
# This is the default and can be omitted
Expand Down Expand Up @@ -204,7 +204,7 @@ async def test_async_anthropic_stream(
allowed_hosts=["api.wandb.ai", "localhost"],
)
def test_tools_calling(
client: weave.weave_client.WeaveClient,
client: weave.trace.weave_client.WeaveClient,
) -> None:
api_key = os.environ.get("ANTHROPIC_API_KEY", "DUMMY_API_KEY")
anthropic_client = Anthropic(
Expand Down
8 changes: 4 additions & 4 deletions weave/integrations/cohere/cohere_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def op_name_from_ref(ref: str) -> str:
allowed_hosts=["api.wandb.ai", "localhost"],
)
def test_cohere(
client: weave.weave_client.WeaveClient,
client: weave.trace.weave_client.WeaveClient,
) -> None:
api_key = os.environ.get("COHERE_API_KEY", "DUMMY_API_KEY")

Expand Down Expand Up @@ -79,7 +79,7 @@ def test_cohere(
allowed_hosts=["api.wandb.ai", "localhost"],
)
def test_cohere_stream(
client: weave.weave_client.WeaveClient,
client: weave.trace.weave_client.WeaveClient,
) -> None:
api_key = os.environ.get("COHERE_API_KEY", "DUMMY_API_KEY")
cohere_client = cohere.Client(api_key=api_key)
Expand Down Expand Up @@ -137,7 +137,7 @@ def test_cohere_stream(
allowed_hosts=["api.wandb.ai", "localhost"],
)
async def test_cohere_async(
client: weave.weave_client.WeaveClient,
client: weave.trace.weave_client.WeaveClient,
) -> None:
api_key = os.environ.get("COHERE_API_KEY", "DUMMY_API_KEY")

Expand Down Expand Up @@ -190,7 +190,7 @@ async def test_cohere_async(
allowed_hosts=["api.wandb.ai", "localhost"],
)
async def test_cohere_async_stream(
client: weave.weave_client.WeaveClient,
client: weave.trace.weave_client.WeaveClient,
) -> None:
api_key = os.environ.get("COHERE_API_KEY", "DUMMY_API_KEY")
cohere_client = cohere.AsyncClient(api_key=api_key)
Expand Down
2 changes: 1 addition & 1 deletion weave/integrations/dspy/dspy_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@
import pytest

import weave
from weave.trace.weave_client import WeaveClient
from weave.trace_server import trace_server_interface as tsi
from weave.weave_client import WeaveClient


def _get_call_output(call: tsi.CallSchema) -> Any:
Expand Down
10 changes: 5 additions & 5 deletions weave/integrations/groq/groq_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def op_name_from_ref(ref: str) -> str:
allowed_hosts=["api.wandb.ai", "localhost", "trace.wandb.ai"],
)
def test_groq_quickstart(
client: weave.weave_client.WeaveClient,
client: weave.trace.weave_client.WeaveClient,
) -> None:
from groq import Groq

Expand Down Expand Up @@ -97,7 +97,7 @@ def test_groq_quickstart(
allowed_hosts=["api.wandb.ai", "localhost", "trace.wandb.ai"],
)
def test_groq_async_chat_completion(
client: weave.weave_client.WeaveClient,
client: weave.trace.weave_client.WeaveClient,
) -> None:
from groq import AsyncGroq

Expand Down Expand Up @@ -165,7 +165,7 @@ async def complete_chat() -> None:
allowed_hosts=["api.wandb.ai", "localhost", "trace.wandb.ai"],
)
def test_groq_streaming_chat_completion(
client: weave.weave_client.WeaveClient,
client: weave.trace.weave_client.WeaveClient,
) -> None:
from groq import Groq

Expand Down Expand Up @@ -248,7 +248,7 @@ def test_groq_streaming_chat_completion(
allowed_hosts=["api.wandb.ai", "localhost", "trace.wandb.ai"],
)
def test_groq_async_streaming_chat_completion(
client: weave.weave_client.WeaveClient,
client: weave.trace.weave_client.WeaveClient,
) -> None:
from groq import AsyncGroq

Expand Down Expand Up @@ -327,7 +327,7 @@ async def generate_reponse() -> str:
allowed_hosts=["api.wandb.ai", "localhost", "trace.wandb.ai"],
)
def test_groq_tool_call(
client: weave.weave_client.WeaveClient,
client: weave.trace.weave_client.WeaveClient,
) -> None:
import json

Expand Down
2 changes: 1 addition & 1 deletion weave/integrations/langchain/langchain.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@
)
from weave.trace import call_context
from weave.trace.patcher import Patcher
from weave.weave_client import Call
from weave.trace.weave_client import Call

import_failed = False

Expand Down
2 changes: 1 addition & 1 deletion weave/integrations/langchain/langchain_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@
import pytest

import weave
from weave.trace.weave_client import WeaveClient
from weave.trace_server import trace_server_interface as tsi
from weave.weave_client import WeaveClient


def filter_body(r: Any) -> Any:
Expand Down
10 changes: 5 additions & 5 deletions weave/integrations/litellm/litellm_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def patch_litellm(request: Any) -> Generator[None, None, None]:
filter_headers=["authorization"], allowed_hosts=["api.wandb.ai", "localhost"]
)
def test_litellm_quickstart(
client: weave.weave_client.WeaveClient, patch_litellm: None
client: weave.trace.weave_client.WeaveClient, patch_litellm: None
) -> None:
# This is taken directly from https://docs.litellm.ai/docs/
chat_response = litellm.completion(
Expand Down Expand Up @@ -101,7 +101,7 @@ def test_litellm_quickstart(
)
@pytest.mark.asyncio
async def test_litellm_quickstart_async(
client: weave.weave_client.WeaveClient, patch_litellm: None
client: weave.trace.weave_client.WeaveClient, patch_litellm: None
) -> None:
# This is taken directly from https://docs.litellm.ai/docs/
chat_response = await litellm.acompletion(
Expand Down Expand Up @@ -144,7 +144,7 @@ async def test_litellm_quickstart_async(
filter_headers=["authorization"], allowed_hosts=["api.wandb.ai", "localhost"]
)
def test_litellm_quickstart_stream(
client: weave.weave_client.WeaveClient, patch_litellm: None
client: weave.trace.weave_client.WeaveClient, patch_litellm: None
) -> None:
# This is taken directly from https://docs.litellm.ai/docs/
chat_response = litellm.completion(
Expand Down Expand Up @@ -187,7 +187,7 @@ def test_litellm_quickstart_stream(
)
@pytest.mark.asyncio
async def test_litellm_quickstart_stream_async(
client: weave.weave_client.WeaveClient, patch_litellm: None
client: weave.trace.weave_client.WeaveClient, patch_litellm: None
) -> None:
# This is taken directly from https://docs.litellm.ai/docs/
chat_response = await litellm.acompletion(
Expand Down Expand Up @@ -231,7 +231,7 @@ async def test_litellm_quickstart_stream_async(
)
@pytest.mark.asyncio
def test_model_predict(
client: weave.weave_client.WeaveClient, patch_litellm: None
client: weave.trace.weave_client.WeaveClient, patch_litellm: None
) -> None:
class TranslatorModel(weave.Model):
model: str
Expand Down
2 changes: 1 addition & 1 deletion weave/integrations/llamaindex/llamaindex.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from weave.client_context import weave_client as weave_client_context
from weave.trace.patcher import Patcher
from weave.weave_client import Call
from weave.trace.weave_client import Call

TRANSFORM_EMBEDDINGS = False
ALLOWED_ROOT_EVENT_TYPES = ("query",)
Expand Down
4 changes: 2 additions & 2 deletions weave/integrations/llamaindex/llamaindex_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ def fake_api_key() -> Generator[None, None, None]:
before_record_request=filter_body,
)
def test_llamaindex_quickstart(
client: weave.weave_client.WeaveClient, fake_api_key: None
client: weave.trace.weave_client.WeaveClient, fake_api_key: None
) -> None:
# This is taken directly from https://docs.llamaindex.ai/en/stable/getting_started/starter_example/
from llama_index.core import SimpleDirectoryReader, VectorStoreIndex
Expand All @@ -105,7 +105,7 @@ def test_llamaindex_quickstart(
)
@pytest.mark.asyncio
async def test_llamaindex_quickstart_async(
client: weave.weave_client.WeaveClient, fake_api_key: None
client: weave.trace.weave_client.WeaveClient, fake_api_key: None
) -> None:
# This is taken directly from https://docs.llamaindex.ai/en/stable/getting_started/starter_example/
from llama_index.core import SimpleDirectoryReader, VectorStoreIndex
Expand Down
12 changes: 8 additions & 4 deletions weave/integrations/mistral/mistral_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def _get_call_output(call: tsi.CallSchema) -> Any:
@pytest.mark.vcr(
filter_headers=["authorization"], allowed_hosts=["api.wandb.ai", "localhost"]
)
def test_mistral_quickstart(client: weave.weave_client.WeaveClient) -> None:
def test_mistral_quickstart(client: weave.trace.weave_client.WeaveClient) -> None:
# This is taken directly from https://docs.mistral.ai/getting-started/quickstart/
api_key = os.environ.get("MISTRAL_API_KEY", "DUMMY_API_KEY")
model = "mistral-large-latest"
Expand Down Expand Up @@ -77,7 +77,9 @@ def test_mistral_quickstart(client: weave.weave_client.WeaveClient) -> None:
filter_headers=["authorization"], allowed_hosts=["api.wandb.ai", "localhost"]
)
@pytest.mark.asyncio
async def test_mistral_quickstart_async(client: weave.weave_client.WeaveClient) -> None:
async def test_mistral_quickstart_async(
client: weave.trace.weave_client.WeaveClient,
) -> None:
# This is taken directly from https://docs.mistral.ai/getting-started/quickstart/
api_key = os.environ.get("MISTRAL_API_KEY", "DUMMY_API_KEY")
model = "mistral-large-latest"
Expand Down Expand Up @@ -125,7 +127,9 @@ async def test_mistral_quickstart_async(client: weave.weave_client.WeaveClient)
@pytest.mark.vcr(
filter_headers=["authorization"], allowed_hosts=["api.wandb.ai", "localhost"]
)
def test_mistral_quickstart_with_stream(client: weave.weave_client.WeaveClient) -> None:
def test_mistral_quickstart_with_stream(
client: weave.trace.weave_client.WeaveClient,
) -> None:
# This is taken directly from https://docs.mistral.ai/getting-started/quickstart/
api_key = os.environ.get("MISTRAL_API_KEY", "DUMMY_API_KEY")
model = "mistral-large-latest"
Expand Down Expand Up @@ -182,7 +186,7 @@ def test_mistral_quickstart_with_stream(client: weave.weave_client.WeaveClient)
)
@pytest.mark.asyncio
async def test_mistral_quickstart_with_stream_async(
client: weave.weave_client.WeaveClient,
client: weave.trace.weave_client.WeaveClient,
) -> None:
# This is taken directly from https://docs.mistral.ai/getting-started/quickstart/
api_key = os.environ.get("MISTRAL_API_KEY", "DUMMY_API_KEY")
Expand Down
Loading

0 comments on commit e1db3b8

Please sign in to comment.