Skip to content

Commit

Permalink
Move tasks.py into io
Browse files Browse the repository at this point in the history
  • Loading branch information
cg123 committed Dec 31, 2023
1 parent d07d066 commit 9e83c33
Show file tree
Hide file tree
Showing 10 changed files with 22 additions and 16 deletions.
File renamed without changes.
2 changes: 1 addition & 1 deletion mergekit/merge.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,8 @@
from mergekit.common import MergeOptions, ModelReference
from mergekit.config import MergeConfiguration
from mergekit.graph import Executor
from mergekit.io.tasks import LoaderCache
from mergekit.plan import MergePlanner
from mergekit.tasks import LoaderCache
from mergekit.tokenizer import TokenizerInfo


Expand Down
2 changes: 1 addition & 1 deletion mergekit/merge_methods/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@

from mergekit.common import ImmutableMap, ModelReference
from mergekit.graph import Task
from mergekit.tasks import GatherTensors
from mergekit.io.tasks import GatherTensors


class ConfigParameterDef(BaseModel):
Expand Down
2 changes: 1 addition & 1 deletion mergekit/merge_methods/generalized_task_arithmetic.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,9 @@

from mergekit.common import ImmutableMap, ModelReference
from mergekit.graph import Task
from mergekit.io.tasks import GatherTensors
from mergekit.merge_methods.base import ConfigParameterDef, MergeMethod
from mergekit.sparsify import SparsificationMethod, sparsify
from mergekit.tasks import GatherTensors


class ConsensusMethod(str, Enum):
Expand Down
2 changes: 1 addition & 1 deletion mergekit/merge_methods/linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@

from mergekit.common import ImmutableMap, ModelReference, rectify_embed_sizes
from mergekit.graph import Task
from mergekit.io.tasks import GatherTensors
from mergekit.merge_methods.base import ConfigParameterDef, MergeMethod
from mergekit.tasks import GatherTensors


class LinearMergeTask(Task[torch.Tensor]):
Expand Down
2 changes: 1 addition & 1 deletion mergekit/merge_methods/passthrough.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@

from mergekit.common import ImmutableMap, ModelReference
from mergekit.graph import Task
from mergekit.io.tasks import GatherTensors
from mergekit.merge_methods.base import ConfigParameterDef, MergeMethod
from mergekit.tasks import GatherTensors


class PassthroughMergeTask(Task[torch.Tensor]):
Expand Down
2 changes: 1 addition & 1 deletion mergekit/merge_methods/slerp.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@

from mergekit.common import ImmutableMap, ModelReference, rectify_embed_sizes
from mergekit.graph import Task
from mergekit.io.tasks import GatherTensors
from mergekit.merge_methods.base import ConfigParameterDef, MergeMethod
from mergekit.tasks import GatherTensors


class SlerpTask(Task[torch.Tensor]):
Expand Down
2 changes: 1 addition & 1 deletion mergekit/merge_methods/tokenizer_permute.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,9 @@

from mergekit.common import ImmutableMap, ModelReference
from mergekit.graph import Task
from mergekit.io.tasks import GatherTensors
from mergekit.merge_methods.base import ConfigParameterDef, MergeMethod
from mergekit.merge_methods.slerp import slerp
from mergekit.tasks import GatherTensors
from mergekit.tokenizer import BuildTokenizer, TokenizerInfo


Expand Down
2 changes: 1 addition & 1 deletion mergekit/plan.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,9 @@
OutputSliceDefinition,
)
from mergekit.graph import Task
from mergekit.io.tasks import FinalizeModel, GatherTensors, SaveTensor, TensorWriterTask
from mergekit.merge_methods import MergeMethod
from mergekit.merge_methods.tokenizer_permute import TokenizerPermutationMerge
from mergekit.tasks import FinalizeModel, GatherTensors, SaveTensor, TensorWriterTask
from mergekit.tokenizer import BuildTokenizer


Expand Down
22 changes: 14 additions & 8 deletions tests/test_merges.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import os
import tempfile
from typing import Optional

Expand Down Expand Up @@ -49,8 +50,7 @@ def test_gpt2_copy(self):
models=[InputModelDefinition(model="gpt2")],
dtype="bfloat16",
)
with tempfile.TemporaryDirectory() as tmpdir:
run_merge(config, out_path=tmpdir, options=MergeOptions())
self.run_and_check_merge(config)

def test_gpt2_stack(self):
config = MergeConfiguration(
Expand All @@ -63,28 +63,34 @@ def test_gpt2_stack(self):
],
dtype="bfloat16",
)
with tempfile.TemporaryDirectory() as tmpdir:
run_merge(config, out_path=tmpdir, options=MergeOptions())
self.run_and_check_merge(config)

def test_linear_merge(self, model_a, model_b):
config = self.two_model_config(model_a, model_b, merge_method="linear")
with tempfile.TemporaryDirectory() as tmpdir:
run_merge(config, out_path=tmpdir, options=MergeOptions())
self.run_and_check_merge(config)

def test_slerp_merge(self, model_a, model_b):
config = self.two_model_config(
model_a, model_b, merge_method="slerp", base_model=model_a
)
config.parameters = {"t": 0.35}
with tempfile.TemporaryDirectory() as tmpdir:
run_merge(config, out_path=tmpdir, options=MergeOptions())
self.run_and_check_merge(config)

def test_task_arithmetic_merge(self, model_a, model_b, model_c):
config = self.two_model_config(
model_a, model_b, merge_method="task_arithmetic", base_model=model_c
)
self.run_and_check_merge(config)

def run_and_check_merge(self, config: MergeConfiguration):
with tempfile.TemporaryDirectory() as tmpdir:
run_merge(config, out_path=tmpdir, options=MergeOptions())
assert os.path.exists(
os.path.join(tmpdir, "model.safetensors.index.json")
), "No index file for merge"
assert os.path.exists(
os.path.join(tmpdir, "config.json")
), "No config json produced by merge"

def two_model_config(
self, model_a, model_b, merge_method: str, base_model: Optional[str] = None
Expand Down

0 comments on commit 9e83c33

Please sign in to comment.