-
Notifications
You must be signed in to change notification settings - Fork 32
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
api: Create generative AI APIs using AI subnet (#2246)
* api/generate: Create the simplest form of a /text-to-image
* (noop) api/controllers: Sort API router alphabetically
* api/generate: Create the other generate API proxies
  Pretend they are JSON for now, adding support for multipart forms next.
* api: Add ajv-formats to support binary format
  That required upgrading ajv itself, which was a bit of a pain but worked and also found a few issues in our schema.
* api/generate: Add multipart request support
* [DEV-only] Default to the staging gateway
* api/test: Create tests for new /generate APIs
* api/test: Fix multistream tests
* prettier 😡
* api/test: Add more cases for multipart validation
* Revert "[DEV-only] Default to the staging gateway"
  This reverts commit 802d297.
- Loading branch information
Showing
15 changed files
with
643 additions
and
173 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,193 @@ | ||
import FormData from "form-data"; | ||
import { User } from "../schema/types"; | ||
import { | ||
AuxTestServer, | ||
TestClient, | ||
clearDatabase, | ||
setupUsers, | ||
startAuxTestServer, | ||
} from "../test-helpers"; | ||
import serverPromise, { TestServer } from "../test-server"; | ||
|
||
// Shared across the whole file: the API test server and the credentials used
// to create the two users each test runs with.
let server: TestServer;
let mockAdminUserInput: User;
let mockNonAdminUserInput: User;

// jest.setTimeout(70000)

// Wait for the shared test server and prepare the user fixtures once.
beforeAll(async () => {
  server = await serverPromise;

  // The two fixtures are passed together to `setupUsers`, so they must be two
  // different, well-formed addresses. NOTE(review): the previous literals were
  // the same obfuscated placeholder for both users; confirm these values match
  // the fixtures used by the other controller tests.
  mockAdminUserInput = {
    email: "user_admin@gmail.com",
    password: "x".repeat(64),
  };

  mockNonAdminUserInput = {
    email: "user_non_admin@gmail.com",
    password: "y".repeat(64),
  };
});

// Reset the database after every test so state never leaks between cases.
afterEach(async () => {
  await clearDatabase(server);
});
|
||
// Tests for the /beta/generate/* routes. A local auxiliary server stands in
// for the AI Gateway and counts how many times each API path was called.
describe("controllers/generate", () => {
  let client: TestClient;
  let adminUser: User;
  let adminApiKey: string;
  let nonAdminUser: User;
  let nonAdminToken: string;

  let aiGatewayServer: AuxTestServer;
  let aiGatewayCalls: Record<string, number>;

  beforeAll(async () => {
    aiGatewayServer = await startAuxTestServer(30303); // port configured in test-params.ts

    // Each fake gateway route bumps its call counter and echoes back the
    // content type it received so tests can check how the request was encoded.
    ["text-to-image", "image-to-image", "image-to-video", "upscale"].forEach(
      (api) => {
        aiGatewayServer.app.post(`/${api}`, (req, res) => {
          const previousCalls = aiGatewayCalls[api] || 0;
          aiGatewayCalls[api] = previousCalls + 1;
          return res.status(200).json({
            message: "success",
            reqContentType: req.headers["content-type"] ?? "unknown",
          });
        });
      },
    );
  });

  afterAll(async () => {
    await aiGatewayServer.close();
  });

  beforeEach(async () => {
    const users = await setupUsers(
      server,
      mockAdminUserInput,
      mockNonAdminUserInput,
    );
    client = users.client;
    adminUser = users.adminUser;
    adminApiKey = users.adminApiKey;
    nonAdminUser = users.nonAdminUser;
    nonAdminToken = users.nonAdminToken;

    // Create the "ai-generate" experiment with the admin API key, then switch
    // the client to the non-admin JWT for the actual test requests.
    client.apiKey = adminApiKey;
    await client.post("/experiment", {
      name: "ai-generate",
      audienceUserIds: [adminUser.id, nonAdminUser.id],
    });
    client.apiKey = null;
    client.jwtAuth = nonAdminToken;

    aiGatewayCalls = {};
  });

  // Builds a multipart form with the given text fields plus a dummy "image"
  // part tagged as image/png.
  function buildMultipartBody(textFields: Record<string, any>) {
    const form = new FormData();
    Object.entries(textFields).forEach(([field, value]) => {
      form.append(field, value);
    });
    form.append("image", "dummy", { contentType: "image/png" });
    return form;
  }

  describe("API proxies", () => {
    it("should call the AI Gateway for generate API /text-to-image", async () => {
      const res = await client.post("/beta/generate/text-to-image", {
        prompt: "a man in a suit and tie",
      });

      expect(res.status).toBe(200);
      const json = await res.json();
      expect(json).toEqual({
        message: "success",
        reqContentType: "application/json",
      });
      expect(aiGatewayCalls).toEqual({ "text-to-image": 1 });
    });

    // The remaining APIs take multipart bodies and differ only in path and
    // text fields, so they share one test body.
    const multipartApis: { api: string; fields: Record<string, any> }[] = [
      {
        api: "image-to-image",
        fields: { prompt: "replace the suit with a bathing suit" },
      },
      { api: "image-to-video", fields: {} },
      { api: "upscale", fields: { prompt: "enhance" } },
    ];

    for (const { api, fields } of multipartApis) {
      it(`should call the AI Gateway for generate API /${api}`, async () => {
        const res = await client.fetch(`/beta/generate/${api}`, {
          method: "POST",
          body: buildMultipartBody(fields),
        });

        expect(res.status).toBe(200);
        const json = await res.json();
        expect(json).toEqual({
          message: "success",
          reqContentType: expect.stringMatching("^multipart/form-data"),
        });
        expect(aiGatewayCalls).toEqual({ [api]: 1 });
      });
    }
  });

  describe("validates multipart schema", () => {
    // Form whose "image" part is 11 * 1024 * 1024 characters long.
    const hugeForm = new FormData();
    hugeForm.append("image", "a".repeat(11 * 1024 * 1024), {
      contentType: "image/png",
    });

    const testCases = [
      {
        title: "should fail with a missing required field",
        input: buildMultipartBody({}),
        error: "must have required property 'prompt'",
      },
      {
        title: "should fail with bad type for a field",
        input: buildMultipartBody({ prompt: "impromptu", seed: "NaN" }),
        error: "must be integer",
      },
      {
        title: "should fail with an unknown field",
        input: buildMultipartBody({
          prompt: "impromptu",
          extra_good_image: "yes pls",
        }),
        error: "must NOT have additional properties",
      },
      {
        title: "should limit maximum payload size",
        input: hugeForm,
        error: "Field value too long",
      },
    ];

    testCases.forEach(({ title, input, error }) => {
      it(title, async () => {
        const res = await client.fetch("/beta/generate/image-to-image", {
          method: "POST",
          body: input,
        });

        expect(res.status).toBe(422);
        const json = await res.json();
        expect(json).toEqual({
          errors: [expect.stringContaining(error)],
        });
        // A rejected request must never be forwarded to the gateway.
        expect(aiGatewayCalls).toEqual({});
      });
    });
  });
});
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,110 @@ | ||
import { RequestHandler, Router } from "express"; | ||
import FormData from "form-data"; | ||
import multer from "multer"; | ||
import { BodyInit } from "node-fetch"; | ||
import logger from "../logger"; | ||
import { authorizer, validateFormData, validatePost } from "../middleware"; | ||
import { fetchWithTimeout } from "../util"; | ||
import { experimentSubjectsOnly } from "./experiment"; | ||
import { pathJoin2 } from "./helpers"; | ||
|
||
// Timeout passed to fetchWithTimeout for AI Gateway requests: 10 minutes, in ms.
const AI_GATEWAY_TIMEOUT = 1000 * 60 * 10;

// Multipart parser: uploaded files are kept in memory and limited to 10MiB each.
const multipart = multer({
  limits: { fileSize: 10 * 1024 * 1024 },
  storage: multer.memoryStorage(),
});

const app = Router();

// Installed before any route so every generate endpoint passes through the
// "ai-generate" experiment middleware first.
app.use(experimentSubjectsOnly("ai-generate"));
|
||
/**
 * Registers a `POST /<name>` route on the router that proxies the request to
 * the AI Gateway at `req.config.aiGatewayUrl`.
 *
 * JSON APIs are validated with `validatePost` and forwarded as a JSON body.
 * All other APIs are parsed by multer, validated with `validateFormData` and
 * re-encoded as a multipart form including the uploaded files. In both cases
 * `model_id` falls back to `defaultModel` when the request does not set it.
 *
 * Gateway responses with status >= 500, and failures to reach the gateway or
 * to parse its response as JSON, are reported to the client as a generic 500.
 * Any other response is relayed with its original status and JSON body.
 *
 * @param name API name; used for the route path, the gateway path and the
 *   `<name>-payload` validation schema.
 * @param defaultModel Value used for `model_id` when the request omits it.
 * @param isJSONReq Whether the API takes a JSON body instead of multipart.
 * @returns Whatever `app.post` returns for the registered route.
 */
function registerGenerateHandler(
  name: string,
  defaultModel: string,
  isJSONReq = false,
): RequestHandler {
  const path = `/${name}`;
  const middlewares = isJSONReq
    ? [validatePost(`${name}-payload`)]
    : [multipart.any(), validateFormData(`${name}-payload`)];
  return app.post(
    path,
    authorizer({}),
    ...middlewares,
    async function proxyGenerate(req, res) {
      const { aiGatewayUrl } = req.config;
      if (!aiGatewayUrl) {
        return res.status(500).json({ errors: ["AI Gateway URL is not set"] });
      }

      const apiUrl = pathJoin2(aiGatewayUrl, path);

      let payload: BodyInit;
      if (isJSONReq) {
        payload = JSON.stringify({
          model_id: defaultModel,
          ...req.body,
        });
      } else {
        // Check the uploads before building the outgoing form.
        if (!Array.isArray(req.files)) {
          return res.status(400).json({
            errors: ["Expected an array of files"],
          });
        }

        const form = new FormData();
        if (!("model_id" in req.body)) {
          form.append("model_id", defaultModel);
        }
        for (const [key, value] of Object.entries(req.body)) {
          form.append(key, value);
        }
        for (const file of req.files) {
          form.append(file.fieldname, file.buffer, {
            filename: file.originalname,
            contentType: file.mimetype,
            knownLength: file.size,
          });
        }
        payload = form;
      }

      try {
        const response = await fetchWithTimeout(apiUrl, {
          method: "POST",
          body: payload,
          timeout: AI_GATEWAY_TIMEOUT,
          // For multipart the content type is left unset here.
          headers: isJSONReq ? { "content-type": "application/json" } : {},
        });

        const body = await response.json();
        if (!response.ok) {
          logger.error(
            `Error from generate API ${path} status=${
              response.status
            } body=${JSON.stringify(body)}`,
          );
        }
        if (response.status >= 500) {
          return res
            .status(500)
            .json({ errors: [`Failed to generate ${name}`] });
        }

        return res.status(response.status).json(body);
      } catch (err) {
        // The gateway call or the JSON parsing rejected. Without this catch
        // the rejection escapes the async handler and no response is sent
        // unless something else forwards it to an error handler.
        const reason = err instanceof Error ? err.message : String(err);
        logger.error(`Error calling generate API ${path} err=${reason}`);
        if (res.headersSent) {
          // A response already went out; nothing more can be sent.
          return;
        }
        return res.status(500).json({ errors: [`Failed to generate ${name}`] });
      }
    },
  );
}
|
||
// Generate APIs exposed by this router with their default models, registered
// in this order. Only text-to-image takes a JSON body; the rest are multipart.
const generateApis: {
  name: string;
  defaultModel: string;
  isJSONReq?: boolean;
}[] = [
  {
    name: "text-to-image",
    defaultModel: "SG161222/RealVisXL_V4.0_Lightning",
    isJSONReq: true,
  },
  { name: "image-to-image", defaultModel: "timbrooks/instruct-pix2pix" },
  {
    name: "image-to-video",
    defaultModel: "stabilityai/stable-video-diffusion-img2vid-xt-1-1",
  },
  { name: "upscale", defaultModel: "stabilityai/stable-diffusion-x4-upscaler" },
];

for (const { name, defaultModel, isJSONReq } of generateApis) {
  registerGenerateHandler(name, defaultModel, isJSONReq);
}
|
||
export default app; |
Oops, something went wrong.