Skip to content

Commit

Permalink
fix
Browse files (browse the repository at this point in the history)
  • Loading branch information
echarlaix committed Nov 5, 2024
1 parent 40967f2 commit da1e9f5
Showing 1 changed file with 14 additions and 5 deletions.
19 changes: 14 additions & 5 deletions optimum/utils/preprocessing/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,15 +20,16 @@
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union

from datasets import Dataset, DatasetDict
from datasets import load_dataset as datasets_load_dataset
from transformers import PreTrainedTokenizerBase
from transformers.image_processing_utils import BaseImageProcessor

from optimum.utils.import_utils import requires_backends

from .. import logging


if TYPE_CHECKING:
from datasets import Dataset, DatasetDict
from transformers import PretrainedConfig


Expand Down Expand Up @@ -102,11 +103,14 @@ def create_dataset_processing_func(

def prepare_dataset(
self,
dataset: Union[DatasetDict, Dataset],
dataset: Union["DatasetDict", "Dataset"],
data_keys: Dict[str, str],
ref_keys: Optional[List[str]] = None,
split: Optional[str] = None,
) -> Union[DatasetDict, Dataset]:
) -> Union["DatasetDict", "Dataset"]:
requires_backends(self, ["datasets"])
from datasets import Dataset

if isinstance(dataset, Dataset) and split is not None:
raise ValueError("A Dataset and a split name were provided, but splits are for DatasetDict.")
elif split is not None:
Expand All @@ -131,7 +135,12 @@ def load_dataset(
num_samples: Optional[int] = None,
shuffle: bool = False,
**load_dataset_kwargs,
) -> Union[DatasetDict, Dataset]:
) -> Union["DatasetDict", "Dataset"]:
requires_backends(self, ["datasets"])

from datasets import DatasetDict
from datasets import load_dataset as datasets_load_dataset

dataset = datasets_load_dataset(path, **load_dataset_kwargs)

if isinstance(dataset, DatasetDict) and load_smallest_split:
Expand Down

0 comments on commit da1e9f5

Please sign in to comment.