Skip to content

Commit

Permalink
use type int for comparisons not enum
Browse files Browse the repository at this point in the history
Summary:
It's inefficient to access `self.type.value`, specifically the second `.value` access, which is pure python `enum.Enum`. We can make a statically typed cython access on the internal data `int` instead.

This fixes a bug that occurs when unions with a field called `type` are compared. This fix is agnostic about what direction we want to take this API.

Reviewed By: yoney

Differential Revision: D67160589

fbshipit-source-id: 6d6471623ed11985e5bf0d334363cec89a2f44a9
  • Loading branch information
ahilger authored and facebook-github-bot committed Dec 13, 2024
1 parent d4eb8f4 commit e476ebe
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 35 deletions.
42 changes: 28 additions & 14 deletions third-party/thrift/src/thrift/lib/python/types.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -1504,22 +1504,22 @@ cdef class Union(StructOrUnion):
Returns the value of the field with the given `field_id` if it is indeed the
field that is (currently) set for this union. Otherwise, raises AttributeError.
"""
if self.type.value != field_id:
if _fbthrift_get_Union_type_int(self) != field_id:
# TODO in python 3.10 update this to use name and obj fields
raise AttributeError(
f'Union contains a value of type {self.type.name}, not '
f'{type(self).Type(field_id).name}')
return self.value

def get_type(self):
def get_type(Union self not None):
return self.type

@property
def fbthrift_current_field(self):
def fbthrift_current_field(Union self not None):
return self.type

@property
def fbthrift_current_value(self):
def fbthrift_current_value(Union self not None):
return self.value

@classmethod
Expand Down Expand Up @@ -1555,36 +1555,50 @@ cdef class Union(StructOrUnion):
def __deepcopy__(Union self, _memo):
return self

def __eq__(Union self, other):
def __eq__(Union self not None, other):
if type(other) != type(self):
return False
return self.type == other.type and self.value == other.value
cdef Union other_u = other
cdef int self_type_int = _fbthrift_get_Union_type_int(self)
cdef int other_type_int = _fbthrift_get_Union_type_int(other_u)
return self_type_int == other_type_int and self.value == other_u.value

def __lt__(self, other):
def __lt__(Union self not None, other):
if type(self) != type(other):
return NotImplemented
return (self.type.value, self.value) < (other.type.value, other.value)
cdef Union other_u = other
cdef int self_type_int = _fbthrift_get_Union_type_int(self)
cdef int other_type_int = _fbthrift_get_Union_type_int(other_u)
return (self_type_int, self.value) < (other_type_int, other_u.value)

def __le__(self, other):
def __le__(Union self not None, other):
if type(self) != type(other):
return NotImplemented
return (self.type.value, self.value) <= (other.type.value, other.value)
cdef Union other_u = other
cdef int self_type_int = _fbthrift_get_Union_type_int(self)
cdef int other_type_int = _fbthrift_get_Union_type_int(other_u)
return (self_type_int, self.value) <= (other_type_int, other_u.value)

def __hash__(self):
return hash((self.type, self.value))
def __hash__(Union self not None):
cdef int self_type_int = _fbthrift_get_Union_type_int(self)
return hash((self_type_int, self.value))

def __repr__(self):
return f"{type(self).__name__}({self.type.name}={self.value!r})"

def __bool__(self):
return self.type.value != 0
def __bool__(self not None):
return _fbthrift_get_Union_type_int(self) != 0

def __dir__(self):
return dir(type(self))

def __reduce__(self):
return (_unpickle_union, (type(self), b''.join(self._serialize(Protocol.COMPACT))))


cdef inline _fbthrift_get_Union_type_int(Union u):
return u._fbthrift_data[0]

cdef _make_fget_struct(i):
"""
Returns a function that takes a `Struct` instance and returns the value of
Expand Down
40 changes: 19 additions & 21 deletions third-party/thrift/src/thrift/test/thrift-python/union_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@
)


