From 8742d631fbdf96b787ae87119ac0f9760767f211 Mon Sep 17 00:00:00 2001 From: Anas Dorbani Date: Thu, 12 Dec 2024 21:49:03 +0100 Subject: [PATCH 1/4] add the global and local prompts/models support --- src/core/config/config.cpp | 66 ++++- src/core/config/model.cpp | 13 +- src/core/config/prompt.cpp | 13 +- src/custom_parser/query/model_parser.cpp | 276 +++++++++++++----- src/custom_parser/query/prompt_parser.cpp | 202 +++++++++---- src/custom_parser/query_parser.cpp | 18 +- src/include/flockmtl/core/config.hpp | 21 +- .../custom_parser/query/model_parser.hpp | 8 + .../custom_parser/query/prompt_parser.hpp | 8 + .../flockmtl/custom_parser/query_parser.hpp | 1 + .../custom_parser/query_statements.hpp | 2 + src/model_manager/model.cpp | 22 +- src/prompt_manager/prompt_manager.cpp | 29 +- 13 files changed, 493 insertions(+), 186 deletions(-) diff --git a/src/core/config/config.cpp b/src/core/config/config.cpp index 8b318b21..49d6aa97 100644 --- a/src/core/config/config.cpp +++ b/src/core/config/config.cpp @@ -1,10 +1,26 @@ #include "flockmtl/core/config.hpp" #include "flockmtl/secret_manager/secret_manager.hpp" +#include +#include namespace flockmtl { duckdb::DatabaseInstance* Config::db; +std::string Config::get_schema_name() { return "flockmtl_config"; } + +std::filesystem::path Config::get_global_storage_path() { +#ifdef _WIN32 + const char* homeDir = getenv("USERPROFILE"); +#else + const char* homeDir = getenv("HOME"); +#endif + if (homeDir == nullptr) { + throw std::runtime_error("Could not find home directory"); + } + return std::filesystem::path(homeDir) / ".duckdb" / "flockmtl_storage" / "flockmtl.db"; +} + duckdb::Connection Config::GetConnection(duckdb::DatabaseInstance* db) { if (db) { Config::db = db; @@ -13,11 +29,25 @@ duckdb::Connection Config::GetConnection(duckdb::DatabaseInstance* db) { return con; } -std::string Config::get_schema_name() { return "flockmtl_config"; } +duckdb::Connection Config::GetGlobalConnection() { + const duckdb::DuckDB db(Config::get_global_storage_path().string()); + duckdb::Connection con(*db.instance); + return con; +} -void Config::ConfigSchema(duckdb::Connection& con, std::string& schema_name) { +void Config::SetupGlobalStorageLocation() { + const auto flockmtl_global_path = get_global_storage_path(); + const auto flockmtlDir = flockmtl_global_path.parent_path(); + if (!std::filesystem::exists(flockmtlDir)) { + try { + std::filesystem::create_directories(flockmtlDir); + } catch (const std::filesystem::filesystem_error& e) { + std::cerr << "Error creating directories: " << e.what() << std::endl; + } + } +} - // Check if schema exists using fmt +void Config::ConfigSchema(duckdb::Connection& con, std::string& schema_name) { auto result = con.Query(duckdb_fmt::format(" SELECT * " " FROM information_schema.schemata " " WHERE schema_name = '{}'; ", @@ -27,19 +57,35 @@ void Config::ConfigSchema(duckdb::Connection& con, std::string& schema_name) { } } -void Config::Configure(duckdb::DatabaseInstance& db) { +void Config::ConfigureGlobal() { + auto con = Config::GetGlobalConnection(); + ConfigureTables(con, ConfigType::GLOBAL); +} + +void Config::ConfigureLocal(duckdb::DatabaseInstance& db) { auto con = Config::GetConnection(&db); - Registry::Register(db); - SecretManager::Register(db); + ConfigureTables(con, ConfigType::LOCAL); +} +void Config::ConfigureTables(duckdb::Connection& con, const ConfigType type) { con.BeginTransaction(); - std::string schema = Config::get_schema_name(); ConfigSchema(con, schema); - ConfigModelTable(con, schema); - ConfigPromptTable(con, schema); - + ConfigModelTable(con, schema, type); + ConfigPromptTable(con, schema, type); + con.Query( + duckdb_fmt::format("ATTACH DATABASE '{}' AS flockmtl_storage;", Config::get_global_storage_path().string())); con.Commit(); } +void Config::Configure(duckdb::DatabaseInstance& db) { + Registry::Register(db); + SecretManager::Register(db); + if (const auto db_path = db.config.options.database_path; db_path != get_global_storage_path().string()) { + ConfigureLocal(db); + SetupGlobalStorageLocation(); + ConfigureGlobal(); + } +} + } // namespace flockmtl diff --git a/src/core/config/model.cpp b/src/core/config/model.cpp index 9162fc6f..8de499e8 100644 --- a/src/core/config/model.cpp +++ b/src/core/config/model.cpp @@ -1,4 +1,5 @@ #include "flockmtl/core/config.hpp" +#include namespace flockmtl { @@ -6,7 +7,7 @@ std::string Config::get_default_models_table_name() { return "FLOCKMTL_MODEL_DEF std::string Config::get_user_defined_models_table_name() { return "FLOCKMTL_MODEL_USER_DEFINED_INTERNAL_TABLE"; } -void Config::setup_default_models_config(duckdb::Connection& con, std::string& schema_name) { +void Config::SetupDefaultModelsConfig(duckdb::Connection& con, std::string& schema_name) { const std::string table_name = Config::get_default_models_table_name(); // Ensure schema exists auto result = con.Query(duckdb_fmt::format(" SELECT table_name " @@ -39,7 +40,7 @@ void Config::setup_default_models_config(duckdb::Connection& con, std::string& s } } -void Config::setup_user_defined_models_config(duckdb::Connection& con, std::string& schema_name) { +void Config::SetupUserDefinedModelsConfig(duckdb::Connection& con, std::string& schema_name) { const std::string table_name = Config::get_user_defined_models_table_name(); // Ensure schema exists auto result = con.Query(duckdb_fmt::format(" SELECT table_name " @@ -59,9 +60,11 @@ void Config::setup_user_defined_models_config(duckdb::Connection& con, std::stri } } -void Config::ConfigModelTable(duckdb::Connection& con, std::string& schema_name) { - setup_default_models_config(con, schema_name); - setup_user_defined_models_config(con, schema_name); +void Config::ConfigModelTable(duckdb::Connection& con, std::string& schema_name, const ConfigType type) { + if (type == ConfigType::GLOBAL) { + SetupDefaultModelsConfig(con, schema_name); + } + SetupUserDefinedModelsConfig(con, schema_name); } } // namespace flockmtl diff --git a/src/core/config/prompt.cpp b/src/core/config/prompt.cpp index 4ee24d7a..ad23db1a 100644 --- a/src/core/config/prompt.cpp +++ b/src/core/config/prompt.cpp @@ -1,10 +1,11 @@ #include "flockmtl/core/config.hpp" +#include namespace flockmtl { std::string Config::get_prompts_table_name() { return "FLOCKMTL_PROMPT_INTERNAL_TABLE"; } -void Config::ConfigPromptTable(duckdb::Connection& con, std::string& schema_name) { +void Config::ConfigPromptTable(duckdb::Connection& con, std::string& schema_name, const ConfigType type) { const std::string table_name = "FLOCKMTL_PROMPT_INTERNAL_TABLE"; auto result = con.Query(duckdb_fmt::format(" SELECT table_name " @@ -16,14 +17,16 @@ void Config::ConfigPromptTable(duckdb::Connection& con, std::string& schema_name con.Query(duckdb_fmt::format(" CREATE TABLE {}.{} ( " " prompt_name VARCHAR NOT NULL, " " prompt VARCHAR NOT NULL, " - " update_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, " + " updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, " " version INT DEFAULT 1, " " PRIMARY KEY (prompt_name, version) " " ); ", schema_name, table_name)); - con.Query(duckdb_fmt::format(" INSERT INTO {}.{} (prompt_name, prompt) " - " VALUES ('hello-world', 'Tell me hello world'); ", - schema_name, table_name)); + if (type == ConfigType::GLOBAL) { + con.Query(duckdb_fmt::format(" INSERT INTO {}.{} (prompt_name, prompt) " + " VALUES ('hello-world', 'Tell me hello world'); ", + schema_name, table_name)); + } } } diff --git a/src/custom_parser/query/model_parser.cpp b/src/custom_parser/query/model_parser.cpp index 5bf6f2e9..0037aef9 100644 --- a/src/custom_parser/query/model_parser.cpp +++ b/src/custom_parser/query/model_parser.cpp @@ -1,7 +1,7 @@ #include "flockmtl/custom_parser/query/model_parser.hpp" #include "flockmtl/core/common.hpp" - +#include "flockmtl/core/config.hpp" #include #include @@ -10,7 +10,7 @@ namespace flockmtl { void ModelParser::Parse(const std::string& query, std::unique_ptr& statement) { Tokenizer tokenizer(query); Token token = tokenizer.NextToken(); - std::string value = duckdb::StringUtil::Upper(token.value); + const std::string value = duckdb::StringUtil::Upper(token.value); if (token.type == TokenType::KEYWORD) { if (value == "CREATE") { @@ -32,6 +32,16 @@ void ModelParser::Parse(const std::string& query, std::unique_ptr& statement) { Token token = tokenizer.NextToken(); std::string value = duckdb::StringUtil::Upper(token.value); + + std::string catalog; + if (token.type == TokenType::KEYWORD && (value == "GLOBAL" || value == "LOCAL")) { + if (value == "GLOBAL") { + catalog = "flockmtl_storage."; + } + token = tokenizer.NextToken(); + value = duckdb::StringUtil::Upper(token.value); + } + if (token.type != TokenType::KEYWORD || value != "MODEL") { throw std::runtime_error("Expected 'MODEL' after 'CREATE'."); } @@ -96,6 +106,7 @@ void ModelParser::ParseCreateModel(Tokenizer& tokenizer, std::unique_ptr(); + create_statement->catalog = catalog; create_statement->model_name = model_name; create_statement->model = model; create_statement->provider_name = provider_name; @@ -130,79 +141,107 @@ void ModelParser::ParseDeleteModel(Tokenizer& tokenizer, std::unique_ptr& statement) { - Token token = tokenizer.NextToken(); - std::string value = duckdb::StringUtil::Upper(token.value); + auto token = tokenizer.NextToken(); + auto value = duckdb::StringUtil::Upper(token.value); if (token.type != TokenType::KEYWORD || value != "MODEL") { throw std::runtime_error("Expected 'MODEL' after 'UPDATE'."); } token = tokenizer.NextToken(); - if (token.type != TokenType::PARENTHESIS || token.value != "(") { - throw std::runtime_error("Expected opening parenthesis '(' after 'MODEL'."); - } + if (token.type == TokenType::STRING_LITERAL) { + auto model_name = token.value; + token = tokenizer.NextToken(); + if (token.type != TokenType::KEYWORD || duckdb::StringUtil::Upper(token.value) != "TO") { + throw std::runtime_error("Expected 'TO' after model name."); + } - token = tokenizer.NextToken(); - if (token.type != TokenType::STRING_LITERAL || token.value.empty()) { - throw std::runtime_error("Expected non-empty string literal for model name."); - } - std::string model_name = token.value; + token = tokenizer.NextToken(); + value = duckdb::StringUtil::Upper(token.value); + if (token.type != TokenType::KEYWORD || (value != "GLOBAL" && value != "LOCAL")) { + throw std::runtime_error("Expected 'GLOBAL' or 'LOCAL' after 'TO'."); + } + auto catalog = value == "GLOBAL" ? "flockmtl_storage." : ""; - token = tokenizer.NextToken(); - if (token.type != TokenType::SYMBOL || token.value != ",") { - throw std::runtime_error("Expected comma ',' after model name."); - } + token = tokenizer.NextToken(); + if (token.type == TokenType::SYMBOL || token.value == ";") { + auto update_statement = std::make_unique(); + update_statement->model_name = model_name; + update_statement->catalog = catalog; + statement = std::move(update_statement); + } else { + throw std::runtime_error( + "Unexpected characters after the closing parenthesis. Only a semicolon is allowed."); + } - token = tokenizer.NextToken(); - if (token.type != TokenType::STRING_LITERAL || token.value.empty()) { - throw std::runtime_error("Expected non-empty string literal for model."); - } - std::string new_model = token.value; + } else { + if (token.type != TokenType::PARENTHESIS || token.value != "(") { + throw std::runtime_error("Expected opening parenthesis '(' after 'MODEL'."); + } - token = tokenizer.NextToken(); - if (token.type != TokenType::SYMBOL || token.value != ",") { - throw std::runtime_error("Expected comma ',' after model."); - } + token = tokenizer.NextToken(); + if (token.type != TokenType::STRING_LITERAL || token.value.empty()) { + throw std::runtime_error("Expected non-empty string literal for model name."); + } + auto model_name = token.value; - token = tokenizer.NextToken(); - if (token.type != TokenType::STRING_LITERAL || token.value.empty()) { - throw std::runtime_error("Expected non-empty string literal for provider_name."); - } - std::string provider_name = token.value; + token = tokenizer.NextToken(); + if (token.type != TokenType::SYMBOL || token.value != ",") { + throw std::runtime_error("Expected comma ',' after model name."); + } - token = tokenizer.NextToken(); - if (token.type != TokenType::SYMBOL || token.value != ",") { - throw std::runtime_error("Expected comma ',' after provider_name."); - } + token = tokenizer.NextToken(); + if (token.type != TokenType::STRING_LITERAL || token.value.empty()) { + throw std::runtime_error("Expected non-empty string literal for model."); + } + std::string new_model = token.value; - token = tokenizer.NextToken(); - if (token.type != TokenType::JSON || token.value.empty()) { - throw std::runtime_error("Expected json value for the model_args."); - } - auto new_model_args = nlohmann::json::parse(token.value); - const std::set expected_keys = {"context_window", "max_output_tokens"}; - std::set json_keys; - for (auto it = new_model_args.begin(); it != new_model_args.end(); ++it) { - json_keys.insert(it.key()); - } - if (json_keys != expected_keys) { - throw std::runtime_error("Expected keys: context_window, max_output_tokens in model_args."); - } + token = tokenizer.NextToken(); + if (token.type != TokenType::SYMBOL || token.value != ",") { + throw std::runtime_error("Expected comma ',' after model."); + } - token = tokenizer.NextToken(); - if (token.type != TokenType::PARENTHESIS || token.value != ")") { - throw std::runtime_error("Expected closing parenthesis ')' after new max_output_tokens."); - } + token = tokenizer.NextToken(); + if (token.type != TokenType::STRING_LITERAL || token.value.empty()) { + throw std::runtime_error("Expected non-empty string literal for provider_name."); + } + std::string provider_name = token.value; - token = tokenizer.NextToken(); - if (token.type == TokenType::END_OF_FILE) { - auto update_statement = std::make_unique(); - update_statement->new_model = new_model; - update_statement->model_name = model_name; - update_statement->provider_name = provider_name; - update_statement->new_model_args = new_model_args; - statement = std::move(update_statement); - } else { - throw std::runtime_error("Unexpected characters after the closing parenthesis. Only a semicolon is allowed."); + token = tokenizer.NextToken(); + if (token.type != TokenType::SYMBOL || token.value != ",") { + throw std::runtime_error("Expected comma ',' after provider_name."); + } + + token = tokenizer.NextToken(); + if (token.type != TokenType::JSON || token.value.empty()) { + throw std::runtime_error("Expected json value for the model_args."); + } + auto new_model_args = nlohmann::json::parse(token.value); + const std::set expected_keys = {"context_window", "max_output_tokens"}; + std::set json_keys; + for (auto it = new_model_args.begin(); it != new_model_args.end(); ++it) { + json_keys.insert(it.key()); + } + if (json_keys != expected_keys) { + throw std::runtime_error("Expected keys: context_window, max_output_tokens in model_args."); + } + + token = tokenizer.NextToken(); + if (token.type != TokenType::PARENTHESIS || token.value != ")") { + throw std::runtime_error("Expected closing parenthesis ')' after new max_output_tokens."); + } + + token = tokenizer.NextToken(); + if (token.type == TokenType::END_OF_FILE) { + auto update_statement = std::make_unique(); + update_statement->new_model = new_model; + update_statement->model_name = model_name; + update_statement->provider_name = provider_name; + update_statement->new_model_args = new_model_args; + statement = std::move(update_statement); + } else { + throw std::runtime_error( + "Unexpected characters after the closing parenthesis. Only a semicolon is allowed."); + } } } @@ -240,50 +279,125 @@ std::string ModelParser::ToSQL(const QueryStatement& statement) const { switch (statement.type) { case StatementType::CREATE_MODEL: { const auto& create_stmt = static_cast(statement); + auto con = Config::GetConnection(); + auto result = con.Query(duckdb_fmt::format( + " SELECT model_name" + " FROM flockmtl_storage.flockmtl_config.FLOCKMTL_MODEL_DEFAULT_INTERNAL_TABLE" + " WHERE model_name = '{}'" + " UNION ALL " + " SELECT model_name " + " FROM {}flockmtl_config.FLOCKMTL_MODEL_USER_DEFINED_INTERNAL_TABLE" + " WHERE model_name = '{}';", + create_stmt.model_name, create_stmt.catalog.empty() ? "flockmtl_storage." : "", create_stmt.model_name)); + if (result->RowCount() != 0) { + throw std::runtime_error(duckdb_fmt::format("Model '{}' already exist.", create_stmt.model_name)); + } + query = duckdb_fmt::format(" INSERT INTO " - " flockmtl_config.FLOCKMTL_MODEL_USER_DEFINED_INTERNAL_TABLE " + " {}flockmtl_config.FLOCKMTL_MODEL_USER_DEFINED_INTERNAL_TABLE " " (model_name, model, provider_name, model_args) " " VALUES ('{}', '{}', '{}', '{}');", - create_stmt.model_name, create_stmt.model, create_stmt.provider_name, - create_stmt.model_args.dump()); + create_stmt.catalog, create_stmt.model_name, create_stmt.model, + create_stmt.provider_name, create_stmt.model_args.dump()); break; } case StatementType::DELETE_MODEL: { const auto& delete_stmt = static_cast(statement); + auto con = Config::GetConnection(); + + con.Query(duckdb_fmt::format(" DELETE FROM flockmtl_config.FLOCKMTL_MODEL_USER_DEFINED_INTERNAL_TABLE " + " WHERE model_name = '{}';", + delete_stmt.model_name)); + query = duckdb_fmt::format(" DELETE FROM " - " flockmtl_config.FLOCKMTL_MODEL_USER_DEFINED_INTERNAL_TABLE " + " flockmtl_storage.flockmtl_config.FLOCKMTL_MODEL_USER_DEFINED_INTERNAL_TABLE " " WHERE model_name = '{}';", - delete_stmt.model_name); + delete_stmt.model_name, delete_stmt.model_name); break; } case StatementType::UPDATE_MODEL: { const auto& update_stmt = static_cast(statement); - query = duckdb_fmt::format(" UPDATE flockmtl_config.FLOCKMTL_MODEL_USER_DEFINED_INTERNAL_TABLE " + auto con = Config::GetConnection(); + // get the location of the model_name if local or global + auto result = con.Query( + duckdb_fmt::format(" SELECT model_name, 'global' AS scope " + " FROM flockmtl_storage.flockmtl_config.FLOCKMTL_MODEL_USER_DEFINED_INTERNAL_TABLE" + " WHERE model_name = '{}'" + " UNION ALL " + " SELECT model_name, 'local' AS scope " + " FROM flockmtl_config.FLOCKMTL_MODEL_USER_DEFINED_INTERNAL_TABLE" + " WHERE model_name = '{}';", + update_stmt.model_name, update_stmt.model_name, update_stmt.model_name)); + + if (result->RowCount() == 0) { + throw std::runtime_error(duckdb_fmt::format("Model '{}' doesn't exist.", update_stmt.model_name)); + } + + auto catalog = result->GetValue(1, 0).ToString() == "global" ? "flockmtl_storage." : ""; + + query = duckdb_fmt::format(" UPDATE {}flockmtl_config.FLOCKMTL_MODEL_USER_DEFINED_INTERNAL_TABLE " " SET model = '{}', provider_name = '{}', " " model_args = '{}' WHERE model_name = '{}'; ", - update_stmt.new_model, update_stmt.provider_name, update_stmt.new_model_args.dump(), + catalog, update_stmt.new_model, update_stmt.provider_name, + update_stmt.new_model_args.dump(), update_stmt.model_name); + break; + } + case StatementType::UPDATE_MODEL_SCOPE: { + const auto& update_stmt = static_cast(statement); + auto con = Config::GetConnection(); + auto result = + con.Query(duckdb_fmt::format(" SELECT model_name " + " FROM {}flockmtl_config.FLOCKMTL_MODEL_USER_DEFINED_INTERNAL_TABLE" + " WHERE model_name = '{}';", + update_stmt.catalog, update_stmt.model_name)); + if (result->RowCount() != 0) { + throw std::runtime_error( + duckdb_fmt::format("Model '{}' already exist in {} storage.", update_stmt.model_name, + update_stmt.catalog == "flockmtl_storage." ? "global" : "local")); + } + + con.Query(duckdb_fmt::format("INSERT INTO {}flockmtl_config.FLOCKMTL_MODEL_USER_DEFINED_INTERNAL_TABLE " + "(model_name, model, provider_name, model_args) " + "SELECT model_name, model, provider_name, model_args " + "FROM {}flockmtl_config.FLOCKMTL_MODEL_USER_DEFINED_INTERNAL_TABLE " + "WHERE model_name = '{}'; ", + update_stmt.catalog, + update_stmt.catalog == "flockmtl_storage." ? "" : "flockmtl_storage.", + update_stmt.model_name)); + + query = duckdb_fmt::format("DELETE FROM {}flockmtl_config.FLOCKMTL_MODEL_USER_DEFINED_INTERNAL_TABLE " + "WHERE model_name = '{}'; ", + update_stmt.catalog == "flockmtl_storage." ? "" : "flockmtl_storage.", update_stmt.model_name); break; } case StatementType::GET_MODEL: { const auto& get_stmt = static_cast(statement); - query = duckdb_fmt::format(" SELECT * " - " FROM flockmtl_config.FLOCKMTL_MODEL_DEFAULT_INTERNAL_TABLE " - " WHERE model_name = '{}' " - " UNION ALL " - " SELECT * " - " FROM flockmtl_config.FLOCKMTL_MODEL_USER_DEFINED_INTERNAL_TABLE " - " WHERE model_name = '{}';", - get_stmt.model_name, get_stmt.model_name); + query = duckdb_fmt::format("SELECT 'global' AS scope, * " + "FROM flockmtl_storage.flockmtl_config.FLOCKMTL_MODEL_DEFAULT_INTERNAL_TABLE " + "WHERE model_name = '{}' " + "UNION ALL " + "SELECT 'global' AS scope, * " + "FROM flockmtl_storage.flockmtl_config.FLOCKMTL_MODEL_USER_DEFINED_INTERNAL_TABLE " + "WHERE model_name = '{}'" + "UNION ALL " + "SELECT 'local' AS scope, * " + "FROM flockmtl_config.FLOCKMTL_MODEL_USER_DEFINED_INTERNAL_TABLE " + "WHERE model_name = '{}';", + get_stmt.model_name, get_stmt.model_name, get_stmt.model_name, get_stmt.model_name); break; } case StatementType::GET_ALL_MODEL: { - query = duckdb_fmt::format(" SELECT * " - " FROM flockmtl_config.FLOCKMTL_MODEL_DEFAULT_INTERNAL_TABLE " - " UNION ALL " - " SELECT * " - " FROM flockmtl_config.FLOCKMTL_MODEL_USER_DEFINED_INTERNAL_TABLE; "); + query = duckdb_fmt::format(" SELECT 'global' AS scope, * " + " FROM flockmtl_storage.flockmtl_config.FLOCKMTL_MODEL_DEFAULT_INTERNAL_TABLE" + " UNION ALL " + " SELECT 'global' AS scope, * " + " FROM flockmtl_storage.flockmtl_config.FLOCKMTL_MODEL_USER_DEFINED_INTERNAL_TABLE" + " UNION ALL " + " SELECT 'local' AS scope, * " + " FROM flockmtl_config.FLOCKMTL_MODEL_USER_DEFINED_INTERNAL_TABLE;", + Config::get_global_storage_path().string()); break; } default: diff --git a/src/custom_parser/query/prompt_parser.cpp b/src/custom_parser/query/prompt_parser.cpp index e5964e8b..13ed57a1 100644 --- a/src/custom_parser/query/prompt_parser.cpp +++ b/src/custom_parser/query/prompt_parser.cpp @@ -33,6 +33,16 @@ void PromptParser::Parse(const std::string& query, std::unique_ptr& statement) { Token token = tokenizer.NextToken(); std::string value = duckdb::StringUtil::Upper(token.value); + + std::string catalog; + if (token.type == TokenType::KEYWORD && (value == "GLOBAL" || value == "LOCAL")) { + if (value == "GLOBAL") { + catalog = "flockmtl_storage."; + } + token = tokenizer.NextToken(); + value = duckdb::StringUtil::Upper(token.value); + } + if (token.type != TokenType::KEYWORD || value != "PROMPT") { throw std::runtime_error("Unknown keyword: " + token.value); } @@ -46,7 +56,7 @@ void PromptParser::ParseCreatePrompt(Tokenizer& tokenizer, std::unique_ptr(); + create_statement->catalog = catalog; create_statement->prompt_name = prompt_name; create_statement->prompt = prompt; statement = std::move(create_statement); @@ -106,40 +117,68 @@ void PromptParser::ParseUpdatePrompt(Tokenizer& tokenizer, std::unique_ptr(); + update_statement->prompt_name = prompt_name; + update_statement->catalog = catalog; + statement = std::move(update_statement); + } else { + throw std::runtime_error( + "Unexpected characters after the closing parenthesis. Only a semicolon is allowed."); + } - token = tokenizer.NextToken(); - if (token.type != TokenType::STRING_LITERAL || token.value.empty()) { - throw std::runtime_error("Expected non-empty string literal for new prompt text."); - } - std::string new_prompt = token.value; + } else { + if (token.type != TokenType::PARENTHESIS || token.value != "(") { + throw std::runtime_error("Expected opening parenthesis '(' after 'PROMPT'."); + } - token = tokenizer.NextToken(); - if (token.type != TokenType::PARENTHESIS || token.value != ")") { - throw std::runtime_error("Expected closing parenthesis ')' after new prompt text."); - } + token = tokenizer.NextToken(); + if (token.type != TokenType::STRING_LITERAL || token.value.empty()) { + throw std::runtime_error("Expected non-empty string literal for prompt name."); + } + std::string prompt_name = token.value; - token = tokenizer.NextToken(); - if (token.type == TokenType::END_OF_FILE) { - auto update_statement = std::make_unique(); - update_statement->prompt_name = prompt_name; - update_statement->new_prompt = new_prompt; - statement = std::move(update_statement); - } else { - throw std::runtime_error("Unexpected characters after the closing parenthesis. Only a semicolon is allowed."); + token = tokenizer.NextToken(); + if (token.type != TokenType::SYMBOL || token.value != ",") { + throw std::runtime_error("Expected comma ',' after prompt name."); + } + + token = tokenizer.NextToken(); + if (token.type != TokenType::STRING_LITERAL || token.value.empty()) { + throw std::runtime_error("Expected non-empty string literal for new prompt text."); + } + std::string new_prompt = token.value; + + token = tokenizer.NextToken(); + if (token.type != TokenType::PARENTHESIS || token.value != ")") { + throw std::runtime_error("Expected closing parenthesis ')' after new prompt text."); + } + + token = tokenizer.NextToken(); + if (token.type == TokenType::END_OF_FILE) { + auto update_statement = std::make_unique(); + update_statement->prompt_name = prompt_name; + update_statement->new_prompt = new_prompt; + statement = std::move(update_statement); + } else { + throw std::runtime_error( + "Unexpected characters after the closing parenthesis. Only a semicolon is allowed."); + } } } @@ -177,59 +216,110 @@ std::string PromptParser::ToSQL(const QueryStatement& statement) const { switch (statement.type) { case StatementType::CREATE_PROMPT: { const auto& create_stmt = static_cast(statement); - // check if prompt_name already exists auto con = Config::GetConnection(); - auto result = con.Query(duckdb_fmt::format(" SELECT * " - " FROM flockmtl_config.FLOCKMTL_PROMPT_INTERNAL_TABLE " - " WHERE prompt_name = '{}'; ", + auto result = con.Query(duckdb_fmt::format(" SELECT prompt_name " + " FROM {}flockmtl_config.FLOCKMTL_PROMPT_INTERNAL_TABLE" + " WHERE prompt_name = '{}';", + create_stmt.catalog.empty() ? "flockmtl_storage." : "", create_stmt.prompt_name)); - if (result->RowCount() > 0) { - throw std::runtime_error(duckdb_fmt::format("Prompt '{}' already exists.", create_stmt.prompt_name)); + if (result->RowCount() != 0) { + throw std::runtime_error(duckdb_fmt::format("Prompt '{}' already exist.", create_stmt.prompt_name)); } - query = duckdb_fmt::format(" INSERT INTO flockmtl_config.FLOCKMTL_PROMPT_INTERNAL_TABLE " + query = duckdb_fmt::format(" INSERT INTO {}flockmtl_config.FLOCKMTL_PROMPT_INTERNAL_TABLE " " (prompt_name, prompt) " " VALUES ('{}', '{}'); ", - create_stmt.prompt_name, create_stmt.prompt); + create_stmt.catalog, create_stmt.prompt_name, create_stmt.prompt); break; } case StatementType::DELETE_PROMPT: { const auto& delete_stmt = static_cast(statement); - query = duckdb_fmt::format(" DELETE FROM flockmtl_config.FLOCKMTL_PROMPT_INTERNAL_TABLE " + auto con = Config::GetConnection(); + auto result = con.Query(duckdb_fmt::format(" DELETE FROM flockmtl_config.FLOCKMTL_PROMPT_INTERNAL_TABLE " + " WHERE prompt_name = '{}'; ", + delete_stmt.prompt_name)); + + query = duckdb_fmt::format(" DELETE FROM flockmtl_storage.flockmtl_config.FLOCKMTL_PROMPT_INTERNAL_TABLE " " WHERE prompt_name = '{}'; ", delete_stmt.prompt_name); break; } case StatementType::UPDATE_PROMPT: { const auto& update_stmt = static_cast(statement); - // get the existing prompt version auto con = Config::GetConnection(); - auto result = con.Query(duckdb_fmt::format(" SELECT version " - " FROM flockmtl_config.FLOCKMTL_PROMPT_INTERNAL_TABLE " - " WHERE prompt_name = '{}' " - " ORDER BY version DESC " - " LIMIT 1; ", - update_stmt.prompt_name)); + auto result = + con.Query(duckdb_fmt::format(" SELECT version, 'local' AS scope " + " FROM flockmtl_config.FLOCKMTL_PROMPT_INTERNAL_TABLE" + " WHERE prompt_name = '{}'" + " UNION ALL " + " SELECT version, 'global' AS scope " + " FROM flockmtl_storage.flockmtl_config.FLOCKMTL_PROMPT_INTERNAL_TABLE" + " WHERE prompt_name = '{}' " + " ORDER BY version DESC;", + update_stmt.prompt_name, update_stmt.prompt_name)); if (result->RowCount() == 0) { - throw std::runtime_error(duckdb_fmt::format("Prompt '{}' does not exist.", update_stmt.prompt_name)); + throw std::runtime_error(duckdb_fmt::format("Prompt '{}' doesn't exist.", update_stmt.prompt_name)); } + int version = result->GetValue(0, 0) + 1; - query = duckdb_fmt::format(" INSERT INTO flockmtl_config.FLOCKMTL_PROMPT_INTERNAL_TABLE " + auto catalog = result->GetValue(1, 0).ToString() == "global" ? "flockmtl_storage." : ""; + query = duckdb_fmt::format(" INSERT INTO {}flockmtl_config.FLOCKMTL_PROMPT_INTERNAL_TABLE " " (prompt_name, prompt, version) " " VALUES ('{}', '{}', {}); ", - update_stmt.prompt_name, update_stmt.new_prompt, version); + catalog, update_stmt.prompt_name, update_stmt.new_prompt, version); + break; + } + case StatementType::UPDATE_PROMPT_SCOPE: { + const auto& update_stmt = static_cast(statement); + auto con = Config::GetConnection(); + auto result = con.Query(duckdb_fmt::format(" SELECT prompt_name " + " FROM {}flockmtl_config.FLOCKMTL_PROMPT_INTERNAL_TABLE" + " WHERE prompt_name = '{}';", + update_stmt.catalog, update_stmt.prompt_name)); + if (result->RowCount() != 0) { + throw std::runtime_error( + duckdb_fmt::format("Model '{}' already exist in {} storage.", update_stmt.prompt_name, + update_stmt.catalog == "flockmtl_storage." ? "global" : "local")); + } + + con.Query(duckdb_fmt::format("INSERT INTO {}flockmtl_config.FLOCKMTL_PROMPT_INTERNAL_TABLE " + "(prompt_name, prompt, updated_at, version) " + "SELECT prompt_name, prompt, updated_at, version " + "FROM {}flockmtl_config.FLOCKMTL_PROMPT_INTERNAL_TABLE " + "WHERE prompt_name = '{}';", + update_stmt.catalog, + update_stmt.catalog == "flockmtl_storage." ? "" : "flockmtl_storage.", + update_stmt.prompt_name)); + + query = duckdb_fmt::format("DELETE FROM {}flockmtl_config.FLOCKMTL_PROMPT_INTERNAL_TABLE " + "WHERE prompt_name = '{}'; ", + update_stmt.catalog == "flockmtl_storage." ? "" : "flockmtl_storage.", + update_stmt.prompt_name); break; } case StatementType::GET_PROMPT: { const auto& get_stmt = static_cast(statement); - query = duckdb_fmt::format(" SELECT * " - " FROM flockmtl_config.FLOCKMTL_PROMPT_INTERNAL_TABLE " - " WHERE prompt_name = '{}' " - " ORDER BY version DESC; ", - get_stmt.prompt_name); + query = duckdb_fmt::format("SELECT 'global' AS scope, * " + "FROM flockmtl_storage.flockmtl_config.FLOCKMTL_PROMPT_INTERNAL_TABLE " + "WHERE prompt_name = '{}' " + "UNION ALL " + "SELECT 'local' AS scope, * " + "FROM flockmtl_config.FLOCKMTL_PROMPT_INTERNAL_TABLE " + "WHERE prompt_name = '{}' " + "ORDER BY version DESC;", + get_stmt.prompt_name, get_stmt.prompt_name); + break; } case StatementType::GET_ALL_PROMPT: { - query = " SELECT t1.* " + query = " SELECT 'global' as scope, t1.* " + " FROM flockmtl_storage.flockmtl_config.FLOCKMTL_PROMPT_INTERNAL_TABLE AS t1 " + " JOIN (SELECT prompt_name, MAX(version) AS max_version " + " FROM flockmtl_storage.flockmtl_config.FLOCKMTL_PROMPT_INTERNAL_TABLE " + " GROUP BY prompt_name) AS t2 " + " ON t1.prompt_name = t2.prompt_name " + " AND t1.version = t2.max_version" + " UNION ALL " + " SELECT 'local' as scope, t1.* " " FROM flockmtl_config.FLOCKMTL_PROMPT_INTERNAL_TABLE AS t1 " " JOIN (SELECT prompt_name, MAX(version) AS max_version " " FROM flockmtl_config.FLOCKMTL_PROMPT_INTERNAL_TABLE " diff --git a/src/custom_parser/query_parser.cpp b/src/custom_parser/query_parser.cpp index ecf884c9..5f57e7a1 100644 --- a/src/custom_parser/query_parser.cpp +++ b/src/custom_parser/query_parser.cpp @@ -10,23 +10,29 @@ namespace flockmtl { std::string QueryParser::ParseQuery(const std::string& query) { Tokenizer tokenizer(query); - Token token = tokenizer.NextToken(); - std::string value = duckdb::StringUtil::Upper(token.value); + auto token = tokenizer.NextToken(); + auto value = duckdb::StringUtil::Upper(token.value); if (token.type != TokenType::KEYWORD || (value != "CREATE" && value != "DELETE" && value != "UPDATE" && value != "GET")) { throw std::runtime_error(duckdb_fmt::format("Unknown keyword: {}", token.value)); } - token = tokenizer.NextToken(); - value = duckdb::StringUtil::Upper(token.value); - if (token.type == TokenType::KEYWORD && value == "MODEL" || value == "MODELS") { + return ParsePromptOrModel(tokenizer, query); +} + +inline std::string QueryParser::ParsePromptOrModel(Tokenizer tokenizer, const std::string& query) { + Token token = tokenizer.NextToken(); + auto value = duckdb::StringUtil::Upper(token.value); + if (token.type == TokenType::KEYWORD && (value == "MODEL" || value == "MODELS")) { ModelParser model_parser; model_parser.Parse(query, statement); return model_parser.ToSQL(*statement); - } else if (token.type == TokenType::KEYWORD && ((value == "PROMPT" || value == "PROMPTS"))) { + } else if (token.type == TokenType::KEYWORD && (value == "PROMPT" || value == "PROMPTS")) { PromptParser prompt_parser; prompt_parser.Parse(query, statement); return prompt_parser.ToSQL(*statement); + } else if (token.type == TokenType::KEYWORD && (value == "GLOBAL" || value == "LOCAL")) { + return ParsePromptOrModel(tokenizer, query); } else { throw std::runtime_error(duckdb_fmt::format("Unknown keyword: {}", token.value)); } diff --git a/src/include/flockmtl/core/config.hpp b/src/include/flockmtl/core/config.hpp index 248390cb..62cc42ec 100644 --- a/src/include/flockmtl/core/config.hpp +++ b/src/include/flockmtl/core/config.hpp @@ -1,19 +1,27 @@ #pragma once #include - +#include #include "flockmtl/core/common.hpp" #include "flockmtl/registry/registry.hpp" namespace flockmtl { +enum ConfigType { LOCAL, GLOBAL }; + class Config { public: static duckdb::DatabaseInstance* db; + static duckdb::DatabaseInstance* global_db; static duckdb::Connection GetConnection(duckdb::DatabaseInstance* db = nullptr); - + static duckdb::Connection GetGlobalConnection(); static void Configure(duckdb::DatabaseInstance& db); + static void ConfigureGlobal(); + static void ConfigureTables(duckdb::Connection& con, ConfigType type); + static void ConfigureLocal(duckdb::DatabaseInstance& db); + static std::string get_schema_name(); + static std::filesystem::path get_global_storage_path(); static std::string get_default_models_table_name(); static std::string get_user_defined_models_table_name(); static std::string get_prompts_table_name(); @@ -21,11 +29,12 @@ class Config { constexpr static int32_t default_max_output_tokens = 4096; private: + static void SetupGlobalStorageLocation(); static void ConfigSchema(duckdb::Connection& con, std::string& schema_name); - static void ConfigModelTable(duckdb::Connection& con, std::string& schema_name); - static void ConfigPromptTable(duckdb::Connection& con, std::string& schema_name); - static void setup_default_models_config(duckdb::Connection& con, std::string& schema_name); - static void setup_user_defined_models_config(duckdb::Connection& con, std::string& schema_name); + static void ConfigPromptTable(duckdb::Connection& con, std::string& schema_name, ConfigType type); + static void ConfigModelTable(duckdb::Connection& con, std::string& schema_name, ConfigType type); + static void SetupDefaultModelsConfig(duckdb::Connection& con, std::string& schema_name); + static void SetupUserDefinedModelsConfig(duckdb::Connection& con, std::string& schema_name); }; } // namespace flockmtl diff --git a/src/include/flockmtl/custom_parser/query/model_parser.hpp b/src/include/flockmtl/custom_parser/query/model_parser.hpp index c47d12e8..9e94d5a2 100644 --- a/src/include/flockmtl/custom_parser/query/model_parser.hpp +++ b/src/include/flockmtl/custom_parser/query/model_parser.hpp @@ -16,6 +16,7 @@ namespace flockmtl { class CreateModelStatement : public QueryStatement { public: CreateModelStatement() { type = StatementType::CREATE_MODEL; } + std::string catalog; std::string model_name; std::string model; std::string provider_name; @@ -29,6 +30,13 @@ class DeleteModelStatement : public QueryStatement { std::string provider_name; }; +class UpdateModelScopeStatement : public QueryStatement { +public: + UpdateModelScopeStatement() { type = StatementType::UPDATE_MODEL_SCOPE; } + std::string model_name; + std::string catalog; +}; + class UpdateModelStatement : public QueryStatement { public: UpdateModelStatement() { type = StatementType::UPDATE_MODEL; } diff --git a/src/include/flockmtl/custom_parser/query/prompt_parser.hpp b/src/include/flockmtl/custom_parser/query/prompt_parser.hpp index 7574c8e2..42365c91 100644 --- a/src/include/flockmtl/custom_parser/query/prompt_parser.hpp +++ b/src/include/flockmtl/custom_parser/query/prompt_parser.hpp @@ -15,6 +15,7 @@ namespace flockmtl { class CreatePromptStatement : public QueryStatement { public: CreatePromptStatement() { type = StatementType::CREATE_PROMPT; } + std::string catalog; std::string prompt_name; std::string prompt; }; @@ -25,6 +26,13 @@ class DeletePromptStatement : public QueryStatement { std::string prompt_name; }; +class UpdatePromptScopeStatement : public QueryStatement { +public: + UpdatePromptScopeStatement() { type = StatementType::UPDATE_PROMPT_SCOPE; } + std::string catalog; + std::string prompt_name; +}; + class UpdatePromptStatement : public QueryStatement { public: UpdatePromptStatement() { type = StatementType::UPDATE_PROMPT; } diff --git a/src/include/flockmtl/custom_parser/query_parser.hpp b/src/include/flockmtl/custom_parser/query_parser.hpp index ee04c779..1586526a 100644 --- a/src/include/flockmtl/custom_parser/query_parser.hpp +++ b/src/include/flockmtl/custom_parser/query_parser.hpp @@ -16,6 +16,7 @@ namespace flockmtl { class QueryParser { public: std::string ParseQuery(const std::string& query); + std::string ParsePromptOrModel(Tokenizer tokenizer, const std::string& query); private: std::unique_ptr statement; diff --git a/src/include/flockmtl/custom_parser/query_statements.hpp b/src/include/flockmtl/custom_parser/query_statements.hpp index 6fd2b196..c220ab1e 100644 --- a/src/include/flockmtl/custom_parser/query_statements.hpp +++ b/src/include/flockmtl/custom_parser/query_statements.hpp @@ -13,11 +13,13 @@ enum class StatementType { CREATE_MODEL, DELETE_MODEL, UPDATE_MODEL, + UPDATE_MODEL_SCOPE, GET_MODEL, GET_ALL_MODEL, CREATE_PROMPT, DELETE_PROMPT, UPDATE_PROMPT, + UPDATE_PROMPT_SCOPE, GET_PROMPT, GET_ALL_PROMPT, }; diff --git a/src/model_manager/model.cpp b/src/model_manager/model.cpp index ca4343e4..a73b7398 100644 --- a/src/model_manager/model.cpp +++ b/src/model_manager/model.cpp @@ -37,19 +37,25 @@ void Model::LoadModelDetails(const nlohmann::json& model_json) { } std::tuple Model::GetQueriedModel(const std::string& model_name) { - std::string query = duckdb_fmt::format(" SELECT model, provider_name, model_args " - " FROM flockmtl_config.FLOCKMTL_MODEL_USER_DEFINED_INTERNAL_TABLE " - " WHERE model_name = '{}' ", - model_name); + const std::string query = + duckdb_fmt::format(" SELECT model, provider_name, model_args " + " FROM flockmtl_storage.flockmtl_config.FLOCKMTL_MODEL_USER_DEFINED_INTERNAL_TABLE" + " WHERE model_name = '{}'" + " UNION ALL " + " SELECT model, provider_name, model_args " + " FROM flockmtl_config.FLOCKMTL_MODEL_USER_DEFINED_INTERNAL_TABLE" + " WHERE model_name = '{}';", + model_name, model_name); auto con = Config::GetConnection(); auto query_result = con.Query(query); if (query_result->RowCount() == 0) { - query_result = con.Query(duckdb_fmt::format(" SELECT model, provider_name, model_args " - " FROM flockmtl_config.FLOCKMTL_MODEL_DEFAULT_INTERNAL_TABLE " - " WHERE model_name = '{}' ", - model_name)); + query_result = con.Query( + duckdb_fmt::format(" SELECT model, provider_name, model_args " + " FROM flockmtl_storage.flockmtl_config.FLOCKMTL_MODEL_DEFAULT_INTERNAL_TABLE " + " WHERE model_name = '{}' ", + model_name)); if (query_result->RowCount() == 0) { throw std::runtime_error("Model not found"); diff --git a/src/prompt_manager/prompt_manager.cpp b/src/prompt_manager/prompt_manager.cpp index a3004b68..53a992d8 100644 --- a/src/prompt_manager/prompt_manager.cpp +++ b/src/prompt_manager/prompt_manager.cpp @@ -81,21 +81,32 @@ PromptDetails PromptManager::CreatePromptDetails(const nlohmann::json& prompt_de "prompt_name with prompt version"); } prompt_details.prompt_name = prompt_details_json["prompt_name"]; - std::string prompt_details_query = duckdb_fmt::format(" SELECT prompt " - " FROM flockmtl_config.FLOCKMTL_PROMPT_INTERNAL_TABLE " - " WHERE prompt_name = '{}'", - prompt_details.prompt_name); - std::string error_message = duckdb_fmt::format("The provided `{}` prompt ", prompt_details.prompt_name); + std::string error_message; + std::string version_where_clause; + std::string order_by_clause; if (prompt_details_json.contains("version")) { prompt_details.version = std::stoi(prompt_details_json["version"].get()); - prompt_details_query += duckdb_fmt::format(" AND version = {}", prompt_details.version); - error_message += duckdb_fmt::format("with version {} not found", prompt_details.version); + version_where_clause = duckdb_fmt::format(" AND version = {}", prompt_details.version); + error_message = duckdb_fmt::format("with version {} not found", prompt_details.version); } else { - prompt_details_query += " ORDER BY version DESC LIMIT 1;"; + order_by_clause = " ORDER BY version DESC LIMIT 1 "; error_message += "not found"; } + const auto prompt_details_query = + duckdb_fmt::format(" SELECT prompt, version " + " FROM flockmtl_storage.flockmtl_config.FLOCKMTL_PROMPT_INTERNAL_TABLE " + " WHERE prompt_name = '{}'" + " {} " + " UNION ALL " + " SELECT prompt, version " + " FROM flockmtl_config.FLOCKMTL_PROMPT_INTERNAL_TABLE " + " WHERE prompt_name = '{}'" + " {} {}", + prompt_details.prompt_name, version_where_clause, prompt_details.prompt_name, + version_where_clause, order_by_clause); + error_message = duckdb_fmt::format("The provided `{}` prompt " + error_message, prompt_details.prompt_name); auto con = Config::GetConnection(); - auto query_result = con.Query(prompt_details_query); + const auto query_result = con.Query(prompt_details_query); if (query_result->RowCount() == 0) { throw std::runtime_error(error_message); } From 0dd08f12535966c6d3faa8a720baec3b9583f731 Mon Sep 17 00:00:00 2001 From: Anas Dorbani Date: Fri, 13 Dec 2024 15:45:41 +0100 Subject: [PATCH 2/4] update documentation --- .../resource-management/model-management.md | 44 +++++++++++++++---- .../resource-management/prompt-management.md | 39 +++++++++++++--- 2 files changed, 68 insertions(+), 15 deletions(-) diff --git a/docs/docs/resource-management/model-management.md b/docs/docs/resource-management/model-management.md index 5fa6f281..2c9c990c 100644 --- a/docs/docs/resource-management/model-management.md +++ b/docs/docs/resource-management/model-management.md @@ -18,39 +18,65 @@ Models are stored in a table with the following structure: | **Provider** | Source of the model (e.g., `openai`, `azure`, `ollama`) | | **Model Arguments** | JSON configuration parameters such as `context_window` size and `max_output_tokens` | -## 2. Management Commands +## 2. Introduction to Global and Local Models -- Retrieve all available models +FlockMTL supports two types of models: + +* **Global Models**: Created using `CREATE GLOBAL MODEL`. These are shared across different databases. +* **Local Models**: Created using `CREATE LOCAL MODEL` or `CREATE MODEL` (default if no type is specified). These are limited to a single database. + +## 3. Management Commands + +### Create Models + +- Create a global model: + +```sql +CREATE GLOBAL MODEL('model_name', 'model_type', 'provider', {'context_window': 128000, 'max_output_tokens': 8000}) +``` + +- Create a local model (default if no type is specified): + +```sql +CREATE LOCAL MODEL('model_name', 'model_type', 'provider', {'context_window': 128000, 'max_output_tokens': 8000}) +CREATE MODEL('model_name', 'model_type', 'provider', {'context_window': 128000, 'max_output_tokens': 8000}) +``` + +### Retrieve Models + +- Retrieve all available models: ```sql GET MODELS; ``` -- Retrieve details of a specific model +- Retrieve details of a specific model: ```sql GET MODEL 'model_name'; ``` -- Create a new user-defined model +### Update Models + +- Update an existing model: ```sql -CREATE MODEL('model_name', 'model', 'provider', {'context_window': 128000, 'max_output_tokens': 8000}) +UPDATE MODEL('model_name', 'model_type', 'provider', {'context_window': 128000, 'max_output_tokens': 8000}); ``` -- Modify an existing user-defined model +- Toggle a model's state between global and local: ```sql -UPDATE MODEL('model_name', 'model', 'provider', {'context_window': 128000, 'max_output_tokens': 8000}); +UPDATE MODEL 'model_name' TO GLOBAL; +UPDATE MODEL 'model_name' TO LOCAL; ``` - - Remove a user-defined model ```sql DELETE MODEL 'model_name'; ``` -## 3. SQL Query Examples +## 4. SQL Query Examples ### Semantic Text Completion diff --git a/docs/docs/resource-management/prompt-management.md b/docs/docs/resource-management/prompt-management.md index 6355cc1b..fb3366d2 100644 --- a/docs/docs/resource-management/prompt-management.md +++ b/docs/docs/resource-management/prompt-management.md @@ -16,30 +16,57 @@ The **Prompt Management** section provides guidance on how to manage and configu | **updated_at** | Timestamp of the last update | | **version** | Version number of the prompt | +## 1. Introduction to Global and Local Prompts + +FlockMTL supports two types of prompts: + +* **Global Prompts**: Created using `CREATE GLOBAL PROMPT`. These are shared across different databases. +* **Local Prompts**: Created using `CREATE LOCAL PROMPT` or `CREATE PROMPT` (default if no type is specified). These are limited to a single database. + ## 2. Management Commands -- Retrieve all available prompts +### Create Prompts + +* Create a global prompt: + +```sql +CREATE GLOBAL PROMPT('prompt_name', 'prompt'); +``` + +- Create a local prompt (default if no type is specified): + +```sql +CREATE LOCAL PROMPT('prompt_name', 'prompt'); +CREATE PROMPT('prompt_name', 'prompt'); +``` + +### Retrieve Prompts + +- Retrieve all available prompts: ```sql GET PROMPTS; ``` -- Retrieve details of a specific prompt +- Retrieve details of a specific prompt: ```sql GET PROMPT 'prompt_name'; ``` -- Create a new prompt +### Update Prompts + +- Update an existing prompt: ```sql -CREATE PROMPT('prompt_name', 'prompt'); +UPDATE PROMPT('prompt_name', 'prompt'); ``` -- Modify an existing prompt +- Toggle a prompt's state between global and local: ```sql -UPDATE PROMPT('prompt_name', 'prompt'); +UPDATE PROMPT 'prompt_name' TO GLOBAL; +UPDATE PROMPT 'prompt_name' TO LOCAL; ``` - Remove a prompt From 2bdd8f96201c350a28dc17fac4d5fab0ed25efba Mon Sep 17 00:00:00 2001 From: Anas Dorbani Date: Fri, 13 Dec 2024 16:02:50 +0100 Subject: [PATCH 3/4] replace with any type with auto --- src/custom_parser/query/model_parser.cpp | 28 +++++++++++------------ src/custom_parser/query/prompt_parser.cpp | 28 +++++++++++------------ src/custom_parser/query_parser.cpp | 2 +- src/custom_parser/tokenizer.cpp | 20 ++++++++-------- 4 files changed, 39 insertions(+), 39 deletions(-) diff --git a/src/custom_parser/query/model_parser.cpp b/src/custom_parser/query/model_parser.cpp index 0037aef9..6ccca96b 100644 --- a/src/custom_parser/query/model_parser.cpp +++ b/src/custom_parser/query/model_parser.cpp @@ -9,8 +9,8 @@ namespace flockmtl { void ModelParser::Parse(const std::string& query, std::unique_ptr& statement) { Tokenizer tokenizer(query); - Token token = tokenizer.NextToken(); - const std::string value = duckdb::StringUtil::Upper(token.value); + auto token = tokenizer.NextToken(); + const auto value = duckdb::StringUtil::Upper(token.value); if (token.type == TokenType::KEYWORD) { if (value == "CREATE") { @@ -30,8 +30,8 @@ void ModelParser::Parse(const std::string& query, std::unique_ptr& statement) { - Token token = tokenizer.NextToken(); - std::string value = duckdb::StringUtil::Upper(token.value); + auto token = tokenizer.NextToken(); + auto value = duckdb::StringUtil::Upper(token.value); std::string catalog; if (token.type == TokenType::KEYWORD && (value == "GLOBAL" || value == "LOCAL")) { @@ -55,7 +55,7 @@ void ModelParser::ParseCreateModel(Tokenizer& tokenizer, std::unique_ptr& statement) { - Token token = tokenizer.NextToken(); - std::string value = duckdb::StringUtil::Upper(token.value); + auto token = tokenizer.NextToken(); + auto value = duckdb::StringUtil::Upper(token.value); if (token.type != TokenType::KEYWORD || value != "MODEL") { throw std::runtime_error("Unknown keyword: " + token.value); } @@ -128,7 +128,7 @@ void ModelParser::ParseDeleteModel(Tokenizer& tokenizer, std::unique_ptr& statement) { - Token token = tokenizer.NextToken(); - std::string value = duckdb::StringUtil::Upper(token.value); + auto token = tokenizer.NextToken(); + auto value = duckdb::StringUtil::Upper(token.value); if (token.type != TokenType::KEYWORD || (value != "MODEL" && value != "MODELS")) { throw std::runtime_error("Expected 'MODEL' after 'GET'."); } @@ -260,7 +260,7 @@ void ModelParser::ParseGetModel(Tokenizer& tokenizer, std::unique_ptr& statement) { Tokenizer tokenizer(query); - Token token = tokenizer.NextToken(); - std::string value = duckdb::StringUtil::Upper(token.value); + auto token = tokenizer.NextToken(); + auto value = duckdb::StringUtil::Upper(token.value); if (token.type == TokenType::KEYWORD) { if (value == "CREATE") { @@ -31,8 +31,8 @@ void PromptParser::Parse(const std::string& query, std::unique_ptr& statement) { - Token token = tokenizer.NextToken(); - std::string value = duckdb::StringUtil::Upper(token.value); + auto token = tokenizer.NextToken(); + auto value = duckdb::StringUtil::Upper(token.value); std::string catalog; if (token.type == TokenType::KEYWORD && (value == "GLOBAL" || value == "LOCAL")) { @@ -87,8 +87,8 @@ void PromptParser::ParseCreatePrompt(Tokenizer& tokenizer, std::unique_ptr& statement) { - Token token = tokenizer.NextToken(); - std::string value = duckdb::StringUtil::Upper(token.value); + auto token = tokenizer.NextToken(); + auto value = duckdb::StringUtil::Upper(token.value); if (token.type != TokenType::KEYWORD || value != "PROMPT") { throw std::runtime_error("Unknown keyword: " + token.value); } @@ -97,7 +97,7 @@ void PromptParser::ParseDeletePrompt(Tokenizer& tokenizer, std::unique_ptr& statement) { - Token token = tokenizer.NextToken(); - std::string value = duckdb::StringUtil::Upper(token.value); + auto token = tokenizer.NextToken(); + auto value = duckdb::StringUtil::Upper(token.value); if (token.type != TokenType::KEYWORD || value != "PROMPT") { throw std::runtime_error("Unknown keyword: " + token.value); } @@ -151,7 +151,7 @@ void PromptParser::ParseUpdatePrompt(Tokenizer& tokenizer, std::unique_ptr& statement) { - Token token = tokenizer.NextToken(); - std::string value = duckdb::StringUtil::Upper(token.value); + auto token = tokenizer.NextToken(); + auto value = duckdb::StringUtil::Upper(token.value); if (token.type != TokenType::KEYWORD || (value != "PROMPT" && value != "PROMPTS")) { throw std::runtime_error("Unknown keyword: " + token.value); } @@ -197,7 +197,7 @@ void PromptParser::ParseGetPrompt(Tokenizer& tokenizer, std::unique_ptr 0) { if (query_[position_] == '{') { ++brace_count; @@ -53,40 +53,40 @@ Token Tokenizer::ParseJson() { if (brace_count > 0) { throw std::runtime_error("Unterminated JSON."); } - std::string value = query_.substr(start, position_ - start); + auto value = query_.substr(start, position_ - start); return {TokenType::JSON, value}; } // Parse a keyword (word made of letters) Token Tokenizer::ParseKeyword() { - int start = position_; + auto start = position_; while (position_ < query_.size() && (std::isalpha(query_[position_]) || query_[position_] == '_')) { ++position_; } - std::string value = query_.substr(start, position_ - start); + auto value = query_.substr(start, position_ - start); return {TokenType::KEYWORD, value}; } // Parse a symbol (single character) Token Tokenizer::ParseSymbol() { - char ch = query_[position_]; + auto ch = query_[position_]; ++position_; return {TokenType::SYMBOL, std::string(1, ch)}; } // Parse a number (sequence of digits) Token Tokenizer::ParseNumber() { - int start = position_; + auto start = position_; while (position_ < query_.size() && std::isdigit(query_[position_])) { ++position_; } - std::string value = query_.substr(start, position_ - start); + auto value = query_.substr(start, position_ - start); return {TokenType::NUMBER, value}; } // Parse a parenthesis Token Tokenizer::ParseParenthesis() { - char ch = query_[position_]; + auto ch = query_[position_]; ++position_; return {TokenType::PARENTHESIS, std::string(1, ch)}; } @@ -98,7 +98,7 @@ Token Tokenizer::GetNextToken() { return {TokenType::END_OF_FILE, ""}; } - char ch = query_[position_]; + auto ch = query_[position_]; if (ch == '\'') { return ParseStringLiteral(); } else if (ch == '{') { From e6d733c95f3f2ea84054be537f9a9efce3c21cd4 Mon Sep 17 00:00:00 2001 From: Anas Dorbani Date: Fri, 13 Dec 2024 16:19:24 +0100 Subject: [PATCH 4/4] update documentation --- .../resource-management/model-management.md | 74 ++++++++++--------- .../resource-management/prompt-management.md | 71 +++++++++--------- 2 files changed, 76 insertions(+), 69 deletions(-) diff --git a/docs/docs/resource-management/model-management.md b/docs/docs/resource-management/model-management.md index 2c9c990c..34e0aa8d 100644 --- a/docs/docs/resource-management/model-management.md +++ b/docs/docs/resource-management/model-management.md @@ -18,65 +18,39 @@ Models are stored in a table with the following structure: | **Provider** | Source of the model (e.g., `openai`, `azure`, `ollama`) | | **Model Arguments** | JSON configuration parameters such as `context_window` size and `max_output_tokens` | -## 2. Introduction to Global and Local Models +## 2. Management Commands -FlockMTL supports two types of models: - -* **Global Models**: Created using `CREATE GLOBAL MODEL`. These are shared across different databases. -* **Local Models**: Created using `CREATE LOCAL MODEL` or `CREATE MODEL` (default if no type is specified). These are limited to a single database. - -## 3. Management Commands - -### Create Models - -- Create a global model: - -```sql -CREATE GLOBAL MODEL('model_name', 'model_type', 'provider', {'context_window': 128000, 'max_output_tokens': 8000}) -``` - -- Create a local model (default if no type is specified): - -```sql -CREATE LOCAL MODEL('model_name', 'model_type', 'provider', {'context_window': 128000, 'max_output_tokens': 8000}) -CREATE MODEL('model_name', 'model_type', 'provider', {'context_window': 128000, 'max_output_tokens': 8000}) -``` - -### Retrieve Models - -- Retrieve all available models: +- Retrieve all available models ```sql GET MODELS; ``` -- Retrieve details of a specific model: +- Retrieve details of a specific model ```sql GET MODEL 'model_name'; ``` -### Update Models - -- Update an existing model: +- Create a new user-defined model ```sql -UPDATE MODEL('model_name', 'model_type', 'provider', {'context_window': 128000, 'max_output_tokens': 8000}); +CREATE MODEL('model_name', 'model', 'provider', {'context_window': 128000, 'max_output_tokens': 8000}) ``` -- Toggle a model's state between global and local: +- Modify an existing user-defined model ```sql -UPDATE MODEL 'model_name' TO GLOBAL; -UPDATE MODEL 'model_name' TO LOCAL; +UPDATE MODEL('model_name', 'model', 'provider', {'context_window': 128000, 'max_output_tokens': 8000}); ``` + - Remove a user-defined model ```sql DELETE MODEL 'model_name'; ``` -## 4. SQL Query Examples +## 3. SQL Query Examples ### Semantic Text Completion @@ -99,3 +73,33 @@ SELECT llm_complete( ) AS search_results FROM search_data; ``` + +## 4. Global and Local Models + +Model creation is database specific if you want it to be available irrespective of the database then make it a GLOBAL mode. Note that previously, the creation was specific to the running database, which is LOCAL by default and the keyword LOCAL is optional. + +### Create Models + +- Create a global model: + +```sql +CREATE GLOBAL MODEL('model_name', 'model_type', 'provider', {'context_window': 128000, 'max_output_tokens': 8000}) +``` + +- Create a local model (default if no type is specified): + +```sql +CREATE LOCAL MODEL('model_name', 'model_type', 'provider', {'context_window': 128000, 'max_output_tokens': 8000}) +CREATE MODEL('model_name', 'model_type', 'provider', {'context_window': 128000, 'max_output_tokens': 8000}) +``` + +### Toggle Model State + +- Toggle a model's state between global and local: + +```sql +UPDATE MODEL 'model_name' TO GLOBAL; +UPDATE MODEL 'model_name' TO LOCAL; +``` + +All the other queries remain the same for both global and local prompts. diff --git a/docs/docs/resource-management/prompt-management.md b/docs/docs/resource-management/prompt-management.md index fb3366d2..bf3f0791 100644 --- a/docs/docs/resource-management/prompt-management.md +++ b/docs/docs/resource-management/prompt-management.md @@ -7,7 +7,7 @@ sidebar_position: 2 The **Prompt Management** section provides guidance on how to manage and configure prompts for **analytics and semantic analysis tasks** within FlockMTL. Prompts guide models in generating specific outputs for tasks like content generation, summarization, and ranking. Each database is configured with its own prompt management table during the initial load. -### Prompt Table Structure +### 1. Prompt Table Structure | **Column Name** | **Description** | | --------------- | --------------------------------- | @@ -16,57 +16,30 @@ The **Prompt Management** section provides guidance on how to manage and configu | **updated_at** | Timestamp of the last update | | **version** | Version number of the prompt | -## 1. Introduction to Global and Local Prompts - -FlockMTL supports two types of prompts: - -* **Global Prompts**: Created using `CREATE GLOBAL PROMPT`. These are shared across different databases. -* **Local Prompts**: Created using `CREATE LOCAL PROMPT` or `CREATE PROMPT` (default if no type is specified). These are limited to a single database. - ## 2. Management Commands -### Create Prompts - -* Create a global prompt: - -```sql -CREATE GLOBAL PROMPT('prompt_name', 'prompt'); -``` - -- Create a local prompt (default if no type is specified): - -```sql -CREATE LOCAL PROMPT('prompt_name', 'prompt'); -CREATE PROMPT('prompt_name', 'prompt'); -``` - -### Retrieve Prompts - -- Retrieve all available prompts: +- Retrieve all available prompts ```sql GET PROMPTS; ``` -- Retrieve details of a specific prompt: +- Retrieve details of a specific prompt ```sql GET PROMPT 'prompt_name'; ``` -### Update Prompts - -- Update an existing prompt: +- Create a new prompt ```sql -UPDATE PROMPT('prompt_name', 'prompt'); +CREATE PROMPT('prompt_name', 'prompt'); ``` -- Toggle a prompt's state between global and local: +- Modify an existing prompt ```sql -UPDATE PROMPT 'prompt_name' TO GLOBAL; -UPDATE PROMPT 'prompt_name' TO LOCAL; +UPDATE PROMPT('prompt_name', 'prompt'); ``` - Remove a prompt @@ -102,3 +75,33 @@ SELECT llm_complete( ) AS review_summary FROM reviews; ``` + +## 4. Global and Local Prompts + +Prompt creation is database specific if you want it to be available irrespective of the database then make it a GLOBAL mode. Note that previously, the creation was specific to the running database, which is LOCAL by default and the keyword LOCAL is optional. + +### Create Prompts + +* Create a global prompt: + +```sql +CREATE GLOBAL PROMPT('prompt_name', 'prompt'); +``` + +- Create a local prompt (default if no type is specified): + +```sql +CREATE LOCAL PROMPT('prompt_name', 'prompt'); +CREATE PROMPT('prompt_name', 'prompt'); +``` + +### Toggle Prompt State + +- Toggle a prompt's state between global and local: + +```sql +UPDATE PROMPT 'prompt_name' TO GLOBAL; +UPDATE PROMPT 'prompt_name' TO LOCAL; +``` + +All the other queries remain the same for both global and local prompts.