diff --git a/docs/docs/guides/cookbooks/summarization/.gitignore b/docs/docs/guides/cookbooks/summarization/.gitignore new file mode 100644 index 00000000000..8b137891791 --- /dev/null +++ b/docs/docs/guides/cookbooks/summarization/.gitignore @@ -0,0 +1 @@ + diff --git a/docs/docs/reference/gen_notebooks/chain_of_density.md b/docs/docs/reference/gen_notebooks/chain_of_density.md new file mode 100644 index 00000000000..04b898bd88c --- /dev/null +++ b/docs/docs/reference/gen_notebooks/chain_of_density.md @@ -0,0 +1,357 @@ +--- +title: Chain of Density Summarization +--- + + +:::tip[This is a notebook] + +
Open In Colab
Open in Colab
+ +
View in Github
View in Github
+ +::: + + + +# Summarization using Chain of Density + +Summarizing complex technical documents while preserving crucial details is a challenging task. The Chain of Density (CoD) summarization technique offers a solution by iteratively refining summaries to be more concise and information-dense. This guide demonstrates how to implement CoD using Weave, a powerful framework for building, tracking, and evaluating LLM applications. By combining CoD's effectiveness with Weave's robust tooling, you'll learn to create a summarization pipeline that produces high-quality, entity-rich summaries of technical content while gaining insights into the summarization process. + +![Final Evaluation](./media/chain_of_density/eval_comparison.gif) + +## What is Chain of Density Summarization? + +[![arXiv](https://img.shields.io/badge/arXiv-2309.04269-b31b1b.svg)](https://arxiv.org/abs/2309.04269) + +Chain of Density (CoD) is an iterative summarization technique that produces increasingly concise and information-dense summaries. It works by: + +1. Starting with an initial summary +2. Iteratively refining the summary, making it more concise while preserving key information +3. Increasing the density of entities and technical details with each iteration + +This approach is particularly useful for summarizing scientific papers or technical documents where preserving detailed information is crucial. + +## Why use Weave? + +In this tutorial, we'll use Weave to implement and evaluate a Chain of Density summarization pipeline for ArXiv papers. You'll learn how to: + +1. **Track your LLM pipeline**: Use Weave to automatically log inputs, outputs, and intermediate steps of your summarization process. +2. **Evaluate LLM outputs**: Create rigorous, apples-to-apples evaluations of your summaries using Weave's built-in tools. +3. **Build composable operations**: Combine and reuse Weave operations across different parts of your summarization pipeline. +4. **Integrate seamlessly**: Add Weave to your existing Python code with minimal overhead. + +By the end of this tutorial, you'll have created a CoD summarization pipeline that leverages Weave's capabilities for model serving, evaluation, and result tracking. + +## Set up the environment + +First, let's set up our environment and import the necessary libraries: + + +```python +!pip install -qU anthropic weave pydantic requests PyPDF2 set-env-colab-kaggle-dotenv +``` + +>To get an Anthropic API key: +> 1. Sign up for an account at https://www.anthropic.com +> 2. Navigate to the API section in your account settings +> 3. Generate a new API key +> 4. Store the API key securely in your .env file + + +```python +import os +import anthropic +import weave +from datetime import datetime, timezone +from pydantic import BaseModel +import requests +import io +from PyPDF2 import PdfReader +from set_env import set_env + +set_env("WANDB_API_KEY") +set_env("ANTHROPIC_API_KEY") + +weave.init("summarization-chain-of-density-cookbook") +anthropic_client = anthropic.Anthropic(api_key=os.getenv("ANTHROPIC_API_KEY")) +``` + +We're using Weave to track our experiment and Anthropic's Claude model for text generation. The `weave.init()` call sets up a new Weave project for our summarization task. + +## Define the ArxivPaper model + +We'll create a simple `ArxivPaper` class to represent our data: + + +```python +# Define ArxivPaper model +class ArxivPaper(BaseModel): + entry_id: str + updated: datetime + published: datetime + title: str + authors: list[str] + summary: str + pdf_url: str + +# Create sample ArxivPaper +arxiv_paper = ArxivPaper( + entry_id="http://arxiv.org/abs/2406.04744v1", + updated=datetime(2024, 6, 7, 8, 43, 7, tzinfo=timezone.utc), + published=datetime(2024, 6, 7, 8, 43, 7, tzinfo=timezone.utc), + title="CRAG -- Comprehensive RAG Benchmark", + authors=["Xiao Yang", "Kai Sun", "Hao Xin"], # Truncated for brevity + summary="Retrieval-Augmented Generation (RAG) has recently emerged as a promising solution...", # Truncated + pdf_url="https://arxiv.org/pdf/2406.04744" +) +``` + +This class encapsulates the metadata and content of an ArXiv paper, which will be the input to our summarization pipeline. + +![Arxiv Paper](./media/chain_of_density/arxiv_paper.gif) + +## Load PDF content + +To work with the full paper content, we'll add a function to load and extract text from PDFs: + + +```python +@weave.op() +def load_pdf(pdf_url: str) -> str: + # Download the PDF + response = requests.get(pdf_url) + pdf_file = io.BytesIO(response.content) + + # Read the PDF + pdf_reader = PdfReader(pdf_file) + + # Extract text from all pages + text = "" + for page in pdf_reader.pages: + text += page.extract_text() + + return text +``` + +## Implement Chain of Density summarization + +Now, let's implement the core CoD summarization logic using Weave operations: + + +```python +# Chain of Density Summarization +@weave.op() +def summarize_current_summary(document: str, instruction: str, current_summary: str = "", iteration: int = 1, model: str = "claude-3-sonnet-20240229"): + prompt = f""" + Document: {document} + Current summary: {current_summary} + Instruction to focus on: {instruction} + Iteration: {iteration} + + Generate an increasingly concise, entity-dense, and highly technical summary from the provided document that specifically addresses the given instruction. + """ + response = anthropic_client.messages.create( + model=model, + max_tokens=4096, + messages=[{"role": "user", "content": prompt}] + ) + return response.content[0].text + +@weave.op() +def iterative_density_summarization(document: str, instruction: str, current_summary: str, density_iterations: int, model: str = "claude-3-sonnet-20240229"): + iteration_summaries = [] + for iteration in range(1, density_iterations + 1): + current_summary = summarize_current_summary(document, instruction, current_summary, iteration, model) + iteration_summaries.append(current_summary) + return current_summary, iteration_summaries + +@weave.op() +def final_summary(instruction: str, current_summary: str, model: str = "claude-3-sonnet-20240229"): + prompt = f""" + Given this summary: {current_summary} + And this instruction to focus on: {instruction} + Create an extremely dense, final summary that captures all key technical information in the most concise form possible, while specifically addressing the given instruction. + """ + return anthropic_client.messages.create( + model=model, + max_tokens=4096, + messages=[{"role": "user", "content": prompt}] + ).content[0].text + +@weave.op() +def chain_of_density_summarization(document: str, instruction: str, current_summary: str = "", model: str = "claude-3-sonnet-20240229", density_iterations: int = 2): + current_summary, iteration_summaries = iterative_density_summarization(document, instruction, current_summary, density_iterations, model) + final_summary_text = final_summary(instruction, current_summary, model) + return { + "final_summary": final_summary_text, + "accumulated_summary": current_summary, + "iteration_summaries": iteration_summaries, + } +``` + +Here's what each function does: + +- `summarize_current_summary`: Generates a single summary iteration based on the current state. +- `iterative_density_summarization`: Applies the CoD technique by calling `summarize_current_summary` multiple times. +- `chain_of_density_summarization`: Orchestrates the entire summarization process and returns the results. + +By using `@weave.op()` decorators, we ensure that Weave tracks the inputs, outputs, and execution of these functions. + +![Chain of Density](./media/chain_of_density/chain_of_density.gif) + +## Create a Weave Model + +Now, let's wrap our summarization pipeline in a Weave Model: + + +```python +# Weave Model +class ArxivChainOfDensityPipeline(weave.Model): + model: str = "claude-3-sonnet-20240229" + density_iterations: int = 3 + + @weave.op() + def predict(self, paper: ArxivPaper, instruction: str) -> dict: + text = load_pdf(paper["pdf_url"]) + result = chain_of_density_summarization(text, instruction, model=self.model, density_iterations=self.density_iterations) + return result + +``` + +This `ArxivChainOfDensityPipeline` class encapsulates our summarization logic as a Weave Model, providing several key benefits: + +1. Automatic experiment tracking: Weave captures inputs, outputs, and parameters for each run of the model. +2. Versioning: Changes to the model's attributes or code are automatically versioned, creating a clear history of how your summarization pipeline evolves over time. +3. Reproducibility: The versioning and tracking make it easy to reproduce any previous result or configuration of your summarization pipeline. +4. Hyperparameter management: Model attributes (like `model` and `density_iterations`) are clearly defined and tracked across different runs, facilitating experimentation. +5. Integration with Weave ecosystem: Using `weave.Model` allows seamless integration with other Weave tools, such as evaluations and serving capabilities. + +![Arxiv Chain of Density Pipeline](./media/chain_of_density/model.gif) + +## Implement evaluation metrics + +To assess the quality of our summaries, we'll implement simple evaluation metrics: + + +```python +import json + +@weave.op() +def evaluate_summary(summary: str, instruction: str, model: str = "claude-3-sonnet-20240229") -> dict: + prompt = f""" + Summary: {summary} + Instruction: {instruction} + + Evaluate the summary based on the following criteria: + 1. Relevance (1-5): How well does the summary address the given instruction? + 2. Conciseness (1-5): How concise is the summary while retaining key information? + 3. Technical Accuracy (1-5): How accurately does the summary convey technical details? + + Your response MUST be in the following JSON format: + {{ + "relevance": {{ + "score": , + "explanation": "" + }}, + "conciseness": {{ + "score": , + "explanation": "" + }}, + "technical_accuracy": {{ + "score": , + "explanation": "" + }} + }} + + Ensure that the scores are integers between 1 and 5, and that the explanations are concise. + """ + response = anthropic_client.messages.create( + model=model, + max_tokens=1000, + messages=[{"role": "user", "content": prompt}] + ) + print(response.content[0].text) + + eval_dict = json.loads(response.content[0].text) + + return { + "relevance": eval_dict['relevance']['score'], + "conciseness": eval_dict['conciseness']['score'], + "technical_accuracy": eval_dict['technical_accuracy']['score'], + "average_score": sum(eval_dict[k]['score'] for k in eval_dict) / 3, + "evaluation_text": response.content[0].text + } +``` + +These evaluation functions use the Claude model to assess the quality of the generated summaries based on relevance, conciseness, and technical accuracy. + +![Evaluation](./media/chain_of_density/evals_main_screen.gif) + +## Create a Weave Dataset and run evaluation + +To evaluate our pipeline, we'll create a Weave Dataset and run an evaluation: + + +```python +# Create a Weave Dataset +dataset = weave.Dataset( + name="arxiv_papers", + rows=[ + { + "paper": arxiv_paper, + "instruction": "What was the approach to experimenting with different data mixtures?" + }, + ] +) + +weave.publish(dataset) +``` + +![Dataset](./media/chain_of_density/eval_dataset.gif) + +For our evaluation, we'll use an LLM-as-a-judge approach. This technique involves using a language model to assess the quality of outputs generated by another model or system. It leverages the LLM's understanding and reasoning capabilities to provide nuanced evaluations, especially for tasks where traditional metrics may fall short. + +[![arXiv](https://img.shields.io/badge/arXiv-2306.05685-b31b1b.svg)](https://arxiv.org/abs/2306.05685) + + +```python +# Define the scorer function +@weave.op() +def quality_scorer(instruction: str, model_output: dict) -> dict: + result = evaluate_summary(model_output["final_summary"], instruction) + return result +``` + + +```python +# Run evaluation +evaluation = weave.Evaluation(dataset=dataset, scorers=[quality_scorer]) +arxiv_chain_of_density_pipeline = ArxivChainOfDensityPipeline() +results = await evaluation.evaluate(arxiv_chain_of_density_pipeline) +``` + +![Final Evaluation](./media/chain_of_density/eval_comparison.gif) + +This code creates a dataset with our sample ArXiv paper, defines a quality scorer, and runs an evaluation of our summarization pipeline. + +## Conclusion + +In this example, we've demonstrated how to implement a Chain of Density summarization pipeline for ArXiv papers using Weave. We've shown how to: + +1. Create Weave operations for each step of the summarization process +2. Wrap the pipeline in a Weave Model for easy tracking and evaluation +3. Implement custom evaluation metrics using Weave operations +4. Create a dataset and run an evaluation of the pipeline + +Weave's seamless integration allows us to track inputs, outputs, and intermediate steps throughout the summarization process, making it easier to debug, optimize, and evaluate our LLM application. + +For more information on Weave and its capabilities, check out the [Weave documentation](https://docs.wandb.ai/weave). You can extend this example to handle larger datasets, implement more sophisticated evaluation metrics, or integrate with other LLM workflows. + + + View Full Report on W&B + diff --git a/docs/docs/reference/gen_notebooks/media/chain_of_density/arxiv_paper.gif b/docs/docs/reference/gen_notebooks/media/chain_of_density/arxiv_paper.gif new file mode 100644 index 00000000000..ae746acf8f0 Binary files /dev/null and b/docs/docs/reference/gen_notebooks/media/chain_of_density/arxiv_paper.gif differ diff --git a/docs/docs/reference/gen_notebooks/media/chain_of_density/chain_of_density.gif b/docs/docs/reference/gen_notebooks/media/chain_of_density/chain_of_density.gif new file mode 100644 index 00000000000..02da71faa26 Binary files /dev/null and b/docs/docs/reference/gen_notebooks/media/chain_of_density/chain_of_density.gif differ diff --git a/docs/docs/reference/gen_notebooks/media/chain_of_density/eval_comparison.gif b/docs/docs/reference/gen_notebooks/media/chain_of_density/eval_comparison.gif new file mode 100644 index 00000000000..df0b7c15e11 Binary files /dev/null and b/docs/docs/reference/gen_notebooks/media/chain_of_density/eval_comparison.gif differ diff --git a/docs/docs/reference/gen_notebooks/media/chain_of_density/eval_dataset.gif b/docs/docs/reference/gen_notebooks/media/chain_of_density/eval_dataset.gif new file mode 100644 index 00000000000..c68db6d594a Binary files /dev/null and b/docs/docs/reference/gen_notebooks/media/chain_of_density/eval_dataset.gif differ diff --git a/docs/docs/reference/gen_notebooks/media/chain_of_density/evals_main_screen.gif b/docs/docs/reference/gen_notebooks/media/chain_of_density/evals_main_screen.gif new file mode 100644 index 00000000000..8739b4747ec Binary files /dev/null and b/docs/docs/reference/gen_notebooks/media/chain_of_density/evals_main_screen.gif differ diff --git a/docs/docs/reference/gen_notebooks/media/chain_of_density/fetch_arxiv_papers.gif b/docs/docs/reference/gen_notebooks/media/chain_of_density/fetch_arxiv_papers.gif new file mode 100644 index 00000000000..bf8d44be585 Binary files /dev/null and b/docs/docs/reference/gen_notebooks/media/chain_of_density/fetch_arxiv_papers.gif differ diff --git a/docs/docs/reference/gen_notebooks/media/chain_of_density/generate_arxiv_query_args.gif b/docs/docs/reference/gen_notebooks/media/chain_of_density/generate_arxiv_query_args.gif new file mode 100644 index 00000000000..c51ca5e94c9 Binary files /dev/null and b/docs/docs/reference/gen_notebooks/media/chain_of_density/generate_arxiv_query_args.gif differ diff --git a/docs/docs/reference/gen_notebooks/media/chain_of_density/model.gif b/docs/docs/reference/gen_notebooks/media/chain_of_density/model.gif new file mode 100644 index 00000000000..6c19afbefd3 Binary files /dev/null and b/docs/docs/reference/gen_notebooks/media/chain_of_density/model.gif differ diff --git a/docs/docs/reference/gen_notebooks/media/chain_of_density/model_extract_images.gif b/docs/docs/reference/gen_notebooks/media/chain_of_density/model_extract_images.gif new file mode 100644 index 00000000000..682d9c3404e Binary files /dev/null and b/docs/docs/reference/gen_notebooks/media/chain_of_density/model_extract_images.gif differ diff --git a/docs/docs/reference/gen_notebooks/media/chain_of_density/model_replace_image_with_descriptions.gif b/docs/docs/reference/gen_notebooks/media/chain_of_density/model_replace_image_with_descriptions.gif new file mode 100644 index 00000000000..78f355f8187 Binary files /dev/null and b/docs/docs/reference/gen_notebooks/media/chain_of_density/model_replace_image_with_descriptions.gif differ diff --git a/docs/docs/tutorial-tracing_2.md b/docs/docs/tutorial-tracing_2.md index 108571cf650..da1980f1155 100644 --- a/docs/docs/tutorial-tracing_2.md +++ b/docs/docs/tutorial-tracing_2.md @@ -5,7 +5,6 @@ In the [Track LLM inputs & outputs](/quickstart) tutorial, the basics of trackin In this tutorial you will learn how to: - **Track data** as it flows though your application - **Track metadata** at call time -- **Export data** that was logged to Weave ## Tracking nested function calls diff --git a/docs/notebooks/chain_of_density.ipynb b/docs/notebooks/chain_of_density.ipynb new file mode 100644 index 00000000000..7a1df2da44b --- /dev/null +++ b/docs/notebooks/chain_of_density.ipynb @@ -0,0 +1,554 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Summarization using Chain of Density\n", + "\n", + "Summarizing complex technical documents while preserving crucial details is a challenging task. The Chain of Density (CoD) summarization technique offers a solution by iteratively refining summaries to be more concise and information-dense. This guide demonstrates how to implement CoD using Weave, a powerful framework for building, tracking, and evaluating LLM applications. By combining CoD's effectiveness with Weave's robust tooling, you'll learn to create a summarization pipeline that produces high-quality, entity-rich summaries of technical content while gaining insights into the summarization process.\n", + "\n", + "![Final Evaluation](./media/chain_of_density/eval_comparison.gif)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## What is Chain of Density Summarization?\n", + "\n", + "[![arXiv](https://img.shields.io/badge/arXiv-2309.04269-b31b1b.svg)](https://arxiv.org/abs/2309.04269)\n", + "\n", + "Chain of Density (CoD) is an iterative summarization technique that produces increasingly concise and information-dense summaries. It works by:\n", + "\n", + "1. Starting with an initial summary\n", + "2. Iteratively refining the summary, making it more concise while preserving key information\n", + "3. Increasing the density of entities and technical details with each iteration\n", + "\n", + "This approach is particularly useful for summarizing scientific papers or technical documents where preserving detailed information is crucial." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Why use Weave?\n", + "\n", + "In this tutorial, we'll use Weave to implement and evaluate a Chain of Density summarization pipeline for ArXiv papers. You'll learn how to:\n", + "\n", + "1. **Track your LLM pipeline**: Use Weave to automatically log inputs, outputs, and intermediate steps of your summarization process.\n", + "2. **Evaluate LLM outputs**: Create rigorous, apples-to-apples evaluations of your summaries using Weave's built-in tools.\n", + "3. **Build composable operations**: Combine and reuse Weave operations across different parts of your summarization pipeline.\n", + "4. **Integrate seamlessly**: Add Weave to your existing Python code with minimal overhead.\n", + "\n", + "By the end of this tutorial, you'll have created a CoD summarization pipeline that leverages Weave's capabilities for model serving, evaluation, and result tracking." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Set up the environment\n", + "\n", + "First, let's set up our environment and import the necessary libraries:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "!pip install -qU anthropic weave pydantic requests PyPDF2 set-env-colab-kaggle-dotenv" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + ">To get an Anthropic API key:\n", + "> 1. Sign up for an account at https://www.anthropic.com\n", + "> 2. Navigate to the API section in your account settings\n", + "> 3. Generate a new API key\n", + "> 4. Store the API key securely in your .env file" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import io\n", + "import os\n", + "from datetime import datetime, timezone\n", + "\n", + "import anthropic\n", + "import requests\n", + "from pydantic import BaseModel\n", + "from PyPDF2 import PdfReader\n", + "from set_env import set_env\n", + "\n", + "import weave\n", + "\n", + "set_env(\"WANDB_API_KEY\")\n", + "set_env(\"ANTHROPIC_API_KEY\")\n", + "\n", + "weave.init(\"summarization-chain-of-density-cookbook\")\n", + "anthropic_client = anthropic.Anthropic(api_key=os.getenv(\"ANTHROPIC_API_KEY\"))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We're using Weave to track our experiment and Anthropic's Claude model for text generation. The `weave.init()` call sets up a new Weave project for our summarization task." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Define the ArxivPaper model\n", + "\n", + "We'll create a simple `ArxivPaper` class to represent our data:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Define ArxivPaper model\n", + "class ArxivPaper(BaseModel):\n", + " entry_id: str\n", + " updated: datetime\n", + " published: datetime\n", + " title: str\n", + " authors: list[str]\n", + " summary: str\n", + " pdf_url: str\n", + "\n", + "\n", + "# Create sample ArxivPaper\n", + "arxiv_paper = ArxivPaper(\n", + " entry_id=\"http://arxiv.org/abs/2406.04744v1\",\n", + " updated=datetime(2024, 6, 7, 8, 43, 7, tzinfo=timezone.utc),\n", + " published=datetime(2024, 6, 7, 8, 43, 7, tzinfo=timezone.utc),\n", + " title=\"CRAG -- Comprehensive RAG Benchmark\",\n", + " authors=[\"Xiao Yang\", \"Kai Sun\", \"Hao Xin\"], # Truncated for brevity\n", + " summary=\"Retrieval-Augmented Generation (RAG) has recently emerged as a promising solution...\", # Truncated\n", + " pdf_url=\"https://arxiv.org/pdf/2406.04744\",\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "This class encapsulates the metadata and content of an ArXiv paper, which will be the input to our summarization pipeline.\n", + "\n", + "![Arxiv Paper](./media/chain_of_density/arxiv_paper.gif)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Load PDF content\n", + "\n", + "To work with the full paper content, we'll add a function to load and extract text from PDFs:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "@weave.op()\n", + "def load_pdf(pdf_url: str) -> str:\n", + " # Download the PDF\n", + " response = requests.get(pdf_url)\n", + " pdf_file = io.BytesIO(response.content)\n", + "\n", + " # Read the PDF\n", + " pdf_reader = PdfReader(pdf_file)\n", + "\n", + " # Extract text from all pages\n", + " text = \"\"\n", + " for page in pdf_reader.pages:\n", + " text += page.extract_text()\n", + "\n", + " return text" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Implement Chain of Density summarization\n", + "\n", + "Now, let's implement the core CoD summarization logic using Weave operations:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Chain of Density Summarization\n", + "@weave.op()\n", + "def summarize_current_summary(\n", + " document: str,\n", + " instruction: str,\n", + " current_summary: str = \"\",\n", + " iteration: int = 1,\n", + " model: str = \"claude-3-sonnet-20240229\",\n", + "):\n", + " prompt = f\"\"\"\n", + " Document: {document}\n", + " Current summary: {current_summary}\n", + " Instruction to focus on: {instruction}\n", + " Iteration: {iteration}\n", + "\n", + " Generate an increasingly concise, entity-dense, and highly technical summary from the provided document that specifically addresses the given instruction.\n", + " \"\"\"\n", + " response = anthropic_client.messages.create(\n", + " model=model, max_tokens=4096, messages=[{\"role\": \"user\", \"content\": prompt}]\n", + " )\n", + " return response.content[0].text\n", + "\n", + "\n", + "@weave.op()\n", + "def iterative_density_summarization(\n", + " document: str,\n", + " instruction: str,\n", + " current_summary: str,\n", + " density_iterations: int,\n", + " model: str = \"claude-3-sonnet-20240229\",\n", + "):\n", + " iteration_summaries = []\n", + " for iteration in range(1, density_iterations + 1):\n", + " current_summary = summarize_current_summary(\n", + " document, instruction, current_summary, iteration, model\n", + " )\n", + " iteration_summaries.append(current_summary)\n", + " return current_summary, iteration_summaries\n", + "\n", + "\n", + "@weave.op()\n", + "def final_summary(\n", + " instruction: str, current_summary: str, model: str = \"claude-3-sonnet-20240229\"\n", + "):\n", + " prompt = f\"\"\"\n", + " Given this summary: {current_summary}\n", + " And this instruction to focus on: {instruction}\n", + " Create an extremely dense, final summary that captures all key technical information in the most concise form possible, while specifically addressing the given instruction.\n", + " \"\"\"\n", + " return (\n", + " anthropic_client.messages.create(\n", + " model=model, max_tokens=4096, messages=[{\"role\": \"user\", \"content\": prompt}]\n", + " )\n", + " .content[0]\n", + " .text\n", + " )\n", + "\n", + "\n", + "@weave.op()\n", + "def chain_of_density_summarization(\n", + " document: str,\n", + " instruction: str,\n", + " current_summary: str = \"\",\n", + " model: str = \"claude-3-sonnet-20240229\",\n", + " density_iterations: int = 2,\n", + "):\n", + " current_summary, iteration_summaries = iterative_density_summarization(\n", + " document, instruction, current_summary, density_iterations, model\n", + " )\n", + " final_summary_text = final_summary(instruction, current_summary, model)\n", + " return {\n", + " \"final_summary\": final_summary_text,\n", + " \"accumulated_summary\": current_summary,\n", + " \"iteration_summaries\": iteration_summaries,\n", + " }" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Here's what each function does:\n", + "\n", + "- `summarize_current_summary`: Generates a single summary iteration based on the current state.\n", + "- `iterative_density_summarization`: Applies the CoD technique by calling `summarize_current_summary` multiple times.\n", + "- `chain_of_density_summarization`: Orchestrates the entire summarization process and returns the results.\n", + "\n", + "By using `@weave.op()` decorators, we ensure that Weave tracks the inputs, outputs, and execution of these functions.\n", + "\n", + "![Chain of Density](./media/chain_of_density/chain_of_density.gif)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Create a Weave Model\n", + "\n", + "Now, let's wrap our summarization pipeline in a Weave Model:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Weave Model\n", + "class ArxivChainOfDensityPipeline(weave.Model):\n", + " model: str = \"claude-3-sonnet-20240229\"\n", + " density_iterations: int = 3\n", + "\n", + " @weave.op()\n", + " def predict(self, paper: ArxivPaper, instruction: str) -> dict:\n", + " text = load_pdf(paper[\"pdf_url\"])\n", + " result = chain_of_density_summarization(\n", + " text,\n", + " instruction,\n", + " model=self.model,\n", + " density_iterations=self.density_iterations,\n", + " )\n", + " return result" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "This `ArxivChainOfDensityPipeline` class encapsulates our summarization logic as a Weave Model, providing several key benefits:\n", + "\n", + "1. Automatic experiment tracking: Weave captures inputs, outputs, and parameters for each run of the model.\n", + "2. Versioning: Changes to the model's attributes or code are automatically versioned, creating a clear history of how your summarization pipeline evolves over time.\n", + "3. Reproducibility: The versioning and tracking make it easy to reproduce any previous result or configuration of your summarization pipeline.\n", + "4. Hyperparameter management: Model attributes (like `model` and `density_iterations`) are clearly defined and tracked across different runs, facilitating experimentation.\n", + "5. Integration with Weave ecosystem: Using `weave.Model` allows seamless integration with other Weave tools, such as evaluations and serving capabilities.\n", + "\n", + "![Arxiv Chain of Density Pipeline](./media/chain_of_density/model.gif)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Implement evaluation metrics\n", + "\n", + "To assess the quality of our summaries, we'll implement simple evaluation metrics:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import json\n", + "\n", + "\n", + "@weave.op()\n", + "def evaluate_summary(\n", + " summary: str, instruction: str, model: str = \"claude-3-sonnet-20240229\"\n", + ") -> dict:\n", + " prompt = f\"\"\"\n", + " Summary: {summary}\n", + " Instruction: {instruction}\n", + "\n", + " Evaluate the summary based on the following criteria:\n", + " 1. Relevance (1-5): How well does the summary address the given instruction?\n", + " 2. Conciseness (1-5): How concise is the summary while retaining key information?\n", + " 3. Technical Accuracy (1-5): How accurately does the summary convey technical details?\n", + "\n", + " Your response MUST be in the following JSON format:\n", + " {{\n", + " \"relevance\": {{\n", + " \"score\": ,\n", + " \"explanation\": \"\"\n", + " }},\n", + " \"conciseness\": {{\n", + " \"score\": ,\n", + " \"explanation\": \"\"\n", + " }},\n", + " \"technical_accuracy\": {{\n", + " \"score\": ,\n", + " \"explanation\": \"\"\n", + " }}\n", + " }}\n", + "\n", + " Ensure that the scores are integers between 1 and 5, and that the explanations are concise.\n", + " \"\"\"\n", + " response = anthropic_client.messages.create(\n", + " model=model, max_tokens=1000, messages=[{\"role\": \"user\", \"content\": prompt}]\n", + " )\n", + " print(response.content[0].text)\n", + "\n", + " eval_dict = json.loads(response.content[0].text)\n", + "\n", + " return {\n", + " \"relevance\": eval_dict[\"relevance\"][\"score\"],\n", + " \"conciseness\": eval_dict[\"conciseness\"][\"score\"],\n", + " \"technical_accuracy\": eval_dict[\"technical_accuracy\"][\"score\"],\n", + " \"average_score\": sum(eval_dict[k][\"score\"] for k in eval_dict) / 3,\n", + " \"evaluation_text\": response.content[0].text,\n", + " }" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "These evaluation functions use the Claude model to assess the quality of the generated summaries based on relevance, conciseness, and technical accuracy.\n", + "\n", + "![Evaluation](./media/chain_of_density/evals_main_screen.gif)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Create a Weave Dataset and run evaluation\n", + "\n", + "To evaluate our pipeline, we'll create a Weave Dataset and run an evaluation:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Create a Weave Dataset\n", + "dataset = weave.Dataset(\n", + " name=\"arxiv_papers\",\n", + " rows=[\n", + " {\n", + " \"paper\": arxiv_paper,\n", + " \"instruction\": \"What was the approach to experimenting with different data mixtures?\",\n", + " },\n", + " ],\n", + ")\n", + "\n", + "weave.publish(dataset)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "![Dataset](./media/chain_of_density/eval_dataset.gif)\n", + "\n", + "For our evaluation, we'll use an LLM-as-a-judge approach. This technique involves using a language model to assess the quality of outputs generated by another model or system. It leverages the LLM's understanding and reasoning capabilities to provide nuanced evaluations, especially for tasks where traditional metrics may fall short." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "[![arXiv](https://img.shields.io/badge/arXiv-2306.05685-b31b1b.svg)](https://arxiv.org/abs/2306.05685)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Define the scorer function\n", + "@weave.op()\n", + "def quality_scorer(instruction: str, model_output: dict) -> dict:\n", + " result = evaluate_summary(model_output[\"final_summary\"], instruction)\n", + " return result" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Run evaluation\n", + "evaluation = weave.Evaluation(dataset=dataset, scorers=[quality_scorer])\n", + "arxiv_chain_of_density_pipeline = ArxivChainOfDensityPipeline()\n", + "results = await evaluation.evaluate(arxiv_chain_of_density_pipeline)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "![Final Evaluation](./media/chain_of_density/eval_comparison.gif)\n", + "\n", + "This code creates a dataset with our sample ArXiv paper, defines a quality scorer, and runs an evaluation of our summarization pipeline." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Conclusion\n", + "\n", + "In this example, we've demonstrated how to implement a Chain of Density summarization pipeline for ArXiv papers using Weave. We've shown how to:\n", + "\n", + "1. Create Weave operations for each step of the summarization process\n", + "2. Wrap the pipeline in a Weave Model for easy tracking and evaluation\n", + "3. Implement custom evaluation metrics using Weave operations\n", + "4. Create a dataset and run an evaluation of the pipeline\n", + "\n", + "Weave's seamless integration allows us to track inputs, outputs, and intermediate steps throughout the summarization process, making it easier to debug, optimize, and evaluate our LLM application.\n", + "\n", + "For more information on Weave and its capabilities, check out the [Weave documentation](https://docs.wandb.ai/weave). You can extend this example to handle larger datasets, implement more sophisticated evaluation metrics, or integrate with other LLM workflows.\n", + "\n", + "\n", + " View Full Report on W&B\n", + "" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": ".venv", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.4" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/docs/notebooks/dspy_prompt_optimization.ipynb b/docs/notebooks/dspy_prompt_optimization.ipynb new file mode 100644 index 00000000000..94c6b485e15 --- /dev/null +++ b/docs/notebooks/dspy_prompt_optimization.ipynb @@ -0,0 +1,434 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "\n", + "\n", + "# Optimizing LLM Workflows Using DSPy and Weave\n", + "\n", + "The [BIG-bench (Beyond the Imitation Game Benchmark)](https://github.com/google/BIG-bench) is a collaborative benchmark intended to probe large language models and extrapolate their future capabilities consisting of more than 200 tasks. The [BIG-Bench Hard (BBH)](https://github.com/suzgunmirac/BIG-Bench-Hard) is a suite of 23 most challenging BIG-Bench tasks that can be quite difficult to be solved using the current generation of language models.\n", + "\n", + "This tutorial demonstrates how we can improve the performance of our LLM workflow implemented on the **causal judgement task** from the BIG-bench Hard benchmark and evaluate our prompting strategies. We will use [DSPy](https://dspy-docs.vercel.app/) for implementing our LLM workflow and optimizing our prompting strategy. We will also use [Weave](../docs/introduction.md) to track our LLM workflow and evaluate our prompting strategies." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Installing the Dependencies\n", + "\n", + "We need the following libraries for this tutorial:\n", + "\n", + "- [DSPy](https://dspy-docs.vercel.app/) for building the LLM workflow and optimizing it.\n", + "- [Weave](../introduction.md) to track our LLM workflow and evaluate our prompting strategies.\n", + "- [datasets](https://huggingface.co/docs/datasets/index) to access the Big-Bench Hard dataset from HuggingFace Hub." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "!pip install -qU dspy-ai weave datasets" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Since we'll be using [OpenAI API](https://openai.com/index/openai-api/) as the LLM Vendor, we will also need an OpenAI API key. You can [sign up](https://platform.openai.com/signup) on the OpenAI platform to get your own API key." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "from getpass import getpass\n", + "\n", + "api_key = getpass(\"Enter you OpenAI API key: \")\n", + "os.environ[\"OPENAI_API_KEY\"] = api_key" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Enable Tracking using Weave\n", + "\n", + "Weave is currently integrated with DSPy, and including [`weave.init`](../docs/reference/python-sdk/weave/index.md) at the start of our code lets us automatically trace our DSPy functions which can be explored in the Weave UI. Check out the [Weave integration docs for DSPy](../docs/guides/integrations/dspy.md) to learn more." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import weave\n", + "\n", + "weave.init(project_name=\"dspy-bigbench-hard\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "In this tutorial, we use a metadata class inherited from [`weave.Model`](../docs/guides/core-types/models.md) to manage our metadata." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "class Metadata(weave.Model):\n", + " dataset_address: str = \"maveriq/bigbenchhard\"\n", + " big_bench_hard_task: str = \"causal_judgement\"\n", + " num_train_examples: int = 50\n", + " openai_model: str = \"gpt-3.5-turbo\"\n", + " openai_max_tokens: int = 2048\n", + " max_bootstrapped_demos: int = 8\n", + " max_labeled_demos: int = 8\n", + "\n", + "\n", + "metadata = Metadata()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "| ![](../static/img/dspy_prompt_optimiztion/metadata.gif) |\n", + "|---|\n", + "| The `Metadata` objects are automatically versioned and traced when functions consuming them are traced |" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Load the BIG-Bench Hard Dataset\n", + "\n", + "We will load this dataset from HuggingFace Hub, split into training and validation sets, and [publish](../docs/guides/core-types/datasets.md) them on Weave, this will let us version the datasets, and also use [`weave.Evaluation`](../docs/guides/core-types/evaluations.md) to evaluate our prompting strategy." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import dspy\n", + "from datasets import load_dataset\n", + "\n", + "\n", + "@weave.op()\n", + "def get_dataset(metadata: Metadata):\n", + " # load the BIG-Bench Hard dataset corresponding to the task from Huggingface Hug\n", + " dataset = load_dataset(metadata.dataset_address, metadata.big_bench_hard_task)[\n", + " \"train\"\n", + " ]\n", + "\n", + " # create the training and validation datasets\n", + " rows = [{\"question\": data[\"input\"], \"answer\": data[\"target\"]} for data in dataset]\n", + " train_rows = rows[0 : metadata.num_train_examples]\n", + " val_rows = rows[metadata.num_train_examples :]\n", + "\n", + " # create the training and validation examples consisting of `dspy.Example` objects\n", + " dspy_train_examples = [\n", + " dspy.Example(row).with_inputs(\"question\") for row in train_rows\n", + " ]\n", + " dspy_val_examples = [dspy.Example(row).with_inputs(\"question\") for row in val_rows]\n", + "\n", + " # publish the datasets to the Weave, this would let us version the data and use for evaluation\n", + " weave.publish(\n", + " weave.Dataset(\n", + " name=f\"bigbenchhard_{metadata.big_bench_hard_task}_train\", rows=train_rows\n", + " )\n", + " )\n", + " weave.publish(\n", + " weave.Dataset(\n", + " name=f\"bigbenchhard_{metadata.big_bench_hard_task}_val\", rows=val_rows\n", + " )\n", + " )\n", + "\n", + " return dspy_train_examples, dspy_val_examples\n", + "\n", + "\n", + "dspy_train_examples, dspy_val_examples = get_dataset(metadata)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "| ![](../static/img/dspy_prompt_optimiztion/datasets.gif) |\n", + "|---|\n", + "| The datasets, once published, can be explored in the Weave UI |" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## The DSPy Program\n", + "\n", + "[DSPy](https://dspy-docs.vercel.app) is a framework that pushes building new LM pipelines away from manipulating free-form strings and closer to programming (composing modular operators to build text transformation graphs) where a compiler automatically generates optimized LM invocation strategies and prompts from a program.\n", + "\n", + "We will use the [`dspy.OpenAI`](https://dspy-docs.vercel.app/api/language_model_clients/OpenAI) abstraction to make LLM calls to [GPT3.5 Turbo](https://platform.openai.com/docs/models/gpt-3-5-turbo)." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "system_prompt = \"\"\"\n", + "You are an expert in the field of causal reasoning. You are to analyze the a given question carefully and answer in `Yes` or `No`.\n", + "You should also provide a detailed explanation justifying your answer.\n", + "\"\"\"\n", + "\n", + "llm = dspy.OpenAI(model=\"gpt-3.5-turbo\", system_prompt=system_prompt)\n", + "dspy.settings.configure(lm=llm)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Writing the Causal Reasoning Signature\n", + "\n", + "A [signature](https://dspy-docs.vercel.app/docs/building-blocks/signatures) is a declarative specification of input/output behavior of a [DSPy module](https://dspy-docs.vercel.app/docs/building-blocks/modules) which are task-adaptive components—akin to neural network layers—that abstract any particular text transformation." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from pydantic import BaseModel, Field\n", + "\n", + "\n", + "class Input(BaseModel):\n", + " query: str = Field(description=\"The question to be answered\")\n", + "\n", + "\n", + "class Output(BaseModel):\n", + " answer: str = Field(description=\"The answer for the question\")\n", + " confidence: float = Field(\n", + " ge=0, le=1, description=\"The confidence score for the answer\"\n", + " )\n", + " explanation: str = Field(description=\"The explanation for the answer\")\n", + "\n", + "\n", + "class QuestionAnswerSignature(dspy.Signature):\n", + " input: Input = dspy.InputField()\n", + " output: Output = dspy.OutputField()\n", + "\n", + "\n", + "class CausalReasoningModule(dspy.Module):\n", + " def __init__(self):\n", + " self.prog = dspy.TypedPredictor(QuestionAnswerSignature)\n", + "\n", + " @weave.op()\n", + " def forward(self, question) -> dict:\n", + " return self.prog(input=Input(query=question)).output.dict()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Let's test our LLM workflow, i.e., the `CausalReasoningModule` on an example from the causal reasoning subset of Big-Bench Hard." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import rich\n", + "\n", + "baseline_module = CausalReasoningModule()\n", + "\n", + "prediction = baseline_module(dspy_train_examples[0][\"question\"])\n", + "rich.print(prediction)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "| ![](../static/img/dspy_prompt_optimiztion/dspy_module_trace.gif) |\n", + "|---|\n", + "| Here's how you can explore the traces of the `CausalReasoningModule` in the Weave UI |" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Evaluating our DSPy Program\n", + "\n", + "Now that we have a baseline prompting strategy, let's evaluate it on our validation set using [`weave.Evaluation`](../docs/guides/core-types/evaluations.md) on a simple metric that matches the predicted answer with the ground truth. Weave will take each example, pass it through your application and score the output on multiple custom scoring functions. By doing this, you'll have a view of the performance of your application, and a rich UI to drill into individual outputs and scores.\n", + "\n", + "First, we need to create a simple weave evaluation scoring function that tells whether the answer from the baseline module's output is the same as the ground truth answer or not. Scoring functions need to have a `model_output` keyword argument, but the other arguments are user defined and are taken from the dataset examples. It will only take the necessary keys by using a dictionary key based on the argument name." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "@weave.op()\n", + "def weave_evaluation_scorer(answer: str, model_output: Output) -> dict:\n", + " return {\"match\": int(answer.lower() == model_output[\"answer\"].lower())}" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Next, we can simply define the evaluation and run it." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "validation_dataset = weave.ref(\n", + " f\"bigbenchhard_{metadata.big_bench_hard_task}_val:v0\"\n", + ").get()\n", + "\n", + "evaluation = weave.Evaluation(\n", + " name=\"baseline_causal_reasoning_module\",\n", + " dataset=validation_dataset,\n", + " scorers=[weave_evaluation_scorer],\n", + ")\n", + "\n", + "await evaluation.evaluate(baseline_module.forward)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + ":::note\n", + "If you're running from a python script, you can use the following code to run the evaluation:\n", + "\n", + "```python\n", + "import asyncio\n", + "asyncio.run(evaluation.evaluate(baseline_module.forward))\n", + "```\n", + ":::\n", + "\n", + ":::warning\n", + "Running the evaluation causal reasoning dataset will cost approximately $0.24 in OpenAI credits.\n", + ":::" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Optimizing our DSPy Program\n", + "\n", + "Now, that we have a baseline DSPy program, let us try to improve its performance for causal reasoning using a [DSPy teleprompter](https://dspy-docs.vercel.app/docs/building-blocks/optimizers) that can tune the parameters of a DSPy program to maximize the specified metrics. In this tutorial, we use the [BootstrapFewShot](https://dspy-docs.vercel.app/api/category/optimizers) teleprompter." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from dspy.teleprompt import BootstrapFewShot\n", + "\n", + "\n", + "@weave.op()\n", + "def get_optimized_program(model: dspy.Module, metadata: Metadata) -> dspy.Module:\n", + " @weave.op()\n", + " def dspy_evaluation_metric(true, prediction, trace=None):\n", + " return prediction[\"answer\"].lower() == true.answer.lower()\n", + "\n", + " teleprompter = BootstrapFewShot(\n", + " metric=dspy_evaluation_metric,\n", + " max_bootstrapped_demos=metadata.max_bootstrapped_demos,\n", + " max_labeled_demos=metadata.max_labeled_demos,\n", + " )\n", + " return teleprompter.compile(model, trainset=dspy_train_examples)\n", + "\n", + "\n", + "optimized_module = get_optimized_program(baseline_module, metadata)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + ":::warning\n", + "Running the evaluation causal reasoning dataset will cost approximately $0.04 in OpenAI credits.\n", + ":::\n", + "\n", + "| ![](../static/img/dspy_prompt_optimiztion/dspy_compile.png) |\n", + "|---|\n", + "| You can explore the traces of the optimization process in the Weave UI. |\n", + "\n", + "Now that we have our optimized program (the optimized prompting strategy), let's evaluate it once again on our validation set and compare it with our baseline DSPy program." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "evaluation = weave.Evaluation(\n", + " name=\"optimized_causal_reasoning_module\",\n", + " dataset=validation_dataset,\n", + " scorers=[weave_evaluation_scorer],\n", + ")\n", + "\n", + "await evaluation.evaluate(optimized_module.forward)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "| ![](../static/img/dspy_prompt_optimiztion/eval_comparison.gif) |\n", + "|---|\n", + "| Comparing the evalution of the baseline program with the optimized one shows that the optimized program answers the causal reasoning questions with siginificantly more accuracy. |" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [] + } + ], + "metadata": { + "language_info": { + "name": "python" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/docs/static/img/dspy_prompt_optimiztion/datasets.gif b/docs/static/img/dspy_prompt_optimiztion/datasets.gif new file mode 100644 index 00000000000..239c7c74767 Binary files /dev/null and b/docs/static/img/dspy_prompt_optimiztion/datasets.gif differ diff --git a/docs/static/img/dspy_prompt_optimiztion/dspy_compile.png b/docs/static/img/dspy_prompt_optimiztion/dspy_compile.png new file mode 100644 index 00000000000..1afff4bf60d Binary files /dev/null and b/docs/static/img/dspy_prompt_optimiztion/dspy_compile.png differ diff --git a/docs/static/img/dspy_prompt_optimiztion/dspy_module_trace.gif b/docs/static/img/dspy_prompt_optimiztion/dspy_module_trace.gif new file mode 100644 index 00000000000..970da65a1dd Binary files /dev/null and b/docs/static/img/dspy_prompt_optimiztion/dspy_module_trace.gif differ diff --git a/docs/static/img/dspy_prompt_optimiztion/eval_comparison.gif b/docs/static/img/dspy_prompt_optimiztion/eval_comparison.gif new file mode 100644 index 00000000000..2faccf15427 Binary files /dev/null and b/docs/static/img/dspy_prompt_optimiztion/eval_comparison.gif differ diff --git a/docs/static/img/dspy_prompt_optimiztion/metadata.gif b/docs/static/img/dspy_prompt_optimiztion/metadata.gif new file mode 100644 index 00000000000..5a00a7375f2 Binary files /dev/null and b/docs/static/img/dspy_prompt_optimiztion/metadata.gif differ diff --git a/weave-js/src/common/components/elements/LegacyWBIcon.tsx b/weave-js/src/common/components/elements/LegacyWBIcon.tsx index b1fce5a4895..fa440a9ba03 100644 --- a/weave-js/src/common/components/elements/LegacyWBIcon.tsx +++ b/weave-js/src/common/components/elements/LegacyWBIcon.tsx @@ -26,6 +26,10 @@ export interface LegacyWBIconProps { style?: any; 'data-test'?: any; + + role?: string; + ariaHidden?: string; + ariaLabel?: string; } const LegacyWBIconComp = React.forwardRef( @@ -42,6 +46,10 @@ const LegacyWBIconComp = React.forwardRef( onMouseLeave, style, 'data-test': dataTest, + role, + title, + ariaHidden, + ariaLabel, }, ref ) => { @@ -59,6 +67,10 @@ const LegacyWBIconComp = React.forwardRef( onMouseLeave, style, 'data-test': dataTest, + role, + title, + 'aria-hidden': ariaHidden, + 'aria-label': ariaLabel, }; if (ref == null) { return ; diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse2/CellValue.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse2/CellValue.tsx index e37595dbd0e..55c52832283 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse2/CellValue.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse2/CellValue.tsx @@ -6,6 +6,8 @@ import {parseRef} from '../../../../react'; import {ValueViewNumber} from '../Browse3/pages/CallPage/ValueViewNumber'; import {ValueViewPrimitive} from '../Browse3/pages/CallPage/ValueViewPrimitive'; import {isRef} from '../Browse3/pages/common/util'; +import {isCustomWeaveTypePayload} from '../Browse3/typeViews/customWeaveType.types'; +import {CustomWeaveTypeDispatcher} from '../Browse3/typeViews/CustomWeaveTypeDispatcher'; import {CellValueBoolean} from './CellValueBoolean'; import {CellValueImage} from './CellValueImage'; import {CellValueString} from './CellValueString'; @@ -64,5 +66,8 @@ export const CellValue = ({value, isExpanded = false}: CellValueProps) => { ); } + if (isCustomWeaveTypePayload(value)) { + return ; + } return ; }; diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse2/browse2Util.ts b/weave-js/src/components/PagePanelComponents/Home/Browse2/browse2Util.ts index f8aff50ee95..9bfb3e3e1d4 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse2/browse2Util.ts +++ b/weave-js/src/components/PagePanelComponents/Home/Browse2/browse2Util.ts @@ -1,9 +1,31 @@ -export const flattenObject = ( +/** + * Flatten an object, but preserve any object that has a `_type` field. + * This is critical for handling "Weave Types" - payloads that should be + * treated as holistic objects, rather than flattened. + */ +export const flattenObjectPreservingWeaveTypes = (obj: { + [key: string]: any; +}) => { + return flattenObject(obj, '', {}, (key, value) => { + return ( + typeof value !== 'object' || + value == null || + value._type !== 'CustomWeaveType' + ); + }); +}; + +const flattenObject = ( obj: {[key: string]: any}, parentKey: string = '', - result: {[key: string]: any} = {} + result: {[key: string]: any} = {}, + shouldFlatten: (key: string, value: any) => boolean = () => true ) => { - if (typeof obj !== 'object' || obj === null) { + if ( + typeof obj !== 'object' || + obj === null || + !shouldFlatten(parentKey, obj) + ) { return obj; } const keys = Object.keys(obj); @@ -14,31 +36,14 @@ export const flattenObject = ( const newKey = parentKey ? `${parentKey}.${key}` : key; if (Array.isArray(obj[key])) { result[newKey] = obj[key]; - } else if (typeof obj[key] === 'object') { - flattenObject(obj[key], newKey, result); + } else if ( + typeof obj[key] === 'object' && + shouldFlatten(newKey, obj[key]) + ) { + flattenObject(obj[key], newKey, result, shouldFlatten); } else { result[newKey] = obj[key]; } }); return result; }; -export const unflattenObject = (obj: {[key: string]: any}) => { - const result: {[key: string]: any} = {}; - for (const key in obj) { - if (!obj.hasOwnProperty(key)) { - continue; - } - const keys = key.split('.'); - let current = result; - for (let i = 0; i < keys.length; i++) { - const k = keys[i]; - if (i === keys.length - 1) { - current[k] = obj[key]; - } else { - current[k] = current[k] || {}; - } - current = current[k]; - } - } - return result; -}; diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/NotFoundPanel.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/NotFoundPanel.tsx new file mode 100644 index 00000000000..54aeb4477aa --- /dev/null +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/NotFoundPanel.tsx @@ -0,0 +1,20 @@ +import {ErrorPanel} from '@wandb/weave/components/ErrorPanel'; +import React, {FC, useContext} from 'react'; + +import {Button} from '../../../Button'; +import {useClosePeek, WeaveflowPeekContext} from './context'; + +export const NotFoundPanel: FC<{title: string}> = ({title}) => { + const close = useClosePeek(); + const {isPeeking} = useContext(WeaveflowPeekContext); + return ( +
+
+ {isPeeking &&
+
+ +
+
+ ); +}; diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/filters/CellFilterWrapper.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/filters/CellFilterWrapper.tsx index 2547dd19f93..afbb66e8b39 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/filters/CellFilterWrapper.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/filters/CellFilterWrapper.tsx @@ -11,6 +11,7 @@ type CellFilterWrapperProps = { field: string; operation: string | null; value: any; + style?: React.CSSProperties; }; export const CellFilterWrapper = ({ @@ -19,6 +20,7 @@ export const CellFilterWrapper = ({ field, operation, value, + style, }: CellFilterWrapperProps) => { const onClickCapture = onAddFilter ? (e: React.MouseEvent) => { @@ -31,5 +33,9 @@ export const CellFilterWrapper = ({ } : undefined; - return
{children}
; + return ( +
+ {children} +
+ ); }; diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallPage/CallDetails.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallPage/CallDetails.tsx index 323e6ad4c71..15e9ca8e64b 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallPage/CallDetails.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallPage/CallDetails.tsx @@ -8,6 +8,7 @@ import styled from 'styled-components'; import {MOON_800} from '../../../../../../common/css/color.styles'; import {Button} from '../../../../../Button'; import {useWeaveflowRouteContext, WeaveflowPeekContext} from '../../context'; +import {CustomWeaveTypeProjectContext} from '../../typeViews/CustomWeaveTypeDispatcher'; import {CallsTable} from '../CallsPage/CallsTable'; import {KeyValueTable} from '../common/KeyValueTable'; import {CallLink, opNiceName} from '../common/Links'; @@ -117,7 +118,10 @@ export const CallDetails: FC<{ flex: '0 0 auto', p: 2, }}> - + + + ) : ( - + + + )} {multipleChildCallOpRefs.map(opVersionRef => { @@ -251,13 +258,15 @@ const getDisplayInputsAndOutput = (call: CallSchema) => { const span = call.rawSpan; const inputKeys = span.inputs._keys ?? - Object.keys(span.inputs).filter(k => !k.startsWith('_')); + Object.keys(span.inputs).filter(k => !k.startsWith('_') || k === '_type'); const inputs = _.fromPairs(inputKeys.map(k => [k, span.inputs[k]])); const callOutput = span.output ?? {}; const outputKeys = callOutput._keys ?? - Object.keys(callOutput).filter(k => k === '_result' || !k.startsWith('_')); + Object.keys(callOutput).filter( + k => k === '_result' || !k.startsWith('_') || k === '_type' + ); const output = _.fromPairs(outputKeys.map(k => [k, callOutput[k]])); return {inputs, output}; }; diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallPage/CallPage.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallPage/CallPage.tsx index 517bb1b4e24..78af2e03159 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallPage/CallPage.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallPage/CallPage.tsx @@ -1,5 +1,4 @@ import Box from '@mui/material/Box'; -import {ErrorPanel} from '@wandb/weave/components/ErrorPanel'; import {Loading} from '@wandb/weave/components/Loading'; import {useViewTraceEvent} from '@wandb/weave/integrations/analytics/useViewEvents'; import React, {FC, useCallback} from 'react'; @@ -9,12 +8,9 @@ import {makeRefCall} from '../../../../../../util/refs'; import {Button} from '../../../../../Button'; import {Tailwind} from '../../../../../Tailwind'; import {Browse2OpDefCode} from '../../../Browse2/Browse2OpDefCode'; -import { - TRACETREE_PARAM, - useClosePeek, - useWeaveflowCurrentRouteContext, -} from '../../context'; +import {TRACETREE_PARAM, useWeaveflowCurrentRouteContext} from '../../context'; import {FeedbackGrid} from '../../feedback/FeedbackGrid'; +import {NotFoundPanel} from '../../NotFoundPanel'; import {isEvaluateOp} from '../common/heuristics'; import {CenteredAnimatedLoader} from '../common/Loader'; import {SimplePageLayoutWithHeader} from '../common/SimplePageLayout'; @@ -26,7 +22,6 @@ import {CallDetails} from './CallDetails'; import {CallOverview} from './CallOverview'; import {CallSummary} from './CallSummary'; import {CallTraceView, useCallFlattenedTraceTree} from './CallTraceView'; - export const CallPage: FC<{ entity: string; project: string; @@ -34,7 +29,6 @@ export const CallPage: FC<{ path?: string; }> = props => { const {useCall} = useWFHooks(); - const close = useClosePeek(); const call = useCall({ entity: props.entity, @@ -45,16 +39,7 @@ export const CallPage: FC<{ if (call.loading) { return ; } else if (call.result === null) { - return ( -
-
-
-
- -
-
- ); + return ; } return ; }; diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallPage/CallSummary.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallPage/CallSummary.tsx index 793191fbd80..daa866cb4fe 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallPage/CallSummary.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallPage/CallSummary.tsx @@ -32,7 +32,7 @@ export const CallSummary: React.FC<{ ); return ( -
+
parseRef(props.tableRefUri) as WeaveObjectRef, + [props.tableRefUri] + ); + // Determines if the table itself is truncated const isTruncated = useMemo(() => { return (fetchQuery.result ?? []).length > MAX_ROWS; @@ -96,16 +106,19 @@ export const WeaveCHTable: FC<{ ); return ( - + + + ); }; @@ -133,7 +146,7 @@ export const DataTableView: FC<{ if (val == null) { return {}; } else if (typeof val === 'object' && !Array.isArray(val)) { - return flattenObject(val); + return flattenObjectPreservingWeaveTypes(val); } return {'': val}; }); diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallPage/ObjectViewer.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallPage/ObjectViewer.tsx index 8868f8a1d7d..24e10a3dd09 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallPage/ObjectViewer.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallPage/ObjectViewer.tsx @@ -21,6 +21,7 @@ import {LoadingDots} from '../../../../../LoadingDots'; import {Browse2OpDefCode} from '../../../Browse2/Browse2OpDefCode'; import {parseRefMaybe} from '../../../Browse2/SmallRef'; import {StyledDataGrid} from '../../StyledDataGrid'; +import {isCustomWeaveTypePayload} from '../../typeViews/customWeaveType.types'; import {isRef} from '../common/util'; import { LIST_INDEX_EDGE_NAME, @@ -163,6 +164,37 @@ export const ObjectViewer = ({ } > = []; traverse(resolvedData, context => { + // Ops should be migrated to the generic CustomWeaveType pattern, but for + // now they are custom handled. + const isOpPayload = context.value?.weave_type?.type === 'Op'; + + if (isCustomWeaveTypePayload(context.value) && !isOpPayload) { + /** + * This block adds an "empty" key that is used to render the custom + * weave type. In the event that a custom type has both properties AND + * custom views, then we might need to extend / modify this part. + */ + const refBackingData = context.value?._ref; + let depth = context.depth; + let path = context.path; + if (refBackingData) { + contexts.push({ + ...context, + isExpandableRef: true, + }); + depth += 1; + path = context.path.plus(''); + } + contexts.push({ + depth, + isLeaf: true, + path, + value: context.value, + valueType: context.valueType, + }); + return 'skip'; + } + if (context.depth !== 0) { const contextTail = context.path.tail(); const isNullDescription = @@ -207,7 +239,8 @@ export const ObjectViewer = ({ if (USE_TABLE_FOR_ARRAYS && context.valueType === 'array') { return 'skip'; } - if (context.value?._ref && context.value?.weave_type?.type === 'Op') { + if (context.value?._ref && isOpPayload) { + // This should be moved to the CustomWeaveType pattern. contexts.push({ depth: context.depth + 1, isLeaf: true, @@ -377,11 +410,15 @@ export const ObjectViewer = ({ isRef(params.model.value) && (parseRefMaybe(params.model.value) as any).weaveKind === 'table'; const {isCode} = params.model; + const isCustomWeaveType = isCustomWeaveTypePayload( + params.model.value + ); if ( isNonRefString || (isArray && USE_TABLE_FOR_ARRAYS) || isTableRef || - isCode + isCode || + isCustomWeaveType ) { return 'auto'; } diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallPage/ObjectViewerSection.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallPage/ObjectViewerSection.tsx index d5176339465..6a72ee6e6d1 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallPage/ObjectViewerSection.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallPage/ObjectViewerSection.tsx @@ -14,6 +14,8 @@ import {isWeaveObjectRef, parseRef} from '../../../../../../react'; import {Alert} from '../../../../../Alert'; import {Button} from '../../../../../Button'; import {CodeEditor} from '../../../../../CodeEditor'; +import {isCustomWeaveTypePayload} from '../../typeViews/customWeaveType.types'; +import {CustomWeaveTypeDispatcher} from '../../typeViews/CustomWeaveTypeDispatcher'; import {isRef} from '../common/util'; import {OBJECT_ATTR_EDGE_NAME} from '../wfReactInterface/constants'; import {WeaveCHTable, WeaveCHTableSourceRefContext} from './DataTableView'; @@ -119,7 +121,7 @@ const ObjectViewerSectionNonEmpty = ({ ); } return null; - }, [apiRef, mode, data, expandedIds]); + }, [mode, apiRef, data, expandedIds]); const setTreeExpanded = useCallback( (setIsExpanded: boolean) => { @@ -215,9 +217,20 @@ export const ObjectViewerSection = ({ noHide, isExpanded, }: ObjectViewerSectionProps) => { - const numKeys = Object.keys(data).length; const currentRef = useContext(WeaveCHTableSourceRefContext); + if (isCustomWeaveTypePayload(data)) { + return ( + <> + + {title} + + + + ); + } + + const numKeys = Object.keys(data).length; if (numKeys === 0) { return ( <> diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallPage/ValueView.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallPage/ValueView.tsx index 1536ff6e9a1..581aba7a729 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallPage/ValueView.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallPage/ValueView.tsx @@ -1,7 +1,9 @@ import React, {useMemo} from 'react'; -import {parseRef} from '../../../../../../react'; +import {isWeaveObjectRef, parseRef} from '../../../../../../react'; import {parseRefMaybe, SmallRef} from '../../../Browse2/SmallRef'; +import {isCustomWeaveTypePayload} from '../../typeViews/customWeaveType.types'; +import {CustomWeaveTypeDispatcher} from '../../typeViews/CustomWeaveTypeDispatcher'; import {isRef} from '../common/util'; import { DataTableView, @@ -77,5 +79,36 @@ export const ValueView = ({data, isExpanded}: ValueViewProps) => { return
{JSON.stringify(data.value)}
; } + if (data.valueType === 'object') { + if (isCustomWeaveTypePayload(data.value)) { + // This is a little ugly, but essentially if the data is coming from an + // expanded ref, then we want to use that ref to get the entity and project. + // Else we just use the current entity and project. + let entityForWeaveType: string | undefined; + let projectForWeaveType: string | undefined; + + if (valueIsExpandedRef(data)) { + const parsedRef = parseRef((data.value as any)._ref); + if (isWeaveObjectRef(parsedRef)) { + entityForWeaveType = parsedRef.entityName; + projectForWeaveType = parsedRef.projectName; + } + } + + // If we have have a custom view for this weave type, use it. + return ( + + ); + } + } + return
{data.value.toString()}
; }; + +const valueIsExpandedRef = (data: ValueData) => { + return data.value != null && (data.value as any)._ref != null; +}; diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallsPage/callsTableColumns.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallsPage/callsTableColumns.tsx index 7d8ba65cb86..c0007df46dd 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallsPage/callsTableColumns.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallsPage/callsTableColumns.tsx @@ -284,6 +284,10 @@ function buildCallsTableColumns( const {cols: newCols, groupingModel} = buildDynamicColumns( filteredDynamicColumnNames, + row => { + const [rowEntity, rowProject] = row.project_id.split('/'); + return {entity: rowEntity, project: rowProject}; + }, (row, key) => (row as any)[key], key => expandedRefCols.has(key), key => columnsWithRefs.has(key), diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/sections/ExampleCompareSection/exampleCompareSectionUtil.ts b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/sections/ExampleCompareSection/exampleCompareSectionUtil.ts index 2ab5664005e..59570f4c14e 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/sections/ExampleCompareSection/exampleCompareSectionUtil.ts +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/sections/ExampleCompareSection/exampleCompareSectionUtil.ts @@ -1,7 +1,7 @@ import _ from 'lodash'; import {useMemo} from 'react'; -import {flattenObject} from '../../../../../Browse2/browse2Util'; +import {flattenObjectPreservingWeaveTypes} from '../../../../../Browse2/browse2Util'; import { buildCompositeMetricsMap, CompositeScoreMetrics, @@ -138,8 +138,10 @@ export const useFilteredAggregateRows = (state: EvaluationComparisonState) => { evaluationCallId: predictAndScoreRes.evaluationCallId, inputDigest: datasetRow.digest, inputRef: predictAndScoreRes.exampleRef, - input: flattenObject({input: datasetRow.val}), - output: flattenObject({output}), + input: flattenObjectPreservingWeaveTypes({ + input: datasetRow.val, + }), + output: flattenObjectPreservingWeaveTypes({output}), scores: Object.fromEntries( [...Object.entries(state.data.scoreMetrics)].map( ([scoreKey, scoreVal]) => { diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ObjectVersionPage.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ObjectVersionPage.tsx index 92e51f0e57d..f8c85adeae7 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ObjectVersionPage.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ObjectVersionPage.tsx @@ -4,6 +4,8 @@ import React, {useMemo} from 'react'; import {maybePluralizeWord} from '../../../../../core/util/string'; import {LoadingDots} from '../../../../LoadingDots'; +import {NotFoundPanel} from '../NotFoundPanel'; +import {CustomWeaveTypeProjectContext} from '../typeViews/CustomWeaveTypeDispatcher'; import {WeaveCHTableSourceRefContext} from './CallPage/DataTableView'; import {ObjectViewerSection} from './CallPage/ObjectViewerSection'; import {WFHighLevelCallFilter} from './CallsPage/callsTableFilter'; @@ -58,7 +60,7 @@ export const ObjectVersionPage: React.FC<{ if (objectVersion.loading) { return ; } else if (objectVersion.result == null) { - return
Object not found
; + return ; } return ( @@ -207,7 +209,7 @@ const ObjectVersionPageInner: React.FC<{ { label: 'Values', content: ( - + ) : ( - + + + )} diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ObjectVersionsPage.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ObjectVersionsPage.tsx index aa6a9161a25..93af4fcacc5 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ObjectVersionsPage.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ObjectVersionsPage.tsx @@ -239,9 +239,16 @@ const ObjectVersionsTable: React.FC<{ }); }); - const {cols: newCols, groupingModel} = - buildDynamicColumns(dynamicFields, (row, key) => { - const obj: ObjectVersionSchema = (row as any).obj; + const {cols: newCols, groupingModel} = buildDynamicColumns<{ + obj: ObjectVersionSchema; + }>( + dynamicFields, + row => ({ + entity: row.obj.entity, + project: row.obj.project, + }), + (row, key) => { + const obj: ObjectVersionSchema = row.obj; const res = obj.val?.[key]; if (isTableRef(res)) { // This whole block is a hack to make the table ref clickable. This @@ -258,7 +265,8 @@ const ObjectVersionsTable: React.FC<{ return makeRefExpandedPayload(targetRefUri, res); } return res; - }); + } + ); cols.push(...newCols); groups = groupingModel; } diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/OpVersionPage.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/OpVersionPage.tsx index 030b8980675..8b76964845a 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/OpVersionPage.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/OpVersionPage.tsx @@ -1,6 +1,7 @@ import React, {useMemo} from 'react'; import {LoadingDots} from '../../../../LoadingDots'; +import {NotFoundPanel} from '../NotFoundPanel'; import {OpCodeViewer} from '../OpCodeViewer'; import { CallsLink, @@ -35,7 +36,7 @@ export const OpVersionPage: React.FC<{ if (opVersion.loading) { return ; } else if (opVersion.result == null) { - return
Op version not found
; + return ; } return ; }; diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/common/tabularListViews/columnBuilder.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/common/tabularListViews/columnBuilder.tsx index 9ff6ae4a191..1e81bbd9643 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/common/tabularListViews/columnBuilder.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/common/tabularListViews/columnBuilder.tsx @@ -9,13 +9,15 @@ import React from 'react'; import {isWeaveObjectRef, parseRef} from '../../../../../../../react'; import {ErrorBoundary} from '../../../../../../ErrorBoundary'; -import {flattenObject} from '../../../../Browse2/browse2Util'; +import {flattenObjectPreservingWeaveTypes} from '../../../../Browse2/browse2Util'; import {CellValue} from '../../../../Browse2/CellValue'; import {CollapseHeader} from '../../../../Browse2/CollapseGroupHeader'; import {ExpandHeader} from '../../../../Browse2/ExpandHeader'; import {NotApplicable} from '../../../../Browse2/NotApplicable'; import {SmallRef} from '../../../../Browse2/SmallRef'; import {CellFilterWrapper} from '../../../filters/CellFilterWrapper'; +import {isCustomWeaveTypePayload} from '../../../typeViews/customWeaveType.types'; +import {CustomWeaveTypeProjectContext} from '../../../typeViews/CustomWeaveTypeDispatcher'; import { OBJECT_ATTR_EDGE_NAME, WEAVE_PRIVATE_PREFIX, @@ -60,7 +62,15 @@ export function prepareFlattenedDataForTable( ): Array { return data.map(r => { // First, flatten the inner object - let flattened = flattenObject(r ?? {}); + let flattened = flattenObjectPreservingWeaveTypes(r ?? {}); + + // In the rare case that we have custom objects in the root (this only occurs if you directly) + // publish a custom object. Then we want to instead nest it under an empty key! + if (isCustomWeaveTypePayload(flattened)) { + flattened = { + ' ': flattened, + }; + } flattened = replaceTableRefsInFlattenedData(flattened); @@ -182,6 +192,7 @@ const isExpandedRefWithValueAsTableRef = ( export const buildDynamicColumns = ( filteredDynamicColumnNames: string[], + entityProjectFromRow: (row: T) => {entity: string; project: string}, valueForKey: (row: T, key: string) => any, columnIsExpanded?: (col: string) => boolean, columnCanBeExpanded?: (col: string) => boolean, @@ -269,6 +280,7 @@ export const buildDynamicColumns = ( return val; }, renderCell: cellParams => { + const {entity, project} = entityProjectFromRow(cellParams.row); const val = valueForKey(cellParams.row, key); if (val === undefined) { return ( @@ -287,7 +299,12 @@ export const buildDynamicColumns = ( onAddFilter={onAddFilter} field={key} operation={null} - value={val}> + value={val} + style={{ + width: '100%', + height: '100%', + alignContent: 'center', + }}> {/* In the future, we may want to move this isExpandedRefWithValueAsTableRef condition into `CellValue`. However, at the moment, `ExpandedRefWithValueAsTableRef` is a Table-specific data structure and we might not want to leak that into the @@ -295,7 +312,10 @@ export const buildDynamicColumns = ( {isExpandedRefWithValueAsTableRef(val) ? ( ) : ( - + + + )} diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/wfReactInterface/tsDataModelHooks.ts b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/wfReactInterface/tsDataModelHooks.ts index 93f239b2d23..b659a35b845 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/wfReactInterface/tsDataModelHooks.ts +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/wfReactInterface/tsDataModelHooks.ts @@ -669,6 +669,13 @@ const useOpVersion = ( }; } + if (opVersionRes.obj == null) { + return { + loading: false, + result: null, + }; + } + const returnedResult = convertTraceServerObjectVersionToOpSchema( opVersionRes.obj ); @@ -812,6 +819,13 @@ const useObjectVersion = ( }; } + if (objectVersionRes.obj == null) { + return { + loading: false, + result: null, + }; + } + const returnedResult: ObjectVersionSchema = convertTraceServerObjectVersionToSchema(objectVersionRes.obj); diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/typeViews/CustomWeaveTypeDispatcher.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/typeViews/CustomWeaveTypeDispatcher.tsx new file mode 100644 index 00000000000..cc8389fecb1 --- /dev/null +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/typeViews/CustomWeaveTypeDispatcher.tsx @@ -0,0 +1,80 @@ +import React from 'react'; + +import {CustomWeaveTypePayload} from './customWeaveType.types'; +import {PILImageImage} from './PIL.Image.Image/PILImageImage'; + +type CustomWeaveTypeDispatcherProps = { + data: CustomWeaveTypePayload; + // Entity and Project can be optionally provided as props, but if they are not + // provided, they must be provided in context. Failure to provide them will + // result in a console warning and a fallback to a default component. + // + // This pattern is used because in many cases we are rendering data from + // hierarchical data structures, and we want to avoid passing entity and project + // down through the tree. + entity?: string; + project?: string; +}; + +const customWeaveTypeRegistry: { + [typeId: string]: { + component: React.FC<{ + entity: string; + project: string; + data: any; // I wish this could be typed more specifically + }>; + }; +} = { + 'PIL.Image.Image': { + component: PILImageImage, + }, +}; + +/** + * This context is used to provide the entity and project to the + * CustomWeaveTypeDispatcher. Importantly, what this does is allows the + * developer to inject an entity/project context around some component tree, and + * then any CustomWeaveTypeDispatchers within that tree will be assumed to be + * within that entity/project context. This is far cleaner than passing + * entity/project down through the tree. We just have to remember in the future + * case when we support multiple entities/projects in the same tree, we will + * need to update this context if you end up traversing into a different + * entity/project. This should already be accounted for in all the current + * use-cases. + */ +export const CustomWeaveTypeProjectContext = React.createContext<{ + entity: string; + project: string; +} | null>(null); + +/** + * This is the primary entry-point for dispatching custom weave types. Currently + * we just have 1, but as we add more, we might want to add a more robust + * "registry" + */ +export const CustomWeaveTypeDispatcher: React.FC< + CustomWeaveTypeDispatcherProps +> = ({data, entity, project}) => { + const projectContext = React.useContext(CustomWeaveTypeProjectContext); + const typeId = data.weave_type.type; + const comp = customWeaveTypeRegistry[typeId]?.component; + const defaultReturn = Custom Weave Type: {data.weave_type.type}; + + if (comp) { + const applicableEntity = entity || projectContext?.entity; + const applicableProject = project || projectContext?.project; + if (applicableEntity == null || applicableProject == null) { + console.warn( + 'CustomWeaveTypeDispatch: entity and project must be provided in context or as props' + ); + return defaultReturn; + } + return React.createElement(comp, { + entity: applicableEntity, + project: applicableProject, + data, + }); + } + + return defaultReturn; +}; diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/typeViews/PIL.Image.Image/PILImageImage.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/typeViews/PIL.Image.Image/PILImageImage.tsx new file mode 100644 index 00000000000..fb07477497f --- /dev/null +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/typeViews/PIL.Image.Image/PILImageImage.tsx @@ -0,0 +1,53 @@ +import React from 'react'; + +import {LoadingDots} from '../../../../../LoadingDots'; +import {useWFHooks} from '../../pages/wfReactInterface/context'; +import {CustomWeaveTypePayload} from '../customWeaveType.types'; + +type PILImageImageTypePayload = CustomWeaveTypePayload< + 'PIL.Image.Image', + {'image.png': string} +>; + +export const isPILImageImageType = ( + data: CustomWeaveTypePayload +): data is PILImageImageTypePayload => { + return data.weave_type.type === 'PIL.Image.Image'; +}; + +export const PILImageImage: React.FC<{ + entity: string; + project: string; + data: PILImageImageTypePayload; +}> = props => { + const {useFileContent} = useWFHooks(); + const imageBinary = useFileContent( + props.entity, + props.project, + props.data.files['image.png'] + ); + + if (imageBinary.loading) { + return ; + } else if (imageBinary.result == null) { + return ; + } + + const arrayBuffer = imageBinary.result as any as ArrayBuffer; + const blob = new Blob([arrayBuffer], {type: 'image/png'}); + const url = URL.createObjectURL(blob); + + // TODO: It would be nice to have a more general image render - similar to the + // ValueViewImage that does things like light box, general scaling, + // downloading, etc.. + return ( + Custom + ); +}; diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/typeViews/customWeaveType.types.ts b/weave-js/src/components/PagePanelComponents/Home/Browse3/typeViews/customWeaveType.types.ts new file mode 100644 index 00000000000..4677e58c407 --- /dev/null +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/typeViews/customWeaveType.types.ts @@ -0,0 +1,48 @@ +export type CustomWeaveTypePayload< + T extends string = string, + FP extends {[filename: string]: string} = {[filename: string]: string} +> = { + _type: 'CustomWeaveType'; + weave_type: { + type: T; + }; + files: FP; + load_op?: string | CustomWeaveTypePayload<'Op', {'obj.py': string}>; +} & {[extra: string]: any}; + +export const isCustomWeaveTypePayload = ( + data: any +): data is CustomWeaveTypePayload => { + if (typeof data !== 'object' || data == null) { + return false; + } + if (data._type !== 'CustomWeaveType') { + return false; + } + if ( + typeof data.weave_type !== 'object' || + data.weave_type == null || + typeof data.weave_type.type !== 'string' + ) { + return false; + } + if (typeof data.files !== 'object' || data.files == null) { + return false; + } + if (data.weave_type.type === 'Op') { + if (data.load_op != null) { + return false; + } + } else { + if (data.load_op == null) { + return false; + } + if ( + typeof data.load_op !== 'string' && + !isCustomWeaveTypePayload(data.load_op) + ) { + return false; + } + } + return true; +}; diff --git a/weave/flow/model.py b/weave/flow/model.py index dc211902ba9..0cd23eabd54 100644 --- a/weave/flow/model.py +++ b/weave/flow/model.py @@ -2,6 +2,11 @@ from weave.flow.obj import Object +INFER_METHOD_NAMES = {"predict", "infer", "forward", "invoke"} + + +class MissingInferenceMethodError(Exception): ... + class Model(Object): """ @@ -32,20 +37,18 @@ def predict(self, input_data: str) -> dict: # TODO: should be infer: Callable def get_infer_method(self) -> Callable: - for infer_method_names in ("predict", "infer", "forward"): - infer_method = getattr(self, infer_method_names, None) - if infer_method: + for name in INFER_METHOD_NAMES: + if infer_method := getattr(self, name, None): return infer_method - raise ValueError( - f"Model {self} does not have a predict, infer, or forward method." + raise MissingInferenceMethodError( + f"Missing a method with name in ({INFER_METHOD_NAMES})" ) def get_infer_method(model: Model) -> Callable: - for infer_method_names in ("predict", "infer", "forward"): - infer_method = getattr(model, infer_method_names, None) - if infer_method: + for name in INFER_METHOD_NAMES: + if (infer_method := getattr(model, name, None)) is not None: return infer_method - raise ValueError( - f"Model {model} does not have a predict, infer, or forward method." + raise MissingInferenceMethodError( + f"Missing a method with name in ({INFER_METHOD_NAMES})" ) diff --git a/weave/frontend/index.html b/weave/frontend/index.html index 74808c202a6..c96a45a9ae8 100644 --- a/weave/frontend/index.html +++ b/weave/frontend/index.html @@ -91,7 +91,7 @@ - + diff --git a/weave/frontend/sha1.txt b/weave/frontend/sha1.txt index 01a425480b3..7a517474529 100644 --- a/weave/frontend/sha1.txt +++ b/weave/frontend/sha1.txt @@ -1 +1 @@ -cd3b2e94bf9dc8702f53efb3844074379cd3b951 +18eebed493dc14f0fbaa3ab62505c6bdfd42ad6f diff --git a/weave/init_message.py b/weave/init_message.py index ffb6da21cf6..34f107139dd 100644 --- a/weave/init_message.py +++ b/weave/init_message.py @@ -44,10 +44,17 @@ def _print_version_check() -> None: if use_message: print(use_message) - orig_module = wandb._wandb_module - wandb._wandb_module = "weave" - weave_messages = wandb.sdk.internal.update.check_available(weave.__version__) - wandb._wandb_module = orig_module + weave_messages = None + if hasattr(weave, "_wandb_module"): + try: + orig_module = wandb._wandb_module # type: ignore + wandb._wandb_module = "weave" # type: ignore + weave_messages = wandb.sdk.internal.update.check_available( + weave.__version__ + ) + wandb._wandb_module = orig_module # type: ignore + except Exception: + weave_messages = None if weave_messages: use_message = ( diff --git a/weave/legacy/wandb_interface/project_creator.py b/weave/legacy/wandb_interface/project_creator.py index 617fe09a167..c12fe52e9a5 100644 --- a/weave/legacy/wandb_interface/project_creator.py +++ b/weave/legacy/wandb_interface/project_creator.py @@ -39,12 +39,19 @@ def wandb_logging_disabled() -> typing.Iterator[None]: wandb.termerror = original_termerror -def ensure_project_exists(entity_name: str, project_name: str) -> None: +def ensure_project_exists(entity_name: str, project_name: str) -> typing.Dict[str, str]: with wandb_logging_disabled(): return _ensure_project_exists(entity_name, project_name) -def _ensure_project_exists(entity_name: str, project_name: str) -> None: +def _ensure_project_exists( + entity_name: str, project_name: str +) -> typing.Dict[str, str]: + """ + Ensures that a W&B project exists by trying to access it, returns the project_name, + which is not guaranteed to be the same if the provided project_name contains invalid + characters. Adheres to trace_server_interface.EnsureProjectExistsRes + """ wandb_logging_disabled() api = InternalApi({"entity": entity_name, "project": project_name}) # Since `UpsertProject` will fail if the user does not have permission to create a project @@ -72,4 +79,4 @@ def _ensure_project_exists(entity_name: str, project_name: str) -> None: raise UnableToCreateProject( f"Failed to create project {entity_name}/{project_name}" ) - return + return {"project_name": project["name"]} diff --git a/weave/tests/test_client_trace.py b/weave/tests/test_client_trace.py index ad0ddba7f87..c319c66cc12 100644 --- a/weave/tests/test_client_trace.py +++ b/weave/tests/test_client_trace.py @@ -2144,6 +2144,44 @@ def calculate(a: int, b: int) -> int: assert i == len(calls.calls) +def test_call_query_stream_columns(client): + @weave.op + def calculate(a: int, b: int) -> int: + return {"result": {"a + b": a + b}, "not result": 123} + + for i in range(2): + calculate(i, i * i) + + calls = client.server.calls_query_stream( + tsi.CallsQueryReq( + project_id=client._project_id(), + columns=["id", "inputs"], + ) + ) + calls = list(calls) + assert len(calls) == 2 + assert len(calls[0].inputs) == 2 + + # NO output returned because not required and not requested + assert calls[0].output is None + assert calls[0].ended_at is None + assert calls[0].attributes == {} + assert calls[0].inputs == {"a": 0, "b": 0} + + # now explicitly get output + calls = client.server.calls_query_stream( + tsi.CallsQueryReq( + project_id=client._project_id(), + columns=["id", "inputs", "output.result"], + ) + ) + calls = list(calls) + assert len(calls) == 2 + assert calls[0].output["result"]["a + b"] == 0 + assert calls[0].attributes == {} + assert calls[0].inputs == {"a": 0, "b": 0} + + @pytest.mark.skip("Not implemented: filter / sort through refs") def test_sort_and_filter_through_refs(client): @weave.op() diff --git a/weave/tests/test_op.py b/weave/tests/test_op.py index a0c2d3240a2..bfa2271e8e5 100644 --- a/weave/tests/test_op.py +++ b/weave/tests/test_op.py @@ -249,3 +249,28 @@ def my_op(self, a: int) -> str: # type: ignore[empty-body] "a": types.Int(), } assert SomeWeaveObj.my_op.concrete_output_type == types.String() + + +def test_op_internal_tracing_enabled(client): + # This test verifies the behavior of `_tracing_enabled` which + # is not a user-facing API and is used internally to toggle + # tracing on and off. + @weave.op + def my_op(): + return "hello" + + my_op() # <-- this call will be traced + + assert len(list(my_op.calls())) == 1 + + my_op._tracing_enabled = False + + my_op() # <-- this call will not be traced + + assert len(list(my_op.calls())) == 1 + + my_op._tracing_enabled = True + + my_op() # <-- this call will be traced + + assert len(list(my_op.calls())) == 2 diff --git a/weave/tests/test_weave_client.py b/weave/tests/test_weave_client.py index 6fc102ce400..7a065f243d2 100644 --- a/weave/tests/test_weave_client.py +++ b/weave/tests/test_weave_client.py @@ -1324,3 +1324,13 @@ def test_summary_tokens_cost_sqlite(client): assert noCostCallSummary is None assert withCostCallSummary is None + + +def test_ref_in_dict(client): + ref = client._save_object({"a": 5}, "d1") + + # Put a ref directly in a dict. + ref2 = client._save_object({"b": ref}, "d2") + + obj = weave.ref(ref2.uri()).get() + assert obj["b"] == {"a": 5} diff --git a/weave/trace/custom_objs.py b/weave/trace/custom_objs.py index 7f8b7215c32..6a8c0022d3e 100644 --- a/weave/trace/custom_objs.py +++ b/weave/trace/custom_objs.py @@ -156,6 +156,9 @@ def decode_custom_obj( raise ValueError(f"No serializer found for {weave_type}") load_instance_op = serializer.load + # Disables tracing so that calls to loading data itself don't get traced + load_instance_op._tracing_enabled = False # type: ignore + art = MemTraceFilesArtifact( encoded_path_contents, metadata={}, diff --git a/weave/trace/op.py b/weave/trace/op.py index b784c2f0a92..d99ac6200b4 100644 --- a/weave/trace/op.py +++ b/weave/trace/op.py @@ -125,6 +125,15 @@ class Op(Protocol): __call__: Callable[..., Any] __self__: Any + # `_tracing_enabled` is a runtime-only flag that can be used to disable + # call tracing for an op. It is not persisted as a property of the op, but is + # respected by the current execution context. It is an underscore property + # because it is not intended to be used by users directly, but rather assists + # with internal Weave behavior. If we find a need to expose this to users, we + # should consider a more user-friendly API (perhaps a setter/getter) & whether + # it disables child ops as well. + _tracing_enabled: bool + def _set_on_output_handler(func: Op, on_output: OnOutputHandlerType) -> None: if func._on_output_handler is not None: @@ -337,6 +346,8 @@ async def wrapper(*args: Any, **kwargs: Any) -> Any: return await func(*args, **kwargs) if weave_client_context.get_weave_client() is None: return await func(*args, **kwargs) + if not wrapper._tracing_enabled: # type: ignore + return await func(*args, **kwargs) call = _create_call(wrapper, *args, **kwargs) # type: ignore res, _ = await _execute_call(wrapper, call, *args, **kwargs) # type: ignore return res @@ -348,6 +359,8 @@ def wrapper(*args: Any, **kwargs: Any) -> Any: return func(*args, **kwargs) if weave_client_context.get_weave_client() is None: return func(*args, **kwargs) + if not wrapper._tracing_enabled: # type: ignore + return func(*args, **kwargs) call = _create_call(wrapper, *args, **kwargs) # type: ignore res, _ = _execute_call(wrapper, call, *args, **kwargs) # type: ignore return res @@ -375,6 +388,8 @@ def wrapper(*args: Any, **kwargs: Any) -> Any: wrapper._set_on_output_handler = partial(_set_on_output_handler, wrapper) # type: ignore wrapper._on_output_handler = None # type: ignore + wrapper._tracing_enabled = True # type: ignore + return cast(Op, wrapper) return create_wrapper(func) diff --git a/weave/trace/serializer.py b/weave/trace/serializer.py index 7d5bf90fc93..334811111d8 100644 --- a/weave/trace/serializer.py +++ b/weave/trace/serializer.py @@ -52,7 +52,7 @@ def id(self) -> str: # "Op" in the database. if ser_id.endswith(".Op"): return "Op" - return self.target_class.__name__ + return ser_id SERIALIZERS = [] diff --git a/weave/trace/vals.py b/weave/trace/vals.py index 5deb9c4bf4e..8933dcc201b 100644 --- a/weave/trace/vals.py +++ b/weave/trace/vals.py @@ -292,7 +292,13 @@ def _remote_iter(self) -> Generator[dict, None, None]: for item in response.rows: new_ref = self.ref.with_item(item.digest) if self.ref else None - yield make_trace_obj(item.val, new_ref, self.server, self.root) + res = from_json( + item.val, + self.table_ref.entity + "/" + self.table_ref.project, + self.server, + ) + res = make_trace_obj(res, new_ref, self.server, self.root) + yield res if len(response.rows) < page_size: break diff --git a/weave/trace_server/clickhouse_schema.py b/weave/trace_server/clickhouse_schema.py index 1cead54c4a8..f41ba84cca3 100644 --- a/weave/trace_server/clickhouse_schema.py +++ b/weave/trace_server/clickhouse_schema.py @@ -105,8 +105,11 @@ class SelectableCHCallSchema(BaseModel): ended_at: typing.Optional[datetime.datetime] = None exception: typing.Optional[str] = None - attributes_dump: str - inputs_dump: str + # attributes and inputs are required on call schema, but can be + # optionally selected when querying + attributes_dump: typing.Optional[str] = None + inputs_dump: typing.Optional[str] = None + output_dump: typing.Optional[str] = None summary_dump: typing.Optional[str] = None diff --git a/weave/trace_server/clickhouse_trace_server_batched.py b/weave/trace_server/clickhouse_trace_server_batched.py index 9cab74f6895..57745ddeea6 100644 --- a/weave/trace_server/clickhouse_trace_server_batched.py +++ b/weave/trace_server/clickhouse_trace_server_batched.py @@ -106,10 +106,7 @@ class NotFoundError(Exception): all_call_select_columns = list(SelectableCHCallSchema.model_fields.keys()) all_call_json_columns = ("inputs", "output", "attributes", "summary") - - -# Let's just make everything required for now ... can optimize when we implement column selection -required_call_columns = list(set(all_call_select_columns) - set([])) +required_call_columns = ["id", "project_id", "trace_id", "op_name", "started_at"] # Columns in the calls_merged table with special aggregation functions: @@ -254,17 +251,20 @@ def calls_query_stream( cq = CallsQuery( project_id=req.project_id, include_costs=req.include_costs or False ) - - # TODO (Perf): By allowing a sub-selection of columns - # we will gain increased performance by not having to - # fetch all columns from the database. Currently all use - # cases call for every column to be fetched, so we have not - # implemented this yet. columns = all_call_select_columns + if req.columns: + # Set columns to user-requested columns, w/ required columns + # These are all formatted by the CallsQuery, which prevents injection + # and other attack vectors. + columns = list(set(required_call_columns + req.columns)) + # TODO: add support for json extract fields + # Split out any nested column requests + columns = [col.split(".")[0] for col in columns] + # We put summary_dump last so that when we compute the costs and summary its in the right place if req.include_costs: columns = [ - *[col for col in all_call_select_columns if col != "summary_dump"], + *[col for col in columns if col != "summary_dump"], "summary_dump", ] for col in columns: @@ -291,9 +291,10 @@ def calls_query_stream( pb.get_params(), ) + select_columns = [c.field for c in cq.select_fields] for row in raw_res: yield tsi.CallSchema.model_validate( - _ch_call_dict_to_call_schema_dict(dict(zip(columns, row))) + _ch_call_dict_to_call_schema_dict(dict(zip(select_columns, row))) ) def calls_delete(self, req: tsi.CallsDeleteReq) -> tsi.CallsDeleteRes: @@ -1383,8 +1384,8 @@ def _ch_call_to_call_schema(ch_call: SelectableCHCallSchema) -> tsi.CallSchema: op_name=ch_call.op_name, started_at=_ensure_datetimes_have_tz(ch_call.started_at), ended_at=_ensure_datetimes_have_tz(ch_call.ended_at), - attributes=_dict_dump_to_dict(ch_call.attributes_dump), - inputs=_dict_dump_to_dict(ch_call.inputs_dump), + attributes=_dict_dump_to_dict(ch_call.attributes_dump or "{}"), + inputs=_dict_dump_to_dict(ch_call.inputs_dump or "{}"), output=_nullable_any_dump_to_any(ch_call.output_dump), summary=_nullable_dict_dump_to_dict(ch_call.summary_dump), exception=ch_call.exception, @@ -1404,8 +1405,8 @@ def _ch_call_dict_to_call_schema_dict(ch_call_dict: typing.Dict) -> typing.Dict: op_name=ch_call_dict.get("op_name"), started_at=_ensure_datetimes_have_tz(ch_call_dict.get("started_at")), ended_at=_ensure_datetimes_have_tz(ch_call_dict.get("ended_at")), - attributes=_dict_dump_to_dict(ch_call_dict["attributes_dump"]), - inputs=_dict_dump_to_dict(ch_call_dict["inputs_dump"]), + attributes=_dict_dump_to_dict(ch_call_dict.get("attributes_dump", "{}")), + inputs=_dict_dump_to_dict(ch_call_dict.get("inputs_dump", "{}")), output=_nullable_any_dump_to_any(ch_call_dict.get("output_dump")), summary=_nullable_dict_dump_to_dict(ch_call_dict.get("summary_dump")), exception=ch_call_dict.get("exception"), diff --git a/weave/trace_server/external_to_internal_trace_server_adapter.py b/weave/trace_server/external_to_internal_trace_server_adapter.py index c75929f2f6f..372aa48d1b3 100644 --- a/weave/trace_server/external_to_internal_trace_server_adapter.py +++ b/weave/trace_server/external_to_internal_trace_server_adapter.py @@ -94,8 +94,10 @@ def cached_int_to_ext_project_id(project_id: str) -> typing.Optional[str]: yield universal_int_to_ext_ref_converter(item, cached_int_to_ext_project_id) # Standard API Below: - def ensure_project_exists(self, entity: str, project: str) -> None: - self._internal_trace_server.ensure_project_exists(entity, project) + def ensure_project_exists( + self, entity: str, project: str + ) -> tsi.EnsureProjectExistsRes: + return self._internal_trace_server.ensure_project_exists(entity, project) def call_start(self, req: tsi.CallStartReq) -> tsi.CallStartRes: req.start.project_id = self._idc.ext_to_int_project_id(req.start.project_id) diff --git a/weave/trace_server/remote_http_trace_server.py b/weave/trace_server/remote_http_trace_server.py index 45617f1a0c6..145d2f50906 100644 --- a/weave/trace_server/remote_http_trace_server.py +++ b/weave/trace_server/remote_http_trace_server.py @@ -108,10 +108,14 @@ def __init__( self._auth: t.Optional[t.Tuple[str, str]] = None self.remote_request_bytes_limit = remote_request_bytes_limit - def ensure_project_exists(self, entity: str, project: str) -> None: + def ensure_project_exists( + self, entity: str, project: str + ) -> tsi.EnsureProjectExistsRes: # TODO: This should happen in the wandb backend, not here, and it's slow # (hundreds of ms) - project_creator.ensure_project_exists(entity, project) + return tsi.EnsureProjectExistsRes.model_validate( + project_creator.ensure_project_exists(entity, project) + ) @classmethod def from_env(cls, should_batch: bool = False) -> "RemoteHTTPTraceServer": diff --git a/weave/trace_server/sqlite_trace_server.py b/weave/trace_server/sqlite_trace_server.py index 2402067b062..82a512bcfc7 100644 --- a/weave/trace_server/sqlite_trace_server.py +++ b/weave/trace_server/sqlite_trace_server.py @@ -379,7 +379,17 @@ def process_operand(operand: tsi_query.Operand) -> str: conds.append(filter_cond) - query = f"SELECT * FROM calls WHERE deleted_at IS NULL AND project_id = '{req.project_id}'" + required_columns = ["id", "trace_id", "project_id", "op_name", "started_at"] + select_columns = list(tsi.CallSchema.model_fields.keys()) + if req.columns: + # TODO(gst): allow json fields to be selected + simple_columns = [x.split(".")[0] for x in req.columns] + select_columns = [x for x in simple_columns if x in select_columns] + # add required columns, preserving requested column order + select_columns += [ + rcol for rcol in required_columns if rcol not in select_columns + ] + query = f"SELECT {', '.join(select_columns)} FROM calls WHERE deleted_at IS NULL AND project_id = '{req.project_id}'" conditions_part = " AND ".join(conds) @@ -431,29 +441,29 @@ def process_operand(operand: tsi_query.Operand) -> str: cursor.execute(query) query_result = cursor.fetchall() - return tsi.CallsQueryRes( - calls=[ - tsi.CallSchema( - project_id=row[0], - id=row[1], - trace_id=row[2], - parent_id=row[3], - op_name=row[4], - started_at=row[5], - ended_at=row[6], - exception=row[7], - attributes=json.loads(row[8]), - inputs=json.loads(row[9]), - output=None if row[11] is None else json.loads(row[11]), - output_refs=None if row[12] is None else json.loads(row[12]), - summary=json.loads(row[13]) if row[13] else None, - wb_user_id=row[14], - wb_run_id=row[15], - display_name=row[17] if row[17] != "" else None, - ) - for row in query_result - ] - ) + calls = [] + for row in query_result: + call_dict = {k: v for k, v in zip(select_columns, row)} + # convert json dump fields into json + for json_field in ["attributes", "summary", "inputs", "output"]: + if call_dict.get(json_field): + call_dict[json_field] = json.loads(call_dict[json_field]) + # convert empty string display_names to None + if "display_name" in call_dict and call_dict["display_name"] == "": + call_dict["display_name"] = None + # fill in missing required fields with defaults + for col, mfield in tsi.CallSchema.model_fields.items(): + if mfield.is_required() and col not in call_dict: + if isinstance(mfield.annotation, str): + call_dict[col] = "" + elif isinstance( + mfield.annotation, (datetime.datetime, datetime.date) + ): + raise ValueError(f"Field '{col}' is required for selection") + else: + call_dict[col] = {} + calls.append(tsi.CallSchema(**call_dict)) + return tsi.CallsQueryRes(calls=calls) def calls_query_stream(self, req: tsi.CallsQueryReq) -> Iterator[tsi.CallSchema]: return iter(self.calls_query(req).calls) diff --git a/weave/trace_server/trace_server_interface.py b/weave/trace_server/trace_server_interface.py index 64df67db54d..a4ad1625c9b 100644 --- a/weave/trace_server/trace_server_interface.py +++ b/weave/trace_server/trace_server_interface.py @@ -1,6 +1,5 @@ -import abc import datetime -import typing +from typing import Any, Dict, Iterator, List, Literal, Optional, Protocol, Union from pydantic import BaseModel, ConfigDict, Field, field_serializer from typing_extensions import TypedDict @@ -21,42 +20,42 @@ class ExtraKeysTypedDict(TypedDict): class LLMUsageSchema(TypedDict, total=False): - prompt_tokens: typing.Optional[int] - input_tokens: typing.Optional[int] - completion_tokens: typing.Optional[int] - output_tokens: typing.Optional[int] - requests: typing.Optional[int] - total_tokens: typing.Optional[int] + prompt_tokens: Optional[int] + input_tokens: Optional[int] + completion_tokens: Optional[int] + output_tokens: Optional[int] + requests: Optional[int] + total_tokens: Optional[int] class LLMCostSchema(LLMUsageSchema): - prompt_tokens_cost: typing.Optional[float] - completion_tokens_cost: typing.Optional[float] - prompt_token_cost: typing.Optional[float] - completion_token_cost: typing.Optional[float] - prompt_token_cost_unit: typing.Optional[str] - completion_token_cost_unit: typing.Optional[str] - effective_date: typing.Optional[str] - provider_id: typing.Optional[str] - pricing_level: typing.Optional[str] - pricing_level_id: typing.Optional[str] - created_at: typing.Optional[str] - created_by: typing.Optional[str] + prompt_tokens_cost: Optional[float] + completion_tokens_cost: Optional[float] + prompt_token_cost: Optional[float] + completion_token_cost: Optional[float] + prompt_token_cost_unit: Optional[str] + completion_token_cost_unit: Optional[str] + effective_date: Optional[str] + provider_id: Optional[str] + pricing_level: Optional[str] + pricing_level_id: Optional[str] + created_at: Optional[str] + created_by: Optional[str] class WeaveSummarySchema(ExtraKeysTypedDict, total=False): - status: typing.Optional[typing.Literal["success", "error", "running"]] - nice_trace_name: typing.Optional[str] - latency: typing.Optional[int] - costs: typing.Optional[typing.Dict[str, LLMCostSchema]] + status: Optional[Literal["success", "error", "running"]] + nice_trace_name: Optional[str] + latency: Optional[int] + costs: Optional[Dict[str, LLMCostSchema]] class SummaryInsertMap(ExtraKeysTypedDict, total=False): - usage: typing.Dict[str, LLMUsageSchema] + usage: Dict[str, LLMUsageSchema] class SummaryMap(SummaryInsertMap, total=False): - weave: typing.Optional[WeaveSummarySchema] + weave: Optional[WeaveSummarySchema] class CallSchema(BaseModel): @@ -66,43 +65,41 @@ class CallSchema(BaseModel): # Name of the calling function (op) op_name: str # Optional display name of the call - display_name: typing.Optional[str] = None + display_name: Optional[str] = None - ## Trace ID + # Trace ID trace_id: str - ## Parent ID is optional because the call may be a root - parent_id: typing.Optional[str] = None + # Parent ID is optional because the call may be a root + parent_id: Optional[str] = None - ## Start time is required + # Start time is required started_at: datetime.datetime - ## Attributes: properties of the call - attributes: typing.Dict[str, typing.Any] + # Attributes: properties of the call + attributes: Dict[str, Any] - ## Inputs - inputs: typing.Dict[str, typing.Any] + # Inputs + inputs: Dict[str, Any] - ## End time is required if finished - ended_at: typing.Optional[datetime.datetime] = None + # End time is required if finished + ended_at: Optional[datetime.datetime] = None - ## Exception is present if the call failed - exception: typing.Optional[str] = None + # Exception is present if the call failed + exception: Optional[str] = None - ## Outputs - output: typing.Optional[typing.Any] = None + # Outputs + output: Optional[Any] = None - ## Summary: a summary of the call - summary: typing.Optional[SummaryMap] = None + # Summary: a summary of the call + summary: Optional[SummaryMap] = None # WB Metadata - wb_user_id: typing.Optional[str] = None - wb_run_id: typing.Optional[str] = None + wb_user_id: Optional[str] = None + wb_run_id: Optional[str] = None - deleted_at: typing.Optional[datetime.datetime] = None + deleted_at: Optional[datetime.datetime] = None @field_serializer("attributes", "summary", when_used="unless-none") - def serialize_typed_dicts( - self, v: typing.Dict[str, typing.Any] - ) -> typing.Dict[str, typing.Any]: + def serialize_typed_dicts(self, v: Dict[str, Any]) -> Dict[str, Any]: return dict(v) @@ -111,51 +108,49 @@ def serialize_typed_dicts( # - trace_id is not required (will be generated) class StartedCallSchemaForInsert(BaseModel): project_id: str - id: typing.Optional[str] = None # Will be generated if not provided + id: Optional[str] = None # Will be generated if not provided # Name of the calling function (op) op_name: str # Optional display name of the call - display_name: typing.Optional[str] = None + display_name: Optional[str] = None - ## Trace ID - trace_id: typing.Optional[str] = None # Will be generated if not provided - ## Parent ID is optional because the call may be a root - parent_id: typing.Optional[str] = None + # Trace ID + trace_id: Optional[str] = None # Will be generated if not provided + # Parent ID is optional because the call may be a root + parent_id: Optional[str] = None - ## Start time is required + # Start time is required started_at: datetime.datetime - ## Attributes: properties of the call - attributes: typing.Dict[str, typing.Any] + # Attributes: properties of the call + attributes: Dict[str, Any] - ## Inputs - inputs: typing.Dict[str, typing.Any] + # Inputs + inputs: Dict[str, Any] # WB Metadata - wb_user_id: typing.Optional[str] = Field(None, description=WB_USER_ID_DESCRIPTION) - wb_run_id: typing.Optional[str] = None + wb_user_id: Optional[str] = Field(None, description=WB_USER_ID_DESCRIPTION) + wb_run_id: Optional[str] = None class EndedCallSchemaForInsert(BaseModel): project_id: str id: str - ## End time is required + # End time is required ended_at: datetime.datetime - ## Exception is present if the call failed - exception: typing.Optional[str] = None + # Exception is present if the call failed + exception: Optional[str] = None - ## Outputs - output: typing.Optional[typing.Any] = None + # Outputs + output: Optional[Any] = None - ## Summary: a summary of the call + # Summary: a summary of the call summary: SummaryInsertMap @field_serializer("summary") - def serialize_typed_dicts( - self, v: typing.Dict[str, typing.Any] - ) -> typing.Dict[str, typing.Any]: + def serialize_typed_dicts(self, v: Dict[str, Any]) -> Dict[str, Any]: return dict(v) @@ -163,24 +158,24 @@ class ObjSchema(BaseModel): project_id: str object_id: str created_at: datetime.datetime - deleted_at: typing.Optional[datetime.datetime] = None + deleted_at: Optional[datetime.datetime] = None digest: str version_index: int is_latest: int kind: str - base_object_class: typing.Optional[str] - val: typing.Any + base_object_class: Optional[str] + val: Any class ObjSchemaForInsert(BaseModel): project_id: str object_id: str - val: typing.Any + val: Any class TableSchemaForInsert(BaseModel): project_id: str - rows: list[dict[str, typing.Any]] + rows: list[dict[str, Any]] class CallStartReq(BaseModel): @@ -203,19 +198,19 @@ class CallEndRes(BaseModel): class CallReadReq(BaseModel): project_id: str id: str - include_costs: typing.Optional[bool] = False + include_costs: Optional[bool] = False class CallReadRes(BaseModel): - call: typing.Optional[CallSchema] + call: Optional[CallSchema] class CallsDeleteReq(BaseModel): project_id: str - call_ids: typing.List[str] + call_ids: List[str] # wb_user_id is automatically populated by the server - wb_user_id: typing.Optional[str] = Field(None, description=WB_USER_ID_DESCRIPTION) + wb_user_id: Optional[str] = Field(None, description=WB_USER_ID_DESCRIPTION) class CallsDeleteRes(BaseModel): @@ -223,15 +218,15 @@ class CallsDeleteRes(BaseModel): class CallsFilter(BaseModel): - op_names: typing.Optional[typing.List[str]] = None - input_refs: typing.Optional[typing.List[str]] = None - output_refs: typing.Optional[typing.List[str]] = None - parent_ids: typing.Optional[typing.List[str]] = None - trace_ids: typing.Optional[typing.List[str]] = None - call_ids: typing.Optional[typing.List[str]] = None - trace_roots_only: typing.Optional[bool] = None - wb_user_ids: typing.Optional[typing.List[str]] = None - wb_run_ids: typing.Optional[typing.List[str]] = None + op_names: Optional[List[str]] = None + input_refs: Optional[List[str]] = None + output_refs: Optional[List[str]] = None + parent_ids: Optional[List[str]] = None + trace_ids: Optional[List[str]] = None + call_ids: Optional[List[str]] = None + trace_roots_only: Optional[bool] = None + wb_user_ids: Optional[List[str]] = None + wb_run_ids: Optional[List[str]] = None class SortBy(BaseModel): @@ -240,32 +235,32 @@ class SortBy(BaseModel): # dot-separated. field: str # Consider changing this to _FieldSelect # Direction should be either 'asc' or 'desc' - direction: typing.Literal["asc", "desc"] + direction: Literal["asc", "desc"] class CallsQueryReq(BaseModel): project_id: str - filter: typing.Optional[CallsFilter] = None - limit: typing.Optional[int] = None - offset: typing.Optional[int] = None + filter: Optional[CallsFilter] = None + limit: Optional[int] = None + offset: Optional[int] = None # Sort by multiple fields - sort_by: typing.Optional[typing.List[SortBy]] = None - query: typing.Optional[Query] = None - include_costs: typing.Optional[bool] = False + sort_by: Optional[List[SortBy]] = None + query: Optional[Query] = None + include_costs: Optional[bool] = False # TODO: type this with call schema columns, following the same rules as # SortBy and thus GetFieldOperator.get_field_ (without direction) - columns: typing.Optional[typing.List[str]] = None + columns: Optional[List[str]] = None class CallsQueryRes(BaseModel): - calls: typing.List[CallSchema] + calls: List[CallSchema] class CallsQueryStatsReq(BaseModel): project_id: str - filter: typing.Optional[CallsFilter] = None - query: typing.Optional[Query] = None + filter: Optional[CallsFilter] = None + query: Optional[Query] = None class CallsQueryStatsRes(BaseModel): @@ -278,10 +273,10 @@ class CallUpdateReq(BaseModel): call_id: str # optional update fields - display_name: typing.Optional[str] = None + display_name: Optional[str] = None # wb_user_id is automatically populated by the server - wb_user_id: typing.Optional[str] = Field(None, description=WB_USER_ID_DESCRIPTION) + wb_user_id: Optional[str] = Field(None, description=WB_USER_ID_DESCRIPTION) class CallUpdateRes(BaseModel): @@ -307,17 +302,17 @@ class OpReadRes(BaseModel): class OpVersionFilter(BaseModel): - op_names: typing.Optional[typing.List[str]] = None - latest_only: typing.Optional[bool] = None + op_names: Optional[List[str]] = None + latest_only: Optional[bool] = None class OpQueryReq(BaseModel): project_id: str - filter: typing.Optional[OpVersionFilter] = None + filter: Optional[OpVersionFilter] = None class OpQueryRes(BaseModel): - op_objs: typing.List[ObjSchema] + op_objs: List[ObjSchema] class ObjCreateReq(BaseModel): @@ -339,19 +334,19 @@ class ObjReadRes(BaseModel): class ObjectVersionFilter(BaseModel): - base_object_classes: typing.Optional[typing.List[str]] = None - object_ids: typing.Optional[typing.List[str]] = None - is_op: typing.Optional[bool] = None - latest_only: typing.Optional[bool] = None + base_object_classes: Optional[List[str]] = None + object_ids: Optional[List[str]] = None + is_op: Optional[bool] = None + latest_only: Optional[bool] = None class ObjQueryReq(BaseModel): project_id: str - filter: typing.Optional[ObjectVersionFilter] = None + filter: Optional[ObjectVersionFilter] = None class ObjQueryRes(BaseModel): - objs: typing.List[ObjSchema] + objs: List[ObjSchema] class TableCreateReq(BaseModel): @@ -410,7 +405,7 @@ class Table[OPERATION]Spec(BaseModel): class TableAppendSpecPayload(BaseModel): - row: dict[str, typing.Any] + row: dict[str, Any] class TableAppendSpec(BaseModel): @@ -427,14 +422,14 @@ class TablePopSpec(BaseModel): class TableInsertSpecPayload(BaseModel): index: int - row: dict[str, typing.Any] + row: dict[str, Any] class TableInsertSpec(BaseModel): insert: TableInsertSpecPayload -TableUpdateSpec = typing.Union[TableAppendSpec, TablePopSpec, TableInsertSpec] +TableUpdateSpec = Union[TableAppendSpec, TablePopSpec, TableInsertSpec] class TableUpdateReq(BaseModel): @@ -449,7 +444,7 @@ class TableUpdateRes(BaseModel): class TableRowSchema(BaseModel): digest: str - val: typing.Any + val: Any class TableCreateRes(BaseModel): @@ -457,27 +452,27 @@ class TableCreateRes(BaseModel): class TableRowFilter(BaseModel): - row_digests: typing.Optional[typing.List[str]] = None + row_digests: Optional[List[str]] = None class TableQueryReq(BaseModel): project_id: str digest: str - filter: typing.Optional[TableRowFilter] = None - limit: typing.Optional[int] = None - offset: typing.Optional[int] = None + filter: Optional[TableRowFilter] = None + limit: Optional[int] = None + offset: Optional[int] = None class TableQueryRes(BaseModel): - rows: typing.List[TableRowSchema] + rows: List[TableRowSchema] class RefsReadBatchReq(BaseModel): - refs: typing.List[str] + refs: List[str] class RefsReadBatchRes(BaseModel): - vals: typing.List[typing.Any] + vals: List[Any] class FeedbackPayloadReactionReq(BaseModel): @@ -491,9 +486,9 @@ class FeedbackPayloadNoteReq(BaseModel): class FeedbackCreateReq(BaseModel): project_id: str = Field(examples=["entity/project"]) weave_ref: str = Field(examples=["weave:///entity/project/object/name:digest"]) - creator: typing.Optional[str] = Field(default=None, examples=["Jane Smith"]) + creator: Optional[str] = Field(default=None, examples=["Jane Smith"]) feedback_type: str = Field(examples=["custom"]) - payload: typing.Dict[str, typing.Any] = Field( + payload: Dict[str, Any] = Field( examples=[ { "key": "value", @@ -502,7 +497,7 @@ class FeedbackCreateReq(BaseModel): ) # wb_user_id is automatically populated by the server - wb_user_id: typing.Optional[str] = Field(None, description=WB_USER_ID_DESCRIPTION) + wb_user_id: Optional[str] = Field(None, description=WB_USER_ID_DESCRIPTION) # The response provides the additional fields needed to convert a request @@ -511,7 +506,7 @@ class FeedbackCreateRes(BaseModel): id: str created_at: datetime.datetime wb_user_id: str - payload: typing.Dict[str, typing.Any] # If not empty, replace payload + payload: Dict[str, Any] # If not empty, replace payload class Feedback(FeedbackCreateReq): @@ -521,20 +516,20 @@ class Feedback(FeedbackCreateReq): class FeedbackQueryReq(BaseModel): project_id: str = Field(examples=["entity/project"]) - fields: typing.Optional[list[str]] = Field( + fields: Optional[list[str]] = Field( default=None, examples=[["id", "feedback_type", "payload.note"]] ) - query: typing.Optional[Query] = None + query: Optional[Query] = None # TODO: I think I would prefer to call this order_by to match SQL, but this is what calls API uses # TODO: Might be nice to have shortcut for single field and implied ASC direction - sort_by: typing.Optional[typing.List[SortBy]] = None - limit: typing.Optional[int] = Field(default=None, examples=[10]) - offset: typing.Optional[int] = Field(default=None, examples=[0]) + sort_by: Optional[List[SortBy]] = None + limit: Optional[int] = Field(default=None, examples=[10]) + offset: Optional[int] = Field(default=None, examples=[0]) class FeedbackQueryRes(BaseModel): # Note: this is not a list of Feedback because user can request any fields. - result: list[dict[str, typing.Any]] + result: list[dict[str, Any]] class FeedbackPurgeReq(BaseModel): @@ -565,114 +560,49 @@ class FileContentReadRes(BaseModel): content: bytes -class TraceServerInterface: - def ensure_project_exists(self, entity: str, project: str) -> None: - pass +class EnsureProjectExistsRes(BaseModel): + project_name: str - # Call API - @abc.abstractmethod - def call_start(self, req: CallStartReq) -> CallStartRes: - raise NotImplementedError() - - @abc.abstractmethod - def call_end(self, req: CallEndReq) -> CallEndRes: - raise NotImplementedError() - - @abc.abstractmethod - def call_read(self, req: CallReadReq) -> CallReadRes: - raise NotImplementedError() - - @abc.abstractmethod - def calls_query(self, req: CallsQueryReq) -> CallsQueryRes: - raise NotImplementedError() - - @abc.abstractmethod - def calls_query_stream(self, req: CallsQueryReq) -> typing.Iterator[CallSchema]: - raise NotImplementedError() - @abc.abstractmethod - def calls_delete(self, req: CallsDeleteReq) -> CallsDeleteRes: - raise NotImplementedError() +class TraceServerInterface(Protocol): + def ensure_project_exists( + self, entity: str, project: str + ) -> EnsureProjectExistsRes: + return EnsureProjectExistsRes(project_name=project) - @abc.abstractmethod - def calls_query_stats(self, req: CallsQueryStatsReq) -> CallsQueryStatsRes: - raise NotImplementedError() - - @abc.abstractmethod - def call_update(self, req: CallUpdateReq) -> CallUpdateRes: - raise NotImplementedError() + # Call API + def call_start(self, req: CallStartReq) -> CallStartRes: ... + def call_end(self, req: CallEndReq) -> CallEndRes: ... + def call_read(self, req: CallReadReq) -> CallReadRes: ... + def calls_query(self, req: CallsQueryReq) -> CallsQueryRes: ... + def calls_query_stream(self, req: CallsQueryReq) -> Iterator[CallSchema]: ... + def calls_delete(self, req: CallsDeleteReq) -> CallsDeleteRes: ... + def calls_query_stats(self, req: CallsQueryStatsReq) -> CallsQueryStatsRes: ... + def call_update(self, req: CallUpdateReq) -> CallUpdateRes: ... # Op API - @abc.abstractmethod - def op_create(self, req: OpCreateReq) -> OpCreateRes: - raise NotImplementedError() - - @abc.abstractmethod - def op_read(self, req: OpReadReq) -> OpReadRes: - raise NotImplementedError() - - @abc.abstractmethod - def ops_query(self, req: OpQueryReq) -> OpQueryRes: - raise NotImplementedError() + def op_create(self, req: OpCreateReq) -> OpCreateRes: ... + def op_read(self, req: OpReadReq) -> OpReadRes: ... + def ops_query(self, req: OpQueryReq) -> OpQueryRes: ... # Obj API - @abc.abstractmethod - def obj_create(self, req: ObjCreateReq) -> ObjCreateRes: - raise NotImplementedError() - - @abc.abstractmethod - def obj_read(self, req: ObjReadReq) -> ObjReadRes: - raise NotImplementedError() - - @abc.abstractmethod - def objs_query(self, req: ObjQueryReq) -> ObjQueryRes: - raise NotImplementedError() - - @abc.abstractmethod - def table_create(self, req: TableCreateReq) -> TableCreateRes: - raise NotImplementedError() - - @abc.abstractmethod - def table_update(self, req: TableUpdateReq) -> TableUpdateRes: - raise NotImplementedError() - - @abc.abstractmethod - def table_query(self, req: TableQueryReq) -> TableQueryRes: - raise NotImplementedError() - - @abc.abstractmethod - def refs_read_batch(self, req: RefsReadBatchReq) -> RefsReadBatchRes: - raise NotImplementedError() - - @abc.abstractmethod - def file_create(self, req: FileCreateReq) -> FileCreateRes: - raise NotImplementedError() - - @abc.abstractmethod - def file_content_read(self, req: FileContentReadReq) -> FileContentReadRes: - raise NotImplementedError() - - @abc.abstractmethod - def feedback_create(self, req: FeedbackCreateReq) -> FeedbackCreateRes: - raise NotImplementedError() - - @abc.abstractmethod - def feedback_query(self, req: FeedbackQueryReq) -> FeedbackQueryRes: - raise NotImplementedError() - - @abc.abstractmethod - def feedback_purge(self, req: FeedbackPurgeReq) -> FeedbackPurgeRes: - raise NotImplementedError() + def obj_create(self, req: ObjCreateReq) -> ObjCreateRes: ... + def obj_read(self, req: ObjReadReq) -> ObjReadRes: ... + def objs_query(self, req: ObjQueryReq) -> ObjQueryRes: ... + def table_create(self, req: TableCreateReq) -> TableCreateRes: ... + def table_update(self, req: TableUpdateReq) -> TableUpdateRes: ... + def table_query(self, req: TableQueryReq) -> TableQueryRes: ... + def refs_read_batch(self, req: RefsReadBatchReq) -> RefsReadBatchRes: ... + def file_create(self, req: FileCreateReq) -> FileCreateRes: ... + def file_content_read(self, req: FileContentReadReq) -> FileContentReadRes: ... + def feedback_create(self, req: FeedbackCreateReq) -> FeedbackCreateRes: ... + def feedback_query(self, req: FeedbackQueryReq) -> FeedbackQueryRes: ... + def feedback_purge(self, req: FeedbackPurgeReq) -> FeedbackPurgeRes: ... # These symbols are used in the WB Trace Server and it is not safe # to remove them, else it will break the server. Once the server # is updated to use the new symbols, these can be removed. -# -# Remove once https://github.com/wandb/core/pull/22040 lands -CallsDeleteReqForInsert = CallsDeleteReq -CallUpdateReqForInsert = CallUpdateReq -FeedbackCreateReqForInsert = FeedbackCreateReq # Legacy Names (i think these might be used in a few growth examples, so keeping # around until we clean those up of them) diff --git a/weave/type_serializers/Image/__init__.py b/weave/type_serializers/Image/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/weave/type_serializers/Image/image.py b/weave/type_serializers/Image/image.py new file mode 100644 index 00000000000..c9b9402a4de --- /dev/null +++ b/weave/type_serializers/Image/image.py @@ -0,0 +1,44 @@ +"""Defines the custom Image weave type.""" + +from weave.trace import serializer +from weave.trace.custom_objs import MemTraceFilesArtifact + +dependencies_met = False + +try: + from PIL import Image + + dependencies_met = True +except ImportError: + pass + + +def save(obj: "Image.Image", artifact: MemTraceFilesArtifact, name: str) -> None: + # Note: I am purposely ignoring the `name` here and hard-coding the filename to "image.png". + # There is an extensive internal discussion here: + # https://weightsandbiases.slack.com/archives/C03BSTEBD7F/p1723670081582949 + # + # In summary, there is an outstanding design decision to be made about how to handle the + # `name` parameter. One school of thought is that using the `name` parameter allows multiple + # object to use the same artifact more cleanly. However, another school of thought is that + # the payload should not be dependent on an external name - resulting in different payloads + # for the same logical object. + # + # Using `image.png` is fine for now since we don't have any cases of multiple objects + # using the same artifact. Moreover, since we package the deserialization logic with the + # object payload, we can always change the serialization logic later without breaking + # existing payloads. + with artifact.new_file("image.png", binary=True) as f: + obj.save(f, format="png") # type: ignore + + +def load(artifact: MemTraceFilesArtifact, name: str) -> "Image.Image": + # Note: I am purposely ignoring the `name` here and hard-coding the filename. See comment + # on save. + path = artifact.path("image.png") + return Image.open(path) + + +def register() -> None: + if dependencies_met: + serializer.register_serializer(Image.Image, save, load) diff --git a/weave/type_serializers/Image/image_test.py b/weave/type_serializers/Image/image_test.py new file mode 100644 index 00000000000..cf42c07c4d6 --- /dev/null +++ b/weave/type_serializers/Image/image_test.py @@ -0,0 +1,102 @@ +from PIL import Image + +import weave +from weave.weave_client import WeaveClient, get_ref + +"""When testing types, it is important to test: +Objects: +1. Publishing Directly +2. Publishing as a property +3. Using as a cell in a table + +Calls: +4. Using as inputs, output, and output component (raw) +5. Using as inputs, output, and output component (refs) + +""" + + +def test_image_publish(client: WeaveClient) -> None: + img = Image.new("RGB", (512, 512), "purple") + weave.publish(img) + + ref = get_ref(img) + + assert ref is not None + gotten_img = weave.ref(ref.uri()).get() + assert img.tobytes() == gotten_img.tobytes() + + +class ImageWrapper(weave.Object): + img: Image.Image + + +def test_image_as_property(client: WeaveClient) -> None: + img = Image.new("RGB", (512, 512), "purple") + img_wrapper = ImageWrapper(img=img) + assert img_wrapper.img == img + + weave.publish(img_wrapper) + + ref = get_ref(img_wrapper) + assert ref is not None + + gotten_img_wrapper = weave.ref(ref.uri()).get() + assert gotten_img_wrapper.img.tobytes() == img.tobytes() + + +def test_image_as_dataset_cell(client: WeaveClient) -> None: + img = Image.new("RGB", (512, 512), "purple") + dataset = weave.Dataset(rows=[{"img": img}]) + assert dataset.rows[0]["img"] == img + + weave.publish(dataset) + + ref = get_ref(dataset) + assert ref is not None + + gotten_dataset = weave.ref(ref.uri()).get() + assert gotten_dataset.rows[0]["img"].tobytes() == img.tobytes() + + +@weave.op +def image_as_solo_output(publish_first: bool) -> Image.Image: + img = Image.new("RGB", (512, 512), "purple") + if publish_first: + weave.publish(img) + return img + + +@weave.op +def image_as_input_and_output_part(in_img: Image.Image) -> dict: + return {"out_img": in_img} + + +def test_image_as_call_io(client: WeaveClient) -> None: + non_published_img = image_as_solo_output(publish_first=False) + img_dict = image_as_input_and_output_part(non_published_img) + + exp_bytes = non_published_img.tobytes() + assert img_dict["out_img"].tobytes() == exp_bytes + + image_as_solo_output_call = image_as_solo_output.calls()[0] + image_as_input_and_output_part_call = image_as_input_and_output_part.calls()[0] + + assert image_as_solo_output_call.output.tobytes() == exp_bytes + assert image_as_input_and_output_part_call.inputs["in_img"].tobytes() == exp_bytes + assert image_as_input_and_output_part_call.output["out_img"].tobytes() == exp_bytes + + +def test_image_as_call_io_refs(client: WeaveClient) -> None: + non_published_img = image_as_solo_output(publish_first=True) + img_dict = image_as_input_and_output_part(non_published_img) + + exp_bytes = non_published_img.tobytes() + assert img_dict["out_img"].tobytes() == exp_bytes + + image_as_solo_output_call = image_as_solo_output.calls()[0] + image_as_input_and_output_part_call = image_as_input_and_output_part.calls()[0] + + assert image_as_solo_output_call.output.tobytes() == exp_bytes + assert image_as_input_and_output_part_call.inputs["in_img"].tobytes() == exp_bytes + assert image_as_input_and_output_part_call.output["out_img"].tobytes() == exp_bytes diff --git a/weave/type_serializers/__init__.py b/weave/type_serializers/__init__.py new file mode 100644 index 00000000000..396af8f791e --- /dev/null +++ b/weave/type_serializers/__init__.py @@ -0,0 +1,3 @@ +from .Image import image + +image.register() diff --git a/weave/weave_client.py b/weave/weave_client.py index 88e205f605a..d65275180d9 100644 --- a/weave/weave_client.py +++ b/weave/weave_client.py @@ -23,6 +23,7 @@ from weave.trace.op import op as op_deco from weave.trace.refs import CallRef, ObjectRef, OpRef, Ref, TableRef from weave.trace.serialize import from_json, isinstance_namedtuple, to_json +from weave.trace.serializer import get_serializer_for_obj from weave.trace.vals import WeaveObject, WeaveTable, make_trace_obj from weave.trace_server.ids import generate_id from weave.trace_server.trace_server_interface import ( @@ -101,6 +102,8 @@ def _get_direct_ref(obj: Any) -> Optional[Ref]: def map_to_refs(obj: Any) -> Any: + if isinstance(obj, Ref): + return obj if ref := _get_direct_ref(obj): return ref @@ -288,7 +291,7 @@ def make_client_call( parent_id=server_call.parent_id, id=server_call.id, inputs=from_json(server_call.inputs, server_call.project_id, server), - output=output, + output=from_json(output, server_call.project_id, server), summary=dict(server_call.summary) if server_call.summary is not None else None, display_name=server_call.display_name, attributes=server_call.attributes, @@ -380,7 +383,9 @@ def __init__( self.ensure_project_exists = ensure_project_exists if ensure_project_exists: - self.server.ensure_project_exists(entity, project) + resp = self.server.ensure_project_exists(entity, project) + # Set Client project name with updated project name + self.project = resp.project_name ################ High Level Convenience Methods ################ @@ -727,10 +732,18 @@ def _project_id(self) -> str: @trace_sentry.global_trace_sentry.watch() def _save_object(self, val: Any, name: str, branch: str = "latest") -> ObjectRef: self._save_nested_objects(val, name=name) + + # typically, this condition would belong inside of the + # `_save_nested_objects` switch. However, we don't want to recursively + # publish all custom objects. Instead we only want to do this at the + # top-most level if requested + if get_serializer_for_obj(val) is not None: + self._save_and_attach_ref(val) + return self._save_object_basic(val, name, branch) def _save_object_basic( - self, val: Any, name: str, branch: str = "latest" + self, val: Any, name: Optional[str] = None, branch: str = "latest" ) -> ObjectRef: # The WeaveTable case is special because object saving happens inside # _save_object_nested and it has a special table_ref -- skip it here. @@ -743,6 +756,14 @@ def _save_object_basic( return val json_val = to_json(val, self._project_id(), self.server) + if name is None: + if json_val.get("_type") == "CustomWeaveType": + custom_name = json_val.get("weave_type", {}).get("type") + name = custom_name + + if name is None: + raise ValueError("Name must be provided for object saving") + response = self.server.obj_create( ObjCreateReq( obj=ObjSchemaForInsert( @@ -778,7 +799,7 @@ def _save_nested_objects(self, obj: Any, name: Optional[str] = None) -> Any: self._save_nested_objects(v) ref = self._save_object_basic(obj_rec, name or get_obj_name(obj_rec)) obj.__dict__["ref"] = ref - elif dataclasses.is_dataclass(obj): + elif dataclasses.is_dataclass(obj) and not isinstance(obj, Ref): obj_rec = dataclass_object_record(obj) for v in obj_rec.__dict__.values(): self._save_nested_objects(v) @@ -808,11 +829,10 @@ def _save_nested_objects(self, obj: Any, name: Optional[str] = None) -> Any: @trace_sentry.global_trace_sentry.watch() def _save_table(self, table: Table) -> TableRef: + rows = to_json(table.rows, self._project_id(), self.server) response = self.server.table_create( TableCreateReq( - table=TableSchemaForInsert( - project_id=self._project_id(), rows=table.rows - ) + table=TableSchemaForInsert(project_id=self._project_id(), rows=rows) ) ) return TableRef( @@ -846,8 +866,16 @@ def _objects(self, filter: Optional[ObjectVersionFilter] = None) -> list[ObjSche def _save_op(self, op: Op, name: Optional[str] = None) -> Ref: if op.ref is not None: return op.ref + if name is None: name = op.name + + return self._save_and_attach_ref(op, name) + + def _save_and_attach_ref(self, op: Any, name: Optional[str] = None) -> Ref: + if (ref := getattr(op, "ref", None)) is not None: + return ref + op_def_ref = self._save_object_basic(op, name) # setattr(op, "ref", op_def_ref) fails here diff --git a/weave/weave_init.py b/weave/weave_init.py index b1744ff0ce1..d2af6bf5eb7 100644 --- a/weave/weave_init.py +++ b/weave/weave_init.py @@ -104,6 +104,8 @@ def init_weave( client = weave_client.WeaveClient( entity_name, project_name, remote_server, ensure_project_exists ) + # If the project name was formatted by init, update the project name + project_name = client.project _current_inited_client = InitializedClient(client) # entity_name, project_name = get_entity_project_from_project_name(project_name)