diff --git a/nbs/Quick_REST_API_DEMO.ipynb b/nbs/Quick_REST_API_DEMO.ipynb new file mode 100644 index 0000000..3fa22f4 --- /dev/null +++ b/nbs/Quick_REST_API_DEMO.ipynb @@ -0,0 +1,1153 @@ +{ + "nbformat": 4, + "nbformat_minor": 0, + "metadata": { + "colab": { + "provenance": [], + "toc_visible": true + }, + "kernelspec": { + "name": "python3", + "display_name": "Python 3" + }, + "language_info": { + "name": "python" + } + }, + "cells": [ + { + "cell_type": "markdown", + "source": [ + "# Stability API Standard Feature Demo\n", + "\n", + "This notebook showcases a few key features available through our API.\n", + "You will need to obtain keys, and may need to be whitelisted for some features\n", + "\n", + "* Stability SDXL Keys are available here: https://platform.stability.ai/account/keys\n", + "\n", + "*For a complete reference of the Stability API, please visit https://platform.stability.ai/docs/api-reference*
\n", + "Please note that a REST API and gRPC API are available." + ], + "metadata": { + "id": "ej6SVLpa7Jtz" + } + }, + { + "cell_type": "code", + "source": [ + "#@title Install Dependencies\n", + "import requests\n", + "import shutil\n", + "import getpass\n", + "import os\n", + "import base64\n", + "from google.colab import files\n", + "from PIL import Image" + ], + "metadata": { + "id": "I0zR1GllPIRe" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "#@title Load in Sample Images\n", + "#Feel free to replace these with your own images\n", + "url_mappings = {\"dog_with_armor\": \"https://i.imgur.com/4nnSP8q.png\",\n", + " \"dog_with_armor_inpaint\": \"https://i.imgur.com/eu44gJe.png\",\n", + " \"dog_with_armor_inpaint_just_armor\": \"https://i.imgur.com/Mw6QU6P.png\",\n", + " \"dog_outpaint_example\": \"https://i.imgur.com/yv9RxjQ.png\",\n", + " \"outpaint_mask_1024_1024\": \"https://i.imgur.com/L1lqrXm.png\"\n", + " }\n", + "for name in url_mappings:\n", + " response = requests.get(url_mappings[name], stream=True)\n", + " with open(f'/content/{name}.png', 'wb') as out_file:\n", + " response.raw.decode_content = True\n", + " shutil.copyfileobj(response.raw, out_file)\n", + " del response" + ], + "metadata": { + "id": "5EL8d-S7jSdM", + "cellView": "form" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "#@title Stability API Key\n", + "# You will be prompted to enter your api keys after running this code\n", + "# You can view your API key here: https://next.platform.stability.ai/account/keys\n", + "api_key = getpass.getpass('Enter your API Key')" + ], + "metadata": { + "id": "_8RqX3BgQUSr" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "#@title Text To Image Example\n", + "url = \"https://api.stability.ai/v1/generation/stable-diffusion-xl-1024-v1-0/text-to-image\"\n", + "\n", + "body = {\n", + " \"steps\": 30,\n", + " \"width\": 1024,\n", + " \"height\": 1024,\n", + " \"seed\": 0,\n", + " \"cfg_scale\": 5,\n", + " \"samples\": 1,\n", + " \"text_prompts\": [\n", + " {\n", + " \"text\": \"A painting of a cat wearing armor, intricate filigree, cinematic masterpiece digital art\",\n", + " \"weight\": 1\n", + " },\n", + " {\n", + " \"text\": \"blurry, bad\",\n", + " \"weight\": -1\n", + " }\n", + " ],\n", + "}\n", + "\n", + "headers = {\n", + " \"Accept\": \"application/json\",\n", + " \"Content-Type\": \"application/json\",\n", + " \"Authorization\": f\"Bearer {api_key}\",\n", + "}\n", + "\n", + "response = requests.post(\n", + " url,\n", + " headers=headers,\n", + " json=body,\n", + ")\n", + "\n", + "if response.status_code != 200:\n", + " raise Exception(\"Non-200 response: \" + str(response.text))\n", + "\n", + "data = response.json()\n", + "\n", + "# make sure the out directory exists\n", + "if not os.path.exists(\"./out\"):\n", + " os.makedirs(\"./out\")\n", + "\n", + "for i, image in enumerate(data[\"artifacts\"]):\n", + " with open(f'./out/txt2img_{image[\"seed\"]}.png', \"wb\") as f:\n", + " f.write(base64.b64decode(image[\"base64\"]))\n", + " files.download(f.name)" + ], + "metadata": { + "id": "vXlxmX1_MnNw" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "#@title Inpainting Example\n", + "response = requests.post(\n", + " \"https://api.stability.ai/v1/generation/stable-diffusion-xl-1024-v1-0/image-to-image/masking\",\n", + " headers={\n", + " \"Accept\": \"application/json\",\n", + " \"Authorization\": f\"Bearer {api_key}\"\n", + " },\n", + " files={\n", + " #replace init image and mask image with your image and mask\n", + " \"init_image\": open(\"/content/dog_with_armor.png\", \"rb\"),\n", + " \"mask_image\": open(\"/content/dog_with_armor_inpaint_just_armor.png\", \"rb\")\n", + " },\n", + " data={\n", + " \"mask_source\": \"MASK_IMAGE_BLACK\",\n", + "\t\t\"steps\": 40,\n", + "\t\t\"seed\": 0,\n", + "\t\t\"cfg_scale\": 5,\n", + "\t\t\"samples\": 1,\n", + "\t\t\"text_prompts[0][text]\": 'Dog Armor made of chocolate',\n", + "\t\t\"text_prompts[0][weight]\": 1,\n", + "\t\t\"text_prompts[1][text]\": 'blurry, bad',\n", + "\t\t\"text_prompts[1][weight]\": -1,\n", + " }\n", + ")\n", + "\n", + "if response.status_code != 200:\n", + " raise Exception(\"Non-200 response: \" + str(response.text))\n", + "\n", + "data = response.json()\n", + "\n", + "# make sure the out directory exists\n", + "if not os.path.exists(\"./out\"):\n", + " os.makedirs(\"./out\")\n", + "\n", + "for i, image in enumerate(data[\"artifacts\"]):\n", + " with open(f'./out/img2img_{image[\"seed\"]}.png', \"wb\") as f:\n", + " f.write(base64.b64decode(image[\"base64\"]))\n", + " files.download(f.name)" + ], + "metadata": { + "id": "Sq14bHVOO__i", + "cellView": "form" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "#@title Inpainting - Change Background Example\n", + "response = requests.post(\n", + " \"https://api.stability.ai/v1/generation/stable-diffusion-xl-1024-v1-0/image-to-image/masking\",\n", + " headers={\n", + " \"Accept\": \"application/json\",\n", + " \"Authorization\": f\"Bearer {api_key}\"\n", + " },\n", + " files={\n", + " \"init_image\": open(\"/content/dog_with_armor.png\", \"rb\"),\n", + " \"mask_image\": open(\"/content/dog_with_armor_inpaint.png\", \"rb\")\n", + " },\n", + " data={\n", + " # Flipping to white will make it inpaint but remove background, even though dog is masked black\n", + " \"mask_source\": \"MASK_IMAGE_WHITE\",\n", + "\t\t\"steps\": 40,\n", + "\t\t\"seed\": 0,\n", + "\t\t\"cfg_scale\": 5,\n", + "\t\t\"samples\": 1,\n", + "\t\t\"text_prompts[0][text]\": 'Medieval castle',\n", + "\t\t\"text_prompts[0][weight]\": 1,\n", + "\t\t\"text_prompts[1][text]\": 'blurry, bad',\n", + "\t\t\"text_prompts[1][weight]\": -1,\n", + " }\n", + ")\n", + "\n", + "if response.status_code != 200:\n", + " raise Exception(\"Non-200 response: \" + str(response.text))\n", + "\n", + "data = response.json()\n", + "\n", + "# make sure the out directory exists\n", + "if not os.path.exists(\"./out\"):\n", + " os.makedirs(\"./out\")\n", + "\n", + "for i, image in enumerate(data[\"artifacts\"]):\n", + " with open(f'./out/img2img_{image[\"seed\"]}.png', \"wb\") as f:\n", + " f.write(base64.b64decode(image[\"base64\"]))\n", + " files.download(f.name)" + ], + "metadata": { + "id": "ahInbhG4Y5oM", + "cellView": "form" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "#@title Outpainting Example\n", + "\n", + "# Init image has to be the same size as mask image\n", + "# Paste the smaller init image onto the mask\n", + "initial_init_image = Image.open(\"/content/dog_outpaint_example.png\")\n", + "# The mask is already blurred, which will improve coherence\n", + "mask = Image.open(\"/content/outpaint_mask_1024_1024.png\")\n", + "mask.paste(initial_init_image)\n", + "mask.save('/content/dog_outpaint_init_image.png', quality=95)\n", + "\n", + "\n", + "response = requests.post(\n", + " \"https://api.stability.ai/v1/generation/stable-diffusion-xl-1024-v1-0/image-to-image/masking\",\n", + " headers={\n", + " \"Accept\": \"application/json\",\n", + " \"Authorization\": f\"Bearer {api_key}\"\n", + " },\n", + " files={\n", + " \"init_image\": open(\"/content/dog_outpaint_init_image.png\", \"rb\"),\n", + " \"mask_image\": open(\"/content/outpaint_mask_1024_1024.png\", \"rb\")\n", + " },\n", + " data={\n", + " \"mask_source\": \"MASK_IMAGE_BLACK\",\n", + "\t\t\"steps\": 40,\n", + "\t\t\"seed\": 0,\n", + "\t\t\"cfg_scale\": 5,\n", + "\t\t\"samples\": 1,\n", + "\t\t\"text_prompts[0][text]\": 'Medieval castle',\n", + "\t\t\"text_prompts[0][weight]\": 1,\n", + "\t\t\"text_prompts[1][text]\": 'blurry, bad',\n", + "\t\t\"text_prompts[1][weight]\": -1,\n", + " }\n", + ")\n", + "\n", + "if response.status_code != 200:\n", + " raise Exception(\"Non-200 response: \" + str(response.text))\n", + "\n", + "data = response.json()\n", + "\n", + "# make sure the out directory exists\n", + "if not os.path.exists(\"./out\"):\n", + " os.makedirs(\"./out\")\n", + "\n", + "for i, image in enumerate(data[\"artifacts\"]):\n", + " with open(f'./out/img2img_{image[\"seed\"]}.png', \"wb\") as f:\n", + " f.write(base64.b64decode(image[\"base64\"]))\n", + " files.download(f.name)" + ], + "metadata": { + "id": "URaWbM6xdXZw", + "cellView": "form" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "#@title Image Upscaling Example\n", + "response = requests.post(\n", + " \"https://api.stability.ai/v1/generation/esrgan-v1-x2plus/image-to-image/upscale\",\n", + " headers={\n", + " \"Accept\": \"application/json\",\n", + " \"Authorization\": f\"Bearer {api_key}\"\n", + " },\n", + " files={\n", + " \"image\": open(\"/content/dog_with_armor.png\", \"rb\")\n", + " },\n", + " data={\n", + " \"width\": 2048\n", + " }\n", + ")\n", + "\n", + "if response.status_code != 200:\n", + " raise Exception(\"Non-200 response: \" + str(response.text))\n", + "\n", + "data = response.json()\n", + "\n", + "# make sure the out directory exists\n", + "if not os.path.exists(\"./out\"):\n", + " os.makedirs(\"./out\")\n", + "\n", + "for i, image in enumerate(data[\"artifacts\"]):\n", + " with open(f'./out/img2img_{image[\"seed\"]}.png', \"wb\") as f:\n", + " f.write(base64.b64decode(image[\"base64\"]))\n", + " files.download(f.name)" + ], + "metadata": { + "id": "vAWBk84DezoP", + "cellView": "form" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "# Stability SDXL Enterprise API Demo\n", + "\n", + "Stability provides enterprise-grade features for customers that require faster speeds and dedicated managed services with support. These nodes can have significantly faster speeds based on the Stability Supercomputer, and may include prototype / preview models.
\n", + "The below section will leverage a demo node that is prepared at request. If you would like to try an enterprise node, please reach out to the Stability team." + ], + "metadata": { + "id": "nmHWiBNZAi4G" + } + }, + { + "cell_type": "code", + "source": [ + "# This demo notebook is designed to help illustrate the latency on prototype nodes with early access models\n", + "# This REST implementation will hit the special demo node. Please note that this is a demo only and results in production should exceed these speeds.\n", + "# Make sure to enable batch downloads from your browser if you want to see the images that will be downloaded locally\n", + "# OUTPUT: You will see Average and Total time for image generation. Note that in the first call there may be a warm-up time of up to 2 seconds, and tha colab will add an additional ~1.5 seconds\n", + "\n", + "import base64\n", + "import requests\n", + "import os\n", + "import time\n", + "from google.colab import files\n", + "\n", + "\n", + "def make_request(index):\n", + " #replace the with the name of the node and module provided to you\n", + " url = \"https://test.api.stability.ai/v1/generation//\"\n", + "#Steps: Increasing can improve quality, and increase latency\n", + " body = {\n", + " \"steps\": 22,\n", + " \"width\": 1024,\n", + " \"height\": 1024,\n", + " \"seed\": 0,\n", + " \"cfg_scale\": 6,\n", + " \"samples\": 1,\n", + " \"text_prompts\": [\n", + " {\n", + " \"text\": \"octane render of a barabaric software engineer\",\n", + " \"weight\": 1\n", + " },\n", + " {\n", + " \"text\": \"blurry, bad\",\n", + " \"weight\": -1\n", + " }\n", + " ],\n", + " }\n", + "\n", + " headers = {\n", + " \"Accept\": \"application/json\",\n", + " \"Content-Type\": \"application/json\",\n", + " #insert your Key\n", + " \"Authorization\": \"Bearer \",\n", + " }\n", + "\n", + " response = requests.post(\n", + " url,\n", + " headers=headers,\n", + " json=body,\n", + " )\n", + "\n", + " if response.status_code != 200:\n", + " raise Exception(\"Non-200 response: \" + str(response.text))\n", + "\n", + " data = response.json()\n", + "\n", + " if not os.path.exists(\"./out\"):\n", + " os.makedirs(\"./out\")\n", + "\n", + " for i, image in enumerate(data[\"artifacts\"]):\n", + " with open(f'./out/txt2img_{image[\"seed\"]}_{index}.png', \"wb\") as f:\n", + " f.write(base64.b64decode(image[\"base64\"]))\n", + "\n", + " #Please comment the below line to execute pure benchmarking without downloading the images\n", + " files.download(f.name)\n", + "\n", + "total_time = 0\n", + "\n", + "#Adjust num to change the number of images to get as batch\n", + "num = 2\n", + "for i in range(num):\n", + " print(i)\n", + " start = time.time()\n", + " make_request(i)\n", + " end = time.time()\n", + " total_time += (end - start)\n", + "\n", + "print(\"Average: \", total_time/num)\n", + "print(\"Total_Time: \", total_time)\n", + "print(\"Num Iterations: \", num)" + ], + "metadata": { + "id": "yDeYgKNXAiQH" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "# SDXL Finetuning REST API Demo\n", + "\n", + "Stability is offering a private beta of its fine-tuning service to select customers.
\n", + "\n", + "Note that this is a **developer beta** - bugs and quality issues with the generated fine-tunes may occur. Please reach out to Stability if this is the case - and share what you've made as well!\n", + "\n", + "The code below hits the Stability REST API. This REST API contract is rather solid, so it's unlikely to see large changes before the production release of fine-tuning.\n", + "\n", + "Known issues:\n", + "\n", + "* Style fine-tunes may result in overfitting - if this is the case, uncomment the `# weight=1.0` field of `DiffusionFineTune` in the diffusion section and provide a value between -1 and 1. You may need to go as low as 0.2 or 0.1.\n", + "* We will be exposing test parameters soon - please reach out with examples of datasets that produce overfitting or errors if you have them." + ], + "metadata": { + "id": "-xGp9o-iTc8e" + } + }, + { + "cell_type": "code", + "source": [ + "#@title Stability API key\n", + "import getpass\n", + "\n", + "#@markdown Execute this step and paste your API key in the box that appears.
Visit https://platform.stability.ai/account/keys to get your API key!

Note: If you are not on the fine-tuning whitelist you will receive an error during training.\n", + "\n", + "API_KEY = getpass.getpass('Paste your Stability API Key here and press Enter: ')\n", + "\n", + "API_HOST = \"https://preview-api.stability.ai\"\n", + "\n", + "ENGINE_ID = \"stable-diffusion-xl-1024-v1-0\"" + ], + "metadata": { + "id": "t7910RqrlFJc", + "cellView": "form" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "#@title Initialize the REST API wrapper\n", + "import io\n", + "import logging\n", + "import requests\n", + "import os\n", + "import shutil\n", + "import sys\n", + "import time\n", + "import json\n", + "import base64\n", + "from enum import Enum\n", + "from dataclasses import dataclass, is_dataclass, field, asdict\n", + "from typing import List, Optional, Any\n", + "from IPython.display import clear_output\n", + "from pathlib import Path\n", + "from PIL import Image\n", + "from zipfile import ZipFile\n", + "\n", + "\n", + "class Printable:\n", + " \"\"\" Helper class for printing a class to the console. \"\"\"\n", + "\n", + " @staticmethod\n", + " def to_json(obj: Any) -> Any:\n", + " if isinstance(obj, Enum):\n", + " return obj.value\n", + " if is_dataclass(obj):\n", + " return asdict(obj)\n", + "\n", + " return obj\n", + "\n", + " def __str__(self):\n", + " return f\"{self.__class__.__name__}: {json.dumps(self, default=self.to_json, indent=4)}\"\n", + "\n", + "\n", + "class ToDict:\n", + " \"\"\" Helper class to simplify converting dataclasses to dicts. \"\"\"\n", + "\n", + " def to_dict(self):\n", + " return {k: v for k, v in asdict(self).items() if v is not None}\n", + "\n", + "\n", + "@dataclass\n", + "class FineTune(Printable):\n", + " id: str\n", + " user_id: str\n", + " name: str\n", + " mode: str\n", + " engine_id: str\n", + " training_set_id: str\n", + " status: str\n", + " failure_reason: Optional[str] = field(default=None)\n", + " duration_seconds: Optional[int] = field(default=None)\n", + " object_prompt: Optional[str] = field(default=None)\n", + "\n", + "\n", + "@dataclass\n", + "class DiffusionFineTune(Printable, ToDict):\n", + " id: str\n", + " token: str\n", + " weight: Optional[float] = field(default=None)\n", + "\n", + "\n", + "@dataclass\n", + "class TextPrompt(Printable, ToDict):\n", + " text: str\n", + " weight: Optional[float] = field(default=None)\n", + "\n", + "\n", + "class Sampler(Enum):\n", + " DDIM = \"DDIM\"\n", + " DDPM = \"DDPM\"\n", + " K_DPMPP_2M = \"K_DPMPP_2M\"\n", + " K_DPMPP_2S_ANCESTRAL = \"K_DPMPP_2S_ANCESTRAL\"\n", + " K_DPM_2 = \"K_DPM_2\"\n", + " K_DPM_2_ANCESTRAL = \"K_DPM_2_ANCESTRAL\"\n", + " K_EULER = \"K_EULER\"\n", + " K_EULER_ANCESTRAL = \"K_EULER_ANCESTRAL\"\n", + " K_HEUN = \"K_HEUN\"\n", + " K_LMS = \"K_LMS\"\n", + "\n", + " @staticmethod\n", + " def from_string(val) -> Enum or None:\n", + " for sampler in Sampler:\n", + " if sampler.value == val:\n", + " return sampler\n", + " raise Exception(f\"Unknown Sampler: {val}\")\n", + "\n", + "\n", + "@dataclass\n", + "class TextToImageParams(Printable):\n", + " fine_tunes: List[DiffusionFineTune]\n", + " text_prompts: List[TextPrompt]\n", + " samples: int\n", + " sampler: Sampler\n", + " engine_id: str\n", + " steps: int\n", + " seed: Optional[int] = field(default=0)\n", + " cfg_value: Optional[int] = field(default=7)\n", + " width: Optional[int] = field(default=1024)\n", + " height: Optional[int] = field(default=1024)\n", + "\n", + "\n", + "@dataclass\n", + "class DiffusionResult:\n", + " base64: str\n", + " seed: int\n", + " finish_reason: str\n", + "\n", + " def __str__(self):\n", + " return f\"DiffusionResult(base64='too long to print', seed='{self.seed}', finish_reason='{self.finish_reason}')\"\n", + "\n", + " def __repr__(self):\n", + " return self.__str__()\n", + "\n", + "\n", + "@dataclass\n", + "class TrainingSetBase(Printable):\n", + " id: str\n", + " name: str\n", + "\n", + "\n", + "@dataclass\n", + "class TrainingSetImage(Printable):\n", + " id: str\n", + "\n", + "\n", + "@dataclass\n", + "class TrainingSet(TrainingSetBase):\n", + " images: List[TrainingSetImage]\n", + "\n", + "\n", + "class FineTuningRESTWrapper:\n", + " \"\"\"\n", + " Helper class to simplify interacting with the fine-tuning service via\n", + " Stability's REST API.\n", + "\n", + " While this class can be copied to your local environment, it is not likely\n", + " robust enough for your needs and does not support all of the features that\n", + " the REST API offers.\n", + " \"\"\"\n", + "\n", + " def __init__(self, api_key: str, api_host: str):\n", + " self.api_key = api_key\n", + " self.api_host = api_host\n", + "\n", + " def create_fine_tune(self,\n", + " name: str,\n", + " images: List[str],\n", + " engine_id: str,\n", + " mode: str,\n", + " object_prompt: Optional[str] = None) -> FineTune:\n", + " print(f\"Creating {mode} fine-tune called '{name}' using {len(images)} images...\")\n", + "\n", + " payload = {\"name\": name, \"engine_id\": engine_id, \"mode\": mode}\n", + " if object_prompt is not None:\n", + " payload[\"object_prompt\"] = object_prompt\n", + "\n", + " # Create a training set\n", + " training_set_id = self.create_training_set(name=name)\n", + " payload[\"training_set_id\"] = training_set_id\n", + " print(f\"\\tCreated training set {training_set_id}\")\n", + "\n", + " # Add images to the training set\n", + " for image in images:\n", + " print(f\"\\t\\tAdding {os.path.basename(image)}\")\n", + " self.add_image_to_training_set(\n", + " training_set_id=training_set_id,\n", + " image=image\n", + " )\n", + "\n", + " # Create the fine-tune\n", + " print(f\"\\tCreating a fine-tune from the training set\")\n", + " response = requests.post(\n", + " f\"{self.api_host}/v1/fine-tunes\",\n", + " json=payload,\n", + " headers={\n", + " \"Authorization\": self.api_key,\n", + " \"Content-Type\": \"application/json\"\n", + " }\n", + " )\n", + " raise_on_non200(response)\n", + " print(f\"\\tCreated fine-tune {response.json()['id']}\")\n", + "\n", + " print(f\"Success\")\n", + " return FineTune(**response.json())\n", + "\n", + " def get_fine_tune(self, fine_tune_id: str) -> FineTune:\n", + " response = requests.get(\n", + " f\"{self.api_host}/v1/fine-tunes/{fine_tune_id}\",\n", + " headers={\"Authorization\": self.api_key}\n", + " )\n", + "\n", + " raise_on_non200(response)\n", + "\n", + " return FineTune(**response.json())\n", + "\n", + " def list_fine_tunes(self) -> List[FineTune]:\n", + " response = requests.get(\n", + " f\"{self.api_host}/v1/fine-tunes\",\n", + " headers={\"Authorization\": self.api_key}\n", + " )\n", + "\n", + " raise_on_non200(response)\n", + "\n", + " return [FineTune(**ft) for ft in response.json()]\n", + "\n", + " def rename_fine_tune(self, fine_tune_id: str, name: str) -> FineTune:\n", + " response = requests.patch(\n", + " f\"{self.api_host}/v1/fine-tunes/{fine_tune_id}\",\n", + " json={\"operation\": \"RENAME\", \"name\": name},\n", + " headers={\n", + " \"Authorization\": self.api_key,\n", + " \"Content-Type\": \"application/json\"\n", + " }\n", + " )\n", + "\n", + " raise_on_non200(response)\n", + "\n", + " return FineTune(**response.json())\n", + "\n", + " def retrain_fine_tune(self, fine_tune_id: str) -> FineTune:\n", + " response = requests.patch(\n", + " f\"{self.api_host}/v1/fine-tunes/{fine_tune_id}\",\n", + " json={\"operation\": \"RETRAIN\"},\n", + " headers={\n", + " \"Authorization\": self.api_key,\n", + " \"Content-Type\": \"application/json\"\n", + " }\n", + " )\n", + "\n", + " raise_on_non200(response)\n", + "\n", + " return FineTune(**response.json())\n", + "\n", + " def delete_fine_tune(self, fine_tune: FineTune):\n", + " # Delete the underlying training set\n", + " self.delete_training_set(fine_tune.training_set_id)\n", + "\n", + " # Delete the fine-tune\n", + " response = requests.delete(\n", + " f\"{self.api_host}/v1/fine-tunes/{fine_tune.id}\",\n", + " headers={\"Authorization\": self.api_key}\n", + " )\n", + "\n", + " raise_on_non200(response)\n", + "\n", + " def create_training_set(self, name: str) -> str:\n", + " response = requests.post(\n", + " f\"{self.api_host}/v1/training-sets\",\n", + " json={\"name\": name},\n", + " headers={\n", + " \"Authorization\": self.api_key,\n", + " \"Content-Type\": \"application/json\"\n", + " }\n", + " )\n", + "\n", + " raise_on_non200(response)\n", + "\n", + " return response.json().get('id')\n", + "\n", + " def get_training_set(self, training_set_id: str) -> TrainingSet:\n", + " response = requests.get(\n", + " f\"{self.api_host}/v1/training-sets/{training_set_id}\",\n", + " headers={\"Authorization\": self.api_key}\n", + " )\n", + "\n", + " raise_on_non200(response)\n", + "\n", + " return TrainingSet(**response.json())\n", + "\n", + " def list_training_sets(self) -> List[TrainingSetBase]:\n", + " response = requests.get(\n", + " f\"{self.api_host}/v1/training-sets\",\n", + " headers={\"Authorization\": self.api_key}\n", + " )\n", + "\n", + " raise_on_non200(response)\n", + "\n", + " return [TrainingSetBase(**tsb) for tsb in response.json()]\n", + "\n", + " def add_image_to_training_set(self, training_set_id: str, image: str) -> str:\n", + " with open(image, 'rb') as image_file:\n", + " response = requests.post(\n", + " f\"{self.api_host}/v1/training-sets/{training_set_id}/images\",\n", + " headers={\"Authorization\": self.api_key},\n", + " files={'image': image_file}\n", + " )\n", + "\n", + " raise_on_non200(response)\n", + "\n", + " return response.json().get('id')\n", + "\n", + " def remove_image_from_training_set(self, training_set_id: str, image_id: str) -> None:\n", + " response = requests.delete(\n", + " f\"{self.api_host}/v1/training-sets/{training_set_id}/images/{image_id}\",\n", + " headers={\"Authorization\": self.api_key}\n", + " )\n", + "\n", + " raise_on_non200(response)\n", + "\n", + " def delete_training_set(self, training_set_id: str) -> None:\n", + " response = requests.delete(\n", + " f\"{self.api_host}/v1/training-sets/{training_set_id}\",\n", + " headers={\"Authorization\": self.api_key}\n", + " )\n", + "\n", + " raise_on_non200(response)\n", + "\n", + " def text_to_image(self, params: TextToImageParams) -> List[DiffusionResult]:\n", + " payload = {\n", + " \"fine_tunes\": [ft.to_dict() for ft in params.fine_tunes],\n", + " \"text_prompts\": [tp.to_dict() for tp in params.text_prompts],\n", + " \"samples\": params.samples,\n", + " \"sampler\": params.sampler.value,\n", + " \"steps\": params.steps,\n", + " \"seed\": params.seed,\n", + " \"width\": params.width,\n", + " \"height\": params.height,\n", + " \"cfg_value\": params.cfg_value,\n", + " }\n", + "\n", + " response = requests.post(\n", + " f\"{self.api_host}/v1/generation/{params.engine_id}/text-to-image\",\n", + " json=payload,\n", + " headers={\n", + " \"Authorization\": self.api_key,\n", + " \"Accept\": \"application/json\",\n", + " }\n", + " )\n", + "\n", + " raise_on_non200(response)\n", + "\n", + " return [\n", + " DiffusionResult(base64=item[\"base64\"], seed=item[\"seed\"], finish_reason=item[\"finishReason\"])\n", + " for item in response.json()[\"artifacts\"]\n", + " ]\n", + "\n", + "\n", + "def raise_on_non200(response):\n", + " if 200 <= response.status_code < 300:\n", + " return\n", + " raise Exception(f\"Status code {response.status_code}: {json.dumps(response.json(), indent=4)}\")\n", + "\n", + "\n", + "# Redirect logs to print statements so we can see them in the notebook\n", + "class PrintHandler(logging.Handler):\n", + " def emit(self, record):\n", + " print(self.format(record))\n", + "logging.getLogger().addHandler(PrintHandler())\n", + "logging.getLogger().setLevel(logging.INFO)\n", + "\n", + "# Initialize the fine-tune service\n", + "rest_api = FineTuningRESTWrapper(API_KEY, API_HOST)" + ], + "metadata": { + "id": "Dr38OlbKTb7Q" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "#@title List your existing fine-tunes\n", + "\n", + "fine_tunes = rest_api.list_fine_tunes()\n", + "print(f\"Found {len(fine_tunes)} models\")\n", + "for fine_tune in fine_tunes:\n", + " print(f\" Model {fine_tune.id} {fine_tune.status:<9} {fine_tune.name}\")" + ], + "metadata": { + "id": "ZqIc2d8FAIW0", + "cellView": "form" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "## Add Training Images\n", + "\n", + "For training, we need a dataset of images in a `.zip` file.\n", + "\n", + "Please only upload images that you have the permission to use.\n", + "\n", + "\n", + "### Image Dimensions\n", + "\n", + "- Images **cannot** have any side less than 328px\n", + "- Images **cannot** be larger than 10MB\n", + "\n", + "There is no upper-bound for what we'll accept for an image's dimensions, but any side above 1024px will be scaled down to 1024px, while preserving aspect ratio. For example:\n", + "- `3024x4032` will be scaled down to `768x1024`\n", + "- `1118x1118` will be scaled down to `1024x1024`\n", + "\n", + "\n", + "### Image Quantity\n", + "\n", + "- Datasets **cannot** have fewer than 3 images\n", + "- Datasets **cannot** have more than 64 images\n", + "\n", + "A larger dataset often tends to result in a more accurate fine-tune, but will also take longer to train.\n", + "\n", + "While each mode can accept up to 64 images, we have a few suggestions for a starter dataset based on the mode you are using:\n", + "* `FACE`: 6 or more images.\n", + "* `OBJECT`: 6 - 10 images.\n", + "* `STYLE`: 20 - 30 images." + ], + "metadata": { + "id": "vnAPh8ydc3SG" + } + }, + { + "cell_type": "code", + "source": [ + "#@title Upload ZIP file of images\n", + "training_dir = \"./train\"\n", + "Path(training_dir).mkdir(exist_ok=True)\n", + "try:\n", + " from google.colab import files\n", + "\n", + " upload_res = files.upload()\n", + " extracted_dir = list(upload_res.keys())[0]\n", + " print(f\"Received {extracted_dir}\")\n", + " if not extracted_dir.endswith(\".zip\"):\n", + " raise ValueError(\"Uploaded file must be a zip file\")\n", + "\n", + " zf = ZipFile(io.BytesIO(upload_res[extracted_dir]), \"r\")\n", + " extracted_dir = Path(extracted_dir).stem\n", + " print(f\"Extracting to {extracted_dir}\")\n", + " zf.extractall(extracted_dir)\n", + "\n", + " for root, dirs, files in os.walk(extracted_dir):\n", + " for file in files:\n", + " source_path = os.path.join(root, file)\n", + " target_path = os.path.join(training_dir, file)\n", + "\n", + " # Ignore Mac-specific files\n", + " if 'MACOSX' in source_path or 'DS' in source_path:\n", + " continue\n", + "\n", + " # Move the file to the target directory\n", + " print('Copying', source_path, '==>', target_path)\n", + " shutil.move(source_path, target_path)\n", + "\n", + "\n", + "except ImportError:\n", + " pass\n", + "\n", + "print(f\"Using training images from: {training_dir}\")" + ], + "metadata": { + "id": "YKQXWltHANju" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "## Train a Fine-Tune\n", + "\n", + "Now we're ready to train our fine-tune. Use the parameters below to configure the name and the kind of fine-tune\n", + "\n", + "Please note that the training duration will vary based on:\n", + "- The number of images in your dataset\n", + "- The `training_mode` used\n", + "- The `engine_id` that is being fine-tuned on\n", + "\n", + "The following are some rough estimates for the training duration for each mode based on our recommended dataset sizes:\n", + "\n", + "* `FACE`: 4 - 5 minutes.\n", + "* `OBJECT`: 5 - 10 minutes.\n", + "* `STYLE`: 20 - 30 minutes." + ], + "metadata": { + "id": "UXAn59XibFv5" + } + }, + { + "cell_type": "code", + "source": [ + "#@title Begin Training\n", + "fine_tune_name = \"my dog spot\" #@param {type:\"string\"}\n", + "#@markdown > Requirements:
  • Must be unique (only across your account, not globally)
  • Must be between 3 and 64 characters (inclusive)
  • Must only contain letters, numbers, spaces, or hyphens
\n", + "training_mode = \"OBJECT\" #@param [\"FACE\", \"STYLE\", \"OBJECT\"] {type:\"string\"}\n", + "#@markdown > Determines the kind of fine-tune you're creating:
  • FACE - a fine-tune on faces; expects pictures containing a face; automatically crops and centers on the face detected in the input photos.
  • OBJECT - a fine-tune on a particular object (e.g. a bottle); segments out the object using the `object_prompt` below
  • STYLE - a fine-tune on a particular style (e.g. satellite photos of earth); crops the images and filters for image quality.
\n", + "object_prompt = \"dog\" #@param {type:\"string\"}\n", + "#@markdown > Used for segmenting out your subject when the `training_mode` is `OBJECT`. (i.e. if you want to fine tune on a cat, put `cat` - for a bottle of liquor, use `bottle`. In general, it's best to use the most general word you can to describe your object.)\n", + "\n", + "# Gather training images\n", + "images = []\n", + "for filename in os.listdir(training_dir):\n", + " if os.path.splitext(filename)[1].lower() in ['.png', '.jpg', '.jpeg', '.heic']:\n", + " images.append(os.path.join(training_dir, filename))\n", + "\n", + "# Create the fine-tune\n", + "fine_tune = rest_api.create_fine_tune(\n", + " name=fine_tune_name,\n", + " images=images,\n", + " mode=training_mode,\n", + " object_prompt=object_prompt if training_mode == \"OBJECT\" else None,\n", + " engine_id=ENGINE_ID,\n", + ")\n", + "\n", + "print()\n", + "print(fine_tune)" + ], + "metadata": { + "id": "DMK3yOrGDLw8" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "#@title Wait For Training to Finish\n", + "start_time = time.time()\n", + "while fine_tune.status != \"COMPLETED\" and fine_tune.status != \"FAILED\":\n", + " fine_tune = rest_api.get_fine_tune(fine_tune.id)\n", + " elapsed = time.time() - start_time\n", + " clear_output(wait=True)\n", + " print(f\"Training '{fine_tune.name}' ({fine_tune.id}) status: {fine_tune.status} for {elapsed:.0f} seconds\")\n", + " time.sleep(10)\n", + "\n", + "clear_output(wait=True)\n", + "status_message = \"completed\" if fine_tune.status == \"COMPLETED\" else \"failed\"\n", + "print(f\"Training '{fine_tune.name}' ({fine_tune.id}) {status_message} after {elapsed:.0f} seconds\")" + ], + "metadata": { + "id": "8-iAUX_ODwU6" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "#@title (Optional) Retrain if Training Failed\n", + "if fine_tune.status == \"FAILED\":\n", + " print(f\"Training failed, due to {fine_tune.failure_reason}. Retraining...\")\n", + " fine_tune = rest_api.retrain_fine_tune(fine_tune.id)" + ], + "metadata": { + "id": "eZaWJT_CDyrb" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "## Use your Fine-Tune\n", + "\n", + "Time to diffuse! The example below uses a single fine-tune, but using multiple fine-tunes is where this process really shines. While this Colab doesn't directly support diffusing with multiple fine-tunes, you can still try it out by commenting out the" + ], + "metadata": { + "id": "vaBl4zuQfO20" + } + }, + { + "cell_type": "code", + "source": [ + "#@title Generate Images\n", + "\n", + "prompt_token=\"$my-dog\" #@param {type:\"string\"}\n", + "#@markdown > This token is an alias for your fine-tune, allowing you to reference your fine-tune directly in your prompt. Each fine-tune you want to diffuse with must provide a unique alias.

For example, if your token was `$my-dog` you might use a prompt like: `a picture of $my-dog` or `$my-dog chasing a rabbit`.

If you have more than one fine-tune you can combine them! Given some fine-tune of film noir images you could use a prompt like `$my-dog in the style of $film-noir`.\n", + "prompt=\"a photo of $my-dog\" #@param {type:\"string\"}\n", + "#@markdown > The prompt to diffuse with. Must contain the `prompt_token` at least once.\n", + "dimensions=\"1024x1024\" #@param ['1024x1024', '1152x896', '1216x832', '1344x768', '1536x640', '640x1536', '768x1344', '832x1216', '896x1152']\n", + "#@markdown > The dimensions of the image to generate (width x height).\n", + "samples=4 #@param {type:\"slider\", min:1, max:10, step:1}\n", + "#@markdown > The number of images to generate. Requesting a large number of images may negatively response time.\n", + "steps=32 #@param {type:\"slider\", min:30, max:60, step:1}\n", + "#@markdown > The number of iterations or stages a diffusion model goes through in the process of generating an image from a given text prompt. Lower steps will generate more quickly, but if steps are lowered too much, image quality will suffer. Images with higher steps take longer to generate, but often give more detailed results.\n", + "cfg_value=7 #@param {type:\"slider\", min:0, max:35, step:1}\n", + "#@markdown > CFG (Classifier Free Guidance) scale determines how strictly the diffusion process adheres to the prompt text (higher values keep your image closer to your prompt).\n", + "seed=0 #@param {type:\"number\"}\n", + "#@markdown > The noise seed to use during diffusion. Using `0` means a random seed will be generated for each image. If you provide a non-zero value, images will be far less random.\n", + "\n", + "params = TextToImageParams(\n", + " fine_tunes=[\n", + " DiffusionFineTune(\n", + " id=fine_tune.id,\n", + " token=prompt_token,\n", + " # Uncomment the following to provide a weight for the fine-tune\n", + " # weight=1.0\n", + " ),\n", + "\n", + " # Uncomment the following to use multiple fine-tunes at once\n", + " # DiffusionFineTune(\n", + " # id=\"\",\n", + " # token=\"\",\n", + " # # weight=1.0\n", + " # ),\n", + " ],\n", + " text_prompts=[\n", + " TextPrompt(\n", + " text=prompt,\n", + " # weight=1.0\n", + " ),\n", + " ],\n", + " engine_id=ENGINE_ID,\n", + " samples=samples,\n", + " steps=steps,\n", + " seed=0,\n", + " cfg_value=cfg_value,\n", + " width=int(dimensions.split(\"x\")[0]),\n", + " height=int(dimensions.split(\"x\")[1]),\n", + " sampler=Sampler.K_DPMPP_2S_ANCESTRAL\n", + ")\n", + "\n", + "start_time = time.time()\n", + "images = rest_api.text_to_image(params)\n", + "\n", + "elapsed = time.time() - start_time\n", + "print(f\"Diffusion completed in {elapsed:.0f} seconds!\")\n", + "print(f\"{len(images)} result{'s' if len(images) > 1 else ''} will be displayed below momentarily (depending on the speed of Colab).\\n\")\n", + "\n", + "for image in images:\n", + " display(Image.open(io.BytesIO(base64.b64decode(image.base64))))" + ], + "metadata": { + "id": "sy1HcYqLEBXu" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "#@title (Optional) Download Images\n", + "from google.colab import files\n", + "\n", + "if not os.path.exists(\"./out\"):\n", + " os.makedirs(\"./out\")\n", + "\n", + "for index, image in enumerate(images):\n", + " with open(f'./out/txt2img_{image.seed}_{index}.png', \"wb\") as f:\n", + " f.write(base64.b64decode(image.base64))\n", + " files.download(f.name)" + ], + "metadata": { + "id": "7P-bBnScfaQQ" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "#@title (Optional) Rename Fine-Tune\n", + "\n", + "name = \"\" #@param {type:\"string\"}\n", + "rest_api.rename_fine_tune(fine_tune.id, name=name)" + ], + "metadata": { + "id": "tg2gkvlDn4Dm" + }, + "execution_count": null, + "outputs": [] + } + ] +} \ No newline at end of file