diff --git a/optimum/utils/testing_utils.py b/optimum/utils/testing_utils.py index f1c2f668e3c..d04a610eb6a 100644 --- a/optimum/utils/testing_utils.py +++ b/optimum/utils/testing_utils.py @@ -16,6 +16,7 @@ import importlib.util import itertools import os +import shutil import subprocess import sys import unittest @@ -184,3 +185,16 @@ def grid_parameters( else: returned_list = [test_name] + list(params) if add_test_name is True else list(params) yield returned_list + + +def remove_directory(dirpath): + """ + Remove a directory and its content. + This is a cross-platform solution to remove a directory and its content that avoids the use of `shutil.rmtree` on Windows. + Reference: https://github.com/python/cpython/issues/107408 + """ + if os.path.exists(dirpath) and os.path.isdir(dirpath): + if os.name == "nt": + os.system(f"rmdir /S /Q {dirpath}") + else: + shutil.rmtree(dirpath) diff --git a/tests/onnxruntime/test_modeling.py b/tests/onnxruntime/test_modeling.py index 5d85dac32bd..36c36c297fc 100644 --- a/tests/onnxruntime/test_modeling.py +++ b/tests/onnxruntime/test_modeling.py @@ -14,7 +14,6 @@ # limitations under the License. import gc import os -import shutil import subprocess import tempfile import time @@ -109,7 +108,7 @@ DIFFUSION_MODEL_VAE_ENCODER_SUBFOLDER, logging, ) -from optimum.utils.testing_utils import grid_parameters, require_hf_token, require_ort_rocm +from optimum.utils.testing_utils import grid_parameters, remove_directory, require_hf_token, require_ort_rocm logger = logging.get_logger() @@ -184,12 +183,8 @@ def test_load_model_from_cache(self): def test_load_model_from_empty_cache(self): dirpath = os.path.join(default_cache_path, "models--" + self.TINY_ONNX_MODEL_ID.replace("/", "--")) + remove_directory(dirpath) - if os.path.exists(dirpath) and os.path.isdir(dirpath): - if os.name == "nt": - os.system(f"rmdir /S /Q {dirpath}") - else: - shutil.rmtree(dirpath) with self.assertRaises(Exception): _ = ORTModel.from_pretrained(self.TINY_ONNX_MODEL_ID, local_files_only=True) @@ -205,12 +200,8 @@ def test_load_seq2seq_model_from_cache(self): def test_load_seq2seq_model_from_empty_cache(self): dirpath = os.path.join(default_cache_path, "models--" + self.TINY_ONNX_SEQ2SEQ_MODEL_ID.replace("/", "--")) + remove_directory(dirpath) - if os.path.exists(dirpath) and os.path.isdir(dirpath): - if os.name == "nt": - os.system(f"rmdir /S /Q {dirpath}") - else: - shutil.rmtree(dirpath) with self.assertRaises(Exception): _ = ORTModelForSeq2SeqLM.from_pretrained(self.TINY_ONNX_SEQ2SEQ_MODEL_ID, local_files_only=True) @@ -231,12 +222,8 @@ def test_load_stable_diffusion_model_from_empty_cache(self): dirpath = os.path.join( default_cache_path, "models--" + self.TINY_ONNX_STABLE_DIFFUSION_MODEL_ID.replace("/", "--") ) + remove_directory(dirpath) - if os.path.exists(dirpath) and os.path.isdir(dirpath): - if os.name == "nt": - os.system(f"rmdir /S /Q {dirpath}") - else: - shutil.rmtree(dirpath) with self.assertRaises(Exception): _ = ORTStableDiffusionPipeline.from_pretrained( self.TINY_ONNX_STABLE_DIFFUSION_MODEL_ID, local_files_only=True @@ -1014,9 +1001,7 @@ def test_save_load_ort_model_with_external_data(self): # verify loading from local folder works model = ORTModelForSequenceClassification.from_pretrained(tmpdirname, export=False) os.environ.pop("FORCE_ONNX_EXTERNAL_DATA") - - if os.name == "nt": - os.system(f"rmdir /s /q {tmpdirname}") + remove_directory(tmpdirname) @parameterized.expand([(False,), (True,)]) @pytest.mark.run_slow @@ -1038,9 +1023,7 @@ def test_save_load_decoder_model_with_external_data(self, use_cache: bool): model = ORTModelForCausalLM.from_pretrained( tmpdirname, use_cache=use_cache, export=False, use_io_binding=False ) - - if os.name == "nt": - os.system(f"rmdir /s /q {tmpdirname}") + remove_directory(tmpdirname) @parameterized.expand([(False,), (True,)]) def test_save_load_seq2seq_model_with_external_data(self, use_cache: bool): @@ -1063,9 +1046,7 @@ def test_save_load_seq2seq_model_with_external_data(self, use_cache: bool): # verify loading from local folder works model = ORTModelForSeq2SeqLM.from_pretrained(tmpdirname, use_cache=use_cache, export=False) os.environ.pop("FORCE_ONNX_EXTERNAL_DATA") - - if os.name == "nt": - os.system(f"rmdir /s /q {tmpdirname}") + remove_directory(tmpdirname) def test_save_load_stable_diffusion_model_with_external_data(self): with tempfile.TemporaryDirectory() as tmpdirname: @@ -1087,9 +1068,7 @@ def test_save_load_stable_diffusion_model_with_external_data(self): # verify loading from local folder works model = ORTStableDiffusionPipeline.from_pretrained(tmpdirname, export=False) os.environ.pop("FORCE_ONNX_EXTERNAL_DATA") - - if os.name == "nt": - os.system(f"rmdir /s /q {tmpdirname}") + remove_directory(tmpdirname) @parameterized.expand([(False,), (True,)]) @unittest.skip("Skipping as this test consumes too much memory")