/*
 * encoding.h / encoding.cpp  (patch "pos encoding")
 *
 *  Created on: Nov 20, 2023
 *      Author: mad
 *
 * Variable-length bit packing for small symbols (0..47).
 *
 * Code layout, least significant bit first, in groups of two bits:
 *   - symbols 0, 1, 2 are stored directly as one 2-bit group (00, 01, 10)
 *   - a symbol s >= 3 is stored as (s / 3) groups of binary 11, followed by
 *     one group holding (s % 3)
 * A code therefore has 2 * (s / 3) + 2 bits, at most 32 bits for s = 47.
 */

#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <stdexcept>
#include <utility>
#include <vector>


namespace mmx {
namespace pos {

/*
 * Public interface (declared in include/mmx/pos/encoding.h).
 */
std::vector<uint64_t> encode(const std::vector<uint8_t>& symbols, uint64_t& total_bits);

std::vector<uint8_t> decode(const std::vector<uint64_t>& bit_stream, const uint64_t num_symbols);


/*
 * Encodes a single symbol.
 *
 * Returns (code bits, number of bits). The code occupies the low bits of `first`.
 * Throws std::logic_error if sym > 47 (the code would not fit into 32 bits).
 */
std::pair<uint32_t, uint32_t> encode_symbol(const uint8_t sym)
{
	// symbols 0..2 are their own 2-bit code
	if(sym < 3) {
		return std::make_pair(uint32_t(sym), uint32_t(2));
	}
	const uint32_t index = sym / 3;
	const uint32_t mod = sym % 3;

	if(index > 15) {
		throw std::logic_error("symbol out of range");
	}
	// index >= 1 here, so the shift amount is in [2, 30]
	uint32_t out = uint32_t(-1) >> (32 - 2 * index);	// `index` groups of binary 11
	out |= mod << (2 * index);							// terminating group (0..2)
	return std::make_pair(out, 2 * index + 2);
}

/*
 * Decodes the first symbol found in the low bits of `bits`.
 *
 * Returns (symbol, number of bits consumed).
 * If all 16 groups are binary 11 there is no terminating group; in that case
 * (48, 32) is returned, which is not a code that encode_symbol() can produce.
 */
std::pair<uint32_t, uint32_t> decode_symbol(const uint32_t bits)
{
	uint32_t shift = bits;

	for(uint32_t index = 0; index < 16; ++index)
	{
		const uint32_t mod = shift & 3;
		if(mod != 3) {
			return std::make_pair(3 * index + mod, 2 * index + 2);
		}
		shift >>= 2;
	}
	return std::make_pair(uint32_t(48), uint32_t(32));
}

/*
 * Packs `symbols` into a stream of 64-bit words, first symbol in the lowest bits.
 *
 * total_bits is an output: it is reset to zero and then set to the exact
 * number of bits used. Unused high bits of the last word are zero.
 * Throws std::logic_error if any symbol is > 47.
 */
std::vector<uint64_t> encode(const std::vector<uint8_t>& symbols, uint64_t& total_bits)
{
	std::vector<uint64_t> out;

	total_bits = 0;
	uint32_t offset = 0;	// next free bit position in `buffer`, always < 64
	uint64_t buffer = 0;

	for(const auto sym : symbols)
	{
		const auto bits = encode_symbol(sym);
		buffer |= uint64_t(bits.first) << offset;

		const auto end = offset + bits.second;
		if(end >= 64) {
			out.push_back(buffer);
			buffer = 0;
		}
		if(end > 64) {
			// code straddles two words: carry the part that did not fit
			// (offset > 32 here, so the shift amount is below 32)
			buffer = bits.first >> (64 - offset);
		}
		offset = end % 64;

		total_bits += bits.second;
	}
	if(offset) {
		out.push_back(buffer);
	}
	return out;
}

/*
 * Unpacks `num_symbols` symbols from `bit_stream` (as produced by encode()).
 *
 * Throws std::logic_error when
 *   - the stream is exhausted before num_symbols were decoded
 *     ("bit stream underflow"), or
 *   - a code is longer than the bits that are left, or it is the invalid
 *     all-ones code ("symbol decode error").
 *
 * Note: zero padding in the last word decodes as symbol 0, so asking for more
 * symbols than were encoded cannot always be detected.
 */
std::vector<uint8_t> decode(const std::vector<uint64_t>& bit_stream, const uint64_t num_symbols)
{
	std::vector<uint8_t> out;

	// Every code has at least 2 bits, so a word holds at most 32 symbols.
	// Capping the reservation keeps a bogus num_symbols from triggering a
	// huge allocation before the underflow check below can fire.
	const uint64_t max_symbols = uint64_t(bit_stream.size()) * 32;
	out.reserve(static_cast<std::size_t>(std::min<uint64_t>(num_symbols, max_symbols)));

	uint32_t bits = 0;		// number of valid bits in `buffer`
	uint64_t offset = 0;	// read position in the stream, multiple of 32
	uint64_t buffer = 0;

	while(out.size() < num_symbols)
	{
		if(bits <= 32) {
			// refill with the next 32-bit half word, if there is one
			const auto index = offset / 64;
			if(index < bit_stream.size()) {
				buffer |= uint64_t((bit_stream[index] >> (offset % 64)) & 0xFFFFFFFF) << bits;
				offset += 32;
				bits += 32;
			} else if(bits == 0) {
				throw std::logic_error("bit stream underflow");
			}
		}
		const auto sym = decode_symbol(uint32_t(buffer));

		// validate before storing: reject the all-ones sentinel (symbol 48)
		// and codes that extend past the available bits
		if(sym.first > 47 || sym.second > bits) {
			throw std::logic_error("symbol decode error");
		}
		out.push_back(uint8_t(sym.first));

		buffer >>= sym.second;
		bits -= sym.second;
	}
	return out;
}


} // pos
} // mmx
add_executable(test_pos_compute test_pos_compute.cpp) +add_executable(test_encoding test_encoding.cpp) add_executable(database_fill database_fill.cpp) add_executable(mmx_tests mmx_tests.cpp) @@ -23,6 +24,7 @@ target_link_libraries(test_database_reads mmx_db mmx_iface) target_link_libraries(database_fill mmx_db mmx_iface) target_link_libraries(test_mem_hash vnx_base mmx_pos) target_link_libraries(test_pos_compute mmx_iface mmx_pos) +target_link_libraries(test_encoding mmx_pos) target_link_libraries(mmx_tests mmx_iface) target_link_libraries(vm_engine_tests mmx_vm) diff --git a/test/test_encoding.cpp b/test/test_encoding.cpp new file mode 100644 index 000000000..1c4c4bd9e --- /dev/null +++ b/test/test_encoding.cpp @@ -0,0 +1,55 @@ +/* + * test_encoding.cpp + * + * Created on: Nov 20, 2023 + * Author: mad + */ + +#include + +#include + + +int main(int argc, char** argv) +{ + const int num_symbols = argc > 1 ? ::atoi(argv[1]) : 4096; + + std::vector symbols; + + for(int i = 0; i < num_symbols; ++i) + { + uint8_t sym = 0; + const auto ticket = ::rand() % 1000; + if(ticket < 900) { + sym = ticket % 3; + } else if(ticket < 990) { + sym = 3 + ticket % 3; + } else { + sym = 6 + ticket % 3; + } + symbols.push_back(sym); + } + + for(auto sym : symbols) { + std::cout << int(sym) << " "; + } + std::cout << std::endl; + + uint64_t total_bits = 0; + const auto bit_stream = mmx::pos::encode(symbols, total_bits); + + std::cout << "symbols = " << num_symbols << std::endl; + std::cout << "bit_stream = " << (total_bits + 7) / 8 << " bytes, " << double(total_bits) / num_symbols << " bits / symbol" << std::endl; + + const auto test = mmx::pos::decode(bit_stream, symbols.size()); + + if(test != symbols) { + for(auto sym : test) { + std::cout << int(sym) << " "; + } + throw std::logic_error("test != symbols"); + } + return 0; +} + +