From 9e83c334364ebfe405480e524fcd532701a2cf4f Mon Sep 17 00:00:00 2001 From: Charles Goddard Date: Sat, 30 Dec 2023 20:04:54 -0800 Subject: [PATCH] Move tasks.py into io --- mergekit/{ => io}/tasks.py | 0 mergekit/merge.py | 2 +- mergekit/merge_methods/base.py | 2 +- .../generalized_task_arithmetic.py | 2 +- mergekit/merge_methods/linear.py | 2 +- mergekit/merge_methods/passthrough.py | 2 +- mergekit/merge_methods/slerp.py | 2 +- mergekit/merge_methods/tokenizer_permute.py | 2 +- mergekit/plan.py | 2 +- tests/test_merges.py | 22 ++++++++++++------- 10 files changed, 22 insertions(+), 16 deletions(-) rename mergekit/{ => io}/tasks.py (100%) diff --git a/mergekit/tasks.py b/mergekit/io/tasks.py similarity index 100% rename from mergekit/tasks.py rename to mergekit/io/tasks.py diff --git a/mergekit/merge.py b/mergekit/merge.py index ab48d813..312b7e8c 100644 --- a/mergekit/merge.py +++ b/mergekit/merge.py @@ -23,8 +23,8 @@ from mergekit.common import MergeOptions, ModelReference from mergekit.config import MergeConfiguration from mergekit.graph import Executor +from mergekit.io.tasks import LoaderCache from mergekit.plan import MergePlanner -from mergekit.tasks import LoaderCache from mergekit.tokenizer import TokenizerInfo diff --git a/mergekit/merge_methods/base.py b/mergekit/merge_methods/base.py index dee79731..d9780349 100644 --- a/mergekit/merge_methods/base.py +++ b/mergekit/merge_methods/base.py @@ -20,7 +20,7 @@ from mergekit.common import ImmutableMap, ModelReference from mergekit.graph import Task -from mergekit.tasks import GatherTensors +from mergekit.io.tasks import GatherTensors class ConfigParameterDef(BaseModel): diff --git a/mergekit/merge_methods/generalized_task_arithmetic.py b/mergekit/merge_methods/generalized_task_arithmetic.py index ea1fb42f..384a9ae5 100644 --- a/mergekit/merge_methods/generalized_task_arithmetic.py +++ b/mergekit/merge_methods/generalized_task_arithmetic.py @@ -23,9 +23,9 @@ from mergekit.common import ImmutableMap, ModelReference from mergekit.graph import Task +from mergekit.io.tasks import GatherTensors from mergekit.merge_methods.base import ConfigParameterDef, MergeMethod from mergekit.sparsify import SparsificationMethod, sparsify -from mergekit.tasks import GatherTensors class ConsensusMethod(str, Enum): diff --git a/mergekit/merge_methods/linear.py b/mergekit/merge_methods/linear.py index 96a2e776..7a3b88e7 100644 --- a/mergekit/merge_methods/linear.py +++ b/mergekit/merge_methods/linear.py @@ -20,8 +20,8 @@ from mergekit.common import ImmutableMap, ModelReference, rectify_embed_sizes from mergekit.graph import Task +from mergekit.io.tasks import GatherTensors from mergekit.merge_methods.base import ConfigParameterDef, MergeMethod -from mergekit.tasks import GatherTensors class LinearMergeTask(Task[torch.Tensor]): diff --git a/mergekit/merge_methods/passthrough.py b/mergekit/merge_methods/passthrough.py index 1165e484..ee6515bb 100644 --- a/mergekit/merge_methods/passthrough.py +++ b/mergekit/merge_methods/passthrough.py @@ -20,8 +20,8 @@ from mergekit.common import ImmutableMap, ModelReference from mergekit.graph import Task +from mergekit.io.tasks import GatherTensors from mergekit.merge_methods.base import ConfigParameterDef, MergeMethod -from mergekit.tasks import GatherTensors class PassthroughMergeTask(Task[torch.Tensor]): diff --git a/mergekit/merge_methods/slerp.py b/mergekit/merge_methods/slerp.py index 907eed89..00160142 100644 --- a/mergekit/merge_methods/slerp.py +++ b/mergekit/merge_methods/slerp.py @@ -21,8 +21,8 @@ from mergekit.common import ImmutableMap, ModelReference, rectify_embed_sizes from mergekit.graph import Task +from mergekit.io.tasks import GatherTensors from mergekit.merge_methods.base import ConfigParameterDef, MergeMethod -from mergekit.tasks import GatherTensors class SlerpTask(Task[torch.Tensor]): diff --git a/mergekit/merge_methods/tokenizer_permute.py b/mergekit/merge_methods/tokenizer_permute.py index d4b1ffc7..5de2286c 100644 --- a/mergekit/merge_methods/tokenizer_permute.py +++ b/mergekit/merge_methods/tokenizer_permute.py @@ -21,9 +21,9 @@ from mergekit.common import ImmutableMap, ModelReference from mergekit.graph import Task +from mergekit.io.tasks import GatherTensors from mergekit.merge_methods.base import ConfigParameterDef, MergeMethod from mergekit.merge_methods.slerp import slerp -from mergekit.tasks import GatherTensors from mergekit.tokenizer import BuildTokenizer, TokenizerInfo diff --git a/mergekit/plan.py b/mergekit/plan.py index 8a755c6e..6ca035d5 100644 --- a/mergekit/plan.py +++ b/mergekit/plan.py @@ -26,9 +26,9 @@ OutputSliceDefinition, ) from mergekit.graph import Task +from mergekit.io.tasks import FinalizeModel, GatherTensors, SaveTensor, TensorWriterTask from mergekit.merge_methods import MergeMethod from mergekit.merge_methods.tokenizer_permute import TokenizerPermutationMerge -from mergekit.tasks import FinalizeModel, GatherTensors, SaveTensor, TensorWriterTask from mergekit.tokenizer import BuildTokenizer diff --git a/tests/test_merges.py b/tests/test_merges.py index cf2fae5b..95999564 100644 --- a/tests/test_merges.py +++ b/tests/test_merges.py @@ -1,3 +1,4 @@ +import os import tempfile from typing import Optional @@ -49,8 +50,7 @@ def test_gpt2_copy(self): models=[InputModelDefinition(model="gpt2")], dtype="bfloat16", ) - with tempfile.TemporaryDirectory() as tmpdir: - run_merge(config, out_path=tmpdir, options=MergeOptions()) + self.run_and_check_merge(config) def test_gpt2_stack(self): config = MergeConfiguration( @@ -63,28 +63,34 @@ def test_gpt2_stack(self): ], dtype="bfloat16", ) - with tempfile.TemporaryDirectory() as tmpdir: - run_merge(config, out_path=tmpdir, options=MergeOptions()) + self.run_and_check_merge(config) def test_linear_merge(self, model_a, model_b): config = self.two_model_config(model_a, model_b, merge_method="linear") - with tempfile.TemporaryDirectory() as tmpdir: - run_merge(config, out_path=tmpdir, options=MergeOptions()) + self.run_and_check_merge(config) def test_slerp_merge(self, model_a, model_b): config = self.two_model_config( model_a, model_b, merge_method="slerp", base_model=model_a ) config.parameters = {"t": 0.35} - with tempfile.TemporaryDirectory() as tmpdir: - run_merge(config, out_path=tmpdir, options=MergeOptions()) + self.run_and_check_merge(config) def test_task_arithmetic_merge(self, model_a, model_b, model_c): config = self.two_model_config( model_a, model_b, merge_method="task_arithmetic", base_model=model_c ) + self.run_and_check_merge(config) + + def run_and_check_merge(self, config: MergeConfiguration): with tempfile.TemporaryDirectory() as tmpdir: run_merge(config, out_path=tmpdir, options=MergeOptions()) + assert os.path.exists( + os.path.join(tmpdir, "model.safetensors.index.json") + ), "No index file for merge" + assert os.path.exists( + os.path.join(tmpdir, "config.json") + ), "No config json produced by merge" def two_model_config( self, model_a, model_b, merge_method: str, base_model: Optional[str] = None