diff --git a/optimum/gptq/data.py b/optimum/gptq/data.py index 37a42714fc8..ae8f83cda25 100644 --- a/optimum/gptq/data.py +++ b/optimum/gptq/data.py @@ -182,40 +182,11 @@ def get_c4_new(tokenizer: Any, seqlen: int, nsamples: int, split: str = "train") def get_ptb(tokenizer: Any, seqlen: int, nsamples: int, split: str = "train"): - if split == "train": - data = load_dataset("ptb_text_only", "penn_treebank", split="train") - elif split == "validation": - data = load_dataset("ptb_text_only", "penn_treebank", split="validation") - - enc = tokenizer(" ".join(data["sentence"]), return_tensors="pt") - - dataset = [] - for _ in range(nsamples): - i = random.randint(0, enc.input_ids.shape[1] - seqlen - 1) - j = i + seqlen - inp = enc.input_ids[:, i:j] - attention_mask = torch.ones_like(inp) - dataset.append({"input_ids": inp, "attention_mask": attention_mask}) - - return dataset + raise RuntimeError("Loading the `ptb` dataset has been deprecated") def get_ptb_new(tokenizer: Any, seqlen: int, nsamples: int, split: str = "train"): - if split == "train": - data = load_dataset("ptb_text_only", "penn_treebank", split="train") - elif split == "validation": - data = load_dataset("ptb_text_only", "penn_treebank", split="test") - - enc = tokenizer(" ".join(data["sentence"]), return_tensors="pt") - - dataset = [] - for _ in range(nsamples): - i = random.randint(0, enc.input_ids.shape[1] - seqlen - 1) - j = i + seqlen - inp = enc.input_ids[:, i:j] - attention_mask = torch.ones_like(inp) - dataset.append({"input_ids": inp, "attention_mask": attention_mask}) - return dataset + raise RuntimeError("Loading the `ptb` dataset has been deprecated") def get_dataset( @@ -226,7 +197,7 @@ def get_dataset( Args: dataset_name (`str`): - Dataset name. Available options are `['wikitext2', 'c4', 'ptb', 'c4-new', 'ptb_new']`. + Dataset name. Available options are `['wikitext2', 'c4', 'c4-new']`. 
tokenizer (`Any`): Tokenizer of the model nsamples (`int`, defaults to `128`): @@ -247,11 +218,11 @@ def get_dataset( "wikitext2": get_wikitext2, "c4": get_c4, "c4-new": get_c4_new, - "ptb": get_ptb, - "ptb-new": get_ptb_new, } if split not in ["train", "validation"]: raise ValueError(f"The split need to be 'train' or 'validation' but found {split}") + if dataset_name in {"ptb", "ptb-new"}: + raise ValueError(f"The {dataset_name} dataset has been deprecated; only the following datasets are supported: {list(get_dataset_map)}") if dataset_name not in get_dataset_map: raise ValueError(f"Expected a value in {list(get_dataset_map.keys())} but found {dataset_name}") get_dataset_fn = get_dataset_map[dataset_name] diff --git a/tests/gptq/test_quantization.py b/tests/gptq/test_quantization.py index 0c070f8c9e4..5ed1619fde3 100644 --- a/tests/gptq/test_quantization.py +++ b/tests/gptq/test_quantization.py @@ -394,7 +394,7 @@ class GPTQDataTest(unittest.TestCase): def setUp(self): self.tokenizer = AutoTokenizer.from_pretrained(self.model_name, use_fast=True) - @parameterized.expand(["wikitext2", "c4", "ptb", "c4-new", "ptb-new"]) + @parameterized.expand(["wikitext2", "c4", "c4-new"]) def test_dataset(self, dataset): train_dataset = get_dataset( dataset, self.tokenizer, nsamples=self.NBSAMPLES, seqlen=self.SEQLEN, split="train"