From cce0ee8f717975b3cb9475e920e406c8132d9094 Mon Sep 17 00:00:00 2001 From: IlyasMoutawwakil Date: Wed, 18 Sep 2024 09:09:30 +0200 Subject: [PATCH] reduce parent model usage in model parts --- optimum/onnxruntime/modeling_diffusion.py | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/optimum/onnxruntime/modeling_diffusion.py b/optimum/onnxruntime/modeling_diffusion.py index 0c965d76bc2..99e618c5deb 100644 --- a/optimum/onnxruntime/modeling_diffusion.py +++ b/optimum/onnxruntime/modeling_diffusion.py @@ -468,13 +468,17 @@ def __call__(self, *args, **kwargs): return self.auto_model_class.__call__(self, *args, **kwargs) -class ORTPipelinePart(ORTModelPart): +class ORTPipelinePart(ORTModelPart, ConfigMixin): + config_name: str = "config.json" + def __init__(self, session: ort.InferenceSession, parent_model: ORTPipeline): super().__init__(session, parent_model) - config_path = Path(session._model_path).parent / "config.json" - config_dict = parent_model._dict_from_json_file(config_path) if config_path.is_file() else {} - self.config = FrozenDict(config_dict) + config_path = Path(session._model_path).parent / self.config_name + config_dict = self.load_config(config_path) if config_path.is_file() else {} + config_dict = config_dict[0] if isinstance(config_dict, tuple) else config_dict + + self._internal_dict = FrozenDict(config_dict) @property def input_dtype(self): @@ -605,10 +609,11 @@ def forward( class ORTVaeWrapper(ORTPipelinePart): def __init__(self, vae_encoder: ORTModelVaeEncoder, vae_decoder: ORTModelVaeDecoder, parent_model: ORTPipeline): - super().__init__(vae_decoder.session, parent_model) self.vae_encoder = vae_encoder self.vae_decoder = vae_decoder + super().__init__(vae_decoder.session, parent_model) + def encode( self, sample: Union[np.ndarray, torch.Tensor],