Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[STFT][CPU] Improve performance of STFT for CPU by reusing the RDFT JIT executor #26967

Merged
merged 33 commits into from
Dec 19, 2024
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
Show all changes
33 commits
Select commit Hold shift + click to select a range
41185fd
Init STFT CPU by reference impl call
mitruska Oct 8, 2024
7c56896
Enable ref tests for cpu
mitruska Oct 8, 2024
843a6b7
STFT impl with RDFT executor
mitruska Oct 14, 2024
9465aea
CPU Layer tests
mitruska Oct 18, 2024
ccee7c3
Merge remote-tracking branch 'upstream/master' into mitruska/stft_cpu
mitruska Nov 29, 2024
e522095
Fix int cast warning
mitruska Nov 29, 2024
d2f8eac
Merge remote-tracking branch 'upstream/master' into mitruska/stft_cpu
mitruska Dec 3, 2024
b9ede8b
Add 1D support, fix failing tests
mitruska Dec 3, 2024
2cc727d
Revert template tests changes
mitruska Dec 3, 2024
085681f
Add JIT RDFT executor
mitruska Dec 6, 2024
6cee735
Merge remote-tracking branch 'upstream/master' into mitruska/stft_cpu
mitruska Dec 10, 2024
9abd14d
Move RDFT key to rdft.h
mitruska Dec 11, 2024
0fb0b9f
Add createPrimitive() for STFT
mitruska Dec 11, 2024
1ca31ea
Remove prepare params
mitruska Dec 11, 2024
728b8f9
Adjust test abs_threshold
mitruska Dec 11, 2024
30addd3
Merge remote-tracking branch 'upstream/master' into mitruska/stft_cpu
mitruska Dec 11, 2024
8b17ad4
Merge upstream/master
mitruska Dec 12, 2024
bd9d080
Merge branch 'master' into mitruska/stft_cpu
mlukasze Dec 13, 2024
a286147
Init executor with nullptr
mitruska Dec 13, 2024
5eb0895
Move functions to anonymous namespace
mitruska Dec 13, 2024
57e147c
Remove axes_order attr from transpose func
mitruska Dec 13, 2024
00cef70
Use uint8_t instead of char in transpose_out
mitruska Dec 13, 2024
b2cbe90
Make members private
mitruska Dec 13, 2024
72acaad
Rename transpose_out to transpose_out4d
mitruska Dec 13, 2024
dc1ac1f
Merge branch 'mitruska/stft_cpu' of github.com:mitruska/openvino into…
mitruska Dec 13, 2024
5947a39
Optimize transpose of stft out by last dim copy
mitruska Dec 13, 2024
0f1ea39
Merge remote-tracking branch 'upstream/master' into mitruska/stft_cpu
mitruska Dec 13, 2024
789a472
Revert exposing rdft executors, and create common RDFTExecutor builder
mitruska Dec 16, 2024
f64afbe
Ensure primDesc is not null before setImplType
mitruska Dec 16, 2024
5a8c341
Use getScratchPadMem for interim results if transpose
mitruska Dec 16, 2024
d1ec726
Merge remote-tracking branch 'upstream/master' into mitruska/stft_cpu
mitruska Dec 16, 2024
e546026
Merge remote-tracking branch 'upstream/master' into mitruska/stft_cpu
mitruska Dec 19, 2024
9a02a0e
Remove custom abs_threshold for tests
mitruska Dec 19, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions src/core/shape_inference/include/stft_shape_inference.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,15 +46,17 @@ std::vector<TRShape> shape_infer(const STFT* op,
if (signal_shape.rank().is_dynamic()) {
return {signal_shape};
} else if (!frame_size || !frame_step) {
return {TRShape{signal_shape[0], -1, -1, 2}};
return {TRShape{signal_shape[0], TDim(ov::util::dim::inf_bound), TDim(ov::util::dim::inf_bound), 2}};
}

const auto& frame_size_val = (*frame_size)[0];
const auto& frame_step_val = (*frame_step)[0];

NODE_SHAPE_INFER_CHECK(op,
input_shapes,
0 < frame_size_val && frame_size_val < signal_shape[1].get_interval().get_max_val(),
0 < frame_size_val && (signal_shape[1].is_static()
mitruska marked this conversation as resolved.
Show resolved Hide resolved
? frame_size_val < signal_shape[1].get_length()
: frame_size_val < signal_shape[1].get_interval().get_max_val()),
"Provided frame size is ",
frame_size_val,
" but must be in range [1, ",
Expand All @@ -77,7 +79,7 @@ std::vector<TRShape> shape_infer(const STFT* op,
"].");

const auto& batch_dim = signal_shape[0];
const TDim frame_size_dim = TDim{frame_size_val};
const TDim frame_size_dim = static_cast<TDim>(frame_size_val);
const TDim signal_frame_size_diff = signal_shape[1] - frame_size_dim;
TDim fft_samples_dim = (frame_size_val / 2) + 1;

Expand Down
2 changes: 2 additions & 0 deletions src/plugins/intel_cpu/src/cpu_types.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,7 @@ static const TypeToNameMap& get_type_to_name_tbl() {
{"IDFT", Type::DFT},
{"RDFT", Type::RDFT},
{"IRDFT", Type::RDFT},
{"STFT", Type::STFT},
{"Abs", Type::Math},
{"Acos", Type::Math},
{"Acosh", Type::Math},
Expand Down Expand Up @@ -339,6 +340,7 @@ std::string NameFromType(const Type type) {
CASE(ShuffleChannels);
CASE(DFT);
CASE(RDFT);
CASE(STFT);
CASE(Math);
CASE(CTCLoss);
CASE(Bucketize);
Expand Down
1 change: 1 addition & 0 deletions src/plugins/intel_cpu/src/cpu_types.h
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@ enum class Type {
ShuffleChannels,
DFT,
RDFT,
STFT,
Math,
CTCLoss,
Bucketize,
Expand Down
223 changes: 110 additions & 113 deletions src/plugins/intel_cpu/src/nodes/rdft.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -842,126 +842,123 @@ struct RDFTJitExecutor : public RDFTExecutor {
};
#endif

struct RDFTRefExecutor : public RDFTExecutor {
RDFTRefExecutor(bool inverse) : RDFTExecutor(inverse) {}
RDFTRefExecutor::RDFTRefExecutor(bool inverse) : RDFTExecutor(inverse) {}

private:
std::vector<float> generateTwiddlesDFT(size_t inputSize, size_t outputSize, enum dft_type type) override {
std::vector<float> twiddles(inputSize * outputSize * 2);
parallel_for2d(outputSize, inputSize, [&] (size_t k, size_t n) {
double angle = 2 * PI * k * n / inputSize;
if (!isInverse)
angle = -angle;
twiddles[(k * inputSize + n) * 2] = std::cos(angle);
twiddles[(k * inputSize + n) * 2 + 1] = std::sin(angle);
});
return twiddles;
}

void dftRealToComplex(float* inputPtr, const float* twiddlesPtr, float* outputPtr,
size_t inputSize, size_t outputSize, bool parallelize) {
auto dftIteration = [&] (size_t k) {
float real = 0, imag = 0;
for (size_t n = 0; n < inputSize; n++) {
float cos = twiddlesPtr[2 * (k * inputSize + n)];
float sin = twiddlesPtr[2 * (k * inputSize + n) + 1];
real += inputPtr[n] * cos;
imag += inputPtr[n] * sin;
}
outputPtr[2 * k] = real;
outputPtr[2 * k + 1] = imag;
};
if (parallelize) {
parallel_for(outputSize, dftIteration);
} else {
for (size_t k = 0; k < outputSize; k++) {
dftIteration(k);
}
}
}
// Builds the direct (non-FFT) DFT twiddle table used by the reference executor.
// The result holds outputSize * inputSize interleaved (cos, sin) pairs; the pair for
// output index k and input index n starts at element 2 * (k * inputSize + n).
// The angle is 2 * PI * k * n / inputSize and is negated when the executor is not inverse.
// The `type` argument is not consulted by this implementation.
std::vector<float> RDFTRefExecutor::generateTwiddlesDFT(size_t inputSize, size_t outputSize, enum dft_type type) {
    std::vector<float> table(inputSize * outputSize * 2);
    parallel_for2d(outputSize, inputSize, [&](size_t row, size_t col) {
        const double magnitude = 2 * PI * row * col / inputSize;
        // Forward transform uses the negative exponent, inverse the positive one.
        const double angle = isInverse ? magnitude : -magnitude;
        const size_t pairOffset = (row * inputSize + col) * 2;
        table[pairOffset] = std::cos(angle);
        table[pairOffset + 1] = std::sin(angle);
    });
    return table;
}

void dftComplexToComplex(float* inputPtr, const float* twiddlesPtr, float* outputPtr,
size_t inputSize, size_t signalSize, size_t outputSize, bool parallelize) {
auto dftIteration = [&] (size_t k) {
float real = 0, imag = 0;
for (size_t n = 0; n < inputSize; n++) {
float cos = twiddlesPtr[2 * (k * outputSize + n)];
float sin = twiddlesPtr[2 * (k * outputSize + n) + 1];
float inputReal = inputPtr[2 * n];
float inputImag = inputPtr[2 * n + 1];
real += inputReal * cos - inputImag * sin;
imag += inputImag * cos + inputReal * sin;
}
if (isInverse) {
float* inp = inputPtr + 2 * (inputSize - 2 + outputSize % 2);
for (size_t n = inputSize; n < signalSize; n++, inp -= 2) {
float cos = twiddlesPtr[2 * (k * outputSize + n)];
float sin = twiddlesPtr[2 * (k * outputSize + n) + 1];
float inputReal = inp[0];
float inputImag = -inp[1];
real += inputReal * cos - inputImag * sin;
imag += inputImag * cos + inputReal * sin;
}
real /= outputSize;
imag /= outputSize;
}
outputPtr[2 * k] = real;
outputPtr[2 * k + 1] = imag;
};
if (parallelize) {
parallel_for(outputSize, dftIteration);
} else {
for (size_t k = 0; k < outputSize; k++) {
dftIteration(k);
}
}
void RDFTRefExecutor::dftRealToComplex(float* inputPtr, const float* twiddlesPtr, float* outputPtr,
size_t inputSize, size_t outputSize, bool parallelize) {
auto dftIteration = [&] (size_t k) {
float real = 0, imag = 0;
for (size_t n = 0; n < inputSize; n++) {
float cos = twiddlesPtr[2 * (k * inputSize + n)];
float sin = twiddlesPtr[2 * (k * inputSize + n) + 1];
real += inputPtr[n] * cos;
imag += inputPtr[n] * sin;
}

void dftComplexToReal(float* inputPtr, const float* twiddlesPtr, float* outputPtr,
size_t inputSize, size_t signalSize, size_t outputSize, bool parallelize) {
auto dftIteration = [&] (size_t k) {
float real = 0;
for (size_t n = 0; n < inputSize; n++) {
float cos = twiddlesPtr[2 * (k * outputSize + n)];
float sin = twiddlesPtr[2 * (k * outputSize + n) + 1];
float inputReal = inputPtr[2 * n];
float inputImag = inputPtr[2 * n + 1];
real += inputReal * cos - inputImag * sin;
}
if (isInverse) {
float* inp = inputPtr + 2 * (inputSize - 2 + outputSize % 2);
for (size_t n = inputSize; n < signalSize; n++, inp -= 2) {
float cos = twiddlesPtr[2 * (k * outputSize + n)];
float sin = twiddlesPtr[2 * (k * outputSize + n) + 1];
float inputReal = inp[0];
float inputImag = inp[1];
real += inputReal * cos + inputImag * sin;
}
real /= outputSize;
}
outputPtr[k] = real;
};
if (parallelize) {
parallel_for(outputSize, dftIteration);
} else {
for (size_t k = 0; k < outputSize; k++) {
dftIteration(k);
}
outputPtr[2 * k] = real;
outputPtr[2 * k + 1] = imag;
};
if (parallelize) {
parallel_for(outputSize, dftIteration);
} else {
for (size_t k = 0; k < outputSize; k++) {
dftIteration(k);
}
}
}

void RDFTRefExecutor::dftComplexToComplex(float* inputPtr, const float* twiddlesPtr, float* outputPtr,
size_t inputSize, size_t signalSize, size_t outputSize, bool parallelize) {
auto dftIteration = [&] (size_t k) {
float real = 0, imag = 0;
for (size_t n = 0; n < inputSize; n++) {
float cos = twiddlesPtr[2 * (k * outputSize + n)];
float sin = twiddlesPtr[2 * (k * outputSize + n) + 1];
float inputReal = inputPtr[2 * n];
float inputImag = inputPtr[2 * n + 1];
real += inputReal * cos - inputImag * sin;
imag += inputImag * cos + inputReal * sin;
}
if (isInverse) {
float* inp = inputPtr + 2 * (inputSize - 2 + outputSize % 2);
for (size_t n = inputSize; n < signalSize; n++, inp -= 2) {
float cos = twiddlesPtr[2 * (k * outputSize + n)];
float sin = twiddlesPtr[2 * (k * outputSize + n) + 1];
float inputReal = inp[0];
float inputImag = -inp[1];
real += inputReal * cos - inputImag * sin;
imag += inputImag * cos + inputReal * sin;
}
real /= outputSize;
imag /= outputSize;
}

void dft(float* inputPtr, const float* twiddlesPtr, float* outputPtr,
size_t inputSize, size_t signalSize, size_t outputSize,
enum dft_type type, bool parallelize) override {
if (type == real_to_complex) {
dftRealToComplex(inputPtr, twiddlesPtr, outputPtr, inputSize, outputSize, parallelize);
} else if (type == complex_to_complex) {
dftComplexToComplex(inputPtr, twiddlesPtr, outputPtr, inputSize, signalSize, outputSize, parallelize);
} else if (type == complex_to_real) {
dftComplexToReal(inputPtr, twiddlesPtr, outputPtr, inputSize, signalSize, outputSize, parallelize);
outputPtr[2 * k] = real;
outputPtr[2 * k + 1] = imag;
};
if (parallelize) {
parallel_for(outputSize, dftIteration);
} else {
for (size_t k = 0; k < outputSize; k++) {
dftIteration(k);
}
}
}

void RDFTRefExecutor::dftComplexToReal(float* inputPtr, const float* twiddlesPtr, float* outputPtr,
size_t inputSize, size_t signalSize, size_t outputSize, bool parallelize) {
auto dftIteration = [&] (size_t k) {
float real = 0;
for (size_t n = 0; n < inputSize; n++) {
float cos = twiddlesPtr[2 * (k * outputSize + n)];
float sin = twiddlesPtr[2 * (k * outputSize + n) + 1];
float inputReal = inputPtr[2 * n];
float inputImag = inputPtr[2 * n + 1];
real += inputReal * cos - inputImag * sin;
}
if (isInverse) {
float* inp = inputPtr + 2 * (inputSize - 2 + outputSize % 2);
for (size_t n = inputSize; n < signalSize; n++, inp -= 2) {
float cos = twiddlesPtr[2 * (k * outputSize + n)];
float sin = twiddlesPtr[2 * (k * outputSize + n) + 1];
float inputReal = inp[0];
float inputImag = inp[1];
real += inputReal * cos + inputImag * sin;
}
real /= outputSize;
}
};
outputPtr[k] = real;
};
if (parallelize) {
parallel_for(outputSize, dftIteration);
} else {
for (size_t k = 0; k < outputSize; k++) {
dftIteration(k);
}
}
}

// Dispatches one 1-D transform to the reference kernel matching `type`.
// real_to_complex forwards only inputSize/outputSize; the two complex-input kernels
// additionally receive signalSize. Any other `type` value is a no-op.
void RDFTRefExecutor::dft(float* inputPtr, const float* twiddlesPtr, float* outputPtr,
                          size_t inputSize, size_t signalSize, size_t outputSize,
                          enum dft_type type, bool parallelize) {
    switch (type) {
    case real_to_complex:
        dftRealToComplex(inputPtr, twiddlesPtr, outputPtr, inputSize, outputSize, parallelize);
        break;
    case complex_to_complex:
        dftComplexToComplex(inputPtr, twiddlesPtr, outputPtr, inputSize, signalSize, outputSize, parallelize);
        break;
    case complex_to_real:
        dftComplexToReal(inputPtr, twiddlesPtr, outputPtr, inputSize, signalSize, outputSize, parallelize);
        break;
    default:
        break;
    }
}

struct RDFTKey {
bool isInverse;
Expand Down
38 changes: 38 additions & 0 deletions src/plugins/intel_cpu/src/nodes/rdft.h
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,44 @@ struct RDFTExecutor {
enum dft_type type, bool useFFT);
};

struct RDFTRefExecutor : public RDFTExecutor {
a-sidorova marked this conversation as resolved.
Show resolved Hide resolved
RDFTRefExecutor(bool inverse);

private:
std::vector<float> generateTwiddlesDFT(size_t inputSize, size_t outputSize, enum dft_type type) override;
void dftRealToComplex(float* inputPtr,
const float* twiddlesPtr,
float* outputPtr,
size_t inputSize,
size_t outputSize,
bool parallelize);

void dftComplexToComplex(float* inputPtr,
const float* twiddlesPtr,
float* outputPtr,
size_t inputSize,
size_t signalSize,
size_t outputSize,
bool parallelize);

void dftComplexToReal(float* inputPtr,
const float* twiddlesPtr,
float* outputPtr,
size_t inputSize,
size_t signalSize,
size_t outputSize,
bool parallelize);

void dft(float* inputPtr,
const float* twiddlesPtr,
float* outputPtr,
size_t inputSize,
size_t signalSize,
size_t outputSize,
enum dft_type type,
bool parallelize) override;
};

class RDFT : public Node {
public:
RDFT(const std::shared_ptr<ov::Node>& op, const GraphContext::CPtr context);
Expand Down
Loading
Loading