rm dataset with restrictive license
echarlaix committed Jun 17, 2024
1 parent f480930 commit 8eb5704
Showing 2 changed files with 6 additions and 35 deletions.
39 changes: 5 additions & 34 deletions optimum/gptq/data.py
@@ -182,40 +182,11 @@ def get_c4_new(tokenizer: Any, seqlen: int, nsamples: int, split: str = "train")


def get_ptb(tokenizer: Any, seqlen: int, nsamples: int, split: str = "train"):
-    if split == "train":
-        data = load_dataset("ptb_text_only", "penn_treebank", split="train")
-    elif split == "validation":
-        data = load_dataset("ptb_text_only", "penn_treebank", split="validation")
-
-    enc = tokenizer(" ".join(data["sentence"]), return_tensors="pt")
-
-    dataset = []
-    for _ in range(nsamples):
-        i = random.randint(0, enc.input_ids.shape[1] - seqlen - 1)
-        j = i + seqlen
-        inp = enc.input_ids[:, i:j]
-        attention_mask = torch.ones_like(inp)
-        dataset.append({"input_ids": inp, "attention_mask": attention_mask})
-
-    return dataset
+    raise RuntimeError("Loading the `ptb` dataset was deprecated")


def get_ptb_new(tokenizer: Any, seqlen: int, nsamples: int, split: str = "train"):
-    if split == "train":
-        data = load_dataset("ptb_text_only", "penn_treebank", split="train")
-    elif split == "validation":
-        data = load_dataset("ptb_text_only", "penn_treebank", split="test")
-
-    enc = tokenizer(" ".join(data["sentence"]), return_tensors="pt")
-
-    dataset = []
-    for _ in range(nsamples):
-        i = random.randint(0, enc.input_ids.shape[1] - seqlen - 1)
-        j = i + seqlen
-        inp = enc.input_ids[:, i:j]
-        attention_mask = torch.ones_like(inp)
-        dataset.append({"input_ids": inp, "attention_mask": attention_mask})
-    return dataset
+    raise RuntimeError("Loading the `ptb` dataset was deprecated")


def get_dataset(
@@ -226,7 +197,7 @@ def get_dataset(
    Args:
        dataset_name (`str`):
-            Dataset name. Available options are `['wikitext2', 'c4', 'ptb', 'c4-new', 'ptb_new']`.
+            Dataset name. Available options are `['wikitext2', 'c4', 'c4-new']`.
        tokenizer (`Any`):
            Tokenizer of the model
        nsamples (`int`, defaults to `128`):
@@ -247,11 +218,11 @@
"wikitext2": get_wikitext2,
"c4": get_c4,
"c4-new": get_c4_new,
"ptb": get_ptb,
"ptb-new": get_ptb_new,
}
if split not in ["train", "validation"]:
raise ValueError(f"The split need to be 'train' or 'validation' but found {split}")
if dataset_name in {"ptb", "ptb-new"}:
raise ValueError(f"{dataset_name} dataset was deprecated, only the following dataset are supported : {list(get_dataset_map)}")
if dataset_name not in get_dataset_map:
raise ValueError(f"Expected a value in {list(get_dataset_map.keys())} but found {dataset_name}")
get_dataset_fn = get_dataset_map[dataset_name]
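
For context, a minimal sketch of how the public entry point behaves after this change. It assumes `from optimum.gptq.data import get_dataset` (the file edited above) and uses an illustrative model name for the tokenizer:

    from transformers import AutoTokenizer
    from optimum.gptq.data import get_dataset

    # Illustrative model; any tokenizer compatible with the calibration text works.
    tokenizer = AutoTokenizer.from_pretrained("gpt2", use_fast=True)

    # Requesting a removed PTB variant now fails fast, before any dataset download.
    try:
        get_dataset("ptb", tokenizer, nsamples=128, seqlen=2048, split="train")
    except ValueError as e:
        print(e)  # ptb dataset was deprecated, only the following dataset are supported : ['wikitext2', 'c4', 'c4-new']

Calling `get_ptb` or `get_ptb_new` directly now raises the new `RuntimeError` instead.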
2 changes: 1 addition & 1 deletion tests/gptq/test_quantization.py
@@ -394,7 +394,7 @@ class GPTQDataTest(unittest.TestCase):
    def setUp(self):
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name, use_fast=True)

-    @parameterized.expand(["wikitext2", "c4", "ptb", "c4-new", "ptb-new"])
+    @parameterized.expand(["wikitext2", "c4", "c4-new"])
    def test_dataset(self, dataset):
        train_dataset = get_dataset(
            dataset, self.tokenizer, nsamples=self.NBSAMPLES, seqlen=self.SEQLEN, split="train"
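
And a hedged smoke check mirroring the updated parameterized test (small `nsamples`/`seqlen` to limit downloads; the per-sample layout is inferred from the removed PTB helpers, which the remaining loaders appear to share):

    from transformers import AutoTokenizer
    from optimum.gptq.data import get_dataset

    tokenizer = AutoTokenizer.from_pretrained("gpt2", use_fast=True)  # illustrative model
    for name in ["wikitext2", "c4", "c4-new"]:
        samples = get_dataset(name, tokenizer, nsamples=2, seqlen=128, split="train")
        assert len(samples) == 2                         # one dict per requested sample
        assert samples[0]["input_ids"].shape[-1] == 128  # each slice clipped to seqlen tokens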
