Skip to content

Commit

Permalink
Merge pull request #208 from ankane/bad-any-cast2
Browse files Browse the repository at this point in the history
Fix `bad any cast` errors caused by hash collisions
  • Loading branch information
jasonroelofs authored Oct 7, 2024
2 parents 245ab01 + 6bea811 commit a82515b
Show file tree
Hide file tree
Showing 4 changed files with 121 additions and 29 deletions.
49 changes: 35 additions & 14 deletions include/rice/rice.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -1075,6 +1075,7 @@ namespace Rice::detail

#include <unordered_map>
#include <any>
#include <tuple>


namespace Rice::detail
Expand All @@ -1094,9 +1095,9 @@ namespace Rice::detail

private:
size_t key(VALUE klass, ID method_id);
std::unordered_map<size_t, std::any> natives_ = {};
std::unordered_multimap<size_t, std::tuple<VALUE, ID, std::any>> natives_ = {};
};
}
}

// --------- NativeRegistry.ipp ---------

Expand All @@ -1114,19 +1115,30 @@ namespace Rice::detail
// https://stackoverflow.com/a/2634715
inline size_t NativeRegistry::key(VALUE klass, ID id)
{
if (rb_type(klass) == T_ICLASS)
{
klass = detail::protect(rb_class_of, klass);
}

uint32_t prime = 53;
return (prime + klass) * prime + id;
}

inline void NativeRegistry::add(VALUE klass, ID method_id, std::any callable)
{
// Now store data about it
this->natives_[key(klass, method_id)] = callable;
if (rb_type(klass) == T_ICLASS)
{
klass = detail::protect(rb_class_of, klass);
}

auto range = this->natives_.equal_range(key(klass, method_id));
for (auto it = range.first; it != range.second; ++it)
{
const auto [k, m, d] = it->second;

if (k == klass && m == method_id)
{
std::get<2>(it->second) = callable;
return;
}
}

this->natives_.emplace(std::make_pair(key(klass, method_id), std::make_tuple(klass, method_id, callable)));
}

template <typename Return_T>
Expand All @@ -1145,14 +1157,23 @@ namespace Rice::detail
template <typename Return_T>
inline Return_T NativeRegistry::lookup(VALUE klass, ID method_id)
{
auto iter = this->natives_.find(key(klass, method_id));
if (iter == this->natives_.end())
if (rb_type(klass) == T_ICLASS)
{
rb_raise(rb_eRuntimeError, "Could not find data for klass and method id");
klass = detail::protect(rb_class_of, klass);
}

auto range = this->natives_.equal_range(key(klass, method_id));
for (auto it = range.first; it != range.second; ++it)
{
const auto [k, m, d] = it->second;

if (k == klass && m == method_id)
{
return std::any_cast<Return_T>(d);
}
}

std::any data = iter->second;
return std::any_cast<Return_T>(data);
rb_raise(rb_eRuntimeError, "Could not find data for klass and method id");
}
}

Expand Down
7 changes: 4 additions & 3 deletions rice/detail/NativeRegistry.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

#include <unordered_map>
#include <any>
#include <tuple>

#include "ruby.hpp"

Expand All @@ -23,9 +24,9 @@ namespace Rice::detail

private:
size_t key(VALUE klass, ID method_id);
std::unordered_map<size_t, std::any> natives_ = {};
std::unordered_multimap<size_t, std::tuple<VALUE, ID, std::any>> natives_ = {};
};
}
}
#include "NativeRegistry.ipp"

#endif // Rice__detail__NativeRegistry__hpp
#endif // Rice__detail__NativeRegistry__hpp
44 changes: 32 additions & 12 deletions rice/detail/NativeRegistry.ipp
Original file line number Diff line number Diff line change
Expand Up @@ -14,19 +14,30 @@ namespace Rice::detail
// https://stackoverflow.com/a/2634715
inline size_t NativeRegistry::key(VALUE klass, ID id)
{
if (rb_type(klass) == T_ICLASS)
{
klass = detail::protect(rb_class_of, klass);
}

uint32_t prime = 53;
return (prime + klass) * prime + id;
}

inline void NativeRegistry::add(VALUE klass, ID method_id, std::any callable)
{
// Now store data about it
this->natives_[key(klass, method_id)] = callable;
if (rb_type(klass) == T_ICLASS)
{
klass = detail::protect(rb_class_of, klass);
}

auto range = this->natives_.equal_range(key(klass, method_id));
for (auto it = range.first; it != range.second; ++it)
{
const auto [k, m, d] = it->second;

if (k == klass && m == method_id)
{
std::get<2>(it->second) = callable;
return;
}
}

this->natives_.emplace(std::make_pair(key(klass, method_id), std::make_tuple(klass, method_id, callable)));
}

template <typename Return_T>
Expand All @@ -45,13 +56,22 @@ namespace Rice::detail
template <typename Return_T>
inline Return_T NativeRegistry::lookup(VALUE klass, ID method_id)
{
auto iter = this->natives_.find(key(klass, method_id));
if (iter == this->natives_.end())
if (rb_type(klass) == T_ICLASS)
{
rb_raise(rb_eRuntimeError, "Could not find data for klass and method id");
klass = detail::protect(rb_class_of, klass);
}

auto range = this->natives_.equal_range(key(klass, method_id));
for (auto it = range.first; it != range.second; ++it)
{
const auto [k, m, d] = it->second;

if (k == klass && m == method_id)
{
return std::any_cast<Return_T>(d);
}
}

std::any data = iter->second;
return std::any_cast<Return_T>(data);
rb_raise(rb_eRuntimeError, "Could not find data for klass and method id");
}
}
50 changes: 50 additions & 0 deletions test/test_Native_Registry.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
#include "unittest.hpp"
#include "embed_ruby.hpp"

#include <rice/rice.hpp>
#include <rice/stl.hpp>

using namespace Rice;

TESTSUITE(NativeRegistry);

SETUP(NativeRegistry)
{
embed_ruby();
}

TESTCASE(collisions)
{
std::array<Class, 100> classes;
int scale = 1000;

for (int i = 0; i < std::size(classes); i++)
{
Class cls(anonymous_class());

for (int j = 0; j < scale; j++)
{
cls.define_function("int" + std::to_string(j), []() { return 1; });
cls.define_function("long" + std::to_string(j), []() { return 1L; });
cls.define_function("double" + std::to_string(j), []() { return 1.0; });
cls.define_function("float" + std::to_string(j), []() { return 1.0f; });
cls.define_function("bool" + std::to_string(j), []() { return true; });
}

classes[i] = cls;
}

for (auto& cls : classes)
{
auto obj = cls.call("new");

for (int j = 0; j < scale; j++)
{
obj.call("int" + std::to_string(j));
obj.call("long" + std::to_string(j));
obj.call("double" + std::to_string(j));
obj.call("float" + std::to_string(j));
obj.call("bool" + std::to_string(j));
}
}
}

0 comments on commit a82515b

Please sign in to comment.