From af81503aa330c0659a8c65b316b1071782a543e6 Mon Sep 17 00:00:00 2001 From: Jeff Raubitschek Date: Wed, 11 Dec 2024 12:34:22 -0800 Subject: [PATCH 01/16] chore(weave): update db to use replicated tables (#3148) --- .../test_clickhouse_trace_server_migrator.py | 230 ++++++++++++++++++ .../clickhouse_trace_server_migrator.py | 132 ++++++++-- 2 files changed, 336 insertions(+), 26 deletions(-) create mode 100644 tests/trace_server/test_clickhouse_trace_server_migrator.py diff --git a/tests/trace_server/test_clickhouse_trace_server_migrator.py b/tests/trace_server/test_clickhouse_trace_server_migrator.py new file mode 100644 index 00000000000..3a6a92f2479 --- /dev/null +++ b/tests/trace_server/test_clickhouse_trace_server_migrator.py @@ -0,0 +1,230 @@ +import types +from unittest.mock import Mock, call, patch + +import pytest + +from weave.trace_server import clickhouse_trace_server_migrator as trace_server_migrator +from weave.trace_server.clickhouse_trace_server_migrator import MigrationError + + +@pytest.fixture +def mock_costs(): + with patch( + "weave.trace_server.costs.insert_costs.should_insert_costs", return_value=False + ) as mock_should_insert: + with patch( + "weave.trace_server.costs.insert_costs.get_current_costs", return_value=[] + ) as mock_get_costs: + yield + + +@pytest.fixture +def migrator(): + ch_client = Mock() + migrator = trace_server_migrator.ClickHouseTraceServerMigrator(ch_client) + migrator._get_migration_status = Mock() + migrator._get_migrations = Mock() + migrator._determine_migrations_to_apply = Mock() + migrator._update_migration_status = Mock() + ch_client.command.reset_mock() + return migrator + + +def test_apply_migrations_with_target_version(mock_costs, migrator, tmp_path): + # Setup + migrator._get_migration_status.return_value = { + "curr_version": 1, + "partially_applied_version": None, + } + migrator._get_migrations.return_value = { + "1": {"up": "1.up.sql", "down": "1.down.sql"}, + "2": {"up": "2.up.sql", "down": "2.down.sql"}, + } + migrator._determine_migrations_to_apply.return_value = [(2, "2.up.sql")] + + # Create a temporary migration file + migration_dir = tmp_path / "migrations" + migration_dir.mkdir() + migration_file = migration_dir / "2.up.sql" + migration_file.write_text( + "CREATE TABLE test1 (id Int32);\nCREATE TABLE test2 (id Int32);" + ) + + # Mock the migration directory path + with patch("os.path.dirname") as mock_dirname: + mock_dirname.return_value = str(tmp_path) + + # Execute + migrator.apply_migrations("test_db", target_version=2) + + # Verify + migrator._get_migration_status.assert_called_once_with("test_db") + migrator._get_migrations.assert_called_once() + migrator._determine_migrations_to_apply.assert_called_once_with( + 1, migrator._get_migrations.return_value, 2 + ) + + # Verify migration execution + assert migrator._update_migration_status.call_count == 2 + migrator._update_migration_status.assert_has_calls( + [call("test_db", 2, is_start=True), call("test_db", 2, is_start=False)] + ) + + # Verify the actual SQL commands were executed + ch_client = migrator.ch_client + assert ch_client.command.call_count == 2 + ch_client.command.assert_has_calls( + [call("CREATE TABLE test1 (id Int32)"), call("CREATE TABLE test2 (id Int32)")] + ) + + +def test_execute_migration_command(mock_costs, migrator): + # Setup + ch_client = migrator.ch_client + ch_client.database = "original_db" + + # Execute + migrator._execute_migration_command("test_db", "CREATE TABLE test (id Int32)") + + # Verify + assert ch_client.database == "original_db" # Should restore original database + ch_client.command.assert_called_once_with("CREATE TABLE test (id Int32)") + + +def test_migration_replicated(mock_costs, migrator): + ch_client = migrator.ch_client + orig = "CREATE TABLE test (id String, project_id String) ENGINE = MergeTree ORDER BY (project_id, id);" + migrator._execute_migration_command("test_db", orig) + ch_client.command.assert_called_once_with(orig) + + +def test_update_migration_status(mock_costs, migrator): + # Don't mock _update_migration_status for this test + migrator._update_migration_status = types.MethodType( + trace_server_migrator.ClickHouseTraceServerMigrator._update_migration_status, + migrator, + ) + + # Test start of migration + migrator._update_migration_status("test_db", 2, is_start=True) + migrator.ch_client.command.assert_called_with( + "ALTER TABLE db_management.migrations UPDATE partially_applied_version = 2 WHERE db_name = 'test_db'" + ) + + # Test end of migration + migrator._update_migration_status("test_db", 2, is_start=False) + migrator.ch_client.command.assert_called_with( + "ALTER TABLE db_management.migrations UPDATE curr_version = 2, partially_applied_version = NULL WHERE db_name = 'test_db'" + ) + + +def test_is_safe_identifier(mock_costs, migrator): + # Valid identifiers + assert migrator._is_safe_identifier("test_db") + assert migrator._is_safe_identifier("my_db123") + assert migrator._is_safe_identifier("db.table") + + # Invalid identifiers + assert not migrator._is_safe_identifier("test-db") + assert not migrator._is_safe_identifier("db;") + assert not migrator._is_safe_identifier("db'name") + assert not migrator._is_safe_identifier("db/*") + + +def test_create_db_sql_validation(mock_costs, migrator): + # Test invalid database name + with pytest.raises(MigrationError, match="Invalid database name"): + migrator._create_db_sql("test;db") + + # Test replicated mode with invalid values + migrator.replicated = True + migrator.replicated_cluster = "test;cluster" + with pytest.raises(MigrationError, match="Invalid cluster name"): + migrator._create_db_sql("test_db") + + migrator.replicated_cluster = "test_cluster" + migrator.replicated_path = "/clickhouse/bad;path/{db}" + with pytest.raises(MigrationError, match="Invalid replicated path"): + migrator._create_db_sql("test_db") + + +def test_create_db_sql_non_replicated(mock_costs, migrator): + # Test non-replicated mode + migrator.replicated = False + sql = migrator._create_db_sql("test_db") + assert sql.strip() == "CREATE DATABASE IF NOT EXISTS test_db" + + +def test_create_db_sql_replicated(mock_costs, migrator): + # Test replicated mode + migrator.replicated = True + migrator.replicated_path = "/clickhouse/tables/{db}" + migrator.replicated_cluster = "test_cluster" + + sql = migrator._create_db_sql("test_db") + expected = """ + CREATE DATABASE IF NOT EXISTS test_db ON CLUSTER test_cluster ENGINE=Replicated('/clickhouse/tables/test_db', '{shard}', '{replica}') + """.strip() + assert sql.strip() == expected + + +def test_format_replicated_sql_non_replicated(mock_costs, migrator): + # Test that SQL is unchanged when replicated=False + migrator.replicated = False + test_cases = [ + "CREATE TABLE test (id Int32) ENGINE = MergeTree", + "CREATE TABLE test (id Int32) ENGINE = SummingMergeTree", + "CREATE TABLE test (id Int32) ENGINE=ReplacingMergeTree", + ] + + for sql in test_cases: + assert migrator._format_replicated_sql(sql) == sql + + +def test_format_replicated_sql_replicated(mock_costs, migrator): + # Test that MergeTree engines are converted to Replicated variants + migrator.replicated = True + + test_cases = [ + ( + "CREATE TABLE test (id Int32) ENGINE = MergeTree", + "CREATE TABLE test (id Int32) ENGINE = ReplicatedMergeTree", + ), + ( + "CREATE TABLE test (id Int32) ENGINE = SummingMergeTree", + "CREATE TABLE test (id Int32) ENGINE = ReplicatedSummingMergeTree", + ), + ( + "CREATE TABLE test (id Int32) ENGINE=ReplacingMergeTree", + "CREATE TABLE test (id Int32) ENGINE = ReplicatedReplacingMergeTree", + ), + # Test with extra whitespace + ( + "CREATE TABLE test (id Int32) ENGINE = MergeTree", + "CREATE TABLE test (id Int32) ENGINE = ReplicatedMergeTree", + ), + # Test with parameters + ( + "CREATE TABLE test (id Int32) ENGINE = MergeTree()", + "CREATE TABLE test (id Int32) ENGINE = ReplicatedMergeTree()", + ), + ] + + for input_sql, expected_sql in test_cases: + assert migrator._format_replicated_sql(input_sql) == expected_sql + + +def test_format_replicated_sql_non_mergetree(mock_costs, migrator): + # Test that non-MergeTree engines are left unchanged + migrator.replicated = True + + test_cases = [ + "CREATE TABLE test (id Int32) ENGINE = Memory", + "CREATE TABLE test (id Int32) ENGINE = Log", + "CREATE TABLE test (id Int32) ENGINE = TinyLog", + # This should not be changed as it's not a complete word match + "CREATE TABLE test (id Int32) ENGINE = MyMergeTreeCustom", + ] + + for sql in test_cases: + assert migrator._format_replicated_sql(sql) == sql diff --git a/weave/trace_server/clickhouse_trace_server_migrator.py b/weave/trace_server/clickhouse_trace_server_migrator.py index 30dffe89365..4336630bf50 100644 --- a/weave/trace_server/clickhouse_trace_server_migrator.py +++ b/weave/trace_server/clickhouse_trace_server_migrator.py @@ -1,6 +1,7 @@ # Clickhouse Trace Server Manager import logging import os +import re from typing import Optional from clickhouse_connect.driver.client import Client as CHClient @@ -9,6 +10,11 @@ logger = logging.getLogger(__name__) +# These settings are only used when `replicated` mode is enabled for +# self managed clickhouse instances. +DEFAULT_REPLICATED_PATH = "/clickhouse/tables/{db}" +DEFAULT_REPLICATED_CLUSTER = "weave_cluster" + class MigrationError(RuntimeError): """Raised when a migration error occurs.""" @@ -16,15 +22,77 @@ class MigrationError(RuntimeError): class ClickHouseTraceServerMigrator: ch_client: CHClient + replicated: bool + replicated_path: str + replicated_cluster: str def __init__( self, ch_client: CHClient, + replicated: Optional[bool] = None, + replicated_path: Optional[str] = None, + replicated_cluster: Optional[str] = None, ): super().__init__() self.ch_client = ch_client + self.replicated = False if replicated is None else replicated + self.replicated_path = ( + DEFAULT_REPLICATED_PATH if replicated_path is None else replicated_path + ) + self.replicated_cluster = ( + DEFAULT_REPLICATED_CLUSTER + if replicated_cluster is None + else replicated_cluster + ) self._initialize_migration_db() + def _is_safe_identifier(self, value: str) -> bool: + """Check if a string is safe to use as an identifier in SQL.""" + return bool(re.match(r"^[a-zA-Z0-9_\.]+$", value)) + + def _format_replicated_sql(self, sql_query: str) -> str: + """Format SQL query to use replicated engines if replicated mode is enabled.""" + if not self.replicated: + return sql_query + + # Match "ENGINE = MergeTree" followed by word boundary + pattern = r"ENGINE\s*=\s*(\w+)?MergeTree\b" + + def replace_engine(match: re.Match[str]) -> str: + engine_prefix = match.group(1) or "" + return f"ENGINE = Replicated{engine_prefix}MergeTree" + + return re.sub(pattern, replace_engine, sql_query, flags=re.IGNORECASE) + + def _create_db_sql(self, db_name: str) -> str: + """Geneate SQL database create string for normal and replicated databases.""" + if not self._is_safe_identifier(db_name): + raise MigrationError(f"Invalid database name: {db_name}") + + replicated_engine = "" + replicated_cluster = "" + if self.replicated: + if not self._is_safe_identifier(self.replicated_cluster): + raise MigrationError(f"Invalid cluster name: {self.replicated_cluster}") + + replicated_path = self.replicated_path.replace("{db}", db_name) + if not all( + self._is_safe_identifier(part) + for part in replicated_path.split("/") + if part + ): + raise MigrationError(f"Invalid replicated path: {replicated_path}") + + replicated_cluster = f" ON CLUSTER {self.replicated_cluster}" + replicated_engine = ( + f" ENGINE=Replicated('{replicated_path}', '{{shard}}', '{{replica}}')" + ) + + create_db_sql = f""" + CREATE DATABASE IF NOT EXISTS {db_name}{replicated_cluster}{replicated_engine} + """ + return create_db_sql + def apply_migrations( self, target_db: str, target_version: Optional[int] = None ) -> None: @@ -46,20 +114,15 @@ def apply_migrations( return logger.info(f"Migrations to apply: {migrations_to_apply}") if status["curr_version"] == 0: - self.ch_client.command(f"CREATE DATABASE IF NOT EXISTS {target_db}") + self.ch_client.command(self._create_db_sql(target_db)) for target_version, migration_file in migrations_to_apply: self._apply_migration(target_db, target_version, migration_file) if should_insert_costs(status["curr_version"], target_version): insert_costs(self.ch_client, target_db) def _initialize_migration_db(self) -> None: - self.ch_client.command( - """ - CREATE DATABASE IF NOT EXISTS db_management - """ - ) - self.ch_client.command( - """ + self.ch_client.command(self._create_db_sql("db_management")) + create_table_sql = """ CREATE TABLE IF NOT EXISTS db_management.migrations ( db_name String, @@ -69,7 +132,7 @@ def _initialize_migration_db(self) -> None: ENGINE = MergeTree() ORDER BY (db_name) """ - ) + self.ch_client.command(self._format_replicated_sql(create_table_sql)) def _get_migration_status(self, db_name: str) -> dict: column_names = ["db_name", "curr_version", "partially_applied_version"] @@ -184,31 +247,48 @@ def _determine_migrations_to_apply( return [] + def _execute_migration_command(self, target_db: str, command: str) -> None: + """Execute a single migration command in the context of the target database.""" + command = command.strip() + if len(command) == 0: + return + curr_db = self.ch_client.database + self.ch_client.database = target_db + self.ch_client.command(self._format_replicated_sql(command)) + self.ch_client.database = curr_db + + def _update_migration_status( + self, target_db: str, target_version: int, is_start: bool = True + ) -> None: + """Update the migration status in db_management.migrations table.""" + if is_start: + self.ch_client.command( + f"ALTER TABLE db_management.migrations UPDATE partially_applied_version = {target_version} WHERE db_name = '{target_db}'" + ) + else: + self.ch_client.command( + f"ALTER TABLE db_management.migrations UPDATE curr_version = {target_version}, partially_applied_version = NULL WHERE db_name = '{target_db}'" + ) + def _apply_migration( self, target_db: str, target_version: int, migration_file: str ) -> None: logger.info(f"Applying migration {migration_file} to `{target_db}`") migration_dir = os.path.join(os.path.dirname(__file__), "migrations") migration_file_path = os.path.join(migration_dir, migration_file) + with open(migration_file_path) as f: migration_sql = f.read() - self.ch_client.command( - f""" - ALTER TABLE db_management.migrations UPDATE partially_applied_version = {target_version} WHERE db_name = '{target_db}' - """ - ) + + # Mark migration as partially applied + self._update_migration_status(target_db, target_version, is_start=True) + + # Execute each command in the migration migration_sub_commands = migration_sql.split(";") for command in migration_sub_commands: - command = command.strip() - if len(command) == 0: - continue - curr_db = self.ch_client.database - self.ch_client.database = target_db - self.ch_client.command(command) - self.ch_client.database = curr_db - self.ch_client.command( - f""" - ALTER TABLE db_management.migrations UPDATE curr_version = {target_version}, partially_applied_version = NULL WHERE db_name = '{target_db}' - """ - ) + self._execute_migration_command(target_db, command) + + # Mark migration as fully applied + self._update_migration_status(target_db, target_version, is_start=False) + logger.info(f"Migration {migration_file} applied to `{target_db}`") From 4c6db183be9b356dfc6781670c5b8f5fb6fd7532 Mon Sep 17 00:00:00 2001 From: Josiah Lee Date: Wed, 11 Dec 2024 12:54:59 -0800 Subject: [PATCH 02/16] add choices drawer (#3203) --- .../Browse3/pages/ChatView/ChoiceView.tsx | 8 +- .../Browse3/pages/ChatView/ChoicesDrawer.tsx | 104 ++++++++++++++++++ .../Browse3/pages/ChatView/ChoicesView.tsx | 39 ++++--- .../pages/ChatView/ChoicesViewCarousel.tsx | 27 +++-- .../pages/ChatView/ChoicesViewLinear.tsx | 73 ------------ .../Home/Browse3/pages/ChatView/types.ts | 2 - 6 files changed, 148 insertions(+), 105 deletions(-) create mode 100644 weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ChatView/ChoicesDrawer.tsx delete mode 100644 weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ChatView/ChoicesViewLinear.tsx diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ChatView/ChoiceView.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ChatView/ChoiceView.tsx index c8143f9549e..d1a2c59d5d0 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ChatView/ChoiceView.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ChatView/ChoiceView.tsx @@ -6,15 +6,21 @@ import {Choice} from './types'; type ChoiceViewProps = { choice: Choice; isStructuredOutput?: boolean; + isNested?: boolean; }; -export const ChoiceView = ({choice, isStructuredOutput}: ChoiceViewProps) => { +export const ChoiceView = ({ + choice, + isStructuredOutput, + isNested, +}: ChoiceViewProps) => { const {message} = choice; return ( ); diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ChatView/ChoicesDrawer.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ChatView/ChoicesDrawer.tsx new file mode 100644 index 00000000000..16d6897c27b --- /dev/null +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ChatView/ChoicesDrawer.tsx @@ -0,0 +1,104 @@ +import {Box, Drawer} from '@mui/material'; +import {MOON_200} from '@wandb/weave/common/css/color.styles'; +import {Tag} from '@wandb/weave/components/Tag'; +import {Tailwind} from '@wandb/weave/components/Tailwind'; +import React from 'react'; + +import {Button} from '../../../../../Button'; +import {ChoiceView} from './ChoiceView'; +import {Choice} from './types'; + +type ChoicesDrawerProps = { + choices: Choice[]; + isStructuredOutput?: boolean; + isDrawerOpen: boolean; + setIsDrawerOpen: React.Dispatch>; + selectedChoiceIndex: number; + setSelectedChoiceIndex: (choiceIndex: number) => void; +}; + +export const ChoicesDrawer = ({ + choices, + isStructuredOutput, + isDrawerOpen, + setIsDrawerOpen, + selectedChoiceIndex, + setSelectedChoiceIndex, +}: ChoicesDrawerProps) => { + return ( + setIsDrawerOpen(false)} + title="Choices" + anchor="right" + sx={{ + '& .MuiDrawer-paper': {mt: '60px', width: '400px'}, + }}> + + + Responses + + + ) : ( + + )} + + + + ))} + + + + ); +}; diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ChatView/ChoicesView.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ChatView/ChoicesView.tsx index c22df7c63d7..5ddc7f12202 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ChatView/ChoicesView.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ChatView/ChoicesView.tsx @@ -1,9 +1,9 @@ import React, {useState} from 'react'; +import {ChoicesDrawer} from './ChoicesDrawer'; import {ChoicesViewCarousel} from './ChoicesViewCarousel'; -import {ChoicesViewLinear} from './ChoicesViewLinear'; import {ChoiceView} from './ChoiceView'; -import {Choice, ChoicesMode} from './types'; +import {Choice} from './types'; type ChoicesViewProps = { choices: Choice[]; @@ -14,7 +14,12 @@ export const ChoicesView = ({ choices, isStructuredOutput, }: ChoicesViewProps) => { - const [mode, setMode] = useState('linear'); + const [isDrawerOpen, setIsDrawerOpen] = useState(false); + const [localSelectedChoiceIndex, setLocalSelectedChoiceIndex] = useState(0); + + const handleSetSelectedChoiceIndex = (choiceIndex: number) => { + setLocalSelectedChoiceIndex(choiceIndex); + }; if (choices.length === 0) { return null; @@ -26,20 +31,20 @@ export const ChoicesView = ({ } return ( <> - {mode === 'linear' && ( - - )} - {mode === 'carousel' && ( - - )} + + ); }; diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ChatView/ChoicesViewCarousel.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ChatView/ChoicesViewCarousel.tsx index f4a52fc6801..a34932dea17 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ChatView/ChoicesViewCarousel.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ChatView/ChoicesViewCarousel.tsx @@ -1,34 +1,37 @@ -import React, {useState} from 'react'; +import React from 'react'; import {Button} from '../../../../../Button'; import {ChoiceView} from './ChoiceView'; -import {Choice, ChoicesMode} from './types'; +import {Choice} from './types'; type ChoicesViewCarouselProps = { choices: Choice[]; isStructuredOutput?: boolean; - setMode: React.Dispatch>; + setIsDrawerOpen: React.Dispatch>; + selectedChoiceIndex: number; + setSelectedChoiceIndex: (choiceIndex: number) => void; }; export const ChoicesViewCarousel = ({ choices, isStructuredOutput, - setMode, + setIsDrawerOpen, + selectedChoiceIndex, + setSelectedChoiceIndex, }: ChoicesViewCarouselProps) => { - const [step, setStep] = useState(0); - const onNext = () => { - setStep((step + 1) % choices.length); + setSelectedChoiceIndex((selectedChoiceIndex + 1) % choices.length); }; const onBack = () => { - const newStep = step === 0 ? choices.length - 1 : step - 1; - setStep(newStep); + const newStep = + selectedChoiceIndex === 0 ? choices.length - 1 : selectedChoiceIndex - 1; + setSelectedChoiceIndex(newStep); }; return ( <>
@@ -37,7 +40,7 @@ export const ChoicesViewCarousel = ({ size="small" variant="quiet" icon="expand-uncollapse" - onClick={() => setMode('linear')} + onClick={() => setIsDrawerOpen(true)} tooltip="Switch to linear view" />
@@ -48,7 +51,7 @@ export const ChoicesViewCarousel = ({ size="small" onClick={onBack} /> - {step + 1} of {choices.length} + {selectedChoiceIndex + 1} of {choices.length} -
-
- - + } + /> ); }; diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ChatView/MessagePanel.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ChatView/MessagePanel.tsx index 1e778727522..f570b2f6295 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ChatView/MessagePanel.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ChatView/MessagePanel.tsx @@ -18,6 +18,7 @@ type MessagePanelProps = { choiceIndex?: number; isNested?: boolean; pendingToolResponseId?: string; + messageHeader?: React.ReactNode; }; export const MessagePanel = ({ @@ -30,6 +31,7 @@ export const MessagePanel = ({ // If the tool call response is pending, the editor will be shown automatically // and on save the tool call response will be updated and sent to the LLM pendingToolResponseId, + messageHeader, }: MessagePanelProps) => { const [isShowingMore, setIsShowingMore] = useState(false); const [isOverflowing, setIsOverflowing] = useState(false); @@ -116,6 +118,7 @@ export const MessagePanel = ({ 'max-h-[400px]': !isShowingMore, 'max-h-full': isShowingMore, })}> + {messageHeader} {isPlayground && editorHeight ? ( { + console.log('playgroundStates', playgroundStates); const [chatText, setChatText] = useState(''); const [isLoading, setIsLoading] = useState(false); diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/PlaygroundPage/PlaygroundChat/useChatFunctions.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/PlaygroundPage/PlaygroundChat/useChatFunctions.tsx index 804670a1dc3..0ce3ad02b51 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/PlaygroundPage/PlaygroundChat/useChatFunctions.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/PlaygroundPage/PlaygroundChat/useChatFunctions.tsx @@ -43,6 +43,8 @@ export const useChatFunctions = ( messageIndex: number, newMessage: Message ) => { + console.log('editMessage', callIndex, messageIndex, newMessage); + setPlaygroundStateField(callIndex, 'traceCall', prevTraceCall => { const newTraceCall = clearTraceCall( cloneDeep(prevTraceCall as OptionalTraceCallSchema) @@ -106,6 +108,7 @@ export const useChatFunctions = ( choiceIndex: number, newChoice: Message ) => { + console.log('editChoice', callIndex, choiceIndex, newChoice); setPlaygroundStateField(callIndex, 'traceCall', prevTraceCall => { const newTraceCall = clearTraceCall( cloneDeep(prevTraceCall as OptionalTraceCallSchema) From a0f12639451979efc01ef3b31767444542688ea4 Mon Sep 17 00:00:00 2001 From: Andrew Truong Date: Wed, 11 Dec 2024 16:56:46 -0500 Subject: [PATCH 09/16] chore(weave): Add generic iterator for trace server API objects (#3177) --- weave/trace/weave_client.py | 207 ++++++++++++++++++++++-------------- 1 file changed, 129 insertions(+), 78 deletions(-) diff --git a/weave/trace/weave_client.py b/weave/trace/weave_client.py index 0eca3fcbedb..1d5d54b9b23 100644 --- a/weave/trace/weave_client.py +++ b/weave/trace/weave_client.py @@ -10,7 +10,7 @@ from collections.abc import Iterator, Sequence from concurrent.futures import Future from functools import lru_cache -from typing import Any, Callable, cast +from typing import Any, Callable, Generic, Protocol, TypeVar, cast, overload import pydantic from requests import HTTPError @@ -90,6 +90,128 @@ logger = logging.getLogger(__name__) +T = TypeVar("T") +R = TypeVar("R", covariant=True) + + +class FetchFunc(Protocol[T]): + def __call__(self, offset: int, limit: int) -> list[T]: ... + + +TransformFunc = Callable[[T], R] + + +class PaginatedIterator(Generic[T, R]): + """An iterator that fetches pages of items from a server and optionally transforms them + into a more user-friendly type.""" + + def __init__( + self, + fetch_func: FetchFunc[T], + page_size: int = 1000, + transform_func: TransformFunc[T, R] | None = None, + ) -> None: + self.fetch_func = fetch_func + self.page_size = page_size + self.transform_func = transform_func + + if page_size <= 0: + raise ValueError("page_size must be greater than 0") + + @lru_cache + def _fetch_page(self, index: int) -> list[T]: + return self.fetch_func(index * self.page_size, self.page_size) + + @overload + def _get_one(self: PaginatedIterator[T, T], index: int) -> T: ... + @overload + def _get_one(self: PaginatedIterator[T, R], index: int) -> R: ... + def _get_one(self, index: int) -> T | R: + if index < 0: + raise IndexError("Negative indexing not supported") + + page_index = index // self.page_size + page_offset = index % self.page_size + + page = self._fetch_page(page_index) + if page_offset >= len(page): + raise IndexError(f"Index {index} out of range") + + res = page[page_offset] + if transform := self.transform_func: + return transform(res) + return res + + @overload + def _get_slice(self: PaginatedIterator[T, T], key: slice) -> Iterator[T]: ... + @overload + def _get_slice(self: PaginatedIterator[T, R], key: slice) -> Iterator[R]: ... + def _get_slice(self, key: slice) -> Iterator[T] | Iterator[R]: + if (start := key.start or 0) < 0: + raise ValueError("Negative start not supported") + if (stop := key.stop) is not None and stop < 0: + raise ValueError("Negative stop not supported") + if (step := key.step or 1) < 0: + raise ValueError("Negative step not supported") + + i = start + while stop is None or i < stop: + try: + yield self._get_one(i) + except IndexError: + break + i += step + + @overload + def __getitem__(self: PaginatedIterator[T, T], key: int) -> T: ... + @overload + def __getitem__(self: PaginatedIterator[T, R], key: int) -> R: ... + @overload + def __getitem__(self: PaginatedIterator[T, T], key: slice) -> list[T]: ... + @overload + def __getitem__(self: PaginatedIterator[T, R], key: slice) -> list[R]: ... + def __getitem__(self, key: slice | int) -> T | R | list[T] | list[R]: + if isinstance(key, slice): + return list(self._get_slice(key)) + return self._get_one(key) + + @overload + def __iter__(self: PaginatedIterator[T, T]) -> Iterator[T]: ... + @overload + def __iter__(self: PaginatedIterator[T, R]) -> Iterator[R]: ... + def __iter__(self) -> Iterator[T] | Iterator[R]: + return self._get_slice(slice(0, None, 1)) + + +# TODO: should be Call, not WeaveObject +CallsIter = PaginatedIterator[CallSchema, WeaveObject] + + +def _make_calls_iterator( + server: TraceServerInterface, + project_id: str, + filter: CallsFilter, + include_costs: bool = False, +) -> CallsIter: + def fetch_func(offset: int, limit: int) -> list[CallSchema]: + response = server.calls_query( + CallsQueryReq( + project_id=project_id, + filter=filter, + offset=offset, + limit=limit, + include_costs=include_costs, + ) + ) + return response.calls + + # TODO: Should be Call, not WeaveObject + def transform_func(call: CallSchema) -> WeaveObject: + entity, project = project_id.split("/") + return make_client_call(entity, project, call, server) + + return PaginatedIterator(fetch_func, transform_func=transform_func) + class OpNameError(ValueError): """Raised when an op name is invalid.""" @@ -284,7 +406,7 @@ def children(self) -> CallsIter: ) client = weave_client_context.require_weave_client() - return CallsIter( + return _make_calls_iterator( client.server, self.project_id, CallsFilter(parent_ids=[self.id]), @@ -362,80 +484,6 @@ def _apply_scorer(self, scorer_op: Op) -> None: ) -class CallsIter: - server: TraceServerInterface - filter: CallsFilter - include_costs: bool - - def __init__( - self, - server: TraceServerInterface, - project_id: str, - filter: CallsFilter, - include_costs: bool = False, - ) -> None: - self.server = server - self.project_id = project_id - self.filter = filter - self._page_size = 1000 - self.include_costs = include_costs - - # seems like this caching should be on the server, but it's here for now... - @lru_cache - def _fetch_page(self, index: int) -> list[CallSchema]: - # caching in here means that any other CallsIter objects would also - # benefit from the cache - response = self.server.calls_query( - CallsQueryReq( - project_id=self.project_id, - filter=self.filter, - offset=index * self._page_size, - limit=self._page_size, - include_costs=self.include_costs, - ) - ) - return response.calls - - def _get_one(self, index: int) -> WeaveObject: - if index < 0: - raise IndexError("Negative indexing not supported") - - page_index = index // self._page_size - page_offset = index % self._page_size - - calls = self._fetch_page(page_index) - if page_offset >= len(calls): - raise IndexError(f"Index {index} out of range") - - call = calls[page_offset] - entity, project = self.project_id.split("/") - return make_client_call(entity, project, call, self.server) - - def _get_slice(self, key: slice) -> Iterator[WeaveObject]: - if (start := key.start or 0) < 0: - raise ValueError("Negative start not supported") - if (stop := key.stop) is not None and stop < 0: - raise ValueError("Negative stop not supported") - if (step := key.step or 1) < 0: - raise ValueError("Negative step not supported") - - i = start - while stop is None or i < stop: - try: - yield self._get_one(i) - except IndexError: - break - i += step - - def __getitem__(self, key: slice | int) -> WeaveObject | list[WeaveObject]: - if isinstance(key, slice): - return list(self._get_slice(key)) - return self._get_one(key) - - def __iter__(self) -> Iterator[WeaveObject]: - return self._get_slice(slice(0, None, 1)) - - def make_client_call( entity: str, project: str, server_call: CallSchema, server: TraceServerInterface ) -> WeaveObject: @@ -642,8 +690,11 @@ def get_calls( if filter is None: filter = CallsFilter() - return CallsIter( - self.server, self._project_id(), filter, include_costs or False + return _make_calls_iterator( + self.server, + self._project_id(), + filter, + include_costs, ) @deprecated(new_name="get_calls") From 0525621c503ee6082a2d809f17bef64b04ea71ef Mon Sep 17 00:00:00 2001 From: Andrew Truong Date: Wed, 11 Dec 2024 17:05:51 -0500 Subject: [PATCH 10/16] feat(weave): Support op configuration for autopatched functions (starting with OpenAI) (#3197) --- tests/conftest.py | 12 +- .../test_configuration_with_dicts.yaml | 102 +++++++++++++ ...est_disabled_integration_doesnt_patch.yaml | 102 +++++++++++++ .../test_enabled_integration_patches.yaml | 102 +++++++++++++ .../test_passthrough_op_kwargs.yaml | 102 +++++++++++++ tests/integrations/openai/test_autopatch.py | 116 +++++++++++++++ weave/integrations/openai/openai_sdk.py | 136 +++++++++++------- weave/scorers/llm_utils.py | 4 - weave/trace/api.py | 9 +- weave/trace/autopatch.py | 59 +++++++- weave/trace/patcher.py | 8 ++ weave/trace/weave_init.py | 6 +- 12 files changed, 689 insertions(+), 69 deletions(-) create mode 100644 tests/integrations/openai/cassettes/test_autopatch/test_configuration_with_dicts.yaml create mode 100644 tests/integrations/openai/cassettes/test_autopatch/test_disabled_integration_doesnt_patch.yaml create mode 100644 tests/integrations/openai/cassettes/test_autopatch/test_enabled_integration_patches.yaml create mode 100644 tests/integrations/openai/cassettes/test_autopatch/test_passthrough_op_kwargs.yaml create mode 100644 tests/integrations/openai/test_autopatch.py diff --git a/tests/conftest.py b/tests/conftest.py index b28187a3833..85e9b53c36b 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -477,7 +477,9 @@ def __getattribute__(self, name): return ServerRecorder(server) -def create_client(request) -> weave_init.InitializedClient: +def create_client( + request, autopatch_settings: typing.Optional[autopatch.AutopatchSettings] = None +) -> weave_init.InitializedClient: inited_client = None weave_server_flag = request.config.getoption("--weave-server") server: tsi.TraceServerInterface @@ -513,7 +515,7 @@ def create_client(request) -> weave_init.InitializedClient: entity, project, make_server_recorder(server) ) inited_client = weave_init.InitializedClient(client) - autopatch.autopatch() + autopatch.autopatch(autopatch_settings) return inited_client @@ -527,6 +529,7 @@ def client(request): yield inited_client.client finally: inited_client.reset() + autopatch.reset_autopatch() @pytest.fixture() @@ -534,12 +537,13 @@ def client_creator(request): """This fixture is useful for delaying the creation of the client (ex. when you want to set settings first)""" @contextlib.contextmanager - def client(): - inited_client = create_client(request) + def client(autopatch_settings: typing.Optional[autopatch.AutopatchSettings] = None): + inited_client = create_client(request, autopatch_settings) try: yield inited_client.client finally: inited_client.reset() + autopatch.reset_autopatch() yield client diff --git a/tests/integrations/openai/cassettes/test_autopatch/test_configuration_with_dicts.yaml b/tests/integrations/openai/cassettes/test_autopatch/test_configuration_with_dicts.yaml new file mode 100644 index 00000000000..7245829a0b3 --- /dev/null +++ b/tests/integrations/openai/cassettes/test_autopatch/test_configuration_with_dicts.yaml @@ -0,0 +1,102 @@ +interactions: +- request: + body: '{"messages":[{"role":"user","content":"tell me a joke"}],"model":"gpt-4o"}' + headers: + accept: + - application/json + accept-encoding: + - gzip, deflate, zstd + connection: + - keep-alive + content-length: + - '74' + content-type: + - application/json + host: + - api.openai.com + user-agent: + - OpenAI/Python 1.57.2 + x-stainless-arch: + - arm64 + x-stainless-async: + - 'false' + x-stainless-lang: + - python + x-stainless-os: + - MacOS + x-stainless-package-version: + - 1.57.2 + x-stainless-retry-count: + - '0' + x-stainless-runtime: + - CPython + x-stainless-runtime-version: + - 3.13.0rc2 + method: POST + uri: https://api.openai.com/v1/chat/completions + response: + body: + string: !!binary | + H4sIAAAAAAAAAwAAAP//jFJNa9wwEL37V0x1yWVd7P3KspcSSKE5thvooSlGK40tJbJGSOOSNOx/ + L/Z+2KEp9KLDe/Me743mNQMQVostCGUkqza4/Eav1f1O/v4aVvvbL/Pdt7tVHUp1s/vsnhsx6xW0 + f0TFZ9VHRW1wyJb8kVYRJWPvWl4vFpvNYl2UA9GSRtfLmsD5kvJ5MV/mxSYv1iehIaswiS38yAAA + Xoe3j+g1PostFLMz0mJKskGxvQwBiEiuR4RMySaWnsVsJBV5Rj+k/m5eQJO/YkhP6JDJJ6htYxhQ + KgPEBuOnB//g7w2eJ438hcAGoek4fZgaR6y7JPtevnPuhB8uSR01IdI+nfgLXltvk6kiykS+T5WY + ghjYQwbwc9hI96akCJHawBXTE/resCyPdmL8ggm5PJFMLN2Iz1ezd9wqjSytS5ONCiWVQT0qx/XL + TluaENmk899h3vM+9ra++R/7kVAKA6OuQkRt1dvC41jE/kD/NXbZ8RBYpJfE2Fa19Q3GEO3xRupQ + qWslC9xLJUV2yP4AAAD//wMA4O+DUSwDAAA= + headers: + CF-Cache-Status: + - DYNAMIC + CF-RAY: + - 8f01fe3aabd037cf-YYZ + Connection: + - keep-alive + Content-Encoding: + - gzip + Content-Type: + - application/json + Date: + - Wed, 11 Dec 2024 02:20:01 GMT + Server: + - cloudflare + Set-Cookie: + - __cf_bm=xqe_jHZdTV5LijJQYQ3GMY5MjtVrCyxbFO4glgLvgD0-1733883601-1.0.1.1-p.DDUca_cHppJu2hXzzA0CXU1mtalxHUNfBWVgPIQj.UkU603pbNscCvSIi4_Zjlz9Zuc3.hjlvoyZxcDBJTsw; + path=/; expires=Wed, 11-Dec-24 02:50:01 GMT; domain=.api.openai.com; HttpOnly; + Secure; SameSite=None + - _cfuvid=WEjxXqkGswaEDhllTROGX_go9tgaWNJcUJ3cCd50xDI-1733883601764-0.0.1.1-604800000; + path=/; domain=.api.openai.com; HttpOnly; Secure; SameSite=None + Transfer-Encoding: + - chunked + X-Content-Type-Options: + - nosniff + access-control-expose-headers: + - X-Request-ID + alt-svc: + - h3=":443"; ma=86400 + openai-organization: + - wandb + openai-processing-ms: + - '607' + openai-version: + - '2020-10-01' + strict-transport-security: + - max-age=31536000; includeSubDomains; preload + x-ratelimit-limit-requests: + - '10000' + x-ratelimit-limit-tokens: + - '30000000' + x-ratelimit-remaining-requests: + - '9999' + x-ratelimit-remaining-tokens: + - '29999979' + x-ratelimit-reset-requests: + - 6ms + x-ratelimit-reset-tokens: + - 0s + x-request-id: + - req_8592a74b531c806f65c63c7471101cb6 + status: + code: 200 + message: OK +version: 1 diff --git a/tests/integrations/openai/cassettes/test_autopatch/test_disabled_integration_doesnt_patch.yaml b/tests/integrations/openai/cassettes/test_autopatch/test_disabled_integration_doesnt_patch.yaml new file mode 100644 index 00000000000..1895cdcd5f2 --- /dev/null +++ b/tests/integrations/openai/cassettes/test_autopatch/test_disabled_integration_doesnt_patch.yaml @@ -0,0 +1,102 @@ +interactions: +- request: + body: '{"messages":[{"role":"user","content":"tell me a joke"}],"model":"gpt-4o"}' + headers: + accept: + - application/json + accept-encoding: + - gzip, deflate, zstd + connection: + - keep-alive + content-length: + - '74' + content-type: + - application/json + host: + - api.openai.com + user-agent: + - OpenAI/Python 1.57.2 + x-stainless-arch: + - arm64 + x-stainless-async: + - 'false' + x-stainless-lang: + - python + x-stainless-os: + - MacOS + x-stainless-package-version: + - 1.57.2 + x-stainless-retry-count: + - '0' + x-stainless-runtime: + - CPython + x-stainless-runtime-version: + - 3.13.0rc2 + method: POST + uri: https://api.openai.com/v1/chat/completions + response: + body: + string: !!binary | + H4sIAAAAAAAAAwAAAP//jFLJbtswEL3rK6a85GIV8lYvl6KX9NQFrYEckkKgyZHImOII5KiJEfjf + C8qLHDQFeuHhbXgzw5cMQFgt1iCUkaya1uWf9Hyuv5H29ebu89Pt80Z9//H0RS2nX/e3LEbJQdtH + VHx2vVfUtA7Zkj/SKqBkTKnjxXS6XCwWk3FPNKTRJVvdcj6jfFJMZnmxzIsPJ6MhqzCKNdxnAAAv + /Zsqeo3PYg3F6Iw0GKOsUawvIgARyCVEyBhtZOmPdU+kIs/o+9Y/u4AjMBjwJoIEZ2vDuUEZGDU8 + 0g6hogB76tYP/sHfmT1o8jcMcYcOmXyEKlkApTJAbDB8TMKNwbPSyN8IbBDqjuO76xoBqy7KtAXf + OXfCD5e5HNVtoG088Re8st5GUwaUkXyaITK1omcPGcCvfn/dq5WINlDTcsm0Q58Cx+NjnBgONpCT + 2YlkYukGfDofvZFWamRpXbzav1BSGdSDcziW7LSlKyK7mvnvMm9lH+e2vv6f+IFQCltGXbYBtVWv + Bx5kAdN3/pfssuO+sIj7yNiUlfU1hjbY44+q2nKl54XSq1WxFdkh+wMAAP//AwAWTTnuWgMAAA== + headers: + CF-Cache-Status: + - DYNAMIC + CF-RAY: + - 8f016eadbff439d2-YYZ + Connection: + - keep-alive + Content-Encoding: + - gzip + Content-Type: + - application/json + Date: + - Wed, 11 Dec 2024 00:42:01 GMT + Server: + - cloudflare + Set-Cookie: + - __cf_bm=8FO1yMjc3pMQWRpWrkIe5mcs39GLeqQPmgHQq0YTT8s-1733877721-1.0.1.1-i4G06DBN08aH1F1H73U_TB9OLK3jLsV1jXydB1cQ4Hqx7I.r8xDn.7hFRZe2hy3D_nABTG1nDcdDoXL_wYiqug; + path=/; expires=Wed, 11-Dec-24 01:12:01 GMT; domain=.api.openai.com; HttpOnly; + Secure; SameSite=None + - _cfuvid=jxwySgtriPkUP8L2os1nb_gRq_SSUo3yWFUyJmHPmGY-1733877721989-0.0.1.1-604800000; + path=/; domain=.api.openai.com; HttpOnly; Secure; SameSite=None + Transfer-Encoding: + - chunked + X-Content-Type-Options: + - nosniff + access-control-expose-headers: + - X-Request-ID + alt-svc: + - h3=":443"; ma=86400 + openai-organization: + - wandb + openai-processing-ms: + - '652' + openai-version: + - '2020-10-01' + strict-transport-security: + - max-age=31536000; includeSubDomains; preload + x-ratelimit-limit-requests: + - '10000' + x-ratelimit-limit-tokens: + - '30000000' + x-ratelimit-remaining-requests: + - '9999' + x-ratelimit-remaining-tokens: + - '29999979' + x-ratelimit-reset-requests: + - 6ms + x-ratelimit-reset-tokens: + - 0s + x-request-id: + - req_1c86d4fda2ad715edfd41bcd2f4bdd89 + status: + code: 200 + message: OK +version: 1 diff --git a/tests/integrations/openai/cassettes/test_autopatch/test_enabled_integration_patches.yaml b/tests/integrations/openai/cassettes/test_autopatch/test_enabled_integration_patches.yaml new file mode 100644 index 00000000000..f0cdca54158 --- /dev/null +++ b/tests/integrations/openai/cassettes/test_autopatch/test_enabled_integration_patches.yaml @@ -0,0 +1,102 @@ +interactions: +- request: + body: '{"messages":[{"role":"user","content":"tell me a joke"}],"model":"gpt-4o"}' + headers: + accept: + - application/json + accept-encoding: + - gzip, deflate, zstd + connection: + - keep-alive + content-length: + - '74' + content-type: + - application/json + host: + - api.openai.com + user-agent: + - OpenAI/Python 1.57.2 + x-stainless-arch: + - arm64 + x-stainless-async: + - 'false' + x-stainless-lang: + - python + x-stainless-os: + - MacOS + x-stainless-package-version: + - 1.57.2 + x-stainless-retry-count: + - '0' + x-stainless-runtime: + - CPython + x-stainless-runtime-version: + - 3.13.0rc2 + method: POST + uri: https://api.openai.com/v1/chat/completions + response: + body: + string: !!binary | + H4sIAAAAAAAAAwAAAP//jFLBjtMwEL3nKwZfuDQoTXc3VS8IBOIACCQOHHZR5NrTxDTxWJ4J2rDq + v6Ok2SYrFomLD+/Ne3pvxg8JgHJW7UCZWotpQ5O+sdfXaL9+/PDp+L6gG3ffyudv735v+y82eLUa + FLT/iUYeVa8MtaFBcTTRJqIWHFzXxWazLYoiz0eiJYvNIKuCpFeU5ll+lWbbNLuZhDU5g6x2cJsA + ADyM7xDRW7xXO8hWj0iLzLpCtbsMAahIzYAozexYtBe1mklDXtCPqb/XPVjyLwXYOPTiWBgkdiyg + hVp+fefv/Fs0umMEqbGHVh8RugD4C2MvtfPVi6V3xEPHeqjmu6aZ8NMlbENViLTnib/gB+cd12VE + zeSHYCwU1MieEoAf41K6Jz1ViNQGKYWO6AfD9fpsp+YrLMh8IoVENzOeb1bPuJUWRbuGF0tVRpsa + 7aycL6A762hBJIvOf4d5zvvc2/nqf+xnwhgMgrYMEa0zTwvPYxGHP/qvscuOx8CKexZsy4PzFcYQ + 3fmbHEJpCqMz3GujVXJK/gAAAP//AwAyhdwOLwMAAA== + headers: + CF-Cache-Status: + - DYNAMIC + CF-RAY: + - 8f016eb36bb3a240-YYZ + Connection: + - keep-alive + Content-Encoding: + - gzip + Content-Type: + - application/json + Date: + - Wed, 11 Dec 2024 00:42:02 GMT + Server: + - cloudflare + Set-Cookie: + - __cf_bm=Q_ATX8JU4jFqXJPdwlneOua9wmNmAaASyAfcbPyPqng-1733877722-1.0.1.1-eTMEvBW7oqQa2i3l.Or2I3LF_cCESxfseq.S9DBr8dAJWsVoFfPxKtr5vMaO6yj4hRW8XOSOHcgIcwwqbHrLbg; + path=/; expires=Wed, 11-Dec-24 01:12:02 GMT; domain=.api.openai.com; HttpOnly; + Secure; SameSite=None + - _cfuvid=2ak.tRpn6uEHbM8GrWy_ALtrN34jVSNIJI1mFG2etvM-1733877722703-0.0.1.1-604800000; + path=/; domain=.api.openai.com; HttpOnly; Secure; SameSite=None + Transfer-Encoding: + - chunked + X-Content-Type-Options: + - nosniff + access-control-expose-headers: + - X-Request-ID + alt-svc: + - h3=":443"; ma=86400 + openai-organization: + - wandb + openai-processing-ms: + - '476' + openai-version: + - '2020-10-01' + strict-transport-security: + - max-age=31536000; includeSubDomains; preload + x-ratelimit-limit-requests: + - '10000' + x-ratelimit-limit-tokens: + - '30000000' + x-ratelimit-remaining-requests: + - '9999' + x-ratelimit-remaining-tokens: + - '29999979' + x-ratelimit-reset-requests: + - 6ms + x-ratelimit-reset-tokens: + - 0s + x-request-id: + - req_52e061e1cc55cdd8847a7ba9342f1a14 + status: + code: 200 + message: OK +version: 1 diff --git a/tests/integrations/openai/cassettes/test_autopatch/test_passthrough_op_kwargs.yaml b/tests/integrations/openai/cassettes/test_autopatch/test_passthrough_op_kwargs.yaml new file mode 100644 index 00000000000..646c57c6123 --- /dev/null +++ b/tests/integrations/openai/cassettes/test_autopatch/test_passthrough_op_kwargs.yaml @@ -0,0 +1,102 @@ +interactions: +- request: + body: '{"messages":[{"role":"user","content":"tell me a joke"}],"model":"gpt-4o"}' + headers: + accept: + - application/json + accept-encoding: + - gzip, deflate, zstd + connection: + - keep-alive + content-length: + - '74' + content-type: + - application/json + host: + - api.openai.com + user-agent: + - OpenAI/Python 1.57.2 + x-stainless-arch: + - arm64 + x-stainless-async: + - 'false' + x-stainless-lang: + - python + x-stainless-os: + - MacOS + x-stainless-package-version: + - 1.57.2 + x-stainless-retry-count: + - '0' + x-stainless-runtime: + - CPython + x-stainless-runtime-version: + - 3.13.0rc2 + method: POST + uri: https://api.openai.com/v1/chat/completions + response: + body: + string: !!binary | + H4sIAAAAAAAAAwAAAP//jFLLbtswELzrK7a89GIVsmLXsS9Fr0UvBQIERVMINLkS2VBcglwVcQL/ + e0H5IQVNgV54mNkZzCz3pQAQVosdCGUkqz648rNer7G2z3fL8ITq+X6lvn7x/SHhN/d9LxZZQftf + qPii+qCoDw7Zkj/RKqJkzK7Lzc3N7WazqeuR6Emjy7IucLmisq7qVVndltXHs9CQVZjEDn4UAAAv + 45sjeo1PYgfV4oL0mJLsUOyuQwAiksuIkCnZxNKzWEykIs/ox9T35gCa/HuG9IgOmXyC1naGAaUy + QGwwfnrwD/7O4GXSyN8IbBC6gdO7uXHEdkgy9/KDc2f8eE3qqAuR9unMX/HWeptME1Em8jlVYgpi + ZI8FwM9xI8OrkiJE6gM3TI/os+FyebIT0xfMyNWZZGLpJrxeL95wazSytC7NNiqUVAb1pJzWLwdt + aUYUs85/h3nL+9Tb+u5/7CdCKQyMugkRtVWvC09jEfOB/mvsuuMxsEiHxNg3rfUdxhDt6Uba0Gz1 + ulJ6u632ojgWfwAAAP//AwCOwDMjLAMAAA== + headers: + CF-Cache-Status: + - DYNAMIC + CF-RAY: + - 8f016eb76b71ac9a-YYZ + Connection: + - keep-alive + Content-Encoding: + - gzip + Content-Type: + - application/json + Date: + - Wed, 11 Dec 2024 00:42:03 GMT + Server: + - cloudflare + Set-Cookie: + - __cf_bm=r.xSSsYQNFPvMiizFSvjQiecNA6Q1wQa0VR1YElfXi4-1733877723-1.0.1.1-GVW0i7wrpHCQSY5eXu7sIQgxYWl6jfeSordQ7JFxV3lO6UfFhwxRT92bBP4DfnrSYpBpRw3k4aONAURyvKctiQ; + path=/; expires=Wed, 11-Dec-24 01:12:03 GMT; domain=.api.openai.com; HttpOnly; + Secure; SameSite=None + - _cfuvid=CQJVOdASzL9ency5_q6SDaInTsvpjA240cIxf.AUwXM-1733877723385-0.0.1.1-604800000; + path=/; domain=.api.openai.com; HttpOnly; Secure; SameSite=None + Transfer-Encoding: + - chunked + X-Content-Type-Options: + - nosniff + access-control-expose-headers: + - X-Request-ID + alt-svc: + - h3=":443"; ma=86400 + openai-organization: + - wandb + openai-processing-ms: + - '523' + openai-version: + - '2020-10-01' + strict-transport-security: + - max-age=31536000; includeSubDomains; preload + x-ratelimit-limit-requests: + - '10000' + x-ratelimit-limit-tokens: + - '30000000' + x-ratelimit-remaining-requests: + - '9999' + x-ratelimit-remaining-tokens: + - '29999979' + x-ratelimit-reset-requests: + - 6ms + x-ratelimit-reset-tokens: + - 0s + x-request-id: + - req_c9c57cfa6f37a99aaf0abac013237ed6 + status: + code: 200 + message: OK +version: 1 diff --git a/tests/integrations/openai/test_autopatch.py b/tests/integrations/openai/test_autopatch.py new file mode 100644 index 00000000000..2c2f5201d3f --- /dev/null +++ b/tests/integrations/openai/test_autopatch.py @@ -0,0 +1,116 @@ +# This is included here for convenience. Instead of creating a dummy API, we can test +# autopatching against the actual OpenAI API. + +from typing import Any + +import pytest +from openai import OpenAI + +from weave.integrations.openai import openai_sdk +from weave.trace.autopatch import AutopatchSettings, IntegrationSettings, OpSettings + + +@pytest.mark.skip_clickhouse_client # TODO:VCR recording does not seem to allow us to make requests to the clickhouse db in non-recording mode +@pytest.mark.vcr( + filter_headers=["authorization"], allowed_hosts=["api.wandb.ai", "localhost"] +) +def test_disabled_integration_doesnt_patch(client_creator): + autopatch_settings = AutopatchSettings( + openai=IntegrationSettings(enabled=False), + ) + + with client_creator(autopatch_settings=autopatch_settings) as client: + oaiclient = OpenAI() + oaiclient.chat.completions.create( + model="gpt-4o", + messages=[{"role": "user", "content": "tell me a joke"}], + ) + + calls = list(client.get_calls()) + assert len(calls) == 0 + + +@pytest.mark.skip_clickhouse_client # TODO:VCR recording does not seem to allow us to make requests to the clickhouse db in non-recording mode +@pytest.mark.vcr( + filter_headers=["authorization"], allowed_hosts=["api.wandb.ai", "localhost"] +) +def test_enabled_integration_patches(client_creator): + autopatch_settings = AutopatchSettings( + openai=IntegrationSettings(enabled=True), + ) + + with client_creator(autopatch_settings=autopatch_settings) as client: + oaiclient = OpenAI() + oaiclient.chat.completions.create( + model="gpt-4o", + messages=[{"role": "user", "content": "tell me a joke"}], + ) + + calls = list(client.get_calls()) + assert len(calls) == 1 + + +@pytest.mark.skip_clickhouse_client # TODO:VCR recording does not seem to allow us to make requests to the clickhouse db in non-recording mode +@pytest.mark.vcr( + filter_headers=["authorization"], allowed_hosts=["api.wandb.ai", "localhost"] +) +def test_passthrough_op_kwargs(client_creator): + def redact_inputs(inputs: dict[str, Any]) -> dict[str, Any]: + return dict.fromkeys(inputs, "REDACTED") + + autopatch_settings = AutopatchSettings( + openai=IntegrationSettings( + op_settings=OpSettings( + postprocess_inputs=redact_inputs, + ) + ) + ) + + # Explicitly reset the patcher here to pretend like we're starting fresh. We need + # to do this because `_openai_patcher` is a global variable that is shared across + # tests. If we don't reset it, it will retain the state from the previous test, + # which can cause this test to fail. + openai_sdk._openai_patcher = None + + with client_creator(autopatch_settings=autopatch_settings) as client: + oaiclient = OpenAI() + oaiclient.chat.completions.create( + model="gpt-4o", + messages=[{"role": "user", "content": "tell me a joke"}], + ) + + calls = list(client.get_calls()) + assert len(calls) == 1 + + call = calls[0] + assert all(v == "REDACTED" for v in call.inputs.values()) + + +@pytest.mark.skip_clickhouse_client # TODO:VCR recording does not seem to allow us to make requests to the clickhouse db in non-recording mode +@pytest.mark.vcr( + filter_headers=["authorization"], allowed_hosts=["api.wandb.ai", "localhost"] +) +def test_configuration_with_dicts(client_creator): + def redact_inputs(inputs: dict[str, Any]) -> dict[str, Any]: + return dict.fromkeys(inputs, "REDACTED") + + autopatch_settings = { + "openai": { + "op_settings": {"postprocess_inputs": redact_inputs}, + } + } + + openai_sdk._openai_patcher = None + + with client_creator(autopatch_settings=autopatch_settings) as client: + oaiclient = OpenAI() + oaiclient.chat.completions.create( + model="gpt-4o", + messages=[{"role": "user", "content": "tell me a joke"}], + ) + + calls = list(client.get_calls()) + assert len(calls) == 1 + + call = calls[0] + assert all(v == "REDACTED" for v in call.inputs.values()) diff --git a/weave/integrations/openai/openai_sdk.py b/weave/integrations/openai/openai_sdk.py index 7814700d4d3..a1e3a9b5831 100644 --- a/weave/integrations/openai/openai_sdk.py +++ b/weave/integrations/openai/openai_sdk.py @@ -1,15 +1,20 @@ +from __future__ import annotations + import importlib from functools import wraps -from typing import TYPE_CHECKING, Any, Callable, Optional +from typing import TYPE_CHECKING, Any, Callable import weave +from weave.trace.autopatch import IntegrationSettings, OpSettings from weave.trace.op import Op, ProcessedInputs from weave.trace.op_extensions.accumulator import add_accumulator -from weave.trace.patcher import MultiPatcher, SymbolPatcher +from weave.trace.patcher import MultiPatcher, NoOpPatcher, SymbolPatcher if TYPE_CHECKING: from openai.types.chat import ChatCompletionChunk +_openai_patcher: MultiPatcher | None = None + def maybe_unwrap_api_response(value: Any) -> Any: """If the caller requests a raw response, we unwrap the APIResponse object. @@ -43,9 +48,7 @@ def maybe_unwrap_api_response(value: Any) -> Any: return value -def openai_on_finish_post_processor( - value: Optional["ChatCompletionChunk"], -) -> Optional[dict]: +def openai_on_finish_post_processor(value: ChatCompletionChunk | None) -> dict | None: from openai.types.chat import ChatCompletion, ChatCompletionChunk from openai.types.chat.chat_completion_chunk import ( ChoiceDeltaFunctionCall, @@ -60,8 +63,8 @@ def openai_on_finish_post_processor( value = maybe_unwrap_api_response(value) def _get_function_call( - function_call: Optional[ChoiceDeltaFunctionCall], - ) -> Optional[FunctionCall]: + function_call: ChoiceDeltaFunctionCall | None, + ) -> FunctionCall | None: if function_call is None: return function_call if isinstance(function_call, ChoiceDeltaFunctionCall): @@ -73,8 +76,8 @@ def _get_function_call( return None def _get_tool_calls( - tool_calls: Optional[list[ChoiceDeltaToolCall]], - ) -> Optional[list[ChatCompletionMessageToolCall]]: + tool_calls: list[ChoiceDeltaToolCall] | None, + ) -> list[ChatCompletionMessageToolCall] | None: if tool_calls is None: return tool_calls @@ -128,10 +131,10 @@ def _get_tool_calls( def openai_accumulator( - acc: Optional["ChatCompletionChunk"], - value: "ChatCompletionChunk", + acc: ChatCompletionChunk | None, + value: ChatCompletionChunk, skip_last: bool = False, -) -> "ChatCompletionChunk": +) -> ChatCompletionChunk: from openai.types.chat import ChatCompletionChunk from openai.types.chat.chat_completion_chunk import ( ChoiceDeltaFunctionCall, @@ -285,7 +288,7 @@ def should_use_accumulator(inputs: dict) -> bool: def openai_on_input_handler( func: Op, args: tuple, kwargs: dict -) -> Optional[ProcessedInputs]: +) -> ProcessedInputs | None: if len(args) == 2 and isinstance(args[1], weave.EasyPrompt): original_args = args original_kwargs = kwargs @@ -305,20 +308,16 @@ def openai_on_input_handler( return None -def create_wrapper_sync( - name: str, -) -> Callable[[Callable], Callable]: +def create_wrapper_sync(settings: OpSettings) -> Callable[[Callable], Callable]: def wrapper(fn: Callable) -> Callable: "We need to do this so we can check if `stream` is used" def _add_stream_options(fn: Callable) -> Callable: @wraps(fn) def _wrapper(*args: Any, **kwargs: Any) -> Any: - if bool(kwargs.get("stream")) and kwargs.get("stream_options") is None: + if kwargs.get("stream") and kwargs.get("stream_options") is None: kwargs["stream_options"] = {"include_usage": True} - return fn( - *args, **kwargs - ) # This is where the final execution of fn is happening. + return fn(*args, **kwargs) return _wrapper @@ -327,8 +326,8 @@ def _openai_stream_options_is_set(inputs: dict) -> bool: return True return False - op = weave.op()(_add_stream_options(fn)) - op.name = name # type: ignore + op_kwargs = settings.model_dump() + op = weave.op(_add_stream_options(fn), **op_kwargs) op._set_on_input_handler(openai_on_input_handler) return add_accumulator( op, # type: ignore @@ -345,16 +344,14 @@ def _openai_stream_options_is_set(inputs: dict) -> bool: # Surprisingly, the async `client.chat.completions.create` does not pass # `inspect.iscoroutinefunction`, so we can't dispatch on it and must write # it manually here... -def create_wrapper_async( - name: str, -) -> Callable[[Callable], Callable]: +def create_wrapper_async(settings: OpSettings) -> Callable[[Callable], Callable]: def wrapper(fn: Callable) -> Callable: "We need to do this so we can check if `stream` is used" def _add_stream_options(fn: Callable) -> Callable: @wraps(fn) async def _wrapper(*args: Any, **kwargs: Any) -> Any: - if bool(kwargs.get("stream")) and kwargs.get("stream_options") is None: + if kwargs.get("stream") and kwargs.get("stream_options") is None: kwargs["stream_options"] = {"include_usage": True} return await fn(*args, **kwargs) @@ -365,8 +362,8 @@ def _openai_stream_options_is_set(inputs: dict) -> bool: return True return False - op = weave.op()(_add_stream_options(fn)) - op.name = name # type: ignore + op_kwargs = settings.model_dump() + op = weave.op(_add_stream_options(fn), **op_kwargs) op._set_on_input_handler(openai_on_input_handler) return add_accumulator( op, # type: ignore @@ -380,28 +377,61 @@ def _openai_stream_options_is_set(inputs: dict) -> bool: return wrapper -symbol_patchers = [ - # Patch the Completions.create method - SymbolPatcher( - lambda: importlib.import_module("openai.resources.chat.completions"), - "Completions.create", - create_wrapper_sync(name="openai.chat.completions.create"), - ), - SymbolPatcher( - lambda: importlib.import_module("openai.resources.chat.completions"), - "AsyncCompletions.create", - create_wrapper_async(name="openai.chat.completions.create"), - ), - SymbolPatcher( - lambda: importlib.import_module("openai.resources.beta.chat.completions"), - "Completions.parse", - create_wrapper_sync(name="openai.beta.chat.completions.parse"), - ), - SymbolPatcher( - lambda: importlib.import_module("openai.resources.beta.chat.completions"), - "AsyncCompletions.parse", - create_wrapper_async(name="openai.beta.chat.completions.parse"), - ), -] - -openai_patcher = MultiPatcher(symbol_patchers) # type: ignore +def get_openai_patcher( + settings: IntegrationSettings | None = None, +) -> MultiPatcher | NoOpPatcher: + if settings is None: + settings = IntegrationSettings() + + if not settings.enabled: + return NoOpPatcher() + + global _openai_patcher + if _openai_patcher is not None: + return _openai_patcher + + base = settings.op_settings + + completions_create_settings = base.model_copy( + update={"name": base.name or "openai.chat.completions.create"} + ) + async_completions_create_settings = base.model_copy( + update={"name": base.name or "openai.chat.completions.create"} + ) + completions_parse_settings = base.model_copy( + update={"name": base.name or "openai.beta.chat.completions.parse"} + ) + async_completions_parse_settings = base.model_copy( + update={"name": base.name or "openai.beta.chat.completions.parse"} + ) + + _openai_patcher = MultiPatcher( + [ + SymbolPatcher( + lambda: importlib.import_module("openai.resources.chat.completions"), + "Completions.create", + create_wrapper_sync(settings=completions_create_settings), + ), + SymbolPatcher( + lambda: importlib.import_module("openai.resources.chat.completions"), + "AsyncCompletions.create", + create_wrapper_async(settings=async_completions_create_settings), + ), + SymbolPatcher( + lambda: importlib.import_module( + "openai.resources.beta.chat.completions" + ), + "Completions.parse", + create_wrapper_sync(settings=completions_parse_settings), + ), + SymbolPatcher( + lambda: importlib.import_module( + "openai.resources.beta.chat.completions" + ), + "AsyncCompletions.parse", + create_wrapper_async(settings=async_completions_parse_settings), + ), + ] + ) + + return _openai_patcher diff --git a/weave/scorers/llm_utils.py b/weave/scorers/llm_utils.py index 68ae2ccb366..eef6f018b0f 100644 --- a/weave/scorers/llm_utils.py +++ b/weave/scorers/llm_utils.py @@ -2,10 +2,6 @@ from typing import TYPE_CHECKING, Any, Union -from weave.trace.autopatch import autopatch - -autopatch() # ensure both weave patching and instructor patching are applied - OPENAI_DEFAULT_MODEL = "gpt-4o" OPENAI_DEFAULT_EMBEDDING_MODEL = "text-embedding-3-small" OPENAI_DEFAULT_MODERATION_MODEL = "text-moderation-latest" diff --git a/weave/trace/api.py b/weave/trace/api.py index ee8131b0875..294308cbb67 100644 --- a/weave/trace/api.py +++ b/weave/trace/api.py @@ -13,6 +13,7 @@ # There is probably a better place for this, but including here for now to get the fix in. from weave import type_handlers # noqa: F401 from weave.trace import urls, util, weave_client, weave_init +from weave.trace.autopatch import AutopatchSettings from weave.trace.constants import TRACE_OBJECT_EMOJI from weave.trace.context import call_context from weave.trace.context import weave_client_context as weave_client_context @@ -32,6 +33,7 @@ def init( project_name: str, *, settings: UserSettings | dict[str, Any] | None = None, + autopatch_settings: AutopatchSettings | None = None, ) -> weave_client.WeaveClient: """Initialize weave tracking, logging to a wandb project. @@ -52,7 +54,12 @@ def init( if should_disable_weave(): return weave_init.init_weave_disabled().client - return weave_init.init_weave(project_name).client + initialized_client = weave_init.init_weave( + project_name, + autopatch_settings=autopatch_settings, + ) + + return initialized_client.client @contextlib.contextmanager diff --git a/weave/trace/autopatch.py b/weave/trace/autopatch.py index 3a5dca14556..0619194a224 100644 --- a/weave/trace/autopatch.py +++ b/weave/trace/autopatch.py @@ -4,8 +4,54 @@ check if libraries are installed and imported and patch in the case that they are. """ +from typing import Any, Callable, Optional, Union -def autopatch() -> None: +from pydantic import BaseModel, Field, validate_call + +from weave.trace.weave_client import Call + + +class OpSettings(BaseModel): + """Op settings for a specific integration. + These currently subset the `op` decorator args to provide a consistent interface + when working with auto-patched functions. See the `op` decorator for more details.""" + + name: Optional[str] = None + call_display_name: Optional[Union[str, Callable[[Call], str]]] = None + postprocess_inputs: Optional[Callable[[dict[str, Any]], dict[str, Any]]] = None + postprocess_output: Optional[Callable[[Any], Any]] = None + + +class IntegrationSettings(BaseModel): + """Configuration for a specific integration.""" + + enabled: bool = True + op_settings: OpSettings = Field(default_factory=OpSettings) + + +class AutopatchSettings(BaseModel): + """Settings for auto-patching integrations.""" + + # These will be uncommented as we add support for more integrations. Note that + + # anthropic: IntegrationSettings = Field(default_factory=IntegrationSettings) + # cerebras: IntegrationSettings = Field(default_factory=IntegrationSettings) + # cohere: IntegrationSettings = Field(default_factory=IntegrationSettings) + # dspy: IntegrationSettings = Field(default_factory=IntegrationSettings) + # google_ai_studio: IntegrationSettings = Field(default_factory=IntegrationSettings) + # groq: IntegrationSettings = Field(default_factory=IntegrationSettings) + # instructor: IntegrationSettings = Field(default_factory=IntegrationSettings) + # langchain: IntegrationSettings = Field(default_factory=IntegrationSettings) + # litellm: IntegrationSettings = Field(default_factory=IntegrationSettings) + # llamaindex: IntegrationSettings = Field(default_factory=IntegrationSettings) + # mistral: IntegrationSettings = Field(default_factory=IntegrationSettings) + # notdiamond: IntegrationSettings = Field(default_factory=IntegrationSettings) + openai: IntegrationSettings = Field(default_factory=IntegrationSettings) + # vertexai: IntegrationSettings = Field(default_factory=IntegrationSettings) + + +@validate_call +def autopatch(settings: Optional[AutopatchSettings] = None) -> None: from weave.integrations.anthropic.anthropic_sdk import anthropic_patcher from weave.integrations.cerebras.cerebras_sdk import cerebras_patcher from weave.integrations.cohere.cohere_sdk import cohere_patcher @@ -20,10 +66,13 @@ def autopatch() -> None: from weave.integrations.llamaindex.llamaindex import llamaindex_patcher from weave.integrations.mistral import mistral_patcher from weave.integrations.notdiamond.tracing import notdiamond_patcher - from weave.integrations.openai.openai_sdk import openai_patcher + from weave.integrations.openai.openai_sdk import get_openai_patcher from weave.integrations.vertexai.vertexai_sdk import vertexai_patcher - openai_patcher.attempt_patch() + if settings is None: + settings = AutopatchSettings() + + get_openai_patcher(settings.openai).attempt_patch() mistral_patcher.attempt_patch() litellm_patcher.attempt_patch() llamaindex_patcher.attempt_patch() @@ -54,10 +103,10 @@ def reset_autopatch() -> None: from weave.integrations.llamaindex.llamaindex import llamaindex_patcher from weave.integrations.mistral import mistral_patcher from weave.integrations.notdiamond.tracing import notdiamond_patcher - from weave.integrations.openai.openai_sdk import openai_patcher + from weave.integrations.openai.openai_sdk import get_openai_patcher from weave.integrations.vertexai.vertexai_sdk import vertexai_patcher - openai_patcher.undo_patch() + get_openai_patcher().undo_patch() mistral_patcher.undo_patch() litellm_patcher.undo_patch() llamaindex_patcher.undo_patch() diff --git a/weave/trace/patcher.py b/weave/trace/patcher.py index 1567c4e2bb9..c1d0d653ffa 100644 --- a/weave/trace/patcher.py +++ b/weave/trace/patcher.py @@ -17,6 +17,14 @@ def undo_patch(self) -> bool: raise NotImplementedError() +class NoOpPatcher(Patcher): + def attempt_patch(self) -> bool: + return True + + def undo_patch(self) -> bool: + return True + + class MultiPatcher(Patcher): def __init__(self, patchers: Sequence[Patcher]) -> None: self.patchers = patchers diff --git a/weave/trace/weave_init.py b/weave/trace/weave_init.py index 563dcbdaed4..f51d42d5018 100644 --- a/weave/trace/weave_init.py +++ b/weave/trace/weave_init.py @@ -63,7 +63,9 @@ def get_entity_project_from_project_name(project_name: str) -> tuple[str, str]: def init_weave( - project_name: str, ensure_project_exists: bool = True + project_name: str, + ensure_project_exists: bool = True, + autopatch_settings: autopatch.AutopatchSettings | None = None, ) -> InitializedClient: global _current_inited_client if _current_inited_client is not None: @@ -120,7 +122,7 @@ def init_weave( # autopatching is only supported for the wandb client, because OpenAI calls are not # logged in local mode currently. When that's fixed, this autopatch call can be # moved to InitializedClient.__init__ - autopatch.autopatch() + autopatch.autopatch(autopatch_settings) username = get_username() try: From 16e47c3a8d804db1f7c8c80fe53b5b58082a6757 Mon Sep 17 00:00:00 2001 From: Jamie Rasmussen <112953339+jamie-rasmussen@users.noreply.github.com> Date: Wed, 11 Dec 2024 16:19:24 -0600 Subject: [PATCH 11/16] chore(ui): update UUID dependency to v11 (latest) (#3208) --- weave-js/package.json | 3 +-- weave-js/yarn.lock | 15 +++++---------- 2 files changed, 6 insertions(+), 12 deletions(-) diff --git a/weave-js/package.json b/weave-js/package.json index cb57125143a..1f551021ed4 100644 --- a/weave-js/package.json +++ b/weave-js/package.json @@ -192,7 +192,6 @@ "@types/react-virtualized-auto-sizer": "^1.0.0", "@types/safe-json-stringify": "^1.1.2", "@types/styled-components": "^5.1.26", - "@types/uuid": "^9.0.1", "@types/wavesurfer.js": "^2.0.0", "@types/zen-observable": "^0.8.3", "@typescript-eslint/eslint-plugin": "5.35.1", @@ -237,7 +236,7 @@ "tslint-config-prettier": "^1.18.0", "tslint-plugin-prettier": "^2.3.0", "typescript": "4.7.4", - "uuid": "^9.0.0", + "uuid": "^11.0.3", "vite": "5.2.9", "vitest": "^1.6.0" }, diff --git a/weave-js/yarn.lock b/weave-js/yarn.lock index c7f9379e32a..6a5ec14e872 100644 --- a/weave-js/yarn.lock +++ b/weave-js/yarn.lock @@ -4776,11 +4776,6 @@ resolved "https://registry.yarnpkg.com/@types/unist/-/unist-2.0.7.tgz#5b06ad6894b236a1d2bd6b2f07850ca5c59cf4d6" integrity sha512-cputDpIbFgLUaGQn6Vqg3/YsJwxUwHLO13v3i5ouxT4lat0khip9AEWxtERujXV9wxIB1EyF97BSJFt6vpdI8g== -"@types/uuid@^9.0.1": - version "9.0.2" - resolved "https://registry.yarnpkg.com/@types/uuid/-/uuid-9.0.2.tgz#ede1d1b1e451548d44919dc226253e32a6952c4b" - integrity sha512-kNnC1GFBLuhImSnV7w4njQkUiJi0ZXUycu1rUaouPqiKlXkh77JKgdRnTAp1x5eBwcIwbtI+3otwzuIDEuDoxQ== - "@types/wavesurfer.js@^2.0.0": version "2.0.2" resolved "https://registry.yarnpkg.com/@types/wavesurfer.js/-/wavesurfer.js-2.0.2.tgz#b98a4d57ca24ee2028ae6dd5c2208b568bb73842" @@ -15032,6 +15027,11 @@ util-deprecate@^1.0.1, util-deprecate@^1.0.2, util-deprecate@~1.0.1: resolved "https://registry.yarnpkg.com/util-deprecate/-/util-deprecate-1.0.2.tgz#450d4dc9fa70de732762fbd2d4a28981419a0ccf" integrity sha512-EPD5q1uXyFxJpCrLnCc1nHnq3gOa6DZBocAIiI2TaSCA7VCJ1UJDMagCzIkXNsUYfD1daK//LTEQ8xiIbrHtcw== +uuid@^11.0.3: + version "11.0.3" + resolved "https://registry.yarnpkg.com/uuid/-/uuid-11.0.3.tgz#248451cac9d1a4a4128033e765d137e2b2c49a3d" + integrity sha512-d0z310fCWv5dJwnX1Y/MncBAqGMKEzlBb1AOf7z9K8ALnd0utBX/msg/fA0+sbyN1ihbMsLhrBlnl1ak7Wa0rg== + uuid@^2.0.2: version "2.0.3" resolved "https://registry.yarnpkg.com/uuid/-/uuid-2.0.3.tgz#67e2e863797215530dff318e5bf9dcebfd47b21a" @@ -15042,11 +15042,6 @@ uuid@^3.0.0, uuid@^3.4.0: resolved "https://registry.yarnpkg.com/uuid/-/uuid-3.4.0.tgz#b23e4358afa8a202fe7a100af1f5f883f02007ee" integrity sha512-HjSDRw6gZE5JMggctHBcjVak08+KEVhSIiDzFnT9S9aegmp85S/bReBVTb4QTFaRNptJ9kuYaNhnbNEOkbKb/A== -uuid@^9.0.0: - version "9.0.0" - resolved "https://registry.yarnpkg.com/uuid/-/uuid-9.0.0.tgz#592f550650024a38ceb0c562f2f6aa435761efb5" - integrity sha512-MXcSTerfPa4uqyzStbRoTgt5XIe3x5+42+q1sDuy3R5MDk66URdLMOZe5aPX/SQd+kuYAh0FdP/pO28IkQyTeg== - uvu@^0.5.0: version "0.5.6" resolved "https://registry.yarnpkg.com/uvu/-/uvu-0.5.6.tgz#2754ca20bcb0bb59b64e9985e84d2e81058502df" From 444c04dcdeb965b531d7508f5e82af56cfcb00f2 Mon Sep 17 00:00:00 2001 From: Jamie Rasmussen <112953339+jamie-rasmussen@users.noreply.github.com> Date: Wed, 11 Dec 2024 18:50:14 -0600 Subject: [PATCH 12/16] chore(ui): remove some unused code (#3157) --- .../PagePanelComponents/Home/Browse3.tsx | 138 +----------------- 1 file changed, 1 insertion(+), 137 deletions(-) diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3.tsx index 3517f4d3b9c..761fd536930 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3.tsx @@ -1,15 +1,5 @@ import {ApolloProvider} from '@apollo/client'; -import {Home} from '@mui/icons-material'; -import { - AppBar, - Box, - Breadcrumbs, - Drawer, - IconButton, - Link as MaterialLink, - Toolbar, - Typography, -} from '@mui/material'; +import {Box, Drawer} from '@mui/material'; import { GridColumnVisibilityModel, GridFilterModel, @@ -21,9 +11,7 @@ import {LicenseInfo} from '@mui/x-license'; import {makeGorillaApolloClient} from '@wandb/weave/apollo'; import {EVALUATE_OP_NAME_POST_PYDANTIC} from '@wandb/weave/components/PagePanelComponents/Home/Browse3/pages/common/heuristics'; import {opVersionKeyToRefUri} from '@wandb/weave/components/PagePanelComponents/Home/Browse3/pages/wfReactInterface/utilities'; -import _ from 'lodash'; import React, { - ComponentProps, FC, useCallback, useEffect, @@ -33,7 +21,6 @@ import React, { } from 'react'; import useMousetrap from 'react-hook-mousetrap'; import { - Link as RouterLink, Redirect, Route, Switch, @@ -199,7 +186,6 @@ export const Browse3: FC<{ `/${URL_BROWSE3}`, ]}> @@ -211,7 +197,6 @@ export const Browse3: FC<{ }; const Browse3Mounted: FC<{ - hideHeader?: boolean; headerOffset?: number; navigateAwayFromProject?: () => void; }> = props => { @@ -225,37 +210,6 @@ const Browse3Mounted: FC<{ overflow: 'auto', flexDirection: 'column', }}> - {!props.hideHeader && ( - theme.zIndex.drawer + 1, - height: '60px', - flex: '0 0 auto', - position: 'static', - }}> - - - theme.palette.getContrastText(theme.palette.primary.main), - '&:hover': { - color: theme => - theme.palette.getContrastText(theme.palette.primary.dark), - }, - marginRight: theme => theme.spacing(2), - }}> - - - - - - )} @@ -1050,20 +1004,6 @@ const ComparePageBinding = () => { return ; }; -const AppBarLink = (props: ComponentProps) => ( - theme.palette.getContrastText(theme.palette.primary.main), - '&:hover': { - color: theme => - theme.palette.getContrastText(theme.palette.primary.dark), - }, - }} - {...props} - component={RouterLink} - /> -); - const PlaygroundPageBinding = () => { const params = useParamsDecoded(); return ( @@ -1074,79 +1014,3 @@ const PlaygroundPageBinding = () => { /> ); }; - -const Browse3Breadcrumbs: FC = props => { - const params = useParamsDecoded(); - const query = useURLSearchParamsDict(); - const filePathParts = query.path?.split('/') ?? []; - const refFields = query.extra?.split('/') ?? []; - - return ( - - {params.entity && ( - - {params.entity} - - )} - {params.project && ( - - {params.project} - - )} - {params.tab && ( - - {params.tab} - - )} - {params.itemName && ( - - {params.itemName} - - )} - {params.version && ( - - {params.version} - - )} - {filePathParts.map((part, idx) => ( - - {part} - - ))} - {_.range(0, refFields.length, 2).map(idx => ( - - - theme.palette.getContrastText(theme.palette.primary.main), - }}> - {refFields[idx]} - - - {refFields[idx + 1]} - - - ))} - - ); -}; From 96d1d0d0f48cd571ae7b7737de2fd58d663588d9 Mon Sep 17 00:00:00 2001 From: Griffin Tarpenning Date: Wed, 11 Dec 2024 17:55:26 -0800 Subject: [PATCH 13/16] chore(ui): Create scorer drawer style + small annotation drawer style tweaks (#3186) --- .../StructuredFeedback/HumanAnnotation.tsx | 4 +- .../ScorersPage/AnnotationScorerForm.tsx | 9 +- .../pages/ScorersPage/FormComponents.tsx | 2 +- .../pages/ScorersPage/NewScorerDrawer.tsx | 25 +++- .../pages/ScorersPage/ZodSchemaForm.tsx | 133 +++++++++++------- 5 files changed, 111 insertions(+), 62 deletions(-) diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/feedback/StructuredFeedback/HumanAnnotation.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/feedback/StructuredFeedback/HumanAnnotation.tsx index 7facffe9556..2821c02affa 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/feedback/StructuredFeedback/HumanAnnotation.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/feedback/StructuredFeedback/HumanAnnotation.tsx @@ -415,7 +415,7 @@ export const TextFeedbackColumn = ({ placeholder="" /> {maxLength && ( -
+
{`Maximum characters: ${maxLength}`}
)} @@ -603,7 +603,7 @@ export const NumericalTextField: React.FC = ({ errorState={error} /> {(min != null || max != null) && ( -
+
{isInteger ? 'Integer required. ' : ''} {min != null && `Min: ${min}`} {min != null && max != null && ', '} diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ScorersPage/AnnotationScorerForm.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ScorersPage/AnnotationScorerForm.tsx index 9acbdfe6c2f..a478437facb 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ScorersPage/AnnotationScorerForm.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ScorersPage/AnnotationScorerForm.tsx @@ -1,5 +1,5 @@ import {Box} from '@material-ui/core'; -import React, {FC, useCallback, useState} from 'react'; +import React, {FC, useCallback, useEffect, useState} from 'react'; import {z} from 'zod'; import {createBaseObjectInstance} from '../wfReactInterface/baseObjectClassQuery'; @@ -28,7 +28,7 @@ const AnnotationScorerFormSchema = z.object({ }), z.object({ type: z.literal('String'), - 'Max length': z.number().optional(), + 'Maximum length': z.number().optional(), }), z.object({ type: z.literal('Select'), @@ -45,6 +45,9 @@ export const AnnotationScorerForm: FC< ScorerFormProps> > = ({data, onDataChange}) => { const [config, setConfig] = useState(data ?? DEFAULT_STATE); + useEffect(() => { + setConfig(data ?? DEFAULT_STATE); + }, [data]); const [isValid, setIsValid] = useState(false); const handleConfigChange = useCallback( @@ -113,7 +116,7 @@ function convertTypeExtrasToJsonSchema( const typeSchema = obj.Type; const typeExtras: Record = {}; if (typeSchema.type === 'String') { - typeExtras.maxLength = typeSchema['Max length']; + typeExtras.maxLength = typeSchema['Maximum length']; } else if (typeSchema.type === 'Integer' || typeSchema.type === 'Number') { typeExtras.minimum = typeSchema.Minimum; typeExtras.maximum = typeSchema.Maximum; diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ScorersPage/FormComponents.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ScorersPage/FormComponents.tsx index 2716bfbfa81..250c896cfea 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ScorersPage/FormComponents.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ScorersPage/FormComponents.tsx @@ -3,7 +3,7 @@ import {Select} from '@wandb/weave/components/Form/Select'; import {TextField} from '@wandb/weave/components/Form/TextField'; import React from 'react'; -export const GAP_BETWEEN_ITEMS_PX = 10; +export const GAP_BETWEEN_ITEMS_PX = 16; export const GAP_BETWEEN_LABEL_AND_FIELD_PX = 10; type AutocompleteWithLabelType