Skip to content

Commit

Permalink
Enhance user input handling for llama-run
Browse files Browse the repository at this point in the history
The main motivation for this change is it was not handing
ctrl-c/ctrl-d correctly. Modify `read_user_input` to handle EOF,
"/bye" command, and empty input cases. Introduce `get_user_input`
function to manage user input loop and handle different return
cases.

Signed-off-by: Eric Curtin <[email protected]>
  • Loading branch information
ericcurtin committed Jan 8, 2025
1 parent 99a3755 commit a85db13
Showing 1 changed file with 63 additions and 3 deletions.
66 changes: 63 additions & 3 deletions examples/run/run.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
# include <curl/curl.h>
#endif

#include <signal.h>

#include <climits>
#include <cstdarg>
#include <cstdio>
Expand All @@ -25,6 +27,16 @@
#include "json.hpp"
#include "llama-cpp.h"

static bool sigint = false;

#if defined(__unix__) || (defined(__APPLE__) && defined(__MACH__)) || defined(_WIN32)
static void sigint_handler(int signo) {
if (signo == SIGINT) {
sigint = true;
}
}
#endif

GGML_ATTRIBUTE_FORMAT(1, 2)
static std::string fmt(const char * fmt, ...) {
va_list ap;
Expand Down Expand Up @@ -801,7 +813,20 @@ static int generate(LlamaData & llama_data, const std::string & prompt, std::str

static int read_user_input(std::string & user) {
std::getline(std::cin, user);
return user.empty(); // Should have data in happy path
if (sigint || std::cin.eof()) {
printf("\n");
return 1;
}

if (user == "/bye") {
return 1;
}

if (user.empty()) {
return 2;
}

return 0; // Should have data in happy path
}

// Function to generate a response based on the prompt
Expand Down Expand Up @@ -868,15 +893,34 @@ static bool is_stdout_a_terminal() {
#endif
}

// Function to tokenize the prompt
// Function to handle user input
static int get_user_input(std::string & user_input, const std::string & user) {
while (true) {
const int ret = handle_user_input(user_input, user);
if (ret == 1) {
return 1;
}

if (ret == 2) {
continue;
}

break;
}

return 0;
}

// Main chat loop function
static int chat_loop(LlamaData & llama_data, const std::string & user) {
int prev_len = 0;
llama_data.fmtted.resize(llama_n_ctx(llama_data.context.get()));
static const bool stdout_a_terminal = is_stdout_a_terminal();
while (true) {
// Get user input
std::string user_input;
while (handle_user_input(user_input, user)) {
if (get_user_input(user_input, user) == 1) {
return 0;
}

add_message("user", user.empty() ? user_input : user, llama_data);
Expand Down Expand Up @@ -917,7 +961,23 @@ static std::string read_pipe_data() {
return result.str();
}

static void ctrl_c_handling() {
#if defined(__unix__) || (defined(__APPLE__) && defined(__MACH__))
struct sigaction sigint_action;
sigint_action.sa_handler = sigint_handler;
sigemptyset(&sigint_action.sa_mask);
sigint_action.sa_flags = 0;
sigaction(SIGINT, &sigint_action, NULL);
#elif defined(_WIN32)
auto console_ctrl_handler = +[](DWORD ctrl_type) -> BOOL {
return (ctrl_type == CTRL_C_EVENT) ? (sigint_handler(SIGINT), true) : false;
};
SetConsoleCtrlHandler(reinterpret_cast<PHANDLER_ROUTINE>(console_ctrl_handler), true);
#endif
}

int main(int argc, const char ** argv) {
ctrl_c_handling();
Opt opt;
const int ret = opt.init(argc, argv);
if (ret == 2) {
Expand Down

0 comments on commit a85db13

Please sign in to comment.