Skip to content

Commit

Permalink
clear allgather valid flag
Browse files Browse the repository at this point in the history
  • Loading branch information
tohtana committed Nov 24, 2024
1 parent f4cf606 commit d7fcd91
Showing 1 changed file with 13 additions and 1 deletion.
14 changes: 13 additions & 1 deletion csrc/compile/native_z3.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,11 @@ class DSParamRegistry {

// Returns a const reference to the parameter stored in params_ under `ds_id`.
// NOTE(review): an unknown id is handled by params_.at() itself — presumably it
// throws; confirm against the container type, which is not visible here.
const DSParam& getParam(long ds_id) const
{
    return params_.at(ds_id);
}
const size_t getNumParams() const { return params_.size(); }
// Returns a const reference to the tensor stored in gathered_params_ under `ds_id`.
// NOTE(review): no explicit presence check here; an unknown id is left to
// gathered_params_.at() — presumably it throws, confirm against the container type.
const at::Tensor& getGatheredParam(long ds_id) const
{
    return gathered_params_.at(ds_id);
}
// Returns a const reference to the tensor stored in gathered_params_ under `ds_id`.
// Callers are expected to ask only for ids that have an entry; this is checked
// with assert(), so the check disappears when NDEBUG is defined.
const at::Tensor& getGatheredParam(long ds_id) const
{
    assert(hasKey(gathered_params_, ds_id));

    const at::Tensor& gathered = gathered_params_.at(ds_id);
    return gathered;
}
// Reports whether gathered_params_ currently has an entry for `ds_id`.
bool hasGatheredParam(long ds_id) const
{
    return hasKey(gathered_params_, ds_id);
}
// Forwards the `persistent` flag to the parameter stored in params_ under `ds_id`.
// NOTE(review): an unknown id is handled by params_.at() — presumably it throws;
// confirm against the container type.
void setPersistent(long ds_id, bool persistent)
{
    auto& target = params_.at(ds_id);
    target.setPersistent(persistent);
}

Expand Down Expand Up @@ -627,6 +631,11 @@ class CustomOpExecutor {

// Reports whether reload_buffers_ currently has an entry for `id`.
bool hasReloadBuffer(long id)
{
    return hasKey(reload_buffers_, id);
}

// Clears the valid_ flag for `ds_id`, if one is being tracked.
// Ids without an entry in valid_ are left untouched, so no new entry is created.
void invalidateGatheredParam(long ds_id)
{
    if (!hasKey(valid_, ds_id)) { return; }

    valid_[ds_id] = false;
}

private:
c10::intrusive_ptr<c10d::ProcessGroup> process_group_;
std::shared_ptr<DSParamRegistry> param_registry_;
Expand Down Expand Up @@ -896,6 +905,8 @@ void invalidate_gathered_param(long ds_id)

param_registry->unregisterGatheredParam(ds_id);
param_registry->registerGatheredParam(ds_id, at::Tensor());

for (auto& it : executors) { it.second->invalidateGatheredParam(ds_id); }
}

void clear_all_gathered_params()
Expand All @@ -906,6 +917,7 @@ void clear_all_gathered_params()
if (param.isPersistent()) { continue; }
if (param_registry->hasGatheredParam(ds_id)) {
param_registry->unregisterGatheredParam(ds_id);
for (auto& it : executors) { it.second->invalidateGatheredParam(ds_id); }
}
}
}
Expand Down

0 comments on commit d7fcd91

Please sign in to comment.