
Commit c7d37d5

horheynm committed Dec 12, 2024
1 parent 3f97a1a commit c7d37d5
Showing 2 changed files with 12 additions and 10 deletions.
6 changes: 4 additions & 2 deletions examples/quantization_kv_cache/llama3_fp8_kv_example.py
@@ -5,8 +5,10 @@
 from llmcompressor.transformers import oneshot
 
 # Select model and load it.
-MODEL_ID = "meta-llama/Meta-Llama-3-8B-Instruct"
-model = AutoModelForCausalLM.from_pretrained(
+# MODEL_ID = "meta-llama/Meta-Llama-3-8B-Instruct"
+MODEL_ID="TinyLlama/TinyLlama-1.1B-Chat-v1.0"
+
+model = SparseAutoModelForCausalLM.from_pretrained(
     MODEL_ID,
     device_map="auto",
     torch_dtype="auto",
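This hunk only touches how the model is selected and loaded; the rest of llama3_fp8_kv_example.py calibrates that model and quantizes its KV cache with oneshot. Below is a minimal sketch of that step, assuming a QuantizationModifier recipe with a kv_cache_scheme and the built-in "open_platypus" calibration dataset (both assumptions; the example's actual recipe, dataset, and save path may differ):

from llmcompressor.transformers import SparseAutoModelForCausalLM, oneshot

MODEL_ID = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
model = SparseAutoModelForCausalLM.from_pretrained(
    MODEL_ID, device_map="auto", torch_dtype="auto"
)

# Assumed recipe: quantize only the KV cache to FP8 (per-tensor, static scales).
recipe = """
quant_stage:
    quant_modifiers:
        QuantizationModifier:
            kv_cache_scheme:
                num_bits: 8
                type: float
                strategy: tensor
                dynamic: false
                symmetric: true
"""

# Calibrate in one shot and save a compressed checkpoint.
oneshot(
    model=model,
    dataset="open_platypus",        # assumed built-in calibration dataset
    recipe=recipe,
    max_seq_length=2048,
    num_calibration_samples=512,
    output_dir="tinyllama-fp8-kv",  # hypothetical save path
)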
16 changes: 8 additions & 8 deletions tests/llmcompressor/test_kv_cache.py
@@ -4,13 +4,13 @@
 import pytest
 import torch
 from compressed_tensors.compressors import ModelCompressor
-from compressed_tensors.quantization.cache import KVCacheScaleType
+from compressed_tensors.quantization.lifecycle import KVCacheScaleType
 from compressed_tensors.quantization.utils.helpers import iter_named_quantizable_modules
 from datasets import load_dataset
-from transformers import AutoTokenizer
+from transformers import AutoModelForCausalLM, AutoTokenizer
 
 from llmcompressor.core import reset_session
-from llmcompressor.transformers import SparseAutoModelForCausalLM, oneshot
+from llmcompressor.transformers import oneshot
 
 MODEL_IDS = [
     "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
@@ -83,7 +83,7 @@ def test_kv_cache_config_format(oneshot_fixture):
 
 def test_kv_cache_model_state_dict_attr(oneshot_fixture):
     for output_dir, _ in oneshot_fixture.items():
-        model = SparseAutoModelForCausalLM.from_pretrained(output_dir)
+        model = AutoModelForCausalLM.from_pretrained(output_dir)
 
         counts = 0
         for name, submodule in iter_named_quantizable_modules(
@@ -99,7 +99,7 @@ def test_kv_cache_model_state_dict_attr(oneshot_fixture):
 
 def test_kv_cache_model_populate_kv_scales_only(tmp_path):
     MODEL_ID = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
-    model = SparseAutoModelForCausalLM.from_pretrained(
+    model = AutoModelForCausalLM.from_pretrained(
         MODEL_ID,
         device_map="cuda:0" if torch.cuda.is_available() else "cpu",
         torch_dtype="auto",
@@ -187,7 +187,7 @@ def tokenize(sample):
     # check for vllm loading
     assert quant_config.quant_method == "compressed-tensors"
 
-    model = SparseAutoModelForCausalLM.from_pretrained(output_dir)
+    model = AutoModelForCausalLM.from_pretrained(output_dir)
 
     counts = 0
     for name, submodule in iter_named_quantizable_modules(
@@ -203,7 +203,7 @@ def tokenize(sample):
 
 def test_kv_cache_with_gptq(tmp_path):
     MODEL_ID = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
-    model = SparseAutoModelForCausalLM.from_pretrained(
+    model = AutoModelForCausalLM.from_pretrained(
         MODEL_ID,
         device_map="cuda:0" if torch.cuda.is_available() else "cpu",
         torch_dtype="auto",
@@ -298,7 +298,7 @@ def tokenize(sample):
     assert scheme.dynamic == kv_cache_dynamic
     assert scheme.symmetric == kv_cache_symmetric
 
-    model = SparseAutoModelForCausalLM.from_pretrained(output_dir)
+    model = AutoModelForCausalLM.from_pretrained(output_dir)
 
     counts = 0
     for name, submodule in iter_named_quantizable_modules(
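These hunks swap SparseAutoModelForCausalLM for the stock transformers AutoModelForCausalLM when the quantized checkpoints are reloaded, and move the KVCacheScaleType import from compressed_tensors.quantization.cache to compressed_tensors.quantization.lifecycle. A minimal sketch of the reload-and-check pattern the affected tests share, assuming iter_named_quantizable_modules takes include_children / include_attn flags and that KVCacheScaleType.KEY / .VALUE name the k_scale / v_scale attributes (assumptions not confirmed by this diff):

from compressed_tensors.quantization.lifecycle import KVCacheScaleType
from compressed_tensors.quantization.utils.helpers import iter_named_quantizable_modules
from transformers import AutoModelForCausalLM

output_dir = "./tinyllama-fp8-kv"  # hypothetical path to a quantized checkpoint

# Reload the checkpoint with the plain transformers class instead of
# SparseAutoModelForCausalLM.
model = AutoModelForCausalLM.from_pretrained(output_dir)

counts = 0
# Iterate only over attention modules; these hold the KV-cache scales.
for name, submodule in iter_named_quantizable_modules(
    model, include_children=False, include_attn=True
):
    counts += 1
    # Calibration should have attached k_scale / v_scale to every attention block.
    assert hasattr(submodule, KVCacheScaleType.KEY.value)
    assert hasattr(submodule, KVCacheScaleType.VALUE.value)
assert counts > 0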
