Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: Fixes issues with special characters in object name #2156

Merged
merged 25 commits into from
Aug 28, 2024
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -460,7 +460,7 @@ const buildBaseRef = (
const parts = path.toPath().slice(-startIndex);
parts.forEach(part => {
if (typeof part === 'string') {
baseRef += '/' + OBJECT_ATTR_EDGE_NAME + '/' + part;
baseRef += '/' + OBJECT_ATTR_EDGE_NAME + '/' + encodeURIComponent(part);
} else if (typeof part === 'number') {
baseRef += '/' + LIST_INDEX_EDGE_NAME + '/' + part.toString();
} else {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,11 @@ function replaceTableRefsInFlattenedData(flattened: Record<string, any>) {
if (parentRef) {
const newVal: ExpandedRefWithValueAsTableRef =
makeRefExpandedPayload(
parentRef + '/' + OBJECT_ATTR_EDGE_NAME + '/' + attr,
parentRef +
'/' +
OBJECT_ATTR_EDGE_NAME +
'/' +
encodeURIComponent(attr),
val
);
return [key, newVal];
Expand Down
7 changes: 4 additions & 3 deletions weave-js/src/react.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -505,6 +505,7 @@ export const isWeaveObjectRef = (ref: ObjectRef): ref is WeaveObjectRef => {
// Unfortunately many teams have been created that violate this.
const PATTERN_ENTITY = '([^/]+)';
const PATTERN_PROJECT = '([^\\#?%:]{1,128})'; // Project name
const PATTERN_REF_EXTRA = '([a-zA-Z0-9_/%]*)'; // Optional ref extra
const RE_WEAVE_OBJECT_REF_PATHNAME = new RegExp(
[
'^', // Start of the string
Expand All @@ -518,7 +519,7 @@ const RE_WEAVE_OBJECT_REF_PATHNAME = new RegExp(
':',
'([*]|[a-zA-Z0-9]+)', // Artifact version, allowing '*' for any version
'/?', // Ref extra portion is optional
'([a-zA-Z0-9_/]*)', // Optional ref extra
PATTERN_REF_EXTRA, // Optional ref extra
'$', // End of the string
].join('')
);
Expand All @@ -531,7 +532,7 @@ const RE_WEAVE_TABLE_REF_PATHNAME = new RegExp(
'/table/',
'([a-f0-9]+)', // Digest
'/?', // Ref extra portion is optional
'([a-zA-Z0-9_/]*)', // Optional ref extra
PATTERN_REF_EXTRA, // Optional ref extra
'$', // End of the string
].join('')
);
Expand All @@ -544,7 +545,7 @@ const RE_WEAVE_CALL_REF_PATHNAME = new RegExp(
'/call/',
'([0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{12})', // Call UUID
'/?', // Ref extra portion is optional
'([a-zA-Z0-9_/]*)', // Optional ref extra
PATTERN_REF_EXTRA, // Optional ref extra
'$', // End of the string
].join('')
);
Expand Down
26 changes: 9 additions & 17 deletions weave/trace/refs.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import dataclasses
import urllib
from typing import Any, Union

from ..trace_server import refs_internal
Expand Down Expand Up @@ -56,7 +57,7 @@ class ObjectRef(RefWithExtra):
def uri(self) -> str:
u = f"weave:///{self.entity}/{self.project}/object/{self.name}:{self.digest}"
if self.extra:
u += "/" + "/".join(self.extra)
u += "/" + "/".join(refs_internal.extra_value_quoter(e) for e in self.extra)
return u

def get(self) -> Any:
Expand Down Expand Up @@ -106,7 +107,7 @@ class OpRef(ObjectRef):
def uri(self) -> str:
u = f"weave:///{self.entity}/{self.project}/op/{self.name}:{self.digest}"
if self.extra:
u += "/" + "/".join(self.extra)
u += "/" + "/".join(refs_internal.extra_value_quoter(e) for e in self.extra)
return u


Expand All @@ -120,7 +121,7 @@ class CallRef(RefWithExtra):
def uri(self) -> str:
u = f"weave:///{self.entity}/{self.project}/call/{self.id}"
if self.extra:
u += "/" + "/".join(self.extra)
u += "/" + "/".join(refs_internal.extra_value_quoter(e) for e in self.extra)
return u


Expand All @@ -138,27 +139,18 @@ def parse_uri(uri: str) -> AnyRef:
remaining = tuple(parts[3:])
if kind == "table":
return TableRef(entity=entity, project=project, digest=remaining[0])
elif kind == "call":
return CallRef(
entity=entity, project=project, id=remaining[0], extra=remaining[1:]
)
extra = tuple(urllib.parse.unquote(r) for r in remaining[1:])
if kind == "call":
return CallRef(entity=entity, project=project, id=remaining[0], extra=extra)
elif kind == "object":
name, version = remaining[0].split(":")
return ObjectRef(
entity=entity,
project=project,
name=name,
digest=version,
extra=remaining[1:],
entity=entity, project=project, name=name, digest=version, extra=extra
)
elif kind == "op":
name, version = remaining[0].split(":")
return OpRef(
entity=entity,
project=project,
name=name,
digest=version,
extra=remaining[1:],
entity=entity, project=project, name=name, digest=version, extra=extra
)
else:
raise ValueError(f"Unknown ref kind: {kind}")
29 changes: 14 additions & 15 deletions weave/trace_server/refs_internal.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
import urllib
from typing import Union

from . import validation

WEAVE_INTERNAL_SCHEME = "weave-trace-internal"
WEAVE_SCHEME = "weave"
WEAVE_PRIVATE_SCHEME = "weave-private"
Expand All @@ -29,18 +31,15 @@ class InvalidInternalRef(ValueError):
pass


def quote_select(s: str, quote_chars: tuple[str, ...] = ("/", ":", "%")) -> str:
def extra_value_quoter(s: str, quote_chars: tuple[str, ...] = ("%", "/")) -> str:
tssweeney marked this conversation as resolved.
Show resolved Hide resolved
"""We don't need to quote every single character, rather
we just need to quote the characters that are used as
delimiters in the URI. Right now, we only use "/", ":", and "%". Moreover,
the only user-controlled fields are the object name and the extra fields.

All other fields are generated by the system and are safe for parsing.
delimiters in the URI.
"""
# We have to be careful here since quoting creates new "%" characters.
# We don't want to double-quote characters.
if "%" in quote_chars and quote_chars[-1] != "%":
raise ValueError("Quoting '%' must be the last character in the list.")
if "%" in quote_chars and quote_chars[0] != "%":
raise ValueError("Quoting '%' must be the first character in the list.")

for c in quote_chars:
s = s.replace(c, urllib.parse.quote(c))
Expand Down Expand Up @@ -118,20 +117,21 @@ def __post_init__(self) -> None:
validate_no_slashes(self.version, "version")
validate_no_colons(self.version, "version")
validate_extra(self.extra)
validation.object_id_validator(self.name)

def uri(self) -> str:
u = f"{WEAVE_INTERNAL_SCHEME}:///{self.project_id}/object/{quote_select(self.name)}:{self.version}"
u = f"{WEAVE_INTERNAL_SCHEME}:///{self.project_id}/object/{self.name}:{self.version}"
if self.extra:
u += "/" + "/".join(quote_select(e) for e in self.extra)
u += "/" + "/".join(extra_value_quoter(e) for e in self.extra)
return u


@dataclasses.dataclass(frozen=True)
class InternalOpRef(InternalObjectRef):
def uri(self) -> str:
u = f"{WEAVE_INTERNAL_SCHEME}:///{self.project_id}/op/{quote_select(self.name)}:{self.version}"
u = f"{WEAVE_INTERNAL_SCHEME}:///{self.project_id}/op/{self.name}:{self.version}"
if self.extra:
u += "/" + "/".join(quote_select(e) for e in self.extra)
u += "/" + "/".join(extra_value_quoter(e) for e in self.extra)
return u


Expand All @@ -151,7 +151,7 @@ def __post_init__(self) -> None:
def uri(self) -> str:
u = f"{WEAVE_INTERNAL_SCHEME}:///{self.project_id}/call/{self.id}"
if self.extra:
u += "/" + "/".join(quote_select(e) for e in self.extra)
u += "/" + "/".join(extra_value_quoter(e) for e in self.extra)
return u


Expand Down Expand Up @@ -202,13 +202,12 @@ def _parse_remaining(remaining: list[str]) -> tuple[str, str, list[str]]:
It is expected to be pre-split by slashes into parts. The return
is a tuple of name, version, and extra parts, properly unquoted.
"""
name_encoded, version = remaining[0].split(":")
name = urllib.parse.unquote_plus(name_encoded)
name, version = remaining[0].split(":")
extra = remaining[1:]
if len(extra) == 1 and extra[0] == "":
extra = []
else:
extra = [urllib.parse.unquote_plus(r) for r in extra]
extra = [urllib.parse.unquote(r) for r in extra]

return name, version, extra

Expand Down
14 changes: 14 additions & 0 deletions weave/trace_server/validation.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import re
import typing

from weave.trace_server import refs_internal
Expand Down Expand Up @@ -57,7 +58,20 @@ def wb_run_id_validator(s: typing.Optional[str]) -> typing.Optional[str]:
return s


def _validate_object_name_charset(name: str) -> None:
# Object names must be alphanumeric with dashes
invalid_chars = re.findall(r"[^\w-]", name)
if invalid_chars:
raise ValueError(
f"Invalid object name: {name}. Contains invalid characters: {invalid_chars}"
)

if not name:
raise ValueError("Object name cannot be empty")


def object_id_validator(s: str) -> str:
_validate_object_name_charset(s)
return validation_util.require_max_str_len(s, 128)


Expand Down
17 changes: 17 additions & 0 deletions weave/weave_client.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import dataclasses
import datetime
import platform
import re
import sys
import typing
from functools import lru_cache
Expand Down Expand Up @@ -48,6 +49,7 @@
TableSchemaForInsert,
TraceServerInterface,
)
from weave.trace_server.validation import object_id_validator

if typing.TYPE_CHECKING:
from . import ref_base
Expand Down Expand Up @@ -762,6 +764,9 @@ def _save_object_basic(
if name is None:
raise ValueError("Name must be provided for object saving")

name = sanitize_object_name(name)
object_id_validator(name)

response = self.server.obj_create(
ObjCreateReq(
obj=ObjSchemaForInsert(
Expand Down Expand Up @@ -999,4 +1004,16 @@ def redact_sensitive_keys(obj: typing.Any) -> typing.Any:
return obj


def sanitize_object_name(name: str) -> str:
# Replaces any non-alphanumeric characters with a single dash and removes
# any leading or trailing dashes. This is more restrictive than the DB
# constraints and can be relaxed if needed.
res = re.sub(r"\W+", "-", name).strip("-_")
if not res:
raise ValueError(f"Invalid object name: {name}")
if len(res) > 128:
raise ValueError(f"Object name too long: {name}")
return res


__docspec__ = [WeaveClient, Call, CallsIter]
Loading