diff --git a/weave/legacy/tests/test_refs.py b/weave/legacy/tests/test_refs.py index 98bd7048bc5..0a22d6028dd 100644 --- a/weave/legacy/tests/test_refs.py +++ b/weave/legacy/tests/test_refs.py @@ -4,7 +4,7 @@ from weave.flow.obj import Object from weave.legacy.weave import artifact_local, storage from weave.legacy.weave import ops_arrow as arrow -from weave.trace import ref_util +from weave.legacy.weave import ref_util from weave.trace_server.refs_internal import ( DICT_KEY_EDGE_NAME, LIST_INDEX_EDGE_NAME, diff --git a/weave/legacy/weave/arrow/list_.py b/weave/legacy/weave/arrow/list_.py index 6e3d0e85161..9a6a28b80ad 100644 --- a/weave/legacy/weave/arrow/list_.py +++ b/weave/legacy/weave/arrow/list_.py @@ -37,7 +37,7 @@ tag_store, tagged_value_type, ) -from weave.trace import ref_util +from weave.legacy.weave import ref_util def reverse_dict(d: dict) -> dict: diff --git a/weave/legacy/weave/artifact_fs.py b/weave/legacy/weave/artifact_fs.py index 37cc0a3b5ae..206d1bf487d 100644 --- a/weave/legacy/weave/artifact_fs.py +++ b/weave/legacy/weave/artifact_fs.py @@ -10,7 +10,7 @@ from weave.legacy.weave import weave_types as types from weave.legacy.weave import artifact_base, file_base, object_context, ref_base, uris from weave.legacy.weave.language_features.tagging import tag_store -from weave.trace import ref_util +from weave.legacy.weave import ref_util if typing.TYPE_CHECKING: from weave.legacy.weave import graph diff --git a/weave/legacy/weave/box.py b/weave/legacy/weave/box.py index a16725af94a..893707e6b26 100644 --- a/weave/legacy/weave/box.py +++ b/weave/legacy/weave/box.py @@ -4,7 +4,7 @@ import numpy as np -from weave.trace import ref_util +from weave.legacy.weave import ref_util from weave.legacy.weave import context_state diff --git a/weave/legacy/weave/object_type_ref_util.py b/weave/legacy/weave/object_type_ref_util.py index d0bc6a1ee35..e841087b064 100644 --- a/weave/legacy/weave/object_type_ref_util.py +++ b/weave/legacy/weave/object_type_ref_util.py @@ -1,6 +1,6 @@ import typing -from weave.trace import ref_util +from weave.legacy.weave import ref_util from weave.legacy.weave import context_state diff --git a/weave/legacy/weave/ref_util.py b/weave/legacy/weave/ref_util.py new file mode 100644 index 00000000000..47aed3131dd --- /dev/null +++ b/weave/legacy/weave/ref_util.py @@ -0,0 +1,105 @@ +import dataclasses +import typing +from urllib import parse + +from weave.legacy.weave import box +from weave.trace_server import refs_internal + +DICT_KEY_EDGE_NAME = refs_internal.DICT_KEY_EDGE_NAME +LIST_INDEX_EDGE_NAME = refs_internal.LIST_INDEX_EDGE_NAME +OBJECT_ATTR_EDGE_NAME = refs_internal.OBJECT_ATTR_EDGE_NAME +AWL_ROW_EDGE_NAME = "row" +AWL_COL_EDGE_NAME = "col" + + +def parse_local_ref_str(s: str) -> typing.Tuple[str, typing.Optional[list[str]]]: + if "#" not in s: + return s, None + path, extra = s.split("#", 1) + return path, extra.split("/") + + +def val_with_relative_ref( + parent_object: typing.Any, child_object: typing.Any, ref_extra_parts: list[str] +) -> typing.Any: + from weave.legacy.weave import context_state, ref_base + + # If we already have a ref, resolve it + if isinstance(child_object, ref_base.Ref): + child_object = child_object.get() + + # Only do this if ref_tracking_enabled right now. I just want to + # avoid introducing new behavior into W&B prod for the moment. + if context_state.ref_tracking_enabled(): + from weave.legacy.weave import storage + + child_ref = storage.get_ref(child_object) + parent_ref = ref_base.get_ref(parent_object) + + # This first check is super important - if the child ref is pointing + # to a completely different artifact (ref), then we want to point to + # the child's inherent ref, not the relative ref from the parent. + if child_ref is not None: + if parent_ref is not None: + if hasattr(child_ref, "digest") and hasattr(parent_ref, "digest"): + if child_ref.digest != parent_ref.digest: + return child_object + + if parent_ref is not None: + child_object = box.box(child_object) + sub_ref = parent_ref.with_extra(None, child_object, ref_extra_parts) + ref_base._put_ref(child_object, sub_ref) + return child_object + + return child_object + + +@dataclasses.dataclass +class RefExtraTuple: + edge_type: str + part: str + + +@dataclasses.dataclass +class ParsedRef: + scheme: str + entity: typing.Optional[str] + project: typing.Optional[str] + artifact: str + alias: str + file_path_parts: list[str] + ref_extra_tuples: list[RefExtraTuple] + + +def parse_ref_str(s: str) -> ParsedRef: + scheme, _, path, _, _, ref_extra = parse.urlparse(s) + entity = None + project = None + assert path.startswith("/") + path = path[1:] + path_parts = path.split("/") + if scheme == "wandb-artifact": + entity = path_parts[0] + project = path_parts[1] + path_parts = path_parts[2:] + + artifact, alias = path_parts[0].split(":") + file_path_parts = path_parts[1:] + ref_extra_tuples = [] + if ref_extra: + ref_extra_parts = ref_extra.split("/") + assert len(ref_extra_parts) % 2 == 0 + for i in range(0, len(ref_extra_parts), 2): + edge_type = ref_extra_parts[i] + part = ref_extra_parts[i + 1] + ref_extra_tuples.append(RefExtraTuple(edge_type, part)) + + return ParsedRef( + scheme=scheme, + entity=entity, + project=project, + artifact=artifact, + alias=alias, + file_path_parts=file_path_parts, + ref_extra_tuples=ref_extra_tuples, + )