diff --git a/.github/workflows/check-which-tests-to-run.yaml b/.github/workflows/check-which-tests-to-run.yaml new file mode 100644 index 00000000000..622ef7389f8 --- /dev/null +++ b/.github/workflows/check-which-tests-to-run.yaml @@ -0,0 +1,88 @@ +name: Check which tests to run + +on: + workflow_call: + outputs: + weave_query_tests: + value: ${{ jobs.check.outputs.weave_query_tests }} + weave_js_tests: + value: ${{ jobs.check.outputs.weave_js_tests }} + trace_server_tests: + value: ${{ jobs.check.outputs.trace_server_tests }} + +env: + WEAVE_QUERY_PATHS: 'weave_query/' + WEAVE_JS_PATHS: 'weave-js/' + TRACE_SERVER_PATHS: 'weave/trace_server/' + # Everything else is implicitly trace SDK + +jobs: + check: + runs-on: ubuntu-latest + outputs: + weave_query_tests: ${{ steps.weave_query.outputs.run_tests }} + weave_js_tests: ${{ steps.weave_js.outputs.run_tests }} + trace_server_tests: ${{ steps.trace_server.outputs.run_tests }} + steps: + - uses: actions/checkout@v3 + with: + fetch-depth: 0 + fetch-tags: true + ref: ${{ github.head_ref }} + - name: Get changed files + run: | + # Fetch all branches + git fetch --all + + # Determine the base branch and current commit + if [[ "${{ github.event_name }}" == "pull_request" ]]; then + # For pull requests + BASE_BRANCH="${{ github.base_ref }}" + CURRENT_COMMIT="${{ github.event.pull_request.head.sha }}" + else + # For pushes + BASE_BRANCH=$(git remote show origin | sed -n '/HEAD branch/s/.*: //p') + CURRENT_COMMIT="${{ github.sha }}" + fi + echo "Base branch is $BASE_BRANCH" + + # Find the common ancestor + MERGE_BASE=$(git merge-base origin/$BASE_BRANCH $CURRENT_COMMIT) + + # Get changed files + changed_files=$(git diff --name-only $MERGE_BASE $CURRENT_COMMIT) + echo "Changed files:" + echo "$changed_files" + echo "changed_files<<EOF" >> $GITHUB_ENV + echo "$changed_files" >> $GITHUB_ENV + echo "EOF" >> $GITHUB_ENV + - id: weave_query + name: Weave Query Checks + run: | + for path in ${{ env.WEAVE_QUERY_PATHS }}; do + if echo 
"$changed_files" | grep -q "$path"; then + echo "run_tests=true" >> $GITHUB_OUTPUT + exit 0 + fi + done + echo "run_tests=false" >> $GITHUB_OUTPUT + - id: weave_js + name: Weave JS Checks + run: | + for path in ${{ env.WEAVE_JS_PATHS }}; do + if echo "$changed_files" | grep -q "$path"; then + echo "run_tests=true" >> $GITHUB_OUTPUT + exit 0 + fi + done + echo "run_tests=false" >> $GITHUB_OUTPUT + - id: trace_server + name: Weave Trace Server Checks + run: | + for path in ${{ env.TRACE_SERVER_PATHS }}; do + if echo "$changed_files" | grep -q "$path"; then + echo "run_tests=true" >> $GITHUB_OUTPUT + exit 0 + fi + done + echo "run_tests=false" >> $GITHUB_OUTPUT diff --git a/.github/workflows/notify-wandb-core.yaml b/.github/workflows/notify-wandb-core.yaml index 12b01234a21..3a1fa818d34 100644 --- a/.github/workflows/notify-wandb-core.yaml +++ b/.github/workflows/notify-wandb-core.yaml @@ -6,11 +6,14 @@ name: Notify wandb/core on: push: branches: - - "**" + - '**' workflow_dispatch: jobs: + check-which-tests-to-run: + uses: ./.github/workflows/check-which-tests-to-run.yaml notify-wandb-core: + needs: check-which-tests-to-run runs-on: ubuntu-latest steps: - name: Repository dispatch @@ -19,4 +22,4 @@ jobs: token: ${{ secrets.WANDB_CORE_ACCESS_TOKEN }} repository: wandb/core event-type: weave-package-updated - client-payload: '{"ref_name": "${{ github.ref_name }}", "sha": "${{ github.sha }}"}' + client-payload: '{"ref_name": "${{ github.ref_name }}", "sha": "${{ github.sha }}", "run_weave_js_tests": ${{ needs.check-which-tests-to-run.outputs.weave_js_tests }}, "run_weave_query_tests": ${{ needs.check-which-tests-to-run.outputs.weave_query_tests }}, "run_trace_server_tests": ${{ needs.check-which-tests-to-run.outputs.trace_server_tests }}}' diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml index 59693b6c877..de8d7ac6be3 100644 --- a/.github/workflows/test.yaml +++ b/.github/workflows/test.yaml @@ -12,45 +12,66 @@ on: push: jobs: + 
check-which-tests-to-run: + uses: ./.github/workflows/check-which-tests-to-run.yaml + # ==== Query Service Jobs ==== build-container-query-service: name: Build Legacy (Query Service) test container timeout-minutes: 30 runs-on: [self-hosted, builder] - # runs-on: ubuntu-latest + outputs: + build_needed: ${{ steps.build_check.outputs.build_needed }} env: REGISTRY: us-east4-docker.pkg.dev/weave-support-367421/weave-images + needs: check-which-tests-to-run + + # if: github.ref == 'refs/heads/master' || needs.check-which-tests-to-run.outputs.weave_query_tests steps: - uses: actions/checkout@v3 with: fetch-depth: 2 + - name: Check if build is needed + id: build_check + run: | + if [[ "${{ github.ref }}" == "refs/heads/master" || "${{ needs.check-which-tests-to-run.outputs.weave_query_tests }}" == "true" ]]; then + echo "Build is needed" + echo "build_needed=true" >> $GITHUB_OUTPUT + else + echo "Build is not needed" + echo "build_needed=false" >> $GITHUB_OUTPUT + fi + - name: Login to Docker Hub + if: steps.build_check.outputs.build_needed == 'true' uses: docker/login-action@v2 with: registry: us-east4-docker.pkg.dev username: _json_key password: ${{ secrets.gcp_sa_key }} - # this script is hardcoded to build for linux/amd64 - name: Prune docker cache + if: steps.build_check.outputs.build_needed == 'true' run: docker system prune -f - - name: Build legacy (query sevice) unit test image + + - name: Build legacy (query service) unit test image + if: steps.build_check.outputs.build_needed == 'true' run: python3 weave/docker/docker_build.py build_deps weave-test-python-query-service builder . weave_query/Dockerfile.ci.test test-query-service: name: Legacy (Query Service) Python unit tests timeout-minutes: 15 # do not raise! running longer than this indicates an issue with the tests. fix there. 
needs: + - check-which-tests-to-run - build-container-query-service - # runs-on: [self-hosted, gke-runner] runs-on: ubuntu-latest strategy: fail-fast: false matrix: job_num: [0, 1] # runs-on: ubuntu-latest - container: us-east4-docker.pkg.dev/weave-support-367421/weave-images/weave-test-python-query-service:${{ github.sha }} + container: ${{ needs.build-container-query-service.outputs.build_needed == 'true' && format('us-east4-docker.pkg.dev/weave-support-367421/weave-images/weave-test-python-query-service:{0}', github.sha) || null }} services: wandbservice: image: us-central1-docker.pkg.dev/wandb-production/images/local-testcontainer:master @@ -65,14 +86,25 @@ jobs: - '8083:8083' - '9015:9015' options: --health-cmd "curl --fail http://localhost:8080/healthz || exit 1" --health-interval=5s --health-timeout=3s + outputs: + tests_should_run: ${{ steps.test_check.outputs.tests_should_run }} steps: - # - uses: datadog/agent-github-action@v1.3 - # with: - # api_key: ${{ secrets.DD_API_KEY }} - uses: actions/checkout@v2 + - name: Check if tests should run + id: test_check + run: | + if [[ "${{ github.ref }}" == "refs/heads/master" || "${{ needs.check-which-tests-to-run.outputs.weave_query_tests }}" == "true" ]]; then + echo "Tests should run" + echo "tests_should_run=true" >> $GITHUB_OUTPUT + else + echo "Tests should not run" + echo "tests_should_run=false" >> $GITHUB_OUTPUT + fi - name: Verify wandb server is running + if: steps.test_check.outputs.tests_should_run == 'true' run: curl -s http://wandbservice:8080/healthz - name: Run Legacy (Query Service) Python Unit Tests + if: steps.test_check.outputs.tests_should_run == 'true' env: DD_SERVICE: weave-python DD_ENV: ci @@ -108,16 +140,30 @@ jobs: weavejs-lint-compile: name: WeaveJS Lint and Compile runs-on: ubuntu-latest - # runs-on: [self-hosted, gke-runner] + needs: + - check-which-tests-to-run steps: - uses: actions/checkout@v2 with: fetch-depth: 0 token: ${{ secrets.GITHUB_TOKEN }} + - name: Check if lint and 
compile should run + id: check_run + run: | + if [[ "${{ github.ref }}" == "refs/heads/master" || "${{ needs.check-which-tests-to-run.outputs.weave_js_tests }}" == "true" ]]; then + echo "Lint and compile should run" + echo "should_lint_and_compile=true" >> $GITHUB_OUTPUT + else + echo "Lint and compile should not run" + echo "should_lint_and_compile=false" >> $GITHUB_OUTPUT + fi - uses: actions/setup-node@v1 + if: steps.check_run.outputs.should_lint_and_compile == 'true' with: node-version: '18.x' - - run: | + - name: Run WeaveJS Lint and Compile + if: steps.check_run.outputs.should_lint_and_compile == 'true' + run: | set -e cd weave-js yarn install --frozen-lockfile diff --git a/.github/workflows/wandb-core-updated.yaml b/.github/workflows/wandb-core-updated.yaml index 99e1a635ec4..2224103ae7f 100644 --- a/.github/workflows/wandb-core-updated.yaml +++ b/.github/workflows/wandb-core-updated.yaml @@ -10,26 +10,25 @@ jobs: update-frontend: runs-on: ubuntu-latest steps: - - name: Checkout weave - uses: actions/checkout@v3 - with: - ref: master - fetch-depth: 1 - - name: Checkout wandb/core - uses: actions/checkout@v3 - with: - path: .wandb_core - submodules: false - fetch-depth: 1 - repository: wandb/core - ref: refs/heads/master - token: ${{ secrets.WANDB_CORE_PAT }} - - name: Placeholder - run: git status + - name: Checkout weave + uses: actions/checkout@v3 + with: + ref: master + fetch-depth: 1 + - name: Checkout wandb/core + uses: actions/checkout@v3 + with: + path: .wandb_core + submodules: false + fetch-depth: 1 + repository: wandb/core + ref: refs/heads/master + token: ${{ secrets.WANDB_CORE_PAT }} + - name: Placeholder + run: git status # - name: Create working branch # run: git checkout -b bot/update-frontend # - name: Rebuild frontend - # run: build_frontend.sh + # run: weave_query/scripts/build_frontend.sh # env: # WANDB_CORE: ${{ env.GITHUB_WORKSPACE }}/.wandb_core - \ No newline at end of file diff --git a/.pre-commit-config.yaml 
b/.pre-commit-config.yaml index 265bbce2136..0e66f1c7ede 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -17,10 +17,9 @@ repos: - id: mypy additional_dependencies: [types-pkg-resources==0.1.3, types-all, wandb>=0.15.5] - # You have to exclude in 3 places. 1) here. 2) mypi.ini exclude, 3) follow_imports = skip for each module in mypy.ini - # exclude: (.*pyi$)|(weave/legacy)|(weave/tests) + # Note: You have to update pyproject.toml[tool.mypy] too! args: ['--config-file=pyproject.toml'] - exclude: (.*pyi$)|(weave_query)|(weave/tests)|(weave/trace_server/tests)|(weave/conftest.py)|(weave/trace/tests)|(integration_test)|(weave/tests/trace)|(weave/conftest.py)|(weave/trace_server/tests)|(examples)|(weave/integrations)|(weave/docker) + exclude: (.*pyi$)|(weave_query)|(tests)|(examples)|(weave/docker) # This is legacy Weave when we were building a notebook product - should be removed - repo: local hooks: diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md new file mode 100644 index 00000000000..f33e9840e2f --- /dev/null +++ b/CONTRIBUTING.md @@ -0,0 +1,162 @@ +# Contributing to `weave` + +- [Contributing to `weave`](#contributing-to-weave) + - [Issues and PRs](#issues-and-prs) + - [Conventional Commits](#conventional-commits) + - [Types](#types) + - [Setting up your environment](#setting-up-your-environment) + - [Working with the `weave` package](#working-with-the-weave-package) + - [Linting](#linting) + - [Building the `weave` package](#building-the-weave-package) + - [Testing](#testing) + - [Deprecating features](#deprecating-features) + +## Issues and PRs + +1. Check the [issues](https://github.com/wandb/weave/issues) and [PRs](https://github.com/wandb/weave/pulls) to see if the feature/bug has already been requested/fixed. If not, [open an issue](https://github.com/wandb/weave/issues/new/choose). This helps us keep track of feature requests and bugs! +2. If you are a first-time contributor, welcome! 
To get started, make a fork and point to the main `weave` repo: + ```sh + git clone https://github.com/<your-username>/weave.git + cd weave + git remote add upstream https://github.com/wandb/weave.git + ``` +3. Build! + 1. Keep your fork up to date with the main `weave` repo: + ```sh + git checkout master + git pull upstream master + ``` + 2. Create a branch for your changes: + ```sh + git checkout -b <your-username>/<your-branch> + ``` + 3. Commit changes to your branch and push: + ```sh + git add your_file.py + git commit -m "feat(integrations): Add new integration for <your-package>" + git push origin <your-username>/<your-branch> + ``` + 4. Open a PR! + +### Conventional Commits + +All PR titles should conform to the [Conventional Commits](https://www.conventionalcommits.org/en/v1.0.0/) spec. Conventional Commits is a lightweight convention on top of commit messages. + +**Structure** + +The commit message should be structured as follows: + +```jsx +<type>(<scope>): <short summary> +``` + + + +#### Types + +Only certain types are permitted. + + + +| Type | Name | Description | User-facing? | +| -------- | ---------------- | ----------------------------------------------------------------------------------- | ------------ | +| feat | ✨ Feature | Changes that add new functionality that directly impacts users | Yes | +| fix | 🐛 Fix | Changes that fix existing issues | Yes | +| refactor | 💎 Code Refactor | A code change that neither fixes a bug nor adds a new feature | No | +| docs | 📜 Documentation | Documentation changes only | Maybe | +| style | 💅 Style | Changes that do not affect the meaning of the code (e.g. linting) | Maybe | +| chore | ⚙️ Chores | Changes that do not modify source code (e.g. CI configuration files, build scripts) | No | +| revert | ♻️ Reverts | Reverts a previous commit | Maybe | +| security | 🔒 Security | Security fix/feature | Maybe | + +## Setting up your environment + +We use: + +1. 
[`uv`](https://astral.sh/blog/uv) for package and env management -- follow the [uv guide to bootstrap an environment](https://docs.astral.sh/uv/getting-started/installation/) + ```sh + curl -LsSf https://astral.sh/uv/install.sh | sh + ``` +2. [`nox`](https://nox.thea.codes/en/stable/tutorial.html#installation) for running tests + ```sh + uv tool install nox + ``` + +### Working with the `weave` package + +We recommend installing the package in editable mode: + +```sh +uv pip install -e . +``` + +### Linting + +We use pre-commit. You can install it with: + +```sh +uv tool install pre-commit +``` + +Then run: + +```sh +pre-commit run --hook-stage=pre-push --all-files +``` + +You can also use the `lint` nox target to run linting. + +```sh +nox -e lint +``` + +### Building the `weave` package + +Use `uv`: + +```sh +uv build +``` + +### Testing + +We use pytest and nox to run tests. To see a list of test environments: + +```sh +nox -l +``` + +Then to run a specific environment: + +```sh +nox -e "$TARGET_ENV" # e.g. nox -e "tests-3.12(shard='trace')" +``` + +Tests are split up into shards, which include: + +1. `trace` -- all of the trace SDK tests +2. `trace_server` -- tests for the trace server backend +3. various integrations, like `openai`, `instructor`, etc. -- these envs are isolated to simplify testing + +### Deprecating features + +`weave` is moving quickly, and sometimes features need to be deprecated. + +To deprecate a feature, use the `deprecated` decorator from `weave.trace.util`. This is currently used primarily for renames. 
+ +```python +from weave.trace.util import deprecated + +@deprecated("new_func_name") +def old_func(): + pass +``` diff --git a/MANIFEST.in b/MANIFEST.in deleted file mode 100644 index fc1d7feabc9..00000000000 --- a/MANIFEST.in +++ /dev/null @@ -1,15 +0,0 @@ -include requirements.* -include wb_schema.gql -graft weave/client_context -graft weave/deploy -graft weave/flow -graft weave_query/weave_query/frontend -graft weave/integrations -graft weave/trace -graft weave/trace_server -graft weave/type_serializers -graft weave/trace_server_bindings -graft weave/wandb_interface -global-exclude */__pycache__/* -global-exclude *.pyc -global-exclude */cassettes/* diff --git a/docs/docs/guides/tracking/ops.md b/docs/docs/guides/tracking/ops.md index b83ff69a72a..48ac1e5ff2a 100644 --- a/docs/docs/guides/tracking/ops.md +++ b/docs/docs/guides/tracking/ops.md @@ -42,22 +42,28 @@ If you want to change the data that is logged to weave without modifying the ori `postprocess_output` takes in any value which would normally be returned by the function and returns the transformed output. 
```py -def postprocess_inputs(inputs: dict[str, Any]) -> dict[str, Any]: - return {k:v for k,v in inputs.items() if k != "hide_me"} - -def postprocess_output(output: CustomObject) -> CustomObject: - return CustomObject(x=output.x, secret_password="REDACTED") - +from dataclasses import dataclass +from typing import Any +import weave @dataclass class CustomObject: x: int secret_password: str +def postprocess_inputs(inputs: dict[str, Any]) -> dict[str, Any]: + return {k:v for k,v in inputs.items() if k != "hide_me"} + +def postprocess_output(output: CustomObject) -> CustomObject: + return CustomObject(x=output.x, secret_password="REDACTED") + @weave.op( postprocess_inputs=postprocess_inputs, postprocess_output=postprocess_output, ) def func(a: int, hide_me: str) -> CustomObject: return CustomObject(x=a, secret_password=hide_me) + +weave.init('hide-data-example') # 🐝 +func(a=1, hide_me="password123") ``` diff --git a/examples/agents/explorer.py b/examples/agents/explorer.py index 9a6a3a44015..a92d34695e6 100644 --- a/examples/agents/explorer.py +++ b/examples/agents/explorer.py @@ -1,7 +1,5 @@ import json import subprocess -from rich import print - import weave from weave.flow.agent import Agent, AgentState diff --git a/examples/agents/programmer.py b/examples/agents/programmer.py index 18a42648afa..18bd201968d 100644 --- a/examples/agents/programmer.py +++ b/examples/agents/programmer.py @@ -1,8 +1,8 @@ import subprocess -from rich import print -from rich.console import Console import sys +from rich import print +from rich.console import Console import weave from weave.flow.agent import Agent, AgentState diff --git a/examples/agents/teacher_explorer.py b/examples/agents/teacher_explorer.py index 553fd2609d7..265ec4d57e0 100644 --- a/examples/agents/teacher_explorer.py +++ b/examples/agents/teacher_explorer.py @@ -1,12 +1,14 @@ -from rich.console import Console -import openai import subprocess import textwrap + +import openai +from rich.console import Console + import 
weave -from weave.flow.obj import Object from weave.flow.agent import Agent, AgentState from weave.flow.chat_util import OpenAIStream from weave.flow.console import LogEvents +from weave.flow.obj import Object TEACHER_MESSAGE = """Assistant is a teacher. The teacher uses a direct technique, without motivational fluff, to drive the student to discover itself. diff --git a/examples/cookbooks/weave_litellm_integration_docs.ipynb b/examples/cookbooks/weave_litellm_integration_docs.ipynb index d968f2997ce..a3352a2b262 100644 --- a/examples/cookbooks/weave_litellm_integration_docs.ipynb +++ b/examples/cookbooks/weave_litellm_integration_docs.ipynb @@ -5,7 +5,8 @@ "metadata": {}, "source": [ "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/wandb/weave/blob/master/examples/cookbooks/weave_litellm_integration_docs.ipynb)\n", - "" + "\n", + "\n" ] }, { @@ -27,13 +28,16 @@ "outputs": [], "source": [ "try:\n", - " from google.colab import userdata\n", " import os\n", + "\n", + " from google.colab import userdata\n", + "\n", " os.environ[\"WANDB_API_KEY\"] = userdata.get(\"WANDB_API_KEY\")\n", " os.environ[\"OPENAI_API_KEY\"] = userdata.get(\"OPENAI_API_KEY\")\n", " os.environ[\"ANTHROPIC_API_KEY\"] = userdata.get(\"ANTHROPIC_API_KEY\")\n", "except:\n", " from dotenv import load_dotenv\n", + "\n", " load_dotenv()" ] }, @@ -68,6 +72,7 @@ "outputs": [], "source": [ "import weave\n", + "\n", "weave.init(project)" ] }, @@ -82,7 +87,7 @@ "response = litellm.completion(\n", " model=\"gpt-3.5-turbo\",\n", " messages=[{\"role\": \"user\", \"content\": \"Translate 'Hello, how are you?' to French\"}],\n", - " max_tokens=1024\n", + " max_tokens=1024,\n", ")\n", "print(response.choices[0].message.content)" ] @@ -98,7 +103,7 @@ "response = litellm.completion(\n", " model=\"claude-3-5-sonnet-20240620\",\n", " messages=[{\"role\": \"user\", \"content\": \"Translate 'Hello, how are you?' 
to French\"}],\n", - " max_tokens=1024\n", + " max_tokens=1024,\n", ")\n", "print(response.choices[0].message.content)" ] @@ -115,11 +120,14 @@ "def translate(text: str, target_language: str, model: str) -> str:\n", " response = litellm.completion(\n", " model=model,\n", - " messages=[{\"role\": \"user\", \"content\": f\"Translate '{text}' to {target_language}\"}],\n", - " max_tokens=1024\n", + " messages=[\n", + " {\"role\": \"user\", \"content\": f\"Translate '{text}' to {target_language}\"}\n", + " ],\n", + " max_tokens=1024,\n", " )\n", " return response.choices[0].message.content\n", "\n", + "\n", "print(translate(\"Hello, how are you?\", \"French\", \"gpt-3.5-turbo\"))\n", "print(translate(\"Hello, how are you?\", \"Spanish\", \"claude-3-5-sonnet-20240620\"))" ] @@ -141,14 +149,18 @@ " response = litellm.completion(\n", " model=self.model,\n", " messages=[\n", - " {\"role\": \"system\", \"content\": f\"You are a translator. Translate the given text to {target_language}.\"},\n", - " {\"role\": \"user\", \"content\": text}\n", + " {\n", + " \"role\": \"system\",\n", + " \"content\": f\"You are a translator. 
Translate the given text to {target_language}.\",\n", + " },\n", + " {\"role\": \"user\", \"content\": text},\n", " ],\n", " max_tokens=1024,\n", - " temperature=self.temperature\n", + " temperature=self.temperature,\n", " )\n", " return response.choices[0].message.content\n", "\n", + "\n", "# Create instances with different models\n", "gpt_translator = TranslatorModel(model=\"gpt-3.5-turbo\", temperature=0.3)\n", "claude_translator = TranslatorModel(model=\"claude-3-5-sonnet-20240620\", temperature=0.1)\n", @@ -200,7 +212,7 @@ " \"target_language\": {\n", " \"type\": \"string\",\n", " \"description\": \"The language to translate to\",\n", - " }\n", + " },\n", " },\n", " \"required\": [\"text\", \"target_language\"],\n", " },\n", @@ -210,15 +222,6 @@ "\n", "print(response.choices[0].message.function_call)" ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "Q0ehhh_0mB_2" - }, - "outputs": [], - "source": [] } ], "metadata": { diff --git a/examples/text-extract/evaluate.py b/examples/text-extract/evaluate.py index 5f846d32f37..abb292b198e 100644 --- a/examples/text-extract/evaluate.py +++ b/examples/text-extract/evaluate.py @@ -1,12 +1,11 @@ -from typing import Any - import asyncio import json import os -import re +from typing import Any + import openai -import weave +import weave from weave.flow.scorer import MultiTaskBinaryClassificationF1 diff --git a/examples/text-extract/example_data/Generate AOI.ipynb b/examples/text-extract/example_data/Generate AOI.ipynb index f42c3db7a3d..540ca1abdf1 100644 --- a/examples/text-extract/example_data/Generate AOI.ipynb +++ b/examples/text-extract/example_data/Generate AOI.ipynb @@ -7,9 +7,9 @@ "metadata": {}, "outputs": [], "source": [ - "import weave\n", - "import openai\n", - "from datetime import datetime" + "from datetime import datetime\n", + "\n", + "import openai" ] }, { @@ -21,19 +21,25 @@ "source": [ "def gen_aoi():\n", " messages = [\n", - " {'role': 'system', 'content': 'you are a 
tool used for generating ultra-realistic datasets. you do not use placeholders in your output'},\n", - " {'role': 'user', 'content': 'please generate an articles of incorporation document for a realistic (but not real) company'}\n", + " {\n", + " \"role\": \"system\",\n", + " \"content\": \"you are a tool used for generating ultra-realistic datasets. you do not use placeholders in your output\",\n", + " },\n", + " {\n", + " \"role\": \"user\",\n", + " \"content\": \"please generate an articles of incorporation document for a realistic (but not real) company\",\n", + " },\n", " ]\n", " response = openai.ChatCompletion.create(\n", - " model='gpt-4',\n", - " messages=messages,\n", - " temperature=1.1,\n", - " n=1,\n", - " )\n", - " result = response['choices'][0]['message']['content']\n", + " model=\"gpt-4\",\n", + " messages=messages,\n", + " temperature=1.1,\n", + " n=1,\n", + " )\n", + " result = response[\"choices\"][0][\"message\"][\"content\"]\n", " now = datetime.now()\n", " timestamp_str = now.strftime(\"%Y%m%d_%H%M%S\")\n", - " open(f'example_data/aoi_{timestamp_str}.txt', 'w').write(result)" + " open(f\"example_data/aoi_{timestamp_str}.txt\", \"w\").write(result)" ] }, { diff --git a/examples/weaveflow/get_started.ipynb b/examples/weaveflow/get_started.ipynb index 4b333b9d0a6..c91c12d65b0 100644 --- a/examples/weaveflow/get_started.ipynb +++ b/examples/weaveflow/get_started.ipynb @@ -23,6 +23,7 @@ "# (note: we can eliminate this step import in the future)\n", "\n", "import wandb\n", + "\n", "wandb.login()" ] }, @@ -36,7 +37,7 @@ "# (note: we can eliminate the `weaveflow` import in future)\n", "\n", "import weave\n", - "from weave import weaveflow\n", + "\n", "client = weave.init(\"wf_eval\")" ] }, @@ -47,12 +48,16 @@ "outputs": [], "source": [ "# Authenticate with OpenAI\n", - "from getpass import getpass\n", "import os\n", + "from getpass import getpass\n", "\n", "if os.getenv(\"OPENAI_API_KEY\") is None:\n", - " os.environ[\"OPENAI_API_KEY\"] = 
getpass(\"Paste your OpenAI key from: https://platform.openai.com/account/api-keys\\n\")\n", - "assert os.getenv(\"OPENAI_API_KEY\", \"\").startswith(\"sk-\"), \"This doesn't look like a valid OpenAI API key\"\n", + " os.environ[\"OPENAI_API_KEY\"] = getpass(\n", + " \"Paste your OpenAI key from: https://platform.openai.com/account/api-keys\\n\"\n", + " )\n", + "assert os.getenv(\"OPENAI_API_KEY\", \"\").startswith(\n", + " \"sk-\"\n", + "), \"This doesn't look like a valid OpenAI API key\"\n", "print(\"OpenAI API key configured\")" ] }, @@ -63,86 +68,98 @@ "outputs": [], "source": [ "import json\n", + "\n", "from openai import OpenAI\n", "\n", + "\n", "@weave.op()\n", "def simple_openai_complete(message: str, model: str, system_prompt: str) -> str:\n", - " completion = OpenAI().chat.completions.create(\n", - " model=model,\n", - " messages=[\n", - " {\"role\": \"system\", \"content\": system_prompt},\n", - " {\"role\": \"user\", \"content\": message}\n", - " ]\n", - " )\n", + " completion = OpenAI().chat.completions.create(\n", + " model=model,\n", + " messages=[\n", + " {\"role\": \"system\", \"content\": system_prompt},\n", + " {\"role\": \"user\", \"content\": message},\n", + " ],\n", + " )\n", + "\n", + " return completion.choices[0].message.content\n", "\n", - " return completion.choices[0].message.content\n", "\n", "@weave.op()\n", "def lookup_docs(user_message: str) -> list[str]:\n", - " docs = simple_openai_complete(\n", - " f\"Query: {user_message}\",\n", - " \"gpt-3.5-turbo\",\n", - " \"Please act like a vector database, returning up to 3 documents that relate to a query. Feel free to be creative, make up names & details, etc... - as if you have access to relevant documents containing critical information. Do not include docuemnt titles, just content. 
You must format your response as a JSON.load-able array of strings.\"\n", - " )\n", - " try:\n", - " json_res = json.loads(docs)\n", - " res = []\n", - " for item in json_res:\n", - " res.append(\"\" + item)\n", - " except Exception as e:\n", - " print(e)\n", - " return []\n", - "\n", - " return res\n", + " docs = simple_openai_complete(\n", + " f\"Query: {user_message}\",\n", + " \"gpt-3.5-turbo\",\n", + " \"Please act like a vector database, returning up to 3 documents that relate to a query. Feel free to be creative, make up names & details, etc... - as if you have access to relevant documents containing critical information. Do not include docuemnt titles, just content. You must format your response as a JSON.load-able array of strings.\",\n", + " )\n", + " try:\n", + " json_res = json.loads(docs)\n", + " res = []\n", + " for item in json_res:\n", + " res.append(\"\" + item)\n", + " except Exception as e:\n", + " print(e)\n", + " return []\n", + "\n", + " return res\n", + "\n", "\n", "@weave.type()\n", "class GptRagModel:\n", - " base_model: str\n", - " system_prompt: str\n", - "\n", - " @weave.op()\n", - " def predict(self, input: str) -> str:\n", - " docs = lookup_docs(input)\n", - " prompt = f\"Given the following documents, please formulate a short, consise answer.\\n Documents: {','.join(docs)}.\\n\\n Query: {input}.\"\n", - " res = simple_openai_complete(prompt, self.base_model, self.system_prompt)\n", - " return res\n", + " base_model: str\n", + " system_prompt: str\n", + "\n", + " @weave.op()\n", + " def predict(self, input: str) -> str:\n", + " docs = lookup_docs(input)\n", + " prompt = f\"Given the following documents, please formulate a short, consise answer.\\n Documents: {','.join(docs)}.\\n\\n Query: {input}.\"\n", + " res = simple_openai_complete(prompt, self.base_model, self.system_prompt)\n", + " return res\n", + "\n", "\n", "@weave.type()\n", "class SimpleDataset:\n", - " examples: list[str]\n", + " examples: list[str]\n", + "\n", "\n", 
"@weave.op()\n", - "def brevity_score(example:str, prediction:str) -> float:\n", + "def brevity_score(example: str, prediction: str) -> float:\n", " # returns a dict of scores\n", " return 1 / (1 + len(prediction))\n", "\n", + "\n", "@weave.op()\n", - "def relevance_score(example:str, prediction:str) -> int:\n", + "def relevance_score(example: str, prediction: str) -> int:\n", " # returns a dict of scores\n", - " return int(simple_openai_complete(\n", - " f\"Prompt:{example}.\\nAnswer:{prediction}.\\nScore:\",\n", - " \"gpt-3.5-turbo\",\n", - " \"Score relevance of the output to the input. Emit ONLY a number between 0 and 9 inclusive. Nothing else\"\n", - " ))\n", + " return int(\n", + " simple_openai_complete(\n", + " f\"Prompt:{example}.\\nAnswer:{prediction}.\\nScore:\",\n", + " \"gpt-3.5-turbo\",\n", + " \"Score relevance of the output to the input. Emit ONLY a number between 0 and 9 inclusive. Nothing else\",\n", + " )\n", + " )\n", + "\n", "\n", "@weave.op()\n", - "def score(example:str, prediction:str) -> dict:\n", + "def score(example: str, prediction: str) -> dict:\n", " # returns a dict of scores\n", " return {\n", - " 'brevity': brevity_score(example, prediction),\n", - " 'relevance': relevance_score(example, prediction),\n", + " \"brevity\": brevity_score(example, prediction),\n", + " \"relevance\": relevance_score(example, prediction),\n", " }\n", "\n", + "\n", "@weave.op()\n", "def eval_iter(model: GptRagModel, example: str) -> dict:\n", " return score(example, model.predict(example))\n", "\n", + "\n", "@weave.op()\n", "def aggregate_scores(score_dicts) -> float:\n", - " return sum([\n", - " score_dict['brevity'] * score_dict['relevance']\n", - " for score_dict in score_dicts\n", - " ]) / len(score_dicts)\n", + " return sum(\n", + " [score_dict[\"brevity\"] * score_dict[\"relevance\"] for score_dict in score_dicts]\n", + " ) / len(score_dicts)\n", + "\n", "\n", "@weave.op()\n", "def evaluate(model: GptRagModel, dataset: SimpleDataset):\n", @@ -163,25 
+180,20 @@ "models = [\n", " GptRagModel(\"gpt-3.5-turbo\", \"You are a helpful assistant.\"),\n", " GptRagModel(\"gpt-3.5-turbo\", \"You are a very bored, sarcastic assistant.\"),\n", - " GptRagModel(\"gpt-3.5-turbo\", \"You are a evil, unhelpful assistant.\")\n", + " GptRagModel(\"gpt-3.5-turbo\", \"You are a evil, unhelpful assistant.\"),\n", "]\n", "\n", - "dataset = SimpleDataset([\n", - " \"What is the square root of pi?\",\n", - " \"Describe inception, not the movie, the concept.\",\n", - " \"How much is Apple worth today?\"\n", - "])\n", + "dataset = SimpleDataset(\n", + " [\n", + " \"What is the square root of pi?\",\n", + " \"Describe inception, not the movie, the concept.\",\n", + " \"How much is Apple worth today?\",\n", + " ]\n", + ")\n", "\n", "for model in models:\n", - " evaluate(model, dataset)" + " evaluate(model, dataset)" ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] } ], "metadata": { diff --git a/noxfile.py b/noxfile.py index a4e34dd72eb..913afed6e3c 100644 --- a/noxfile.py +++ b/noxfile.py @@ -33,7 +33,7 @@ def lint(session): ) def tests(session, shard): session.install("-e", f".[{shard},test]") - session.chdir("weave") + session.chdir("tests") env = { k: session.env.get(k) @@ -48,7 +48,7 @@ def tests(session, shard): default_test_dirs = [f"integrations/{shard}/"] test_dirs_dict = { - "trace": ["tests/trace/", "trace/"], + "trace": ["trace/"], "trace_server": ["trace_server/"], "mistral0": ["integrations/mistral/v0/"], "mistral1": ["integrations/mistral/v1/"], diff --git a/pyproject.toml b/pyproject.toml index e6b7617c8d2..9c0cb00d680 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -2,13 +2,17 @@ name = "weave" description = "A toolkit for building composable interactive data driven applications." 
readme = "README.md" -license = { text = "Apache-2.0" } +license = { file = "LICENSE" } maintainers = [{ name = "W&B", email = "support@wandb.com" }] authors = [ { name = "Shawn Lewis", email = "shawn@wandb.com" }, - { name = "Danny Goldstein", email = "danny@wandb.com" }, { name = "Tim Sweeney", email = "tim@wandb.com" }, { name = "Nick Peneranda", email = "nick.penaranda@wandb.com" }, + { name = "Jeff Raubitschek", email = "jeff@wandb.com" }, + { name = "Jamie Rasmussen", email = "jamie.rasmussen@wandb.com" }, + { name = "Griffin Tarpenning", email = "griffin.tarpenning@wandb.com" }, + { name = "Josiah Lee", email = "josiah.lee@wandb.com" }, + { name = "Andrew Truong", email = "andrew@wandb.com" }, ] classifiers = [ "Development Status :: 4 - Beta", @@ -38,8 +42,8 @@ dependencies = [ "tenacity>=8.3.0,!=8.4.0", # Excluding 8.4.0 because it had a bug on import of AsyncRetrying "emoji>=2.12.1", # For emoji shortcode support in Feedback "uuid-utils>=0.9.0", # Used for ID generation - remove once python's built-in uuid supports UUIDv7 - "numpy>1.21.0", - "rich", + "numpy>1.21.0", # Used in box.py (should be made optional) + "rich", # Used for special formatting of tables (should be made optional) # dependencies for remaining legacy code. Remove when possible "httpx", @@ -83,7 +87,6 @@ test = [ # Integration Tests "pytest-recording==0.13.1", "vcrpy==6.0.1", - "semver==2.13.0", # Used for switching logic based on package versions (should be replaced with packaging!) 
# serving tests "flask", @@ -152,47 +155,24 @@ ignore = [ "D213", # https://docs.astral.sh/ruff/rules/multi-line-summary-second-line/ "D215", # https://docs.astral.sh/ruff/rules/section-underline-not-over-indented/ ] -exclude = [ - "weave/api.py", - "weave/__init__.py", - "weave/legacy/**/*.py", - "examples", - "weave_query", -] +exclude = ["weave_query"] [tool.ruff.lint.isort] known-third-party = ["wandb", "weave_query"] [tool.ruff.lint.per-file-ignores] -"weave/tests/*" = ["F401"] -"weave/weave_server.py" = ["F401"] +"tests/*" = ["F401"] [tool.ruff] line-length = 88 show-fixes = true -exclude = [ - "weave/query_api.py", - "weave/__init__.py", - "weave/legacy/**/*.py", - "examples", - "weave_query", -] +exclude = ["weave_query"] [tool.mypy] warn_unused_configs = true -exclude = [ - ".*pyi$", - "weave_query", - "weave/tests", - "weave/trace_server/tests", - "weave/conftest.py", - "weave/trace/tests", - "integration_test", - "weave/trace_server/tests", - "examples", - "weave/docker", -] +# Note: You have to update .pre-commit-config.yaml too! 
+exclude = [".*pyi$", "weave_query", "tests", "examples", "weave/docker"] ignore_missing_imports = true [[tool.mypy.overrides]] diff --git a/scripts/build_frontend.sh b/scripts/build_frontend.sh deleted file mode 100755 index e82be6470da..00000000000 --- a/scripts/build_frontend.sh +++ /dev/null @@ -1,2 +0,0 @@ -cd weave_query/weave_query/frontend -sh build.sh diff --git a/weave/tests/__init__.py b/tests/__init__.py similarity index 100% rename from weave/tests/__init__.py rename to tests/__init__.py diff --git a/weave/conftest.py b/tests/conftest.py similarity index 97% rename from weave/conftest.py rename to tests/conftest.py index bbd0e0cd01f..1341093f25d 100644 --- a/weave/conftest.py +++ b/tests/conftest.py @@ -8,7 +8,8 @@ from fastapi.testclient import TestClient import weave -from weave.trace import weave_init +from tests.trace.util import DummyTestException +from weave.trace import autopatch, weave_init from weave.trace.client_context import context_state from weave.trace_server import ( clickhouse_trace_server_batched, @@ -17,9 +18,8 @@ from weave.trace_server import trace_server_interface as tsi from weave.trace_server_bindings import remote_http_trace_server -from .tests.trace.trace_server_clickhouse_conftest import * -from .tests.wandb_system_tests_conftest import * -from .trace import autopatch +from .trace.trace_server_clickhouse_conftest import * +from .wandb_system_tests_conftest import * # Force testing to never report wandb sentry events os.environ["WANDB_ERROR_REPORTING"] = "false" @@ -72,7 +72,7 @@ def get_error_logs(self): # is causing it and need to ship this PR, so I am just going to filter it out # for now. 
and not record.msg.startswith( - "Job failed with exception: 400 Client Error: Bad Request for url: https://trace.wandb.ai/" + "Task failed: HTTPError: 400 Client Error: Bad Request for url: https://trace.wandb.ai/" ) # Exclude legacy and not record.name.startswith("weave.weave_server") @@ -315,10 +315,6 @@ def post(url, data=None, json=None, **kwargs): weave.trace_server.requests.post = orig_post -class DummyTestException(Exception): - pass - - class ThrowingServer(tsi.TraceServerInterface): # Call API def call_start(self, req: tsi.CallStartReq) -> tsi.CallStartRes: diff --git a/weave/integrations/anthropic/anthropic_test.py b/tests/integrations/anthropic/anthropic_test.py similarity index 100% rename from weave/integrations/anthropic/anthropic_test.py rename to tests/integrations/anthropic/anthropic_test.py diff --git a/weave/integrations/anthropic/cassettes/anthropic_test/test_anthropic.yaml b/tests/integrations/anthropic/cassettes/anthropic_test/test_anthropic.yaml similarity index 100% rename from weave/integrations/anthropic/cassettes/anthropic_test/test_anthropic.yaml rename to tests/integrations/anthropic/cassettes/anthropic_test/test_anthropic.yaml diff --git a/weave/integrations/anthropic/cassettes/anthropic_test/test_anthropic_messages_stream_ctx_manager.yaml b/tests/integrations/anthropic/cassettes/anthropic_test/test_anthropic_messages_stream_ctx_manager.yaml similarity index 100% rename from weave/integrations/anthropic/cassettes/anthropic_test/test_anthropic_messages_stream_ctx_manager.yaml rename to tests/integrations/anthropic/cassettes/anthropic_test/test_anthropic_messages_stream_ctx_manager.yaml diff --git a/weave/integrations/anthropic/cassettes/anthropic_test/test_anthropic_messages_stream_ctx_manager_text.yaml b/tests/integrations/anthropic/cassettes/anthropic_test/test_anthropic_messages_stream_ctx_manager_text.yaml similarity index 100% rename from 
weave/integrations/anthropic/cassettes/anthropic_test/test_anthropic_messages_stream_ctx_manager_text.yaml rename to tests/integrations/anthropic/cassettes/anthropic_test/test_anthropic_messages_stream_ctx_manager_text.yaml diff --git a/weave/integrations/anthropic/cassettes/anthropic_test/test_anthropic_stream.yaml b/tests/integrations/anthropic/cassettes/anthropic_test/test_anthropic_stream.yaml similarity index 100% rename from weave/integrations/anthropic/cassettes/anthropic_test/test_anthropic_stream.yaml rename to tests/integrations/anthropic/cassettes/anthropic_test/test_anthropic_stream.yaml diff --git a/weave/integrations/anthropic/cassettes/anthropic_test/test_async_anthropic.yaml b/tests/integrations/anthropic/cassettes/anthropic_test/test_async_anthropic.yaml similarity index 100% rename from weave/integrations/anthropic/cassettes/anthropic_test/test_async_anthropic.yaml rename to tests/integrations/anthropic/cassettes/anthropic_test/test_async_anthropic.yaml diff --git a/weave/integrations/anthropic/cassettes/anthropic_test/test_async_anthropic_messages_stream_ctx_manager.yaml b/tests/integrations/anthropic/cassettes/anthropic_test/test_async_anthropic_messages_stream_ctx_manager.yaml similarity index 100% rename from weave/integrations/anthropic/cassettes/anthropic_test/test_async_anthropic_messages_stream_ctx_manager.yaml rename to tests/integrations/anthropic/cassettes/anthropic_test/test_async_anthropic_messages_stream_ctx_manager.yaml diff --git a/weave/integrations/anthropic/cassettes/anthropic_test/test_async_anthropic_messages_stream_ctx_manager_text.yaml b/tests/integrations/anthropic/cassettes/anthropic_test/test_async_anthropic_messages_stream_ctx_manager_text.yaml similarity index 100% rename from weave/integrations/anthropic/cassettes/anthropic_test/test_async_anthropic_messages_stream_ctx_manager_text.yaml rename to tests/integrations/anthropic/cassettes/anthropic_test/test_async_anthropic_messages_stream_ctx_manager_text.yaml diff --git 
a/weave/integrations/anthropic/cassettes/anthropic_test/test_async_anthropic_stream.yaml b/tests/integrations/anthropic/cassettes/anthropic_test/test_async_anthropic_stream.yaml similarity index 100% rename from weave/integrations/anthropic/cassettes/anthropic_test/test_async_anthropic_stream.yaml rename to tests/integrations/anthropic/cassettes/anthropic_test/test_async_anthropic_stream.yaml diff --git a/weave/integrations/anthropic/cassettes/anthropic_test/test_tools_calling.yaml b/tests/integrations/anthropic/cassettes/anthropic_test/test_tools_calling.yaml similarity index 100% rename from weave/integrations/anthropic/cassettes/anthropic_test/test_tools_calling.yaml rename to tests/integrations/anthropic/cassettes/anthropic_test/test_tools_calling.yaml diff --git a/weave/integrations/cerebras/cassettes/cerebras_test/test_cerebras_async.yaml b/tests/integrations/cerebras/cassettes/cerebras_test/test_cerebras_async.yaml similarity index 100% rename from weave/integrations/cerebras/cassettes/cerebras_test/test_cerebras_async.yaml rename to tests/integrations/cerebras/cassettes/cerebras_test/test_cerebras_async.yaml diff --git a/weave/integrations/cerebras/cassettes/cerebras_test/test_cerebras_sync.yaml b/tests/integrations/cerebras/cassettes/cerebras_test/test_cerebras_sync.yaml similarity index 100% rename from weave/integrations/cerebras/cassettes/cerebras_test/test_cerebras_sync.yaml rename to tests/integrations/cerebras/cassettes/cerebras_test/test_cerebras_sync.yaml diff --git a/weave/integrations/cerebras/cerebras_test.py b/tests/integrations/cerebras/cerebras_test.py similarity index 100% rename from weave/integrations/cerebras/cerebras_test.py rename to tests/integrations/cerebras/cerebras_test.py diff --git a/weave/integrations/cohere/cassettes/cohere_test/test_cohere.yaml b/tests/integrations/cohere/cassettes/cohere_test/test_cohere.yaml similarity index 100% rename from weave/integrations/cohere/cassettes/cohere_test/test_cohere.yaml rename to 
tests/integrations/cohere/cassettes/cohere_test/test_cohere.yaml diff --git a/weave/integrations/cohere/cassettes/cohere_test/test_cohere_async.yaml b/tests/integrations/cohere/cassettes/cohere_test/test_cohere_async.yaml similarity index 100% rename from weave/integrations/cohere/cassettes/cohere_test/test_cohere_async.yaml rename to tests/integrations/cohere/cassettes/cohere_test/test_cohere_async.yaml diff --git a/weave/integrations/cohere/cassettes/cohere_test/test_cohere_async_stream.yaml b/tests/integrations/cohere/cassettes/cohere_test/test_cohere_async_stream.yaml similarity index 100% rename from weave/integrations/cohere/cassettes/cohere_test/test_cohere_async_stream.yaml rename to tests/integrations/cohere/cassettes/cohere_test/test_cohere_async_stream.yaml diff --git a/weave/integrations/cohere/cassettes/cohere_test/test_cohere_async_stream_v2.yaml b/tests/integrations/cohere/cassettes/cohere_test/test_cohere_async_stream_v2.yaml similarity index 100% rename from weave/integrations/cohere/cassettes/cohere_test/test_cohere_async_stream_v2.yaml rename to tests/integrations/cohere/cassettes/cohere_test/test_cohere_async_stream_v2.yaml diff --git a/weave/integrations/cohere/cassettes/cohere_test/test_cohere_async_v2.yaml b/tests/integrations/cohere/cassettes/cohere_test/test_cohere_async_v2.yaml similarity index 100% rename from weave/integrations/cohere/cassettes/cohere_test/test_cohere_async_v2.yaml rename to tests/integrations/cohere/cassettes/cohere_test/test_cohere_async_v2.yaml diff --git a/weave/integrations/cohere/cassettes/cohere_test/test_cohere_stream.yaml b/tests/integrations/cohere/cassettes/cohere_test/test_cohere_stream.yaml similarity index 100% rename from weave/integrations/cohere/cassettes/cohere_test/test_cohere_stream.yaml rename to tests/integrations/cohere/cassettes/cohere_test/test_cohere_stream.yaml diff --git a/weave/integrations/cohere/cassettes/cohere_test/test_cohere_stream_v2.yaml 
b/tests/integrations/cohere/cassettes/cohere_test/test_cohere_stream_v2.yaml similarity index 100% rename from weave/integrations/cohere/cassettes/cohere_test/test_cohere_stream_v2.yaml rename to tests/integrations/cohere/cassettes/cohere_test/test_cohere_stream_v2.yaml diff --git a/weave/integrations/cohere/cassettes/cohere_test/test_cohere_v2.yaml b/tests/integrations/cohere/cassettes/cohere_test/test_cohere_v2.yaml similarity index 100% rename from weave/integrations/cohere/cassettes/cohere_test/test_cohere_v2.yaml rename to tests/integrations/cohere/cassettes/cohere_test/test_cohere_v2.yaml diff --git a/weave/integrations/cohere/cohere_test.py b/tests/integrations/cohere/cohere_test.py similarity index 98% rename from weave/integrations/cohere/cohere_test.py rename to tests/integrations/cohere/cohere_test.py index ba8baf9f416..db7e3b1d5f8 100644 --- a/weave/integrations/cohere/cohere_test.py +++ b/tests/integrations/cohere/cohere_test.py @@ -4,7 +4,7 @@ import pytest import weave -from weave.integrations.integration_utilities import _get_call_output, op_name_from_ref +from weave.integrations.integration_utilities import op_name_from_ref cohere_model = "command" # You can change this to a specific model if needed @@ -256,7 +256,7 @@ def test_cohere_v2( assert call.exception is None and call.ended_at is not None assert call.started_at < call.ended_at assert op_name_from_ref(call.op_name) == "cohere.ClientV2.chat" - output = _get_call_output(call) + output = call.output assert output.message.content[0].text == response.message.content[0].text assert output.id == response.id @@ -303,7 +303,7 @@ async def test_cohere_async_v2( assert call.exception is None and call.ended_at is not None assert call.started_at < call.ended_at assert op_name_from_ref(call.op_name) == "cohere.AsyncClientV2.chat" - output = _get_call_output(call) + output = call.output assert output.message.content[0].text == response.message.content[0].text assert output.id == response.id @@ -361,7 
+361,7 @@ def test_cohere_stream_v2( assert call.exception is None and call.ended_at is not None assert call.started_at < call.ended_at assert op_name_from_ref(call.op_name) == "cohere.ClientV2.chat_stream" - output = _get_call_output(call) + output = call.output assert output.message.content[0] == all_content assert output.id == id @@ -411,7 +411,7 @@ async def test_cohere_async_stream_v2( assert call.exception is None and call.ended_at is not None assert call.started_at < call.ended_at assert op_name_from_ref(call.op_name) == "cohere.AsyncClientV2.chat_stream" - output = _get_call_output(call) + output = call.output assert output.message.content[0] == all_content assert output.id == id diff --git a/weave/integrations/dspy/cassettes/dspy_test/test_dspy_inline_signatures.yaml b/tests/integrations/dspy/cassettes/dspy_test/test_dspy_inline_signatures.yaml similarity index 100% rename from weave/integrations/dspy/cassettes/dspy_test/test_dspy_inline_signatures.yaml rename to tests/integrations/dspy/cassettes/dspy_test/test_dspy_inline_signatures.yaml diff --git a/weave/integrations/dspy/cassettes/dspy_test/test_dspy_language_models.yaml b/tests/integrations/dspy/cassettes/dspy_test/test_dspy_language_models.yaml similarity index 100% rename from weave/integrations/dspy/cassettes/dspy_test/test_dspy_language_models.yaml rename to tests/integrations/dspy/cassettes/dspy_test/test_dspy_language_models.yaml diff --git a/weave/integrations/dspy/dspy_test.py b/tests/integrations/dspy/dspy_test.py similarity index 100% rename from weave/integrations/dspy/dspy_test.py rename to tests/integrations/dspy/dspy_test.py diff --git a/weave/integrations/groq/cassettes/groq_test/test_groq_async_chat_completion.yaml b/tests/integrations/groq/cassettes/groq_test/test_groq_async_chat_completion.yaml similarity index 100% rename from weave/integrations/groq/cassettes/groq_test/test_groq_async_chat_completion.yaml rename to 
tests/integrations/groq/cassettes/groq_test/test_groq_async_chat_completion.yaml diff --git a/weave/integrations/groq/cassettes/groq_test/test_groq_async_streaming_chat_completion.yaml b/tests/integrations/groq/cassettes/groq_test/test_groq_async_streaming_chat_completion.yaml similarity index 100% rename from weave/integrations/groq/cassettes/groq_test/test_groq_async_streaming_chat_completion.yaml rename to tests/integrations/groq/cassettes/groq_test/test_groq_async_streaming_chat_completion.yaml diff --git a/weave/integrations/groq/cassettes/groq_test/test_groq_quickstart.yaml b/tests/integrations/groq/cassettes/groq_test/test_groq_quickstart.yaml similarity index 100% rename from weave/integrations/groq/cassettes/groq_test/test_groq_quickstart.yaml rename to tests/integrations/groq/cassettes/groq_test/test_groq_quickstart.yaml diff --git a/weave/integrations/groq/cassettes/groq_test/test_groq_streaming_chat_completion.yaml b/tests/integrations/groq/cassettes/groq_test/test_groq_streaming_chat_completion.yaml similarity index 100% rename from weave/integrations/groq/cassettes/groq_test/test_groq_streaming_chat_completion.yaml rename to tests/integrations/groq/cassettes/groq_test/test_groq_streaming_chat_completion.yaml diff --git a/weave/integrations/groq/cassettes/groq_test/test_groq_tool_call.yaml b/tests/integrations/groq/cassettes/groq_test/test_groq_tool_call.yaml similarity index 100% rename from weave/integrations/groq/cassettes/groq_test/test_groq_tool_call.yaml rename to tests/integrations/groq/cassettes/groq_test/test_groq_tool_call.yaml diff --git a/weave/integrations/groq/groq_test.py b/tests/integrations/groq/groq_test.py similarity index 100% rename from weave/integrations/groq/groq_test.py rename to tests/integrations/groq/groq_test.py diff --git a/weave/integrations/instructor/cassettes/instructor_test/test_instructor_iterable.yaml b/tests/integrations/instructor/cassettes/instructor_test/test_instructor_iterable.yaml similarity index 100% rename 
from weave/integrations/instructor/cassettes/instructor_test/test_instructor_iterable.yaml rename to tests/integrations/instructor/cassettes/instructor_test/test_instructor_iterable.yaml diff --git a/weave/integrations/instructor/cassettes/instructor_test/test_instructor_iterable_async_stream.yaml b/tests/integrations/instructor/cassettes/instructor_test/test_instructor_iterable_async_stream.yaml similarity index 100% rename from weave/integrations/instructor/cassettes/instructor_test/test_instructor_iterable_async_stream.yaml rename to tests/integrations/instructor/cassettes/instructor_test/test_instructor_iterable_async_stream.yaml diff --git a/weave/integrations/instructor/cassettes/instructor_test/test_instructor_iterable_sync_stream.yaml b/tests/integrations/instructor/cassettes/instructor_test/test_instructor_iterable_sync_stream.yaml similarity index 100% rename from weave/integrations/instructor/cassettes/instructor_test/test_instructor_iterable_sync_stream.yaml rename to tests/integrations/instructor/cassettes/instructor_test/test_instructor_iterable_sync_stream.yaml diff --git a/weave/integrations/instructor/cassettes/instructor_test/test_instructor_openai.yaml b/tests/integrations/instructor/cassettes/instructor_test/test_instructor_openai.yaml similarity index 100% rename from weave/integrations/instructor/cassettes/instructor_test/test_instructor_openai.yaml rename to tests/integrations/instructor/cassettes/instructor_test/test_instructor_openai.yaml diff --git a/weave/integrations/instructor/cassettes/instructor_test/test_instructor_openai_async.yaml b/tests/integrations/instructor/cassettes/instructor_test/test_instructor_openai_async.yaml similarity index 100% rename from weave/integrations/instructor/cassettes/instructor_test/test_instructor_openai_async.yaml rename to tests/integrations/instructor/cassettes/instructor_test/test_instructor_openai_async.yaml diff --git 
a/weave/integrations/instructor/cassettes/instructor_test/test_instructor_partial_stream.yaml b/tests/integrations/instructor/cassettes/instructor_test/test_instructor_partial_stream.yaml similarity index 100% rename from weave/integrations/instructor/cassettes/instructor_test/test_instructor_partial_stream.yaml rename to tests/integrations/instructor/cassettes/instructor_test/test_instructor_partial_stream.yaml diff --git a/weave/integrations/instructor/cassettes/instructor_test/test_instructor_partial_stream_async.yaml b/tests/integrations/instructor/cassettes/instructor_test/test_instructor_partial_stream_async.yaml similarity index 100% rename from weave/integrations/instructor/cassettes/instructor_test/test_instructor_partial_stream_async.yaml rename to tests/integrations/instructor/cassettes/instructor_test/test_instructor_partial_stream_async.yaml diff --git a/weave/integrations/instructor/instructor_test.py b/tests/integrations/instructor/instructor_test.py similarity index 100% rename from weave/integrations/instructor/instructor_test.py rename to tests/integrations/instructor/instructor_test.py diff --git a/weave/integrations/integration_utilities_test.py b/tests/integrations/integration_utilities_test.py similarity index 86% rename from weave/integrations/integration_utilities_test.py rename to tests/integrations/integration_utilities_test.py index c5feea58ac8..2171b25d50e 100644 --- a/weave/integrations/integration_utilities_test.py +++ b/tests/integrations/integration_utilities_test.py @@ -1,15 +1,15 @@ import pytest -from weave.integrations.integration_utilities import truncate_op_name +from weave.integrations.integration_utilities import ( + _make_string_of_length, + _truncated_str, + truncate_op_name, +) MAX_RUN_NAME_LENGTH = 128 NON_HASH_LIMIT = 5 -def _make_string_of_length(n: int) -> str: - return "a" * n - - def test_truncate_op_name_less_than_limit() -> None: name = _make_string_of_length(MAX_RUN_NAME_LENGTH - 1) trunc = truncate_op_name(name) 
@@ -22,15 +22,6 @@ def test_truncate_op_name_at_limit() -> None: assert trunc == name -def _truncated_str(tail_len: int, total_len: int) -> tuple: - name = ( - _make_string_of_length(total_len - tail_len - 1) - + "." - + _make_string_of_length(tail_len) - ) - return name, truncate_op_name(name) - - def test_truncate_op_name_too_short_for_hash() -> None: # Remove 1 character for a range of tail lengths: chars_to_remove = 1 diff --git a/weave/integrations/langchain/cassettes/langchain_test/test_agent_run_with_function_call.yaml b/tests/integrations/langchain/cassettes/langchain_test/test_agent_run_with_function_call.yaml similarity index 100% rename from weave/integrations/langchain/cassettes/langchain_test/test_agent_run_with_function_call.yaml rename to tests/integrations/langchain/cassettes/langchain_test/test_agent_run_with_function_call.yaml diff --git a/weave/integrations/langchain/cassettes/langchain_test/test_agent_run_with_tools.yaml b/tests/integrations/langchain/cassettes/langchain_test/test_agent_run_with_tools.yaml similarity index 100% rename from weave/integrations/langchain/cassettes/langchain_test/test_agent_run_with_tools.yaml rename to tests/integrations/langchain/cassettes/langchain_test/test_agent_run_with_tools.yaml diff --git a/weave/integrations/langchain/cassettes/langchain_test/test_simple_chain_abatch.yaml b/tests/integrations/langchain/cassettes/langchain_test/test_simple_chain_abatch.yaml similarity index 100% rename from weave/integrations/langchain/cassettes/langchain_test/test_simple_chain_abatch.yaml rename to tests/integrations/langchain/cassettes/langchain_test/test_simple_chain_abatch.yaml diff --git a/weave/integrations/langchain/cassettes/langchain_test/test_simple_chain_ainvoke.yaml b/tests/integrations/langchain/cassettes/langchain_test/test_simple_chain_ainvoke.yaml similarity index 100% rename from weave/integrations/langchain/cassettes/langchain_test/test_simple_chain_ainvoke.yaml rename to 
tests/integrations/langchain/cassettes/langchain_test/test_simple_chain_ainvoke.yaml diff --git a/weave/integrations/langchain/cassettes/langchain_test/test_simple_chain_astream.yaml b/tests/integrations/langchain/cassettes/langchain_test/test_simple_chain_astream.yaml similarity index 100% rename from weave/integrations/langchain/cassettes/langchain_test/test_simple_chain_astream.yaml rename to tests/integrations/langchain/cassettes/langchain_test/test_simple_chain_astream.yaml diff --git a/weave/integrations/langchain/cassettes/langchain_test/test_simple_chain_batch.yaml b/tests/integrations/langchain/cassettes/langchain_test/test_simple_chain_batch.yaml similarity index 100% rename from weave/integrations/langchain/cassettes/langchain_test/test_simple_chain_batch.yaml rename to tests/integrations/langchain/cassettes/langchain_test/test_simple_chain_batch.yaml diff --git a/weave/integrations/langchain/cassettes/langchain_test/test_simple_chain_batch_inside_op.yaml b/tests/integrations/langchain/cassettes/langchain_test/test_simple_chain_batch_inside_op.yaml similarity index 100% rename from weave/integrations/langchain/cassettes/langchain_test/test_simple_chain_batch_inside_op.yaml rename to tests/integrations/langchain/cassettes/langchain_test/test_simple_chain_batch_inside_op.yaml diff --git a/weave/integrations/langchain/cassettes/langchain_test/test_simple_chain_invoke.yaml b/tests/integrations/langchain/cassettes/langchain_test/test_simple_chain_invoke.yaml similarity index 100% rename from weave/integrations/langchain/cassettes/langchain_test/test_simple_chain_invoke.yaml rename to tests/integrations/langchain/cassettes/langchain_test/test_simple_chain_invoke.yaml diff --git a/weave/integrations/langchain/cassettes/langchain_test/test_simple_chain_stream.yaml b/tests/integrations/langchain/cassettes/langchain_test/test_simple_chain_stream.yaml similarity index 100% rename from 
weave/integrations/langchain/cassettes/langchain_test/test_simple_chain_stream.yaml rename to tests/integrations/langchain/cassettes/langchain_test/test_simple_chain_stream.yaml diff --git a/weave/integrations/langchain/cassettes/langchain_test/test_simple_rag_chain.yaml b/tests/integrations/langchain/cassettes/langchain_test/test_simple_rag_chain.yaml similarity index 100% rename from weave/integrations/langchain/cassettes/langchain_test/test_simple_rag_chain.yaml rename to tests/integrations/langchain/cassettes/langchain_test/test_simple_rag_chain.yaml diff --git a/weave/integrations/langchain/langchain_test.py b/tests/integrations/langchain/langchain_test.py similarity index 95% rename from weave/integrations/langchain/langchain_test.py rename to tests/integrations/langchain/langchain_test.py index e063a0c4999..0121a640b48 100644 --- a/weave/integrations/langchain/langchain_test.py +++ b/tests/integrations/langchain/langchain_test.py @@ -304,13 +304,29 @@ def assert_correct_calls_for_rag_chain(calls: list[Call]) -> None: ("langchain.Chain.RunnableSequence", 2), ("langchain.Retriever.VectorStoreRetriever", 3), ("langchain.Chain.format_docs", 3), - ("langchain.Chain.RunnablePassthrough", 2), + ("langchain.Chain.RunnablePassthrough", 2), # Potential position ("langchain.Prompt.ChatPromptTemplate", 1), ("langchain.Llm.ChatOpenAI", 1), ("openai.chat.completions.create", 2), ("langchain.Parser.StrOutputParser", 1), ] - assert got == exp + # exp_2 is an alternative to exp with a different order of operations. + # The `RunnableParallel_context_question` executes its children in parallel, + # allowing for variation in the order of execution. As a result, + # `RunnablePassthrough` may appear in one of two possible positions. 
+ exp_2 = [ + ("langchain.Chain.RunnableSequence", 0), + ("langchain.Chain.RunnableParallel_context_question", 1), + ("langchain.Chain.RunnablePassthrough", 2), # Potential position + ("langchain.Chain.RunnableSequence", 2), + ("langchain.Retriever.VectorStoreRetriever", 3), + ("langchain.Chain.format_docs", 3), + ("langchain.Prompt.ChatPromptTemplate", 1), + ("langchain.Llm.ChatOpenAI", 1), + ("openai.chat.completions.create", 2), + ("langchain.Parser.StrOutputParser", 1), + ] + assert (got == exp) or (got == exp_2) @pytest.fixture diff --git a/weave/integrations/langchain/test_data/paul_graham_essay.txt b/tests/integrations/langchain/test_data/paul_graham_essay.txt similarity index 100% rename from weave/integrations/langchain/test_data/paul_graham_essay.txt rename to tests/integrations/langchain/test_data/paul_graham_essay.txt diff --git a/weave/integrations/litellm/cassettes/litellm_test/test_litellm_quickstart.yaml b/tests/integrations/litellm/cassettes/litellm_test/test_litellm_quickstart.yaml similarity index 100% rename from weave/integrations/litellm/cassettes/litellm_test/test_litellm_quickstart.yaml rename to tests/integrations/litellm/cassettes/litellm_test/test_litellm_quickstart.yaml diff --git a/weave/integrations/litellm/cassettes/litellm_test/test_litellm_quickstart_async.yaml b/tests/integrations/litellm/cassettes/litellm_test/test_litellm_quickstart_async.yaml similarity index 100% rename from weave/integrations/litellm/cassettes/litellm_test/test_litellm_quickstart_async.yaml rename to tests/integrations/litellm/cassettes/litellm_test/test_litellm_quickstart_async.yaml diff --git a/weave/integrations/litellm/cassettes/litellm_test/test_litellm_quickstart_stream.yaml b/tests/integrations/litellm/cassettes/litellm_test/test_litellm_quickstart_stream.yaml similarity index 100% rename from weave/integrations/litellm/cassettes/litellm_test/test_litellm_quickstart_stream.yaml rename to 
tests/integrations/litellm/cassettes/litellm_test/test_litellm_quickstart_stream.yaml diff --git a/weave/integrations/litellm/cassettes/litellm_test/test_litellm_quickstart_stream_async.yaml b/tests/integrations/litellm/cassettes/litellm_test/test_litellm_quickstart_stream_async.yaml similarity index 100% rename from weave/integrations/litellm/cassettes/litellm_test/test_litellm_quickstart_stream_async.yaml rename to tests/integrations/litellm/cassettes/litellm_test/test_litellm_quickstart_stream_async.yaml diff --git a/weave/integrations/litellm/cassettes/litellm_test/test_model_predict.yaml b/tests/integrations/litellm/cassettes/litellm_test/test_model_predict.yaml similarity index 100% rename from weave/integrations/litellm/cassettes/litellm_test/test_model_predict.yaml rename to tests/integrations/litellm/cassettes/litellm_test/test_model_predict.yaml diff --git a/weave/integrations/litellm/litellm_test.py b/tests/integrations/litellm/litellm_test.py similarity index 98% rename from weave/integrations/litellm/litellm_test.py rename to tests/integrations/litellm/litellm_test.py index 24fa4503ea9..965a8ae71df 100644 --- a/weave/integrations/litellm/litellm_test.py +++ b/tests/integrations/litellm/litellm_test.py @@ -4,11 +4,10 @@ import litellm import pytest -import semver +from packaging.version import parse as version_parse import weave - -from .litellm import litellm_patcher +from weave.integrations.litellm.litellm import litellm_patcher # This PR: # https://github.com/BerriAI/litellm/commit/fe2aa706e8ff4edbcd109897e5da6b83ef6ad693 @@ -16,7 +15,7 @@ # We can handle this in non-streaming mode, but in streaming mode, we # have no way of correctly capturing the output and not messing up the # users' code (that i can see). In these cases, model cost is not captured. 
-USES_RAW_OPENAI_RESPONSE = semver.compare(version("litellm"), "1.42.11") > 0 +USES_RAW_OPENAI_RESPONSE = version_parse(version("litellm")) > version_parse("1.42.11") class Nearly: diff --git a/weave/integrations/llamaindex/cassettes/llamaindex_test/test_llamaindex_quickstart.yaml b/tests/integrations/llamaindex/cassettes/llamaindex_test/test_llamaindex_quickstart.yaml similarity index 100% rename from weave/integrations/llamaindex/cassettes/llamaindex_test/test_llamaindex_quickstart.yaml rename to tests/integrations/llamaindex/cassettes/llamaindex_test/test_llamaindex_quickstart.yaml diff --git a/weave/integrations/llamaindex/cassettes/llamaindex_test/test_llamaindex_quickstart_async.yaml b/tests/integrations/llamaindex/cassettes/llamaindex_test/test_llamaindex_quickstart_async.yaml similarity index 100% rename from weave/integrations/llamaindex/cassettes/llamaindex_test/test_llamaindex_quickstart_async.yaml rename to tests/integrations/llamaindex/cassettes/llamaindex_test/test_llamaindex_quickstart_async.yaml diff --git a/weave/integrations/llamaindex/llamaindex_test.py b/tests/integrations/llamaindex/llamaindex_test.py similarity index 100% rename from weave/integrations/llamaindex/llamaindex_test.py rename to tests/integrations/llamaindex/llamaindex_test.py diff --git a/weave/integrations/llamaindex/test_data/paul_graham_essay.txt b/tests/integrations/llamaindex/test_data/paul_graham_essay.txt similarity index 100% rename from weave/integrations/llamaindex/test_data/paul_graham_essay.txt rename to tests/integrations/llamaindex/test_data/paul_graham_essay.txt diff --git a/weave/integrations/mistral/v0/cassettes/mistral_test/test_mistral_quickstart.yaml b/tests/integrations/mistral/v0/cassettes/mistral_test/test_mistral_quickstart.yaml similarity index 100% rename from weave/integrations/mistral/v0/cassettes/mistral_test/test_mistral_quickstart.yaml rename to tests/integrations/mistral/v0/cassettes/mistral_test/test_mistral_quickstart.yaml diff --git 
a/weave/integrations/mistral/v0/cassettes/mistral_test/test_mistral_quickstart_async.yaml b/tests/integrations/mistral/v0/cassettes/mistral_test/test_mistral_quickstart_async.yaml similarity index 100% rename from weave/integrations/mistral/v0/cassettes/mistral_test/test_mistral_quickstart_async.yaml rename to tests/integrations/mistral/v0/cassettes/mistral_test/test_mistral_quickstart_async.yaml diff --git a/weave/integrations/mistral/v0/cassettes/mistral_test/test_mistral_quickstart_with_stream.yaml b/tests/integrations/mistral/v0/cassettes/mistral_test/test_mistral_quickstart_with_stream.yaml similarity index 100% rename from weave/integrations/mistral/v0/cassettes/mistral_test/test_mistral_quickstart_with_stream.yaml rename to tests/integrations/mistral/v0/cassettes/mistral_test/test_mistral_quickstart_with_stream.yaml diff --git a/weave/integrations/mistral/v0/cassettes/mistral_test/test_mistral_quickstart_with_stream_async.yaml b/tests/integrations/mistral/v0/cassettes/mistral_test/test_mistral_quickstart_with_stream_async.yaml similarity index 100% rename from weave/integrations/mistral/v0/cassettes/mistral_test/test_mistral_quickstart_with_stream_async.yaml rename to tests/integrations/mistral/v0/cassettes/mistral_test/test_mistral_quickstart_with_stream_async.yaml diff --git a/weave/integrations/mistral/v0/mistral_test.py b/tests/integrations/mistral/v0/mistral_test.py similarity index 100% rename from weave/integrations/mistral/v0/mistral_test.py rename to tests/integrations/mistral/v0/mistral_test.py diff --git a/weave/integrations/mistral/v1/cassettes/mistral_test/test_mistral_quickstart.yaml b/tests/integrations/mistral/v1/cassettes/mistral_test/test_mistral_quickstart.yaml similarity index 100% rename from weave/integrations/mistral/v1/cassettes/mistral_test/test_mistral_quickstart.yaml rename to tests/integrations/mistral/v1/cassettes/mistral_test/test_mistral_quickstart.yaml diff --git 
a/weave/integrations/mistral/v1/cassettes/mistral_test/test_mistral_quickstart_async.yaml b/tests/integrations/mistral/v1/cassettes/mistral_test/test_mistral_quickstart_async.yaml similarity index 100% rename from weave/integrations/mistral/v1/cassettes/mistral_test/test_mistral_quickstart_async.yaml rename to tests/integrations/mistral/v1/cassettes/mistral_test/test_mistral_quickstart_async.yaml diff --git a/weave/integrations/mistral/v1/cassettes/mistral_test/test_mistral_quickstart_with_stream.yaml b/tests/integrations/mistral/v1/cassettes/mistral_test/test_mistral_quickstart_with_stream.yaml similarity index 100% rename from weave/integrations/mistral/v1/cassettes/mistral_test/test_mistral_quickstart_with_stream.yaml rename to tests/integrations/mistral/v1/cassettes/mistral_test/test_mistral_quickstart_with_stream.yaml diff --git a/weave/integrations/mistral/v1/cassettes/mistral_test/test_mistral_quickstart_with_stream_async.yaml b/tests/integrations/mistral/v1/cassettes/mistral_test/test_mistral_quickstart_with_stream_async.yaml similarity index 100% rename from weave/integrations/mistral/v1/cassettes/mistral_test/test_mistral_quickstart_with_stream_async.yaml rename to tests/integrations/mistral/v1/cassettes/mistral_test/test_mistral_quickstart_with_stream_async.yaml diff --git a/weave/integrations/mistral/v1/mistral_test.py b/tests/integrations/mistral/v1/mistral_test.py similarity index 100% rename from weave/integrations/mistral/v1/mistral_test.py rename to tests/integrations/mistral/v1/mistral_test.py diff --git a/weave/integrations/openai/cassettes/openai_test/test_openai_as_context_manager.yaml b/tests/integrations/openai/cassettes/openai_test/test_openai_as_context_manager.yaml similarity index 100% rename from weave/integrations/openai/cassettes/openai_test/test_openai_as_context_manager.yaml rename to tests/integrations/openai/cassettes/openai_test/test_openai_as_context_manager.yaml diff --git 
a/weave/integrations/openai/cassettes/openai_test/test_openai_as_context_manager_async.yaml b/tests/integrations/openai/cassettes/openai_test/test_openai_as_context_manager_async.yaml similarity index 100% rename from weave/integrations/openai/cassettes/openai_test/test_openai_as_context_manager_async.yaml rename to tests/integrations/openai/cassettes/openai_test/test_openai_as_context_manager_async.yaml diff --git a/weave/integrations/openai/cassettes/openai_test/test_openai_async_quickstart.yaml b/tests/integrations/openai/cassettes/openai_test/test_openai_async_quickstart.yaml similarity index 100% rename from weave/integrations/openai/cassettes/openai_test/test_openai_async_quickstart.yaml rename to tests/integrations/openai/cassettes/openai_test/test_openai_async_quickstart.yaml diff --git a/weave/integrations/openai/cassettes/openai_test/test_openai_async_stream_quickstart.yaml b/tests/integrations/openai/cassettes/openai_test/test_openai_async_stream_quickstart.yaml similarity index 100% rename from weave/integrations/openai/cassettes/openai_test/test_openai_async_stream_quickstart.yaml rename to tests/integrations/openai/cassettes/openai_test/test_openai_async_stream_quickstart.yaml diff --git a/weave/integrations/openai/cassettes/openai_test/test_openai_function_call.yaml b/tests/integrations/openai/cassettes/openai_test/test_openai_function_call.yaml similarity index 100% rename from weave/integrations/openai/cassettes/openai_test/test_openai_function_call.yaml rename to tests/integrations/openai/cassettes/openai_test/test_openai_function_call.yaml diff --git a/weave/integrations/openai/cassettes/openai_test/test_openai_function_call_async.yaml b/tests/integrations/openai/cassettes/openai_test/test_openai_function_call_async.yaml similarity index 100% rename from weave/integrations/openai/cassettes/openai_test/test_openai_function_call_async.yaml rename to tests/integrations/openai/cassettes/openai_test/test_openai_function_call_async.yaml diff --git 
a/weave/integrations/openai/cassettes/openai_test/test_openai_function_call_async_stream.yaml b/tests/integrations/openai/cassettes/openai_test/test_openai_function_call_async_stream.yaml similarity index 100% rename from weave/integrations/openai/cassettes/openai_test/test_openai_function_call_async_stream.yaml rename to tests/integrations/openai/cassettes/openai_test/test_openai_function_call_async_stream.yaml diff --git a/weave/integrations/openai/cassettes/openai_test/test_openai_quickstart.yaml b/tests/integrations/openai/cassettes/openai_test/test_openai_quickstart.yaml similarity index 100% rename from weave/integrations/openai/cassettes/openai_test/test_openai_quickstart.yaml rename to tests/integrations/openai/cassettes/openai_test/test_openai_quickstart.yaml diff --git a/weave/integrations/openai/cassettes/openai_test/test_openai_stream_quickstart.yaml b/tests/integrations/openai/cassettes/openai_test/test_openai_stream_quickstart.yaml similarity index 100% rename from weave/integrations/openai/cassettes/openai_test/test_openai_stream_quickstart.yaml rename to tests/integrations/openai/cassettes/openai_test/test_openai_stream_quickstart.yaml diff --git a/weave/integrations/openai/cassettes/openai_test/test_openai_stream_usage_quickstart.yaml b/tests/integrations/openai/cassettes/openai_test/test_openai_stream_usage_quickstart.yaml similarity index 100% rename from weave/integrations/openai/cassettes/openai_test/test_openai_stream_usage_quickstart.yaml rename to tests/integrations/openai/cassettes/openai_test/test_openai_stream_usage_quickstart.yaml diff --git a/weave/integrations/openai/cassettes/openai_test/test_openai_tool_call.yaml b/tests/integrations/openai/cassettes/openai_test/test_openai_tool_call.yaml similarity index 100% rename from weave/integrations/openai/cassettes/openai_test/test_openai_tool_call.yaml rename to tests/integrations/openai/cassettes/openai_test/test_openai_tool_call.yaml diff --git 
a/weave/integrations/openai/cassettes/openai_test/test_openai_tool_call_async.yaml b/tests/integrations/openai/cassettes/openai_test/test_openai_tool_call_async.yaml similarity index 100% rename from weave/integrations/openai/cassettes/openai_test/test_openai_tool_call_async.yaml rename to tests/integrations/openai/cassettes/openai_test/test_openai_tool_call_async.yaml diff --git a/weave/integrations/openai/cassettes/openai_test/test_openai_tool_call_async_stream.yaml b/tests/integrations/openai/cassettes/openai_test/test_openai_tool_call_async_stream.yaml similarity index 100% rename from weave/integrations/openai/cassettes/openai_test/test_openai_tool_call_async_stream.yaml rename to tests/integrations/openai/cassettes/openai_test/test_openai_tool_call_async_stream.yaml diff --git a/weave/integrations/openai/openai_test.py b/tests/integrations/openai/openai_test.py similarity index 100% rename from weave/integrations/openai/openai_test.py rename to tests/integrations/openai/openai_test.py diff --git a/weave/tests/test_patcher_module/__init__.py b/tests/test_patcher_module/__init__.py similarity index 100% rename from weave/tests/test_patcher_module/__init__.py rename to tests/test_patcher_module/__init__.py diff --git a/weave/tests/test_patcher_module/example_class.py b/tests/test_patcher_module/example_class.py similarity index 100% rename from weave/tests/test_patcher_module/example_class.py rename to tests/test_patcher_module/example_class.py diff --git a/weave/trace/concurrent/test_futures.py b/tests/trace/concurrent/test_futures.py similarity index 89% rename from weave/trace/concurrent/test_futures.py rename to tests/trace/concurrent/test_futures.py index 92c085c3636..e8cbb7bdc61 100644 --- a/weave/trace/concurrent/test_futures.py +++ b/tests/trace/concurrent/test_futures.py @@ -17,7 +17,8 @@ def simple_task() -> int: assert future.result() == 42 -def test_defer_with_exception() -> None: +@pytest.mark.disable_logging_error_check +def 
test_defer_with_exception(log_collector) -> None: executor: FutureExecutor = FutureExecutor() def failing_task() -> None: @@ -27,6 +28,10 @@ def failing_task() -> None: with pytest.raises(ValueError, match="Test exception"): future.result() + logs = log_collector.get_error_logs() + assert len(logs) == 1 + assert "ValueError: Test exception" in logs[0].msg + def test_then_single_future() -> None: executor: FutureExecutor = FutureExecutor() @@ -62,7 +67,8 @@ def process_multiple_data(data_list: List[List[int]]) -> int: assert future_result.result() == 15 -def test_then_with_exception_in_future() -> None: +@pytest.mark.disable_logging_error_check +def test_then_with_exception_in_future(log_collector) -> None: executor: FutureExecutor = FutureExecutor() def failing_task() -> None: @@ -77,8 +83,13 @@ def process_data(data_list: List[Any]) -> Any: with pytest.raises(ValueError, match="Future exception"): future_result.result() + logs = log_collector.get_error_logs() + assert len(logs) == 1 + assert "ValueError: Future exception" in logs[0].msg + -def test_then_with_exception_in_callback() -> None: +@pytest.mark.disable_logging_error_check +def test_then_with_exception_in_callback(log_collector) -> None: executor: FutureExecutor = FutureExecutor() def fetch_data() -> List[int]: @@ -93,6 +104,10 @@ def failing_process(data_list: List[List[int]]) -> None: with pytest.raises(ValueError, match="Callback exception"): future_result.result() + logs = log_collector.get_error_logs() + assert len(logs) == 1 + assert "ValueError: Callback exception" in logs[0].msg + def test_concurrent_execution() -> None: executor: FutureExecutor = FutureExecutor() diff --git a/weave/tests/trace/op_versioning_importfrom.py b/tests/trace/op_versioning_importfrom.py similarity index 100% rename from weave/tests/trace/op_versioning_importfrom.py rename to tests/trace/op_versioning_importfrom.py diff --git a/weave/tests/trace/op_versioning_inlineimport.py b/tests/trace/op_versioning_inlineimport.py 
similarity index 100% rename from weave/tests/trace/op_versioning_inlineimport.py rename to tests/trace/op_versioning_inlineimport.py diff --git a/weave/tests/trace/op_versioning_obj.py b/tests/trace/op_versioning_obj.py similarity index 100% rename from weave/tests/trace/op_versioning_obj.py rename to tests/trace/op_versioning_obj.py diff --git a/weave/tests/trace/op_versioning_solo.py b/tests/trace/op_versioning_solo.py similarity index 100% rename from weave/tests/trace/op_versioning_solo.py rename to tests/trace/op_versioning_solo.py diff --git a/weave/tests/trace/test_anonymous_ops.py b/tests/trace/test_anonymous_ops.py similarity index 100% rename from weave/tests/trace/test_anonymous_ops.py rename to tests/trace/test_anonymous_ops.py diff --git a/weave/tests/trace/test_call_behaviours.py b/tests/trace/test_call_behaviours.py similarity index 100% rename from weave/tests/trace/test_call_behaviours.py rename to tests/trace/test_call_behaviours.py diff --git a/weave/tests/trace/test_cli.py b/tests/trace/test_cli.py similarity index 100% rename from weave/tests/trace/test_cli.py rename to tests/trace/test_cli.py diff --git a/weave/tests/trace/test_client_cost.py b/tests/trace/test_client_cost.py similarity index 100% rename from weave/tests/trace/test_client_cost.py rename to tests/trace/test_client_cost.py diff --git a/weave/tests/trace/test_client_feedback.py b/tests/trace/test_client_feedback.py similarity index 100% rename from weave/tests/trace/test_client_feedback.py rename to tests/trace/test_client_feedback.py diff --git a/weave/tests/trace/test_client_trace.py b/tests/trace/test_client_trace.py similarity index 98% rename from weave/tests/trace/test_client_trace.py rename to tests/trace/test_client_trace.py index 6673717910e..28b29e490ee 100644 --- a/weave/tests/trace/test_client_trace.py +++ b/tests/trace/test_client_trace.py @@ -14,16 +14,15 @@ from pydantic import BaseModel, ValidationError import weave -from weave import Thread, ThreadPoolExecutor 
-from weave.tests.trace.util import ( +from tests.trace.util import ( AnyIntMatcher, DatetimeMatcher, FuzzyDateTimeMatcher, MaybeStringMatcher, client_is_sqlite, ) +from weave import Thread, ThreadPoolExecutor from weave.trace import weave_client -from weave.trace.object_record import ObjectRecord from weave.trace.vals import MissingSelfInstanceError from weave.trace.weave_client import sanitize_object_name from weave.trace_server import trace_server_interface as tsi @@ -2664,6 +2663,31 @@ def return_nested_object(nested_obj: NestedObject): assert call_result.output == nested_ref.uri() +# Batch size is dynamically increased from 10 to MAX_CALLS_STREAM_BATCH_SIZE (500) +# in clickhouse_trace_server_batched.py, this test verifies that the dynamic +# increase works as expected +@pytest.mark.parametrize("batch_size", [1, 10, 100, 110]) +def test_calls_stream_column_expansion_dynamic_batch_size(client, batch_size): + @weave.op + def test_op(x): + return x + + for i in range(batch_size): + test_op(i) + + res = client.server.calls_query_stream( + tsi.CallsQueryReq( + project_id=client._project_id(), + columns=["output"], + expand_columns=["output"], + ) + ) + calls = list(res) + assert len(calls) == batch_size + for i in range(batch_size): + assert calls[i].output == i + + class Custom(weave.Object): val: dict @@ -2793,16 +2817,18 @@ def test(obj: Custom): def test_calls_stream_feedback(client): + BATCH_SIZE = 10 + num_calls = BATCH_SIZE + 1 + @weave.op def test_call(x): return "ello chap" - test_call(1) - test_call(2) - test_call(3) + for i in range(num_calls): + test_call(i) calls = list(test_call.calls()) - assert len(calls) == 3 + assert len(calls) == num_calls # add feedback to the first call calls[0].feedback.add("note", {"note": "this is a note on call1"}) @@ -2821,7 +2847,7 @@ def test_call(x): ) calls = list(res) - assert len(calls) == 3 + assert len(calls) == num_calls assert len(calls[0].summary["weave"]["feedback"]) == 4 assert 
len(calls[1].summary["weave"]["feedback"]) == 1 assert not calls[2].summary.get("weave", {}).get("feedback") diff --git a/weave/tests/trace/test_dirty_model_op_retrieval.py b/tests/trace/test_dirty_model_op_retrieval.py similarity index 100% rename from weave/tests/trace/test_dirty_model_op_retrieval.py rename to tests/trace/test_dirty_model_op_retrieval.py diff --git a/weave/tests/trace/test_evaluate.py b/tests/trace/test_evaluate.py similarity index 100% rename from weave/tests/trace/test_evaluate.py rename to tests/trace/test_evaluate.py diff --git a/weave/tests/trace/test_evaluation_performance.py b/tests/trace/test_evaluation_performance.py similarity index 93% rename from weave/tests/trace/test_evaluation_performance.py rename to tests/trace/test_evaluation_performance.py index 5563dbb15c2..909e192c408 100644 --- a/weave/tests/trace/test_evaluation_performance.py +++ b/tests/trace/test_evaluation_performance.py @@ -1,4 +1,3 @@ -import sys from collections import Counter from contextlib import contextmanager from threading import Lock @@ -8,7 +7,7 @@ import pytest import weave -from weave.conftest import DummyTestException +from tests.trace.util import DummyTestException from weave.trace.context import raise_on_captured_errors from weave.trace.weave_client import WeaveClient from weave.trace_server import trace_server_interface as tsi @@ -161,7 +160,10 @@ async def test_evaluation_resilience( client_with_throwing_server._flush() logs = log_collector.get_error_logs() - assert len(logs) == 0 + ag_res = Counter([k.split(", req:")[0] for k in set([l.msg for l in logs])]) + assert len(ag_res) == 2 + assert ag_res["Task failed: DummyTestException: ('FAILURE - obj_create"] <= 2 + assert ag_res["Task failed: DummyTestException: ('FAILURE - file_create"] <= 2 # We should gracefully handle the error and return a value with raise_on_captured_errors(False): @@ -177,7 +179,8 @@ async def test_evaluation_resilience( # For some reason with high parallelism, some logs are not 
captured, # so instead of exact counts, we just check that the number of unique # logs is <= the expected number of logs. - assert len(ag_res) == 3 + assert len(ag_res) == 4 assert ag_res["Job failed during flush: ('FAILURE - call_end"] <= 14 assert ag_res["Job failed during flush: ('FAILURE - obj_create"] <= 6 + assert ag_res["Job failed during flush: ('FAILURE - file_create"] <= 6 assert ag_res["Job failed during flush: ('FAILURE - table_create"] <= 1 diff --git a/weave/tests/trace/test_evaluations.py b/tests/trace/test_evaluations.py similarity index 99% rename from weave/tests/trace/test_evaluations.py rename to tests/trace/test_evaluations.py index 205509c17e9..dbba17f5d88 100644 --- a/weave/tests/trace/test_evaluations.py +++ b/tests/trace/test_evaluations.py @@ -7,8 +7,8 @@ from PIL import Image import weave +from tests.trace.util import AnyIntMatcher from weave import Evaluation, Model -from weave.tests.trace.util import AnyIntMatcher from weave.trace_server import trace_server_interface as tsi diff --git a/weave/tests/trace/test_exec.py b/tests/trace/test_exec.py similarity index 98% rename from weave/tests/trace/test_exec.py rename to tests/trace/test_exec.py index 86f2d276393..efafcb1cef8 100644 --- a/weave/tests/trace/test_exec.py +++ b/tests/trace/test_exec.py @@ -1,9 +1,6 @@ import sys import textwrap -import typing -from typing import Union -import numpy as np import pytest diff --git a/weave/tests/trace/test_feedback.py b/tests/trace/test_feedback.py similarity index 97% rename from weave/tests/trace/test_feedback.py rename to tests/trace/test_feedback.py index 80ae74b8af8..cc05dccd624 100644 --- a/weave/tests/trace/test_feedback.py +++ b/tests/trace/test_feedback.py @@ -1,7 +1,5 @@ import pytest -import weave.trace_server.trace_server_interface as tsi - def test_client_feedback(client) -> None: feedbacks = client.get_feedback() diff --git a/weave/tests/trace/test_objs_query.py b/tests/trace/test_objs_query.py similarity index 100% rename from 
weave/tests/trace/test_objs_query.py rename to tests/trace/test_objs_query.py diff --git a/weave/tests/trace/test_op_argument_forms.py b/tests/trace/test_op_argument_forms.py similarity index 100% rename from weave/tests/trace/test_op_argument_forms.py rename to tests/trace/test_op_argument_forms.py diff --git a/weave/tests/trace/test_op_call_method.py b/tests/trace/test_op_call_method.py similarity index 100% rename from weave/tests/trace/test_op_call_method.py rename to tests/trace/test_op_call_method.py diff --git a/weave/tests/trace/test_op_decorator_behaviour.py b/tests/trace/test_op_decorator_behaviour.py similarity index 99% rename from weave/tests/trace/test_op_decorator_behaviour.py rename to tests/trace/test_op_decorator_behaviour.py index 93b7741900b..6e788fac150 100644 --- a/weave/tests/trace/test_op_decorator_behaviour.py +++ b/tests/trace/test_op_decorator_behaviour.py @@ -1,12 +1,11 @@ import inspect -import sys from typing import Any import pytest import weave from weave.trace import errors -from weave.trace.op import Op, is_op, op +from weave.trace.op import is_op, op from weave.trace.refs import ObjectRef, parse_uri from weave.trace.vals import MissingSelfInstanceError from weave.trace.weave_client import Call diff --git a/weave/tests/trace/test_op_return_forms.py b/tests/trace/test_op_return_forms.py similarity index 100% rename from weave/tests/trace/test_op_return_forms.py rename to tests/trace/test_op_return_forms.py diff --git a/weave/tests/trace/test_op_versioning.py b/tests/trace/test_op_versioning.py similarity index 97% rename from weave/tests/trace/test_op_versioning.py rename to tests/trace/test_op_versioning.py index 13369761b15..b0892063518 100644 --- a/weave/tests/trace/test_op_versioning.py +++ b/tests/trace/test_op_versioning.py @@ -35,7 +35,7 @@ def solo_versioned_op(a: int) -> float: def test_solo_op_versioning(strict_op_saving, client): - from weave.tests.trace import op_versioning_solo + from tests.trace import 
op_versioning_solo ref = weave.publish(op_versioning_solo.solo_versioned_op) @@ -58,7 +58,7 @@ def versioned_op(self, a: int) -> float: def test_object_op_versioning(strict_op_saving, client): - from weave.tests.trace import op_versioning_obj + from tests.trace import op_versioning_obj obj = op_versioning_obj.MyTestObjWithOp(val=5) # Call it to publish @@ -82,7 +82,7 @@ def versioned_op_importfrom(a: int) -> float: def test_op_versioning_importfrom(strict_op_saving, client): - from weave.tests.trace import op_versioning_importfrom + from tests.trace import op_versioning_importfrom ref = weave.publish(op_versioning_importfrom.versioned_op_importfrom) saved_code = get_saved_code(client, ref) @@ -101,7 +101,7 @@ def versioned_op_lotsofstuff(a: int) -> float: def test_op_versioning_inline_import(strict_op_saving, client): - from weave.tests.trace import op_versioning_inlineimport + pass def test_op_versioning_inline_func_decl(strict_op_saving): diff --git a/weave/tests/trace/test_patcher.py b/tests/trace/test_patcher.py similarity index 76% rename from weave/tests/trace/test_patcher.py rename to tests/trace/test_patcher.py index 94797aec94d..dec04d0d971 100644 --- a/weave/tests/trace/test_patcher.py +++ b/tests/trace/test_patcher.py @@ -6,12 +6,10 @@ def test_symbol_patcher(): - from weave.tests.test_patcher_module.example_class import ExampleClass + from tests.test_patcher_module.example_class import ExampleClass patcher = SymbolPatcher( - lambda: importlib.import_module( - "weave.tests.test_patcher_module.example_class" - ), + lambda: importlib.import_module("tests.test_patcher_module.example_class"), "ExampleClass.example_fn", lambda original_fn: lambda self: 43, ) @@ -39,9 +37,7 @@ def test_symbol_patcher_invalid_module(): def test_symbol_patcher_invalid_attr(): patcher = SymbolPatcher( - lambda: importlib.import_module( - "weave.tests.test_patcher_module.example_class" - ), + lambda: importlib.import_module("tests.test_patcher_module.example_class"), 
"NotARealExampleClass.example_fn", lambda original_fn: lambda self: 43, ) @@ -52,12 +48,10 @@ def test_symbol_patcher_invalid_attr(): @pytest.mark.disable_logging_error_check def test_symbol_patcher_invalid_patching(log_collector): - from weave.tests.test_patcher_module.example_class import ExampleClass + from tests.test_patcher_module.example_class import ExampleClass patcher = SymbolPatcher( - lambda: importlib.import_module( - "weave.tests.test_patcher_module.example_class" - ), + lambda: importlib.import_module("tests.test_patcher_module.example_class"), "ExampleClass.example_fn", lambda original_fn: [] + 42, ) diff --git a/weave/tests/trace/test_ref_trace.py b/tests/trace/test_ref_trace.py similarity index 100% rename from weave/tests/trace/test_ref_trace.py rename to tests/trace/test_ref_trace.py diff --git a/weave/tests/trace/test_table_query.py b/tests/trace/test_table_query.py similarity index 100% rename from weave/tests/trace/test_table_query.py rename to tests/trace/test_table_query.py diff --git a/weave/tests/trace/test_trace_server.py b/tests/trace/test_trace_server.py similarity index 100% rename from weave/tests/trace/test_trace_server.py rename to tests/trace/test_trace_server.py diff --git a/weave/tests/trace/test_trace_server_common.py b/tests/trace/test_trace_server_common.py similarity index 99% rename from weave/tests/trace/test_trace_server_common.py rename to tests/trace/test_trace_server_common.py index 52ce1cb2976..9bc7495481f 100644 --- a/weave/tests/trace/test_trace_server_common.py +++ b/tests/trace/test_trace_server_common.py @@ -1,5 +1,3 @@ -import pytest - from weave.trace_server.trace_server_common import ( LRUCache, get_nested_key, diff --git a/weave/tests/trace/test_trace_settings.py b/tests/trace/test_trace_settings.py similarity index 99% rename from weave/tests/trace/test_trace_settings.py rename to tests/trace/test_trace_settings.py index 76e3083e05d..bcbd27aeea8 100644 --- a/weave/tests/trace/test_trace_settings.py +++ 
b/tests/trace/test_trace_settings.py @@ -1,8 +1,6 @@ -import inspect import io import os import sys -import textwrap import time import timeit diff --git a/weave/tests/trace/test_tracing_resilience.py b/tests/trace/test_tracing_resilience.py similarity index 98% rename from weave/tests/trace/test_tracing_resilience.py rename to tests/trace/test_tracing_resilience.py index 8f3c953bfb2..8a982d48418 100644 --- a/weave/tests/trace/test_tracing_resilience.py +++ b/tests/trace/test_tracing_resilience.py @@ -12,7 +12,7 @@ import pytest import weave -from weave.conftest import DummyTestException +from tests.trace.util import DummyTestException from weave.trace import call_context from weave.trace.context import raise_on_captured_errors from weave.trace.op_extensions.accumulator import add_accumulator @@ -68,8 +68,9 @@ def simple_op(): # Tim: This is very specific and intentiaion, please don't change # this unless you are sure that is the expected behavior assert ag_res == { - "Job failed during flush: ('FAILURE - call_end": 1, - "Job failed during flush: ('FAILURE - obj_create": 1, + "Task failed: DummyTestException: ('FAILURE - call_end": 1, + "Task failed: DummyTestException: ('FAILURE - file_create": 1, + "Task failed: DummyTestException: ('FAILURE - obj_create": 1, } diff --git a/weave/tests/trace/test_vals.py b/tests/trace/test_vals.py similarity index 100% rename from weave/tests/trace/test_vals.py rename to tests/trace/test_vals.py diff --git a/weave/tests/trace/test_weave_client.py b/tests/trace/test_weave_client.py similarity index 95% rename from weave/tests/trace/test_weave_client.py rename to tests/trace/test_weave_client.py index 4d8f281e0f8..5d081f8ce63 100644 --- a/weave/tests/trace/test_weave_client.py +++ b/tests/trace/test_weave_client.py @@ -10,17 +10,16 @@ import weave import weave.trace_server.trace_server_interface as tsi -from weave import Evaluation -from weave.tests.trace.testutil import ObjectRefStrMatcher -from weave.tests.trace.util import ( 
+from tests.trace.testutil import ObjectRefStrMatcher +from tests.trace.util import ( AnyIntMatcher, DatetimeMatcher, RegexStringMatcher, - client_is_sqlite, ) +from weave import Evaluation from weave.trace import refs, weave_client from weave.trace.isinstance import weave_isinstance -from weave.trace.op import Op, is_op +from weave.trace.op import is_op from weave.trace.refs import ( DICT_KEY_EDGE_NAME, LIST_INDEX_EDGE_NAME, @@ -28,7 +27,13 @@ TABLE_ROW_ID_EDGE_NAME, ) from weave.trace.serializer import get_serializer_for_obj, register_serializer -from weave.trace_server.sqlite_trace_server import SqliteTraceServer +from weave.trace_server.clickhouse_trace_server_batched import NotFoundError +from weave.trace_server.sqlite_trace_server import ( + NotFoundError as sqliteNotFoundError, +) +from weave.trace_server.sqlite_trace_server import ( + SqliteTraceServer, +) from weave.trace_server.trace_server_interface import ( FileContentReadReq, FileCreateReq, @@ -1437,7 +1442,31 @@ def test_object_version_read(client): assert obj_res.obj.val == {"a": 9} assert obj_res.obj.version_index == 9 - # now grab version 5 + # now grab each by their digests + for i, digest in enumerate([obj.digest for obj in objs]): + obj_res = client.server.obj_read( + tsi.ObjReadReq( + project_id=client._project_id(), + object_id=refs[0].name, + digest=digest, + ) + ) + assert obj_res.obj.val == {"a": i} + assert obj_res.obj.version_index == i + + # publish another, check that latest is updated + client._save_object({"a": 10}, refs[0].name) + obj_res = client.server.obj_read( + tsi.ObjReadReq( + project_id=client._project_id(), + object_id=refs[0].name, + digest="latest", + ) + ) + assert obj_res.obj.val == {"a": 10} + assert obj_res.obj.version_index == 10 + + # check that v5 is still correct obj_res = client.server.obj_read( tsi.ObjReadReq( project_id=client._project_id(), @@ -1447,3 +1476,26 @@ def test_object_version_read(client): ) assert obj_res.obj.val == {"a": 5} assert 
obj_res.obj.version_index == 5 + + # check badly formatted digests + digests = ["v1111", "1", ""] + for digest in digests: + with pytest.raises((NotFoundError, sqliteNotFoundError)): + # grab non-existant version + obj_res = client.server.obj_read( + tsi.ObjReadReq( + project_id=client._project_id(), + object_id=refs[0].name, + digest=digest, + ) + ) + + # check non-existant object_id + with pytest.raises((NotFoundError, sqliteNotFoundError)): + obj_res = client.server.obj_read( + tsi.ObjReadReq( + project_id=client._project_id(), + object_id="refs[0].name", + digest="v1", + ) + ) diff --git a/weave/tests/trace/test_weave_client_mutations.py b/tests/trace/test_weave_client_mutations.py similarity index 100% rename from weave/tests/trace/test_weave_client_mutations.py rename to tests/trace/test_weave_client_mutations.py diff --git a/weave/tests/trace/test_weave_client_threaded.py b/tests/trace/test_weave_client_threaded.py similarity index 100% rename from weave/tests/trace/test_weave_client_threaded.py rename to tests/trace/test_weave_client_threaded.py diff --git a/weave/tests/trace/test_weave_object_vals.py b/tests/trace/test_weave_object_vals.py similarity index 100% rename from weave/tests/trace/test_weave_object_vals.py rename to tests/trace/test_weave_object_vals.py diff --git a/weave/tests/trace/test_weaveflow.py b/tests/trace/test_weaveflow.py similarity index 99% rename from weave/tests/trace/test_weaveflow.py rename to tests/trace/test_weaveflow.py index 71e63c44a70..37490de691a 100644 --- a/weave/tests/trace/test_weaveflow.py +++ b/tests/trace/test_weaveflow.py @@ -1,5 +1,3 @@ -import typing - import numpy as np import pytest from pydantic import Field diff --git a/weave/tests/trace/testutil.py b/tests/trace/testutil.py similarity index 100% rename from weave/tests/trace/testutil.py rename to tests/trace/testutil.py diff --git a/weave/tests/trace/trace_server_clickhouse_conftest.py b/tests/trace/trace_server_clickhouse_conftest.py similarity index 100% 
rename from weave/tests/trace/trace_server_clickhouse_conftest.py rename to tests/trace/trace_server_clickhouse_conftest.py diff --git a/weave/tests/trace/trace_weaveflow.py b/tests/trace/trace_weaveflow.py similarity index 96% rename from weave/tests/trace/trace_weaveflow.py rename to tests/trace/trace_weaveflow.py index 4ba87560697..0addc4f8ff9 100644 --- a/weave/tests/trace/trace_weaveflow.py +++ b/tests/trace/trace_weaveflow.py @@ -1,5 +1,4 @@ import weave -from weave import weaveflow # This is just an experiment script, it is not a real test for now because we dont # have a unit test harness set up that supports what this needs to run successfully. diff --git a/weave/tests/trace/util.py b/tests/trace/util.py similarity index 96% rename from weave/tests/trace/util.py rename to tests/trace/util.py index 69a3de1b6e6..eb4c6002beb 100644 --- a/weave/tests/trace/util.py +++ b/tests/trace/util.py @@ -53,3 +53,7 @@ def __eq__(self, other): class DatetimeMatcher: def __eq__(self, other): return isinstance(other, datetime.datetime) + + +class DummyTestException(Exception): + pass diff --git a/weave/trace_server/tests/test_call_lifecycle.py b/tests/trace_server/test_call_lifecycle.py similarity index 100% rename from weave/trace_server/tests/test_call_lifecycle.py rename to tests/trace_server/test_call_lifecycle.py diff --git a/weave/trace_server/tests/test_calls_query_builder.py b/tests/trace_server/test_calls_query_builder.py similarity index 100% rename from weave/trace_server/tests/test_calls_query_builder.py rename to tests/trace_server/test_calls_query_builder.py diff --git a/weave/trace_server/tests/test_emoji_util.py b/tests/trace_server/test_emoji_util.py similarity index 100% rename from weave/trace_server/tests/test_emoji_util.py rename to tests/trace_server/test_emoji_util.py diff --git a/weave/trace_server/tests/test_orm.py b/tests/trace_server/test_orm.py similarity index 100% rename from weave/trace_server/tests/test_orm.py rename to 
tests/trace_server/test_orm.py diff --git a/weave/trace_server/tests/test_refs.py b/tests/trace_server/test_refs.py similarity index 100% rename from weave/trace_server/tests/test_refs.py rename to tests/trace_server/test_refs.py diff --git a/weave/trace_server/tests/test_remote_http_trace_server.py b/tests/trace_server/test_remote_http_trace_server.py similarity index 100% rename from weave/trace_server/tests/test_remote_http_trace_server.py rename to tests/trace_server/test_remote_http_trace_server.py diff --git a/weave/trace_server/tests/test_validation.py b/tests/trace_server/test_validation.py similarity index 100% rename from weave/trace_server/tests/test_validation.py rename to tests/trace_server/test_validation.py diff --git a/weave/type_serializers/Image/image_test.py b/tests/type_serializers/Image/image_test.py similarity index 100% rename from weave/type_serializers/Image/image_test.py rename to tests/type_serializers/Image/image_test.py diff --git a/weave/tests/wandb_system_tests_conftest.py b/tests/wandb_system_tests_conftest.py similarity index 100% rename from weave/tests/wandb_system_tests_conftest.py rename to tests/wandb_system_tests_conftest.py diff --git a/weave-js/src/components/FancyPage/useProjectSidebar.ts b/weave-js/src/components/FancyPage/useProjectSidebar.ts index cdd5aa1e246..6fcc50703f9 100644 --- a/weave-js/src/components/FancyPage/useProjectSidebar.ts +++ b/weave-js/src/components/FancyPage/useProjectSidebar.ts @@ -9,15 +9,24 @@ export const useProjectSidebar = ( viewingRestricted: boolean, hasModelsData: boolean, hasWeaveData: boolean, - hasTraceBackend: boolean = true + hasTraceBackend: boolean = true, + hasModelsAccess: boolean = true ): FancyPageSidebarItem[] => { // Should show models sidebar items if we have models data or if we don't have a trace backend - const showModelsSidebarItems = hasModelsData || !hasTraceBackend; + let showModelsSidebarItems = hasModelsData || !hasTraceBackend; // Should show weave sidebar items if we 
have weave data and we have a trace backend - const showWeaveSidebarItems = hasWeaveData && hasTraceBackend; + let showWeaveSidebarItems = hasWeaveData && hasTraceBackend; - const isModelsOnly = showModelsSidebarItems && !showWeaveSidebarItems; - const isWeaveOnly = !showModelsSidebarItems && showWeaveSidebarItems; + let isModelsOnly = showModelsSidebarItems && !showWeaveSidebarItems; + let isWeaveOnly = !showModelsSidebarItems && showWeaveSidebarItems; + + if (!hasModelsAccess) { + showModelsSidebarItems = false; + isModelsOnly = false; + + showWeaveSidebarItems = true; + isWeaveOnly = true; + } const isNoSidebarItems = !showModelsSidebarItems && !showWeaveSidebarItems; const isBothSidebarItems = showModelsSidebarItems && showWeaveSidebarItems; diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/filters/CellFilterWrapper.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/filters/CellFilterWrapper.tsx index afbb66e8b39..2ac8de0debc 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/filters/CellFilterWrapper.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/filters/CellFilterWrapper.tsx @@ -5,12 +5,19 @@ import React from 'react'; +export type OnAddFilter = ( + field: string, + operator: string | null, + value: any, + rowId: string +) => void; type CellFilterWrapperProps = { children: React.ReactNode; - onAddFilter?: (key: string, operation: string | null, value: any) => void; + onAddFilter?: OnAddFilter; field: string; operation: string | null; value: any; + rowId: string; style?: React.CSSProperties; }; @@ -20,6 +27,7 @@ export const CellFilterWrapper = ({ field, operation, value, + rowId, style, }: CellFilterWrapperProps) => { const onClickCapture = onAddFilter @@ -28,7 +36,7 @@ export const CellFilterWrapper = ({ if (e.altKey) { e.stopPropagation(); e.preventDefault(); - onAddFilter(field, operation, value); + onAddFilter(field, operation, value, rowId); } } : undefined; diff --git 
a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallPage/DataTableView.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallPage/DataTableView.tsx index e05df4f1b2b..4f6d6340dac 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallPage/DataTableView.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallPage/DataTableView.tsx @@ -1,6 +1,12 @@ import LinkIcon from '@mui/icons-material/Link'; -import {Alert, Box} from '@mui/material'; -import {GridColDef, useGridApiRef} from '@mui/x-data-grid-pro'; +import {Box} from '@mui/material'; +import { + GridColDef, + GridEventListener, + GridPaginationModel, + GridSortModel, + useGridApiRef, +} from '@mui/x-data-grid-pro'; import { isAssignableTo, list, @@ -9,29 +15,33 @@ import { typedDict, typedDictPropertyTypes, } from '@wandb/weave/core'; +import {useDeepMemo} from '@wandb/weave/hookUtils'; import _ from 'lodash'; -import React, {FC, useCallback, useContext, useEffect, useMemo} from 'react'; +import React, { + FC, + useCallback, + useContext, + useEffect, + useMemo, + useRef, + useState, +} from 'react'; import {useHistory} from 'react-router-dom'; -import { - isWeaveObjectRef, - parseRef, - WeaveObjectRef, -} from '../../../../../../react'; +import {isWeaveObjectRef, parseRef} from '../../../../../../react'; import {flattenObjectPreservingWeaveTypes} from '../../../Browse2/browse2Util'; import {CellValue} from '../../../Browse2/CellValue'; +import {parseRefMaybe} from '../../../Browse2/SmallRef'; import { useWeaveflowCurrentRouteContext, WeaveflowPeekContext, } from '../../context'; +import {DEFAULT_PAGE_SIZE} from '../../grid/pagination'; import {StyledDataGrid} from '../../StyledDataGrid'; import {CustomWeaveTypeProjectContext} from '../../typeViews/CustomWeaveTypeDispatcher'; import {TABLE_ID_EDGE_NAME} from '../wfReactInterface/constants'; import {useWFHooks} from '../wfReactInterface/context'; -import {TableQuery} from 
'../wfReactInterface/wfDataModelHooksInterface'; - -// Controls the maximum number of rows to display in the table -const MAX_ROWS = 10_000; +import {SortBy} from '../wfReactInterface/traceServerClientTypes'; // Controls whether to use a table for arrays or not. export const USE_TABLE_FOR_ARRAYS = false; @@ -53,26 +63,99 @@ export const WeaveCHTable: FC<{ // Gets the source of this Table (set by a few levels up) const sourceRef = useContext(WeaveCHTableSourceRefContext); - // Retrieves the data for the table, with a limit of MAX_ROWS + 1 - const fetchQuery = useValueOfRefUri(props.tableRefUri, { - limit: MAX_ROWS + 1, - }); + const {useTableQueryStats, useTableRowsQuery} = useWFHooks(); const parsedRef = useMemo( - () => parseRef(props.tableRefUri) as WeaveObjectRef, + () => parseRefMaybe(props.tableRefUri), [props.tableRefUri] ); - // Determines if the table itself is truncated - const isTruncated = useMemo(() => { - return (fetchQuery.result ?? []).length > MAX_ROWS; - }, [fetchQuery.result]); + const lookupKey = useMemo(() => { + if ( + parsedRef == null || + !isWeaveObjectRef(parsedRef) || + parsedRef.weaveKind !== 'table' + ) { + return null; + } + return { + entity: parsedRef.entityName, + project: parsedRef.projectName, + digest: parsedRef.artifactVersion, + }; + }, [parsedRef]); + + const numRowsQuery = useTableQueryStats( + lookupKey?.entity ?? '', + lookupKey?.project ?? '', + lookupKey?.digest ?? 
'', + {skip: lookupKey == null} + ); + + const [limit, setLimit] = useState(DEFAULT_PAGE_SIZE); + const [offset, setOffset] = useState(0); + const [sortBy, setSortBy] = useState([]); + const [sortModel, setSortModel] = useState([]); + + const onSortModelChange = useCallback( + (model: GridSortModel) => { + setSortModel(model); + }, + [setSortModel] + ); + + const [paginationModel, setPaginationModel] = useState({ + page: 0, + pageSize: DEFAULT_PAGE_SIZE, + }); + + const onPaginationModelChange = useCallback( + (model: GridPaginationModel) => { + setPaginationModel(model); + }, + [setPaginationModel] + ); + + useEffect(() => { + setOffset(paginationModel.page * paginationModel.pageSize); + setLimit(paginationModel.pageSize); + }, [paginationModel]); + + useEffect(() => { + setSortBy( + sortModel.map(sort => ({ + field: sort.field, + direction: sort.sort === 'asc' ? 'asc' : 'desc', + })) + ); + }, [sortModel]); + + const fetchQuery = useTableRowsQuery( + lookupKey?.entity ?? '', + lookupKey?.project ?? '', + lookupKey?.digest ?? '', + undefined, + limit, + offset, + sortBy, + {skip: lookupKey == null} + ); + + const [loadedRows, setLoadedRows] = useState>([]); + + useEffect(() => { + if (!fetchQuery.loading && fetchQuery.result) { + setLoadedRows(fetchQuery.result.rows); + } + }, [fetchQuery.loading, fetchQuery.result]); + + const pagedRows = useMemo(() => { + return loadedRows ?? []; + }, [loadedRows]); - // `sourceRows` are the effective rows to display. If the table is truncated, - // we only display the first MAX_ROWS rows. - const sourceRows = useMemo(() => { - return (fetchQuery.result ?? []).slice(0, MAX_ROWS); - }, [fetchQuery.result]); + const totalRows = useMemo(() => { + return numRowsQuery.result?.count ?? pagedRows.length; + }, [numRowsQuery.result, pagedRows]); // In this block, we setup a click handler. The underlying datatable is more general // and not aware of the nuances of our links and ref model. 
Therefore, we handle @@ -105,31 +188,60 @@ export const WeaveCHTable: FC<{ [history, sourceRef, router] ); + const pageControl: DataTableServerSidePaginationControls = useMemo( + () => ({ + paginationModel, + onPaginationModelChange, + totalRows, + pageSizeOptions: [DEFAULT_PAGE_SIZE], + sortModel, + onSortModelChange, + }), + [ + paginationModel, + onPaginationModelChange, + totalRows, + sortModel, + onSortModelChange, + ] + ); + return ( + value={{ + entity: lookupKey?.entity ?? '', + project: lookupKey?.project ?? '', + }}> ); }; +type DataTableServerSidePaginationControls = { + paginationModel: GridPaginationModel; + onPaginationModelChange: (model: GridPaginationModel) => void; + totalRows: number; + pageSizeOptions: number[]; + sortModel: GridSortModel; + onSortModelChange: (model: GridSortModel) => void; +}; + // This is a general purpose table view that can be used to render any data. export const DataTableView: FC<{ data: Array<{[key: string]: any}>; fullHeight?: boolean; loading?: boolean; displayKey?: string; - isTruncated?: boolean; onLinkClick?: (row: any) => void; + pageControl?: DataTableServerSidePaginationControls; + autoPageSize?: boolean; }> = props => { const apiRef = useGridApiRef(); const {isPeeking} = useContext(WeaveflowPeekContext); @@ -157,31 +269,11 @@ export const DataTableView: FC<{ () => (dataAsListOfDict ?? []).map((row, i) => ({ id: i, - ...row, + data: row, })), [dataAsListOfDict] ); - // This effect will resize the columns after the table is rendered. We use a - // timeout to ensure that the table has been rendered before we resize the - // columns. - useEffect(() => { - let mounted = true; - const timeoutId = setTimeout(() => { - if (!mounted) { - return; - } - apiRef.current.autosizeColumns({ - includeHeaders: true, - includeOutliers: true, - }); - }, 0); - return () => { - mounted = false; - clearInterval(timeoutId); - }; - }, [gridRows, apiRef]); - // Next, we determine the type of the data. 
Previously, we used the WeaveJS // `Type` system to determine the type of the data. However, this is way to // slow for big tables and too detailed for our purposes. We just need to know @@ -242,9 +334,16 @@ export const DataTableView: FC<{ return typedDict(propertyTypes); }, [dataAsListOfDict]); + const propsDataRef = useRef(props.data); + useEffect(() => { + propsDataRef.current = props.data; + }, [props.data]); + + const objectTypeDeepMemo = useDeepMemo(objectType); + // Here we define the column spec for the table. It is based on // the type of the data and if we have a link or not. - const columnSpec: GridColDef[] = useMemo(() => { + const dataInitializedColumnSpec: GridColDef[] = useMemo(() => { const res: GridColDef[] = []; if (props.onLinkClick) { res.push({ @@ -256,21 +355,26 @@ export const DataTableView: FC<{ style={{ cursor: 'pointer', }} - onClick={() => props.onLinkClick!(props.data[params.id as number])} + onClick={() => + props.onLinkClick!(propsDataRef.current[params.id as number]) + } /> ), }); } - return [...res, ...typeToDataGridColumnSpec(objectType, isPeeking, true)]; - }, [props.onLinkClick, props.data, objectType, isPeeking]); + return [ + ...res, + ...typeToDataGridColumnSpec(objectTypeDeepMemo, isPeeking, true), + ]; + }, [props.onLinkClick, objectTypeDeepMemo, isPeeking]); // Finally, we do some math to determine the height of the table. const isSingleColumn = USE_TABLE_FOR_ARRAYS && - columnSpec.length === 1 && - columnSpec[0].field === ''; + dataInitializedColumnSpec.length === 1 && + dataInitializedColumnSpec[0].field === ''; if (isSingleColumn) { - columnSpec[0].flex = 1; + dataInitializedColumnSpec[0].flex = 1; } const hideHeader = isSingleColumn; const displayRows = 10; @@ -284,6 +388,74 @@ export const DataTableView: FC<{ (hideHeader ? 0 : headerHeight) + (hideFooter ? 0 : footerHeight) + (props.loading ? 
loadingHeight : contentHeight); + + const [columnSpec, setColumnSpec] = useState([]); + + // This effect will resize the columns after the table is rendered. We use a + // timeout to ensure that the table has been rendered before we resize the + // columns. + const hasLinkClick = props.onLinkClick != null; + useEffect(() => { + let mounted = true; + + // Update the column set if the column spec changes (ignore empty columns + // which can occur during loading) + setColumnSpec(curr => { + const dataFieldSet = new Set( + dataInitializedColumnSpec.map(col => col.field) + ); + const currFieldSet = new Set(curr.map(col => col.field)); + if (dataFieldSet.size > (hasLinkClick ? 1 : 0)) { + // Update if they are different + if (!_.isEqual(dataFieldSet, currFieldSet)) { + return dataInitializedColumnSpec; + } + } + return curr; + }); + + const timeoutId = setTimeout(() => { + if (!mounted) { + return; + } + apiRef.current.autosizeColumns({ + includeHeaders: true, + includeOutliers: true, + }); + // apiRef.current.forceUpdate() + }, 0); + return () => { + mounted = false; + clearInterval(timeoutId); + }; + }, [dataInitializedColumnSpec, apiRef, hasLinkClick]); + + const onColumnOrderChange: GridEventListener<'columnOrderChange'> = + useCallback(params => { + const oldIndex = params.oldIndex; + const newIndex = params.targetIndex; + setColumnSpec(currSpec => { + const col = currSpec[oldIndex]; + currSpec.splice(oldIndex, 1); + currSpec.splice(newIndex, 0, col); + return currSpec; + }); + }, []); + + const onColumnWidthChange: GridEventListener<'columnWidthChange'> = + useCallback(params => { + const field = params.colDef.field; + const newWidth = params.width; + setColumnSpec(currSpec => { + for (const col of currSpec) { + if (col.field === field) { + col.width = newWidth; + } + } + return currSpec; + }); + }, []); + return (
- {props.isTruncated && ( - - Showing {dataAsListOfDict.length.toLocaleString()} rows only. - - )}
@@ -356,6 +532,7 @@ export const typeToDataGridColumnSpec = ( ): GridColDef[] => { if (isAssignableTo(type, {type: 'typedDict', propertyTypes: {}})) { const maxWidth = window.innerWidth * (isPeeking ? 0.5 : 0.75); + const minWidth = 100; const propertyTypes = typedDictPropertyTypes(type); return Object.entries(propertyTypes).flatMap(([key, valueType]) => { const innerKey = parentKey ? `${parentKey}.${key}` : key; @@ -382,12 +559,14 @@ export const typeToDataGridColumnSpec = ( return [ { maxWidth, + minWidth, + flex: 1, type: 'string' as const, editable: false, field: innerKey, headerName: innerKey, renderCell: params => { - const listValue = params.row[innerKey]; + const listValue = params.row.data[innerKey]; if (listValue == null) { return '-'; } @@ -400,12 +579,14 @@ export const typeToDataGridColumnSpec = ( return [ { maxWidth, + minWidth, + flex: 1, type: colType, editable: editable && !disableEdits, field: innerKey, headerName: innerKey, renderCell: params => { - const data = params.row[innerKey]; + const data = params.row.data[innerKey]; return ; }, }, @@ -419,26 +600,3 @@ export const typeToDataGridColumnSpec = ( } return []; }; - -const useValueOfRefUri = (refUriStr: string, tableQuery?: TableQuery) => { - const {useRefsData} = useWFHooks(); - const data = useRefsData([refUriStr], tableQuery); - return useMemo(() => { - if (data.loading) { - return { - loading: true, - result: undefined, - }; - } - if (data.result == null || data.result.length === 0) { - return { - loading: true, - result: undefined, - }; - } - return { - loading: false, - result: data.result[0], - }; - }, [data]); -}; diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallPage/ValueView.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallPage/ValueView.tsx index 5f4b02ef63c..7db8e1e9c1b 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallPage/ValueView.tsx +++ 
b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallPage/ValueView.tsx @@ -42,7 +42,7 @@ export const ValueView = ({data, isExpanded}: ValueViewProps) => { } } if (USE_TABLE_FOR_ARRAYS && data.valueType === 'array') { - return ; + return ; } if (data.valueType === 'array' && data.value.length === 0) { return Empty List; diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallsPage/CallsTable.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallsPage/CallsTable.tsx index bbe9d5c230c..39b54db8a14 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallsPage/CallsTable.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallsPage/CallsTable.tsx @@ -39,7 +39,9 @@ import {useHistory} from 'react-router-dom'; import {useViewerInfo} from '../../../../../../common/hooks/useViewerInfo'; import {A, TargetBlank} from '../../../../../../common/util/links'; import {Tailwind} from '../../../../../Tailwind'; +import {flattenObjectPreservingWeaveTypes} from '../../../Browse2/browse2Util'; import {useWeaveflowCurrentRouteContext} from '../../context'; +import {OnAddFilter} from '../../filters/CellFilterWrapper'; import {getDefaultOperatorForValue} from '../../filters/common'; import {FilterPanel} from '../../filters/FilterPanel'; import {DEFAULT_PAGE_SIZE} from '../../grid/pagination'; @@ -62,6 +64,7 @@ import { import {useWFHooks} from '../wfReactInterface/context'; import {TraceCallSchema} from '../wfReactInterface/traceServerClientTypes'; import {traceCallToUICallSchema} from '../wfReactInterface/tsDataModelHooks'; +import {EXPANDED_REF_REF_KEY} from '../wfReactInterface/tsDataModelHooksCallRefExpansion'; import {objectVersionNiceString} from '../wfReactInterface/utilities'; import {CallSchema} from '../wfReactInterface/wfDataModelHooksInterface'; import {CallsCustomColumnMenu} from './CallsCustomColumnMenu'; @@ -268,14 +271,75 @@ export const CallsTable: FC<{ }, [calls, 
effectiveFilter]); // Construct Flattened Table Data - const tableData: TraceCallSchema[] = useMemo( + const tableData: FlattenedCallData[] = useMemo( () => prepareFlattenedCallDataForTable(callsResult), [callsResult] ); - const onAddFilter = + // This is a specific helper that is used when the user attempts to option-click + // a cell that is a child cell of an expanded ref. In this case, we want to + // add a filter on the parent ref itself, not the child cell. Once we can properly + // filter by reffed values on the backend, this can be removed. + const getFieldAndValueForRefExpandedFilter = useCallback( + (field: string, rowId: string) => { + if (columnIsRefExpanded(field)) { + // In this case, we actually just want to filter by the parent ref itself. + // This means we need to: + // 1. Determine the column of the highest level ancestor column with a ref + // 2. Get the value of that corresponding cell (ref column @ row) + // 3. Add a filter for that ref on that column. + // The acknowledge drawback of this approach is that we are not filtering by that + // cell's value, but rather the entire object itself. This still might confuse users, + // but is better than returning nothing. + const fieldParts = field.split('.'); + let ancestorField: string | null = null; + let targetRef: string | null = null; + for (let i = 1; i <= fieldParts.length; i++) { + const ancestorFieldCandidate = fieldParts.slice(0, i).join('.'); + if (expandedRefCols.has(ancestorFieldCandidate)) { + const candidateRow = callsResult.find( + row => row.traceCall?.id === rowId + )?.traceCall; + if (candidateRow != null) { + const flattenedCandidateRow = + flattenObjectPreservingWeaveTypes(candidateRow); + const targetRefCandidate = + flattenedCandidateRow[ + ancestorFieldCandidate + '.' 
+ EXPANDED_REF_REF_KEY + ]; + if (targetRefCandidate != null) { + ancestorField = ancestorFieldCandidate; + targetRef = targetRefCandidate; + break; + } + } + } + } + if (ancestorField == null) { + console.warn('Could not find ancestor ref column for', field); + return null; + } + + return {value: targetRef, field: ancestorField}; + } + return null; + }, + [callsResult, columnIsRefExpanded, expandedRefCols] + ); + + const onAddFilter: OnAddFilter | undefined = filterModel && setFilterModel - ? (field: string, operator: string | null, value: any) => { + ? (field: string, operator: string | null, value: any, rowId: string) => { + // This condition is used to filter by the parent ref itself, not the child cell. + // Should be removed once we can filter by reffed values on the backend. + const expandedRef = getFieldAndValueForRefExpandedFilter( + field, + rowId + ); + if (expandedRef != null) { + value = expandedRef.value; + field = expandedRef.field; + } const op = operator ? operator : getDefaultOperatorForValue(value); const newModel = { ...filterModel, @@ -306,6 +370,18 @@ export const CallsTable: FC<{ onAddFilter ); + // This contains columns which are suitable for selection and raw data + // entry. Notably, not children of expanded refs. + const filterFriendlyColumnInfo = useMemo(() => { + const filteredCols = columns.cols.filter( + col => !columnIsRefExpanded(col.field) + ); + return { + cols: filteredCols, + colGroupingModel: columns.colGroupingModel, + }; + }, [columnIsRefExpanded, columns.colGroupingModel, columns.cols]); + // Now, there are 4 primary controls: // 1. Op Version // 2. Input Object Version @@ -634,7 +710,7 @@ export const CallsTable: FC<{ {filterModel && setFilterModel && ( { return pathname.split('/').pop() ?? 
null; }; +export type FlattenedCallData = TraceCallSchema & {[key: string]: string}; + function prepareFlattenedCallDataForTable( callsResult: CallSchema[] -): Array { +): FlattenedCallData[] { return prepareFlattenedDataForTable(callsResult.map(c => c.traceCall)); } diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallsPage/callsTableColumns.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallsPage/callsTableColumns.tsx index 8c2c84ecb48..bc1ffe08388 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallsPage/callsTableColumns.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallsPage/callsTableColumns.tsx @@ -18,7 +18,7 @@ import {isWeaveObjectRef, parseRef} from '../../../../../../react'; import {makeRefCall} from '../../../../../../util/refs'; import {Timestamp} from '../../../../../Timestamp'; import {Reactions} from '../../feedback/Reactions'; -import {CellFilterWrapper} from '../../filters/CellFilterWrapper'; +import {CellFilterWrapper, OnAddFilter} from '../../filters/CellFilterWrapper'; import {isWeaveRef} from '../../filters/common'; import { getTokensAndCostFromUsage, @@ -54,7 +54,7 @@ export const useCallsTableColumns = ( onCollapse: (col: string) => void, onExpand: (col: string) => void, columnIsRefExpanded: (col: string) => boolean, - onAddFilter?: (field: string, operator: string | null, value: any) => void + onAddFilter?: OnAddFilter ) => { const [userDefinedColumnWidths, setUserDefinedColumnWidths] = useState< Record @@ -168,7 +168,7 @@ function buildCallsTableColumns( onExpand: (col: string) => void, columnIsRefExpanded: (col: string) => boolean, userDefinedColumnWidths: Record, - onAddFilter?: (field: string, operator: string | null, value: any) => void + onAddFilter?: OnAddFilter ): { cols: Array>; colGroupingModel: GridColumnGroupingModel; @@ -299,8 +299,8 @@ function buildCallsTableColumns( onExpand, // TODO (Tim) - (BackendExpansion): This can be 
removed once we support backend expansion! key => !columnIsRefExpanded(key) && !columnsWithRefs.has(key), - (key, operator, value) => { - onAddFilter?.(key, operator, value); + (key, operator, value, rowId) => { + onAddFilter?.(key, operator, value, rowId); } ); cols.push(...newCols); @@ -323,6 +323,7 @@ function buildCallsTableColumns( @@ -350,6 +351,7 @@ function buildCallsTableColumns( diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ChatView/ToolCalls.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ChatView/ToolCalls.tsx index 26317b6cb52..7a624fa77ab 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ChatView/ToolCalls.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ChatView/ToolCalls.tsx @@ -1,4 +1,5 @@ -import React from 'react'; +import Prism from 'prismjs'; +import React, {useEffect, useRef} from 'react'; import {Alert} from '../../../../../Alert'; import {ToolCall} from './types'; @@ -8,6 +9,13 @@ type OneToolCallProps = { }; const OneToolCall = ({toolCall}: OneToolCallProps) => { + const ref = useRef(null); + useEffect(() => { + if (ref.current) { + Prism.highlightElement(ref.current!); + } + }); + const {function: toolCallFunction} = toolCall; const {name, arguments: args} = toolCallFunction; let parsedArgs: any = null; @@ -21,9 +29,14 @@ const OneToolCall = ({toolCall}: OneToolCallProps) => { // The model does not always generate valid JSON return Invalid JSON: {args}; } + return ( - - {name}({parsedArgs}) + + {name}( + + {parsedArgs} + + ) ); }; diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ObjectVersionPage.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ObjectVersionPage.tsx index 0b581e464a0..28002ee9bde 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ObjectVersionPage.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ObjectVersionPage.tsx @@ -180,6 
+180,8 @@ const ObjectVersionPageInner: React.FC<{ return viewerData; }, [viewerData]); + const isDataset = baseObjectClass === 'Dataset' && refExtra == null; + return ( + ( onCollapse?: (col: string) => void, onExpand?: (col: string) => void, columnIsSortable?: (col: string) => boolean, - onAddFilter?: (field: string, operator: string | null, value: any) => void + onAddFilter?: OnAddFilter ) => { const cols: Array> = []; @@ -279,12 +282,14 @@ export const buildDynamicColumns = ( }, renderCell: cellParams => { const {entity, project} = entityProjectFromRow(cellParams.row); + const val = valueForKey(cellParams.row, key); if (val === undefined) { return ( @@ -296,6 +301,7 @@ export const buildDynamicColumns = ( { + return this.makeRequest( + '/table/query_stats', + req + ); + } + public feedbackCreate(req: FeedbackCreateReq): Promise { return this.makeRequest( '/feedback/create', diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/wfReactInterface/tsDataModelHooks.ts b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/wfReactInterface/tsDataModelHooks.ts index 12fa9872465..515969919dd 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/wfReactInterface/tsDataModelHooks.ts +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/wfReactInterface/tsDataModelHooks.ts @@ -963,6 +963,8 @@ const useTableQuery = makeTraceServerEndpointHook< string, traceServerTypes.TraceTableQueryReq['filter'], traceServerTypes.TraceTableQueryReq['limit'], + traceServerTypes.TraceTableQueryReq['offset'], + traceServerTypes.TraceTableQueryReq['sort_by'], {skip?: boolean}? 
], any[] @@ -973,6 +975,8 @@ const useTableQuery = makeTraceServerEndpointHook< digest: traceServerTypes.TraceTableQueryReq['digest'], filter: traceServerTypes.TraceTableQueryReq['filter'], limit: traceServerTypes.TraceTableQueryReq['limit'], + offset: traceServerTypes.TraceTableQueryReq['offset'], + sortBy: traceServerTypes.TraceTableQueryReq['sort_by'], opts?: {skip?: boolean} ) => ({ params: { @@ -980,6 +984,8 @@ const useTableQuery = makeTraceServerEndpointHook< digest, filter, limit, + offset, + sort_by: sortBy, }, skip: opts?.skip, }), @@ -1136,6 +1142,8 @@ const useRefsData = ( tableUriDigest, tableQueryFilter, tableQuery?.limit, + undefined, + undefined, {skip: tableRefUris.length === 0 || cachedTableResult != null} ); @@ -1192,6 +1200,129 @@ const useRefsData = ( ]); }; +const useTableRowsQuery = ( + entity: string, + project: string, + digest: string, + filter?: traceServerTypes.TraceTableQueryReq['filter'], + limit?: traceServerTypes.TraceTableQueryReq['limit'], + offset?: traceServerTypes.TraceTableQueryReq['offset'], + sortBy?: traceServerTypes.TraceTableQueryReq['sort_by'], + opts?: {skip?: boolean} +): Loadable => { + const getTsClient = useGetTraceServerClientContext(); + const [queryRes, setQueryRes] = + useState(null); + const loadingRef = useRef(false); + + const projectId = projectIdFromParts({entity, project}); + + const doFetch = useCallback(() => { + if (opts?.skip) { + return; + } + setQueryRes(null); + loadingRef.current = true; + + const req: traceServerTypes.TraceTableQueryReq = { + project_id: projectId, + digest, + filter, + limit, + offset, + sort_by: sortBy, + }; + + getTsClient() + .tableQuery(req) + .then(res => { + loadingRef.current = false; + setQueryRes(res); + }) + .catch(err => { + loadingRef.current = false; + console.error('Error fetching table rows:', err); + setQueryRes(null); + }); + }, [ + getTsClient, + projectId, + digest, + filter, + limit, + offset, + sortBy, + opts?.skip, + ]); + + useEffect(() => { + doFetch(); + 
}, [doFetch]); + + return useMemo(() => { + if (opts?.skip) { + return {loading: false, result: null}; + } + if (queryRes == null || loadingRef.current) { + return {loading: true, result: null}; + } + return {loading: false, result: queryRes}; + }, [queryRes, opts?.skip]); +}; + +const useTableQueryStats = ( + entity: string, + project: string, + digest: string, + opts?: {skip?: boolean} +): Loadable => { + const getTsClient = useGetTraceServerClientContext(); + const [statsRes, setStatsRes] = + useState(null); + const loadingRef = useRef(false); + + const projectId = projectIdFromParts({entity, project}); + + const doFetch = useCallback(() => { + if (opts?.skip) { + return; + } + setStatsRes(null); + loadingRef.current = true; + + const req: traceServerTypes.TraceTableQueryStatsReq = { + project_id: projectId, + digest, + }; + + getTsClient() + .tableQueryStats(req) + .then(res => { + loadingRef.current = false; + setStatsRes(res); + }) + .catch(err => { + loadingRef.current = false; + console.error('Error fetching table query stats:', err); + setStatsRes(null); + }); + }, [getTsClient, projectId, digest, opts?.skip]); + + useEffect(() => { + doFetch(); + }, [doFetch]); + + return useMemo(() => { + if (opts?.skip) { + return {loading: false, result: null}; + } + if (statsRes == null || loadingRef.current) { + return {loading: true, result: null}; + } + return {loading: false, result: statsRes}; + }, [statsRes, opts?.skip]); +}; + const useApplyMutationsToRef = (): (( refUri: string, edits: RefMutation[] @@ -1534,6 +1665,8 @@ export const tsWFDataModelHooks: WFDataModelHooksInterface = { useApplyMutationsToRef, useFeedback, useFileContent, + useTableRowsQuery, + useTableQueryStats, derived: { useChildCallsForCompare, useGetRefsType, diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/wfReactInterface/wfDataModelHooksInterface.ts b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/wfReactInterface/wfDataModelHooksInterface.ts 
index a760cbfb207..1dba29c06db 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/wfReactInterface/wfDataModelHooksInterface.ts +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/wfReactInterface/wfDataModelHooksInterface.ts @@ -218,6 +218,22 @@ export type WFDataModelHooksInterface = { useObjectVersion: ( key: ObjectVersionKey | null ) => Loadable; + useTableRowsQuery: ( + entity: string, + project: string, + digest: string, + filter?: traceServerClientTypes.TraceTableQueryReq['filter'], + limit?: traceServerClientTypes.TraceTableQueryReq['limit'], + offset?: traceServerClientTypes.TraceTableQueryReq['offset'], + sortBy?: traceServerClientTypes.TraceTableQueryReq['sort_by'], + opts?: {skip?: boolean} + ) => Loadable; + useTableQueryStats: ( + entity: string, + project: string, + digest: string, + opts?: {skip?: boolean} + ) => Loadable; useRootObjectVersions: ( entity: string, project: string, diff --git a/weave-js/src/components/Panel2/PanelStringHistogram.tsx b/weave-js/src/components/Panel2/PanelStringHistogram.tsx index 9bf8fccb54b..9226bf2c6de 100644 --- a/weave-js/src/components/Panel2/PanelStringHistogram.tsx +++ b/weave-js/src/components/Panel2/PanelStringHistogram.tsx @@ -169,7 +169,12 @@ const PanelStringHistogramInner: React.FC< const isColorable = colorNode.nodeType !== 'void'; const nodeValueQuery = CGReact.useNodeValue(props.input); const data = useMemo(() => { - if (nodeValueQuery.loading || (isColorable && colorNodeValue.loading)) { + const shouldEject = [ + nodeValueQuery.loading, + isColorable && colorNodeValue.loading, + nodeValueQuery.result == null, + ].some(Boolean); + if (shouldEject) { return []; } if (!isColorable) { diff --git a/weave-js/src/core/README.md b/weave-js/src/core/README.md index b6ab48c56ba..eb3524ab0db 100644 --- a/weave-js/src/core/README.md +++ b/weave-js/src/core/README.md @@ -424,5 +424,5 @@ See [generateDocs.ts](./generateDocs.ts) for implementation. 
# Testing Against Weave Python -1. Start a Weave Server using `./scripts/weave_server.sh` from the `weave/weave` directory +1. Start a Weave Server using `./scripts/weave_server.sh` from the `weave_query` directory 2. Run `yarn test:python-backend` from this directory diff --git a/weave-js/src/react.tsx b/weave-js/src/react.tsx index d11c404d8cd..ad7a4cb3aa6 100644 --- a/weave-js/src/react.tsx +++ b/weave-js/src/react.tsx @@ -1,3 +1,4 @@ +import * as Sentry from '@sentry/react'; import { callOpVeryUnsafe, Client, @@ -354,8 +355,9 @@ export const useNodeValue = ( const message = 'Node execution failed (useNodeValue): ' + errorToText(error); // console.error(message); - - throw new UseNodeValueServerExecutionError(message); + const err = new UseNodeValueServerExecutionError(message); + Sentry.captureException(err, {fingerprint: ['useNodeValue']}); + throw err; } if (isConstNode(node)) { if (isFunction(node.type)) { diff --git a/weave/__init__.py b/weave/__init__.py index 1cd10716131..c865f4c5ad7 100644 --- a/weave/__init__.py +++ b/weave/__init__.py @@ -1,17 +1,18 @@ """The top-level functions and classes for working with Weave.""" from weave import version - from weave.trace.api import * __version__ = version.VERSION -from weave.flow.obj import Object +from weave.flow.agent import Agent as Agent +from weave.flow.agent import AgentState as AgentState from weave.flow.dataset import Dataset -from weave.flow.model import Model from weave.flow.eval import Evaluation, Scorer -from weave.flow.agent import Agent, AgentState -from weave.trace.util import ThreadPoolExecutor, Thread +from weave.flow.model import Model +from weave.flow.obj import Object +from weave.trace.util import Thread as Thread +from weave.trace.util import ThreadPoolExecutor as ThreadPoolExecutor # Special object informing doc generation tooling which symbols # to document & to associate with this module. 
diff --git a/weave/deploy/gcp/__init__.py b/weave/deploy/gcp/__init__.py index d0bab77e2e5..157c885499e 100644 --- a/weave/deploy/gcp/__init__.py +++ b/weave/deploy/gcp/__init__.py @@ -183,7 +183,7 @@ def ensure_secret( [ "secrets", "list", - f"--filter=name~^.*\/{name}$", + rf"--filter=name~^.*\/{name}$", f"--project={project}", "--format=json", ] diff --git a/weave/flow/eval.py b/weave/flow/eval.py index 54715b61e89..281bb98bec1 100644 --- a/weave/flow/eval.py +++ b/weave/flow/eval.py @@ -46,7 +46,7 @@ def async_call(func: Union[Callable, Op], *args: Any, **kwargs: Any) -> Coroutin return asyncio.to_thread(func, *args, **kwargs) -class EvaluationResults(weave.Object): +class EvaluationResults(Object): rows: weave.Table diff --git a/weave/integrations/integration_utilities.py b/weave/integrations/integration_utilities.py index 0ae65d10b6d..017bcc9017f 100644 --- a/weave/integrations/integration_utilities.py +++ b/weave/integrations/integration_utilities.py @@ -2,10 +2,8 @@ import re from typing import Any, Iterable, Union -import weave from weave.trace.refs import OpRef, parse_uri from weave.trace.weave_client import Call, CallsIter -from weave.trace_server import trace_server_interface as tsi MAX_RUN_NAME_LENGTH = 128 @@ -73,25 +71,6 @@ def _hash_str(s: str, hash_len: int) -> str: return hashlib.md5(s.encode()).hexdigest()[:hash_len] -def _get_call_output(call: tsi.CallSchema) -> Any: - """This is a hack and should not be needed. We should be able to auto-resolve this for the user. - Keeping this here for now, but it should be removed in the future once we have a better solution. - """ - call_output = call.output - if isinstance(call_output, str) and call_output.startswith("weave://"): - return weave.ref(call_output).get() - return call_output - - -def _get_op_name(s: str) -> str: - """This is a hack and should not be needed. We should be able to auto-resolve this for the user. 
- Keeping this here for now, but it should be removed in the future once we have a better solution. - """ - _, s = s.split("weave:///shawn/test-project/op/", 1) - s, _ = s.split(":", 1) - return s - - def flatten_calls(calls: Union[Iterable[Call], CallsIter], *, depth: int = 0) -> list: lst = [] for call in calls: @@ -116,3 +95,16 @@ def op_name_from_ref(ref: str) -> str: def filter_body(r: Any) -> Any: r.body = "" return r + + +def _make_string_of_length(n: int) -> str: + return "a" * n + + +def _truncated_str(tail_len: int, total_len: int) -> tuple: + name = ( + _make_string_of_length(total_len - tail_len - 1) + + "." + + _make_string_of_length(tail_len) + ) + return name, truncate_op_name(name) diff --git a/weave/trace/concurrent/futures.py b/weave/trace/concurrent/futures.py index 7ddeb6ef6f1..76b07ad23c6 100644 --- a/weave/trace/concurrent/futures.py +++ b/weave/trace/concurrent/futures.py @@ -91,11 +91,7 @@ def defer(self, f: Callable[..., T], *args: Any, **kwargs: Any) -> Future[T]: Returns: Future[T]: A Future object representing the eventual result of the function. 
""" - future = self._safe_submit(f, *args, **kwargs) - with self._active_futures_lock: - self._active_futures.append(future) - future.add_done_callback(self._future_done_callback) - return future + return self._safe_submit(f, *args, **kwargs) def then(self, futures: List[Future[T]], g: Callable[[List[T]], U]) -> Future[U]: """ @@ -173,7 +169,6 @@ def flush(self, timeout: Optional[float] = None) -> bool: except Exception as e: if get_raise_on_captured_errors(): raise - logger.error(f"Job failed during flush: {e}") return True def _future_done_callback(self, future: Future) -> None: @@ -181,6 +176,9 @@ def _future_done_callback(self, future: Future) -> None: with self._active_futures_lock: if future in self._active_futures: self._active_futures.remove(future) + exception = future.exception() + if exception: + logger.error(f"Task failed: {_format_exception(exception)}") def _shutdown(self) -> None: """Shutdown the thread pool executor. Should only be called when the program is exiting.""" @@ -223,12 +221,18 @@ def _safe_submit(self, f: Callable[..., T], *args: Any, **kwargs: Any) -> Future return self._execute_directly(wrapped, *args, **kwargs) try: - return self._executor.submit(wrapped, *args, **kwargs) + future = self._executor.submit(wrapped, *args, **kwargs) except Exception as e: if get_raise_on_captured_errors(): raise return self._execute_directly(wrapped, *args, **kwargs) + with self._active_futures_lock: + self._active_futures.append(future) + future.add_done_callback(self._future_done_callback) + + return future + def _execute_directly( self, f: Callable[..., T], *args: Any, **kwargs: Any ) -> Future[T]: @@ -238,8 +242,22 @@ def _execute_directly( res = f(*args, **kwargs) fut.set_result(res) except Exception as e: + logger.error(f"Task failed: {_format_exception(e)}") fut.set_exception(e) return fut +def _format_exception(e: BaseException) -> str: + exception_str = f"{type(e).__name__}: {e}" + return exception_str + # try: + # if hasattr(e, 
"__traceback__"): + # traceback_str = "".join(traceback.format_tb(e.__traceback__)) + # if traceback_str: + # exception_str += f"\nTraceback:\n{traceback_str}" + # return exception_str + # except: + # return exception_str + + __all__ = ["FutureExecutor"] diff --git a/weave/trace/vals.py b/weave/trace/vals.py index 5fac1dfe905..017f3b51e37 100644 --- a/weave/trace/vals.py +++ b/weave/trace/vals.py @@ -364,7 +364,7 @@ def _local_iter_with_remote_fallback(self) -> Generator[dict, None, None]: def _remote_iter(self) -> Generator[dict, None, None]: page_index = 0 - page_size = 1000 + page_size = 100 while True: if self.table_ref is None: break diff --git a/weave/trace_server/clickhouse_trace_server_batched.py b/weave/trace_server/clickhouse_trace_server_batched.py index 08a244604f7..75fe1f7b8e7 100644 --- a/weave/trace_server/clickhouse_trace_server_batched.py +++ b/weave/trace_server/clickhouse_trace_server_batched.py @@ -28,7 +28,7 @@ import json import logging import threading -from collections import Counter, defaultdict +from collections import defaultdict from contextlib import contextmanager from typing import ( Any, @@ -114,6 +114,7 @@ FILE_CHUNK_SIZE = 100000 MAX_DELETE_CALLS_COUNT = 100 +MAX_CALLS_STREAM_BATCH_SIZE = 500 class NotFoundError(Exception): @@ -356,15 +357,8 @@ def calls_query_stream(self, req: tsi.CallsQueryReq) -> Iterator[tsi.CallSchema] for call in hydrated_batch: yield tsi.CallSchema.model_validate(call) - # *** Dynamic Batch Size *** - # count the number of columns at each depth - depths = Counter(col.count(".") for col in expand_columns) - # take the max number of columns at any depth - max_count_at_ref_depth = max(depths.values()) - # divide max refs that we can resolve 1000 refs at any depth - max_size = 1000 // max_count_at_ref_depth - # double batch size up to what refs_read_batch can handle - batch_size = min(max_size, batch_size * 2) + # *** Dynamic increase from 10 to 500 *** + batch_size = min(MAX_CALLS_STREAM_BATCH_SIZE, batch_size 
* 10) batch = [] hydrated_batch = self._hydrate_calls( @@ -1440,9 +1434,8 @@ def _select_objs_query( parameters to be passed to the query. Must include all parameters for both conditions and object_id_conditions. metadata_only: - if metadata_only is True, then we exclude the val_dump field in the select query. - generally, "queries" should not include the val_dump, but "reads" should, as - the val_dump is the most expensive part of the query. + if metadata_only is True, then we return early and dont grab the value. + Otherwise, make a second query to grab the val_dump from the db """ if not conditions: conditions = ["1 = 1"] @@ -1475,11 +1468,7 @@ def _select_objs_query( if parameters is None: parameters = {} - # When metadata_only is false, dont actually read from the field - val_dump_field = "'{}' AS val_dump" if metadata_only else "val_dump" - - # The subquery is for deduplication of object versions by digest - select_query = f""" + select_without_val_dump_query = f""" SELECT project_id, object_id, @@ -1487,7 +1476,6 @@ def _select_objs_query( kind, base_object_class, refs, - val_dump, digest, is_op, version_index, @@ -1500,7 +1488,6 @@ def _select_objs_query( kind, base_object_class, refs, - val_dump, digest, is_op, row_number() OVER ( @@ -1518,7 +1505,6 @@ def _select_objs_query( kind, base_object_class, refs, - {val_dump_field}, digest, if (kind = 'op', 1, 0) AS is_op, row_number() OVER ( @@ -1540,7 +1526,7 @@ def _select_objs_query( {offset_part} """ query_result = self._query_stream( - select_query, + select_without_val_dump_query, {"project_id": project_id, **parameters}, ) result: list[SelectableCHObjSchema] = [] @@ -1556,19 +1542,50 @@ def _select_objs_query( "kind", "base_object_class", "refs", - "val_dump", "digest", "is_op", "version_index", "version_count", "is_latest", + "val_dump", ], - row, + # Add an empty val_dump to the end of the row + list(row) + ["{}"], ) ) ) ) + # -- Don't make second query for object values if metadata_only -- + if 
metadata_only: + return result + + # now get the val_dump for each object + object_ids = list(set([row.object_id for row in result])) + digests = list(set([row.digest for row in result])) + query = """ + SELECT object_id, digest, any(val_dump) + FROM object_versions + WHERE project_id = {project_id: String} AND + object_id IN {object_ids: Array(String)} AND + digest IN {digests: Array(String)} + GROUP BY object_id, digest + """ + parameters = { + "project_id": project_id, + "object_ids": object_ids, + "digests": digests, + } + query_result = self._query_stream(query, parameters) + # Map (object_id, digest) to val_dump + object_values: Dict[tuple[str, str], Any] = {} + for row in query_result: + (object_id, digest, val_dump) = row + object_values[(object_id, digest)] = val_dump + + # update the val_dump for each object + for obj in result: + obj.val_dump = object_values.get((obj.object_id, obj.digest), "{}") return result def _run_migrations(self) -> None: @@ -1581,7 +1598,7 @@ def _query_stream( query: str, parameters: Dict[str, Any], column_formats: Optional[Dict[str, Any]] = None, - ) -> Iterator[QueryResult]: + ) -> Iterator[tuple]: """Streams the results of a query from the database.""" summary = None parameters = _process_parameters(parameters) diff --git a/weave/trace_server/table_query_builder.py b/weave/trace_server/table_query_builder.py index 0a38a8245ec..c5c204c388e 100644 --- a/weave/trace_server/table_query_builder.py +++ b/weave/trace_server/table_query_builder.py @@ -23,30 +23,33 @@ def make_natural_sort_table_query( """ project_id_name = pb.add_param(project_id) digest_name = pb.add_param(digest) - sql_safe_dir = "ASC" if natural_direction == "ASC" else "DESC" - sql_safe_limit = ( - f"LIMIT {{{pb.add_param(limit)}: Int64}}" if limit is not None else "" - ) - sql_safe_offset = ( - f"OFFSET {{{pb.add_param(offset)}: Int64}}" if offset is not None else "" - ) + row_digests_selection = "row_digests" + if natural_direction.lower() == "desc": + 
row_digests_selection = f"reverse({row_digests_selection})" + if limit is not None and offset is None: + offset = 0 + if offset is not None: + if limit is None: + row_digests_selection = f"arraySlice({row_digests_selection}, 1 + {{{pb.add_param(offset)}: Int64}})" + else: + row_digests_selection = f"arraySlice({row_digests_selection}, 1 + {{{pb.add_param(offset)}: Int64}}, {{{pb.add_param(limit)}: Int64}})" query = f""" SELECT DISTINCT tr.digest, tr.val_dump, t.row_order FROM table_rows tr - RIGHT JOIN ( + INNER JOIN ( SELECT row_digest, row_number() OVER () AS row_order - FROM tables + FROM ( + SELECT {row_digests_selection} as row_digests + FROM tables + WHERE project_id = {{{project_id_name}: String}} + AND digest = {{{digest_name}: String}} + ) ARRAY JOIN row_digests AS row_digest - WHERE project_id = {{{project_id_name}: String}} - AND digest = {{{digest_name}: String}} - ORDER BY row_order {sql_safe_dir} - {sql_safe_limit} - {sql_safe_offset} ) AS t ON tr.digest = t.row_digest WHERE tr.project_id = {{{project_id_name}: String}} - ORDER BY row_order {sql_safe_dir} + ORDER BY row_order ASC """ return query @@ -88,12 +91,15 @@ def make_standard_table_query( ( SELECT DISTINCT tr.digest, tr.val_dump, t.row_order FROM table_rows tr - RIGHT JOIN ( + INNER JOIN ( SELECT row_digest, row_number() OVER () AS row_order - FROM tables + FROM ( + SELECT row_digests + FROM tables + WHERE project_id = {{{project_id_name}: String}} + AND digest = {{{digest_name}: String}} + ) ARRAY JOIN row_digests AS row_digest - WHERE project_id = {{{project_id_name}: String}} - AND digest = {{{digest_name}: String}} ) AS t ON tr.digest = t.row_digest WHERE tr.project_id = {{{project_id_name}: String}} {sql_safe_filter_clause} diff --git a/weave/trace_server_bindings/remote_http_trace_server.py b/weave/trace_server_bindings/remote_http_trace_server.py index 78f163749cf..fd8fd5d3fb6 100644 --- a/weave/trace_server_bindings/remote_http_trace_server.py +++ 
b/weave/trace_server_bindings/remote_http_trace_server.py @@ -57,7 +57,7 @@ def _is_retryable_exception(e: Exception) -> bool: # Unknown server error # TODO(np): We need to fix the server to return proper status codes # for downstream 401, 403, 404, etc... Those should propagate back to - # the clien + # the client. if e.response.status_code == 500: return False @@ -263,13 +263,13 @@ def call_start( req_as_obj = tsi.CallStartReq.model_validate(req) else: req_as_obj = req - if req_as_obj.starid == None or req_as_obj.startrace_id == None: + if req_as_obj.start.id == None or req_as_obj.start.trace_id == None: raise ValueError( "CallStartReq must have id and trace_id when batching." ) self.call_processor.enqueue([StartBatchItem(req=req_as_obj)]) return tsi.CallStartRes( - id=req_as_obj.starid, trace_id=req_as_obj.startrace_id + id=req_as_obj.start.id, trace_id=req_as_obj.start.trace_id ) return self._generic_request( "/call/start", req, tsi.CallStartReq, tsi.CallStartRes @@ -362,7 +362,7 @@ def table_create( """Similar to `calls/batch_upsert`, we can dynamically adjust the payload size due to the property that table creation can be decomposed into a series of updates. This is useful when the table creation size is too big to be sent in - a single reques We can create an empty table first, then update the table + a single request. We can create an empty table first, then update the table with the rows. 
""" if isinstance(req, dict): diff --git a/weave/weave_server.sh b/weave/weave_server.sh deleted file mode 100644 index 216605c8e88..00000000000 --- a/weave/weave_server.sh +++ /dev/null @@ -1,3 +0,0 @@ -#!/bin/sh - -FLASK_APP=weave_server flask run --port 9994 diff --git a/weave_query/MANIFEST.in b/weave_query/MANIFEST.in new file mode 100644 index 00000000000..4727a536e70 --- /dev/null +++ b/weave_query/MANIFEST.in @@ -0,0 +1,6 @@ +include requirements.* +include wb_schema.gql +graft weave_query +global-exclude */__pycache__/* +global-exclude *.pyc +global-exclude */cassettes/* diff --git a/build_dist.py b/weave_query/build_dist.py old mode 100755 new mode 100644 similarity index 87% rename from build_dist.py rename to weave_query/build_dist.py index feac15dfcd2..a51d048b401 --- a/build_dist.py +++ b/weave_query/build_dist.py @@ -26,11 +26,11 @@ with open("MANIFEST.in") as f: for line in f.readlines(): if line.startswith("graft"): - known_dirs.add(line.split(" ")[1].strip().replace("weave/", "")) + known_dirs.add(line.split(" ")[1].strip().replace("weave_query/", "")) bad_dirs = set() -for name in os.listdir("./weave"): - if os.path.isfile(os.path.join("./weave", name)): +for name in os.listdir("./weave_query"): + if os.path.isfile(os.path.join("./weave_query", name)): continue if name in ignored_dirs: continue @@ -43,7 +43,7 @@ ) if os.getenv("WEAVE_SKIP_BUILD") == None: - subprocess.run(["bash", "weave_query/weave_query/frontend/build.sh"], check=True) + subprocess.run(["bash", "weave_query/frontend/build.sh"], check=True) else: print("!!! 
Skipping frontend build !!!") diff --git a/dev_docs/DEVELOPMENT.md b/weave_query/dev_docs/DEVELOPMENT.md similarity index 100% rename from dev_docs/DEVELOPMENT.md rename to weave_query/dev_docs/DEVELOPMENT.md diff --git a/scripts/auth_modes.sh b/weave_query/scripts/auth_modes.sh similarity index 100% rename from scripts/auth_modes.sh rename to weave_query/scripts/auth_modes.sh diff --git a/scripts/bisect_script.sh b/weave_query/scripts/bisect_script.sh similarity index 100% rename from scripts/bisect_script.sh rename to weave_query/scripts/bisect_script.sh diff --git a/weave_query/scripts/build_frontend.sh b/weave_query/scripts/build_frontend.sh new file mode 100755 index 00000000000..c6ac77e94b7 --- /dev/null +++ b/weave_query/scripts/build_frontend.sh @@ -0,0 +1,2 @@ +cd weave_query/frontend +sh build.sh diff --git a/scripts/build_test_container.sh b/weave_query/scripts/build_test_container.sh similarity index 100% rename from scripts/build_test_container.sh rename to weave_query/scripts/build_test_container.sh diff --git a/scripts/dd_weave_server.sh b/weave_query/scripts/dd_weave_server.sh similarity index 100% rename from scripts/dd_weave_server.sh rename to weave_query/scripts/dd_weave_server.sh diff --git a/scripts/dd_weave_server_replay.sh b/weave_query/scripts/dd_weave_server_replay.sh similarity index 83% rename from scripts/dd_weave_server_replay.sh rename to weave_query/scripts/dd_weave_server_replay.sh index 590a0b38bd8..04cfcf20db4 100755 --- a/scripts/dd_weave_server_replay.sh +++ b/weave_query/scripts/dd_weave_server_replay.sh @@ -5,6 +5,6 @@ DD_SERVICE="weave-python" DD_ENV="dev-$(whoami)" DD_LOGS_INJECTION=true WEAVE_SERVER_ENABLE_LOGGING=true -FLASK_APP=weave.weave_server +FLASK_APP=weave_query.weave_server ddtrace-run flask run --port 9994 diff --git a/weave_query/scripts/wandb_artifact_perf.py b/weave_query/scripts/wandb_artifact_perf.py index 00bf4bd402e..d457a5e1ce3 100644 --- a/weave_query/scripts/wandb_artifact_perf.py +++ 
b/weave_query/scripts/wandb_artifact_perf.py @@ -7,7 +7,7 @@ # # Try setting http.TRACE=True when you run this. # -# Run from repo root with: `python -m weave.test_scripts.wandb_artifact_perf` +# Run from repo root with: `python -m weave_query.scripts.wandb_artifact_perf` import asyncio import cProfile diff --git a/scripts/weave_server.sh b/weave_query/scripts/weave_server.sh similarity index 100% rename from scripts/weave_server.sh rename to weave_query/scripts/weave_server.sh diff --git a/scripts/weave_server_replay.sh b/weave_query/scripts/weave_server_replay.sh similarity index 100% rename from scripts/weave_server_replay.sh rename to weave_query/scripts/weave_server_replay.sh diff --git a/scripts/weave_server_test.sh b/weave_query/scripts/weave_server_test.sh similarity index 100% rename from scripts/weave_server_test.sh rename to weave_query/scripts/weave_server_test.sh diff --git a/setup.py b/weave_query/setup.py similarity index 88% rename from setup.py rename to weave_query/setup.py index 7028f55aa59..dab321b6d30 100644 --- a/setup.py +++ b/weave_query/setup.py @@ -14,9 +14,7 @@ ROOT = Path(__file__).resolve().parent SKIP_BUILD = os.environ.get("WEAVE_SKIP_BUILD", False) -IS_BUILT = ( - ROOT / "weave_query" / "weave_query" / "frontend" / "assets" -).is_dir() or SKIP_BUILD +IS_BUILT = (ROOT / "weave_query" / "frontend" / "assets").is_dir() or SKIP_BUILD FORCE_BUILD = os.environ.get("WEAVE_FORCE_BUILD", False) @@ -42,7 +40,7 @@ def check_build_deps() -> bool: def build_frontend() -> None: check_build_deps() try: - build_script = str(Path("weave_query", "weave_query", "frontend", "build.sh")) + build_script = str(Path("weave_query", "frontend", "build.sh")) subprocess.run(["bash", build_script], cwd=ROOT) except OSError: raise RuntimeError("Failed to build frontend.") @@ -73,16 +71,10 @@ def download_and_extract_tarball( def download_frontend() -> None: - sha = ( - open(ROOT / "weave_query" / "weave_query" / "frontend" / "sha1.txt") - .read() - .strip() - ) + 
sha = open(ROOT / "weave_query" / "frontend" / "sha1.txt").read().strip() url = f"https://storage.googleapis.com/wandb-cdn-prod/weave/{sha}.tar.gz" try: - download_and_extract_tarball( - url, extract_path=ROOT / "weave_query" / "weave_query" - ) + download_and_extract_tarball(url, extract_path=ROOT / "weave_query") except HTTPError: print(f"Warning: Failed to download frontend for sha {sha}") diff --git a/weave_query/weave_query/frontend/build.sh b/weave_query/weave_query/frontend/build.sh index 64c7ccb5a40..a8ceee55c1e 100755 --- a/weave_query/weave_query/frontend/build.sh +++ b/weave_query/weave_query/frontend/build.sh @@ -1,12 +1,12 @@ set -e SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd ) -JS_DIR=$SCRIPT_DIR/../../../weave-js +JS_DIR=$SCRIPT_DIR/../../weave-js SHA1=$(find $JS_DIR -not -path "*/.vite-cache/*" -not -path "*/node_modules/*" -not -path "*/build/*" -type f -print0 | sort -z | xargs -0 sha1sum | cut -d " " -f1 | sha1sum | cut -d " " -f1) yarn --cwd=$JS_DIR install --frozen-lockfile -yarn --cwd=$SCRIPT_DIR/../../../weave-js build +yarn --cwd=$SCRIPT_DIR/../../weave-js build cd $SCRIPT_DIR rm -rf assets index.html -cp -R ../../../weave-js/build/* . +cp -R ../../weave-js/build/* . echo $SHA1 > sha1.txt diff --git a/weave_query/weave_query/frontend/index.html b/weave_query/weave_query/frontend/index.html index 85df1c14d5d..e18ff570f1a 100644 --- a/weave_query/weave_query/frontend/index.html +++ b/weave_query/weave_query/frontend/index.html @@ -91,7 +91,7 @@ - + diff --git a/weave_query/weave_query/frontend/sha1.txt b/weave_query/weave_query/frontend/sha1.txt index 031526369d5..8b52feec8d9 100644 --- a/weave_query/weave_query/frontend/sha1.txt +++ b/weave_query/weave_query/frontend/sha1.txt @@ -1 +1 @@ -31abf52bfbaf506ae62e1644dd7edbea2d011eea +1205c2b2e2e9a0d5bd7f536297c3178ea755302a