diff --git a/optimum/intel/generation/modeling.py b/optimum/intel/generation/modeling.py index a6e8a76f4f..625af75841 100644 --- a/optimum/intel/generation/modeling.py +++ b/optimum/intel/generation/modeling.py @@ -34,7 +34,7 @@ from ..utils.constant import _TASK_ALIASES from ..utils.import_utils import is_torch_version -from ..utils.modeling_utils import MULTI_QUERY_ATTN_MODELS +from ..utils.modeling_utils import MULTI_QUERY_ATTN_MODELS, recursive_to_device logger = logging.getLogger(__name__) @@ -63,7 +63,11 @@ def prepare_jit_inputs(model: PreTrainedModel, task: str, use_cache: bool = Fals dummy_inputs = onnx_config.generate_dummy_inputs(framework="pt") - return {key: dummy_inputs[key] for key in signature.parameters if dummy_inputs.get(key, None) is not None} + return { + key: recursive_to_device(dummy_inputs[key], model.device) + for key in signature.parameters + if dummy_inputs.get(key, None) is not None + } def jit_trace(model: PreTrainedModel, task: str, use_cache: bool = False):