Skip to content

Commit

Permalink
restore input format for stable diffusion and export configs mapping (#1091)
Browse files Browse the repository at this point in the history

* restore input format for stable diffusion

* update configs registration

* fix shapes for timestep

* align names for t5
  • Loading branch information
eaidova authored Dec 23, 2024
1 parent 29b2ac9 commit 87c431c
Showing 1 changed file with 23 additions and 1 deletion.
24 changes: 23 additions & 1 deletion optimum/exporters/openvino/model_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -1250,6 +1250,7 @@ def patch_model_for_export(

@register_in_tasks_manager("clip-text-model", *["feature-extraction"], library_name="transformers")
@register_in_tasks_manager("clip-text-model", *["feature-extraction"], library_name="diffusers")
@register_in_tasks_manager("clip-text", *["feature-extraction"], library_name="diffusers")
class CLIPTextOpenVINOConfig(CLIPTextOnnxConfig):
def patch_model_for_export(
self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None
Expand Down Expand Up @@ -1795,12 +1796,31 @@ def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int
)


class DummyUnetTimestepInputGenerator(DummyTimestepInputGenerator):
    """Dummy-input generator that special-cases the UNet ``timestep`` input.

    Every other input name is delegated unchanged to ``DummyTimestepInputGenerator``.
    """

    def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"):
        """Build the dummy tensor for ``input_name``.

        For ``"timestep"`` this returns a random integer tensor of shape
        ``[self.batch_size]`` bounded by ``self.vocab_size``; any other name is
        handled by the parent class.
        """
        if input_name == "timestep":
            # One timestep value per batch element.
            return self.random_int_tensor(
                [self.batch_size], max_value=self.vocab_size, framework=framework, dtype=int_dtype
            )
        return super().generate(input_name, framework, int_dtype, float_dtype)


@register_in_tasks_manager("unet", *["semantic-segmentation"], library_name="diffusers")
@register_in_tasks_manager("unet-2d-condition", *["semantic-segmentation"], library_name="diffusers")
class UnetOpenVINOConfig(UNetOnnxConfig):
    """OpenVINO export configuration for diffusers UNet models.

    Registered for the ``semantic-segmentation`` task under both the ``unet`` and
    ``unet-2d-condition`` model types.
    """

    # Single definition of the generator tuple: the first two entries inherited from
    # ``UNetOnnxConfig`` are replaced by the OpenVINO-specific generators and the
    # remaining entries are reused as-is.
    # NOTE(review): presumably slots 0 and 1 of the parent tuple are the vision and
    # timestep generators -- confirm against ``UNetOnnxConfig``.
    DUMMY_INPUT_GENERATOR_CLASSES = (
        DummyUnetVisionInputGenerator,
        DummyUnetTimestepInputGenerator,
    ) + UNetOnnxConfig.DUMMY_INPUT_GENERATOR_CLASSES[2:]

    @property
    def inputs(self) -> Dict[str, Dict[int, str]]:
        """Return the parent's input spec with ``timestep`` dynamic on axis 0.

        The ``timestep`` entry is set to ``{0: "batch_size"}``, overriding whatever
        the parent config declared for it.
        """
        common_inputs = super().inputs
        common_inputs["timestep"] = {0: "batch_size"}
        return common_inputs


@register_in_tasks_manager("sd3-transformer", *["semantic-segmentation"], library_name="diffusers")
@register_in_tasks_manager("sd3-transformer-2d", *["semantic-segmentation"], library_name="diffusers")
class SD3TransformerOpenVINOConfig(UNetOnnxConfig):
DUMMY_INPUT_GENERATOR_CLASSES = (
(DummyTransformerTimestpsInputGenerator,)
Expand Down Expand Up @@ -1830,6 +1850,7 @@ def rename_ambiguous_inputs(self, inputs):


@register_in_tasks_manager("t5-encoder-model", *["feature-extraction"], library_name="diffusers")
@register_in_tasks_manager("t5-encoder", *["feature-extraction"], library_name="diffusers")
class T5EncoderOpenVINOConfig(CLIPTextOpenVINOConfig):
    """Export config for the T5 encoder; inherits everything from ``CLIPTextOpenVINOConfig``.

    Registered for ``feature-extraction`` under the ``t5-encoder-model`` and
    ``t5-encoder`` model types of the ``diffusers`` library.
    """

Expand Down Expand Up @@ -1905,6 +1926,7 @@ def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int


@register_in_tasks_manager("flux-transformer", *["semantic-segmentation"], library_name="diffusers")
@register_in_tasks_manager("flux-transformer-2d", *["semantic-segmentation"], library_name="diffusers")
class FluxTransformerOpenVINOConfig(SD3TransformerOpenVINOConfig):
DUMMY_INPUT_GENERATOR_CLASSES = (
DummyTransformerTimestpsInputGenerator,
Expand Down

0 comments on commit 87c431c

Please sign in to comment.