diff --git a/xla/python/BUILD b/xla/python/BUILD index 5387b8986ad0d..1029606e0fa6e 100644 --- a/xla/python/BUILD +++ b/xla/python/BUILD @@ -1112,6 +1112,7 @@ cc_library( # placeholder for index annotation deps "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/cleanup", + "@com_google_absl//absl/hash", "@com_google_absl//absl/strings", "@com_google_absl//absl/synchronization", "@nanobind", diff --git a/xla/python/weakref_lru_cache.cc b/xla/python/weakref_lru_cache.cc index a6fd73bc645b0..ade03a916864f 100644 --- a/xla/python/weakref_lru_cache.cc +++ b/xla/python/weakref_lru_cache.cc @@ -28,6 +28,7 @@ limitations under the License. #include "absl/base/thread_annotations.h" #include "absl/cleanup/cleanup.h" +#include "absl/hash/hash.h" #include "absl/strings/str_cat.h" #include "absl/synchronization/mutex.h" #include "absl/synchronization/notification.h" @@ -78,34 +79,58 @@ class HashablePyDictIter { nb::detail::dict_iterator& iter_; }; +struct HashableKey { + nb::object context; + nb::args args; + nb::kwargs kwargs; + + template + friend H AbslHashValue(H h, const HashableKey& key) { + // Note: Despite the fact this is an ABSL hash function, it's safe to call + // functions that may throw exceptions such as nb::hash(), because it is + // used by an LRUCache, which uses a std::unordered_map, which is + // exception-safe. + h = H::combine(std::move(h), nb::hash(key.context), nb::hash(key.args)); + nb::detail::dict_iterator begin = key.kwargs.begin(); + nb::detail::dict_iterator end = key.kwargs.end(); + h = H::combine_unordered(std::move(h), HashablePyDictIter(begin), + HashablePyDictIter(end)); + h = H::combine(std::move(h), key.kwargs.size()); + return h; + } +}; + } // namespace class WeakrefLRUCache : public std::enable_shared_from_this { public: - struct Key { - nb::object context; - nb::args args; - nb::kwargs kwargs; + class Key { + public: + Key(nb::object context, nb::args args, nb::kwargs kwargs) + : context_(std::move(context)), + args_(std::move(args)), + kwargs_(std::move(kwargs)), + cached_hash_(absl::HashOf(HashableKey{context_, args_, kwargs_})) {} bool operator==(const Key& other) const { - return context.equal(other.context) && args.equal(other.args) && - kwargs.equal(other.kwargs); + return context_.equal(other.context_) && args_.equal(other.args_) && + kwargs_.equal(other.kwargs_); } template friend H AbslHashValue(H h, const Key& key) { - // Note: Despite the fact this is an ABSL hash function, it's safe to call - // functions that may throw exceptions such as nb::hash(), because it is - // used by an LRUCache, which uses a std::unordered_map, which is - // exception-safe. - h = H::combine(std::move(h), nb::hash(key.context), nb::hash(key.args)); - nb::detail::dict_iterator begin = key.kwargs.begin(); - nb::detail::dict_iterator end = key.kwargs.end(); - h = H::combine_unordered(std::move(h), HashablePyDictIter(begin), - HashablePyDictIter(end)); - h = H::combine(std::move(h), key.kwargs.size()); - return h; + return H::combine(std::move(h), key.cached_hash_); } + + nb::object context() const { return context_; } + nb::args args() const { return args_; } + nb::kwargs kwargs() const { return kwargs_; } + + private: + nb::object context_; + nb::args args_; + nb::kwargs kwargs_; + size_t cached_hash_; }; struct CacheEntry { @@ -123,14 +148,13 @@ class WeakrefLRUCache : public std::enable_shared_from_this { }; struct WeakrefCacheKey { - nb::handle object; + nb::weakref ref; size_t cached_hash; }; using Cache = xla::LRUCache>; struct WeakrefCacheValue { - std::optional weakref; std::shared_ptr cache; }; @@ -141,7 +165,7 @@ class WeakrefLRUCache : public std::enable_shared_from_this { struct WeakrefKeyEq { bool operator()(const WeakrefCacheKey& lhs, const WeakrefCacheKey& rhs) const { - return lhs.object.equal(rhs.object); + return lhs.ref.equal(rhs.ref); } }; @@ -150,21 +174,39 @@ class WeakrefLRUCache : public std::enable_shared_from_this { : cache_context_fn_(cache_context_fn), fn_(fn), lru_list_(maxsize) {} std::shared_ptr GetCache(WeakrefCacheKey key) { - auto [it, inserted] = entries_.emplace(key, WeakrefCacheValue()); - if (!inserted) { - return it->second.cache; + WeakrefCacheValue& value = entries_[key]; + if (!value.cache) { + value.cache = std::make_shared(&lru_list_); } + return value.cache; + } - auto& value = it->second; + nb::object Call(nb::object weakref_key, nb::args args, + nb::kwargs kwargs) ABSL_NO_THREAD_SAFETY_ANALYSIS { + nb::object context = cache_context_fn_(); + + // We precompute all of the hash values needed by the various maps rather + // than computing them during the std::unordered_map insertions. At the very + // least, MSVC's std::unordered_map has undefined behavior if the hash + // function throws an exception + // (https://learn.microsoft.com/en-us/cpp/standard-library/unordered-map-class?view=msvc-170#emplace). + Key key(context, args, kwargs); + size_t wrcache_hash = static_cast(nb::hash(weakref_key)); + + // No hash computations after this point. - value.cache = std::make_shared(&lru_list_); auto weakref_gc_callback = nb::cpp_function( - [this_weak = weak_from_this(), key](nb::handle weakref) { + [this_weak = weak_from_this(), wrcache_hash](nb::handle weakref) { auto cache = this_weak.lock(); if (cache == nullptr) { return; } - auto it = cache->entries_.find(key); + // The object the reference referred to is now in the process of being + // destroyed, so we cannot refer to its contents. Python weakref + // objects compare based on identity if the object they refer to is + // gone, so the hash lookup will work fine. + auto it = cache->entries_.find( + WeakrefCacheKey{nb::borrow(weakref), wrcache_hash}); if (it == cache->entries_.end()) { return; } @@ -172,21 +214,9 @@ class WeakrefLRUCache : public std::enable_shared_from_this { auto tmp = std::move(it->second); cache->entries_.erase(it); }); - PyObject* ref = - PyWeakref_NewRef(key.object.ptr(), weakref_gc_callback.ptr()); - if (!ref) { - entries_.erase(it); - throw nb::python_error(); - } - value.weakref = nb::steal(ref); - return value.cache; - } - - nb::object Call(nb::object weakref_key, nb::args args, - nb::kwargs kwargs) ABSL_NO_THREAD_SAFETY_ANALYSIS { - nb::object context = cache_context_fn_(); - std::shared_ptr cache_ptr = GetCache(WeakrefCacheKey{ - weakref_key, static_cast(nb::hash(weakref_key))}); + nb::weakref weakref = nb::weakref(weakref_key, weakref_gc_callback); + WeakrefCacheKey wrcache_key{weakref, wrcache_hash}; + std::shared_ptr cache_ptr = GetCache(wrcache_key); Cache& cache = *cache_ptr; ++total_queries_; @@ -206,7 +236,6 @@ class WeakrefLRUCache : public std::enable_shared_from_this { // released if that happens. absl::Cleanup unlock = [this]() ABSL_UNLOCK_FUNCTION(mu_) { mu_.Unlock(); }; - Key key{context, args, kwargs}; entry = cache.GetOrCreateIfAbsent(key, [&inserted](const Key& key) { inserted = true; return std::make_shared(); @@ -245,8 +274,8 @@ class WeakrefLRUCache : public std::enable_shared_from_this { for (const auto& wr_entry : entries_) { for (const auto& rest : *wr_entry.second.cache) { nb::tuple result = - nb::make_tuple(*wr_entry.second.weakref, rest.first.context, - rest.first.args, rest.first.kwargs); + nb::make_tuple(*wr_entry.first.ref, rest.first.context(), + rest.first.args(), rest.first.kwargs()); results.push_back(std::move(result)); } } diff --git a/xla/python/weakref_lru_cache_test.py b/xla/python/weakref_lru_cache_test.py index 0376cf1d3690d..92aa783d6b52a 100644 --- a/xla/python/weakref_lru_cache_test.py +++ b/xla/python/weakref_lru_cache_test.py @@ -160,6 +160,29 @@ class WRKey: "WeakrefLRUCache(hits=5, misses=10, maxsize=2048, currsize=10)", ) + def testGCKeys(self): + class WRKey: + + def __init__(self, x): + self.x = x + + def __eq__(self, other): + return self.x == other.x + + def __hash__(self): + return hash(self.x) + + cache = xla_client.weakref_lru_cache(lambda: None, lambda x, y: y, 2048) + keys = [WRKey(i) for i in range(10)] + for i in range(10): + cache(keys[i], i) + + # Delete some keys, to exercise the weakref callback behavior. + del keys[::2] + + for key in keys: + cache(key, 7) + if __name__ == "__main__": absltest.main()