Cleanup models after conversion
aleksandr-mokrov committed Nov 25, 2024
1 parent 9af6200 commit 31479d1
Showing 2 changed files with 15 additions and 6 deletions.
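The changes below follow one convert-then-free pattern: as soon as a submodule has been exported to OpenVINO IR, the PyTorch copy is deleted and the garbage collector is invoked, so the original weights do not stay in host memory next to the converted models. A minimal, self-contained sketch of that pattern (the toy model, example input, and "model.xml" path are illustrative, not taken from the repository):

import gc

import openvino as ov
import torch

# Hypothetical stand-ins: any traceable PyTorch module and a matching example input.
pt_model = torch.nn.Linear(16, 16)
example_input = torch.zeros(1, 16)

ov_model = ov.convert_model(pt_model, example_input=example_input)
ov.save_model(ov_model, "model.xml")  # IR (.xml + .bin) is now on disk

# The PyTorch weights are no longer needed; drop the references and force a
# collection so host memory is reclaimed before the IR models are compiled.
del pt_model, ov_model
gc.collect()

compiled_model = ov.Core().compile_model("model.xml", "CPU")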
5 changes: 3 additions & 2 deletions notebooks/catvton/catvton.ipynb
@@ -137,6 +137,7 @@
"from ov_catvton_helper import download_models, convert_pipeline_models, convert_automasker_models\n",
"\n",
"pipeline, mask_processor, automasker = download_models()\n",
+"vae_scaling_factor = pipeline.vae.config.scaling_factor\n",
"convert_pipeline_models(pipeline)\n",
"convert_automasker_models(automasker)"
]
@@ -181,7 +182,7 @@
},
{
"cell_type": "code",
-"execution_count": 22,
+"execution_count": null,
"id": "8612d4be-e0cf-4249-881e-5270cc33ef28",
"metadata": {},
"outputs": [],
@@ -197,7 +198,7 @@
" SCHP_PROCESSOR_LIP,\n",
")\n",
"\n",
-"pipeline = get_compiled_pipeline(pipeline, core, device, VAE_ENCODER_PATH, VAE_DECODER_PATH, UNET_PATH)\n",
+"pipeline = get_compiled_pipeline(pipeline, core, device, VAE_ENCODER_PATH, VAE_DECODER_PATH, UNET_PATH, vae_scaling_factor)\n",
"automasker = get_compiled_automasker(automasker, core, device, DENSEPOSE_PROCESSOR_PATH, SCHP_PROCESSOR_ATR, SCHP_PROCESSOR_LIP)"
]
},
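In the notebook cell above, the order of the new lines matters: vae_scaling_factor is read from pipeline.vae.config before convert_pipeline_models(pipeline) runs, because that helper now deletes pipeline.vae once the IR files exist, and the captured value is later handed to get_compiled_pipeline. A small, hypothetical illustration of the constraint (stand-in objects, not the real CatVTON pipeline):

from types import SimpleNamespace

# Stand-ins for the real pipeline objects; 0.18215 is just an example value
# (the common Stable Diffusion default), not read from the repository.
pipeline = SimpleNamespace(vae=SimpleNamespace(config=SimpleNamespace(scaling_factor=0.18215)))

vae_scaling_factor = pipeline.vae.config.scaling_factor  # capture first
del pipeline.vae                                         # what convert_pipeline_models() now does

# pipeline.vae.config.scaling_factor would now raise AttributeError,
# so the value has to be captured before the conversion step.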
16 changes: 12 additions & 4 deletions notebooks/catvton/ov_catvton_helper.py
@@ -1,3 +1,4 @@
+import gc
import os
from collections import namedtuple
from pathlib import Path
@@ -93,13 +94,16 @@ def download_models():
def convert_pipeline_models(pipeline):
convert(VaeEncoder(pipeline.vae), VAE_ENCODER_PATH, torch.zeros(1, 3, 1024, 768))
convert(VaeDecoder(pipeline.vae), VAE_DECODER_PATH, torch.zeros(1, 4, 128, 96))
+del pipeline.vae

inpainting_latent_model_input = torch.zeros(2, 9, 256, 96)
timestep = torch.tensor(0)
encoder_hidden_states = torch.zeros(2, 1, 768)
example_input = (inpainting_latent_model_input, timestep, encoder_hidden_states)

convert(UNetWrapper(pipeline.unet), UNET_PATH, example_input)
+del pipeline.unet
+gc.collect()


def convert_automasker_models(automasker):
@@ -115,19 +119,23 @@ def inference(model, inputs):
traceable_model = TracingAdapter(automasker.densepose_processor.predictor.model, tracing_input, inference)

convert(traceable_model, DENSEPOSE_PROCESSOR_PATH, tracing_input[0]["image"])
+del automasker.densepose_processor.predictor.model

convert(automasker.schp_processor_atr.model, SCHP_PROCESSOR_ATR, torch.rand([1, 3, 512, 512], dtype=torch.float32))
convert(automasker.schp_processor_lip.model, SCHP_PROCESSOR_LIP, torch.rand([1, 3, 473, 473], dtype=torch.float32))
+del automasker.schp_processor_atr.model
+del automasker.schp_processor_lip.model
+gc.collect()


class VAEWrapper(torch.nn.Module):
-def __init__(self, vae_encoder, vae_decoder, config):
+def __init__(self, vae_encoder, vae_decoder, scaling_factor):
super().__init__()
self.vae_enocder = vae_encoder
self.vae_decoder = vae_decoder
self.device = "cpu"
self.dtype = torch.float32
-self.config = config
+self.config = namedtuple("VAEConfig", ["scaling_factor"])(scaling_factor)

def encode(self, pixel_values):
ov_outputs = self.vae_enocder(pixel_values).to_dict()
@@ -202,12 +210,12 @@ def forward(self, image):
return torch.from_numpy(outputs[0])


-def get_compiled_pipeline(pipeline, core, device, vae_encoder_path, vae_decoder_path, unet_path):
+def get_compiled_pipeline(pipeline, core, device, vae_encoder_path, vae_decoder_path, unet_path, vae_scaling_factor):
compiled_unet = core.compile_model(unet_path, device.value)
compiled_vae_encoder = core.compile_model(vae_encoder_path, device.value)
compiled_vae_decoder = core.compile_model(vae_decoder_path, device.value)

-pipeline.vae = VAEWrapper(compiled_vae_encoder, compiled_vae_decoder, pipeline.vae.config)
+pipeline.vae = VAEWrapper(compiled_vae_encoder, compiled_vae_decoder, vae_scaling_factor)
pipeline.unet = ConvUnetWrapper(compiled_unet)

return pipeline
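Since the PyTorch VAE (and with it pipeline.vae.config) is gone after conversion, the wrapper can no longer be handed the full config object; it keeps only what the surrounding code appears to need, the scaling factor, by rebuilding a one-field config with namedtuple. A short sketch of why that is enough for config.scaling_factor lookups (0.18215 is only an example value, the usual Stable Diffusion default):

from collections import namedtuple

VAEConfig = namedtuple("VAEConfig", ["scaling_factor"])
config = VAEConfig(0.18215)

# Attribute access looks the same as on the original diffusers config object,
# so code such as `latents / vae.config.scaling_factor` keeps working.
print(config.scaling_factor)  # 0.18215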
