Fix fp16 DARE on CPU (#69)
Use either float32 or bfloat16 for `torch.bernoulli`.

Also improves merge tests by checking for NaN in output tensors.
cg123 authored Jan 4, 2024
1 parent 519a868 commit a5c5dc0
Showing 3 changed files with 60 additions and 5 deletions.
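For context (not part of the commit itself): a minimal sketch of the failure being worked around, assuming a stock PyTorch install. torch.bernoulli has no float16 kernel on CPU, so sampling a dropout mask from an fp16 probability tensor raises a RuntimeError, while upcasting to float32 (or staying in bfloat16) before sampling and casting back afterwards succeeds. The tensor shape and density value below are illustrative.

import torch

density = 0.5
probs_fp16 = torch.full((4, 4), density, dtype=torch.float16)  # fp16 tensor on CPU

try:
    torch.bernoulli(probs_fp16)  # RuntimeError on CPU: no Half kernel for bernoulli
except RuntimeError as err:
    print(f"fp16 bernoulli on CPU failed: {err}")

# Workaround mirrored by the fix: sample the mask in a supported dtype,
# then cast the result back to the original dtype.
work_dtype = torch.float32
mask = torch.bernoulli(probs_fp16.to(work_dtype))
masked = (probs_fp16.to(work_dtype) * mask).to(probs_fp16.dtype)
print(masked.dtype)  # torch.float16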
14 changes: 11 additions & 3 deletions mergekit/sparsify.py
@@ -48,11 +48,19 @@ def bernoulli(
     if density >= 1:
         return tensor
 
-    mask = torch.bernoulli(torch.full_like(input=tensor, fill_value=density))
-    res = tensor * mask
+    if (tensor.device.type != "cpu") or tensor.dtype == torch.bfloat16:
+        work_dtype = tensor.dtype
+    else:
+        # torch.bernoulli not implemented for float16 on CPU, upcast to float32
+        work_dtype = torch.float32
+
+    mask = torch.bernoulli(
+        torch.full_like(input=tensor, fill_value=density, dtype=work_dtype)
+    )
+    res = tensor.to(work_dtype) * mask
     if rescale:
         res /= density
-    return res
+    return res.to(tensor.dtype)
 
 
 def sparsify(
43 changes: 41 additions & 2 deletions tests/test_merges.py
@@ -1,6 +1,6 @@
 import os
 import tempfile
-from typing import Optional
+from typing import Dict, Optional
 
 import pytest
 from transformers import LlamaConfig, LlamaForCausalLM
@@ -10,7 +10,9 @@
     InputSliceDefinition,
     MergeConfiguration,
     OutputSliceDefinition,
+    ParameterSetting,
 )
+from mergekit.io.lazy_tensor_loader import LazyTensorLoader, ShardedTensorIndex
 from mergekit.merge import MergeOptions, run_merge


@@ -81,6 +83,26 @@ def test_task_arithmetic_merge(self, model_a, model_b, model_c):
         )
         self.run_and_check_merge(config)
 
+    def test_ties_merge(self, model_a, model_b, model_c):
+        config = self.two_model_config(
+            model_a,
+            model_b,
+            merge_method="ties",
+            base_model=model_c,
+            params={"density": 0.3},
+        )
+        self.run_and_check_merge(config)
+
+    def test_dare_ties_merge(self, model_a, model_b, model_c):
+        config = self.two_model_config(
+            model_a,
+            model_b,
+            merge_method="dare_ties",
+            base_model=model_c,
+            params={"density": 0.66},
+        )
+        self.run_and_check_merge(config)
+
     def run_and_check_merge(self, config: MergeConfiguration):
         with tempfile.TemporaryDirectory() as tmpdir:
             run_merge(config, out_path=tmpdir, options=MergeOptions())
@@ -91,8 +113,24 @@ def run_and_check_merge(self, config: MergeConfiguration):
                 os.path.join(tmpdir, "config.json")
             ), "No config json produced by merge"
 
+            # check for NaN in output
+            loader = LazyTensorLoader(
+                ShardedTensorIndex.from_disk(tmpdir), lazy_unpickle=False
+            )
+            tp = loader.index.tensor_paths
+            sorted_tensors = sorted(tp.keys(), key=lambda k: tp[k])
+            for tensor_name in sorted_tensors:
+                tensor = loader.get_tensor(tensor_name)
+                has_nan = tensor.view(-1).isnan().any()
+                assert not has_nan, "Output contains NaN"
+
     def two_model_config(
-        self, model_a, model_b, merge_method: str, base_model: Optional[str] = None
+        self,
+        model_a,
+        model_b,
+        merge_method: str,
+        base_model: Optional[str] = None,
+        params: Optional[Dict[str, ParameterSetting]] = None,
     ):
         config = MergeConfiguration(
             merge_method=merge_method,
@@ -108,6 +146,7 @@
                 ),
             ],
             dtype="bfloat16",
+            parameters=params,
         )
 
         return config
8 changes: 8 additions & 0 deletions tests/test_sparsify.py
@@ -49,3 +49,11 @@ def test_bernoulli_without_rescale(self, sample_tensor):
             sample_tensor, density=0.5, method=SparsificationMethod.random
         )
         assert 0 < torch.count_nonzero(result) <= sample_tensor.view(-1).shape[0]
+
+    def test_cpu_dtypes(self, sample_tensor):
+        for dt in (torch.float16, torch.bfloat16, torch.float32):
+            sparsify(
+                tensor=sample_tensor.to(dtype=dt).cpu(),
+                density=0.5,
+                method=SparsificationMethod.rescaled_random,
+            )
