Skip to content

Commit

Permalink
Cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
cg123 committed Oct 5, 2024
1 parent 5ca1c51 commit 0924993
Show file tree
Hide file tree
Showing 4 changed files with 19 additions and 8 deletions.
1 change: 0 additions & 1 deletion mergekit/architecture/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
from typing_extensions import Literal



class WeightInfo(BaseModel, frozen=True):
"""Information about an individual weight tensor in a model.
Expand Down
6 changes: 3 additions & 3 deletions mergekit/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,9 +102,9 @@ class MergeConfiguration(BaseModel):
merge_method: str
base_model: Optional[ModelReference] = None
dtype: Optional[str] = None
tokenizer_source: Union[Literal["union"], Literal["base"], ModelReference, None] = (
None
)
tokenizer_source: Union[
Literal["union"], Literal["base"], ModelReference, None
] = None
tokenizer: Optional[TokenizerConfig] = None
chat_template: Optional[str] = None
out_dtype: Optional[str] = None
Expand Down
2 changes: 1 addition & 1 deletion mergekit/plan.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

import logging
from functools import lru_cache
from typing import Any, Dict, List, Optional, Tuple
from typing import Any, List, Optional, Tuple

from mergekit import merge_methods
from mergekit.architecture import (
Expand Down
18 changes: 15 additions & 3 deletions tests/test_chat_template.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,23 @@
from typing import Optional

from common import run_and_check_merge
import pytest
from common import make_picollama, run_and_check_merge
from test_tokenizer import make_tokenizer
from transformers import AutoTokenizer

from mergekit.config import InputModelDefinition, MergeConfiguration
from test_tokenizer import model_base # pylint: disable=unused-import
from test_basic_merges import model_b # pylint: disable=unused-import


@pytest.fixture(scope="session")
def model_base(tmp_path_factory):
model_path = make_picollama(tmp_path_factory.mktemp("model_base"), vocab_size=64)
make_tokenizer(vocab_size=64, added_tokens=[]).save_pretrained(model_path)
return model_path


@pytest.fixture(scope="session")
def model_b(tmp_path_factory):
return make_picollama(tmp_path_factory.mktemp("model_b"))


def check_chat_template(model_path: str, needle: Optional[str] = None):
Expand Down

0 comments on commit 0924993

Please sign in to comment.