Skip to content

Commit

Permalink
chore(weave): Legacy Refactor pt7 (#2225)
Browse files Browse the repository at this point in the history
* move to legacy: types_numpy

* move to legacy: registry_mem

* move to legacy: weave_types
  • Loading branch information
andrewtruong authored Aug 27, 2024
1 parent 33511e7 commit dfa4cf0
Show file tree
Hide file tree
Showing 36 changed files with 73 additions and 66 deletions.
2 changes: 1 addition & 1 deletion weave/legacy/artifact_fs.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,7 +266,7 @@ def type(self) -> types.Type:

ot = self._outer_type
if self.extra is not None:
from weave import types_numpy
from weave.legacy import types_numpy

if not types.is_list_like(ot) and isinstance(
ot, types_numpy.NumpyArrayType
Expand Down
4 changes: 2 additions & 2 deletions weave/legacy/codify.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@

import black

from weave import registry_mem, storage, weave_types
from weave.legacy import graph
from weave import storage, weave_types
from weave.legacy import graph, registry_mem

from . import codifiable_value_mixin

Expand Down
2 changes: 1 addition & 1 deletion weave/legacy/compile.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
from weave import (
engine_trace,
errors,
registry_mem,
weave_internal,
)
from weave import weave_types as types
Expand All @@ -26,6 +25,7 @@
op_args,
partial_object,
propagate_gql_keys,
registry_mem,
serialize,
stitch,
value_or_error,
Expand Down
4 changes: 2 additions & 2 deletions weave/legacy/compile_domain.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@

import graphql

from weave import errors, registry_mem
from weave import errors
from weave import weave_types as types
from weave.legacy import gql_op_plugin, gql_to_weave, graph, op_args, stitch
from weave.legacy import gql_op_plugin, gql_to_weave, graph, op_args, stitch, registry_mem
from weave.legacy.input_provider import InputAndStitchProvider

if typing.TYPE_CHECKING:
Expand Down
4 changes: 2 additions & 2 deletions weave/legacy/decorator_class.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
import inspect
import typing

from weave import errors, registry_mem
from weave import errors
from weave import weave_types as types
from weave.legacy import context_state, derive_op, op_def
from weave.legacy import context_state, derive_op, op_def, registry_mem

# Contrary to the way it is read, the weave.class() decorator runs AFTER the
# inner methods are defined. Therefore, this function runs after the ops are
Expand Down
3 changes: 1 addition & 2 deletions weave/legacy/decorator_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,8 @@

from typing_extensions import ParamSpec

from weave import registry_mem
from weave import weave_types as types
from weave.legacy import context_state, derive_op, op_args, op_def, pyfunc_type_util
from weave.legacy import context_state, derive_op, op_args, op_def, pyfunc_type_util, registry_mem

if typing.TYPE_CHECKING:
from weave.legacy.gql_op_plugin import GqlOpPlugin
Expand Down
2 changes: 1 addition & 1 deletion weave/legacy/derive_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,12 @@
from weave import (
errors,
parallelism,
registry_mem,
storage,
weave_internal,
)
from weave import weave_types as types
from weave.legacy import (
registry_mem,
box,
context_state,
execute_fast,
Expand Down
4 changes: 2 additions & 2 deletions weave/legacy/dispatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@
import typing
from dataclasses import dataclass

from weave import errors, registry_mem, util
from weave import errors, util
from weave import weave_types as types
from weave.legacy import graph, memo, op_args, op_def, pyfunc_type_util
from weave.legacy import graph, memo, op_args, op_def, pyfunc_type_util, registry_mem
from weave.legacy.language_features.tagging.is_tag_getter import is_tag_getter
from weave.legacy.language_features.tagging.tagged_value_type import TaggedValueType

Expand Down
4 changes: 2 additions & 2 deletions weave/legacy/ecosystem/root.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
"""
)

from weave import registry_mem
from weave.legacy import registry_mem

op_org_name = registry_mem.memory_registry.get_op("user-name")

Expand Down Expand Up @@ -63,7 +63,7 @@ def ops(self) -> list[op_def.OpDef]:
# objects.
@weave.op(name="op-ecosystem", render_info={"type": "function"})
def ecosystem() -> Ecosystem:
from weave import registry_mem
from weave.legacy import registry_mem

return Ecosystem(
_orgs=[],
Expand Down
2 changes: 1 addition & 1 deletion weave/legacy/ecosystem/wandb/wandb_objs.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import typing

import weave
from weave import registry_mem
from weave.legacy import registry_mem
from weave.legacy.ops_domain import run_ops, wb_domain_types

# We can't chain ops called .name() because of a weird bug :( [its a field on VarNode].
Expand Down
2 changes: 1 addition & 1 deletion weave/legacy/execute.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
errors,
parallelism,
ref_base,
registry_mem,
trace_local,
)
from weave import weave_types as types
Expand All @@ -27,6 +26,7 @@
# Trace / cache
# Language Features
from weave.legacy import (
registry_mem,
box,
compile,
context,
Expand Down
2 changes: 1 addition & 1 deletion weave/legacy/execute_fast.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,11 @@
engine_trace,
errors,
ref_base,
registry_mem,
weave_internal,
)
from weave import weave_types as types
from weave.legacy import (
registry_mem,
box,
compile,
forward_graph,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

import typing

from weave import registry_mem
from weave.legacy import registry_mem
from weave import weave_types as types
from weave.legacy import graph
from weave.legacy.language_features.tagging.opdef_util import (
Expand Down
4 changes: 2 additions & 2 deletions weave/legacy/op_def_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,9 @@
from _ast import AsyncFunctionDef, ExceptHandler
from typing import Any

from weave import environment, errors, registry_mem, storage
from weave import environment, errors, storage
from weave import weave_types as types
from weave.legacy import artifact_fs, artifact_local, context_state, infer_types
from weave.legacy import artifact_fs, artifact_local, context_state, infer_types, registry_mem

if typing.TYPE_CHECKING:
from .op_def import OpDef
Expand Down
3 changes: 1 addition & 2 deletions weave/legacy/ops_arrow/vectorize.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,9 @@

from weave import (
errors,
registry_mem,
weave_internal,
)
from weave.legacy import weavify
from weave.legacy import weavify, registry_mem
from weave import weave_types as types
from weave.query_api import op, use
from weave.legacy import dispatch, graph, graph_debug, op_args, op_def
Expand Down
2 changes: 1 addition & 1 deletion weave/legacy/ops_domain/run_history/history_op_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,12 @@
from weave import (
engine_trace,
errors,
registry_mem,
util,
)
from weave import weave_types as types
from weave.query_api import use
from weave.legacy import (
registry_mem,
_dict_utils,
artifact_base,
artifact_fs,
Expand Down
2 changes: 1 addition & 1 deletion weave/legacy/ops_primitives/weave_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,13 @@
from weave import (
errors,
ref_base,
registry_mem,
storage,
weave_internal,
)
from weave import weave_types as types
from weave.query_api import mutation, op, weave_class
from weave.legacy import (
registry_mem,
artifact_fs,
artifact_local,
artifact_wandb,
Expand Down
4 changes: 2 additions & 2 deletions weave/legacy/panels_py/generator_templates.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@
import dataclasses
import typing

from weave import registry_mem, weave_types
from weave.legacy import decorator_op, graph
from weave import weave_types
from weave.legacy import decorator_op, graph, registry_mem


@dataclasses.dataclass
Expand Down
4 changes: 2 additions & 2 deletions weave/legacy/propagate_gql_keys.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import typing

from weave import registry_mem

from weave import weave_types as types
from weave.legacy import gql_op_plugin, graph, input_provider, op_def, partial_object
from weave.legacy import gql_op_plugin, graph, input_provider, op_def, partial_object, registry_mem


def _propagate_gql_keys_for_node(
Expand Down
18 changes: 9 additions & 9 deletions weave/registry_mem.py → weave/legacy/registry_mem.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ class Registry:
# the registry over HTTP.
_updated_at: float

def __init__(self):
def __init__(self): # type: ignore
self._types = {}
self._ops = {}
self._ops_by_common_name = {}
Expand All @@ -41,8 +41,8 @@ def mark_updated(self) -> None:
def updated_at(self) -> float:
return self._updated_at

def register_op(self, op: "OpDef", location=None):
if context_state.get_no_op_register():
def register_op(self, op: "OpDef", location=None): # type: ignore
if context_state.get_no_op_register(): # type: ignore[no-untyped-call]
return op
self.mark_updated()
# do not save built-in ops today
Expand Down Expand Up @@ -97,7 +97,7 @@ def get_op(self, uri: str) -> "OpDef":
raise errors.WeaveMissingOpDefError("Op not registered: %s" % uri)
return res

def find_op_by_fn(self, lazy_local_fn):
def find_op_by_fn(self, lazy_local_fn): # type: ignore
for op_def in self._op_versions.values():
if op_def.call_fn == lazy_local_fn:
return op_def
Expand All @@ -111,17 +111,17 @@ def find_ops_by_common_name(self, common_name: str) -> typing.List["OpDef"]:
return ops

def find_chainable_ops(self, arg0_type: weave_types.Type) -> typing.List["OpDef"]:
def is_chainable(op):
def is_chainable(op): # type: ignore
if not isinstance(op.input_type, op_args.OpNamedArgs):
return False
args = list(op.input_type.arg_types.values())
if not args:
return False
return args[0].assign_type(arg0_type)

return [op for op in self._ops.values() if is_chainable(op)]
return [op for op in self._ops.values() if is_chainable(op)] # type: ignore[no-untyped-call]

def load_saved_ops(self):
def load_saved_ops(self): # type: ignore
from weave.legacy import op_def_type

for op_ref in storage.objects(op_def_type.OpDefType()):
Expand Down Expand Up @@ -149,7 +149,7 @@ def list_packages(self) -> typing.List["OpDef"]:
]
return packages

def rename_op(self, name, new_name):
def rename_op(self, name, new_name): # type: ignore
"""Internal use only, used during op bootstrapping at decorator time"""
self.mark_updated()
op = self._ops.pop(name)
Expand Down Expand Up @@ -186,4 +186,4 @@ def rename_op(self, name, new_name):


# Processes have a singleton MemoryRegistry
memory_registry = Registry()
memory_registry = Registry() # type: ignore[no-untyped-call]
4 changes: 2 additions & 2 deletions weave/legacy/stitch.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,10 @@
import dataclasses
import typing

from weave.legacy import graph, op_def
from weave.legacy import graph, op_def, registry_mem
from weave.legacy.language_features.tagging import opdef_util

from .. import errors, registry_mem
from .. import errors
from .. import weave_types as types
from . import _dict_utils

Expand Down
18 changes: 9 additions & 9 deletions weave/types_numpy.py → weave/legacy/types_numpy.py
Original file line number Diff line number Diff line change
@@ -1,29 +1,29 @@
import numpy as np

from . import weave_types as types
from .. import weave_types as types


# TODO: this doesn't match how extra works for list types...
class NumpyArrayType(types.Type):
instance_classes = np.ndarray
name = "WeaveNDArray"

def __init__(self, dtype="x", shape=(0,)):
def __init__(self, dtype="x", shape=(0,)): # type: ignore
self.dtype = dtype
self.shape = shape

@classmethod
def type_of_instance(cls, obj):
def type_of_instance(cls, obj): # type: ignore
return cls(obj.dtype, obj.shape)

def _to_dict(self):
def _to_dict(self): # type: ignore
return {"dtype": str(self.dtype), "shape": self.shape}

@classmethod
def from_dict(cls, d):
def from_dict(cls, d): # type: ignore
return cls(np.dtype(d.get("dtype", "object")), d["shape"])

def _assign_type_inner(self, next_type):
def _assign_type_inner(self, next_type): # type: ignore
if not isinstance(next_type, NumpyArrayType):
return False
if (
Expand All @@ -33,13 +33,13 @@ def _assign_type_inner(self, next_type):
return False
return True

def save_instance(self, obj, artifact, name):
def save_instance(self, obj, artifact, name): # type: ignore
with artifact.new_file(f"{name}.npz", binary=True) as f:
np.savez_compressed(f, arr=obj)

def load_instance(self, artifact, name, extra=None):
def load_instance(self, artifact, name, extra=None): # type: ignore
with artifact.open(f"{name}.npz", binary=True) as f:
return np.load(f)["arr"]

def __str__(self):
def __str__(self): # type: ignore
return "<NumpyArrayType %s %s>" % (self.dtype, self.shape)
4 changes: 2 additions & 2 deletions weave/weave_pydantic.py → weave/legacy/weave_pydantic.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from pydantic import BaseModel, create_model

from . import weave_types as types
from .legacy import infer_types
from .. import weave_types as types
from . import infer_types


def weave_type_to_pydantic(
Expand Down
2 changes: 1 addition & 1 deletion weave/query_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from . import weave_types as types

# needed to enable automatic numpy serialization
from . import types_numpy as _types_numpy
from .legacy import types_numpy as _types_numpy

from . import errors
from weave.legacy.decorators import weave_class, mutation, type
Expand Down
3 changes: 2 additions & 1 deletion weave/serve_fastapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@
from weave.trace.op import Op
from weave.trace.refs import ObjectRef

from . import errors, weave_pydantic
from . import errors
from .legacy import weave_pydantic

key_cache: cache.LruTimeWindowCache[str, typing.Optional[bool]] = (
cache.LruTimeWindowCache(datetime.timedelta(minutes=5))
Expand Down
Loading

0 comments on commit dfa4cf0

Please sign in to comment.