diff --git a/examples/server/b64.cpp b/examples/server/b64.cpp new file mode 100644 index 00000000..e0aacf04 --- /dev/null +++ b/examples/server/b64.cpp @@ -0,0 +1,42 @@ + +//FROM +//https://stackoverflow.com/a/34571089/5155484 + +static const std::string b = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";//= +static std::string base64_encode(const std::string &in) { + std::string out; + + int val=0, valb=-6; + for (uint8_t c : in) { + val = (val<<8) + c; + valb += 8; + while (valb>=0) { + out.push_back(b[(val>>valb)&0x3F]); + valb-=6; + } + } + if (valb>-6) out.push_back(b[((val<<8)>>(valb+8))&0x3F]); + while (out.size()%4) out.push_back('='); + return out; +} + + +static std::string base64_decode(const std::string &in) { + + std::string out; + + std::vector T(256,-1); + for (int i=0; i<64; i++) T[b[i]] = i; + + int val=0, valb=-8; + for (uint8_t c : in) { + if (T[c] == -1) break; + val = (val<<6) + T[c]; + valb += 6; + if (valb>=0) { + out.push_back(char((val>>valb)&0xFF)); + valb-=8; + } + } + return out; +} \ No newline at end of file diff --git a/examples/server/main.cpp b/examples/server/main.cpp index 73ea0b21..419c15ae 100644 --- a/examples/server/main.cpp +++ b/examples/server/main.cpp @@ -7,7 +7,9 @@ #include // #include "preprocessing.hpp" +#include "b64.cpp" #include "flux.hpp" +#include "json.hpp" #include "stable-diffusion.h" #define STB_IMAGE_IMPLEMENTATION @@ -49,7 +51,6 @@ const char* schedule_str[] = { "ays", }; - enum SDMode { TXT2IMG, IMG2IMG, @@ -86,7 +87,6 @@ struct SDParams { int height = 512; int batch_count = 1; - sample_method_t sample_method = EULER_A; schedule_t schedule = DEFAULT; int sample_steps = 20; @@ -100,9 +100,9 @@ struct SDParams { bool vae_on_cpu = false; bool color = false; - //server things - int port = 8080; - std::string host = "127.0.0.1"; + // server things + int port = 8080; + std::string host = "127.0.0.1"; }; void print_params(SDParams params) { @@ -227,7 +227,7 @@ void parse_args(int argc, const char** argv, SDParams& params) { break; } params.vae_path = argv[i]; - // TODO Tiny AE + // TODO Tiny AE } else if (arg == "--type") { if (++i >= argc) { invalid_arg = true; @@ -565,17 +565,104 @@ void sd_log_cb(enum sd_log_level_t level, const char* log, void* data) { fflush(out_stream); } -static void log_server_request(const httplib::Request & req, const httplib::Response & res) { +static void log_server_request(const httplib::Request& req, const httplib::Response& res) { printf("request: %s %s (%s)\n", req.method.c_str(), req.path.c_str(), req.body.c_str()); } +void parseJsonPrompt(std::string json_str, SDParams* params) { + using namespace nlohmann; + json payload = json::parse(json_str); + // if no exception, the request is a json object + // now we try to get the new param values from the payload object + // const char *prompt, const char *negative_prompt, int clip_skip, float cfg_scale, float guidance, int width, int height, sample_method_t sample_method, int sample_steps, int64_t seed, int batch_count, const sd_image_t *control_cond, float control_strength, float style_strength, bool normalize_input, const char *input_id_images_path + try { + std::string prompt = payload["prompt"]; + params->prompt = prompt; + } catch (...) { + } + try { + std::string negative_prompt = payload["negative_prompt"]; + params->negative_prompt = negative_prompt; + } catch (...) { + } + try { + int clip_skip = payload["clip_skip"]; + params->clip_skip = clip_skip; + } catch (...) { + } + try { + float cfg_scale = payload["cfg_scale"]; + params->cfg_scale = cfg_scale; + } catch (...) { + } + try { + float guidance = payload["guidance"]; + params->guidance = guidance; + } catch (...) { + } + try { + int width = payload["width"]; + params->width = width; + } catch (...) { + } + try { + int height = payload["height"]; + params->height = height; + } catch (...) { + } + try { + std::string sample_method = payload["sample_method"]; + // TODO map to enum value + LOG_WARN("sample_method is not supported yet\n"); + } catch (...) { + } + try { + int sample_steps = payload["sample_steps"]; + params->sample_steps = sample_steps; + } catch (...) { + } + try { + int64_t seed = payload["seed"]; + params->seed = seed; + } catch (...) { + } + try { + int batch_count = payload["batch_count"]; + params->batch_count = batch_count; + } catch (...) { + } + + try { + std::string control_cond = payload["control_cond"]; + // TODO map to enum value + LOG_WARN("control_cond is not supported yet\n"); + } catch (...) { + } + try { + float control_strength = payload["control_strength"]; + } catch (...) { + } + try { + float style_strength = payload["style_strength"]; + } catch (...) { + } + try { + bool normalize_input = payload["normalize_input"]; + params->normalize_input = normalize_input; + } catch (...) { + } + try { + std::string input_id_images_path = payload["input_id_images_path"]; + // TODO replace with b64 image maybe? + } catch (...) { + } +} int main(int argc, const char* argv[]) { SDParams params; parse_args(argc, argv, params); - sd_set_log_callback(sd_log_cb, (void*)¶ms); if (params.verbose) { @@ -583,9 +670,8 @@ int main(int argc, const char* argv[]) { printf("%s", sd_get_system_info()); } + bool vae_decode_only = true; - bool vae_decode_only = true; - sd_ctx_t* sd_ctx = new_sd_ctx(params.model_path.c_str(), params.clip_l_path.c_str(), params.t5xxl_path.c_str(), @@ -614,33 +700,48 @@ int main(int argc, const char* argv[]) { int n_prompts = 0; - const auto txt2imgRequest = [&sd_ctx, ¶ms, &n_prompts](const httplib::Request & req, httplib::Response & res) { - //TODO: proper payloads - std::string prompt = req.body; - if(!prompt.empty()){ - params.prompt = prompt; - }else{ - params.seed+=1; + const auto txt2imgRequest = [&sd_ctx, ¶ms, &n_prompts](const httplib::Request& req, httplib::Response& res) { + LOG_INFO("raw body is: %s\n", req.body.c_str()); + // parse req.body as json using jsoncpp + using json = nlohmann::json; + + try { + std::string json_str = req.body; + parseJsonPrompt(json_str, ¶ms); + } catch (json::parse_error& e) { + // assume the request is just a prompt + LOG_WARN("Failed to parse json: %s\n Assuming it's just a prompt...\n", e.what()); + std::string prompt = req.body; + if (!prompt.empty()) { + params.prompt = prompt; + } else { + params.seed += 1; + } + } catch (...) { + // Handle any other type of exception + LOG_ERROR("An unexpected error occurred\n"); } + LOG_INFO("prompt is: %s\n", params.prompt.c_str()); + { sd_image_t* results; results = txt2img(sd_ctx, - params.prompt.c_str(), - params.negative_prompt.c_str(), - params.clip_skip, - params.cfg_scale, - params.guidance, - params.width, - params.height, - params.sample_method, - params.sample_steps, - params.seed, - params.batch_count, - NULL, - 1, - params.style_ratio, - params.normalize_input, - ""); + params.prompt.c_str(), + params.negative_prompt.c_str(), + params.clip_skip, + params.cfg_scale, + params.guidance, + params.width, + params.height, + params.sample_method, + params.sample_steps, + params.seed, + params.batch_count, + NULL, + 1, + params.style_ratio, + params.normalize_input, + ""); if (results == NULL) { printf("generate failed\n"); @@ -650,52 +751,67 @@ int main(int argc, const char* argv[]) { size_t last = params.output_path.find_last_of("."); std::string dummy_name = last != std::string::npos ? params.output_path.substr(0, last) : params.output_path; + json images_json = json::array(); for (int i = 0; i < params.batch_count; i++) { if (results[i].data == NULL) { continue; } - std::string final_image_path = i > 0 ? dummy_name + "_" + std::to_string(i + 1 + n_prompts*params.batch_count) + ".png" : dummy_name + ".png"; + // TODO allow disable save to disk + std::string final_image_path = i > 0 ? dummy_name + "_" + std::to_string(i + 1 + n_prompts * params.batch_count) + ".png" : dummy_name + ".png"; stbi_write_png(final_image_path.c_str(), results[i].width, results[i].height, results[i].channel, - results[i].data, 0, get_image_params(params, params.seed + i).c_str()); + results[i].data, 0, get_image_params(params, params.seed + i).c_str()); printf("save result image to '%s'\n", final_image_path.c_str()); - // Todo: return base64 encoded image via websocket? + // Todo: return base64 encoded image via httplib::Response& res + + int len; + unsigned char* png = stbi_write_png_to_mem((const unsigned char*)results[i].data, 0, results[i].width, results[i].height, results[i].channel, &len, NULL); + + std::string data_str(png, png + len); + std::string encoded_img = base64_encode(data_str); + + images_json.push_back({{"width", results[i].width}, + {"height", results[i].height}, + {"channel", results[i].channel}, + {"data", encoded_img}, + {"encoding", "png"}}); + free(results[i].data); results[i].data = NULL; } free(results); n_prompts++; + res.set_content(images_json.dump(), "application/json"); } return 0; }; - std::unique_ptr svr; svr.reset(new httplib::Server()); svr->set_default_headers({{"Server", "sd.cpp"}}); // CORS preflight - svr->Options(R"(.*)", [](const httplib::Request &, httplib::Response & res) { + svr->Options(R"(.*)", [](const httplib::Request&, httplib::Response& res) { // Access-Control-Allow-Origin is already set by middleware res.set_header("Access-Control-Allow-Credentials", "true"); - res.set_header("Access-Control-Allow-Methods", "POST"); - res.set_header("Access-Control-Allow-Headers", "*"); - return res.set_content("", "text/html"); // blank response, no data + res.set_header("Access-Control-Allow-Methods", "POST"); + res.set_header("Access-Control-Allow-Headers", "*"); + return res.set_content("", "text/html"); // blank response, no data }); svr->set_logger(log_server_request); svr->Post("/txt2img", txt2imgRequest); - // bind HTTP listen port, run the HTTP server in a thread if (!svr->bind_to_port(params.host, params.port)) { - //TODO: Error message + // TODO: Error message return 1; - } + } std::thread t([&]() { svr->listen_after_bind(); }); svr->wait_until_ready(); - printf("Server listening at %s:%d\n",params.host.c_str(),params.port); + printf("Server listening at %s:%d\n", params.host.c_str(), params.port); - while(1); + while (1) + ; free_sd_ctx(sd_ctx);