Skip to content

Commit

Permalink
api: Create generative AI APIs using AI subnet (#2246)
Browse files Browse the repository at this point in the history
* api/generate: Create the simplest form of a /text-to-image

* (noop) api/controllers: Sort API router alphabetically

* api/generate: Create the other generate API proxies

Pretend they are JSON for now, adding support for multipart forms next.

* api: Add ajv-formats to support binary format

That required upgrading ajv itself, which was a bit
of a pain but worked and also found a few issues in
our schema.

* api/generate: Add multipart request support

* [DEV-only] Default to the staging gateway

* api/test: Create tests for new /generate APIs

* api/test: Fix multistream tests

* prettier 😡

* api/test: Add more cases for multipart validation

* Revert "[DEV-only] Default to the staging gateway"

This reverts commit 802d297.
  • Loading branch information
victorges authored Jul 17, 2024
1 parent 4e24f8e commit a9979b7
Show file tree
Hide file tree
Showing 15 changed files with 643 additions and 173 deletions.
7 changes: 5 additions & 2 deletions packages/api/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -55,8 +55,9 @@
"@tus/server": "^1.0.0",
"@types/amqp-connection-manager": "2.0.10",
"@types/cors": "^2.8.12",
"ajv": "^6.10.0",
"ajv": "^8.16.0",
"ajv-cli": "^3.1.0",
"ajv-formats": "^3.0.1",
"amqp-connection-manager": "^4.1.6",
"amqplib": "^0.8.0",
"analytics-node": "^3.4.0-beta.1",
Expand All @@ -78,6 +79,7 @@
"express-prom-bundle": "^6.4.1",
"fakefilter": "^0.1.880",
"fast-stable-stringify": "^1.0.0",
"form-data": "^4.0.0",
"fs-extra": "^7.0.1",
"google-auth-library": "^5.2.2",
"googleapis": "^43.0.0",
Expand All @@ -96,6 +98,7 @@
"morgan": "^1.9.1",
"mqtt": "^4.2.6",
"ms": "^2.1.2",
"multer": "^1.4.5-lts.1",
"mustache": "^3.0.3",
"node-fetch": "^2.6.1",
"node-jose": "^1.1.4",
Expand Down Expand Up @@ -126,11 +129,11 @@
"@types/jsonwebtoken": "^8.5.6",
"@types/lodash": "^4.14.191",
"@types/ms": "^0.7.31",
"@types/multer": "^1.4.11",
"@types/mustache": "^4.1.1",
"@types/node-fetch": "^2.5.10",
"@types/pg": "^7.14.4",
"@types/uuid": "^9.0.0",
"ajv-pack": "^0.3.1",
"cloudflare": "^2.7.0",
"esbuild": "^0.18.17",
"esm": "^3.2.22",
Expand Down
20 changes: 14 additions & 6 deletions packages/api/src/compile-schemas.js
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
import Ajv from "ajv";
import pack from "ajv-pack";
import { safeLoad as parseYaml, safeDump as serializeYaml } from "js-yaml";
import ajvFormats from "ajv-formats";
import standaloneCode from "ajv/dist/standalone";
import fs from "fs-extra";
import { safeLoad as parseYaml, safeDump as serializeYaml } from "js-yaml";
import $RefParser from "json-schema-ref-parser";
import { compile as generateTypes } from "json-schema-to-typescript";
import _ from "lodash";
import path from "path";
import { compile as generateTypes } from "json-schema-to-typescript";
import $RefParser from "json-schema-ref-parser";

// This takes schema.yaml as its input and produces a few outputs.
// 1. types.d.ts, TypeScript definitions of the JSON-schema objects
Expand Down Expand Up @@ -52,7 +53,14 @@ const data = _.merge({}, apiData, dbData);
write(path.resolve(schemaDir, "schema.json"), str);
write(path.resolve(schemaDistDir, "schema.json"), str);

const ajv = new Ajv({ sourceCode: true });
let ajv = new Ajv({
keywords: [
...["example", "minValue"], // OpenAPI keywords not supported by ajv
...["table", "index", "indexType", "unique"], // our custom keywords
],
code: { source: true },
});
ajv = ajvFormats(ajv, ["binary", "uri"]);

const index = [];
let types = [];
Expand All @@ -62,7 +70,7 @@ const data = _.merge({}, apiData, dbData);
const type = await generateTypes(schema);
types.push(type);
var validate = ajv.compile(schema);
var moduleCode = pack(ajv, validate);
var moduleCode = standaloneCode(ajv, validate);
const outPath = path.resolve(validatorDir, `${name}.js`);
write(outPath, moduleCode);
index.push(`'${name}': require('./${name}.js'),`);
Expand Down
2 changes: 1 addition & 1 deletion packages/api/src/controllers/asset.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -718,7 +718,7 @@ describe("controllers/asset", () => {
await testStoragePatch(
{ ipfs: { nftMetadata: { a: "b" } } } as any,
null,
[expect.stringContaining("should NOT have additional properties")],
[expect.stringContaining("must NOT have additional properties")],
);
await testStoragePatch(
{ ipfs: {} },
Expand Down
193 changes: 193 additions & 0 deletions packages/api/src/controllers/generate.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,193 @@
import FormData from "form-data";
import { User } from "../schema/types";
import {
AuxTestServer,
TestClient,
clearDatabase,
setupUsers,
startAuxTestServer,
} from "../test-helpers";
import serverPromise, { TestServer } from "../test-server";

let server: TestServer;
let mockAdminUserInput: User;
let mockNonAdminUserInput: User;

// jest.setTimeout(70000)

beforeAll(async () => {
server = await serverPromise;

mockAdminUserInput = {
email: "[email protected]",
password: "x".repeat(64),
};

mockNonAdminUserInput = {
email: "[email protected]",
password: "y".repeat(64),
};
});

afterEach(async () => {
await clearDatabase(server);
});

describe("controllers/generate", () => {
let client: TestClient;
let adminUser: User;
let adminApiKey: string;
let nonAdminUser: User;
let nonAdminToken: string;

let aiGatewayServer: AuxTestServer;
let aiGatewayCalls: Record<string, number>;

beforeAll(async () => {
aiGatewayServer = await startAuxTestServer(30303); // port configured in test-params.ts
const apis = [
"text-to-image",
"image-to-image",
"image-to-video",
"upscale",
];
for (const api of apis) {
aiGatewayServer.app.post(`/${api}`, (req, res) => {
aiGatewayCalls[api] = (aiGatewayCalls[api] || 0) + 1;
return res.status(200).json({
message: "success",
reqContentType: req.headers["content-type"] ?? "unknown",
});
});
}
});

afterAll(async () => {
await aiGatewayServer.close();
});

beforeEach(async () => {
({ client, adminUser, adminApiKey, nonAdminUser, nonAdminToken } =
await setupUsers(server, mockAdminUserInput, mockNonAdminUserInput));

client.apiKey = adminApiKey;
await client.post("/experiment", {
name: "ai-generate",
audienceUserIds: [adminUser.id, nonAdminUser.id],
});
client.apiKey = null;
client.jwtAuth = nonAdminToken;

aiGatewayCalls = {};
});

const buildMultipartBody = (textFields: Record<string, any>) => {
const form = new FormData();
for (const [k, v] of Object.entries(textFields)) {
form.append(k, v);
}
form.append("image", "dummy", {
contentType: "image/png",
});
return form;
};

describe("API proxies", () => {
it("should call the AI Gateway for generate API /text-to-image", async () => {
const res = await client.post("/beta/generate/text-to-image", {
prompt: "a man in a suit and tie",
});
expect(res.status).toBe(200);
expect(await res.json()).toEqual({
message: "success",
reqContentType: "application/json",
});
expect(aiGatewayCalls).toEqual({ "text-to-image": 1 });
});

it("should call the AI Gateway for generate API /image-to-image", async () => {
const res = await client.fetch("/beta/generate/image-to-image", {
method: "POST",
body: buildMultipartBody({
prompt: "replace the suit with a bathing suit",
}),
});
expect(res.status).toBe(200);
expect(await res.json()).toEqual({
message: "success",
reqContentType: expect.stringMatching("^multipart/form-data"),
});
expect(aiGatewayCalls).toEqual({ "image-to-image": 1 });
});

it("should call the AI Gateway for generate API /image-to-video", async () => {
const res = await client.fetch("/beta/generate/image-to-video", {
method: "POST",
body: buildMultipartBody({}),
});
expect(res.status).toBe(200);
expect(await res.json()).toEqual({
message: "success",
reqContentType: expect.stringMatching("^multipart/form-data"),
});
expect(aiGatewayCalls).toEqual({ "image-to-video": 1 });
});

it("should call the AI Gateway for generate API /upscale", async () => {
const res = await client.fetch("/beta/generate/upscale", {
method: "POST",
body: buildMultipartBody({ prompt: "enhance" }),
});
expect(res.status).toBe(200);
expect(await res.json()).toEqual({
message: "success",
reqContentType: expect.stringMatching("^multipart/form-data"),
});
expect(aiGatewayCalls).toEqual({ upscale: 1 });
});
});

describe("validates multipart schema", () => {
const hugeForm = new FormData();
const file11mb = "a".repeat(11 * 1024 * 1024);
hugeForm.append("image", file11mb, {
contentType: "image/png",
});

const testCases = [
[
"should fail with a missing required field",
buildMultipartBody({}),
"must have required property 'prompt'",
],
[
"should fail with bad type for a field",
buildMultipartBody({ prompt: "impromptu", seed: "NaN" }),
"must be integer",
],
[
"should fail with an unknown field",
buildMultipartBody({
prompt: "impromptu",
extra_good_image: "yes pls",
}),
"must NOT have additional properties",
],
["should limit maximum payload size", hugeForm, "Field value too long"],
] as const;

for (const [title, input, error] of testCases) {
it(title, async () => {
const res = await client.fetch("/beta/generate/image-to-image", {
method: "POST",
body: input,
});
expect(res.status).toBe(422);
expect(await res.json()).toEqual({
errors: [expect.stringContaining(error)],
});
expect(aiGatewayCalls).toEqual({});
});
}
});
});
110 changes: 110 additions & 0 deletions packages/api/src/controllers/generate.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
import { RequestHandler, Router } from "express";
import FormData from "form-data";
import multer from "multer";
import { BodyInit } from "node-fetch";
import logger from "../logger";
import { authorizer, validateFormData, validatePost } from "../middleware";
import { fetchWithTimeout } from "../util";
import { experimentSubjectsOnly } from "./experiment";
import { pathJoin2 } from "./helpers";

const AI_GATEWAY_TIMEOUT = 10 * 60 * 1000; // 10 minutes

const multipart = multer({
storage: multer.memoryStorage(),
limits: { fileSize: 10485760 }, // 10MiB
});

const app = Router();

app.use(experimentSubjectsOnly("ai-generate"));

function registerGenerateHandler(
name: string,
defaultModel: string,
isJSONReq = false,
): RequestHandler {
const path = `/${name}`;
const middlewares = isJSONReq
? [validatePost(`${name}-payload`)]
: [multipart.any(), validateFormData(`${name}-payload`)];
return app.post(
path,
authorizer({}),
...middlewares,
async function proxyGenerate(req, res) {
const { aiGatewayUrl } = req.config;
if (!aiGatewayUrl) {
res.status(500).json({ errors: ["AI Gateway URL is not set"] });
return;
}

const apiUrl = pathJoin2(aiGatewayUrl, path);

let payload: BodyInit;
if (isJSONReq) {
payload = JSON.stringify({
model_id: defaultModel,
...req.body,
});
} else {
const form = new FormData();
if (!("model_id" in req.body)) {
form.append("model_id", defaultModel);
}
for (const [key, value] of Object.entries(req.body)) {
form.append(key, value);
}

if (!Array.isArray(req.files)) {
return res.status(400).json({
errors: ["Expected an array of files"],
});
}
for (const file of req.files) {
form.append(file.fieldname, file.buffer, {
filename: file.originalname,
contentType: file.mimetype,
knownLength: file.size,
});
}
payload = form;
}

const response = await fetchWithTimeout(apiUrl, {
method: "POST",
body: payload,
timeout: AI_GATEWAY_TIMEOUT,
headers: isJSONReq ? { "content-type": "application/json" } : {},
});

const body = await response.json();
if (!response.ok) {
logger.error(
`Error from generate API ${path} status=${
response.status
} body=${JSON.stringify(body)}`,
);
}
if (response.status >= 500) {
return res.status(500).json({ errors: [`Failed to generate ${name}`] });
}

res.status(response.status).json(body);
},
);
}

registerGenerateHandler(
"text-to-image",
"SG161222/RealVisXL_V4.0_Lightning",
true,
);
registerGenerateHandler("image-to-image", "timbrooks/instruct-pix2pix");
registerGenerateHandler(
"image-to-video",
"stabilityai/stable-video-diffusion-img2vid-xt-1-1",
);
registerGenerateHandler("upscale", "stabilityai/stable-diffusion-x4-upscaler");

export default app;
Loading

0 comments on commit a9979b7

Please sign in to comment.