Skip to content

Commit

Permalink
refactor: Update model loading logic and file path handling
Browse files Browse the repository at this point in the history
  • Loading branch information
royshil committed Oct 20, 2024
1 parent f6241fa commit fb12416
Show file tree
Hide file tree
Showing 7 changed files with 124 additions and 71 deletions.
9 changes: 5 additions & 4 deletions src/cleanstream-filter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -490,10 +490,11 @@ obs_properties_t *cleanstream_properties(void *data)
obs_properties_add_list(ppts, "whisper_model_path", MT_("whisper_model"),
OBS_COMBO_TYPE_LIST, OBS_COMBO_FORMAT_STRING);
// Add models from models_info map
for (const auto &model_info : models_info()) {
if (model_info.second.type == MODEL_TYPE_TRANSCRIPTION) {
obs_property_list_add_string(whisper_models_list, model_info.first.c_str(),
model_info.first.c_str());
for (const auto &model_info : get_sorted_models_info()) {
if (model_info.type == MODEL_TYPE_TRANSCRIPTION) {
obs_property_list_add_string(whisper_models_list,
model_info.friendly_name.c_str(),
model_info.friendly_name.c_str());
}
}

Expand Down
27 changes: 22 additions & 5 deletions src/model-utils/model-downloader-ui.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -66,8 +66,10 @@ void ModelDownloader::closeEvent(QCloseEvent *e)
{
if (!this->mPrepareToClose)
e->ignore();
else
else {
QDialog::closeEvent(e);
deleteLater();
}
}

void ModelDownloader::close()
Expand Down Expand Up @@ -125,11 +127,24 @@ std::string get_filename_from_url(const std::string &url)

void ModelDownloadWorker::download_model()
{
char *config_folder = obs_module_get_config_path(obs_current_module(), "models");
const std::filesystem::path module_config_models_folder =
std::filesystem::absolute(config_folder);
char *config_folder = obs_module_config_path("models");
#ifdef _WIN32
// convert mbstring to wstring
int count =
MultiByteToWideChar(CP_UTF8, 0, config_folder, (int)strlen(config_folder), NULL, 0);
std::wstring config_folder_str(count, 0);
MultiByteToWideChar(CP_UTF8, 0, config_folder, (int)strlen(config_folder),
&config_folder_str[0], count);
obs_log(LOG_INFO, "Download: Config models folder: %S", config_folder_str.c_str());
#else
std::string config_folder_str = config_folder;
obs_log(LOG_INFO, "Download: Config models folder: %s", config_folder_str.c_str());
#endif
bfree(config_folder);

const std::filesystem::path module_config_models_folder =
std::filesystem::absolute(config_folder_str);

// Check if the config folder exists
if (!std::filesystem::exists(module_config_models_folder)) {
obs_log(LOG_WARNING, "Config folder does not exist: %s",
Expand Down Expand Up @@ -231,7 +246,9 @@ ModelDownloader::~ModelDownloader()
}
delete this->download_thread;
}
delete this->download_worker;
if (this->download_worker != nullptr) {
delete this->download_worker;
}
}

ModelDownloadWorker::~ModelDownloadWorker()
Expand Down
4 changes: 2 additions & 2 deletions src/model-utils/model-downloader-ui.h
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,8 @@ public slots:
private:
QVBoxLayout *layout;
QProgressBar *progress_bar;
QThread *download_thread;
ModelDownloadWorker *download_worker;
QPointer<QThread> download_thread;
QPointer<ModelDownloadWorker> download_worker;
// Callback for when the download is finished
download_finished_callback_t download_finished_callback;
bool mPrepareToClose;
Expand Down
85 changes: 31 additions & 54 deletions src/model-utils/model-downloader.cpp
Original file line number Diff line number Diff line change
@@ -1,65 +1,26 @@
#include "model-downloader.h"
#include "plugin-support.h"
#include "model-downloader-ui.h"
#include "model-find-utils.h"

#include <obs-module.h>
#include <obs-frontend-api.h>

#include <filesystem>
#include <fstream>
#include <iostream>
#include <string>

#include <curl/curl.h>

bool check_if_model_exists(const std::string &model_name)
std::string find_model_folder(const ModelInfo &model_info)
{
obs_log(LOG_INFO, "Checking if model %s exists...", model_name.c_str());
char *model_file_path = obs_module_file(model_name.c_str());
obs_log(LOG_INFO, "Model file path: %s", model_file_path);
if (model_file_path == nullptr) {
obs_log(LOG_INFO, "Model %s does not exist.", model_name.c_str());
return false;
}

if (!std::filesystem::exists(model_file_path)) {
obs_log(LOG_INFO, "Model %s does not exist.", model_file_path);
bfree(model_file_path);
return false;
if (model_info.friendly_name.empty()) {
obs_log(LOG_ERROR, "Model info is invalid. Friendly name is empty.");
return "";
}
bfree(model_file_path);
return true;
}

std::string find_file_in_folder_by_name(const std::string &folder_path,
const std::string &file_name)
{
for (const auto &entry : std::filesystem::directory_iterator(folder_path)) {
if (entry.path().filename() == file_name) {
return entry.path().string();
}
if (model_info.local_folder_name.empty()) {
obs_log(LOG_ERROR, "Model info is invalid. Local folder name is empty.");
return "";
}
return "";
}

std::string find_bin_file_in_folder(const std::string &model_local_folder_path)
{
// find .bin file in folder
for (const auto &entry : std::filesystem::directory_iterator(model_local_folder_path)) {
if (entry.path().extension() == ".bin") {
const std::string bin_file_path = entry.path().string();
obs_log(LOG_INFO, "Model bin file found in folder: %s",
bin_file_path.c_str());
return bin_file_path;
}
if (model_info.files.empty()) {
obs_log(LOG_ERROR, "Model info is invalid. Files list is empty.");
return "";
}
obs_log(LOG_ERROR, "Model bin file not found in folder: %s",
model_local_folder_path.c_str());
return "";
}

std::string find_model_folder(const ModelInfo &model_info)
{
char *data_folder_models = obs_module_file("models");
const std::filesystem::path module_data_models_folder =
std::filesystem::absolute(data_folder_models);
Expand All @@ -79,9 +40,25 @@ std::string find_model_folder(const ModelInfo &model_info)
}

// Check if model exists in the config folder
char *config_folder = obs_module_get_config_path(obs_current_module(), "models");
char *config_folder = obs_module_config_path("models");
if (!config_folder) {
obs_log(LOG_INFO, "Config folder not set.");
return "";
}
#ifdef _WIN32
// convert mbstring to wstring
int count = MultiByteToWideChar(CP_UTF8, 0, config_folder, strlen(config_folder), NULL, 0);
std::wstring config_folder_str(count, 0);
MultiByteToWideChar(CP_UTF8, 0, config_folder, strlen(config_folder), &config_folder_str[0],
count);
obs_log(LOG_INFO, "Config models folder: %S", config_folder_str.c_str());
#else
std::string config_folder_str = config_folder;
obs_log(LOG_INFO, "Config models folder: %s", config_folder_str.c_str());
#endif

const std::filesystem::path module_config_models_folder =
std::filesystem::absolute(config_folder);
std::filesystem::absolute(config_folder_str);
bfree(config_folder);

obs_log(LOG_INFO, "Checking if model '%s' exists in config...",
Expand All @@ -90,9 +67,9 @@ std::string find_model_folder(const ModelInfo &model_info)
const std::string model_local_config_path =
(module_config_models_folder / model_info.local_folder_name).string();

obs_log(LOG_INFO, "Model path in config: %s", model_local_config_path.c_str());
obs_log(LOG_INFO, "Lookig for model in config: %s", model_local_config_path.c_str());
if (std::filesystem::exists(model_local_config_path)) {
obs_log(LOG_INFO, "Model exists in config folder: %s",
obs_log(LOG_INFO, "Model folder exists in config folder: %s",
model_local_config_path.c_str());
return model_local_config_path;
}
Expand Down
6 changes: 0 additions & 6 deletions src/model-utils/model-downloader.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,9 @@
#define MODEL_DOWNLOADER_H

#include <string>
#include <functional>

#include "model-downloader-types.h"

bool check_if_model_exists(const std::string &model_name);

std::string find_file_in_folder_by_name(const std::string &folder_path,
const std::string &file_name);
std::string find_bin_file_in_folder(const std::string &path);
std::string find_model_folder(const ModelInfo &model_info);
std::string find_model_bin_file(const ModelInfo &model_info);

Expand Down
50 changes: 50 additions & 0 deletions src/model-utils/model-find-utils.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
#include <filesystem>
#include <fstream>
#include <iostream>
#include <string>
#include <regex>

#include <obs-module.h>

#include "model-find-utils.h"
#include "plugin-support.h"

std::string find_file_in_folder_by_name(const std::string &folder_path,
const std::string &file_name)
{
for (const auto &entry : std::filesystem::directory_iterator(folder_path)) {
if (entry.path().filename() == file_name) {
return entry.path().string();
}
}
return "";
}

// Find a file in a folder by expression
std::string find_file_in_folder_by_regex_expression(const std::string &folder_path,
const std::string &file_name_regex)
{
for (const auto &entry : std::filesystem::directory_iterator(folder_path)) {
if (std::regex_match(entry.path().filename().string(),
std::regex(file_name_regex))) {
return entry.path().string();
}
}
return "";
}

std::string find_bin_file_in_folder(const std::string &model_local_folder_path)
{
// find .bin file in folder
for (const auto &entry : std::filesystem::directory_iterator(model_local_folder_path)) {
if (entry.path().extension() == ".bin") {
const std::string bin_file_path = entry.path().string();
obs_log(LOG_INFO, "Model bin file found in folder: %s",
bin_file_path.c_str());
return bin_file_path;
}
}
obs_log(LOG_ERROR, "Model bin file not found in folder: %s",
model_local_folder_path.c_str());
return "";
}
14 changes: 14 additions & 0 deletions src/model-utils/model-find-utils.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
#ifndef MODEL_FIND_UTILS_H
#define MODEL_FIND_UTILS_H

#include <string>

#include "model-downloader-types.h"

std::string find_file_in_folder_by_name(const std::string &folder_path,
const std::string &file_name);
std::string find_bin_file_in_folder(const std::string &path);
std::string find_file_in_folder_by_regex_expression(const std::string &folder_path,
const std::string &file_name_regex);

#endif // MODEL_FIND_UTILS_H

0 comments on commit fb12416

Please sign in to comment.