diff --git a/examples/run/run.cpp b/examples/run/run.cpp index 2888fcfed1e15..08faedc601ded 100644 --- a/examples/run/run.cpp +++ b/examples/run/run.cpp @@ -11,6 +11,8 @@ # include #endif +#include + #include #include #include @@ -25,6 +27,16 @@ #include "json.hpp" #include "llama-cpp.h" +static bool sigint = false; + +#if defined(__unix__) || (defined(__APPLE__) && defined(__MACH__)) || defined(_WIN32) +static void sigint_handler(int signo) { + if (signo == SIGINT) { + sigint = true; + } +} +#endif + GGML_ATTRIBUTE_FORMAT(1, 2) static std::string fmt(const char * fmt, ...) { va_list ap; @@ -801,7 +813,20 @@ static int generate(LlamaData & llama_data, const std::string & prompt, std::str static int read_user_input(std::string & user) { std::getline(std::cin, user); - return user.empty(); // Should have data in happy path + if (sigint || std::cin.eof()) { + printf("\n"); + return 1; + } + + if (user == "/bye") { + return 1; + } + + if (user.empty()) { + return 2; + } + + return 0; // Should have data in happy path } // Function to generate a response based on the prompt @@ -868,7 +893,25 @@ static bool is_stdout_a_terminal() { #endif } -// Function to tokenize the prompt +// Function to handle user input +static int get_user_input(std::string & user_input, const std::string & user) { + while (true) { + const int ret = handle_user_input(user_input, user); + if (ret == 1) { + return 1; + } + + if (ret == 2) { + continue; + } + + break; + } + + return 0; +} + +// Main chat loop function static int chat_loop(LlamaData & llama_data, const std::string & user) { int prev_len = 0; llama_data.fmtted.resize(llama_n_ctx(llama_data.context.get())); @@ -876,7 +919,8 @@ static int chat_loop(LlamaData & llama_data, const std::string & user) { while (true) { // Get user input std::string user_input; - while (handle_user_input(user_input, user)) { + if (get_user_input(user_input, user) == 1) { + return 0; } add_message("user", user.empty() ? user_input : user, llama_data); @@ -917,7 +961,23 @@ static std::string read_pipe_data() { return result.str(); } +static void ctrl_c_handling() { +#if defined(__unix__) || (defined(__APPLE__) && defined(__MACH__)) + struct sigaction sigint_action; + sigint_action.sa_handler = sigint_handler; + sigemptyset(&sigint_action.sa_mask); + sigint_action.sa_flags = 0; + sigaction(SIGINT, &sigint_action, NULL); +#elif defined(_WIN32) + auto console_ctrl_handler = +[](DWORD ctrl_type) -> BOOL { + return (ctrl_type == CTRL_C_EVENT) ? (sigint_handler(SIGINT), true) : false; + }; + SetConsoleCtrlHandler(reinterpret_cast(console_ctrl_handler), true); +#endif +} + int main(int argc, const char ** argv) { + ctrl_c_handling(); Opt opt; const int ret = opt.init(argc, argv); if (ret == 2) {