Skip to content

Commit

Permalink
use options
Browse files Browse the repository at this point in the history
  • Loading branch information
gutenye committed May 12, 2024
1 parent 2e897f6 commit 2062928
Show file tree
Hide file tree
Showing 17 changed files with 168 additions and 162 deletions.
1 change: 1 addition & 0 deletions .clang-format
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,4 @@ BasedOnStyle: Google
ColumnLimit: 120
AccessModifierOffset: -2
IncludeBlocks: Merge
SpaceBeforeCpp11BracedList: true
7 changes: 0 additions & 7 deletions assets/config.txt

This file was deleted.

19 changes: 10 additions & 9 deletions packages/react-native/cpp/db_post_process.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -198,13 +198,14 @@ float PolygonScoreAcc(std::vector<cv::Point> contour, cv::Mat pred) {
return score;
}

std::vector<std::vector<std::vector<int>>> BoxesFromBitmap(const cv::Mat pred, const cv::Mat bitmap,
std::map<std::string, double> Config) {
std::vector<std::vector<std::vector<int>>> BoxesFromBitmap(const cv::Mat pred, const cv::Mat bitmap, Options &options) {
const int min_size = 3;
const int max_candidates = 1000;
const float box_thresh = static_cast<float>(Config["det_db_box_thresh"]);
const float unclip_ratio = static_cast<float>(Config["det_db_unclip_ratio"]);
const int det_use_polygon_score = int(Config["det_use_polygon_score"]);
// const float box_thresh = static_cast<float>(options.detection_box_threshold);
// const float unclip_ratio = static_cast<float>(options.detection_unclip_ratiop);
const float box_thresh = options.detection_box_threshold;
const float unclip_ratio = options.detection_unclip_ratiop;
const int det_use_polygon_score = options.detection_use_polygon_score;

int width = bitmap.cols;
int height = bitmap.rows;
Expand Down Expand Up @@ -256,10 +257,10 @@ std::vector<std::vector<std::vector<int>>> BoxesFromBitmap(const cv::Mat pred, c
std::vector<std::vector<int>> intcliparray;

for (int num_pt = 0; num_pt < 4; num_pt++) {
std::vector<int> a{static_cast<int>(clamp(roundf(cliparray[num_pt][0] / float(width) * float(dest_width)),
float(0), float(dest_width))),
static_cast<int>(clamp(roundf(cliparray[num_pt][1] / float(height) * float(dest_height)),
float(0), float(dest_height)))};
std::vector<int> a {static_cast<int>(clamp(roundf(cliparray[num_pt][0] / float(width) * float(dest_width)),
float(0), float(dest_width))),
static_cast<int>(clamp(roundf(cliparray[num_pt][1] / float(height) * float(dest_height)),
float(0), float(dest_height)))};
intcliparray.push_back(a);
}
boxes.push_back(intcliparray);
Expand Down
4 changes: 2 additions & 2 deletions packages/react-native/cpp/db_post_process.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
#include "opencv2/core.hpp"
#include "opencv2/imgcodecs.hpp"
#include "opencv2/imgproc.hpp"
#include "shared.h"

template <class T>
T clamp(T x, T min, T max) {
Expand All @@ -48,8 +49,7 @@ std::vector<std::vector<float>> GetMiniBoxes(cv::RotatedRect box, float &ssid);

float BoxScoreFast(std::vector<std::vector<float>> box_array, cv::Mat pred);

std::vector<std::vector<std::vector<int>>> BoxesFromBitmap(const cv::Mat pred, const cv::Mat bitmap,
std::map<std::string, double> Config);
std::vector<std::vector<std::vector<int>>> BoxesFromBitmap(const cv::Mat pred, const cv::Mat bitmap, Options &options);

std::vector<std::vector<std::vector<int>>> FilterTagDetRes(std::vector<std::vector<std::vector<int>>> boxes,
float ratio_h, float ratio_w, cv::Mat srcimg);
27 changes: 11 additions & 16 deletions packages/react-native/cpp/det_process.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -61,8 +61,8 @@ cv::Mat DetResizeImg(const cv::Mat img, int max_size_len, std::vector<float> &ra
return resize_img;
}

DetPredictor::DetPredictor(const std::string &modelDir, const int cpuThreadNum, const std::string &cpuPowerMode)
: m_model_path{modelDir} {}
DetPredictor::DetPredictor(Options &options, const int cpuThreadNum, const std::string &cpuPowerMode)
: m_options {options} {}

ImageRaw DetPredictor::Preprocess(const cv::Mat &srcimg, const int max_side_len) {
cv::Mat img = DetResizeImg(srcimg, max_side_len, ratio_hw_);
Expand All @@ -75,56 +75,51 @@ ImageRaw DetPredictor::Preprocess(const cv::Mat &srcimg, const int max_side_len)
const float *dimg = reinterpret_cast<const float *>(img_fp.data);
NHWC3ToNC3HW(dimg, data0.data(), img_fp.rows * img_fp.cols, mean, scale);

ImageRaw image_raw{.data = data0, .width = img_fp.cols, .height = img_fp.rows, .channels = 3};
ImageRaw image_raw {.data = data0, .width = img_fp.cols, .height = img_fp.rows, .channels = 3};

return image_raw;
}

std::vector<std::vector<std::vector<int>>> DetPredictor::Postprocess(ModelOutput &model_output, const cv::Mat &srcimg,
std::map<std::string, double> Config,
int det_db_use_dilate) {
Options &options) {
auto height = model_output.shape[2];
auto width = model_output.shape[3];
cv::Mat pred_map = cv::Mat(height, width, CV_32F, model_output.data.data());
cv::Mat cbuf_map;
pred_map.convertTo(cbuf_map, CV_8UC1, 255.0f);

const double threshold = double(Config["det_db_thresh"]) * 255;
const double threshold = options.detection_threshold * 255;
const double max_value = 255;
cv::Mat bit_map;
cv::threshold(cbuf_map, bit_map, threshold, max_value, cv::THRESH_BINARY);
if (det_db_use_dilate == 1) {
if (options.detection_use_dilate == 1) {
cv::Mat dilation_map;
cv::Mat dila_ele = cv::getStructuringElement(cv::MORPH_RECT, cv::Size(2, 2));
cv::dilate(bit_map, dilation_map, dila_ele);
bit_map = dilation_map;
}
auto boxes = BoxesFromBitmap(pred_map, bit_map, Config);
auto boxes = BoxesFromBitmap(pred_map, bit_map, options);

std::vector<std::vector<std::vector<int>>> filter_boxes = FilterTagDetRes(boxes, ratio_hw_[0], ratio_hw_[1], srcimg);

return filter_boxes;
}

std::vector<std::vector<std::vector<int>>> DetPredictor::Predict(cv::Mat &img, std::map<std::string, double> Config) {
std::vector<std::vector<std::vector<int>>> DetPredictor::Predict(cv::Mat &img) {
cv::Mat srcimg;
img.copyTo(srcimg);

// Read img
int max_side_len = int(Config["max_side_len"]);
int det_db_use_dilate = int(Config["det_db_use_dilate"]);

Timer tic;
tic.start();
auto image = Preprocess(img, max_side_len);
auto image = Preprocess(img, m_options.image_max_size);
tic.end();
auto preprocessTime = tic.get_average_ms();
std::cout << "det predictor preprocess costs " << preprocessTime << std::endl;

// Run predictor
std::vector<int64_t> input_shape = {1, image.channels, image.height, image.width};

Onnx onnx{m_model_path};
Onnx onnx {m_options.detection_model_path};
tic.start();
auto model_output = onnx.run(image.data, input_shape);
tic.end();
Expand All @@ -133,7 +128,7 @@ std::vector<std::vector<std::vector<int>>> DetPredictor::Predict(cv::Mat &img, s

// Process Output
tic.start();
auto filter_boxes = Postprocess(model_output, srcimg, Config, det_db_use_dilate);
auto filter_boxes = Postprocess(model_output, srcimg, m_options);
tic.end();
auto postprocessTime = tic.get_average_ms();
std::cout << "det predictor postprocess costs " << postprocessTime << std::endl;
Expand Down
14 changes: 7 additions & 7 deletions packages/react-native/cpp/det_process.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,16 +24,16 @@

class DetPredictor {
public:
explicit DetPredictor(const std::string &modelDir, const int cpuThreadNum, const std::string &cpuPowerMode);
explicit DetPredictor(Options &options, const int cpuThreadNum, const std::string &cpuPowerMode);

std::vector<std::vector<std::vector<int>>> Predict(cv::Mat &rgbImage, std::map<std::string, double> Config);
std::vector<std::vector<std::vector<int>>> Predict(cv::Mat &rgbImage);

private:
std::string m_model_path;
Options m_options {};
std::vector<float> ratio_hw_;

ImageRaw Preprocess(const cv::Mat &img, const int max_side_len);
std::vector<std::vector<std::vector<int>>> Postprocess(ModelOutput &model_output, const cv::Mat &srcimg,
std::map<std::string, double> Config, int det_db_use_dilate);

private:
std::vector<float> ratio_hw_;
std::vector<std::vector<std::vector<int>>> Postprocess(ModelOutput &model_output, const cv::Mat &srcimg,
Options &options);
};
32 changes: 13 additions & 19 deletions packages/react-native/cpp/example/ocr_example.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,32 +3,26 @@
#include "../native-ocr.h"

int main() {
printf("Start\n");
std::cout << "Start Ocr Example" << std::endl;

std::string asset_dir = "../assets";
auto det_model_file = asset_dir + "/ch_PP-OCRv4_det_infer.onnx";
auto rec_model_file = asset_dir + "/ch_PP-OCRv4_rec_infer.onnx";
auto cls_model_file = asset_dir + "/ch_ppocr_mobile_v2";
RawOptions rawOptions {
{"detectionModelPath", asset_dir + "/ch_PP-OCRv4_det_infer.onnx"},
{"recognitionModelPath", asset_dir + "/ch_PP-OCRv4_rec_infer.onnx"},
{"classiferModelPath", asset_dir + "/ch_ppocr_mobile_v2"},
{"dictionaryPath", asset_dir + "/ppocr_keys_v1.txt"},
};
auto image_path = asset_dir + "/cn-01.jpg";
auto output_img_path = asset_dir + "/out.jpg";
auto dict_path = asset_dir + "/ppocr_keys_v1.txt";
auto config_path = asset_dir + "/config.txt";

NativeOcr *pipe_ =
new NativeOcr(det_model_file, cls_model_file, rec_model_file, "LITE_POWER_HIGH", 1, config_path, dict_path);
// std::cout << rawOptions.at("a") << std::endl;

NativeOcr *pipe_ = new NativeOcr(rawOptions);
std::vector<std::string> res_txt;
pipe_->Process(image_path, output_img_path, res_txt);
auto lines = pipe_->Process(image_path);

std::ostringstream result;
for (int i = 0; i < res_txt.size() / 2; i++) {
auto text = res_txt[2 * i];
auto score = res_txt[2 * i + 1];
// std::cout << score << "\t" << text << std::endl;
// result << i << "\t" << res_txt[2*i] << "\t" << res_txt[2*i + 1] <<
// "\t";
result << res_txt[2 * i] << "\n";
for (auto line : lines) {
std::cout << line << std::endl;
}
// std::cout << result.str().c_str() << std::endl;

return 0;
}
72 changes: 50 additions & 22 deletions packages/react-native/cpp/native-ocr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,21 +22,19 @@
#include <stdexcept>
#include "timer.h"

std::map<std::string, double> LoadConfigTxt(std::string config_path);
std::vector<std::string> ReadDict(std::string path);
cv::Mat GetRotateCropImage(cv::Mat srcimage, std::vector<std::vector<int>> box);
Options convertRawOptions(RawOptions rawOptions);

NativeOcr::NativeOcr(const std::string &detModelDir, const std::string &clsModelDir, const std::string &recModelDir,
const std::string &config_path, const std::string &dict_path) {
NativeOcr::NativeOcr(RawOptions rawOptions) : m_options {convertRawOptions(rawOptions)} {
auto cPUThreadNum = 1;
auto cPUPowerMode = "LITE_POWER_HIGH";
try {
// clsPredictor_.reset(
// new ClsPredictor(clsModelDir, cPUThreadNum, cPUPowerMode));
detPredictor_.reset(new DetPredictor(detModelDir, cPUThreadNum, cPUPowerMode));
recPredictor_.reset(new RecPredictor(recModelDir, cPUThreadNum, cPUPowerMode));
Config_ = LoadConfigTxt(config_path);
charactor_dict_ = ReadDict(dict_path);
detPredictor_.reset(new DetPredictor(m_options, cPUThreadNum, cPUPowerMode));
recPredictor_.reset(new RecPredictor(m_options, cPUThreadNum, cPUPowerMode));
charactor_dict_ = ReadDict(m_options.dictionary_path);
charactor_dict_.insert(charactor_dict_.begin(), "#");
charactor_dict_.push_back(" ");
} catch (std::string &error) {
Expand All @@ -47,15 +45,14 @@ NativeOcr::NativeOcr(const std::string &detModelDir, const std::string &clsModel
std::vector<std::string> NativeOcr::Process(std::string &image_path) {
try {
auto img = cv::imread(image_path);
int use_direction_classify = int(Config_["use_direction_classify"]);
cv::Mat srcimg;
img.copyTo(srcimg);

printf("Run Detection\n");
// det predict
Timer tic;
tic.start();
auto boxes = detPredictor_->Predict(srcimg, Config_);
auto boxes = detPredictor_->Predict(srcimg);

std::vector<float> mean = {0.5f, 0.5f, 0.5f};
std::vector<float> scale = {1 / 0.5f, 1 / 0.5f, 1 / 0.5f};
Expand All @@ -69,7 +66,7 @@ std::vector<std::string> NativeOcr::Process(std::string &image_path) {
std::vector<float> rec_text_score;
for (int i = boxes.size() - 1; i >= 0; i--) {
crop_img = GetRotateCropImage(img_copy, boxes[i]);
// if (use_direction_classify >= 1)
// if (m_options.detection_use_direction_classify)
// {
// crop_img =
// clsPredictor_->Predict(crop_img, nullptr, nullptr, nullptr, 0.9);
Expand All @@ -95,7 +92,7 @@ std::vector<std::string> NativeOcr::Process(std::string &image_path) {
return lines;
} catch (std::string &error) {
std::cerr << error << std::endl;
return std::vector<std::string>{};
return std::vector<std::string> {};
}
}

Expand Down Expand Up @@ -185,17 +182,6 @@ std::vector<std::string> split(const std::string &str, const std::string &delim)
return res;
}

std::map<std::string, double> LoadConfigTxt(std::string config_path) {
auto config = ReadDict(config_path);

std::map<std::string, double> dict;
for (int i = 0; i < config.size(); i++) {
std::vector<std::string> res = split(config[i], " ");
dict[res[0]] = stod(res[1]);
}
return dict;
}

cv::Mat Visualization(cv::Mat srcimg, std::vector<std::vector<std::vector<int>>> boxes, std::string output_image_path) {
cv::Point rook_points[boxes.size()][4];
for (int n = 0; n < boxes.size(); n++) {
Expand All @@ -215,3 +201,45 @@ cv::Mat Visualization(cv::Mat srcimg, std::vector<std::vector<std::vector<int>>>
std::cout << "The detection visualized image saved in " << output_image_path.c_str() << std::endl;
return img_vis;
}

Options convertRawOptions(RawOptions rawOptions) {
Options options {};
if (rawOptions.count("isDebug") > 0) {
options.is_debug = std::get<bool>(rawOptions["isDebug"]);
}
if (rawOptions.count("imageMaxSize") > 0) {
options.image_max_size = std::get<double>(rawOptions.at("imageMaxSize"));
}
if (rawOptions.count("detectionThreshold") > 0) {
options.detection_threshold = std::get<double>(rawOptions.at("detectionThreshold"));
}
if (rawOptions.count("detectionBoxThreshold") > 0) {
options.detection_box_threshold = std::get<double>(rawOptions.at("detectionBoxThreshold"));
}
if (rawOptions.count("detectionUnclipRatiop") > 0) {
options.detection_unclip_ratiop = std::get<double>(rawOptions.at("detectionUnclipRatiop"));
}
if (rawOptions.count("detectionUseDilate") > 0) {
options.detection_use_dilate = std::get<bool>(rawOptions.at("detectionUseDilate"));
}
if (rawOptions.count("detectionUsePolygonScore") > 0) {
options.detection_use_polygon_score = std::get<bool>(rawOptions.at("detectionUsePolygonScore"));
}
if (rawOptions.count("detectionuseDirectionClassify") > 0) {
options.detection_use_direction_classify = std::get<bool>(rawOptions.at("detectionuseDirectionClassify"));
}
if (rawOptions.count("detectionModelPath") > 0) {
options.detection_model_path = std::get<std::string>(rawOptions.at("detectionModelPath"));
}
if (rawOptions.count("recognitionModelPath") > 0) {
options.recognition_model_path = std::get<std::string>(rawOptions.at("recognitionModelPath"));
}
if (rawOptions.count("classififerModelPath") > 0) {
options.classifier_model_path = std::get<std::string>(rawOptions.at("classififerModelPath"));
}
if (rawOptions.count("dictionaryPath") > 0) {
options.dictionary_path = std::get<std::string>(rawOptions.at("dictionaryPath"));
}
std::cout << "isDebug " << options.is_debug << std::endl;
return options;
}
10 changes: 7 additions & 3 deletions packages/react-native/cpp/native-ocr.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,19 +20,23 @@
#include <opencv2/imgcodecs.hpp>
#include <opencv2/imgproc.hpp>
#include <string>
#include <variant>
#include <vector>
#include "det_process.h"
#include "rec_process.h"
#include "shared.h"

using RawOptions = std::unordered_map<std::string, std::variant<bool, double, std::string>>;

class NativeOcr {
public:
NativeOcr(const std::string &detModelDir, const std::string &clsModelDir, const std::string &recModelDir,
const std::string &config_path, const std::string &dict_path);
NativeOcr(RawOptions rawOptions);

std::vector<std::string> Process(std::string &image_path);

private:
std::map<std::string, double> Config_;
Options m_options;
// TODO: charactor_dict_ -> m_dictionary
std::vector<std::string> charactor_dict_;
// std::shared_ptr<ClsPredictor> clsPredictor_;
std::shared_ptr<DetPredictor> detPredictor_;
Expand Down
Loading

0 comments on commit 2062928

Please sign in to comment.