From 5f53f887d9d3d8f8fd4ae3cf7dd94b0c8404c4f8 Mon Sep 17 00:00:00 2001 From: Tim Sweeney Date: Fri, 16 Aug 2024 14:56:24 -0700 Subject: [PATCH 01/22] more work --- weave/trace_server/refs_internal.py | 24 ++++++++++++------------ weave/trace_server/validation.py | 14 ++++++++++++++ 2 files changed, 26 insertions(+), 12 deletions(-) diff --git a/weave/trace_server/refs_internal.py b/weave/trace_server/refs_internal.py index eaeaf4450761..c6be1e419674 100644 --- a/weave/trace_server/refs_internal.py +++ b/weave/trace_server/refs_internal.py @@ -8,6 +8,8 @@ import urllib from typing import Union +from . import validation + WEAVE_INTERNAL_SCHEME = "weave-trace-internal" WEAVE_SCHEME = "weave" WEAVE_PRIVATE_SCHEME = "weave-private" @@ -29,18 +31,15 @@ class InvalidInternalRef(ValueError): pass -def quote_select(s: str, quote_chars: tuple[str, ...] = ("/", ":", "%")) -> str: +def extra_value_quoter(s: str, quote_chars: tuple[str, ...] = ("%", "/")) -> str: """We don't need to quote every single character, rather we just need to quote the characters that are used as - delimiters in the URI. Right now, we only use "/", ":", and "%". Moreover, - the only user-controlled fields are the object name and the extra fields. - - All other fields are generated by the system and are safe for parsing. + delimiters in the URI. """ # We have to be careful here since quoting creates new "%" characters. # We don't want to double-quote characters. - if "%" in quote_chars and quote_chars[-1] != "%": - raise ValueError("Quoting '%' must be the last character in the list.") + if "%" in quote_chars and quote_chars[0] != "%": + raise ValueError("Quoting '%' must be the first character in the list.") for c in quote_chars: s = s.replace(c, urllib.parse.quote(c)) @@ -118,20 +117,21 @@ def __post_init__(self) -> None: validate_no_slashes(self.version, "version") validate_no_colons(self.version, "version") validate_extra(self.extra) + validation.object_id_validator(self.name) def uri(self) -> str: - u = f"{WEAVE_INTERNAL_SCHEME}:///{self.project_id}/object/{quote_select(self.name)}:{self.version}" + u = f"{WEAVE_INTERNAL_SCHEME}:///{self.project_id}/object/{self.name}:{self.version}" if self.extra: - u += "/" + "/".join(quote_select(e) for e in self.extra) + u += "/" + "/".join(extra_value_quoter(e) for e in self.extra) return u @dataclasses.dataclass(frozen=True) class InternalOpRef(InternalObjectRef): def uri(self) -> str: - u = f"{WEAVE_INTERNAL_SCHEME}:///{self.project_id}/op/{quote_select(self.name)}:{self.version}" + u = f"{WEAVE_INTERNAL_SCHEME}:///{self.project_id}/op/{self.name}:{self.version}" if self.extra: - u += "/" + "/".join(quote_select(e) for e in self.extra) + u += "/" + "/".join(extra_value_quoter(e) for e in self.extra) return u @@ -151,7 +151,7 @@ def __post_init__(self) -> None: def uri(self) -> str: u = f"{WEAVE_INTERNAL_SCHEME}:///{self.project_id}/call/{self.id}" if self.extra: - u += "/" + "/".join(quote_select(e) for e in self.extra) + u += "/" + "/".join(extra_value_quoter(e) for e in self.extra) return u diff --git a/weave/trace_server/validation.py b/weave/trace_server/validation.py index 226e946f691a..bfe99946b2de 100644 --- a/weave/trace_server/validation.py +++ b/weave/trace_server/validation.py @@ -1,3 +1,4 @@ +import re import typing from weave.trace_server import refs_internal @@ -57,7 +58,20 @@ def wb_run_id_validator(s: typing.Optional[str]) -> typing.Optional[str]: return s +def _validate_object_name_charset(name: str) -> None: + # Object names must be alphanumeric with dashes + invalid_chars = re.findall(r"[^\w-]", name) + if invalid_chars: + raise ValueError( + f"Invalid object name: {name}. Contains invalid characters: {invalid_chars}" + ) + + if not name: + raise ValueError("Object name cannot be empty") + + def object_id_validator(s: str) -> str: + _validate_object_name_charset(s) return validation_util.require_max_str_len(s, 128) From fb36617e42ed249015706f656322ffb684e94f16 Mon Sep 17 00:00:00 2001 From: Tim Sweeney Date: Fri, 16 Aug 2024 15:05:14 -0700 Subject: [PATCH 02/22] adding --- weave/trace/refs.py | 17 +++++++++++------ weave/trace_server/refs_internal.py | 5 ++--- weave/weave_client.py | 17 +++++++++++++++++ 3 files changed, 30 insertions(+), 9 deletions(-) diff --git a/weave/trace/refs.py b/weave/trace/refs.py index cfffe99436c1..7d62ea1068ae 100644 --- a/weave/trace/refs.py +++ b/weave/trace/refs.py @@ -1,6 +1,8 @@ import dataclasses from typing import Any, Union +import urllib3 + from ..trace_server import refs_internal DICT_KEY_EDGE_NAME = refs_internal.DICT_KEY_EDGE_NAME @@ -56,7 +58,7 @@ class ObjectRef(RefWithExtra): def uri(self) -> str: u = f"weave:///{self.entity}/{self.project}/object/{self.name}:{self.digest}" if self.extra: - u += "/" + "/".join(self.extra) + u += "/" + "/".join(refs_internal.extra_value_quoter(e) for e in self.extra) return u def get(self) -> Any: @@ -106,7 +108,7 @@ class OpRef(ObjectRef): def uri(self) -> str: u = f"weave:///{self.entity}/{self.project}/op/{self.name}:{self.digest}" if self.extra: - u += "/" + "/".join(self.extra) + u += "/" + "/".join(refs_internal.extra_value_quoter(e) for e in self.extra) return u @@ -120,7 +122,7 @@ class CallRef(RefWithExtra): def uri(self) -> str: u = f"weave:///{self.entity}/{self.project}/call/{self.id}" if self.extra: - u += "/" + "/".join(self.extra) + u += "/" + "/".join(refs_internal.extra_value_quoter(e) for e in self.extra) return u @@ -140,7 +142,10 @@ def parse_uri(uri: str) -> AnyRef: return TableRef(entity=entity, project=project, digest=remaining[0]) elif kind == "call": return CallRef( - entity=entity, project=project, id=remaining[0], extra=remaining[1:] + entity=entity, + project=project, + id=remaining[0], + extra=[urllib3.parse.unquote(r) for r in remaining[1:]], ) elif kind == "object": name, version = remaining[0].split(":") @@ -149,7 +154,7 @@ def parse_uri(uri: str) -> AnyRef: project=project, name=name, digest=version, - extra=remaining[1:], + extra=[urllib3.parse.unquote(r) for r in remaining[1:]], ) elif kind == "op": name, version = remaining[0].split(":") @@ -158,7 +163,7 @@ def parse_uri(uri: str) -> AnyRef: project=project, name=name, digest=version, - extra=remaining[1:], + extra=[urllib3.parse.unquote(r) for r in remaining[1:]], ) else: raise ValueError(f"Unknown ref kind: {kind}") diff --git a/weave/trace_server/refs_internal.py b/weave/trace_server/refs_internal.py index c6be1e419674..1c8c1700bdc0 100644 --- a/weave/trace_server/refs_internal.py +++ b/weave/trace_server/refs_internal.py @@ -202,13 +202,12 @@ def _parse_remaining(remaining: list[str]) -> tuple[str, str, list[str]]: It is expected to be pre-split by slashes into parts. The return is a tuple of name, version, and extra parts, properly unquoted. """ - name_encoded, version = remaining[0].split(":") - name = urllib.parse.unquote_plus(name_encoded) + name, version = remaining[0].split(":") extra = remaining[1:] if len(extra) == 1 and extra[0] == "": extra = [] else: - extra = [urllib.parse.unquote_plus(r) for r in extra] + extra = [urllib.parse.unquote(r) for r in extra] return name, version, extra diff --git a/weave/weave_client.py b/weave/weave_client.py index 47344959a850..e1ab873fd6ca 100644 --- a/weave/weave_client.py +++ b/weave/weave_client.py @@ -1,6 +1,7 @@ import dataclasses import datetime import platform +import re import sys import typing from functools import lru_cache @@ -48,6 +49,7 @@ TableSchemaForInsert, TraceServerInterface, ) +from weave.trace_server.validation import object_id_validator if typing.TYPE_CHECKING: from . import ref_base @@ -762,6 +764,9 @@ def _save_object_basic( if name is None: raise ValueError("Name must be provided for object saving") + name = sanitize_object_name(name) + object_id_validator(name) + response = self.server.obj_create( ObjCreateReq( obj=ObjSchemaForInsert( @@ -999,4 +1004,16 @@ def redact_sensitive_keys(obj: typing.Any) -> typing.Any: return obj +def sanitize_object_name(name: str) -> str: + # Replaces any non-alphanumeric characters with a single dash and removes + # any leading or trailing dashes. This is more restrictive than the DB + # constraints and can be relaxed if needed. + res = re.sub(r"\W+", "-", name).strip("-_") + if not res: + raise ValueError(f"Invalid object name: {name}") + if len(res) > 128: + raise ValueError(f"Object name too long: {name}") + return res + + __docspec__ = [WeaveClient, Call, CallsIter] From c5d66b1bc0d6dfe8b607f3112706c009bb1b20ff Mon Sep 17 00:00:00 2001 From: Tim Sweeney Date: Fri, 16 Aug 2024 15:12:24 -0700 Subject: [PATCH 03/22] linting --- weave/trace/refs.py | 25 ++++++------------------- 1 file changed, 6 insertions(+), 19 deletions(-) diff --git a/weave/trace/refs.py b/weave/trace/refs.py index 7d62ea1068ae..684086d5865c 100644 --- a/weave/trace/refs.py +++ b/weave/trace/refs.py @@ -1,8 +1,7 @@ import dataclasses +import urllib from typing import Any, Union -import urllib3 - from ..trace_server import refs_internal DICT_KEY_EDGE_NAME = refs_internal.DICT_KEY_EDGE_NAME @@ -140,30 +139,18 @@ def parse_uri(uri: str) -> AnyRef: remaining = tuple(parts[3:]) if kind == "table": return TableRef(entity=entity, project=project, digest=remaining[0]) - elif kind == "call": - return CallRef( - entity=entity, - project=project, - id=remaining[0], - extra=[urllib3.parse.unquote(r) for r in remaining[1:]], - ) + extra = tuple(urllib.parse.unquote(r) for r in remaining[1:]) + if kind == "call": + return CallRef(entity=entity, project=project, id=remaining[0], extra=extra) elif kind == "object": name, version = remaining[0].split(":") return ObjectRef( - entity=entity, - project=project, - name=name, - digest=version, - extra=[urllib3.parse.unquote(r) for r in remaining[1:]], + entity=entity, project=project, name=name, digest=version, extra=extra ) elif kind == "op": name, version = remaining[0].split(":") return OpRef( - entity=entity, - project=project, - name=name, - digest=version, - extra=[urllib3.parse.unquote(r) for r in remaining[1:]], + entity=entity, project=project, name=name, digest=version, extra=extra ) else: raise ValueError(f"Unknown ref kind: {kind}") From 26b9ae4f631b87438431f95aa98c56bdbb116df9 Mon Sep 17 00:00:00 2001 From: Tim Sweeney Date: Fri, 16 Aug 2024 15:17:58 -0700 Subject: [PATCH 04/22] some of the UI fixes --- .../Home/Browse3/pages/CallPage/ObjectViewer.tsx | 2 +- .../pages/common/tabularListViews/columnBuilder.tsx | 6 +++++- weave-js/src/react.tsx | 7 ++++--- 3 files changed, 10 insertions(+), 5 deletions(-) diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallPage/ObjectViewer.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallPage/ObjectViewer.tsx index 24e10a3dd09c..fc72aab7d71d 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallPage/ObjectViewer.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallPage/ObjectViewer.tsx @@ -460,7 +460,7 @@ const buildBaseRef = ( const parts = path.toPath().slice(-startIndex); parts.forEach(part => { if (typeof part === 'string') { - baseRef += '/' + OBJECT_ATTR_EDGE_NAME + '/' + part; + baseRef += '/' + OBJECT_ATTR_EDGE_NAME + '/' + encodeURIComponent(part); } else if (typeof part === 'number') { baseRef += '/' + LIST_INDEX_EDGE_NAME + '/' + part.toString(); } else { diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/common/tabularListViews/columnBuilder.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/common/tabularListViews/columnBuilder.tsx index 1e81bbd96436..73f26cf297b4 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/common/tabularListViews/columnBuilder.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/common/tabularListViews/columnBuilder.tsx @@ -160,7 +160,11 @@ function replaceTableRefsInFlattenedData(flattened: Record) { if (parentRef) { const newVal: ExpandedRefWithValueAsTableRef = makeRefExpandedPayload( - parentRef + '/' + OBJECT_ATTR_EDGE_NAME + '/' + attr, + parentRef + + '/' + + OBJECT_ATTR_EDGE_NAME + + '/' + + encodeURIComponent(attr), val ); return [key, newVal]; diff --git a/weave-js/src/react.tsx b/weave-js/src/react.tsx index d585b54421f6..2ee5118060c7 100644 --- a/weave-js/src/react.tsx +++ b/weave-js/src/react.tsx @@ -505,6 +505,7 @@ export const isWeaveObjectRef = (ref: ObjectRef): ref is WeaveObjectRef => { // Unfortunately many teams have been created that violate this. const PATTERN_ENTITY = '([^/]+)'; const PATTERN_PROJECT = '([^\\#?%:]{1,128})'; // Project name +const PATTERN_REF_EXTRA = '([a-zA-Z0-9_/%]*)'; // Optional ref extra const RE_WEAVE_OBJECT_REF_PATHNAME = new RegExp( [ '^', // Start of the string @@ -518,7 +519,7 @@ const RE_WEAVE_OBJECT_REF_PATHNAME = new RegExp( ':', '([*]|[a-zA-Z0-9]+)', // Artifact version, allowing '*' for any version '/?', // Ref extra portion is optional - '([a-zA-Z0-9_/]*)', // Optional ref extra + PATTERN_REF_EXTRA, // Optional ref extra '$', // End of the string ].join('') ); @@ -531,7 +532,7 @@ const RE_WEAVE_TABLE_REF_PATHNAME = new RegExp( '/table/', '([a-f0-9]+)', // Digest '/?', // Ref extra portion is optional - '([a-zA-Z0-9_/]*)', // Optional ref extra + PATTERN_REF_EXTRA, // Optional ref extra '$', // End of the string ].join('') ); @@ -544,7 +545,7 @@ const RE_WEAVE_CALL_REF_PATHNAME = new RegExp( '/call/', '([0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{12})', // Call UUID '/?', // Ref extra portion is optional - '([a-zA-Z0-9_/]*)', // Optional ref extra + PATTERN_REF_EXTRA, // Optional ref extra '$', // End of the string ].join('') ); From 988007111c5b0f42bc0e2e8800ed53ea79cd2ad3 Mon Sep 17 00:00:00 2001 From: Tim Sweeney Date: Fri, 16 Aug 2024 15:43:30 -0700 Subject: [PATCH 05/22] tests - rough --- weave/tests/test_client_trace.py | 75 ++++++++++++++++++++++++ weave/trace_server/tests/test_refs.py | 83 ++++++++++++++++++++++++++- 2 files changed, 155 insertions(+), 3 deletions(-) diff --git a/weave/tests/test_client_trace.py b/weave/tests/test_client_trace.py index c319c66cc126..2cf0a5de31ac 100644 --- a/weave/tests/test_client_trace.py +++ b/weave/tests/test_client_trace.py @@ -17,6 +17,7 @@ from weave.trace.vals import MissingSelfInstanceError from weave.trace_server.ids import generate_id from weave.trace_server.sqlite_trace_server import SqliteTraceServer +from weave.weave_client import sanitize_object_name from ..trace_server import trace_server_interface as tsi from ..trace_server.trace_server_interface_util import ( @@ -2421,3 +2422,77 @@ def test_model_save(client): assert isinstance(expected_predict_op, str) and expected_predict_op.startswith( "weave:///" ) + + +class Custom(weave.Object): + val: dict + + +def test_object_with_disallowed_keys(client): + name = "%" + obj = Custom(name=name, val={"1": 1}) + + weave.publish(obj) + + # we sanitize the name + assert obj.ref.name == "_" + + create_req = tsi.ObjCreateReq.model_validate( + dict( + obj=dict( + project_id=client._project_id(), + object_id=name, + val={"1": 1}, + ) + ) + ) + with pytest.raises(Exception): + client.server.obj_create(create_req) + + +chars = "+_(){}|\"'<>!@$^&*#:,.[]-=;~`" + + +def test_objects_and_keys_with_special_characters(client): + # make sure to include ":", "/" which are URI-related + + name_with_special_characters = "name: /" + chars + dict_payload = {name_with_special_characters: "hello world"} + + obj = Custom(name=name_with_special_characters, val=dict_payload) + + weave.publish(obj) + assert obj.ref is not None + + entity, project = client._project_id().split("/") + project_id = f"{entity}/{project_id}" + ref_base = f"weave:///{project}" + exp_name = sanitize_object_name(name_with_special_characters) + # exp_digest = "2bnzTXFjtlwrtXWNLhAyvYq0XbRFfr633kKL2IkBOlI" + + # exp_obj_ref = f"{ref_base}/object/{exp_name}:{exp_digest}" + # assert obj.ref.uri() == exp_obj_ref + + @weave.op + def test(obj: Custom): + return obj.val[name_with_special_characters] + + test.name = name_with_special_characters + + res = test(obj) + + # exp_res_ref = f"{exp_obj_ref}/attr/val/key/{exp_name}" + found_ref = res.ref.uri() + assert res == "hello world" + # assert found_ref == exp_res_ref + + gotten_res = weave.ref(found_ref).get() + assert gotten_res == "hello world" + + # exp_op_digest = "WZAc2HA5GyWPr7YzHiSBncbHKMywXN3hk8onqRy2KkA" + # exp_op_ref = f"{ref_base}/op/{exp_name}:{exp_op_digest}" + + found_ref = test.ref.uri() + # assert found_ref == exp_op_ref + gotten_fn = weave.ref(found_ref).get() + assert gotten_fn(obj) == "hello world" diff --git a/weave/trace_server/tests/test_refs.py b/weave/trace_server/tests/test_refs.py index 41ce450a1716..d2d99be976e1 100644 --- a/weave/trace_server/tests/test_refs.py +++ b/weave/trace_server/tests/test_refs.py @@ -1,8 +1,85 @@ -from weave.trace.refs import ObjectRef +import random + +import pytest + +from weave.trace import refs +from weave.trace_server import refs_internal +from weave.weave_client import sanitize_object_name + +quote = refs_internal.ref_part_quoter def test_isdescended_from(): - a = ObjectRef(entity="e", project="p", name="n", digest="v", extra=["x1"]) - b = ObjectRef(entity="e", project="p", name="n", digest="v", extra=["x1", "x2"]) + a = refs.ObjectRef(entity="e", project="p", name="n", digest="v", extra=["x1"]) + b = refs.ObjectRef( + entity="e", project="p", name="n", digest="v", extra=["x1", "x2"] + ) assert a.is_descended_from(b) == False assert b.is_descended_from(a) == True + + +def string_with_every_char(disallowed_chars=[]): + char_codes = list(range(256)) + random.shuffle(char_codes) + return "".join(chr(i) for i in char_codes if chr(i) not in disallowed_chars) + + +def test_ref_parsing_external_invalid(): + ref_start = refs.ObjectRef( + entity="entity", + project="project", + name=string_with_every_char(), + digest="1234567890", + extra=("key", string_with_every_char()), + ) + + ref_str = ref_start.uri() + with pytest.raises(): + refs.parse_uri(ref_str) + + +def test_ref_parsing_external_sanitized(): + ref_start = refs.ObjectRef( + entity="entity", + project="project", + name=sanitize_object_name(string_with_every_char()), + digest="1234567890", + extra=("key", string_with_every_char()), + ) + + ref_str = ref_start.uri() + exp_ref = f"{refs_internal.WEAVE_SCHEME}:///{ref_start.entity}/{ref_start.project}/object/{ref_start.name}:{ref_start.digest}/{ref_start.extra[0]}/{quote(ref_start.extra[1])}" + assert ref_str == exp_ref + + parsed = refs.parse_uri(ref_str) + assert parsed == ref_start + + +def test_ref_parsing_external_invalid(): + ref_start = refs_internal.InternalObjectRef( + project_id="project", + name=string_with_every_char(), + version="1234567890", + extra=("key", string_with_every_char()), + ) + + ref_str = ref_start.uri() + with pytest.raises(): + refs.parse_uri(ref_str) + + +def test_ref_parsing_external_sanitized(): + ref_start = refs_internal.InternalObjectRef( + entity="entity", + project_id="project", + name=sanitize_object_name(string_with_every_char()), + version="1234567890", + extra=("key", string_with_every_char()), + ) + + ref_str = ref_start.uri() + exp_ref = f"{refs_internal.WEAVE_INTERNAL_SCHEME}:///{ref_start.project_id}/object/{ref_start.name}:{ref_start.digest}/{ref_start.extra[0]}/{quote(ref_start.extra[1])}" + assert ref_str == exp_ref + + parsed = refs.parse_uri(ref_str) + assert parsed == ref_start From e7bdc68113f5a236b6cf39fb93139a44a42e6edb Mon Sep 17 00:00:00 2001 From: Tim Sweeney Date: Fri, 16 Aug 2024 15:45:02 -0700 Subject: [PATCH 06/22] lint complete --- weave/trace_server/tests/test_refs.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/weave/trace_server/tests/test_refs.py b/weave/trace_server/tests/test_refs.py index d2d99be976e1..9ffa3e38782b 100644 --- a/weave/trace_server/tests/test_refs.py +++ b/weave/trace_server/tests/test_refs.py @@ -6,7 +6,7 @@ from weave.trace_server import refs_internal from weave.weave_client import sanitize_object_name -quote = refs_internal.ref_part_quoter +quote = refs_internal.extra_value_quoter def test_isdescended_from(): @@ -55,7 +55,7 @@ def test_ref_parsing_external_sanitized(): assert parsed == ref_start -def test_ref_parsing_external_invalid(): +def test_ref_parsing_internal_invalid(): ref_start = refs_internal.InternalObjectRef( project_id="project", name=string_with_every_char(), @@ -68,7 +68,7 @@ def test_ref_parsing_external_invalid(): refs.parse_uri(ref_str) -def test_ref_parsing_external_sanitized(): +def test_ref_parsing_internal_sanitized(): ref_start = refs_internal.InternalObjectRef( entity="entity", project_id="project", From b22a461ef4b43d3e669f85feedeb311345d73867 Mon Sep 17 00:00:00 2001 From: Tim Sweeney Date: Fri, 16 Aug 2024 15:58:12 -0700 Subject: [PATCH 07/22] test improvement --- weave/trace_server/refs_internal.py | 2 +- weave/trace_server/tests/test_refs.py | 54 ++++++--------------------- weave/trace_server/validation_util.py | 2 +- weave/weave_client.py | 2 +- 4 files changed, 14 insertions(+), 46 deletions(-) diff --git a/weave/trace_server/refs_internal.py b/weave/trace_server/refs_internal.py index 1c8c1700bdc0..d3c071fc88fb 100644 --- a/weave/trace_server/refs_internal.py +++ b/weave/trace_server/refs_internal.py @@ -42,7 +42,7 @@ def extra_value_quoter(s: str, quote_chars: tuple[str, ...] = ("%", "/")) -> str raise ValueError("Quoting '%' must be the first character in the list.") for c in quote_chars: - s = s.replace(c, urllib.parse.quote(c)) + s = s.replace(c, urllib.parse.quote(c, safe="")) return s diff --git a/weave/trace_server/tests/test_refs.py b/weave/trace_server/tests/test_refs.py index 9ffa3e38782b..e46a555e1720 100644 --- a/weave/trace_server/tests/test_refs.py +++ b/weave/trace_server/tests/test_refs.py @@ -24,62 +24,30 @@ def string_with_every_char(disallowed_chars=[]): return "".join(chr(i) for i in char_codes if chr(i) not in disallowed_chars) -def test_ref_parsing_external_invalid(): - ref_start = refs.ObjectRef( - entity="entity", - project="project", - name=string_with_every_char(), - digest="1234567890", - extra=("key", string_with_every_char()), - ) - - ref_str = ref_start.uri() - with pytest.raises(): - refs.parse_uri(ref_str) - - -def test_ref_parsing_external_sanitized(): - ref_start = refs.ObjectRef( - entity="entity", - project="project", - name=sanitize_object_name(string_with_every_char()), - digest="1234567890", - extra=("key", string_with_every_char()), - ) - - ref_str = ref_start.uri() - exp_ref = f"{refs_internal.WEAVE_SCHEME}:///{ref_start.entity}/{ref_start.project}/object/{ref_start.name}:{ref_start.digest}/{ref_start.extra[0]}/{quote(ref_start.extra[1])}" - assert ref_str == exp_ref - - parsed = refs.parse_uri(ref_str) - assert parsed == ref_start +# TODO: Write the equivalent tests for the external refs! def test_ref_parsing_internal_invalid(): - ref_start = refs_internal.InternalObjectRef( - project_id="project", - name=string_with_every_char(), - version="1234567890", - extra=("key", string_with_every_char()), - ) - - ref_str = ref_start.uri() - with pytest.raises(): - refs.parse_uri(ref_str) + with pytest.raises(Exception): + ref_start = refs_internal.InternalObjectRef( + project_id="project", + name=string_with_every_char(), + version="1234567890", + extra=("key", string_with_every_char()), + ) def test_ref_parsing_internal_sanitized(): ref_start = refs_internal.InternalObjectRef( - entity="entity", project_id="project", name=sanitize_object_name(string_with_every_char()), version="1234567890", - extra=("key", string_with_every_char()), + extra=["key", string_with_every_char()], ) ref_str = ref_start.uri() - exp_ref = f"{refs_internal.WEAVE_INTERNAL_SCHEME}:///{ref_start.project_id}/object/{ref_start.name}:{ref_start.digest}/{ref_start.extra[0]}/{quote(ref_start.extra[1])}" + exp_ref = f"{refs_internal.WEAVE_INTERNAL_SCHEME}:///{ref_start.project_id}/object/{ref_start.name}:{ref_start.version}/{ref_start.extra[0]}/{quote(ref_start.extra[1])}" assert ref_str == exp_ref - parsed = refs.parse_uri(ref_str) + parsed = refs_internal.parse_internal_uri(ref_str) assert parsed == ref_start diff --git a/weave/trace_server/validation_util.py b/weave/trace_server/validation_util.py index 2879c5b2ddae..2496a597926e 100644 --- a/weave/trace_server/validation_util.py +++ b/weave/trace_server/validation_util.py @@ -58,6 +58,6 @@ def require_internal_ref_uri( def require_max_str_len(s: str, length: int) -> str: - if len(s) >= length: + if len(s) > length: raise CHValidationError(f"String too long: {s}. Max length is {length}") return s diff --git a/weave/weave_client.py b/weave/weave_client.py index e1ab873fd6ca..4d5e8942260c 100644 --- a/weave/weave_client.py +++ b/weave/weave_client.py @@ -1012,7 +1012,7 @@ def sanitize_object_name(name: str) -> str: if not res: raise ValueError(f"Invalid object name: {name}") if len(res) > 128: - raise ValueError(f"Object name too long: {name}") + res = res[:128] return res From dd841d1cf1f3b1c5279c1ab2ea584c1f2aa13070 Mon Sep 17 00:00:00 2001 From: Tim Sweeney Date: Fri, 16 Aug 2024 16:09:03 -0700 Subject: [PATCH 08/22] done with tests --- weave/tests/test_client_trace.py | 25 +++++++++++++++---------- 1 file changed, 15 insertions(+), 10 deletions(-) diff --git a/weave/tests/test_client_trace.py b/weave/tests/test_client_trace.py index 2cf0a5de31ac..ffb2ddd9dc94 100644 --- a/weave/tests/test_client_trace.py +++ b/weave/tests/test_client_trace.py @@ -16,6 +16,7 @@ from weave import Thread, ThreadPoolExecutor, weave_client from weave.trace.vals import MissingSelfInstanceError from weave.trace_server.ids import generate_id +from weave.trace_server.refs_internal import extra_value_quoter from weave.trace_server.sqlite_trace_server import SqliteTraceServer from weave.weave_client import sanitize_object_name @@ -2465,13 +2466,17 @@ def test_objects_and_keys_with_special_characters(client): assert obj.ref is not None entity, project = client._project_id().split("/") - project_id = f"{entity}/{project_id}" - ref_base = f"weave:///{project}" + project_id = f"{entity}/{project}" + ref_base = f"weave:///{project_id}" exp_name = sanitize_object_name(name_with_special_characters) - # exp_digest = "2bnzTXFjtlwrtXWNLhAyvYq0XbRFfr633kKL2IkBOlI" + assert exp_name == "name" + exp_key = extra_value_quoter(name_with_special_characters) + # If we decide to quote everything, this would change + assert exp_key == "name: %2F+_(){}|\"'<>!@$^&*#:,.[]-=;~`" + exp_digest = "2bnzTXFjtlwrtXWNLhAyvYq0XbRFfr633kKL2IkBOlI" - # exp_obj_ref = f"{ref_base}/object/{exp_name}:{exp_digest}" - # assert obj.ref.uri() == exp_obj_ref + exp_obj_ref = f"{ref_base}/object/{exp_name}:{exp_digest}" + assert obj.ref.uri() == exp_obj_ref @weave.op def test(obj: Custom): @@ -2481,18 +2486,18 @@ def test(obj: Custom): res = test(obj) - # exp_res_ref = f"{exp_obj_ref}/attr/val/key/{exp_name}" + exp_res_ref = f"{exp_obj_ref}/attr/val/key/{exp_key}" found_ref = res.ref.uri() assert res == "hello world" - # assert found_ref == exp_res_ref + assert found_ref == exp_res_ref gotten_res = weave.ref(found_ref).get() assert gotten_res == "hello world" - # exp_op_digest = "WZAc2HA5GyWPr7YzHiSBncbHKMywXN3hk8onqRy2KkA" - # exp_op_ref = f"{ref_base}/op/{exp_name}:{exp_op_digest}" + exp_op_digest = "WZAc2HA5GyWPr7YzHiSBncbHKMywXN3hk8onqRy2KkA" + exp_op_ref = f"{ref_base}/op/{exp_name}:{exp_op_digest}" found_ref = test.ref.uri() - # assert found_ref == exp_op_ref + assert found_ref == exp_op_ref gotten_fn = weave.ref(found_ref).get() assert gotten_fn(obj) == "hello world" From 853a2fa29a507cdc99b953498247ed031e36334c Mon Sep 17 00:00:00 2001 From: Tim Sweeney Date: Fri, 16 Aug 2024 16:28:22 -0700 Subject: [PATCH 09/22] ok, basically done --- weave-js/src/react.tsx | 2 +- weave/tests/test_client_trace.py | 6 ++++-- weave/trace_server/refs_internal.py | 16 ++-------------- 3 files changed, 7 insertions(+), 17 deletions(-) diff --git a/weave-js/src/react.tsx b/weave-js/src/react.tsx index 2ee5118060c7..cb5be7cf1957 100644 --- a/weave-js/src/react.tsx +++ b/weave-js/src/react.tsx @@ -505,7 +505,7 @@ export const isWeaveObjectRef = (ref: ObjectRef): ref is WeaveObjectRef => { // Unfortunately many teams have been created that violate this. const PATTERN_ENTITY = '([^/]+)'; const PATTERN_PROJECT = '([^\\#?%:]{1,128})'; // Project name -const PATTERN_REF_EXTRA = '([a-zA-Z0-9_/%]*)'; // Optional ref extra +const PATTERN_REF_EXTRA = '([a-zA-Z0-9_/%.-~]*)'; // Optional ref extra (valid chars are result of python urllib.parse.quote and javascript encodeURIComponent) const RE_WEAVE_OBJECT_REF_PATHNAME = new RegExp( [ '^', // Start of the string diff --git a/weave/tests/test_client_trace.py b/weave/tests/test_client_trace.py index ffb2ddd9dc94..b11995325511 100644 --- a/weave/tests/test_client_trace.py +++ b/weave/tests/test_client_trace.py @@ -2471,8 +2471,10 @@ def test_objects_and_keys_with_special_characters(client): exp_name = sanitize_object_name(name_with_special_characters) assert exp_name == "name" exp_key = extra_value_quoter(name_with_special_characters) - # If we decide to quote everything, this would change - assert exp_key == "name: %2F+_(){}|\"'<>!@$^&*#:,.[]-=;~`" + assert ( + exp_key + == "name%3A%20%2F%2B_%28%29%7B%7D%7C%22%27%3C%3E%21%40%24%5E%26%2A%23%3A%2C.%5B%5D-%3D%3B~%60" + ) exp_digest = "2bnzTXFjtlwrtXWNLhAyvYq0XbRFfr633kKL2IkBOlI" exp_obj_ref = f"{ref_base}/object/{exp_name}:{exp_digest}" diff --git a/weave/trace_server/refs_internal.py b/weave/trace_server/refs_internal.py index d3c071fc88fb..69a5a363f7a2 100644 --- a/weave/trace_server/refs_internal.py +++ b/weave/trace_server/refs_internal.py @@ -31,20 +31,8 @@ class InvalidInternalRef(ValueError): pass -def extra_value_quoter(s: str, quote_chars: tuple[str, ...] = ("%", "/")) -> str: - """We don't need to quote every single character, rather - we just need to quote the characters that are used as - delimiters in the URI. - """ - # We have to be careful here since quoting creates new "%" characters. - # We don't want to double-quote characters. - if "%" in quote_chars and quote_chars[0] != "%": - raise ValueError("Quoting '%' must be the first character in the list.") - - for c in quote_chars: - s = s.replace(c, urllib.parse.quote(c, safe="")) - - return s +def extra_value_quoter(s: str) -> str: + return urllib.parse.quote(s, safe="") def validate_extra(extra: list[str]) -> None: From e669fe0c880d2e9cf4ef4a1aaedc3fa64a56e072 Mon Sep 17 00:00:00 2001 From: Tim Sweeney Date: Fri, 16 Aug 2024 16:43:53 -0700 Subject: [PATCH 10/22] Final --- weave-js/src/react.tsx | 100 ++++++++++++++++++++--------------------- 1 file changed, 50 insertions(+), 50 deletions(-) diff --git a/weave-js/src/react.tsx b/weave-js/src/react.tsx index cb5be7cf1957..d11c404d8cd5 100644 --- a/weave-js/src/react.tsx +++ b/weave-js/src/react.tsx @@ -47,6 +47,7 @@ import { useState, } from 'react'; +import {WEAVE_REF_PREFIX} from './components/PagePanelComponents/Home/Browse3/pages/wfReactInterface/constants'; import {PanelCompContext} from './components/Panel2/PanelComp'; import {usePanelContext} from './components/Panel2/PanelContext'; import {toWeaveType} from './components/Panel2/toWeaveType'; @@ -64,7 +65,6 @@ import { getChainRootVar, isConstructor, } from './core/mutate'; -import {trimStartChar} from './core/util/string'; import {UseNodeValueServerExecutionError} from './errors'; import {useDeepMemo} from './hookUtils'; import {consoleLog} from './util'; @@ -505,7 +505,7 @@ export const isWeaveObjectRef = (ref: ObjectRef): ref is WeaveObjectRef => { // Unfortunately many teams have been created that violate this. const PATTERN_ENTITY = '([^/]+)'; const PATTERN_PROJECT = '([^\\#?%:]{1,128})'; // Project name -const PATTERN_REF_EXTRA = '([a-zA-Z0-9_/%.-~]*)'; // Optional ref extra (valid chars are result of python urllib.parse.quote and javascript encodeURIComponent) +const PATTERN_REF_EXTRA = '([a-zA-Z0-9_.~/%-]*)'; // Optional ref extra (valid chars are result of python urllib.parse.quote and javascript encodeURIComponent) const RE_WEAVE_OBJECT_REF_PATHNAME = new RegExp( [ '^', // Start of the string @@ -562,7 +562,7 @@ export const parseRef = (ref: string): ObjectRef => { } else if (isLocalArtifact) { splitLimit = 2; } else if (isWeaveRef) { - splitLimit = 4; + return parseWeaveRef(ref); } else { throw new Error(`Unknown protocol: ${url.protocol}`); } @@ -599,58 +599,58 @@ export const parseRef = (ref: string): ObjectRef => { artifactPath, }; } + throw new Error(`Unknown protocol: ${url.protocol}`); +}; - if (isWeaveRef) { - const trimmed = trimStartChar(decodedUri, '/'); - const tableMatch = trimmed.match(RE_WEAVE_TABLE_REF_PATHNAME); - if (tableMatch !== null) { - const [entity, project, digest] = tableMatch.slice(1); - return { - scheme: 'weave', - entityName: entity, - projectName: project, - weaveKind: 'table' as WeaveKind, - artifactName: '', - artifactVersion: digest, - artifactRefExtra: '', - }; - } - const callMatch = trimmed.match(RE_WEAVE_CALL_REF_PATHNAME); - if (callMatch !== null) { - const [entity, project, callId] = callMatch.slice(1); - return { - scheme: 'weave', - entityName: entity, - projectName: project, - weaveKind: 'call' as WeaveKind, - artifactName: callId, - artifactVersion: '', - artifactRefExtra: '', - }; - } - const match = trimmed.match(RE_WEAVE_OBJECT_REF_PATHNAME); - if (match === null) { - throw new Error('Invalid weave ref uri: ' + ref); - } - const [ - entityName, - projectName, - weaveKind, - artifactName, - artifactVersion, - artifactRefExtra, - ] = match.slice(1); +const parseWeaveRef = (ref: string): WeaveObjectRef => { + const trimmed = ref.slice(WEAVE_REF_PREFIX.length); + const tableMatch = trimmed.match(RE_WEAVE_TABLE_REF_PATHNAME); + if (tableMatch !== null) { + const [entity, project, digest] = tableMatch.slice(1); return { scheme: 'weave', - entityName, - projectName, - weaveKind: weaveKind as WeaveKind, - artifactName, - artifactVersion, - artifactRefExtra: artifactRefExtra ?? '', + entityName: entity, + projectName: project, + weaveKind: 'table' as WeaveKind, + artifactName: '', + artifactVersion: digest, + artifactRefExtra: '', }; } - throw new Error(`Unknown protocol: ${url.protocol}`); + const callMatch = trimmed.match(RE_WEAVE_CALL_REF_PATHNAME); + if (callMatch !== null) { + const [entity, project, callId] = callMatch.slice(1); + return { + scheme: 'weave', + entityName: entity, + projectName: project, + weaveKind: 'call' as WeaveKind, + artifactName: callId, + artifactVersion: '', + artifactRefExtra: '', + }; + } + const match = trimmed.match(RE_WEAVE_OBJECT_REF_PATHNAME); + if (match === null) { + throw new Error('Invalid weave ref uri: ' + ref); + } + const [ + entityName, + projectName, + weaveKind, + artifactName, + artifactVersion, + artifactRefExtra, + ] = match.slice(1); + return { + scheme: 'weave', + entityName, + projectName, + weaveKind: weaveKind as WeaveKind, + artifactName, + artifactVersion, + artifactRefExtra: artifactRefExtra ?? '', + }; }; export const objectRefWithExtra = ( From bd36e5a69dfc048c87ad6c44718d77e393272afa Mon Sep 17 00:00:00 2001 From: Tim Sweeney Date: Fri, 16 Aug 2024 16:57:28 -0700 Subject: [PATCH 11/22] Final --- weave/trace_server/validation.py | 2 +- weave/weave_client.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/weave/trace_server/validation.py b/weave/trace_server/validation.py index bfe99946b2de..97f3bc4b1f13 100644 --- a/weave/trace_server/validation.py +++ b/weave/trace_server/validation.py @@ -60,7 +60,7 @@ def wb_run_id_validator(s: typing.Optional[str]) -> typing.Optional[str]: def _validate_object_name_charset(name: str) -> None: # Object names must be alphanumeric with dashes - invalid_chars = re.findall(r"[^\w-]", name) + invalid_chars = re.findall(r"[^\w._-]", name) if invalid_chars: raise ValueError( f"Invalid object name: {name}. Contains invalid characters: {invalid_chars}" diff --git a/weave/weave_client.py b/weave/weave_client.py index 4d5e8942260c..7e99eef91b9c 100644 --- a/weave/weave_client.py +++ b/weave/weave_client.py @@ -1008,7 +1008,7 @@ def sanitize_object_name(name: str) -> str: # Replaces any non-alphanumeric characters with a single dash and removes # any leading or trailing dashes. This is more restrictive than the DB # constraints and can be relaxed if needed. - res = re.sub(r"\W+", "-", name).strip("-_") + res = re.sub(r"[._-]+", "-", re.sub(r"[^\w._]+", "-", name)).strip("-_") if not res: raise ValueError(f"Invalid object name: {name}") if len(res) > 128: From dc633d793d03c692bed9ade19013007a90811a45 Mon Sep 17 00:00:00 2001 From: Tim Sweeney Date: Fri, 16 Aug 2024 17:10:22 -0700 Subject: [PATCH 12/22] Final --- weave/weave_client.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/weave/weave_client.py b/weave/weave_client.py index 7e99eef91b9c..c1ce37dd1476 100644 --- a/weave/weave_client.py +++ b/weave/weave_client.py @@ -1008,7 +1008,7 @@ def sanitize_object_name(name: str) -> str: # Replaces any non-alphanumeric characters with a single dash and removes # any leading or trailing dashes. This is more restrictive than the DB # constraints and can be relaxed if needed. - res = re.sub(r"[._-]+", "-", re.sub(r"[^\w._]+", "-", name)).strip("-_") + res = re.sub(r"([._-]{2,})+", "-", re.sub(r"[^\w._]+", "-", name)).strip("-_") if not res: raise ValueError(f"Invalid object name: {name}") if len(res) > 128: From 8bb015d586619e996afcb0dde184b0682bf5d8f5 Mon Sep 17 00:00:00 2001 From: Tim Sweeney Date: Fri, 16 Aug 2024 17:25:54 -0700 Subject: [PATCH 13/22] Final --- weave/tests/test_client_trace.py | 14 +++++++------- weave/tests/test_trace_server.py | 23 ++++++----------------- 2 files changed, 13 insertions(+), 24 deletions(-) diff --git a/weave/tests/test_client_trace.py b/weave/tests/test_client_trace.py index b11995325511..a726fcce8792 100644 --- a/weave/tests/test_client_trace.py +++ b/weave/tests/test_client_trace.py @@ -2430,13 +2430,13 @@ class Custom(weave.Object): def test_object_with_disallowed_keys(client): - name = "%" + name = "thing % with / disallowed : keys" obj = Custom(name=name, val={"1": 1}) weave.publish(obj) # we sanitize the name - assert obj.ref.name == "_" + assert obj.ref.name == "thing-with-disallowed-keys" create_req = tsi.ObjCreateReq.model_validate( dict( @@ -2457,7 +2457,7 @@ def test_object_with_disallowed_keys(client): def test_objects_and_keys_with_special_characters(client): # make sure to include ":", "/" which are URI-related - name_with_special_characters = "name: /" + chars + name_with_special_characters = "n-a_m.e: /" + chars + "100" dict_payload = {name_with_special_characters: "hello world"} obj = Custom(name=name_with_special_characters, val=dict_payload) @@ -2469,13 +2469,13 @@ def test_objects_and_keys_with_special_characters(client): project_id = f"{entity}/{project}" ref_base = f"weave:///{project_id}" exp_name = sanitize_object_name(name_with_special_characters) - assert exp_name == "name" + assert exp_name == "n-a_m.e-100" exp_key = extra_value_quoter(name_with_special_characters) assert ( exp_key - == "name%3A%20%2F%2B_%28%29%7B%7D%7C%22%27%3C%3E%21%40%24%5E%26%2A%23%3A%2C.%5B%5D-%3D%3B~%60" + == "n-a_m.e%3A%20%2F%2B_%28%29%7B%7D%7C%22%27%3C%3E%21%40%24%5E%26%2A%23%3A%2C.%5B%5D-%3D%3B~%60100" ) - exp_digest = "2bnzTXFjtlwrtXWNLhAyvYq0XbRFfr633kKL2IkBOlI" + exp_digest = "O66Mk7g91rlAUtcGYOFR1Y2Wk94YyPXJy2UEAzDQcYM" exp_obj_ref = f"{ref_base}/object/{exp_name}:{exp_digest}" assert obj.ref.uri() == exp_obj_ref @@ -2496,7 +2496,7 @@ def test(obj: Custom): gotten_res = weave.ref(found_ref).get() assert gotten_res == "hello world" - exp_op_digest = "WZAc2HA5GyWPr7YzHiSBncbHKMywXN3hk8onqRy2KkA" + exp_op_digest = "xEPCVKKjDWxKzqaCxxU09jD82FGGf5WcNy2fC9VUF3M" exp_op_ref = f"{ref_base}/op/{exp_name}:{exp_op_digest}" found_ref = test.ref.uri() diff --git a/weave/tests/test_trace_server.py b/weave/tests/test_trace_server.py index cab5bb56efee..e02899370a56 100644 --- a/weave/tests/test_trace_server.py +++ b/weave/tests/test_trace_server.py @@ -25,10 +25,8 @@ def test_save_object(client): def test_robust_to_url_sensitive_chars(client): - entity = "entity" - project = "project" - project_id = f"{entity}/{project}" - object_id = "mali:cious/obj%ect" + project_id = client._project_id() + object_id = "mali_cious-obj.ect" bad_key = "mali:cious/ke%y" bad_val = {bad_key: "hello world"} @@ -53,22 +51,13 @@ def test_robust_to_url_sensitive_chars(client): assert read_res.obj.val == bad_val # Object ID that contains reserved characters should be rejected. - with pytest.raises(Exception): - read_res = client.server.refs_read_batch( - tsi.RefsReadBatchReq( - refs=[f"weave:///{project_id}/object/{object_id}:{create_res.digest}"] - ) - ) - encoded_object_id = urllib.parse.quote_plus(object_id) - assert encoded_object_id == "mali%3Acious%2Fobj%25ect" read_res = client.server.refs_read_batch( tsi.RefsReadBatchReq( - refs=[ - f"weave:///{project_id}/object/{encoded_object_id}:{create_res.digest}" - ] + refs=[f"weave:///{project_id}/object/{object_id}:{create_res.digest}"] ) ) + assert read_res.vals[0] == bad_val # Key that contains reserved characters should be rejected. @@ -76,7 +65,7 @@ def test_robust_to_url_sensitive_chars(client): read_res = client.server.refs_read_batch( tsi.RefsReadBatchReq( refs=[ - f"weave:///{project_id}/object/{encoded_object_id}:{create_res.digest}/key/{bad_key}" + f"weave:///{project_id}/object/{object_id}:{create_res.digest}/key/{bad_key}" ] ) ) @@ -86,7 +75,7 @@ def test_robust_to_url_sensitive_chars(client): read_res = client.server.refs_read_batch( tsi.RefsReadBatchReq( refs=[ - f"weave:///{project_id}/object/{encoded_object_id}:{create_res.digest}/key/{encoded_bad_key}" + f"weave:///{project_id}/object/{object_id}:{create_res.digest}/key/{encoded_bad_key}" ] ) ) From a4ede9c7d178c6effa88d53924ea74420a2e9792 Mon Sep 17 00:00:00 2001 From: Tim Sweeney Date: Fri, 16 Aug 2024 17:38:05 -0700 Subject: [PATCH 14/22] Final --- weave/integrations/langchain/langchain_test.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/weave/integrations/langchain/langchain_test.py b/weave/integrations/langchain/langchain_test.py index 91e698ef98c4..8d051e239531 100644 --- a/weave/integrations/langchain/langchain_test.py +++ b/weave/integrations/langchain/langchain_test.py @@ -75,12 +75,12 @@ def test_simple_chain_invoke( ) prompt = PromptTemplate.from_template("1 + {number} = ") long_str = ( - "really_massive_name_that_is_longer_than_max_characters_which_would_be_crazy_" + "really_massive_name_that_is_longer_than_max_characters_which_would_be_crazy" ) name = long_str + long_str prompt.name = name - exp_name = "really_massive_name_that_is_longer_than_max_characte_9ad6_t_is_longer_than_max_characters_which_would_be_crazy_" + exp_name = "really_massive_name_that_is_longer_than_max_characte_ff6e_at_is_longer_than_max_characters_which_would_be_crazy" llm_chain = prompt | llm _ = llm_chain.invoke({"number": 2}) From e19dd390b6a8c2d7f86d56adeae85fca7d00f11c Mon Sep 17 00:00:00 2001 From: Tim Sweeney Date: Mon, 26 Aug 2024 15:39:14 -0700 Subject: [PATCH 15/22] i think complete --- .../PagePanelComponents/Home/Browse3.tsx | 47 ++++++++++++------- weave-js/src/react.test.ts | 16 +------ 2 files changed, 31 insertions(+), 32 deletions(-) diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3.tsx index be81c42d08b5..f6c0d2ad16c4 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3.tsx @@ -284,7 +284,7 @@ const Browse3Mounted: FC<{ const MainPeekingLayout: FC = () => { const {baseRouter} = useWeaveflowRouteContext(); - const params = useParams(); + const params = useParamsDecoded(); const baseRouterProjectRoot = baseRouter.projectUrl(':entity', ':project'); const generalProjectRoot = browse2Context.projectUrl(':entity', ':project'); const query = useURLSearchParamsDict(); @@ -406,7 +406,7 @@ const MainPeekingLayout: FC = () => { }; const ProjectRedirect: FC = () => { - const {entity, project} = useParams(); + const {entity, project} = useParamsDecoded(); const {baseRouter} = useWeaveflowRouteContext(); const url = baseRouter.tracesUIUrl(entity, project); return ; @@ -497,7 +497,7 @@ const Browse3ProjectRoot: FC<{ // TODO(tim/weaveflow_improved_nav): Generalize this const ObjectVersionRoutePageBinding = () => { - const params = useParams(); + const params = useParamsDecoded(); const query = useURLSearchParamsDict(); const history = useHistory(); @@ -538,7 +538,7 @@ const ObjectVersionRoutePageBinding = () => { // TODO(tim/weaveflow_improved_nav): Generalize this const OpVersionRoutePageBinding = () => { - const params = useParams(); + const params = useParamsDecoded(); const history = useHistory(); const routerContext = useWeaveflowCurrentRouteContext(); useEffect(() => { @@ -573,7 +573,7 @@ const useCallPeekRedirect = () => { // This is a "hack" since the client doesn't have all the info // needed to make a correct peek URL. This allows the client to request // such a view and we can redirect to the correct URL. - const params = useParams(); + const params = useParamsDecoded(); const {baseRouter} = useWeaveflowRouteContext(); const history = useHistory(); const {useCall} = useWFHooks(); @@ -618,10 +618,23 @@ const useCallPeekRedirect = () => { ]); }; +const useParamsDecoded = () => { + // Handle the case where entity/project (old) have spaces + const params = useParams(); + return useMemo(() => { + return Object.fromEntries( + Object.entries(params).map(([key, value]) => [ + key, + decodeURIComponent(value), + ]) + ); + }, [params]); +}; + // TODO(tim/weaveflow_improved_nav): Generalize this const CallPageBinding = () => { useCallPeekRedirect(); - const params = useParams(); + const params = useParamsDecoded(); const query = useURLSearchParamsDict(); return ( @@ -636,7 +649,7 @@ const CallPageBinding = () => { // TODO(tim/weaveflow_improved_nav): Generalize this const CallsPageBinding = () => { - const {entity, project, tab} = useParams(); + const {entity, project, tab} = useParamsDecoded(); const query = useURLSearchParamsDict(); const initialFilter = useMemo(() => { if (tab === 'evaluations') { @@ -769,7 +782,7 @@ const CallsPageBinding = () => { // TODO(tim/weaveflow_improved_nav): Generalize this const ObjectVersionsPageBinding = () => { - const {entity, project, tab} = useParams(); + const {entity, project, tab} = useParamsDecoded(); const query = useURLSearchParamsDict(); const filters: WFHighLevelObjectVersionFilter = useMemo(() => { let queryFilter: WFHighLevelObjectVersionFilter = {}; @@ -815,7 +828,7 @@ const ObjectVersionsPageBinding = () => { // TODO(tim/weaveflow_improved_nav): Generalize this const OpVersionsPageBinding = () => { - const params = useParams(); + const params = useParamsDecoded(); const query = useURLSearchParamsDict(); const filters = useMemo(() => { @@ -851,7 +864,7 @@ const OpVersionsPageBinding = () => { // TODO(tim/weaveflow_improved_nav): Generalize this const BoardPageBinding = () => { - const params = useParams(); + const params = useParamsDecoded(); return ( { // TODO(tim/weaveflow_improved_nav): Generalize this const ObjectPageBinding = () => { - const params = useParams(); + const params = useParamsDecoded(); return ( { }; const OpPageBinding = () => { - const params = useParams(); + const params = useParamsDecoded(); return ( { }; const CompareEvaluationsBinding = () => { - const {entity, project} = useParams(); + const {entity, project} = useParamsDecoded(); const query = useURLSearchParamsDict(); const evaluationCallIds = useMemo(() => { return JSON.parse(query.evaluationCallIds); @@ -902,19 +915,19 @@ const CompareEvaluationsBinding = () => { }; const OpsPageBinding = () => { - const params = useParams(); + const params = useParamsDecoded(); return ; }; const BoardsPageBinding = () => { - const params = useParams(); + const params = useParamsDecoded(); return ; }; const TablesPageBinding = () => { - const params = useParams(); + const params = useParamsDecoded(); return ; }; @@ -934,7 +947,7 @@ const AppBarLink = (props: ComponentProps) => ( ); const Browse3Breadcrumbs: FC = props => { - const params = useParams(); + const params = useParamsDecoded(); const query = useURLSearchParamsDict(); const filePathParts = query.path?.split('/') ?? []; const refFields = query.extra?.split('/') ?? []; diff --git a/weave-js/src/react.test.ts b/weave-js/src/react.test.ts index a20e61692fa0..02877f1ba574 100644 --- a/weave-js/src/react.test.ts +++ b/weave-js/src/react.test.ts @@ -34,7 +34,7 @@ describe('parseRef', () => { }); it('parses a ref with spaces in entity', () => { const parsed = parseRef( - 'weave:///Entity%20Name/project/object/artifact-name:artifactversion' + 'weave:///Entity Name/project/object/artifact-name:artifactversion' ); expect(parsed).toEqual({ artifactName: 'artifact-name', @@ -102,20 +102,6 @@ describe('parseRef', () => { weaveKind: 'object', }); }); - it('parses a ref with escaped spaces in name and projectName', () => { - const parsed = parseRef( - 'weave:///entity/project%20with%20spaces/object/artifact%20name%20with%20spaces:artifactversion' - ); - expect(parsed).toEqual({ - artifactName: 'artifact name with spaces', - artifactRefExtra: '', - artifactVersion: 'artifactversion', - entityName: 'entity', - projectName: 'project with spaces', - scheme: 'weave', - weaveKind: 'object', - }); - }); }); it('parses a weave table ref', () => { const parsed = parseRef( From 5a174ba6a35e040cca566c58ea458f5451adf1c4 Mon Sep 17 00:00:00 2001 From: Tim Sweeney Date: Mon, 26 Aug 2024 17:17:24 -0700 Subject: [PATCH 16/22] test fix --- weave/trace_server/sqlite_trace_server.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/weave/trace_server/sqlite_trace_server.py b/weave/trace_server/sqlite_trace_server.py index 238617071829..a8611c134554 100644 --- a/weave/trace_server/sqlite_trace_server.py +++ b/weave/trace_server/sqlite_trace_server.py @@ -32,6 +32,7 @@ from weave.trace_server.trace_server_interface_util import ( WILDCARD_ARTIFACT_VERSION_AND_PATH, ) +from weave.trace_server.validation import object_id_validator from . import trace_server_interface as tsi from .ids import generate_id @@ -553,6 +554,9 @@ def obj_create(self, req: tsi.ObjCreateReq) -> tsi.ObjCreateRes: json_val = json.dumps(req.obj.val) digest = str_digest(json_val) + # Validate + object_id_validator(req.obj.object_id) + req_obj = req.obj # TODO: version index isn't right here, what if we delete stuff? with self.lock: From d923f3c2971213dd51d542c9ce16e545c2df2ac7 Mon Sep 17 00:00:00 2001 From: Tim Sweeney Date: Tue, 27 Aug 2024 13:53:08 -0700 Subject: [PATCH 17/22] removed client check --- weave/weave_client.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/weave/weave_client.py b/weave/weave_client.py index c0fb9b60d8ea..809e3a0a2f4b 100644 --- a/weave/weave_client.py +++ b/weave/weave_client.py @@ -49,7 +49,6 @@ TableSchemaForInsert, TraceServerInterface, ) -from weave.trace_server.validation import object_id_validator if typing.TYPE_CHECKING: from . import ref_base @@ -767,7 +766,6 @@ def _save_object_basic( raise ValueError("Name must be provided for object saving") name = sanitize_object_name(name) - object_id_validator(name) response = self.server.obj_create( ObjCreateReq( From 6c7584aa9b5462f2b6b7b3c4a2a4f4da3a5458d1 Mon Sep 17 00:00:00 2001 From: Tim Sweeney Date: Tue, 27 Aug 2024 14:15:18 -0700 Subject: [PATCH 18/22] added checks --- weave/trace/refs.py | 8 +++++++- weave/trace_server/tests/test_refs.py | 27 ++++++++++++++++++++++++++- 2 files changed, 33 insertions(+), 2 deletions(-) diff --git a/weave/trace/refs.py b/weave/trace/refs.py index 684086d5865c..16fb47d8f21b 100644 --- a/weave/trace/refs.py +++ b/weave/trace/refs.py @@ -2,7 +2,7 @@ import urllib from typing import Any, Union -from ..trace_server import refs_internal +from ..trace_server import refs_internal, validation DICT_KEY_EDGE_NAME = refs_internal.DICT_KEY_EDGE_NAME LIST_INDEX_EDGE_NAME = refs_internal.LIST_INDEX_EDGE_NAME @@ -54,6 +54,12 @@ class ObjectRef(RefWithExtra): digest: str extra: tuple[str, ...] = () + def __post_init__(self) -> None: + refs_internal.validate_no_slashes(self.digest, "digest") + refs_internal.validate_no_colons(self.digest, "digest") + refs_internal.validate_extra(list(self.extra)) + validation.object_id_validator(self.name) + def uri(self) -> str: u = f"weave:///{self.entity}/{self.project}/object/{self.name}:{self.digest}" if self.extra: diff --git a/weave/trace_server/tests/test_refs.py b/weave/trace_server/tests/test_refs.py index e46a555e1720..a518fa1ef797 100644 --- a/weave/trace_server/tests/test_refs.py +++ b/weave/trace_server/tests/test_refs.py @@ -24,7 +24,32 @@ def string_with_every_char(disallowed_chars=[]): return "".join(chr(i) for i in char_codes if chr(i) not in disallowed_chars) -# TODO: Write the equivalent tests for the external refs! +def test_ref_parsing_external_invalid(): + with pytest.raises(Exception): + ref_start = refs.ObjectRef( + entity="entity", + project="project", + name=string_with_every_char(), + digest="1234567890", + extra=("key", string_with_every_char()), + ) + + +def test_ref_parsing_external_sanitized(): + ref_start = refs.ObjectRef( + entity="entity", + project="project", + name=sanitize_object_name(string_with_every_char()), + digest="1234567890", + extra=("key", string_with_every_char()), + ) + + ref_str = ref_start.uri() + exp_ref = f"{refs_internal.WEAVE_SCHEME}:///{ref_start.entity}/{ref_start.project}/object/{ref_start.name}:{ref_start.digest}/{ref_start.extra[0]}/{quote(ref_start.extra[1])}" + assert ref_str == exp_ref + + parsed = refs.parse_uri(ref_str) + assert parsed == ref_start def test_ref_parsing_internal_invalid(): From ea298a5584585c99e0c76de2372349df2048e2a3 Mon Sep 17 00:00:00 2001 From: Tim Sweeney Date: Tue, 27 Aug 2024 14:40:42 -0700 Subject: [PATCH 19/22] another little fix --- weave/trace_server/tests/test_refs.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/weave/trace_server/tests/test_refs.py b/weave/trace_server/tests/test_refs.py index a518fa1ef797..4a9b4b28bb3b 100644 --- a/weave/trace_server/tests/test_refs.py +++ b/weave/trace_server/tests/test_refs.py @@ -10,9 +10,9 @@ def test_isdescended_from(): - a = refs.ObjectRef(entity="e", project="p", name="n", digest="v", extra=["x1"]) + a = refs.ObjectRef(entity="e", project="p", name="n", digest="v", extra=["attr", 'x2']) b = refs.ObjectRef( - entity="e", project="p", name="n", digest="v", extra=["x1", "x2"] + entity="e", project="p", name="n", digest="v", extra=["attr", "x2", 'attr', 'x4'] ) assert a.is_descended_from(b) == False assert b.is_descended_from(a) == True From 4920ee7ec8420d894d74fd2eefa24d09b3b7f25b Mon Sep 17 00:00:00 2001 From: Tim Sweeney Date: Tue, 27 Aug 2024 15:05:41 -0700 Subject: [PATCH 20/22] done --- weave/trace_server/tests/test_refs.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/weave/trace_server/tests/test_refs.py b/weave/trace_server/tests/test_refs.py index 4a9b4b28bb3b..4e0d6647fbbd 100644 --- a/weave/trace_server/tests/test_refs.py +++ b/weave/trace_server/tests/test_refs.py @@ -10,9 +10,15 @@ def test_isdescended_from(): - a = refs.ObjectRef(entity="e", project="p", name="n", digest="v", extra=["attr", 'x2']) + a = refs.ObjectRef( + entity="e", project="p", name="n", digest="v", extra=["attr", "x2"] + ) b = refs.ObjectRef( - entity="e", project="p", name="n", digest="v", extra=["attr", "x2", 'attr', 'x4'] + entity="e", + project="p", + name="n", + digest="v", + extra=["attr", "x2", "attr", "x4"], ) assert a.is_descended_from(b) == False assert b.is_descended_from(a) == True From d269cc3cf7a9898975c3efb399f2f87b4758f3e6 Mon Sep 17 00:00:00 2001 From: Tim Sweeney Date: Tue, 27 Aug 2024 15:30:23 -0700 Subject: [PATCH 21/22] lint --- weave/trace_server/tests/test_refs.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/weave/trace_server/tests/test_refs.py b/weave/trace_server/tests/test_refs.py index 4e0d6647fbbd..42b90235d349 100644 --- a/weave/trace_server/tests/test_refs.py +++ b/weave/trace_server/tests/test_refs.py @@ -3,8 +3,8 @@ import pytest from weave.trace import refs +from weave.trace.weave_client import sanitize_object_name from weave.trace_server import refs_internal -from weave.weave_client import sanitize_object_name quote = refs_internal.extra_value_quoter From 2a7b86d4f4e89879daf911787d15edb4e4e91d11 Mon Sep 17 00:00:00 2001 From: Tim Sweeney Date: Tue, 27 Aug 2024 17:00:48 -0700 Subject: [PATCH 22/22] comments --- weave/tests/trace/test_client_trace.py | 46 ++++++++++++++++++++++++++ weave/trace_server/refs_internal.py | 1 + 2 files changed, 47 insertions(+) diff --git a/weave/tests/trace/test_client_trace.py b/weave/tests/trace/test_client_trace.py index 90e84b01be80..c2846ed126c8 100644 --- a/weave/tests/trace/test_client_trace.py +++ b/weave/tests/trace/test_client_trace.py @@ -2596,6 +2596,52 @@ def test_object_with_disallowed_keys(client): client.server.obj_create(create_req) +CHAR_LIMIT = 128 + + +def test_object_with_char_limit(client): + name = "l" * CHAR_LIMIT + obj = Custom(name=name, val={"1": 1}) + + weave.publish(obj) + + # we sanitize the name + assert obj.ref.name == name + + create_req = tsi.ObjCreateReq.model_validate( + dict( + obj=dict( + project_id=client._project_id(), + object_id=name, + val={"1": 1}, + ) + ) + ) + client.server.obj_create(create_req) + + +def test_object_with_char_over_limit(client): + name = "l" * (CHAR_LIMIT + 1) + obj = Custom(name=name, val={"1": 1}) + + weave.publish(obj) + + # we sanitize the name + assert obj.ref.name == name[:-1] + + create_req = tsi.ObjCreateReq.model_validate( + dict( + obj=dict( + project_id=client._project_id(), + object_id=name, + val={"1": 1}, + ) + ) + ) + with pytest.raises(Exception): + client.server.obj_create(create_req) + + chars = "+_(){}|\"'<>!@$^&*#:,.[]-=;~`" diff --git a/weave/trace_server/refs_internal.py b/weave/trace_server/refs_internal.py index 4edfd3e9014a..ecf2a5fc75c5 100644 --- a/weave/trace_server/refs_internal.py +++ b/weave/trace_server/refs_internal.py @@ -32,6 +32,7 @@ class InvalidInternalRef(ValueError): def extra_value_quoter(s: str) -> str: + # Here, we encode all non alpha-numerics or `_.-~`. return urllib.parse.quote(s, safe="")