diff --git a/mergekit/sparsify.py b/mergekit/sparsify.py
index df05d912..01239cd3 100644
--- a/mergekit/sparsify.py
+++ b/mergekit/sparsify.py
@@ -48,11 +48,19 @@ def bernoulli(
     if density >= 1:
         return tensor
 
-    mask = torch.bernoulli(torch.full_like(input=tensor, fill_value=density))
-    res = tensor * mask
+    if (tensor.device.type != "cpu") or tensor.dtype == torch.bfloat16:
+        work_dtype = tensor.dtype
+    else:
+        # torch.bernoulli not implemented for float16 on CPU, upcast to float32
+        work_dtype = torch.float32
+
+    mask = torch.bernoulli(
+        torch.full_like(input=tensor, fill_value=density, dtype=work_dtype)
+    )
+    res = tensor.to(work_dtype) * mask
     if rescale:
         res /= density
-    return res
+    return res.to(tensor.dtype)
 
 
 def sparsify(
diff --git a/tests/test_merges.py b/tests/test_merges.py
index 9c5f35bb..195d0a3c 100644
--- a/tests/test_merges.py
+++ b/tests/test_merges.py
@@ -1,6 +1,6 @@
 import os
 import tempfile
-from typing import Optional
+from typing import Dict, Optional
 
 import pytest
 from transformers import LlamaConfig, LlamaForCausalLM
@@ -10,7 +10,9 @@
     InputSliceDefinition,
     MergeConfiguration,
     OutputSliceDefinition,
+    ParameterSetting,
 )
+from mergekit.io.lazy_tensor_loader import LazyTensorLoader, ShardedTensorIndex
 from mergekit.merge import MergeOptions, run_merge
 
 
@@ -81,6 +83,26 @@ def test_task_arithmetic_merge(self, model_a, model_b, model_c):
         )
         self.run_and_check_merge(config)
 
+    def test_ties_merge(self, model_a, model_b, model_c):
+        config = self.two_model_config(
+            model_a,
+            model_b,
+            merge_method="ties",
+            base_model=model_c,
+            params={"density": 0.3},
+        )
+        self.run_and_check_merge(config)
+
+    def test_dare_ties_merge(self, model_a, model_b, model_c):
+        config = self.two_model_config(
+            model_a,
+            model_b,
+            merge_method="dare_ties",
+            base_model=model_c,
+            params={"density": 0.66},
+        )
+        self.run_and_check_merge(config)
+
     def run_and_check_merge(self, config: MergeConfiguration):
         with tempfile.TemporaryDirectory() as tmpdir:
             run_merge(config, out_path=tmpdir, options=MergeOptions())
@@ -91,8 +113,24 @@ def run_and_check_merge(self, config: MergeConfiguration):
                 os.path.join(tmpdir, "config.json")
             ), "No config json produced by merge"
 
+            # check for NaN in output
+            loader = LazyTensorLoader(
+                ShardedTensorIndex.from_disk(tmpdir), lazy_unpickle=False
+            )
+            tp = loader.index.tensor_paths
+            sorted_tensors = sorted(tp.keys(), key=lambda k: tp[k])
+            for tensor_name in sorted_tensors:
+                tensor = loader.get_tensor(tensor_name)
+                has_nan = tensor.view(-1).isnan().any()
+                assert not has_nan, "Output contains NaN"
+
     def two_model_config(
-        self, model_a, model_b, merge_method: str, base_model: Optional[str] = None
+        self,
+        model_a,
+        model_b,
+        merge_method: str,
+        base_model: Optional[str] = None,
+        params: Optional[Dict[str, ParameterSetting]] = None,
     ):
         config = MergeConfiguration(
             merge_method=merge_method,
@@ -108,6 +146,7 @@ def two_model_config(
                 ),
             ],
             dtype="bfloat16",
+            parameters=params,
         )
 
         return config
diff --git a/tests/test_sparsify.py b/tests/test_sparsify.py
index 837d47b7..e63f1607 100644
--- a/tests/test_sparsify.py
+++ b/tests/test_sparsify.py
@@ -49,3 +49,11 @@ def test_bernoulli_without_rescale(self, sample_tensor):
             sample_tensor, density=0.5, method=SparsificationMethod.random
         )
         assert 0 < torch.count_nonzero(result) <= sample_tensor.view(-1).shape[0]
+
+    def test_cpu_dtypes(self, sample_tensor):
+        for dt in (torch.float16, torch.bfloat16, torch.float32):
+            sparsify(
+                tensor=sample_tensor.to(dtype=dt).cpu(),
+                density=0.5,
+                method=SparsificationMethod.rescaled_random,
+            )
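
For reference, a minimal standalone sketch of the workaround the sparsify.py hunk applies. The helper name random_mask is illustrative only, not part of mergekit; it assumes a PyTorch build where torch.bernoulli has no float16 kernel on CPU, which is the situation the diff's in-code comment describes:

    import torch

    def random_mask(tensor: torch.Tensor, density: float) -> torch.Tensor:
        # CUDA tensors and bfloat16 already have a bernoulli kernel;
        # CPU float16 does not, so compute the mask in float32 instead.
        if tensor.device.type != "cpu" or tensor.dtype == torch.bfloat16:
            work_dtype = tensor.dtype
        else:
            work_dtype = torch.float32

        mask = torch.bernoulli(torch.full_like(tensor, density, dtype=work_dtype))
        # Multiply in the working dtype, then cast back to the input dtype.
        return (tensor.to(work_dtype) * mask).to(tensor.dtype)

    # Without the upcast, this float16 CPU call would fail; with it, the
    # result keeps the caller's dtype, which is what test_cpu_dtypes checks.
    masked = random_mask(torch.randn(4, 4, dtype=torch.float16), density=0.5)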