From bd09407763d2e047423845b6b8d34996b5bd3eda Mon Sep 17 00:00:00 2001 From: Chun-nien Chan Date: Fri, 1 Nov 2024 11:10:29 -0700 Subject: [PATCH] set PJRT_DEVICE before loading torch_xla Mitigation for https://github.com/google-ai-edge/ai-edge-torch/issues/326 PiperOrigin-RevId: 692234642 --- ai_edge_torch/lowertools/torch_xla_utils.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/ai_edge_torch/lowertools/torch_xla_utils.py b/ai_edge_torch/lowertools/torch_xla_utils.py index a63c76dd..5130e62a 100644 --- a/ai_edge_torch/lowertools/torch_xla_utils.py +++ b/ai_edge_torch/lowertools/torch_xla_utils.py @@ -19,9 +19,14 @@ import gc import itertools import logging +import os import tempfile from typing import Any, Dict, Optional, Tuple, Union +if "PJRT_DEVICE" not in os.environ: + # https://github.com/google-ai-edge/ai-edge-torch/issues/326 + os.environ["PJRT_DEVICE"] = "CPU" + from ai_edge_torch import model from ai_edge_torch._convert import conversion_utils from ai_edge_torch._convert import signature as signature_module