From 8a795284ac77db9ef65f770d7795509b97a39ae1 Mon Sep 17 00:00:00 2001
From: Marie Barr-Ramsey <126013019+mbarrramsey@users.noreply.github.com>
Date: Tue, 29 Oct 2024 16:37:47 -0700
Subject: [PATCH 01/16] feat(weave): adding charts to the traces page (#2745)
---
weave-js/package.json | 1 +
.../Browse3/pages/CallsPage/CallsCharts.tsx | 190 ++++++++++++
.../Browse3/pages/CallsPage/CallsTable.tsx | 22 +-
.../Home/Browse3/pages/CallsPage/Charts.tsx | 293 ++++++++++++++++++
.../pages/CallsPage/callsTableQuery.ts | 23 +-
.../ComparisonDefinitionSection.tsx | 2 +-
weave-js/yarn.lock | 5 +
7 files changed, 525 insertions(+), 11 deletions(-)
create mode 100644 weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallsPage/CallsCharts.tsx
create mode 100644 weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallsPage/Charts.tsx
diff --git a/weave-js/package.json b/weave-js/package.json
index fd4ae47bf23..d925f9d0a42 100644
--- a/weave-js/package.json
+++ b/weave-js/package.json
@@ -161,6 +161,7 @@
"@types/color": "^3.0.0",
"@types/cytoscape": "^3.2.0",
"@types/cytoscape-dagre": "^2.2.2",
+ "@types/d3-array": "^3.2.1",
"@types/diff": "^5.0.3",
"@types/downloadjs": "^1.4.2",
"@types/is-buffer": "^2.0.0",
diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallsPage/CallsCharts.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallsPage/CallsCharts.tsx
new file mode 100644
index 00000000000..164122753d8
--- /dev/null
+++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallsPage/CallsCharts.tsx
@@ -0,0 +1,190 @@
+import {GridFilterModel, GridSortModel} from '@mui/x-data-grid-pro';
+import React, {useMemo} from 'react';
+
+import {MOON_400} from '../../../../../../common/css/color.styles';
+import {IconInfo} from '../../../../../Icon';
+import {WaveLoader} from '../../../../../Loaders/WaveLoader';
+import {Tailwind} from '../../../../../Tailwind';
+import {WFHighLevelCallFilter} from './callsTableFilter';
+import {useCallsForQuery} from './callsTableQuery';
+import {
+ ErrorPlotlyChart,
+ LatencyPlotlyChart,
+ RequestsPlotlyChart,
+} from './Charts';
+
+type CallsChartsProps = {
+ entity: string;
+ project: string;
+ filterModelProp: GridFilterModel;
+ filter: WFHighLevelCallFilter;
+};
+
+const Chart = ({
+ isLoading,
+ chartData,
+ title,
+}: {
+ isLoading: boolean;
+ chartData: any;
+ title: string;
+}) => {
+ const CHART_CONTAINER_STYLES =
+ 'flex-1 rounded-lg border border-moon-250 bg-white p-10';
+ const CHART_TITLE_STYLES = 'ml-12 mt-8 text-base font-semibold text-moon-750';
+ const CHART_HEIGHT = 250;
+ const LOADING_CONTAINER_STYLES = `flex h-[${CHART_HEIGHT}px] items-center justify-center`;
+
+  let chart = null;
+  if (isLoading) {
+    chart = (
+      <div className={LOADING_CONTAINER_STYLES}>
+        <WaveLoader size="small" />
+      </div>
+    );
+  } else if (chartData.length > 0) {
+    switch (title) {
+      case 'Latency':
+        chart = (
+          <LatencyPlotlyChart chartData={chartData} height={CHART_HEIGHT} />
+        );
+        break;
+      case 'Errors':
+        chart = (
+          <ErrorPlotlyChart chartData={chartData} height={CHART_HEIGHT} />
+        );
+        break;
+      case 'Requests':
+        chart = (
+          <RequestsPlotlyChart chartData={chartData} height={CHART_HEIGHT} />
+        );
+        break;
+    }
+  } else {
+    chart = (
+      <div className={LOADING_CONTAINER_STYLES}>
+        <div className="flex items-center gap-8 text-moon-500">
+          <IconInfo style={{color: MOON_400}} />
+          <span>No data available for the selected time frame</span>
+        </div>
+      </div>
+    );
+  }
+  return (
+    <div className={CHART_CONTAINER_STYLES}>
+      <div className={CHART_TITLE_STYLES}>{title}</div>
+      {chart}
+    </div>
+  );
+};
+
+export const CallsCharts = ({
+ entity,
+ project,
+ filter,
+ filterModelProp,
+}: CallsChartsProps) => {
+ const columns = useMemo(
+ () => ['started_at', 'ended_at', 'exception', 'id'],
+ []
+ );
+ const columnSet = useMemo(() => new Set(columns), [columns]);
+ const sortCalls: GridSortModel = useMemo(
+ () => [{field: 'started_at', sort: 'desc'}],
+ []
+ );
+ const page = useMemo(
+ () => ({
+ pageSize: 1000,
+ page: 0,
+ }),
+ []
+ );
+
+ const calls = useCallsForQuery(
+ entity,
+ project,
+ filter,
+ filterModelProp,
+ page,
+ sortCalls,
+ columnSet,
+ columns
+ );
+
+ const chartData = useMemo(() => {
+ if (calls.loading || !calls.result || calls.result.length === 0) {
+ return {latency: [], errors: [], requests: []};
+ }
+
+ const data: {
+ latency: Array<{started_at: string; latency: number}>;
+ errors: Array<{started_at: string; isError: boolean}>;
+ requests: Array<{started_at: string}>;
+ } = {
+ latency: [],
+ errors: [],
+ requests: [],
+ };
+
+ calls.result.forEach(call => {
+ const started_at = call.traceCall?.started_at;
+ if (!started_at) {
+ return;
+ }
+ const ended_at = call.traceCall?.ended_at;
+
+ const isError =
+ call.traceCall?.exception !== null &&
+ call.traceCall?.exception !== undefined &&
+ call.traceCall?.exception !== '';
+
+ data.requests.push({started_at});
+
+    data.errors.push({started_at, isError});
+
+ if (ended_at !== undefined) {
+ const startTime = new Date(started_at).getTime();
+ const endTime = new Date(ended_at).getTime();
+ const latency = endTime - startTime;
+ data.latency.push({started_at, latency});
+ }
+ });
+ return data;
+ }, [calls.result, calls.loading]);
+
+  const charts = (
+    <div className="flex w-full flex-row gap-10">
+      <Chart
+        isLoading={calls.loading}
+        chartData={chartData.latency}
+        title="Latency"
+      />
+      <Chart
+        isLoading={calls.loading}
+        chartData={chartData.errors}
+        title="Errors"
+      />
+      <Chart
+        isLoading={calls.loading}
+        chartData={chartData.requests}
+        title="Requests"
+      />
+    </div>
+  );
+
+  return (
+    <Tailwind>
+      {/* setting the width to the width of the screen minus the sidebar width because of overflow: 'hidden' properties in SimplePageLayout causing issues */}
+      {charts}
+    </Tailwind>
+  );
+};
diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallsPage/CallsTable.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallsPage/CallsTable.tsx
index 224e4d9a12d..25d80005260 100644
--- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallsPage/CallsTable.tsx
+++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallsPage/CallsTable.tsx
@@ -26,6 +26,7 @@ import {
useGridApiRef,
} from '@mui/x-data-grid-pro';
import {MOON_200, TEAL_300} from '@wandb/weave/common/css/color.styles';
+import {Switch} from '@wandb/weave/components';
import {Checkbox} from '@wandb/weave/components/Checkbox/Checkbox';
import {Icon} from '@wandb/weave/components/Icon';
import React, {
@@ -69,6 +70,7 @@ import {traceCallToUICallSchema} from '../wfReactInterface/tsDataModelHooks';
import {EXPANDED_REF_REF_KEY} from '../wfReactInterface/tsDataModelHooksCallRefExpansion';
import {objectVersionNiceString} from '../wfReactInterface/utilities';
import {CallSchema} from '../wfReactInterface/wfDataModelHooksInterface';
+import {CallsCharts} from './CallsCharts';
import {CallsCustomColumnMenu} from './CallsCustomColumnMenu';
import {
BulkDeleteButton,
@@ -168,6 +170,7 @@ export const CallsTable: FC<{
allowedColumnPatterns,
}) => {
const {loading: loadingUserInfo, userInfo} = useViewerInfo();
+ const [isMetricsChecked, setMetricsChecked] = useState(false);
const isReadonly =
loadingUserInfo || !userInfo?.username || !userInfo?.teams.includes(entity);
@@ -245,8 +248,8 @@ export const CallsTable: FC<{
project,
effectiveFilter,
filterModelResolved,
- sortModelResolved,
paginationModelResolved,
+ sortModelResolved,
expandedRefCols
);
@@ -742,6 +745,15 @@ export const CallsTable: FC<{
clearSelectedCalls={clearSelectedCalls}
/>
)}
+
+
+
+
+ Metrics
+
{selectedInputObjectVersion && (
}>
+          {isMetricsChecked && (
+            <CallsCharts entity={entity} project={project} filter={effectiveFilter} filterModelProp={filterModelResolved} />
+          )}
+const X_AXIS_STYLE: Partial<Plotly.LayoutAxis> = {
+ type: 'date' as const,
+ automargin: true,
+ showgrid: false,
+ linecolor: MOON_300,
+ tickfont: {color: MOON_500},
+ showspikes: false,
+};
+
+const X_AXIS_STYLE_WITH_SPIKES: Partial<Plotly.LayoutAxis> = {
+ ...X_AXIS_STYLE,
+ showspikes: true,
+ spikemode: 'across',
+ spikethickness: 1,
+ spikecolor: MOON_300,
+};
+
+const Y_AXIS_STYLE: Partial<Plotly.LayoutAxis> = {
+ automargin: true,
+ griddash: 'dot',
+ showgrid: true,
+ gridcolor: MOON_300,
+ linecolor: MOON_300,
+ showspikes: false,
+ tickfont: {color: MOON_500},
+ zeroline: false,
+};
+
+export const calculateBinSize = (
+ data: ChartDataLatency[] | ChartDataErrors[] | ChartDataRequests[],
+ targetBinCount = 15
+) => {
+ if (data.length === 0) {
+ return 60;
+ } // default to 60 minutes if no data
+
+ const startTime = moment(_.minBy(data, 'started_at')?.started_at);
+ const endTime = moment(_.maxBy(data, 'started_at')?.started_at);
+
+ const minutesInRange = endTime.diff(startTime, 'minutes');
+
+ // Calculate bin size in minutes, rounded to a nice number
+ const rawBinSize = Math.max(1, Math.ceil(minutesInRange / targetBinCount));
+ const niceNumbers = [1, 2, 5, 10, 15, 30, 60, 120, 240, 360, 720, 1440];
+
+ // Find the closest nice number
+ return niceNumbers.reduce((prev, curr) => {
+ return Math.abs(curr - rawBinSize) < Math.abs(prev - rawBinSize)
+ ? curr
+ : prev;
+ }, niceNumbers[0]);
+};
+
+export const LatencyPlotlyChart: React.FC<{
+ height: number;
+ chartData: ChartDataLatency[];
+ targetBinCount?: number;
+}> = ({height, chartData, targetBinCount}) => {
+  const divRef = useRef<HTMLDivElement>(null);
+ const binSize = calculateBinSize(chartData, targetBinCount);
+
+ const plotlyData: Plotly.Data[] = useMemo(() => {
+ const groupedData = _(chartData)
+ .groupBy(d => {
+ const date = moment(d.started_at);
+ const roundedMinutes = Math.floor(date.minutes() / binSize) * binSize;
+ return date.startOf('hour').add(roundedMinutes, 'minutes').format();
+ })
+ .map((group, date) => {
+ const latenciesNonSorted = group.map(d => d.latency);
+ const p50 = quantile(latenciesNonSorted, 0.5) ?? 0;
+ const p95 = quantile(latenciesNonSorted, 0.95) ?? 0;
+ const p99 = quantile(latenciesNonSorted, 0.99) ?? 0;
+ return {timestamp: date, p50, p95, p99};
+ })
+ .value();
+
+ return [
+ {
+ type: 'scatter',
+ mode: 'lines+markers',
+ x: groupedData.map(d => d.timestamp),
+ y: groupedData.map(d => d.p50),
+ name: 'p50 Latency',
+ line: {color: BLUE_500},
+ marker: {color: BLUE_500},
+ hovertemplate: '%{data.name}: %{y:.2f} ms',
+ },
+ {
+ type: 'scatter',
+ mode: 'lines+markers',
+ x: groupedData.map(d => d.timestamp),
+ y: groupedData.map(d => d.p95),
+ name: 'p95 Latency',
+ line: {color: GREEN_500},
+ marker: {color: GREEN_500},
+ hovertemplate: '%{data.name}: %{y:.2f} ms',
+ },
+ {
+ type: 'scatter',
+ mode: 'lines+markers',
+ x: groupedData.map(d => d.timestamp),
+ y: groupedData.map(d => d.p99),
+ name: 'p99 Latency',
+ line: {color: MOON_500},
+ marker: {color: MOON_500},
+ hovertemplate: '%{data.name}: %{y:.2f} ms',
+ },
+ ];
+ }, [chartData, binSize]);
+
+ useEffect(() => {
+    const plotlyLayout: Partial<Plotly.Layout> = {
+ height,
+ margin: CHART_MARGIN_STYLE,
+ xaxis: X_AXIS_STYLE_WITH_SPIKES,
+ yaxis: Y_AXIS_STYLE,
+ hovermode: 'x unified',
+ showlegend: false,
+ hoverlabel: {
+ bordercolor: MOON_200,
+ },
+ };
+
+    const plotlyConfig: Partial<Plotly.Config> = {
+ displayModeBar: false,
+ responsive: true,
+ };
+
+ if (divRef.current) {
+ Plotly.newPlot(divRef.current, plotlyData, plotlyLayout, plotlyConfig);
+ }
+ }, [plotlyData, height]);
+
+  return <div ref={divRef}></div>;
+};
+
+export const ErrorPlotlyChart: React.FC<{
+ height: number;
+ chartData: ChartDataErrors[];
+ targetBinCount?: number;
+}> = ({height, chartData, targetBinCount}) => {
+  const divRef = useRef<HTMLDivElement>(null);
+ const binSize = calculateBinSize(chartData, targetBinCount);
+
+ const plotlyData: Plotly.Data[] = useMemo(() => {
+ const groupedData = _(chartData)
+ .groupBy(d => {
+ const date = moment(d.started_at);
+ const roundedMinutes = Math.floor(date.minutes() / binSize) * binSize;
+ return date.startOf('hour').add(roundedMinutes, 'minutes').format();
+ })
+ .map((group, date) => ({
+ timestamp: date,
+ count: group.filter(d => d.isError).length,
+ }))
+ .value();
+
+ return [
+ {
+ type: 'bar',
+ x: groupedData.map(d => d.timestamp),
+ y: groupedData.map(d => d.count),
+ name: 'Error Count',
+ marker: {color: RED_400},
+ hovertemplate: '%{y} errors',
+ },
+ ];
+ }, [chartData, binSize]);
+
+ useEffect(() => {
+    const plotlyLayout: Partial<Plotly.Layout> = {
+ height,
+ margin: CHART_MARGIN_STYLE,
+ bargap: 0.2,
+ xaxis: X_AXIS_STYLE,
+ yaxis: Y_AXIS_STYLE,
+ hovermode: 'x unified',
+ hoverlabel: {
+ bordercolor: MOON_200,
+ },
+ dragmode: 'zoom',
+ };
+
+    const plotlyConfig: Partial<Plotly.Config> = {
+ displayModeBar: false,
+ responsive: true,
+ };
+
+ if (divRef.current) {
+ Plotly.newPlot(divRef.current, plotlyData, plotlyLayout, plotlyConfig);
+ }
+ }, [plotlyData, height]);
+
+  return <div ref={divRef}></div>;
+};
+
+export const RequestsPlotlyChart: React.FC<{
+ height: number;
+ chartData: ChartDataRequests[];
+ targetBinCount?: number;
+}> = ({height, chartData, targetBinCount}) => {
+  const divRef = useRef<HTMLDivElement>(null);
+ const binSize = calculateBinSize(chartData, targetBinCount);
+
+ const plotlyData: Plotly.Data[] = useMemo(() => {
+ const groupedData = _(chartData)
+ .groupBy(d => {
+ const date = moment(d.started_at);
+ const roundedMinutes = Math.floor(date.minutes() / binSize) * binSize;
+ return date.startOf('hour').add(roundedMinutes, 'minutes').format();
+ })
+ .map((group, date) => ({
+ timestamp: date,
+ count: group.length,
+ }))
+ .value();
+
+ return [
+ {
+ type: 'bar',
+ x: groupedData.map(d => d.timestamp),
+ y: groupedData.map(d => d.count),
+ name: 'Requests',
+ marker: {color: TEAL_400},
+ hovertemplate: '%{y} requests',
+ },
+ ];
+ }, [chartData, binSize]);
+
+ useEffect(() => {
+    const plotlyLayout: Partial<Plotly.Layout> = {
+ height,
+ margin: CHART_MARGIN_STYLE,
+ xaxis: X_AXIS_STYLE,
+ yaxis: Y_AXIS_STYLE,
+ bargap: 0.2,
+ hovermode: 'x unified',
+ hoverlabel: {
+ bordercolor: MOON_200,
+ },
+ };
+
+    const plotlyConfig: Partial<Plotly.Config> = {
+ displayModeBar: false,
+ responsive: true,
+ };
+
+ if (divRef.current) {
+ Plotly.newPlot(divRef.current, plotlyData, plotlyLayout, plotlyConfig);
+ }
+ }, [plotlyData, height]);
+
+  return <div ref={divRef}></div>;
+};
diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallsPage/callsTableQuery.ts b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallsPage/callsTableQuery.ts
index 2a0d1bad489..de221b652dc 100644
--- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallsPage/callsTableQuery.ts
+++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallsPage/callsTableQuery.ts
@@ -32,9 +32,9 @@ export const useCallsForQuery = (
project: string,
filter: WFHighLevelCallFilter,
gridFilter: GridFilterModel,
- gridSort: GridSortModel,
gridPage: GridPaginationModel,
-  expandedColumns: Set<string>,
+ gridSort?: GridSortModel,
+  expandedColumns?: Set<string>,
columns?: string[]
): {
costsLoading: boolean;
@@ -44,8 +44,8 @@ export const useCallsForQuery = (
refetch: () => void;
} => {
const {useCalls, useCallsStats} = useWFHooks();
- const offset = gridPage.page * gridPage.pageSize;
- const limit = gridPage.pageSize;
+ const effectiveOffset = gridPage?.page * gridPage?.pageSize;
+ const effectiveLimit = gridPage.pageSize;
const {sortBy, lowLevelFilter, filterBy} = useFilterSortby(
filter,
gridFilter,
@@ -56,8 +56,8 @@ export const useCallsForQuery = (
entity,
project,
lowLevelFilter,
- limit,
- offset,
+ effectiveLimit,
+ effectiveOffset,
sortBy,
filterBy,
columns,
@@ -77,11 +77,16 @@ export const useCallsForQuery = (
const total = useMemo(() => {
if (callsStats.loading || callsStats.result == null) {
- return offset + callResults.length;
+ return effectiveOffset + callResults.length;
} else {
return callsStats.result.count;
}
- }, [callResults.length, callsStats.loading, callsStats.result, offset]);
+ }, [
+ callResults.length,
+ callsStats.loading,
+ callsStats.result,
+ effectiveOffset,
+ ]);
const costFilter: CallFilter = useMemo(
() => ({
@@ -94,7 +99,7 @@ export const useCallsForQuery = (
entity,
project,
costFilter,
- limit,
+ effectiveLimit,
undefined,
sortBy,
undefined,
diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/sections/ComparisonDefinitionSection/ComparisonDefinitionSection.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/sections/ComparisonDefinitionSection/ComparisonDefinitionSection.tsx
index 3d461681a3c..b5c1a4bf96c 100644
--- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/sections/ComparisonDefinitionSection/ComparisonDefinitionSection.tsx
+++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/sections/ComparisonDefinitionSection/ComparisonDefinitionSection.tsx
@@ -111,8 +111,8 @@ const AddEvaluationButton: React.FC<{
props.state.data.project,
evaluationsFilter,
DEFAULT_FILTER_CALLS,
- DEFAULT_SORT_CALLS,
page,
+ DEFAULT_SORT_CALLS,
expandedRefCols,
columns
);
diff --git a/weave-js/yarn.lock b/weave-js/yarn.lock
index 3315ac14ade..2ee20553257 100644
--- a/weave-js/yarn.lock
+++ b/weave-js/yarn.lock
@@ -4215,6 +4215,11 @@
resolved "https://registry.yarnpkg.com/@types/cytoscape/-/cytoscape-3.19.10.tgz#f4540749d68cd3db6f89da5197f7ec2a2ca516ee"
integrity sha512-PLsKQcsUd05nz4PYyulIhjkLnlq9oD2WYpswrWOjoqtFZEuuBje0f9fi2zTG5/yfTf5+Gpllf/MPcFmfDzZ24w==
+"@types/d3-array@^3.2.1":
+ version "3.2.1"
+ resolved "https://registry.yarnpkg.com/@types/d3-array/-/d3-array-3.2.1.tgz#1f6658e3d2006c4fceac53fde464166859f8b8c5"
+ integrity sha512-Y2Jn2idRrLzUfAKV2LyRImR+y4oa2AntrgID95SHJxuMUrkNXmanDSed71sRNZysveJVt1hLLemQZIady0FpEg==
+
"@types/debug@^4.0.0":
version "4.1.8"
resolved "https://registry.yarnpkg.com/@types/debug/-/debug-4.1.8.tgz#cef723a5d0a90990313faec2d1e22aee5eecb317"
From 4ebc2b5f5ff1c1cf8cda09e606e333f2fec2430c Mon Sep 17 00:00:00 2001
From: Jamie Rasmussen <112953339+jamie-rasmussen@users.noreply.github.com>
Date: Tue, 29 Oct 2024 18:50:24 -0500
Subject: [PATCH 02/16] feat: prompts as first class objects (#2811)
---
docs/docs/guides/core-types/prompts.md | 373 +++++++++++++++
docs/sidebars.ts | 1 +
tests/trace/test_prompt.py | 23 +
tests/trace/test_prompt_easy.py | 260 +++++++++++
.../components/FancyPage/useProjectSidebar.ts | 7 +
.../Home/Browse3/pages/CallPage/CallPage.tsx | 8 +-
.../Home/Browse3/pages/ObjectVersionPage.tsx | 68 ++-
.../Home/Browse3/pages/OpVersionPage.tsx | 9 +-
.../Home/Browse3/pages/TabPrompt.tsx | 25 +
.../Home/Browse3/pages/TabUseCall.tsx | 2 +-
.../Home/Browse3/pages/TabUseDataset.tsx | 2 +-
.../Home/Browse3/pages/TabUseModel.tsx | 2 +-
.../Home/Browse3/pages/TabUseObject.tsx | 2 +-
.../Home/Browse3/pages/TabUseOp.tsx | 2 +-
.../Home/Browse3/pages/TabUsePrompt.tsx | 99 ++++
weave/__init__.py | 7 +
weave/flow/prompt/common.py | 14 +
weave/flow/prompt/prompt.py | 440 ++++++++++++++++++
weave/integrations/openai/openai_sdk.py | 25 +
weave/trace/op.py | 71 ++-
weave/trace/refs.py | 17 +-
21 files changed, 1417 insertions(+), 40 deletions(-)
create mode 100644 docs/docs/guides/core-types/prompts.md
create mode 100644 tests/trace/test_prompt.py
create mode 100644 tests/trace/test_prompt_easy.py
create mode 100644 weave-js/src/components/PagePanelComponents/Home/Browse3/pages/TabPrompt.tsx
create mode 100644 weave-js/src/components/PagePanelComponents/Home/Browse3/pages/TabUsePrompt.tsx
create mode 100644 weave/flow/prompt/common.py
create mode 100644 weave/flow/prompt/prompt.py
diff --git a/docs/docs/guides/core-types/prompts.md b/docs/docs/guides/core-types/prompts.md
new file mode 100644
index 00000000000..9a2d50ecf2b
--- /dev/null
+++ b/docs/docs/guides/core-types/prompts.md
@@ -0,0 +1,373 @@
+# Prompts
+
+Creating, evaluating, and refining prompts is a core activity for AI engineers.
+Small changes to a prompt can have big impacts on your application's behavior.
+Weave lets you create prompts, save and retrieve them, and evolve them over time.
+Some of the benefits of Weave's prompt management system are:
+
+- Unopinionated core, with a batteries-included option for rapid development
+- Versioning that shows you how a prompt has evolved over time
+- The ability to update a prompt in production without redeploying your application
+- The ability to run a prompt against many inputs to evaluate its performance
+
+## Getting started
+
+If you want complete control over how a Prompt is constructed, you can subclass one of the base classes `weave.Prompt`, `weave.StringPrompt`, or `weave.MessagesPrompt` and implement the corresponding `format` method. When you publish one of these objects with `weave.publish`, it will appear in your Weave project on the "Prompts" page.
+
+```
+class Prompt(Object):
+ def format(self, **kwargs: Any) -> Any:
+ ...
+
+class StringPrompt(Prompt):
+ def format(self, **kwargs: Any) -> str:
+ ...
+
+class MessagesPrompt(Prompt):
+ def format(self, **kwargs: Any) -> list:
+ ...
+```
+
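+For example, here is a minimal sketch of a custom `StringPrompt` that is then published; the class name and project name are illustrative:
+
+```python
+import weave
+
+
+class PirateGreeting(weave.StringPrompt):
+    def format(self, **kwargs) -> str:
+        # Any string-building logic can live here; this sketch returns a fixed string.
+        return "Ahoy, matey! Welcome aboard."
+
+
+weave.init("intro-example")
+weave.publish(PirateGreeting(), name="pirate-greeting")
+```
+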
+Weave also includes a "batteries-included" class called `EasyPrompt` that can be simpler to start with, especially if you are working with APIs that are similar to OpenAI. This document highlights the features you get with EasyPrompt.
+
+## Constructing prompts
+
+You can think of the EasyPrompt object as a list of messages with associated roles, optional
+placeholder variables, and an optional model configuration.
+But constructing a prompt can be as simple as providing a single string:
+
+```python
+import weave
+
+prompt = weave.EasyPrompt("What's 23 * 42?")
+assert prompt[0] == {"role": "user", "content": "What's 23 * 42?"}
+```
+
+For terseness, the weave library aliases the `EasyPrompt` class to `P`.
+
+```python
+from weave import P
+p = P("What's 23 * 42?")
+```
+
+It is common for a prompt to consist of multiple messages. Each message has an associated `role`.
+If the role is omitted, it defaults to `"user"`.
+
+**Some common roles**
+
+| Role | Description |
+| --------- | -------------------------------------------------------------------------------------------------------------------- |
+| system | System prompts provide high level instructions and can be used to set the behavior, knowledge, or persona of the AI. |
+| user | Represents input from a human user. (This is the default role.) |
+| assistant | Represents the AI's generated replies. Can be used for historical completions or to show examples. |
+
+For convenience, you can prefix a message string with one of these known roles:
+
+```python
+import weave
+
+prompt = weave.EasyPrompt("system: Talk like a pirate")
+assert prompt[0] == {"role": "system", "content": "Talk like a pirate"}
+
+# An explicit role parameter takes precedence
+prompt = weave.EasyPrompt("system: Talk like a pirate", role="user")
+assert prompt[0] == {"role": "user", "content": "system: Talk like a pirate"}
+
+```
+
+Messages can be appended to a prompt one-by-one:
+
+```python
+import weave
+
+prompt = weave.EasyPrompt()
+prompt.append("You are an expert travel consultant.", role="system")
+prompt.append("Give me five ideas for top kid-friendly attractions in New Zealand.")
+```
+
+Or you can append multiple messages at once, either with the `append` method or with the `Prompt`
+constructor, which is convenient for constructing a prompt from existing messages.
+
+```python
+import weave
+
+prompt = weave.EasyPrompt()
+prompt.append([
+ {"role": "system", "content": "You are an expert travel consultant."},
+ "Give me five ideas for top kid-friendly attractions in New Zealand."
+])
+
+# Same
+prompt = weave.EasyPrompt([
+ {"role": "system", "content": "You are an expert travel consultant."},
+ "Give me five ideas for top kid-friendly attractions in New Zealand."
+])
+```
+
+The Prompt class is designed to be easily inserted into existing code.
+For example, you can quickly wrap it around all of the arguments to the
+OpenAI chat completion `create` call including its messages and model
+configuration. If you don't wrap the inputs, Weave's integration would still
+track all of the call's inputs, but it would not extract them as a separate
+versioned object. Having a separate Prompt object allows you to version
+the prompt, easily filter calls by that version, etc.
+
+```python
+from weave import init, P
+from openai import OpenAI
+client = OpenAI()
+
+# Must specify a target project, otherwise the Weave code is a no-op
+# highlight-next-line
+init("intro-example")
+
+# highlight-next-line
+response = client.chat.completions.create(P(
+ model="gpt-4o-mini",
+ messages=[
+ {"role": "user", "content": "What's 23 * 42?"}
+ ],
+ temperature=0.7,
+ max_tokens=64,
+ top_p=1
+# highlight-next-line
+))
+```
+
+:::note
+Why this works: Weave's OpenAI integration wraps the OpenAI `create` method to make it a Weave Op.
+When the Op is executed, the Prompt object in the input will get saved and associated with the Call.
+However, it will be replaced with the structure the `create` method expects for the execution of the
+underlying function.
+:::
+
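+To make the mechanism concrete, here is a rough sketch of the idea. This is not Weave's actual implementation, and `log_input` is only a stand-in for Weave's call tracking:
+
+```python
+import weave
+
+
+def log_input(obj) -> None:
+    """Stand-in for Weave's call tracking; not a real Weave API."""
+    print(f"saved prompt object: {obj.name!r}")
+
+
+def call_with_prompt(create_fn, prompt: weave.EasyPrompt):
+    # The Prompt object itself is what gets saved with the Call...
+    log_input(prompt)
+    # ...while the wrapped create() receives ordinary arguments: the rendered
+    # messages plus the model configuration stored on the prompt.
+    return create_fn(messages=prompt(), **prompt.config)
+```
+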
+## Parameterizing prompts
+
+When specifying a prompt, you can include placeholders for values you want to fill in later. These placeholders are called "Parameters".
+Parameters are indicated with curly braces. Here's a simple example:
+
+```python
+import weave
+
+prompt = weave.EasyPrompt("What's {A} + {B}?")
+```
+
+You specify values for all of the parameters, or "bind" them, when you [use the prompt](#using-prompts).
+
+The `require` method of Prompt allows you to associate parameters with restrictions that will be checked at bind time to detect programming errors.
+
+```python
+import weave
+
+prompt = weave.EasyPrompt("What's {A} + 42?")
+prompt.require("A", type="int", min=0, max=100)
+
+prompt = weave.EasyPrompt("system: You are a {profession}")
+prompt.require("profession", oneof=('pirate', 'cartoon mouse', 'hungry dragon'), default='pirate')
+```
+
+## Using prompts
+
+You use a Prompt by converting it into a list of messages where all template placeholders have been filled in. You can bind a prompt to parameter values with the `bind` method or by simply calling it as a function. Here's an example where the prompt has zero parameters.
+
+```python
+import weave
+prompt = weave.EasyPrompt("What's 23 * 42?")
+assert prompt() == prompt.bind() == [
+ {"role": "user", "content": "What's 23 * 42?"}
+]
+```
+
+If a prompt has parameters, you would specify values for them when you use the prompt.
+Parameter values can be passed in as a dictionary or as keyword arguments.
+
+```python
+import weave
+prompt = weave.EasyPrompt("What's {A} + {B}?")
+assert prompt(A=5, B="10") == prompt({"A": 5, "B": "10"})
+```
+
+If any parameters are missing, they will be left unsubstituted in the output.
+
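+For example, binding only `A` leaves `{B}` intact in the rendered message:
+
+```python
+import weave
+
+prompt = weave.EasyPrompt("What's {A} + {B}?")
+# Only A is provided, so B remains a literal placeholder.
+assert prompt(A=5) == [{"role": "user", "content": "What's 5 + {B}?"}]
+```
+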
+Here's a complete example of using a prompt with OpenAI. This example also uses [Weave's OpenAI integration](../integrations/openai.md) to automatically log the prompt and response.
+
+```python
+import weave
+from openai import OpenAI
+client = OpenAI()
+
+weave.init("intro-example")
+prompt = weave.EasyPrompt()
+prompt.append("You will be provided with a tweet, and your task is to classify its sentiment as positive, neutral, or negative.", role="system")
+prompt.append("I love {this_thing}!")
+
+response = client.chat.completions.create(
+ model="gpt-4o-mini",
+ messages=prompt(this_thing="Weave"),
+ temperature=0.7,
+ max_tokens=64,
+ top_p=1
+)
+```
+
+## Publishing to server
+
+Prompts are a type of [Weave object](../tracking/objects.md) and use the same methods for publishing to the Weave server.
+You must specify a destination project name with `weave.init` before you can publish a prompt.
+
+```python
+import weave
+
+prompt = weave.EasyPrompt()
+prompt.append("What's 23 * 42?")
+
+weave.init("intro-example") # Use entity/project format if not targeting your default entity
+weave.publish(prompt, name="calculation-prompt")
+```
+
+Weave will automatically determine if the object has changed and only publish a new version if it has.
+You can also specify a name or description for the Prompt as part of its constructor.
+
+```python
+import weave
+
+prompt = weave.EasyPrompt(
+ "What's 23 * 42?",
+ name="calculation-prompt",
+ description="A prompt for calculating the product of two numbers.",
+)
+
+weave.init("intro-example")
+weave.publish(prompt)
+```
+
+## Retrieving from server
+
+Prompts are a type of [Weave object](../tracking/objects.md) and use the same methods for retrieval from the Weave server.
+You must specify a source project name with `weave.init` before you can retrieve a prompt.
+
+```python
+import weave
+
+weave.init("intro-example")
+prompt = weave.ref("calculation-prompt").get()
+```
+
+By default, the latest version of the prompt is returned. You can make this explicit or select a specific version by providing its version id.
+
+```python
+import weave
+
+weave.init("intro-example")
+prompt = weave.ref("calculation-prompt:latest").get()
+# "<prompt_name>:<version_id>", for example:
+prompt = weave.ref("calculation-prompt:QSLzr96CTzFwLWgFFi3EuawCI4oODz4Uax98SxIY79E").get()
+```
+
+It is also possible to retrieve a Prompt without calling `init` if you pass a fully qualified URI to `weave.ref`.
+
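+The exact URI can be copied from the Weave UI; the one below only illustrates the format:
+
+```python
+import weave
+
+# No call to weave.init is needed when the ref is fully qualified.
+prompt = weave.ref(
+    "weave:///your-entity/intro-example/object/calculation-prompt:latest"
+).get()
+```
+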
+## Loading and saving from files
+
+Prompts can be saved to files and loaded from files. This can be convenient if you want your Prompt to be versioned through
+a mechanism other than Weave, such as git, or as a fallback if Weave is not available.
+
+To save a prompt to a file, you can use the `dump_file` method.
+
+```python
+import weave
+
+prompt = weave.EasyPrompt("What's 23 * 42?")
+prompt.dump_file("~/prompt.json")
+```
+
+and load it again later with `Prompt.load_file`.
+
+```python
+import weave
+
+prompt = weave.EasyPrompt.load_file("~/prompt.json")
+```
+
+You can also use the lower level `dump` and `Prompt.load` methods for custom (de)serialization.
+
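+As a rough sketch, assuming `dump` writes to and `Prompt.load` reads from a file-like object (check the API reference for the exact signatures):
+
+```python
+import io
+
+import weave
+
+prompt = weave.EasyPrompt("What's 23 * 42?")
+
+buffer = io.StringIO()
+prompt.dump(buffer)  # assumed: serialize to any writable file-like object
+restored = weave.EasyPrompt.load(io.StringIO(buffer.getvalue()))
+```
+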
+## Evaluating prompts
+
+The [Parameter feature of prompts](#parameterizing-prompts) can be used to execute or evaluate variations of a prompt.
+
+You can bind each row of a [Dataset](./datasets.md) to generate N variations of a prompt.
+
+```python
+import weave
+
+# Create a dataset
+dataset = weave.Dataset(name='countries', rows=[
+ {'id': '0', 'country': "Argentina"},
+ {'id': '1', 'country': "Belize"},
+ {'id': '2', 'country': "Canada"},
+ {'id': '3', 'country': "New Zealand"},
+])
+
+prompt = weave.EasyPrompt(name='travel_agent')
+prompt.append("You are an expert travel consultant.", role="system")
+prompt.append("Tell me the capital of {country} and about five kid-friendly attractions there.")
+
+
+prompts = prompt.bind_rows(dataset)
+assert prompts[2][1]["content"] == "Tell me the capital of Canada and about five kid-friendly attractions there."
+```
+
+You can extend this into an [Evaluation](./evaluations.md):
+
+```python
+import asyncio
+
+import openai
+import weave
+
+weave.init("intro-example")
+
+# Create a dataset
+dataset = weave.Dataset(name='countries', rows=[
+ {'id': '0', 'country': "Argentina", 'capital': "Buenos Aires"},
+ {'id': '1', 'country': "Belize", 'capital': "Belmopan"},
+ {'id': '2', 'country': "Canada", 'capital': "Ottawa"},
+ {'id': '3', 'country': "New Zealand", 'capital': "Wellington"},
+])
+
+# Create a prompt
+prompt = weave.EasyPrompt(name='travel_agent')
+prompt.append("You are an expert travel consultant.", role="system")
+prompt.append("Tell me the capital of {country} and about five kid-friendly attractions there.")
+
+# Create a model, combining a prompt with model configuration
+class TravelAgentModel(weave.Model):
+
+ model_name: str
+ prompt: weave.EasyPrompt
+
+ @weave.op
+    async def predict(self, country: str) -> str:
+ client = openai.AsyncClient()
+
+ response = await client.chat.completions.create(
+ model=self.model_name,
+ messages=self.prompt(country=country),
+ )
+ result = response.choices[0].message.content
+ if result is None:
+ raise ValueError("No response from model")
+ return result
+
+# Define and run the evaluation
+@weave.op
+def mentions_capital_scorer(capital: str, model_output: str) -> dict:
+ return {'correct': capital in model_output}
+
+model = TravelAgentModel(model_name="gpt-4o-mini", prompt=prompt)
+evaluation = weave.Evaluation(
+ dataset=dataset,
+ scorers=[mentions_capital_scorer],
+)
+asyncio.run(evaluation.evaluate(model))
+
+```
diff --git a/docs/sidebars.ts b/docs/sidebars.ts
index c5da61462b5..d56f563fd3a 100644
--- a/docs/sidebars.ts
+++ b/docs/sidebars.ts
@@ -64,6 +64,7 @@ const sidebars: SidebarsConfig = {
"guides/evaluation/scorers",
],
},
+ "guides/core-types/prompts",
"guides/core-types/models",
"guides/core-types/datasets",
"guides/tracking/feedback",
diff --git a/tests/trace/test_prompt.py b/tests/trace/test_prompt.py
new file mode 100644
index 00000000000..98bb731d076
--- /dev/null
+++ b/tests/trace/test_prompt.py
@@ -0,0 +1,23 @@
+from weave.flow.prompt.prompt import MessagesPrompt, StringPrompt
+
+
+def test_stringprompt_format():
+ class MyPrompt(StringPrompt):
+ def format(self, **kwargs) -> str:
+ return "Imagine a lot of complicated logic build this string."
+
+ prompt = MyPrompt()
+ assert prompt.format() == "Imagine a lot of complicated logic build this string."
+
+
+def test_messagesprompt_format():
+ class MyPrompt(MessagesPrompt):
+ def format(self, **kwargs) -> list:
+ return [
+ {"role": "user", "content": "What's 23 * 42"},
+ ]
+
+ prompt = MyPrompt()
+ assert prompt.format() == [
+ {"role": "user", "content": "What's 23 * 42"},
+ ]
diff --git a/tests/trace/test_prompt_easy.py b/tests/trace/test_prompt_easy.py
new file mode 100644
index 00000000000..6d01db92a9f
--- /dev/null
+++ b/tests/trace/test_prompt_easy.py
@@ -0,0 +1,260 @@
+import itertools
+
+import pytest
+
+from weave import EasyPrompt
+
+
+def iter_equal(items1, items2):
+ """`True` if iterators `items1` and `items2` contain equal items."""
+ return (items1 is items2) or all(
+ a == b for a, b in itertools.zip_longest(items1, items2, fillvalue=object())
+ )
+
+
+def test_prompt_message_constructor_str():
+ prompt = EasyPrompt("What's 23 * 42")
+ assert prompt() == [{"role": "user", "content": "What's 23 * 42"}]
+
+
+def test_prompt_message_constructor_prefix_str():
+ prompt = EasyPrompt("system: you are a pirate")
+ assert prompt() == [{"role": "system", "content": "you are a pirate"}]
+
+
+def test_prompt_message_constructor_role_arg():
+ prompt = EasyPrompt("You're a calculator.", role="system")
+ assert prompt() == [{"role": "system", "content": "You're a calculator."}]
+
+
+def test_prompt_message_constructor_array():
+ prompt = EasyPrompt(
+ [
+ {"role": "system", "content": "You're a calculator."},
+ {"role": "user", "content": "What's 23 * 42"},
+ ]
+ )
+ assert prompt() == [
+ {"role": "system", "content": "You're a calculator."},
+ {"role": "user", "content": "What's 23 * 42"},
+ ]
+
+
+def test_prompt_message_constructor_obj():
+ prompt = EasyPrompt(
+ name="myprompt",
+ model="gpt-4o",
+ messages=[
+ {
+ "role": "system",
+ "content": "You will be provided with text, and your task is to translate it into emojis. Do not use any regular text. Do your best with emojis only.",
+ },
+ {
+ "role": "user",
+ "content": "Artificial intelligence is a technology with great promise.",
+ },
+ ],
+ temperature=0.8,
+ max_tokens=64,
+ top_p=1,
+ )
+ assert prompt() == [
+ {
+ "role": "system",
+ "content": "You will be provided with text, and your task is to translate it into emojis. Do not use any regular text. Do your best with emojis only.",
+ },
+ {
+ "role": "user",
+ "content": "Artificial intelligence is a technology with great promise.",
+ },
+ ]
+ assert prompt.config == {
+ "model": "gpt-4o",
+ "temperature": 0.8,
+ "max_tokens": 64,
+ "top_p": 1,
+ }
+
+
+def test_prompt_append() -> None:
+ prompt = EasyPrompt()
+ prompt.append("You are a helpful assistant.", role="system")
+ prompt.append("system: who knows a lot about geography")
+ prompt.append(
+ """
+ What's the capital of Brazil?
+ """,
+ dedent=True,
+ )
+ assert prompt() == [
+ {"role": "system", "content": "You are a helpful assistant."},
+ {"role": "system", "content": "who knows a lot about geography"},
+ {"role": "user", "content": "What's the capital of Brazil?"},
+ ]
+
+
+def test_prompt_append_with_role() -> None:
+ prompt = EasyPrompt()
+ prompt.append("system: who knows a lot about geography", role="asdf")
+ assert prompt() == [
+ {"role": "asdf", "content": "system: who knows a lot about geography"},
+ ]
+
+
+def test_prompt_unbound_iteration() -> None:
+ """We don't error - is that the right behavior?"""
+ prompt = EasyPrompt("Tell me about {x}, {y}, and {z}. Especially {z}.")
+ prompt.bind(y="strawberry")
+ assert prompt.placeholders == ["x", "y", "z"]
+ assert not prompt.is_bound
+ assert prompt.unbound_placeholders == ["x", "z"]
+ assert list(prompt()) == [
+ {
+ "role": "user",
+ "content": "Tell me about {x}, strawberry, and {z}. Especially {z}.",
+ }
+ ]
+ prompt.bind(x="vanilla", z="chocolate")
+ assert prompt.is_bound
+ assert prompt.unbound_placeholders == []
+ assert list(prompt()) == [
+ {
+ "role": "user",
+ "content": "Tell me about vanilla, strawberry, and chocolate. Especially chocolate.",
+ }
+ ]
+
+
+def test_prompt_format_specifiers() -> None:
+ prompt = EasyPrompt("{x:.5}")
+ assert prompt.placeholders == ["x"]
+ assert prompt(x=3.14159)[0]["content"] == "3.1416"
+
+
+def test_prompt_parameter_default() -> None:
+ prompt = EasyPrompt("{A} * {B}")
+ prompt.require("A", default=23)
+ prompt.require("B", default=42)
+ assert list(prompt()) == [{"role": "user", "content": "23 * 42"}]
+
+
+def test_prompt_parameter_validation_int() -> None:
+ prompt = EasyPrompt("{A} + {B}")
+ prompt.require("A", min=10, max=100)
+ with pytest.raises(ValueError) as e:
+ prompt.bind(A=0)
+ assert str(e.value) == "A (0) is less than min (10)"
+
+
+def test_prompt_parameter_validation_oneof() -> None:
+ prompt = EasyPrompt("{flavor}")
+ prompt.require("flavor", oneof=("vanilla", "strawberry", "chocolate"))
+ with pytest.raises(ValueError) as e:
+ prompt.bind(flavor="mint chip")
+ assert (
+ str(e.value)
+ == "flavor (mint chip) must be one of vanilla, strawberry, chocolate"
+ )
+
+
+def test_prompt_bind_iteration() -> None:
+ """Iterating over a prompt should return messages with placeholders filled in."""
+ prompt = EasyPrompt(
+ model="gpt-4o",
+ messages=[
+ {
+ "role": "system",
+ "content": "You will be provided with text, and your task is to translate it into emojis. Do not use any regular text. Do your best with emojis only.",
+ },
+ {"role": "user", "content": "{sentence}"},
+ ],
+ temperature=0.8,
+ max_tokens=64,
+ top_p=1,
+ ).bind(sentence="Artificial intelligence is a technology with great promise.")
+ desired = [
+ {
+ "role": "system",
+ "content": "You will be provided with text, and your task is to translate it into emojis. Do not use any regular text. Do your best with emojis only.",
+ },
+ {
+ "role": "user",
+ "content": "Artificial intelligence is a technology with great promise.",
+ },
+ ]
+ assert iter_equal(prompt, iter(desired))
+
+
+def test_prompt_as_dict():
+ prompt = EasyPrompt(
+ model="gpt-4o",
+ messages=[
+ {
+ "role": "system",
+ "content": "You will be provided with text, and your task is to translate it into emojis. Do not use any regular text. Do your best with emojis only.",
+ },
+ {
+ "role": "user",
+ "content": "Artificial intelligence is a technology with great promise.",
+ },
+ ],
+ temperature=0.8,
+ max_tokens=64,
+ top_p=1,
+ )
+ assert prompt.as_dict() == {
+ "model": "gpt-4o",
+ "temperature": 0.8,
+ "max_tokens": 64,
+ "top_p": 1,
+ "messages": [
+ {
+ "role": "system",
+ "content": "You will be provided with text, and your task is to translate it into emojis. Do not use any regular text. Do your best with emojis only.",
+ },
+ {
+ "role": "user",
+ "content": "Artificial intelligence is a technology with great promise.",
+ },
+ ],
+ }
+
+
+def test_prompt_as_pydantic_dict():
+ prompt = EasyPrompt(
+ model="gpt-4o",
+ messages=[
+ {
+ "role": "system",
+ "content": "You will be provided with text, and your task is to translate it into emojis. Do not use any regular text. Do your best with emojis only.",
+ },
+ {
+ "role": "user",
+ "content": "Artificial intelligence is a technology with great promise.",
+ },
+ ],
+ temperature=0.8,
+ max_tokens=64,
+ top_p=1,
+ )
+ assert prompt.as_pydantic_dict() == {
+ "name": None,
+ "description": None,
+ "config": {
+ "model": "gpt-4o",
+ "temperature": 0.8,
+ "max_tokens": 64,
+ "top_p": 1,
+ },
+ "data": [
+ {
+ "role": "system",
+ "content": "You will be provided with text, and your task is to translate it into emojis. Do not use any regular text. Do your best with emojis only.",
+ },
+ {
+ "role": "user",
+ "content": "Artificial intelligence is a technology with great promise.",
+ },
+ ],
+ "requirements": {},
+ }
diff --git a/weave-js/src/components/FancyPage/useProjectSidebar.ts b/weave-js/src/components/FancyPage/useProjectSidebar.ts
index 6fcc50703f9..ab3e11a77df 100644
--- a/weave-js/src/components/FancyPage/useProjectSidebar.ts
+++ b/weave-js/src/components/FancyPage/useProjectSidebar.ts
@@ -144,6 +144,13 @@ export const useProjectSidebar = (
isShown: showWeaveSidebarItems || isShowAll,
iconName: IconNames.BaselineAlt,
},
+ {
+ type: 'button' as const,
+ name: 'Prompts',
+ slug: 'weave/prompts',
+ isShown: showWeaveSidebarItems || isShowAll,
+ iconName: IconNames.ForumChatBubble,
+ },
{
type: 'button' as const,
name: 'Models',
diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallPage/CallPage.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallPage/CallPage.tsx
index 1bd8c13106b..79f091e6a31 100644
--- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallPage/CallPage.tsx
+++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallPage/CallPage.tsx
@@ -125,9 +125,11 @@ const useCallTabs = (call: CallSchema) => {
{
label: 'Use',
content: (
-
-
-
+
+
+
+
+
),
},
];
diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ObjectVersionPage.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ObjectVersionPage.tsx
index 7e1663c70dc..045ceb54900 100644
--- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ObjectVersionPage.tsx
+++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ObjectVersionPage.tsx
@@ -27,9 +27,11 @@ import {
SimplePageLayoutWithHeader,
} from './common/SimplePageLayout';
import {EvaluationLeaderboardTab} from './LeaderboardTab';
+import {TabPrompt} from './TabPrompt';
import {TabUseDataset} from './TabUseDataset';
import {TabUseModel} from './TabUseModel';
import {TabUseObject} from './TabUseObject';
+import {TabUsePrompt} from './TabUsePrompt';
import {KNOWN_BASE_OBJECT_CLASSES} from './wfReactInterface/constants';
import {useWFHooks} from './wfReactInterface/context';
import {
@@ -127,6 +129,8 @@ const ObjectVersionPageInner: React.FC<{
}, [objectVersion.baseObjectClass]);
const refUri = objectVersionKeyToRefUri(objectVersion);
+ const showPromptTab = objectVersion.val._class_name === 'EasyPrompt';
+
const minimalColumns = useMemo(() => {
return ['id', 'op_name', 'project_id'];
}, []);
@@ -287,6 +291,26 @@ const ObjectVersionPageInner: React.FC<{
// },
// ]}
tabs={[
+ ...(showPromptTab
+ ? [
+ {
+ label: 'Prompt',
+ content: (
+
+ {data.loading ? (
+
+ ) : (
+
+ )}
+
+ ),
+ },
+ ]
+ : []),
...(isEvaluation && evalHasCalls
? [
{
@@ -333,23 +357,33 @@ const ObjectVersionPageInner: React.FC<{
{
label: 'Use',
content: (
-
- {baseObjectClass === 'Dataset' ? (
-
- ) : baseObjectClass === 'Model' ? (
-
- ) : (
-
- )}
-
+
+
+ {baseObjectClass === 'Prompt' ? (
+
+ ) : baseObjectClass === 'Dataset' ? (
+
+ ) : baseObjectClass === 'Model' ? (
+
+ ) : (
+
+ )}
+
+
),
},
diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/OpVersionPage.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/OpVersionPage.tsx
index 5e06b4a0474..1a6e4afc577 100644
--- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/OpVersionPage.tsx
+++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/OpVersionPage.tsx
@@ -12,6 +12,7 @@ import {
} from './common/Links';
import {CenteredAnimatedLoader} from './common/Loader';
import {
+ ScrollableTabContent,
SimpleKeyValueTable,
SimplePageLayoutWithHeader,
} from './common/SimplePageLayout';
@@ -136,9 +137,11 @@ const OpVersionPageInner: React.FC<{
{
label: 'Use',
content: (
-
-
-
+
+
+
+
+
),
},
]
diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/TabPrompt.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/TabPrompt.tsx
new file mode 100644
index 00000000000..2f2819c3b34
--- /dev/null
+++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/TabPrompt.tsx
@@ -0,0 +1,25 @@
+import classNames from 'classnames';
+import React from 'react';
+
+import {Tailwind} from '../../../../Tailwind';
+import {MessageList} from './ChatView/MessageList';
+
+type Data = Record<string, any>;
+
+type TabPromptProps = {
+ entity: string;
+ project: string;
+ data: Data;
+};
+
+export const TabPrompt = ({entity, project, data}: TabPromptProps) => {
+ return (
+
+
+
+ );
+};
diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/TabUseCall.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/TabUseCall.tsx
index 817d647d970..3f33be98e7c 100644
--- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/TabUseCall.tsx
+++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/TabUseCall.tsx
@@ -30,7 +30,7 @@ os.environ["WF_TRACE_SERVER_URL"] = "http://127.0.0.1:6345"
const codeFeedback = `call.feedback.add("correctness", {"value": 4})`;
return (
-
+
See{' '}
{' '}
diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/TabUseDataset.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/TabUseDataset.tsx
index 8b56a17604d..861eb15f443 100644
--- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/TabUseDataset.tsx
+++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/TabUseDataset.tsx
@@ -43,7 +43,7 @@ ${pythonName} = weave.ref('${ref.artifactName}:v${versionIndex}').get()`;
}
return (
-
+
See{' '}
{
const label = isParentObject ? 'model version' : 'object';
return (
-
+
See{' '}
{' '}
diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/TabUseObject.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/TabUseObject.tsx
index 4ea8dc6af30..e8178521316 100644
--- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/TabUseObject.tsx
+++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/TabUseObject.tsx
@@ -15,7 +15,7 @@ type TabUseObjectProps = {
export const TabUseObject = ({name, uri}: TabUseObjectProps) => {
const pythonName = isValidVarName(name) ? name : 'obj';
return (
-
+
See{' '}
{
const pythonName = isValidVarName(name) ? name : 'op';
return (
-
+
See for
more information.
diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/TabUsePrompt.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/TabUsePrompt.tsx
new file mode 100644
index 00000000000..6d00af48bc6
--- /dev/null
+++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/TabUsePrompt.tsx
@@ -0,0 +1,99 @@
+import {Box} from '@mui/material';
+import React from 'react';
+
+import {isValidVarName} from '../../../../../core/util/var';
+import {parseRef} from '../../../../../react';
+import {abbreviateRef} from '../../../../../util/refs';
+import {Alert} from '../../../../Alert';
+import {CopyableText} from '../../../../CopyableText';
+import {DocLink} from './common/Links';
+
+type Data = Record<string, any>;
+
+type TabUsePromptProps = {
+ name: string;
+ uri: string;
+ entityName: string;
+ projectName: string;
+ data: Data;
+};
+
+export const TabUsePrompt = ({
+ name,
+ uri,
+ entityName,
+ projectName,
+ data,
+}: TabUsePromptProps) => {
+ const pythonName = isValidVarName(name) ? name : 'prompt';
+ const ref = parseRef(uri);
+ const isParentObject = !ref.artifactRefExtra;
+ const label = isParentObject ? 'prompt version' : 'prompt';
+
+ // TODO: Simplify if no params.
+ const longExample = `import weave
+from openai import OpenAI
+
+weave.init("${projectName}")
+
+${pythonName} = weave.ref("${uri}").get()
+
+class MyModel(weave.Model):
+ model_name: str
+ prompt: weave.Prompt
+
+ @weave.op
+ def predict(self, params: dict) -> dict:
+ client = OpenAI()
+ response = client.chat.completions.create(
+ model=self.model_name,
+ messages=self.prompt.bind(params),
+ )
+ result = response.choices[0].message.content
+ if result is None:
+ raise ValueError("No response from model")
+ return result
+
+mymodel = MyModel(model_name="gpt-3.5-turbo", prompt=${pythonName})
+
+# Replace with desired parameter values
+params = ${JSON.stringify({}, null, 2)}
+print(mymodel.predict(params))
+`;
+
+ return (
+
+
+ See{' '}
+ {' '}
+ and for more
+ information.
+
+
+
+ The ref for this {label} is:
+
+
+
+ Use the following code to retrieve this {label}:
+
+
+
+ A more complete example:
+
+
+
+
+ );
+};
diff --git a/weave/__init__.py b/weave/__init__.py
index 3b54ba97176..781d1e89d89 100644
--- a/weave/__init__.py
+++ b/weave/__init__.py
@@ -12,9 +12,15 @@
from weave.flow.eval import Evaluation, Scorer
from weave.flow.model import Model
from weave.flow.obj import Object
+from weave.flow.prompt.prompt import EasyPrompt, Prompt
+from weave.flow.prompt.prompt import MessagesPrompt as MessagesPrompt
+from weave.flow.prompt.prompt import StringPrompt as StringPrompt
from weave.trace.util import Thread as Thread
from weave.trace.util import ThreadPoolExecutor as ThreadPoolExecutor
+# Alias for succinct code
+P = EasyPrompt
+
# Special object informing doc generation tooling which symbols
# to document & to associate with this module.
__docspec__ = [
@@ -31,6 +37,7 @@
Object,
Dataset,
Model,
+ Prompt,
Evaluation,
Scorer,
]
diff --git a/weave/flow/prompt/common.py b/weave/flow/prompt/common.py
new file mode 100644
index 00000000000..80bc63ae60f
--- /dev/null
+++ b/weave/flow/prompt/common.py
@@ -0,0 +1,14 @@
+# TODO: Maybe use an enum or something to lock down types more
+
+ROLE_COLORS: dict[str, str] = {
+ "system": "bold blue",
+ "user": "bold green",
+ "assistant": "bold magenta",
+}
+
+
+def color_role(role: str) -> str:
+ color = ROLE_COLORS.get(role)
+ if color:
+ return f"[{color}]{role}[/]"
+ return role
diff --git a/weave/flow/prompt/prompt.py b/weave/flow/prompt/prompt.py
new file mode 100644
index 00000000000..016e9d3f996
--- /dev/null
+++ b/weave/flow/prompt/prompt.py
@@ -0,0 +1,440 @@
+import copy
+import json
+import os
+import re
+import textwrap
+from collections import UserList
+from pathlib import Path
+from typing import IO, Any, Optional, SupportsIndex, TypedDict, Union, overload
+
+from pydantic import Field
+from rich.table import Table
+
+from weave.flow.obj import Object
+from weave.flow.prompt.common import ROLE_COLORS, color_role
+from weave.trace.api import publish as weave_publish
+from weave.trace.op import op
+from weave.trace.refs import ObjectRef
+from weave.trace.rich import pydantic_util
+
+
+class Message(TypedDict):
+ role: str
+ content: str
+
+
+def maybe_dedent(content: str, dedent: bool) -> str:
+ if dedent:
+ return textwrap.dedent(content).strip()
+ return content
+
+
+def str_to_message(
+ content: str, role: Optional[str] = None, dedent: bool = False
+) -> Message:
+ if role is not None:
+ return {"role": role, "content": maybe_dedent(content, dedent)}
+ for role in ROLE_COLORS:
+ prefix = role + ":"
+ if content.startswith(prefix):
+ return {
+ "role": role,
+ "content": maybe_dedent(content[len(prefix) :].lstrip(), dedent),
+ }
+ return {"role": "user", "content": maybe_dedent(content, dedent)}
+
+
+# TODO: This supports Python format specifiers, but maybe we don't want to
+# because it will be harder to do in clients in other languages?
+RE_PLACEHOLDER = re.compile(r"\{(\w+)(:[^}]+)?\}")
+
+
+def extract_placeholders(text: str) -> list[str]:
+ placeholders = re.findall(RE_PLACEHOLDER, text)
+ unique = []
+ for name, _ in placeholders:
+ if name not in unique:
+ unique.append(name)
+ return unique
+
+
+def color_content(content: str, values: dict) -> str:
+ placeholders = extract_placeholders(content)
+ colored_values = {}
+ for placeholder in placeholders:
+ if placeholder not in values:
+ colored_values[placeholder] = "[red]{" + placeholder + "}[/]"
+ else:
+ colored_values[placeholder] = (
+ "[orange3]{" + placeholder + ":" + str(values[placeholder]) + "}[/]"
+ )
+ return content.format(**colored_values)
+
+
+class Prompt(Object):
+ def format(self, **kwargs: Any) -> Any:
+        raise NotImplementedError
+
+
+class MessagesPrompt(Prompt):
+ def format(self, **kwargs: Any) -> list:
+        raise NotImplementedError
+
+
+class StringPrompt(Prompt):
+ def format(self, **kwargs: Any) -> str:
+        raise NotImplementedError
+
+
+class EasyPrompt(UserList, Prompt):
+ data: list = Field(default_factory=list)
+ config: dict = Field(default_factory=dict)
+ requirements: dict = Field(default_factory=dict)
+
+ _values: dict
+
+ def __init__(
+ self,
+ content: Optional[Union[str, dict, list]] = None,
+ *,
+ role: Optional[str] = None,
+ dedent: bool = False,
+ **kwargs: Any,
+ ) -> None:
+ super(UserList, self).__init__()
+ name = kwargs.pop("name", None)
+ description = kwargs.pop("description", None)
+ config = kwargs.pop("config", {})
+ requirements = kwargs.pop("requirements", {})
+ if "messages" in kwargs:
+ content = kwargs.pop("messages")
+ config.update(kwargs)
+ kwargs = {"config": config, "requirements": requirements}
+ super(Object, self).__init__(name=name, description=description, **kwargs)
+ self._values = {}
+ if content is not None:
+ if isinstance(content, (str, dict)):
+ content = [content]
+ for item in content:
+ self.append(item, role=role, dedent=dedent)
+
+ def __add__(self, other: Any) -> "Prompt":
+ new_prompt = self.copy()
+ new_prompt += other
+ return new_prompt
+
+ def append(
+ self,
+ item: Any,
+ role: Optional[str] = None,
+ dedent: bool = False,
+ ) -> None:
+ if isinstance(item, str):
+ # Seems like we don't want to do this, if the user wants
+ # all system we have helpers for that, and we want to make the
+ # case of constructing system + user easy
+ # role = self.data[-1].get("role", "user") if self.data else "user"
+ self.data.append(str_to_message(item, role=role, dedent=dedent))
+ elif isinstance(item, dict):
+ # TODO: Validate that item has message shape
+ # TODO: Override role and do dedent?
+ self.data.append(item)
+ elif isinstance(item, list):
+ for item in item:
+ self.append(item)
+ else:
+ raise ValueError(f"Cannot append {item} of type {type(item)} to Prompt")
+
+ def __iadd__(self, item: Any) -> "Prompt":
+ self.append(item)
+ return self
+
+ @property
+ def as_str(self) -> str:
+ """Join all messages into a single string."""
+ return " ".join(message.get("content", "") for message in self.data)
+
+ @property
+ def system_message(self) -> Message:
+ """Join all messages into a system prompt message."""
+ return {"role": "system", "content": self.as_str}
+
+ @property
+ def system_prompt(self) -> "Prompt":
+ """Join all messages into a system prompt object."""
+        return EasyPrompt(self.as_str, role="system")
+
+ @property
+ def messages(self) -> list[Message]:
+ return self.data
+
+ @property
+ def placeholders(self) -> list[str]:
+ all_placeholders: list[str] = []
+ for message in self.data:
+ # TODO: Support placeholders in image messages?
+ placeholders = extract_placeholders(message["content"])
+ all_placeholders.extend(
+ p for p in placeholders if p not in all_placeholders
+ )
+ return all_placeholders
+
+ @property
+ def unbound_placeholders(self) -> list[str]:
+ unbound = []
+ for p in self.placeholders:
+ if p not in self._values:
+ unbound.append(p)
+ return unbound
+
+ @property
+ def is_bound(self) -> bool:
+ return not self.unbound_placeholders
+
+ def validate_requirement(self, key: str, value: Any) -> list:
+ problems = []
+ requirement = self.requirements.get(key)
+ if not requirement:
+ return []
+ # TODO: Type coercion
+ min = requirement.get("min")
+ if min is not None and value < min:
+ problems.append(f"{key} ({value}) is less than min ({min})")
+ max = requirement.get("max")
+ if max is not None and value > max:
+ problems.append(f"{key} ({value}) is greater than max ({max})")
+ oneof = requirement.get("oneof")
+ if oneof is not None and value not in oneof:
+ problems.append(f"{key} ({value}) must be one of {', '.join(oneof)}")
+ return problems
+
+ def validate_requirements(self, values: dict[str, Any]) -> list:
+ problems = []
+ for key, value in values.items():
+ problems += self.validate_requirement(key, value)
+ return problems
+
+ def bind(self, *args: Any, **kwargs: Any) -> "Prompt":
+ is_dict = len(args) == 1 and isinstance(args[0], dict)
+ problems = []
+ if is_dict:
+ problems += self.validate_requirements(args[0])
+ problems += self.validate_requirements(kwargs)
+ if problems:
+ raise ValueError("\n".join(problems))
+ if is_dict:
+ self._values.update(args[0])
+ self._values.update(kwargs)
+ return self
+
+ def __call__(self, *args: Any, **kwargs: Any) -> list[Message]:
+ if len(args) == 1 and len(kwargs) == 0 and isinstance(args[0], dict):
+ kwargs = args[0]
+ prompt = self.bind(kwargs)
+ return list(prompt)
+
+ # TODO: Any should be Dataset but there is a circular dependency issue
+ def bind_rows(self, dataset: Union[list[dict], Any]) -> list["Prompt"]:
+ rows = dataset if isinstance(dataset, list) else dataset.rows
+ bound: list["Prompt"] = []
+ for row in rows:
+ bound.append(self.copy().bind(row))
+ return bound
+
+ @overload
+ def __getitem__(self, index: SupportsIndex) -> Any: ...
+
+ @overload
+ def __getitem__(self, key: slice) -> "EasyPrompt": ...
+
+ def __getitem__(self, key: Union[SupportsIndex, slice]) -> Any:
+ """Override getitem to return a Message, Prompt object, or config value."""
+ if isinstance(key, SupportsIndex):
+ int_index = key.__index__()
+ message = self.data[int_index].copy()
+ placeholders = extract_placeholders(message["content"])
+ values = {}
+ for placeholder in placeholders:
+ if placeholder in self._values:
+ values[placeholder] = self._values[placeholder]
+ elif (
+ placeholder in self.requirements
+ and "default" in self.requirements[placeholder]
+ ):
+ values[placeholder] = self.requirements[placeholder]["default"]
+ else:
+ values[placeholder] = "{" + placeholder + "}"
+ message["content"] = message["content"].format(**values)
+ return message
+ elif isinstance(key, slice):
+ new_prompt = Prompt()
+ new_prompt.name = self.name
+ new_prompt.description = self.description
+ new_prompt.data = self.data[key]
+ new_prompt.config = self.config.copy()
+ new_prompt.requirements = self.requirements.copy()
+ new_prompt._values = self._values.copy()
+ return new_prompt
+ elif isinstance(key, str):
+ if key == "ref":
+ return self
+ if key == "messages":
+ return self.data
+ return self.config[key]
+ else:
+ raise TypeError(f"Invalid argument type: {type(key)}")
+
+ def __deepcopy__(self, memo: dict) -> "Prompt":
+ # I'm sure this isn't right, but hacking in to avoid
+ # TypeError: cannot pickle '_thread.lock' object.
+ # Basically, as part of logging, our message objects are
+ # turned into WeaveDicts, which hold a server reference and
+ # therefore can't be copied.
+ c = copy.deepcopy(dict(self.config), memo)
+ r = copy.deepcopy(dict(self.requirements), memo)
+ p = Prompt(
+ name=self.name, description=self.description, config=c, requirements=r
+ )
+ p._values = dict(self._values)
+ for value in self.data:
+ p.data.append(dict(value))
+ return p
+
+ def require(self, param_name: str, **kwargs: Any) -> "Prompt":
+ self.requirements[param_name] = kwargs
+ return self
+
+ def configure(self, config: Optional[dict] = None, **kwargs: Any) -> "Prompt":
+ if config:
+ self.config = config
+ self.config.update(kwargs)
+ return self
+
+ def publish(self, name: Optional[str] = None) -> ObjectRef:
+ # TODO: This only works if we've called weave.init, but it seems like
+ # that shouldn't be necessary if we have loaded this from a ref.
+ return weave_publish(self, name=name)
+
+ def messages_table(self, title: Optional[str] = None) -> Table:
+ table = Table(title=title, title_justify="left", show_header=False)
+ table.add_column("Role", justify="right")
+ table.add_column("Content")
+ # TODO: Maybe we should inline the values here? Or highlight placeholders missing values in red?
+ for message in self.data:
+ table.add_row(
+ color_role(message.get("role", "user")),
+ color_content(message.get("content", ""), self._values),
+ )
+ return table
+
+ def values_table(self, title: Optional[str] = None) -> Table:
+ table = Table(title=title, title_justify="left", show_header=False)
+ table.add_column("Parameter", justify="right")
+ table.add_column("Value")
+ for key, value in self._values.items():
+ table.add_row(key, str(value))
+ return table
+
+ def config_table(self, title: Optional[str] = None) -> Table:
+ table = Table(title=title, title_justify="left", show_header=False)
+ table.add_column("Key", justify="right")
+ table.add_column("Value")
+ for key, value in self.config.items():
+ table.add_row(key, str(value))
+ return table
+
+ def print(self) -> str:
+ tables = []
+ if self.name or self.description:
+ table1 = Table(show_header=False)
+ table1.add_column("Key", justify="right", style="bold cyan")
+ table1.add_column("Value")
+ if self.name is not None:
+ table1.add_row("Name", self.name)
+ if self.description is not None:
+ table1.add_row("Description", self.description)
+ tables.append(table1)
+ if self.data:
+ tables.append(self.messages_table(title="Messages"))
+ if self._values:
+ tables.append(self.values_table(title="Parameters"))
+ if self.config:
+ tables.append(self.config_table(title="Config"))
+ tables = [pydantic_util.table_to_str(t) for t in tables]
+ return "\n".join(tables)
+
+ def __str__(self) -> str:
+ """Return a single prompt string when str() is called on the object."""
+ return self.as_str
+
+ def _repr_pretty_(self, p: Any, cycle: bool) -> None:
+ """Show a nicely formatted table in ipython."""
+ if cycle:
+ p.text("Prompt(...)")
+ else:
+ p.text(self.print())
+
+ def as_pydantic_dict(self) -> dict[str, Any]:
+ return self.model_dump()
+
+ def as_dict(self) -> dict[str, Any]:
+ # In chat completion kwargs format
+ return {
+ **self.config,
+ "messages": list(self),
+ }
+
+ @staticmethod
+ def from_obj(obj: Any) -> "EasyPrompt":
+ messages = obj.messages if hasattr(obj, "messages") else obj.data
+ messages = [dict(m) for m in messages]
+ config = dict(obj.config)
+ requirements = dict(obj.requirements)
+ return EasyPrompt(
+ name=obj.name,
+ description=obj.description,
+ messages=messages,
+ config=config,
+ requirements=requirements,
+ )
+
+ @staticmethod
+ def load(fp: IO) -> "EasyPrompt":
+ if isinstance(fp, str): # Common mistake
+ raise ValueError(
+ "Prompt.load() takes a file-like object, not a string. Did you mean Prompt.e()?"
+ )
+ data = json.load(fp)
+ prompt = EasyPrompt(**data)
+ return prompt
+
+ @staticmethod
+ def load_file(filepath: Union[str, Path]) -> "Prompt":
+ expanded_path = os.path.expanduser(str(filepath))
+ with open(expanded_path, "r") as f:
+ return EasyPrompt.load(f)
+
+ def dump(self, fp: IO) -> None:
+ json.dump(self.as_pydantic_dict(), fp, indent=2)
+
+ def dump_file(self, filepath: Union[str, Path]) -> None:
+ expanded_path = os.path.expanduser(str(filepath))
+ with open(expanded_path, "w") as f:
+ self.dump(f)
+
+ # TODO: We would like to be able to make this an Op.
+ # Unfortunately, litellm tries to make a deepcopy of the messages,
+ # and that fails because the Message objects aren't picklable:
+ # TypeError: cannot pickle '_thread.RLock' object
+ # (likely because they keep a reference to the server interface).
+ @op
+ def run(self) -> Any:
+ # TODO: Nicer result type
+ import litellm
+
+ result = litellm.completion(
+ messages=list(self),
+ model=self.config.get("model", "gpt-4o-mini"),
+ )
+ # TODO: Print in a nicer format
+ return result
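For reviewers, a rough usage sketch of the require/bind flow added above. This is a sketch only: the import path follows the one used later in this patch, and the placeholder names, values, and dataset rows are illustrative rather than taken from the source.

    from weave.flow.prompt.prompt import EasyPrompt

    # One templated message plus a constrained parameter.
    prompt = EasyPrompt("Summarize {text} at temperature {temperature}")
    prompt.require("temperature", min=0, max=2, default=0.7)
    prompt.configure(model="gpt-4o-mini")

    print(prompt.unbound_placeholders)        # ['text', 'temperature'] before binding
    prompt.bind(text="...", temperature=1.0)  # temperature=3 would raise ValueError

    # as_dict() merges config with the formatted messages, chat-completion style,
    # and bind_rows() produces one bound copy per dataset row.
    request_kwargs = prompt.as_dict()
    bound = prompt.bind_rows([{"text": "row one"}, {"text": "row two"}])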
diff --git a/weave/integrations/openai/openai_sdk.py b/weave/integrations/openai/openai_sdk.py
index d32d1a80a70..558373ab44a 100644
--- a/weave/integrations/openai/openai_sdk.py
+++ b/weave/integrations/openai/openai_sdk.py
@@ -3,6 +3,7 @@
from typing import TYPE_CHECKING, Any, Callable, Optional
import weave
+from weave.trace.op import Op, ProcessedInputs
from weave.trace.op_extensions.accumulator import add_accumulator
from weave.trace.patcher import MultiPatcher, SymbolPatcher
@@ -277,6 +278,28 @@ def should_use_accumulator(inputs: dict) -> bool:
)
+def openai_on_input_handler(
+ func: Op, args: tuple, kwargs: dict
+) -> Optional[ProcessedInputs]:
+ if len(args) == 2 and isinstance(args[1], weave.EasyPrompt):
+ original_args = args
+ original_kwargs = kwargs
+ prompt = args[1]
+ args = args[:-1]
+ kwargs.update(prompt.as_dict())
+ inputs = {
+ "prompt": prompt,
+ }
+ return ProcessedInputs(
+ original_args=original_args,
+ original_kwargs=original_kwargs,
+ args=args,
+ kwargs=kwargs,
+ inputs=inputs,
+ )
+ return None
+
+
def create_wrapper_sync(
name: str,
) -> Callable[[Callable], Callable]:
@@ -301,6 +324,7 @@ def _openai_stream_options_is_set(inputs: dict) -> bool:
op = weave.op()(_add_stream_options(fn))
op.name = name # type: ignore
+ op._set_on_input_handler(openai_on_input_handler)
return add_accumulator(
op, # type: ignore
make_accumulator=lambda inputs: lambda acc, value: openai_accumulator(
@@ -338,6 +362,7 @@ def _openai_stream_options_is_set(inputs: dict) -> bool:
op = weave.op()(_add_stream_options(fn))
op.name = name # type: ignore
+ op._set_on_input_handler(openai_on_input_handler)
return add_accumulator(
op, # type: ignore
make_accumulator=lambda inputs: lambda acc, value: openai_accumulator(
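A hedged sketch of what the new input handler means for a patched OpenAI call. It assumes weave.init has patched the SDK as usual; the project, model, and placeholder names are illustrative and not taken from this diff.

    import weave
    from openai import OpenAI

    weave.init("my-team/prompt-demo")  # illustrative project
    client = OpenAI()

    prompt = weave.EasyPrompt("Say hello to {name}.")
    prompt.configure(model="gpt-4o-mini")
    prompt.bind(name="World")

    # The handler strips the EasyPrompt positional argument, expands
    # prompt.as_dict() into chat-completion kwargs (model, messages, ...)
    # for the real create() call, and records {"prompt": prompt} as the
    # op's logged inputs instead of the raw kwargs.
    response = client.chat.completions.create(prompt)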
diff --git a/weave/trace/op.py b/weave/trace/op.py
index 7614b1d8630..ae85d65e7b8 100644
--- a/weave/trace/op.py
+++ b/weave/trace/op.py
@@ -5,6 +5,7 @@
import sys
import traceback
import typing
+from dataclasses import dataclass
from functools import partial, wraps
from types import MethodType
from typing import (
@@ -84,6 +85,21 @@ def print_call_link(call: "Call") -> None:
print(f"{TRACE_CALL_EMOJI} {call.ui_url}")
+@dataclass
+class ProcessedInputs:
+ # What the user passed to the function
+ original_args: tuple
+ original_kwargs: dict[str, Any]
+
+ # What should get passed to the interior function
+ args: tuple
+ kwargs: dict[str, Any]
+
+ # What should get sent to the Weave server
+ inputs: dict[str, Any]
+
+
+OnInputHandlerType = Callable[["Op", tuple, dict], Optional[ProcessedInputs]]
FinishCallbackType = Callable[[Any, Optional[BaseException]], None]
OnOutputHandlerType = Callable[[Any, FinishCallbackType, Dict], Any]
# Call, original function output, exception if occurred
@@ -155,6 +171,9 @@ class Op(Protocol):
call: Callable[..., Any]
calls: Callable[..., "CallsIter"]
+ _set_on_input_handler: Callable[[OnInputHandlerType], None]
+ _on_input_handler: Optional[OnInputHandlerType]
+
# not sure if this is the best place for this, but kept for compat
_set_on_output_handler: Callable[[OnOutputHandlerType], None]
_on_output_handler: Optional[OnOutputHandlerType]
@@ -175,6 +194,12 @@ class Op(Protocol):
_tracing_enabled: bool
+def _set_on_input_handler(func: Op, on_input: OnInputHandlerType) -> None:
+ if func._on_input_handler is not None:
+ raise ValueError("Cannot set on_input_handler multiple times")
+ func._on_input_handler = on_input
+
+
def _set_on_output_handler(func: Op, on_output: OnOutputHandlerType) -> None:
if func._on_output_handler is not None:
raise ValueError("Cannot set on_output_handler multiple times")
@@ -203,16 +228,32 @@ def _is_unbound_method(func: Callable) -> bool:
return bool(is_method)
-def _create_call(
- func: Op, *args: Any, __weave: Optional[WeaveKwargs] = None, **kwargs: Any
-) -> "Call":
- client = weave_client_context.require_weave_client()
-
+def default_on_input_handler(func: Op, args: tuple, kwargs: dict) -> ProcessedInputs:
try:
inputs = func.signature.bind(*args, **kwargs).arguments
except TypeError as e:
raise OpCallError(f"Error calling {func.name}: {e}")
inputs_with_defaults = _apply_fn_defaults_to_inputs(func, inputs)
+ return ProcessedInputs(
+ original_args=args,
+ original_kwargs=kwargs,
+ args=args,
+ kwargs=kwargs,
+ inputs=inputs_with_defaults,
+ )
+
+
+def _create_call(
+ func: Op, *args: Any, __weave: Optional[WeaveKwargs] = None, **kwargs: Any
+) -> "Call":
+ client = weave_client_context.require_weave_client()
+
+ pargs = None
+ if func._on_input_handler is not None:
+ pargs = func._on_input_handler(func, args, kwargs)
+ if not pargs:
+ pargs = default_on_input_handler(func, args, kwargs)
+ inputs_with_defaults = pargs.inputs
# This should probably be configurable, but for now we redact the api_key
if "api_key" in inputs_with_defaults:
@@ -368,12 +409,19 @@ def _do_call(
) -> tuple[Any, "Call"]:
func = op.resolve_fn
call = _placeholder_call()
+
+ pargs = None
+ if op._on_input_handler is not None:
+ pargs = op._on_input_handler(op, args, kwargs)
+ if not pargs:
+ pargs = default_on_input_handler(op, args, kwargs)
+
if settings.should_disable_weave():
- res = func(*args, **kwargs)
+ res = func(*pargs.args, **pargs.kwargs)
elif weave_client_context.get_weave_client() is None:
- res = func(*args, **kwargs)
+ res = func(*pargs.args, **pargs.kwargs)
elif not op._tracing_enabled:
- res = func(*args, **kwargs)
+ res = func(*pargs.args, **pargs.kwargs)
else:
try:
# This try/except allows us to fail gracefully and
@@ -388,10 +436,10 @@ def _do_call(
logger.error,
CALL_CREATE_MSG.format(traceback.format_exc()),
)
- res = func(*args, **kwargs)
+ res = func(*pargs.args, **pargs.kwargs)
else:
execute_result = _execute_call(
- op, call, *args, __should_raise=__should_raise, **kwargs
+ op, call, *pargs.args, __should_raise=__should_raise, **pargs.kwargs
)
if inspect.iscoroutine(execute_result):
raise Exception(
@@ -600,6 +648,9 @@ def wrapper(*args: Any, **kwargs: Any) -> Any:
wrapper.__call__ = wrapper # type: ignore
wrapper.__self__ = wrapper # type: ignore
+ wrapper._set_on_input_handler = partial(_set_on_input_handler, wrapper) # type: ignore
+ wrapper._on_input_handler = None # type: ignore
+
wrapper._set_on_output_handler = partial(_set_on_output_handler, wrapper) # type: ignore
wrapper._on_output_handler = None # type: ignore
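For context, a minimal sketch of wiring a custom input handler onto an op with the hooks added above. The op, handler, and redaction logic are illustrative; returning None from a handler falls back to default_on_input_handler, and setting a second handler raises ValueError.

    from typing import Optional

    import weave
    from weave.trace.op import Op, ProcessedInputs


    @weave.op()
    def greet(greeting: str, name: str) -> str:
        return f"{greeting}, {name}!"


    def redact_name(func: Op, args: tuple, kwargs: dict) -> Optional[ProcessedInputs]:
        # Pass the real arguments through to the wrapped function unchanged,
        # but log a redacted view of them to the Weave server.
        bound = func.signature.bind(*args, **kwargs).arguments
        return ProcessedInputs(
            original_args=args,
            original_kwargs=kwargs,
            args=args,
            kwargs=kwargs,
            inputs={**bound, "name": "<redacted>"},
        )


    greet._set_on_input_handler(redact_name)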
diff --git a/weave/trace/refs.py b/weave/trace/refs.py
index f29c79091a1..ef002997ea3 100644
--- a/weave/trace/refs.py
+++ b/weave/trace/refs.py
@@ -144,6 +144,19 @@ def uri(self) -> str:
u += "/" + "/".join(refs_internal.extra_value_quoter(e) for e in self.extra)
return u
+ def objectify(self, obj: Any) -> Any:
+ """Convert back to higher level object."""
+ class_name = getattr(obj, "_class_name", None)
+ if "EasyPrompt" == class_name:
+ from weave.flow.prompt.prompt import EasyPrompt
+
+ prompt = EasyPrompt.from_obj(obj)
+ # We want to use the ref on the object (and not self) as it will have had
+ # its version number or "latest" alias resolved to a specific digest.
+ prompt.__dict__["ref"] = obj.ref
+ return prompt
+ return obj
+
def get(self) -> Any:
# Move import here so that it only happens when the function is called.
# This import is invalid in the trace server and represents a dependency
@@ -153,7 +166,7 @@ def get(self) -> Any:
gc = get_weave_client()
if gc is not None:
- return gc.get(self)
+ return self.objectify(gc.get(self))
# Special case: If the user is attempting to fetch an object but has not
# yet initialized the client, we can initialize a client to
@@ -166,7 +179,7 @@ def get(self) -> Any:
res = init_client.client.get(self)
finally:
init_client.reset()
- return res
+ return self.objectify(res)
def is_descended_from(self, potential_ancestor: "ObjectRef") -> bool:
if self.entity != potential_ancestor.entity:
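A rough sketch of what the objectify() hook above means for callers; the project and prompt names are illustrative.

    import weave
    from weave.flow.prompt.prompt import EasyPrompt

    weave.init("my-team/prompt-demo")  # illustrative project
    ref = weave.publish(EasyPrompt("Hello, {name}!"), name="greeting")

    # ObjectRef.get() now passes the fetched object through objectify(), so a
    # stored EasyPrompt comes back as an EasyPrompt (with .ref resolved to a
    # specific digest) rather than as a generic traced object.
    prompt = ref.get()
    assert isinstance(prompt, EasyPrompt)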
From e85bdc14866c9a0e46c87a3b09d1e08e09cbcc33 Mon Sep 17 00:00:00 2001
From: Connie Lee
Date: Tue, 29 Oct 2024 16:52:33 -0700
Subject: [PATCH 03/16] style(ui): Fix night colors for Callout (#2815)
---
weave-js/src/components/Callout/Callout.tsx | 1 +
1 file changed, 1 insertion(+)
diff --git a/weave-js/src/components/Callout/Callout.tsx b/weave-js/src/components/Callout/Callout.tsx
index 51028420f46..9fc6535d9cf 100644
--- a/weave-js/src/components/Callout/Callout.tsx
+++ b/weave-js/src/components/Callout/Callout.tsx
@@ -18,6 +18,7 @@ export const Callout = ({className, color, icon, size}: CalloutProps) => {