Skip to content

Commit

Permalink
Lint fixes for _models (#1483)
Browse files Browse the repository at this point in the history
* Lint fixes for _models

* Format files
  • Loading branch information
jainapurva authored Jan 3, 2025
1 parent 3f36c78 commit d9fe2c2
Show file tree
Hide file tree
Showing 26 changed files with 1,404 additions and 668 deletions.
1 change: 1 addition & 0 deletions ruff.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ include = [
"torchao/sparsity/**/*.py",
"torchao/profiler/**/*.py",
"torchao/testing/**/*.py",
"torchao/_models/**/*.py",
"torchao/prototype/low_bit_optim/**.py",
"torchao/utils.py",
"torchao/ops.py",
Expand Down
30 changes: 14 additions & 16 deletions torchao/_models/_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,13 @@
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import lm_eval
import torch
import torch.nn.functional as F

from torchao.quantization.utils import _lm_eval_available, _MultiInput
from torchao.quantization.GPTQ_MT import MultiTensor
import lm_eval
from torchao.quantization.utils import _MultiInput

try: # lm_eval version 0.4
from lm_eval.evaluator import evaluate # pyre-ignore[21]
from lm_eval.models.huggingface import HFLM as eval_wrapper # pyre-ignore[21]
Expand All @@ -24,8 +25,7 @@
eval_wrapper = base.BaseLM
get_task_dict = tasks.get_task_dict
evaluate = evaluator.evaluate
import torch
import torch.nn.functional as F


class MultiTensorInputRecorder(eval_wrapper):
def __init__(
Expand Down Expand Up @@ -140,7 +140,9 @@ def _model_call(self, inps):
if T >= self.calibration_seq_length:
inps = inps[: self.calibration_seq_length]
else:
inps = F.pad(inps, (0, self.calibration_seq_length - T), value=self.pad_token)
inps = F.pad(
inps, (0, self.calibration_seq_length - T), value=self.pad_token
)

inps = inps.unsqueeze(0)
model_in = self.input_prep_func(inps)
Expand All @@ -155,6 +157,7 @@ def _model_call(self, inps):
def _model_generate(self, context, max_length, eos_token_id):
raise Exception("unimplemented")


class InputRecorder(eval_wrapper):
"""
This is a fake evaluation wrapper from the lm_eval library that just records the inputs
Expand Down Expand Up @@ -196,8 +199,7 @@ def __init__(
# need to take inps and convert to correct input
# for model
self.input_prep_func = (
input_prep_func if input_prep_func is not None
else lambda x: (x,)
input_prep_func if input_prep_func is not None else lambda x: (x,)
)

self.pad_calibration_inputs = pad_calibration_inputs
Expand Down Expand Up @@ -307,17 +309,14 @@ def _model_call(self, inps):
def _model_generate(self, context, max_length, eos_token_id):
raise Exception("unimplemented")


class TransformerEvalWrapper(InputRecorder):
"""
A wrapper class for GPTFast, providing integration with the lm-evaluation-harness library.
"""

def __init__(
self,
model,
tokenizer,
max_seq_length,
input_prep_func=None,
device="cuda"
self, model, tokenizer, max_seq_length, input_prep_func=None, device="cuda"
):
super().__init__(tokenizer, None)
self._model = model
Expand All @@ -328,8 +327,7 @@ def __init__(
# need to take inps and convert to correct input
# for model
self.input_prep_func = (
input_prep_func if input_prep_func is not None
else lambda x: (x,)
input_prep_func if input_prep_func is not None else lambda x: (x,)
)

def _model_call(self, inps):
Expand All @@ -343,7 +341,7 @@ def _model_call(self, inps):
return logits

def _model_generate(self, context, max_length, eos_token_id):
raise Exception('unimplemented')
raise Exception("unimplemented")

def run_eval(self, tasks, limit):
try:
Expand Down
Loading

0 comments on commit d9fe2c2

Please sign in to comment.