From d7fcd910ca5b1bdaeeddc0458a5131f107d00372 Mon Sep 17 00:00:00 2001 From: Masahiro Tanaka Date: Sun, 24 Nov 2024 06:06:07 +0000 Subject: [PATCH] clear allgather valid flag --- csrc/compile/native_z3.cpp | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/csrc/compile/native_z3.cpp b/csrc/compile/native_z3.cpp index 4214ea80fab2..4c38dccdd22a 100644 --- a/csrc/compile/native_z3.cpp +++ b/csrc/compile/native_z3.cpp @@ -137,7 +137,11 @@ class DSParamRegistry { const DSParam& getParam(long ds_id) const { return params_.at(ds_id); } const size_t getNumParams() const { return params_.size(); } - const at::Tensor& getGatheredParam(long ds_id) const { return gathered_params_.at(ds_id); } + const at::Tensor& getGatheredParam(long ds_id) const + { + assert(hasKey(gathered_params_, ds_id)); + return gathered_params_.at(ds_id); + } bool hasGatheredParam(long ds_id) const { return hasKey(gathered_params_, ds_id); } void setPersistent(long ds_id, bool persistent) { params_.at(ds_id).setPersistent(persistent); } @@ -627,6 +631,11 @@ class CustomOpExecutor { bool hasReloadBuffer(long id) { return hasKey(reload_buffers_, id); } + void invalidateGatheredParam(long ds_id) + { + if (hasKey(valid_, ds_id)) { valid_[ds_id] = false; } + } + private: c10::intrusive_ptr process_group_; std::shared_ptr param_registry_; @@ -896,6 +905,8 @@ void invalidate_gathered_param(long ds_id) param_registry->unregisterGatheredParam(ds_id); param_registry->registerGatheredParam(ds_id, at::Tensor()); + + for (auto& it : executors) { it.second->invalidateGatheredParam(ds_id); } } void clear_all_gathered_params() @@ -906,6 +917,7 @@ void clear_all_gathered_params() if (param.isPersistent()) { continue; } if (param_registry->hasGatheredParam(ds_id)) { param_registry->unregisterGatheredParam(ds_id); + for (auto& it : executors) { it.second->invalidateGatheredParam(ds_id); } } } }