diff --git a/third-party/thrift/src/thrift/lib/cpp2/protocol/CursorBasedSerializer.h b/third-party/thrift/src/thrift/lib/cpp2/protocol/CursorBasedSerializer.h index ad2c8ef61497a..69777d1ff8938 100644 --- a/third-party/thrift/src/thrift/lib/cpp2/protocol/CursorBasedSerializer.h +++ b/third-party/thrift/src/thrift/lib/cpp2/protocol/CursorBasedSerializer.h @@ -17,11 +17,11 @@ #pragma once #include -#include #include #include #include #include +#include #include #include #include @@ -70,6 +70,7 @@ class CursorSerializationWrapper { static_assert( std::is_same_v, "ProtocolWriter must be BinaryProtocolReader"); + using Serializer = Serializer; public: CursorSerializationWrapper() = default; @@ -100,12 +101,13 @@ class CursorSerializationWrapper { * Object read path (traditional Thrift deserialization) * Deserializes into a (returned) Thrift object. */ - T deserialize() { + T deserialize() const { + if (std::holds_alternative(protocol_)) { + folly::throw_exception( + "Concurrent reads/writes not supported"); + } checkHasData(); - T ret; - ret.read(reader()); - done(); - return ret; + return Serializer::template deserialize(serializedData_.get()); } /** diff --git a/third-party/thrift/src/thrift/lib/cpp2/protocol/test/CursorBasedSerializerTest.cpp b/third-party/thrift/src/thrift/lib/cpp2/protocol/test/CursorBasedSerializerTest.cpp index ebaaf98ee62e2..487acbc78d3c4 100644 --- a/third-party/thrift/src/thrift/lib/cpp2/protocol/test/CursorBasedSerializerTest.cpp +++ b/third-party/thrift/src/thrift/lib/cpp2/protocol/test/CursorBasedSerializerTest.cpp @@ -16,7 +16,6 @@ #include -#include #include #include #include @@ -686,3 +685,18 @@ TEST(CursorBasedSerializer, CursorReadRemainingEndOne) { TEST(CursorBasedSerializer, CursorReadRemainingEndMany) { doCursorReadRemainEndTest(10); } + +TEST(CursorBasedSerializer, ConcurrentAccess) { + EmptyWrapper wrapper; + auto writer = wrapper.beginWrite(); + EXPECT_THROW(wrapper.beginRead(), std::runtime_error); + EXPECT_THROW(wrapper.deserialize(), std::runtime_error); + EXPECT_THROW(wrapper.beginWrite(), std::runtime_error); + wrapper.endWrite(std::move(writer)); + + auto reader = wrapper.beginRead(); + EXPECT_THROW(wrapper.beginRead(), std::runtime_error); + EXPECT_EQ(wrapper.deserialize(), Empty{}); + EXPECT_THROW(wrapper.beginWrite(), std::runtime_error); + wrapper.endRead(std::move(reader)); +}