Skip to content

Commit

Permalink
Run code formatting
Browse files Browse the repository at this point in the history
Fix isort generated folder

Signed-off-by: Alex-Brooks <[email protected]>
  • Loading branch information
alex-jw-brooks committed Jun 28, 2024
1 parent deb93c6 commit 08baaf4
Show file tree
Hide file tree
Showing 6 changed files with 53 additions and 19 deletions.
2 changes: 1 addition & 1 deletion .isort.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,4 @@ import_heading_thirdparty=Third Party
import_heading_firstparty=First Party
import_heading_localfolder=Local
known_firstparty=alog,aconfig,caikit,import_tracker
known_localfolder=caikit_computer_vision,tests
known_localfolder=caikit_computer_vision,tests,generated
6 changes: 3 additions & 3 deletions examples/runtime/run_train_and_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@

# pylint: disable=no-name-in-module,import-error
try:
# Third Party
# Local
from generated import (
computervisionservice_pb2_grpc,
computervisiontrainingservice_pb2_grpc,
Expand All @@ -41,7 +41,7 @@
# The location of these imported message types depends on the version of Caikit
# that we are using.
try:
# Third Party
# Local
from generated.caikit_data_model.caikit_computer_vision import (
flatchannel_pb2,
flatimage_pb2,
Expand All @@ -58,7 +58,7 @@
IS_LEGACY = False
except ModuleNotFoundError:
# older versions of Caikit / py to proto create a flat proto structure
# Third Party
# Local
from generated import objectdetectiontaskrequest_pb2
from generated import (
objectdetectiontasktransformersobjectdetectortrainrequest_pb2 as odt_request_pb2,
Expand Down
32 changes: 28 additions & 4 deletions examples/text_to_image/README.md
Original file line number Diff line number Diff line change
@@ -1,82 +1,101 @@
# Text To Image (SDXL)
# Getting Started With Text To Image

This directory provides guidance for running text to image inference for text to image and a few useful scripts for getting started.

## Task and Module Overview

The text to image task has only one required parameter, the input text, and produces a `caikit_computer_vision.data_model.CaptionedImage` in response, which wraps the provided input text, as well as the generated image.

Currently there are two modules for text to image.:

- `caikit_computer_vision.modules.text_to_image.TTIStub` - A simple stub which produces a blue image of the request height and width at inference. This module is purely used for testing purposes.

- `caikit_computer_vision.modules.text_to_image.SDXL` - A module implementing text to image via SDXL.

This document will help you get started with both at the library & runtime level, ending with a sample gRPC client that can be usde to hit models running in a Caikit runtime container.

## Building the Environment

The easiest way to get started is to build a virtual environment in the root directory of this repo. Make sure the root of this project is on the `PYTHONPATH` so that `caikit_computer_vision` is findable.

To install the project:

```bash
python3 -m venv venv
source venv/bin/activate
pip install .
```

Note that if you prefer running in Docker, you can build an image as you normally would, and mount things into a running container:

```bash
docker build -t caikit-computer-vision:latest .
```

## Creating the Models

For the remainder of this demo, commands are intended to be run from this directory. First, we will be creating our models & runtime config in a directory named `caikit`, which is convenient for running locally or mounting into a container.

Copy the runtime config from the root of this project into the `caikit` directory.

```bash
mkdir -p caikit/models
cp ../../runtime_config.yaml caikit/runtime_config.yaml
```

Next, create your models.

```bash
python create_tti_models.py
```

This will create two models.

1. The stub model, at `caikit/models/stub_model`
2. The SDXL turbo model, at `caikit/models/sdxl_turbo_model`

Note that the names of these directories will be their model IDs in caikit runtime.

## Running Local Inference / API Overview

The text to image API is simple.

### Stub Module

For the stub module, we take an input prompt, a height, and a width, and create a blue image of the specified height and width.

```python
run(
inputs: str,
height: int,
inputs: str,
height: int,
width: int
) -> CaptionedImage:
```

Example using the stub model created from above:

```python
>>> import caikit_computer_vision, caikit
>>> stub_model = caikit.load("caikit/models/stub_model")
>>> res = stub_model.run("This is a text", height=512, width=512)
```

The resulting object holds the provided input text under `.caption`:

```python
>>> res.caption
'This is a text'
```

And the image bytes stored as PNG under `.output.image_data`

```python
>>> res.output.image_data
b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x02\x00\x00\x00\x02\x00 ...
```

Note that the `output` object is a `Caikit` image backed by PIL. If you need a handle to it, you can call `as_pil()` to get handle to the PIL object as shown below.

```
>>> pil_im = res.output.as_pil()
>>> type(pil_im)
Expand All @@ -86,7 +105,9 @@ Note that the `output` object is a `Caikit` image backed by PIL. If you need a h
Grabbing a handle to the PIL image and then `.save()` on the result is the easiest way to save the image to disk.

### SDXL Module

The SDXL module is signature to the stub, with some additional options.

```python
run(
inputs: str,
Expand All @@ -109,16 +130,18 @@ The `image_format` arg follows the same conventions as PIL and controls the form
>>> res = stub_model.run("A golden retriever puppy sitting in a grassy field", height=512, width=512, num_steps=2, image_format="jpeg")
```


## Inference Through Runtime

To write a client, you'll need to export the proto files to compile. To do so, run `python export_protos.py`; this will use the runtime file you had previously copied to create a new directory called `protos`, containing the exported data model / task protos from caikit runtime.

Then to compile them, you can do something like the following; note that you may need to `pip install grpcio-tools` if it's not present in your environment, since it's not a dependency of `caikit_computer_vision`:

```bash
python -m grpc_tools.protoc -I protos --python_out=generated --grpc_python_out=generated protos/*.proto
```

In general, you will want to run Caikit Runtime in a Docker container. The easiest way to do this is to mount the `caikit` directory with your models into the container as shown below.

```bash
docker run -e CONFIG_FILES=/caikit/runtime_config.yaml \
-v $PWD/caikit/:/caikit \
Expand All @@ -129,5 +152,6 @@ docker run -e CONFIG_FILES=/caikit/runtime_config.yaml \
Then, you can hit it with a gRPC client using your compiled protobufs. A full example of inference via gRPC client calling both models can be found in `sample_client.py`.

Running `python sample_client.py` should produce two images.

- `stub_response_image.png` - blue image generated from the stub module
- `turbo_response_image.png` - picture of a golden retriever in a field generated by SDXL turbo
14 changes: 12 additions & 2 deletions examples/text_to_image/create_tti_models.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
"""Creates and exports SDXL Turbo as a caikit module.
"""
from caikit_computer_vision.modules.text_to_image import TTIStub
# Standard
import os

# Local
from caikit_computer_vision.modules.text_to_image import TTIStub

SCRIPT_DIR = os.path.dirname(__file__)
MODELS_DIR = os.path.join(SCRIPT_DIR, "caikit", "models")
STUB_MODEL_PATH = os.path.join(MODELS_DIR, "stub_model")
Expand All @@ -15,13 +18,17 @@
model.save(STUB_MODEL_PATH)


# Third Party
### Make the model for SDXL turbo
import diffusers

# Local
from caikit_computer_vision.modules.text_to_image import SDXL

### Download the model for SDXL turbo...
sdxl_model = SDXL.bootstrap("stabilityai/sdxl-turbo")
sdxl_model.save(SDXL_TURBO_MODEL_PATH)
# Standard
# There appears to be a bug in the way that sharded safetensors are reloaded into the
# pipeline from diffusers, and there ALSO appears to be a bug where passing the max
# safetensor shard size to diffusers on a pipeline doesn't work as exoected.
Expand All @@ -30,11 +37,14 @@
# the sharded u-net, and reexport it as one file. By default the
# max shard size if 10GB, and the turbo unit is barely larger than 10.
from shutil import rmtree

unet_path = os.path.join(SDXL_TURBO_MODEL_PATH, "sdxl_model", "unet")
try:
diffusers.UNet2DConditionModel.from_pretrained(unet_path)
except RuntimeError:
print("Unable to reload turbo u-net due to sharding issues; reexporting as single file")
print(
"Unable to reload turbo u-net due to sharding issues; reexporting as single file"
)
rmtree(unet_path)
sdxl_model.pipeline.unet.save_pretrained(unet_path, max_shard_size="12GB")

Expand Down
8 changes: 4 additions & 4 deletions examples/text_to_image/export_protos.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,9 @@
from caikit.runtime.dump_services import dump_grpc_services
import caikit

SCRIPT_DIR=os.path.dirname(__file__)
PROTO_EXPORT_DIR=os.path.join(SCRIPT_DIR, "protos")
RUNTIME_CONFIG_PATH=os.path.join(SCRIPT_DIR, "caikit", "runtime_config.yaml")
SCRIPT_DIR = os.path.dirname(__file__)
PROTO_EXPORT_DIR = os.path.join(SCRIPT_DIR, "protos")
RUNTIME_CONFIG_PATH = os.path.join(SCRIPT_DIR, "caikit", "runtime_config.yaml")

if os.path.isdir(PROTO_EXPORT_DIR):
rmtree(PROTO_EXPORT_DIR)
Expand All @@ -27,4 +27,4 @@
k: v for k, v in grpc_service_dumper_kwargs.items() if k in expected_grpc_params
}
dump_grpc_services(**grpc_service_dumper_kwargs)
# NOTE: If you need an http client for inference, use `dump_http_services` from caikit instead.
# NOTE: If you need an http client for inference, use `dump_http_services` from caikit instead.
10 changes: 5 additions & 5 deletions examples/text_to_image/sample_client.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
# Standard
import io

from generated import (
computervisionservice_pb2_grpc,
)
from generated.ccv import texttoimagetaskrequest_pb2

# Third Party
from PIL import Image
import grpc

# Local
from generated import computervisionservice_pb2_grpc
from generated.ccv import texttoimagetaskrequest_pb2

# Setup the client
port = 8085
Expand Down

0 comments on commit 08baaf4

Please sign in to comment.