diff --git a/plugins/cuda/cuda_plugin.c b/plugins/cuda/cuda_plugin.c index 718db30251..52c494dd09 100644 --- a/plugins/cuda/cuda_plugin.c +++ b/plugins/cuda/cuda_plugin.c @@ -26,6 +26,13 @@ #define ACTION_RESTORE "restore" #define ACTION_UNLOCK "unlock" +typedef enum { + CUDA_TASK_RUNNING = 0, + CUDA_TASK_LOCKED, + CUDA_TASK_CHECKPOINTED, + CUDA_TASK_UNKNOWN = -1 +} cuda_task_state_t; + #define CUDA_CKPT_BUF_SIZE (128) #ifdef LOG_PREFIX @@ -43,6 +50,7 @@ bool plugin_added_to_inventory = false; struct pid_info { int pid; char checkpointed; + cuda_task_state_t original_state; struct list_head list; }; @@ -62,7 +70,7 @@ static void dealloc_pid_buffer(struct list_head *pid_buf) } } -static int add_pid_to_buf(struct list_head *pid_buf, int pid) +static int add_pid_to_buf(struct list_head *pid_buf, int pid, cuda_task_state_t state) { struct pid_info *new = xmalloc(sizeof(*new)); @@ -72,25 +80,12 @@ static int add_pid_to_buf(struct list_head *pid_buf, int pid) new->pid = pid; new->checkpointed = 0; + new->original_state = state; list_add_tail(&new->list, pid_buf); return 0; } -static int update_checkpointed_pid(struct list_head *pid_buf, int pid) -{ - struct pid_info *info; - - list_for_each_entry(info, pid_buf, list) { - if (info->pid == pid) { - info->checkpointed = 1; - return 0; - } - } - - return -1; -} - static int launch_cuda_checkpoint(const char **args, char *buf, int buf_size) { #define READ 0 @@ -231,6 +226,38 @@ static int get_cuda_restore_tid(int root_pid) return atoi(pid_out); } +static cuda_task_state_t get_task_state_enum(const char *state_str) +{ + if (strncmp(state_str, "running", 7) == 0) + return CUDA_TASK_RUNNING; + + if (strncmp(state_str, "locked", 6) == 0) + return CUDA_TASK_LOCKED; + + if (strncmp(state_str, "checkpointed", 12) == 0) + return CUDA_TASK_CHECKPOINTED; + + pr_err("Unknown CUDA state: %s\n", state_str); + return CUDA_TASK_UNKNOWN; +} + +/* Retrieve the current CUDA state of a process by running cuda-checkpoint --get-state; returns CUDA_TASK_UNKNOWN on failure */ +static cuda_task_state_t get_cuda_state(pid_t pid) 
+{ + char pid_buf[16]; + char state_str[CUDA_CKPT_BUF_SIZE]; + const char *args[] = { CUDA_CHECKPOINT, "--get-state", "--pid", pid_buf, NULL }; + + snprintf(pid_buf, sizeof(pid_buf), "%d", pid); + + if (launch_cuda_checkpoint(args, state_str, sizeof(state_str))) { + pr_err("Failed to launch cuda-checkpoint to retrieve state: %s\n", state_str); + return CUDA_TASK_UNKNOWN; + } + + return get_task_state_enum(state_str); +} + static int cuda_process_checkpoint_action(int pid, const char *action, unsigned int timeout, char *msg_buf, int buf_size) { @@ -319,6 +346,8 @@ int cuda_plugin_checkpoint_devices(int pid) int int_ret; int status; k_rtsigset_t save_sigset; + struct pid_info *task_info; + bool pid_found = false; if (plugin_disabled) { return -ENOTSUP; @@ -336,6 +365,26 @@ int cuda_plugin_checkpoint_devices(int pid) return 0; } + /* Find the entry tracked for this process pid (the key used by cuda_plugin_pause_devices()) and skip it if it was already checkpointed */ + list_for_each_entry(task_info, &cuda_pids, list) { + if (task_info->pid == pid) { + if (task_info->original_state == CUDA_TASK_CHECKPOINTED) { + pr_info("pid %d already in a checkpointed state\n", pid); + return 0; + } + pid_found = true; + break; + } + } + + if (pid_found == false) { + /* We return an error here. The task should be restored + * to its original state at cuda_plugin_fini(). 
+ */ + pr_err("Failed to track pid %d\n", pid); + return -1; + } + pr_info("Checkpointing CUDA devices on pid %d restore_tid %d\n", pid, restore_tid); /* We need to resume the checkpoint thread to prepare the mappings for * checkpointing @@ -348,22 +397,8 @@ int cuda_plugin_checkpoint_devices(int pid) pr_err("CHECKPOINT_DEVICES failed with %s\n", msg_buf); goto interrupt; } - status = update_checkpointed_pid(&cuda_pids, pid); - if (status) { - pr_err("Failed to track checkpointed pid %d\n", pid); - status = cuda_process_checkpoint_action(pid, ACTION_RESTORE, 0, msg_buf, sizeof(msg_buf)); - if (status) { - pr_err("Failed to restore process after error %s on pid %d\n", msg_buf, pid); - } - } - if (!status && !plugin_added_to_inventory) { - status = add_inventory_plugin(CR_PLUGIN_DESC.name); - if (status) - pr_err("Failed to add CUDA plugin to inventory image\n"); - else - plugin_added_to_inventory = true; - } + task_info->checkpointed = 1; interrupt: int_ret = interrupt_restore_thread(restore_tid, &save_sigset); @@ -376,6 +411,7 @@ int cuda_plugin_pause_devices(int pid) { int restore_tid; char msg_buf[CUDA_CKPT_BUF_SIZE]; + cuda_task_state_t task_state; if (plugin_disabled) { return -ENOTSUP; @@ -388,6 +424,34 @@ int cuda_plugin_pause_devices(int pid) return 0; } + task_state = get_cuda_state(restore_tid); + if (task_state == CUDA_TASK_UNKNOWN) { + pr_err("Failed to get CUDA state for PID %d\n", restore_tid); + return -1; + } + + if (!plugin_added_to_inventory) { + if (add_inventory_plugin(CR_PLUGIN_DESC.name)) { + pr_err("Failed to add CUDA plugin to inventory image\n"); + return -1; + } + plugin_added_to_inventory = true; + } + + if (task_state == CUDA_TASK_LOCKED) { + pr_info("pid %d already in a locked state\n", pid); + /* Leave this PID in a "locked" state at resume_device() */ + add_pid_to_buf(&cuda_pids, pid, CUDA_TASK_LOCKED); + return 0; + } + + if (task_state == CUDA_TASK_CHECKPOINTED) { + /* We need to skip this PID in 
cuda_plugin_checkpoint_devices(), + * and leave it in a "checkpointed" state at resume_device(). */ + add_pid_to_buf(&cuda_pids, pid, CUDA_TASK_CHECKPOINTED); + return 0; + } + pr_info("pausing devices on pid %d\n", pid); int status = cuda_process_checkpoint_action(pid, ACTION_LOCK, opts.timeout * 1000, msg_buf, sizeof(msg_buf)); if (status) { @@ -397,7 +461,7 @@ int cuda_plugin_pause_devices(int pid) return -1; } - if (add_pid_to_buf(&cuda_pids, pid)) { + if (add_pid_to_buf(&cuda_pids, pid, CUDA_TASK_RUNNING)) { pr_err("unable to track paused pid %d\n", pid); goto unlock; } @@ -412,7 +476,7 @@ int cuda_plugin_pause_devices(int pid) } CR_PLUGIN_REGISTER_HOOK(CR_PLUGIN_HOOK__PAUSE_DEVICES, cuda_plugin_pause_devices) -int resume_device(int pid, int checkpointed) +int resume_device(int pid, int checkpointed, cuda_task_state_t original_state) { char msg_buf[CUDA_CKPT_BUF_SIZE]; int status; @@ -420,6 +484,11 @@ int resume_device(int pid, int checkpointed) int int_ret; k_rtsigset_t save_sigset; + if (original_state == CUDA_TASK_UNKNOWN) { + pr_info("skip resume for PID %d (unknown state)\n", pid); + return 0; + } + int restore_tid = get_cuda_restore_tid(pid); if (restore_tid == -1) { pr_info("No need to resume devices on pid %d\n", pid); @@ -439,7 +508,8 @@ int resume_device(int pid, int checkpointed) return -1; } - if (checkpointed) { + if (checkpointed && (original_state == CUDA_TASK_RUNNING || original_state == CUDA_TASK_LOCKED)) { + /* If the process was locked or running before we checkpoint it, we need to restore it */ status = cuda_process_checkpoint_action(pid, ACTION_RESTORE, 0, msg_buf, sizeof(msg_buf)); if (status) { pr_err("RESUME_DEVICES RESTORE failed with %s\n", msg_buf); @@ -448,10 +518,13 @@ int resume_device(int pid, int checkpointed) } } - status = cuda_process_checkpoint_action(pid, ACTION_UNLOCK, 0, msg_buf, sizeof(msg_buf)); - if (status) { - pr_err("RESUME_DEVICES UNLOCK failed with %s\n", msg_buf); - ret = -1; + if (original_state == 
CUDA_TASK_RUNNING) { + /* If the process was running before we paused it, we need to unlock it */ + status = cuda_process_checkpoint_action(pid, ACTION_UNLOCK, 0, msg_buf, sizeof(msg_buf)); + if (status) { + pr_err("RESUME_DEVICES UNLOCK failed with %s\n", msg_buf); + ret = -1; + } } interrupt: @@ -466,7 +539,12 @@ int cuda_plugin_resume_devices_late(int pid) return -ENOTSUP; } - return resume_device(pid, 1); + /* RESUME_DEVICES_LATE is used during `criu restore`. + * Here, we assume that users expect the target process + * to be in a running state after restore, even if it was + * in a "locked" or "checkpointed" state during `criu dump`. + */ + return resume_device(pid, 1, CUDA_TASK_RUNNING); } CR_PLUGIN_REGISTER_HOOK(CR_PLUGIN_HOOK__RESUME_DEVICES_LATE, cuda_plugin_resume_devices_late) @@ -542,7 +620,7 @@ void cuda_plugin_fini(int stage, int ret) if (stage == CR_PLUGIN_STAGE__DUMP && (opts.final_state == TASK_ALIVE || ret != 0)) { struct pid_info *info; list_for_each_entry(info, &cuda_pids, list) { - resume_device(info->pid, info->checkpointed); + resume_device(info->pid, info->checkpointed, info->original_state); } } if (stage == CR_PLUGIN_STAGE__DUMP) {