diff --git a/.github/workflows/R-CMD-check.yaml b/.github/workflows/R-CMD-check.yaml index 7a20aac..3e90a0a 100644 --- a/.github/workflows/R-CMD-check.yaml +++ b/.github/workflows/R-CMD-check.yaml @@ -37,8 +37,9 @@ jobs: TORCH_TEST: 1 TORCH_INSTALL: 1 TORCH_COMMIT_SHA: "none" + OMP_NUM_THREADS: 1 + MKL_NUM_THREADS: 1 PYTORCH_ENABLE_MPS_FALLBACK: 1 - PYTORCH_MPS_HIGH_WATERMARK_RATIO: 0.0 steps: - uses: actions/checkout@v3 @@ -65,6 +66,13 @@ jobs: run: keras::install_keras() shell: Rscript {0} + # Get a tmux ssh session for interactive debugging + # Controlled via inputs from GitHub webinterface + # See https://github.com/mxschmitt/action-tmate + - name: Setup tmate session + uses: mxschmitt/action-tmate@v3 + if: ${{ github.event_name == 'workflow_dispatch' && inputs.debug_enabled }} + - name: Check if torch is installed run: | library(torch) @@ -72,6 +80,10 @@ jobs: print("Torch is not installed!") install_torch() } + if (torch::backends_mps_is_available()) { + print("LibTorch is built with MPS support!") + } + print(paste0("Default number of threads: ", torch_get_num_threads())) print(torch_randn(1)) shell: Rscript {0}