diff --git a/optimum/utils/preprocessing/base.py b/optimum/utils/preprocessing/base.py index dc995ccc50b..19b4d9614c0 100644 --- a/optimum/utils/preprocessing/base.py +++ b/optimum/utils/preprocessing/base.py @@ -20,15 +20,16 @@ from abc import ABC, abstractmethod from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union -from datasets import Dataset, DatasetDict -from datasets import load_dataset as datasets_load_dataset from transformers import PreTrainedTokenizerBase from transformers.image_processing_utils import BaseImageProcessor +from optimum.utils.import_utils import requires_backends + from .. import logging if TYPE_CHECKING: + from datasets import Dataset, DatasetDict from transformers import PretrainedConfig @@ -102,11 +103,14 @@ def create_dataset_processing_func( def prepare_dataset( self, - dataset: Union[DatasetDict, Dataset], + dataset: Union["DatasetDict", "Dataset"], data_keys: Dict[str, str], ref_keys: Optional[List[str]] = None, split: Optional[str] = None, - ) -> Union[DatasetDict, Dataset]: + ) -> Union["DatasetDict", "Dataset"]: + requires_backends(self, ["datasets"]) + from datasets import Dataset + if isinstance(dataset, Dataset) and split is not None: raise ValueError("A Dataset and a split name were provided, but splits are for DatasetDict.") elif split is not None: @@ -131,7 +135,12 @@ def load_dataset( num_samples: Optional[int] = None, shuffle: bool = False, **load_dataset_kwargs, - ) -> Union[DatasetDict, Dataset]: + ) -> Union["DatasetDict", "Dataset"]: + requires_backends(self, ["datasets"]) + + from datasets import DatasetDict + from datasets import load_dataset as datasets_load_dataset + dataset = datasets_load_dataset(path, **load_dataset_kwargs) if isinstance(dataset, DatasetDict) and load_smallest_split: