Skip to content

Commit

Permalink
remove subfolder
Browse files Browse the repository at this point in the history
  • Loading branch information
IlyasMoutawwakil committed Oct 7, 2024
1 parent 34e2789 commit e4c7184
Showing 1 changed file with 9 additions and 16 deletions.
25 changes: 9 additions & 16 deletions optimum/onnxruntime/modeling_diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@

class ORTPipeline(ORTModel, ConfigMixin):
config_name = "model_index.json"
auto_model_class = None

def __init__(
self,
Expand Down Expand Up @@ -241,7 +242,6 @@ def _from_pretrained(
cls,
model_id: Union[str, Path],
config: Dict[str, Any],
subfolder: str = "",
force_download: bool = False,
local_files_only: bool = False,
revision: Optional[str] = None,
Expand All @@ -264,10 +264,8 @@ def _from_pretrained(
"IOBinding is not yet available for diffusion pipelines, please set `use_io_binding` to False."
)

all_components = {key for key in config.keys() if not key.startswith("_")}
all_components.update({"vae_encoder", "vae_decoder"})

if not os.path.isdir(str(model_id)):
all_components = {key for key in config.keys() if not key.startswith("_")} | {"vae_encoder", "vae_decoder"}
allow_patterns = {os.path.join(component, "*") for component in all_components}
allow_patterns.update(
{
Expand All @@ -281,10 +279,6 @@ def _from_pretrained(
CONFIG_NAME,
}
)

if subfolder:
allow_patterns = {os.path.join(subfolder, pattern) for pattern in allow_patterns}

model_id = snapshot_download(
model_id,
cache_dir=cache_dir,
Expand All @@ -298,8 +292,8 @@ def _from_pretrained(

model_save_path = Path(model_id)

if subfolder:
model_save_path = model_save_path / subfolder
if model_save_dir is None:
model_save_dir = model_save_path

submodels = {}
for name in {"feature_extractor", "tokenizer", "tokenizer_2", "scheduler"}:
Expand Down Expand Up @@ -329,7 +323,7 @@ def _from_pretrained(
**models,
**submodels,
use_io_binding=use_io_binding,
model_save_dir=model_save_dir or model_save_path,
model_save_dir=model_save_dir,
)

@classmethod
Expand Down Expand Up @@ -437,8 +431,8 @@ def components(self) -> Dict[str, Any]:
"unet": self.unet,
"text_encoder": self.text_encoder,
"text_encoder_2": self.text_encoder_2,
"image_encoder": self.image_encoder,
"safety_checker": self.safety_checker,
"image_encoder": self.image_encoder,
}
components = {k: v for k, v in components.items() if v is not None}
return components
Expand Down Expand Up @@ -466,17 +460,16 @@ def __init__(self, session: ort.InferenceSession, parent_pipeline: ORTPipeline):
# config is mandatory for the model part to be used for inference
raise ValueError(f"Configuration file for {self.__class__.__name__} not found at {config_file_path}")

config_dict = parent_pipeline._dict_from_json_file(config_file_path)
config_dict = self._dict_from_json_file(config_file_path)
self.register_to_config(**config_dict)

# ort model part
self.session = session
self.parent_pipeline = parent_pipeline

self.input_names = {input_key.name: idx for idx, input_key in enumerate(self.session.get_inputs())}
self.output_names = {output_key.name: idx for idx, output_key in enumerate(self.session.get_outputs())}
self.input_dtypes = {input_key.name: input_key.type for input_key in session.get_inputs()}
self.output_dtypes = {output_key.name: output_key.type for output_key in session.get_outputs()}
self.input_dtypes = {input_key.name: input_key.type for input_key in self.session.get_inputs()}
self.output_dtypes = {output_key.name: output_key.type for output_key in self.session.get_outputs()}

@property
def device(self):
Expand Down

0 comments on commit e4c7184

Please sign in to comment.