Skip to content

Commit

Permalink
weavejs_fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
andrewtruong committed Aug 22, 2024
1 parent a045424 commit cb0c9d4
Show file tree
Hide file tree
Showing 7 changed files with 18 additions and 20 deletions.
2 changes: 1 addition & 1 deletion weave/legacy/op_def.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
tagged_value_type,
)
from weave.legacy.run import Run
from weave.weavejs_fixes import fixup_node
from weave.legacy.weavejs_fixes import fixup_node

if typing.TYPE_CHECKING:
from weave.trace import weave_client
Expand Down
22 changes: 11 additions & 11 deletions weave/weavejs_fixes.py → weave/legacy/weavejs_fixes.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

from weave.legacy import graph

from . import weave_types
from .. import weave_types


def _convert_specific_opname_to_generic_opname(
Expand Down Expand Up @@ -85,7 +85,7 @@ def convert_specific_opname_to_generic_opname(
def convert_specific_ops_to_generic_ops_node(node: graph.Node) -> graph.Node:
"""Converts specific ops like typedDict-pick to generic ops like pick"""

def convert_specific_op_to_generic_op(node: graph.Node):
def convert_specific_op_to_generic_op(node: graph.Node): # type: ignore
if isinstance(node, graph.ConstNode) and isinstance(
node.type, weave_types.Function
):
Expand All @@ -102,14 +102,14 @@ def convert_specific_op_to_generic_op(node: graph.Node):
return graph.map_nodes_full([node], convert_specific_op_to_generic_op)[0]


def _obj_is_node_like(data: typing.Any):
def _obj_is_node_like(data: typing.Any): # type: ignore
if not isinstance(data, dict):
return False
return data.get("nodeType") in ["const", "output", "var", "void"]


# Non-perfect heuristic to determine if a serialized dict is likely an op
def _dict_is_op_like(data: dict):
def _dict_is_op_like(data: dict): # type: ignore
# Firstly, ops will only have "name" and "input" keys
if set(data.keys()) == set(["name", "inputs"]):
# Those keys will be str and list respectively.
Expand All @@ -121,7 +121,7 @@ def _dict_is_op_like(data: dict):
return False


def convert_specific_ops_to_generic_ops_data(data):
def convert_specific_ops_to_generic_ops_data(data): # type: ignore
"""Fix op call names for serialized objects containing graphs"""
if isinstance(data, list):
return [convert_specific_ops_to_generic_ops_data(d) for d in data]
Expand All @@ -138,7 +138,7 @@ def convert_specific_ops_to_generic_ops_data(data):
def remove_opcall_versions_node(node: graph.Node) -> graph.Node:
"""Fix op call names"""

def remove_op_version(node: graph.Node):
def remove_op_version(node: graph.Node): # type: ignore
if not isinstance(node, graph.OutputNode):
return node
return graph.OutputNode(
Expand All @@ -150,7 +150,7 @@ def remove_op_version(node: graph.Node):
return graph.map_nodes_full([node], remove_op_version)[0]


def remove_opcall_versions_data(data):
def remove_opcall_versions_data(data): # type: ignore
"""Fix op call names for serialized objects containing graphs"""
if isinstance(data, list):
return [remove_opcall_versions_data(d) for d in data]
Expand All @@ -168,7 +168,7 @@ def fixup_node(node: graph.Node) -> graph.Node:
return convert_specific_ops_to_generic_ops_node(node)


def recursively_unwrap_unions(obj):
def recursively_unwrap_unions(obj): # type: ignore
if isinstance(obj, list):
return [recursively_unwrap_unions(o) for o in obj]
if isinstance(obj, dict):
Expand All @@ -183,7 +183,7 @@ def recursively_unwrap_unions(obj):
return obj


def remove_nan_and_inf(obj):
def remove_nan_and_inf(obj): # type: ignore
if isinstance(obj, list):
return [remove_nan_and_inf(o) for o in obj]
if isinstance(obj, dict):
Expand All @@ -194,7 +194,7 @@ def remove_nan_and_inf(obj):
return obj


def remove_partialobject_from_types(data):
def remove_partialobject_from_types(data): # type: ignore
"""Convert weave-internal types like
{"type":"PartialObject","keys":{"name":"string"},"keyless_weave_type_class":"project"}
Expand All @@ -218,7 +218,7 @@ def remove_partialobject_from_types(data):
return data


def fixup_data(data):
def fixup_data(data): # type: ignore
data = recursively_unwrap_unions(data)
data = remove_opcall_versions_data(data)
# No good! We have to do this because remoteHttp doesn't handle NaN/inf in
Expand Down
2 changes: 1 addition & 1 deletion weave/show.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,9 @@
storage,
usage_analytics,
util,
weavejs_fixes,
)
from . import weave_types as types
from .legacy import weavejs_fixes


def make_varname_for_type(t: types.Type):
Expand Down
3 changes: 1 addition & 2 deletions weave/tests/legacy/test_js_compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,7 @@
# weave Python code that I haven't documented here.

from weave import weave_types as types
from weave import weavejs_fixes
from weave.legacy import partial_object
from weave.legacy import partial_object, weavejs_fixes
from weave.legacy.ops_domain import wb_domain_types


Expand Down
2 changes: 1 addition & 1 deletion weave/tests/legacy/test_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@
from ... import api as weave
from ... import storage
from ... import weave_types as types
from ...legacy.weavejs_fixes import recursively_unwrap_unions
from ...weave_internal import make_const_node
from ...weavejs_fixes import recursively_unwrap_unions
from . import test_helpers


Expand Down
4 changes: 2 additions & 2 deletions weave/tests/legacy/test_weavejs_fixes.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@

import pytest

from weave.legacy import context_state, mappers_python, ops
from weave.legacy import context_state, mappers_python, ops, weavejs_fixes

from ... import api, weave_internal, weavejs_fixes
from ... import api, weave_internal
from ... import weave_types as types


Expand Down
3 changes: 1 addition & 2 deletions weave/weave_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,8 @@
server,
storage,
util,
weavejs_fixes,
)
from weave.legacy import context_state, graph, value_or_error, wandb_api
from weave.legacy import context_state, graph, value_or_error, wandb_api, weavejs_fixes
from weave.legacy.language_features.tagging import tag_store
from weave.server_error_handling import client_safe_http_exceptions_as_werkzeug

Expand Down

0 comments on commit cb0c9d4

Please sign in to comment.