diff --git a/.github/workflows/pr.yaml b/.github/workflows/pr.yaml
index d44f064aa14..46672e37048 100644
--- a/.github/workflows/pr.yaml
+++ b/.github/workflows/pr.yaml
@@ -30,7 +30,10 @@ jobs:
          scopes: |
            ui
            weave
+           weave_ts
            weave_query
+           app
+           dev
          wip: false
          requireScope: true
          validateSingleCommit: false
diff --git a/.gitignore b/.gitignore
index 866ec62f011..0f6040766e8 100644
--- a/.gitignore
+++ b/.gitignore
@@ -18,3 +18,4 @@ gha-creds-*.json
 .nox
 *.log
 */file::memory:?cache=shared
+weave/trace_server/model_providers
\ No newline at end of file
diff --git a/docs/docs/guides/core-types/prompts.md b/docs/docs/guides/core-types/prompts.md
new file mode 100644
index 00000000000..9a2d50ecf2b
--- /dev/null
+++ b/docs/docs/guides/core-types/prompts.md
@@ -0,0 +1,373 @@
+# Prompts
+
+Creating, evaluating, and refining prompts is a core activity for AI engineers.
+Small changes to a prompt can have big impacts on your application's behavior.
+Weave lets you create prompts, save and retrieve them, and evolve them over time.
+Some of the benefits of Weave's prompt management system are:
+
+- Unopinionated core, with a batteries-included option for rapid development
+- Versioning that shows you how a prompt has evolved over time
+- The ability to update a prompt in production without redeploying your application
+- The ability to run a prompt against many inputs to evaluate its performance
+
+## Getting started
+
+If you want complete control over how a Prompt is constructed, you can subclass the base `weave.Prompt` class, or one of its specializations `weave.StringPrompt` or `weave.MessagesPrompt`, and implement the corresponding `format` method. When you publish one of these objects with `weave.publish`, it will appear in your Weave project on the "Prompts" page.
+
+```
+class Prompt(Object):
+    def format(self, **kwargs: Any) -> Any:
+        ...
+
+class StringPrompt(Prompt):
+    def format(self, **kwargs: Any) -> str:
+        ...
+
+class MessagesPrompt(Prompt):
+    def format(self, **kwargs: Any) -> list:
+        ...
+```
+
+Weave also includes a "batteries-included" class called `EasyPrompt` that can be simpler to start with, especially if you are working with APIs similar to OpenAI's. This document highlights the features you get with EasyPrompt.
+
+## Constructing prompts
+
+You can think of the EasyPrompt object as a list of messages with associated roles, optional
+placeholder variables, and an optional model configuration.
+But constructing a prompt can be as simple as providing a single string:
+
+```python
+import weave
+
+prompt = weave.EasyPrompt("What's 23 * 42?")
+assert prompt[0] == {"role": "user", "content": "What's 23 * 42?"}
+```
+
+For terseness, the weave library aliases the `EasyPrompt` class to `P`.
+
+```python
+from weave import P
+p = P("What's 23 * 42?")
+```
+
+It is common for a prompt to consist of multiple messages. Each message has an associated `role`.
+If the role is omitted, it defaults to `"user"`.
+
+**Some common roles**
+
+| Role      | Description                                                                                                           |
+| --------- | --------------------------------------------------------------------------------------------------------------------- |
+| system    | System prompts provide high-level instructions and can be used to set the behavior, knowledge, or persona of the AI.  |
+| user      | Represents input from a human user. (This is the default role.)                                                       |
+| assistant | Represents the AI's generated replies. Can be used for historical completions or to show examples.                    |
+
+For convenience, you can prefix a message string with one of these known roles:
+
+```python
+import weave
+
+prompt = weave.EasyPrompt("system: Talk like a pirate")
+assert prompt[0] == {"role": "system", "content": "Talk like a pirate"}
+
+# An explicit role parameter takes precedence
+prompt = weave.EasyPrompt("system: Talk like a pirate", role="user")
+assert prompt[0] == {"role": "user", "content": "system: Talk like a pirate"}
+
+```
+
+Messages can be appended to a prompt one by one:
+
+```python
+import weave
+
+prompt = weave.EasyPrompt()
+prompt.append("You are an expert travel consultant.", role="system")
+prompt.append("Give me five ideas for top kid-friendly attractions in New Zealand.")
+```
+
+Or you can append multiple messages at once, either with the `append` method or with the `Prompt`
+constructor, which is convenient for constructing a prompt from existing messages.
+
+```python
+import weave
+
+prompt = weave.EasyPrompt()
+prompt.append([
+    {"role": "system", "content": "You are an expert travel consultant."},
+    "Give me five ideas for top kid-friendly attractions in New Zealand."
+])
+
+# Same
+prompt = weave.EasyPrompt([
+    {"role": "system", "content": "You are an expert travel consultant."},
+    "Give me five ideas for top kid-friendly attractions in New Zealand."
+])
+```
+
+The Prompt class is designed to be easily inserted into existing code.
+For example, you can quickly wrap it around all of the arguments to the
+OpenAI chat completion `create` call, including its messages and model
+configuration. If you don't wrap the inputs, Weave's integration will still
+track all of the call's inputs, but it will not extract them as a separate
+versioned object. Having a separate Prompt object allows you to version
+the prompt, easily filter calls by that version, and so on.
+
+```python
+from weave import init, P
+from openai import OpenAI
+client = OpenAI()
+
+# Must specify a target project, otherwise the Weave code is a no-op
+# highlight-next-line
+init("intro-example")
+
+# highlight-next-line
+response = client.chat.completions.create(P(
+    model="gpt-4o-mini",
+    messages=[
+        {"role": "user", "content": "What's 23 * 42?"}
+    ],
+    temperature=0.7,
+    max_tokens=64,
+    top_p=1
+# highlight-next-line
+))
+```
+
+:::note
+Why this works: Weave's OpenAI integration wraps the OpenAI `create` method to make it a Weave Op.
+When the Op is executed, the Prompt object in the input will get saved and associated with the Call.
+However, it will be replaced with the structure the `create` method expects for the execution of the
+underlying function.
+:::
+
+## Parameterizing prompts
+
+When specifying a prompt, you can include placeholders for values you want to fill in later. These placeholders are called "Parameters".
+Parameters are indicated with curly braces. Here's a simple example:
+
+```python
+import weave
+
+prompt = weave.EasyPrompt("What's {A} + {B}?")
+```
+
+You specify values for all of the parameters, or "bind" them, when you [use the prompt](#using-prompts).
+
+The `require` method of Prompt allows you to associate parameters with restrictions that will be checked at bind time to detect programming errors.
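+If a value bound to a parameter violates one of its requirements, a `ValueError` describing the problem is raised.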
+
+```python
+import weave
+
+prompt = weave.EasyPrompt("What's {A} + 42?")
+prompt.require("A", type="int", min=0, max=100)
+
+prompt = weave.EasyPrompt("system: You are a {profession}")
+prompt.require("profession", oneof=('pirate', 'cartoon mouse', 'hungry dragon'), default='pirate')
+```
+
+## Using prompts
+
+You use a Prompt by converting it into a list of messages where all template placeholders have been filled in. You can bind a prompt to parameter values with the `bind` method or by simply calling it as a function. Here's an example where the prompt has zero parameters.
+
+```python
+import weave
+prompt = weave.EasyPrompt("What's 23 * 42?")
+assert prompt() == prompt.bind() == [
+    {"role": "user", "content": "What's 23 * 42?"}
+]
+```
+
+If a prompt has parameters, you specify values for them when you use the prompt.
+Parameter values can be passed in as a dictionary or as keyword arguments.
+
+```python
+import weave
+prompt = weave.EasyPrompt("What's {A} + {B}?")
+assert prompt(A=5, B="10") == prompt({"A": 5, "B": "10"})
+```
+
+If any parameters are missing, they will be left unsubstituted in the output.
+
+Here's a complete example of using a prompt with OpenAI. This example also uses [Weave's OpenAI integration](../integrations/openai.md) to automatically log the prompt and response.
+
+```python
+import weave
+from openai import OpenAI
+client = OpenAI()
+
+weave.init("intro-example")
+prompt = weave.EasyPrompt()
+prompt.append("You will be provided with a tweet, and your task is to classify its sentiment as positive, neutral, or negative.", role="system")
+prompt.append("I love {this_thing}!")
+
+response = client.chat.completions.create(
+    model="gpt-4o-mini",
+    messages=prompt(this_thing="Weave"),
+    temperature=0.7,
+    max_tokens=64,
+    top_p=1
+)
+```
+
+## Publishing to server
+
+Prompts are a type of [Weave object](../tracking/objects.md) and use the same methods for publishing to the Weave server.
+You must specify a destination project name with `weave.init` before you can publish a prompt.
+
+```python
+import weave
+
+prompt = weave.EasyPrompt()
+prompt.append("What's 23 * 42?")
+
+weave.init("intro-example") # Use entity/project format if not targeting your default entity
+weave.publish(prompt, name="calculation-prompt")
+```
+
+Weave will automatically determine if the object has changed and only publish a new version if it has.
+You can also specify a name or description for the Prompt as part of its constructor.
+
+```python
+import weave
+
+prompt = weave.EasyPrompt(
+    "What's 23 * 42?",
+    name="calculation-prompt",
+    description="A prompt for calculating the product of two numbers.",
+)
+
+weave.init("intro-example")
+weave.publish(prompt)
+```
+
+## Retrieving from server
+
+Prompts are a type of [Weave object](../tracking/objects.md) and use the same methods for retrieval from the Weave server.
+You must specify a source project name with `weave.init` before you can retrieve a prompt.
+
+```python
+import weave
+
+weave.init("intro-example")
+prompt = weave.ref("calculation-prompt").get()
+```
+
+By default, the latest version of the prompt is returned. You can make this explicit or select a specific version by providing its version ID.
+
+```python
+import weave
+
+weave.init("intro-example")
+prompt = weave.ref("calculation-prompt:latest").get()
+# "<prompt_name>:<version_digest>", for example:
+prompt = weave.ref("calculation-prompt:QSLzr96CTzFwLWgFFi3EuawCI4oODz4Uax98SxIY79E").get()
+```
+
+It is also possible to retrieve a Prompt without calling `init` if you pass a fully qualified URI to `weave.ref`.
+
+## Loading and saving from files
+
+Prompts can be saved to and loaded from files. This can be convenient if you want your Prompt to be versioned through
+a mechanism other than Weave, such as git, or as a fallback if Weave is not available.
+
+To save a prompt to a file, you can use the `dump_file` method.
+
+```python
+import weave
+
+prompt = weave.EasyPrompt("What's 23 * 42?")
+prompt.dump_file("~/prompt.json")
+```
+
+and load it again later with `Prompt.load_file`.
+
+```python
+import weave
+
+prompt = weave.EasyPrompt.load_file("~/prompt.json")
+```
+
+You can also use the lower-level `dump` and `Prompt.load` methods for custom (de)serialization.
+
+## Evaluating prompts
+
+The [Parameter feature of prompts](#parameterizing-prompts) can be used to execute or evaluate variations of a prompt.
+
+You can bind each row of a [Dataset](./datasets.md) to generate N variations of a prompt.
+
+```python
+import weave
+
+# Create a dataset
+dataset = weave.Dataset(name='countries', rows=[
+    {'id': '0', 'country': "Argentina"},
+    {'id': '1', 'country': "Belize"},
+    {'id': '2', 'country': "Canada"},
+    {'id': '3', 'country': "New Zealand"},
+])
+
+prompt = weave.EasyPrompt(name='travel_agent')
+prompt.append("You are an expert travel consultant.", role="system")
+prompt.append("Tell me the capital of {country} and about five kid-friendly attractions there.")
+
+
+prompts = prompt.bind_rows(dataset)
+assert prompts[2][1]["content"] == "Tell me the capital of Canada and about five kid-friendly attractions there."
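+
+# A sketch of the resulting structure (hypothetical values, assuming
+# bind_rows substitutes one dataset row into each prompt variation):
+# prompts[2] == [
+#     {"role": "system", "content": "You are an expert travel consultant."},
+#     {"role": "user", "content": "Tell me the capital of Canada and about five kid-friendly attractions there."},
+# ]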
+```
+
+You can extend this into an [Evaluation](./evaluations.md):
+
+```python
+import asyncio
+
+import openai
+import weave
+
+weave.init("intro-example")
+
+# Create a dataset
+dataset = weave.Dataset(name='countries', rows=[
+    {'id': '0', 'country': "Argentina", 'capital': "Buenos Aires"},
+    {'id': '1', 'country': "Belize", 'capital': "Belmopan"},
+    {'id': '2', 'country': "Canada", 'capital': "Ottawa"},
+    {'id': '3', 'country': "New Zealand", 'capital': "Wellington"},
+])
+
+# Create a prompt
+prompt = weave.EasyPrompt(name='travel_agent')
+prompt.append("You are an expert travel consultant.", role="system")
+prompt.append("Tell me the capital of {country} and about five kid-friendly attractions there.")
+
+# Create a model, combining a prompt with model configuration
+class TravelAgentModel(weave.Model):
+
+    model_name: str
+    prompt: weave.EasyPrompt
+
+    @weave.op
+    async def predict(self, country: str) -> str:
+        client = openai.AsyncClient()
+
+        response = await client.chat.completions.create(
+            model=self.model_name,
+            messages=self.prompt(country=country),
+        )
+        result = response.choices[0].message.content
+        if result is None:
+            raise ValueError("No response from model")
+        return result
+
+# Define and run the evaluation
+@weave.op
+def mentions_capital_scorer(capital: str, model_output: str) -> dict:
+    return {'correct': capital in model_output}
+
+model = TravelAgentModel(model_name="gpt-4o-mini", prompt=prompt)
+evaluation = weave.Evaluation(
+    dataset=dataset,
+    scorers=[mentions_capital_scorer],
+)
+asyncio.run(evaluation.evaluate(model))
+
+```
diff --git a/docs/sidebars.ts b/docs/sidebars.ts
index c5da61462b5..d56f563fd3a 100644
--- a/docs/sidebars.ts
+++ b/docs/sidebars.ts
@@ -64,6 +64,7 @@ const sidebars: SidebarsConfig = {
         "guides/evaluation/scorers",
       ],
     },
+    "guides/core-types/prompts",
     "guides/core-types/models",
     "guides/core-types/datasets",
     "guides/tracking/feedback",
diff --git a/tests/integrations/litellm/client_completions_create_test.py b/tests/integrations/litellm/client_completions_create_test.py
new file mode 100644
index 00000000000..a48f9155465
--- /dev/null
+++ b/tests/integrations/litellm/client_completions_create_test.py
@@ -0,0 +1,97 @@
+import os
+from contextlib import contextmanager
+from unittest.mock import patch
+
+from litellm.types.utils import ModelResponse
+
+from tests.trace.util import client_is_sqlite
+from weave.trace.settings import _context_vars
+from weave.trace_server import trace_server_interface as tsi
+from weave.trace_server.secret_fetcher_context import secret_fetcher_context
+
+
+@contextmanager
+def with_tracing_disabled():
+    token = _context_vars["disabled"].set(True)
+    try:
+        yield
+    finally:
+        _context_vars["disabled"].reset(token)
+
+
+def test_completions_create(client):
+    """
+    This test exercises the backend implementation of completions_create. It relies on LiteLLM,
+    and we don't want to jump through the hoops to add it to the integration sharding, so we are
+    putting it here for now. It should be moved to a dedicated client tester that pins to a
+    single Python version.
+    """
+    is_sqlite = client_is_sqlite(client)
+    if is_sqlite:
+        # no need to test in sqlite
+        return
+
+    model_name = "gpt-4o"
+    inputs = {
+        "model": model_name,
+        "messages": [{"role": "user", "content": "Hello, world!"}],
+    }
+    mock_response = {
+        "id": "chatcmpl-ANnboqjHwrm6uWcubQma9pzxye0Cm",
+        "choices": [
+            {
+                "finish_reason": "stop",
+                "index": 0,
+                "message": {
+                    "content": "Hello! How can I assist you today?",
+                    "role": "assistant",
+                    "tool_calls": None,
+                    "function_call": None,
+                },
+            }
+        ],
+        "created": 1730235604,
+        "model": "gpt-4o-2024-08-06",
+        "object": "chat.completion",
+        "system_fingerprint": "fp_90354628f2",
+        "usage": {
+            "completion_tokens": 9,
+            "prompt_tokens": 11,
+            "total_tokens": 20,
+            "completion_tokens_details": {"audio_tokens": None, "reasoning_tokens": 0},
+            "prompt_tokens_details": {"audio_tokens": None, "cached_tokens": 0},
+        },
+        "service_tier": None,
+    }
+
+    class DummySecretFetcher:
+        def fetch(self, secret_name: str) -> dict:
+            return {
+                "secrets": {
+                    secret_name: os.environ.get(secret_name, "DUMMY_SECRET_VALUE")
+                }
+            }
+
+    # Have to do this since we run the tests in the same process as the server
+    # and the inner litellm gets patched!
+    with with_tracing_disabled():
+        with secret_fetcher_context(DummySecretFetcher()):
+            with patch("litellm.completion") as mock_completion:
+                mock_completion.return_value = ModelResponse.model_validate(
+                    mock_response
+                )
+                res = client.server.completions_create(
+                    tsi.CompletionsCreateReq.model_validate(
+                        {
+                            "project_id": client._project_id(),
+                            "inputs": inputs,
+                        }
+                    )
+                )
+
+    assert res.response == mock_response
+    calls = list(client.get_calls())
+    assert len(calls) == 1
+    assert calls[0].output == res.response
+    assert calls[0].summary["usage"][model_name] == res.response["usage"]
+    assert calls[0].inputs == inputs
+    assert calls[0].op_name == "weave.completions_create"
diff --git a/tests/trace/test_prompt.py b/tests/trace/test_prompt.py
new file mode 100644
index 00000000000..98bb731d076
--- /dev/null
+++ b/tests/trace/test_prompt.py
@@ -0,0 +1,23 @@
+from weave.flow.prompt.prompt import MessagesPrompt, StringPrompt
+
+
+def test_stringprompt_format():
+    class MyPrompt(StringPrompt):
+        def format(self, **kwargs) -> str:
+            return "Imagine a lot of complicated logic built this string."
+
+    prompt = MyPrompt()
+    assert prompt.format() == "Imagine a lot of complicated logic built this string."
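+    # A further sketch (hypothetical subclass, not part of the change above):
+    # format() receives arbitrary kwargs, so a subclass can build its string
+    # from them however it likes.
+    class MyTemplatedPrompt(StringPrompt):
+        def format(self, **kwargs) -> str:
+            return f"Summarize the following text: {kwargs.get('text', '')}"
+
+    assert (
+        MyTemplatedPrompt().format(text="Weave prompts")
+        == "Summarize the following text: Weave prompts"
+    )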
+ + +def test_messagesprompt_format(): + class MyPrompt(MessagesPrompt): + def format(self, **kwargs) -> list: + return [ + {"role": "user", "content": "What's 23 * 42"}, + ] + + prompt = MyPrompt() + assert prompt.format() == [ + {"role": "user", "content": "What's 23 * 42"}, + ] diff --git a/tests/trace/test_prompt_easy.py b/tests/trace/test_prompt_easy.py new file mode 100644 index 00000000000..6d01db92a9f --- /dev/null +++ b/tests/trace/test_prompt_easy.py @@ -0,0 +1,260 @@ +import itertools + +import pytest + +from weave import EasyPrompt + + +def iter_equal(items1, items2): + """`True` if iterators `items1` and `items2` contain equal items.""" + return (items1 is items2) or all( + a == b for a, b in itertools.zip_longest(items1, items2, fillvalue=object()) + ) + + +def test_prompt_message_constructor_str(): + prompt = EasyPrompt("What's 23 * 42") + assert prompt() == [{"role": "user", "content": "What's 23 * 42"}] + + +def test_prompt_message_constructor_prefix_str(): + prompt = EasyPrompt("system: you are a pirate") + assert prompt() == [{"role": "system", "content": "you are a pirate"}] + + +def test_prompt_message_constructor_role_arg(): + prompt = EasyPrompt("You're a calculator.", role="system") + assert prompt() == [{"role": "system", "content": "You're a calculator."}] + + +def test_prompt_message_constructor_array(): + prompt = EasyPrompt( + [ + {"role": "system", "content": "You're a calculator."}, + {"role": "user", "content": "What's 23 * 42"}, + ] + ) + assert prompt() == [ + {"role": "system", "content": "You're a calculator."}, + {"role": "user", "content": "What's 23 * 42"}, + ] + + +def test_prompt_message_constructor_obj(): + prompt = EasyPrompt( + name="myprompt", + model="gpt-4o", + messages=[ + { + "role": "system", + "content": "You will be provided with text, and your task is to translate it into emojis. Do not use any regular text. Do your best with emojis only.", + }, + { + "role": "user", + "content": "Artificial intelligence is a technology with great promise.", + }, + ], + temperature=0.8, + max_tokens=64, + top_p=1, + ) + assert prompt() == [ + { + "role": "system", + "content": "You will be provided with text, and your task is to translate it into emojis. Do not use any regular text. Do your best with emojis only.", + }, + { + "role": "user", + "content": "Artificial intelligence is a technology with great promise.", + }, + ] + assert prompt.config == { + "model": "gpt-4o", + "temperature": 0.8, + "max_tokens": 64, + "top_p": 1, + } + + +def test_prompt_append() -> None: + prompt = EasyPrompt() + prompt.append("You are a helpful assistant.", role="system") + prompt.append("system: who knows a lot about geography") + prompt.append( + """ + What's the capital of Brazil? + """, + dedent=True, + ) + assert prompt() == [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "system", "content": "who knows a lot about geography"}, + {"role": "user", "content": "What's the capital of Brazil?"}, + ] + + +def test_prompt_append_with_role() -> None: + prompt = EasyPrompt() + prompt.append("system: who knows a lot about geography", role="asdf") + assert prompt() == [ + {"role": "asdf", "content": "system: who knows a lot about geography"}, + ] + + +def test_prompt_unbound_iteration() -> None: + """We don't error - is that the right behavior?""" + prompt = EasyPrompt("Tell me about {x}, {y}, and {z}. 
Especially {z}.") + prompt.bind(y="strawberry") + assert prompt.placeholders == ["x", "y", "z"] + assert not prompt.is_bound + assert prompt.unbound_placeholders == ["x", "z"] + assert list(prompt()) == [ + { + "role": "user", + "content": "Tell me about {x}, strawberry, and {z}. Especially {z}.", + } + ] + prompt.bind(x="vanilla", z="chocolate") + assert prompt.is_bound + assert prompt.unbound_placeholders == [] + assert list(prompt()) == [ + { + "role": "user", + "content": "Tell me about vanilla, strawberry, and chocolate. Especially chocolate.", + } + ] + + +def test_prompt_format_specifiers() -> None: + prompt = EasyPrompt("{x:.5}") + assert prompt.placeholders == ["x"] + assert prompt(x=3.14159)[0]["content"] == "3.1416" + + +def test_prompt_parameter_default() -> None: + prompt = EasyPrompt("{A} * {B}") + prompt.require("A", default=23) + prompt.require("B", default=42) + assert list(prompt()) == [{"role": "user", "content": "23 * 42"}] + + +def test_prompt_parameter_validation_int() -> None: + prompt = EasyPrompt("{A} + {B}") + prompt.require("A", min=10, max=100) + with pytest.raises(ValueError) as e: + prompt.bind(A=0) + assert str(e.value) == "A (0) is less than min (10)" + + +def test_prompt_parameter_validation_oneof() -> None: + prompt = EasyPrompt("{flavor}") + prompt.require("flavor", oneof=("vanilla", "strawberry", "chocolate")) + with pytest.raises(ValueError) as e: + prompt.bind(flavor="mint chip") + assert ( + str(e.value) + == "flavor (mint chip) must be one of vanilla, strawberry, chocolate" + ) + + +def test_prompt_bind_iteration() -> None: + """Iterating over a prompt should return messages with placeholders filled in.""" + prompt = EasyPrompt( + model="gpt-4o", + messages=[ + { + "role": "system", + "content": "You will be provided with text, and your task is to translate it into emojis. Do not use any regular text. Do your best with emojis only.", + }, + {"role": "user", "content": "{sentence}"}, + ], + temperature=0.8, + max_tokens=64, + top_p=1, + ).bind(sentence="Artificial intelligence is a technology with great promise.") + desired = [ + { + "role": "system", + "content": "You will be provided with text, and your task is to translate it into emojis. Do not use any regular text. Do your best with emojis only.", + }, + { + "role": "user", + "content": "Artificial intelligence is a technology with great promise.", + }, + ] + assert iter_equal(prompt, iter(desired)) + + +def test_prompt_as_dict(): + prompt = EasyPrompt( + model="gpt-4o", + messages=[ + { + "role": "system", + "content": "You will be provided with text, and your task is to translate it into emojis. Do not use any regular text. Do your best with emojis only.", + }, + { + "role": "user", + "content": "Artificial intelligence is a technology with great promise.", + }, + ], + temperature=0.8, + max_tokens=64, + top_p=1, + ) + assert prompt.as_dict() == { + "model": "gpt-4o", + "temperature": 0.8, + "max_tokens": 64, + "top_p": 1, + "messages": [ + { + "role": "system", + "content": "You will be provided with text, and your task is to translate it into emojis. Do not use any regular text. Do your best with emojis only.", + }, + { + "role": "user", + "content": "Artificial intelligence is a technology with great promise.", + }, + ], + } + + +def test_prompt_as_pydantic_dict(): + prompt = EasyPrompt( + model="gpt-4o", + messages=[ + { + "role": "system", + "content": "You will be provided with text, and your task is to translate it into emojis. Do not use any regular text. 
Do your best with emojis only.", + }, + { + "role": "user", + "content": "Artificial intelligence is a technology with great promise.", + }, + ], + temperature=0.8, + max_tokens=64, + top_p=1, + ) + assert prompt.as_pydantic_dict() == { + "name": None, + "description": None, + "config": { + "model": "gpt-4o", + "temperature": 0.8, + "max_tokens": 64, + "top_p": 1, + }, + "data": [ + { + "role": "system", + "content": "You will be provided with text, and your task is to translate it into emojis. Do not use any regular text. Do your best with emojis only.", + }, + { + "role": "user", + "content": "Artificial intelligence is a technology with great promise.", + }, + ], + "requirements": {}, + } diff --git a/weave-js/package.json b/weave-js/package.json index fd4ae47bf23..d925f9d0a42 100644 --- a/weave-js/package.json +++ b/weave-js/package.json @@ -161,6 +161,7 @@ "@types/color": "^3.0.0", "@types/cytoscape": "^3.2.0", "@types/cytoscape-dagre": "^2.2.2", + "@types/d3-array": "^3.2.1", "@types/diff": "^5.0.3", "@types/downloadjs": "^1.4.2", "@types/is-buffer": "^2.0.0", diff --git a/weave-js/src/components/Callout/Callout.tsx b/weave-js/src/components/Callout/Callout.tsx index 51028420f46..9fc6535d9cf 100644 --- a/weave-js/src/components/Callout/Callout.tsx +++ b/weave-js/src/components/Callout/Callout.tsx @@ -18,6 +18,7 @@ export const Callout = ({className, color, icon, size}: CalloutProps) => { <Tailwind> <div className={twMerge( + 'night-aware', getTagColorClass(color), 'flex items-center justify-center rounded-full', size === 'x-small' && 'h-[40px] w-[40px]', diff --git a/weave-js/src/components/FancyPage/useProjectSidebar.ts b/weave-js/src/components/FancyPage/useProjectSidebar.ts index 6fcc50703f9..ab3e11a77df 100644 --- a/weave-js/src/components/FancyPage/useProjectSidebar.ts +++ b/weave-js/src/components/FancyPage/useProjectSidebar.ts @@ -144,6 +144,13 @@ export const useProjectSidebar = ( isShown: showWeaveSidebarItems || isShowAll, iconName: IconNames.BaselineAlt, }, + { + type: 'button' as const, + name: 'Prompts', + slug: 'weave/prompts', + isShown: showWeaveSidebarItems || isShowAll, + iconName: IconNames.ForumChatBubble, + }, { type: 'button' as const, name: 'Models', diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallPage/CallPage.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallPage/CallPage.tsx index 1bd8c13106b..79f091e6a31 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallPage/CallPage.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallPage/CallPage.tsx @@ -125,9 +125,11 @@ const useCallTabs = (call: CallSchema) => { { label: 'Use', content: ( - <Tailwind> - <TabUseCall call={call} /> - </Tailwind> + <ScrollableTabContent> + <Tailwind> + <TabUseCall call={call} /> + </Tailwind> + </ScrollableTabContent> ), }, ]; diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallsPage/CallsCharts.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallsPage/CallsCharts.tsx new file mode 100644 index 00000000000..164122753d8 --- /dev/null +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallsPage/CallsCharts.tsx @@ -0,0 +1,190 @@ +import {GridFilterModel, GridSortModel} from '@mui/x-data-grid-pro'; +import React, {useMemo} from 'react'; + +import {MOON_400} from '../../../../../../common/css/color.styles'; +import {IconInfo} from '../../../../../Icon'; +import {WaveLoader} from '../../../../../Loaders/WaveLoader'; 
+import {Tailwind} from '../../../../../Tailwind'; +import {WFHighLevelCallFilter} from './callsTableFilter'; +import {useCallsForQuery} from './callsTableQuery'; +import { + ErrorPlotlyChart, + LatencyPlotlyChart, + RequestsPlotlyChart, +} from './Charts'; + +type CallsChartsProps = { + entity: string; + project: string; + filterModelProp: GridFilterModel; + filter: WFHighLevelCallFilter; +}; + +const Chart = ({ + isLoading, + chartData, + title, +}: { + isLoading: boolean; + chartData: any; + title: string; +}) => { + const CHART_CONTAINER_STYLES = + 'flex-1 rounded-lg border border-moon-250 bg-white p-10'; + const CHART_TITLE_STYLES = 'ml-12 mt-8 text-base font-semibold text-moon-750'; + const CHART_HEIGHT = 250; + const LOADING_CONTAINER_STYLES = `flex h-[${CHART_HEIGHT}px] items-center justify-center`; + + let chart = null; + if (isLoading) { + chart = ( + <div className={LOADING_CONTAINER_STYLES}> + <WaveLoader size="small" /> + </div> + ); + } else if (chartData.length > 0) { + switch (title) { + case 'Latency': + chart = ( + <LatencyPlotlyChart chartData={chartData} height={CHART_HEIGHT} /> + ); + break; + case 'Errors': + chart = ( + <ErrorPlotlyChart chartData={chartData} height={CHART_HEIGHT} /> + ); + break; + case 'Requests': + chart = ( + <RequestsPlotlyChart chartData={chartData} height={CHART_HEIGHT} /> + ); + break; + } + } else { + chart = ( + <div className={LOADING_CONTAINER_STYLES}> + <div className="flex flex-col items-center justify-center"> + <IconInfo color={MOON_400} /> + <div className="text-moon-500"> + No data available for the selected time frame + </div> + </div> + </div> + ); + } + return ( + <div className={CHART_CONTAINER_STYLES}> + <div className={CHART_TITLE_STYLES}>{title}</div> + {chart} + </div> + ); +}; + +export const CallsCharts = ({ + entity, + project, + filter, + filterModelProp, +}: CallsChartsProps) => { + const columns = useMemo( + () => ['started_at', 'ended_at', 'exception', 'id'], + [] + ); + const columnSet = useMemo(() => new Set(columns), [columns]); + const sortCalls: GridSortModel = useMemo( + () => [{field: 'started_at', sort: 'desc'}], + [] + ); + const page = useMemo( + () => ({ + pageSize: 1000, + page: 0, + }), + [] + ); + + const calls = useCallsForQuery( + entity, + project, + filter, + filterModelProp, + page, + sortCalls, + columnSet, + columns + ); + + const chartData = useMemo(() => { + if (calls.loading || !calls.result || calls.result.length === 0) { + return {latency: [], errors: [], requests: []}; + } + + const data: { + latency: Array<{started_at: string; latency: number}>; + errors: Array<{started_at: string; isError: boolean}>; + requests: Array<{started_at: string}>; + } = { + latency: [], + errors: [], + requests: [], + }; + + calls.result.forEach(call => { + const started_at = call.traceCall?.started_at; + if (!started_at) { + return; + } + const ended_at = call.traceCall?.ended_at; + + const isError = + call.traceCall?.exception !== null && + call.traceCall?.exception !== undefined && + call.traceCall?.exception !== ''; + + data.requests.push({started_at}); + + if (isError) { + data.errors.push({started_at, isError}); + } else { + data.errors.push({started_at, isError: false}); + } + + if (ended_at !== undefined) { + const startTime = new Date(started_at).getTime(); + const endTime = new Date(ended_at).getTime(); + const latency = endTime - startTime; + data.latency.push({started_at, latency}); + } + }); + return data; + }, [calls.result, calls.loading]); + + const charts = ( + <div className="m-10 flex flex-row 
gap-10"> + <Chart + isLoading={calls.loading} + chartData={chartData.latency} + title="Latency" + /> + <Chart + isLoading={calls.loading} + chartData={chartData.errors} + title="Errors" + /> + <Chart + isLoading={calls.loading} + chartData={chartData.requests} + title="Requests" + /> + </div> + ); + + return ( + <Tailwind> + {/* setting the width to the width of the screen minus the sidebar width because of overflow: 'hidden' properties in SimplePageLayout causing issues */} + <div className="md:w-[calc(100vw-56px)]"> + <div className="mb-20 mt-10">{charts}</div> + </div> + </Tailwind> + ); +}; diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallsPage/CallsTable.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallsPage/CallsTable.tsx index 224e4d9a12d..25d80005260 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallsPage/CallsTable.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallsPage/CallsTable.tsx @@ -26,6 +26,7 @@ import { useGridApiRef, } from '@mui/x-data-grid-pro'; import {MOON_200, TEAL_300} from '@wandb/weave/common/css/color.styles'; +import {Switch} from '@wandb/weave/components'; import {Checkbox} from '@wandb/weave/components/Checkbox/Checkbox'; import {Icon} from '@wandb/weave/components/Icon'; import React, { @@ -69,6 +70,7 @@ import {traceCallToUICallSchema} from '../wfReactInterface/tsDataModelHooks'; import {EXPANDED_REF_REF_KEY} from '../wfReactInterface/tsDataModelHooksCallRefExpansion'; import {objectVersionNiceString} from '../wfReactInterface/utilities'; import {CallSchema} from '../wfReactInterface/wfDataModelHooksInterface'; +import {CallsCharts} from './CallsCharts'; import {CallsCustomColumnMenu} from './CallsCustomColumnMenu'; import { BulkDeleteButton, @@ -168,6 +170,7 @@ export const CallsTable: FC<{ allowedColumnPatterns, }) => { const {loading: loadingUserInfo, userInfo} = useViewerInfo(); + const [isMetricsChecked, setMetricsChecked] = useState(false); const isReadonly = loadingUserInfo || !userInfo?.username || !userInfo?.teams.includes(entity); @@ -245,8 +248,8 @@ export const CallsTable: FC<{ project, effectiveFilter, filterModelResolved, - sortModelResolved, paginationModelResolved, + sortModelResolved, expandedRefCols ); @@ -742,6 +745,15 @@ export const CallsTable: FC<{ clearSelectedCalls={clearSelectedCalls} /> )} + <div className="flex items-center gap-6"> + <Switch.Root + size="small" + checked={isMetricsChecked} + onCheckedChange={setMetricsChecked}> + <Switch.Thumb size="small" checked={isMetricsChecked} /> + </Switch.Root> + Metrics + </div> {selectedInputObjectVersion && ( <Chip label={`Input: ${objectVersionNiceString( @@ -849,6 +861,14 @@ export const CallsTable: FC<{ )} </TailwindContents> }> + {isMetricsChecked && ( + <CallsCharts + entity={entity} + project={project} + filter={filter} + filterModelProp={filterModelResolved} + /> + )} <StyledDataGrid // Start Column Menu // ColumnMenu is needed to support pinning and column visibility diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallsPage/Charts.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallsPage/Charts.tsx new file mode 100644 index 00000000000..7d1e9313069 --- /dev/null +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallsPage/Charts.tsx @@ -0,0 +1,293 @@ +import {quantile} from 'd3-array'; +import _ from 'lodash'; +import moment from 'moment'; +import * as Plotly from 'plotly.js'; +import React, {useEffect, 
useMemo, useRef} from 'react'; + +import { + BLUE_500, + GREEN_500, + MOON_200, + MOON_300, + MOON_500, + RED_400, + TEAL_400, +} from '../../../../../../common/css/color.styles'; + +type ChartDataRequests = { + started_at: string; +}; + +type ChartDataErrors = { + started_at: string; + isError: boolean; +}; + +type ChartDataLatency = { + started_at: string; + latency: number; +}; + +const CHART_MARGIN_STYLE = { + l: 50, + r: 30, + b: 50, + t: 20, + pad: 0, +}; + +const X_AXIS_STYLE: Partial<Plotly.LayoutAxis> = { + type: 'date' as const, + automargin: true, + showgrid: false, + linecolor: MOON_300, + tickfont: {color: MOON_500}, + showspikes: false, +}; + +const X_AXIS_STYLE_WITH_SPIKES: Partial<Plotly.LayoutAxis> = { + ...X_AXIS_STYLE, + showspikes: true, + spikemode: 'across', + spikethickness: 1, + spikecolor: MOON_300, +}; + +const Y_AXIS_STYLE: Partial<Plotly.LayoutAxis> = { + automargin: true, + griddash: 'dot', + showgrid: true, + gridcolor: MOON_300, + linecolor: MOON_300, + showspikes: false, + tickfont: {color: MOON_500}, + zeroline: false, +}; + +export const calculateBinSize = ( + data: ChartDataLatency[] | ChartDataErrors[] | ChartDataRequests[], + targetBinCount = 15 +) => { + if (data.length === 0) { + return 60; + } // default to 60 minutes if no data + + const startTime = moment(_.minBy(data, 'started_at')?.started_at); + const endTime = moment(_.maxBy(data, 'started_at')?.started_at); + + const minutesInRange = endTime.diff(startTime, 'minutes'); + + // Calculate bin size in minutes, rounded to a nice number + const rawBinSize = Math.max(1, Math.ceil(minutesInRange / targetBinCount)); + const niceNumbers = [1, 2, 5, 10, 15, 30, 60, 120, 240, 360, 720, 1440]; + + // Find the closest nice number + return niceNumbers.reduce((prev, curr) => { + return Math.abs(curr - rawBinSize) < Math.abs(prev - rawBinSize) + ? curr + : prev; + }, niceNumbers[0]); +}; + +export const LatencyPlotlyChart: React.FC<{ + height: number; + chartData: ChartDataLatency[]; + targetBinCount?: number; +}> = ({height, chartData, targetBinCount}) => { + const divRef = useRef<HTMLDivElement>(null); + const binSize = calculateBinSize(chartData, targetBinCount); + + const plotlyData: Plotly.Data[] = useMemo(() => { + const groupedData = _(chartData) + .groupBy(d => { + const date = moment(d.started_at); + const roundedMinutes = Math.floor(date.minutes() / binSize) * binSize; + return date.startOf('hour').add(roundedMinutes, 'minutes').format(); + }) + .map((group, date) => { + const latenciesNonSorted = group.map(d => d.latency); + const p50 = quantile(latenciesNonSorted, 0.5) ?? 0; + const p95 = quantile(latenciesNonSorted, 0.95) ?? 0; + const p99 = quantile(latenciesNonSorted, 0.99) ?? 
0; + return {timestamp: date, p50, p95, p99}; + }) + .value(); + + return [ + { + type: 'scatter', + mode: 'lines+markers', + x: groupedData.map(d => d.timestamp), + y: groupedData.map(d => d.p50), + name: 'p50 Latency', + line: {color: BLUE_500}, + marker: {color: BLUE_500}, + hovertemplate: '%{data.name}: %{y:.2f} ms<extra></extra>', + }, + { + type: 'scatter', + mode: 'lines+markers', + x: groupedData.map(d => d.timestamp), + y: groupedData.map(d => d.p95), + name: 'p95 Latency', + line: {color: GREEN_500}, + marker: {color: GREEN_500}, + hovertemplate: '%{data.name}: %{y:.2f} ms<extra></extra>', + }, + { + type: 'scatter', + mode: 'lines+markers', + x: groupedData.map(d => d.timestamp), + y: groupedData.map(d => d.p99), + name: 'p99 Latency', + line: {color: MOON_500}, + marker: {color: MOON_500}, + hovertemplate: '%{data.name}: %{y:.2f} ms<extra></extra>', + }, + ]; + }, [chartData, binSize]); + + useEffect(() => { + const plotlyLayout: Partial<Plotly.Layout> = { + height, + margin: CHART_MARGIN_STYLE, + xaxis: X_AXIS_STYLE_WITH_SPIKES, + yaxis: Y_AXIS_STYLE, + hovermode: 'x unified', + showlegend: false, + hoverlabel: { + bordercolor: MOON_200, + }, + }; + + const plotlyConfig: Partial<Plotly.Config> = { + displayModeBar: false, + responsive: true, + }; + + if (divRef.current) { + Plotly.newPlot(divRef.current, plotlyData, plotlyLayout, plotlyConfig); + } + }, [plotlyData, height]); + + return <div ref={divRef}></div>; +}; + +export const ErrorPlotlyChart: React.FC<{ + height: number; + chartData: ChartDataErrors[]; + targetBinCount?: number; +}> = ({height, chartData, targetBinCount}) => { + const divRef = useRef<HTMLDivElement>(null); + const binSize = calculateBinSize(chartData, targetBinCount); + + const plotlyData: Plotly.Data[] = useMemo(() => { + const groupedData = _(chartData) + .groupBy(d => { + const date = moment(d.started_at); + const roundedMinutes = Math.floor(date.minutes() / binSize) * binSize; + return date.startOf('hour').add(roundedMinutes, 'minutes').format(); + }) + .map((group, date) => ({ + timestamp: date, + count: group.filter(d => d.isError).length, + })) + .value(); + + return [ + { + type: 'bar', + x: groupedData.map(d => d.timestamp), + y: groupedData.map(d => d.count), + name: 'Error Count', + marker: {color: RED_400}, + hovertemplate: '%{y} errors<extra></extra>', + }, + ]; + }, [chartData, binSize]); + + useEffect(() => { + const plotlyLayout: Partial<Plotly.Layout> = { + height, + margin: CHART_MARGIN_STYLE, + bargap: 0.2, + xaxis: X_AXIS_STYLE, + yaxis: Y_AXIS_STYLE, + hovermode: 'x unified', + hoverlabel: { + bordercolor: MOON_200, + }, + dragmode: 'zoom', + }; + + const plotlyConfig: Partial<Plotly.Config> = { + displayModeBar: false, + responsive: true, + }; + + if (divRef.current) { + Plotly.newPlot(divRef.current, plotlyData, plotlyLayout, plotlyConfig); + } + }, [plotlyData, height]); + + return <div ref={divRef}></div>; +}; + +export const RequestsPlotlyChart: React.FC<{ + height: number; + chartData: ChartDataRequests[]; + targetBinCount?: number; +}> = ({height, chartData, targetBinCount}) => { + const divRef = useRef<HTMLDivElement>(null); + const binSize = calculateBinSize(chartData, targetBinCount); + + const plotlyData: Plotly.Data[] = useMemo(() => { + const groupedData = _(chartData) + .groupBy(d => { + const date = moment(d.started_at); + const roundedMinutes = Math.floor(date.minutes() / binSize) * binSize; + return date.startOf('hour').add(roundedMinutes, 'minutes').format(); + }) + .map((group, date) => ({ + timestamp: date, + 
count: group.length, + })) + .value(); + + return [ + { + type: 'bar', + x: groupedData.map(d => d.timestamp), + y: groupedData.map(d => d.count), + name: 'Requests', + marker: {color: TEAL_400}, + hovertemplate: '%{y} requests<extra></extra>', + }, + ]; + }, [chartData, binSize]); + + useEffect(() => { + const plotlyLayout: Partial<Plotly.Layout> = { + height, + margin: CHART_MARGIN_STYLE, + xaxis: X_AXIS_STYLE, + yaxis: Y_AXIS_STYLE, + bargap: 0.2, + hovermode: 'x unified', + hoverlabel: { + bordercolor: MOON_200, + }, + }; + + const plotlyConfig: Partial<Plotly.Config> = { + displayModeBar: false, + responsive: true, + }; + + if (divRef.current) { + Plotly.newPlot(divRef.current, plotlyData, plotlyLayout, plotlyConfig); + } + }, [plotlyData, height]); + + return <div ref={divRef}></div>; +}; diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallsPage/callsTableQuery.ts b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallsPage/callsTableQuery.ts index 2a0d1bad489..de221b652dc 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallsPage/callsTableQuery.ts +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallsPage/callsTableQuery.ts @@ -32,9 +32,9 @@ export const useCallsForQuery = ( project: string, filter: WFHighLevelCallFilter, gridFilter: GridFilterModel, - gridSort: GridSortModel, gridPage: GridPaginationModel, - expandedColumns: Set<string>, + gridSort?: GridSortModel, + expandedColumns?: Set<string>, columns?: string[] ): { costsLoading: boolean; @@ -44,8 +44,8 @@ export const useCallsForQuery = ( refetch: () => void; } => { const {useCalls, useCallsStats} = useWFHooks(); - const offset = gridPage.page * gridPage.pageSize; - const limit = gridPage.pageSize; + const effectiveOffset = gridPage?.page * gridPage?.pageSize; + const effectiveLimit = gridPage.pageSize; const {sortBy, lowLevelFilter, filterBy} = useFilterSortby( filter, gridFilter, @@ -56,8 +56,8 @@ export const useCallsForQuery = ( entity, project, lowLevelFilter, - limit, - offset, + effectiveLimit, + effectiveOffset, sortBy, filterBy, columns, @@ -77,11 +77,16 @@ export const useCallsForQuery = ( const total = useMemo(() => { if (callsStats.loading || callsStats.result == null) { - return offset + callResults.length; + return effectiveOffset + callResults.length; } else { return callsStats.result.count; } - }, [callResults.length, callsStats.loading, callsStats.result, offset]); + }, [ + callResults.length, + callsStats.loading, + callsStats.result, + effectiveOffset, + ]); const costFilter: CallFilter = useMemo( () => ({ @@ -94,7 +99,7 @@ export const useCallsForQuery = ( entity, project, costFilter, - limit, + effectiveLimit, undefined, sortBy, undefined, diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/sections/ComparisonDefinitionSection/ComparisonDefinitionSection.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/sections/ComparisonDefinitionSection/ComparisonDefinitionSection.tsx index 3d461681a3c..b5c1a4bf96c 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/sections/ComparisonDefinitionSection/ComparisonDefinitionSection.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/sections/ComparisonDefinitionSection/ComparisonDefinitionSection.tsx @@ -111,8 +111,8 @@ const AddEvaluationButton: React.FC<{ props.state.data.project, evaluationsFilter, 
DEFAULT_FILTER_CALLS, - DEFAULT_SORT_CALLS, page, + DEFAULT_SORT_CALLS, expandedRefCols, columns ); diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ObjectVersionPage.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ObjectVersionPage.tsx index 7e1663c70dc..045ceb54900 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ObjectVersionPage.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ObjectVersionPage.tsx @@ -27,9 +27,11 @@ import { SimplePageLayoutWithHeader, } from './common/SimplePageLayout'; import {EvaluationLeaderboardTab} from './LeaderboardTab'; +import {TabPrompt} from './TabPrompt'; import {TabUseDataset} from './TabUseDataset'; import {TabUseModel} from './TabUseModel'; import {TabUseObject} from './TabUseObject'; +import {TabUsePrompt} from './TabUsePrompt'; import {KNOWN_BASE_OBJECT_CLASSES} from './wfReactInterface/constants'; import {useWFHooks} from './wfReactInterface/context'; import { @@ -127,6 +129,8 @@ const ObjectVersionPageInner: React.FC<{ }, [objectVersion.baseObjectClass]); const refUri = objectVersionKeyToRefUri(objectVersion); + const showPromptTab = objectVersion.val._class_name === 'EasyPrompt'; + const minimalColumns = useMemo(() => { return ['id', 'op_name', 'project_id']; }, []); @@ -287,6 +291,26 @@ const ObjectVersionPageInner: React.FC<{ // }, // ]} tabs={[ + ...(showPromptTab + ? [ + { + label: 'Prompt', + content: ( + <ScrollableTabContent> + {data.loading ? ( + <CenteredAnimatedLoader /> + ) : ( + <TabPrompt + entity={entityName} + project={projectName} + data={viewerDataAsObject} + /> + )} + </ScrollableTabContent> + ), + }, + ] + : []), ...(isEvaluation && evalHasCalls ? [ { @@ -333,23 +357,33 @@ const ObjectVersionPageInner: React.FC<{ { label: 'Use', content: ( - <Tailwind> - {baseObjectClass === 'Dataset' ? ( - <TabUseDataset - name={objectName} - uri={refUri} - versionIndex={objectVersionIndex} - /> - ) : baseObjectClass === 'Model' ? ( - <TabUseModel - name={objectName} - uri={refUri} - projectName={projectName} - /> - ) : ( - <TabUseObject name={objectName} uri={refUri} /> - )} - </Tailwind> + <ScrollableTabContent> + <Tailwind> + {baseObjectClass === 'Prompt' ? ( + <TabUsePrompt + name={objectName} + uri={refUri} + entityName={entityName} + projectName={projectName} + data={viewerDataAsObject} + /> + ) : baseObjectClass === 'Dataset' ? ( + <TabUseDataset + name={objectName} + uri={refUri} + versionIndex={objectVersionIndex} + /> + ) : baseObjectClass === 'Model' ? 
( + <TabUseModel + name={objectName} + uri={refUri} + projectName={projectName} + /> + ) : ( + <TabUseObject name={objectName} uri={refUri} /> + )} + </Tailwind> + </ScrollableTabContent> ), }, diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/OpVersionPage.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/OpVersionPage.tsx index 5e06b4a0474..1a6e4afc577 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/OpVersionPage.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/OpVersionPage.tsx @@ -12,6 +12,7 @@ import { } from './common/Links'; import {CenteredAnimatedLoader} from './common/Loader'; import { + ScrollableTabContent, SimpleKeyValueTable, SimplePageLayoutWithHeader, } from './common/SimplePageLayout'; @@ -136,9 +137,11 @@ const OpVersionPageInner: React.FC<{ { label: 'Use', content: ( - <Tailwind> - <TabUseOp name={opNiceName(opId)} uri={uri} /> - </Tailwind> + <ScrollableTabContent> + <Tailwind> + <TabUseOp name={opNiceName(opId)} uri={uri} /> + </Tailwind> + </ScrollableTabContent> ), }, ] diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/TabPrompt.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/TabPrompt.tsx new file mode 100644 index 00000000000..2f2819c3b34 --- /dev/null +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/TabPrompt.tsx @@ -0,0 +1,25 @@ +import classNames from 'classnames'; +import React from 'react'; + +import {Tailwind} from '../../../../Tailwind'; +import {MessageList} from './ChatView/MessageList'; + +type Data = Record<string, any>; + +type TabPromptProps = { + entity: string; + project: string; + data: Data; +}; + +export const TabPrompt = ({entity, project, data}: TabPromptProps) => { + return ( + <Tailwind> + <div className="flex flex-col sm:flex-row"> + <div className={classNames('mt-4 w-full')}> + <MessageList messages={data.data} /> + </div> + </div> + </Tailwind> + ); +}; diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/TabUseCall.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/TabUseCall.tsx index 817d647d970..3f33be98e7c 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/TabUseCall.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/TabUseCall.tsx @@ -30,7 +30,7 @@ os.environ["WF_TRACE_SERVER_URL"] = "http://127.0.0.1:6345" const codeFeedback = `call.feedback.add("correctness", {"value": 4})`; return ( - <Box m={2} className="text-sm"> + <Box className="text-sm"> <TabUseBanner> See{' '} <DocLink path="guides/tracking/tracing" text="Weave docs on tracing" />{' '} diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/TabUseDataset.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/TabUseDataset.tsx index 8b56a17604d..861eb15f443 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/TabUseDataset.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/TabUseDataset.tsx @@ -43,7 +43,7 @@ ${pythonName} = weave.ref('${ref.artifactName}:v${versionIndex}').get()`; } return ( - <Box m={2} className="text-sm"> + <Box className="text-sm"> <TabUseBanner> See{' '} <DocLink diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/TabUseModel.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/TabUseModel.tsx index da2560ad94b..234e9ffe6f1 100644 --- 
a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/TabUseModel.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/TabUseModel.tsx @@ -21,7 +21,7 @@ export const TabUseModel = ({name, uri, projectName}: TabUseModelProps) => { const label = isParentObject ? 'model version' : 'object'; return ( - <Box m={2} className="text-sm"> + <Box className="text-sm"> <TabUseBanner> See{' '} <DocLink path="guides/tracking/models" text="Weave docs on models" />{' '} diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/TabUseObject.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/TabUseObject.tsx index 4ea8dc6af30..e8178521316 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/TabUseObject.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/TabUseObject.tsx @@ -15,7 +15,7 @@ type TabUseObjectProps = { export const TabUseObject = ({name, uri}: TabUseObjectProps) => { const pythonName = isValidVarName(name) ? name : 'obj'; return ( - <Box m={2} className="text-sm"> + <Box className="text-sm"> <TabUseBanner> See{' '} <DocLink diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/TabUseOp.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/TabUseOp.tsx index 6cc0dd7c848..2370fac74cf 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/TabUseOp.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/TabUseOp.tsx @@ -16,7 +16,7 @@ export const TabUseOp = ({name, uri}: TabUseOpProps) => { const pythonName = isValidVarName(name) ? name : 'op'; return ( - <Box m={2} className="text-sm"> + <Box className="text-sm"> <TabUseBanner> See <DocLink path="guides/tracking/ops" text="Weave docs on ops" /> for more information. diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/TabUsePrompt.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/TabUsePrompt.tsx new file mode 100644 index 00000000000..6d00af48bc6 --- /dev/null +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/TabUsePrompt.tsx @@ -0,0 +1,99 @@ +import {Box} from '@mui/material'; +import React from 'react'; + +import {isValidVarName} from '../../../../../core/util/var'; +import {parseRef} from '../../../../../react'; +import {abbreviateRef} from '../../../../../util/refs'; +import {Alert} from '../../../../Alert'; +import {CopyableText} from '../../../../CopyableText'; +import {DocLink} from './common/Links'; + +type Data = Record<string, any>; + +type TabUsePromptProps = { + name: string; + uri: string; + entityName: string; + projectName: string; + data: Data; +}; + +export const TabUsePrompt = ({ + name, + uri, + entityName, + projectName, + data, +}: TabUsePromptProps) => { + const pythonName = isValidVarName(name) ? name : 'prompt'; + const ref = parseRef(uri); + const isParentObject = !ref.artifactRefExtra; + const label = isParentObject ? 'prompt version' : 'prompt'; + + // TODO: Simplify if no params. 
+ const longExample = `import weave +from openai import OpenAI + +weave.init("${projectName}") + +${pythonName} = weave.ref("${uri}").get() + +class MyModel(weave.Model): + model_name: str + prompt: weave.Prompt + + @weave.op + def predict(self, params: dict) -> dict: + client = OpenAI() + response = client.chat.completions.create( + model=self.model_name, + messages=self.prompt.bind(params), + ) + result = response.choices[0].message.content + if result is None: + raise ValueError("No response from model") + return result + +mymodel = MyModel(model_name="gpt-3.5-turbo", prompt=${pythonName}) + +# Replace with desired parameter values +params = ${JSON.stringify({}, null, 2)} +print(mymodel.predict(params)) +`; + + return ( + <Box className="text-sm"> + <Alert icon="lightbulb-info"> + See{' '} + <DocLink + path="guides/tracking/objects#getting-an-object-back" + text="Weave docs on refs" + />{' '} + and <DocLink path="guides/core-types/prompts" text="prompts" /> for more + information. + </Alert> + + <Box mt={2}> + The ref for this {label} is: + <CopyableText text={uri} /> + </Box> + <Box mt={2}> + Use the following code to retrieve this {label}: + <CopyableText + text={`${pythonName} = weave.ref("${abbreviateRef(uri)}").get()`} + copyText={`${pythonName} = weave.ref("${uri}").get()`} + tooltipText="Click to copy unabridged string" + /> + </Box> + + <Box mt={2}>A more complete example:</Box> + <Box mt={2}> + <CopyableText + text={longExample} + copyText={longExample} + tooltipText="Click to copy unabridged string" + /> + </Box> + </Box> + ); +}; diff --git a/weave-js/yarn.lock b/weave-js/yarn.lock index 3315ac14ade..2ee20553257 100644 --- a/weave-js/yarn.lock +++ b/weave-js/yarn.lock @@ -4215,6 +4215,11 @@ resolved "https://registry.yarnpkg.com/@types/cytoscape/-/cytoscape-3.19.10.tgz#f4540749d68cd3db6f89da5197f7ec2a2ca516ee" integrity sha512-PLsKQcsUd05nz4PYyulIhjkLnlq9oD2WYpswrWOjoqtFZEuuBje0f9fi2zTG5/yfTf5+Gpllf/MPcFmfDzZ24w== +"@types/d3-array@^3.2.1": + version "3.2.1" + resolved "https://registry.yarnpkg.com/@types/d3-array/-/d3-array-3.2.1.tgz#1f6658e3d2006c4fceac53fde464166859f8b8c5" + integrity sha512-Y2Jn2idRrLzUfAKV2LyRImR+y4oa2AntrgID95SHJxuMUrkNXmanDSed71sRNZysveJVt1hLLemQZIady0FpEg== + "@types/debug@^4.0.0": version "4.1.8" resolved "https://registry.yarnpkg.com/@types/debug/-/debug-4.1.8.tgz#cef723a5d0a90990313faec2d1e22aee5eecb317" diff --git a/weave/__init__.py b/weave/__init__.py index 3b54ba97176..781d1e89d89 100644 --- a/weave/__init__.py +++ b/weave/__init__.py @@ -12,9 +12,15 @@ from weave.flow.eval import Evaluation, Scorer from weave.flow.model import Model from weave.flow.obj import Object +from weave.flow.prompt.prompt import EasyPrompt, Prompt +from weave.flow.prompt.prompt import MessagesPrompt as MessagesPrompt +from weave.flow.prompt.prompt import StringPrompt as StringPrompt from weave.trace.util import Thread as Thread from weave.trace.util import ThreadPoolExecutor as ThreadPoolExecutor +# Alias for succinct code +P = EasyPrompt + # Special object informing doc generation tooling which symbols # to document & to associate with this module. 
 __docspec__ = [
@@ -31,6 +37,7 @@
     Object,
     Dataset,
     Model,
+    Prompt,
     Evaluation,
     Scorer,
 ]
diff --git a/weave/flow/prompt/common.py b/weave/flow/prompt/common.py
new file mode 100644
index 00000000000..80bc63ae60f
--- /dev/null
+++ b/weave/flow/prompt/common.py
@@ -0,0 +1,14 @@
+# TODO: Maybe use an enum or something to lock down types more
+
+ROLE_COLORS: dict[str, str] = {
+    "system": "bold blue",
+    "user": "bold green",
+    "assistant": "bold magenta",
+}
+
+
+def color_role(role: str) -> str:
+    color = ROLE_COLORS.get(role)
+    if color:
+        return f"[{color}]{role}[/]"
+    return role
diff --git a/weave/flow/prompt/prompt.py b/weave/flow/prompt/prompt.py
new file mode 100644
index 00000000000..016e9d3f996
--- /dev/null
+++ b/weave/flow/prompt/prompt.py
@@ -0,0 +1,440 @@
+import copy
+import json
+import os
+import re
+import textwrap
+from collections import UserList
+from pathlib import Path
+from typing import IO, Any, Optional, SupportsIndex, TypedDict, Union, overload
+
+from pydantic import Field
+from rich.table import Table
+
+from weave.flow.obj import Object
+from weave.flow.prompt.common import ROLE_COLORS, color_role
+from weave.trace.api import publish as weave_publish
+from weave.trace.op import op
+from weave.trace.refs import ObjectRef
+from weave.trace.rich import pydantic_util
+
+
+class Message(TypedDict):
+    role: str
+    content: str
+
+
+def maybe_dedent(content: str, dedent: bool) -> str:
+    if dedent:
+        return textwrap.dedent(content).strip()
+    return content
+
+
+def str_to_message(
+    content: str, role: Optional[str] = None, dedent: bool = False
+) -> Message:
+    if role is not None:
+        return {"role": role, "content": maybe_dedent(content, dedent)}
+    for known_role in ROLE_COLORS:
+        prefix = known_role + ":"
+        if content.startswith(prefix):
+            return {
+                "role": known_role,
+                "content": maybe_dedent(content[len(prefix) :].lstrip(), dedent),
+            }
+    return {"role": "user", "content": maybe_dedent(content, dedent)}
+
+
+# TODO: This supports Python format specifiers, but maybe we don't want to
+# because it will be harder to do in clients in other languages?
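+#
+# Illustrative sketch (not part of the module): with the regex below,
+#   extract_placeholders("Hi {name}, you are {age:d} years old")
+# returns ["name", "age"] -- names are deduplicated and format specs ignored.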
+RE_PLACEHOLDER = re.compile(r"\{(\w+)(:[^}]+)?\}")
+
+
+def extract_placeholders(text: str) -> list[str]:
+    placeholders = re.findall(RE_PLACEHOLDER, text)
+    unique = []
+    for name, _ in placeholders:
+        if name not in unique:
+            unique.append(name)
+    return unique
+
+
+def color_content(content: str, values: dict) -> str:
+    placeholders = extract_placeholders(content)
+    colored_values = {}
+    for placeholder in placeholders:
+        if placeholder not in values:
+            colored_values[placeholder] = "[red]{" + placeholder + "}[/]"
+        else:
+            colored_values[placeholder] = (
+                "[orange3]{" + placeholder + ":" + str(values[placeholder]) + "}[/]"
+            )
+    return content.format(**colored_values)
+
+
+class Prompt(Object):
+    def format(self, **kwargs: Any) -> Any:
+        raise NotImplementedError
+
+
+class MessagesPrompt(Prompt):
+    def format(self, **kwargs: Any) -> list:
+        raise NotImplementedError
+
+
+class StringPrompt(Prompt):
+    def format(self, **kwargs: Any) -> str:
+        raise NotImplementedError
+
+
+class EasyPrompt(UserList, Prompt):
+    data: list = Field(default_factory=list)
+    config: dict = Field(default_factory=dict)
+    requirements: dict = Field(default_factory=dict)
+
+    _values: dict
+
+    def __init__(
+        self,
+        content: Optional[Union[str, dict, list]] = None,
+        *,
+        role: Optional[str] = None,
+        dedent: bool = False,
+        **kwargs: Any,
+    ) -> None:
+        super(UserList, self).__init__()
+        name = kwargs.pop("name", None)
+        description = kwargs.pop("description", None)
+        config = kwargs.pop("config", {})
+        requirements = kwargs.pop("requirements", {})
+        if "messages" in kwargs:
+            content = kwargs.pop("messages")
+        config.update(kwargs)
+        kwargs = {"config": config, "requirements": requirements}
+        super(Object, self).__init__(name=name, description=description, **kwargs)
+        self._values = {}
+        if content is not None:
+            if isinstance(content, (str, dict)):
+                content = [content]
+            for item in content:
+                self.append(item, role=role, dedent=dedent)
+
+    def __add__(self, other: Any) -> "Prompt":
+        new_prompt = self.copy()
+        new_prompt += other
+        return new_prompt
+
+    def append(
+        self,
+        item: Any,
+        role: Optional[str] = None,
+        dedent: bool = False,
+    ) -> None:
+        if isinstance(item, str):
+            # Seems like we don't want to do this, if the user wants
+            # all system we have helpers for that, and we want to make the
+            # case of constructing system + user easy
+            # role = self.data[-1].get("role", "user") if self.data else "user"
+            self.data.append(str_to_message(item, role=role, dedent=dedent))
+        elif isinstance(item, dict):
+            # TODO: Validate that item has message shape
+            # TODO: Override role and do dedent?
+            self.data.append(item)
+        elif isinstance(item, list):
+            for sub_item in item:
+                self.append(sub_item)
+        else:
+            raise ValueError(f"Cannot append {item} of type {type(item)} to Prompt")
+
+    def __iadd__(self, item: Any) -> "Prompt":
+        self.append(item)
+        return self
+
+    @property
+    def as_str(self) -> str:
+        """Join all messages into a single string."""
+        return " ".join(message.get("content", "") for message in self.data)
+
+    @property
+    def system_message(self) -> Message:
+        """Join all messages into a system prompt message."""
+        return {"role": "system", "content": self.as_str}
+
+    @property
+    def system_prompt(self) -> "Prompt":
+        """Join all messages into a system prompt object."""
+        return EasyPrompt(self.as_str, role="system")
+
+    @property
+    def messages(self) -> list[Message]:
+        return self.data
+
+    @property
+    def placeholders(self) -> list[str]:
+        all_placeholders: list[str] = []
+        for message in self.data:
+            # TODO: Support placeholders in image messages?
+            placeholders = extract_placeholders(message["content"])
+            all_placeholders.extend(
+                p for p in placeholders if p not in all_placeholders
+            )
+        return all_placeholders
+
+    @property
+    def unbound_placeholders(self) -> list[str]:
+        unbound = []
+        for p in self.placeholders:
+            if p not in self._values:
+                unbound.append(p)
+        return unbound
+
+    @property
+    def is_bound(self) -> bool:
+        return not self.unbound_placeholders
+
+    def validate_requirement(self, key: str, value: Any) -> list:
+        problems = []
+        requirement = self.requirements.get(key)
+        if not requirement:
+            return []
+        # TODO: Type coercion
+        min_value = requirement.get("min")
+        if min_value is not None and value < min_value:
+            problems.append(f"{key} ({value}) is less than min ({min_value})")
+        max_value = requirement.get("max")
+        if max_value is not None and value > max_value:
+            problems.append(f"{key} ({value}) is greater than max ({max_value})")
+        oneof = requirement.get("oneof")
+        if oneof is not None and value not in oneof:
+            problems.append(f"{key} ({value}) must be one of {', '.join(oneof)}")
+        return problems
+
+    def validate_requirements(self, values: dict[str, Any]) -> list:
+        problems = []
+        for key, value in values.items():
+            problems += self.validate_requirement(key, value)
+        return problems
+
+    def bind(self, *args: Any, **kwargs: Any) -> "Prompt":
+        is_dict = len(args) == 1 and isinstance(args[0], dict)
+        problems = []
+        if is_dict:
+            problems += self.validate_requirements(args[0])
+        problems += self.validate_requirements(kwargs)
+        if problems:
+            raise ValueError("\n".join(problems))
+        if is_dict:
+            self._values.update(args[0])
+        self._values.update(kwargs)
+        return self
+
+    def __call__(self, *args: Any, **kwargs: Any) -> list[Message]:
+        if len(args) == 1 and len(kwargs) == 0 and isinstance(args[0], dict):
+            kwargs = args[0]
+        prompt = self.bind(kwargs)
+        return list(prompt)
+
+    # TODO: Any should be Dataset but there is a circular dependency issue
+    def bind_rows(self, dataset: Union[list[dict], Any]) -> list["Prompt"]:
+        rows = dataset if isinstance(dataset, list) else dataset.rows
+        bound: list["Prompt"] = []
+        for row in rows:
+            bound.append(self.copy().bind(row))
+        return bound
+
+    @overload
+    def __getitem__(self, index: SupportsIndex) -> Any: ...
+
+    @overload
+    def __getitem__(self, key: slice) -> "EasyPrompt": ...
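+    # Sketch of the indexing behavior (illustrative, not part of the class):
+    #   p = EasyPrompt("Hi {name}")  # p[0] -> {"role": "user", "content": "Hi {name}"}
+    #   p.bind(name="Ada")           # p[0] -> {"role": "user", "content": "Hi Ada"}
+    #   p[0:1]                       # -> a new EasyPrompt with just the first message
+    #   p["model"]                   # -> a config value, looked up by string key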
+
+    def __getitem__(self, key: Union[SupportsIndex, slice, str]) -> Any:
+        """Override getitem to return a Message, Prompt object, or config value."""
+        if isinstance(key, SupportsIndex):
+            int_index = key.__index__()
+            message = self.data[int_index].copy()
+            placeholders = extract_placeholders(message["content"])
+            values = {}
+            for placeholder in placeholders:
+                if placeholder in self._values:
+                    values[placeholder] = self._values[placeholder]
+                elif (
+                    placeholder in self.requirements
+                    and "default" in self.requirements[placeholder]
+                ):
+                    values[placeholder] = self.requirements[placeholder]["default"]
+                else:
+                    values[placeholder] = "{" + placeholder + "}"
+            message["content"] = message["content"].format(**values)
+            return message
+        elif isinstance(key, slice):
+            new_prompt = EasyPrompt()
+            new_prompt.name = self.name
+            new_prompt.description = self.description
+            new_prompt.data = self.data[key]
+            new_prompt.config = self.config.copy()
+            new_prompt.requirements = self.requirements.copy()
+            new_prompt._values = self._values.copy()
+            return new_prompt
+        elif isinstance(key, str):
+            if key == "ref":
+                return self
+            if key == "messages":
+                return self.data
+            return self.config[key]
+        else:
+            raise TypeError(f"Invalid argument type: {type(key)}")
+
+    def __deepcopy__(self, memo: dict) -> "Prompt":
+        # I'm sure this isn't right, but hacking in to avoid
+        # TypeError: cannot pickle '_thread.lock' object.
+        # Basically, as part of logging our message objects are
+        # turning into WeaveDicts which have a server reference which
+        # in turn can't be copied
+        c = copy.deepcopy(dict(self.config), memo)
+        r = copy.deepcopy(dict(self.requirements), memo)
+        p = EasyPrompt(
+            name=self.name, description=self.description, config=c, requirements=r
+        )
+        p._values = dict(self._values)
+        for value in self.data:
+            p.data.append(dict(value))
+        return p
+
+    def require(self, param_name: str, **kwargs: Any) -> "Prompt":
+        self.requirements[param_name] = kwargs
+        return self
+
+    def configure(self, config: Optional[dict] = None, **kwargs: Any) -> "Prompt":
+        if config:
+            self.config = config
+        self.config.update(kwargs)
+        return self
+
+    def publish(self, name: Optional[str] = None) -> ObjectRef:
+        # TODO: This only works if we've called weave.init, but it seems like
+        # that shouldn't be necessary if we have loaded this from a ref.
+        return weave_publish(self, name=name)
+
+    def messages_table(self, title: Optional[str] = None) -> Table:
+        table = Table(title=title, title_justify="left", show_header=False)
+        table.add_column("Role", justify="right")
+        table.add_column("Content")
+        # TODO: Maybe we should inline the values here? Or highlight placeholders missing values in red?
+        for message in self.data:
+            table.add_row(
+                color_role(message.get("role", "user")),
+                color_content(message.get("content", ""), self._values),
+            )
+        return table
+
+    def values_table(self, title: Optional[str] = None) -> Table:
+        table = Table(title=title, title_justify="left", show_header=False)
+        table.add_column("Parameter", justify="right")
+        table.add_column("Value")
+        for key, value in self._values.items():
+            table.add_row(key, str(value))
+        return table
+
+    def config_table(self, title: Optional[str] = None) -> Table:
+        table = Table(title=title, title_justify="left", show_header=False)
+        table.add_column("Key", justify="right")
+        table.add_column("Value")
+        for key, value in self.config.items():
+            table.add_row(key, str(value))
+        return table
+
+    def print(self) -> str:
+        tables = []
+        if self.name or self.description:
+            table1 = Table(show_header=False)
+            table1.add_column("Key", justify="right", style="bold cyan")
+            table1.add_column("Value")
+            if self.name is not None:
+                table1.add_row("Name", self.name)
+            if self.description is not None:
+                table1.add_row("Description", self.description)
+            tables.append(table1)
+        if self.data:
+            tables.append(self.messages_table(title="Messages"))
+        if self._values:
+            tables.append(self.values_table(title="Parameters"))
+        if self.config:
+            tables.append(self.config_table(title="Config"))
+        tables = [pydantic_util.table_to_str(t) for t in tables]
+        return "\n".join(tables)
+
+    def __str__(self) -> str:
+        """Return a single prompt string when str() is called on the object."""
+        return self.as_str
+
+    def _repr_pretty_(self, p: Any, cycle: bool) -> None:
+        """Show a nicely formatted table in ipython."""
+        if cycle:
+            p.text("Prompt(...)")
+        else:
+            p.text(self.print())
+
+    def as_pydantic_dict(self) -> dict[str, Any]:
+        return self.model_dump()
+
+    def as_dict(self) -> dict[str, Any]:
+        # In chat completion kwargs format
+        return {
+            **self.config,
+            "messages": list(self),
+        }
+
+    @staticmethod
+    def from_obj(obj: Any) -> "EasyPrompt":
+        messages = obj.messages if hasattr(obj, "messages") else obj.data
+        messages = [dict(m) for m in messages]
+        config = dict(obj.config)
+        requirements = dict(obj.requirements)
+        return EasyPrompt(
+            name=obj.name,
+            description=obj.description,
+            messages=messages,
+            config=config,
+            requirements=requirements,
+        )
+
+    @staticmethod
+    def load(fp: IO) -> "EasyPrompt":
+        if isinstance(fp, str):  # Common mistake
+            raise ValueError(
+                "Prompt.load() takes a file-like object, not a string. Did you mean Prompt.load_file()?"
+            )
+        data = json.load(fp)
+        prompt = EasyPrompt(**data)
+        return prompt
+
+    @staticmethod
+    def load_file(filepath: Union[str, Path]) -> "Prompt":
+        expanded_path = os.path.expanduser(str(filepath))
+        with open(expanded_path, "r") as f:
+            return EasyPrompt.load(f)
+
+    def dump(self, fp: IO) -> None:
+        json.dump(self.as_pydantic_dict(), fp, indent=2)
+
+    def dump_file(self, filepath: Union[str, Path]) -> None:
+        expanded_path = os.path.expanduser(str(filepath))
+        with open(expanded_path, "w") as f:
+            self.dump(f)
+
+    # TODO: We would like to be able to make this an Op.
+    # Unfortunately, litellm tries to make a deepcopy of the messages
+    # and that fails because the Message objects aren't picklable.
+    # TypeError: cannot pickle '_thread.RLock' object
+    # (Which I think is because they keep a reference to the server interface maybe?)
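+    # Usage sketch (illustrative): assumes weave.init() has been called and the
+    # relevant provider API key (e.g. OPENAI_API_KEY) is available to litellm.
+    #   p = EasyPrompt("What's 23 * 42?", model="gpt-4o-mini")
+    #   result = p.run()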
+    @op
+    def run(self) -> Any:
+        # TODO: Nicer result type
+        import litellm
+
+        result = litellm.completion(
+            messages=list(self),
+            model=self.config.get("model", "gpt-4o-mini"),
+        )
+        # TODO: Print in a nicer format
+        return result
diff --git a/weave/integrations/openai/openai_sdk.py b/weave/integrations/openai/openai_sdk.py
index d32d1a80a70..558373ab44a 100644
--- a/weave/integrations/openai/openai_sdk.py
+++ b/weave/integrations/openai/openai_sdk.py
@@ -3,6 +3,7 @@
 from typing import TYPE_CHECKING, Any, Callable, Optional
 
 import weave
+from weave.trace.op import Op, ProcessedInputs
 from weave.trace.op_extensions.accumulator import add_accumulator
 from weave.trace.patcher import MultiPatcher, SymbolPatcher
 
@@ -277,6 +278,28 @@ def should_use_accumulator(inputs: dict) -> bool:
     )
 
 
+def openai_on_input_handler(
+    func: Op, args: tuple, kwargs: dict
+) -> Optional[ProcessedInputs]:
+    if len(args) == 2 and isinstance(args[1], weave.EasyPrompt):
+        original_args = args
+        original_kwargs = kwargs
+        prompt = args[1]
+        args = args[:-1]
+        kwargs.update(prompt.as_dict())
+        inputs = {
+            "prompt": prompt,
+        }
+        return ProcessedInputs(
+            original_args=original_args,
+            original_kwargs=original_kwargs,
+            args=args,
+            kwargs=kwargs,
+            inputs=inputs,
+        )
+    return None
+
+
 def create_wrapper_sync(
     name: str,
 ) -> Callable[[Callable], Callable]:
@@ -301,6 +324,7 @@ def _openai_stream_options_is_set(inputs: dict) -> bool:
 
         op = weave.op()(_add_stream_options(fn))
         op.name = name  # type: ignore
+        op._set_on_input_handler(openai_on_input_handler)
         return add_accumulator(
             op,  # type: ignore
             make_accumulator=lambda inputs: lambda acc, value: openai_accumulator(
@@ -338,6 +362,7 @@ def _openai_stream_options_is_set(inputs: dict) -> bool:
 
         op = weave.op()(_add_stream_options(fn))
         op.name = name  # type: ignore
+        op._set_on_input_handler(openai_on_input_handler)
         return add_accumulator(
             op,  # type: ignore
             make_accumulator=lambda inputs: lambda acc, value: openai_accumulator(
diff --git a/weave/trace/op.py b/weave/trace/op.py
index 7614b1d8630..ae85d65e7b8 100644
--- a/weave/trace/op.py
+++ b/weave/trace/op.py
@@ -5,6 +5,7 @@
 import sys
 import traceback
 import typing
+from dataclasses import dataclass
 from functools import partial, wraps
 from types import MethodType
 from typing import (
@@ -84,6 +85,21 @@ def print_call_link(call: "Call") -> None:
     print(f"{TRACE_CALL_EMOJI} {call.ui_url}")
 
 
+@dataclass
+class ProcessedInputs:
+    # What the user passed to the function
+    original_args: tuple
+    original_kwargs: dict[str, Any]
+
+    # What should get passed to the interior function
+    args: tuple
+    kwargs: dict[str, Any]
+
+    # What should get sent to the Weave server
+    inputs: dict[str, Any]
+
+
+OnInputHandlerType = Callable[["Op", tuple, dict], Optional[ProcessedInputs]]
 FinishCallbackType = Callable[[Any, Optional[BaseException]], None]
 OnOutputHandlerType = Callable[[Any, FinishCallbackType, Dict], Any]
 # Call, original function output, exception if occurred
@@ -155,6 +171,9 @@ class Op(Protocol):
     call: Callable[..., Any]
     calls: Callable[..., "CallsIter"]
 
+    _set_on_input_handler: Callable[[OnInputHandlerType], None]
+    _on_input_handler: Optional[OnInputHandlerType]
+
     # not sure if this is the best place for this, but kept for compat
     _set_on_output_handler: Callable[[OnOutputHandlerType], None]
     _on_output_handler: Optional[OnOutputHandlerType]
@@ -175,6 +194,12 @@ class Op(Protocol):
     _tracing_enabled: bool
 
 
+def _set_on_input_handler(func: Op, on_input: OnInputHandlerType) -> None:
+    if func._on_input_handler is not None:
+        raise ValueError("Cannot set on_input_handler multiple times")
+    func._on_input_handler = on_input
+
+
 def _set_on_output_handler(func: Op, on_output: OnOutputHandlerType) -> None:
     if func._on_output_handler is not None:
         raise ValueError("Cannot set on_output_handler multiple times")
@@ -203,16 +228,32 @@ def _is_unbound_method(func: Callable) -> bool:
     return bool(is_method)
 
 
-def _create_call(
-    func: Op, *args: Any, __weave: Optional[WeaveKwargs] = None, **kwargs: Any
-) -> "Call":
-    client = weave_client_context.require_weave_client()
-
+def default_on_input_handler(func: Op, args: tuple, kwargs: dict) -> ProcessedInputs:
     try:
         inputs = func.signature.bind(*args, **kwargs).arguments
     except TypeError as e:
         raise OpCallError(f"Error calling {func.name}: {e}")
     inputs_with_defaults = _apply_fn_defaults_to_inputs(func, inputs)
+    return ProcessedInputs(
+        original_args=args,
+        original_kwargs=kwargs,
+        args=args,
+        kwargs=kwargs,
+        inputs=inputs_with_defaults,
+    )
+
+
+def _create_call(
+    func: Op, *args: Any, __weave: Optional[WeaveKwargs] = None, **kwargs: Any
+) -> "Call":
+    client = weave_client_context.require_weave_client()
+
+    pargs = None
+    if func._on_input_handler is not None:
+        pargs = func._on_input_handler(func, args, kwargs)
+    if not pargs:
+        pargs = default_on_input_handler(func, args, kwargs)
+    inputs_with_defaults = pargs.inputs
 
     # This should probably be configurable, but for now we redact the api_key
     if "api_key" in inputs_with_defaults:
@@ -368,12 +409,19 @@ def _do_call(
 ) -> tuple[Any, "Call"]:
     func = op.resolve_fn
     call = _placeholder_call()
+
+    pargs = None
+    if op._on_input_handler is not None:
+        pargs = op._on_input_handler(op, args, kwargs)
+    if not pargs:
+        pargs = default_on_input_handler(op, args, kwargs)
+
     if settings.should_disable_weave():
-        res = func(*args, **kwargs)
+        res = func(*pargs.args, **pargs.kwargs)
     elif weave_client_context.get_weave_client() is None:
-        res = func(*args, **kwargs)
+        res = func(*pargs.args, **pargs.kwargs)
     elif not op._tracing_enabled:
-        res = func(*args, **kwargs)
+        res = func(*pargs.args, **pargs.kwargs)
     else:
         try:
             # This try/except allows us to fail gracefully and
@@ -388,10 +436,10 @@ def _do_call(
                 logger.error,
                 CALL_CREATE_MSG.format(traceback.format_exc()),
             )
-            res = func(*args, **kwargs)
+            res = func(*pargs.args, **pargs.kwargs)
         else:
             execute_result = _execute_call(
-                op, call, *args, __should_raise=__should_raise, **kwargs
+                op, call, *pargs.args, __should_raise=__should_raise, **pargs.kwargs
             )
             if inspect.iscoroutine(execute_result):
                 raise Exception(
@@ -600,6 +648,9 @@ def wrapper(*args: Any, **kwargs: Any) -> Any:
     wrapper.__call__ = wrapper  # type: ignore
     wrapper.__self__ = wrapper  # type: ignore
 
+    wrapper._set_on_input_handler = partial(_set_on_input_handler, wrapper)  # type: ignore
+    wrapper._on_input_handler = None  # type: ignore
+
     wrapper._set_on_output_handler = partial(_set_on_output_handler, wrapper)  # type: ignore
     wrapper._on_output_handler = None  # type: ignore
diff --git a/weave/trace/refs.py b/weave/trace/refs.py
index f29c79091a1..ef002997ea3 100644
--- a/weave/trace/refs.py
+++ b/weave/trace/refs.py
@@ -144,6 +144,19 @@ def uri(self) -> str:
             u += "/" + "/".join(refs_internal.extra_value_quoter(e) for e in self.extra)
         return u
 
+    def objectify(self, obj: Any) -> Any:
+        """Convert back to higher level object."""
+        class_name = getattr(obj, "_class_name", None)
+        if "EasyPrompt" == class_name:
+            from weave.flow.prompt.prompt import EasyPrompt
+
+            prompt = EasyPrompt.from_obj(obj)
+            # We want to use the ref on the object (and not self) as it will have had
+            # the version number or latest alias resolved to a specific digest.
+            prompt.__dict__["ref"] = obj.ref
+            return prompt
+        return obj
+
     def get(self) -> Any:
         # Move import here so that it only happens when the function is called.
         # This import is invalid in the trace server and represents a dependency
@@ -153,7 +166,7 @@ def get(self) -> Any:
 
         gc = get_weave_client()
         if gc is not None:
-            return gc.get(self)
+            return self.objectify(gc.get(self))
 
         # Special case: If the user is attempting to fetch an object but has not
         # yet initialized the client, we can initialize a client to
@@ -166,7 +179,7 @@ def get(self) -> Any:
             res = init_client.client.get(self)
         finally:
             init_client.reset()
-        return res
+        return self.objectify(res)
 
     def is_descended_from(self, potential_ancestor: "ObjectRef") -> bool:
         if self.entity != potential_ancestor.entity:
diff --git a/weave/trace_server/clickhouse_trace_server_batched.py b/weave/trace_server/clickhouse_trace_server_batched.py
index bc6f834f662..9a1f1a1303f 100644
--- a/weave/trace_server/clickhouse_trace_server_batched.py
+++ b/weave/trace_server/clickhouse_trace_server_batched.py
@@ -71,6 +71,7 @@
     SelectableCHCallSchema,
     SelectableCHObjSchema,
 )
+from weave.trace_server.constants import COMPLETIONS_CREATE_OP_NAME
 from weave.trace_server.emoji_util import detone_emojis
 from weave.trace_server.errors import InsertTooLarge, InvalidRequest, RequestTooLarge
 from weave.trace_server.feedback import (
@@ -92,7 +93,13 @@
     ActionScore,
     feedback_base_models,
 )
+from weave.trace_server.llm_completion import lite_llm_completion
+from weave.trace_server.model_providers.model_providers import (
+    MODEL_PROVIDERS_FILE,
+    fetch_model_to_provider_info_map,
+)
 from weave.trace_server.orm import ParamBuilder, Row
+from weave.trace_server.secret_fetcher_context import _secret_fetcher_context
 from weave.trace_server.table_query_builder import (
     ROW_ORDER_COLUMN_NAME,
     TABLE_ROWS_ALIAS,
@@ -200,6 +207,9 @@ def __init__(
         self._flush_immediately = True
         self._call_batch: list[list[Any]] = []
         self._use_async_insert = use_async_insert
+        self._model_to_provider_info_map = fetch_model_to_provider_info_map(
+            MODEL_PROVIDERS_FILE
+        )
 
     @classmethod
     def from_env(cls, use_async_insert: bool = False) -> "ClickHouseTraceServer":
@@ -1513,6 +1523,64 @@ def execute_batch_action(
 
         return tsi.ExecuteBatchActionRes()
 
+    def completions_create(
+        self, req: tsi.CompletionsCreateReq
+    ) -> tsi.CompletionsCreateRes:
+        model_name = req.inputs.model
+        model_info = self._model_to_provider_info_map.get(model_name)
+        if not model_info:
+            raise InvalidRequest(f"No model info found for model {model_name}")
+        secret_fetcher = _secret_fetcher_context.get()
+        if not secret_fetcher:
+            raise InvalidRequest(
+                f"No secret fetcher found, cannot fetch API key for model {model_name}"
+            )
+        secret_name = model_info.get("api_key_name")
+        if not secret_name:
+            raise InvalidRequest(f"No secret name found for model {model_name}")
+        api_key = secret_fetcher.fetch(secret_name).get("secrets", {}).get(secret_name)
+        if not api_key:
+            raise InvalidRequest(f"No API key found for model {model_name}")
+
+        start_time = datetime.datetime.now()
+        res = lite_llm_completion(api_key, req.inputs)
+        end_time = datetime.datetime.now()
+
+        start = tsi.StartedCallSchemaForInsert(
+            project_id=req.project_id,
+            wb_user_id=req.wb_user_id,
+            op_name=COMPLETIONS_CREATE_OP_NAME,
+            started_at=start_time,
+            inputs={**req.inputs.model_dump(exclude_none=True)},
+            attributes={},
+        )
+        start_call = _start_call_for_insert_to_ch_insertable_start_call(start)
+        end = tsi.EndedCallSchemaForInsert(
+            project_id=req.project_id,
+            id=start_call.id,
+            ended_at=end_time,
+            output=res.response,
+            summary={},
+        )
+        if "usage" in res.response:
+            end.summary["usage"] = {req.inputs.model: res.response["usage"]}
+
+        if "error" in res.response:
+            end.exception = res.response["error"]
+        end_call = _end_call_for_insert_to_ch_insertable_end_call(end)
+        calls: list[Union[CallStartCHInsertable, CallEndCHInsertable]] = [
+            start_call,
+            end_call,
+        ]
+        batch_data = []
+        for call in calls:
+            call_dict = call.model_dump()
+            values = [call_dict.get(col) for col in all_call_insert_columns]
+            batch_data.append(values)
+
+        self._insert_call_batch(batch_data)
+        return res
+
     # Private Methods
 
     @property
     def ch_client(self) -> CHClient:
diff --git a/weave/trace_server/constants.py b/weave/trace_server/constants.py
new file mode 100644
index 00000000000..2e6f117d816
--- /dev/null
+++ b/weave/trace_server/constants.py
@@ -0,0 +1 @@
+COMPLETIONS_CREATE_OP_NAME = "weave.completions_create"
diff --git a/weave/trace_server/external_to_internal_trace_server_adapter.py b/weave/trace_server/external_to_internal_trace_server_adapter.py
index 1acdb0c55c4..ffc229a9270 100644
--- a/weave/trace_server/external_to_internal_trace_server_adapter.py
+++ b/weave/trace_server/external_to_internal_trace_server_adapter.py
@@ -351,3 +351,10 @@ def execute_batch_action(
     ) -> tsi.ExecuteBatchActionRes:
         req.project_id = self._idc.ext_to_int_project_id(req.project_id)
         return self._ref_apply(self._internal_trace_server.execute_batch_action, req)
+
+    def completions_create(
+        self, req: tsi.CompletionsCreateReq
+    ) -> tsi.CompletionsCreateRes:
+        req.project_id = self._idc.ext_to_int_project_id(req.project_id)
+        res = self._ref_apply(self._internal_trace_server.completions_create, req)
+        return res
diff --git a/weave/trace_server/llm_completion.py b/weave/trace_server/llm_completion.py
new file mode 100644
index 00000000000..907e8fe413b
--- /dev/null
+++ b/weave/trace_server/llm_completion.py
@@ -0,0 +1,13 @@
+from weave.trace_server import trace_server_interface as tsi
+
+
+def lite_llm_completion(
+    api_key: str, inputs: tsi.CompletionsCreateRequestInputs
+) -> tsi.CompletionsCreateRes:
+    from litellm import completion
+
+    try:
+        res = completion(**inputs.model_dump(exclude_none=True), api_key=api_key)
+        return tsi.CompletionsCreateRes(response=res.model_dump())
+    except Exception as e:
+        return tsi.CompletionsCreateRes(response={"error": str(e)})
diff --git a/weave/trace_server/model_providers/model_providers.py b/weave/trace_server/model_providers/model_providers.py
new file mode 100644
index 00000000000..f33f18ad224
--- /dev/null
+++ b/weave/trace_server/model_providers/model_providers.py
@@ -0,0 +1,49 @@
+import json
+import os
+from typing import Dict, TypedDict
+
+import requests
+
+model_providers_url = "https://raw.githubusercontent.com/BerriAI/litellm/main/model_prices_and_context_window.json"
+MODEL_PROVIDERS_FILE = "model_providers.json"
+
+PROVIDER_TO_API_KEY_NAME_MAP = {
+    "anthropic": "ANTHROPIC_API_KEY",
+    "gemini": "GOOGLE_API_KEY",
+    "openai": "OPENAI_API_KEY",
+    "fireworks": "FIREWORKS_API_KEY",
+    "groq": "GROQ_API_KEY",
+}
+
+
+class LLMModelProviderInfo(TypedDict):
+    litellm_provider: str
+    api_key_name: str
+
+
+def fetch_model_to_provider_info_map(
+    cached_file_name: str = MODEL_PROVIDERS_FILE,
+) -> Dict[str, LLMModelProviderInfo]:
+    full_path = os.path.join(os.path.dirname(__file__), cached_file_name)
+    if os.path.exists(full_path):
+        with open(full_path, "r") as f:
+            return json.load(f)
+    try:
+        req = requests.get(model_providers_url)
+        req.raise_for_status()
+    except requests.exceptions.RequestException as e:
+        print("Failed to fetch models:", e)
+        return {}
+
+    providers: Dict[str, LLMModelProviderInfo] = {}
+    for k, val in req.json().items():
+        provider = val.get("litellm_provider")
+        api_key_name = PROVIDER_TO_API_KEY_NAME_MAP.get(provider)
+        if api_key_name:
+            providers[k] = LLMModelProviderInfo(
+                litellm_provider=provider, api_key_name=api_key_name
+            )
+
+    with open(full_path, "w") as f:
+        json.dump(providers, f)
+    return providers
diff --git a/weave/trace_server/secret_fetcher_context.py b/weave/trace_server/secret_fetcher_context.py
new file mode 100644
index 00000000000..535118078e2
--- /dev/null
+++ b/weave/trace_server/secret_fetcher_context.py
@@ -0,0 +1,21 @@
+import contextvars
+from contextlib import contextmanager
+from typing import Generator, Optional, Protocol
+
+
+class SecretFetcher(Protocol):
+    def fetch(self, secret_name: str) -> dict: ...
+
+
+_secret_fetcher_context: contextvars.ContextVar[Optional[SecretFetcher]] = (
+    contextvars.ContextVar("secret_fetcher", default=None)
+)
+
+
+@contextmanager
+def secret_fetcher_context(sf: SecretFetcher) -> Generator[None, None, None]:
+    token = _secret_fetcher_context.set(sf)
+    try:
+        yield
+    finally:
+        _secret_fetcher_context.reset(token)
diff --git a/weave/trace_server/sqlite_trace_server.py b/weave/trace_server/sqlite_trace_server.py
index 673c5237705..6f5b5648f63 100644
--- a/weave/trace_server/sqlite_trace_server.py
+++ b/weave/trace_server/sqlite_trace_server.py
@@ -1088,6 +1088,12 @@ def execute_batch_action(
             "EXECUTE BATCH ACTION is not implemented for local sqlite"
         )
 
+    def completions_create(
+        self, req: tsi.CompletionsCreateReq
+    ) -> tsi.CompletionsCreateRes:
+        print("COMPLETIONS CREATE is not implemented for local sqlite", req)
+        return tsi.CompletionsCreateRes()
+
     def _table_row_read(self, project_id: str, row_digest: str) -> tsi.TableRowSchema:
         conn, cursor = get_conn_cursor(self.db_path)
         # Now get the rows
diff --git a/weave/trace_server/trace_server_interface.py b/weave/trace_server/trace_server_interface.py
index 330696a9890..4f064759505 100644
--- a/weave/trace_server/trace_server_interface.py
+++ b/weave/trace_server/trace_server_interface.py
@@ -1,6 +1,6 @@
 import datetime
 from enum import Enum
-from typing import Any, Dict, Iterator, List, Literal, Optional, Protocol, Union
+from typing import Any, Dict, Iterator, List, Literal, Optional, Protocol, Type, Union
 
 from pydantic import BaseModel, ConfigDict, Field, field_serializer
 from typing_extensions import TypedDict
@@ -237,6 +237,46 @@ class CallsDeleteRes(BaseModel):
     pass
 
 
+class CompletionsCreateRequestInputs(BaseModel):
+    model: str
+    messages: List = []
+    timeout: Optional[Union[float, str]] = None
+    temperature: Optional[float] = None
+    top_p: Optional[float] = None
+    n: Optional[int] = None
+    stop: Optional[Union[str, List]] = None
+    max_completion_tokens: Optional[int] = None
+    max_tokens: Optional[int] = None
+    modalities: Optional[List] = None
+    presence_penalty: Optional[float] = None
+    frequency_penalty: Optional[float] = None
+    logit_bias: Optional[dict] = None
+    user: Optional[str] = None
+    # openai v1.0+ new params
+    response_format: Optional[Union[dict, Type[BaseModel]]] = None
+    seed: Optional[int] = None
+    tools: Optional[List] = None
+    tool_choice: Optional[Union[str, dict]] = None
+    logprobs: Optional[bool] = None
+    top_logprobs: Optional[int] = None
+    parallel_tool_calls: Optional[bool] = None
+    extra_headers: Optional[dict] = None
+    # soon to be deprecated params by OpenAI
+    functions: Optional[List] = None
+    function_call: Optional[str] = None
+    api_version: Optional[str] = None
+
+
+class CompletionsCreateReq(BaseModel):
+    project_id: str
+    inputs: CompletionsCreateRequestInputs
+    wb_user_id: Optional[str] = Field(None, description=WB_USER_ID_DESCRIPTION)
+
+
+class CompletionsCreateRes(BaseModel):
+    response: Dict[str, Any]
+
+
 class CallsFilter(BaseModel):
     op_names: Optional[List[str]] = None
     input_refs: Optional[List[str]] = None
@@ -848,8 +888,10 @@ def file_content_read(self, req: FileContentReadReq) -> FileContentReadRes: ...
     def feedback_create(self, req: FeedbackCreateReq) -> FeedbackCreateRes: ...
     def feedback_query(self, req: FeedbackQueryReq) -> FeedbackQueryRes: ...
     def feedback_purge(self, req: FeedbackPurgeReq) -> FeedbackPurgeRes: ...
-
     # Action API
     def execute_batch_action(
         self, req: ExecuteBatchActionReq
     ) -> ExecuteBatchActionRes: ...
+
+    # Execute LLM API
+    def completions_create(self, req: CompletionsCreateReq) -> CompletionsCreateRes: ...
diff --git a/weave/trace_server_bindings/remote_http_trace_server.py b/weave/trace_server_bindings/remote_http_trace_server.py
index 2b8dc7ae170..af13a7c856f 100644
--- a/weave/trace_server_bindings/remote_http_trace_server.py
+++ b/weave/trace_server_bindings/remote_http_trace_server.py
@@ -559,6 +559,16 @@ def execute_batch_action(
             tsi.ExecuteBatchActionRes,
         )
 
+    def completions_create(
+        self, req: tsi.CompletionsCreateReq
+    ) -> tsi.CompletionsCreateRes:
+        return self._generic_request(
+            "/completions/create",
+            req,
+            tsi.CompletionsCreateReq,
+            tsi.CompletionsCreateRes,
+        )
+
 
 __docspec__ = [
     RemoteHTTPTraceServer,