Fine-tuning AutoModelForSequenceClassification.from_pretrained(meta-llama/Llama-3.2-1B) Bug: RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:1 and cuda:0! (when checking argument for argument target in method wrapper_CUDA_nll_loss_forward) and awq importing #35365

Open

wants to merge 5 commits into base: main
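
For context, a minimal sketch of the kind of multi-GPU fine-tuning setup that can trigger the reported error. The `num_labels` value, batch contents, and device placements below are illustrative assumptions, not taken from this PR:

```python
# Hypothetical reproduction sketch (assumes two visible GPUs).
# With device_map="auto" the model is sharded across cuda:0 and cuda:1,
# so the logits can end up on a different device than the labels.
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer

model_id = "meta-llama/Llama-3.2-1B"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForSequenceClassification.from_pretrained(
    model_id,
    num_labels=2,          # illustrative
    device_map="auto",     # shards layers across the available GPUs
)

batch = tokenizer(["an example sentence"], return_tensors="pt").to("cuda:0")
labels = torch.tensor([1], device="cuda:0")

# Before the labels.to(logits.device) fix in this PR, this forward pass could raise:
# RuntimeError: Expected all tensors to be on the same device, but found at
# least two devices, cuda:1 and cuda:0!
outputs = model(**batch, labels=labels)
print(outputs.loss)
```
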
3 changes: 2 additions & 1 deletion src/transformers/modeling_utils.py
@@ -43,7 +43,8 @@
 from .activations import get_activation
 from .configuration_utils import PretrainedConfig
 from .dynamic_module_utils import custom_object_save
-from .generation import CompileConfig, GenerationConfig, GenerationMixin
+from .generation.configuration_utils import CompileConfig, GenerationConfig
+from .generation import GenerationMixin
 from .integrations import PeftAdapterMixin, deepspeed_config, is_deepspeed_zero3_enabled
 from .integrations.flash_attention import flash_attention_forward
 from .integrations.flex_attention import flex_attention_forward
1 change: 1 addition & 0 deletions src/transformers/models/llama/modeling_llama.py
@@ -958,6 +958,7 @@ def forward(

         loss = None
         if labels is not None:
+            labels = labels.to(logits.device)
             loss = self.loss_function(logits=logits, labels=labels, pooled_logits=pooled_logits, config=self.config)

         if not return_dict:
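
The single added line moves the labels onto the device that holds the logits before the loss is computed: with `device_map="auto"` the classification head (and therefore the logits) may sit on the last shard, e.g. cuda:1, while the labels arrive on the device of the input batch. A standalone sketch of the pattern, with an illustrative loss helper that is not the actual transformers `loss_function`:

```python
import torch
import torch.nn.functional as F

def classification_loss(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    # Illustrative stand-in for the loss computation: without the .to() call,
    # cross_entropy/nll_loss raises "Expected all tensors to be on the same
    # device" whenever logits and labels live on different GPUs.
    labels = labels.to(logits.device)  # same fix as the line added above
    return F.cross_entropy(logits, labels)

# On a sharded model, logits might be produced on cuda:1 while labels stay on
# cuda:0; on CPU this sketch still runs and shows the intended call pattern.
logits = torch.randn(4, 2)
labels = torch.tensor([0, 1, 1, 0])
print(classification_loss(logits, labels))
```
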
1 change: 1 addition & 0 deletions src/transformers/utils/import_utils.py
@@ -46,6 +46,7 @@ def _is_package_available(pkg_name: str, return_version: bool = False) -> Union[
     if package_exists:
         try:
             # Primary method to get the package version
+            if pkg_name == 'awq': pkg_name = 'autoawq'
             package_version = importlib.metadata.version(pkg_name)
         except importlib.metadata.PackageNotFoundError:
             # Fallback method: Only for "torch" and versions containing "dev"
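
The import_utils.py change works around a naming mismatch: the AWQ backend installs from PyPI as `autoawq`, but the importable module is `awq`, so `importlib.metadata.version("awq")` raises `PackageNotFoundError` even when the package is present. As an alternative sketch (not what this PR does), the distribution name can also be resolved from the import name instead of being hard-coded; `packages_distributions()` requires Python 3.10+ and the helper name here is illustrative:

```python
import importlib.metadata
import importlib.util
from typing import Optional

def resolve_awq_version() -> Optional[str]:
    # Illustrative helper: return the installed AWQ version, or None.
    if importlib.util.find_spec("awq") is None:
        return None
    try:
        # Fails for AWQ, since no distribution is literally named "awq".
        return importlib.metadata.version("awq")
    except importlib.metadata.PackageNotFoundError:
        # Map the import name back to its distribution name(s) (Python 3.10+);
        # this should yield "autoawq" when that package provides the awq module.
        dists = importlib.metadata.packages_distributions().get("awq", [])
        return importlib.metadata.version(dists[0]) if dists else None

print(resolve_awq_version())
```
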