Skip to content

Commit

Permalink
Server: accept json inputs + return bas64 image
Browse files Browse the repository at this point in the history
  • Loading branch information
stduhpf committed Oct 4, 2024
1 parent 8529431 commit 533da39
Show file tree
Hide file tree
Showing 2 changed files with 204 additions and 46 deletions.
42 changes: 42 additions & 0 deletions examples/server/b64.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@

//FROM
//https://stackoverflow.com/a/34571089/5155484

static const std::string b = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";//=
static std::string base64_encode(const std::string &in) {
std::string out;

int val=0, valb=-6;
for (uint8_t c : in) {
val = (val<<8) + c;
valb += 8;
while (valb>=0) {
out.push_back(b[(val>>valb)&0x3F]);
valb-=6;
}
}
if (valb>-6) out.push_back(b[((val<<8)>>(valb+8))&0x3F]);
while (out.size()%4) out.push_back('=');
return out;
}


static std::string base64_decode(const std::string &in) {

std::string out;

std::vector<int> T(256,-1);
for (int i=0; i<64; i++) T[b[i]] = i;

int val=0, valb=-8;
for (uint8_t c : in) {
if (T[c] == -1) break;
val = (val<<6) + T[c];
valb += 6;
if (valb>=0) {
out.push_back(char((val>>valb)&0xFF));
valb-=8;
}
}
return out;
}
208 changes: 162 additions & 46 deletions examples/server/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,9 @@
#include <vector>

// #include "preprocessing.hpp"
#include "b64.cpp"
#include "flux.hpp"
#include "json.hpp"
#include "stable-diffusion.h"

#define STB_IMAGE_IMPLEMENTATION
Expand Down Expand Up @@ -49,7 +51,6 @@ const char* schedule_str[] = {
"ays",
};