def _thrift_serialization_round_trip(
def _assert_serialization_round_trip(
test: unittest.TestCase,
serializer_module: types.ModuleType,
thrift_object: typing.Union[MutableStructOrUnion, ImmutableStructOrUnion],
Expand Down Expand Up @@ -80,7 +80,7 @@ def test_creation(self) -> None:
AttributeError, "Union contains a value of type EMPTY, not string_field"
):
u.string_field
_thrift_serialization_round_trip(self, immutable_serializer, u)
_assert_serialization_round_trip(self, immutable_serializer, u)

# Specifying exactly one keyword argument whose name corresponds to that of a
# field for this Union, and a non-None value whose type is valid for that field,
Expand All @@ -98,7 +98,7 @@ def test_creation(self) -> None:
AttributeError, "Union contains a value of type string_field, not int_field"
):
u2.int_field
_thrift_serialization_round_trip(self, immutable_serializer, u2)
_assert_serialization_round_trip(self, immutable_serializer, u2)

# Attempts to initialize an instance with a keyword argument whose name does
# not match that of a field should raise an error.
Expand Down Expand Up @@ -250,7 +250,7 @@ def test_from_value_ambiguous_int_bool(self) -> None:
)
self.assertEqual(union_int_bool_1.value, 1)
self.assertEqual(union_int_bool_1.int_field, 1)
_thrift_serialization_round_trip(self, immutable_serializer, union_int_bool_1)
_assert_serialization_round_trip(self, immutable_serializer, union_int_bool_1)

# BAD: fromValue(bool) populates an int field if it comes before bool.
union_int_bool_2 = TestUnionAmbiguousFromValueIntBoolImmutable.fromValue(True)
Expand All @@ -264,7 +264,7 @@ def test_from_value_ambiguous_int_bool(self) -> None:
)
self.assertEqual(union_int_bool_2.value, 1)
self.assertEqual(union_int_bool_2.int_field, 1)
_thrift_serialization_round_trip(self, immutable_serializer, union_int_bool_2)
_assert_serialization_round_trip(self, immutable_serializer, union_int_bool_2)

def test_from_value_ambiguous_bool_int(self) -> None:
# BAD: Unlike the previous test case, fromValue(int) does not populate
Expand All @@ -281,7 +281,7 @@ def test_from_value_ambiguous_bool_int(self) -> None:
self.assertEqual(union_bool_int_1.value, 1)
self.assertEqual(union_bool_int_1.int_field, 1)
self.assertEqual(union_bool_int_1.int_field, True)
_thrift_serialization_round_trip(self, immutable_serializer, union_bool_int_1)
_assert_serialization_round_trip(self, immutable_serializer, union_bool_int_1)

union_bool_int_2 = TestUnionAmbiguousFromValueBoolIntImmutable.fromValue(True)
self.assertIs(
Expand All @@ -295,7 +295,7 @@ def test_from_value_ambiguous_bool_int(self) -> None:
self.assertEqual(union_bool_int_2.value, True)
self.assertEqual(union_bool_int_2.value, 1)
self.assertEqual(union_bool_int_2.bool_field, 1)
_thrift_serialization_round_trip(self, immutable_serializer, union_bool_int_2)
_assert_serialization_round_trip(self, immutable_serializer, union_bool_int_2)

def test_from_value_ambiguous_float_int(self) -> None:
# BAD: fromValue(int) populated a float field if it comes before int.
Expand All @@ -310,7 +310,7 @@ def test_from_value_ambiguous_float_int(self) -> None:
)
self.assertEqual(union_float_int_1.value, 1.0)
self.assertEqual(union_float_int_1.float_field, 1)
_thrift_serialization_round_trip(self, immutable_serializer, union_float_int_1)
_assert_serialization_round_trip(self, immutable_serializer, union_float_int_1)

