diff --git a/aana/deployments/hqq_text_generation_deployment.py b/aana/deployments/hqq_text_generation_deployment.py index 7f00a5c4..bec971f2 100644 --- a/aana/deployments/hqq_text_generation_deployment.py +++ b/aana/deployments/hqq_text_generation_deployment.py @@ -118,6 +118,7 @@ async def apply_config(self, config: dict[str, Any]): self.dtype = Dtype.BFLOAT16 if config_obj.quantize_on_fly: + self.model_kwargs.pop("device_map", None) self.model = AutoModelForCausalLM.from_pretrained( self.model_id, torch_dtype=self.dtype.to_torch(), **self.model_kwargs )