From fc47275218e27531068da5db89e7fa56780551e3 Mon Sep 17 00:00:00 2001
From: Marc Sun
Date: Mon, 11 Dec 2023 18:19:54 +0100
Subject: [PATCH 1/7] allow tokenized dataset

---
 optimum/gptq/quantizer.py | 45 +++++++++++++++++++++------------------
 1 file changed, 24 insertions(+), 21 deletions(-)

diff --git a/optimum/gptq/quantizer.py b/optimum/gptq/quantizer.py
index 1a3d4b8702c..7935583c17a 100644
--- a/optimum/gptq/quantizer.py
+++ b/optimum/gptq/quantizer.py
@@ -84,9 +84,9 @@ def __init__(
         Args:
             bits (`int`):
                 The number of bits to quantize to, supported numbers are (2, 3, 4, 8).
-            dataset (`Union[List[str],str]`, defaults to None):
-                The dataset used for quantization. You can provide your own dataset in a list of string or just use the original datasets used
-                in GPTQ paper ['wikitext2','c4','c4-new','ptb','ptb-new'].
+            dataset (`Union[List[str],str,Any]`, defaults to None):
+                The dataset used for quantization. You can provide your own dataset in a list of string or in a list of tokenized data or
+                just use the original datasets used in GPTQ paper ['wikitext2','c4','c4-new','ptb','ptb-new'].
             group_size (int, defaults to 128):
                 The group size to use for quantization. Recommended value is 128 and -1 uses per-column quantization.
             damp_percent (`float`, defaults to `0.1`):
@@ -283,14 +283,14 @@ def _replace_by_quant_layers(self, module: nn.Module, names: List[str], name: st
             self._replace_by_quant_layers(child, names, name + "." + name1 if name != "" else name1)

     @torch.no_grad()
-    def quantize_model(self, model: nn.Module, tokenizer: Any):
+    def quantize_model(self, model: nn.Module, tokenizer: Optional[Any] = None):
         """
         Quantizes the model using the dataset

         Args:
             model (`nn.Module`):
                 The model to quantize
-            tokenizer (`Any`):
+            tokenizer (Optional[`Any`], defaults to None):
                 The tokenizer to use in order to prepare the dataset. You can pass either:
                     - A custom tokenizer object.
                     - A string, the *model id* of a predefined tokenizer hosted inside a model repo on huggingface.co.
@@ -341,23 +341,26 @@ def quantize_model(self, model: nn.Module, tokenizer: Any):
         device = get_device(model)

         # Step 1: Prepare the data
-        if isinstance(tokenizer, str):
-            try:
-                tokenizer = AutoTokenizer.from_pretrained(tokenizer)
-            except Exception:
-                raise ValueError(
-                    f"""We were not able to get the tokenizer using `AutoTokenizer.from_pretrained`
-                    with the string that you have passed {tokenizer}. If you have a custom tokenizer, you can pass it as input.
-                    For now, we only support quantization for text model. Support for vision, speech and multimodel will come later."""
-                )
-        if self.dataset is None:
-            raise ValueError("You need to pass `dataset` in order to quantize your model")
-        elif isinstance(self.dataset, str):
-            dataset = get_dataset(self.dataset, tokenizer, seqlen=self.model_seqlen, split="train")
-        elif isinstance(self.dataset, list):
-            dataset = [tokenizer(data, return_tensors="pt") for data in self.dataset]
+        if isinstance(self.dataset, list) and not isinstance(self.dataset[0], str):
+            logger.info("You are using an already tokenized dataset")
         else:
-            raise ValueError("You need to pass a list of string or a string for `dataset`")
+            if isinstance(tokenizer, str):
+                try:
+                    tokenizer = AutoTokenizer.from_pretrained(tokenizer)
+                except Exception:
+                    raise ValueError(
+                        f"""We were not able to get the tokenizer using `AutoTokenizer.from_pretrained`
+                        with the string that you have passed {tokenizer}. If you have a custom tokenizer, you can pass it as input.
+                        For now, we only support quantization for text model. Support for vision, speech and multimodel will come later."""
+                    )
+            if self.dataset is None:
+                raise ValueError("You need to pass `dataset` in order to quantize your model")
+            elif isinstance(self.dataset, str):
+                dataset = get_dataset(self.dataset, tokenizer, seqlen=self.model_seqlen, split="train")
+            elif isinstance(self.dataset, list):
+                dataset = [tokenizer(data, return_tensors="pt") for data in self.dataset]
+            else:
+                raise ValueError("You need to pass a list of string, a list of tokenized data or a string for `dataset`")

         dataset = prepare_dataset(dataset, pad_token_id=self.pad_token_id, batch_size=self.batch_size)


From c9f105b78f682cd58ffa2ac71e7b873518518255 Mon Sep 17 00:00:00 2001
From: Marc Sun
Date: Mon, 11 Dec 2023 18:20:07 +0100
Subject: [PATCH 2/7] style

---
 optimum/gptq/quantizer.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/optimum/gptq/quantizer.py b/optimum/gptq/quantizer.py
index 7935583c17a..948cdbaa12b 100644
--- a/optimum/gptq/quantizer.py
+++ b/optimum/gptq/quantizer.py
@@ -360,7 +360,9 @@ def quantize_model(self, model: nn.Module, tokenizer: Optional[Any] = None):
             elif isinstance(self.dataset, list):
                 dataset = [tokenizer(data, return_tensors="pt") for data in self.dataset]
             else:
-                raise ValueError("You need to pass a list of string, a list of tokenized data or a string for `dataset`")
+                raise ValueError(
+                    "You need to pass a list of string, a list of tokenized data or a string for `dataset`"
+                )

         dataset = prepare_dataset(dataset, pad_token_id=self.pad_token_id, batch_size=self.batch_size)


From 2c6a6f091ce9a5073fd7c901e01f2c5d488fbe65 Mon Sep 17 00:00:00 2001
From: Marc Sun <57196510+SunMarc@users.noreply.github.com>
Date: Tue, 12 Dec 2023 14:18:54 -0500
Subject: [PATCH 3/7] Update optimum/gptq/quantizer.py

Co-authored-by: fxmarty <9808326+fxmarty@users.noreply.github.com>
---
 optimum/gptq/quantizer.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/optimum/gptq/quantizer.py b/optimum/gptq/quantizer.py
index 948cdbaa12b..9cdc2ac67a5 100644
--- a/optimum/gptq/quantizer.py
+++ b/optimum/gptq/quantizer.py
@@ -84,7 +84,7 @@ def __init__(
         Args:
             bits (`int`):
                 The number of bits to quantize to, supported numbers are (2, 3, 4, 8).
-            dataset (`Union[List[str],str,Any]`, defaults to None):
+            dataset (`Union[List[str], str, Any]`, defaults to `None`):
                 The dataset used for quantization. You can provide your own dataset in a list of string or in a list of tokenized data or
                 just use the original datasets used in GPTQ paper ['wikitext2','c4','c4-new','ptb','ptb-new'].
             group_size (int, defaults to 128):


From cb3803b6d352f30633faf328cc403a7ab02dea01 Mon Sep 17 00:00:00 2001
From: Marc Sun <57196510+SunMarc@users.noreply.github.com>
Date: Tue, 12 Dec 2023 14:19:05 -0500
Subject: [PATCH 4/7] Update optimum/gptq/quantizer.py

Co-authored-by: fxmarty <9808326+fxmarty@users.noreply.github.com>
---
 optimum/gptq/quantizer.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/optimum/gptq/quantizer.py b/optimum/gptq/quantizer.py
index 9cdc2ac67a5..65f3ef4154d 100644
--- a/optimum/gptq/quantizer.py
+++ b/optimum/gptq/quantizer.py
@@ -290,7 +290,7 @@ def quantize_model(self, model: nn.Module, tokenizer: Optional[Any] = None):
         Args:
             model (`nn.Module`):
                 The model to quantize
-            tokenizer (Optional[`Any`], defaults to None):
+            tokenizer (Optional[`Any`], defaults to `None`):
                 The tokenizer to use in order to prepare the dataset. You can pass either:
                     - A custom tokenizer object.
                     - A string, the *model id* of a predefined tokenizer hosted inside a model repo on huggingface.co.


From 1d2439f93338b54e709cfa3be22fb495d13575ae Mon Sep 17 00:00:00 2001
From: Marc Sun <57196510+SunMarc@users.noreply.github.com>
Date: Tue, 12 Dec 2023 14:19:18 -0500
Subject: [PATCH 5/7] Update optimum/gptq/quantizer.py

Co-authored-by: fxmarty <9808326+fxmarty@users.noreply.github.com>
---
 optimum/gptq/quantizer.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/optimum/gptq/quantizer.py b/optimum/gptq/quantizer.py
index 65f3ef4154d..69bac225bee 100644
--- a/optimum/gptq/quantizer.py
+++ b/optimum/gptq/quantizer.py
@@ -342,7 +342,7 @@ def quantize_model(self, model: nn.Module, tokenizer: Optional[Any] = None):

         # Step 1: Prepare the data
         if isinstance(self.dataset, list) and not isinstance(self.dataset[0], str):
-            logger.info("You are using an already tokenized dataset")
+            logger.info("GPTQQuantizer dataset appears to be already tokenized. Skipping tokenization.")
         else:
             if isinstance(tokenizer, str):
                 try:


From 19322613b76b2b2b8f5d4d865e83dd2f20ba60c5 Mon Sep 17 00:00:00 2001
From: Marc Sun <57196510+SunMarc@users.noreply.github.com>
Date: Tue, 12 Dec 2023 14:19:26 -0500
Subject: [PATCH 6/7] Update optimum/gptq/quantizer.py

Co-authored-by: fxmarty <9808326+fxmarty@users.noreply.github.com>
---
 optimum/gptq/quantizer.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/optimum/gptq/quantizer.py b/optimum/gptq/quantizer.py
index 69bac225bee..118d12a2e2b 100644
--- a/optimum/gptq/quantizer.py
+++ b/optimum/gptq/quantizer.py
@@ -361,7 +361,7 @@ def quantize_model(self, model: nn.Module, tokenizer: Optional[Any] = None):
                 dataset = [tokenizer(data, return_tensors="pt") for data in self.dataset]
             else:
                 raise ValueError(
-                    "You need to pass a list of string, a list of tokenized data or a string for `dataset`"
+                    f"You need to pass a list of string, a list of tokenized data or a string for `dataset`. Found: {type(self.dataset)}."
                 )

         dataset = prepare_dataset(dataset, pad_token_id=self.pad_token_id, batch_size=self.batch_size)


From d68b32b9d1b305f4c1fcd724356181c83bec5783 Mon Sep 17 00:00:00 2001
From: Marc Sun
Date: Wed, 13 Dec 2023 15:24:23 +0100
Subject: [PATCH 7/7] add example in docstring

---
 optimum/gptq/quantizer.py | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/optimum/gptq/quantizer.py b/optimum/gptq/quantizer.py
index 118d12a2e2b..3113e7f2ecd 100644
--- a/optimum/gptq/quantizer.py
+++ b/optimum/gptq/quantizer.py
@@ -85,8 +85,9 @@ def __init__(
             bits (`int`):
                 The number of bits to quantize to, supported numbers are (2, 3, 4, 8).
             dataset (`Union[List[str], str, Any]`, defaults to `None`):
-                The dataset used for quantization. You can provide your own dataset in a list of string or in a list of tokenized data or
-                just use the original datasets used in GPTQ paper ['wikitext2','c4','c4-new','ptb','ptb-new'].
+                The dataset used for quantization. You can provide your own dataset in a list of string or in a list of tokenized data
+                (e.g. [{ "input_ids": [ 1, 100, 15, ... ],"attention_mask": [ 1, 1, 1, ... ]},...])
+                or just use the original datasets used in GPTQ paper ['wikitext2','c4','c4-new','ptb','ptb-new'].
             group_size (int, defaults to 128):
                 The group size to use for quantization. Recommended value is 128 and -1 uses per-column quantization.
             damp_percent (`float`, defaults to `0.1`):
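
Usage sketch of the code path this series enables (not part of the patches themselves):
passing pre-tokenized calibration data so that quantize_model() runs without a tokenizer.
It assumes auto-gptq is installed and a CUDA device is available; the model id
(facebook/opt-125m) and the calibration sentences are illustrative placeholders.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from optimum.gptq import GPTQQuantizer

model_id = "facebook/opt-125m"  # placeholder model id for illustration
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16)

# Tokenize the calibration samples up front; each element is a dict-like object
# with "input_ids" and "attention_mask" tensors, as the updated docstring describes.
calibration_texts = [
    "GPTQ uses a small calibration set to estimate quantization error.",
    "Pre-tokenized inputs now skip the quantizer's own tokenization step.",
]
tokenized_dataset = [tokenizer(text, return_tensors="pt") for text in calibration_texts]

# Because dataset[0] is not a string, the quantizer logs that the data is already
# tokenized and never calls AutoTokenizer.from_pretrained, so no tokenizer is passed.
quantizer = GPTQQuantizer(bits=4, dataset=tokenized_dataset, group_size=128)
quantized_model = quantizer.quantize_model(model)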