diff --git a/optimum/exporters/utils.py b/optimum/exporters/utils.py index 59e053ee444..d386c09b7d8 100644 --- a/optimum/exporters/utils.py +++ b/optimum/exporters/utils.py @@ -690,7 +690,7 @@ def check_dummy_inputs_are_allowed( The model input names. """ - forward = model.forward if is_torch_available() and isinstance(model, nn.Module) else model.call + forward = model.forward if is_torch_available() and isinstance(model, torch.nn.Module) else model.call forward_parameters = signature(forward).parameters forward_inputs_set = set(forward_parameters.keys()) dummy_input_names = set(dummy_input_names)