diff --git a/.github/workflows/check-which-tests-to-run.yaml b/.github/workflows/check-which-tests-to-run.yaml index ac81cd713a6..622ef7389f8 100644 --- a/.github/workflows/check-which-tests-to-run.yaml +++ b/.github/workflows/check-which-tests-to-run.yaml @@ -29,16 +29,33 @@ jobs: fetch-depth: 0 fetch-tags: true ref: ${{ github.head_ref }} - - name: Git setup + - name: Get changed files run: | - git fetch origin ${{ github.base_ref }} - if [ "${{ github.event_name }}" = "pull_request" ]; then - base_sha=$(git rev-parse origin/${{ github.base_ref }}) - head_sha=$(git rev-parse HEAD) - changed_files=$(git diff --name-only $base_sha $head_sha) + # Fetch all branches + git fetch --all + + # Determine the base branch and current commit + if [[ "${{ github.event_name }}" == "pull_request" ]]; then + # For pull requests + BASE_BRANCH="${{ github.base_ref }}" + CURRENT_COMMIT="${{ github.event.pull_request.head.sha }}" else - changed_files=$(git diff --name-only HEAD^) + # For pushes + BASE_BRANCH=$(git remote show origin | sed -n '/HEAD branch/s/.*: //p') + CURRENT_COMMIT="${{ github.sha }}" fi + echo "Base branch is $BASE_BRANCH" + + # Find the common ancestor + MERGE_BASE=$(git merge-base origin/$BASE_BRANCH $CURRENT_COMMIT) + + # Get changed files + changed_files=$(git diff --name-only $MERGE_BASE $CURRENT_COMMIT) + echo "Changed files:" + echo "$changed_files" + echo "changed_files<<EOF" >> $GITHUB_ENV + echo "$changed_files" >> $GITHUB_ENV + echo "EOF" >> $GITHUB_ENV - id: weave_query name: Weave Query Checks run: | @@ -62,7 +79,7 @@ jobs: - id: trace_server name: Weave Trace Server Checks run: | - for path in ${{ env.CORE_INTEGRATION_PATHS }}; do + for path in ${{ env.TRACE_SERVER_PATHS }}; do if echo "$changed_files" | grep -q "$path"; then echo "run_tests=true" >> $GITHUB_OUTPUT exit 0 diff --git a/.github/workflows/cla.yaml b/.github/workflows/cla.yaml index b7b343bf3bd..e1543843c96 100644 --- a/.github/workflows/cla.yaml +++ b/.github/workflows/cla.yaml @@ -27,4 +27,4 @@ jobs: # branch should not be protected branch: "cla" # cannot use teams due to: https://github.com/contributor-assistant/github-action/issues/100 - allowlist: actions-user, altay, bdytx5, dannygoldstein, davidwallacejackson, jamie-rasmussen, jlzhao27, jo-fang, jwlee64, laxels, morganmcg1, nickpenaranda, scottire, shawnlewis, staceysv, tssweeney, vanpelt, vwrj, wandbmachine + allowlist: actions-user, altay, andrewtruong, bdytx5, dannygoldstein, davidwallacejackson, jamie-rasmussen, jlzhao27, jo-fang, jwlee64, laxels, morganmcg1, nickpenaranda, scottire, shawnlewis, staceysv, tssweeney, vanpelt, vwrj, wandbmachine, weave@wandb.com diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml index 2ee24c971a4..de8d7ac6be3 100644 --- a/.github/workflows/test.yaml +++ b/.github/workflows/test.yaml @@ -20,7 +20,8 @@ jobs: name: Build Legacy (Query Service) test container timeout-minutes: 30 runs-on: [self-hosted, builder] - # runs-on: ubuntu-latest + outputs: + build_needed: ${{ steps.build_check.outputs.build_needed }} env: REGISTRY: us-east4-docker.pkg.dev/weave-support-367421/weave-images needs: check-which-tests-to-run @@ -70,7 +71,7 @@ jobs: matrix: job_num: [0, 1] # runs-on: ubuntu-latest - container: ${{ needs.build-container-query-service.outputs.build_needed == 'true' && 'us-east4-docker.pkg.dev/weave-support-367421/weave-images/weave-test-python-query-service:${{ github.sha }}' || 'ubuntu:latest' }} + container: ${{ needs.build-container-query-service.outputs.build_needed == 'true' &&
format('us-east4-docker.pkg.dev/weave-support-367421/weave-images/weave-test-python-query-service:{0}', github.sha) || null }} services: wandbservice: image: us-central1-docker.pkg.dev/wandb-production/images/local-testcontainer:master diff --git a/docs/docs/media/dspy_optimization/1.png b/docs/docs/media/dspy_optimization/1.png new file mode 100644 index 00000000000..e98f29936f7 Binary files /dev/null and b/docs/docs/media/dspy_optimization/1.png differ diff --git a/docs/docs/media/dspy_optimization/2.png b/docs/docs/media/dspy_optimization/2.png new file mode 100644 index 00000000000..4cbfba3d4ad Binary files /dev/null and b/docs/docs/media/dspy_optimization/2.png differ diff --git a/docs/docs/media/dspy_optimization/3.png b/docs/docs/media/dspy_optimization/3.png new file mode 100644 index 00000000000..51e67461a86 Binary files /dev/null and b/docs/docs/media/dspy_optimization/3.png differ diff --git a/docs/docs/media/dspy_optimization/4.png b/docs/docs/media/dspy_optimization/4.png new file mode 100644 index 00000000000..46884cf9ec1 Binary files /dev/null and b/docs/docs/media/dspy_optimization/4.png differ diff --git a/docs/docs/media/dspy_optimization/5.png b/docs/docs/media/dspy_optimization/5.png new file mode 100644 index 00000000000..be128063e58 Binary files /dev/null and b/docs/docs/media/dspy_optimization/5.png differ diff --git a/docs/docs/media/intro/1.png b/docs/docs/media/intro/1.png new file mode 100644 index 00000000000..fa91fa30735 Binary files /dev/null and b/docs/docs/media/intro/1.png differ diff --git a/docs/docs/media/intro/10.png b/docs/docs/media/intro/10.png new file mode 100644 index 00000000000..de246ec8a74 Binary files /dev/null and b/docs/docs/media/intro/10.png differ diff --git a/docs/docs/media/intro/2.png b/docs/docs/media/intro/2.png new file mode 100644 index 00000000000..03780ef1e69 Binary files /dev/null and b/docs/docs/media/intro/2.png differ diff --git a/docs/docs/media/intro/3.png b/docs/docs/media/intro/3.png new file mode 100644 index 00000000000..1f3342a2ad5 Binary files /dev/null and b/docs/docs/media/intro/3.png differ diff --git a/docs/docs/media/intro/4.png b/docs/docs/media/intro/4.png new file mode 100644 index 00000000000..7747a8943cd Binary files /dev/null and b/docs/docs/media/intro/4.png differ diff --git a/docs/docs/media/intro/5.png b/docs/docs/media/intro/5.png new file mode 100644 index 00000000000..beddf178d9d Binary files /dev/null and b/docs/docs/media/intro/5.png differ diff --git a/docs/docs/media/intro/6.png b/docs/docs/media/intro/6.png new file mode 100644 index 00000000000..6aafda16c41 Binary files /dev/null and b/docs/docs/media/intro/6.png differ diff --git a/docs/docs/media/intro/7.png b/docs/docs/media/intro/7.png new file mode 100644 index 00000000000..a71e9467605 Binary files /dev/null and b/docs/docs/media/intro/7.png differ diff --git a/docs/docs/media/intro/8.png b/docs/docs/media/intro/8.png new file mode 100644 index 00000000000..9c35e0124a8 Binary files /dev/null and b/docs/docs/media/intro/8.png differ diff --git a/docs/docs/media/intro/9.png b/docs/docs/media/intro/9.png new file mode 100644 index 00000000000..1a69ca4505d Binary files /dev/null and b/docs/docs/media/intro/9.png differ diff --git a/docs/docs/media/summarization/dataset.png b/docs/docs/media/summarization/dataset.png new file mode 100644 index 00000000000..dae068d9124 Binary files /dev/null and b/docs/docs/media/summarization/dataset.png differ diff --git a/docs/docs/media/summarization/eval_dash.png 
b/docs/docs/media/summarization/eval_dash.png new file mode 100644 index 00000000000..66d4c94aa11 Binary files /dev/null and b/docs/docs/media/summarization/eval_dash.png differ diff --git a/docs/docs/media/summarization/model.png b/docs/docs/media/summarization/model.png new file mode 100644 index 00000000000..f05d5cc85fe Binary files /dev/null and b/docs/docs/media/summarization/model.png differ diff --git a/docs/docs/media/summarization/summarization_trace.png b/docs/docs/media/summarization/summarization_trace.png new file mode 100644 index 00000000000..428339ea621 Binary files /dev/null and b/docs/docs/media/summarization/summarization_trace.png differ diff --git a/docs/docs/reference/gen_notebooks/01-intro_notebook.md b/docs/docs/reference/gen_notebooks/01-intro_notebook.md index 17e9332397e..20639f24fd7 100644 --- a/docs/docs/reference/gen_notebooks/01-intro_notebook.md +++ b/docs/docs/reference/gen_notebooks/01-intro_notebook.md @@ -72,6 +72,8 @@ weave.init('project-name') # initialize tracking for a specific W&B project Add the @weave.op decorator to the functions you want to track +![](../../media/intro/1.png) + ```python from openai import OpenAI @@ -102,6 +104,8 @@ You can find your interactive dashboard by clicking any of the 👆 wandb links Here, we're automatically tracking all calls to `openai`. We automatically track a lot of LLM libraries, but it's really easy to add support for whatever LLM you're using, as you'll see below. +![](../../media/intro/2.png) + ```python import weave @@ -128,6 +132,8 @@ Now that you've seen the basics, let's combine all of the above and track some d +![](../../media/intro/3.png) + ```python from openai import OpenAI @@ -169,6 +175,8 @@ print(result) Whenever your code crashes, weave will highlight what caused the issue. This is especially useful for finding things like JSON parsing issues that can occasionally happen when parsing data from LLM responses. +![](../../media/intro/4.png) + ```python import json @@ -221,6 +229,8 @@ Organizing experimentation is difficult when there are many moving pieces. You c Many times, it is useful to track & version data, just like you track and version code. For example, here we define a `SystemPrompt(weave.Object)` object that can be shared between teammates +![](../../media/intro/5.png) + ```python import weave @@ -242,6 +252,8 @@ weave.publish(system_prompt) Models are so common of an object type, that we have a special class to represent them: `weave.Model`. The only requirement is that we define a `predict` method. +![](../../media/intro/6.png) + ```python from openai import OpenAI @@ -283,6 +295,8 @@ print(result) Similar to models, a `weave.Dataset` object exists to help track, organize, and operate on datasets +![](../../media/intro/7.png) + ```python dataset = weave.Dataset( @@ -309,6 +323,8 @@ Notice that we saved a versioned `GrammarCorrector` object that captures the con You can publish objects and then retrieve them in your code. You can even call functions from your retrieved objects! +![](../../media/intro/8.png) + ```python import weave @@ -324,6 +340,8 @@ ref = weave.publish(corrector) print(ref.uri()) ``` +![](../../media/intro/9.png) + ```python import weave @@ -346,6 +364,8 @@ Evaluation-driven development helps you reliably iterate on an application. 
The See a preview of the API below: +![](../../media/intro/10.png) + ```python import weave diff --git a/docs/docs/reference/gen_notebooks/chain_of_density.md b/docs/docs/reference/gen_notebooks/chain_of_density.md index caa9e6da805..9e9205e1d72 100644 --- a/docs/docs/reference/gen_notebooks/chain_of_density.md +++ b/docs/docs/reference/gen_notebooks/chain_of_density.md @@ -20,6 +20,8 @@ title: Chain of Density Summarization Summarizing complex technical documents while preserving crucial details is a challenging task. The Chain of Density (CoD) summarization technique offers a solution by iteratively refining summaries to be more concise and information-dense. This guide demonstrates how to implement CoD using Weave for tracking and evaluating the application. +![](../../media/summarization/eval_dash.png) + ## What is Chain of Density Summarization? [![arXiv](https://img.shields.io/badge/arXiv-2309.04269-b31b1b.svg)](https://arxiv.org/abs/2309.04269) @@ -139,6 +141,8 @@ def load_pdf(pdf_url: str) -> str: Now, let's implement the core CoD summarization logic using Weave operations: +![](../../media/summarization/summarization_trace.png) + ```python # Chain of Density Summarization @@ -231,6 +235,8 @@ By using `@weave.op()` decorators, we ensure that Weave tracks the inputs, outpu Now, let's wrap our summarization pipeline in a Weave Model: +![](../../media/summarization/model.png) + ```python # Weave Model @@ -240,7 +246,7 @@ class ArxivChainOfDensityPipeline(weave.Model): @weave.op() def predict(self, paper: ArxivPaper, instruction: str) -> dict: - text = load_pdf(paper["pdf_url"]) + text = load_pdf(paper.pdf_url) result = chain_of_density_summarization( text, instruction, @@ -320,6 +326,8 @@ These evaluation functions use the Claude model to assess the quality of the gen To evaluate our pipeline, we'll create a Weave Dataset and run an evaluation: +![](../../media/summarization/dataset.png) + ```python # Create a Weave Dataset @@ -340,6 +348,8 @@ For our evaluation, we'll use an LLM-as-a-judge approach. This technique involve [![arXiv](https://img.shields.io/badge/arXiv-2306.05685-b31b1b.svg)](https://arxiv.org/abs/2306.05685) +![](../../media/summarization/eval_dash.png) + ```python # Define the scorer function diff --git a/docs/docs/reference/gen_notebooks/dspy_prompt_optimization.md b/docs/docs/reference/gen_notebooks/dspy_prompt_optimization.md index d39ccbaa14e..878d7c2dac8 100644 --- a/docs/docs/reference/gen_notebooks/dspy_prompt_optimization.md +++ b/docs/docs/reference/gen_notebooks/dspy_prompt_optimization.md @@ -125,6 +125,8 @@ def get_dataset(metadata: Metadata): dspy_train_examples, dspy_val_examples = get_dataset(metadata) ``` +![](../../media/dspy_optimization/1.png) + ## The DSPy Program [DSPy](https://dspy-docs.vercel.app) is a framework that pushes building new LM pipelines away from manipulating free-form strings and closer to programming (composing modular operators to build text transformation graphs) where a compiler automatically generates optimized LM invocation strategies and prompts from a program. @@ -189,6 +191,8 @@ prediction = baseline_module(dspy_train_examples[0]["question"]) rich.print(prediction) ``` +![](../../media/dspy_optimization/2.png) + ## Evaluating our DSPy Program Now that we have a baseline prompting strategy, let's evaluate it on our validation set using [`weave.Evaluation`](../../guides/core-types/evaluations.md) on a simple metric that matches the predicted answer with the ground truth. 
Weave will take each example, pass it through your application and score the output on multiple custom scoring functions. By doing this, you'll have a view of the performance of your application, and a rich UI to drill into individual outputs and scores. @@ -219,6 +223,8 @@ evaluation = weave.Evaluation( await evaluation.evaluate(baseline_module.forward) ``` +![](../../media/dspy_optimization/3.png) + :::note If you're running from a python script, you can use the following code to run the evaluation: @@ -258,6 +264,8 @@ def get_optimized_program(model: dspy.Module, metadata: Metadata) -> dspy.Module optimized_module = get_optimized_program(baseline_module, metadata) ``` +![](../../media/dspy_optimization/4.png) + :::warning Running the evaluation on the causal reasoning dataset will cost approximately $0.04 in OpenAI credits. ::: @@ -275,6 +283,8 @@ evaluation = weave.Evaluation( await evaluation.evaluate(optimized_module.forward) ``` +![](../../media/dspy_optimization/5.png) + Comparing the evaluation of the baseline program with the optimized one shows that the optimized program answers the causal reasoning questions with significantly more accuracy. ## Conclusion diff --git a/docs/docs/reference/gen_notebooks/online_monitoring.md b/docs/docs/reference/gen_notebooks/online_monitoring.md index 27d033a8ca4..bfe048de9bb 100644 --- a/docs/docs/reference/gen_notebooks/online_monitoring.md +++ b/docs/docs/reference/gen_notebooks/online_monitoring.md @@ -62,19 +62,25 @@ MODEL_NAMES = [ ("gpt-4o-mini", 0.03, 0.06), ("gpt-4-turbo", 0.03, 0.06), ("claude-3-haiku-20240307", 0.01, 0.03), - ("gpt-4o", 0.03, 0.06) + ("gpt-4o", 0.03, 0.06), ] + def init_weave_client(project_name): try: client = weave.init(project_name) for model, prompt_cost, completion_cost in MODEL_NAMES: - client.add_cost(llm_id=model, prompt_token_cost=prompt_cost, completion_token_cost=completion_cost) + client.add_cost( + llm_id=model, + prompt_token_cost=prompt_cost, + completion_token_cost=completion_cost, + ) return client except Exception as e: print(f"Failed to initialize Weave client for project '{project_name}': {e}") return None - + + client = init_weave_client(PROJECT_NAME) ``` @@ -96,9 +102,11 @@ The first option to access data from Weave is to retrieve a list of filtered cal ```python import itertools -import pandas as pd from datetime import datetime, timedelta +import pandas as pd + + def fetch_calls(client, project_id, start_time, trace_roots_only, limit): filter_params = { "project_id": project_id, @@ -110,13 +118,16 @@ def fetch_calls(client, project_id, start_time, trace_roots_only, limit): try: calls_stream = client.server.calls_query_stream(filter_params) - calls = list(itertools.islice(calls_stream, limit)) # limit the number of calls to fetch if too many + calls = list( + itertools.islice(calls_stream, limit) + ) # limit the number of calls to fetch if too many print(f"Fetched {len(calls)} calls.") return calls except Exception as e: print(f"Error fetching calls: {e}") return [] - + + calls = fetch_calls(client, PROJECT_NAME, datetime.now() - timedelta(days=1), True, 100) ``` @@ -131,28 +142,40 @@ Processing the calls is very easy with the return from Weave - we'll extract the ```python import json -import pandas as pd from datetime import datetime +import pandas as pd + + def process_calls(calls): records = [] for call in calls: feedback = call.summary.get("weave", {}).get("feedback", []) - thumbs_up = sum(1 for item in feedback if isinstance(item, dict) and item.get("payload", {}).get("emoji") == "👍") -
thumbs_down = sum(1 for item in feedback if isinstance(item, dict) and item.get("payload", {}).get("emoji") == "👎") + thumbs_up = sum( + 1 + for item in feedback + if isinstance(item, dict) and item.get("payload", {}).get("emoji") == "👍" + ) + thumbs_down = sum( + 1 + for item in feedback + if isinstance(item, dict) and item.get("payload", {}).get("emoji") == "👎" + ) latency = call.summary.get("weave", {}).get("latency_ms", 0) - - records.append({ - "Call ID": call.id, - "Trace ID": call.trace_id, # this is a unique ID for the trace that can be used to retrieve it - "Display Name": call.display_name, # this is an optional name you can set in the UI or programatically - "Latency (ms)": latency, - "Thumbs Up": thumbs_up, - "Thumbs Down": thumbs_down, - "Started At": pd.to_datetime(getattr(call, 'started_at', datetime.min)), - "Inputs": json.dumps(call.inputs, default=str), - "Outputs": json.dumps(call.output, default=str) - }) + + records.append( + { + "Call ID": call.id, + "Trace ID": call.trace_id, # this is a unique ID for the trace that can be used to retrieve it + "Display Name": call.display_name, # this is an optional name you can set in the UI or programmatically + "Latency (ms)": latency, + "Thumbs Up": thumbs_up, + "Thumbs Down": thumbs_down, + "Started At": pd.to_datetime(getattr(call, "started_at", datetime.min)), + "Inputs": json.dumps(call.inputs, default=str), + "Outputs": json.dumps(call.output, default=str), + } + ) return pd.DataFrame(records) ``` @@ -171,7 +194,9 @@ For example, for the cost, we'll use the `query_costs` API to fetch the costs of # Use cost API to get costs costs = client.query_costs() df_costs = pd.DataFrame([cost.dict() for cost in costs]) -df_costs['total_cost'] = df_costs['prompt_token_cost'] + df_costs['completion_token_cost'] +df_costs["total_cost"] = ( + df_costs["prompt_token_cost"] + df_costs["completion_token_cost"] +) # only show the first row for every unique llm_id df_costs @@ -185,17 +210,35 @@ Next, we can generate the visualizations using plotly. 
This is the most basic da import plotly.express as px import plotly.graph_objects as go + def plot_feedback_pie_chart(thumbs_up, thumbs_down): - fig = go.Figure(data=[go.Pie(labels=['Thumbs Up', 'Thumbs Down'], values=[thumbs_up, thumbs_down], marker=dict(colors=['#66b3ff', '#ff9999']), hole=.3)]) - fig.update_traces(textinfo='percent+label', hoverinfo='label+percent') + fig = go.Figure( + data=[ + go.Pie( + labels=["Thumbs Up", "Thumbs Down"], + values=[thumbs_up, thumbs_down], + marker=dict(colors=["#66b3ff", "#ff9999"]), + hole=0.3, + ) + ] + ) + fig.update_traces(textinfo="percent+label", hoverinfo="label+percent") fig.update_layout(showlegend=False, title="Feedback Summary") return fig + def plot_model_cost_distribution(df): - fig = px.bar(df, x="llm_id", y="total_cost", color="llm_id", title="Cost Distribution by Model") + fig = px.bar( + df, + x="llm_id", + y="total_cost", + color="llm_id", + title="Cost Distribution by Model", + ) fig.update_layout(xaxis_title="Model", yaxis_title="Cost (USD)") return fig + # See the source code for all the plots ``` diff --git a/docs/docs/reference/gen_notebooks/pii.md b/docs/docs/reference/gen_notebooks/pii.md index f7688c35d20..e042d0ea738 100644 --- a/docs/docs/reference/gen_notebooks/pii.md +++ b/docs/docs/reference/gen_notebooks/pii.md @@ -18,14 +18,52 @@ title: Handling and Redacting PII # How to use Weave with PII data: -In this tutorial, we'll demonstrate how to utilize Weave while preventing your Personally Identifiable Information (PII) data from being incorporated into Weave or the LLMs you employ. +In this tutorial, we'll demonstrate how to utilize Weave while ensuring your Personally Identifiable Information (PII) data remains private. Weave supports removing PII from LLM calls and preventing PII from being displayed in the Weave UI. -To protect our PII data, we'll employ a couple techniques. First, we'll use regular expressions to identify PII data and redact it. Second, we'll use Microsoft's [Presidio](https://microsoft.github.io/presidio/), a python-based data protection SDK. This tool provides redaction and replacement functionalities, both of which we will implement in this tutorial. +To detect and protect our PII data, we'll identify and redact PII data and optionally anonymize it with the following methods: +1. __Regular expressions__ to identify PII data and redact it. +2. __Microsoft's [Presidio](https://microsoft.github.io/presidio/)__, a python-based data protection SDK. This tool provides redaction and replacement functionalities. +3. __[Faker](https://faker.readthedocs.io/en/master/)__, a Python library to generate fake data, combined with Presidio to anonymize PII data. -For this use-case. We will leverage Anthropic's Claude Sonnet to perform sentiment analysis. While we use Weave's [Traces](https://wandb.github.io/weave/quickstart) to track and analize the LLM's API calls. Sonnet will receive a block of text and output one of the following sentiment classifications: -1. positive -2. negative -3. neutral +Additionally, we'll make use of _Weave Ops input/output logging customization_ to seamlessly integrate PII redaction and anonymization into the workflow. See [here](https://weave-docs.wandb.ai/guides/tracking/ops/#customize-logged-inputs-and-outputs) for more information. + +For this use-case, we will leverage Anthropic's Claude Sonnet to perform sentiment analysis while tracing the LLM calls using Weave's [Traces](https://wandb.github.io/weave/quickstart). 
Sonnet will receive a block of text and output one of the following sentiment classifications: _positive_, _negative_, or _neutral_. + +## Overview of Weave Ops Input/Output Logging Customization + +Weave Ops support defining input and output postprocessing functions. These functions allow you to modify the data that is passed to your LLM call or logged to Weave, respectively. + +```python +from dataclasses import dataclass +from typing import Any + +import weave + +# Inputs Wrapper Class +@dataclass +class CustomObject: + x: int + secret_password: str + +# First we define functions for input and output postprocessing: +def postprocess_inputs(inputs: dict[str, Any]) -> dict[str, Any]: + return {k:v for k,v in inputs.items() if k != "hide_me"} + +def postprocess_output(output: CustomObject) -> CustomObject: + return CustomObject(x=output.x, secret_password="REDACTED") + +# Then, when we use the `@weave.op` decorator, we pass these processing functions as arguments to the decorator: +@weave.op( + postprocess_inputs=postprocess_inputs, + postprocess_output=postprocess_output, +) +def some_llm_call(a: int, hide_me: str) -> CustomObject: + return CustomObject(x=a, secret_password=hide_me) +``` + +# Setup + +Let's install the required packages and set up our API keys. Your Weights & Biases API key can be found [here](https://wandb.ai/authorize), and your Anthropic API keys are [here](https://console.anthropic.com/settings/keys). ```python @@ -33,19 +71,20 @@ For this use-case. We will leverage Anthropic's Claude Sonnet to perform sentime # @title required python packages: !pip install presidio_analyzer !pip install presidio_anonymizer -!python -m spacy download en_core_web_lg # Presidio uses spacy NLP engine -!pip install Faker # we'll use Faker to replace PII data with fake data -!pip install weave # To leverage Traces +!python -m spacy download en_core_web_lg # Presidio uses spacy NLP engine +!pip install Faker # we'll use Faker to replace PII data with fake data +!pip install weave # To leverage Traces !pip install set-env-colab-kaggle-dotenv -q # for env var !pip install anthropic # to use sonnet !pip install cryptography # to encrypt our data ``` -# Setup - ```python +%%capture # @title Make sure to set up your API keys correctly +# See: https://pypi.org/project/set-env-colab-kaggle-dotenv/ for usage instructions. 
+ from set_env import set_env _ = set_env("ANTHROPIC_API_KEY") @@ -56,6 +95,7 @@ _ = set_env("WANDB_API_KEY") ```python import weave +# Start a new Weave project WEAVE_PROJECT = "pii_cookbook" weave.init(WEAVE_PROJECT) ``` @@ -69,25 +109,13 @@ import requests url = "https://raw.githubusercontent.com/wandb/weave/master/docs/notebooks/10_pii_data.json" response = requests.get(url) pii_data = response.json() -``` -# Using Weave Safely with PII Data - -## During Testing -- Log anonymized data to check PII detection -- Track PII handling processes with Weave traces -- Measure anonymization performance without exposing real PII - -## In Production -- Never log raw PII -- Encrypt sensitive fields before logging +print('PII data first sample: "' + pii_data[0]["text"] + '"') +``` -## Encryption Tips -- Use reversible encryption for data you need to decrypt later -- Apply one-way hashing for unique IDs you don't need to reverse -- Consider specialized encryption for data you need to analyze while encrypted +# Redaction Methods Implementation -# Method 1: +## Method 1: Regular Expression Filtering Our initial method is to use [regular expressions (regex)](https://docs.python.org/3/library/re.html) to identify PII data and redact it. It allows us to define patterns that can match various formats of sensitive information like phone numbers, email addresses, and social security numbers. By using regex, we can scan through large volumes of text and replace or redact information without the need for more complex NLP techniques. @@ -96,7 +124,8 @@ Our initial method is to use [regular expressions (regex)](https://docs.python.o import re -def clean_pii_with_regex(text): +# Define a function to clean PII data using regex +def redact_with_regex(text): # Phone number pattern # \b : Word boundary # \d{3} : Exactly 3 digits @@ -140,164 +169,79 @@ def clean_pii_with_regex(text): text = re.sub(r"\b[A-Z][a-z]+ [A-Z][a-z]+\b", "", text) return text - - -# Test the function -test_text = "My name is John Doe, my email is john.doe@example.com, my phone is 123-456-7890, and my SSN is 123-45-6789." -cleaned_text = clean_pii_with_regex(test_text) -print(cleaned_text) ``` -# Method 2: Microsoft Presidio - -In this example, we'll create a [Weave Model](https://wandb.github.io/weave/guides/core-types/models) which is a combination of data (which can include configuration, trained model weights, or other information) and code that defines how the model operates. -In this model, we will include our predict function where the Anthropic API will be called. 
- -Once you run this code you will receive a link to the Weave project page ```python import json from anthropic import AsyncAnthropic import weave # Weave model / predict function class sentiment_analysis_model(weave.Model): model_name: str system_prompt: str temperature: int @weave.op() async def predict(self, text_block: str) -> dict: client = AsyncAnthropic() response = await client.messages.create( max_tokens=1024, model=self.model_name, system=self.system_prompt, messages=[ {"role": "user", "content": [{"type": "text", "text": text_block}]} ], ) result = response.content[0].text if result is None: raise ValueError("No response from model") parsed = json.loads(result) return parsed # create our LLM model with a system prompt model = sentiment_analysis_model( name="claude-3-sonnet", model_name="claude-3-5-sonnet-20240620", system_prompt='You are a Sentiment Analysis classifier. You will be classifying text based on their sentiment. Your input will be a block of text. You will answer with one of the following rating option["positive", "negative", "neutral"]. Your answer should be one word in json format: {classification}. Ensure that it is valid JSON.', temperature=0, ) ``` -# Method 2A: -Our next method involves complete removal of PII data using Presidio. This approach redacts PII and replaces it with a placeholder representing the PII type. For example: -``` - "My name is Alex" -``` - -Will be: - -``` - "My name is <PERSON>" -``` -Presidio comes with a built-in [list of recognizable entities](https://microsoft.github.io/presidio/supported_entities/). We can select the ones that are important for our use case. In the below example, we are only looking at redicating names and phone numbers from our text: - -![](../../media/pii/redact.png) +## Method 2: Microsoft Presidio Redaction +Our next method involves complete removal of PII data using Presidio. This approach redacts PII and replaces it with a placeholder representing the PII type. + +For example: +`"My name is Alex"` becomes `"My name is <PERSON>"`. + +Presidio comes with a built-in [list of recognizable entities](https://microsoft.github.io/presidio/supported_entities/). We can select the ones that are important for our use case. In the below example, we redact names, phone numbers, locations, email addresses, and US Social Security Numbers. + +We'll then encapsulate the Presidio process into a function. ```python from presidio_analyzer import AnalyzerEngine from presidio_anonymizer import AnonymizerEngine -text = "My phone number is 212-555-5555 and my name is alex" - -# Set up the engine, loads the NLP module (spaCy model by default) -# and other PII recognizers +# Set up the Analyzer, which loads an NLP module (spaCy model by default) and other PII recognizers. analyzer = AnalyzerEngine() -# Call analyzer to get results -results = analyzer.analyze( - text=text, entities=["PHONE_NUMBER", "PERSON"], language="en" -) - -# Analyzer results are passed to the AnonymizerEngine for anonymization - +# Set up the Anonymizer, which will use the analyzer results to anonymize the text. 
anonymizer = AnonymizerEngine() -anonymized_text = anonymizer.anonymize(text=text, analyzer_results=results) - -print(anonymized_text) -``` - -Let's encapsulate the previous step into a function and expand the entity recognition capabilities. We will expand our redaction scope to include addresses, email addresses, and US Social Security numbers. - -```python -from presidio_analyzer import AnalyzerEngine -from presidio_anonymizer import AnonymizerEngine - -analyzer = AnalyzerEngine() -anonymizer = AnonymizerEngine() -""" -The below function will take a block of text, process it using presidio -and return a block of text with the PII data redicated. -PII data to be redicated: -- Phone Numbers -- Names -- Addresses -- Email addresses -- US Social Security Numbers -""" - - -def anonymize_my_text(text): +# Encapsulate the Presidio redaction process into a function +def redact_with_presidio(text): + # Analyze the text to identify PII data results = analyzer.analyze( text=text, entities=["PHONE_NUMBER", "PERSON", "LOCATION", "EMAIL_ADDRESS", "US_SSN"], language="en", ) + # Anonymize the identified PII data anonymized_text = anonymizer.anonymize(text=text, analyzer_results=results) return anonymized_text.text ``` +Let's test the function with a sample text: + ```python -# for every block of text, anonymized first and then predict -for entry in pii_data: - anonymized_entry = anonymize_my_text(entry["text"]) - (await model.predict(anonymized_entry)) -``` +text = "My phone number is 212-555-5555 and my name is alex" -# Method 2B: Replace PII data with fake data +# Test the function +anonymized_text = redact_with_presidio(text) -Instead of redacting text, we can anonymize it by swapping PII (like names and phone numbers) with fake data generated using the [Faker](https://faker.readthedocs.io/en/master/) Python library. For example: +print(f"Raw text:\n\t{text}") +print(f"Redacted text:\n\t{anonymized_text}") +``` +## Method 3: Anonymization with Replacement using Faker and Presidio -``` -"My name is Raphael and I like to fish. My phone number is 212-555-5555" -``` -Will be: +Instead of redacting text, we can anonymize it by swapping PII (like names and phone numbers) with fake data generated using the [Faker](https://faker.readthedocs.io/en/master/) Python library. For example: +`"My name is Raphael and I like to fish. My phone number is 212-555-5555"` -``` -"My name is Katherine Dixon and I like to fish. My phone number is 667.431.7379" +might become -``` +`"My name is Katherine Dixon and I like to fish. My phone number is 667.431.7379"` To effectively utilize Presidio, we must supply references to our custom operators. These operators will direct Presidio to the functions responsible for swapping PII with fake data. -![](../../media/pii/replace.png) - ```python from faker import Faker @@ -322,7 +266,6 @@ operators = { "PHONE_NUMBER": OperatorConfig("custom", {"lambda": fake_number}), } - text_to_anonymize = ( "My name is Raphael and I like to fish. 
My phone number is 212-555-5555" ) @@ -332,7 +275,6 @@ analyzer_results = analyzer.analyze( text=text_to_anonymize, entities=["PHONE_NUMBER", "PERSON"], language="en" ) - anonymizer = AnonymizerEngine() # do not forget to pass the operators from above to the anonymizer @@ -340,22 +282,20 @@ anonymized_results = anonymizer.anonymize( text=text_to_anonymize, analyzer_results=analyzer_results, operators=operators ) -print(anonymized_results.text) +print(f"Raw text:\n\t{text_to_anonymize}") +print(f"Anonymized text:\n\t{anonymized_results.text}") ``` Let's consolidate our code into a single class and expand the list of entities to include the additional ones we identified earlier. ```python -from anthropic import AsyncAnthropic from faker import Faker from presidio_anonymizer import AnonymizerEngine from presidio_anonymizer.entities import OperatorConfig -import weave - -# Let's build a custom class for generating fake data that will extend Faker +# A custom class for generating fake data that extends Faker class my_faker(Faker): # Create faker functions (note that it has to receive a value) def fake_address(x): @@ -382,7 +322,7 @@ class my_faker(Faker): "US_SSN": OperatorConfig("custom", {"lambda": fake_ssn}), } - def anonymize_my_text(self, text): + def redact_and_anonymize_with_faker(self, text): anonymizer = AnonymizerEngine() analyzer_results = analyzer.analyze( text=text, @@ -393,17 +333,243 @@ class my_faker(Faker): text=text, analyzer_results=analyzer_results, operators=self.operators ) return anonymized_results.text +``` + +Let's test the function with a sample text: + + +```python +faker = my_faker() +text_to_anonymize = ( + "My name is Raphael and I like to fish. My phone number is 212-555-5555" +) +anonymized_text = faker.redact_and_anonymize_with_faker(text_to_anonymize) + +print(f"Raw text:\n\t{text_to_anonymize}") +print(f"Anonymized text:\n\t{anonymized_text}") +``` + +# Applying the Methods to Weave Calls + +In these examples we will integrate our PII redaction and anonymization methods into Weave Models, and preview the results in Weave Traces. + +We'll create a [Weave Model](https://wandb.github.io/weave/guides/core-types/models) which is a combination of data (which can include configuration, trained model weights, or other information) and code that defines how the model operates. + +In this model, we will include our predict function where the Anthropic API will be called. Additionally, we will include our postprocessing functions to ensure that our PII data is redacted or anonymized before it is sent to the LLM. + +Once you run this code you will receive links to the Weave project page as well as the specific trace (LLM calls) you ran. + +## Regex Method + +In the simplest case, we can use regex to identify and redact PII data in the original text. 
+ + +```python +import json +from typing import Any + +import anthropic + +import weave + +# Define an input postprocessing function that applies our regex redaction for the model prediction Weave Op +def postprocess_inputs_regex(inputs: dict[str, Any]) -> dict: + inputs["text_block"] = redact_with_regex(inputs["text_block"]) + return inputs + +# Weave model / predict function +class sentiment_analysis_regex_pii_model(weave.Model): + model_name: str + system_prompt: str + temperature: int + + @weave.op( + postprocess_inputs=postprocess_inputs_regex, + ) + async def predict(self, text_block: str) -> dict: + client = anthropic.AsyncAnthropic() + response = await client.messages.create( + max_tokens=1024, + model=self.model_name, + system=self.system_prompt, + messages=[ + {"role": "user", "content": [{"type": "text", "text": text_block}]} + ], + ) + result = response.content[0].text + if result is None: + raise ValueError("No response from model") + parsed = json.loads(result) + return parsed +``` + + +```python +# create our LLM model with a system prompt +model = sentiment_analysis_regex_pii_model( + name="claude-3-sonnet", + model_name="claude-3-5-sonnet-20240620", + system_prompt='You are a Sentiment Analysis classifier. You will be classifying text based on their sentiment. Your input will be a block of text. You will answer with one of the following rating option["positive", "negative", "neutral"]. Your answer should be one word in json format: {classification}. Ensure that it is valid JSON.', + temperature=0, +) + +print("Model: ", model) +# for every block of text, PII is redacted by the postprocessing function before prediction +for entry in pii_data: + await model.predict(entry["text"]) +``` + +## Presidio Redaction Method + +Here we will use Presidio to identify and redact PII data in the original text. + +![](../../media/pii/redact.png) + + +```python +import json +from typing import Any + +import anthropic + +import weave + + +# Define an input postprocessing function that applies our Presidio redaction for the model prediction Weave Op +def postprocess_inputs_presidio(inputs: dict[str, Any]) -> dict: + inputs["text_block"] = redact_with_presidio(inputs["text_block"]) + return inputs + + +# Weave model / predict function +class sentiment_analysis_presidio_pii_model(weave.Model): + model_name: str + system_prompt: str + temperature: int + + @weave.op( + postprocess_inputs=postprocess_inputs_presidio, + ) + async def predict(self, text_block: str) -> dict: + client = anthropic.AsyncAnthropic() + response = await client.messages.create( + max_tokens=1024, + model=self.model_name, + system=self.system_prompt, + messages=[ + {"role": "user", "content": [{"type": "text", "text": text_block}]} + ], + ) + result = response.content[0].text + if result is None: + raise ValueError("No response from model") + parsed = json.loads(result) + return parsed +``` + + +```python +# create our LLM model with a system prompt +model = sentiment_analysis_presidio_pii_model( + name="claude-3-sonnet", + model_name="claude-3-5-sonnet-20240620", + system_prompt='You are a Sentiment Analysis classifier. You will be classifying text based on their sentiment. Your input will be a block of text. You will answer with one of the following rating option["positive", "negative", "neutral"]. Your answer should be one word in json format: {classification}. 
Ensure that it is valid JSON.', + temperature=0, +) + +print("Model: ", model) +# for every block of text, PII is redacted by the postprocessing function before prediction +for entry in pii_data: + await model.predict(entry["text"]) +``` + +## Faker + Presidio Replacement Method + +Here we will have Faker generate anonymized replacement PII data and use Presidio to identify and replace the PII data in the original text. + + +![](../../media/pii/replace.png) + + +```python +import json +from typing import Any + +import anthropic + +import weave + +# Define an input postprocessing function that applies our Faker anonymization and Presidio redaction for the model prediction Weave Op faker = my_faker() + + +def postprocess_inputs_faker(inputs: dict[str, Any]) -> dict: + inputs["text_block"] = faker.redact_and_anonymize_with_faker(inputs["text_block"]) + return inputs + + +# Weave model / predict function +class sentiment_analysis_faker_pii_model(weave.Model): + model_name: str + system_prompt: str + temperature: int + + @weave.op( + postprocess_inputs=postprocess_inputs_faker, + ) + async def predict(self, text_block: str) -> dict: + client = anthropic.AsyncAnthropic() + response = await client.messages.create( + max_tokens=1024, + model=self.model_name, + system=self.system_prompt, + messages=[ + {"role": "user", "content": [{"type": "text", "text": text_block}]} + ], + ) + result = response.content[0].text + if result is None: + raise ValueError("No response from model") + parsed = json.loads(result) + return parsed +``` + + +```python +# create our LLM model with a system prompt +model = sentiment_analysis_faker_pii_model( + name="claude-3-sonnet", + model_name="claude-3-5-sonnet-20240620", + system_prompt='You are a Sentiment Analysis classifier. You will be classifying text based on their sentiment. Your input will be a block of text. You will answer with one of the following rating option["positive", "negative", "neutral"]. Your answer should be one word in json format: {classification}. Ensure that it is valid JSON.', + temperature=0, +) + +print("Model: ", model) +# for every block of text, PII is anonymized by the postprocessing function before prediction for entry in pii_data: - anonymized_entry = faker.anonymize_my_text(entry["text"]) - (await model.predict(anonymized_entry)) + await model.predict(entry["text"]) ``` +## Checklist for Safely Using Weave with PII Data + +### During Testing +- Log anonymized data to check PII detection +- Track PII handling processes with Weave Traces +- Measure anonymization performance without exposing real PII + +### In Production +- Never log raw PII +- Encrypt sensitive fields before logging + +### Encryption Tips +- Use reversible encryption for data you need to decrypt later +- Apply one-way hashing for unique IDs you don't need to reverse +- Consider specialized encryption for data you need to analyze while encrypted +
(Optional) Encrypting our data - ![](../../media/pii/encrypt.png) In addition to anonymizing PII, we can add an extra layer of security by encrypting our data using the cryptography library's [Fernet](https://cryptography.io/en/latest/fernet/) symmetric encryption. This approach ensures that even if the anonymized data is intercepted, it remains unreadable without the encryption key. diff --git a/docs/intro_notebook.ipynb b/docs/intro_notebook.ipynb index 35af77fa213..c491787657f 100644 --- a/docs/intro_notebook.ipynb +++ b/docs/intro_notebook.ipynb @@ -103,6 +103,14 @@ "Add the @weave.op decorator to the functions you want to track" ] }, + { + "cell_type": "markdown", + "id": "32c013bc", + "metadata": {}, + "source": [ + "![](../../media/intro/1.png)" + ] + }, { "cell_type": "code", "execution_count": null, @@ -150,6 +158,14 @@ "Here, we're automatically tracking all calls to `openai`. We automatically track a lot of LLM libraries, but it's really easy to add support for whatever LLM you're using, as you'll see below. " ] }, + { + "cell_type": "markdown", + "id": "4dc2d909", + "metadata": {}, + "source": [ + "![](../../media/intro/2.png)" + ] + }, { "cell_type": "code", "execution_count": null, @@ -192,6 +208,14 @@ "\n" ] }, + { + "cell_type": "markdown", + "id": "12213633", + "metadata": {}, + "source": [ + "![](../../media/intro/3.png)" + ] + }, { "cell_type": "code", "execution_count": null, @@ -244,6 +268,14 @@ "Whenever your code crashes, weave will highlight what caused the issue. This is especially useful for finding things like JSON parsing issues that can occasionally happen when parsing data from LLM responses." ] }, + { + "cell_type": "markdown", + "id": "768cdd2f", + "metadata": {}, + "source": [ + "![](../../media/intro/4.png)" + ] + }, { "cell_type": "code", "execution_count": null, @@ -313,6 +345,14 @@ "Many times, it is useful to track & version data, just like you track and version code. For example, here we define a `SystemPrompt(weave.Object)` object that can be shared between teammates" ] }, + { + "cell_type": "markdown", + "id": "95017cd1", + "metadata": {}, + "source": [ + "![](../../media/intro/5.png)" + ] + }, { "cell_type": "code", "execution_count": null, @@ -345,6 +385,14 @@ "Models are so common of an object type, that we have a special class to represent them: `weave.Model`. The only requirement is that we define a `predict` method." ] }, + { + "cell_type": "markdown", + "id": "7feb0667", + "metadata": {}, + "source": [ + "![](../../media/intro/6.png)" + ] + }, { "cell_type": "code", "execution_count": null, @@ -397,6 +445,14 @@ "Similar to models, a `weave.Dataset` object exists to help track, organize, and operate on datasets" ] }, + { + "cell_type": "markdown", + "id": "5384d6c3", + "metadata": {}, + "source": [ + "![](../../media/intro/7.png)" + ] + }, { "cell_type": "code", "execution_count": null, @@ -440,6 +496,14 @@ "You can publish objects and then retrieve them in your code. You can even call functions from your retrieved objects!" 
] }, + { + "cell_type": "markdown", + "id": "a9bf0233", + "metadata": {}, + "source": [ + "![](../../media/intro/8.png)" + ] + }, { "cell_type": "code", "execution_count": null, @@ -460,6 +524,14 @@ "print(ref.uri())" ] }, + { + "cell_type": "markdown", + "id": "4e8ea290", + "metadata": {}, + "source": [ + "![](../../media/intro/9.png)" + ] + }, { "cell_type": "code", "execution_count": null, @@ -493,6 +565,14 @@ "See a preview of the API below:" ] }, + { + "cell_type": "markdown", + "id": "72bdf072", + "metadata": {}, + "source": [ + "![](../../media/intro/10.png)" + ] + }, { "cell_type": "code", "execution_count": null, diff --git a/docs/notebooks/chain_of_density.ipynb b/docs/notebooks/chain_of_density.ipynb index 2fb6ecc3f45..c7bc1136325 100644 --- a/docs/notebooks/chain_of_density.ipynb +++ b/docs/notebooks/chain_of_density.ipynb @@ -23,6 +23,13 @@ "Summarizing complex technical documents while preserving crucial details is a challenging task. The Chain of Density (CoD) summarization technique offers a solution by iteratively refining summaries to be more concise and information-dense. This guide demonstrates how to implement CoD using Weave for tracking and evaluating the application. " ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "![](../../media/summarization/eval_dash.png)" + ] + }, { "cell_type": "markdown", "metadata": {}, @@ -203,6 +210,13 @@ "Now, let's implement the core CoD summarization logic using Weave operations:" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "![](../../media/summarization/summarization_trace.png)" + ] + }, { "cell_type": "code", "execution_count": null, @@ -308,6 +322,13 @@ "Now, let's wrap our summarization pipeline in a Weave Model:" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "![](../../media/summarization/model.png)" + ] + }, { "cell_type": "code", "execution_count": null, @@ -321,7 +342,7 @@ "\n", " @weave.op()\n", " def predict(self, paper: ArxivPaper, instruction: str) -> dict:\n", - " text = load_pdf(paper[\"pdf_url\"])\n", + " text = load_pdf(paper.pdf_url)\n", " result = chain_of_density_summarization(\n", " text,\n", " instruction,\n", @@ -425,6 +446,13 @@ "To evaluate our pipeline, we'll create a Weave Dataset and run an evaluation:" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "![](../../media/summarization/dataset.png)" + ] + }, { "cell_type": "code", "execution_count": null, @@ -459,6 +487,13 @@ "[![arXiv](https://img.shields.io/badge/arXiv-2306.05685-b31b1b.svg)](https://arxiv.org/abs/2306.05685)" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "![](../../media/summarization/eval_dash.png)" + ] + }, { "cell_type": "code", "execution_count": null, diff --git a/docs/notebooks/dspy_prompt_optimization.ipynb b/docs/notebooks/dspy_prompt_optimization.ipynb index 573aaf7085a..e2da185e2ad 100644 --- a/docs/notebooks/dspy_prompt_optimization.ipynb +++ b/docs/notebooks/dspy_prompt_optimization.ipynb @@ -177,6 +177,13 @@ "dspy_train_examples, dspy_val_examples = get_dataset(metadata)" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "![](../../media/dspy_optimization/1.png)" + ] + }, { "cell_type": "markdown", "metadata": {}, @@ -268,6 +275,13 @@ "rich.print(prediction)" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "![](../../media/dspy_optimization/2.png)" + ] + }, { "cell_type": "markdown", "metadata": {}, @@ -316,6 +330,13 @@ "await evaluation.evaluate(baseline_module.forward)" ] 
}, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "![](../../media/dspy_optimization/3.png)" + ] + }, { "cell_type": "markdown", "metadata": {}, @@ -369,6 +390,13 @@ "optimized_module = get_optimized_program(baseline_module, metadata)" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "![](../../media/dspy_optimization/4.png)" + ] + }, { "cell_type": "markdown", "metadata": {}, @@ -395,6 +423,13 @@ "await evaluation.evaluate(optimized_module.forward)" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "![](../../media/dspy_optimization/5.png)" + ] + }, { "cell_type": "markdown", "metadata": {}, diff --git a/docs/notebooks/pii.ipynb b/docs/notebooks/pii.ipynb index e227c0b3bfa..7d89df49828 100644 --- a/docs/notebooks/pii.ipynb +++ b/docs/notebooks/pii.ipynb @@ -29,55 +29,97 @@ "id": "C70egOGRLCgm" }, "source": [ - "In this tutorial, we'll demonstrate how to utilize Weave while preventing your Personally Identifiable Information (PII) data from being incorporated into Weave or the LLMs you employ.\n", + "In this tutorial, we'll demonstrate how to utilize Weave while ensuring your Personally Identifiable Information (PII) data remains private. Weave supports removing PII from LLM calls and preventing PII from being displayed in the Weave UI. \n", "\n", - "To protect our PII data, we'll employ a couple techniques. First, we'll use regular expressions to identify PII data and redact it. Second, we'll use Microsoft's [Presidio](https://microsoft.github.io/presidio/), a python-based data protection SDK. This tool provides redaction and replacement functionalities, both of which we will implement in this tutorial.\n", + "To detect and protect our PII data, we'll identify and redact PII data and optionally anonymize it with the following methods:\n", + "1. __Regular expressions__ to identify PII data and redact it.\n", + "2. __Microsoft's [Presidio](https://microsoft.github.io/presidio/)__, a python-based data protection SDK. This tool provides redaction and replacement functionalities.\n", + "3. __[Faker](https://faker.readthedocs.io/en/master/)__, a Python library to generate fake data, combined with Presidio to anonymize PII data.\n", "\n", - "For this use-case. We will leverage Anthropic's Claude Sonnet to perform sentiment analysis. While we use Weave's [Traces](https://wandb.github.io/weave/quickstart) to track and analize the LLM's API calls. Sonnet will receive a block of text and output one of the following sentiment classifications:\n", - "1. positive\n", - "2. negative\n", - "3. neutral" + "Additionally, we'll make use of _Weave Ops input/output logging customization_ to seamlessly integrate PII redaction and anonymization into the workflow. See [here](https://weave-docs.wandb.ai/guides/tracking/ops/#customize-logged-inputs-and-outputs) for more information.\n", + "\n", + "For this use-case, we will leverage Anthropic's Claude Sonnet to perform sentiment analysis while tracing the LLM calls using Weave's [Traces](https://wandb.github.io/weave/quickstart). Sonnet will receive a block of text and output one of the following sentiment classifications: _positive_, _negative_, or _neutral_." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Overview of Weave Ops Input/Output Logging Customization\n", + "\n", + "Weave Ops support defining input and output postprocessing functions. These functions allow you to modify the data that is passed to your LLM call or logged to Weave, respectively." 
+ ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "```python\n", + "from dataclasses import dataclass\n", + "from typing import Any\n", + "\n", + "import weave\n", + "\n", + "# Inputs Wrapper Class\n", + "@dataclass\n", + "class CustomObject:\n", + " x: int\n", + " secret_password: str\n", + "\n", + "# First we define functions for input and output postprocessing:\n", + "def postprocess_inputs(inputs: dict[str, Any]) -> dict[str, Any]:\n", + " return {k:v for k,v in inputs.items() if k != \"hide_me\"}\n", + "\n", + "def postprocess_output(output: CustomObject) -> CustomObject:\n", + " return CustomObject(x=output.x, secret_password=\"REDACTED\")\n", + "\n", + "# Then, when we use the `@weave.op` decorator, we pass these processing functions as arguments to the decorator:\n", + "@weave.op(\n", + " postprocess_inputs=postprocess_inputs,\n", + " postprocess_output=postprocess_output,\n", + ")\n", + "def some_llm_call(a: int, hide_me: str) -> CustomObject:\n", + " return CustomObject(x=a, secret_password=hide_me)\n", + "```" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Setup\n", + "\n", + "Let's install the required packages and set up our API keys. Your Weights & Biases API key can be found [here](https://wandb.ai/authorize), and your Anthropic API keys are [here](https://console.anthropic.com/settings/keys)." ] }, { "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "collapsed": true, - "id": "qi-VNJT35v2j", - "outputId": "8870f124-e141-4fee-8789-024d81944c56" - }, + "execution_count": 3, + "metadata": {}, "outputs": [], "source": [ "%%capture\n", "# @title required python packages:\n", "!pip install presidio_analyzer\n", "!pip install presidio_anonymizer\n", - "!python -m spacy download en_core_web_lg # Presidio uses spacy NLP engine\n", - "!pip install Faker # we'll use Faker to replace PII data with fake data\n", - "!pip install weave # To leverage Traces\n", + "!python -m spacy download en_core_web_lg # Presidio uses spacy NLP engine\n", + "!pip install Faker # we'll use Faker to replace PII data with fake data\n", + "!pip install weave # To leverage Traces\n", "!pip install set-env-colab-kaggle-dotenv -q # for env var\n", "!pip install anthropic # to use sonnet\n", "!pip install cryptography # to encrypt our data" ] }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Setup" - ] - }, { "cell_type": "code", - "execution_count": null, + "execution_count": 4, "metadata": {}, "outputs": [], "source": [ + "%%capture\n", "# @title Make sure to set up your API keys correctly\n", + "# See: https://pypi.org/project/set-env-colab-kaggle-dotenv/ for usage instructions.\n", + "\n", "from set_env import set_env\n", "\n", "_ = set_env(\"ANTHROPIC_API_KEY\")\n", @@ -86,12 +128,24 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 6, "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 6, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ "import weave\n", "\n", + "# Start a new Weave project\n", "WEAVE_PROJECT = \"pii_cookbook\"\n", "weave.init(WEAVE_PROJECT)" ] @@ -105,57 +159,54 @@ }, { "cell_type": "code", - "execution_count": 1, + "execution_count": 8, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "PII data first sample: \"I remember the day like it was yesterday. 
I was a waitress at a busy restaurant in New York City, Mohammed Gross is the name. The place was packed, and I was running around like a crazy person, trying to keep up with the demand. Suddenly, I noticed a customer sitting at one of my tables. He looked very upset. I went over to see what was wrong. \"My food is cold,\" he said. \"I've been waiting for it for over an hour.\" I apologized and asked him if he would like me to get him a new meal. He said yes, so I went back to the kitchen and told the chef. The chef was very apologetic and made the customer a new meal right away. I brought it out to the customer and he seemed satisfied. \"Thank you,\" he said. \"That's much better.\" I was glad that I was able to resolve the situation and make the customer happy. It's always rewarding to be able to help people, and it's one of the things I love about my job. I believe that customer service is very important, and I always try to go the extra mile for my customers. If you're ever in New York City, be sure to stop by the restaurant and ask for Mohammed Gross. I'll be happy to serve you! Feel free to email me at mohammedgross@msn.edu if you have any questions or concerns. You can also reach me by mail at 8202 Dudley Way, [City, State, ZIP].\"\n" + ] + } + ], "source": [ "import requests\n", "\n", "url = \"https://raw.githubusercontent.com/wandb/weave/master/docs/notebooks/10_pii_data.json\"\n", "response = requests.get(url)\n", - "pii_data = response.json()" + "pii_data = response.json()\n", + "\n", + "print('PII data first sample: \"' + pii_data[0][\"text\"] + '\"')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "# Using Weave Safely with PII Data\n", - "\n", - "## During Testing\n", - "- Log anonymized data to check PII detection\n", - "- Track PII handling processes with Weave traces\n", - "- Measure anonymization performance without exposing real PII\n", - "\n", - "## In Production\n", - "- Never log raw PII\n", - "- Encrypt sensitive fields before logging\n", - "\n", - "## Encryption Tips\n", - "- Use reversible encryption for data you need to decrypt later\n", - "- Apply one-way hashing for unique IDs you don't need to reverse\n", - "- Consider specialized encryption for data you need to analyze while encrypted" + "# Redaction Methods Implementation" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "# Method 1: \n", + "## Method 1: Regular Expression Filtering\n", "\n", "Our initial method is to use [regular expressions (regex)](https://docs.python.org/3/library/re.html) to identify PII data and redact it. It allows us to define patterns that can match various formats of sensitive information like phone numbers, email addresses, and social security numbers. By using regex, we can scan through large volumes of text and replace or redact information without the need for more complex NLP techniques. 
" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 9, "metadata": {}, "outputs": [], "source": [ "import re\n", "\n", "\n", - "def clean_pii_with_regex(text):\n", + "# Define a function to clean PII data using regex\n", + "def redact_with_regex(text):\n", " # Phone number pattern\n", " # \\b : Word boundary\n", " # \\d{3} : Exactly 3 digits\n", @@ -198,110 +249,58 @@ " # \\b : Word boundary\n", " text = re.sub(r\"\\b[A-Z][a-z]+ [A-Z][a-z]+\\b\", \"\", text)\n", "\n", - " return text\n", - "\n", - "\n", - "# Test the function\n", - "test_text = \"My name is John Doe, my email is john.doe@example.com, my phone is 123-456-7890, and my SSN is 123-45-6789.\"\n", - "cleaned_text = clean_pii_with_regex(test_text)\n", - "print(cleaned_text)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Method 2: Microsoft Presidio" + " return text" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "In this example, we'll create a [Weave Model](https://wandb.github.io/weave/guides/core-types/models) which is a combination of data (which can include configuration, trained model weights, or other information) and code that defines how the model operates.\n", - "In this model, we will include our predict function where the Anthropic API will be called.\n", - "\n", - "Once you run this code you will receive a link to the Weave project page" + "Let's test the function with a sample text:" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 10, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Raw text:\n", + "\tMy name is John Doe, my email is john.doe@example.com, my phone is 123-456-7890, and my SSN is 123-45-6789.\n", + "Redacted text:\n", + "\tMy name is , my email is , my phone is , and my SSN is .\n" + ] + } + ], "source": [ - "import json\n", - "\n", - "from anthropic import AsyncAnthropic\n", - "\n", - "import weave\n", - "\n", - "\n", - "# Weave model / predict function\n", - "class sentiment_analysis_model(weave.Model):\n", - " model_name: str\n", - " system_prompt: str\n", - " temperature: int\n", - "\n", - " @weave.op()\n", - " async def predict(self, text_block: str) -> dict:\n", - " client = AsyncAnthropic()\n", - "\n", - " response = await client.messages.create(\n", - " max_tokens=1024,\n", - " model=self.model_name,\n", - " system=self.system_prompt,\n", - " messages=[\n", - " {\"role\": \"user\", \"content\": [{\"type\": \"text\", \"text\": text_block}]}\n", - " ],\n", - " )\n", - " result = response.content[0].text\n", - " if result is None:\n", - " raise ValueError(\"No response from model\")\n", - " parsed = json.loads(result)\n", - " return parsed\n", - "\n", - " # create our LLM model with a system prompt\n", - "\n", - "\n", - "model = sentiment_analysis_model(\n", - " name=\"claude-3-sonnet\",\n", - " model_name=\"claude-3-5-sonnet-20240620\",\n", - " system_prompt='You are a Sentiment Analysis classifier. You will be classifying text based on their sentiment. Your input will be a block of text. You will answer with one the following rating option[\"positive\", \"negative\", \"neutral\"]. Your answer should be one word in json format: {classification}. 
Ensure that it is valid JSON.',\n", - " temperature=0,\n", - ")" ] }, { "cell_type": "markdown", "metadata": { "id": "sYEX_yPqMDOk" }, "source": [ - "# Method 2A:\n", - "Our next method involves complete removal of PII data using Presidio. This approach redacts PII and replaces it with a placeholder representing the PII type. For example:\n", - "```\n", - " \"My name is Alex\"\n", - "```\n", "\n", - "Will be:\n", + "## Method 2: Microsoft Presidio Redaction\n", + "Our next method involves complete removal of PII data using Presidio. This approach redacts PII and replaces it with a placeholder representing the PII type. \n", "\n", - "```\n", - " \"My name is <PERSON>\"\n", - "```\n", + "For example:\n", + "`\"My name is Alex\"` becomes `\"My name is <PERSON>\"`.\n", + "\n", + "Presidio comes with a built-in [list of recognizable entities](https://microsoft.github.io/presidio/supported_entities/). We can select the ones that are important for our use case. In the below example, we redact names, phone numbers, locations, email addresses, and US Social Security Numbers.\n", + "\n", + "We'll then encapsulate the Presidio process into a function." ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 11, "metadata": { "id": "Bh6MgL3g6sc2" }, @@ -310,128 +309,94 @@ "from presidio_analyzer import AnalyzerEngine\n", "from presidio_anonymizer import AnonymizerEngine\n", "\n", - "text = \"My phone number is 212-555-5555 and my name is alex\"\n", - "\n", - "# Set up the engine, loads the NLP module (spaCy model by default)\n", - "# and other PII recognizers\n", + "# Set up the Analyzer, which loads an NLP module (spaCy model by default) and other PII recognizers.\n", "analyzer = AnalyzerEngine()\n", "\n", - "# Call analyzer to get results\n", - "results = analyzer.analyze(\n", - " text=text, entities=[\"PHONE_NUMBER\", \"PERSON\"], language=\"en\"\n", - ")\n", - "\n", - "# Analyzer results are passed to the AnonymizerEngine for anonymization\n", - "\n", + "# Set up the Anonymizer, which will use the analyzer results to anonymize the text.\n", "anonymizer = AnonymizerEngine()\n", "\n", - "anonymized_text = anonymizer.anonymize(text=text, analyzer_results=results)\n", - "\n", - "print(anonymized_text)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "zq1abYAUl4rV" - }, - "source": [ - "Let's encapsulate the previous step into a function and expand the entity recognition capabilities. We will expand our redaction scope to include addresses, email addresses, and US Social Security numbers." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "collapsed": true, - "id": "rQyUmN5nmQrj", - "outputId": "4a538fda-c079-4cea-a65f-ccbeee58f17e" - }, - "outputs": [], - "source": [ - "from presidio_analyzer import AnalyzerEngine\n", - "from presidio_anonymizer import AnonymizerEngine\n", "\n", - "analyzer = AnalyzerEngine()\n", - "anonymizer = AnonymizerEngine()\n", - "\"\"\"\n", - "The below function will take a block of text, process it using presidio\n", - "and return a block of text with the PII data redicated.\n", - "PII data to be redicated:\n", - "- Phone Numbers\n", - "- Names\n", - "- Addresses\n", - "- Email addresses\n", - "- US Social Security Numbers\n", - "\"\"\"\n", - "\n", - "\n", - "def anonymize_my_text(text):\n", + "# Encapsulate the Presidio redaction process into a function\n", + "def redact_with_presidio(text):\n", + " # Analyze the text to identify PII data\n", " results = analyzer.analyze(\n", " text=text,\n", " entities=[\"PHONE_NUMBER\", \"PERSON\", \"LOCATION\", \"EMAIL_ADDRESS\", \"US_SSN\"],\n", " language=\"en\",\n", " )\n", + " # Anonymize the identified PII data\n", " anonymized_text = anonymizer.anonymize(text=text, analyzer_results=results)\n", " return anonymized_text.text" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Let's test the function with a sample text:" + ] + }, { "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "sTGbemoMtcO8" - }, - "outputs": [], + "execution_count": 12, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Raw text:\n", + "\tMy phone number is 212-555-5555 and my name is alex\n", + "Redacted text:\n", + "\tMy phone number is <PHONE_NUMBER> and my name is <PERSON>\n" + ] + } + ], "source": [ - "# for every block of text, anonymized first and then predict\n", - "for entry in pii_data:\n", - " anonymized_entry = anonymize_my_text(entry[\"text\"])\n", - " (await model.predict(anonymized_entry))" + "text = \"My phone number is 212-555-5555 and my name is alex\"\n", + "\n", + "# Test the function\n", + "redacted_text = redact_with_presidio(text)\n", + "\n", + "print(f\"Raw text:\\n\\t{text}\")\n", + "print(f\"Redacted text:\\n\\t{redacted_text}\")" ] }, { "cell_type": "markdown", - "metadata": { - "id": "7omtvsdpMuel" - }, + "metadata": {}, "source": [ - "# Method 2B: Replace PII data with fake data\n", + "## Method 3: Anonymization with Replacement using Faker and Presidio\n", "\n", "Instead of redacting text, we can anonymize it by swapping PII (like names and phone numbers) with fake data generated using the [Faker](https://faker.readthedocs.io/en/master/) Python library. For example:\n", "\n", + "`\"My name is Raphael and I like to fish. My phone number is 212-555-5555\"` \n", "\n", - "```\n", - "\"My name is Raphael and I like to fish. My phone number is 212-555-5555\"\n", - "```\n", - "Will be:\n", + "might become\n", "\n", - "\n", - "```\n", - "\"My name is Katherine Dixon and I like to fish. My phone number is 667.431.7379\"\n", - "\n", - "```\n", + "`\"My name is Katherine Dixon and I like to fish. My phone number is 667.431.7379\"`\n", "\n", "To effectively utilize Presidio, we must supply references to our custom operators. These operators will direct Presidio to the functions responsible for swapping PII with fake data.\n", 
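+ "\n", + "For example, here is a minimal sketch of a single custom operator (the full implementation below covers several entity types):\n", + "\n", + "```python\n", + "from faker import Faker\n", + "from presidio_anonymizer.entities import OperatorConfig\n", + "\n", + "fake = Faker()\n", + "\n", + "# Presidio passes the matched text to the lambda and substitutes its return value,\n", + "# so every detected PERSON is replaced by a freshly generated fake name.\n", + "operators = {\"PERSON\": OperatorConfig(\"custom\", {\"lambda\": lambda x: fake.name()})}\n", + "```"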
] }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "![](../../media/pii/replace.png)" - ] - }, { "cell_type": "code", - "execution_count": null, + "execution_count": 13, "metadata": { "id": "3XJa7u5T_WYd" }, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Raw text:\n", + "\tMy name is Raphael and I like to fish. My phone number is 212-555-5555\n", + "Anonymized text:\n", + "\tMy name is Jennifer Waters and I like to fish. My phone number is 8007771735\n" + ] + } + ], "source": [ "from faker import Faker\n", "from presidio_anonymizer import AnonymizerEngine\n", @@ -455,7 +420,6 @@ " \"PHONE_NUMBER\": OperatorConfig(\"custom\", {\"lambda\": fake_number}),\n", "}\n", "\n", - "\n", "text_to_anonymize = (\n", " \"My name is Raphael and I like to fish. My phone number is 212-555-5555\"\n", ")\n", @@ -465,7 +429,6 @@ " text=text_to_anonymize, entities=[\"PHONE_NUMBER\", \"PERSON\"], language=\"en\"\n", ")\n", "\n", - "\n", "anonymizer = AnonymizerEngine()\n", "\n", "# do not forget to pass the operators from above to the anonymizer\n", @@ -473,7 +436,8 @@ " text=text_to_anonymize, analyzer_results=analyzer_results, operators=operators\n", ")\n", "\n", - "print(anonymized_results.text)" + "print(f\"Raw text:\\n\\t{text_to_anonymize}\")\n", + "print(f\"Anonymized text:\\n\\t{anonymized_results.text}\")" ] }, { @@ -487,7 +451,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 14, "metadata": { "colab": { "background_save": true @@ -496,15 +460,12 @@ }, "outputs": [], "source": [ - "from anthropic import AsyncAnthropic\n", "from faker import Faker\n", "from presidio_anonymizer import AnonymizerEngine\n", "from presidio_anonymizer.entities import OperatorConfig\n", "\n", - "import weave\n", "\n", - "\n", - "# Let's build a custom class for generating fake data that will extend Faker\n", + "# A custom class for generating fake data that extends Faker\n", "class my_faker(Faker):\n", " # Create faker functions (note that it has to receive a value)\n", " def fake_address(x):\n", @@ -531,7 +492,7 @@ " \"US_SSN\": OperatorConfig(\"custom\", {\"lambda\": fake_ssn}),\n", " }\n", "\n", - " def anonymize_my_text(self, text):\n", + " def redact_and_anonymize_with_faker(self, text):\n", " anonymizer = AnonymizerEngine()\n", " analyzer_results = analyzer.analyze(\n", " text=text,\n", @@ -541,13 +502,405 @@ " anonymized_results = anonymizer.anonymize(\n", " text=text, analyzer_results=analyzer_results, operators=self.operators\n", " )\n", - " return anonymized_results.text\n", + " return anonymized_results.text" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Let's test the function with a sample text:" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Raw text:\n", + "\tMy name is Raphael and I like to fish. My phone number is 212-555-5555\n", + "Anonymized text:\n", + "\tMy name is Amy Adkins and I like to fish. My phone number is +1-481-464-7524x69850\n" + ] + } + ], + "source": [ + "faker = my_faker()\n", + "text_to_anonymize = (\n", + " \"My name is Raphael and I like to fish. 
My phone number is 212-555-5555\"\n", + ")\n", + "anonymized_text = faker.redact_and_anonymize_with_faker(text_to_anonymize)\n", + "\n", + "print(f\"Raw text:\\n\\t{text_to_anonymize}\")\n", + "print(f\"Anonymized text:\\n\\t{anonymized_text}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Applying the Methods to Weave Calls\n", + "\n", + "In these examples, we will integrate our PII redaction and anonymization methods into Weave Models, and preview the results in Weave Traces.\n", + "\n", + "We'll create a [Weave Model](https://wandb.github.io/weave/guides/core-types/models) which is a combination of data (which can include configuration, trained model weights, or other information) and code that defines how the model operates. \n", + "\n", + "In this model, we will include our predict function where the Anthropic API will be called. Additionally, we will include our postprocessing functions to ensure that our PII data is redacted or anonymized before it is logged to Weave.\n", + "\n", + "Once you run this code, you will receive links to the Weave project page as well as to the specific traces (LLM calls) you ran." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Regex Method \n", + "\n", + "In the simplest case, we can use regex to identify and redact PII data in the original text." + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": {}, + "outputs": [], + "source": [ + "import json\n", + "from typing import Any\n", + "\n", + "import anthropic\n", + "\n", + "import weave\n", + "\n", + "\n", + "# Define an input postprocessing function that applies our regex redaction for the model prediction Weave Op\n", + "def postprocess_inputs_regex(inputs: dict[str, Any]) -> dict:\n", + "    inputs[\"text_block\"] = redact_with_regex(inputs[\"text_block\"])\n", + "    return inputs\n", + "\n", + "\n", + "# Weave model / predict function\n", + "class sentiment_analysis_regex_pii_model(weave.Model):\n", + "    model_name: str\n", + "    system_prompt: str\n", + "    temperature: int\n", + "\n", + "    @weave.op(\n", + "        postprocess_inputs=postprocess_inputs_regex,\n", + "    )\n", + "    async def predict(self, text_block: str) -> dict:\n", + "        client = anthropic.AsyncAnthropic()\n", + "        response = await client.messages.create(\n", + "            max_tokens=1024,\n", + "            model=self.model_name,\n", + "            system=self.system_prompt,\n", + "            messages=[\n", + "                {\"role\": \"user\", \"content\": [{\"type\": \"text\", \"text\": text_block}]}\n", + "            ],\n", + "        )\n", + "        result = response.content[0].text\n", + "        if result is None:\n", + "            raise ValueError(\"No response from model\")\n", + "        parsed = json.loads(result)\n", + "        return parsed" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Model: sentiment_analysis_regex_pii_model(name='claude-3-sonnet', description=None, model_name='claude-3-5-sonnet-20240620', system_prompt='You are a Sentiment Analysis classifier. You will be classifying text based on their sentiment. Your input will be a block of text. You will answer with one the following rating option[\"positive\", \"negative\", \"neutral\"]. Your answer should be one word in json format: {classification}. 
Ensure that it is valid JSON.', temperature=0)\n", + "🍩 https://wandb.ai/wandb/pii_cookbook/r/call/01924a61-b36e-7f60-a856-6569856128c9\n", + "🍩 https://wandb.ai/wandb/pii_cookbook/r/call/01924a61-b9e3-7751-95e9-7949b34cd9c3\n", + "🍩 https://wandb.ai/wandb/pii_cookbook/r/call/01924a61-bd4c-72b3-bd33-3940cb2c3812\n", + "🍩 https://wandb.ai/wandb/pii_cookbook/r/call/01924a61-c0a8-7390-a5f4-3ed0767d2b87\n", + "🍩 https://wandb.ai/wandb/pii_cookbook/r/call/01924a61-c442-7fc2-9523-3e20f1e25dc0\n", + "🍩 https://wandb.ai/wandb/pii_cookbook/r/call/01924a61-c7d3-7a30-9975-29cab0ec0fd7\n", + "🍩 https://wandb.ai/wandb/pii_cookbook/r/call/01924a61-cb5a-76d2-a7ed-e8cd55a4f280\n", + "🍩 https://wandb.ai/wandb/pii_cookbook/r/call/01924a61-cef1-7220-852b-2d9a1e30e8ab\n", + "🍩 https://wandb.ai/wandb/pii_cookbook/r/call/01924a61-d284-7bc3-9181-d2284f291906\n", + "🍩 https://wandb.ai/wandb/pii_cookbook/r/call/01924a61-d637-7713-8fea-392f83ae8af1\n" + ] + } + ], + "source": [ + "# create our LLM model with a system prompt\n", + "model = sentiment_analysis_regex_pii_model(\n", + "    name=\"claude-3-sonnet\",\n", + "    model_name=\"claude-3-5-sonnet-20240620\",\n", + "    system_prompt='You are a Sentiment Analysis classifier. You will be classifying text based on their sentiment. Your input will be a block of text. You will answer with one the following rating option[\"positive\", \"negative\", \"neutral\"]. Your answer should be one word in json format: {classification}. Ensure that it is valid JSON.',\n", + "    temperature=0,\n", + ")\n", + "\n", + "print(\"Model: \", model)\n", + "# For every block of text, run the prediction; the postprocess function redacts the logged inputs\n", + "for entry in pii_data:\n", + "    await model.predict(entry[\"text\"])" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Presidio Redaction Method\n", + "\n", + "Here we will use Presidio to identify and redact PII data in the original text.\n", 
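+ "\n", + "The model definition mirrors the regex version above; the only substantive change is the postprocess function passed to the `@weave.op` decorator, which now redacts the logged inputs with Presidio instead of regular expressions."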
+ ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "![](../../media/pii/redact.png)" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "metadata": {}, + "outputs": [], + "source": [ + "import json\n", + "from typing import Any\n", + "\n", + "import anthropic\n", + "\n", + "import weave\n", + "\n", + "\n", + "# Define an input postprocessing function that applies our Presidio redaction for the model prediction Weave Op\n", + "def postprocess_inputs_presidio(inputs: dict[str, Any]) -> dict:\n", + " inputs[\"text_block\"] = redact_with_presidio(inputs[\"text_block\"])\n", + " return inputs\n", + "\n", + "\n", + "# Weave model / predict function\n", + "class sentiment_analysis_presidio_pii_model(weave.Model):\n", + " model_name: str\n", + " system_prompt: str\n", + " temperature: int\n", + "\n", + " @weave.op(\n", + " postprocess_inputs=postprocess_inputs_presidio,\n", + " )\n", + " async def predict(self, text_block: str) -> dict:\n", + " client = anthropic.AsyncAnthropic()\n", + " response = await client.messages.create(\n", + " max_tokens=1024,\n", + " model=self.model_name,\n", + " system=self.system_prompt,\n", + " messages=[\n", + " {\"role\": \"user\", \"content\": [{\"type\": \"text\", \"text\": text_block}]}\n", + " ],\n", + " )\n", + " result = response.content[0].text\n", + " if result is None:\n", + " raise ValueError(\"No response from model\")\n", + " parsed = json.loads(result)\n", + " return parsed" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Model: sentiment_analysis_presidio_pii_model(name='claude-3-sonnet', description=None, model_name='claude-3-5-sonnet-20240620', system_prompt='You are a Sentiment Analysis classifier. You will be classifying text based on their sentiment. Your input will be a block of text. You will answer with one the following rating option[\"positive\", \"negative\", \"neutral\"]. Your answer should be one word in json format: {classification}. 
Ensure that it is valid JSON.', temperature=0)\n", + "🍩 https://wandb.ai/wandb/pii_cookbook/r/call/01924a61-dd6f-7851-8c99-18529ddc63f4\n", + "🍩 https://wandb.ai/wandb/pii_cookbook/r/call/01924a61-e17e-7cc3-a28f-7be155d87eb1\n", + "🍩 https://wandb.ai/wandb/pii_cookbook/r/call/01924a61-e538-76d3-a5a3-0e89bf4ff6b9\n", + "🍩 https://wandb.ai/wandb/pii_cookbook/r/call/01924a61-e8cf-7fb1-b942-11e5442dcf22\n", + "🍩 https://wandb.ai/wandb/pii_cookbook/r/call/01924a61-ef00-7793-8799-10ff5a32e129\n" + ] + }, + { + "ename": "CancelledError", + "evalue": "", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mCancelledError\u001b[0m Traceback (most recent call last)", + "Cell \u001b[0;32mIn[20], line 12\u001b[0m\n\u001b[1;32m 10\u001b[0m \u001b[38;5;66;03m# for every block of text, anonymized first and then predict\u001b[39;00m\n\u001b[1;32m 11\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m entry \u001b[38;5;129;01min\u001b[39;00m pii_data:\n\u001b[0;32m---> 12\u001b[0m \u001b[38;5;28;01mawait\u001b[39;00m model\u001b[38;5;241m.\u001b[39mpredict(entry[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mtext\u001b[39m\u001b[38;5;124m\"\u001b[39m])\n", + "File \u001b[0;32m~/Desktop/projects/weave/.conda/lib/python3.12/site-packages/weave/trace/op.py:464\u001b[0m, in \u001b[0;36mop..op_deco..create_wrapper..wrapper\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 459\u001b[0m log_once(\n\u001b[1;32m 460\u001b[0m logger\u001b[38;5;241m.\u001b[39merror,\n\u001b[1;32m 461\u001b[0m ASYNC_CALL_CREATE_MSG\u001b[38;5;241m.\u001b[39mformat(traceback\u001b[38;5;241m.\u001b[39mformat_exc()),\n\u001b[1;32m 462\u001b[0m )\n\u001b[1;32m 463\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;01mawait\u001b[39;00m func(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n\u001b[0;32m--> 464\u001b[0m res, _ \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mawait\u001b[39;00m _execute_call(wrapper, call, \u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs) \u001b[38;5;66;03m# type: ignore\u001b[39;00m\n\u001b[1;32m 465\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m res\n", + "File \u001b[0;32m~/Desktop/projects/weave/.conda/lib/python3.12/site-packages/weave/trace/op.py:274\u001b[0m, in \u001b[0;36m_execute_call.._call_async\u001b[0;34m()\u001b[0m\n\u001b[1;32m 272\u001b[0m \u001b[38;5;28;01masync\u001b[39;00m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m_call_async\u001b[39m() \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m Coroutine[Any, Any, Any]:\n\u001b[1;32m 273\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m--> 274\u001b[0m res \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mawait\u001b[39;00m func(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n\u001b[1;32m 275\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m \u001b[38;5;167;01mException\u001b[39;00m \u001b[38;5;28;01mas\u001b[39;00m e:\n\u001b[1;32m 276\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m handle_exception(e)\n", + "Cell \u001b[0;32mIn[19], line 23\u001b[0m, in \u001b[0;36msentiment_analysis_presidio_pii_model.predict\u001b[0;34m(self, text_block)\u001b[0m\n\u001b[1;32m 18\u001b[0m \u001b[38;5;129m@weave\u001b[39m\u001b[38;5;241m.\u001b[39mop(\n\u001b[1;32m 19\u001b[0m postprocess_inputs\u001b[38;5;241m=\u001b[39mpostprocess_inputs_presidio,\n\u001b[1;32m 20\u001b[0m )\n\u001b[1;32m 21\u001b[0m 
\u001b[38;5;28;01masync\u001b[39;00m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mpredict\u001b[39m(\u001b[38;5;28mself\u001b[39m, text_block: \u001b[38;5;28mstr\u001b[39m) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m \u001b[38;5;28mdict\u001b[39m:\n\u001b[1;32m 22\u001b[0m client \u001b[38;5;241m=\u001b[39m anthropic\u001b[38;5;241m.\u001b[39mAsyncAnthropic()\n\u001b[0;32m---> 23\u001b[0m response \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mawait\u001b[39;00m client\u001b[38;5;241m.\u001b[39mmessages\u001b[38;5;241m.\u001b[39mcreate(\n\u001b[1;32m 24\u001b[0m max_tokens\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m1024\u001b[39m,\n\u001b[1;32m 25\u001b[0m model\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mmodel_name,\n\u001b[1;32m 26\u001b[0m system\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39msystem_prompt,\n\u001b[1;32m 27\u001b[0m messages\u001b[38;5;241m=\u001b[39m[\n\u001b[1;32m 28\u001b[0m {\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mrole\u001b[39m\u001b[38;5;124m\"\u001b[39m: \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124muser\u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mcontent\u001b[39m\u001b[38;5;124m\"\u001b[39m: [{\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mtype\u001b[39m\u001b[38;5;124m\"\u001b[39m: \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mtext\u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mtext\u001b[39m\u001b[38;5;124m\"\u001b[39m: text_block}]}\n\u001b[1;32m 29\u001b[0m ],\n\u001b[1;32m 30\u001b[0m )\n\u001b[1;32m 31\u001b[0m result \u001b[38;5;241m=\u001b[39m response\u001b[38;5;241m.\u001b[39mcontent[\u001b[38;5;241m0\u001b[39m]\u001b[38;5;241m.\u001b[39mtext\n\u001b[1;32m 32\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m result \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n", + "File \u001b[0;32m~/Desktop/projects/weave/.conda/lib/python3.12/site-packages/weave/trace/op.py:464\u001b[0m, in \u001b[0;36mop..op_deco..create_wrapper..wrapper\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 459\u001b[0m log_once(\n\u001b[1;32m 460\u001b[0m logger\u001b[38;5;241m.\u001b[39merror,\n\u001b[1;32m 461\u001b[0m ASYNC_CALL_CREATE_MSG\u001b[38;5;241m.\u001b[39mformat(traceback\u001b[38;5;241m.\u001b[39mformat_exc()),\n\u001b[1;32m 462\u001b[0m )\n\u001b[1;32m 463\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;01mawait\u001b[39;00m func(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n\u001b[0;32m--> 464\u001b[0m res, _ \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mawait\u001b[39;00m _execute_call(wrapper, call, \u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs) \u001b[38;5;66;03m# type: ignore\u001b[39;00m\n\u001b[1;32m 465\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m res\n", + "File \u001b[0;32m~/Desktop/projects/weave/.conda/lib/python3.12/site-packages/weave/trace/op.py:274\u001b[0m, in \u001b[0;36m_execute_call.._call_async\u001b[0;34m()\u001b[0m\n\u001b[1;32m 272\u001b[0m \u001b[38;5;28;01masync\u001b[39;00m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m_call_async\u001b[39m() \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m Coroutine[Any, Any, Any]:\n\u001b[1;32m 273\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m--> 274\u001b[0m res \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mawait\u001b[39;00m func(\u001b[38;5;241m*\u001b[39margs, 
\u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n\u001b[1;32m 275\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m \u001b[38;5;167;01mException\u001b[39;00m \u001b[38;5;28;01mas\u001b[39;00m e:\n\u001b[1;32m 276\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m handle_exception(e)\n", + "File \u001b[0;32m~/Desktop/projects/weave/.conda/lib/python3.12/site-packages/weave/integrations/anthropic/anthropic_sdk.py:100\u001b[0m, in \u001b[0;36mcreate_wrapper_async..wrapper.._fn_wrapper.._async_wrapper\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 96\u001b[0m \u001b[38;5;129m@wraps\u001b[39m(fn)\n\u001b[1;32m 97\u001b[0m \u001b[38;5;28;01masync\u001b[39;00m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m_async_wrapper\u001b[39m(\n\u001b[1;32m 98\u001b[0m \u001b[38;5;241m*\u001b[39margs: typing\u001b[38;5;241m.\u001b[39mAny, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs: typing\u001b[38;5;241m.\u001b[39mAny\n\u001b[1;32m 99\u001b[0m ) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m typing\u001b[38;5;241m.\u001b[39mAny:\n\u001b[0;32m--> 100\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;01mawait\u001b[39;00m fn(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n", + "File \u001b[0;32m~/Desktop/projects/weave/.conda/lib/python3.12/site-packages/anthropic/resources/messages.py:1799\u001b[0m, in \u001b[0;36mAsyncMessages.create\u001b[0;34m(self, max_tokens, messages, model, metadata, stop_sequences, stream, system, temperature, tool_choice, tools, top_k, top_p, extra_headers, extra_query, extra_body, timeout)\u001b[0m\n\u001b[1;32m 1792\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m model \u001b[38;5;129;01min\u001b[39;00m DEPRECATED_MODELS:\n\u001b[1;32m 1793\u001b[0m warnings\u001b[38;5;241m.\u001b[39mwarn(\n\u001b[1;32m 1794\u001b[0m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mThe model \u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mmodel\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124m is deprecated and will reach end-of-life on \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mDEPRECATED_MODELS[model]\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m.\u001b[39m\u001b[38;5;130;01m\\n\u001b[39;00m\u001b[38;5;124mPlease migrate to a newer model. 
Visit https://docs.anthropic.com/en/docs/resources/model-deprecations for more information.\u001b[39m\u001b[38;5;124m\"\u001b[39m,\n\u001b[1;32m 1795\u001b[0m \u001b[38;5;167;01mDeprecationWarning\u001b[39;00m,\n\u001b[1;32m 1796\u001b[0m stacklevel\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m3\u001b[39m,\n\u001b[1;32m 1797\u001b[0m )\n\u001b[0;32m-> 1799\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;01mawait\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_post(\n\u001b[1;32m 1800\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m/v1/messages\u001b[39m\u001b[38;5;124m\"\u001b[39m,\n\u001b[1;32m 1801\u001b[0m body\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mawait\u001b[39;00m async_maybe_transform(\n\u001b[1;32m 1802\u001b[0m {\n\u001b[1;32m 1803\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mmax_tokens\u001b[39m\u001b[38;5;124m\"\u001b[39m: max_tokens,\n\u001b[1;32m 1804\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mmessages\u001b[39m\u001b[38;5;124m\"\u001b[39m: messages,\n\u001b[1;32m 1805\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mmodel\u001b[39m\u001b[38;5;124m\"\u001b[39m: model,\n\u001b[1;32m 1806\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mmetadata\u001b[39m\u001b[38;5;124m\"\u001b[39m: metadata,\n\u001b[1;32m 1807\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mstop_sequences\u001b[39m\u001b[38;5;124m\"\u001b[39m: stop_sequences,\n\u001b[1;32m 1808\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mstream\u001b[39m\u001b[38;5;124m\"\u001b[39m: stream,\n\u001b[1;32m 1809\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124msystem\u001b[39m\u001b[38;5;124m\"\u001b[39m: system,\n\u001b[1;32m 1810\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mtemperature\u001b[39m\u001b[38;5;124m\"\u001b[39m: temperature,\n\u001b[1;32m 1811\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mtool_choice\u001b[39m\u001b[38;5;124m\"\u001b[39m: tool_choice,\n\u001b[1;32m 1812\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mtools\u001b[39m\u001b[38;5;124m\"\u001b[39m: tools,\n\u001b[1;32m 1813\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mtop_k\u001b[39m\u001b[38;5;124m\"\u001b[39m: top_k,\n\u001b[1;32m 1814\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mtop_p\u001b[39m\u001b[38;5;124m\"\u001b[39m: top_p,\n\u001b[1;32m 1815\u001b[0m },\n\u001b[1;32m 1816\u001b[0m message_create_params\u001b[38;5;241m.\u001b[39mMessageCreateParams,\n\u001b[1;32m 1817\u001b[0m ),\n\u001b[1;32m 1818\u001b[0m options\u001b[38;5;241m=\u001b[39mmake_request_options(\n\u001b[1;32m 1819\u001b[0m extra_headers\u001b[38;5;241m=\u001b[39mextra_headers, extra_query\u001b[38;5;241m=\u001b[39mextra_query, extra_body\u001b[38;5;241m=\u001b[39mextra_body, timeout\u001b[38;5;241m=\u001b[39mtimeout\n\u001b[1;32m 1820\u001b[0m ),\n\u001b[1;32m 1821\u001b[0m cast_to\u001b[38;5;241m=\u001b[39mMessage,\n\u001b[1;32m 1822\u001b[0m stream\u001b[38;5;241m=\u001b[39mstream \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28;01mFalse\u001b[39;00m,\n\u001b[1;32m 1823\u001b[0m stream_cls\u001b[38;5;241m=\u001b[39mAsyncStream[RawMessageStreamEvent],\n\u001b[1;32m 1824\u001b[0m )\n", + "File \u001b[0;32m~/Desktop/projects/weave/.conda/lib/python3.12/site-packages/anthropic/_base_client.py:1816\u001b[0m, in \u001b[0;36mAsyncAPIClient.post\u001b[0;34m(self, path, cast_to, body, files, options, stream, stream_cls)\u001b[0m\n\u001b[1;32m 1802\u001b[0m \u001b[38;5;28;01masync\u001b[39;00m \u001b[38;5;28;01mdef\u001b[39;00m 
\u001b[38;5;21mpost\u001b[39m(\n\u001b[1;32m 1803\u001b[0m \u001b[38;5;28mself\u001b[39m,\n\u001b[1;32m 1804\u001b[0m path: \u001b[38;5;28mstr\u001b[39m,\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 1811\u001b[0m stream_cls: \u001b[38;5;28mtype\u001b[39m[_AsyncStreamT] \u001b[38;5;241m|\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m,\n\u001b[1;32m 1812\u001b[0m ) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m ResponseT \u001b[38;5;241m|\u001b[39m _AsyncStreamT:\n\u001b[1;32m 1813\u001b[0m opts \u001b[38;5;241m=\u001b[39m FinalRequestOptions\u001b[38;5;241m.\u001b[39mconstruct(\n\u001b[1;32m 1814\u001b[0m method\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mpost\u001b[39m\u001b[38;5;124m\"\u001b[39m, url\u001b[38;5;241m=\u001b[39mpath, json_data\u001b[38;5;241m=\u001b[39mbody, files\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mawait\u001b[39;00m async_to_httpx_files(files), \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39moptions\n\u001b[1;32m 1815\u001b[0m )\n\u001b[0;32m-> 1816\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;01mawait\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mrequest(cast_to, opts, stream\u001b[38;5;241m=\u001b[39mstream, stream_cls\u001b[38;5;241m=\u001b[39mstream_cls)\n", + "File \u001b[0;32m~/Desktop/projects/weave/.conda/lib/python3.12/site-packages/anthropic/_base_client.py:1510\u001b[0m, in \u001b[0;36mAsyncAPIClient.request\u001b[0;34m(self, cast_to, options, stream, stream_cls, remaining_retries)\u001b[0m\n\u001b[1;32m 1501\u001b[0m \u001b[38;5;28;01masync\u001b[39;00m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mrequest\u001b[39m(\n\u001b[1;32m 1502\u001b[0m \u001b[38;5;28mself\u001b[39m,\n\u001b[1;32m 1503\u001b[0m cast_to: Type[ResponseT],\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 1508\u001b[0m remaining_retries: Optional[\u001b[38;5;28mint\u001b[39m] \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m,\n\u001b[1;32m 1509\u001b[0m ) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m ResponseT \u001b[38;5;241m|\u001b[39m _AsyncStreamT:\n\u001b[0;32m-> 1510\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;01mawait\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_request(\n\u001b[1;32m 1511\u001b[0m cast_to\u001b[38;5;241m=\u001b[39mcast_to,\n\u001b[1;32m 1512\u001b[0m options\u001b[38;5;241m=\u001b[39moptions,\n\u001b[1;32m 1513\u001b[0m stream\u001b[38;5;241m=\u001b[39mstream,\n\u001b[1;32m 1514\u001b[0m stream_cls\u001b[38;5;241m=\u001b[39mstream_cls,\n\u001b[1;32m 1515\u001b[0m remaining_retries\u001b[38;5;241m=\u001b[39mremaining_retries,\n\u001b[1;32m 1516\u001b[0m )\n", + "File \u001b[0;32m~/Desktop/projects/weave/.conda/lib/python3.12/site-packages/anthropic/_base_client.py:1549\u001b[0m, in \u001b[0;36mAsyncAPIClient._request\u001b[0;34m(self, cast_to, options, stream, stream_cls, remaining_retries)\u001b[0m\n\u001b[1;32m 1546\u001b[0m kwargs[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mauth\u001b[39m\u001b[38;5;124m\"\u001b[39m] \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcustom_auth\n\u001b[1;32m 1548\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m-> 1549\u001b[0m response \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mawait\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_client\u001b[38;5;241m.\u001b[39msend(\n\u001b[1;32m 1550\u001b[0m request,\n\u001b[1;32m 1551\u001b[0m 
stream\u001b[38;5;241m=\u001b[39mstream \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_should_stream_response_body(request\u001b[38;5;241m=\u001b[39mrequest),\n\u001b[1;32m 1552\u001b[0m \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs,\n\u001b[1;32m 1553\u001b[0m )\n\u001b[1;32m 1554\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m httpx\u001b[38;5;241m.\u001b[39mTimeoutException \u001b[38;5;28;01mas\u001b[39;00m err:\n\u001b[1;32m 1555\u001b[0m log\u001b[38;5;241m.\u001b[39mdebug(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mEncountered httpx.TimeoutException\u001b[39m\u001b[38;5;124m\"\u001b[39m, exc_info\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m)\n", + "File \u001b[0;32m~/Desktop/projects/weave/.conda/lib/python3.12/site-packages/httpx/_client.py:1674\u001b[0m, in \u001b[0;36mAsyncClient.send\u001b[0;34m(self, request, stream, auth, follow_redirects)\u001b[0m\n\u001b[1;32m 1670\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_set_timeout(request)\n\u001b[1;32m 1672\u001b[0m auth \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_build_request_auth(request, auth)\n\u001b[0;32m-> 1674\u001b[0m response \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mawait\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_send_handling_auth(\n\u001b[1;32m 1675\u001b[0m request,\n\u001b[1;32m 1676\u001b[0m auth\u001b[38;5;241m=\u001b[39mauth,\n\u001b[1;32m 1677\u001b[0m follow_redirects\u001b[38;5;241m=\u001b[39mfollow_redirects,\n\u001b[1;32m 1678\u001b[0m history\u001b[38;5;241m=\u001b[39m[],\n\u001b[1;32m 1679\u001b[0m )\n\u001b[1;32m 1680\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m 1681\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m stream:\n", + "File \u001b[0;32m~/Desktop/projects/weave/.conda/lib/python3.12/site-packages/httpx/_client.py:1702\u001b[0m, in \u001b[0;36mAsyncClient._send_handling_auth\u001b[0;34m(self, request, auth, follow_redirects, history)\u001b[0m\n\u001b[1;32m 1699\u001b[0m request \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mawait\u001b[39;00m auth_flow\u001b[38;5;241m.\u001b[39m\u001b[38;5;21m__anext__\u001b[39m()\n\u001b[1;32m 1701\u001b[0m \u001b[38;5;28;01mwhile\u001b[39;00m \u001b[38;5;28;01mTrue\u001b[39;00m:\n\u001b[0;32m-> 1702\u001b[0m response \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mawait\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_send_handling_redirects(\n\u001b[1;32m 1703\u001b[0m request,\n\u001b[1;32m 1704\u001b[0m follow_redirects\u001b[38;5;241m=\u001b[39mfollow_redirects,\n\u001b[1;32m 1705\u001b[0m history\u001b[38;5;241m=\u001b[39mhistory,\n\u001b[1;32m 1706\u001b[0m )\n\u001b[1;32m 1707\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m 1708\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n", + "File \u001b[0;32m~/Desktop/projects/weave/.conda/lib/python3.12/site-packages/httpx/_client.py:1739\u001b[0m, in \u001b[0;36mAsyncClient._send_handling_redirects\u001b[0;34m(self, request, follow_redirects, history)\u001b[0m\n\u001b[1;32m 1736\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m hook \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_event_hooks[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mrequest\u001b[39m\u001b[38;5;124m\"\u001b[39m]:\n\u001b[1;32m 1737\u001b[0m \u001b[38;5;28;01mawait\u001b[39;00m hook(request)\n\u001b[0;32m-> 1739\u001b[0m response \u001b[38;5;241m=\u001b[39m 
\u001b[38;5;28;01mawait\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_send_single_request(request)\n\u001b[1;32m 1740\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m 1741\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m hook \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_event_hooks[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mresponse\u001b[39m\u001b[38;5;124m\"\u001b[39m]:\n", + "File \u001b[0;32m~/Desktop/projects/weave/.conda/lib/python3.12/site-packages/httpx/_client.py:1776\u001b[0m, in \u001b[0;36mAsyncClient._send_single_request\u001b[0;34m(self, request)\u001b[0m\n\u001b[1;32m 1771\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mRuntimeError\u001b[39;00m(\n\u001b[1;32m 1772\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mAttempted to send an sync request with an AsyncClient instance.\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 1773\u001b[0m )\n\u001b[1;32m 1775\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m request_context(request\u001b[38;5;241m=\u001b[39mrequest):\n\u001b[0;32m-> 1776\u001b[0m response \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mawait\u001b[39;00m transport\u001b[38;5;241m.\u001b[39mhandle_async_request(request)\n\u001b[1;32m 1778\u001b[0m \u001b[38;5;28;01massert\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(response\u001b[38;5;241m.\u001b[39mstream, AsyncByteStream)\n\u001b[1;32m 1779\u001b[0m response\u001b[38;5;241m.\u001b[39mrequest \u001b[38;5;241m=\u001b[39m request\n", + "File \u001b[0;32m~/Desktop/projects/weave/.conda/lib/python3.12/site-packages/httpx/_transports/default.py:377\u001b[0m, in \u001b[0;36mAsyncHTTPTransport.handle_async_request\u001b[0;34m(self, request)\u001b[0m\n\u001b[1;32m 364\u001b[0m req \u001b[38;5;241m=\u001b[39m httpcore\u001b[38;5;241m.\u001b[39mRequest(\n\u001b[1;32m 365\u001b[0m method\u001b[38;5;241m=\u001b[39mrequest\u001b[38;5;241m.\u001b[39mmethod,\n\u001b[1;32m 366\u001b[0m url\u001b[38;5;241m=\u001b[39mhttpcore\u001b[38;5;241m.\u001b[39mURL(\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 374\u001b[0m extensions\u001b[38;5;241m=\u001b[39mrequest\u001b[38;5;241m.\u001b[39mextensions,\n\u001b[1;32m 375\u001b[0m )\n\u001b[1;32m 376\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m map_httpcore_exceptions():\n\u001b[0;32m--> 377\u001b[0m resp \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mawait\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_pool\u001b[38;5;241m.\u001b[39mhandle_async_request(req)\n\u001b[1;32m 379\u001b[0m \u001b[38;5;28;01massert\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(resp\u001b[38;5;241m.\u001b[39mstream, typing\u001b[38;5;241m.\u001b[39mAsyncIterable)\n\u001b[1;32m 381\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m Response(\n\u001b[1;32m 382\u001b[0m status_code\u001b[38;5;241m=\u001b[39mresp\u001b[38;5;241m.\u001b[39mstatus,\n\u001b[1;32m 383\u001b[0m headers\u001b[38;5;241m=\u001b[39mresp\u001b[38;5;241m.\u001b[39mheaders,\n\u001b[1;32m 384\u001b[0m stream\u001b[38;5;241m=\u001b[39mAsyncResponseStream(resp\u001b[38;5;241m.\u001b[39mstream),\n\u001b[1;32m 385\u001b[0m extensions\u001b[38;5;241m=\u001b[39mresp\u001b[38;5;241m.\u001b[39mextensions,\n\u001b[1;32m 386\u001b[0m )\n", + "File \u001b[0;32m~/Desktop/projects/weave/.conda/lib/python3.12/site-packages/httpcore/_async/connection_pool.py:216\u001b[0m, in \u001b[0;36mAsyncConnectionPool.handle_async_request\u001b[0;34m(self, request)\u001b[0m\n\u001b[1;32m 213\u001b[0m closing \u001b[38;5;241m=\u001b[39m 
\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_assign_requests_to_connections()\n\u001b[1;32m 215\u001b[0m \u001b[38;5;28;01mawait\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_close_connections(closing)\n\u001b[0;32m--> 216\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m exc \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[1;32m 218\u001b[0m \u001b[38;5;66;03m# Return the response. Note that in this case we still have to manage\u001b[39;00m\n\u001b[1;32m 219\u001b[0m \u001b[38;5;66;03m# the point at which the response is closed.\u001b[39;00m\n\u001b[1;32m 220\u001b[0m \u001b[38;5;28;01massert\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(response\u001b[38;5;241m.\u001b[39mstream, AsyncIterable)\n", + "File \u001b[0;32m~/Desktop/projects/weave/.conda/lib/python3.12/site-packages/httpcore/_async/connection_pool.py:196\u001b[0m, in \u001b[0;36mAsyncConnectionPool.handle_async_request\u001b[0;34m(self, request)\u001b[0m\n\u001b[1;32m 192\u001b[0m connection \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mawait\u001b[39;00m pool_request\u001b[38;5;241m.\u001b[39mwait_for_connection(timeout\u001b[38;5;241m=\u001b[39mtimeout)\n\u001b[1;32m 194\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m 195\u001b[0m \u001b[38;5;66;03m# Send the request on the assigned connection.\u001b[39;00m\n\u001b[0;32m--> 196\u001b[0m response \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mawait\u001b[39;00m connection\u001b[38;5;241m.\u001b[39mhandle_async_request(\n\u001b[1;32m 197\u001b[0m pool_request\u001b[38;5;241m.\u001b[39mrequest\n\u001b[1;32m 198\u001b[0m )\n\u001b[1;32m 199\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m ConnectionNotAvailable:\n\u001b[1;32m 200\u001b[0m \u001b[38;5;66;03m# In some cases a connection may initially be available to\u001b[39;00m\n\u001b[1;32m 201\u001b[0m \u001b[38;5;66;03m# handle a request, but then become unavailable.\u001b[39;00m\n\u001b[1;32m 202\u001b[0m \u001b[38;5;66;03m#\u001b[39;00m\n\u001b[1;32m 203\u001b[0m \u001b[38;5;66;03m# In this case we clear the connection and try again.\u001b[39;00m\n\u001b[1;32m 204\u001b[0m pool_request\u001b[38;5;241m.\u001b[39mclear_connection()\n", + "File \u001b[0;32m~/Desktop/projects/weave/.conda/lib/python3.12/site-packages/httpcore/_async/connection.py:101\u001b[0m, in \u001b[0;36mAsyncHTTPConnection.handle_async_request\u001b[0;34m(self, request)\u001b[0m\n\u001b[1;32m 98\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_connect_failed \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mTrue\u001b[39;00m\n\u001b[1;32m 99\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m exc\n\u001b[0;32m--> 101\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;01mawait\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_connection\u001b[38;5;241m.\u001b[39mhandle_async_request(request)\n", + "File \u001b[0;32m~/Desktop/projects/weave/.conda/lib/python3.12/site-packages/httpcore/_async/http11.py:143\u001b[0m, in \u001b[0;36mAsyncHTTP11Connection.handle_async_request\u001b[0;34m(self, request)\u001b[0m\n\u001b[1;32m 141\u001b[0m \u001b[38;5;28;01masync\u001b[39;00m \u001b[38;5;28;01mwith\u001b[39;00m Trace(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mresponse_closed\u001b[39m\u001b[38;5;124m\"\u001b[39m, logger, request) \u001b[38;5;28;01mas\u001b[39;00m trace:\n\u001b[1;32m 142\u001b[0m \u001b[38;5;28;01mawait\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_response_closed()\n\u001b[0;32m--> 143\u001b[0m 
\u001b[38;5;28;01mraise\u001b[39;00m exc\n", + "File \u001b[0;32m~/Desktop/projects/weave/.conda/lib/python3.12/site-packages/httpcore/_async/http11.py:113\u001b[0m, in \u001b[0;36mAsyncHTTP11Connection.handle_async_request\u001b[0;34m(self, request)\u001b[0m\n\u001b[1;32m 102\u001b[0m \u001b[38;5;28;01mpass\u001b[39;00m\n\u001b[1;32m 104\u001b[0m \u001b[38;5;28;01masync\u001b[39;00m \u001b[38;5;28;01mwith\u001b[39;00m Trace(\n\u001b[1;32m 105\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mreceive_response_headers\u001b[39m\u001b[38;5;124m\"\u001b[39m, logger, request, kwargs\n\u001b[1;32m 106\u001b[0m ) \u001b[38;5;28;01mas\u001b[39;00m trace:\n\u001b[1;32m 107\u001b[0m (\n\u001b[1;32m 108\u001b[0m http_version,\n\u001b[1;32m 109\u001b[0m status,\n\u001b[1;32m 110\u001b[0m reason_phrase,\n\u001b[1;32m 111\u001b[0m headers,\n\u001b[1;32m 112\u001b[0m trailing_data,\n\u001b[0;32m--> 113\u001b[0m ) \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mawait\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_receive_response_headers(\u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n\u001b[1;32m 114\u001b[0m trace\u001b[38;5;241m.\u001b[39mreturn_value \u001b[38;5;241m=\u001b[39m (\n\u001b[1;32m 115\u001b[0m http_version,\n\u001b[1;32m 116\u001b[0m status,\n\u001b[1;32m 117\u001b[0m reason_phrase,\n\u001b[1;32m 118\u001b[0m headers,\n\u001b[1;32m 119\u001b[0m )\n\u001b[1;32m 121\u001b[0m network_stream \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_network_stream\n", + "File \u001b[0;32m~/Desktop/projects/weave/.conda/lib/python3.12/site-packages/httpcore/_async/http11.py:186\u001b[0m, in \u001b[0;36mAsyncHTTP11Connection._receive_response_headers\u001b[0;34m(self, request)\u001b[0m\n\u001b[1;32m 183\u001b[0m timeout \u001b[38;5;241m=\u001b[39m timeouts\u001b[38;5;241m.\u001b[39mget(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mread\u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;28;01mNone\u001b[39;00m)\n\u001b[1;32m 185\u001b[0m \u001b[38;5;28;01mwhile\u001b[39;00m \u001b[38;5;28;01mTrue\u001b[39;00m:\n\u001b[0;32m--> 186\u001b[0m event \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mawait\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_receive_event(timeout\u001b[38;5;241m=\u001b[39mtimeout)\n\u001b[1;32m 187\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(event, h11\u001b[38;5;241m.\u001b[39mResponse):\n\u001b[1;32m 188\u001b[0m \u001b[38;5;28;01mbreak\u001b[39;00m\n", + "File \u001b[0;32m~/Desktop/projects/weave/.conda/lib/python3.12/site-packages/httpcore/_async/http11.py:224\u001b[0m, in \u001b[0;36mAsyncHTTP11Connection._receive_event\u001b[0;34m(self, timeout)\u001b[0m\n\u001b[1;32m 221\u001b[0m event \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_h11_state\u001b[38;5;241m.\u001b[39mnext_event()\n\u001b[1;32m 223\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m event \u001b[38;5;129;01mis\u001b[39;00m h11\u001b[38;5;241m.\u001b[39mNEED_DATA:\n\u001b[0;32m--> 224\u001b[0m data \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mawait\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_network_stream\u001b[38;5;241m.\u001b[39mread(\n\u001b[1;32m 225\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mREAD_NUM_BYTES, timeout\u001b[38;5;241m=\u001b[39mtimeout\n\u001b[1;32m 226\u001b[0m )\n\u001b[1;32m 228\u001b[0m \u001b[38;5;66;03m# If we feed this case through h11 we'll raise an exception 
like:\u001b[39;00m\n\u001b[1;32m 229\u001b[0m \u001b[38;5;66;03m#\u001b[39;00m\n\u001b[1;32m 230\u001b[0m \u001b[38;5;66;03m# httpcore.RemoteProtocolError: can't handle event type\u001b[39;00m\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 234\u001b[0m \u001b[38;5;66;03m# perspective. Instead we handle this case distinctly and treat\u001b[39;00m\n\u001b[1;32m 235\u001b[0m \u001b[38;5;66;03m# it as a ConnectError.\u001b[39;00m\n\u001b[1;32m 236\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m data \u001b[38;5;241m==\u001b[39m \u001b[38;5;124mb\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m\"\u001b[39m \u001b[38;5;129;01mand\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_h11_state\u001b[38;5;241m.\u001b[39mtheir_state \u001b[38;5;241m==\u001b[39m h11\u001b[38;5;241m.\u001b[39mSEND_RESPONSE:\n", + "File \u001b[0;32m~/Desktop/projects/weave/.conda/lib/python3.12/site-packages/httpcore/_backends/anyio.py:35\u001b[0m, in \u001b[0;36mAnyIOStream.read\u001b[0;34m(self, max_bytes, timeout)\u001b[0m\n\u001b[1;32m 33\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m anyio\u001b[38;5;241m.\u001b[39mfail_after(timeout):\n\u001b[1;32m 34\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m---> 35\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;01mawait\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_stream\u001b[38;5;241m.\u001b[39mreceive(max_bytes\u001b[38;5;241m=\u001b[39mmax_bytes)\n\u001b[1;32m 36\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m anyio\u001b[38;5;241m.\u001b[39mEndOfStream: \u001b[38;5;66;03m# pragma: nocover\u001b[39;00m\n\u001b[1;32m 37\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;124mb\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m\"\u001b[39m\n", + "File \u001b[0;32m~/Desktop/projects/weave/.conda/lib/python3.12/site-packages/anyio/streams/tls.py:205\u001b[0m, in \u001b[0;36mTLSStream.receive\u001b[0;34m(self, max_bytes)\u001b[0m\n\u001b[1;32m 204\u001b[0m \u001b[38;5;28;01masync\u001b[39;00m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mreceive\u001b[39m(\u001b[38;5;28mself\u001b[39m, max_bytes: \u001b[38;5;28mint\u001b[39m \u001b[38;5;241m=\u001b[39m \u001b[38;5;241m65536\u001b[39m) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m \u001b[38;5;28mbytes\u001b[39m:\n\u001b[0;32m--> 205\u001b[0m data \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mawait\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_call_sslobject_method(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_ssl_object\u001b[38;5;241m.\u001b[39mread, max_bytes)\n\u001b[1;32m 206\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m data:\n\u001b[1;32m 207\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m EndOfStream\n", + "File \u001b[0;32m~/Desktop/projects/weave/.conda/lib/python3.12/site-packages/anyio/streams/tls.py:147\u001b[0m, in \u001b[0;36mTLSStream._call_sslobject_method\u001b[0;34m(self, func, *args)\u001b[0m\n\u001b[1;32m 144\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_write_bio\u001b[38;5;241m.\u001b[39mpending:\n\u001b[1;32m 145\u001b[0m \u001b[38;5;28;01mawait\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mtransport_stream\u001b[38;5;241m.\u001b[39msend(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_write_bio\u001b[38;5;241m.\u001b[39mread())\n\u001b[0;32m--> 147\u001b[0m data \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mawait\u001b[39;00m 
\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mtransport_stream\u001b[38;5;241m.\u001b[39mreceive()\n\u001b[1;32m 148\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m EndOfStream:\n\u001b[1;32m 149\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_read_bio\u001b[38;5;241m.\u001b[39mwrite_eof()\n",
+ "File \u001b[0;32m~/Desktop/projects/weave/.conda/lib/python3.12/site-packages/anyio/_backends/_asyncio.py:1198\u001b[0m, in \u001b[0;36mSocketStream.receive\u001b[0;34m(self, max_bytes)\u001b[0m\n\u001b[1;32m 1192\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m (\n\u001b[1;32m 1193\u001b[0m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_protocol\u001b[38;5;241m.\u001b[39mread_event\u001b[38;5;241m.\u001b[39mis_set()\n\u001b[1;32m 1194\u001b[0m \u001b[38;5;129;01mand\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_transport\u001b[38;5;241m.\u001b[39mis_closing()\n\u001b[1;32m 1195\u001b[0m \u001b[38;5;129;01mand\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_protocol\u001b[38;5;241m.\u001b[39mis_at_eof\n\u001b[1;32m 1196\u001b[0m ):\n\u001b[1;32m 1197\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_transport\u001b[38;5;241m.\u001b[39mresume_reading()\n\u001b[0;32m-> 1198\u001b[0m \u001b[38;5;28;01mawait\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_protocol\u001b[38;5;241m.\u001b[39mread_event\u001b[38;5;241m.\u001b[39mwait()\n\u001b[1;32m 1199\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_transport\u001b[38;5;241m.\u001b[39mpause_reading()\n\u001b[1;32m 1200\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n",
+ "File \u001b[0;32m~/Desktop/projects/weave/.conda/lib/python3.12/asyncio/locks.py:212\u001b[0m, in \u001b[0;36mEvent.wait\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 210\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_waiters\u001b[38;5;241m.\u001b[39mappend(fut)\n\u001b[1;32m 211\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m--> 212\u001b[0m \u001b[38;5;28;01mawait\u001b[39;00m fut\n\u001b[1;32m 213\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;01mTrue\u001b[39;00m\n\u001b[1;32m 214\u001b[0m \u001b[38;5;28;01mfinally\u001b[39;00m:\n",
+ "\u001b[0;31mCancelledError\u001b[0m: "
+ ]
+ }
+ ],
+ "source": [
+ "# create our LLM model with a system prompt\n",
+ "model = sentiment_analysis_presidio_pii_model(\n",
+ "    name=\"claude-3-sonnet\",\n",
+ "    model_name=\"claude-3-5-sonnet-20240620\",\n",
+ "    system_prompt='You are a Sentiment Analysis classifier. You will be classifying text based on its sentiment. Your input will be a block of text. You will answer with one of the following rating options: [\"positive\", \"negative\", \"neutral\"]. Your answer should be one word in JSON format: {classification}. Ensure that it is valid JSON.',\n",
+ "    temperature=0,\n",
+ ")\n",
+ "\n",
+ "print(\"Model: \", model)\n",
+ "# for every block of text, anonymize it first and then predict\n",
+ "for entry in pii_data:\n",
+ "    await model.predict(entry[\"text\"])"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Faker + Presidio Replacement Method\n",
+ "\n",
+ "Here we will use Presidio to identify the PII in the original text and Faker to generate realistic, anonymized replacement values for it.\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "![](../../media/pii/replace.png)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 23,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import json\n",
+ "from typing import Any\n",
+ "\n",
+ "import anthropic\n",
 "\n",
+ "import weave\n",
 "\n",
+ "# Define an input postprocessing function that applies our Presidio redaction and Faker anonymization to the inputs recorded for the model's predict op\n",
 "faker = my_faker()\n",
+ "\n",
+ "\n",
+ "def postprocess_inputs_faker(inputs: dict[str, Any]) -> dict:\n",
+ "    inputs[\"text_block\"] = faker.redact_and_anonymize_with_faker(inputs[\"text_block\"])\n",
+ "    return inputs\n",
+ "\n",
+ "\n",
+ "# Weave model / predict function\n",
+ "class sentiment_analysis_faker_pii_model(weave.Model):\n",
+ "    model_name: str\n",
+ "    system_prompt: str\n",
+ "    temperature: int\n",
+ "\n",
+ "    @weave.op(\n",
+ "        postprocess_inputs=postprocess_inputs_faker,\n",
+ "    )\n",
+ "    async def predict(self, text_block: str) -> dict:\n",
+ "        client = anthropic.AsyncAnthropic()\n",
+ "        response = await client.messages.create(\n",
+ "            max_tokens=1024,\n",
+ "            model=self.model_name,\n",
+ "            system=self.system_prompt,\n",
+ "            messages=[\n",
+ "                {\"role\": \"user\", \"content\": [{\"type\": \"text\", \"text\": text_block}]}\n",
+ "            ],\n",
+ "        )\n",
+ "        result = response.content[0].text\n",
+ "        if result is None:\n",
+ "            raise ValueError(\"No response from model\")\n",
+ "        parsed = json.loads(result)\n",
+ "        return parsed"
+ ]
+ },
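Before running the cell below, it may help to see the redaction step in isolation. The following is a minimal sketch, assuming `presidio-analyzer`, `presidio-anonymizer`, and `faker` are installed; `redact_with_fake_values` is an illustrative stand-in for the notebook's `my_faker` helper, not its exact code:

```python
from faker import Faker
from presidio_analyzer import AnalyzerEngine
from presidio_anonymizer import AnonymizerEngine
from presidio_anonymizer.entities import OperatorConfig

fake = Faker()
analyzer = AnalyzerEngine()
anonymizer = AnonymizerEngine()


def redact_with_fake_values(text: str) -> str:
    # Presidio detects PII spans (names, emails, phone numbers, ...).
    results = analyzer.analyze(text=text, language="en")
    # Each entity type is swapped for a Faker-generated stand-in. One value
    # is drawn per call, so repeated entities in a text share a replacement.
    operators = {
        "PERSON": OperatorConfig("replace", {"new_value": fake.name()}),
        "EMAIL_ADDRESS": OperatorConfig("replace", {"new_value": fake.email()}),
        "PHONE_NUMBER": OperatorConfig("replace", {"new_value": fake.phone_number()}),
        "DEFAULT": OperatorConfig("replace", {"new_value": "<REDACTED>"}),
    }
    return anonymizer.anonymize(
        text=text, analyzer_results=results, operators=operators
    ).text
```

Because the replacement values stay fluent, the sentiment classifier's behavior is largely unaffected, which is the point of preferring Faker values over opaque `<REDACTED>` tokens.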
+ {
+ "cell_type": "code",
+ "execution_count": 24,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Model: sentiment_analysis_faker_pii_model(name='claude-3-sonnet', description=None, model_name='claude-3-5-sonnet-20240620', system_prompt='You are a Sentiment Analysis classifier. You will be classifying text based on its sentiment. Your input will be a block of text. You will answer with one of the following rating options: [\"positive\", \"negative\", \"neutral\"]. Your answer should be one word in JSON format: {classification}. Ensure that it is valid JSON.', temperature=0)\n",
+ "🍩 https://wandb.ai/wandb/pii_cookbook/r/call/01924a62-34d0-70a0-a442-d08e59d745de\n",
+ "🍩 https://wandb.ai/wandb/pii_cookbook/r/call/01924a62-386a-7632-8785-8102ac495ae8\n",
+ "🍩 https://wandb.ai/wandb/pii_cookbook/r/call/01924a62-3c15-70b1-8bd9-b143e0aa8b20\n",
+ "🍩 https://wandb.ai/wandb/pii_cookbook/r/call/01924a62-4105-7af1-8446-572485bab118\n",
+ "🍩 https://wandb.ai/wandb/pii_cookbook/r/call/01924a62-44c5-7760-8d93-65743f1f2152\n",
+ "🍩 https://wandb.ai/wandb/pii_cookbook/r/call/01924a62-48c1-7b71-ab41-0e02dddd795c\n",
+ "🍩 https://wandb.ai/wandb/pii_cookbook/r/call/01924a62-4d50-7320-9f1b-03e9a3efeb4b\n",
+ "🍩 https://wandb.ai/wandb/pii_cookbook/r/call/01924a62-5155-78e3-a7a5-605fc663a8c3\n",
+ "🍩 https://wandb.ai/wandb/pii_cookbook/r/call/01924a62-54fc-71d1-9b6c-775570e84b24\n",
+ "🍩 https://wandb.ai/wandb/pii_cookbook/r/call/01924a62-589a-7d22-a8d5-eaeed5e8df40\n"
+ ]
+ }
+ ],
+ "source": [
+ "# create our LLM model with a system prompt\n",
+ "model = sentiment_analysis_faker_pii_model(\n",
+ "    name=\"claude-3-sonnet\",\n",
+ "    model_name=\"claude-3-5-sonnet-20240620\",\n",
+ "    system_prompt='You are a Sentiment Analysis classifier. You will be classifying text based on its sentiment. Your input will be a block of text. You will answer with one of the following rating options: [\"positive\", \"negative\", \"neutral\"]. Your answer should be one word in JSON format: {classification}. Ensure that it is valid JSON.',\n",
+ "    temperature=0,\n",
+ ")\n",
+ "\n",
+ "print(\"Model: \", model)\n",
+ "# for every block of text, anonymize it first and then predict\n",
 "for entry in pii_data:\n",
- "    anonymized_entry = faker.anonymize_my_text(entry[\"text\"])\n",
- "    (await model.predict(anonymized_entry))"
+ "    await model.predict(entry[\"text\"])"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Checklist for Safely Using Weave with PII Data\n",
+ "\n",
+ "### During Testing\n",
+ "- Log anonymized data to check PII detection\n",
+ "- Track PII handling processes with Weave Traces\n",
+ "- Measure anonymization performance without exposing real PII\n",
+ "\n",
+ "### In Production\n",
+ "- Never log raw PII\n",
+ "- Encrypt sensitive fields before logging\n",
+ "\n",
+ "### Encryption Tips\n",
+ "- Use reversible encryption for data you need to decrypt later\n",
+ "- Apply one-way hashing for unique IDs you don't need to reverse\n",
+ "- Consider specialized encryption for data you need to analyze while encrypted"
+ ]
+ },
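To make the first two encryption tips concrete, here is a minimal sketch using the `cryptography` package and `hashlib`; the email value is made up:

```python
import hashlib

from cryptography.fernet import Fernet

key = Fernet.generate_key()  # keep this secret outside your logs
f = Fernet(key)

# Reversible: encrypt a sensitive field before logging, decrypt it later
# with the same key.
token = f.encrypt(b"jane.doe@example.com")
assert f.decrypt(token) == b"jane.doe@example.com"

# One-way: a stable pseudonymous ID that cannot be reversed.
user_id = hashlib.sha256(b"jane.doe@example.com").hexdigest()
```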
\n", " (Optional) Encrypting our data \n", - "\n", "![](../../media/pii/encrypt.png)\n", "\n", "In addition to anonymizing PII, we can add an extra layer of security by encrypting our data using the cryptography library's [Fernet](https://cryptography.io/en/latest/fernet/) symmetric encryption. This approach ensures that even if the anonymized data is intercepted, it remains unreadable without the encryption key.\n", @@ -681,7 +1033,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.12.4" + "version": "3.12.5" } }, "nbformat": 4, diff --git a/pyproject.toml b/pyproject.toml index 9c0cb00d680..f0392b7ace4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -36,20 +36,13 @@ dynamic = ["version"] dependencies = [ "pydantic>=2.0.0", "wandb>=0.17.1", - "analytics-python>=1.2.9", # Segment logging - "python-dateutil>=2.8.2", # For ISO date parsing "packaging>=21.0", # For version parsing in integrations "tenacity>=8.3.0,!=8.4.0", # Excluding 8.4.0 because it had a bug on import of AsyncRetrying "emoji>=2.12.1", # For emoji shortcode support in Feedback "uuid-utils>=0.9.0", # Used for ID generation - remove once python's built-in uuid supports UUIDv7 - "numpy>1.21.0", # Used in box.py (should be made optional) + "numpy>1.21.0", # Used in box.py and scorer.py (should be made optional) "rich", # Used for special formatting of tables (should be made optional) - - # dependencies for remaining legacy code. Remove when possible - "httpx", - "aiohttp", - "gql", - "requests_toolbelt", + "gql[aiohttp,requests]", # Used exclusively in wandb_api.py ] [project.optional-dependencies] @@ -85,14 +78,15 @@ test = [ "sqlparse==0.5.0", # Integration Tests - "pytest-recording==0.13.1", - "vcrpy==6.0.1", + "pytest-recording>=0.13.2", + "vcrpy>=6.0.1", # serving tests "flask", "uvicorn>=0.27.0", "pillow", "filelock", + "httpx", ] [project.scripts] @@ -148,6 +142,7 @@ select = [ "W291", # https://docs.astral.sh/ruff/rules/trailing-whitespace/ "W391", # https://docs.astral.sh/ruff/rules/too-many-newlines-at-end-of-file/ "F401", # https://docs.astral.sh/ruff/rules/unused-import/ + "TID252", # https://docs.astral.sh/ruff/rules/relative-imports/#relative-imports-tid252 ] ignore = [ # we use Google style @@ -160,10 +155,6 @@ exclude = ["weave_query"] [tool.ruff.lint.isort] known-third-party = ["wandb", "weave_query"] -[tool.ruff.lint.per-file-ignores] -"tests/*" = ["F401"] - - [tool.ruff] line-length = 88 show-fixes = true diff --git a/tests/conftest.py b/tests/conftest.py index 1341093f25d..bbba381f04b 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,33 +1,41 @@ +import base64 +import contextlib import logging import os -from contextlib import _GeneratorContextManager -from typing import Callable, Iterator +import subprocess +import time +import typing +import urllib +from typing import Iterator import pytest +import requests from fastapi import FastAPI from fastapi.testclient import TestClient import weave from tests.trace.util import DummyTestException -from weave.trace import autopatch, weave_init -from weave.trace.client_context import context_state +from weave.trace import autopatch, weave_client, weave_init from weave.trace_server import ( clickhouse_trace_server_batched, + external_to_internal_trace_server_adapter, sqlite_trace_server, ) +from weave.trace_server import environment as ts_env from weave.trace_server import trace_server_interface as tsi from weave.trace_server_bindings import remote_http_trace_server -from 
.trace.trace_server_clickhouse_conftest import * -from .wandb_system_tests_conftest import * - # Force testing to never report wandb sentry events os.environ["WANDB_ERROR_REPORTING"] = "false" -def pytest_sessionfinish(session, exitstatus): - if exitstatus == pytest.ExitCode.NO_TESTS_COLLECTED: - session.exitstatus = 0 +def pytest_addoption(parser): + parser.addoption( + "--weave-server", + action="store", + default="sqlite", + help="Specify the client object to use: sqlite or clickhouse", + ) def pytest_collection_modifyitems(config, items): @@ -37,11 +45,289 @@ def pytest_collection_modifyitems(config, items): item.add_marker(pytest.mark.weave_client) -PYTEST_CURRENT_TEST_ENV_VAR = "PYTEST_CURRENT_TEST" +def pytest_sessionfinish(session, exitstatus): + if exitstatus == pytest.ExitCode.NO_TESTS_COLLECTED: + session.exitstatus = 0 + + +class ThrowingServer(tsi.TraceServerInterface): + # Call API + def call_start(self, req: tsi.CallStartReq) -> tsi.CallStartRes: + raise DummyTestException("FAILURE - call_start, req:", req) + + def call_end(self, req: tsi.CallEndReq) -> tsi.CallEndRes: + raise DummyTestException("FAILURE - call_end, req:", req) + + def call_read(self, req: tsi.CallReadReq) -> tsi.CallReadRes: + raise DummyTestException("FAILURE - call_read, req:", req) + + def calls_query(self, req: tsi.CallsQueryReq) -> tsi.CallsQueryRes: + raise DummyTestException("FAILURE - calls_query, req:", req) + + def calls_query_stream(self, req: tsi.CallsQueryReq) -> Iterator[tsi.CallSchema]: + raise DummyTestException("FAILURE - calls_query_stream, req:", req) + + def calls_delete(self, req: tsi.CallsDeleteReq) -> tsi.CallsDeleteRes: + raise DummyTestException("FAILURE - calls_delete, req:", req) + + def calls_query_stats(self, req: tsi.CallsQueryStatsReq) -> tsi.CallsQueryStatsRes: + raise DummyTestException("FAILURE - calls_query_stats, req:", req) + + def call_update(self, req: tsi.CallUpdateReq) -> tsi.CallUpdateRes: + raise DummyTestException("FAILURE - call_update, req:", req) + + # Op API + def op_create(self, req: tsi.OpCreateReq) -> tsi.OpCreateRes: + raise DummyTestException("FAILURE - op_create, req:", req) + + def op_read(self, req: tsi.OpReadReq) -> tsi.OpReadRes: + raise DummyTestException("FAILURE - op_read, req:", req) + + def ops_query(self, req: tsi.OpQueryReq) -> tsi.OpQueryRes: + raise DummyTestException("FAILURE - ops_query, req:", req) + + # Cost API + def cost_create(self, req: tsi.CostCreateReq) -> tsi.CostCreateRes: + raise DummyTestException("FAILURE - cost_create, req:", req) + + def cost_query(self, req: tsi.CostQueryReq) -> tsi.CostQueryRes: + raise DummyTestException("FAILURE - cost_query, req:", req) + + def cost_purge(self, req: tsi.CostPurgeReq) -> tsi.CostPurgeRes: + raise DummyTestException("FAILURE - cost_purge, req:", req) + + # Obj API + def obj_create(self, req: tsi.ObjCreateReq) -> tsi.ObjCreateRes: + raise DummyTestException("FAILURE - obj_create, req:", req) + + def obj_read(self, req: tsi.ObjReadReq) -> tsi.ObjReadRes: + raise DummyTestException("FAILURE - obj_read, req:", req) + + def objs_query(self, req: tsi.ObjQueryReq) -> tsi.ObjQueryRes: + raise DummyTestException("FAILURE - objs_query, req:", req) + + def table_create(self, req: tsi.TableCreateReq) -> tsi.TableCreateRes: + raise DummyTestException("FAILURE - table_create, req:", req) + + def table_update(self, req: tsi.TableUpdateReq) -> tsi.TableUpdateRes: + raise DummyTestException("FAILURE - table_update, req:", req) + + def table_query(self, req: tsi.TableQueryReq) -> tsi.TableQueryRes: + raise 
DummyTestException("FAILURE - table_query, req:", req) + + def refs_read_batch(self, req: tsi.RefsReadBatchReq) -> tsi.RefsReadBatchRes: + raise DummyTestException("FAILURE - refs_read_batch, req:", req) + + def file_create(self, req: tsi.FileCreateReq) -> tsi.FileCreateRes: + raise DummyTestException("FAILURE - file_create, req:", req) + + def file_content_read(self, req: tsi.FileContentReadReq) -> tsi.FileContentReadRes: + raise DummyTestException("FAILURE - file_content_read, req:", req) + + def feedback_create(self, req: tsi.FeedbackCreateReq) -> tsi.FeedbackCreateRes: + raise DummyTestException("FAILURE - feedback_create, req:", req) + + def feedback_query(self, req: tsi.FeedbackQueryReq) -> tsi.FeedbackQueryRes: + raise DummyTestException("FAILURE - feedback_query, req:", req) + + def feedback_purge(self, req: tsi.FeedbackPurgeReq) -> tsi.FeedbackPurgeRes: + raise DummyTestException("FAILURE - feedback_purge, req:", req) + + +@pytest.fixture() +def client_with_throwing_server(client): + curr_server = client.server + client.server = ThrowingServer() + try: + yield client + finally: + client.server = curr_server + + +@pytest.fixture(scope="session") +def clickhouse_server(): + server_up = _check_server_up( + ts_env.wf_clickhouse_host(), ts_env.wf_clickhouse_port() + ) + if not server_up: + pytest.fail("clickhouse server is not running") + + +@pytest.fixture(scope="session") +def clickhouse_trace_server(clickhouse_server): + clickhouse_trace_server = ( + clickhouse_trace_server_batched.ClickHouseTraceServer.from_env( + use_async_insert=False + ) + ) + clickhouse_trace_server._run_migrations() + yield clickhouse_trace_server + + +def _check_server_health( + base_url: str, endpoint: str, num_retries: int = 1, sleep_time: int = 1 +) -> bool: + for _ in range(num_retries): + try: + response = requests.get(urllib.parse.urljoin(base_url, endpoint)) + if response.status_code == 200: + return True + time.sleep(sleep_time) + except requests.exceptions.ConnectionError: + time.sleep(sleep_time) + + print( + f"Server not healthy @ {urllib.parse.urljoin(base_url, endpoint)}: no response" + ) + return False + + +def _check_server_up(host, port) -> bool: + base_url = f"http://{host}:{port}/" + endpoint = "ping" + + def server_healthy(num_retries=1): + return _check_server_health( + base_url=base_url, endpoint=endpoint, num_retries=num_retries + ) + + if server_healthy(): + return True + + if os.environ.get("CI") != "true": + print("CI is not true, not starting clickhouse server") + + subprocess.Popen( + [ + "docker", + "run", + "-d", + "--rm", + "-p", + f"{port}:8123", + "--name", + "weave-python-test-clickhouse-server", + "--ulimit", + "nofile=262144:262144", + "clickhouse/clickhouse-server", + ] + ) + + # wait for the server to start + return server_healthy(num_retries=30) + + +class TwoWayMapping: + def __init__(self): + self._ext_to_int_map = {} + self._int_to_ext_map = {} + + # Useful for testing to ensure caching is working + self.stats = { + "ext_to_int": { + "hits": 0, + "misses": 0, + }, + "int_to_ext": { + "hits": 0, + "misses": 0, + }, + } + + def ext_to_int(self, key, default=None): + if key not in self._ext_to_int_map: + if default is None: + raise ValueError(f"Key {key} not found") + if default in self._int_to_ext_map: + raise ValueError(f"Default {default} already in use") + self._ext_to_int_map[key] = default + self._int_to_ext_map[default] = key + self.stats["ext_to_int"]["misses"] += 1 + else: + self.stats["ext_to_int"]["hits"] += 1 + return self._ext_to_int_map[key] + + def 
int_to_ext(self, key, default): + if key not in self._int_to_ext_map: + if default is None: + raise ValueError(f"Key {key} not found") + if default in self._ext_to_int_map: + raise ValueError(f"Default {default} already in use") + self._int_to_ext_map[key] = default + self._ext_to_int_map[default] = key + self.stats["int_to_ext"]["misses"] += 1 + else: + self.stats["int_to_ext"]["hits"] += 1 + return self._int_to_ext_map[key] + + +def b64(s: str) -> str: + # Base64 encode the string + return base64.b64encode(s.encode("ascii")).decode("ascii") + + +class DummyIdConverter(external_to_internal_trace_server_adapter.IdConverter): + def __init__(self): + self._project_map = TwoWayMapping() + self._run_map = TwoWayMapping() + self._user_map = TwoWayMapping() + + def ext_to_int_project_id(self, project_id: str) -> str: + return self._project_map.ext_to_int(project_id, b64(project_id)) + + def int_to_ext_project_id(self, project_id: str) -> typing.Optional[str]: + return self._project_map.int_to_ext(project_id, b64(project_id)) + + def ext_to_int_run_id(self, run_id: str) -> str: + return self._run_map.ext_to_int(run_id, b64(run_id) + ":" + run_id) + + def int_to_ext_run_id(self, run_id: str) -> str: + exp = run_id.split(":")[1] + return self._run_map.int_to_ext(run_id, exp) + + def ext_to_int_user_id(self, user_id: str) -> str: + return self._user_map.ext_to_int(user_id, b64(user_id)) + + def int_to_ext_user_id(self, user_id: str) -> str: + return self._user_map.int_to_ext(user_id, b64(user_id)) + + +class TestOnlyUserInjectingExternalTraceServer( + external_to_internal_trace_server_adapter.ExternalTraceServer +): + def __init__( + self, + internal_trace_server: tsi.TraceServerInterface, + id_converter: external_to_internal_trace_server_adapter.IdConverter, + user_id: str, + ): + super().__init__(internal_trace_server, id_converter) + self._user_id = user_id + + def call_start(self, req: tsi.CallStartReq) -> tsi.CallStartRes: + req.start.wb_user_id = self._user_id + return super().call_start(req) + + def calls_delete(self, req: tsi.CallsDeleteReq) -> tsi.CallsDeleteRes: + req.wb_user_id = self._user_id + return super().calls_delete(req) + + def call_update(self, req: tsi.CallUpdateReq) -> tsi.CallUpdateRes: + req.wb_user_id = self._user_id + return super().call_update(req) + def feedback_create(self, req: tsi.FeedbackCreateReq) -> tsi.FeedbackCreateRes: + req.wb_user_id = self._user_id + return super().feedback_create(req) + def cost_create(self, req: tsi.CostCreateReq) -> tsi.CostCreateRes: + req.wb_user_id = self._user_id + return super().cost_create(req) + + +# https://docs.pytest.org/en/7.1.x/example/simple.html#pytest-current-test-environment-variable def get_test_name(): - return os.environ.get(PYTEST_CURRENT_TEST_ENV_VAR).split(" ")[0] + return os.environ.get("PYTEST_CURRENT_TEST", " ").split(" ")[0] class InMemoryWeaveLogCollector(logging.Handler): @@ -102,12 +388,6 @@ def logging_error_check(request, log_collector): ) -@pytest.fixture() -def strict_op_saving(): - with context_state.strict_op_saving(True): - yield - - class TestOnlyFlushingWeaveClient(weave_client.WeaveClient): """ A WeaveClient that automatically flushes after every method call. 
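The docstring above is the whole idea behind the test-only client. As a rough sketch of how such an auto-flushing wrapper can work, assuming only that the wrapped client exposes a `flush()` method (the proxy class below is illustrative; the real fixture subclasses `WeaveClient` directly rather than proxying it):

```python
class AutoFlushingProxy:
    """Illustrative stand-in: forward every method call, then flush.

    The effect is that tests never have to remember to flush pending
    batched writes before asserting on server state.
    """

    def __init__(self, client):
        self._client = client

    def __getattr__(self, name):
        attr = getattr(self._client, name)
        if not callable(attr):
            return attr

        def wrapper(*args, **kwargs):
            result = attr(*args, **kwargs)
            self._client.flush()  # push any queued batch to the trace server
            return result

        return wrapper
```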
@@ -233,7 +513,7 @@ def create_client(request) -> weave_init.InitializedClient: @pytest.fixture() -def client(request) -> Generator[weave_client.WeaveClient, None, None]: +def client(request): """This is the standard fixture used everywhere in tests to test end to end client functionality""" inited_client = create_client(request) @@ -244,11 +524,7 @@ def client(request) -> Generator[weave_client.WeaveClient, None, None]: @pytest.fixture() -def client_creator( - request, -) -> Generator[ - Callable[[], _GeneratorContextManager[weave_client.WeaveClient]], None, None -]: +def client_creator(request): """This fixture is useful for delaying the creation of the client (ex. when you want to set settings first)""" @contextlib.contextmanager @@ -313,97 +589,3 @@ def post(url, data=None, json=None, **kwargs): yield (client, remote_client, records) weave.trace_server.requests.post = orig_post - - -class ThrowingServer(tsi.TraceServerInterface): - # Call API - def call_start(self, req: tsi.CallStartReq) -> tsi.CallStartRes: - raise DummyTestException("FAILURE - call_start, req:", req) - - def call_end(self, req: tsi.CallEndReq) -> tsi.CallEndRes: - raise DummyTestException("FAILURE - call_end, req:", req) - - def call_read(self, req: tsi.CallReadReq) -> tsi.CallReadRes: - raise DummyTestException("FAILURE - call_read, req:", req) - - def calls_query(self, req: tsi.CallsQueryReq) -> tsi.CallsQueryRes: - raise DummyTestException("FAILURE - calls_query, req:", req) - - def calls_query_stream(self, req: tsi.CallsQueryReq) -> Iterator[tsi.CallSchema]: - raise DummyTestException("FAILURE - calls_query_stream, req:", req) - - def calls_delete(self, req: tsi.CallsDeleteReq) -> tsi.CallsDeleteRes: - raise DummyTestException("FAILURE - calls_delete, req:", req) - - def calls_query_stats(self, req: tsi.CallsQueryStatsReq) -> tsi.CallsQueryStatsRes: - raise DummyTestException("FAILURE - calls_query_stats, req:", req) - - def call_update(self, req: tsi.CallUpdateReq) -> tsi.CallUpdateRes: - raise DummyTestException("FAILURE - call_update, req:", req) - - # Op API - def op_create(self, req: tsi.OpCreateReq) -> tsi.OpCreateRes: - raise DummyTestException("FAILURE - op_create, req:", req) - - def op_read(self, req: tsi.OpReadReq) -> tsi.OpReadRes: - raise DummyTestException("FAILURE - op_read, req:", req) - - def ops_query(self, req: tsi.OpQueryReq) -> tsi.OpQueryRes: - raise DummyTestException("FAILURE - ops_query, req:", req) - - # Cost API - def cost_create(self, req: tsi.CostCreateReq) -> tsi.CostCreateRes: - raise DummyTestException("FAILURE - cost_create, req:", req) - - def cost_query(self, req: tsi.CostQueryReq) -> tsi.CostQueryRes: - raise DummyTestException("FAILURE - cost_query, req:", req) - - def cost_purge(self, req: tsi.CostPurgeReq) -> tsi.CostPurgeRes: - raise DummyTestException("FAILURE - cost_purge, req:", req) - - # Obj API - def obj_create(self, req: tsi.ObjCreateReq) -> tsi.ObjCreateRes: - raise DummyTestException("FAILURE - obj_create, req:", req) - - def obj_read(self, req: tsi.ObjReadReq) -> tsi.ObjReadRes: - raise DummyTestException("FAILURE - obj_read, req:", req) - - def objs_query(self, req: tsi.ObjQueryReq) -> tsi.ObjQueryRes: - raise DummyTestException("FAILURE - objs_query, req:", req) - - def table_create(self, req: tsi.TableCreateReq) -> tsi.TableCreateRes: - raise DummyTestException("FAILURE - table_create, req:", req) - - def table_update(self, req: tsi.TableUpdateReq) -> tsi.TableUpdateRes: - raise DummyTestException("FAILURE - table_update, req:", req) - - def table_query(self, 
req: tsi.TableQueryReq) -> tsi.TableQueryRes: - raise DummyTestException("FAILURE - table_query, req:", req) - - def refs_read_batch(self, req: tsi.RefsReadBatchReq) -> tsi.RefsReadBatchRes: - raise DummyTestException("FAILURE - refs_read_batch, req:", req) - - def file_create(self, req: tsi.FileCreateReq) -> tsi.FileCreateRes: - raise DummyTestException("FAILURE - file_create, req:", req) - - def file_content_read(self, req: tsi.FileContentReadReq) -> tsi.FileContentReadRes: - raise DummyTestException("FAILURE - file_content_read, req:", req) - - def feedback_create(self, req: tsi.FeedbackCreateReq) -> tsi.FeedbackCreateRes: - raise DummyTestException("FAILURE - feedback_create, req:", req) - - def feedback_query(self, req: tsi.FeedbackQueryReq) -> tsi.FeedbackQueryRes: - raise DummyTestException("FAILURE - feedback_query, req:", req) - - def feedback_purge(self, req: tsi.FeedbackPurgeReq) -> tsi.FeedbackPurgeRes: - raise DummyTestException("FAILURE - feedback_purge, req:", req) - - -@pytest.fixture() -def client_with_throwing_server(client: weave_client.WeaveClient): - curr_server = client.server - client.server = ThrowingServer() - try: - yield client - finally: - client.server = curr_server diff --git a/tests/trace/test_client_trace.py b/tests/trace/test_client_trace.py index 28b29e490ee..f387d20b62b 100644 --- a/tests/trace/test_client_trace.py +++ b/tests/trace/test_client_trace.py @@ -644,32 +644,6 @@ def test_trace_call_query_filter_trace_roots_only(client): assert len(inner_res.calls) == exp_count -@pytest.mark.skip("too slow") -def test_trace_call_query_filter_wb_run_ids(client, user_by_api_key_in_env): - call_spec = simple_line_call_bootstrap(init_wandb=True) - - res = get_all_calls_asserting_finished(client, call_spec) - - wb_run_ids = list(set([call.wb_run_id for call in res.calls]) - set([None])) - - for wb_run_ids, exp_count in [ - # Test the None case - (None, call_spec.total_calls), - # Test the empty list case - ([], call_spec.total_calls), - # Test List (of 1) - (wb_run_ids, call_spec.run_calls), - ]: - inner_res = get_client_trace_server(client).calls_query( - tsi.CallsQueryReq( - project_id=get_client_project_id(client), - filter=tsi.CallsFilter(wb_run_ids=wb_run_ids), - ) - ) - - assert len(inner_res.calls) == exp_count - - def test_trace_call_query_limit(client): call_spec = simple_line_call_bootstrap() @@ -1561,23 +1535,6 @@ class MySerializableClass(weave.Object): assert b2.obj == repr(b_obj) -# Note: this test only works with the `trace_init_client` fixture -@pytest.mark.skip(reason="TODO: Skipping since it seems to rely on the testcontainer") -def test_ref_get_no_client(trace_init_client): - trace_client = trace_init_client.client - data = weave.publish(42) - data_got = weave.ref(data.uri()).get() - assert data_got == 42 - - # clear the graph client effectively "de-initializing it" - with _no_graph_client(): - # This patching is required just to make the test path work - with _patched_default_initializer(trace_client): - # Now we will try to get the data again - data_got = weave.ref(data.uri()).get() - assert data_got == 42 - - @contextmanager def _no_graph_client(): client = weave.trace.client_context.weave_client.get_weave_client() diff --git a/tests/trace/test_op_versioning.py b/tests/trace/test_op_versioning.py index b0892063518..a8ed18ce780 100644 --- a/tests/trace/test_op_versioning.py +++ b/tests/trace/test_op_versioning.py @@ -34,7 +34,7 @@ def solo_versioned_op(a: int) -> float: """ -def test_solo_op_versioning(strict_op_saving, client): +def 
test_solo_op_versioning(client): from tests.trace import op_versioning_solo ref = weave.publish(op_versioning_solo.solo_versioned_op) @@ -57,7 +57,7 @@ def versioned_op(self, a: int) -> float: """ -def test_object_op_versioning(strict_op_saving, client): +def test_object_op_versioning(client): from tests.trace import op_versioning_obj obj = op_versioning_obj.MyTestObjWithOp(val=5) @@ -81,7 +81,7 @@ def versioned_op_importfrom(a: int) -> float: """ -def test_op_versioning_importfrom(strict_op_saving, client): +def test_op_versioning_importfrom(client): from tests.trace import op_versioning_importfrom ref = weave.publish(op_versioning_importfrom.versioned_op_importfrom) @@ -92,7 +92,7 @@ def test_op_versioning_importfrom(strict_op_saving, client): assert saved_code == EXPECTED_IMPORTFROM_OP_CODE -def test_op_versioning_lotsofstuff(strict_op_saving): +def test_op_versioning_lotsofstuff(): @weave.op() def versioned_op_lotsofstuff(a: int) -> float: j = [x + 1 for x in range(a)] @@ -100,11 +100,11 @@ def versioned_op_lotsofstuff(a: int) -> float: return np.array(k).mean() -def test_op_versioning_inline_import(strict_op_saving, client): +def test_op_versioning_inline_import(client): pass -def test_op_versioning_inline_func_decl(strict_op_saving): +def test_op_versioning_inline_func_decl(): @weave.op() def versioned_op_inline_func_decl(a: int) -> float: def inner_func(x): @@ -126,7 +126,7 @@ def versioned_op_closure_constant(a: int) -> float: """ -def test_op_versioning_closure_constant(strict_op_saving, client): +def test_op_versioning_closure_constant(client): x = 10 @weave.op() @@ -155,7 +155,7 @@ def versioned_op_closure_constant(a: int) -> float: """ -def test_op_versioning_closure_dict_simple(strict_op_saving, client): +def test_op_versioning_closure_dict_simple(client): x = {"a": 5, "b": 10} @weave.op() @@ -186,7 +186,7 @@ def versioned_op_closure_constant(a: int) -> float: @pytest.mark.skip("custom objs not working with new weave_client") -def test_op_versioning_closure_dict_np(strict_op_saving, client): +def test_op_versioning_closure_dict_np(client): x = {"a": 5, "b": np.array([1, 2, 3])} @weave.op() @@ -224,7 +224,7 @@ def pony(v: int): @pytest.mark.skip("failing in ci, due to some kind of /tmp file slowness?") -def test_op_versioning_closure_dict_ops(strict_op_saving, client): +def test_op_versioning_closure_dict_ops(client): @weave.op() def cat(v: int): print("hello from cat()") @@ -279,7 +279,7 @@ def pony(v: int): @pytest.mark.skip("custom objs not working with new weave_client") -def test_op_versioning_mixed(strict_op_saving, client): +def test_op_versioning_mixed(client): @weave.op() def cat(v: int): print("hello from cat()") @@ -313,7 +313,7 @@ def pony(v: int): assert op2(1) == 102.0 -def test_op_versioning_exception(strict_op_saving): +def test_op_versioning_exception(): # Just ensure this doesn't raise by running it. 
@weave.op() def versioned_op_exception(a: int) -> float: @@ -325,7 +325,7 @@ def versioned_op_exception(a: int) -> float: return x -def test_op_versioning_2ops(strict_op_saving, client): +def test_op_versioning_2ops(client): @weave.op() def dog(): print("hello from dog()") @@ -355,7 +355,9 @@ def some_d(v: int) -> SomeDict: """ -def test_op_return_typeddict_annotation(client, strict_op_saving): +def test_op_return_typeddict_annotation( + client, +): class SomeDict(typing.TypedDict): val: int @@ -392,7 +394,9 @@ def some_d(v: int): """ -def test_op_return_return_custom_class(client, strict_op_saving): +def test_op_return_return_custom_class( + client, +): class MyCoolClass: val: int @@ -426,7 +430,9 @@ def internal_fn(x): """ -def test_op_nested_function(client, strict_op_saving): +def test_op_nested_function( + client, +): @weave.op() def some_d(v: int): def internal_fn(x): diff --git a/tests/trace/test_weave_client.py b/tests/trace/test_weave_client.py index 5d081f8ce63..f1e07070ac9 100644 --- a/tests/trace/test_weave_client.py +++ b/tests/trace/test_weave_client.py @@ -632,7 +632,7 @@ def hello_world(): assert obj.project_id == "shawn/test-project2" -def test_saveload_customtype(client, strict_op_saving): +def test_saveload_customtype(client): class MyCustomObj: a: int b: str diff --git a/tests/trace/test_weaveflow.py b/tests/trace/test_weaveflow.py index 37490de691a..0fbc81180a4 100644 --- a/tests/trace/test_weaveflow.py +++ b/tests/trace/test_weaveflow.py @@ -1,5 +1,4 @@ import numpy as np -import pytest from pydantic import Field import weave @@ -119,46 +118,6 @@ def op_with_unknown_param() -> int: assert op_with_unknown_param() == 12 -@pytest.mark.skip("artifact file download doesn't work here?") -def test_saveloop_idempotent_with_refs(user_by_api_key_in_env): - with weave.wandb_client("weaveflow_example-idempotent_with_refs"): - - @weave.type() - class A: - b: int - - @weave.op() - def call(self, v): - return self.b + v - - @weave.type() - class C: - a: A - c: int - - @weave.op() - def call(self, v): - return self.a.call(v) * self.c - - a = A(5) - c = C(a, 10) - assert c.call(40) == 450 - - c2_0_ref = weave.ref("C:latest") - c2_0 = c2_0_ref.get() - assert c2_0.call(50) == 550 - - c2_1_ref = weave.ref("C:latest") - c2_1 = c2_1_ref.get() - assert c2_1.call(60) == 650 - assert c2_0_ref.version == c2_1_ref.version - - c2_2_ref = weave.ref("C:latest") - c2_2 = c2_2_ref.get() - assert c2_2.call(60) == 650 - assert c2_1_ref.version == c2_2_ref.version - - def test_subobj_ref_passing(client): dataset = client.save( weave.Dataset(rows=[{"x": 1, "y": 3}, {"x": 2, "y": 16}]), "my-dataset" diff --git a/tests/trace/trace_server_clickhouse_conftest.py b/tests/trace/trace_server_clickhouse_conftest.py deleted file mode 100644 index 5f80d1478de..00000000000 --- a/tests/trace/trace_server_clickhouse_conftest.py +++ /dev/null @@ -1,225 +0,0 @@ -import base64 -import os -import subprocess -import time -import typing -import urllib -import uuid - -import pytest -import requests - -from weave.trace import weave_client -from weave.trace.weave_init import InitializedClient -from weave.trace_server import ( - clickhouse_trace_server_batched, - external_to_internal_trace_server_adapter, -) -from weave.trace_server import environment as wf_env -from weave.trace_server import trace_server_interface as tsi - - -@pytest.fixture(scope="session") -def clickhouse_server(): - server_up = _check_server_up( - wf_env.wf_clickhouse_host(), wf_env.wf_clickhouse_port() - ) - if not server_up: - pytest.fail("clickhouse 
server is not running") - - -@pytest.fixture(scope="session") -def clickhouse_trace_server(clickhouse_server): - clickhouse_trace_server = ( - clickhouse_trace_server_batched.ClickHouseTraceServer.from_env( - use_async_insert=False - ) - ) - clickhouse_trace_server._run_migrations() - yield clickhouse_trace_server - - -@pytest.fixture() -def trace_init_client(clickhouse_trace_server, user_by_api_key_in_env): - # Generate a random project name to avoid conflicts between tests - # using the same shared backend server - random_project_name = str(uuid.uuid4()) - server = TestOnlyUserInjectingExternalTraceServer( - clickhouse_trace_server, DummyIdConverter(), user_by_api_key_in_env.username - ) - graph_client = weave_client.WeaveClient( - user_by_api_key_in_env.username, random_project_name, server - ) - - inited_client = InitializedClient(graph_client) - - try: - yield inited_client - finally: - inited_client.reset() - - -@pytest.fixture() -def trace_client(trace_init_client): - return trace_init_client.client - - -def _check_server_health( - base_url: str, endpoint: str, num_retries: int = 1, sleep_time: int = 1 -) -> bool: - for _ in range(num_retries): - try: - response = requests.get(urllib.parse.urljoin(base_url, endpoint)) - if response.status_code == 200: - return True - time.sleep(sleep_time) - except requests.exceptions.ConnectionError: - time.sleep(sleep_time) - - print( - f"Server not healthy @ {urllib.parse.urljoin(base_url, endpoint)}: no response" - ) - return False - - -def _check_server_up(host, port) -> bool: - base_url = f"http://{host}:{port}/" - endpoint = "ping" - - def server_healthy(num_retries=1): - return _check_server_health( - base_url=base_url, endpoint=endpoint, num_retries=num_retries - ) - - if server_healthy(): - return True - - if os.environ.get("CI") != "true": - print("CI is not true, not starting clickhouse server") - - subprocess.Popen( - [ - "docker", - "run", - "-d", - "--rm", - "-p", - f"{port}:8123", - "--name", - "weave-python-test-clickhouse-server", - "--ulimit", - "nofile=262144:262144", - "clickhouse/clickhouse-server", - ] - ) - - # wait for the server to start - return server_healthy(num_retries=30) - - -class TwoWayMapping: - def __init__(self): - self._ext_to_int_map = {} - self._int_to_ext_map = {} - - # Useful for testing to ensure caching is working - self.stats = { - "ext_to_int": { - "hits": 0, - "misses": 0, - }, - "int_to_ext": { - "hits": 0, - "misses": 0, - }, - } - - def ext_to_int(self, key, default=None): - if key not in self._ext_to_int_map: - if default is None: - raise ValueError(f"Key {key} not found") - if default in self._int_to_ext_map: - raise ValueError(f"Default {default} already in use") - self._ext_to_int_map[key] = default - self._int_to_ext_map[default] = key - self.stats["ext_to_int"]["misses"] += 1 - else: - self.stats["ext_to_int"]["hits"] += 1 - return self._ext_to_int_map[key] - - def int_to_ext(self, key, default): - if key not in self._int_to_ext_map: - if default is None: - raise ValueError(f"Key {key} not found") - if default in self._ext_to_int_map: - raise ValueError(f"Default {default} already in use") - self._int_to_ext_map[key] = default - self._ext_to_int_map[default] = key - self.stats["int_to_ext"]["misses"] += 1 - else: - self.stats["int_to_ext"]["hits"] += 1 - return self._int_to_ext_map[key] - - -def b64(s: str) -> str: - # Base64 encode the string - return base64.b64encode(s.encode("ascii")).decode("ascii") - - -class DummyIdConverter(external_to_internal_trace_server_adapter.IdConverter): - def 
__init__(self): - self._project_map = TwoWayMapping() - self._run_map = TwoWayMapping() - self._user_map = TwoWayMapping() - - def ext_to_int_project_id(self, project_id: str) -> str: - return self._project_map.ext_to_int(project_id, b64(project_id)) - - def int_to_ext_project_id(self, project_id: str) -> typing.Optional[str]: - return self._project_map.int_to_ext(project_id, b64(project_id)) - - def ext_to_int_run_id(self, run_id: str) -> str: - return self._run_map.ext_to_int(run_id, b64(run_id) + ":" + run_id) - - def int_to_ext_run_id(self, run_id: str) -> str: - exp = run_id.split(":")[1] - return self._run_map.int_to_ext(run_id, exp) - - def ext_to_int_user_id(self, user_id: str) -> str: - return self._user_map.ext_to_int(user_id, b64(user_id)) - - def int_to_ext_user_id(self, user_id: str) -> str: - return self._user_map.int_to_ext(user_id, b64(user_id)) - - -class TestOnlyUserInjectingExternalTraceServer( - external_to_internal_trace_server_adapter.ExternalTraceServer -): - def __init__( - self, - internal_trace_server: tsi.TraceServerInterface, - id_converter: external_to_internal_trace_server_adapter.IdConverter, - user_id: str, - ): - super().__init__(internal_trace_server, id_converter) - self._user_id = user_id - - def call_start(self, req: tsi.CallStartReq) -> tsi.CallStartRes: - req.start.wb_user_id = self._user_id - return super().call_start(req) - - def calls_delete(self, req: tsi.CallsDeleteReq) -> tsi.CallsDeleteRes: - req.wb_user_id = self._user_id - return super().calls_delete(req) - - def call_update(self, req: tsi.CallUpdateReq) -> tsi.CallUpdateRes: - req.wb_user_id = self._user_id - return super().call_update(req) - - def feedback_create(self, req: tsi.FeedbackCreateReq) -> tsi.FeedbackCreateRes: - req.wb_user_id = self._user_id - return super().feedback_create(req) - - def cost_create(self, req: tsi.CostCreateReq) -> tsi.CostCreateRes: - req.wb_user_id = self._user_id - return super().cost_create(req) diff --git a/tests/wandb_system_tests_conftest.py b/tests/wandb_system_tests_conftest.py deleted file mode 100644 index e0b18358315..00000000000 --- a/tests/wandb_system_tests_conftest.py +++ /dev/null @@ -1,406 +0,0 @@ -import contextlib -import dataclasses -import os -import platform -import secrets -import string -import subprocess -import time -import unittest.mock -import urllib.parse -from typing import Any, Generator, Literal, Optional, Union - -import filelock -import pytest -import requests -import wandb - -from weave.wandb_interface.wandb_api import from_environment - - -# The following code snippet was copied from: -# https://github.com/pytest-dev/pytest-xdist/issues/84#issuecomment-617566804 -# -# The purpose of the `serial` fixture is to ensure that any test with `serial` -# must be run serially. This is critical for tests which use environment -# variables for auth since the low level W&B API state ends up being shared -# between tests. 
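For illustration, a test opts into this serialization simply by requesting the fixture; a minimal sketch (the test name and body below are hypothetical):

```python
import os
import unittest.mock


def test_mutates_global_auth_state(serial):
    # Holding `serial` means no other xdist worker that also requests it
    # can run concurrently, so patching process-global env vars is safe.
    with unittest.mock.patch.dict(os.environ, {"WANDB_API_KEY": "dummy-key"}):
        assert os.environ["WANDB_API_KEY"] == "dummy-key"
```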
-@pytest.fixture(scope="session") -def lock(tmp_path_factory): - base_temp = tmp_path_factory.getbasetemp() - lock_file = base_temp.parent / "serial.lock" - yield filelock.FileLock(lock_file=str(lock_file)) - with contextlib.suppress(OSError): - os.remove(path=lock_file) - - -@pytest.fixture() -def serial(lock): - with lock.acquire(poll_interval=0.1): - yield - - -# End of copied code snippet - - -@dataclasses.dataclass -class LocalBackendFixturePayload: - username: str - password: str - api_key: str - base_url: str - cookie: str - - -def determine_scope(fixture_name, config): - return config.getoption("--user-scope") - - -@pytest.fixture(scope=determine_scope) -def bootstrap_user( - worker_id: str, fixture_fn, base_url, wandb_debug -) -> Generator[LocalBackendFixturePayload, None, None]: - username = f"user-{worker_id}-{random_string()}" - command = UserFixtureCommand(command="up", username=username) - fixture_fn(command) - command = UserFixtureCommand( - command="password", username=username, password=username - ) - fixture_fn(command) - - with unittest.mock.patch.dict( - os.environ, - { - "WANDB_BASE_URL": base_url, - }, - ): - yield LocalBackendFixturePayload( - username=username, - password=username, - api_key=username, - base_url=base_url, - cookie="NOT-IMPLEMENTED", - ) - - -@pytest.fixture(scope=determine_scope) -def user_by_api_key_in_env( - bootstrap_user: LocalBackendFixturePayload, serial -) -> Generator[LocalBackendFixturePayload, None, None]: - with unittest.mock.patch.dict( - os.environ, - { - "WANDB_API_KEY": bootstrap_user.api_key, - }, - ): - with from_environment(): - wandb.teardown() # type: ignore - yield bootstrap_user - wandb.teardown() # type: ignore - - -@pytest.fixture(scope=determine_scope) -def user_by_api_key_netrc( - bootstrap_user: LocalBackendFixturePayload, -) -> Generator[LocalBackendFixturePayload, None, None]: - netrc_path = os.path.expanduser("~/.netrc") - old_netrc = None - if os.path.exists(netrc_path): - with open(netrc_path, "r") as f: - old_netrc = f.read() - try: - with open(netrc_path, "w") as f: - url = urllib.parse.urlparse(bootstrap_user.base_url).netloc - f.write( - f"machine {url}\n login user\n password {bootstrap_user.api_key}\n" - ) - with from_environment(): - yield bootstrap_user - finally: - if old_netrc is None: - os.remove(netrc_path) - else: - with open(netrc_path, "w") as f: - f.write(old_netrc) - - -################################################################## -## The following code is a selection copied from the wandb sdk. ## -## wandb/tests/unit_tests/conftest.py ## -################################################################## - - -# # `local-testcontainer` ports -LOCAL_BASE_PORT = "8080" -SERVICES_API_PORT = "8083" -FIXTURE_SERVICE_PORT = "9015" -WB_SERVER_HOST = "http://localhost" - - -def pytest_addoption(parser): - # note: we default to "function" scope to ensure the environment is - # set up properly when running the tests in parallel with pytest-xdist. 
- if os.environ.get("WB_SERVER_HOST"): - wandb_server_host = os.environ["WB_SERVER_HOST"] - else: - wandb_server_host = WB_SERVER_HOST - - parser.addoption( - "--user-scope", - default="function", # or "function" or "session" or "module" - help='cli to set scope of fixture "user-scope"', - ) - - parser.addoption( - "--job-num", - default=None, - help='cli to set "job-num"', - ) - - parser.addoption( - "--base-url", - default=f"{wandb_server_host}:{LOCAL_BASE_PORT}", - help='cli to set "base-url"', - ) - parser.addoption( - "--wandb-server-tag", - default="master", - help="Image tag to use for the wandb server", - ) - parser.addoption( - "--wandb-server-pull", - action="store_true", - default=False, - help="Force pull the latest wandb server image", - ) - # debug option: creates an admin account that can be used to log in to the - # app and inspect the test runs. - parser.addoption( - "--wandb-debug", - action="store_true", - default=False, - help="Run tests in debug mode", - ) - - parser.addoption( - "--weave-server", - action="store", - default="sqlite", - help="Specify the client object to use: sqlite or clickhouse", - ) - - -def random_string(length: int = 12) -> str: - """Generate a random string of a given length. - - :param length: Length of the string to generate. - :return: Random string. - """ - return "".join( - secrets.choice(string.ascii_lowercase + string.digits) for _ in range(length) - ) - - -@pytest.fixture(scope="session") -def base_url(request): - return request.config.getoption("--base-url") - - -@pytest.fixture(scope="session") -def wandb_server_tag(request): - return request.config.getoption("--wandb-server-tag") - - -@pytest.fixture(scope="session") -def wandb_server_pull(request): - if request.config.getoption("--wandb-server-pull"): - return "always" - return "missing" - - -@pytest.fixture(scope="session") -def wandb_debug(request): - return request.config.getoption("--wandb-debug", default=False) - - -def check_server_health( - base_url: str, endpoint: str, num_retries: int = 1, sleep_time: int = 1 -) -> bool: - """Check if wandb server is healthy. - - :param base_url: - :param num_retries: - :param sleep_time: - :return: - """ - for _ in range(num_retries): - try: - response = requests.get(urllib.parse.urljoin(base_url, endpoint)) - if response.status_code == 200: - return True - time.sleep(sleep_time) - except requests.exceptions.ConnectionError: - time.sleep(sleep_time) - - print( - f"Server not healthy @ {urllib.parse.urljoin(base_url, endpoint)}: no response" - ) - return False - - -def check_server_up( - base_url: str, - wandb_server_tag: str = "master", - wandb_server_pull: Literal["missing", "always"] = "missing", -) -> bool: - """Check if wandb server is up and running. - - If not on the CI and the server is not running, then start it first. 
- - :param base_url: - :param wandb_server_tag: - :param wandb_server_pull: - :return: - """ - app_health_endpoint = "healthz" - fixture_url = base_url.replace(LOCAL_BASE_PORT, FIXTURE_SERVICE_PORT) - fixture_health_endpoint = "health" - - if os.environ.get("CI") == "true": - return check_server_health(base_url=base_url, endpoint=app_health_endpoint) - - if not check_server_health(base_url=base_url, endpoint=app_health_endpoint): - # start wandb server locally and expose necessary ports to the host - command = [ - "docker", - "run", - "--pull", - wandb_server_pull, - "--rm", - "-v", - "wandb:/vol", - "-p", - f"{LOCAL_BASE_PORT}:{LOCAL_BASE_PORT}", - "-p", - f"{SERVICES_API_PORT}:{SERVICES_API_PORT}", - "-p", - f"{FIXTURE_SERVICE_PORT}:{FIXTURE_SERVICE_PORT}", - "-e", - "WANDB_ENABLE_TEST_CONTAINER=true", - *( - ["-e", "PARQUET_ENABLED=true"] - if os.environ.get("PARQUET_ENABLED") - else [] - ), - "--name", - "wandb-local", - "--platform", - "linux/amd64", - ( - "us-central1-docker.pkg.dev/wandb-production/images/local-testcontainer:tim-franken_branch_parquet" - if os.environ.get("PARQUET_ENABLED") - else f"us-central1-docker.pkg.dev/wandb-production/images/local-testcontainer:{wandb_server_tag}" - ), - ] - subprocess.Popen(command) - # wait for the server to start - server_is_up = check_server_health( - base_url=base_url, endpoint=app_health_endpoint, num_retries=30 - ) - if not server_is_up: - return False - # check that the fixture service is accessible - return check_server_health( - base_url=fixture_url, endpoint=fixture_health_endpoint, num_retries=30 - ) - - return check_server_health( - base_url=fixture_url, endpoint=fixture_health_endpoint, num_retries=10 - ) - - -@dataclasses.dataclass -class UserFixtureCommand: - command: Literal["up", "down", "down_all", "logout", "login", "password"] - username: Optional[str] = None - password: Optional[str] = None - admin: bool = False - endpoint: str = "db/user" - port: str = FIXTURE_SERVICE_PORT - method: Literal["post"] = "post" - - -@dataclasses.dataclass -class AddAdminAndEnsureNoDefaultUser: - email: str - password: str - endpoint: str = "api/users-admin" - port: str = SERVICES_API_PORT - method: Literal["put"] = "put" - - -@pytest.fixture(scope="session") -def fixture_fn(base_url, wandb_server_tag, wandb_server_pull): - def fixture_util( - cmd: Union[UserFixtureCommand, AddAdminAndEnsureNoDefaultUser], - ) -> bool: - endpoint = urllib.parse.urljoin( - base_url.replace(LOCAL_BASE_PORT, cmd.port), - cmd.endpoint, - ) - data: Any - if isinstance(cmd, UserFixtureCommand): - data = {"command": cmd.command} - if cmd.username: - data["username"] = cmd.username - if cmd.password: - data["password"] = cmd.password - if cmd.admin is not None: - data["admin"] = cmd.admin - elif isinstance(cmd, AddAdminAndEnsureNoDefaultUser): - data = [ - {"email": f"{cmd.email}@wandb.com", "password": cmd.password}, - ] - else: - raise NotImplementedError(f"{cmd} is not implemented") - # trigger fixture - print(f"Triggering fixture on {endpoint}: {data}") - response = getattr(requests, cmd.method)(endpoint, json=data) - if response.status_code != 200: - print(response.json()) - return False - return True - - # todo: remove this once testcontainer is available on Win - if platform.system() == "Windows": - pytest.skip("testcontainer is not available on Win") - - if not check_server_up(base_url, wandb_server_tag, wandb_server_pull): - pytest.fail("wandb server is not running") - - yield fixture_util - - -@pytest.fixture(scope=determine_scope) -def 
dev_only_admin_env_override() -> Generator[None, None, None]: - new_env = {} - admin_path = "../config/.admin.env" - if not os.path.exists(admin_path): - print( - f"WARNING: Could not find admin env file at {admin_path}. Please follow instructions in README.md to create one." - ) - yield - return - with open(admin_path) as file: - for line in file: - # skip comments and blank lines - if line.startswith("#") or line.strip().__len__() == 0: - continue - # otherwise treat lines as environment variables in a KEY=VALUE combo - key, value = line.split("=", 1) - new_env[key.strip()] = value.strip() - with unittest.mock.patch.dict( - os.environ, - new_env, - ): - yield diff --git a/weave-js/src/components/Button/Button.tsx b/weave-js/src/components/Button/Button.tsx index 0ef7472636a..2e9d987e5e4 100644 --- a/weave-js/src/components/Button/Button.tsx +++ b/weave-js/src/components/Button/Button.tsx @@ -68,6 +68,7 @@ export const Button = React.forwardRef( const isGhost = variant === 'ghost'; const isQuiet = variant === 'quiet'; const isDestructive = variant === 'destructive'; + const isOutline = variant === 'outline'; const hasBothIcons = startIcon && endIcon; const hasOnlyOneIcon = hasIcon && !hasBothIcons; @@ -91,7 +92,7 @@ export const Button = React.forwardRef( className={twMerge( classNames( 'night-aware', - "inline-flex items-center justify-center whitespace-nowrap rounded border-none font-['Source_Sans_Pro'] font-semibold", + "inline-flex items-center justify-center whitespace-nowrap rounded font-['Source_Sans_Pro'] font-semibold", 'disabled:pointer-events-none disabled:opacity-35', 'focus-visible:outline focus-visible:outline-[2px] focus-visible:outline-teal-500', { @@ -131,6 +132,18 @@ export const Button = React.forwardRef( // destructive 'bg-red-500 text-white hover:bg-red-450': isDestructive, 'bg-red-450': isDestructive && active, + + // outline + 'box-border gap-4 border border-moon-200 bg-white text-moon-650 hover:border-transparent hover:bg-teal-300/[0.48] hover:text-teal-600': + isOutline, + 'dark:border-moon-750 dark:bg-transparent dark:text-moon-200 dark:hover:bg-teal-700/[0.48] dark:hover:text-teal-400': + isOutline, + // the border was adding 2px even with className="box-border" so we manually set height + 'h-24': isOutline && isSmall, + 'h-32': isOutline && isMedium, + 'h-40': isOutline && isLarge, + + 'border-none': !isOutline, }, className ) diff --git a/weave-js/src/components/Button/types.ts b/weave-js/src/components/Button/types.ts index 38b50dc6f7b..f9d9187fc3e 100644 --- a/weave-js/src/components/Button/types.ts +++ b/weave-js/src/components/Button/types.ts @@ -11,6 +11,7 @@ export const ButtonVariants = { Ghost: 'ghost', Quiet: 'quiet', Destructive: 'destructive', + Outline: 'outline', } as const; export type ButtonVariant = (typeof ButtonVariants)[keyof typeof ButtonVariants]; diff --git a/weave-js/src/components/Callout/Callout.tsx b/weave-js/src/components/Callout/Callout.tsx new file mode 100644 index 00000000000..51028420f46 --- /dev/null +++ b/weave-js/src/components/Callout/Callout.tsx @@ -0,0 +1,42 @@ +import React from 'react'; +import {twMerge} from 'tailwind-merge'; + +import {Icon, IconName} from '../Icon'; +import {getTagColorClass, type TagColorName} from '../Tag'; +import {Tailwind} from '../Tailwind'; +import {CalloutSize} from './types'; + +export type CalloutProps = { + className?: string; + color: TagColorName; + icon: IconName; + size: CalloutSize; +}; + +export const Callout = ({className, color, icon, size}: CalloutProps) => { + return ( + +
+    <Tailwind>
+      <div
+        className={twMerge(
+          'flex items-center justify-center rounded-full',
+          size === 'x-small' ? 'h-24 w-24' : size === 'small' ? 'h-32 w-32' : size === 'medium' ? 'h-40 w-40' : 'h-48 w-48',
+          getTagColorClass(color),
+          className
+        )}>
+        <Icon name={icon} />
+      </div>
+    </Tailwind>
+ ); +}; diff --git a/weave-js/src/components/Callout/index.ts b/weave-js/src/components/Callout/index.ts new file mode 100644 index 00000000000..1034de14c13 --- /dev/null +++ b/weave-js/src/components/Callout/index.ts @@ -0,0 +1,2 @@ +export * from './Callout'; +export * from './types'; diff --git a/weave-js/src/components/Callout/types.ts b/weave-js/src/components/Callout/types.ts new file mode 100644 index 00000000000..c9d555fd1b0 --- /dev/null +++ b/weave-js/src/components/Callout/types.ts @@ -0,0 +1,7 @@ +export const CalloutSizes = { + XSmall: 'x-small', + Small: 'small', + Medium: 'medium', + Large: 'large', +} as const; +export type CalloutSize = (typeof CalloutSizes)[keyof typeof CalloutSizes]; diff --git a/weave-js/src/components/FancyPage/useProjectSidebar.ts b/weave-js/src/components/FancyPage/useProjectSidebar.ts index cdd5aa1e246..6fcc50703f9 100644 --- a/weave-js/src/components/FancyPage/useProjectSidebar.ts +++ b/weave-js/src/components/FancyPage/useProjectSidebar.ts @@ -9,15 +9,24 @@ export const useProjectSidebar = ( viewingRestricted: boolean, hasModelsData: boolean, hasWeaveData: boolean, - hasTraceBackend: boolean = true + hasTraceBackend: boolean = true, + hasModelsAccess: boolean = true ): FancyPageSidebarItem[] => { // Should show models sidebar items if we have models data or if we don't have a trace backend - const showModelsSidebarItems = hasModelsData || !hasTraceBackend; + let showModelsSidebarItems = hasModelsData || !hasTraceBackend; // Should show weave sidebar items if we have weave data and we have a trace backend - const showWeaveSidebarItems = hasWeaveData && hasTraceBackend; + let showWeaveSidebarItems = hasWeaveData && hasTraceBackend; - const isModelsOnly = showModelsSidebarItems && !showWeaveSidebarItems; - const isWeaveOnly = !showModelsSidebarItems && showWeaveSidebarItems; + let isModelsOnly = showModelsSidebarItems && !showWeaveSidebarItems; + let isWeaveOnly = !showModelsSidebarItems && showWeaveSidebarItems; + + if (!hasModelsAccess) { + showModelsSidebarItems = false; + isModelsOnly = false; + + showWeaveSidebarItems = true; + isWeaveOnly = true; + } const isNoSidebarItems = !showModelsSidebarItems && !showWeaveSidebarItems; const isBothSidebarItems = showModelsSidebarItems && showWeaveSidebarItems; diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/filters/CellFilterWrapper.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/filters/CellFilterWrapper.tsx index afbb66e8b39..2ac8de0debc 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/filters/CellFilterWrapper.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/filters/CellFilterWrapper.tsx @@ -5,12 +5,19 @@ import React from 'react'; +export type OnAddFilter = ( + field: string, + operator: string | null, + value: any, + rowId: string +) => void; type CellFilterWrapperProps = { children: React.ReactNode; - onAddFilter?: (key: string, operation: string | null, value: any) => void; + onAddFilter?: OnAddFilter; field: string; operation: string | null; value: any; + rowId: string; style?: React.CSSProperties; }; @@ -20,6 +27,7 @@ export const CellFilterWrapper = ({ field, operation, value, + rowId, style, }: CellFilterWrapperProps) => { const onClickCapture = onAddFilter @@ -28,7 +36,7 @@ export const CellFilterWrapper = ({ if (e.altKey) { e.stopPropagation(); e.preventDefault(); - onAddFilter(field, operation, value); + onAddFilter(field, operation, value, rowId); } } : undefined; diff --git 
diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallPage/CallDetails.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallPage/CallDetails.tsx
index b9e4201e410..8a9c011b5de 100644
--- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallPage/CallDetails.tsx
+++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallPage/CallDetails.tsx
@@ -71,6 +71,14 @@ export const CallSchemaLink = ({call}: {call: CallSchema}) => {
   );
 };
+const ALLOWED_COLUMN_PATTERNS = [
+  'op_name',
+  'status',
+  'inputs.*',
+  'output',
+  'output.*',
+];
+
 export const CallDetails: FC<{
   call: CallSchema;
 }> = ({call}) => {
@@ -176,6 +184,7 @@ export const CallDetails: FC<{
       }}
       entity={call.entity}
       project={call.project}
+      allowedColumnPatterns={ALLOWED_COLUMN_PATTERNS}
     />
   );
   if (isPeeking) {
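The matcher that consumes `allowedColumnPatterns` lives in the table component, not in this hunk; the intended semantics (exact names plus a `.*` suffix wildcard) can be sketched as:

```ts
// Same list as in CallDetails.tsx above.
const ALLOWED_COLUMN_PATTERNS = ['op_name', 'status', 'inputs.*', 'output', 'output.*'];

// Hypothetical matcher illustrating the intended semantics; the real check
// belongs to the component that receives `allowedColumnPatterns`.
const matchesPattern = (column: string, pattern: string): boolean =>
  pattern.endsWith('.*')
    ? column.startsWith(pattern.slice(0, -1)) // keep the trailing '.'
    : column === pattern;

const isColumnAllowed = (column: string): boolean =>
  ALLOWED_COLUMN_PATTERNS.some(p => matchesPattern(column, p));

// isColumnAllowed('inputs.model')  -> true  (matches 'inputs.*')
// isColumnAllowed('attributes.os') -> false (no pattern matches)
```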
diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallPage/ObjectViewer.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallPage/ObjectViewer.tsx
index a3b999296a8..a21e47e3917 100644
--- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallPage/ObjectViewer.tsx
+++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallPage/ObjectViewer.tsx
@@ -6,6 +6,7 @@ import {
   GridRowHeightParams,
   GridRowId,
 } from '@mui/x-data-grid-pro';
+import {Button} from '@wandb/weave/components/Button';
 import _ from 'lodash';
 import React, {
   Dispatch,
@@ -38,7 +39,13 @@ import {
   getKnownImageDictContexts,
   isKnownImageDictFormat,
 } from './objectViewerUtilities';
-import {mapObject, ObjectPath, traverse, TraverseContext} from './traverse';
+import {
+  getValueType,
+  mapObject,
+  ObjectPath,
+  traverse,
+  TraverseContext,
+} from './traverse';
 import {ValueView} from './ValueView';

 type Data = Record<string, any> | any[];
@@ -64,8 +71,13 @@ const getRefs = (data: Data): string[] => {
 type RefValues = Record<string, any>; // ref URI to value
+type TruncatedStore = {[key: string]: {values: any; index: number}};
+
 const RESOVLED_REF_KEY = '_ref';
+export const ARRAY_TRUNCATION_LENGTH = 50;
+const TRUNCATION_KEY = '__weave_array_truncated__';
+
 // This is a general purpose object viewer that can be used to view any object.
 export const ObjectViewer = ({
   apiRef,
@@ -76,8 +88,13 @@
 }: ObjectViewerProps) => {
   const {useRefsData} = useWFHooks();

+  // `truncatedData` holds the data with all arrays truncated to ARRAY_TRUNCATION_LENGTH, unless we have specifically added more rows to the array
+  // `truncatedStore` is used to store the additional rows that we can add to the array when the user clicks "Show more"
+  const {truncatedData, truncatedStore, setTruncatedData, setTruncatedStore} =
+    useTruncatedData(data);
+
   // `resolvedData` holds ref-resolved data.
-  const [resolvedData, setResolvedData] = useState<Data>(data);
+  const [resolvedData, setResolvedData] = useState<Data>(truncatedData);
   // `dataRefs` are the refs contained in the data, filtered to only include expandable refs.
   const dataRefs = useMemo(() => getRefs(data).filter(isExpandableRef), [data]);
@@ -158,7 +175,7 @@
       }
       refValues[r] = val;
     }
-    let resolved = data;
+    let resolved = truncatedData;
     let dirty = true;
     const mapper = (context: TraverseContext) => {
       if (
@@ -177,7 +194,7 @@
       resolved = mapObject(resolved, mapper);
     }
     setResolvedData(resolved);
-  }, [data, refs, refsData.loading, refsData.result]);
+  }, [data, refs, refsData.loading, refsData.result, truncatedData]);

   // `rows` are the data-grid friendly rows that we will render. This method traverses
   // the data, hiding certain keys and adding loader rows for expandable refs.
@@ -297,6 +314,19 @@
       display: 'flex',
       sortable: false,
       renderCell: ({row}) => {
+        const isTruncated = row?.value?.[TRUNCATION_KEY];
+        const parentPath = row?.parent?.path?.toString() ?? '';
+        if (isTruncated && truncatedStore[parentPath]) {
+          return (
+            <ShowMoreButtons
+              parentPath={parentPath}
+              truncatedData={truncatedData}
+              truncatedStore={truncatedStore}
+              setTruncatedData={setTruncatedData}
+              setTruncatedStore={setTruncatedStore}
+            />
+          );
+        }
         if (row.isCode) {
           return (
@@ ... @@
       renderCell: params => {
         const refToExpand = params.row.value;
+        const isTruncated = params.row?.value?.[TRUNCATION_KEY];
+        if (isTruncated) {
+          return null;
+        }
+
         return (
@@ ... @@
+const traverseAndTruncate = (data: Data) => {
+  const result = getValueType(data) === 'array' ? [] : {};
+
+  const store: TruncatedStore = {};
+  traverse(data, (context: TraverseContext) => {
+    let value = context.value;
+
+    // Truncates the value if it is an array and the length is greater than ARRAY_TRUNCATION_LENGTH
+    if (Array.isArray(value) && value.length > ARRAY_TRUNCATION_LENGTH) {
+      // Stores the truncated values in the store
+      store[context.path.toString()] = {
+        values: value.slice(ARRAY_TRUNCATION_LENGTH),
+        index: ARRAY_TRUNCATION_LENGTH,
+      };
+
+      // Truncates and sets the value to ARRAY_TRUNCATION_LENGTH
+      value = [
+        ...value.slice(0, ARRAY_TRUNCATION_LENGTH),
+        {
+          [TRUNCATION_KEY]: true,
+        },
+      ];
+      context.value = value;
+    }
+
+    // Passes the value to the result
+    if (context.depth === 0) {
+      // For the root object, we just want to assign the value to the result
+      if (Array.isArray(result)) {
+        result.push(...value);
+      } else {
+        Object.assign(result, value);
+      }
+    } else {
+      // For all other objects, we want to assign the value to the result
+      context.path.set(result, value);
+    }
+  });
+  return {store, result};
+};
+
+// This function updates the truncatedData from the truncatedStore, adding more data to the truncatedData array, based on the parentID and truncatedCount
+const updateTruncatedDataFromStore = (
+  key: string,
+  truncatedData: Data,
+  truncatedStore: TruncatedStore,
+  truncatedCount: number = ARRAY_TRUNCATION_LENGTH
+) => {
+  const store = {
+    ...truncatedStore,
+  };
+
+  const newData = mapObject(truncatedData, (context: TraverseContext) => {
+    // If the path is the key, we need to show more data
+    if (context.path.toString() === key) {
+      const storeValue = truncatedStore[key].values;
+      // Depending on the length of the store value, we either add truncatedCount more, or the rest of the values
+      if (storeValue.length > truncatedCount) {
+        // Remove the truncated indicator
+        context.value.pop();
+        // Add the new values and truncated indicator
+        context.value.push(...storeValue.slice(0, truncatedCount));
+        context.value.push({
+          [TRUNCATION_KEY]: true,
+        });
+        // Update the store
+        store[key] = {
+          values: storeValue.slice(truncatedCount),
+          index: store[key].index + truncatedCount,
+        };
+      } else {
+        // Remove the truncated indicator
+        context.value.pop();
+        // Add the new values
+        context.value.push(...storeValue);
+        // Update the store
+        delete store[key];
+      }
+    }
+    return context.value;
+  });
+  return {newData, store};
+};
+
+const ShowMoreButtons = ({
+  parentPath,
+  truncatedData,
+  truncatedStore,
+  setTruncatedData,
+  setTruncatedStore,
+}: {
+  parentPath: string;
+  truncatedData: Data;
+  truncatedStore: TruncatedStore;
+  setTruncatedData: (data: Data) => void;
+  setTruncatedStore: (store: TruncatedStore) => void;
+}) => {
+  const truncatedCount = truncatedStore[parentPath]?.values.length ?? 0;
+  return (
+    <div>
+      {truncatedCount > ARRAY_TRUNCATION_LENGTH && (
+        <Button
+          variant="ghost"
+          onClick={() => {
+            const {newData, store} = updateTruncatedDataFromStore(
+              parentPath,
+              truncatedData,
+              truncatedStore,
+              truncatedCount
+            );
+            setTruncatedData(newData);
+            setTruncatedStore(store);
+          }}>
+          Show all
+        </Button>
+      )}
+      <Button
+        variant="ghost"
+        onClick={() => {
+          const {newData, store} = updateTruncatedDataFromStore(
+            parentPath,
+            truncatedData,
+            truncatedStore
+          );
+          setTruncatedData(newData);
+          setTruncatedStore(store);
+        }}>
+        Show more
+      </Button>
+    </div>
+  );
+};
+
+const useTruncatedData = (data: Data) => {
+  const [truncatedData, setTruncatedData] = useState<Data>(data);
+  const [truncatedStore, setTruncatedStore] = useState<TruncatedStore>({});
+
+  useEffect(() => {
+    const {store, result} = traverseAndTruncate(data);
+    setTruncatedData(result);
+    setTruncatedStore(store);
+  }, [data]);
+
+  return {truncatedData, truncatedStore, setTruncatedData, setTruncatedStore};
+};
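Taken together, the viewer now renders at most `ARRAY_TRUNCATION_LENGTH` rows per array and parks the overflow in a keyed store until the user asks for more. A self-contained sketch of that bookkeeping on a plain array (names mirror the diff; the real code keys the store by `traverse()` paths):

```ts
const ARRAY_TRUNCATION_LENGTH = 50;
const TRUNCATION_KEY = '__weave_array_truncated__';

type TruncatedEntry = {values: unknown[]; index: number};

// Truncate one array, returning the rows to render plus the parked overflow.
const truncate = (
  arr: unknown[]
): {shown: unknown[]; overflow?: TruncatedEntry} => {
  if (arr.length <= ARRAY_TRUNCATION_LENGTH) {
    return {shown: arr};
  }
  return {
    // First 50 rows plus a sentinel object the grid renders as "Show more".
    shown: [...arr.slice(0, ARRAY_TRUNCATION_LENGTH), {[TRUNCATION_KEY]: true}],
    overflow: {
      values: arr.slice(ARRAY_TRUNCATION_LENGTH),
      index: ARRAY_TRUNCATION_LENGTH,
    },
  };
};

const {shown, overflow} = truncate(Array.from({length: 120}, (_, i) => i));
// shown.length === 51 (50 rows + sentinel); overflow.values.length === 70
```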
diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallPage/ObjectViewerSection.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallPage/ObjectViewerSection.tsx
index de632732b2b..74cdcd2f79f 100644
--- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallPage/ObjectViewerSection.tsx
+++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallPage/ObjectViewerSection.tsx
@@ -19,7 +19,7 @@ import {isCustomWeaveTypePayload} from '../../typeViews/customWeaveType.types';
 import {CustomWeaveTypeDispatcher} from '../../typeViews/CustomWeaveTypeDispatcher';
 import {OBJECT_ATTR_EDGE_NAME} from '../wfReactInterface/constants';
 import {WeaveCHTable, WeaveCHTableSourceRefContext} from './DataTableView';
-import {ObjectViewer} from './ObjectViewer';
+import {ARRAY_TRUNCATION_LENGTH, ObjectViewer} from './ObjectViewer';
 import {getValueType, traverse} from './traverse';
 import {ValueView} from './ValueView';
@@ -156,7 +156,13 @@ const ObjectViewerSectionNonEmpty = ({
       setTreeExpanded(true);
     }
     setMode('expanded');
-    setExpandedIds(getGroupIds());
+    if (getGroupIds().length > ARRAY_TRUNCATION_LENGTH) {
+      setExpandedIds(
+        getGroupIds().slice(0, expandedIds.length + ARRAY_TRUNCATION_LENGTH)
+      );
+    } else {
+      setExpandedIds(getGroupIds());
+    }
   };

   // On first render and when data changes, recompute expansion state
@@ -187,7 +193,7 @@
       <Button
         icon="expand-uncollapse"
         active={mode === 'expanded'}
         onClick={onClickExpanded}
-        tooltip="View expanded"
+        tooltip={`Expand next ${ARRAY_TRUNCATION_LENGTH} rows`}
       />