
Commit

Merge pull request oobabooga#5534 from oobabooga/dev
Merge dev branch
oobabooga authored Feb 17, 2024
2 parents dd46229 + d6bd71d commit 7838075
Showing 2 changed files with 19 additions and 16 deletions.
17 changes: 9 additions & 8 deletions modules/exllamav2.py
@@ -51,20 +51,21 @@ def from_pretrained(self, path_to_model):

         model = ExLlamaV2(config)
 
-        if shared.args.cache_8bit:
-            cache = ExLlamaV2Cache_8bit(model, lazy=True)
-        else:
-            cache = ExLlamaV2Cache(model, lazy=True)
-
-        if shared.args.autosplit:
-            model.load_autosplit(cache)
-        else:
+        if not shared.args.autosplit:
             split = None
             if shared.args.gpu_split:
                 split = [float(alloc) for alloc in shared.args.gpu_split.split(",")]
 
             model.load(split)
 
+        if shared.args.cache_8bit:
+            cache = ExLlamaV2Cache_8bit(model, lazy=shared.args.autosplit)
+        else:
+            cache = ExLlamaV2Cache(model, lazy=shared.args.autosplit)
+
+        if shared.args.autosplit:
+            model.load_autosplit(cache)
+
         tokenizer = ExLlamaV2Tokenizer(config)
         generator = ExLlamaV2StreamingGenerator(model, cache, tokenizer)
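
In short, the loader now loads the weights before allocating the KV cache: without --autosplit, the optional --gpu_split string is parsed and passed to model.load(), and the cache is created with lazy=shared.args.autosplit so it stays unallocated only when model.load_autosplit(cache) will place it afterwards. A minimal sketch of the resulting flow, assuming the module's existing imports (ExLlamaV2, ExLlamaV2Cache, ExLlamaV2Cache_8bit from exllamav2 and shared from modules):

# Sketch of the resulting from_pretrained() load order, simplified from the diff above.
model = ExLlamaV2(config)

if not shared.args.autosplit:
    # Manual placement: parse the optional comma-separated per-GPU split and
    # load the weights before any cache memory is reserved.
    split = None
    if shared.args.gpu_split:
        split = [float(alloc) for alloc in shared.args.gpu_split.split(",")]
    model.load(split)

# The cache is now created after the weights; it is left lazy (unallocated)
# only when autosplit will place it afterwards.
if shared.args.cache_8bit:
    cache = ExLlamaV2Cache_8bit(model, lazy=shared.args.autosplit)
else:
    cache = ExLlamaV2Cache(model, lazy=shared.args.autosplit)

if shared.args.autosplit:
    # Let ExLlamaV2 distribute the model and cache across the available GPUs automatically.
    model.load_autosplit(cache)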

18 changes: 10 additions & 8 deletions modules/exllamav2_hf.py
@@ -36,24 +36,26 @@ class Exllamav2HF(PreTrainedModel):
     def __init__(self, config: ExLlamaV2Config):
         super().__init__(PretrainedConfig())
         self.ex_config = config
-        self.ex_model = ExLlamaV2(config)
         self.loras = None
         self.generation_config = GenerationConfig()
 
-        if shared.args.cache_8bit:
-            self.ex_cache = ExLlamaV2Cache_8bit(self.ex_model, lazy=True)
-        else:
-            self.ex_cache = ExLlamaV2Cache(self.ex_model, lazy=True)
+        self.ex_model = ExLlamaV2(config)
 
-        if shared.args.autosplit:
-            self.ex_model.load_autosplit(self.ex_cache)
-        else:
+        if not shared.args.autosplit:
             split = None
             if shared.args.gpu_split:
                 split = [float(alloc) for alloc in shared.args.gpu_split.split(",")]
 
             self.ex_model.load(split)
 
+        if shared.args.cache_8bit:
+            self.ex_cache = ExLlamaV2Cache_8bit(self.ex_model, lazy=shared.args.autosplit)
+        else:
+            self.ex_cache = ExLlamaV2Cache(self.ex_model, lazy=shared.args.autosplit)
+
+        if shared.args.autosplit:
+            self.ex_model.load_autosplit(self.ex_cache)
+
         self.past_seq = None
         if shared.args.cfg_cache:
             if shared.args.cache_8bit:
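
The HF wrapper gets the same reordering. For context, --gpu_split is a comma-separated list of per-GPU VRAM amounts (in GB) that is parsed into floats before being handed to load(); a minimal illustration with a hypothetical two-GPU value:

# Hypothetical example value; any comma-separated list of numbers is parsed the same way.
gpu_split = "10,14"
split = [float(alloc) for alloc in gpu_split.split(",")]
print(split)  # [10.0, 14.0] -> passed as the per-GPU allocation to ex_model.load(split)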
