diff --git a/mergekit/architecture/base.py b/mergekit/architecture/base.py index c3c62d5b..4ac3ef63 100644 --- a/mergekit/architecture/base.py +++ b/mergekit/architecture/base.py @@ -21,7 +21,6 @@ from typing_extensions import Literal - class WeightInfo(BaseModel, frozen=True): """Information about an individual weight tensor in a model. diff --git a/mergekit/config.py b/mergekit/config.py index ed874852..6edb579e 100644 --- a/mergekit/config.py +++ b/mergekit/config.py @@ -102,9 +102,9 @@ class MergeConfiguration(BaseModel): merge_method: str base_model: Optional[ModelReference] = None dtype: Optional[str] = None - tokenizer_source: Union[Literal["union"], Literal["base"], ModelReference, None] = ( - None - ) + tokenizer_source: Union[ + Literal["union"], Literal["base"], ModelReference, None + ] = None tokenizer: Optional[TokenizerConfig] = None chat_template: Optional[str] = None out_dtype: Optional[str] = None diff --git a/mergekit/plan.py b/mergekit/plan.py index 2f370422..81e3d476 100644 --- a/mergekit/plan.py +++ b/mergekit/plan.py @@ -15,7 +15,7 @@ import logging from functools import lru_cache -from typing import Any, Dict, List, Optional, Tuple +from typing import Any, List, Optional, Tuple from mergekit import merge_methods from mergekit.architecture import ( diff --git a/tests/test_chat_template.py b/tests/test_chat_template.py index f678b25e..2bd41cde 100644 --- a/tests/test_chat_template.py +++ b/tests/test_chat_template.py @@ -1,11 +1,23 @@ from typing import Optional -from common import run_and_check_merge +import pytest +from common import make_picollama, run_and_check_merge +from test_tokenizer import make_tokenizer from transformers import AutoTokenizer from mergekit.config import InputModelDefinition, MergeConfiguration -from test_tokenizer import model_base # pylint: disable=unused-import -from test_basic_merges import model_b # pylint: disable=unused-import + + +@pytest.fixture(scope="session") +def model_base(tmp_path_factory): + model_path = make_picollama(tmp_path_factory.mktemp("model_base"), vocab_size=64) + make_tokenizer(vocab_size=64, added_tokens=[]).save_pretrained(model_path) + return model_path + + +@pytest.fixture(scope="session") +def model_b(tmp_path_factory): + return make_picollama(tmp_path_factory.mktemp("model_b")) def check_chat_template(model_path: str, needle: Optional[str] = None):