Skip to content

Commit

Permalink
chore(weave): Add Object field deprecation decorator (#2688)
Browse files Browse the repository at this point in the history
  • Loading branch information
andrewtruong authored Oct 14, 2024
1 parent d53cd9b commit a21e6dd
Show file tree
Hide file tree
Showing 2 changed files with 63 additions and 1 deletion.
38 changes: 38 additions & 0 deletions tests/trace/test_weaveflow.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
import logging

import numpy as np
from pydantic import Field

import weave
from weave.flow.obj import deprecated_field


def test_weaveflow_op_wandb(client):
Expand Down Expand Up @@ -193,3 +196,38 @@ def append_tool(self, f):
thing.append_tool(lambda: 2)
assert len(thing.tools) == 2
assert thing.tools is thing.tools


def test_deprecated_field_warning(caplog):
caplog.set_level(logging.WARNING)

class TestObj(weave.Object):
new_field: int = Field(..., alias="old_field")

@deprecated_field("new_field")
def old_field(self): ...

# Using new field is the same, but you can access the old field name
obj = TestObj(new_field=1)
assert obj.new_field == obj.old_field == 1

obj.new_field = 2
assert obj.new_field == obj.old_field == 2

# You can also instantiate with the old field name, but using it will show warnings
obj = TestObj(old_field=1)
with caplog.at_level(logging.WARNING):
v = obj.old_field

assert v == 1 == obj.new_field
assert "Use `new_field` instead of `old_field`" in caplog.text

caplog.clear()

with caplog.at_level(logging.WARNING):
obj.old_field = 2

assert obj.new_field == obj.old_field == 2
assert "Use `new_field` instead of `old_field`" in caplog.text

caplog.clear()
26 changes: 25 additions & 1 deletion weave/flow/obj.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from typing import Any, Optional
import logging
from typing import Any, Callable, Optional, TypeVar

from pydantic import (
BaseModel,
Expand All @@ -12,6 +13,27 @@
from weave.trace.vals import WeaveObject, pydantic_getattribute
from weave.trace.weave_client import get_ref

logger = logging.getLogger(__name__)

T = TypeVar("T")


def deprecated_field(new_field_name: str) -> Callable[[Callable[[Any], T]], property]:
def decorator(func: Callable[[Any], T]) -> property:
warning_msg = f"Use `{new_field_name}` instead of `{func.__name__}`, which is deprecated and will be removed in a future version."

def getter(self: Any) -> T:
logger.warning(warning_msg)
return getattr(self, new_field_name)

def setter(self: Any, value: T) -> None:
logger.warning(warning_msg)
setattr(self, new_field_name, value)

return property(fget=getter, fset=setter)

return decorator


class Object(BaseModel):
name: Optional[str] = None
Expand All @@ -23,6 +45,8 @@ class Object(BaseModel):
arbitrary_types_allowed=True,
protected_namespaces=(),
extra="forbid",
# Intended to be used to allow "deprecated" aliases for fields until we fully remove them.
populate_by_name=True,
)

__str__ = BaseModel.__repr__
Expand Down

0 comments on commit a21e6dd

Please sign in to comment.