diff --git a/integration-tests/conftest.py b/integration-tests/conftest.py
index c1cbe7f3e96..d2216241720 100644
--- a/integration-tests/conftest.py
+++ b/integration-tests/conftest.py
@@ -210,6 +210,7 @@ def local_launcher(
         quantize: Optional[str] = None,
         trust_remote_code: bool = False,
         use_flash_attention: bool = True,
+        dtype: Optional[str] = None
     ):
         port = random.randint(8000, 10_000)
         master_port = random.randint(10_000, 20_000)
@@ -237,6 +238,9 @@ def local_launcher(
         if quantize is not None:
             args.append("--quantize")
             args.append(quantize)
+        if dtype is not None:
+            args.append("--dtype")
+            args.append(dtype)
         if trust_remote_code:
             args.append("--trust-remote-code")

@@ -269,6 +273,7 @@ def docker_launcher(
         quantize: Optional[str] = None,
         trust_remote_code: bool = False,
         use_flash_attention: bool = True,
+        dtype: Optional[str] = None
     ):
         port = random.randint(8000, 10_000)

@@ -279,6 +284,9 @@ def docker_launcher(
         if quantize is not None:
             args.append("--quantize")
             args.append(quantize)
+        if dtype is not None:
+            args.append("--dtype")
+            args.append(dtype)
         if trust_remote_code:
             args.append("--trust-remote-code")

diff --git a/integration-tests/models/test_idefics.py b/integration-tests/models/test_idefics.py
index 5f4571b57ed..5a81a4f0911 100644
--- a/integration-tests/models/test_idefics.py
+++ b/integration-tests/models/test_idefics.py
@@ -3,7 +3,7 @@

 @pytest.fixture(scope="module")
 def idefics_handle(launcher):
-    with launcher("HuggingFaceM4/idefics-9b-instruct", num_shard=2) as handle:
+    with launcher("HuggingFaceM4/idefics-9b-instruct", num_shard=2, dtype="float16") as handle:
         yield handle


diff --git a/server/text_generation_server/cli.py b/server/text_generation_server/cli.py
index b741a84cfbd..3abe86afd23 100644
--- a/server/text_generation_server/cli.py
+++ b/server/text_generation_server/cli.py
@@ -76,7 +76,7 @@ def serve(
     # Downgrade enum into str for easier management later on
     quantize = None if quantize is None else quantize.value
     dtype = None if dtype is None else dtype.value
-    if dtype is not None and quantize is not None:
+    if dtype is not None and quantize not in {None, "bitsandbytes", "bitsandbytes-nf4", "bitsandbytes-fp4"}:
         raise RuntimeError(
             "Only 1 can be set between `dtype` and `quantize`, as they both decide how goes the final model."
         )
diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py
index 5b1b5715c5e..ab3b25b7c89 100644
--- a/server/text_generation_server/models/__init__.py
+++ b/server/text_generation_server/models/__init__.py
@@ -87,7 +87,9 @@ def get_model(
     trust_remote_code: bool,
 ) -> Model:
     if dtype is None:
-        dtype = torch.float16
+        # Keep it as default for now and let
+        # every model resolve their own default dtype.
+        dtype = None
     elif dtype == "float16":
         dtype = torch.float16
     elif dtype == "bfloat16":
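
For reference, here is a minimal standalone sketch (not part of the diff) of how the relaxed check in server/text_generation_server/cli.py behaves after this change: an explicit dtype is now accepted alongside the bitsandbytes quantizers, but still conflicts with any other quantization scheme. The helper name check_dtype_quantize is hypothetical and used only for illustration.

from typing import Optional

# Quantizers that tolerate an explicit --dtype after this change.
_DTYPE_COMPATIBLE_QUANTIZERS = {None, "bitsandbytes", "bitsandbytes-nf4", "bitsandbytes-fp4"}


def check_dtype_quantize(dtype: Optional[str], quantize: Optional[str]) -> None:
    # Same condition as the updated `serve` entrypoint: raise only when a
    # dtype is combined with a quantizer outside the bitsandbytes family.
    if dtype is not None and quantize not in _DTYPE_COMPATIBLE_QUANTIZERS:
        raise RuntimeError(
            "Only 1 can be set between `dtype` and `quantize`, as they both decide how goes the final model."
        )


check_dtype_quantize("float16", "bitsandbytes")  # accepted after this change
check_dtype_quantize("float16", None)            # accepted, as before
# check_dtype_quantize("float16", "gptq")        # would still raise RuntimeError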