Skip to content

Commit

Permalink
Review comments 2
Browse files Browse the repository at this point in the history
  • Loading branch information
IvanNovoselov committed Jan 2, 2025
1 parent 0b7ac60 commit 81f5a81
Show file tree
Hide file tree
Showing 8 changed files with 69 additions and 28 deletions.
4 changes: 0 additions & 4 deletions src/common/snippets/src/lowered/pass/init_live_ranges.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,10 +39,6 @@ bool InitLiveRanges::run(LinearIR& linear_ir) {
expr->set_live_regs(std::prev(expr_it)->get()->get_live_regs());
continue;
}

OPENVINO_ASSERT(expr->get_output_count() == op->get_output_size() ||
ov::is_type<op::LoopEnd>(op) ||
ov::is_type<ov::op::v0::Result>(op), "Incorrect count of output port descriptors!");
const double start = expr->get_exec_num();
// Remove all regs that expired before start
regs_to_expire.erase(regs_to_expire.begin(), regs_to_expire.lower_bound(start)); // remove all elements lower than start (not equal)
Expand Down
12 changes: 6 additions & 6 deletions src/common/snippets/src/lowered/pass/insert_reg_spills.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,14 +37,14 @@ bool InsertRegSpills::run(LinearIR& linear_ir) {
// Note: we need to insert immediately before LoopBegin => increment start_it
start_it++;
const auto& loop_begin_live = start_it->get()->get_live_regs();
std::set<Reg> brgemm_used;
const auto& brgemm_reg_info = expr->get_reg_info();
brgemm_used.insert(brgemm_reg_info.first.begin(), brgemm_reg_info.first.end());
brgemm_used.insert(brgemm_reg_info.second.begin(), brgemm_reg_info.second.end());
// Note: before the loop, we need to spill all live regs except for the ones used by brgemm
std::set<Reg> used;
const auto& reg_info = expr->get_reg_info();
used.insert(reg_info.first.begin(), reg_info.first.end());
used.insert(reg_info.second.begin(), reg_info.second.end());
// Note: before the loop, we need to spill all live regs except for the ones used by the target expression
std::set<Reg> regs_to_spill;
std::set_difference(loop_begin_live.begin(), loop_begin_live.end(),
brgemm_used.begin(), brgemm_used.end(),
used.begin(), used.end(),
std::inserter(regs_to_spill, regs_to_spill.begin()));
// we also need to keep kernel regs alive (actually only abi_param_1 is used in emitters, but save all for consistency)
for (const auto& r : m_reg_manager.get_kernel_call_regs( snippets::op::Kernel::make_kernel(linear_ir.is_dynamic())))
Expand Down
10 changes: 9 additions & 1 deletion src/common/snippets/src/lowered/pass/validate.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -97,8 +97,13 @@ void validate_buffer(const ExpressionPtr& expr, const LinearIR& linear_ir) {
void validate_loop_end(const ExpressionPtr& expr, const LinearIR& linear_ir) {
const auto loop_end = ov::as_type_ptr<op::LoopEnd>(expr->get_node());
OPENVINO_ASSERT(loop_end, "LoopEnd validation expects LoopEnd op");
OPENVINO_ASSERT(loop_end->get_loop_begin() != nullptr,
const auto& loop_begin = loop_end->get_loop_begin();
OPENVINO_ASSERT(loop_begin != nullptr,
"LoopEnd must be connected to the LoopBegin");
const auto num_inputs = expr->get_input_count();
OPENVINO_ASSERT(num_inputs >= 1, "LoopEnd expression must have at least 1 input");
OPENVINO_ASSERT(expr->get_input_port_connector(num_inputs - 1)->get_source().get_expr()->get_node() == loop_begin,
"LopEnd expression must have LoopBegin attached to the last connector");

const auto& loop_manager = linear_ir.get_loop_manager();
const auto& loop_info = loop_manager->get_loop_info<UnifiedLoopInfo>(loop_end->get_id());
Expand Down Expand Up @@ -148,6 +153,9 @@ bool Validate::run(LinearIR& linear_ir, lowered::LinearIR::constExprIt begin, lo
if (found != m_validation_map.cend()) {
(found->second)(expr, linear_ir);
}
OPENVINO_ASSERT(expr->get_output_count() == node->get_output_size() ||
ov::is_type<op::LoopEnd>(node) ||
ov::is_type<ov::op::v0::Result>(node), "Incorrect count of output port descriptors!");
expr->validate();
// Loop expr doesn't have shapes and layouts
if (!ov::is_type<op::LoopBase>(node))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ jit_brgemm_copy_b_emitter::jit_brgemm_copy_b_emitter(jit_generator* h,
m_memory_offsets.push_back(brgemm_repack->get_offset_compensations());
m_buffer_ids.push_back(utils::get_buffer_cluster_id(expr->get_output_port(1)));
}
m_live_regs = expr->get_live_regs();
}

void jit_brgemm_copy_b_emitter::validate_arguments(const std::vector<size_t>& in,
Expand All @@ -81,31 +82,38 @@ void jit_brgemm_copy_b_emitter::emit_impl(const std::vector<size_t>& in, const s
if (out.size() > 1)
mem_ptrs_idxs.emplace_back(out[1]);

std::set<snippets::Reg> regs_to_spill = m_live_regs;
// Note: these 3 registers will be corrupted by the caller during the ABI call
regs_to_spill.emplace(snippets::RegType::gpr, abi_param1.getIdx());
regs_to_spill.emplace(snippets::RegType::gpr, abi_param2.getIdx());
regs_to_spill.emplace(snippets::RegType::gpr, h->rbp.getIdx());
// Note: abi_param_1 is a default invalid value to check later that the aux reg was allocated properly
Xbyak::Reg64 aux_reg = abi_param1;
utils::init_memory_access_aux_gpr(mem_ptrs_idxs, m_memory_offsets, aux_gpr_idxs, regs_to_spill, aux_reg);

EmitABIRegSpills spill(h);
spill.preamble();
spill.preamble(regs_to_spill);

h->mov(h->rbp, reinterpret_cast<uint64_t>(BrgemmCopyBKernelExecutor::execute));
auto reserved_stack_size = sizeof(BrgemmCopyBKernel::call_args);
// Reserve memory on the stack
h->sub(h->rsp, reserved_stack_size);

const bool is_dynamic_case =
std::any_of(m_memory_offsets.cbegin(), m_memory_offsets.cend(), ov::snippets::utils::is_dynamic_value<size_t>);
Xbyak::Reg64 aux_reg = is_dynamic_case ? ov::intel_cpu::utils::get_aux_gpr(mem_ptrs_idxs) : Xbyak::Reg64();

const std::vector<size_t> args_offsets{GET_OFF_BRGEMM_COPY_B_ARGS(src),
GET_OFF_BRGEMM_COPY_B_ARGS(tr_src),
GET_OFF_BRGEMM_COPY_B_ARGS(compensation_ptr)};
const auto& mem_ptrs = ov::intel_cpu::utils::transform_idxs_to_regs(mem_ptrs_idxs);
for (size_t i = 0; i < mem_ptrs.size(); i++) {
if (ov::snippets::utils::is_dynamic_value(m_memory_offsets[i]))
if (ov::snippets::utils::is_dynamic_value(m_memory_offsets[i])) {
OV_CPU_JIT_EMITTER_ASSERT(aux_reg != abi_param1, "Aux reg is needed, but wasn't allocated");
utils::push_ptr_with_runtime_offset_on_stack(h,
args_offsets[i],
mem_ptrs[i],
aux_reg,
GET_OFF(buffer_offsets) + m_buffer_ids[i] * sizeof(size_t));
else
} else {
utils::push_ptr_with_static_offset_on_stack(h, args_offsets[i], mem_ptrs[i], m_memory_offsets[i]);
}
}

// No scratchpad => need to write nullptr manually
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ class jit_brgemm_copy_b_emitter : public jit_emitter {
std::vector<size_t> m_memory_offsets{};
std::vector<size_t> m_buffer_ids{};
std::shared_ptr<BrgemmCopyBKernelExecutor> m_kernel_executor{nullptr};
std::set<snippets::Reg> m_live_regs{};
bool m_with_comp{false};

#ifdef SNIPPETS_DEBUG_CAPS
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -112,18 +112,9 @@ void jit_brgemm_emitter::emit_call(const std::vector<size_t>& mem_ptrs_idxs) con
regs_to_spill.emplace(snippets::RegType::gpr, abi_param1.getIdx());
regs_to_spill.emplace(snippets::RegType::gpr, abi_param2.getIdx());
regs_to_spill.emplace(snippets::RegType::gpr, h->rbp.getIdx());
const bool is_dynamic_case =
std::any_of(m_memory_offsets.cbegin(), m_memory_offsets.cend(), ov::snippets::utils::is_dynamic_value<size_t>);
// Note: abi_param_1 is a default invalid value to check later that the aux reg was allocated properly
Xbyak::Reg64 aux_reg = abi_param1;
if (std::is_same<T, BrgemmAMXKernelExecutor>() || is_dynamic_case) {
if (!aux_gpr_idxs.empty()) {
aux_reg = Xbyak::Reg64(static_cast<int>(aux_gpr_idxs[0]));
} else {
aux_reg = ov::intel_cpu::utils::get_aux_gpr(mem_ptrs_idxs);
regs_to_spill.emplace(snippets::RegType::gpr, aux_reg.getIdx());
}
}
utils::init_memory_access_aux_gpr(mem_ptrs_idxs, m_memory_offsets, aux_gpr_idxs, regs_to_spill, aux_reg);
EmitABIRegSpills spill(h);
spill.preamble(regs_to_spill);

Expand Down
19 changes: 19 additions & 0 deletions src/plugins/intel_cpu/src/emitters/snippets/x64/utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,25 @@ Xbyak::Reg64 get_aux_gpr(const std::vector<size_t>& used_gpr_idxs) {
OV_CPU_JIT_EMITTER_THROW("Failed to allocate aux GPR");
}

void init_memory_access_aux_gpr(const std::vector<size_t>& mem_ptr_reg_idxs,
const std::vector<size_t>& memory_offsets,
const std::vector<size_t>& aux_gpr_idxs,
std::set<snippets::Reg>& regs_to_spill,
Xbyak::Reg64& aux_reg) {
const bool is_dynamic_case =
std::any_of(memory_offsets.cbegin(), memory_offsets.cend(),
ov::snippets::utils::is_dynamic_value<size_t>);
// Note: abi_param_1 is a default invalid value to check later that the aux reg was allocated properly
if (is_dynamic_case) {
if (!aux_gpr_idxs.empty()) {
aux_reg = Xbyak::Reg64(static_cast<int>(aux_gpr_idxs[0]));
} else {
aux_reg = ov::intel_cpu::utils::get_aux_gpr(mem_ptr_reg_idxs);
regs_to_spill.emplace(snippets::RegType::gpr, aux_reg.getIdx());
}
}
};

void push_ptr_with_runtime_offset_on_stack(dnnl::impl::cpu::x64::jit_generator* h,
size_t stack_offset,
Xbyak::Reg64 ptr_reg,
Expand Down
18 changes: 18 additions & 0 deletions src/plugins/intel_cpu/src/emitters/snippets/x64/utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,22 @@ size_t get_buffer_cluster_id(const ov::snippets::lowered::ExpressionPort& port);
*/
Xbyak::Reg64 get_aux_gpr(const std::vector<size_t>& used_gpr_idxs);

/**
* @brief Initializes aux gpr register for dynamic memory access emitters. If any of the `memory_offsets` is dynamic,
* then try to assign `aux_reg` a register from `aux_gpr_idxs`. If `aux_gpr_idxs` is empty, then choose a register that
* is not in `mem_ptr_reg_idxs` and add it to `regs_to_spill`.
* @param mem_ptr_reg_idxs register indexes reserved to store memory pointers in this emitter
* @param memory_offsets memory offsets, could be dynamic
* @param aux_gpr_idxs pool of available gp register indexes
* @param regs_to_spill set of live registers to be spilled before ABI call
* @param aux_reg auxiliary register that should be initialized
*/
void init_memory_access_aux_gpr(const std::vector<size_t>& mem_ptr_reg_idxs,
const std::vector<size_t>& memory_offsets,
const std::vector<size_t>& aux_gpr_idxs,
std::set<snippets::Reg>& regs_to_spill,
Xbyak::Reg64& aux_reg);

/**
* @brief Push data pointer on stack adding offset. The offset is taken from runtime params `abi_param1`
* @param h generator
Expand All @@ -60,6 +76,8 @@ void push_ptr_with_static_offset_on_stack(dnnl::impl::cpu::x64::jit_generator* h
Xbyak::Reg64 ptr_reg,
size_t ptr_offset);



} // namespace utils
} // namespace intel_cpu
} // namespace ov

0 comments on commit 81f5a81

Please sign in to comment.