union_float_int_2 = TestUnionAmbiguousFromValueFloatIntImmutable.fromValue(1.0)
self.assertIs(
Expand All @@ -323,7 +323,7 @@ def test_from_value_ambiguous_float_int(self) -> None:
)
self.assertEqual(union_float_int_2.value, 1.0)
self.assertEqual(union_float_int_2.float_field, 1)
_thrift_serialization_round_trip(self, immutable_serializer, union_float_int_2)
_assert_serialization_round_trip(self, immutable_serializer, union_float_int_2)

def test_field_name_conflict(self) -> None:
# By setting class type `Type` attr after field attrs, we get the desired behavior
Expand Down Expand Up @@ -355,20 +355,18 @@ def test_field_name_conflict(self) -> None:
):
# pyre-ignore[41]: Intentional for test
type_union.Type = 1
_thrift_serialization_round_trip(self, immutable_serializer, type_union)
_assert_serialization_round_trip(self, immutable_serializer, type_union)

u = TestUnionAmbiguousValueFieldNameImmutable(value=42)
self.assertEqual(u.value, 42)
with self.assertRaises(AttributeError):
u.type
with self.assertRaises(AttributeError):
_thrift_serialization_round_trip(self, immutable_serializer, u)
_assert_serialization_round_trip(self, immutable_serializer, u)

u2 = TestUnionAmbiguousValueFieldNameImmutable(type=123)
with self.assertRaises(AttributeError):
u2.value
with self.assertRaises(AssertionError):
_thrift_serialization_round_trip(self, immutable_serializer, u2)
_assert_serialization_round_trip(self, immutable_serializer, u2)

def test_hash(self) -> None:
hash(TestUnionImmutable())
Expand All @@ -378,13 +376,13 @@ def test_equality(self) -> None:
u2 = TestUnionImmutable(string_field="hello")
self.assertIsNot(u1, u2)
self.assertEqual(u1, u2)
_thrift_serialization_round_trip(self, immutable_serializer, u1)
_thrift_serialization_round_trip(self, immutable_serializer, u2)
_assert_serialization_round_trip(self, immutable_serializer, u1)
_assert_serialization_round_trip(self, immutable_serializer, u2)

u3 = TestUnionImmutable(string_field="world")
self.assertIsNot(u1, u3)
self.assertNotEqual(u1, u3)
_thrift_serialization_round_trip(self, immutable_serializer, u3)
_assert_serialization_round_trip(self, immutable_serializer, u3)

def test_ordering(self) -> None:
self.assertLess(
Expand All @@ -408,7 +406,7 @@ def test_adapted_types(self) -> None:
)
self.assertIs(u1.type, TestUnionAdaptedTypesImmutable.Type.EMPTY)
self.assertIsNone(u1.value)
_thrift_serialization_round_trip(self, immutable_serializer, u1)
_assert_serialization_round_trip(self, immutable_serializer, u1)

with self.assertRaisesRegex(
AttributeError,
Expand Down Expand Up @@ -465,8 +463,8 @@ def test_adapted_types(self) -> None:
with self.assertRaisesRegex(
AttributeError, "'str' object has no attribute 'timestamp'"
):
(TestUnionAdaptedTypesImmutable.fromValue("1718728839"),)
_thrift_serialization_round_trip(self, immutable_serializer, u2)
TestUnionAdaptedTypesImmutable.fromValue("1718728839")
_assert_serialization_round_trip(self, immutable_serializer, u2)

u3 = TestUnionAdaptedTypesImmutable(non_adapted_i32=1718728839)
self.assertIs(
Expand All @@ -476,7 +474,7 @@ def test_adapted_types(self) -> None:
self.assertIs(u3.type, TestUnionAdaptedTypesImmutable.Type.non_adapted_i32)
self.assertIs(u3.value, u3.non_adapted_i32)
self.assertEqual(u3.non_adapted_i32, 1718728839)
_thrift_serialization_round_trip(self, immutable_serializer, u3)
_assert_serialization_round_trip(self, immutable_serializer, u3)

def test_to_immutable_python(self) -> None:
union_immutable = TestUnionImmutable(string_field="hello")
Expand Down

0 comments on commit e476ebe

Please sign in to comment.