diff --git a/include/rice/rice.hpp b/include/rice/rice.hpp index ddf79045..13fb2c89 100644 --- a/include/rice/rice.hpp +++ b/include/rice/rice.hpp @@ -1075,6 +1075,7 @@ namespace Rice::detail #include #include +#include namespace Rice::detail @@ -1094,9 +1095,9 @@ namespace Rice::detail private: size_t key(VALUE klass, ID method_id); - std::unordered_map natives_ = {}; + std::unordered_multimap> natives_ = {}; }; -} +} // --------- NativeRegistry.ipp --------- @@ -1114,19 +1115,30 @@ namespace Rice::detail // https://stackoverflow.com/a/2634715 inline size_t NativeRegistry::key(VALUE klass, ID id) { - if (rb_type(klass) == T_ICLASS) - { - klass = detail::protect(rb_class_of, klass); - } - uint32_t prime = 53; return (prime + klass) * prime + id; } inline void NativeRegistry::add(VALUE klass, ID method_id, std::any callable) { - // Now store data about it - this->natives_[key(klass, method_id)] = callable; + if (rb_type(klass) == T_ICLASS) + { + klass = detail::protect(rb_class_of, klass); + } + + auto range = this->natives_.equal_range(key(klass, method_id)); + for (auto it = range.first; it != range.second; ++it) + { + const auto [k, m, d] = it->second; + + if (k == klass && m == method_id) + { + std::get<2>(it->second) = callable; + return; + } + } + + this->natives_.emplace(std::make_pair(key(klass, method_id), std::make_tuple(klass, method_id, callable))); } template @@ -1145,14 +1157,23 @@ namespace Rice::detail template inline Return_T NativeRegistry::lookup(VALUE klass, ID method_id) { - auto iter = this->natives_.find(key(klass, method_id)); - if (iter == this->natives_.end()) + if (rb_type(klass) == T_ICLASS) { - rb_raise(rb_eRuntimeError, "Could not find data for klass and method id"); + klass = detail::protect(rb_class_of, klass); + } + + auto range = this->natives_.equal_range(key(klass, method_id)); + for (auto it = range.first; it != range.second; ++it) + { + const auto [k, m, d] = it->second; + + if (k == klass && m == method_id) + { + return std::any_cast(d); + } } - std::any data = iter->second; - return std::any_cast(data); + rb_raise(rb_eRuntimeError, "Could not find data for klass and method id"); } } diff --git a/rice/detail/NativeRegistry.hpp b/rice/detail/NativeRegistry.hpp index 6b0f3161..3c7b9575 100644 --- a/rice/detail/NativeRegistry.hpp +++ b/rice/detail/NativeRegistry.hpp @@ -3,6 +3,7 @@ #include #include +#include #include "ruby.hpp" @@ -23,9 +24,9 @@ namespace Rice::detail private: size_t key(VALUE klass, ID method_id); - std::unordered_map natives_ = {}; + std::unordered_multimap> natives_ = {}; }; -} +} #include "NativeRegistry.ipp" -#endif // Rice__detail__NativeRegistry__hpp \ No newline at end of file +#endif // Rice__detail__NativeRegistry__hpp diff --git a/rice/detail/NativeRegistry.ipp b/rice/detail/NativeRegistry.ipp index ff7b8413..d325f81b 100644 --- a/rice/detail/NativeRegistry.ipp +++ b/rice/detail/NativeRegistry.ipp @@ -14,19 +14,30 @@ namespace Rice::detail // https://stackoverflow.com/a/2634715 inline size_t NativeRegistry::key(VALUE klass, ID id) { - if (rb_type(klass) == T_ICLASS) - { - klass = detail::protect(rb_class_of, klass); - } - uint32_t prime = 53; return (prime + klass) * prime + id; } inline void NativeRegistry::add(VALUE klass, ID method_id, std::any callable) { - // Now store data about it - this->natives_[key(klass, method_id)] = callable; + if (rb_type(klass) == T_ICLASS) + { + klass = detail::protect(rb_class_of, klass); + } + + auto range = this->natives_.equal_range(key(klass, method_id)); + for (auto it = range.first; it != range.second; ++it) + { + const auto [k, m, d] = it->second; + + if (k == klass && m == method_id) + { + std::get<2>(it->second) = callable; + return; + } + } + + this->natives_.emplace(std::make_pair(key(klass, method_id), std::make_tuple(klass, method_id, callable))); } template @@ -45,13 +56,22 @@ namespace Rice::detail template inline Return_T NativeRegistry::lookup(VALUE klass, ID method_id) { - auto iter = this->natives_.find(key(klass, method_id)); - if (iter == this->natives_.end()) + if (rb_type(klass) == T_ICLASS) { - rb_raise(rb_eRuntimeError, "Could not find data for klass and method id"); + klass = detail::protect(rb_class_of, klass); + } + + auto range = this->natives_.equal_range(key(klass, method_id)); + for (auto it = range.first; it != range.second; ++it) + { + const auto [k, m, d] = it->second; + + if (k == klass && m == method_id) + { + return std::any_cast(d); + } } - std::any data = iter->second; - return std::any_cast(data); + rb_raise(rb_eRuntimeError, "Could not find data for klass and method id"); } } diff --git a/test/test_Native_Registry.cpp b/test/test_Native_Registry.cpp new file mode 100644 index 00000000..a4c892a0 --- /dev/null +++ b/test/test_Native_Registry.cpp @@ -0,0 +1,50 @@ +#include "unittest.hpp" +#include "embed_ruby.hpp" + +#include +#include + +using namespace Rice; + +TESTSUITE(NativeRegistry); + +SETUP(NativeRegistry) +{ + embed_ruby(); +} + +TESTCASE(collisions) +{ + std::array classes; + int scale = 1000; + + for (int i = 0; i < std::size(classes); i++) + { + Class cls(anonymous_class()); + + for (int j = 0; j < scale; j++) + { + cls.define_function("int" + std::to_string(j), []() { return 1; }); + cls.define_function("long" + std::to_string(j), []() { return 1L; }); + cls.define_function("double" + std::to_string(j), []() { return 1.0; }); + cls.define_function("float" + std::to_string(j), []() { return 1.0f; }); + cls.define_function("bool" + std::to_string(j), []() { return true; }); + } + + classes[i] = cls; + } + + for (auto& cls : classes) + { + auto obj = cls.call("new"); + + for (int j = 0; j < scale; j++) + { + obj.call("int" + std::to_string(j)); + obj.call("long" + std::to_string(j)); + obj.call("double" + std::to_string(j)); + obj.call("float" + std::to_string(j)); + obj.call("bool" + std::to_string(j)); + } + } +}