diff --git a/examples/quantization_2of4_sparse_w4a16/llama7b_sparse_w4a16.py b/examples/quantization_2of4_sparse_w4a16/llama7b_sparse_w4a16.py
index 8e0ba8291..7c1aee71d 100644
--- a/examples/quantization_2of4_sparse_w4a16/llama7b_sparse_w4a16.py
+++ b/examples/quantization_2of4_sparse_w4a16/llama7b_sparse_w4a16.py
@@ -1,4 +1,5 @@
 import torch
+from loguru import logger
 from transformers import AutoModelForCausalLM
 
 from llmcompressor.transformers import apply
@@ -52,3 +53,7 @@
     lr_scheduler_type=lr_scheduler_type,
     warmup_ratio=warmup_ratio,
 )
+logger.info(
+    "Note: vLLM requires dtype=torch.float16 when running the "
+    "compressed marlin-24 model"
+)
diff --git a/setup.py b/setup.py
index 4c57ae3ce..71a681b48 100644
--- a/setup.py
+++ b/setup.py
@@ -71,6 +71,7 @@
         "pytest-mock>=3.6.0",
         "pytest-rerunfailures>=13.0",
         "parameterized",
+        "lm_eval==0.4.5",
         # example test dependencies
         "beautifulsoup4~=4.12.3",
         "cmarkgfm~=2024.1.14",
diff --git a/src/llmcompressor/modifiers/modifier.py b/src/llmcompressor/modifiers/modifier.py
index 494f8bdfc..65b4a4029 100644
--- a/src/llmcompressor/modifiers/modifier.py
+++ b/src/llmcompressor/modifiers/modifier.py
@@ -1,16 +1,15 @@
-from abc import ABC, abstractmethod
+from abc import abstractmethod
 from typing import Optional
 
-from pydantic import BaseModel
-
 from llmcompressor.core.events import Event, EventType
 from llmcompressor.core.state import State
 from llmcompressor.modifiers.interface import ModifierInterface
+from llmcompressor.modifiers.utils.hooks import HooksMixin
 
 __all__ = ["Modifier"]
 
 
-class Modifier(BaseModel, ModifierInterface, ABC):
+class Modifier(ModifierInterface, HooksMixin):
     """
     A base class for all modifiers to inherit from. Modifiers are used to
     modify the training process for a model.
diff --git a/src/llmcompressor/modifiers/obcq/base.py b/src/llmcompressor/modifiers/obcq/base.py
index 3da0e3d0c..9cf0ff331 100644
--- a/src/llmcompressor/modifiers/obcq/base.py
+++ b/src/llmcompressor/modifiers/obcq/base.py
@@ -1,3 +1,4 @@
+from functools import partial
 from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
 
 import numpy as np
@@ -130,7 +131,8 @@ def initialize_compression(
                 "Inferring layer-wise sparsities from "
                 f"{len(dataloader)} calibration samples..."
) - self.sparsity = self._infer_layer_sparsity(dataloader) + activations = self._get_activations(dataloader) + self.sparsity = self._infer_layer_sparsity(activations) self._validate_layerwise_sparsity() for idx, (name, layer) in enumerate(self.compressible_layers_.items()): @@ -254,19 +256,17 @@ def _infer_mask_block_size(self): self.prunen_, self.prunem_ = list(map(int, self.mask_structure.split(":"))) - def _infer_layer_sparsity(self, calibration_dataloader): - acts = _get_activations(self.model, calibration_dataloader) + def _infer_layer_sparsity(self, activations): sparsegpt_groups = {} for name, layer in self.compressible_layers_.items(): prunable_layers = get_prunable_layers(layer) z = [ - m.weight.abs() * acts[f"{name}.{n}"].unsqueeze(0) + m.weight.abs() * activations[f"{name}.{n}"].unsqueeze(0) for n, m in prunable_layers.items() ] sparsegpt_groups[name] = torch.cat([item.flatten().cpu() for item in z]) - acts = None - del acts + del activations torch.cuda.empty_cache() outlier_ratios = {} @@ -300,36 +300,34 @@ def _infer_layer_sparsity(self, calibration_dataloader): logger.info(f"Sparsity for {k}: {sparsities[k]}") return sparsities + @torch.no_grad() + def _get_activations(self, data_loader, nsamples=128): + self.model.eval() + acts = {} + + def save_acts(module, input, name): + if isinstance(input, tuple): + input = input[0] + if name not in acts: + acts[name] = ( + 1.0 / nsamples * input.detach().pow(2).sum(dim=(0, 1)).sqrt() + ) + else: + acts[name] += ( + 1.0 / nsamples * input.detach().pow(2).sum(dim=(0, 1)).sqrt() + ) + + for name, mod in self.model.named_modules(): + if isinstance(mod, torch.nn.Linear) and "lm_head" not in name: + self.register_hook(mod, partial(save_acts, name=name), "forward_pre") + + device = next(self.model.parameters()).device + for batch in tqdm(data_loader): + batch = {k: v.to(device) for k, v in batch.items()} + self.model(**batch) + batch = None + torch.cuda.empty_cache() -@torch.no_grad() -def _get_activations(model, data_loader, nsamples=128): - import functools - - model.eval() - acts = {} - - def save_acts(module, input, name): - if isinstance(input, tuple): - input = input[0] - if name not in acts: - acts[name] = 1.0 / nsamples * input.detach().pow(2).sum(dim=(0, 1)).sqrt() - else: - acts[name] += 1.0 / nsamples * input.detach().pow(2).sum(dim=(0, 1)).sqrt() - - hooks = [] - for name, mod in model.named_modules(): - if isinstance(mod, torch.nn.Linear) and "lm_head" not in name: - hooks.append( - mod.register_forward_pre_hook(functools.partial(save_acts, name=name)) - ) - device = next(model.parameters()).device - for batch in tqdm(data_loader): - batch = {k: v.to(device) for k, v in batch.items()} - model(**batch) - batch = None - torch.cuda.empty_cache() - - for h in hooks: - h.remove() + self.remove_hooks() - return acts + return acts diff --git a/src/llmcompressor/modifiers/pruning/utils/pytorch/layer_mask.py b/src/llmcompressor/modifiers/pruning/utils/pytorch/layer_mask.py index 3ada8c7fb..d59b4563b 100644 --- a/src/llmcompressor/modifiers/pruning/utils/pytorch/layer_mask.py +++ b/src/llmcompressor/modifiers/pruning/utils/pytorch/layer_mask.py @@ -2,11 +2,10 @@ from typing import Dict import torch -from pydantic import BaseModel from torch.nn import Parameter -from torch.utils.hooks import RemovableHandle from llmcompressor.core import ModelParameterizedLayer +from llmcompressor.modifiers.utils.hooks import HooksMixin __all__ = ["LayerParamMasking", "param_mask_name"] @@ -39,11 +38,9 @@ class ParameterizedLayerMaskSettings: use_hooks: 
bool = False -class LayerParamMasking(BaseModel): +class LayerParamMasking(HooksMixin): _mask_settings: Dict[str, ParameterizedLayerMaskSettings] = {} _masked_layer_params: Dict[str, ModelParameterizedLayer] = {} - _forward_hooks: Dict[str, RemovableHandle] = {} - _backward_hooks: Dict[str, RemovableHandle] = {} enabled_: bool = False def add_mask( @@ -100,12 +97,8 @@ def _backward_hook_fn(gradients): return gradients - self._forward_hooks[layer_param_name] = ( - parameterized_layer.layer.register_forward_hook(_forward_hook_fn) - ) - self._backward_hooks[layer_param_name] = ( - parameterized_layer.param.register_hook(_backward_hook_fn) - ) + self.register_hook(parameterized_layer.layer, _forward_hook_fn, "forward") + self.register_hook(parameterized_layer.param, _backward_hook_fn, "") def update_mask( self, @@ -131,11 +124,7 @@ def remove_mask(self, layer_param_name: str): del self._mask_settings[layer_param_name] if mask_settings.use_hooks: - self._forward_hooks[layer_param_name].remove() - self._backward_hooks[layer_param_name].remove() - - del self._forward_hooks[layer_param_name] - del self._backward_hooks[layer_param_name] + self.remove_hooks() def apply_mask_weight(self, layer_param_name: str): if not self.enabled_: diff --git a/src/llmcompressor/modifiers/pruning/wanda/base.py b/src/llmcompressor/modifiers/pruning/wanda/base.py index f056ee1ae..1881a347c 100644 --- a/src/llmcompressor/modifiers/pruning/wanda/base.py +++ b/src/llmcompressor/modifiers/pruning/wanda/base.py @@ -1,3 +1,4 @@ +from functools import partial from typing import Any, Dict, Iterable, List, Optional, Tuple, Union import numpy as np @@ -121,7 +122,8 @@ def initialize_compression( "Inferring layer-wise sparsities from " f"{len(dataloader) if dataloader else 0} calibration samples..." 
) - self.sparsity = self._infer_layer_sparsity(dataloader) + activations = self._get_activations(dataloader) + self.sparsity = self._infer_layer_sparsity(activations) self._validate_layerwise_sparsity() for idx, (name, layer) in enumerate(self.compressible_layers_.items()): @@ -224,19 +226,17 @@ def _infer_mask_block_size(self): self.prunen_, self.prunem_ = list(map(int, self.mask_structure.split(":"))) - def _infer_layer_sparsity(self, calibration_dataloader): - acts = _get_activations(self.model, calibration_dataloader) + def _infer_layer_sparsity(self, activations): wanda = {} for name, layer in self.compressible_layers_.items(): prunable_layers = get_prunable_layers(layer) z = [ - m.weight.abs() * acts[f"{name}.{n}"].unsqueeze(0) + m.weight.abs() * activations[f"{name}.{n}"].unsqueeze(0) for n, m in prunable_layers.items() ] wanda[name] = torch.cat([item.flatten().cpu() for item in z]) - acts = None - del acts + del activations torch.cuda.empty_cache() outlier_ratios = {} @@ -268,36 +268,34 @@ def _infer_layer_sparsity(self, calibration_dataloader): logger.info(f"Sparsity for {k}: {sparsities[k]}") return sparsities + @torch.no_grad() + def _get_activations(self, data_loader, nsamples=128): + self.model.eval() + acts = {} + + def save_acts(module, input, name): + if isinstance(input, tuple): + input = input[0] + if name not in acts: + acts[name] = ( + 1.0 / nsamples * input.detach().pow(2).sum(dim=(0, 1)).sqrt() + ) + else: + acts[name] += ( + 1.0 / nsamples * input.detach().pow(2).sum(dim=(0, 1)).sqrt() + ) + + for name, mod in self.model.named_modules(): + if isinstance(mod, torch.nn.Linear) and "lm_head" not in name: + self.register_hook(mod, partial(save_acts, name=name), "forward_pre") + + device = next(self.model.parameters()).device + for batch in tqdm(data_loader): + batch = {k: v.to(device) for k, v in batch.items()} + self.model(**batch) + batch = None + torch.cuda.empty_cache() -@torch.no_grad() -def _get_activations(model, data_loader, nsamples=128): - import functools - - model.eval() - acts = {} - - def save_acts(module, input, name): - if isinstance(input, tuple): - input = input[0] - if name not in acts: - acts[name] = 1.0 / nsamples * input.detach().pow(2).sum(dim=(0, 1)).sqrt() - else: - acts[name] += 1.0 / nsamples * input.detach().pow(2).sum(dim=(0, 1)).sqrt() - - hooks = [] - for name, mod in model.named_modules(): - if isinstance(mod, torch.nn.Linear) and "lm_head" not in name: - hooks.append( - mod.register_forward_pre_hook(functools.partial(save_acts, name=name)) - ) - device = next(model.parameters()).device - for batch in tqdm(data_loader): - batch = {k: v.to(device) for k, v in batch.items()} - model(**batch) - batch = None - torch.cuda.empty_cache() - - for h in hooks: - h.remove() + self.remove_hooks() - return acts + return acts diff --git a/src/llmcompressor/modifiers/quantization/calibration.py b/src/llmcompressor/modifiers/quantization/calibration.py index 0c9508530..ee4ce171e 100644 --- a/src/llmcompressor/modifiers/quantization/calibration.py +++ b/src/llmcompressor/modifiers/quantization/calibration.py @@ -1,4 +1,4 @@ -from typing import Optional +from typing import Any, Dict, Optional, Tuple import torch from compressed_tensors.quantization import QuantizationStatus, is_attention_module @@ -146,71 +146,57 @@ def calibrate_activations(module: Module, value: torch.Tensor, base_name: str): ) -def calibrate_input_hook(): +def calibrate_input_hook(module: Module, args: Any): """ Hook to calibrate input activations. 
Will call the observers to update the scales/zp before applying input QDQ in the module's forward pass. """ + args = args[0] if isinstance(args, tuple) else args + calibrate_activations(module, value=args, base_name="input") - def hook_fn(module: Module, inp): - inp = inp[0] if isinstance(inp, tuple) else inp - calibrate_activations(module, value=inp, base_name="input") - return hook_fn - - -def calibrate_output_hook(): +def calibrate_output_hook(module: Module, _args: Any, output: torch.Tensor): """ Hook to calibrate output activations. Will call the observers to update the scales/zp before applying output QDQ. """ - - def hook_fn(module: Module, inp, output: torch.Tensor): - calibrate_activations( - module, - value=output, - base_name="output", - ) - output = forward_quantize( - module=module, - value=output, - base_name="output", - args=module.quantization_scheme.output_activations, - ) - return output - - return hook_fn + calibrate_activations( + module, + value=output, + base_name="output", + ) + output = forward_quantize( + module=module, + value=output, + base_name="output", + args=module.quantization_scheme.output_activations, + ) + return output -def calibrate_kv_cache_input_hook(): +def calibrate_kv_cache_input_hook( + module: Module, args: Any, kwargs: Dict[str, Any] +) -> Tuple[Tuple[Any, ...], Dict[str, Any]]: """ Hook to update inputs to attention layers when running kv_cache quantization. Will update the passed in kv_cache to singleton QuantizedKVParameterCache. """ + kv_cache = getattr(module, "kv_cache") + kwargs["past_key_value"] = kv_cache + kwargs["use_cache"] = False + return args, kwargs - def hook_fn(module: Module, args, kwargs): - kv_cache = getattr(module, "kv_cache") - kwargs["past_key_value"] = kv_cache - kwargs["use_cache"] = False - return args, kwargs - - return hook_fn - -def calibrate_kv_cache_output_hook(): +def calibrate_kv_cache_output_hook(module: Module, _args: Any, _output: torch.Tensor): """ Hook to update k_scale and v_scale parameters when running kv_cache quantization. 
""" - - def hook_fn(module: Module, inpt, output: torch.Tensor): - kv_cache = getattr(module, "kv_cache") - update_parameter_data(module, kv_cache.k_scales[module.layer_idx], "k_scale") - update_parameter_data(module, kv_cache.v_scales[module.layer_idx], "v_scale") - - return hook_fn + kv_cache = getattr(module, "kv_cache") + update_parameter_data(module, kv_cache.k_scales[module.layer_idx], "k_scale") + update_parameter_data(module, kv_cache.v_scales[module.layer_idx], "v_scale") def set_unset_kv_cache(module: Module): diff --git a/src/llmcompressor/modifiers/quantization/gptq/base.py b/src/llmcompressor/modifiers/quantization/gptq/base.py index b6dbda485..c5200cf0f 100644 --- a/src/llmcompressor/modifiers/quantization/gptq/base.py +++ b/src/llmcompressor/modifiers/quantization/gptq/base.py @@ -21,6 +21,7 @@ from llmcompressor.modifiers.utils.layer_compressor import LayerCompressor from llmcompressor.modifiers.utils.pytorch_helpers import run_calibration_forward from llmcompressor.utils.fsdp.context import fix_fsdp_module_name +from llmcompressor.utils.helpers import DisableKVCache from llmcompressor.utils.pytorch.module import ( get_layers, get_no_split_params, @@ -281,37 +282,33 @@ def apply_compression( # want to calibrate wrt to these self.model.apply(disable_quantization) - forward_pass_use_cache = self.model.config.use_cache - self.model.config.use_cache = False - - # run_calibration_forward uses the early stop exception to capture values - # as intermediates right before the forward pass of the first module - intermediates = run_calibration_forward( - self.model, dataloader, mask_padding=True - ) - self.layer_compressors_[0].clear_early_stop() - - num_layers = len(self.compressible_layers_) - for idx, layer_compressor in enumerate(self.layer_compressors_): - logger.info(f"\n===== Compressing layer {idx+1}/{num_layers} " " =====") - - # run the forward pass for each transformer layer (block) one at a time - logger.info(f"Calibrating {layer_compressor.name}...") - layer_compressor.pre_compress() - unquantized_outputs = layer_compressor.calibrate_layer(intermediates) - - layer_compressor.compress() - layer_compressor.post_compress() - layer_compressor.revert_layer_wrappers() - - # perform a second forward pass of the module to calculate weight-quantized - # outputs for use as inputs to the next layer (block) - quantized_outputs = layer_compressor.calibrate_layer(intermediates) - error = get_output_error(unquantized_outputs, quantized_outputs) - logger.info(f"Mean output error from quantization: {error:.3f}") - intermediates = quantized_outputs - - self.model.config.use_cache = forward_pass_use_cache + with DisableKVCache(self.model): + # run_calibration_forward uses the early stop exception to capture values + # as intermediates right before the forward pass of the first module + intermediates = run_calibration_forward( + self.model, dataloader, mask_padding=True + ) + self.layer_compressors_[0].clear_early_stop() + + num_layers = len(self.compressible_layers_) + for idx, layer_compressor in enumerate(self.layer_compressors_): + logger.info(f"\n===== Compressing layer {idx+1}/{num_layers} " " =====") + + # run the forward pass for each transformer layer (block) one at a time + logger.info(f"Calibrating {layer_compressor.name}...") + layer_compressor.pre_compress() + unquantized_outputs = layer_compressor.calibrate_layer(intermediates) + + layer_compressor.compress() + layer_compressor.post_compress() + layer_compressor.revert_layer_wrappers() + + # perform a second forward pass of the 
module to calculate
+                # weight-quantized outputs for use as inputs to the next layer
+                quantized_outputs = layer_compressor.calibrate_layer(intermediates)
+                error = get_output_error(unquantized_outputs, quantized_outputs)
+                logger.info(f"Mean output error from quantization: {error:.3f}")
+                intermediates = quantized_outputs
 
         # re-enable quantization
         self.model.apply(enable_quantization)
diff --git a/src/llmcompressor/modifiers/quantization/quantization/base.py b/src/llmcompressor/modifiers/quantization/quantization/base.py
index b469c00ac..9b4516b52 100644
--- a/src/llmcompressor/modifiers/quantization/quantization/base.py
+++ b/src/llmcompressor/modifiers/quantization/quantization/base.py
@@ -82,7 +82,6 @@ class QuantizationModifier(Modifier):
     calibration_dataloader_: Any = None
     calibration_function_: Any = None
-    calibration_hooks_: List = None
 
     @field_validator("targets", mode="before")
     def validate_targets(cls, value: Union[str, List[str]]) -> List[str]:
@@ -109,7 +108,6 @@ def on_initialize(self, state: State, **kwargs) -> bool:
             self._check_calibration_data(config)
             module.apply(update_weight_zp_scale)
             module.apply(apply_calibration_status)
-            self.calibration_hooks_ = []
             self._calibrate_if_possible(module)
             self._check_token_distribution(
                 module, threshold=kwargs.get("min_tokens_per_module")
             )
@@ -232,15 +230,12 @@ def _calibrate_if_possible(self, module: Module):
         register_calibration_hooks():
             if input activation and not dynamic quant (used to call observers before input QDQ):
-                - pre_hook_handle = module.register_forward_pre_hook(calibrate_input_hook())
+                - pre_hook := calibrate_input_hook
             if output activation and not dynamic quant (used to call observers before output QDQ):
-                - post_hook_handle = module.register_forward_hook(calibrate_kv_cache_output_hook())
+                - post_hook := calibrate_output_hook
             if kv_cache quantization (used to set kv_cache to QuantizedKVParameterCache and update k_scale/v_scale)
-                - pre_hook_handle = module.register_forward_pre_hook(calibrate_kv_cache_input_hook(), with_kwargs=True)
-                - post_hook_handle = module.register_forward_hook(calibrate_kv_cache_output_hook())
-
-            self.calibration_hooks.append(pre_hook_handle)
-            self.calibration_hooks.append(post_hook_handle)
+                - pre_hook := calibrate_kv_cache_input_hook
+                - post_hook := calibrate_kv_cache_output_hook
 
         self._calibrate(module) # run forward pass through model using calibration data
         set_unset_kv_cache() # remove kv_cache objects attached to attention layers
@@ -269,8 +264,7 @@ def _calibrate_if_possible(self, module: Module):
             module.apply(self.register_calibration_hooks)
             self._calibrate(module)
             module.apply(set_unset_kv_cache)
-            for h in self.calibration_hooks_:
-                h.remove()
+            self.remove_hooks()
 
     def register_calibration_hooks(self, module: Module):
         """
@@ -280,8 +274,6 @@ def register_calibration_hooks(self, module: Module):
         if not quantization_scheme:
             return
 
-        pre_hook_handle = None
-        post_hook_handle = None
         is_attention_module_ = is_attention_module(module)
         input_quant = quantization_scheme.input_activations
         output_quant = quantization_scheme.output_activations
@@ -292,27 +284,23 @@ def register_calibration_hooks(self, module: Module):
 
         # Calibrate inputs if an input_quant is provided and not running dynamic quant
         if calibrate_inputs:
-            pre_hook_handle = module.register_forward_pre_hook(calibrate_input_hook())
+            self.register_hook(module, calibrate_input_hook, "forward_pre")
 
         if output_quant:
             # hooks for attn modules if running kv_cache quant
             if is_attention_module_:
-                pre_hook_handle =
module.register_forward_pre_hook( - calibrate_kv_cache_input_hook(), with_kwargs=True - ) - post_hook_handle = module.register_forward_hook( - calibrate_kv_cache_output_hook() + self.register_hook( + module, + calibrate_kv_cache_input_hook, + "forward_pre", + with_kwargs=True, ) + + self.register_hook(module, calibrate_kv_cache_output_hook, "forward") + # hooks for output quant if not running dynamic quant elif not output_quant.dynamic: - post_hook_handle = module.register_forward_hook(calibrate_output_hook()) - - if pre_hook_handle: - logger.debug(f"Add {pre_hook_handle} for calibration") - self.calibration_hooks_.append(pre_hook_handle) - if post_hook_handle: - logger.debug(f"Add {post_hook_handle} for calibration") - self.calibration_hooks_.append(post_hook_handle) + self.register_hook(module, calibrate_output_hook, "forward") def _calibrate(self, module: Module): class_name = self.__class__.__name__.replace("PyTorch", "") diff --git a/src/llmcompressor/modifiers/smoothquant/base.py b/src/llmcompressor/modifiers/smoothquant/base.py index 7487d0609..f4117e31d 100644 --- a/src/llmcompressor/modifiers/smoothquant/base.py +++ b/src/llmcompressor/modifiers/smoothquant/base.py @@ -99,7 +99,6 @@ class SmoothQuantModifier(Modifier): num_calibration_steps: Optional[int] = None calibration_function: Optional[Callable] = None - hooks_: Optional[List] = None resolved_mappings_: Optional[List] = None scales_: Optional[Dict] = None @@ -127,7 +126,6 @@ def on_initialize(self, state: State, **kwargs) -> bool: self.scales_ = {} calibration_dataloader = state.data.calib - self.hooks_ = [] self._setup_scale_hooks() self._calibrate(state.model, calibration_dataloader) @@ -228,7 +226,7 @@ def hook_fn(module, inp, out): for mapping in self.resolved_mappings_: name = mapping.smooth_name layer = mapping.smooth_layer - self.hooks_.append(layer.register_forward_hook(create_hook_fn(name))) + self.register_hook(layer, create_hook_fn(name), "forward") @torch.no_grad() def _calibrate(self, model: Module, calibration_dataloader: List): @@ -255,9 +253,7 @@ def _calibrate(self, model: Module, calibration_dataloader: List): ) # remove the hooks now that we are done calibrating - for hook in self.hooks_: - hook.remove() - del self.hooks_ + self.remove_hooks() @torch.no_grad() def _apply_smoothing(self, model: Module): diff --git a/src/llmcompressor/modifiers/utils/hooks.py b/src/llmcompressor/modifiers/utils/hooks.py new file mode 100644 index 000000000..bb1755519 --- /dev/null +++ b/src/llmcompressor/modifiers/utils/hooks.py @@ -0,0 +1,83 @@ +import contextlib +from functools import wraps +from typing import Any, Callable, ClassVar, List, Union + +import torch +from loguru import logger +from pydantic import BaseModel +from torch.utils.hooks import RemovableHandle + +__all__ = ["HooksMixin"] + + +class HooksMixin(BaseModel): + """ + Mixin to manage hook registration, disabling, and removal. + Modifiers should use `self.register_hook(module, hook, hook_type)` + for hook registration and `self.remove_hooks()` for removal. + + Modifiers which implement hooks should register them using + `self.register_..._hook(module, hook)` rather than the usual + `module.register_..._hook(hook)`. Modifiers should remove hooks with + `self.remove_hooks()`. 
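+
+    Example (illustrative sketch only; `LoggingModifier` is a hypothetical
+    subclass used to demonstrate the API defined in this mixin):
+
+        class LoggingModifier(HooksMixin):
+            def attach(self, model: torch.nn.Module):
+                for module in model.modules():
+                    if isinstance(module, torch.nn.Linear):
+                        self.register_hook(module, self.log_hook, "forward")
+
+            def log_hook(self, module, args, output):
+                # forward hooks receive (module, args, output)
+                logger.debug(f"{module} output shape: {tuple(output.shape)}")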
+ + Hooks can be applied to modules or parameters + + Lifecycle: + - modifier.register_forward_hook(module, hook) + - with HooksMixin.disable_hooks(): model.forward() + - modifier.remove_hooks() + """ + + _HOOKS_DISABLED: ClassVar[bool] = False # attached to global HooksMixin + _hooks: List[RemovableHandle] = [] # attached to local subclasses + + @classmethod + @contextlib.contextmanager + def disable_hooks(cls): + """Disable all hooks across all modifiers""" + try: + cls._HOOKS_DISABLED = True + yield + finally: + cls._HOOKS_DISABLED = False + + def register_hook( + self, + target: Union[torch.nn.Module, torch.nn.Parameter], + hook: Callable[[Any], Any], + hook_type: str, + **kwargs, + ) -> RemovableHandle: + """ + Registers a hook on a specified module/parameter with the option to disable it + with HooksMixin.disable_hooks() + + :param target: the module or parameter on which the hook should be registered + :param hook: the hook to register + :param hook_type: the type of hook to register corresponding to the + `register_{hook_type}_hook` attribute on torch.nn.Module. + Ex. "forward", "forward_pre", "full_backward", "state_dict_post", "" + :param kwargs: keyword arguments to pass to register hook method + """ + + @wraps(hook) + def wrapped_hook(*args, **kwargs): + if HooksMixin._HOOKS_DISABLED: + return + + return hook(*args, **kwargs) + + register_function = getattr(target, f"register_{hook_type}_hook") + handle = register_function(wrapped_hook, **kwargs) + self._hooks.append(handle) + logger.debug(f"{self} added {handle}") + + return handle + + def remove_hooks(self): + """Remove all hooks belonging to a modifier""" + for hook in self._hooks: + hook.remove() + + self._hooks = [] diff --git a/src/llmcompressor/pytorch/model_load/helpers.py b/src/llmcompressor/pytorch/model_load/helpers.py index 11a924f1d..3db9be173 100644 --- a/src/llmcompressor/pytorch/model_load/helpers.py +++ b/src/llmcompressor/pytorch/model_load/helpers.py @@ -24,8 +24,6 @@ "save_completed_stages", ] -RECIPE_FILE_NAME = "recipe.yaml" - def log_model_load( model: Module, model_name_or_path: str, model_type: str, delayed_load: bool @@ -106,6 +104,8 @@ def save_model_and_recipe( :param save_safetensors: whether to save as safetensors or pickle (bin) :param save_compressed: whether to compress sparse weights on disk """ + # avoid circular import + from llmcompressor.transformers.utils.helpers import RECIPE_FILE_NAME model.save_pretrained( save_path, save_compressed=save_compressed, safe_serialization=save_safetensors diff --git a/src/llmcompressor/transformers/finetune/session_mixin.py b/src/llmcompressor/transformers/finetune/session_mixin.py index e3a9c4d84..b1ac57b95 100644 --- a/src/llmcompressor/transformers/finetune/session_mixin.py +++ b/src/llmcompressor/transformers/finetune/session_mixin.py @@ -24,8 +24,9 @@ from llmcompressor.modifiers.distillation.utils.pytorch.model_wrapper import ( KDModelWrapper, ) -from llmcompressor.pytorch.model_load.helpers import RECIPE_FILE_NAME, get_session_model +from llmcompressor.pytorch.model_load.helpers import get_session_model from llmcompressor.pytorch.utils import ModuleSparsificationInfo +from llmcompressor.transformers import RECIPE_FILE_NAME from llmcompressor.transformers.finetune.callbacks import ( DisableHalfPrecisionCallback, TrainingLoopCallbacks, diff --git a/src/llmcompressor/transformers/sparsification/compressed_tensors_utils.py b/src/llmcompressor/transformers/sparsification/compressed_tensors_utils.py index 6de89dd8b..759098894 100644 --- 
a/src/llmcompressor/transformers/sparsification/compressed_tensors_utils.py +++ b/src/llmcompressor/transformers/sparsification/compressed_tensors_utils.py @@ -24,6 +24,7 @@ from llmcompressor.transformers.compression.sparsity_config import ( SparsityConfigMetadata, ) +from llmcompressor.transformers.utils import RECIPE_FILE_NAME from llmcompressor.utils.fsdp.helpers import ( find_and_move_state_dicts_to_cpu, unwrap_and_export_model, @@ -189,7 +190,7 @@ def skip(*args, **kwargs): ) compressor.update_config(save_directory) - recipe_path = os.path.join(save_directory, "recipe.yaml") + recipe_path = os.path.join(save_directory, RECIPE_FILE_NAME) session = active_session() if (recipe_yaml_str := session.get_serialized_recipe()) is not None: diff --git a/src/llmcompressor/transformers/utils/helpers.py b/src/llmcompressor/transformers/utils/helpers.py index 401a454cf..1263bb004 100644 --- a/src/llmcompressor/transformers/utils/helpers.py +++ b/src/llmcompressor/transformers/utils/helpers.py @@ -3,80 +3,26 @@ huggingface/transformers flows """ -import inspect import os -from collections import OrderedDict -from contextlib import suppress -from enum import Enum -from pathlib import Path -from typing import Iterable, List, Optional -from typing import OrderedDict as OrderedDictType -from typing import Tuple, Union +from typing import TYPE_CHECKING, Optional -import requests -import torch -import transformers -from huggingface_hub import HUGGINGFACE_CO_URL_HOME, HfFileSystem, hf_hub_download from loguru import logger -from transformers import AutoConfig from transformers.trainer_utils import get_last_checkpoint -from transformers.utils import PaddingStrategy -from llmcompressor.utils.fsdp.context import main_process_first_context +if TYPE_CHECKING: + from llmcompressor.transformers import ModelArguments, TrainingArguments __all__ = [ - "RECIPE_NAME", + "RECIPE_FILE_NAME", "detect_last_checkpoint", - "TaskNames", - "resolve_sequence_length", - "ALL_TASK_NAMES", - "create_fake_dataloader", - "POSSIBLE_TOKENIZER_FILES", - "download_repo_from_huggingface_hub", - "download_model_directory", ] - -class TaskNames(Enum): - mlm = {"masked-language-modeling", "mlm"} - qa = {"question-answering", "qa"} - token_classification = {"token-classification", "ner"} - text_classification = { - "text-classification", - "sentiment-analysis", - "sequence-classification", - "glue", - } - text_generation = {"text-generation"} - - -ALL_TASK_NAMES = list(set.union(*[task_names.value for task_names in TaskNames])) -RECIPE_NAME = "recipe.yaml" - -MANDATORY_DEPLOYMENT_FILES = { - "tokenizer_config.json", - "config.json", -} -OPTIONAL_DEPLOYMENT_FILES = {"tokenizer.json", "tokenizer.model"} -NLG_MANDATORY_DEPLOYMENT_FILES = {"special_tokens_map.json"} -NLG_OPTIONAL_DEPLOYMENT_FILES = { - "vocab.json", - "merges.txt", -} -POSSIBLE_TOKENIZER_FILES = { - "vocab.json", - "merges.txt", - "tokenizer.json", - "tokenizer.model", - "special_tokens_map.json", - "tokenizer_config.json", -} -RELEVANT_HF_SUFFIXES = ["json", "md", "bin", "safetensors", "yaml", "yml", "py"] +RECIPE_FILE_NAME = "recipe.yaml" def detect_last_checkpoint( - training_args: "TrainingArguments", # noqa 821 - model_args: Optional["ModelArguments"] = None, # noqa 821 + training_args: "TrainingArguments", + model_args: Optional["ModelArguments"] = None, ): last_checkpoint = None if ( @@ -108,385 +54,3 @@ def detect_last_checkpoint( ) return last_checkpoint - - -def resolve_sequence_length(config: AutoConfig) -> int: - """ - Resolve the sequence length from the config - 
- :param config: the config to resolve the sequence length from - :return: the sequence length - """ - if hasattr(config, "max_position_embeddings"): - sequence_length = config.max_position_embeddings - - elif hasattr(config, "max_seq_len"): - sequence_length = config.max_seq_len - else: - raise ValueError( - "Could not infer a default sequence length " - "from the HF transformers config. Please specify " - "the sequence length with --sequence_length" - ) - logger.debug( - f"Using default sequence length of {sequence_length} " - "(inferred from HF transformers config) " - ) - return sequence_length - - -def resolve_recipe( - model_path: Union[str, Path], - recipe: Union[str, Path, None] = None, -) -> Union[str, None]: - """ - Resolve the recipe to apply to the model. - :param recipe: the recipe to apply to the model. - It can be one of the following: - - None - This means that we are not either not applying - any recipe and allowing the model to potentially - infer the appropriate pre-existing recipe - from the model_path - - a path to the recipe file - This can be a string or Path object pointing - to a recipe file. If the specified recipe file - is different from the potential pre-existing - recipe for that model (stored in the model_path), - the function will raise an warning - - name of the recipe file (e.g. "recipe.yaml") - Recipe file name specific is assumed to be stored - in the model_path - - a string containing the recipe - Needs to adhere to the SparseML recipe format - - :param model_path: the path to the model to load. - It can be one of the following: - - a path to the model directory - - a path to the model file - - Hugging face model id - - :return: the resolved recipe - """ - - if recipe is None: - return infer_recipe_from_model_path(model_path) - - elif os.path.isfile(recipe): - # recipe is a path to a recipe file - return resolve_recipe_file(recipe, model_path) - - elif os.path.isfile(os.path.join(model_path, recipe)): - # recipe is a name of a recipe file - recipe = os.path.join(model_path, recipe) - return resolve_recipe_file(recipe, model_path) - - elif isinstance(recipe, str): - # recipe is a string containing the recipe - logger.debug( - "Applying the recipe string directly to the model, without " - "checking for a potential existing recipe in the model_path." - ) - return recipe - - logger.info( - "No recipe requested and no default recipe " - f"found in {model_path}. Skipping recipe resolution." - ) - return None - - -def infer_recipe_from_model_path(model_path: Union[str, Path]) -> Optional[str]: - """ - Infer the recipe from the model_path. - :param model_path: the path to the model to load. 
- It can be one of the following: - - a path to the model directory - - a path to the model file - - Hugging face model id - :return the path to the recipe file if found, None otherwise - """ - model_path = model_path.as_posix() if isinstance(model_path, Path) else model_path - - if os.path.isdir(model_path) or os.path.isfile(model_path): - # model_path is a local path to the model directory or model file - # attempting to find the recipe in the model_directory - model_path = ( - os.path.dirname(model_path) if os.path.isfile(model_path) else model_path - ) - recipe = os.path.join(model_path, RECIPE_NAME) - if os.path.isfile(recipe): - logger.info(f"Found recipe in the model_path: {recipe}") - return recipe - logger.debug(f"No recipe found in the model_path: {model_path}") - return None - - recipe = recipe_from_huggingface_model_id(model_path)[0] - - if recipe is None: - logger.info("Failed to infer the recipe from the model_path") - return recipe - - -def recipe_from_huggingface_model_id( - model_path: str, recipe_name: str = RECIPE_NAME -) -> Tuple[Optional[str], bool]: - """ - Attempts to download the recipe from the huggingface model id. - - :param model_path: Assumed to be the huggingface model id. - If it is not, this function will return None. - :param recipe_name: The name of the recipe file to download. - Defaults to RECIPE_NAME. - :return: tuple: - - the path to the recipe file if found, None otherwise - - True if model_path is a valid huggingface model id, False otherwise - """ - model_id = os.path.join(HUGGINGFACE_CO_URL_HOME, model_path) - request = requests.get(model_id) - if not request.status_code == 200: - logger.debug( - "model_path is not a valid huggingface model id. " - "Skipping recipe resolution." - ) - return None, False - - logger.info( - "model_path is a huggingface model id. " - "Attempting to download recipe from " - f"{HUGGINGFACE_CO_URL_HOME}" - ) - try: - recipe = hf_hub_download(repo_id=model_path, filename=recipe_name) - logger.info(f"Found recipe: {recipe_name} for model id: {model_path}.") - except Exception as e: - logger.info( - f"Unable to to find recipe {recipe_name} " - f"for model id: {model_path}: {e}. " - "Skipping recipe resolution." - ) - recipe = None - return recipe, True - - -def resolve_recipe_file( - requested_recipe: Union[str, Path], model_path: Union[str, Path] -) -> Union[str, Path, None]: - """ - Given the requested recipe and the model_path, return the path to the recipe file. - - :param requested_recipe. Is a full path to the recipe file - :param model_path: the path to the model to load. 
- It can be one of the following: - - a path to the model directory - - a path to the model file - - Hugging face model id - :return the path to the recipe file if found, None otherwise - """ - # preprocess arguments so that they are all strings - requested_recipe = ( - requested_recipe.as_posix() - if isinstance(requested_recipe, Path) - else requested_recipe - ) - model_path = model_path.as_posix() if isinstance(model_path, Path) else model_path - model_path = ( - os.path.dirname(model_path) if os.path.isfile(model_path) else model_path - ) - - if not os.path.isdir(model_path): - default_recipe, model_exists = recipe_from_huggingface_model_id(model_path) - if not model_exists: - raise ValueError(f"Unrecognized model_path: {model_path}") - - if not default_recipe == requested_recipe and default_recipe is not None: - logger.warning( - f"Attempting to apply recipe: {requested_recipe} " - f"to the model at: {model_path}, " - f"but the model already has a recipe: {default_recipe}. " - f"Using {requested_recipe} instead." - ) - return requested_recipe - - # pathway for model_path that is a directory - default_recipe = os.path.join(model_path, RECIPE_NAME) - default_recipe_exists = os.path.isfile(default_recipe) - default_and_request_recipes_identical = os.path.samefile( - default_recipe, requested_recipe - ) - - if ( - default_recipe_exists - and requested_recipe - and not default_and_request_recipes_identical - ): - logger.warning( - f"Attempting to apply recipe: {requested_recipe} " - f"to the model located in {model_path}, " - f"but the model already has a recipe stored as {default_recipe}. " - f"Using {requested_recipe} instead." - ) - - elif not default_recipe_exists and requested_recipe: - logger.warning( - f"Attempting to apply {requested_recipe} " - f"to the model located in {model_path}." - "However, it is expected that the model " - f"has its target recipe stored as {default_recipe}." - "Applying any recipe before the target recipe may " - "result in unexpected behavior." - f"Applying {requested_recipe} nevertheless." - ) - - elif default_recipe_exists: - logger.info(f"Using the default recipe: {requested_recipe}") - - return requested_recipe - - -def create_fake_dataloader( - model: torch.nn.Module, - tokenizer: transformers.AutoTokenizer, - num_samples: int, -) -> Tuple[Iterable[OrderedDictType[str, torch.Tensor]], List[str]]: - """ - Creates fake transformers dataloader for the model, based on the model's - forward signature. - - :param model: The model to create the dataloader for - :param tokenizer: The tokenizer to use for the dataloader - :param num_samples: The number of fake samples in the dataloader - :return: The data loader (iterable) and the input names for the model - """ - - forward_args_spec = inspect.getfullargspec(model.__class__.forward) - inputs = tokenizer( - "", return_tensors="pt", padding=PaddingStrategy.MAX_LENGTH.value - ).data - fake_inputs = OrderedDict( - [ - (input_key, inputs[input_key][0].reshape(1, -1)) - for input_key in forward_args_spec.args - if input_key in inputs - ] - ) - data_loader = (fake_inputs for _ in range(num_samples)) - input_names = list(fake_inputs.keys()) - return data_loader, input_names - - -def fetch_recipe_path(target: str): - """ - Fetches the recipe path for the given target. - This method will also download the recipe if it is not - already downloaded. - - Takes care of three scenarios: - 1. target is a local path to a model directory - (looks for recipe.yaml in the directory) - 2. 
target is a HuggingFace stub (downloads and - returns the path to the default recipe) - - :param target: The target to fetch the recipe path for - can be a local path or HuggingFace stub - :return: The path to the recipe for the target - """ - DEFAULT_RECIPE_NAME = "recipe.yaml" - if Path(target).exists(): - # target is a local path - potential_recipe_path = Path(target) / DEFAULT_RECIPE_NAME - return str(potential_recipe_path) if potential_recipe_path.exists() else None - - # Recipe must be downloaded - - recipe_path = None - - # target is a HuggingFace stub - with suppress(Exception): - # suppress any errors if the recipe is not found on HuggingFace - recipe_path = hf_hub_download(repo_id=target, filename=DEFAULT_RECIPE_NAME) - - return recipe_path - - -def download_repo_from_huggingface_hub(repo_id, **kwargs): - """ - Download relevant model files from the Hugging Face Hub - using the huggingface_hub.hf_hub_download function - - Note(s): - - Does not download the entire repo, only the relevant files - for the model, such as the model weights, tokenizer files, etc. - - Does not re-download files that already exist locally, unless - the force_download flag is set to True - - :pre-condition: the repo_id must be a valid Hugging Face Hub repo id - :param repo_id: the repo id to download - :param kwargs: additional keyword arguments to pass to hf_hub_download - """ - hf_filesystem = HfFileSystem() - files = hf_filesystem.ls(repo_id) - - if not files: - raise ValueError(f"Could not find any files in HF repo {repo_id}") - - # All file(s) from hf_filesystem have "name" key - # Extract the file names from the files - relevant_file_names = ( - Path(file["name"]).name - for file in files - if any(file["name"].endswith(suffix) for suffix in RELEVANT_HF_SUFFIXES) - ) - - hub_kwargs_names = ( - "subfolder", - "repo_type", - "revision", - "library_name", - "library_version", - "cache_dir", - "local_dir", - "local_dir_use_symlinks", - "user_agent", - "force_download", - "force_filename", - "proxies", - "etag_timeout", - "resume_download", - "token", - "local_files_only", - "headers", - "legacy_cache_layout", - "endpoint", - ) - hub_kwargs = {name: kwargs[name] for name in hub_kwargs_names if name in kwargs} - - for file_name in relevant_file_names: - last_file = hf_hub_download(repo_id=repo_id, filename=file_name, **hub_kwargs) - - # parent directory of the last file is the model directory - return str(Path(last_file).parent.resolve().absolute()) - - -def download_model_directory(pretrained_model_name_or_path: str, **kwargs): - """ - Download the model directory from the HF hub if the model is not found locally - - :param pretrained_model_name_or_path: the name of or path to the model to load - can be a HuggingFace model stub - :param kwargs: additional keyword arguments to pass to the download function - :return: the path to the downloaded model directory - """ - pretrained_model_path: Path = Path(pretrained_model_name_or_path) - - if pretrained_model_path.exists(): - logger.debug( - "Model directory already exists locally.", - ) - return pretrained_model_name_or_path - - with main_process_first_context(): - logger.debug("Downloading model from HuggingFace Hub.") - return download_repo_from_huggingface_hub( - repo_id=pretrained_model_name_or_path, **kwargs - ) diff --git a/src/llmcompressor/utils/helpers.py b/src/llmcompressor/utils/helpers.py index 266acf973..bdf27f620 100644 --- a/src/llmcompressor/utils/helpers.py +++ b/src/llmcompressor/utils/helpers.py @@ -22,6 +22,7 @@ from urllib.parse import 
urlparse import numpy +import torch from loguru import logger __all__ = [ @@ -59,6 +60,7 @@ "is_package_available", "import_from_path", "getattr_chain", + "DisableKVCache", ] @@ -1041,3 +1043,40 @@ def getattr_chain(obj: Any, chain_str: str, *args, **kwargs) -> Any: res = getattr(res, attr_name) return res + + +class DisableKVCache: + """ + Temporarily disable the key-value cache for transformer models. Used to prevent + excess memory use in one-shot cases where the model only performs the prefill + phase and not the generation phase. + + Example: + >>> model = AutoModel.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0") + >>> input = torch.randint(0, 32, size=(1, 32)) + >>> with DisableKVCache(model): + ... output = model(input) + """ + + def __init__(self, model: torch.nn.Module): + if hasattr(model.config, "use_cache"): + self.config = model.config + + # MllamaConfig + elif hasattr(model.config, "text_config") and hasattr( + model.config.text_config, "use_cache" + ): + self.config = model.config.text_config + + # unknown config structure + else: + raise NotImplementedError(f"Cannot find `use_cache` for {model.config}") + + self.restore_value = self.config.use_cache + + def __enter__(self): + self.restore_value = self.config.use_cache + self.config.use_cache = False + + def __exit__(self, _exc_type, _exc_val, _exc_tb): + self.config.use_cache = self.restore_value diff --git a/src/llmcompressor/version.py b/src/llmcompressor/version.py index 9632ef77f..df576ff82 100644 --- a/src/llmcompressor/version.py +++ b/src/llmcompressor/version.py @@ -8,7 +8,7 @@ from typing import Optional, Tuple # Define the base version and build type -version_base = "0.3.0" +version_base = "0.3.1" build_type = "dev" # can be 'release', 'nightly', 'dev', or 'dev' with a dev number diff --git a/tests/e2e/e2e_utils.py b/tests/e2e/e2e_utils.py new file mode 100644 index 000000000..d8dfea005 --- /dev/null +++ b/tests/e2e/e2e_utils.py @@ -0,0 +1,58 @@ +from datasets import load_dataset +from loguru import logger +from transformers import AutoModelForCausalLM, AutoTokenizer + +from llmcompressor.modifiers.quantization import GPTQModifier, QuantizationModifier +from llmcompressor.transformers import oneshot +from tests.testing_utils import preprocess_tokenize_dataset + + +def run_oneshot_for_e2e_testing( + model: str, + device: str, + num_calibration_samples: int, + max_seq_length: int, + dataset_id: str, + recipe: str, + dataset_split: str, + dataset_config: str, + scheme: str, + quant_type: str, +): + # Load model. + oneshot_kwargs = {} + loaded_model = AutoModelForCausalLM.from_pretrained( + model, device_map=device, torch_dtype="auto" + ) + tokenizer = AutoTokenizer.from_pretrained(model) + + if dataset_id: + ds = load_dataset(dataset_id, name=dataset_config, split=dataset_split) + ds = ds.shuffle(seed=42).select(range(num_calibration_samples)) + ds = preprocess_tokenize_dataset(ds, tokenizer, max_seq_length) + oneshot_kwargs["dataset"] = ds + oneshot_kwargs["max_seq_length"] = max_seq_length + oneshot_kwargs["num_calibration_samples"] = num_calibration_samples + + oneshot_kwargs["model"] = loaded_model + if recipe: + oneshot_kwargs["recipe"] = recipe + else: + # Test assumes that if a recipe was not provided, using + # a compatible preset sceme + if quant_type == "GPTQ": + oneshot_kwargs["recipe"] = GPTQModifier( + targets="Linear", scheme=scheme, ignore=["lm_head"] + ) + else: + oneshot_kwargs["recipe"] = QuantizationModifier( + targets="Linear", scheme=scheme, ignore=["lm_head"] + ) + + # Apply quantization. 
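+    # `oneshot` applies the recipe (or preset scheme) to the loaded model in a
+    # single post-training calibration pass; no training loop is run here.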
+ logger.info("ONESHOT KWARGS", oneshot_kwargs) + oneshot( + **oneshot_kwargs, + oneshot_device=device, + ) + return oneshot_kwargs["model"], tokenizer diff --git a/tests/e2e/vLLM/configs/FP8/fp8_dynamic_per_token.yaml b/tests/e2e/vLLM/configs/fp8_dynamic_per_token.yaml similarity index 100% rename from tests/e2e/vLLM/configs/FP8/fp8_dynamic_per_token.yaml rename to tests/e2e/vLLM/configs/fp8_dynamic_per_token.yaml diff --git a/tests/e2e/vLLM/configs/FP8/fp8_static_per_tensor.yaml b/tests/e2e/vLLM/configs/fp8_static_per_tensor.yaml similarity index 100% rename from tests/e2e/vLLM/configs/FP8/fp8_static_per_tensor.yaml rename to tests/e2e/vLLM/configs/fp8_static_per_tensor.yaml diff --git a/tests/e2e/vLLM/configs/FP8/fp8_weight_only_channel.yaml b/tests/e2e/vLLM/configs/fp8_weight_only_channel.yaml similarity index 100% rename from tests/e2e/vLLM/configs/FP8/fp8_weight_only_channel.yaml rename to tests/e2e/vLLM/configs/fp8_weight_only_channel.yaml diff --git a/tests/e2e/vLLM/configs/FP8/fp8_weight_only_tensor.yaml b/tests/e2e/vLLM/configs/fp8_weight_only_tensor.yaml similarity index 100% rename from tests/e2e/vLLM/configs/FP8/fp8_weight_only_tensor.yaml rename to tests/e2e/vLLM/configs/fp8_weight_only_tensor.yaml diff --git a/tests/e2e/vLLM/configs/INT8/int8_channel_weight_static_per_tensor_act.yaml b/tests/e2e/vLLM/configs/int8_channel_weight_static_per_tensor_act.yaml similarity index 100% rename from tests/e2e/vLLM/configs/INT8/int8_channel_weight_static_per_tensor_act.yaml rename to tests/e2e/vLLM/configs/int8_channel_weight_static_per_tensor_act.yaml diff --git a/tests/e2e/vLLM/configs/INT8/int8_dynamic_per_token.yaml b/tests/e2e/vLLM/configs/int8_dynamic_per_token.yaml similarity index 100% rename from tests/e2e/vLLM/configs/INT8/int8_dynamic_per_token.yaml rename to tests/e2e/vLLM/configs/int8_dynamic_per_token.yaml diff --git a/tests/e2e/vLLM/configs/INT8/int8_tensor_weight_static_per_tensor_act.yaml b/tests/e2e/vLLM/configs/int8_tensor_weight_static_per_tensor_act.yaml similarity index 100% rename from tests/e2e/vLLM/configs/INT8/int8_tensor_weight_static_per_tensor_act.yaml rename to tests/e2e/vLLM/configs/int8_tensor_weight_static_per_tensor_act.yaml diff --git a/tests/e2e/vLLM/configs/WNA16_2of4/w4a16_2of4_channel_quant.yaml b/tests/e2e/vLLM/configs/w4a16_2of4_channel_quant.yaml similarity index 100% rename from tests/e2e/vLLM/configs/WNA16_2of4/w4a16_2of4_channel_quant.yaml rename to tests/e2e/vLLM/configs/w4a16_2of4_channel_quant.yaml diff --git a/tests/e2e/vLLM/configs/WNA16_2of4/w4a16_2of4_grouped_quant.yaml b/tests/e2e/vLLM/configs/w4a16_2of4_grouped_quant.yaml similarity index 100% rename from tests/e2e/vLLM/configs/WNA16_2of4/w4a16_2of4_grouped_quant.yaml rename to tests/e2e/vLLM/configs/w4a16_2of4_grouped_quant.yaml diff --git a/tests/e2e/vLLM/configs/actorder/w4a16_actorder_group.yaml b/tests/e2e/vLLM/configs/w4a16_actorder_group.yaml similarity index 90% rename from tests/e2e/vLLM/configs/actorder/w4a16_actorder_group.yaml rename to tests/e2e/vLLM/configs/w4a16_actorder_group.yaml index ddc9fc803..bb02c51ef 100644 --- a/tests/e2e/vLLM/configs/actorder/w4a16_actorder_group.yaml +++ b/tests/e2e/vLLM/configs/w4a16_actorder_group.yaml @@ -5,5 +5,5 @@ recipe: tests/e2e/vLLM/recipes/actorder/recipe_w4a16_actorder_group.yaml dataset_id: openai/gsm8k dataset_config: main dataset_split: train -scheme: W4A16 +scheme: W4A16_actorder_group save_dir: TinyLlama-1.1B-Chat-v1.0-actorder-group \ No newline at end of file diff --git 
a/tests/e2e/vLLM/configs/actorder/w4a16_actorder_weight.yaml b/tests/e2e/vLLM/configs/w4a16_actorder_weight.yaml similarity index 90% rename from tests/e2e/vLLM/configs/actorder/w4a16_actorder_weight.yaml rename to tests/e2e/vLLM/configs/w4a16_actorder_weight.yaml index 7362be296..318e4706e 100644 --- a/tests/e2e/vLLM/configs/actorder/w4a16_actorder_weight.yaml +++ b/tests/e2e/vLLM/configs/w4a16_actorder_weight.yaml @@ -5,5 +5,5 @@ recipe: tests/e2e/vLLM/recipes/actorder/recipe_w4a16_actorder_weight.yaml dataset_id: openai/gsm8k dataset_config: main dataset_split: train -scheme: W4A16 +scheme: W4A16_actorder_weight save_dir: TinyLlama-1.1B-Chat-v1.0-actorder-weight \ No newline at end of file diff --git a/tests/e2e/vLLM/configs/WNA16/w4a16_channel_quant.yaml b/tests/e2e/vLLM/configs/w4a16_channel_quant.yaml similarity index 100% rename from tests/e2e/vLLM/configs/WNA16/w4a16_channel_quant.yaml rename to tests/e2e/vLLM/configs/w4a16_channel_quant.yaml diff --git a/tests/e2e/vLLM/configs/WNA16/w4a16_grouped_quant.yaml b/tests/e2e/vLLM/configs/w4a16_grouped_quant.yaml similarity index 76% rename from tests/e2e/vLLM/configs/WNA16/w4a16_grouped_quant.yaml rename to tests/e2e/vLLM/configs/w4a16_grouped_quant.yaml index bbd1406ce..6a53963e0 100644 --- a/tests/e2e/vLLM/configs/WNA16/w4a16_grouped_quant.yaml +++ b/tests/e2e/vLLM/configs/w4a16_grouped_quant.yaml @@ -3,4 +3,5 @@ test_type: "regression" model: TinyLlama/TinyLlama-1.1B-Chat-v1.0 scheme: W4A16 dataset_id: HuggingFaceH4/ultrachat_200k -dataset_split: train_sft \ No newline at end of file +dataset_split: train_sft +quant_type: "GPTQ" \ No newline at end of file diff --git a/tests/e2e/vLLM/configs/WNA16/w8a16_channel_quant.yaml b/tests/e2e/vLLM/configs/w8a16_channel_quant.yaml similarity index 100% rename from tests/e2e/vLLM/configs/WNA16/w8a16_channel_quant.yaml rename to tests/e2e/vLLM/configs/w8a16_channel_quant.yaml diff --git a/tests/e2e/vLLM/configs/WNA16/w8a16_grouped_quant.yaml b/tests/e2e/vLLM/configs/w8a16_grouped_quant.yaml similarity index 76% rename from tests/e2e/vLLM/configs/WNA16/w8a16_grouped_quant.yaml rename to tests/e2e/vLLM/configs/w8a16_grouped_quant.yaml index 4e9a278a5..44fd79032 100644 --- a/tests/e2e/vLLM/configs/WNA16/w8a16_grouped_quant.yaml +++ b/tests/e2e/vLLM/configs/w8a16_grouped_quant.yaml @@ -3,4 +3,5 @@ test_type: "regression" model: TinyLlama/TinyLlama-1.1B-Chat-v1.0 scheme: W8A16 dataset_id: HuggingFaceH4/ultrachat_200k -dataset_split: train_sft \ No newline at end of file +dataset_split: train_sft +quant_type: "GPTQ" \ No newline at end of file diff --git a/tests/e2e/vLLM/lm_eval_configs/fp8_dynamic_per_token.yaml b/tests/e2e/vLLM/lm_eval_configs/fp8_dynamic_per_token.yaml new file mode 100644 index 000000000..461353770 --- /dev/null +++ b/tests/e2e/vLLM/lm_eval_configs/fp8_dynamic_per_token.yaml @@ -0,0 +1,8 @@ +cadence: "weekly" +model: meta-llama/Meta-Llama-3-8B-Instruct +scheme: FP8_DYNAMIC +num_fewshot: 5 +limit: 1000 +task: "gsm8k" +exact_match,flexible-extract: 0.753 +exact_match,strict-match: 0.753 diff --git a/tests/e2e/vLLM/lm_eval_configs/int8_w8a8_dynamic_per_token.yaml b/tests/e2e/vLLM/lm_eval_configs/int8_w8a8_dynamic_per_token.yaml new file mode 100644 index 000000000..b16f5575a --- /dev/null +++ b/tests/e2e/vLLM/lm_eval_configs/int8_w8a8_dynamic_per_token.yaml @@ -0,0 +1,8 @@ +cadence: "weekly" +model: meta-llama/Meta-Llama-3-8B-Instruct +scheme: INT8 +num_fewshot: 5 +limit: 250 +task: "gsm8k" +exact_match,flexible-extract: 0.728 +exact_match,strict-match: 0.728 diff --git 
a/tests/e2e/vLLM/recipes/INT8/recipe_int8_channel_weight_static_per_tensor_act.yaml b/tests/e2e/vLLM/recipes/INT8/recipe_int8_channel_weight_static_per_tensor_act.yaml index 7528f7dfb..2c0094f88 100644 --- a/tests/e2e/vLLM/recipes/INT8/recipe_int8_channel_weight_static_per_tensor_act.yaml +++ b/tests/e2e/vLLM/recipes/INT8/recipe_int8_channel_weight_static_per_tensor_act.yaml @@ -1,5 +1,7 @@ quant_stage: quant_modifiers: + SmoothQuantModifier: + smoothing_strength: 0.8 QuantizationModifier: ignore: [lm_head] config_groups: diff --git a/tests/e2e/vLLM/recipes/INT8/recipe_int8_tensor_weight_static_per_tensor_act.yaml b/tests/e2e/vLLM/recipes/INT8/recipe_int8_tensor_weight_static_per_tensor_act.yaml index 16e39d8b0..4473829e1 100644 --- a/tests/e2e/vLLM/recipes/INT8/recipe_int8_tensor_weight_static_per_tensor_act.yaml +++ b/tests/e2e/vLLM/recipes/INT8/recipe_int8_tensor_weight_static_per_tensor_act.yaml @@ -1,5 +1,7 @@ quant_stage: quant_modifiers: + SmoothQuantModifier: + smoothing_strength: 0.8 QuantizationModifier: ignore: [lm_head] config_groups: diff --git a/tests/e2e/vLLM/recipes/WNA16/recipe_w4a16_channel_quant.yaml b/tests/e2e/vLLM/recipes/WNA16/recipe_w4a16_channel_quant.yaml index 1c4fcf7ab..8a5302c7f 100644 --- a/tests/e2e/vLLM/recipes/WNA16/recipe_w4a16_channel_quant.yaml +++ b/tests/e2e/vLLM/recipes/WNA16/recipe_w4a16_channel_quant.yaml @@ -1,6 +1,6 @@ quant_stage: quant_modifiers: - QuantizationModifier: + GPTQModifier: ignore: [lm_head] config_groups: group_0: diff --git a/tests/e2e/vLLM/recipes/WNA16/recipe_w8a16_channel_quant.yaml b/tests/e2e/vLLM/recipes/WNA16/recipe_w8a16_channel_quant.yaml index ecf57221a..f7d1b742b 100644 --- a/tests/e2e/vLLM/recipes/WNA16/recipe_w8a16_channel_quant.yaml +++ b/tests/e2e/vLLM/recipes/WNA16/recipe_w8a16_channel_quant.yaml @@ -1,6 +1,6 @@ quant_stage: quant_modifiers: - QuantizationModifier: + GPTQModifier: ignore: [lm_head] config_groups: group_0: diff --git a/tests/e2e/vLLM/run_tests.sh b/tests/e2e/vLLM/run_tests.sh new file mode 100644 index 000000000..6f19acedb --- /dev/null +++ b/tests/e2e/vLLM/run_tests.sh @@ -0,0 +1,43 @@ +#!/bin/bash + +SUCCESS=0 + +while getopts "c:t:" OPT; do + case ${OPT} in + c ) + CONFIG="$OPTARG" + ;; + t ) + TEST="$OPTARG" + ;; + \? ) + exit 1 + ;; + esac +done + +# Parse list of configs. +for MODEL_CONFIG in "$CONFIG"/* +do + LOCAL_SUCCESS=0 + + echo "=== RUNNING MODEL: $MODEL_CONFIG ===" + + export TEST_DATA_FILE="$MODEL_CONFIG" + pytest \ + -r a \ + --capture=tee-sys \ + --junitxml="test-results/e2e-$(date +%s).xml" \ + "$TEST" || LOCAL_SUCCESS=$? + + if [[ $LOCAL_SUCCESS == 0 ]]; then + echo "=== PASSED MODEL: $MODEL_CONFIG ===" + else + echo "=== FAILED MODEL: $MODEL_CONFIG ===" + fi + + SUCCESS=$((SUCCESS + LOCAL_SUCCESS)) + +done + +exit "$SUCCESS" diff --git a/tests/e2e/vLLM/test_lmeval.py b/tests/e2e/vLLM/test_lmeval.py new file mode 100644 index 000000000..f77bda983 --- /dev/null +++ b/tests/e2e/vLLM/test_lmeval.py @@ -0,0 +1,131 @@ +import os +import shutil +from pathlib import Path + +import numpy +import pytest +import yaml +from loguru import logger + +from llmcompressor.core import active_session +from tests.e2e.e2e_utils import run_oneshot_for_e2e_testing +from tests.examples.utils import requires_gpu_count + +try: + import lm_eval + + lm_eval_installed = True +except ImportError: + lm_eval_installed = False + logger.warning("lm_eval is not installed. 
This test will be skipped") + +TEST_DATA_FILE = os.environ.get("TEST_DATA_FILE", None) + + +# Will run each test case in its own process through run_tests.sh +# emulating vLLM CI testing +@requires_gpu_count(1) +@pytest.mark.skipif( + not lm_eval_installed, reason="lm eval is not installed, skipping test" +) +class TestLMEval: + """ + The following test quantizes a model using a preset scheme or recipe, + and then evaluates the model using LM Eval. Each test case is focused on a + specific quantization type (e.g. W4A16 with grouped quantization, + W4A16 with channel quantization). To add a new test case, a new config has to be + added to the lm_eval_configs folder. The tests run on a cadence defined by the + `cadence` field. Each config defines the model to quantize. Optionally, a dataset + id and split can be provided for calibration. Finally, all config files must list + a scheme. The scheme can be a preset scheme from + https://github.com/neuralmagic/compressed-tensors/blob/main/src/compressed_tensors/quantization/quant_scheme.py + or another identifier which can be used for the particular test case. If a recipe + is not provided, it is assumed that the scheme provided is a preset scheme and will + be used for quantization. Otherwise, the recipe will always be used if given. + """ # noqa: E501 + + def set_up(self): + eval_config = yaml.safe_load(Path(TEST_DATA_FILE).read_text(encoding="utf-8")) + + if os.environ.get("CADENCE", "commit") != eval_config.get("cadence"): + pytest.skip("Skipping test; cadence mismatch") + + self.model = eval_config["model"] + self.scheme = eval_config.get("scheme") + self.dataset_id = eval_config.get("dataset_id") + self.dataset_config = eval_config.get("dataset_config") + self.dataset_split = eval_config.get("dataset_split") + self.recipe = eval_config.get("recipe") + self.quant_type = eval_config.get("quant_type") + self.save_dir = eval_config.get("save_dir") + self.task = eval_config.get("task") + self.num_fewshot = eval_config.get("num_fewshot") + self.limit = eval_config.get("limit") + self.exact_flex = eval_config.get("exact_match,flexible-extract") + self.exact_strict = eval_config.get("exact_match,strict-match") + + logger.info("========== RUNNING ==============") + logger.info(self.scheme) + + self.device = "cuda:0" + self.num_calibration_samples = 256 + self.max_seq_length = 2048 + + def test_lm_eval(self): + # Quantize the model, then evaluate the saved checkpoint with lm_eval + self.set_up() + if not self.save_dir: + self.save_dir = self.model.split("/")[1] + f"-{self.scheme}" + oneshot_model, tokenizer = run_oneshot_for_e2e_testing( + model=self.model, + device=self.device, + num_calibration_samples=self.num_calibration_samples, + max_seq_length=self.max_seq_length, + scheme=self.scheme, + dataset_id=self.dataset_id, + dataset_config=self.dataset_config, + dataset_split=self.dataset_split, + recipe=self.recipe, + quant_type=self.quant_type, + ) + + logger.info("================= SAVING TO DISK ======================") + oneshot_model.save_pretrained(self.save_dir) + tokenizer.save_pretrained(self.save_dir) + recipe_path = os.path.join(self.save_dir, "recipe.yaml") + + # Use the session to fetch the recipe; + # Reset session for next test case + session = active_session() + recipe_yaml_str = session.get_serialized_recipe() + with open(recipe_path, "w") as fp: + fp.write(recipe_yaml_str) + session.reset() + + logger.info("================= Running LM Eval ======================") + + model_args = f"pretrained={self.save_dir}" + results = lm_eval.simple_evaluate( + model="hf",
model_args=model_args, + tasks=[self.task], + num_fewshot=self.num_fewshot, + limit=self.limit, + device="cuda:0", + batch_size=100, + ) + + metrics = results["results"][self.task] + exact_match_strict = metrics.get("exact_match,strict-match") + exact_match_flex = metrics.get("exact_match,flexible-extract") + logger.info("Exact Match, Strict") + logger.info(exact_match_strict) + logger.info("Exact Match, Flex") + logger.info(exact_match_flex) + assert numpy.isclose(exact_match_strict, self.exact_strict, rtol=0.05) + assert numpy.isclose(exact_match_flex, self.exact_flex, rtol=0.05) + self.tear_down() + + def tear_down(self): + if self.save_dir is not None: + shutil.rmtree(self.save_dir) diff --git a/tests/e2e/vLLM/test_vllm.py b/tests/e2e/vLLM/test_vllm.py index 6fcc901ac..aab2a7ed8 100644 --- a/tests/e2e/vLLM/test_vllm.py +++ b/tests/e2e/vLLM/test_vllm.py @@ -1,24 +1,17 @@ import os import re import shutil -import unittest +from pathlib import Path from typing import Callable import pytest -from datasets import load_dataset +import yaml +from huggingface_hub import HfApi from loguru import logger -from parameterized import parameterized, parameterized_class -from transformers import AutoTokenizer from llmcompressor.core import active_session -from llmcompressor.modifiers.quantization import QuantizationModifier -from llmcompressor.transformers import SparseAutoModelForCausalLM, oneshot -from tests.testing_utils import ( - parse_params, - preprocess_tokenize_dataset, - requires_gpu, - requires_torch, -) +from tests.e2e.e2e_utils import run_oneshot_for_e2e_testing +from tests.examples.utils import requires_gpu_count try: from vllm import LLM, SamplingParams @@ -28,16 +21,8 @@ vllm_installed = False logger.warning("vllm is not installed. This test will be skipped") -# Defines the file paths to the directories containing the test configs -# for each of the quantization schemes -WNA16 = "tests/e2e/vLLM/configs/WNA16" -FP8 = "tests/e2e/vLLM/configs/FP8" -INT8 = "tests/e2e/vLLM/configs/INT8" -ACTORDER = "tests/e2e/vLLM/configs/actorder" -WNA16_2of4 = "tests/e2e/vLLM/configs/WNA16_2of4" -CONFIGS = [WNA16, FP8, INT8, ACTORDER, WNA16_2of4] - HF_MODEL_HUB_NAME = "nm-testing" +TEST_DATA_FILE = os.environ.get("TEST_DATA_FILE", "") EXPECTED_SAVED_FILES = [ "config.json", @@ -47,54 +32,52 @@ ] -def gen_test_name(testcase_func: Callable, param_num: int, param: dict) -> str: - return "_".join( - [ - testcase_func.__name__, - parameterized.to_safe_name( - param.get("testconfig_path", "").split("configs/")[-1] - ), - param.get("cadence", "").lower(), - ] - ) +@pytest.fixture +def record_config_file(record_testsuite_property: Callable[[str, object], None]): + test_data_file_name = TEST_DATA_FILE.split("configs/")[-1] + record_testsuite_property("TEST_DATA_FILE_NAME", test_data_file_name) -@requires_gpu -@requires_torch -@pytest.mark.skipif(not vllm_installed, reason="vLLM is not installed, skipping test") -@parameterized_class(parse_params(CONFIGS), class_name_func=gen_test_name) -class TestvLLM(unittest.TestCase): +# Will run each test case in its own process through run_tests.sh +# emulating vLLM CI testing +@requires_gpu_count(1) +# @pytest.mark.skipif(not vllm_installed, reason="vLLM is not installed, skipping test") +class TestvLLM: """ The following test quantizes a model using a preset scheme or recipe, runs the model using vLLM, and then pushes the model to the hub for future use. 
Each test case is focused on a specific quantization type (e.g W4A16 with grouped quantization, W4N16 with channel quantization). - To add a new test case, a new config has to be added to one of the folders - listed in the `CONFIGS` folder. If the test case is for a data type not listed - in `CONFIGS`, a new folder can be created and added to the list. The tests - run on a cadence defined by the `cadence` field. Each config defines the model - to quantize. Optionally, a dataset id and split can be provided for calibration. - Finally, all config files must list a scheme. The scheme can be a preset scheme - from https://github.com/neuralmagic/compressed-tensors/blob/main/src/compressed_tensors/quantization/quant_scheme.py + To add a new test case, a new config has to be added to the `configs` folder. + The tests run on a cadence defined by the `cadence` field. Each config defines + the model to quantize. Optionally, a dataset id and split can be provided for + calibration. Finally, all config files must list a scheme. The scheme can be a + preset scheme from + https://github.com/neuralmagic/compressed-tensors/blob/main/src/compressed_tensors/quantization/quant_scheme.py or another identifier which can be used for the particular test case. If a recipe is not provided, it is assumed that the scheme provided is a preset scheme and will be used for quantization. Otherwise, the recipe will always be used if given. """ # noqa: E501 - model = None - scheme = None - dataset_id = None - dataset_config = None - dataset_split = None - recipe = None - save_dir = None + def set_up(self): + eval_config = yaml.safe_load(Path(TEST_DATA_FILE).read_text(encoding="utf-8")) + + if os.environ.get("CADENCE", "commit") != eval_config.get("cadence"): + pytest.skip("Skipping test; cadence mismatch") + + self.model = eval_config["model"] + self.scheme = eval_config.get("scheme") + self.dataset_id = eval_config.get("dataset_id") + self.dataset_config = eval_config.get("dataset_config") + self.dataset_split = eval_config.get("dataset_split") + self.recipe = eval_config.get("recipe") + self.quant_type = eval_config.get("quant_type") + self.save_dir = eval_config.get("save_dir") - def setUp(self): logger.info("========== RUNNING ==============") - logger.debug(self.scheme) + logger.info(self.scheme) self.device = "cuda:0" - self.oneshot_kwargs = {} self.num_calibration_samples = 256 self.max_seq_length = 2048 self.prompts = [ @@ -102,54 +85,36 @@ def setUp(self): "The president of the US is", "My name is", ] - self.session = active_session() + self.api = HfApi() + @pytest.mark.usefixtures("record_config_file") def test_vllm(self): + # Run vLLM with saved model import torch - # Load model. 
- loaded_model = SparseAutoModelForCausalLM.from_pretrained( - self.model, device_map=self.device, torch_dtype="auto" - ) - tokenizer = AutoTokenizer.from_pretrained(self.model) - - if self.dataset_id: - ds = load_dataset( - self.dataset_id, name=self.dataset_config, split=self.dataset_split - ) - ds = ds.shuffle(seed=42).select(range(self.num_calibration_samples)) - ds = preprocess_tokenize_dataset(ds, tokenizer, self.max_seq_length) - self.oneshot_kwargs["dataset"] = ds - self.oneshot_kwargs["max_seq_length"] = self.max_seq_length - self.oneshot_kwargs["num_calibration_samples"] = ( - self.num_calibration_samples - ) - - if self.save_dir is None: + self.set_up() + if not self.save_dir: self.save_dir = self.model.split("/")[1] + f"-{self.scheme}" - - self.oneshot_kwargs["model"] = loaded_model - if self.recipe: - self.oneshot_kwargs["recipe"] = self.recipe - else: - # Test assumes that if a recipe was not provided, using - # a compatible preset sceme - self.oneshot_kwargs["recipe"] = QuantizationModifier( - targets="Linear", scheme=self.scheme, ignore=["lm_head"] - ) - - # Apply quantization. - logger.debug("ONESHOT KWARGS", self.oneshot_kwargs) - oneshot( - **self.oneshot_kwargs, - oneshot_device=self.device, + oneshot_model, tokenizer = run_oneshot_for_e2e_testing( + model=self.model, + device=self.device, + num_calibration_samples=self.num_calibration_samples, + max_seq_length=self.max_seq_length, + scheme=self.scheme, + dataset_id=self.dataset_id, + dataset_config=self.dataset_config, + dataset_split=self.dataset_split, + recipe=self.recipe, + quant_type=self.quant_type, ) # check that session contains recipe self._check_session_contains_recipe() - self.oneshot_kwargs["model"].save_pretrained(self.save_dir) + logger.info("================= SAVING TO DISK ======================") + oneshot_model.save_pretrained(self.save_dir) tokenizer.save_pretrained(self.save_dir) + recipe_path = os.path.join(self.save_dir, "recipe.yaml") # check that expected files exist self._check_save_dir_has_expected_files() @@ -157,8 +122,23 @@ def test_vllm(self): # Reset after session info is extracted on save -- recipe self.session.reset() - # Run vLLM with saved model + # Use the session to fetch the recipe; + # Reset session for next test case + session = active_session() + recipe_yaml_str = session.get_serialized_recipe() + with open(recipe_path, "w") as fp: + fp.write(recipe_yaml_str) + session.reset() + + logger.info("================= UPLOADING TO HUB ======================") + + self.api.upload_folder( + repo_id=f"{HF_MODEL_HUB_NAME}/{self.save_dir}-e2e", + folder_path=self.save_dir, + ) + logger.info("================= RUNNING vLLM =========================") + sampling_params = SamplingParams(temperature=0.80, top_p=0.95) if "W4A16_2of4" in self.scheme: # required by the kernel @@ -172,15 +152,15 @@ def test_vllm(self): assert output prompt = output.prompt generated_text = output.outputs[0].text - logger.debug("PROMPT", prompt) - logger.debug("GENERATED TEXT", generated_text) - logger.info("================= UPLOADING TO HUB ======================") - hf_upload_path = os.path.join(HF_MODEL_HUB_NAME, f"{self.save_dir}-e2e") - self.oneshot_kwargs["model"].push_to_hub(hf_upload_path) - tokenizer.push_to_hub(hf_upload_path) + logger.info("PROMPT") + logger.info(prompt) + logger.info("GENERATED TEXT") + logger.info(generated_text) + + self.tear_down() - def tearDown(self): + def tear_down(self): if self.save_dir is not None: shutil.rmtree(self.save_dir) diff --git 
a/tests/llmcompressor/modifiers/calibration/test_kv_cache.py b/tests/llmcompressor/modifiers/calibration/test_kv_cache.py index d9fca8fa2..25b8468f4 100644 --- a/tests/llmcompressor/modifiers/calibration/test_kv_cache.py +++ b/tests/llmcompressor/modifiers/calibration/test_kv_cache.py @@ -54,9 +54,9 @@ def _prep_for_calibration(module: torch.nn.Module): if is_attention_module(module): module.register_forward_pre_hook( - calibrate_kv_cache_input_hook(), with_kwargs=True + calibrate_kv_cache_input_hook, with_kwargs=True ) - module.register_forward_hook(calibrate_kv_cache_output_hook()) + module.register_forward_hook(calibrate_kv_cache_output_hook) module.quantization_status = QuantizationStatus.CALIBRATION diff --git a/tests/llmcompressor/modifiers/utils/test_hooks.py b/tests/llmcompressor/modifiers/utils/test_hooks.py new file mode 100644 index 000000000..5c4fc5891 --- /dev/null +++ b/tests/llmcompressor/modifiers/utils/test_hooks.py @@ -0,0 +1,83 @@ +import torch + +from llmcompressor.modifiers.utils.hooks import HooksMixin + + +class DummyModel(torch.nn.Module): + """Dummy Model for testing hooks""" + + def __init__(self): + super(DummyModel, self).__init__() + + self.linear1 = torch.nn.Linear(1, 2) + self.linear2 = torch.nn.Linear(2, 3) + self.linear3 = torch.nn.Linear(3, 1) + self.dummy_inputs = torch.tensor([0.0]) + + def forward(self, x): + x = self.linear1(x) + x = self.linear2(x) + x = self.linear3(x) + + return x + + +class DummyMod(HooksMixin): + hook_called: bool = False + + def hook(self, *args, **kwargs): + self.hook_called = True + + +class ModA(DummyMod): + pass + + +class ModB(DummyMod): + pass + + +def test_register_hook(): + model = DummyModel() + + mod_a = ModA() + mod_a.register_hook(model.linear1, mod_a.hook, "forward") + + mod_b = ModB() + mod_b.register_hook(model.linear2, mod_b.hook, "forward_pre") + + model(model.dummy_inputs) + assert mod_a.hook_called and mod_b.hook_called + + +def test_remove_hooks(): + model = DummyModel() + + mod_a = ModA() + mod_a.register_hook(model.linear1, mod_a.hook, "forward") + + mod_b = ModB() + mod_b.register_hook(model.linear2, mod_b.hook, "forward_pre") + mod_b.remove_hooks() + + model(model.dummy_inputs) + assert mod_a.hook_called and not mod_b.hook_called + + +def test_disable_hooks(): + model = DummyModel() + + mod_a = ModA() + mod_a.register_hook(model.linear1, mod_a.hook, "forward") + + mod_b = ModB() + mod_b.register_hook(model.linear2, mod_b.hook, "forward_pre") + + with HooksMixin.disable_hooks(): + model(model.dummy_inputs) + assert not mod_a.hook_called and not mod_b.hook_called + + mod_a.hook_called = False + mod_b.hook_called = False + model(model.dummy_inputs) + assert mod_a.hook_called and mod_b.hook_called diff --git a/tests/llmcompressor/observers/test_helpers.py b/tests/llmcompressor/observers/test_helpers.py index 6668223f7..527176019 100644 --- a/tests/llmcompressor/observers/test_helpers.py +++ b/tests/llmcompressor/observers/test_helpers.py @@ -32,7 +32,7 @@ def _prep_for_input_quant_calibration(module: torch.nn.Module): if not quantization_scheme: return - module.register_forward_pre_hook(calibrate_input_hook()) + module.register_forward_pre_hook(calibrate_input_hook) module.quantization_status = QuantizationStatus.CALIBRATION diff --git a/tests/llmcompressor/pytorch/modifiers/pruning/sparsegpt/test_pytorch.py b/tests/llmcompressor/pytorch/modifiers/pruning/sparsegpt/test_pytorch.py index 5421af4cf..1a229a6aa 100644 --- a/tests/llmcompressor/pytorch/modifiers/pruning/sparsegpt/test_pytorch.py +++ 
b/tests/llmcompressor/pytorch/modifiers/pruning/sparsegpt/test_pytorch.py @@ -13,11 +13,9 @@ setup_modifier_factory, ) from tests.llmcompressor.pytorch.helpers import LinearNet -from tests.testing_utils import requires_torch @pytest.mark.unit -@requires_torch class TestInvalidLayerwiseRecipesRaiseExceptions(unittest.TestCase): def setUp(self): setup_modifier_factory() @@ -45,7 +43,6 @@ def test_invalid_layerwise_recipes_raise_exceptions(self, sparsity, targets): @pytest.mark.unit -@requires_torch class TestSuccessfulLayerwiseRecipe(unittest.TestCase): def setUp(self): setup_modifier_factory() @@ -66,7 +63,6 @@ def test_successful_layerwise_recipe(self): @pytest.mark.unit -@requires_torch class TestCreateDefaultQuantModifier(unittest.TestCase): def setUp(self): setup_modifier_factory() @@ -91,7 +87,6 @@ def test_create_default_quant_modifier(self): @pytest.mark.unit -@requires_torch class TestSetQuantIfModifierAlreadyExists(unittest.TestCase): def setUp(self): setup_modifier_factory() diff --git a/tests/llmcompressor/pytorch/modifiers/pruning/wanda/test_pytorch.py b/tests/llmcompressor/pytorch/modifiers/pruning/wanda/test_pytorch.py index 4fac2b12f..b2050c179 100644 --- a/tests/llmcompressor/pytorch/modifiers/pruning/wanda/test_pytorch.py +++ b/tests/llmcompressor/pytorch/modifiers/pruning/wanda/test_pytorch.py @@ -4,11 +4,9 @@ from llmcompressor.modifiers.factory import ModifierFactory from tests.llmcompressor.modifiers.conf import setup_modifier_factory -from tests.testing_utils import requires_torch @pytest.mark.unit -@requires_torch class TestWandaPytorchIsRegistered(unittest.TestCase): def setUp(self): self.kwargs = dict( diff --git a/tests/llmcompressor/pytorch/modifiers/smoothquant/test_pytorch.py b/tests/llmcompressor/pytorch/modifiers/smoothquant/test_pytorch.py index c521a1361..7977c4546 100644 --- a/tests/llmcompressor/pytorch/modifiers/smoothquant/test_pytorch.py +++ b/tests/llmcompressor/pytorch/modifiers/smoothquant/test_pytorch.py @@ -6,11 +6,9 @@ from llmcompressor.core import State from llmcompressor.modifiers.smoothquant import SmoothQuantModifier from tests.llmcompressor.pytorch.helpers import LinearNet -from tests.testing_utils import requires_torch @pytest.mark.unit -@requires_torch class TestSmoothQuantMapping(unittest.TestCase): def setUp(self): self.model = LinearNet() diff --git a/tests/llmcompressor/transformers/compression/test_quantization.py b/tests/llmcompressor/transformers/compression/test_quantization.py index 2dd1249d6..c0f0d2c02 100644 --- a/tests/llmcompressor/transformers/compression/test_quantization.py +++ b/tests/llmcompressor/transformers/compression/test_quantization.py @@ -14,12 +14,11 @@ from llmcompressor.transformers import oneshot from llmcompressor.transformers.finetune.data import TextGenerationDataset from llmcompressor.transformers.finetune.data.data_args import DataTrainingArguments -from tests.testing_utils import parse_params, requires_gpu, requires_torch +from tests.testing_utils import parse_params, requires_gpu CONFIGS_DIRECTORY = "tests/llmcompressor/transformers/compression/configs" -@requires_torch @requires_gpu @pytest.mark.integration @parameterized_class(parse_params(CONFIGS_DIRECTORY)) diff --git a/tests/llmcompressor/transformers/compression/test_run_compressed.py b/tests/llmcompressor/transformers/compression/test_run_compressed.py index 33a31c332..0c2a0ab0e 100644 --- a/tests/llmcompressor/transformers/compression/test_run_compressed.py +++ b/tests/llmcompressor/transformers/compression/test_run_compressed.py @@ -9,12 +9,11 
@@ from parameterized import parameterized_class from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer -from tests.testing_utils import parse_params, requires_gpu, requires_torch +from tests.testing_utils import parse_params, requires_gpu CONFIG_DIR = "tests/llmcompressor/transformers/compression/run_compressed_configs" -@requires_torch @requires_gpu @parameterized_class(parse_params(CONFIG_DIR)) class TestQuantizationMatches(unittest.TestCase): diff --git a/tests/llmcompressor/transformers/finetune/data/test_dataset_loading.py b/tests/llmcompressor/transformers/finetune/data/test_dataset_loading.py index 3415858af..a602c4828 100644 --- a/tests/llmcompressor/transformers/finetune/data/test_dataset_loading.py +++ b/tests/llmcompressor/transformers/finetune/data/test_dataset_loading.py @@ -8,7 +8,6 @@ from llmcompressor.transformers.finetune.data.data_args import DataTrainingArguments from llmcompressor.transformers.finetune.runner import StageRunner from llmcompressor.transformers.finetune.training_args import TrainingArguments -from tests.testing_utils import requires_torch @pytest.mark.unit @@ -283,7 +282,6 @@ def test_split_loading(self, split_def): self.assertIsInstance(train_dataset[0], dict) -@requires_torch @pytest.mark.unit class TestTokenizationDataset(unittest.TestCase): @pytest.fixture(autouse=True) diff --git a/tests/llmcompressor/transformers/finetune/test_finetune_no_recipe_custom_dataset.py b/tests/llmcompressor/transformers/finetune/test_finetune_no_recipe_custom_dataset.py index 5f00c4c28..e7c8e7b9a 100644 --- a/tests/llmcompressor/transformers/finetune/test_finetune_no_recipe_custom_dataset.py +++ b/tests/llmcompressor/transformers/finetune/test_finetune_no_recipe_custom_dataset.py @@ -10,7 +10,7 @@ import pytest from parameterized import parameterized_class -from tests.testing_utils import parse_params, requires_gpu, requires_torch +from tests.testing_utils import parse_params, requires_gpu CONFIGS_DIRECTORY = "tests/llmcompressor/transformers/finetune/finetune_custom" GPU_CONFIGS_DIRECTORY = "tests/llmcompressor/transformers/finetune/finetune_custom/gpu" @@ -112,7 +112,6 @@ def tearDown(self): self.monkeypatch.undo() -@requires_torch @pytest.mark.integration @parameterized_class(parse_params(CONFIGS_DIRECTORY)) class TestOneshotCustomDatasetSmall(TestFinetuneNoRecipeCustomDataset): @@ -137,7 +136,6 @@ def test_oneshot_then_finetune_small(self): self._test_finetune_wout_recipe_custom_dataset() -@requires_torch @requires_gpu @pytest.mark.integration @parameterized_class(parse_params(GPU_CONFIGS_DIRECTORY)) diff --git a/tests/llmcompressor/transformers/finetune/test_finetune_oneshot_with_modifier.py b/tests/llmcompressor/transformers/finetune/test_finetune_oneshot_with_modifier.py index 47ef85244..ec517e2d6 100644 --- a/tests/llmcompressor/transformers/finetune/test_finetune_oneshot_with_modifier.py +++ b/tests/llmcompressor/transformers/finetune/test_finetune_oneshot_with_modifier.py @@ -5,13 +5,12 @@ import pytest from parameterized import parameterized_class -from tests.testing_utils import parse_params, requires_gpu, requires_torch +from tests.testing_utils import parse_params, requires_gpu CONFIGS_DIRECTORY = "tests/llmcompressor/transformers/finetune/finetune_generic" @pytest.mark.integration -@requires_torch @requires_gpu @parameterized_class(parse_params(CONFIGS_DIRECTORY)) class TestOneshotWithModifierObject(unittest.TestCase): diff --git a/tests/llmcompressor/transformers/finetune/test_finetune_without_recipe.py 
b/tests/llmcompressor/transformers/finetune/test_finetune_without_recipe.py index 0087c7c2d..7facd088e 100644 --- a/tests/llmcompressor/transformers/finetune/test_finetune_without_recipe.py +++ b/tests/llmcompressor/transformers/finetune/test_finetune_without_recipe.py @@ -4,13 +4,12 @@ import pytest from parameterized import parameterized_class -from tests.testing_utils import parse_params, requires_gpu, requires_torch +from tests.testing_utils import parse_params, requires_gpu CONFIGS_DIRECTORY = "tests/llmcompressor/transformers/finetune/finetune_generic" @pytest.mark.integration -@requires_torch @requires_gpu @parameterized_class(parse_params(CONFIGS_DIRECTORY)) class TestFinetuneWithoutRecipe(unittest.TestCase): diff --git a/tests/llmcompressor/transformers/finetune/test_oneshot_and_finetune.py b/tests/llmcompressor/transformers/finetune/test_oneshot_and_finetune.py index b0f24de43..870503496 100644 --- a/tests/llmcompressor/transformers/finetune/test_oneshot_and_finetune.py +++ b/tests/llmcompressor/transformers/finetune/test_oneshot_and_finetune.py @@ -7,7 +7,7 @@ from parameterized import parameterized_class from transformers import AutoConfig -from tests.testing_utils import parse_params, requires_gpu, requires_torch +from tests.testing_utils import parse_params, requires_gpu CONFIGS_DIRECTORY = "tests/llmcompressor/transformers/finetune/finetune_oneshot_configs" GPU_CONFIGS_DIRECTORY = ( @@ -56,7 +56,6 @@ def tearDown(self): shutil.rmtree(self.output) -@requires_torch @pytest.mark.integration @parameterized_class(parse_params(CONFIGS_DIRECTORY)) class TestOneshotAndFinetuneSmall(TestOneshotAndFinetune): @@ -77,7 +76,6 @@ def test_oneshot_then_finetune_small(self): self._test_oneshot_and_finetune() -@requires_torch @requires_gpu @pytest.mark.integration @parameterized_class(parse_params(GPU_CONFIGS_DIRECTORY)) diff --git a/tests/llmcompressor/transformers/finetune/test_oneshot_and_finetune_with_tokenizer.py b/tests/llmcompressor/transformers/finetune/test_oneshot_and_finetune_with_tokenizer.py index fa0644a39..509464a34 100644 --- a/tests/llmcompressor/transformers/finetune/test_oneshot_and_finetune_with_tokenizer.py +++ b/tests/llmcompressor/transformers/finetune/test_oneshot_and_finetune_with_tokenizer.py @@ -4,13 +4,12 @@ import pytest from parameterized import parameterized_class -from tests.testing_utils import parse_params, requires_gpu, requires_torch +from tests.testing_utils import parse_params, requires_gpu CONFIGS_DIRECTORY = "tests/llmcompressor/transformers/finetune/finetune_tokenizer" @pytest.mark.integration -@requires_torch @requires_gpu @parameterized_class(parse_params(CONFIGS_DIRECTORY)) class TestOneshotAndFinetuneWithTokenizer(unittest.TestCase): diff --git a/tests/llmcompressor/transformers/finetune/test_oneshot_then_finetune.py b/tests/llmcompressor/transformers/finetune/test_oneshot_then_finetune.py index db5950188..e9c3d7c5c 100644 --- a/tests/llmcompressor/transformers/finetune/test_oneshot_then_finetune.py +++ b/tests/llmcompressor/transformers/finetune/test_oneshot_then_finetune.py @@ -5,11 +5,8 @@ import pytest -from tests.testing_utils import requires_torch - @pytest.mark.unit -@requires_torch @pytest.mark.skipif( "CADENCE" in os.environ and (os.environ["CADENCE"] == "weekly" or os.environ["CADENCE"] == "nightly"), diff --git a/tests/llmcompressor/transformers/finetune/test_safetensors.py b/tests/llmcompressor/transformers/finetune/test_safetensors.py index 09d08b459..84c1bf1b2 100644 --- a/tests/llmcompressor/transformers/finetune/test_safetensors.py 
+++ b/tests/llmcompressor/transformers/finetune/test_safetensors.py @@ -6,13 +6,12 @@ import pytest from parameterized import parameterized_class -from tests.testing_utils import parse_params, requires_gpu, requires_torch +from tests.testing_utils import parse_params, requires_gpu CONFIGS_DIRECTORY = "tests/llmcompressor/transformers/finetune/finetune_generic" @pytest.mark.integration -@requires_torch @requires_gpu @parameterized_class(parse_params(CONFIGS_DIRECTORY)) class TestSafetensors(unittest.TestCase): diff --git a/tests/llmcompressor/transformers/gptq/test_oneshot.py b/tests/llmcompressor/transformers/gptq/test_oneshot.py index b9e9ed41e..7f1a1ec99 100644 --- a/tests/llmcompressor/transformers/gptq/test_oneshot.py +++ b/tests/llmcompressor/transformers/gptq/test_oneshot.py @@ -6,7 +6,6 @@ from transformers import AutoModelForCausalLM from llmcompressor.modifiers.quantization.gptq import GPTQModifier -from tests.testing_utils import requires_torch recipe_str = """ quant_stage: @@ -51,7 +50,6 @@ ) -@requires_torch @parameterized_class( [ {"recipe": recipe_str}, diff --git a/tests/llmcompressor/transformers/obcq/test_consecutive_runs.py b/tests/llmcompressor/transformers/obcq/test_consecutive_runs.py index 60b0e2b31..2f6c51ebb 100644 --- a/tests/llmcompressor/transformers/obcq/test_consecutive_runs.py +++ b/tests/llmcompressor/transformers/obcq/test_consecutive_runs.py @@ -6,7 +6,7 @@ import yaml from parameterized import parameterized_class -from tests.testing_utils import parse_params, requires_gpu, requires_torch +from tests.testing_utils import parse_params, requires_gpu CONFIGS_DIRECTORY = "tests/llmcompressor/transformers/obcq/obcq_configs/consec_runs" GPU_CONFIGS_DIRECTORY = ( @@ -83,7 +83,6 @@ def tearDown(self): shutil.rmtree(self.output) -@requires_torch @pytest.mark.integration @parameterized_class(parse_params(CONFIGS_DIRECTORY)) class TestConsecutiveRunsSmall(TestConsecutiveRuns): @@ -106,7 +105,6 @@ def test_consecutive_runs_small(self): # TODO: @Satrat and @dsikka, revisit if we want these nightly or weekly @requires_gpu -@requires_torch @pytest.mark.integration @parameterized_class(parse_params(GPU_CONFIGS_DIRECTORY)) class TestConsecutiveRunsGPU(TestConsecutiveRuns): diff --git a/tests/llmcompressor/transformers/obcq/test_mask_structure_preservation.py b/tests/llmcompressor/transformers/obcq/test_mask_structure_preservation.py index 957f19b3f..5095fe827 100644 --- a/tests/llmcompressor/transformers/obcq/test_mask_structure_preservation.py +++ b/tests/llmcompressor/transformers/obcq/test_mask_structure_preservation.py @@ -6,14 +6,13 @@ from parameterized import parameterized_class from llmcompressor.core import reset_session -from tests.testing_utils import parse_params, requires_torch +from tests.testing_utils import parse_params MASK_STRUCTURE_CONFIGS_DIRECTORY = ( "tests/llmcompressor/transformers/obcq/obcq_configs/consec_runs/mask_structure" ) -@requires_torch @pytest.mark.integration @parameterized_class(parse_params(MASK_STRUCTURE_CONFIGS_DIRECTORY)) class TestMaskStructurePreserved(unittest.TestCase): diff --git a/tests/llmcompressor/transformers/obcq/test_obcq_completion.py b/tests/llmcompressor/transformers/obcq/test_obcq_completion.py index 03517de07..cb7f64943 100644 --- a/tests/llmcompressor/transformers/obcq/test_obcq_completion.py +++ b/tests/llmcompressor/transformers/obcq/test_obcq_completion.py @@ -4,7 +4,7 @@ import pytest from parameterized import parameterized_class -from tests.testing_utils import parse_params, requires_gpu, requires_torch +from 
tests.testing_utils import parse_params, requires_gpu CONFIGS_DIRECTORY = "tests/llmcompressor/transformers/obcq/obcq_configs/completion" GPU_CONFIGS_DIRECTORY = ( @@ -99,7 +99,6 @@ def tearDown(self): shutil.rmtree(self.output) -@requires_torch @requires_gpu @pytest.mark.integration @parameterized_class(parse_params(CONFIGS_DIRECTORY)) @@ -121,7 +120,6 @@ def test_obcq_completion_small(self): self._test_oneshot_completion() -@requires_torch @requires_gpu @pytest.mark.integration @parameterized_class(parse_params(GPU_CONFIGS_DIRECTORY)) diff --git a/tests/llmcompressor/transformers/obcq/test_obcq_infer_targets.py b/tests/llmcompressor/transformers/obcq/test_obcq_infer_targets.py index f15b37c4c..997794ae9 100644 --- a/tests/llmcompressor/transformers/obcq/test_obcq_infer_targets.py +++ b/tests/llmcompressor/transformers/obcq/test_obcq_infer_targets.py @@ -3,11 +3,9 @@ import pytest from llmcompressor.utils.pytorch.module import get_no_split_params -from tests.testing_utils import requires_torch @pytest.mark.integration -@requires_torch class TestInferTargets(unittest.TestCase): def setUp(self): from transformers import AutoModelForCausalLM diff --git a/tests/llmcompressor/transformers/obcq/test_obcq_lm_head.py b/tests/llmcompressor/transformers/obcq/test_obcq_lm_head.py index f040da45c..f9818391c 100644 --- a/tests/llmcompressor/transformers/obcq/test_obcq_lm_head.py +++ b/tests/llmcompressor/transformers/obcq/test_obcq_lm_head.py @@ -2,11 +2,8 @@ import pytest -from tests.testing_utils import requires_torch - @pytest.mark.integration -@requires_torch class TestLMHead(unittest.TestCase): def setUp(self): import torch diff --git a/tests/llmcompressor/transformers/obcq/test_obcq_sparsity.py b/tests/llmcompressor/transformers/obcq/test_obcq_sparsity.py index 0e80b6d0c..0ef7f872d 100644 --- a/tests/llmcompressor/transformers/obcq/test_obcq_sparsity.py +++ b/tests/llmcompressor/transformers/obcq/test_obcq_sparsity.py @@ -5,13 +5,12 @@ import pytest from parameterized import parameterized_class -from tests.testing_utils import parse_params, requires_gpu, requires_torch +from tests.testing_utils import parse_params, requires_gpu CONFIGS_DIRECTORY = "tests/llmcompressor/transformers/obcq/obcq_configs/sparse" GPU_CONFIGS_DIRECTORY = "tests/llmcompressor/transformers/obcq/obcq_configs/sparse/gpu" -@requires_torch @pytest.mark.integration @parameterized_class(parse_params(CONFIGS_DIRECTORY)) class TestSparsities(unittest.TestCase): @@ -59,7 +58,6 @@ def tearDown(self): # TODO: @Satrat and @dsikka, revisit if we want these nightly or weekly @requires_gpu -@requires_torch @pytest.mark.integration @parameterized_class(parse_params(GPU_CONFIGS_DIRECTORY)) class TestSparsitiesGPU(unittest.TestCase): diff --git a/tests/llmcompressor/transformers/obcq/test_sgpt_defaults.py b/tests/llmcompressor/transformers/obcq/test_sgpt_defaults.py index 8cdca786a..dd27ebc2e 100644 --- a/tests/llmcompressor/transformers/obcq/test_sgpt_defaults.py +++ b/tests/llmcompressor/transformers/obcq/test_sgpt_defaults.py @@ -2,11 +2,8 @@ import pytest -from tests.testing_utils import requires_torch - @pytest.mark.integration -@requires_torch class TestSGPTDefaults(unittest.TestCase): def test_sgpt_defaults(self): from llmcompressor.core.state import State diff --git a/tests/llmcompressor/transformers/oneshot/test_api_inputs.py b/tests/llmcompressor/transformers/oneshot/test_api_inputs.py index 3e3ee2147..a64a218db 100644 --- a/tests/llmcompressor/transformers/oneshot/test_api_inputs.py +++ 
b/tests/llmcompressor/transformers/oneshot/test_api_inputs.py @@ -5,7 +5,7 @@ from parameterized import parameterized_class from tests.llmcompressor.transformers.oneshot.dataset_processing import get_data_utils -from tests.testing_utils import parse_params, requires_torch +from tests.testing_utils import parse_params CONFIGS_DIRECTORY = "tests/llmcompressor/transformers/oneshot/oneshot_configs" @@ -15,7 +15,6 @@ @pytest.mark.smoke @pytest.mark.integration -@requires_torch @parameterized_class(parse_params(CONFIGS_DIRECTORY)) class TestOneShotInputs(unittest.TestCase): model = None diff --git a/tests/llmcompressor/transformers/oneshot/test_cli.py b/tests/llmcompressor/transformers/oneshot/test_cli.py index 15b1ba379..5780ca46f 100644 --- a/tests/llmcompressor/transformers/oneshot/test_cli.py +++ b/tests/llmcompressor/transformers/oneshot/test_cli.py @@ -4,14 +4,13 @@ import pytest from parameterized import parameterized_class -from tests.testing_utils import parse_params, requires_torch, run_cli_command +from tests.testing_utils import parse_params, run_cli_command CONFIGS_DIRECTORY = "tests/llmcompressor/transformers/oneshot/oneshot_configs" @pytest.mark.smoke @pytest.mark.integration -@requires_torch @parameterized_class(parse_params(CONFIGS_DIRECTORY)) class TestOneShotCli(unittest.TestCase): model = None diff --git a/tests/testing_utils.py b/tests/testing_utils.py index 7a5dab66f..07b166013 100644 --- a/tests/testing_utils.py +++ b/tests/testing_utils.py @@ -37,10 +37,6 @@ def is_gpu_available(): return False -def requires_torch(test_case): - return unittest.skipUnless(is_torch_available(), "test requires PyTorch")(test_case) - - def requires_gpu(test_case): return unittest.skipUnless(is_gpu_available(), "test requires GPU")(test_case)
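Usage sketch for the config-driven runner in tests/e2e/vLLM/run_tests.sh: the script loops over every file in the directory passed with -c, exports each one as TEST_DATA_FILE, and runs the test module passed with -t in its own pytest process. The CADENCE value below is an assumption (the lm_eval configs above use cadence: "weekly"); each test skips itself unless CADENCE matches the cadence field of the active config.

# Run every e2e vLLM config against test_vllm.py, one pytest process per config.
CADENCE=weekly bash tests/e2e/vLLM/run_tests.sh -c tests/e2e/vLLM/configs -t tests/e2e/vLLM/test_vllm.py

# Same pattern for the lm-eval accuracy checks (requires lm_eval to be installed).
CADENCE=weekly bash tests/e2e/vLLM/run_tests.sh -c tests/e2e/vLLM/lm_eval_configs -t tests/e2e/vLLM/test_lmeval.py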