Skip to content

Commit

Permalink
Fix std::smart_ptr handling. See #203.
Browse files Browse the repository at this point in the history
  • Loading branch information
cfis committed Feb 25, 2024
1 parent 0129b88 commit f6baa4b
Show file tree
Hide file tree
Showing 4 changed files with 86 additions and 5 deletions.
12 changes: 11 additions & 1 deletion rice/detail/Wrapper.ipp
Original file line number Diff line number Diff line change
Expand Up @@ -99,16 +99,26 @@ namespace Rice::detail

Wrapper* wrapper = nullptr;

if constexpr (!std::is_void_v<Wrapper_T>)
// Is this a pointer but cannot be copied? For example a std::unique_ptr
if constexpr (!std::is_void_v<Wrapper_T> && !std::is_copy_constructible_v<Wrapper_T>)
{
wrapper = new Wrapper_T(std::move(data));
result = TypedData_Wrap_Struct(klass, rb_type, wrapper);
}
// Is this a pointer or smart pointer like std::shared_ptr
else if constexpr (!std::is_void_v<Wrapper_T>)
{
wrapper = new Wrapper_T(data);
result = TypedData_Wrap_Struct(klass, rb_type, wrapper);
}
// Is this a pointer and it cannot copied? This is for std::unique_ptr
// If ruby is the owner than copy the object
else if (isOwner)
{
wrapper = new WrapperValue<T>(data);
result = TypedData_Wrap_Struct(klass, rb_type, wrapper);
}
// Ruby is not the owner so just wrap the reference
else
{
wrapper = new WrapperReference<T>(data);
Expand Down
2 changes: 1 addition & 1 deletion rice/stl/smart_ptr.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ namespace Rice::detail
class WrapperSmartPointer : public Wrapper
{
public:
WrapperSmartPointer(SmartPointer_T<Arg_Ts...>& data);
WrapperSmartPointer(SmartPointer_T<Arg_Ts...> data);
~WrapperSmartPointer();
void* get() override;
SmartPointer_T<Arg_Ts...>& data();
Expand Down
30 changes: 29 additions & 1 deletion rice/stl/smart_ptr.ipp
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ namespace Rice::detail
{
// ---- WrapperSmartPointer ------
template <template <typename, typename...> typename SmartPointer_T, typename...Arg_Ts>
inline WrapperSmartPointer<SmartPointer_T, Arg_Ts...>::WrapperSmartPointer(SmartPointer_T<Arg_Ts...>& data)
inline WrapperSmartPointer<SmartPointer_T, Arg_Ts...>::WrapperSmartPointer(SmartPointer_T<Arg_Ts...> data)
: data_(std::move(data))
{
}
Expand Down Expand Up @@ -46,6 +46,20 @@ namespace Rice::detail
}
};

template <typename T>
class To_Ruby<std::unique_ptr<T>&>
{
public:
VALUE convert(std::unique_ptr<T>& data)
{
std::pair<VALUE, rb_data_type_t*> rubyTypeInfo = detail::Registries::instance.types.figureType<T>(*data);

// Use custom wrapper type
using Wrapper_T = WrapperSmartPointer<std::unique_ptr, T>;
return detail::wrap<std::unique_ptr<T>, Wrapper_T>(rubyTypeInfo.first, rubyTypeInfo.second, data, true);
}
};

template <typename T>
class From_Ruby<std::unique_ptr<T>&>
{
Expand Down Expand Up @@ -121,6 +135,20 @@ namespace Rice::detail
Arg* arg_ = nullptr;
};

template <typename T>
class To_Ruby<std::shared_ptr<T>&>
{
public:
VALUE convert(std::shared_ptr<T>& data)
{
std::pair<VALUE, rb_data_type_t*> rubyTypeInfo = detail::Registries::instance.types.figureType<T>(*data);

// Use custom wrapper type
using Wrapper_T = WrapperSmartPointer<std::shared_ptr, T>;
return detail::wrap<std::shared_ptr<T>, Wrapper_T>(rubyTypeInfo.first, rubyTypeInfo.second, data, true);
}
};

template <typename T>
class From_Ruby<std::shared_ptr<T>&>
{
Expand Down
47 changes: 45 additions & 2 deletions test/test_Stl_SmartPointer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,15 @@ namespace
return instance_;
}

std::shared_ptr<MyClass>& share_ref()
{
if (!instance_)
{
instance_ = std::make_shared<MyClass>();
}
return instance_;
}

public:
static inline std::shared_ptr<MyClass> instance_;
};
Expand Down Expand Up @@ -104,7 +113,8 @@ SETUP(SmartPointer)
define_class<Factory>("Factory").
define_constructor(Constructor<Factory>()).
define_method("transfer", &Factory::transfer).
define_method("share", &Factory::share);
define_method("share", &Factory::share).
define_method("share_ref", &Factory::share_ref);

define_global_function("extract_flag_unique_ptr_ref", &extractFlagUniquePtrRef);
define_global_function("extract_flag_shared_ptr", &extractFlagSharedPtr);
Expand Down Expand Up @@ -152,8 +162,13 @@ TESTCASE(ShareOwnership)
my_class.set_flag(i)
end)";


ASSERT_EQUAL(0, Factory::instance_.use_count());
m.module_eval(code);

ASSERT_EQUAL(11, Factory::instance_.use_count());
rb_gc_start();
ASSERT_EQUAL(1, Factory::instance_.use_count());

ASSERT_EQUAL(1, MyClass::constructorCalls);
ASSERT_EQUAL(0, MyClass::copyConstructorCalls);
Expand All @@ -162,6 +177,34 @@ TESTCASE(ShareOwnership)
ASSERT_EQUAL(9, Factory::instance_->flag);
}


TESTCASE(ShareOwnership2)
{
MyClass::reset();

Module m = define_module("TestingModule");

// Create ruby objects that point to the same instance of MyClass
std::string code = R"(factory = Factory.new
10.times do |i|
my_class = factory.share
my_class.set_flag(i)
end)";

Factory factory;
std::shared_ptr<MyClass> myClass = factory.share();
ASSERT_EQUAL(2, Factory::instance_.use_count());

// Call some ruby code
Data_Object<Factory> wrapper(factory);
ASSERT_EQUAL(2, Factory::instance_.use_count());
wrapper.instance_eval("self.share_ref.set_flag(1)");

ASSERT_EQUAL(3, Factory::instance_.use_count());
rb_gc_start();
ASSERT_EQUAL(2, Factory::instance_.use_count());
}

TESTCASE(UniquePtrRefParameter)
{
MyClass::reset();
Expand Down Expand Up @@ -237,4 +280,4 @@ TESTCASE(SharedPtrRefDefaultParameter)
Object result = m.module_eval(code);
// The default value kicks in and ignores any previous pointer
ASSERT_EQUAL(0, detail::From_Ruby<int>().convert(result));
}
}

0 comments on commit f6baa4b

Please sign in to comment.