Commit
Cleanup models after conversion (#2558)
CVS-157654
aleksandr-mokrov authored Nov 28, 2024
1 parent e242c65 commit 3dc041d
Showing 3 changed files with 90 additions and 66 deletions.
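The common thread across all three files: once a PyTorch module has been converted to OpenVINO IR (or quantized), the commit deletes the in-memory original and forces garbage collection so peak host memory stays low. A minimal sketch of that pattern, assuming a diffusers-style pipeline object (the helper name and IR path are illustrative; ov.convert_model and ov.save_model are real OpenVINO calls):

import gc

import openvino as ov


def convert_unet_and_free(pipeline, example_input, xml_path="unet/openvino_model.xml"):
    # Export the PyTorch submodule to OpenVINO IR and serialize it to disk.
    ov_model = ov.convert_model(pipeline.unet, example_input=example_input)
    ov.save_model(ov_model, xml_path)
    # Drop the remaining references to the original weights so Python can
    # reclaim the memory before the next model is converted.
    del pipeline.unet
    del ov_model
    gc.collect()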
69 changes: 36 additions & 33 deletions notebooks/catvton/catvton.ipynb
@@ -137,6 +137,7 @@
"from ov_catvton_helper import download_models, convert_pipeline_models, convert_automasker_models\n",
"\n",
"pipeline, mask_processor, automasker = download_models()\n",
"vae_scaling_factor = pipeline.vae.config.scaling_factor\n",
"convert_pipeline_models(pipeline)\n",
"convert_automasker_models(automasker)"
]
@@ -181,7 +182,7 @@
},
{
"cell_type": "code",
"execution_count": 22,
"execution_count": null,
"id": "8612d4be-e0cf-4249-881e-5270cc33ef28",
"metadata": {},
"outputs": [],
@@ -197,7 +198,7 @@
" SCHP_PROCESSOR_LIP,\n",
")\n",
"\n",
"pipeline = get_compiled_pipeline(pipeline, core, device, VAE_ENCODER_PATH, VAE_DECODER_PATH, UNET_PATH)\n",
"pipeline = get_compiled_pipeline(pipeline, core, device, VAE_ENCODER_PATH, VAE_DECODER_PATH, UNET_PATH, vae_scaling_factor)\n",
"automasker = get_compiled_automasker(automasker, core, device, DENSEPOSE_PROCESSOR_PATH, SCHP_PROCESSOR_ATR, SCHP_PROCESSOR_LIP)"
]
},
@@ -239,13 +240,12 @@
},
{
"cell_type": "code",
"execution_count": 8,
"execution_count": null,
"id": "1b307bdd",
"metadata": {},
"outputs": [],
"source": [
"optimized_pipe = None\n",
"optimized_automasker = None\n",
"is_optimized_pipe_available = False\n",
"\n",
"# Fetch skip_kernel_extension module\n",
"r = requests.get(\n",
@@ -309,16 +309,23 @@
},
{
"cell_type": "code",
"execution_count": 10,
"execution_count": null,
"id": "f64b96e4",
"metadata": {},
"outputs": [],
"source": [
"%%skip not $to_quantize.value\n",
"\n",
"import gc\n",
"import nncf\n",
"from ov_catvton_helper import UNET_PATH\n",
"\n",
"# cleanup before quantization to free memory\n",
"del pipeline\n",
"del automasker\n",
"gc.collect()\n",
"\n",
"\n",
"if not UNET_INT8_PATH.exists():\n",
" unet = core.read_model(UNET_PATH)\n",
" quantized_model = nncf.quantize(\n",
@@ -327,7 +334,9 @@
" subset_size=subset_size,\n",
" model_type=nncf.ModelType.TRANSFORMER,\n",
" )\n",
" ov.save_model(quantized_model, UNET_INT8_PATH)"
" ov.save_model(quantized_model, UNET_INT8_PATH)\n",
" del quantized_model\n",
" gc.collect()"
]
},
{
@@ -352,29 +361,9 @@
"\n",
"from catvton_quantization_helper import compress_models\n",
"\n",
"compress_models(core)"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "e9c41725",
"metadata": {},
"outputs": [],
"source": [
"%%skip not $to_quantize.value\n",
"\n",
"from catvton_quantization_helper import (\n",
" VAE_ENCODER_INT4_PATH,\n",
" VAE_DECODER_INT4_PATH,\n",
" DENSEPOSE_PROCESSOR_INT4_PATH,\n",
" SCHP_PROCESSOR_ATR_INT4,\n",
" SCHP_PROCESSOR_LIP_INT4,\n",
")\n",
"compress_models(core)\n",
"\n",
"optimized_pipe, _, optimized_automasker = download_models()\n",
"optimized_pipe = get_compiled_pipeline(optimized_pipe, core, device, VAE_ENCODER_INT4_PATH, VAE_DECODER_INT4_PATH, UNET_INT8_PATH)\n",
"optimized_automasker = get_compiled_automasker(optimized_automasker, core, device, DENSEPOSE_PROCESSOR_INT4_PATH, SCHP_PROCESSOR_ATR_INT4, SCHP_PROCESSOR_LIP_INT4)"
"is_optimized_pipe_available = True"
]
},
{
@@ -432,7 +421,7 @@
"source": [
"from ov_catvton_helper import get_pipeline_selection_option\n",
"\n",
"use_quantized_models = get_pipeline_selection_option(optimized_pipe)\n",
"use_quantized_models = get_pipeline_selection_option(is_optimized_pipe_available)\n",
"\n",
"use_quantized_models"
]
@@ -448,11 +437,25 @@
"source": [
"from gradio_helper import make_demo\n",
"\n",
"pipe = optimized_pipe if use_quantized_models.value else pipeline\n",
"masker = optimized_automasker if use_quantized_models.value else automasker\n",
"from catvton_quantization_helper import (\n",
" VAE_ENCODER_INT4_PATH,\n",
" VAE_DECODER_INT4_PATH,\n",
" DENSEPOSE_PROCESSOR_INT4_PATH,\n",
" SCHP_PROCESSOR_ATR_INT4,\n",
" SCHP_PROCESSOR_LIP_INT4,\n",
" UNET_INT8_PATH,\n",
")\n",
"\n",
"pipeline, mask_processor, automasker = download_models()\n",
"if use_quantized_models.value:\n",
" pipeline = get_compiled_pipeline(pipeline, core, device, VAE_ENCODER_INT4_PATH, VAE_DECODER_INT4_PATH, UNET_INT8_PATH, vae_scaling_factor)\n",
" automasker = get_compiled_automasker(automasker, core, device, DENSEPOSE_PROCESSOR_INT4_PATH, SCHP_PROCESSOR_ATR_INT4, SCHP_PROCESSOR_LIP_INT4)\n",
"else:\n",
" pipeline = get_compiled_pipeline(pipeline, core, device, VAE_ENCODER_PATH, VAE_DECODER_PATH, UNET_PATH, vae_scaling_factor)\n",
" automasker = get_compiled_automasker(automasker, core, device, DENSEPOSE_PROCESSOR_PATH, SCHP_PROCESSOR_ATR, SCHP_PROCESSOR_LIP)\n",
"\n",
"output_dir = \"output\"\n",
"demo = make_demo(pipe, mask_processor, masker, output_dir)\n",
"demo = make_demo(pipeline, mask_processor, automasker, output_dir)\n",
"try:\n",
" demo.launch(debug=True)\n",
"except Exception:\n",
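For context, the INT8 step the notebook runs after this cleanup follows the standard NNCF post-training quantization flow. A hedged sketch, assuming the calibration inputs were already collected into a list (collected_inputs and both IR paths are assumptions; nncf.Dataset, nncf.quantize, nncf.ModelType.TRANSFORMER, core.read_model, and ov.save_model are real API):

import gc

import nncf
import openvino as ov

core = ov.Core()

unet = core.read_model("models/unet.xml")  # hypothetical FP16 IR path
calibration_data = nncf.Dataset(collected_inputs)  # UNet inputs captured during pipeline runs
quantized_model = nncf.quantize(
    unet,
    calibration_data,
    subset_size=len(collected_inputs),
    model_type=nncf.ModelType.TRANSFORMER,  # preserves accuracy-sensitive transformer ops
)
ov.save_model(quantized_model, "models/unet_int8.xml")

# Same cleanup discipline as the rest of the commit: release both graphs.
del quantized_model
del unet
gc.collect()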
71 changes: 42 additions & 29 deletions notebooks/catvton/catvton_quantization_helper.py
@@ -1,12 +1,14 @@
from typing import Any, List
- import torch
- import nncf
from pathlib import Path
+ import pickle

from tqdm.notebook import tqdm
from transformers import set_seed
import numpy as np
import openvino as ov
from PIL import Image
+ import torch
+ import nncf

from ov_catvton_helper import (
MODEL_DIR,
@@ -49,34 +51,45 @@ def __call__(self, *args, **kwargs):


def collect_calibration_data(pipeline, automasker, mask_processor, dataset, subset_size):
-     original_unet = pipeline.unet.unet
-     pipeline.unet.unet = CompiledModelDecorator(original_unet)
- 
-     calibration_dataset = []
-     pbar = tqdm(total=subset_size, desc="Collecting calibration dataset")
-     for data in dataset:
-         person_image_path, cloth_image_path = data
-         person_image = Image.open(person_image_path)
-         cloth_image = Image.open(cloth_image_path)
-         cloth_type = "upper" if "upper" in person_image_path.as_posix() else "overall"
-         mask = automasker(person_image, cloth_type)["mask"]
-         mask = mask_processor.blur(mask, blur_factor=9)
- 
-         pipeline(
-             image=person_image,
-             condition_image=cloth_image,
-             mask=mask,
-             num_inference_steps=NUM_INFERENCE_STEPS,
-             guidance_scale=GUIDANCE_SCALE,
-             generator=GENERATOR,
-         )
-         collected_subset_size = len(pipeline.unet.unet.data_cache)
-         pbar.update(NUM_INFERENCE_STEPS)
-         if collected_subset_size >= subset_size:
-             break
+     calibration_dataset_filepath = Path("calibration_data") / f"{subset_size}.pkl"
+     calibration_dataset_filepath.parent.mkdir(exist_ok=True, parents=True)
+ 
+     if not calibration_dataset_filepath.exists():
+         original_unet = pipeline.unet.unet
+         pipeline.unet.unet = CompiledModelDecorator(original_unet)
+ 
+         calibration_dataset = []
+         pbar = tqdm(total=subset_size, desc="Collecting calibration dataset")
+         for data in dataset:
+             person_image_path, cloth_image_path = data
+             person_image = Image.open(person_image_path)
+             cloth_image = Image.open(cloth_image_path)
+             cloth_type = "upper" if "upper" in person_image_path.as_posix() else "overall"
+             mask = automasker(person_image, cloth_type)["mask"]
+             mask = mask_processor.blur(mask, blur_factor=9)
+ 
+             pipeline(
+                 image=person_image,
+                 condition_image=cloth_image,
+                 mask=mask,
+                 num_inference_steps=NUM_INFERENCE_STEPS,
+                 guidance_scale=GUIDANCE_SCALE,
+                 generator=GENERATOR,
+             )
+             collected_subset_size = len(pipeline.unet.unet.data_cache)
+             pbar.update(NUM_INFERENCE_STEPS)
+             if collected_subset_size >= subset_size:
+                 break
+ 
+         calibration_dataset = pipeline.unet.unet.data_cache
+         pipeline.unet.unet = original_unet
+ 
+         with open(calibration_dataset_filepath, "wb") as f:
+             pickle.dump(calibration_dataset, f)
+     else:
+         with open(calibration_dataset_filepath, "rb") as f:
+             calibration_dataset = pickle.load(f)
 
-     calibration_dataset = pipeline.unet.unet.data_cache
-     pipeline.unet.unet = original_unet
return calibration_dataset


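The reworked collect_calibration_data amounts to a compute-or-load cache: the expensive diffusion runs happen once, and the captured UNet inputs are pickled for reuse. The same pattern in isolation (the cached helper is hypothetical, not part of the repo):

import pickle
from pathlib import Path


def cached(path: Path, compute):
    # Load a previously pickled result if present; otherwise compute and store it.
    if path.exists():
        with open(path, "rb") as f:
            return pickle.load(f)
    result = compute()
    path.parent.mkdir(parents=True, exist_ok=True)
    with open(path, "wb") as f:
        pickle.dump(result, f)
    return result


# Usage mirroring the helper above:
# data = cached(Path("calibration_data") / f"{subset_size}.pkl", collect_fn)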
16 changes: 12 additions & 4 deletions notebooks/catvton/ov_catvton_helper.py
@@ -1,3 +1,4 @@
+ import gc
import os
from collections import namedtuple
from pathlib import Path
@@ -93,13 +94,16 @@ def download_models():
def convert_pipeline_models(pipeline):
convert(VaeEncoder(pipeline.vae), VAE_ENCODER_PATH, torch.zeros(1, 3, 1024, 768))
convert(VaeDecoder(pipeline.vae), VAE_DECODER_PATH, torch.zeros(1, 4, 128, 96))
+     del pipeline.vae

inpainting_latent_model_input = torch.zeros(2, 9, 256, 96)
timestep = torch.tensor(0)
encoder_hidden_states = torch.zeros(2, 1, 768)
example_input = (inpainting_latent_model_input, timestep, encoder_hidden_states)

convert(UNetWrapper(pipeline.unet), UNET_PATH, example_input)
+     del pipeline.unet
+     gc.collect()


def convert_automasker_models(automasker):
@@ -115,19 +119,23 @@ def inference(model, inputs):
traceable_model = TracingAdapter(automasker.densepose_processor.predictor.model, tracing_input, inference)

convert(traceable_model, DENSEPOSE_PROCESSOR_PATH, tracing_input[0]["image"])
+     del automasker.densepose_processor.predictor.model

convert(automasker.schp_processor_atr.model, SCHP_PROCESSOR_ATR, torch.rand([1, 3, 512, 512], dtype=torch.float32))
convert(automasker.schp_processor_lip.model, SCHP_PROCESSOR_LIP, torch.rand([1, 3, 473, 473], dtype=torch.float32))
+     del automasker.schp_processor_atr.model
+     del automasker.schp_processor_lip.model
+     gc.collect()


class VAEWrapper(torch.nn.Module):
-     def __init__(self, vae_encoder, vae_decoder, config):
+     def __init__(self, vae_encoder, vae_decoder, scaling_factor):
super().__init__()
self.vae_enocder = vae_encoder
self.vae_decoder = vae_decoder
self.device = "cpu"
self.dtype = torch.float32
-         self.config = config
+         self.config = namedtuple("VAEConfig", ["scaling_factor"])(scaling_factor)

def encode(self, pixel_values):
ov_outputs = self.vae_enocder(pixel_values).to_dict()
@@ -202,12 +210,12 @@ def forward(self, image):
return torch.from_numpy(outputs[0])


- def get_compiled_pipeline(pipeline, core, device, vae_encoder_path, vae_decoder_path, unet_path):
+ def get_compiled_pipeline(pipeline, core, device, vae_encoder_path, vae_decoder_path, unet_path, vae_scaling_factor):
compiled_unet = core.compile_model(unet_path, device.value)
compiled_vae_encoder = core.compile_model(vae_encoder_path, device.value)
compiled_vae_decoder = core.compile_model(vae_decoder_path, device.value)

-     pipeline.vae = VAEWrapper(compiled_vae_encoder, compiled_vae_decoder, pipeline.vae.config)
+     pipeline.vae = VAEWrapper(compiled_vae_encoder, compiled_vae_decoder, vae_scaling_factor)
pipeline.unet = ConvUnetWrapper(compiled_unet)

return pipeline
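The namedtuple shim above lets VAEWrapper drop the full diffusers VAE (and its config object) while preserving the vae.config.scaling_factor attribute access the pipeline relies on. A small illustration (the 0.18215 value is the customary Stable Diffusion VAE factor, used here only as an example):

from collections import namedtuple

VAEConfig = namedtuple("VAEConfig", ["scaling_factor"])

config = VAEConfig(0.18215)  # illustrative value; the real one comes from pipeline.vae.config
assert config.scaling_factor == 0.18215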
