diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 43ab9649de6..78ea91c9759 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -15,7 +15,8 @@ repos: rev: "v1.10.0" hooks: - id: mypy - additional_dependencies: [types-all, wandb>=0.15.5] + additional_dependencies: + [types-pkg-resources==0.1.3, types-all, wandb>=0.15.5] # You have to exclude in 3 places. 1) here. 2) mypi.ini exclude, 3) follow_imports = skip for each module in mypy.ini exclude: (.*pyi$)|(weave/legacy)|(weave/tests) # Turn pyright back off, duplicative of mypy diff --git a/docs/CONTRIBUTING_DOCS.md b/docs/CONTRIBUTING_DOCS.md new file mode 100644 index 00000000000..9fdc1e5e203 --- /dev/null +++ b/docs/CONTRIBUTING_DOCS.md @@ -0,0 +1,79 @@ +# Contributing to Weave Documentation + +## Guidelines + +- Ensure tone and style is consistent with existing documentation. +- Ensure that the `sidebar.ts` file is updated if adding new pages + +## Installation + +Satisfy the following dependencies to create, build, and locally serve Weave Docs on your local machine: + +- (Recommended) Install [`nvm`](https://github.com/nvm-sh/nvm) to manage your node.js versions. +- Install [Node.js](https://nodejs.org/en/download/) version 18.0.0. + ```node + nvm install 18.0.0 + ``` +- Install Yarn. It is recommended to install Yarn through the [npm package manager](http://npmjs.org/), which comes bundled with [Node.js](https://nodejs.org/) when you install it on your system. + ```yarn + npm install --global yarn + ``` +- Install an IDE (e.g. VS Code) or Text Editor (e.g. Sublime) + +  + +Build and run the docs locally to test that all edits, links etc are working. After you have forked and cloned wandb/weave: + +``` +cd docs + +yarn install +``` + +Then test that you can build and run the docs locally: + +``` +yarn start +``` + +This will return the port number where you can preview your changes to the docs. + +## How to edit the docs locally + +1. Navigate to your local GitHub repo of `weave` and pull the latest changes from master: + +```bash +cd docs +git pull origin main +``` + +2. Create a feature branch off of `main`. + +```bash +git checkout -b +``` + +3. In a new terminal, start a local preview of the docs with `yarn start`. + +```bash +yarn start +``` + +This will return the port number where you can preview your changes to the docs. + +4. Make your changes on the new branch. +5. Check your changes are rendered correctly. + +6. Commit the changes to the branch. + +```bash +git commit -m 'chore(docs): Useful commit message.' +``` + +7. Push the branch to GitHub. + +```bash +git push origin +``` + +8. Open a pull request from the new branch to the original repo. diff --git a/docs/docs/guides/tracking/ops.md b/docs/docs/guides/tracking/ops.md index 5b39e6b9668..e1a69064d62 100644 --- a/docs/docs/guides/tracking/ops.md +++ b/docs/docs/guides/tracking/ops.md @@ -20,7 +20,7 @@ weave.init('intro-example') track_me(15) ``` -Calling an op will created a new op version if the code has changed from the last call, and log the inputs and outputs of the function. +Calling an op will create a new op version if the code has changed from the last call, and log the inputs and outputs of the function. :::note Functions decorated with `@weave.op()` will behave normally (without code versioning and tracking), if you don't call `weave.init('your-project-name')` before calling them. diff --git a/docs/docs/tutorial-tracing_2.md b/docs/docs/tutorial-tracing_2.md index 3fa9e19bd7e..33af0044d80 100644 --- a/docs/docs/tutorial-tracing_2.md +++ b/docs/docs/tutorial-tracing_2.md @@ -110,4 +110,4 @@ To track system attributes, such as a System Prompt, we recommend using [weave M ## What's next? -- Follow the [Build an Evaluation pipeline tutorial](/tutorial-eval) to start iteratively improving your applications. +- Follow the [App Versioning tutorial](/tutorial-weave_models) to capture, version and organize ad-hoc prompt, model, and application changes. diff --git a/docs/docs/tutorial-weave_models.md b/docs/docs/tutorial-weave_models.md new file mode 100644 index 00000000000..43fe4495805 --- /dev/null +++ b/docs/docs/tutorial-weave_models.md @@ -0,0 +1,136 @@ +--- +sidebar_position: 1 +hide_table_of_contents: true +--- + +# App versioning + +Tracking the [inputs, outputs, metadata](/quickstart) as well as [data flowing through your app](/tutorial-tracing_2) is critical to understanding the performance of your system. However **versioning your app over time** is also critical to understand how modifications to your code or app attributes change your outputs. Weave's `Model` class is how these changes can be tracked in Weave. + + +In this tutorial you'll learn: + +- How to use Weave `Model` to track and version your app and its attributes. +- How to export, modify and re-use a Weave `Model` already logged. + +## Using `weave.Model` + +Using Weave `Model`s means that attributes such as model vendor ids, prompts, temperature, and more are stored and versioned when they change. + +To create a `Model` in Weave, you need the following: + +- a class that inherits from `weave.Model` +- type definitions on all class attributes +- a typed `invoke` function with the `@weave.op()` decorator + +When you change the class attributes or the code that defines your model, **these changes will be logged and the version will be updated**. This ensures that you can compare the generations across different versions of your app. + +In the example below, the **model name, temperature and system prompt will be tracked and versioned**: + +```python +import json +from openai import OpenAI + +import weave + +@weave.op() +def extract_dinos(wmodel: weave.Model, sentence: str) -> dict: + response = wmodel.client.chat.completions.create( + model=wmodel.model_name, + temperature=wmodel.temperature, + messages=[ + { + "role": "system", + "content": wmodel.system_prompt + }, + { + "role": "user", + "content": sentence + } + ], + response_format={ "type": "json_object" } + ) + return response.choices[0].message.content + +# Sub-class with a weave.Model +# highlight-next-line +class ExtractDinos(weave.Model): + client: OpenAI = None + model_name: str + temperature: float + system_prompt: str + + # Ensure your function is called `invoke` or `predict` + # highlight-next-line + @weave.op() + # highlight-next-line + def invoke(self, sentence: str) -> dict: + dino_data = extract_dinos(self, sentence) + return json.loads(dino_data) +``` + +Now you can instantiate and call the model with `invoke`: + +```python +weave.init('jurassic-park') +client = OpenAI() + +system_prompt = """Extract any dinosaur `name`, their `common_name`, \ +names and whether its `diet` is a herbivore or carnivore, in JSON format.""" + +# highlight-next-line +dinos = ExtractDinos( + client=client, + model_name='gpt-4o', + temperature=0.4, + system_prompt=system_prompt +) + +sentence = """I watched as a Tyrannosaurus rex (T. rex) chased after a Triceratops (Trike), \ +both carnivore and herbivore locked in an ancient dance. Meanwhile, a gentle giant \ +Brachiosaurus (Brachi) calmly munched on treetops, blissfully unaware of the chaos below.""" + +# highlight-next-line +result = dinos.invoke(sentence) +print(result) +``` + +Now after calling `.invoke` you can see the trace in Weave **now tracks the model attributes as well as the code** for the model functions that have been decorated with `weave.op()`. You can see the model is also versioned, "v21" in this case, and if you click on the model **you can see all of the calls** that have used that version of the model + +![Re-using a weave model](../static/img/tutorial-model_invoke3.png) + +**A note on using `weave.Model`:** +- You can use `predict` instead of `invoke` for the name of the function in your Weave `Model` if you prefer. +- If you want other class methods to be tracked by weave they need to be wrapped in `weave.op()` +- Attributes starting with an underscore are ignored by weave and won't be logged + +## Exporting and re-using a logged `weave.Model` +Because Weave stores and versions Models that have been invoked, it is possible to export and re-use these models. + +**Get the Model ref** +In the Weave UI you can get the Model ref for a particular version + + +**Using the Model** +Once you have the URI of the Model object, you can export and re-use it. Note that the exported model is already initialised and ready to use: + +```python +# the exported weave model is already initialised and ready to be called +# highlight-next-line +new_dinos = weave.ref("weave:///morgan/jurassic-park/object/ExtractDinos:ey4udBU2MU23heQFJenkVxLBX4bmDsFk7vsGcOWPjY4").get() + +# set the client to the openai client again +new_dinos.client = client + +new_sentence = """I also saw a Ankylosaurus grazing on giant ferns""" +new_result = new_dinos.invoke(new_sentence) +print(new_result) +``` + +Here you can now see the name Model version (v21) was used with the new input: + +![Re-using a weave model](../static/img/tutorial-model_re-use.png) + +## What's next? + +- Follow the [Build an Evaluation pipeline tutorial](/tutorial-eval) to start iteratively improving your applications. diff --git a/docs/sidebars.ts b/docs/sidebars.ts index eabe083c750..a9d6a627665 100644 --- a/docs/sidebars.ts +++ b/docs/sidebars.ts @@ -21,7 +21,14 @@ const sidebars: SidebarsConfig = { { type: "category", label: "Getting Started", - items: ["introduction", "quickstart", "tutorial-tracing_2", "tutorial-eval", "tutorial-rag"], + items: [ + "introduction", + "quickstart", + "tutorial-tracing_2", + "tutorial-weave_models", + "tutorial-eval", + "tutorial-rag", + ], }, { type: "category", diff --git a/docs/static/img/tutorial-model_invoke3.png b/docs/static/img/tutorial-model_invoke3.png new file mode 100644 index 00000000000..c504f59799a Binary files /dev/null and b/docs/static/img/tutorial-model_invoke3.png differ diff --git a/docs/static/img/tutorial-model_re-use.png b/docs/static/img/tutorial-model_re-use.png new file mode 100644 index 00000000000..5f2ae04c84d Binary files /dev/null and b/docs/static/img/tutorial-model_re-use.png differ diff --git a/weave-js/src/assets/icons/icon-crop-beginning.svg b/weave-js/src/assets/icons/icon-crop-beginning.svg new file mode 100644 index 00000000000..c6476a2aea9 --- /dev/null +++ b/weave-js/src/assets/icons/icon-crop-beginning.svg @@ -0,0 +1,5 @@ + + + + + diff --git a/weave-js/src/assets/icons/icon-crop-end.svg b/weave-js/src/assets/icons/icon-crop-end.svg new file mode 100644 index 00000000000..70a93219cbf --- /dev/null +++ b/weave-js/src/assets/icons/icon-crop-end.svg @@ -0,0 +1,5 @@ + + + + + diff --git a/weave-js/src/assets/icons/icon-crop-middle.svg b/weave-js/src/assets/icons/icon-crop-middle.svg new file mode 100644 index 00000000000..e8f0010b5b8 --- /dev/null +++ b/weave-js/src/assets/icons/icon-crop-middle.svg @@ -0,0 +1,10 @@ + + + + + + + + + + diff --git a/weave-js/src/components/Form/TextField.tsx b/weave-js/src/components/Form/TextField.tsx index ae31aff4cfa..80b134180d2 100644 --- a/weave-js/src/components/Form/TextField.tsx +++ b/weave-js/src/components/Form/TextField.tsx @@ -57,6 +57,7 @@ export const TextField = ({ dataTest, }: TextFieldProps) => { const textFieldSize = size ?? 'medium'; + const leftPaddingForIcon = textFieldSize === 'medium' ? 'pl-34' : 'pl-36'; const handleChange = onChange ? (e: React.ChangeEvent) => { @@ -78,34 +79,42 @@ export const TextField = ({
-
- {prefix &&
{prefix}
} +
+ {prefix && ( +
+ {prefix} +
+ )} )} diff --git a/weave-js/src/components/Icon/Icon.tsx b/weave-js/src/components/Icon/Icon.tsx index ac4ffaeeae5..f475839d4d0 100644 --- a/weave-js/src/components/Icon/Icon.tsx +++ b/weave-js/src/components/Icon/Icon.tsx @@ -41,6 +41,9 @@ import {ReactComponent as ImportContentWide} from '../../assets/icons/icon-conte import {ReactComponent as ImportContractLeft} from '../../assets/icons/icon-contract-left.svg'; import {ReactComponent as ImportCopy} from '../../assets/icons/icon-copy.svg'; import {ReactComponent as ImportCreditCardPayment} from '../../assets/icons/icon-credit-card-payment.svg'; +import {ReactComponent as ImportCropBeginning} from '../../assets/icons/icon-crop-beginning.svg'; +import {ReactComponent as ImportCropEnd} from '../../assets/icons/icon-crop-end.svg'; +import {ReactComponent as ImportCropMiddle} from '../../assets/icons/icon-crop-middle.svg'; import {ReactComponent as ImportCross} from '../../assets/icons/icon-cross.svg'; import {ReactComponent as ImportCrownPro} from '../../assets/icons/icon-crown-pro.svg'; import {ReactComponent as ImportCubeContainer} from '../../assets/icons/icon-cube-container.svg'; @@ -373,6 +376,15 @@ export const IconCopy = (props: SVGIconProps) => ( export const IconCreditCardPayment = (props: SVGIconProps) => ( ); +export const IconCropBeginning = (props: SVGIconProps) => ( + +); +export const IconCropEnd = (props: SVGIconProps) => ( + +); +export const IconCropMiddle = (props: SVGIconProps) => ( + +); export const IconCross = (props: SVGIconProps) => ( ); @@ -1001,6 +1013,9 @@ const ICON_NAME_TO_ICON: Record = { 'contract-left': IconContractLeft, copy: IconCopy, 'credit-card-payment': IconCreditCardPayment, + 'crop-beginning': IconCropBeginning, + 'crop-end': IconCropEnd, + 'crop-middle': IconCropMiddle, cross: IconCross, 'crown-pro': IconCrownPro, 'cube-container': IconCubeContainer, diff --git a/weave-js/src/components/Icon/index.ts b/weave-js/src/components/Icon/index.ts index 37475489f64..337e32c1658 100644 --- a/weave-js/src/components/Icon/index.ts +++ b/weave-js/src/components/Icon/index.ts @@ -41,6 +41,9 @@ export { IconContractLeft, IconCopy, IconCreditCardPayment, + IconCropBeginning, + IconCropEnd, + IconCropMiddle, IconCross, IconCrownPro, IconCubeContainer, diff --git a/weave-js/src/components/Icon/types.ts b/weave-js/src/components/Icon/types.ts index 387baa6b6b2..d7f0907c5cb 100644 --- a/weave-js/src/components/Icon/types.ts +++ b/weave-js/src/components/Icon/types.ts @@ -40,6 +40,9 @@ export const IconNames = { ContractLeft: 'contract-left', Copy: 'copy', CreditCardPayment: 'credit-card-payment', + CropBeginning: 'crop-beginning', + CropEnd: 'crop-end', + CropMiddle: 'crop-middle', Cross: 'cross', CrownPro: 'crown-pro', CubeContainer: 'cube-container', diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3.tsx index 16ad224278e..a25808ba512 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3.tsx @@ -12,6 +12,7 @@ import { import { GridColumnVisibilityModel, GridPaginationModel, + GridPinnedColumns, GridSortModel, } from '@mui/x-data-grid-pro'; import {LicenseInfo} from '@mui/x-license-pro'; @@ -61,13 +62,16 @@ import { DEFAULT_PAGE_SIZE, getValidPaginationModel, } from './Browse3/grid/pagination'; +import {getValidPinModel, removeAlwaysLeft} from './Browse3/grid/pin'; import {getValidSortModel} from './Browse3/grid/sort'; import {BoardPage} from './Browse3/pages/BoardPage'; import {BoardsPage} from './Browse3/pages/BoardsPage'; import {CallPage} from './Browse3/pages/CallPage/CallPage'; import {CallsPage} from './Browse3/pages/CallsPage/CallsPage'; import { + ALWAYS_PIN_LEFT_CALLS, DEFAULT_COLUMN_VISIBILITY_CALLS, + DEFAULT_PIN_CALLS, DEFAULT_SORT_CALLS, } from './Browse3/pages/CallsPage/CallsTable'; import {Empty} from './Browse3/pages/common/Empty'; @@ -696,6 +700,19 @@ const CallsPageBinding = () => { history.push({search: newQuery.toString()}); }; + const pinModel = useMemo( + () => getValidPinModel(query.pin, DEFAULT_PIN_CALLS, ALWAYS_PIN_LEFT_CALLS), + [query.pin] + ); + const setPinModel = (newModel: GridPinnedColumns) => { + const newQuery = new URLSearchParams(location.search); + newQuery.set( + 'pin', + JSON.stringify(removeAlwaysLeft(newModel, ALWAYS_PIN_LEFT_CALLS)) + ); + history.push({search: newQuery.toString()}); + }; + const sortModel = useMemo( () => getValidSortModel(query.sort, DEFAULT_SORT_CALLS), [query.sort] @@ -739,6 +756,8 @@ const CallsPageBinding = () => { onFilterUpdate={onFilterUpdate} columnVisibilityModel={columnVisibilityModel} setColumnVisibilityModel={setColumnVisibilityModel} + pinModel={pinModel} + setPinModel={setPinModel} sortModel={sortModel} setSortModel={setSortModel} paginationModel={paginationModel} diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/feedback/FeedbackGrid.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/feedback/FeedbackGrid.tsx index 6db389e289b..aa86af6062d 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/feedback/FeedbackGrid.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/feedback/FeedbackGrid.tsx @@ -1,11 +1,13 @@ import {Box} from '@mui/material'; import _ from 'lodash'; -import React from 'react'; +import React, {useEffect} from 'react'; +import {useViewerInfo} from '../../../../../common/hooks/useViewerInfo'; import {Alert} from '../../../../Alert'; import {Loading} from '../../../../Loading'; import {Tailwind} from '../../../../Tailwind'; import {useWFHooks} from '../pages/wfReactInterface/context'; +import {useGetTraceServerClientContext} from '../pages/wfReactInterface/traceServerClientContext'; import {FeedbackGridInner} from './FeedbackGridInner'; type FeedbackGridProps = { @@ -21,6 +23,8 @@ export const FeedbackGrid = ({ weaveRef, objectType, }: FeedbackGridProps) => { + const {loading: loadingUserInfo, userInfo} = useViewerInfo(); + const {useFeedback} = useWFHooks(); const query = useFeedback({ entity, @@ -28,7 +32,13 @@ export const FeedbackGrid = ({ weaveRef, }); - if (query.loading) { + const getTsClient = useGetTraceServerClientContext(); + useEffect(() => { + return getTsClient().registerOnFeedbackListener(weaveRef, query.refetch); + // eslint-disable-next-line react-hooks/exhaustive-deps + }, []); + + if (query.loading || loadingUserInfo) { return ( {paths.map(path => { return (
{path &&
On {path}
} - +
); })} diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/feedback/FeedbackGridActions.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/feedback/FeedbackGridActions.tsx new file mode 100644 index 00000000000..030ec71005c --- /dev/null +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/feedback/FeedbackGridActions.tsx @@ -0,0 +1,162 @@ +import { + Dialog, + DialogActions as MaterialDialogActions, + DialogContent as MaterialDialogContent, + DialogTitle as MaterialDialogTitle, +} from '@material-ui/core'; +import React, {useEffect, useState} from 'react'; +import styled from 'styled-components'; + +import {Button} from '../../../../Button'; +import {useGetTraceServerClientContext} from '../pages/wfReactInterface/traceServerClientContext'; + +type FeedbackGridActionsProps = { + projectId: string; + feedbackId: string; +}; + +export const FeedbackGridActions = ({ + projectId, + feedbackId, +}: FeedbackGridActionsProps) => { + const [confirmDelete, setConfirmDelete] = useState(false); + + return ( + <> + + + + + ); +}; diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/feedback/FeedbackGridInner.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/feedback/FeedbackGridInner.tsx index 76c13ae2f07..4bf6505bb23 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/feedback/FeedbackGridInner.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/feedback/FeedbackGridInner.tsx @@ -7,13 +7,18 @@ import {CellValueString} from '../../Browse2/CellValueString'; import {CopyableId} from '../pages/common/Id'; import {Feedback} from '../pages/wfReactInterface/traceServerClientTypes'; import {StyledDataGrid} from '../StyledDataGrid'; +import {FeedbackGridActions} from './FeedbackGridActions'; import {FeedbackTypeChip} from './FeedbackTypeChip'; type FeedbackGridInnerProps = { feedback: Feedback[]; + currentViewerId: string | null; }; -export const FeedbackGridInner = ({feedback}: FeedbackGridInnerProps) => { +export const FeedbackGridInner = ({ + feedback, + currentViewerId, +}: FeedbackGridInnerProps) => { const columns: GridColDef[] = [ { field: 'feedback_type', @@ -71,6 +76,26 @@ export const FeedbackGridInner = ({feedback}: FeedbackGridInnerProps) => { return ; }, }, + { + field: 'actions', + headerName: '', + width: 50, + filterable: false, + sortable: false, + resizable: false, + disableColumnMenu: true, + renderCell: params => { + const projectId = params.row.project_id; + const feedbackId = params.row.id; + const creatorId = params.row.wb_user_id; + if (!currentViewerId || creatorId !== currentViewerId) { + return null; + } + return ( + + ); + }, + }, ]; const rows = feedback; return ( diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/feedback/Reactions.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/feedback/Reactions.tsx index 8435d88ddb7..f39c80a1b4e 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/feedback/Reactions.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/feedback/Reactions.tsx @@ -2,8 +2,12 @@ import React, {useEffect, useState} from 'react'; import {useViewerInfo} from '../../../../../common/hooks/useViewerInfo'; import {parseRef} from '../../../../../react'; +import {useWFHooks} from '../pages/wfReactInterface/context'; import {useGetTraceServerClientContext} from '../pages/wfReactInterface/traceServerClientContext'; -import {Feedback} from '../pages/wfReactInterface/traceServerClientTypes'; +import { + Feedback, + SortBy, +} from '../pages/wfReactInterface/traceServerClientTypes'; import {ReactionsLoaded} from './ReactionsLoaded'; type ReactionsProps = { @@ -17,6 +21,8 @@ type ReactionsProps = { twWrapperStyles?: React.CSSProperties; }; +const SORT_BY: SortBy[] = [{field: 'created_at', direction: 'asc'}]; + export const Reactions = ({ weaveRef, readonly = false, @@ -33,32 +39,25 @@ export const Reactions = ({ const {entityName: entity, projectName: project} = parsedRef; const projectId = `${entity}/${project}`; + const {useFeedback} = useWFHooks(); + const query = useFeedback( + { + entity, + project, + weaveRef, + }, + SORT_BY + ); const getTsClient = useGetTraceServerClientContext(); - useEffect(() => { - let mounted = true; - getTsClient() - .feedbackQuery({ - project_id: projectId, - query: { - $expr: { - $eq: [{$getField: 'weave_ref'}, {$literal: weaveRef}], - }, - }, - sort_by: [{field: 'created_at', direction: 'asc'}], - }) - .then(res => { - if (!mounted) { - return; - } - if ('result' in res) { - setFeedback(res.result); - } - }); - return () => { - mounted = false; - }; - }, [getTsClient, projectId, weaveRef]); + return getTsClient().registerOnFeedbackListener(weaveRef, query.refetch); + // eslint-disable-next-line react-hooks/exhaustive-deps + }, []); + useEffect(() => { + if (query.result) { + setFeedback(query.result); + } + }, [query.result]); const onAddEmoji = (emoji: string) => { const req = { @@ -68,25 +67,7 @@ export const Reactions = ({ feedback_type: 'wandb.reaction.1', payload: {emoji}, }; - getTsClient() - .feedbackCreate(req) - .then(res => { - if (feedback === null) { - return; - } - if ('detail' in res) { - return; - } - const newReaction = { - ...req, - id: res.id, - created_at: res.created_at, - wb_user_id: res.wb_user_id, - payload: res.payload, - }; - const newFeedback = [...feedback, newReaction]; - setFeedback(newFeedback); - }); + getTsClient().feedbackCreate(req); }; const onAddNote = (note: string) => { const req = { @@ -96,43 +77,18 @@ export const Reactions = ({ feedback_type: 'wandb.note.1', payload: {note}, }; - getTsClient() - .feedbackCreate(req) - .then(res => { - if (feedback === null) { - return; - } - if ('detail' in res) { - return; - } - const newReaction = { - ...req, - id: res.id, - created_at: res.created_at, - wb_user_id: res.wb_user_id, - }; - const newFeedback = [...feedback, newReaction]; - setFeedback(newFeedback); - }); + getTsClient().feedbackCreate(req); }; const onRemoveFeedback = (id: string) => { - getTsClient() - .feedbackPurge({ - project_id: projectId, - query: { - $expr: { - $eq: [{$getField: 'id'}, {$literal: id}], - }, + getTsClient().feedbackPurge({ + project_id: projectId, + query: { + $expr: { + $eq: [{$getField: 'id'}, {$literal: id}], }, - }) - .then(res => { - if (!feedback) { - return; - } - const newFeedback = feedback.filter(f => f.id !== id); - setFeedback(newFeedback); - }); + }, + }); }; if (loadingUserInfo || feedback === null) { diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/grid/pin.test.ts b/weave-js/src/components/PagePanelComponents/Home/Browse3/grid/pin.test.ts new file mode 100644 index 00000000000..fa4868a6b0f --- /dev/null +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/grid/pin.test.ts @@ -0,0 +1,77 @@ +import {getValidPinModel, removeAlwaysLeft} from './pin'; + +describe('removeAlwaysLeft', () => { + it('removes an alwaysLeft item from left', () => { + const result = removeAlwaysLeft({left: ['checkbox', 'foo']}, ['checkbox']); + expect(result).toEqual({ + left: ['foo'], + }); + }); + it('removes multiple alwaysLeft items from left', () => { + const result = removeAlwaysLeft({left: ['checkbox', 'foo', 'bar']}, [ + 'bar', + 'checkbox', + ]); + expect(result).toEqual({ + left: ['foo'], + }); + }); + it('does not change when cols are not in left', () => { + const result = removeAlwaysLeft({left: ['checkbox', 'foo']}, ['bar']); + expect(result).toEqual({ + left: ['checkbox', 'foo'], + }); + }); + it('does not change when no left', () => { + const result = removeAlwaysLeft({right: ['checkbox', 'foo']}, ['checkbox']); + expect(result).toEqual({ + right: ['checkbox', 'foo'], + }); + }); +}); + +describe('getValidPinModel', () => { + it('parses a valid pin model', () => { + const parsed = getValidPinModel( + '{"left": ["CustomCheckbox", "op_name", "feedback"]}' + ); + expect(parsed).toEqual({ + left: ['CustomCheckbox', 'op_name', 'feedback'], + }); + }); + it('includes alwaysLeft items when left is specified', () => { + const parsed = getValidPinModel('{"left": ["foo"]}', null, ['checkbox']); + expect(parsed).toEqual({ + left: ['checkbox', 'foo'], + }); + }); + it('includes alwaysLeft items when left is not specified', () => { + const parsed = getValidPinModel('{}', null, ['checkbox']); + expect(parsed).toEqual({ + left: ['checkbox'], + }); + }); + it('moves alwaysLeft items to front', () => { + const parsed = getValidPinModel('{"left": ["foo", "checkbox"]}', null, [ + 'checkbox', + ]); + expect(parsed).toEqual({ + left: ['checkbox', 'foo'], + }); + }); + + it('returns {} on non-object with no explicit default', () => { + const parsed = getValidPinModel('[]'); + expect(parsed).toEqual({}); + }); + it('returns {} on invalid pin value with no explicit default', () => { + const parsed = getValidPinModel('{"lef": ["foo"]}'); + expect(parsed).toEqual({}); + }); + it('returns default on non-object', () => { + const parsed = getValidPinModel('[]', {left: ['checkbox']}); + expect(parsed).toEqual({ + left: ['checkbox'], + }); + }); +}); diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/grid/pin.ts b/weave-js/src/components/PagePanelComponents/Home/Browse3/grid/pin.ts new file mode 100644 index 00000000000..8338214af2a --- /dev/null +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/grid/pin.ts @@ -0,0 +1,63 @@ +import {GridPinnedColumns} from '@mui/x-data-grid-pro'; +import _ from 'lodash'; + +const isValidPinValue = (value: any): boolean => { + return _.isArray(value) && value.every(v => _.isString(v)); +}; + +// Columns that are always pinned left don't need to be present in serialized state. +export const removeAlwaysLeft = ( + pinModel: GridPinnedColumns, + alwaysLeft: string[] +): GridPinnedColumns => { + if (!pinModel.left) { + return pinModel; + } + const {left} = pinModel; + return { + ...pinModel, + left: left.filter(col => !alwaysLeft.includes(col)), + }; +}; + +// Ensure specified columns are always pinned left. +const ensureAlwaysLeft = ( + pinModel: GridPinnedColumns, + alwaysLeft: string[] +): GridPinnedColumns => { + let left = pinModel.left ?? []; + left = left.filter(col => !alwaysLeft.includes(col)); + left = alwaysLeft.concat(left); + return { + ...pinModel, + left, + }; +}; + +export const getValidPinModel = ( + jsonString: string, + def: GridPinnedColumns | null = null, + alwaysLeft?: string[] +): GridPinnedColumns => { + def = def ?? {}; + try { + const parsed = JSON.parse(jsonString); + if (_.isPlainObject(parsed)) { + const keys = Object.keys(parsed); + if ( + keys.every( + key => ['left', 'right'].includes(key) && isValidPinValue(parsed[key]) + ) + ) { + const pinModel = parsed as GridPinnedColumns; + if (alwaysLeft) { + return ensureAlwaysLeft(pinModel, alwaysLeft); + } + return pinModel; + } + } + } catch (e) { + // Ignore + } + return def; +}; diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallPage/ObjectViewerSection.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallPage/ObjectViewerSection.tsx index e2c0a71e46b..d5176339465 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallPage/ObjectViewerSection.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallPage/ObjectViewerSection.tsx @@ -1,7 +1,13 @@ import Box from '@mui/material/Box'; import {GridRowId, useGridApiRef} from '@mui/x-data-grid-pro'; import _ from 'lodash'; -import React, {useCallback, useContext, useMemo, useState} from 'react'; +import React, { + useCallback, + useContext, + useEffect, + useMemo, + useState, +} from 'react'; import styled from 'styled-components'; import {isWeaveObjectRef, parseRef} from '../../../../../../react'; @@ -86,9 +92,7 @@ const ObjectViewerSectionNonEmpty = ({ isExpanded, }: ObjectViewerSectionProps) => { const apiRef = useGridApiRef(); - const [mode, setMode] = useState( - isSimpleData(data) || isExpanded ? 'expanded' : 'collapsed' - ); + const [mode, setMode] = useState('collapsed'); const [expandedIds, setExpandedIds] = useState([]); const body = useMemo(() => { @@ -153,6 +157,18 @@ const ObjectViewerSectionNonEmpty = ({ setExpandedIds(getGroupIds()); }; + // On first render and when data changes, recompute expansion state + useEffect(() => { + const isSimple = isSimpleData(data); + const newMode = isSimple || isExpanded ? 'expanded' : 'collapsed'; + if (newMode === 'expanded') { + onClickExpanded(); + } else { + onClickCollapsed(); + } + // eslint-disable-next-line react-hooks/exhaustive-deps + }, [data, isExpanded]); + return ( <> diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallsPage/CallsPage.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallsPage/CallsPage.tsx index 48533c5c72a..86e4cd1e805 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallsPage/CallsPage.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallsPage/CallsPage.tsx @@ -1,6 +1,7 @@ import { GridColumnVisibilityModel, GridPaginationModel, + GridPinnedColumns, GridSortModel, } from '@mui/x-data-grid-pro'; import _ from 'lodash'; @@ -37,6 +38,9 @@ export const CallsPage: FC<{ columnVisibilityModel: GridColumnVisibilityModel; setColumnVisibilityModel: (newModel: GridColumnVisibilityModel) => void; + pinModel: GridPinnedColumns; + setPinModel: (newModel: GridPinnedColumns) => void; + sortModel: GridSortModel; setSortModel: (newModel: GridSortModel) => void; @@ -88,6 +92,8 @@ export const CallsPage: FC<{ onFilterUpdate={setFilter} columnVisibilityModel={props.columnVisibilityModel} setColumnVisibilityModel={props.setColumnVisibilityModel} + pinModel={props.pinModel} + setPinModel={props.setPinModel} sortModel={props.sortModel} setSortModel={props.setSortModel} paginationModel={props.paginationModel} diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallsPage/CallsTable.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallsPage/CallsTable.tsx index 46114a6a2c2..4a0a295475e 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallsPage/CallsTable.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallsPage/CallsTable.tsx @@ -47,7 +47,7 @@ import { } from '../../context'; import {DEFAULT_PAGE_SIZE} from '../../grid/pagination'; import {StyledPaper} from '../../StyledAutocomplete'; -import {SELECTED_FOR_DELETION, StyledDataGrid} from '../../StyledDataGrid'; +import {StyledDataGrid} from '../../StyledDataGrid'; import {StyledTextField} from '../../StyledTextField'; import {ConfirmDeleteModal} from '../CallPage/OverflowMenu'; import {Empty} from '../common/Empty'; @@ -81,7 +81,7 @@ import {ManageColumnsButton} from './ManageColumnsButton'; const OP_FILTER_GROUP_HEADER = 'Op'; const MAX_EVAL_COMPARISONS = 5; -const MAX_BULK_DELETE = 10; +const MAX_SELECT = 100; export const DEFAULT_COLUMN_VISIBILITY_CALLS = { 'attributes.weave.client_version': false, @@ -92,6 +92,12 @@ export const DEFAULT_COLUMN_VISIBILITY_CALLS = { 'attributes.weave.sys_version': false, }; +export const ALWAYS_PIN_LEFT_CALLS = ['CustomCheckbox']; + +export const DEFAULT_PIN_CALLS: GridPinnedColumns = { + left: ['CustomCheckbox', 'op_name', 'feedback'], +}; + export const DEFAULT_SORT_CALLS: GridSortModel = [ {field: 'started_at', sort: 'desc'}, ]; @@ -114,6 +120,9 @@ export const CallsTable: FC<{ columnVisibilityModel?: GridColumnVisibilityModel; setColumnVisibilityModel?: (newModel: GridColumnVisibilityModel) => void; + pinModel?: GridPinnedColumns; + setPinModel?: (newModel: GridPinnedColumns) => void; + sortModel?: GridSortModel; setSortModel?: (newModel: GridSortModel) => void; @@ -128,6 +137,8 @@ export const CallsTable: FC<{ hideControls, columnVisibilityModel, setColumnVisibilityModel, + pinModel, + setPinModel, sortModel, setSortModel, paginationModel, @@ -330,10 +341,7 @@ export const CallsTable: FC<{ ); // DataGrid Model Management - const [pinnedColumnsModel, setPinnedColumnsModel] = - useState({ - left: ['CustomCheckbox', 'op_name', 'feedback'], - }); + const pinModelResolved = pinModel ?? DEFAULT_PIN_CALLS; // END OF CPR FACTORED CODE @@ -380,18 +388,17 @@ export const CallsTable: FC<{ project ); - const [bulkDeleteMode, setBulkDeleteMode] = useState(false); - // Selection Management const [selectedCalls, setSelectedCalls] = useState([]); const muiColumns = useMemo(() => { - return [ + const cols = [ { minWidth: 30, width: 38, field: 'CustomCheckbox', sortable: false, disableColumnMenu: true, + resizable: false, renderHeader: (params: any) => { return ( { - // if bulk delete move, or not eval table, select all calls - if (bulkDeleteMode || !isEvaluateTable) { - if ( - selectedCalls.length === - Math.min(tableData.length, MAX_BULK_DELETE) - ) { - setSelectedCalls([]); - } else { - setSelectedCalls( - tableData.map(row => row.id).slice(0, MAX_BULK_DELETE) - ); - } + const maxForTable = isEvaluateTable + ? MAX_EVAL_COMPARISONS + : MAX_SELECT; + if ( + selectedCalls.length === + Math.min(tableData.length, maxForTable) + ) { + setSelectedCalls([]); } else { - // exclude non-successful calls from selection - const filtered = tableData.filter( - row => row.exception == null && row.ended_at != null + setSelectedCalls( + tableData.map(row => row.id).slice(0, maxForTable) ); - if ( - selectedCalls.length === - Math.min(filtered.length, MAX_EVAL_COMPARISONS) - ) { - setSelectedCalls([]); - } else { - setSelectedCalls( - filtered.map(row => row.id).slice(0, MAX_EVAL_COMPARISONS) - ); - } } }} /> @@ -438,30 +430,18 @@ export const CallsTable: FC<{ renderCell: (params: any) => { const rowId = params.id as string; const isSelected = selectedCalls.includes(rowId); - const disabledDueToMax = - selectedCalls.length >= MAX_EVAL_COMPARISONS && !isSelected; - const disabledDueToNonSuccess = - params.row.exception != null || params.row.ended_at == null; + const disabled = + !isSelected && + (isEvaluateTable + ? selectedCalls.length >= MAX_EVAL_COMPARISONS + : selectedCalls.length >= MAX_SELECT); let tooltipText = ''; - if (bulkDeleteMode || !isEvaluateTable) { - if (selectedCalls.length >= MAX_BULK_DELETE) { - tooltipText = `Deletion limited to ${MAX_BULK_DELETE} items`; - } else { - tooltipText = ''; - } - } else { - if (disabledDueToNonSuccess) { - tooltipText = 'Cannot compare non-successful evaluations'; - } else if (disabledDueToMax) { + if (isEvaluateTable) { + if (selectedCalls.length >= MAX_EVAL_COMPARISONS && !isSelected) { tooltipText = `Comparison limited to ${MAX_EVAL_COMPARISONS} evaluations`; } - } - - let disabled = false; - if ((bulkDeleteMode || !isEvaluateTable) && !isSelected) { - disabled = selectedCalls.length >= MAX_BULK_DELETE; - } else if (isEvaluateTable) { - disabled = disabledDueToNonSuccess || disabledDueToMax; + } else if (selectedCalls.length >= MAX_SELECT && !isSelected) { + tooltipText = `Selection limited to ${MAX_SELECT} items`; } return ( @@ -490,7 +470,8 @@ export const CallsTable: FC<{ }, ...columns.cols, ]; - }, [columns.cols, selectedCalls, tableData, bulkDeleteMode, isEvaluateTable]); + return cols; + }, [columns.cols, selectedCalls, tableData, isEvaluateTable]); // Register Compare Evaluations Button const history = useHistory(); @@ -507,10 +488,7 @@ export const CallsTable: FC<{ router.compareEvaluationsUri(entity, project, selectedCalls) ); }} - disabled={selectedCalls.length === 0 || bulkDeleteMode} - tooltipText={ - bulkDeleteMode ? 'Cannot compare while bulk deleting' : undefined - } + disabled={selectedCalls.length === 0} /> ), order: 1, @@ -529,7 +507,6 @@ export const CallsTable: FC<{ entity, project, history, - bulkDeleteMode, ]); // Register Delete Button @@ -541,29 +518,15 @@ export const CallsTable: FC<{ addExtra('deleteSelectedCalls', { node: ( setDeleteConfirmModalOpen(true)} + onClick={() => setDeleteConfirmModalOpen(true)} disabled={selectedCalls.length === 0} - bulkDeleteModeToggle={mode => { - setBulkDeleteMode(mode); - if (!mode) { - setSelectedCalls([]); - } - }} - selectedCalls={selectedCalls} /> ), order: 3, }); return () => removeExtra('deleteSelectedCalls'); - }, [ - addExtra, - removeExtra, - selectedCalls, - isEvaluateTable, - bulkDeleteMode, - isReadonly, - ]); + }, [addExtra, removeExtra, selectedCalls, isEvaluateTable, isReadonly]); useEffect(() => { if (isReadonly) { @@ -603,6 +566,16 @@ export const CallsTable: FC<{ } : undefined; + const onPinnedColumnsChange = useCallback( + (newModel: GridPinnedColumns) => { + if (!setPinModel || callsLoading) { + return; + } + setPinModel(newModel); + }, + [callsLoading, setPinModel] + ); + const onSortModelChange = useCallback( (newModel: GridSortModel) => { if (!setSortModel || callsLoading) { @@ -785,11 +758,6 @@ export const CallsTable: FC<{ // columnGroupingModel={groupingModel} columnGroupingModel={columns.colGroupingModel} hideFooterSelectedRowCount - getRowClassName={params => - bulkDeleteMode && selectedCalls.includes(params.row.id) - ? SELECTED_FOR_DELETION - : '' - } onColumnWidthChange={newCol => { setUserDefinedColumnWidths(curr => { return { @@ -798,8 +766,8 @@ export const CallsTable: FC<{ }; }); }} - pinnedColumns={pinnedColumnsModel} - onPinnedColumnsChange={newModel => setPinnedColumnsModel(newModel)} + pinnedColumns={pinModelResolved} + onPinnedColumnsChange={onPinnedColumnsChange} sx={{ borderRadius: 0, }} @@ -938,7 +906,7 @@ const ExportRunsTableButton = ({ alignItems: 'center', }}>
); const BulkDeleteButton: FC<{ disabled?: boolean; - selectedCalls: string[]; - onConfirm: () => void; - bulkDeleteModeToggle: (mode: boolean) => void; -}> = ({disabled, selectedCalls, onConfirm, bulkDeleteModeToggle}) => { - const [deleteClicked, setDeleteClicked] = useState(false); + onClick: () => void; +}> = ({disabled, onClick}) => { return ( - {deleteClicked ? ( - <> - - - - ) : selectedCalls.length > 0 ? ( - + }> + + No summary information found for{' '} + {maybePluralizeWord(invalidEvals.length, 'evaluation')}:{' '} + {invalidEvals.join(', ')}. + + + + + ); +}; diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/ecpUtil.ts b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/ecpUtil.ts index 19dd436d8a5..601116db34d 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/ecpUtil.ts +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/ecpUtil.ts @@ -18,6 +18,8 @@ import { SourceType, } from './ecpTypes'; +export const EVALUATION_NAME_DEFAULT = 'Evaluation'; + export const flattenedDimensionPath = (dim: MetricDefinition): string => { const paths = [...dim.metricSubPath]; if (dim.source === 'scorer') { diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/wfReactInterface/traceServerClient.ts b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/wfReactInterface/traceServerClient.ts index 193bd0fd30f..af997016a4a 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/wfReactInterface/traceServerClient.ts +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/wfReactInterface/traceServerClient.ts @@ -1,6 +1,10 @@ import _ from 'lodash'; import { + FeedbackCreateReq, + FeedbackCreateRes, + FeedbackPurgeReq, + FeedbackPurgeRes, TraceCallsDeleteReq, TraceCallUpdateReq, TraceRefsReadBatchReq, @@ -19,6 +23,7 @@ export class TraceServerClient extends DirectTraceServerClient { }> = []; private onDeleteListeners: Array<() => void>; private onRenameListeners: Array<() => void>; + private onFeedbackListeners: Record void>>; constructor(baseUrl: string) { super(baseUrl); @@ -26,6 +31,7 @@ export class TraceServerClient extends DirectTraceServerClient { this.scheduleReadBatch(); this.onDeleteListeners = []; this.onRenameListeners = []; + this.onFeedbackListeners = {}; } /** @@ -52,6 +58,25 @@ export class TraceServerClient extends DirectTraceServerClient { ); }; } + public registerOnFeedbackListener( + weaveRef: string, + callback: () => void + ): () => void { + if (!(weaveRef in this.onFeedbackListeners)) { + this.onFeedbackListeners[weaveRef] = []; + } + this.onFeedbackListeners[weaveRef].push(callback); + return () => { + const newListeners = this.onFeedbackListeners[weaveRef].filter( + listener => listener !== callback + ); + if (newListeners.length) { + this.onFeedbackListeners[weaveRef] = newListeners; + } else { + delete this.onFeedbackListeners[weaveRef]; + } + }; + } public callsDelete(req: TraceCallsDeleteReq): Promise { const res = super.callsDelete(req).then(() => { @@ -67,6 +92,27 @@ export class TraceServerClient extends DirectTraceServerClient { return res; } + public feedbackCreate(req: FeedbackCreateReq): Promise { + const res = super.feedbackCreate(req).then(createRes => { + const listeners = this.onFeedbackListeners[req.weave_ref] ?? []; + listeners.forEach(listener => listener()); + return createRes; + }); + return res; + } + public feedbackPurge(req: FeedbackPurgeReq): Promise { + const res = super.feedbackPurge(req).then(purgeRes => { + // TODO: Since purge takes a query, we need to change the result to include + // information about the refs that were modified. + // For now, just call all registered feedback listeners. + for (const listeners of Object.values(this.onFeedbackListeners)) { + listeners.forEach(listener => listener()); + } + return purgeRes; + }); + return res; + } + public readBatch(req: TraceRefsReadBatchReq): Promise { return this.requestReadBatch(req); } diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/wfReactInterface/tsDataModelHooks.ts b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/wfReactInterface/tsDataModelHooks.ts index 917de4c83a1..f4c6a8d9873 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/wfReactInterface/tsDataModelHooks.ts +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/wfReactInterface/tsDataModelHooks.ts @@ -42,6 +42,7 @@ import { OpVersionKey, OpVersionSchema, RawSpanFromStreamTableEra, + Refetchable, RefMutation, TableQuery, WFDataModelHooksInterface, @@ -512,8 +513,9 @@ const useCallUpdateFunc = () => { }; const useFeedback = ( - key: FeedbackKey | null -): LoadableWithError => { + key: FeedbackKey | null, + sortBy?: traceServerTypes.SortBy[] +): LoadableWithError & Refetchable => { const getTsClient = useGetTraceServerClientContext(); const [result, setResult] = useState< @@ -523,10 +525,18 @@ const useFeedback = ( result: null, error: null, }); + const [doReload, setDoReload] = useState(false); + const refetch = useCallback(() => { + setDoReload(true); + }, [setDoReload]); const deepKey = useDeepMemo(key); useEffect(() => { + let mounted = true; + if (doReload) { + setDoReload(false); + } if (!deepKey) { return; } @@ -542,20 +552,29 @@ const useFeedback = ( $eq: [{$getField: 'weave_ref'}, {$literal: deepKey.weaveRef}], }, }, - sort_by: [{field: 'created_at', direction: 'desc'}], + sort_by: sortBy ?? [{field: 'created_at', direction: 'desc'}], }) .then(res => { + if (!mounted) { + return; + } if ('result' in res) { setResult({loading: false, result: res.result, error: null}); } // TODO: handle error case }) .catch(err => { + if (!mounted) { + return; + } setResult({loading: false, result: null, error: err}); }); - }, [deepKey, getTsClient]); + return () => { + mounted = false; + }; + }, [deepKey, getTsClient, doReload, sortBy]); - return result; + return {...result, refetch}; }; const useOpVersion = ( diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/wfReactInterface/tsDataModelHooksEvaluationComparison.ts b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/wfReactInterface/tsDataModelHooksEvaluationComparison.ts index 2f10edfd968..60f2025b47a 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/wfReactInterface/tsDataModelHooksEvaluationComparison.ts +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/wfReactInterface/tsDataModelHooksEvaluationComparison.ts @@ -81,7 +81,10 @@ import { EvaluationComparisonData, MetricDefinition, } from '../CompareEvaluationsPage/ecpTypes'; -import {metricDefinitionId} from '../CompareEvaluationsPage/ecpUtil'; +import { + EVALUATION_NAME_DEFAULT, + metricDefinitionId, +} from '../CompareEvaluationsPage/ecpUtil'; import {getScoreKeyNameFromScorerRef} from '../CompareEvaluationsPage/ecpUtil'; import {TraceServerClient} from '../wfReactInterface/traceServerClient'; import {useGetTraceServerClientContext} from '../wfReactInterface/traceServerClientContext'; @@ -183,8 +186,7 @@ const fetchEvaluationComparisonData = async ( call.id, { callId: call.id, - // TODO: Get user-defined name for the evaluation - name: 'Evaluation', + name: call.display_name ?? EVALUATION_NAME_DEFAULT, color: pickColor(ndx), evaluationRef: call.inputs.self, modelRef: call.inputs.model, @@ -226,11 +228,15 @@ const fetchEvaluationComparisonData = async ( if (evalObj == null) { return; } + const output = evaluationCallCache[evalCall.callId].output; + if (output == null) { + return; + } // Add the user-defined scores evalObj.scorerRefs.forEach(scorerRef => { const scorerKey = getScoreKeyNameFromScorerRef(scorerRef); - const score = evaluationCallCache[evalCall.callId].output[scorerKey]; + const score = output[scorerKey]; const recursiveAddScore = (scoreVal: any, currPath: string[]) => { if (isBinarySummaryScore(scoreVal)) { const metricDimension: MetricDefinition = { @@ -304,15 +310,13 @@ const fetchEvaluationComparisonData = async ( // Add the derived metrics // Model latency - const model_latency = - evaluationCallCache[evalCall.callId].output.model_latency; - if (model_latency != null) { + if (output.model_latency != null) { const metricId = metricDefinitionId(modelLatencyMetricDimension); result.summaryMetrics[metricId] = { ...modelLatencyMetricDimension, }; evalCall.summaryMetrics[metricId] = { - value: model_latency.mean, + value: output.model_latency.mean, sourceCallId: evalCallId, }; result.scoreMetrics[metricId] = { @@ -323,11 +327,10 @@ const fetchEvaluationComparisonData = async ( // Total Tokens // TODO: This "mean" is incorrect - really should average across all model // calls since this includes LLM scorers - const totalTokens = sum( - Object.values( - evaluationCallCache[evalCall.callId].summary.usage ?? {} - ).map(v => v.total_tokens) - ); + const summary = evaluationCallCache[evalCall.callId].summary; + const totalTokens = summary + ? sum(Object.values(summary.usage ?? {}).map(v => v.total_tokens)) + : null; if (totalTokens != null) { const metricId = metricDefinitionId(totalTokensMetricDimension); result.summaryMetrics[metricId] = { diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/wfReactInterface/wfDataModelHooksInterface.ts b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/wfReactInterface/wfDataModelHooksInterface.ts index c226e3af04e..f61c9e2c6db 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/wfReactInterface/wfDataModelHooksInterface.ts +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/wfReactInterface/wfDataModelHooksInterface.ts @@ -156,6 +156,10 @@ export type FeedbackKey = { weaveRef: string; }; +export type Refetchable = { + refetch: () => void; +}; + export type WFDataModelHooksInterface = { useCall: (key: CallKey | null) => Loadable; useCalls: ( @@ -221,7 +225,10 @@ export type WFDataModelHooksInterface = { digest: string, opts?: {skip?: boolean} ) => Loadable; - useFeedback: (key: FeedbackKey | null) => LoadableWithError; + useFeedback: ( + key: FeedbackKey | null, + sortBy?: traceServerClientTypes.SortBy[] + ) => LoadableWithError & Refetchable; derived: { useChildCallsForCompare: ( entity: string, diff --git a/weave-js/src/components/UserLink.tsx b/weave-js/src/components/UserLink.tsx index 5e2d9b6436c..b2ea4414ac2 100644 --- a/weave-js/src/components/UserLink.tsx +++ b/weave-js/src/components/UserLink.tsx @@ -243,14 +243,24 @@ export const useUsers = (userIds: string[]) => { const [users, setUsers] = useState('load'); useEffect(() => { + let mounted = true; setUsers('loading'); fetchUsers(memoedUserIds) .then(userRes => { + if (!mounted) { + return; + } setUsers(userRes); }) .catch(err => { + if (!mounted) { + return; + } setUsers('error'); }); + return () => { + mounted = false; + }; }, [memoedUserIds]); return users; diff --git a/weave/flow/scorer.py b/weave/flow/scorer.py index 09d184d749c..92be745e212 100644 --- a/weave/flow/scorer.py +++ b/weave/flow/scorer.py @@ -44,6 +44,10 @@ def auto_summarize(data: list) -> Optional[dict[str, Any]]: if not data: return {} data = [x for x in data if x is not None] + + if not data: + return None + val = data[0] if isinstance(val, bool): @@ -55,8 +59,13 @@ def auto_summarize(data: list) -> Optional[dict[str, Any]]: return {"mean": np.mean(data).item()} elif isinstance(val, dict): result = {} - for k in val: - if (summary := auto_summarize([x[k] for x in data])) is not None: + all_keys = set().union(*[x.keys() for x in data if isinstance(x, dict)]) + for k in all_keys: + if ( + summary := auto_summarize( + [x.get(k) for x in data if isinstance(x, dict)] + ) + ) is not None: if k in summary: result.update(summary) else: diff --git a/weave/frontend/index.html b/weave/frontend/index.html index 93710ee3208..7d0662afea1 100644 --- a/weave/frontend/index.html +++ b/weave/frontend/index.html @@ -91,8 +91,8 @@ - - + + diff --git a/weave/frontend/sha1.txt b/weave/frontend/sha1.txt index 42ce2d0b3f6..8977b1f3826 100644 --- a/weave/frontend/sha1.txt +++ b/weave/frontend/sha1.txt @@ -1 +1 @@ -3092c45b6a1cf06977f4510864e61b268fae92b1 +9ffbbb13b8ce1e58bf851e179dff32aa74db3de9 diff --git a/weave/integrations/langchain/langchain.py b/weave/integrations/langchain/langchain.py index f6995850d84..003eb8170f4 100644 --- a/weave/integrations/langchain/langchain.py +++ b/weave/integrations/langchain/langchain.py @@ -17,7 +17,7 @@ 3. Respecting User Settings: The patcher respects any existing WEAVE_TRACE_LANGCHAIN environment variable set by the user: - - If not set, it's set to "true" and global patchin is enabled. + - If not set, it's set to "true" and global patching is enabled. - If already set, its value is preserved 4. Context Manager: @@ -60,7 +60,7 @@ def make_valid_run_name(name: str) -> str: name = name.replace("<", "_").replace(">", "") - valid_run_name = re.sub(r"[^a-zA-Z0-9 .-_]", "_", name) + valid_run_name = re.sub(r"[^a-zA-Z0-9 .\\-_]", "_", name) return valid_run_name def _run_to_dict(run: Run, as_input: bool = False) -> dict: diff --git a/weave/tests/test_evaluations.py b/weave/tests/test_evaluations.py index e98e6af380f..09909d23558 100644 --- a/weave/tests/test_evaluations.py +++ b/weave/tests/test_evaluations.py @@ -586,3 +586,40 @@ def function_model(sentence: str) -> dict: # # 2: Assert that the model was correctly oped # assert shouldBeModelRef.startswith("weave:///") + + +@pytest.mark.asyncio +async def test_eval_is_robust_to_missing_values(client): + # At least 1 None + # All dicts have "d": None + resp = [ + None, + {"a": 1, "b": {"c": 2}, "always_none": None}, + {"a": 2, "b": {"c": None}, "always_none": None}, + {"a": 3, "b": {}, "always_none": None}, + {"a": 4, "b": None, "always_none": None}, + {"a": 5, "always_none": None}, + {"a": None, "always_none": None}, + {"always_none": None}, + {}, + ] + + @weave.op + def model_func(model_res) -> dict: + return resp[model_res] + + def function_score(scorer_res, model_output) -> dict: + return resp[scorer_res] + + evaluation = weave.Evaluation( + name="fruit_eval", + dataset=[{"model_res": i, "scorer_res": i} for i in range(len(resp))], + scorers=[function_score], + ) + + res = await evaluation.evaluate(model_func) + assert res == { + "model_output": {"a": {"mean": 3.0}, "b": {"c": {"mean": 2.0}}}, + "function_score": {"a": {"mean": 3.0}, "b": {"c": {"mean": 2.0}}}, + "model_latency": {"mean": pytest.approx(0, abs=1)}, + } diff --git a/weave/trace_server/trace_server_interface.py b/weave/trace_server/trace_server_interface.py index 1c0d9ecf1c0..afdf8a9b7a0 100644 --- a/weave/trace_server/trace_server_interface.py +++ b/weave/trace_server/trace_server_interface.py @@ -191,6 +191,10 @@ class CallsQueryReq(BaseModel): sort_by: typing.Optional[typing.List[_SortBy]] = None query: typing.Optional[Query] = None + # TODO: type this with call schema columns, following the same rules as + # _SortBy and thus GetFieldOperator.get_field_ (without direction) + columns: typing.Optional[typing.List[str]] = None + class CallsQueryRes(BaseModel): calls: typing.List[CallSchema] diff --git a/weave/version.py b/weave/version.py index 335954d9c69..c3dcdfb7e30 100644 --- a/weave/version.py +++ b/weave/version.py @@ -1 +1 @@ -VERSION = "0.50.12" +VERSION = "0.50.13"