From ad121dd5c4f1951470a445051238c3e805010e70 Mon Sep 17 00:00:00 2001 From: Egor Duplensky Date: Tue, 6 Aug 2024 20:53:32 +0200 Subject: [PATCH] [CPU] Avoid storing extra copies of constant inputs --- src/plugins/intel_cpu/src/cpu_memory.cpp | 34 +- src/plugins/intel_cpu/src/cpu_memory.h | 6 + src/plugins/intel_cpu/src/graph.cpp | 96 +++++ src/plugins/intel_cpu/src/graph.h | 2 + src/plugins/intel_cpu/src/node.cpp | 15 +- src/plugins/intel_cpu/src/node.h | 22 +- .../src/nodes/common/has_subnormals.cpp | 270 ++++++++++++++ .../src/nodes/common/has_subnormals.h | 19 + .../src/nodes/common/subnormals_to_zero.cpp | 24 ++ .../src/nodes/common/subnormals_to_zero.h | 19 + src/plugins/intel_cpu/src/nodes/conv.cpp | 6 +- src/plugins/intel_cpu/src/nodes/conv.h | 10 + src/plugins/intel_cpu/src/nodes/deconv.cpp | 2 +- src/plugins/intel_cpu/src/nodes/deconv.h | 5 + .../executors/dnnl/dnnl_fullyconnected.hpp | 11 +- .../dnnl/dnnl_fullyconnected_primitive.cpp | 3 +- .../src/nodes/executors/dnnl/dnnl_utils.cpp | 22 +- .../src/nodes/executors/dnnl/dnnl_utils.hpp | 4 +- .../nodes/executors/fullyconnected_config.hpp | 2 + .../src/nodes/executors/mlas/mlas_gemm.cpp | 14 +- .../intel_cpu/src/nodes/fullyconnected.h | 11 + src/plugins/intel_cpu/src/nodes/input.cpp | 349 +----------------- src/plugins/intel_cpu/src/nodes/input.h | 6 +- src/plugins/intel_cpu/src/nodes/reorder.cpp | 29 +- src/plugins/intel_cpu/src/nodes/reorder.h | 2 +- .../src/utils/clone_original_blob.cpp | 121 ++++++ .../intel_cpu/src/utils/clone_original_blob.h | 25 ++ .../intel_cpu/tests/unit/graph/dummy_node.hpp | 130 +++++-- .../tests/unit/graph/inplace_resolve_io.cpp | 20 +- .../tests/unit/graph/memory_state.cpp | 6 +- .../graph/merge_transpose_reorder_test.cpp | 8 +- .../graph/resolve_edge_conflicts_test.cpp | 8 +- 32 files changed, 847 insertions(+), 454 deletions(-) create mode 100644 src/plugins/intel_cpu/src/nodes/common/has_subnormals.cpp create mode 100644 
src/plugins/intel_cpu/src/nodes/common/has_subnormals.h create mode 100644 src/plugins/intel_cpu/src/nodes/common/subnormals_to_zero.cpp create mode 100644 src/plugins/intel_cpu/src/nodes/common/subnormals_to_zero.h create mode 100644 src/plugins/intel_cpu/src/utils/clone_original_blob.cpp create mode 100644 src/plugins/intel_cpu/src/utils/clone_original_blob.h diff --git a/src/plugins/intel_cpu/src/cpu_memory.cpp b/src/plugins/intel_cpu/src/cpu_memory.cpp index 8e5fe8d72fd1f2..4c008df80bf25e 100644 --- a/src/plugins/intel_cpu/src/cpu_memory.cpp +++ b/src/plugins/intel_cpu/src/cpu_memory.cpp @@ -27,41 +27,9 @@ BlockedMemoryDescPtr IMemory::getDescWithType() const { } namespace { - inline void setSubnormalsToZero(float *data, size_t size) { - uint32_t *u32data = reinterpret_cast(data); - for (size_t i = 0; i < size; ++i) { - if ((u32data[i] & (0xFF << 23)) == 0) { - u32data[i] = 0; - } - } - } - void transferData(const IMemory& src, const IMemory& dst, bool ftz) { - node::Reorder::reorderData(src, dst); - - if (!ftz) { - return; - } - if (src.getDesc().getPrecision() != ov::element::f32 || dst.getDesc().getPrecision() == ov::element::bf16) { - return; - } - size_t offset = 0; - if (dst.getDesc().getType() & MemoryDescType::Dnnl) { - // here we can safely cast to DnnlMemoryDesc - auto dnnl_desc = dst.getDescWithType(); - auto desc = dnnl_desc->getDnnlDesc(); - dnnl::impl::memory_desc_wrapper wrapper(desc.get()); - offset = wrapper.offset0(); - if (wrapper.is_wino_desc() || wrapper.is_rnn_packed_desc()) { - return; - } - } - // actual FTZ - auto* memData = static_cast(dst.getData()); - memData += offset; - setSubnormalsToZero(memData, dst.getSize() / sizeof(float)); + node::Reorder::reorderData(src, dst, nullptr, ftz); } - } // namespace Memory::Memory(const dnnl::engine& eng, MemoryDescPtr desc, const void* data, bool pads_zeroing) : diff --git a/src/plugins/intel_cpu/src/cpu_memory.h b/src/plugins/intel_cpu/src/cpu_memory.h index 70e6713e36b886..ce2dd37c0e4d40 
100644 --- a/src/plugins/intel_cpu/src/cpu_memory.h +++ b/src/plugins/intel_cpu/src/cpu_memory.h @@ -177,6 +177,12 @@ class IMemory { return static_cast(getData()); } + template + const typename element_type_traits::value_type* getDataAs() const { + OPENVINO_ASSERT(ET == getPrecision(), "get_data_ptr() called for incorrect element type."); + return static_cast::value_type*>(getData()); + } + virtual size_t getSize() const = 0; // in bytes virtual const Shape& getShape() const = 0; virtual const VectorDims& getStaticDims() const = 0; diff --git a/src/plugins/intel_cpu/src/graph.cpp b/src/plugins/intel_cpu/src/graph.cpp index e0573e310ac86c..9429da196c4831 100644 --- a/src/plugins/intel_cpu/src/graph.cpp +++ b/src/plugins/intel_cpu/src/graph.cpp @@ -9,6 +9,7 @@ #include #include #include +#include #include #include #include @@ -16,7 +17,9 @@ #include #include +#include "cpu_types.h" #include "edge.h" +#include "graph_context.h" #include "graph_dumper.h" #include "graph_optimizer.h" #include "infer_request.h" @@ -34,12 +37,14 @@ #include "openvino/core/model.hpp" #include "openvino/core/node.hpp" #include "openvino/core/type/element_type.hpp" +#include "ov_optional.hpp" #include "utils/debug_capabilities.h" #include "utils/general_utils.h" #include "utils/ngraph_utils.hpp" #include "utils/node_dumper.h" #include "utils/verbose.h" #include "utils/precision_support.h" +#include "utils/clone_original_blob.h" #include #include "common/primitive_desc_iface.hpp" @@ -325,6 +330,8 @@ void Graph::InitGraph(bool optimize) { ResolveComplexInplaceConflicts(); + PreProcessConstantInputs(); + SortTopologically(); const bool hasDynNodes = ProcessDynNodes(); @@ -900,6 +907,95 @@ bool Graph::ProcessDynNodes() { return containsDynamicNodes; } +// @todo add ascii diagram +void Graph::PreProcessConstantInputs() { + std::vector visited(graphNodes.size()); + + std::function(NodePtr, bool, bool)> visitConstantPath; + visitConstantPath = [this, &visitConstantPath, &visited](NodePtr node, + 
int inPlaceOutPort, + bool oneShotCopyPossible) -> ov::optional { + if (visited[node->getExecIndex()]) + return {}; + + visited[node->getExecIndex()] = true; + + if (!node->getParentEdges().empty()) { + for (size_t i = 0; i < node->getParentEdges().size(); i++) { + const auto edge = node->getParentEdgeAt(i); + const auto parent = node->getParentEdgeAt(0)->getParent(); + // keep track of inplace up by inplace output ports + inPlaceOutPort = inPlaceOutPort == parent->inPlaceOutPort(i) ? edge->parent_port : -1; + + return visitConstantPath(parent, inPlaceOutPort, oneShotCopyPossible); + } + } + + // that means this is an input node + OPENVINO_ASSERT(node->getType() == Type::Input, "Only Input node is expected to have no parent edges"); + + auto input = std::dynamic_pointer_cast(node); + MemoryCPtr inputMemory = input->getMemoryPtr(); + + InputPrepType prepType = requiresPreProcessing(*inputMemory, context, getEngine()); + + if (prepType == InputPrepType::None) { + return {}; + } + + const bool isInPlace = inPlaceOutPort >= 0; + + if (isInPlace && oneShotCopyPossible && !std::getenv("DISABLE_CLONE_POSTPONE")) { + // clone will be done by a node + return ov::optional(prepType); + } + + if (!isInPlace && prepType == InputPrepType::PutToNumaLocalCache && !std::getenv("DISABLE_CLONE_POSTPONE")) { + // no need for numa local copy, since current constant path is not inplace, so it will produce a new blob anyway + return {}; + } + + auto blobKey = [](std::shared_ptr input) { + const auto memory = input->getMemoryPtr(); + return input->getName() + + "_" + std::to_string(memory->getSize() * memory->getPrecision().size()) + + "_" + std::to_string(reinterpret_cast(memory->getData())); + }; + + auto create = [&]() { + return cloneBlob(*inputMemory, getEngine(), prepType == InputPrepType::FTZ); + }; + + auto weightCache = context->getWeightsCache(); + auto clone = weightCache ? 
*weightCache->findOrCreate(blobKey(input), create) + : create(); + + input->setMemoryPtr(clone); + + return {}; + }; + + for (auto& node : graphNodes) { + if (node->isConstant()) + continue; // constant nodes will be visited in scope of 'visitConstantPath' + + for (size_t i = 0; i < node->getParentEdges().size(); i++) { + const auto parent = node->getParentEdgeAt(i)->getParent(); + + if (!parent->isConstant()) + continue; + + bool oneShotCopyPossible = node->canPrepInput(i); + if (auto postponePreProcessing = visitConstantPath(parent, true, oneShotCopyPossible)) { + const auto preprocessing = *postponePreProcessing; + node->prepInput(i, preprocessing); + } + } + } + + return; +} + void Graph::PushInputData(const std::size_t& index, const ov::SoPtr& input) { if (!IsReady()) OPENVINO_THROW("Wrong state. Topology not ready."); auto input_itr = inputNodesMap.find(index); diff --git a/src/plugins/intel_cpu/src/graph.h b/src/plugins/intel_cpu/src/graph.h index 3f9debefe7e06c..148f88e612dd43 100644 --- a/src/plugins/intel_cpu/src/graph.h +++ b/src/plugins/intel_cpu/src/graph.h @@ -222,6 +222,8 @@ class Graph { void ResolveEdgeConflicts(); void ResolveComplexInplaceConflicts(); bool ProcessDynNodes(); + void PreProcessConstantInputs(); + void GroupParallelNodes(); void Allocate(const std::vector& syncNodesInds); void AllocateWithReuse(const std::vector& syncNodesInds); void CreatePrimitivesAndExecConstants() const; diff --git a/src/plugins/intel_cpu/src/node.cpp b/src/plugins/intel_cpu/src/node.cpp index 31c4a0d2a5b54d..894e8eb9de4d30 100644 --- a/src/plugins/intel_cpu/src/node.cpp +++ b/src/plugins/intel_cpu/src/node.cpp @@ -35,6 +35,7 @@ #include "memory_desc/dnnl_blocked_memory_desc.h" #include #include +#include "utils/clone_original_blob.h" using namespace dnnl; using namespace openvino; @@ -917,7 +918,9 @@ void Node::prepareMemory(dnnl::primitive_desc_iterator& itpd) { Node::prepareMemory(intDescs); } -MemoryPtr Node::prepareWeightMemory(DnnlMemoryDescPtr 
dstWeightDesc, DnnlMemoryDescPtr srcWeightDesc) { +MemoryPtr Node::prepareWeightMemory(DnnlMemoryDescPtr dstWeightDesc, + DnnlMemoryDescPtr srcWeightDesc, + InputPrepType preprocessing) { if (!getParentEdgeAt(1)->getParent()->isConstant()) OPENVINO_THROW("Weight input is not const for node ", getName(), "."); auto edgeMem = getSrcMemoryAtPort(1); @@ -933,10 +936,14 @@ MemoryPtr Node::prepareWeightMemory(DnnlMemoryDescPtr dstWeightDesc, DnnlMemoryD auto create = [&] () { Memory srcMemory{ getEngine(), srcWeightDesc, edgeMem->getData() }; - MemoryPtr _ptr = std::make_shared(getEngine(), dstWeightDesc); - node::Reorder::reorderData(srcMemory, *_ptr, context->getParamsCache()); + MemoryPtr weightsMem = std::make_shared(getEngine(), dstWeightDesc); - return _ptr; + node::Reorder::reorderData(srcMemory, + *weightsMem, + context->getParamsCache(), + preprocessing == InputPrepType::FTZ); + + return weightsMem; }; MemoryPtr ptr; diff --git a/src/plugins/intel_cpu/src/node.h b/src/plugins/intel_cpu/src/node.h index ff8bf87d993a74..b5bde5d6e07b84 100644 --- a/src/plugins/intel_cpu/src/node.h +++ b/src/plugins/intel_cpu/src/node.h @@ -5,6 +5,7 @@ #pragma once #include +#include #include #include "cpu_memory.h" #include "cpu_shape.h" @@ -24,6 +25,7 @@ #include "utils/debug_capabilities.h" #include "utils/bit_util.hpp" #include "utils/debug_capabilities.h" +#include "utils/clone_original_blob.h" #include "graph_context.h" #include "nodes/executors/executor.hpp" @@ -269,6 +271,22 @@ class Node { return !hasEmptyInputTensors(); } + /** + * Return true if a node can perform a preprocessing for an input \idx + */ + virtual bool canPrepInput(size_t idx) const { + (void) idx; + return false; + } + + /** + * Require a node to perform \type preprocessing for an input \idx + */ + virtual void prepInput(size_t idx, InputPrepType type) { + (void) idx; + (void) type; + } + enum class ConstantType { Const, // Node is placed in a constant subgraph NoConst, // Node is placed in a non-constant 
subgraph @@ -740,7 +758,9 @@ class Node { virtual void prepareMemory(const DnnlMemoryDescPtr& intDesc, size_t indx); void prepareMemory(dnnl::primitive_desc_iterator& itpd); - MemoryPtr prepareWeightMemory(DnnlMemoryDescPtr dstWeightDesc, DnnlMemoryDescPtr srcWeightDesc = nullptr); + MemoryPtr prepareWeightMemory(DnnlMemoryDescPtr dstWeightDesc, + DnnlMemoryDescPtr srcWeightDesc = nullptr, + InputPrepType preprocessing = InputPrepType::None); bool isDynamic = false; diff --git a/src/plugins/intel_cpu/src/nodes/common/has_subnormals.cpp b/src/plugins/intel_cpu/src/nodes/common/has_subnormals.cpp new file mode 100644 index 00000000000000..c868c5d662b655 --- /dev/null +++ b/src/plugins/intel_cpu/src/nodes/common/has_subnormals.cpp @@ -0,0 +1,270 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "openvino/core/visibility.hpp" + +#include "has_subnormals.h" +#include "cpu_memory.h" +#include "openvino/core/parallel.hpp" +#include "cpu/x64/jit_generator.hpp" + +using namespace dnnl; +using namespace dnnl::impl::cpu::x64; +using namespace Xbyak; + +namespace ov { +namespace intel_cpu { + +#if defined(OPENVINO_ARCH_X86_64) + +struct jit_has_subnormals_base : public jit_generator { + DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_has_subnormals_base) + + typedef struct { + const float* src; + const size_t count; + bool hasSubnormals; + } args_t; + + void (*ker_)(const args_t *); + void operator()(const args_t* args) { assert(ker_); ker_(args); } + + jit_has_subnormals_base() : jit_generator(jit_name()) { + jit_ker_ = nullptr; + } + + virtual void create_ker() = 0; + +protected: + void foreach(const Xbyak::Reg64& idx, + size_t step, + const Xbyak::Reg64& end, + std::function && fn) { + Label loop, exit; + + L(loop); + cmp(idx, end); + jge(exit); + + fn(idx); + + add(idx, step); + jmp(loop); + L(exit); + } + + void copy_floats(const Xbyak::Reg64& dst, + const Xbyak::Reg64& src, + const Xbyak::Reg64& size) { + push(rsi); + push(r15); + + 
xor_(rsi, rsi); + + foreach(rsi, 1, size, [&, this](const Xbyak::Reg64& idx) { + mov(r15d, dword[src + idx * sizeof(float)]); + mov(dword[dst + idx * sizeof(float)], r15d); + }); + + pop(r15); + pop(rsi); + } + + void check_subnormals(const Xbyak::Reg64& src, const Xbyak::Ymm &exponent_mask, const Xbyak::Ymm &mantissa_mask, const Xbyak::Ymm &zero) { + auto a = ymm1; + auto b = ymm2; + auto c = ymm3; + + vmovdqu(a, yword[src]); // load 8 floats + vpand(b, a, mantissa_mask); // b = a & 00000000011111111111111111111111 + vpcmpeqd(b, b, zero); // if (b == 0) b = 1 else b = 0 + vpand(c, a, exponent_mask); // c = a & 01111111100000000000000000000000 + vpcmpeqd(c, c, zero); // if (c == 0) c = 1 else c = 0 + vptest(b, c); // if ((!b & c) == 0) CF = 1 else CF = 0 + } + + void check_subnormals(const Xbyak::Reg64& src, const Xbyak::Xmm &exponent_mask, const Xbyak::Xmm &mantissa_mask, const Xbyak::Xmm &zero) { + auto a = xmm1; + auto b = xmm2; + auto c = xmm3; + + uni_vmovdqu(a, xword[src]); // load 4 floats + uni_vmovdqu(b, a); // b = a + uni_vmovdqu(c, a); // c = a + uni_vpand(b, b, mantissa_mask); // b = a & 00000000011111111111111111111111 + uni_vpcmpeqd(b, b, zero); // if (b == 0) b = 1 else b = 0 + uni_vpand(c, c, exponent_mask); // c = a & 01111111100000000000000000000000 + uni_vpcmpeqd(c, c, zero); // if (c == 0) c = 1 else c = 0 + uni_vtestps(b, c); // if ((!b & c) == 0) CF = 1 else CF = 0 + } + +protected: + Label exit, has_subnormals, no_subnormals; + + const Reg64 ®_src = rax; + const Reg64 ®_dst = rbx; + const Reg64 ®_sz = rdx; + const Reg64 ®_idx = rsi; + const Reg64 ®_mask_addr = r15; + + static const uint32_t exponent_mask_data[8]; + static const uint32_t mantissa_mask_data[8]; +}; + +const uint32_t jit_has_subnormals_base::exponent_mask_data[8] = { + 0x7f800000, 0x7f800000, 0x7f800000, 0x7f800000, + 0x7f800000, 0x7f800000, 0x7f800000, 0x7f800000 +}; + +const uint32_t jit_has_subnormals_base::mantissa_mask_data[8] = { + 0x007fffff, 0x007fffff, 0x007fffff, 
0x007fffff, + 0x007fffff, 0x007fffff, 0x007fffff, 0x007fffff +}; + +template +struct jit_has_subnormals : public jit_has_subnormals_base { + using Vmm = typename dnnl::impl::utils::conditional::type; + + const Vmm rmm4 = Vmm(4); + const Vmm rmm5 = Vmm(5); + const Vmm rmm6 = Vmm(6); + const int length = isa == sse41 ? 4 : 8; + + void create_ker() override { + jit_generator::create_kernel(); + ker_ = (decltype(ker_))jit_ker(); + } + + void generate() override final { // NOLINT + size_t const vlen = length; + const int sh_bits = std::ilogb(vlen); + + auto zero = rmm4; + auto exponent_mask = rmm5; + auto mantissa_mask = rmm6; + + preamble(); + + // Get arguments addresses + mov(reg_src, ptr[param1 + offsetof(args_t, src)]); + lea(reg_dst, ptr[param1 + offsetof(args_t, hasSubnormals)]); + mov(reg_sz, ptr[param1 + offsetof(args_t, count)]); + + // Initialize necessary consts + uni_vpxor(zero, zero, zero); + mov(reg_mask_addr, (size_t)exponent_mask_data); + uni_vmovdqu(exponent_mask, ptr[reg_mask_addr]); + mov(reg_mask_addr, (size_t)mantissa_mask_data); + uni_vmovdqu(mantissa_mask, ptr[reg_mask_addr]); + + // Main loop + xor_(reg_idx, reg_idx); + mov(r8, reg_sz); + shr(r8, sh_bits); + + foreach(reg_idx, 1, r8, [&, this](const Xbyak::Reg64& idx) { + check_subnormals(reg_src, exponent_mask, mantissa_mask, zero); + jnc(has_subnormals); + add(reg_src, sizeof(float) * vlen); + }); + + // Tail + shl(reg_idx, sh_bits); + sub(reg_sz, reg_idx); + test(reg_sz, reg_sz); + jz(exit); + + // use space on stack for 4 or 8 floats + sub(rsp, vlen * sizeof(float)); + mov(r8, rsp); + + uni_vmovdqu(ptr[r8], zero); + + copy_floats(r8, reg_src, reg_sz); + check_subnormals(r8, exponent_mask, mantissa_mask, zero); + jc(no_subnormals); + add(rsp, vlen * sizeof(float)); + + L(has_subnormals); + + mov(rax, 1); + mov(byte[reg_dst], al); + jmp(exit); + + L(no_subnormals); + add(rsp, vlen * sizeof(float)); + + L(exit); + + postamble(); + } +}; + +static std::shared_ptr createKernel() { + 
std::shared_ptr kernel; + if (mayiuse(cpu_isa_t::avx2)) { + kernel = std::make_shared>(); + } else if (mayiuse(cpu_isa_t::sse41)) { + kernel = std::make_shared>(); + } + + if (kernel) { + kernel->create_ker(); + if (!kernel->jit_ker()) + kernel = nullptr; + } + + return kernel; +} +#endif + +bool HasSubnormals::execute(const IMemory& src) { + const auto prec = src.getPrecision(); + const auto size = src.getShape().getElementsCount(); + + if (size == 0) + return false; + + if (prec != ov::element::f32) + return false; + + const uint32_t* u32data = src.getDataAs(); + +#if defined(OPENVINO_ARCH_X86_64) + static std::shared_ptr kernel = createKernel(); + if (kernel) { + static const size_t batch_size = 2048; + const size_t iterations_num = size / batch_size + 1; + + volatile bool has_subnormals = false; + + parallel_for(iterations_num, [&](int n) { + auto ptr = u32data + n * batch_size; + const jit_has_subnormals_base::args_t args = {reinterpret_cast(ptr), + std::min(batch_size, (size_t)(u32data + size - ptr)), + false}; + (*kernel)(&args); + // result is written to the input 'hasSubnormals' parameter + if (args.hasSubnormals) + has_subnormals = true; + }); + + return has_subnormals; + } +#endif + + // @todo optimize for ARM and other architectures + uint32_t mantissaMask = 0x007fffff; + uint32_t exponentMask = 0x7f800000; + for (size_t i = 0; i < size; ++i) { + if ((u32data[i] & exponentMask) == 0 && (u32data[i] & mantissaMask) != 0) { + return true; + } + } + + return false; +} + +} // namespace intel_cpu +} // namespace ov diff --git a/src/plugins/intel_cpu/src/nodes/common/has_subnormals.h b/src/plugins/intel_cpu/src/nodes/common/has_subnormals.h new file mode 100644 index 00000000000000..fa25332a794835 --- /dev/null +++ b/src/plugins/intel_cpu/src/nodes/common/has_subnormals.h @@ -0,0 +1,19 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// +#pragma once + +#include "cpu_memory.h" + +namespace ov { +namespace intel_cpu { + 
+struct jit_has_subnormals_base; + +class HasSubnormals { +public: + bool execute(const IMemory& src); +}; + +} // namespace intel_cpu +} // namespace ov diff --git a/src/plugins/intel_cpu/src/nodes/common/subnormals_to_zero.cpp b/src/plugins/intel_cpu/src/nodes/common/subnormals_to_zero.cpp new file mode 100644 index 00000000000000..8e04622efeb3d9 --- /dev/null +++ b/src/plugins/intel_cpu/src/nodes/common/subnormals_to_zero.cpp @@ -0,0 +1,24 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "subnormals_to_zero.h" + +#include +#include + +namespace ov { +namespace intel_cpu { + +// @todo add optimized implementation as Eltwise / emitter +void setSubnormalsToZero(float* data, size_t size) { + uint32_t *u32data = reinterpret_cast(data); + for (size_t i = 0; i < size; ++i) { + if ((u32data[i] & (0xFF << 23)) == 0) { + u32data[i] = 0; + } + } +} + +} // namespace intel_cpu +} // namespace ov diff --git a/src/plugins/intel_cpu/src/nodes/common/subnormals_to_zero.h b/src/plugins/intel_cpu/src/nodes/common/subnormals_to_zero.h new file mode 100644 index 00000000000000..263689dc19cfef --- /dev/null +++ b/src/plugins/intel_cpu/src/nodes/common/subnormals_to_zero.h @@ -0,0 +1,19 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// +#pragma once + +#include "cpu/x64/jit_generator.hpp" +#include "cpu_memory.h" + +using namespace dnnl; +using namespace dnnl::impl::cpu::x64; +using namespace Xbyak; + +namespace ov { +namespace intel_cpu { + +void setSubnormalsToZero(float* data, size_t size); + +} // namespace intel_cpu +} // namespace ov diff --git a/src/plugins/intel_cpu/src/nodes/conv.cpp b/src/plugins/intel_cpu/src/nodes/conv.cpp index cbdb35db271622..b24af8633a936d 100644 --- a/src/plugins/intel_cpu/src/nodes/conv.cpp +++ b/src/plugins/intel_cpu/src/nodes/conv.cpp @@ -1423,7 +1423,11 @@ void Convolution::prepareParams() { auto it = primArgs.find(DNNL_ARG_WEIGHTS); if (it == primArgs.end() 
|| !prevExecPtr || !execPtr->getWeightDesc()->isCompatible(*(prevExecPtr->getWeightDesc()))) { - primArgs[DNNL_ARG_WEIGHTS] = prepareWeightMemory(execPtr->getWeightDesc())->getPrimitive(); + // if (it != primArgs.end() && prevExecPtr && + // !execPtr->getWeightDesc()->isCompatible(*(prevExecPtr->getWeightDesc()))) { + primArgs[DNNL_ARG_WEIGHTS] = prepareWeightMemory(execPtr->getWeightDesc(), + nullptr, + weightsPrepType)->getPrimitive(); } } else { // non-const weight will be reordered by executor on every exec diff --git a/src/plugins/intel_cpu/src/nodes/conv.h b/src/plugins/intel_cpu/src/nodes/conv.h index a7cac9bced1241..30db988b30b382 100644 --- a/src/plugins/intel_cpu/src/nodes/conv.h +++ b/src/plugins/intel_cpu/src/nodes/conv.h @@ -62,6 +62,15 @@ class Convolution : public Node { return isGrouped && 1 == groupOC && 1 == groupIC; } + bool canPrepInput(size_t idx) const override { + return idx == 1; + } + + void prepInput(size_t idx, InputPrepType type) override { + OPENVINO_ASSERT(idx == 1, "Only weights input (1) can be preprocessed"); + this->weightsPrepType = type; + } + protected: ov::element::Type fusedEltwisePrecision(const NodePtr& fusingNode) const; void redefineOutputMemory(const std::vector &newOutputShapes) override; @@ -177,6 +186,7 @@ class Convolution : public Node { #else const dnnl::algorithm baseConvAlgorithm = dnnl::algorithm::convolution_auto; #endif + InputPrepType weightsPrepType = InputPrepType::None; }; } // namespace node diff --git a/src/plugins/intel_cpu/src/nodes/deconv.cpp b/src/plugins/intel_cpu/src/nodes/deconv.cpp index 57046a0a06d55b..cf1cf484bc642b 100644 --- a/src/plugins/intel_cpu/src/nodes/deconv.cpp +++ b/src/plugins/intel_cpu/src/nodes/deconv.cpp @@ -979,7 +979,7 @@ void Deconvolution::prepareParams() { auto it = primArgs.find(DNNL_ARG_WEIGHTS); if (it == primArgs.end() || !prevExecPtr || !execPtr->getWeightDesc()->isCompatible(*(prevExecPtr->getWeightDesc()))) { - primArgs[DNNL_ARG_WEIGHTS] = 
prepareWeightMemory(execPtr->getWeightDesc(), wghDesc)->getPrimitive(); + primArgs[DNNL_ARG_WEIGHTS] = prepareWeightMemory(execPtr->getWeightDesc(), wghDesc, weightsPrepType)->getPrimitive(); } } else { // non-const weight will be reordered by executor on every exec diff --git a/src/plugins/intel_cpu/src/nodes/deconv.h b/src/plugins/intel_cpu/src/nodes/deconv.h index d94bcd8bcaca13..51e2456fe26b8f 100644 --- a/src/plugins/intel_cpu/src/nodes/deconv.h +++ b/src/plugins/intel_cpu/src/nodes/deconv.h @@ -49,6 +49,10 @@ class Deconvolution : public Node { bool canBeExecutedInInt8() const override; const std::vector& getDefaultImplPriority() override; + void prepInput(size_t idx, InputPrepType type) override { + OPENVINO_ASSERT(idx == 1, "Only weights input (1) can be preprocessed"); + this->weightsPrepType = type; + } protected: AttrPtr initPrimitiveAttr() override; @@ -108,6 +112,7 @@ class Deconvolution : public Node { bool asymmetricPaddingAnd1x1 = false; bool is1x1 = false; bool isConstOutShape = false; + InputPrepType weightsPrepType = InputPrepType::None; }; } // namespace node diff --git a/src/plugins/intel_cpu/src/nodes/executors/dnnl/dnnl_fullyconnected.hpp b/src/plugins/intel_cpu/src/nodes/executors/dnnl/dnnl_fullyconnected.hpp index 266e78b3d46c77..b26d18b9e37d3a 100644 --- a/src/plugins/intel_cpu/src/nodes/executors/dnnl/dnnl_fullyconnected.hpp +++ b/src/plugins/intel_cpu/src/nodes/executors/dnnl/dnnl_fullyconnected.hpp @@ -120,12 +120,19 @@ class DnnlFCExecutor : public Executor { const PrimitivePtr newPrimitive, const MemoryPtr memory) { const auto newPrimMemDesc = newPrimitive->weightsDesc(); - if (currentPrimitive && currentPrimitive->weightsDesc()->isCompatible(*newPrimMemDesc)) + if (currentPrimitive && currentPrimitive->weightsDesc()->isCompatible(*newPrimMemDesc)) { return; + } originalMemDesc = Primitive::makeTransposedWeightDescriptor(originalMemDesc, newPrimMemDesc, m_attrs.weightsNonTransposed); - const auto weiMemory = 
utils::prepareWeightsMemory(originalMemDesc, newPrimMemDesc, memory, m_context, true); + const auto weiMemory = utils::prepareWeightsMemory(originalMemDesc, + newPrimMemDesc, + memory, + m_context, + true, + m_attrs.weightsPrepType); + m_primArgs[DNNL_ARG_WEIGHTS] = weiMemory->getPrimitive(); } diff --git a/src/plugins/intel_cpu/src/nodes/executors/dnnl/dnnl_fullyconnected_primitive.cpp b/src/plugins/intel_cpu/src/nodes/executors/dnnl/dnnl_fullyconnected_primitive.cpp index fcb70d4753b2ce..6eeff127cd846e 100644 --- a/src/plugins/intel_cpu/src/nodes/executors/dnnl/dnnl_fullyconnected_primitive.cpp +++ b/src/plugins/intel_cpu/src/nodes/executors/dnnl/dnnl_fullyconnected_primitive.cpp @@ -432,7 +432,8 @@ DnnlShapeAgnosticDataPtr DnnlFCPrimitive::createShapeAgnosticData(const FCAttrs& weightsDesc, memory.at(ARG_WEI), context, - useDynamicQuantization); + useDynamicQuantization, + attrs.weightsPrepType); return std::make_shared(postOpData); } diff --git a/src/plugins/intel_cpu/src/nodes/executors/dnnl/dnnl_utils.cpp b/src/plugins/intel_cpu/src/nodes/executors/dnnl/dnnl_utils.cpp index fa273ac3d6c3ff..2debf2efb685e5 100644 --- a/src/plugins/intel_cpu/src/nodes/executors/dnnl/dnnl_utils.cpp +++ b/src/plugins/intel_cpu/src/nodes/executors/dnnl/dnnl_utils.cpp @@ -13,6 +13,7 @@ #include "nodes/executors/executor.hpp" #include "nodes/reorder.h" #include "utils/cpu_utils.hpp" +#include "utils/clone_original_blob.h" namespace ov { namespace intel_cpu { @@ -22,23 +23,28 @@ MemoryPtr prepareWeightsMemory(const DnnlMemoryDescPtr srcWeightDesc, const DnnlMemoryDescPtr dstWeightDesc, const MemoryCPtr weightsMem, const ExecutorContext::CPtr context, - const bool needShiftSignedToUnsigned) { + const bool needShiftSignedToUnsigned, + const InputPrepType preprocessing) { const auto& eng = context->getEngine(); const auto& format = dstWeightDesc->serializeFormat(); - const auto privateWeightCache = context->getPrivateWeighCache(); + OPENVINO_ASSERT(privateWeightCache, "privateWeightCache 
is nullptr"); - if (privateWeightCache) { - auto itr = privateWeightCache->find(format); - if (privateWeightCache->end() != itr) { - return itr->second; - } + + auto itr = privateWeightCache->find(format); + if (privateWeightCache->end() != itr) { + return itr->second; } auto create = [&]() { // https://oneapi-src.github.io/oneDNN/dev_guide_int8_computations.html?highlight=128#inputs-of-the-same-type-s8 auto src_wdt = srcWeightDesc->getPrecision(); auto dst_wdt = dstWeightDesc->getPrecision(); + + // if (cloneWeights) { + // std::cout << "Cloning the weightsMem in scope of FC node" << "\n"; + // } + if (needShiftSignedToUnsigned && src_wdt.is_integral_number() && src_wdt.is_signed() && dst_wdt.is_integral_number() && !dst_wdt.is_signed()) { assert(src_wdt.bitwidth() == dst_wdt.bitwidth()); @@ -72,7 +78,7 @@ MemoryPtr prepareWeightsMemory(const DnnlMemoryDescPtr srcWeightDesc, Memory srcMemory{eng, srcWeightDesc, weightsMem->getData()}; MemoryPtr _ptr = std::make_shared(eng, dstWeightDesc); auto rtCache = context->getRuntimeCache(); - node::Reorder::reorderData(srcMemory, *_ptr, rtCache); + node::Reorder::reorderData(srcMemory, *_ptr, rtCache, preprocessing == InputPrepType::FTZ); return _ptr; }; diff --git a/src/plugins/intel_cpu/src/nodes/executors/dnnl/dnnl_utils.hpp b/src/plugins/intel_cpu/src/nodes/executors/dnnl/dnnl_utils.hpp index 1f35caef0f74d2..be085789b1d20a 100644 --- a/src/plugins/intel_cpu/src/nodes/executors/dnnl/dnnl_utils.hpp +++ b/src/plugins/intel_cpu/src/nodes/executors/dnnl/dnnl_utils.hpp @@ -10,6 +10,7 @@ #include "cpu_memory.h" #include "memory_desc/dnnl_memory_desc.h" #include "nodes/executors/executor.hpp" +#include "utils/clone_original_blob.h" namespace ov { namespace intel_cpu { @@ -18,7 +19,8 @@ MemoryPtr prepareWeightsMemory(const DnnlMemoryDescPtr srcWeightDesc, const DnnlMemoryDescPtr dstWeightDesc, const MemoryCPtr weightsMem, const ExecutorContext::CPtr context, - const bool needShiftSignedToUnsigned = false); + const bool 
needShiftSignedToUnsigned = false, + const InputPrepType preprocessing = InputPrepType::None); } // namespace utils } // namespace intel_cpu } // namespace ov diff --git a/src/plugins/intel_cpu/src/nodes/executors/fullyconnected_config.hpp b/src/plugins/intel_cpu/src/nodes/executors/fullyconnected_config.hpp index ad6479597c6971..87843ddb3b15b9 100644 --- a/src/plugins/intel_cpu/src/nodes/executors/fullyconnected_config.hpp +++ b/src/plugins/intel_cpu/src/nodes/executors/fullyconnected_config.hpp @@ -19,6 +19,8 @@ struct FCAttrs { bool withBias = false; bool weightsNonTransposed = false; bool sparseWeights = false; + // @todo combine all the weights related attributes into some WeightsPreprocess structure + InputPrepType weightsPrepType = InputPrepType::None; // @todo only memory descriptors should be a part of attributes // actual memory should be passed into "execute" or "prepareMemory" calls std::vector dequantizationScales; diff --git a/src/plugins/intel_cpu/src/nodes/executors/mlas/mlas_gemm.cpp b/src/plugins/intel_cpu/src/nodes/executors/mlas/mlas_gemm.cpp index a03bfe2649413a..8d775c14c84e3a 100644 --- a/src/plugins/intel_cpu/src/nodes/executors/mlas/mlas_gemm.cpp +++ b/src/plugins/intel_cpu/src/nodes/executors/mlas/mlas_gemm.cpp @@ -15,6 +15,8 @@ #include "nodes/executors/memory_arguments.hpp" #include "nodes/executors/mlas/mlas_gemm.hpp" #include "utils/debug_capabilities.h" +#include "utils/clone_original_blob.h" +#include "nodes/common/subnormals_to_zero.h" namespace ov { namespace intel_cpu { @@ -25,7 +27,7 @@ using namespace ov::element; static MemoryPtr prepareWeightMemory(const MemoryPtr weightsMemory, const ExecutorContext::CPtr context, - const bool weightsTransposed) { + const FCAttrs& attrs) { DEBUG_LOG("MlasGemmExecutor: prepack weights"); const auto& wgtDims = weightsMemory->getStaticDims(); // Weights are transposed by MatMulConstTransposesExtraction @@ -36,6 +38,8 @@ static MemoryPtr prepareWeightMemory(const MemoryPtr weightsMemory, auto 
packedBsize = mlas_sgemm_pack_get_size(N, K); + const bool weightsTransposed = !attrs.weightsNonTransposed; + auto create = [&]() { float* weightPtr = weightsMemory->getDataAs(); size_t ldb = weightsTransposed ? K : N; @@ -44,6 +48,12 @@ static MemoryPtr prepareWeightMemory(const MemoryPtr weightsMemory, float* prepackedDst = _ptr->getDataAs(); DEBUG_LOG("MlasGemmExecutor: cache miss, perform packing"); mlas_sgemm_pack(weightsTransposed ? "T" : "F", N, K, ldb, weightPtr, prepackedDst); + + // @todo can we ommit flushing subnormals in case of down convertion, i.e. FP32 -> FP16? + if (attrs.weightsPrepType == InputPrepType::FTZ) { + setSubnormalsToZero(prepackedDst, packedBsize / sizeof(float)); + } + return _ptr; }; @@ -108,7 +118,7 @@ MlasGemmExecutor::MlasGemmExecutor(const FCAttrs& attrs, const ExecutorContext::CPtr context) : m_attrs(attrs), m_memoryArgs(memory), - packedWeights(prepareWeightMemory(memory.at(ARG_WEI), context, !attrs.weightsNonTransposed)) {} + packedWeights(prepareWeightMemory(memory.at(ARG_WEI), context, attrs)) {} bool MlasGemmExecutor::update(const MemoryArgs& memory) { const auto& weiDesc = memory.at(ARG_WEI)->getDescPtr(); diff --git a/src/plugins/intel_cpu/src/nodes/fullyconnected.h b/src/plugins/intel_cpu/src/nodes/fullyconnected.h index be29342b851988..da54f4b9c6893b 100644 --- a/src/plugins/intel_cpu/src/nodes/fullyconnected.h +++ b/src/plugins/intel_cpu/src/nodes/fullyconnected.h @@ -15,6 +15,7 @@ #include "nodes/executors/executor_factory.hpp" #include "nodes/executors/memory_arguments.hpp" #include "nodes/executors/fullyconnected_config.hpp" +#include "openvino/core/except.hpp" #include "post_ops.hpp" #include "openvino/runtime/threading/cpu_message.hpp" @@ -70,6 +71,16 @@ class FullyConnected : public Node { void prepareParams() override; void executeDynamicImpl(dnnl::stream strm) override; bool canBeExecutedInInt8() const override; + + bool canPrepInput(size_t idx) const override { + return idx == 1; + } + + void prepInput(size_t 
idx, InputPrepType type) override { + OPENVINO_ASSERT(idx == 1, "Only weights input (1) can be preprocessed"); + attrs.weightsPrepType = type; + } + void keepWeightsNonTransposed(bool weightsNonTransposed) { this->attrs.weightsNonTransposed = weightsNonTransposed; } diff --git a/src/plugins/intel_cpu/src/nodes/input.cpp b/src/plugins/intel_cpu/src/nodes/input.cpp index ea659ec1e31b84..a501c474b7df5f 100644 --- a/src/plugins/intel_cpu/src/nodes/input.cpp +++ b/src/plugins/intel_cpu/src/nodes/input.cpp @@ -4,219 +4,27 @@ #include "input.h" -#include "cpu/x64/jit_generator.hpp" -#include "openvino/core/parallel.hpp" +#include "cpu_memory.h" +#include "openvino/core/type/element_type.hpp" +#include "openvino/op/constant.hpp" #include "shape_inference/shape_inference_pass_through.hpp" -using namespace dnnl; -using namespace dnnl::impl::cpu::x64; -using namespace Xbyak; - namespace ov { namespace intel_cpu { namespace node { -#if defined(OPENVINO_ARCH_X86_64) -namespace { -struct jit_has_subnormals_base : public jit_generator { - DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_has_subnormals_base) - - typedef struct { - const float* src; - const size_t count; - bool hasSubnormals; - } args_t; - - typedef void (*fn_t)(const args_t*); - - jit_has_subnormals_base() : jit_generator(jit_name()) { - jit_ker_ = nullptr; - } - - fn_t get() { - return jit_ker() || create_kernel() == dnnl::impl::status::success - ? 
(fn_t)jit_ker() - : nullptr; - } - -protected: - void foreach(const Xbyak::Reg64& idx, - size_t step, - const Xbyak::Reg64& end, - std::function && fn) { - Label loop, exit; - - L(loop); - cmp(idx, end); - jge(exit); - - fn(idx); - - add(idx, step); - jmp(loop); - L(exit); - } - - void copy_floats(const Xbyak::Reg64& dst, - const Xbyak::Reg64& src, - const Xbyak::Reg64& size) { - push(rsi); - push(r15); - - xor_(rsi, rsi); - - foreach(rsi, 1, size, [&, this](const Xbyak::Reg64& idx) { - mov(r15d, dword[src + idx * sizeof(float)]); - mov(dword[dst + idx * sizeof(float)], r15d); - }); - - pop(r15); - pop(rsi); - } - - void check_subnormals(const Xbyak::Reg64& src, const Xbyak::Ymm &exponent_mask, const Xbyak::Ymm &mantissa_mask, const Xbyak::Ymm &zero) { - auto a = ymm1; - auto b = ymm2; - auto c = ymm3; - - vmovdqu(a, yword[src]); // load 8 floats - vpand(b, a, mantissa_mask); // b = a & 00000000011111111111111111111111 - vpcmpeqd(b, b, zero); // if (b == 0) b = 1 else b = 0 - vpand(c, a, exponent_mask); // c = a & 01111111100000000000000000000000 - vpcmpeqd(c, c, zero); // if (c == 0) c = 1 else c = 0 - vptest(b, c); // if ((!b & c) == 0) CF = 1 else CF = 0 - } - - void check_subnormals(const Xbyak::Reg64& src, const Xbyak::Xmm &exponent_mask, const Xbyak::Xmm &mantissa_mask, const Xbyak::Xmm &zero) { - auto a = xmm1; - auto b = xmm2; - auto c = xmm3; - - uni_vmovdqu(a, xword[src]); // load 4 floats - uni_vmovdqu(b, a); // b = a - uni_vmovdqu(c, a); // c = a - uni_vpand(b, b, mantissa_mask); // b = a & 00000000011111111111111111111111 - uni_vpcmpeqd(b, b, zero); // if (b == 0) b = 1 else b = 0 - uni_vpand(c, c, exponent_mask); // c = a & 01111111100000000000000000000000 - uni_vpcmpeqd(c, c, zero); // if (c == 0) c = 1 else c = 0 - uni_vtestps(b, c); // if ((!b & c) == 0) CF = 1 else CF = 0 - } - -protected: - Label exit, has_subnormals, no_subnormals; - - const Reg64 ®_src = rax; - const Reg64 ®_dst = rbx; - const Reg64 ®_sz = rdx; - const Reg64 ®_idx = rsi; - 
const Reg64 ®_mask_addr = r15; - - static const uint32_t exponent_mask_data[8]; - static const uint32_t mantissa_mask_data[8]; -}; - -const uint32_t jit_has_subnormals_base::exponent_mask_data[8] = { - 0x7f800000, 0x7f800000, 0x7f800000, 0x7f800000, - 0x7f800000, 0x7f800000, 0x7f800000, 0x7f800000 -}; - -const uint32_t jit_has_subnormals_base::mantissa_mask_data[8] = { - 0x007fffff, 0x007fffff, 0x007fffff, 0x007fffff, - 0x007fffff, 0x007fffff, 0x007fffff, 0x007fffff -}; - -template -struct jit_has_subnormals : public jit_has_subnormals_base { - using Vmm = typename dnnl::impl::utils::conditional::type; - - const Vmm rmm4 = Vmm(4); - const Vmm rmm5 = Vmm(5); - const Vmm rmm6 = Vmm(6); - const int length = isa == sse41 ? 4 : 8; - - void generate() override final { // NOLINT - size_t const vlen = length; - const int sh_bits = std::ilogb(vlen); - - auto zero = rmm4; - auto exponent_mask = rmm5; - auto mantissa_mask = rmm6; - - preamble(); - - // Get arguments addresses - mov(reg_src, ptr[param1 + offsetof(args_t, src)]); - lea(reg_dst, ptr[param1 + offsetof(args_t, hasSubnormals)]); - mov(reg_sz, ptr[param1 + offsetof(args_t, count)]); - - // Initialize necessary consts - uni_vpxor(zero, zero, zero); - mov(reg_mask_addr, (size_t)exponent_mask_data); - uni_vmovdqu(exponent_mask, ptr[reg_mask_addr]); - mov(reg_mask_addr, (size_t)mantissa_mask_data); - uni_vmovdqu(mantissa_mask, ptr[reg_mask_addr]); - - // Main loop - xor_(reg_idx, reg_idx); - mov(r8, reg_sz); - shr(r8, sh_bits); - - foreach(reg_idx, 1, r8, [&, this](const Xbyak::Reg64& idx) { - check_subnormals(reg_src, exponent_mask, mantissa_mask, zero); - jnc(has_subnormals); - add(reg_src, sizeof(float) * vlen); - }); - - // Tail - shl(reg_idx, sh_bits); - sub(reg_sz, reg_idx); - test(reg_sz, reg_sz); - jz(exit); - - // use space on stack for 4 or 8 floats - sub(rsp, vlen * sizeof(float)); - mov(r8, rsp); - - uni_vmovdqu(ptr[r8], zero); - - copy_floats(r8, reg_src, reg_sz); - check_subnormals(r8, exponent_mask, 
mantissa_mask, zero); - jc(no_subnormals); - add(rsp, vlen * sizeof(float)); - - L(has_subnormals); - - mov(rax, 1); - mov(byte[reg_dst], al); - jmp(exit); - - L(no_subnormals); - add(rsp, vlen * sizeof(float)); - - L(exit); - - postamble(); - } -}; +static MemoryPtr createMemoryForConstantOp(const std::shared_ptr& constOp, dnnl::engine engine) { + Shape shape(constOp->get_shape().empty() ? ov::Shape(1, 1) : constOp->get_shape()); + CpuBlockedMemoryDesc memDesc(constOp->get_element_type(), shape); -jit_has_subnormals_base::fn_t jit_has_subnormals_function() { - if (mayiuse(cpu_isa_t::avx2)) { - static jit_has_subnormals generator; - static auto fn = generator.get(); - return fn; - } else if (mayiuse(cpu_isa_t::sse41)) { - static jit_has_subnormals generator; - static auto fn = generator.get(); - return fn; - } - return nullptr; + if (memDesc.getPrecision() == element::string) + return std::make_shared(engine, memDesc, constOp->get_data_ptr()); + else + return std::make_shared(engine, memDesc, constOp->get_data_ptr()); } -} // namespace -#endif - Input::Input(const std::shared_ptr& op, const GraphContext::CPtr context) - : Node(op, context, PassThroughShapeInferFactory()) { + : Node(op, context, PassThroughShapeInferFactory()) { if (!one_of(op->get_type_info(), op::v0::Parameter::get_type_info_static(), op::v0::Constant::get_type_info_static(), @@ -227,146 +35,17 @@ Input::Input(const std::shared_ptr& op, const GraphContext::CPtr conte op->get_type_name(), " with name ", op->get_friendly_name()); + // @todo is it required to hold a pointer to the original Constant to preserve a memory? constOp = ov::as_type_ptr(op); + if (constOp) { constant = ConstantType::Const; - cloneBlobIfRequired(); + memoryPtr = createMemoryForConstantOp(constOp, getEngine()); } else { constant = ConstantType::StrictNoConst; } } -void Input::cloneBlobIfRequired() { - Shape shape(constOp->get_shape().empty() ? 
ov::Shape(1, 1) : constOp->get_shape()); - const auto prec = constOp->get_element_type(); - const size_t size = shape.getElementsCount(); - CpuBlockedMemoryDesc memDesc(prec, shape); - - bool needFlushDenormalsToZero = true; - if (context->getConfig().DAZOn) { - // DAZ has been set, processor automatically converts all denormal source operands - // to a zero with the sign of the original operand before performing any - // computations on them, thus no need to flush them to zero manually - needFlushDenormalsToZero = false; - } - - auto cloneBlob = [&, this] () { - MemoryPtr memory; - - // CVS-74980 - // oneDNN always allocate 1byte for element type with bitWidth < 8 (u4,u1...) - // but ngraph Constant uses actual bitWidth for data storage allocation - // in that case we make a copy to avoid overflow - if (constOp->get_byte_size() >= memDesc.getCurrentMemSize()) { - if (constOp->get_element_type() == element::string) { - memory = std::make_shared(getEngine(), memDesc, constOp->get_data_ptr()); - } else { - memory = std::make_shared(getEngine(), memDesc, constOp->get_data_ptr()); - } - } else { - if (constOp->get_element_type() == element::string) { - memory = std::make_shared(getEngine(), memDesc); - auto src = constOp->get_data_ptr(); - auto dst = memory->getDataAs(); - std::copy(src, src + size, dst); - } else { - memory = std::make_shared(getEngine(), memDesc); - memcpy(memory->getData(), constOp->get_data_ptr(), constOp->get_byte_size()); - } - } - - MemoryPtr ptr; - if (memDesc.getPrecision() == element::string) { - ptr = std::make_shared(getEngine(), memDesc); - } else { - ptr = std::make_shared(getEngine(), memDesc); - } - ptr->load(*memory.get(), needFlushDenormalsToZero); - - return ptr; - }; - - auto isBlobAligned = [&, this] () { - const void *ptr = constOp->get_data_ptr(); - bool blobAlignedOnSSE = true; -#if defined(OPENVINO_ARCH_X86) || defined(OPENVINO_ARCH_X86_64) - // Majority of arithmetic and data processing instructions in legacy SSE isa requires 
- // the memory address in the operands must be aligned on 16-byte boundary. To ensure - // safely reusing ngraph const blob memory, need to check address alignment. - blobAlignedOnSSE = mayiuse(cpu_isa_t::avx2) || ((reinterpret_cast(ptr) & 15) == 0); -#endif - const bool blobAlignedWithPrec = prec.size() > 1 ? (reinterpret_cast(ptr) % prec.size()) == 0 : true; - return blobAlignedWithPrec && blobAlignedOnSSE; - }; - - // The presence of subnormals is better to determined at IR read time. - auto hasSubnormals = [&, this] () { - if (prec == ov::element::f32) { - uint32_t const *u32data = constOp->get_data_ptr(); - - if (!size) - return false; - -#if defined(OPENVINO_ARCH_X86_64) - if (auto fn = jit_has_subnormals_function()) { - static const size_t batch_size = 2048; - const size_t iterations_num = size / batch_size + 1; - - volatile bool has_subnormals = false; - - parallel_for(iterations_num, [&](int n) { - auto ptr = u32data + n * batch_size; - const jit_has_subnormals_base::args_t args = { - reinterpret_cast(ptr), - std::min(batch_size, (size_t)(u32data + size - ptr)), - false - }; - - fn(&args); - - if (args.hasSubnormals) - has_subnormals = true; - }); - - return has_subnormals; - } -#endif - - uint32_t mantissaMask = 0x007fffff; - uint32_t exponentMask = 0x7f800000; - for (size_t i = 0; i < size; ++i) { - if ((u32data[i] & exponentMask) == 0 && (u32data[i] & mantissaMask) != 0) { - return true; - } - } - } - return false; - }; - - auto blobKey = [&, this] () { - char ptr[32]; - snprintf(ptr, sizeof ptr, "%p", constOp->get_data_ptr()); - return getName() - + "_" + std::to_string(size * prec.size()) - + "_" + ptr; - }; - - const auto weightCache = context->getWeightsCache(); - const bool clone_is_not_needed = - prec != element::string && - // IRs already have all subnormals flushed to zero, but in - // read_model scenario with directly loaded original model still can have subnormals - isBlobAligned() && (!needFlushDenormalsToZero || !hasSubnormals()) && - // 
Blob should be cloned in cache only if original weights are stored on other numa node. - // This is possible only in multistream case on multisocket machine. - // TODO: don't clone blob for multisocket + multistream case if current stream is run on the numa node where original weights are stored. - (!weightCache || context->getNumNumaNodes() == 1 || context->getCPUStreamExecutor()->get_streams_num() == 1); - - memoryPtr = clone_is_not_needed ? std::make_shared(getEngine(), memDesc, constOp->get_data_ptr()) - : std::const_pointer_cast( - weightCache ? *weightCache->findOrCreate(blobKey(), cloneBlob) : cloneBlob()); -} - static std::vector createInputShapes(const Shape& shape, const Type type) { if (type == Type::Output) diff --git a/src/plugins/intel_cpu/src/nodes/input.h b/src/plugins/intel_cpu/src/nodes/input.h index 9b304e5a75a891..40d55aeab61dd9 100644 --- a/src/plugins/intel_cpu/src/nodes/input.h +++ b/src/plugins/intel_cpu/src/nodes/input.h @@ -25,8 +25,12 @@ class Input : public Node { void initSupportedPrimitiveDescriptors() override; void createPrimitive() override; bool created() const override; - void withMeanImage(); + + void setMemoryPtr(MemoryCPtr memory) { + memoryPtr = memory; + } + MemoryCPtr getMemoryPtr() const; void execute(dnnl::stream strm) override {} diff --git a/src/plugins/intel_cpu/src/nodes/reorder.cpp b/src/plugins/intel_cpu/src/nodes/reorder.cpp index 9b521cdb3b57c7..e199e0b1c41797 100644 --- a/src/plugins/intel_cpu/src/nodes/reorder.cpp +++ b/src/plugins/intel_cpu/src/nodes/reorder.cpp @@ -27,6 +27,7 @@ #include "utils/precision_support.h" #include "nodes/executors/executor.hpp" #include "nodes/executors/transpose_list.hpp" +#include "nodes/common/subnormals_to_zero.h" namespace ov { namespace intel_cpu { @@ -438,7 +439,7 @@ std::string Reorder::getReorderArgs(const MemoryDesc &parentDesc, const MemoryDe return inArgs + "_" + outArgs; } -void Reorder::reorderData(const IMemory &input, const IMemory &output, MultiCachePtr cache) { +void 
Reorder::reorderData(const IMemory &input, const IMemory &output, MultiCachePtr cache, bool ftz) { if (!input.getDesc().isDefined() || !output.getDesc().isDefined()) OPENVINO_THROW("Can't reorder data with dynamic shapes"); @@ -515,6 +516,32 @@ void Reorder::reorderData(const IMemory &input, const IMemory &output, MultiCach OPENVINO_THROW("Could not make onednn reorder."); } } + + if (!ftz) { + return; + } + + if (input.getDesc().getPrecision() != ov::element::f32 || output.getDesc().getPrecision() == ov::element::bf16) { + return; + } + + size_t offset = 0; + if (output.getDesc().getType() & MemoryDescType::Dnnl) { + // here we can safely cast to DnnlMemoryDesc + auto dnnl_desc = output.getDescWithType(); + auto desc = dnnl_desc->getDnnlDesc(); + dnnl::impl::memory_desc_wrapper wrapper(desc.get()); + offset = wrapper.offset0(); + if (wrapper.is_wino_desc() || wrapper.is_rnn_packed_desc()) { + return; + } + } + // actual FTZ + auto* memData = static_cast(output.getData()); + memData += offset; + + // @todo can we omit flushing subnormals in case of down conversion, i.e. FP32 -> FP16?
+ setSubnormalsToZero(memData, output.getSize() / sizeof(float)); } } // namespace node diff --git a/src/plugins/intel_cpu/src/nodes/reorder.h b/src/plugins/intel_cpu/src/nodes/reorder.h index ab94b60b6a4a18..a067fc287112bd 100644 --- a/src/plugins/intel_cpu/src/nodes/reorder.h +++ b/src/plugins/intel_cpu/src/nodes/reorder.h @@ -52,7 +52,7 @@ class Reorder : public Node { static std::string getReorderArgs(const MemoryDesc &parentDesc, const MemoryDesc &childDesc); - static void reorderData(const IMemory &input, const IMemory &output, MultiCachePtr cache = nullptr); + static void reorderData(const IMemory &input, const IMemory &output, MultiCachePtr cache = nullptr, bool ftz = false); private: dnnl::reorder::primitive prim; diff --git a/src/plugins/intel_cpu/src/utils/clone_original_blob.cpp b/src/plugins/intel_cpu/src/utils/clone_original_blob.cpp new file mode 100644 index 00000000000000..4b5862a705f369 --- /dev/null +++ b/src/plugins/intel_cpu/src/utils/clone_original_blob.cpp @@ -0,0 +1,121 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "clone_original_blob.h" +#include + +#include "cpu_memory.h" +#include "graph_context.h" +#include "cpu/x64/jit_generator.hpp" +#include "memory_desc/dnnl_blocked_memory_desc.h" +#include "nodes/common/has_subnormals.h" +#include "openvino/core/parallel.hpp" +// #include "dnnl_extension_utils.h" +#include "memory_desc/cpu_memory_desc_utils.h" +#include "openvino/core/type/element_type.hpp" +#include "ov_optional.hpp" +#include "dnnl_extension_utils.h" +#include "utils/debug_capabilities.h" + +using namespace dnnl; +using namespace dnnl::impl::cpu::x64; +using namespace Xbyak; + +namespace ov { +namespace intel_cpu { + +MemoryPtr cloneBlob(const IMemory& blob, const dnnl::engine& engine, bool needFlushDenormalsToZero) { + const auto& memDesc = blob.getDesc(); + const auto prec = blob.getPrecision(); + const size_t size = blob.getShape().getElementsCount(); + MemoryPtr memory; + 
+ // CVS-74980 + // oneDNN always allocate 1byte for element type with bitWidth < 8 (u4,u1...) + // but ngraph Constant uses actual bitWidth for data storage allocation + // in that case we make a copy to avoid overflow + if (blob.getSize() >= memDesc.getCurrentMemSize()) { + if (prec == element::string) { + memory = std::make_shared(engine, memDesc, blob.getDataAs()); + } else { + memory = std::make_shared(engine, memDesc, blob.getData()); + } + } else { + if (prec == element::string) { + memory = std::make_shared(engine, memDesc); + auto src = blob.getDataAs(); + auto dst = memory->getDataAs(); + std::copy(src, src + size, dst); + } else { + memory = std::make_shared(engine, memDesc); + memcpy(memory->getData(), blob.getData(), blob.getSize()); + } + } + + MemoryPtr ptr; + if (memDesc.getPrecision() == element::string) { + ptr = std::make_shared(engine, memDesc); + } else { + ptr = std::make_shared(engine, memDesc); + } + + ptr->load(*memory.get(), needFlushDenormalsToZero); + + return ptr; +} + +InputPrepType requiresPreProcessing(const IMemory& blob, GraphContext::CPtr context, const dnnl::engine& engine) { + const auto shape = blob.getShape(); + const auto prec = blob.getPrecision(); + + // DAZ has been set, processor automatically converts all denormal source operands + // to a zero with the sign of the original operand before performing any + // computations on them, thus no need to flush them to zero manually + bool needFlushDenormalsToZero = context->getConfig().DAZOn ? false : true; + + auto isBlobAligned = [&] () { + const void *ptr = blob.getData(); + bool blobAlignedOnSSE = true; +#if defined(OPENVINO_ARCH_X86) || defined(OPENVINO_ARCH_X86_64) + // Majority of arithmetic and data processing instructions in legacy SSE isa requires + // the memory address in the operands must be aligned on 16-byte boundary. To ensure + // safely reusing ngraph const blob memory, need to check address alignment. 
+ blobAlignedOnSSE = mayiuse(dnnl::impl::cpu::x64::avx2) || ((reinterpret_cast(ptr) & 15) == 0); +#endif + const bool blobAlignedWithPrec = prec.size() > 1 ? (reinterpret_cast(ptr) % prec.size()) == 0 : true; + return blobAlignedWithPrec && blobAlignedOnSSE; + }; + + // @WARNING The order of the checks below matters + // The checks are ordered from lightweight to heavy + if (prec == element::string) { + DEBUG_LOG("Clone is necessary for a string Constant"); + return InputPrepType::SimpleClone; + } + + const bool mustFlushDenormalsToZero = needFlushDenormalsToZero && std::make_shared()->execute(blob); + if (mustFlushDenormalsToZero) { + DEBUG_LOG("Clone is necessary for Constant containing subnormals"); + return InputPrepType::FTZ; + } + + if (!isBlobAligned()) { + DEBUG_LOG("Clone is necessary for not aligned blobs"); + return InputPrepType::SimpleClone; + } + + if (context->getWeightsCache() && + context->getNumNumaNodes() > 1 && + context->getCPUStreamExecutor()->get_streams_num() > 1) { + DEBUG_LOG("Clone is necessary for multistream multisocket configuration"); + return InputPrepType::PutToNumaLocalCache; + } + + DEBUG_LOG("Clone is not required"); + + return InputPrepType::None; +} + +} // namespace intel_cpu +} // namespace ov diff --git a/src/plugins/intel_cpu/src/utils/clone_original_blob.h b/src/plugins/intel_cpu/src/utils/clone_original_blob.h new file mode 100644 index 00000000000000..eb259f1bf5be4e --- /dev/null +++ b/src/plugins/intel_cpu/src/utils/clone_original_blob.h @@ -0,0 +1,25 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// +#pragma once + +#include "cpu_memory.h" +#include "graph_context.h" + +namespace ov { +namespace intel_cpu { + +enum InputPrepType { + FTZ, + PutToNumaLocalCache, + SimpleClone, + None +}; + +MemoryPtr cloneBlob(const IMemory& blob, const dnnl::engine& engine, bool needFlushDenormalsToZero); +InputPrepType requiresPreProcessing(const IMemory& blob, + GraphContext::CPtr context, + const 
dnnl::engine& engine); + +} // namespace intel_cpu +} // namespace ov diff --git a/src/plugins/intel_cpu/tests/unit/graph/dummy_node.hpp b/src/plugins/intel_cpu/tests/unit/graph/dummy_node.hpp index db11f274623726..d0feacef3f251b 100644 --- a/src/plugins/intel_cpu/tests/unit/graph/dummy_node.hpp +++ b/src/plugins/intel_cpu/tests/unit/graph/dummy_node.hpp @@ -4,38 +4,84 @@ #pragma once +#include #include "cpu_shape.h" +#include "memory_desc/blocked_memory_desc.h" #include "node.h" #include "graph_context.h" #include "edge.h" -#include "openvino/core/shape.hpp" +#include "node_config.h" +#include "openvino/core/partial_shape.hpp" namespace ov { namespace intel_cpu { namespace cpu_unit_test { -class DummyNode : public Node { -public: - // dummy node of the same shape and precision to both input and output. - DummyNode(const ov::PartialShape& shape, - const ov::element::Type_t& prc, - const std::string& name, - const std::string& type, - const GraphContext::CPtr context, - LayoutType layout = LayoutType::ncsp, - int in_place_direction = Edge::LOOK::LOOK_UP, - bool is_executable = false) : - Node(type, - {ov::intel_cpu::Shape(shape)}, - {ov::intel_cpu::Shape(shape)}, - {prc}, - {prc}, - name, - context), - m_layout(layout), - m_inplace(in_place_direction), - m_is_executable(is_executable) { +static std::shared_ptr createSharedDesc(const ov::PartialShape& shape, + const ov::element::Type& precision, + const LayoutType layout = LayoutType::ncsp) { + const auto& layoutCreator = BlockedDescCreator::getCommonCreators().at(layout); + return layoutCreator->createSharedDesc(precision, {ov::intel_cpu::Shape(shape)}); +} + +struct NoOpExecutor { + void operator()(){ + // do nothing } +}; + +class GenericNode : public Node { +public: + GenericNode(const ov::PartialShape& shape, + const ov::element::Type_t& prc, + const std::string& name, + const std::string& type, + const GraphContext::CPtr context, + std::vector inplace_input, + std::vector inplace_output, + bool is_executable) + 
: Node(type, {ov::intel_cpu::Shape(shape)}, {ov::intel_cpu::Shape(shape)}, {prc}, {prc}, name, context), + inputPortsConfig(inplace_input), + outputPortsConfig(inplace_output), + m_is_executable(is_executable) {} + + // single input single output node of the same shape and precision to both input and output. + GenericNode(const ov::PartialShape& shape, + const ov::element::Type_t& prc, + const std::string& name, + const std::string& type, + const GraphContext::CPtr context, + LayoutType layout = LayoutType::ncsp, + int in_place_direction = Edge::LOOK::LOOK_UP, + bool is_executable = false) + : GenericNode(shape, + prc, + name, + type, + context, + std::vector{ + { + createSharedDesc(shape, prc, layout), + BlockedMemoryDesc::FULL_MASK, + in_place_direction == static_cast(Edge::LOOK_DOWN) || + in_place_direction == static_cast(Edge::LOOK_BOTH) + ? 0 + : -1, + false + } + }, + std::vector{ + { + createSharedDesc(shape, prc, layout), + BlockedMemoryDesc::FULL_MASK, + in_place_direction == static_cast(Edge::LOOK_UP) || + in_place_direction == static_cast(Edge::LOOK_BOTH) + ? 0 + : -1, + false + } + }, + is_executable) {} void getSupportedDescriptors() override { if (getParentEdges().size() != 1) @@ -48,28 +94,30 @@ class DummyNode : public Node { if (!supportedPrimitiveDescriptors.empty()) return; - NodeConfig config; - config.inConfs.resize(1); - config.outConfs.resize(1); + NodeConfig nodeConfig; + nodeConfig.inConfs.reserve(inputPortsConfig.size()); + nodeConfig.outConfs.reserve(outputPortsConfig.size()); - config.inConfs[0].inPlace(m_inplace == static_cast(Edge::LOOK::LOOK_DOWN) || - m_inplace == static_cast(Edge::LOOK::LOOK_BOTH)? 0 : -1); - config.inConfs[0].constant(false); - config.outConfs[0].inPlace(m_inplace == static_cast(Edge::LOOK::LOOK_UP) || - m_inplace == static_cast(Edge::LOOK::LOOK_BOTH) ? 
0 : -1); - config.outConfs[0].constant(false); + for (const auto& config : inputPortsConfig) { + nodeConfig.inConfs.push_back(config); + } - auto layoutCreator = BlockedDescCreator::getCommonCreators().at(m_layout); - auto& originInputPrecisions = getOriginalInputPrecisions(); - config.inConfs[0].setMemDesc(layoutCreator->createSharedDesc(originInputPrecisions[0], getInputShapeAtPort(0))); - config.outConfs[0].setMemDesc(layoutCreator->createSharedDesc(originInputPrecisions[0], getOutputShapeAtPort(0))); + for (const auto& config : outputPortsConfig) { + nodeConfig.outConfs.push_back(config); + } - supportedPrimitiveDescriptors.emplace_back(config, impl_desc_type::undef); + supportedPrimitiveDescriptors.emplace_back(nodeConfig, impl_desc_type::undef); }; - bool isExecutable() const override {return m_is_executable;} + bool isExecutable() const override { + return m_is_executable; + } + void execute(dnnl::stream strm) override {}; - bool created() const override {return true;} + + bool created() const override { + return true; + } bool needPrepareParams() const override { return false; @@ -83,10 +131,10 @@ class DummyNode : public Node { using Node::Node; private: - LayoutType m_layout = LayoutType::ncsp; - int m_inplace = Edge::LOOK::LOOK_UP; + std::vector inputPortsConfig; + std::vector outputPortsConfig; bool m_is_executable = false; }; -} // namespace cpu_unit_test +} // namespace cpu_unit_test } // namespace intel_cpu } // namespace ov diff --git a/src/plugins/intel_cpu/tests/unit/graph/inplace_resolve_io.cpp b/src/plugins/intel_cpu/tests/unit/graph/inplace_resolve_io.cpp index a41cb4c4300d42..c976ae80b26f19 100644 --- a/src/plugins/intel_cpu/tests/unit/graph/inplace_resolve_io.cpp +++ b/src/plugins/intel_cpu/tests/unit/graph/inplace_resolve_io.cpp @@ -113,7 +113,7 @@ Edge Concat -> Result0 can share memory of inference output; Reshape1 -> Result1 inputNodes.push_back(std::make_shared(params[i], context)); } - auto dummy_softmax = std::make_shared( + auto 
dummy_softmax = std::make_shared( params[0]->get_output_partial_shape(0), testPrec, "Softmax0" /*name*/, "DummyNode" /*type*/, context, LayoutType::ncsp, 0/*look*/); auto concat = std::make_shared(ov::OutputVector{params[0], params[0]}, 0); // default, the connection will be reset by addEdge @@ -136,7 +136,7 @@ Edge Concat -> Result0 can share memory of inference output; Reshape1 -> Result1 hidden_size, RecurrentSequenceDirection::FORWARD); auto rnnseqNode = std::make_shared(rnnseq, context); - auto dummy_reshape = std::make_shared( + auto dummy_reshape = std::make_shared( rnnseq->get_output_partial_shape(0), testPrec, "Reshape1" /*name*/, "DummyNode" /*type*/, context, LayoutType::ncsp, Edge::LOOK::LOOK_BOTH); auto outputNode0 = std::make_shared(results.front(), context); @@ -196,13 +196,13 @@ expect edge Reshape0->Result1 to be referenced by its upstreams, instead of refe outputNodes.push_back(std::make_shared(results[i], context)); } - auto dummy_softmax = std::make_shared( + auto dummy_softmax = std::make_shared( testShape, testPrec, "softmax" /*name*/, "DummyNode" /*type*/, context, LayoutType::ncsp, 0/*look*/); - auto dummy_add = std::make_shared( + auto dummy_add = std::make_shared( testShape, testPrec, "add" /*name*/, "DummyNode" /*type*/, context, LayoutType::ncsp, 0/*look*/); - auto dummy_reshape = std::make_shared( + auto dummy_reshape = std::make_shared( testShape, testPrec, "reshape" /*name*/, "DummyNode" /*type*/, context, LayoutType::ncsp, Edge::LOOK::LOOK_BOTH); addEdge(inputNodes.front(), dummy_softmax, 0, 0); @@ -253,16 +253,16 @@ could get a chance to be referenced by infer request. 
outputNodes.push_back(std::make_shared(results[i], context)); } - auto dummy_softmax = std::make_shared( + auto dummy_softmax = std::make_shared( testShape, testPrec, "softmax" /*name*/, "DummyNode" /*type*/, context, LayoutType::ncsp, 0/*look*/); - auto dummy_add = std::make_shared( + auto dummy_add = std::make_shared( testShape, testPrec, "add" /*name*/, "DummyNode" /*type*/, context, LayoutType::ncsp, 0/*look*/); - auto dummy_reshape0 = std::make_shared( + auto dummy_reshape0 = std::make_shared( testShape, testPrec, "reshape0" /*name*/, "DummyNode" /*type*/, context, LayoutType::ncsp, Edge::LOOK::LOOK_BOTH); - auto dummy_reshape1 = std::make_shared( + auto dummy_reshape1 = std::make_shared( testShape, testPrec, "reshape1" /*name*/, "DummyNode" /*type*/, context, LayoutType::ncsp, Edge::LOOK::LOOK_BOTH); addEdge(inputNodes.front(), dummy_softmax, 0, 0); @@ -308,7 +308,7 @@ Edge Reshape0 -> Result0 cannot be referenced by its upstreams as its upstream i outputNodes.push_back(std::make_shared(results[i], context)); } - auto dummy_reshape = std::make_shared( + auto dummy_reshape = std::make_shared( testShape, testPrec, "reshape0" /*name*/, "DummyNode" /*type*/, context, LayoutType::ncsp, Edge::LOOK::LOOK_BOTH); addEdge(inputNodes.front(), dummy_reshape, 0, 0); diff --git a/src/plugins/intel_cpu/tests/unit/graph/memory_state.cpp b/src/plugins/intel_cpu/tests/unit/graph/memory_state.cpp index 5b9468ffc35e6f..0de73de77bc3e9 100644 --- a/src/plugins/intel_cpu/tests/unit/graph/memory_state.cpp +++ b/src/plugins/intel_cpu/tests/unit/graph/memory_state.cpp @@ -86,7 +86,7 @@ TEST(MemStateGraphTest, smoke_Check_Memory_Modification_Guard) { auto input_node = std::make_shared(param, context); auto memory_input = std::make_shared(read, context); - auto first_dummy = std::make_shared( + auto first_dummy = std::make_shared( test_shape, test_prec, "first_dummy", "DummyNode", @@ -96,7 +96,7 @@ TEST(MemStateGraphTest, smoke_Check_Memory_Modification_Guard) { true); auto 
memory_output = std::make_shared(assign, context); - auto second_dummy = std::make_shared( + auto second_dummy = std::make_shared( test_shape, test_prec, "second_dummy", "DummyNode", context, LayoutType::ncsp, Edge::LOOK::LOOK_UP, true); auto softmax_node = std::make_shared(softmax, context); auto output_node = std::make_shared(res, context); @@ -285,7 +285,7 @@ TEST(MemStateGraphTest, smoke_ShapeOf_no_Inplace_Conflicts) { auto input_node = std::make_shared(param, context); auto memory_input = std::make_shared(read, context); - auto dummy = std::make_shared( + auto dummy = std::make_shared( test_shape, test_prec, "first_dummy", "DummyNode", diff --git a/src/plugins/intel_cpu/tests/unit/graph/merge_transpose_reorder_test.cpp b/src/plugins/intel_cpu/tests/unit/graph/merge_transpose_reorder_test.cpp index 003aca979398fb..bc4008132b34c4 100644 --- a/src/plugins/intel_cpu/tests/unit/graph/merge_transpose_reorder_test.cpp +++ b/src/plugins/intel_cpu/tests/unit/graph/merge_transpose_reorder_test.cpp @@ -118,7 +118,7 @@ class MergeTransposeReorderCPUTest : public testing::WithParamInterface(params[0], m_context); - auto dummyNode1 = std::make_shared( + auto dummyNode1 = std::make_shared( testShape, precision, "reshape", "DummyNode", m_context, firstNodeLayout, firstNodeInplaceDirection); auto orderNode = std::make_shared(constOrder, m_context); @@ -131,7 +131,7 @@ class MergeTransposeReorderCPUTest : public testing::WithParamInterfaceget_output_shape(0); for (size_t i = 0; i < num_consumers; i++) { - auto dummyConsumer = std::make_shared(transpose_shape, + auto dummyConsumer = std::make_shared(transpose_shape, precision, "multiply", "DummyNode", @@ -232,7 +232,7 @@ class MergeTransposeReorderWithReshapeCPUTest : public MergeTransposeReorderCPUT }; auto inputNode = std::make_shared(param, m_context); - auto dummyNode1 = std::make_shared( + auto dummyNode1 = std::make_shared( testShape, precision, "before_reshape", "DummyNode", m_context, LayoutType::nspc, LOOK::LOOK_UP); 
auto reshapeConstNode = std::make_shared(reshape_const, m_context); @@ -250,7 +250,7 @@ class MergeTransposeReorderWithReshapeCPUTest : public MergeTransposeReorderCPUT const auto& transpose_shape = transpose->get_output_shape(0); for (size_t i = 0; i < num_consumers; i++) { - auto dummyConsumer = std::make_shared(transpose_shape, + auto dummyConsumer = std::make_shared(transpose_shape, precision, "multiply", "DummyNode", diff --git a/src/plugins/intel_cpu/tests/unit/graph/resolve_edge_conflicts_test.cpp b/src/plugins/intel_cpu/tests/unit/graph/resolve_edge_conflicts_test.cpp index b44194a3d5806c..b2ceaf2c457ccf 100644 --- a/src/plugins/intel_cpu/tests/unit/graph/resolve_edge_conflicts_test.cpp +++ b/src/plugins/intel_cpu/tests/unit/graph/resolve_edge_conflicts_test.cpp @@ -57,13 +57,13 @@ TEST(ResolveEdgeConflictsCPUTest, smoke_Run_ResolveEdgeConflicts) { auto inputNode = std::make_shared(params[0], context); auto outputNode = std::make_shared(results[0], context); auto concatNode = std::make_shared(concat, context); - auto dummyNode1 = std::make_shared( + auto dummyNode1 = std::make_shared( testShape, testPrec, "Dummy1", "DummyNode", context); - auto dummyNode2 = std::make_shared( + auto dummyNode2 = std::make_shared( testShape, testPrec, "Dummy2", "DummyNode", context); - auto dummyNode3 = std::make_shared( + auto dummyNode3 = std::make_shared( testShape, testPrec, "Dummy3", "DummyNode", context, LayoutType::ncsp, Edge::LOOK::LOOK_UP, true); - auto dummyNode4 = std::make_shared( + auto dummyNode4 = std::make_shared( testShape, testPrec, "Dummy4", "DummyNode", context, LayoutType::ncsp, 0, true); std::vector graphNodes;