diff --git a/weave-js/src/components/Icon/Icon.tsx b/weave-js/src/components/Icon/Icon.tsx index f8cec63146b..fce338e1391 100644 --- a/weave-js/src/components/Icon/Icon.tsx +++ b/weave-js/src/components/Icon/Icon.tsx @@ -637,6 +637,12 @@ export const IconLogOut = (props: SVGIconProps) => ( export const IconLogoColab = (props: SVGIconProps) => ( ); +export const IconMarker = (props: SVGIconProps) => ( + +); +export const IconReloadRefresh = (props: SVGIconProps) => ( + +); export const IconMagicWandStar = (props: SVGIconProps) => ( ); @@ -1083,6 +1089,7 @@ const ICON_NAME_TO_ICON: Record = { 'cube-container': IconCubeContainer, 'dashboard-blackboard': IconDashboardBlackboard, 'database-artifacts': IconDatabaseArtifacts, + 'reload-refresh': IconReloadRefresh, date: IconDate, delete: IconDelete, diamond: IconDiamond, @@ -1157,6 +1164,7 @@ const ICON_NAME_TO_ICON: Record = { 'logo-colab': IconLogoColab, 'magic-wand-star': IconMagicWandStar, 'magic-wand-stick': IconMagicWandStick, + marker: IconMarker, markdown: IconMarkdown, marker: IconMarker, menu: IconMenu, diff --git a/weave-js/src/components/Icon/types.ts b/weave-js/src/components/Icon/types.ts index e536e365157..4bb158ff7d4 100644 --- a/weave-js/src/components/Icon/types.ts +++ b/weave-js/src/components/Icon/types.ts @@ -176,6 +176,7 @@ export const IconNames = { RemoveAlt: 'remove-alt', Report: 'report', Retry: 'retry', + ReloadRefresh: 'reload-refresh', RobotServiceMember: 'robot-service-member', RocketLaunch: 'rocket-launch', RowHeightLarge: 'row-height-large', diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse2/SmallRef.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse2/SmallRef.tsx index e6f4c0955a8..288a8d2beea 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse2/SmallRef.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse2/SmallRef.tsx @@ -124,6 +124,7 @@ export const SmallRef: FC<{ }; } } + // if a wandb artifact const objectVersion = useObjectVersion(objVersionKey); const opVersion = useOpVersion(opVersionKey); const versionIndex = diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3.tsx index bee4705042c..da1cadf78b8 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3.tsx @@ -162,6 +162,58 @@ const browse3Paths = (projectRoot: string) => [ `${projectRoot}`, ]; +// Create a context for managing call IDs +export const CallIdContext = React.createContext<{ + setCallIds?: (callIds: string[]) => void; + getNextCallId?: (currentId: string) => string | null; + getPreviousCallId?: (currentId: string) => string | null; + nextPageNeeded: boolean; +}>({ + setCallIds: () => {}, + getNextCallId: () => null, + getPreviousCallId: () => null, + nextPageNeeded: false, +}); + +const CallIdProvider: FC<{children: React.ReactNode}> = ({children}) => { + const [callIds, setCallIds] = useState([]); + const [nextPageNeeded, setNextPageNeeded] = useState(false); + + const getNextCallId = useCallback( + (currentId: string) => { + const currentIndex = callIds.indexOf(currentId); + if ( + currentIndex === callIds.length - 1 && + callIds.length === DEFAULT_PAGE_SIZE + ) { + setNextPageNeeded(true); + } else if (currentIndex !== -1) { + return callIds[currentIndex + 1]; + } else if (nextPageNeeded) { + setNextPageNeeded(false); + return callIds[0]; + } + return null; + }, + [callIds, nextPageNeeded] + ); + + const getPreviousCallId = useCallback( + (currentId: string) => { + const currentIndex = callIds.indexOf(currentId); + return callIds[currentIndex - 1]; + }, + [callIds] + ); + + return ( + + {children} + + ); +}; + export const Browse3: FC<{ hideHeader?: boolean; headerOffset?: number; @@ -338,90 +390,92 @@ const MainPeekingLayout: FC = () => { - + - - - - -
+ + + + - {peekLocation && ( - - - - +
+
+ {options.map((option, index) => ( +
+ {option} +
+ ))} +
+ + ); +}; + +const createStructuredFeedback = ( + type: string, + name: string, + min?: number, + max?: number, + options?: string[] +): StructuredFeedback => { + // validate min and max + switch (type) { + case 'NumericalFeedback': + // validate min and max dont conflict + if (min && max && min > max) { + throw new Error('Min is greater than max'); + } + return {type: 'NumericalFeedback', name, min: min!, max: max!}; + case 'TextFeedback': + return {type: 'TextFeedback', name}; + case 'CategoricalFeedback': + return { + type: 'CategoricalFeedback', + name, + options: options!, + multiSelect: false, + addNewOption: false, + }; + case 'BooleanFeedback': + return {type: 'BooleanFeedback', name}; + case 'EmojiFeedback': + return {type: 'EmojiFeedback', name}; + default: + throw new Error('Invalid feedback type'); + } +}; + +const FeedbackTypeSelector = ({ + selectedFeedbackType, + setSelectedFeedbackType, + feedbackTypeOptions, + readOnly, +}: { + selectedFeedbackType: string; + setSelectedFeedbackType: (value: string) => void; + feedbackTypeOptions: Array<{name: string; value: string}>; + readOnly?: boolean; +}) => { + return ( +
+ Metric type + option.name} + onChange={(e, newValue) => + setSelectedFeedbackType(newValue?.value ?? '') + } + value={feedbackTypeOptions.find( + option => option.value === selectedFeedbackType + )} + renderInput={params => ( + + )} + disableClearable + sx={{ + minWidth: '244px', + width: 'auto', + }} + fullWidth + ListboxProps={{ + style: { + maxHeight: '200px', + }, + }} + disabled={readOnly} + renderOption={(props, option) => ( +
  • + {option.name ||  } +
  • + )} + /> +
    + ); +}; + +const submitStructuredFeedback = ( + entity: string, + project: string, + newFeedback: StructuredFeedback, + existingFeedbackColumns: StructuredFeedback[], + editColumnName: string | null, + getTsClient: () => any, + onClose: () => void +) => { + const tsClient = getTsClient(); + let updatedTypes: StructuredFeedback[]; + + if (editColumnName) { + updatedTypes = existingFeedbackColumns.map(t => + t.name === editColumnName ? newFeedback : t + ); + } else { + updatedTypes = [...existingFeedbackColumns, newFeedback]; + } + + const value: StructuredFeedbackSpec = { + _bases: ['StructuredFeedback', 'Object', 'BaseModel'], + _class_name: 'StructuredFeedback', + types: updatedTypes, + }; + + const req = { + obj: { + project_id: `${entity}/${project}`, + object_id: 'StructuredFeedback-obj', + val: value, + }, + }; + + tsClient + .objCreate(req) + .then(() => { + onClose(); + }) + .catch((e: any) => { + console.error( + `Error ${editColumnName ? 'updating' : 'creating'} structured feedback`, + e + ); + }); +}; + +const CreateStructuredFeedbackModal = ({ + entity, + project, + existingFeedbackColumns, + onClose, +}: { + entity: string; + project: string; + existingFeedbackColumns: StructuredFeedback[]; + onClose: () => void; +}) => { + const [open, setOpen] = useState(true); + const [nameField, setNameField] = useState(''); + const [selectedFeedbackType, setSelectedFeedbackType] = + useState('Numerical feedback'); + const [minValue, setMinValue] = useState(undefined); + const [maxValue, setMaxValue] = useState(undefined); + const [categoricalOptions, setCategoricalOptions] = useState([]); + const getTsClient = useGetTraceServerClientContext(); + + const submit = () => { + const option = FEEDBACK_TYPE_OPTIONS.find( + o => o.value === selectedFeedbackType + ); + if (!option) { + console.error( + `Invalid feedback type: ${selectedFeedbackType}, options: ${FEEDBACK_TYPE_OPTIONS.map( + o => o.value + ).join(', ')}` + ); + return; + } + let newFeedback; + try { + newFeedback = createStructuredFeedback( + option.value, + nameField, + minValue, + maxValue, + categoricalOptions + ); + } catch (e) { + console.error(e); + return; + } + submitStructuredFeedback( + entity, + project, + newFeedback, + existingFeedbackColumns, + null, + getTsClient, + onClose + ); + }; + + return ( + { + setOpen(false); + onClose(); + }} + maxWidth="xs" + fullWidth> + + Add feedback column + +
    + Metric name + setNameField(value)} + placeholder="..." + /> +
    + + {selectedFeedbackType === 'NumericalFeedback' && ( + + )} + {selectedFeedbackType === 'CategoricalFeedback' && ( + + )} +
    + + + +
    +
    + ); +}; + +const EditStructuredFeedbackModal = ({ + entity, + project, + feedbackColumn, + structuredFeedbackData, + onClose, +}: { + entity: string; + project: string; + feedbackColumn: StructuredFeedback; + structuredFeedbackData: StructuredFeedbackSpec; + onClose: () => void; +}) => { + const [open, setOpen] = useState(true); + const [nameField, setNameField] = useState(feedbackColumn.name); + const [selectedFeedbackType, setSelectedFeedbackType] = useState( + feedbackColumn.type + ); + const [minValue, setMinValue] = useState( + 'min' in feedbackColumn ? feedbackColumn.min : undefined + ); + const [maxValue, setMaxValue] = useState( + 'max' in feedbackColumn ? feedbackColumn.max : undefined + ); + const [categoricalOptions, setCategoricalOptions] = useState( + 'options' in feedbackColumn ? feedbackColumn.options : [] + ); + + const getTsClient = useGetTraceServerClientContext(); + + const submit = () => { + let updatedFeedbackColumn; + try { + updatedFeedbackColumn = createStructuredFeedback( + selectedFeedbackType, + nameField, + minValue, + maxValue, + categoricalOptions + ); + } catch (e) { + console.error(e); + return; + } + submitStructuredFeedback( + entity, + project, + updatedFeedbackColumn, + structuredFeedbackData.types, + feedbackColumn.name, + getTsClient, + onClose + ); + }; + + return ( + { + setOpen(false); + onClose(); + }} + maxWidth="xs" + fullWidth> + + Edit structured feedback + +
    + Metric name + setNameField(value)} + placeholder="..." + /> +
    + + {selectedFeedbackType === 'NumericalFeedback' && ( + + )} + {selectedFeedbackType === 'CategoricalFeedback' && ( + + )} +
    + + + +
    +
    + ); +}; + +export const ConfigureStructuredFeedbackModal = ({ + entity, + project, + structuredFeedbackData, + editColumnName, + onClose, +}: { + entity: string; + project: string; + structuredFeedbackData?: StructuredFeedbackSpec; + editColumnName?: string; + onClose: () => void; +}) => { + if (editColumnName && structuredFeedbackData) { + const feedbackColumn = structuredFeedbackData?.types.find( + t => t.name === editColumnName + ); + if (!feedbackColumn) { + console.error(`Feedback column not found: ${editColumnName}`); + return null; + } + return ( + + ); + } else { + return ( + + ); + } +}; + +const DialogContent = styled(MaterialDialogContent)` + padding: 0 32px !important; +`; +DialogContent.displayName = 'S.DialogContent'; + +const DialogTitle = styled(MaterialDialogTitle)` + padding: 32px 32px 16px 32px !important; + + h2 { + font-weight: 600; + font-size: 24px; + line-height: 30px; + } +`; +DialogTitle.displayName = 'S.DialogTitle'; + +const DialogActions = styled(MaterialDialogActions)<{$align: string}>` + justify-content: ${({$align}) => + $align === 'left' ? 'flex-start' : 'flex-end'} !important; + padding: 32px 32px 32px 32px !important; +`; +DialogActions.displayName = 'S.DialogActions'; diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/feedback/StructuredFeedback/StructuredFeedback.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/feedback/StructuredFeedback/StructuredFeedback.tsx new file mode 100644 index 00000000000..cdb9403772f --- /dev/null +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/feedback/StructuredFeedback/StructuredFeedback.tsx @@ -0,0 +1,506 @@ +import {Checkbox} from '@mui/material'; +import {Autocomplete, TextField as MuiTextField} from '@mui/material'; +import {MOON_300} from '@wandb/weave/common/css/color.styles'; +import {TextField} from '@wandb/weave/components/Form/TextField'; +import {LoadingDots} from '@wandb/weave/components/LoadingDots'; +import {Tailwind} from '@wandb/weave/components/Tailwind'; +import debounce from 'lodash/debounce'; +import React, { + SyntheticEvent, + useCallback, + useEffect, + useMemo, + useState, +} from 'react'; + +import {useWFHooks} from '../../pages/wfReactInterface/context'; +import {useGetTraceServerClientContext} from '../../pages/wfReactInterface/traceServerClientContext'; +import { + FeedbackCreateError, + FeedbackCreateReq, + FeedbackCreateRes, + FeedbackCreateSuccess, + FeedbackReplaceReq, + FeedbackReplaceRes, +} from '../../pages/wfReactInterface/traceServerClientTypes'; + +// Constants +const STRUCTURED_FEEDBACK_TYPE = 'wandb.structuredFeedback.1'; +const FEEDBACK_TYPES = { + NUMERICAL: 'NumericalFeedback', + TEXT: 'TextFeedback', + CATEGORICAL: 'CategoricalFeedback', + BOOLEAN: 'BooleanFeedback', +}; +const DEBOUNCE_VAL = 150; + +// Interfaces +interface StructuredFeedbackProps { + sfData: any; + callRef: string; + entity: string; + project: string; + readOnly?: boolean; + focused?: boolean; +} + +// Utility function for creating feedback request +const createFeedbackRequest = ( + props: StructuredFeedbackProps, + value: any, + currentFeedbackId: string | null +) => { + const baseRequest = { + project_id: `${props.entity}/${props.project}`, + weave_ref: props.callRef, + creator: null, + feedback_type: STRUCTURED_FEEDBACK_TYPE, + payload: { + value, + ref: props.sfData.ref, + name: props.sfData.name, + }, + sort_by: [{created_at: 'desc'}], + }; + + if (currentFeedbackId) { + return {...baseRequest, feedback_id: currentFeedbackId}; + } + + return baseRequest; +}; + +const renderFeedbackComponent = ( + props: StructuredFeedbackProps, + onAddFeedback: (value: any) => Promise, + foundValue: string | number | null, + currentFeedbackId: string | null +) => { + switch (props.sfData.type) { + case FEEDBACK_TYPES.NUMERICAL: + return ( + + ); + case FEEDBACK_TYPES.TEXT: + return ( + + ); + case FEEDBACK_TYPES.CATEGORICAL: + return ( + + ); + case FEEDBACK_TYPES.BOOLEAN: + return ( + + ); + default: + return
    Unknown feedback type
    ; + } +}; + +export const StructuredFeedbackCell: React.FC< + StructuredFeedbackProps +> = props => { + const {useFeedback} = useWFHooks(); + const query = useFeedback({ + entity: props.entity, + project: props.project, + weaveRef: props.callRef, + }); + + const [currentFeedbackId, setCurrentFeedbackId] = useState( + null + ); + const [foundValue, setFoundValue] = useState(null); + const getTsClient = useGetTraceServerClientContext(); + + useEffect(() => { + if (!props.readOnly) { + // We don't need to listen for feedback changes if the cell is editable + return; + } + return getTsClient().registerOnFeedbackListener( + props.callRef, + query.refetch + ); + // eslint-disable-next-line react-hooks/exhaustive-deps + }, []); + + useEffect(() => { + if (props.callRef !== query?.result?.[0]?.weave_ref) { + // The call was changed without the component unmounted, we need to reset + setFoundValue(null); + setCurrentFeedbackId(null); + } + }, [props.callRef, query?.result]); + + const onAddFeedback = async (value: any): Promise => { + const tsClient = getTsClient(); + + if (!tsClient) { + console.error('Failed to get trace server client'); + return false; + } + + try { + let res: FeedbackCreateRes | FeedbackReplaceRes; + + if (currentFeedbackId) { + const replaceRequest = createFeedbackRequest( + props, + value, + currentFeedbackId + ) as FeedbackReplaceReq; + res = await tsClient.feedbackReplace(replaceRequest); + } else { + const createRequest = createFeedbackRequest( + props, + value, + null + ) as FeedbackCreateReq; + res = await tsClient.feedbackCreate(createRequest); + } + + if ('detail' in res) { + const errorRes = res as FeedbackCreateError; + console.error( + `Feedback ${currentFeedbackId ? 'replace' : 'create'} failed:`, + errorRes.detail + ); + return false; + } + const successRes = res as FeedbackCreateSuccess; + + if (successRes.id) { + setCurrentFeedbackId(successRes.id); + return true; + } + + return false; + } catch (error) { + console.error(`Error in onAddFeedback:`, error); + return false; + } + }; + + useEffect(() => { + if (query?.loading) { + return; + } + + // 3 conditions must be true for the feedback to be valid for this component: + // 1. Feedback is for this feedback spec + // 2. Feedback is for this structured feedback type + // 3. Feedback is for this structured feedback name + + const feedbackTypeMatches = (feedback: any) => + feedback.feedback_type === STRUCTURED_FEEDBACK_TYPE; + const feedbackNameMatches = (feedback: any) => + feedback.payload.name === props.sfData.name; + const feedbackSpecMatches = (feedback: any) => + feedback.payload.ref === props.sfData.ref; + + const currFeedback = query.result?.find( + (feedback: any) => + feedbackTypeMatches(feedback) && + feedbackNameMatches(feedback) && + feedbackSpecMatches(feedback) + ); + if (!currFeedback) { + return; + } + + setCurrentFeedbackId(currFeedback.id); + setFoundValue(currFeedback?.payload?.value ?? null); + }, [query?.result, query?.loading, props.sfData]); + + if (query?.loading) { + return ; + } + + if (props.readOnly) { + return
    {foundValue}
    ; + } + + return ( +
    + {renderFeedbackComponent( + props, + onAddFeedback, + foundValue, + currentFeedbackId + )} +
    + ); +}; + +export const NumericalFeedbackColumn = ({ + min, + max, + onAddFeedback, + defaultValue, + currentFeedbackId, + focused, +}: { + min: number; + max: number; + onAddFeedback?: (value: number, currentFeedbackId: string | null) => Promise; + defaultValue: number | null; + currentFeedbackId?: string | null; + focused?: boolean; +}) => { + const [value, setValue] = useState( + defaultValue ?? undefined + ); + const [error, setError] = useState(false); + + useEffect(() => { + setValue(defaultValue ?? undefined); + }, [defaultValue]); + + const debouncedOnAddFeedback = useCallback( + debounce((val: number) => { + onAddFeedback?.(val, currentFeedbackId ?? null); + }, DEBOUNCE_VAL), + [onAddFeedback, currentFeedbackId] + ); + + const onValueChange = (v: string) => { + const val = parseInt(v); + setValue(val); + if (val < min || val > max) { + setError(true); + return; + } else { + setError(false); + } + debouncedOnAddFeedback(val); + }; + + return ( +
    +
    + min: {min}, max: {max} +
    + +
    + ); +}; + +export const TextFeedbackColumn = ({ + onAddFeedback, + defaultValue, + currentFeedbackId, + focused, +}: { + onAddFeedback?: ( + value: string, + currentFeedbackId: string | null + ) => Promise; + defaultValue: string | null; + currentFeedbackId?: string | null; + focused?: boolean; +}) => { + const [value, setValue] = useState(defaultValue ?? ''); + + useEffect(() => { + setValue(defaultValue ?? ''); + }, [defaultValue]); + + const debouncedOnAddFeedback = useCallback( + debounce((val: string) => { + onAddFeedback?.(val, currentFeedbackId ?? null); + }, DEBOUNCE_VAL), + [onAddFeedback, currentFeedbackId] + ); + + const onValueChange = (newValue: string) => { + setValue(newValue); + debouncedOnAddFeedback(newValue); + }; + + return ( +
    + +
    + ); +}; + +type Option = { + label: string; + value: string; +}; + +export const CategoricalFeedbackColumn = ({ + options, + onAddFeedback, + defaultValue, + currentFeedbackId, + focused, +}: { + options: string[]; + onAddFeedback?: ( + value: string, + currentFeedbackId: string | null + ) => Promise; + defaultValue: string | null; + currentFeedbackId?: string | null; + focused?: boolean; +}) => { + const dropdownOptions = useMemo(() => { + const opts = options.map((option: string) => ({ + label: option, + value: option, + })); + opts.splice(0, 0, {label: '', value: ''}); + return opts; + }, [options]); + const [value, setValue] = useState
    ); diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallsPage/ManageColumnsButton.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallsPage/ManageColumnsButton.tsx index d9946c05d36..d518ca2c762 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallsPage/ManageColumnsButton.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallsPage/ManageColumnsButton.tsx @@ -20,12 +20,14 @@ type ManageColumnsButtonProps = { columnInfo: ColumnInfo; columnVisibilityModel: GridColumnVisibilityModel; setColumnVisibilityModel: (model: GridColumnVisibilityModel) => void; + onEditColumns: (existingColumn?: string) => void; }; export const ManageColumnsButton = ({ columnInfo, columnVisibilityModel, setColumnVisibilityModel, + onEditColumns, }: ManageColumnsButtonProps) => { const [search, setSearch] = useState(''); const lowerSearch = search.toLowerCase(); @@ -83,6 +85,11 @@ export const ManageColumnsButton = ({ setColumnVisibilityModel(newModel); }; + const handleEditColumn = (columnName: string) => { + onEditColumns(columnName); + setAnchorEl(null); + }; + return ( <> @@ -121,9 +128,17 @@ export const ManageColumnsButton = ({
    Manage columns
    -
    - {maybePluralize(numHidden, 'hidden column', 's')} -
    +
    +
    @@ -141,6 +156,7 @@ export const ManageColumnsButton = ({ const checked = columnVisibilityModel[col.field] ?? true; const label = col.headerName ?? value; const disabled = !(col.hideable ?? true); + const feedbackCol = col.field.startsWith('feedback.'); if ( search && !label.toLowerCase().includes(search.toLowerCase()) @@ -172,6 +188,23 @@ export const ManageColumnsButton = ({ )}> {label} + {feedbackCol && ( + <> +
    +
    ); @@ -186,7 +219,6 @@ export const ManageColumnsButton = ({ onClick={onHideAll}> {`Hide ${buttonSuffix}`} -
    +
    +
    + {maybePluralize(numHidden, 'hidden column', 's')} +
    diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallsPage/callsTableColumns.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallsPage/callsTableColumns.tsx index 2ef528c2cf0..ac2b1a7530b 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallsPage/callsTableColumns.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallsPage/callsTableColumns.tsx @@ -19,6 +19,7 @@ import {isWeaveObjectRef, parseRef} from '../../../../../../react'; import {makeRefCall} from '../../../../../../util/refs'; import {Timestamp} from '../../../../../Timestamp'; import {Reactions} from '../../feedback/Reactions'; +import {StructuredFeedbackCell} from '../../feedback/StructuredFeedback/StructuredFeedback'; import {CellFilterWrapper, OnAddFilter} from '../../filters/CellFilterWrapper'; import {isWeaveRef} from '../../filters/common'; import { @@ -57,7 +58,8 @@ export const useCallsTableColumns = ( columnIsRefExpanded: (col: string) => boolean, allowedColumnPatterns?: string[], onAddFilter?: OnAddFilter, - costsLoading: boolean = false + costsLoading: boolean = false, + structuredFeedbackOptions: any | null = null ) => { const [userDefinedColumnWidths, setUserDefinedColumnWidths] = useState< Record @@ -134,7 +136,8 @@ export const useCallsTableColumns = ( userDefinedColumnWidths, allowedColumnPatterns, onAddFilter, - costsLoading + costsLoading, + structuredFeedbackOptions ), [ entity, @@ -152,6 +155,7 @@ export const useCallsTableColumns = ( allowedColumnPatterns, onAddFilter, costsLoading, + structuredFeedbackOptions, ] ); @@ -177,7 +181,8 @@ function buildCallsTableColumns( userDefinedColumnWidths: Record, allowedColumnPatterns?: string[], onAddFilter?: OnAddFilter, - costsLoading: boolean = false + costsLoading: boolean = false, + structuredFeedbackOptions: any | null = null ): { cols: Array>; colGroupingModel: GridColumnGroupingModel; @@ -200,6 +205,34 @@ function buildCallsTableColumns( return a.localeCompare(b); }); + const simpleFeedback = + !structuredFeedbackOptions || + structuredFeedbackOptions?.types?.length === 0; + + const structuredFeedbackColumns = ( + structuredFeedbackOptions?.types ?? [] + ).map((feedbackType: any) => ({ + field: feedbackColName(feedbackType), + headerName: feedbackType.name ?? feedbackType.type, + width: 150, + sortable: false, + filterable: false, + renderCell: (rowParams: GridRenderCellParams) => { + const callId = rowParams.row.id; + const weaveRef = makeRefCall(entity, project, callId); + + return ( + + ); + }, + })); + const cols: Array> = [ { field: 'op_name', @@ -233,30 +266,32 @@ function buildCallsTableColumns( ); }, }, - { - field: 'feedback', - headerName: 'Feedback', - width: 150, - sortable: false, - filterable: false, - renderCell: (rowParams: GridRenderCellParams) => { - const rowIndex = rowParams.api.getRowIndexRelativeToVisibleRows( - rowParams.id - ); - const callId = rowParams.row.id; - const weaveRef = makeRefCall(entity, project, callId); - return ( - - ); + { + field: 'feedback.emojis', + headerName: 'Reactions', + width: 150, + sortable: false, + filterable: false, + renderCell: (rowParams: GridRenderCellParams) => { + const rowIndex = rowParams.api.getRowIndexRelativeToVisibleRows( + rowParams.id + ); + const callId = rowParams.row.id; + const weaveRef = makeRefCall(entity, project, callId); + + return ( + + ); + }, }, - }, + ...structuredFeedbackColumns, ...(isSingleOp && !isSingleOpVersion ? [ { @@ -327,6 +362,17 @@ function buildCallsTableColumns( ); cols.push(...newCols); + const structuredFeedbackFields = structuredFeedbackOptions?.types.map((feedbackType: any) => ({ + field: feedbackColName(feedbackType), + })) ?? []; + const feedbackChildren = [...structuredFeedbackFields, {field: 'feedback.emojis'}]; + + groupingModel.push({ + headerName: 'Feedback', + groupId: 'feedback', + children: feedbackChildren, + }); + cols.push({ field: 'wb_user_id', headerName: 'User', @@ -543,3 +589,7 @@ const refIsExpandable = (ref: string): boolean => { } return false; }; + +export const feedbackColName = (feedbackType: any) => { + return 'feedback.' + (feedbackType.name ?? feedbackType.type); +}; diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ObjectVersionPage.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ObjectVersionPage.tsx index 045ceb54900..7a0e4226ba4 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ObjectVersionPage.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ObjectVersionPage.tsx @@ -1,18 +1,18 @@ import Box from '@mui/material/Box'; -import {useObjectViewEvent} from '@wandb/weave/integrations/analytics/useViewEvents'; +import { useObjectViewEvent } from '@wandb/weave/integrations/analytics/useViewEvents'; import numeral from 'numeral'; -import React, {useMemo} from 'react'; +import React, { useMemo } from 'react'; -import {maybePluralizeWord} from '../../../../../core/util/string'; -import {Icon, IconName} from '../../../../Icon'; -import {LoadingDots} from '../../../../LoadingDots'; -import {Tailwind} from '../../../../Tailwind'; -import {Tooltip} from '../../../../Tooltip'; -import {NotFoundPanel} from '../NotFoundPanel'; -import {CustomWeaveTypeProjectContext} from '../typeViews/CustomWeaveTypeDispatcher'; -import {WeaveCHTableSourceRefContext} from './CallPage/DataTableView'; -import {ObjectViewerSection} from './CallPage/ObjectViewerSection'; -import {WFHighLevelCallFilter} from './CallsPage/callsTableFilter'; +import { maybePluralizeWord } from '../../../../../core/util/string'; +import { Icon, IconName } from '../../../../Icon'; +import { LoadingDots } from '../../../../LoadingDots'; +import { Tailwind } from '../../../../Tailwind'; +import { Tooltip } from '../../../../Tooltip'; +import { NotFoundPanel } from '../NotFoundPanel'; +import { CustomWeaveTypeProjectContext } from '../typeViews/CustomWeaveTypeDispatcher'; +import { WeaveCHTableSourceRefContext } from './CallPage/DataTableView'; +import { ObjectViewerSection } from './CallPage/ObjectViewerSection'; +import { WFHighLevelCallFilter } from './CallsPage/callsTableFilter'; import { CallLink, CallsLink, @@ -20,20 +20,20 @@ import { objectVersionText, OpVersionLink, } from './common/Links'; -import {CenteredAnimatedLoader} from './common/Loader'; +import { CenteredAnimatedLoader } from './common/Loader'; import { ScrollableTabContent, SimpleKeyValueTable, SimplePageLayoutWithHeader, } from './common/SimplePageLayout'; -import {EvaluationLeaderboardTab} from './LeaderboardTab'; -import {TabPrompt} from './TabPrompt'; -import {TabUseDataset} from './TabUseDataset'; -import {TabUseModel} from './TabUseModel'; -import {TabUseObject} from './TabUseObject'; -import {TabUsePrompt} from './TabUsePrompt'; -import {KNOWN_BASE_OBJECT_CLASSES} from './wfReactInterface/constants'; -import {useWFHooks} from './wfReactInterface/context'; +import { EvaluationLeaderboardTab } from './LeaderboardTab'; +import { TabPrompt } from './TabPrompt'; +import { TabUseDataset } from './TabUseDataset'; +import { TabUseModel } from './TabUseModel'; +import { TabUseObject } from './TabUseObject'; +import { TabUsePrompt } from './TabUsePrompt'; +import { KNOWN_BASE_OBJECT_CLASSES } from './wfReactInterface/constants'; +import { useWFHooks } from './wfReactInterface/context'; import { objectVersionKeyToRefUri, refUriToOpVersionKey, @@ -52,6 +52,7 @@ const OBJECT_ICONS: Record = { Model: 'model', Dataset: 'table', Evaluation: 'benchmark-square', + StructuredFeedback: 'forum-chat-bubble', }; const ObjectIcon = ({baseObjectClass}: ObjectIconProps) => { if (baseObjectClass in OBJECT_ICONS) { diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/common/SimplePageLayout.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/common/SimplePageLayout.tsx index dfa213fde4c..bb6c709946e 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/common/SimplePageLayout.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/common/SimplePageLayout.tsx @@ -169,6 +169,8 @@ export const SimplePageLayoutWithHeader: FC<{ leftSidebar?: ReactNode; hideTabsIfSingle?: boolean; isSidebarOpen?: boolean; + isFeedbackSidebarOpen?: boolean; + feedbackSidebarContent?: ReactNode; }> = props => { const {tabs} = props; const simplePageLayoutContextValue = useContext(SimplePageLayoutContext); @@ -242,7 +244,7 @@ export const SimplePageLayoutWithHeader: FC<{ {props.headerExtra} {simplePageLayoutContextValue.headerSuffix} -
    +
    } /> + {props.isFeedbackSidebarOpen && ( + + {props.feedbackSidebarContent} + + )}
    ); diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/common/TypeVersionCategoryChip.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/common/TypeVersionCategoryChip.tsx index 7dd51250fc0..79469c840ad 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/common/TypeVersionCategoryChip.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/common/TypeVersionCategoryChip.tsx @@ -8,6 +8,7 @@ const colorMap: Record = { Model: 'blue', Dataset: 'green', Evaluation: 'cactus', + StructuredFeedback: 'moon', }; export const TypeVersionCategoryChip: React.FC<{ diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/wfReactInterface/constants.ts b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/wfReactInterface/constants.ts index 13f305a96d9..2918f2c810b 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/wfReactInterface/constants.ts +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/wfReactInterface/constants.ts @@ -24,4 +24,5 @@ export const KNOWN_BASE_OBJECT_CLASSES = [ 'Model', 'Dataset', 'Evaluation', + 'StructuredFeedback', ] as const; diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/wfReactInterface/traceServerClient.ts b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/wfReactInterface/traceServerClient.ts index af997016a4a..8b394ee51f2 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/wfReactInterface/traceServerClient.ts +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/wfReactInterface/traceServerClient.ts @@ -5,6 +5,10 @@ import { FeedbackCreateRes, FeedbackPurgeReq, FeedbackPurgeRes, + FeedbackReplaceReq, + FeedbackReplaceRes, + ObjCreateReq, + ObjCreateRes, TraceCallsDeleteReq, TraceCallUpdateReq, TraceRefsReadBatchReq, @@ -112,6 +116,18 @@ export class TraceServerClient extends DirectTraceServerClient { }); return res; } + public feedbackReplace(req: FeedbackReplaceReq): Promise { + const res = super.feedbackReplace(req).then(replaceRes => { + const listeners = this.onFeedbackListeners[req.weave_ref] ?? []; + listeners.forEach(listener => listener()); + return replaceRes; + }); + return res; + } + + public objCreate(req: ObjCreateReq): Promise { + return super.objCreate(req); + } public readBatch(req: TraceRefsReadBatchReq): Promise { return this.requestReadBatch(req); diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/wfReactInterface/traceServerClientTypes.ts b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/wfReactInterface/traceServerClientTypes.ts index 88113a37a74..6a64340726e 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/wfReactInterface/traceServerClientTypes.ts +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/wfReactInterface/traceServerClientTypes.ts @@ -190,6 +190,16 @@ export type FeedbackPurgeError = { detail: string; }; export type FeedbackPurgeRes = FeedbackPurgeSuccess | FeedbackPurgeError; + +export type FeedbackReplaceReq = FeedbackCreateReq & { + feedback_id: string; +}; +export type FeedbackReplaceSuccess = {}; +export type FeedbackReplaceError = { + detail: string; +}; +export type FeedbackReplaceRes = FeedbackCreateRes; + interface TraceObjectsFilter { base_object_classes?: string[]; object_ids?: string[]; @@ -229,6 +239,18 @@ export type TraceObjReadRes = { obj: TraceObjSchema; }; +export type ObjCreateReq = { + obj: { + project_id: string; + object_id: string; + val: any; + }; +}; + +export type ObjCreateRes = { + digest: string; +}; + export type TraceRefsReadBatchReq = { refs: string[]; }; diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/wfReactInterface/traceServerDirectClient.ts b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/wfReactInterface/traceServerDirectClient.ts index caaf63b7f56..1333783dd02 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/wfReactInterface/traceServerDirectClient.ts +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/wfReactInterface/traceServerDirectClient.ts @@ -23,6 +23,10 @@ import { FeedbackPurgeRes, FeedbackQueryReq, FeedbackQueryRes, + FeedbackReplaceReq, + FeedbackReplaceRes, + ObjCreateReq, + ObjCreateRes, TraceCallReadReq, TraceCallReadRes, TraceCallSchema, @@ -268,6 +272,17 @@ export class DirectTraceServerClient { ); } + public feedbackReplace(req: FeedbackReplaceReq): Promise { + return this.makeRequest( + '/feedback/replace', + req + ); + } + + public objCreate(req: ObjCreateReq): Promise { + return this.makeRequest('/obj/create', req); + } + public fileContent( req: TraceFileContentReadReq ): Promise { diff --git a/weave/flow/structured_feedback.py b/weave/flow/structured_feedback.py new file mode 100644 index 00000000000..644c0f46f91 --- /dev/null +++ b/weave/flow/structured_feedback.py @@ -0,0 +1,60 @@ +from typing import Optional + +from pydantic import BaseModel + +from weave.flow.obj import Object + + +class FeedbackType(Object): + name: str = None + + +class StructuredFeedback(Object): + types: list[FeedbackType] + + +class BinaryFeedback(FeedbackType): + type: str = "BinaryFeedback" + + +class NumericalFeedback(FeedbackType): + type: str = "NumericalFeedback" + + min: float + max: float + + +class TextFeedback(FeedbackType): + type: str = "TextFeedback" + + +class CategoricalFeedback(FeedbackType): + type: str = "CategoricalFeedback" + + options: list[str] + multi_select: bool = False + add_new_option: bool = False + + +if __name__ == "__main__": + import weave + + api = weave.init("griffin_wb/trace-values") + + feedback = StructuredFeedback( + types=[ + NumericalFeedback(min=0, max=5, name="Score"), + CategoricalFeedback( + name="Qualitative", options=["plain", "complex", "spicy"] + ), + BinaryFeedback(name="Viewed"), + CategoricalFeedback( + options=[], + multi_select=True, + add_new_option=True, + name="Tags", + ), + ] + ) + + weave.publish(feedback, name="StructuredFeedback obj") diff --git a/weave/trace/op_type.py b/weave/trace/op_type.py index d7b7c270674..36392a5f33b 100644 --- a/weave/trace/op_type.py +++ b/weave/trace/op_type.py @@ -292,7 +292,9 @@ def func(*args, **kwargs): def {func_name}{sig_str}: ... # Code-capture unavailable for this op """ - )[1:] # skip first newline char + )[ + 1: + ] # skip first newline char return missing_code_template diff --git a/weave/trace_server/clickhouse_trace_server_batched.py b/weave/trace_server/clickhouse_trace_server_batched.py index eb4ce265f70..15b010f259b 100644 --- a/weave/trace_server/clickhouse_trace_server_batched.py +++ b/weave/trace_server/clickhouse_trace_server_batched.py @@ -1415,8 +1415,34 @@ def feedback_purge(self, req: tsi.FeedbackPurgeReq) -> tsi.FeedbackPurgeRes: query = query.project_id(req.project_id) query = query.where(req.query) prepared = query.prepare(database_type="clickhouse") - self.ch_client.query(prepared.sql, prepared.parameters) - return tsi.FeedbackPurgeRes() + query_result = self.ch_client.query(prepared.sql, prepared.parameters) + return tsi.FeedbackPurgeRes(num_deleted=query_result.row_count) + + def feedback_replace(self, req: tsi.FeedbackReplaceReq) -> tsi.FeedbackReplaceRes: + # To replace, first purge, then if successful, create. + purge_request = tsi.FeedbackPurgeReq( + project_id=req.project_id, + query={ + "$expr": { + "$eq": [ + {"$getField": "id"}, + {"$literal": req.feedback_id}, + ], + } + }, + ) + purge_result = self.feedback_purge(purge_request) + if purge_result.num_deleted == 0: + raise InvalidRequest(f"Failed to purge feedback with id {req.feedback_id}") + if purge_result.num_deleted > 1: + raise InvalidRequest( + f"Purged more than one feedback with id {req.feedback_id}" + ) + + create_req = tsi.FeedbackCreateReq(**req.model_dump(exclude={"feedback_id"})) + create_result = self.feedback_create(create_req) + + return tsi.FeedbackReplaceRes(**create_result.model_dump()) def completions_create( self, req: tsi.CompletionsCreateReq diff --git a/weave/trace_server/external_to_internal_trace_server_adapter.py b/weave/trace_server/external_to_internal_trace_server_adapter.py index 7e085b8f75e..4e31cc9e98e 100644 --- a/weave/trace_server/external_to_internal_trace_server_adapter.py +++ b/weave/trace_server/external_to_internal_trace_server_adapter.py @@ -326,6 +326,18 @@ def feedback_purge(self, req: tsi.FeedbackPurgeReq) -> tsi.FeedbackPurgeRes: req.project_id = self._idc.ext_to_int_project_id(req.project_id) return self._ref_apply(self._internal_trace_server.feedback_purge, req) + def feedback_replace(self, req: tsi.FeedbackReplaceReq) -> tsi.FeedbackReplaceRes: + req.project_id = self._idc.ext_to_int_project_id(req.project_id) + original_user_id = req.wb_user_id + if original_user_id is None: + raise ValueError("wb_user_id cannot be None") + req.wb_user_id = self._idc.ext_to_int_user_id(original_user_id) + res = self._ref_apply(self._internal_trace_server.feedback_replace, req) + if res.wb_user_id != req.wb_user_id: + raise ValueError("Internal Error - User Mismatch") + res.wb_user_id = original_user_id + return res + def cost_create(self, req: tsi.CostCreateReq) -> tsi.CostCreateRes: req.project_id = self._idc.ext_to_int_project_id(req.project_id) return self._ref_apply(self._internal_trace_server.cost_create, req) diff --git a/weave/trace_server/sqlite_trace_server.py b/weave/trace_server/sqlite_trace_server.py index 93a4f510090..f21ed013ebd 100644 --- a/weave/trace_server/sqlite_trace_server.py +++ b/weave/trace_server/sqlite_trace_server.py @@ -1041,7 +1041,32 @@ def feedback_purge(self, req: tsi.FeedbackPurgeReq) -> tsi.FeedbackPurgeRes: with self.lock: cursor.execute(prepared.sql, prepared.parameters) conn.commit() - return tsi.FeedbackPurgeRes() + return tsi.FeedbackPurgeRes(num_deleted=cursor.rowcount) + + def feedback_replace(self, req: tsi.FeedbackReplaceReq) -> tsi.FeedbackReplaceRes: + purge_request = tsi.FeedbackPurgeReq( + project_id=req.project_id, + query={ + "$expr": { + "$eq": [ + {"$getField": "id"}, + {"$literal": req.feedback_id}, + ], + } + }, + ) + purge_result = self.feedback_purge(purge_request) + if purge_result.num_deleted == 0: + raise InvalidRequest(f"Failed to purge feedback with id {req.feedback_id}") + if purge_result.num_deleted > 1: + raise InvalidRequest( + f"Purged more than one feedback with id {req.feedback_id}" + ) + + create_req = tsi.FeedbackCreateReq(**req.model_dump(exclude={"feedback_id"})) + create_result = self.feedback_create(create_req) + + return tsi.FeedbackReplaceRes(**create_result.model_dump()) def file_create(self, req: tsi.FileCreateReq) -> tsi.FileCreateRes: conn, cursor = get_conn_cursor(self.db_path) diff --git a/weave/trace_server/trace_server_interface.py b/weave/trace_server/trace_server_interface.py index abdfeae38ac..d9283809b49 100644 --- a/weave/trace_server/trace_server_interface.py +++ b/weave/trace_server/trace_server_interface.py @@ -728,6 +728,14 @@ class FeedbackPurgeReq(BaseModel): class FeedbackPurgeRes(BaseModel): + num_deleted: int + + +class FeedbackReplaceReq(FeedbackCreateReq): + feedback_id: str + + +class FeedbackReplaceRes(FeedbackCreateRes): pass @@ -877,5 +885,6 @@ def file_content_read(self, req: FileContentReadReq) -> FileContentReadRes: ... def feedback_create(self, req: FeedbackCreateReq) -> FeedbackCreateRes: ... def feedback_query(self, req: FeedbackQueryReq) -> FeedbackQueryRes: ... def feedback_purge(self, req: FeedbackPurgeReq) -> FeedbackPurgeRes: ... + def feedback_replace(self, req: FeedbackReplaceReq) -> FeedbackReplaceRes: ... # Execute LLM API def completions_create(self, req: CompletionsCreateReq) -> CompletionsCreateRes: ... diff --git a/weave/trace_server_bindings/remote_http_trace_server.py b/weave/trace_server_bindings/remote_http_trace_server.py index 34b906a560c..0d2750d26b3 100644 --- a/weave/trace_server_bindings/remote_http_trace_server.py +++ b/weave/trace_server_bindings/remote_http_trace_server.py @@ -527,6 +527,13 @@ def feedback_purge( "/feedback/purge", req, tsi.FeedbackPurgeReq, tsi.FeedbackPurgeRes ) + def feedback_replace( + self, req: Union[tsi.FeedbackReplaceReq, dict[str, Any]] + ) -> tsi.FeedbackReplaceRes: + return self._generic_request( + "/feedback/replace", req, tsi.FeedbackReplaceReq, tsi.FeedbackReplaceRes + ) + # Cost API def cost_query( self, req: Union[tsi.CostQueryReq, dict[str, Any]]