Gptq tokenized dataset #1584

Merged: 8 commits, Dec 13, 2023

Changes from 2 commits
45 changes: 25 additions & 20 deletions optimum/gptq/quantizer.py
@@ -84,9 +84,9 @@ def __init__(
         Args:
             bits (`int`):
                 The number of bits to quantize to, supported numbers are (2, 3, 4, 8).
-            dataset (`Union[List[str],str]`, defaults to None):
-                The dataset used for quantization. You can provide your own dataset in a list of string or just use the original datasets used
-                in GPTQ paper ['wikitext2','c4','c4-new','ptb','ptb-new'].
+            dataset (`Union[List[str],str,Any]`, defaults to None):
+                The dataset used for quantization. You can provide your own dataset in a list of strings or a list of tokenized data, or
+                just use the original datasets used in the GPTQ paper ['wikitext2','c4','c4-new','ptb','ptb-new'].
             group_size (int, defaults to 128):
                 The group size to use for quantization. Recommended value is 128 and -1 uses per-column quantization.
             damp_percent (`float`, defaults to `0.1`):
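
For context, here is a minimal usage sketch of what the updated docstring describes: passing an already tokenized calibration set instead of raw strings. The model id and sample texts are placeholders for illustration, not part of this PR.

```python
# Minimal sketch (assumed usage, not code from this PR): quantize with a
# pre-tokenized calibration dataset. Model id and texts are placeholders.
from optimum.gptq import GPTQQuantizer
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "facebook/opt-125m"  # placeholder model
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)

# Each element is the dict a tokenizer returns ("input_ids", "attention_mask").
texts = ["GPTQ quantizes weights after training.", "Calibration data drives rounding."]
calibration = [tokenizer(t, return_tensors="pt") for t in texts]

quantizer = GPTQQuantizer(bits=4, dataset=calibration, group_size=128)
quantized_model = quantizer.quantize_model(model, tokenizer)
```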
@@ -283,14 +283,14 @@ def _replace_by_quant_layers(self, module: nn.Module, names: List[str], name: str = ""):
                 self._replace_by_quant_layers(child, names, name + "." + name1 if name != "" else name1)
 
     @torch.no_grad()
-    def quantize_model(self, model: nn.Module, tokenizer: Any):
+    def quantize_model(self, model: nn.Module, tokenizer: Optional[Any] = None):
         """
         Quantizes the model using the dataset
 
         Args:
             model (`nn.Module`):
                 The model to quantize
-            tokenizer (`Any`):
+            tokenizer (Optional[`Any`], defaults to None):
                 The tokenizer to use in order to prepare the dataset. You can pass either:
                     - A custom tokenizer object.
                     - A string, the *model id* of a predefined tokenizer hosted inside a model repo on huggingface.co.
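
The docstring above lists the two forms `tokenizer` can take. A short sketch of both, again under a placeholder model id and not code from this PR:

```python
# Sketch of the tokenizer options (placeholder model id, not PR code).
from optimum.gptq import GPTQQuantizer
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")  # placeholder

# A string model id is resolved internally via AutoTokenizer.from_pretrained.
quantizer = GPTQQuantizer(bits=4, dataset="wikitext2")
quantized = quantizer.quantize_model(model, tokenizer="facebook/opt-125m")

# With this PR, `tokenizer` defaults to None, so it can be omitted entirely
# when `dataset` is already tokenized (see the next hunk).
```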
@@ -341,23 +341,28 @@ def quantize_model(self, model: nn.Module, tokenizer: Any):
         device = get_device(model)
 
         # Step 1: Prepare the data
-        if isinstance(tokenizer, str):
-            try:
-                tokenizer = AutoTokenizer.from_pretrained(tokenizer)
-            except Exception:
-                raise ValueError(
-                    f"""We were not able to get the tokenizer using `AutoTokenizer.from_pretrained`
-                    with the string that you have passed {tokenizer}. If you have a custom tokenizer, you can pass it as input.
-                    For now, we only support quantization for text model. Support for vision, speech and multimodel will come later."""
-                )
-        if self.dataset is None:
-            raise ValueError("You need to pass `dataset` in order to quantize your model")
-        elif isinstance(self.dataset, str):
-            dataset = get_dataset(self.dataset, tokenizer, seqlen=self.model_seqlen, split="train")
-        elif isinstance(self.dataset, list):
-            dataset = [tokenizer(data, return_tensors="pt") for data in self.dataset]
-        else:
-            raise ValueError("You need to pass a list of string or a string for `dataset`")
+        if isinstance(self.dataset, list) and not isinstance(self.dataset[0], str):
+            logger.info("You are using an already tokenized dataset")
+        else:
+            if isinstance(tokenizer, str):
+                try:
+                    tokenizer = AutoTokenizer.from_pretrained(tokenizer)
+                except Exception:
+                    raise ValueError(
+                        f"""We were not able to get the tokenizer using `AutoTokenizer.from_pretrained`
+                        with the string that you have passed {tokenizer}. If you have a custom tokenizer, you can pass it as input.
+                        For now, we only support quantization for text model. Support for vision, speech and multimodel will come later."""
+                    )
+            if self.dataset is None:
+                raise ValueError("You need to pass `dataset` in order to quantize your model")
+            elif isinstance(self.dataset, str):
+                dataset = get_dataset(self.dataset, tokenizer, seqlen=self.model_seqlen, split="train")
+            elif isinstance(self.dataset, list):
+                dataset = [tokenizer(data, return_tensors="pt") for data in self.dataset]
+            else:
+                raise ValueError(
+                    "You need to pass a list of string, a list of tokenized data or a string for `dataset`"
+                )
 
         dataset = prepare_dataset(dataset, pad_token_id=self.pad_token_id, batch_size=self.batch_size)
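
The new branch decides whether the dataset is pre-tokenized by inspecting only its first element. A standalone illustration of that check, using a hypothetical helper that is not in the PR:

```python
# Hypothetical helper mirroring the check above; not part of the PR.
def is_pretokenized(dataset) -> bool:
    return isinstance(dataset, list) and not isinstance(dataset[0], str)

print(is_pretokenized(["raw calibration text"]))      # False: will be tokenized
print(is_pretokenized([{"input_ids": [[1, 2, 3]]}]))  # True: used as-is
```

Note that because only `dataset[0]` is inspected, a mixed list of strings and tokenized samples is classified by its first element alone.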
