Skip to content

Commit

Permalink
added correct auto classes
Browse files Browse the repository at this point in the history
  • Loading branch information
IlyasMoutawwakil committed Jul 18, 2024
1 parent fcb1690 commit 1cbb544
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 13 deletions.
9 changes: 6 additions & 3 deletions optimum/modeling_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,6 @@ class PreTrainedModel(ABC): # noqa: F811

class OptimizedModel(PreTrainedModel):
config_class = AutoConfig
load_tf_weights = None
base_model_prefix = "optimized_model"
config_name = CONFIG_NAME

Expand Down Expand Up @@ -378,10 +377,14 @@ def from_pretrained(
)
model_id, revision = model_id.split("@")

library_name = TasksManager.infer_library_from_model(model_id, subfolder, revision, cache_dir, token=token)
library_name = TasksManager.infer_library_from_model(
model_id, subfolder=subfolder, revision=revision, cache_dir=cache_dir, token=token
)

if library_name == "timm":
config = PretrainedConfig.from_pretrained(model_id, subfolder, revision)
config = PretrainedConfig.from_pretrained(
model_id, subfolder=subfolder, revision=revision, cache_dir=cache_dir, token=token
)

if config is None:
if os.path.isdir(os.path.join(model_id, subfolder)) and cls.config_name == CONFIG_NAME:
Expand Down
24 changes: 14 additions & 10 deletions optimum/onnxruntime/modeling_diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,14 @@
from diffusers import (
DDIMScheduler,
DiffusionPipeline,
LatentConsistencyModelPipeline,
LMSDiscreteScheduler,
PNDMScheduler,
StableDiffusionImg2ImgPipeline,
StableDiffusionInpaintPipeline,
StableDiffusionPipeline,
StableDiffusionXLImg2ImgPipeline,
StableDiffusionXLPipeline,
)
from diffusers.schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME
from diffusers.utils import CONFIG_NAME, is_invisible_watermark_available
Expand Down Expand Up @@ -73,11 +77,13 @@

class ORTDiffusionPipeline(ORTModel):
auto_model_class = DiffusionPipeline
main_input_name = "input_ids"
main_input_name = "prompt"
base_model_prefix = "onnx_model"
config_name = "model_index.json"
sub_component_config_name = "config.json"

# TODO: instead of having a bloated init, we should probably have an init per pipeline,
# so that we can easily add new pipelines without having to modify the base class
def __init__(
self,
vae_decoder_session: ort.InferenceSession,
Expand Down Expand Up @@ -401,7 +407,7 @@ def _from_transformers(
provider_options: Optional[Dict[str, Any]] = None,
use_io_binding: Optional[bool] = None,
task: Optional[str] = None,
) -> "ORTStableDiffusionPipeline":
) -> "ORTDiffusionPipeline":
if use_auth_token is not None:
warnings.warn(
"The `use_auth_token` argument is deprecated and will be removed soon. Please use the `token` argument instead.",
Expand Down Expand Up @@ -568,7 +574,7 @@ class ORTStableDiffusionPipeline(ORTDiffusionPipeline, StableDiffusionPipelineMi
ONNX Runtime-powered stable diffusion pipeline corresponding to [diffusers.StableDiffusionPipeline](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/text2img#diffusers.StableDiffusionPipeline).
"""

__call__ = StableDiffusionPipelineMixin.__call__
auto_model_class = StableDiffusionPipeline


@add_end_docstrings(ONNX_MODEL_END_DOCSTRING)
Expand All @@ -577,7 +583,7 @@ class ORTStableDiffusionImg2ImgPipeline(ORTDiffusionPipeline, StableDiffusionImg
ONNX Runtime-powered stable diffusion pipeline corresponding to [diffusers.StableDiffusionImg2ImgPipeline](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/img2img#diffusers.StableDiffusionImg2ImgPipeline).
"""

__call__ = StableDiffusionImg2ImgPipelineMixin.__call__
auto_model_class = StableDiffusionImg2ImgPipeline


@add_end_docstrings(ONNX_MODEL_END_DOCSTRING)
Expand All @@ -586,7 +592,7 @@ class ORTStableDiffusionInpaintPipeline(ORTDiffusionPipeline, StableDiffusionInp
ONNX Runtime-powered stable diffusion pipeline corresponding to [diffusers.StableDiffusionInpaintPipeline](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/inpaint#diffusers.StableDiffusionInpaintPipeline).
"""

__call__ = StableDiffusionInpaintPipelineMixin.__call__
auto_model_class = StableDiffusionInpaintPipeline


@add_end_docstrings(ONNX_MODEL_END_DOCSTRING)
Expand All @@ -595,12 +601,10 @@ class ORTLatentConsistencyModelPipeline(ORTDiffusionPipeline, LatentConsistencyP
ONNX Runtime-powered stable diffusion pipeline corresponding to [diffusers.LatentConsistencyModelPipeline](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/latent_consistency#diffusers.LatentConsistencyModelPipeline).
"""

__call__ = LatentConsistencyPipelineMixin.__call__
auto_model_class = LatentConsistencyModelPipeline


class ORTStableDiffusionXLPipelineBase(ORTDiffusionPipeline):
auto_model_class = StableDiffusionXLImg2ImgPipeline

def __init__(
self,
vae_decoder_session: ort.InferenceSession,
Expand Down Expand Up @@ -653,7 +657,7 @@ class ORTStableDiffusionXLPipeline(ORTStableDiffusionXLPipelineBase, StableDiffu
ONNX Runtime-powered stable diffusion pipeline corresponding to [diffusers.StableDiffusionXLPipeline](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#diffusers.StableDiffusionXLPipeline).
"""

__call__ = StableDiffusionXLPipelineMixin.__call__
auto_model_class = StableDiffusionXLPipeline


@add_end_docstrings(ONNX_MODEL_END_DOCSTRING)
Expand All @@ -662,7 +666,7 @@ class ORTStableDiffusionXLImg2ImgPipeline(ORTStableDiffusionXLPipelineBase, Stab
ONNX Runtime-powered stable diffusion pipeline corresponding to [diffusers.StableDiffusionXLImg2ImgPipeline](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#diffusers.StableDiffusionXLImg2ImgPipeline).
"""

__call__ = StableDiffusionXLImg2ImgPipelineMixin.__call__
auto_model_class = StableDiffusionXLImg2ImgPipeline


AUTO_TEXT2IMAGE_PIPELINES_MAPPING = OrderedDict(
Expand Down

0 comments on commit 1cbb544

Please sign in to comment.