diff --git a/ai_edge_torch/lowertools/torch_xla_utils.py b/ai_edge_torch/lowertools/torch_xla_utils.py index a63c76dd..5130e62a 100644 --- a/ai_edge_torch/lowertools/torch_xla_utils.py +++ b/ai_edge_torch/lowertools/torch_xla_utils.py @@ -19,9 +19,14 @@ import gc import itertools import logging +import os import tempfile from typing import Any, Dict, Optional, Tuple, Union +if "PJRT_DEVICE" not in os.environ: + # https://github.com/google-ai-edge/ai-edge-torch/issues/326 + os.environ["PJRT_DEVICE"] = "CPU" + from ai_edge_torch import model from ai_edge_torch._convert import conversion_utils from ai_edge_torch._convert import signature as signature_module