enum SDMode {
TXT2IMG,
IMG2IMG,
Expand Down Expand Up @@ -86,7 +87,6 @@ struct SDParams {
int height = 512;
int batch_count = 1;


sample_method_t sample_method = EULER_A;
schedule_t schedule = DEFAULT;
int sample_steps = 20;
Expand All @@ -100,9 +100,9 @@ struct SDParams {
bool vae_on_cpu = false;
bool color = false;

//server things
int port = 8080;
std::string host = "127.0.0.1";
// server things
int port = 8080;
std::string host = "127.0.0.1";
};

void print_params(SDParams params) {
Expand Down Expand Up @@ -227,7 +227,7 @@ void parse_args(int argc, const char** argv, SDParams& params) {
break;
}
params.vae_path = argv[i];
// TODO Tiny AE
// TODO Tiny AE
} else if (arg == "--type") {
if (++i >= argc) {
invalid_arg = true;
Expand Down Expand Up @@ -565,27 +565,113 @@ void sd_log_cb(enum sd_log_level_t level, const char* log, void* data) {
fflush(out_stream);
}

static void log_server_request(const httplib::Request & req, const httplib::Response & res) {
static void log_server_request(const httplib::Request& req, const httplib::Response& res) {
printf("request: %s %s (%s)\n", req.method.c_str(), req.path.c_str(), req.body.c_str());
}

void parseJsonPrompt(std::string json_str, SDParams* params) {
using namespace nlohmann;
json payload = json::parse(json_str);
// if no exception, the request is a json object
// now we try to get the new param values from the payload object
// const char *prompt, const char *negative_prompt, int clip_skip, float cfg_scale, float guidance, int width, int height, sample_method_t sample_method, int sample_steps, int64_t seed, int batch_count, const sd_image_t *control_cond, float control_strength, float style_strength, bool normalize_input, const char *input_id_images_path
try {
std::string prompt = payload["prompt"];
params->prompt = prompt;
} catch (...) {
}
try {
std::string negative_prompt = payload["negative_prompt"];
params->negative_prompt = negative_prompt;
} catch (...) {
}
try {
int clip_skip = payload["clip_skip"];
params->clip_skip = clip_skip;
} catch (...) {
}
try {
float cfg_scale = payload["cfg_scale"];
params->cfg_scale = cfg_scale;
} catch (...) {
}
try {
float guidance = payload["guidance"];
params->guidance = guidance;
} catch (...) {
}
try {
int width = payload["width"];
params->width = width;
} catch (...) {
}
try {
int height = payload["height"];
params->height = height;
} catch (...) {
}
try {
std::string sample_method = payload["sample_method"];
// TODO map to enum value
LOG_WARN("sample_method is not supported yet\n");
} catch (...) {
}
try {
int sample_steps = payload["sample_steps"];
params->sample_steps = sample_steps;
} catch (...) {
}
try {
int64_t seed = payload["seed"];
params->seed = seed;
} catch (...) {
}
try {
int batch_count = payload["batch_count"];
params->batch_count = batch_count;
} catch (...) {
}

try {
std::string control_cond = payload["control_cond"];
// TODO map to enum value
LOG_WARN("control_cond is not supported yet\n");
} catch (...) {
}
try {
float control_strength = payload["control_strength"];
} catch (...) {
}
try {
float style_strength = payload["style_strength"];
} catch (...) {
}
try {
bool normalize_input = payload["normalize_input"];
params->normalize_input = normalize_input;
} catch (...) {
}
try {
std::string input_id_images_path = payload["input_id_images_path"];
// TODO replace with b64 image maybe?
} catch (...) {
}
}

int main(int argc, const char* argv[]) {
SDParams params;

parse_args(argc, argv, params);


sd_set_log_callback(sd_log_cb, (void*)&params);

if (params.verbose) {
print_params(params);
printf("%s", sd_get_system_info());
}

bool vae_decode_only = true;

bool vae_decode_only = true;

sd_ctx_t* sd_ctx = new_sd_ctx(params.model_path.c_str(),
params.clip_l_path.c_str(),
params.t5xxl_path.c_str(),
Expand Down Expand Up @@ -614,33 +700,48 @@ int main(int argc, const char* argv[]) {

int n_prompts = 0;

const auto txt2imgRequest = [&sd_ctx, &params, &n_prompts](const httplib::Request & req, httplib::Response & res) {
//TODO: proper payloads
std::string prompt = req.body;
if(!prompt.empty()){
params.prompt = prompt;
}else{
params.seed+=1;
const auto txt2imgRequest = [&sd_ctx, &params, &n_prompts](const httplib::Request& req, httplib::Response& res) {
LOG_INFO("raw body is: %s\n", req.body.c_str());
// parse req.body as json using jsoncpp
using json = nlohmann::json;

try {
std::string json_str = req.body;
parseJsonPrompt(json_str, &params);
} catch (json::parse_error& e) {
// assume the request is just a prompt
LOG_WARN("Failed to parse json: %s\n Assuming it's just a prompt...\n", e.what());
std::string prompt = req.body;
if (!prompt.empty()) {
params.prompt = prompt;
} else {
params.seed += 1;
}
} catch (...) {
// Handle any other type of exception
LOG_ERROR("An unexpected error occurred\n");
}
LOG_INFO("prompt is: %s\n", params.prompt.c_str());

{
sd_image_t* results;
results = txt2img(sd_ctx,
params.prompt.c_str(),
params.negative_prompt.c_str(),
params.clip_skip,
params.cfg_scale,
params.guidance,
params.width,
params.height,
params.sample_method,
params.sample_steps,
params.seed,
params.batch_count,
NULL,
1,
params.style_ratio,
params.normalize_input,
"");
params.prompt.c_str(),
params.negative_prompt.c_str(),
params.clip_skip,
params.cfg_scale,
params.guidance,
params.width,
params.height,
params.sample_method,
params.sample_steps,
params.seed,
params.batch_count,
NULL,
1,
params.style_ratio,
params.normalize_input,
"");

if (results == NULL) {
printf("generate failed\n");
Expand All @@ -650,52 +751,67 @@ int main(int argc, const char* argv[]) {

size_t last = params.output_path.find_last_of(".");
std::string dummy_name = last != std::string::npos ? params.output_path.substr(0, last) : params.output_path;
json images_json = json::array();
for (int i = 0; i < params.batch_count; i++) {
if (results[i].data == NULL) {
continue;
}
std::string final_image_path = i > 0 ? dummy_name + "_" + std::to_string(i + 1 + n_prompts*params.batch_count) + ".png" : dummy_name + ".png";
// TODO allow disable save to disk
std::string final_image_path = i > 0 ? dummy_name + "_" + std::to_string(i + 1 + n_prompts * params.batch_count) + ".png" : dummy_name + ".png";
stbi_write_png(final_image_path.c_str(), results[i].width, results[i].height, results[i].channel,
results[i].data, 0, get_image_params(params, params.seed + i).c_str());
results[i].data, 0, get_image_params(params, params.seed + i).c_str());
printf("save result image to '%s'\n", final_image_path.c_str());
// Todo: return base64 encoded image via websocket?
// Todo: return base64 encoded image via httplib::Response& res

int len;
unsigned char* png = stbi_write_png_to_mem((const unsigned char*)results[i].data, 0, results[i].width, results[i].height, results[i].channel, &len, NULL);

std::string data_str(png, png + len);
std::string encoded_img = base64_encode(data_str);

images_json.push_back({{"width", results[i].width},
{"height", results[i].height},
{"channel", results[i].channel},
{"data", encoded_img},
{"encoding", "png"}});

free(results[i].data);
results[i].data = NULL;
}
free(results);
n_prompts++;
res.set_content(images_json.dump(), "application/json");
}
return 0;
};


std::unique_ptr<httplib::Server> svr;
svr.reset(new httplib::Server());
svr->set_default_headers({{"Server", "sd.cpp"}});
// CORS preflight
svr->Options(R"(.*)", [](const httplib::Request &, httplib::Response & res) {
svr->Options(R"(.*)", [](const httplib::Request&, httplib::Response& res) {
// Access-Control-Allow-Origin is already set by middleware
res.set_header("Access-Control-Allow-Credentials", "true");
res.set_header("Access-Control-Allow-Methods", "POST");
res.set_header("Access-Control-Allow-Headers", "*");
return res.set_content("", "text/html"); // blank response, no data
res.set_header("Access-Control-Allow-Methods", "POST");
res.set_header("Access-Control-Allow-Headers", "*");
return res.set_content("", "text/html"); // blank response, no data
});
svr->set_logger(log_server_request);

svr->Post("/txt2img", txt2imgRequest);


// bind HTTP listen port, run the HTTP server in a thread
if (!svr->bind_to_port(params.host, params.port)) {
//TODO: Error message
// TODO: Error message
return 1;
}
}
std::thread t([&]() { svr->listen_after_bind(); });
svr->wait_until_ready();

printf("Server listening at %s:%d\n",params.host.c_str(),params.port);
printf("Server listening at %s:%d\n", params.host.c_str(), params.port);

while(1);
while (1)
;

free_sd_ctx(sd_ctx);

Expand Down

0 comments on commit 533da39

Please sign in to comment.