Skip to content

Commit

Permalink
ROCm 6.0 prep changes (#4537)
Browse files Browse the repository at this point in the history
* ROCm 6.0 prep changes

* PR feedback

* Try updating apex
  • Loading branch information
loadams authored Oct 20, 2023
1 parent 488f7e2 commit e238351
Show file tree
Hide file tree
Showing 5 changed files with 3 additions and 16 deletions.
1 change: 1 addition & 0 deletions .github/workflows/amd-mi200.yml
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ jobs:
run: |
git clone https://github.com/ROCmSoftwarePlatform/apex.git
cd apex
git checkout torch_2.1_higher
CURRENT_VER=$(git rev-parse HEAD)
INSTALLED_VER=$(cat /blob/amd-apex/.venv_installed_version)
if [[ "$CURRENT_VER" != "$INSTALLED_VER" ]]; then
Expand Down
6 changes: 1 addition & 5 deletions op_builder/cpu_adagrad.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,5 @@ def include_paths(self):
elif not self.is_rocm_pytorch():
CUDA_INCLUDE = [os.path.join(torch.utils.cpp_extension.CUDA_HOME, "include")]
else:
CUDA_INCLUDE = [
os.path.join(torch.utils.cpp_extension.ROCM_HOME, "include"),
os.path.join(torch.utils.cpp_extension.ROCM_HOME, "include", "rocrand"),
os.path.join(torch.utils.cpp_extension.ROCM_HOME, "include", "hiprand"),
]
CUDA_INCLUDE = []
return ['csrc/includes'] + CUDA_INCLUDE
6 changes: 1 addition & 5 deletions op_builder/cpu_adam.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,5 @@ def include_paths(self):
elif not self.is_rocm_pytorch():
CUDA_INCLUDE = [os.path.join(torch.utils.cpp_extension.CUDA_HOME, "include")]
else:
CUDA_INCLUDE = [
os.path.join(torch.utils.cpp_extension.ROCM_HOME, "include"),
os.path.join(torch.utils.cpp_extension.ROCM_HOME, "include", "rocrand"),
os.path.join(torch.utils.cpp_extension.ROCM_HOME, "include", "hiprand"),
]
CUDA_INCLUDE = []
return ['csrc/includes'] + CUDA_INCLUDE
3 changes: 0 additions & 3 deletions op_builder/random_ltd.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,4 @@ def sources(self):

def include_paths(self):
includes = ['csrc/includes']
if self.is_rocm_pytorch():
from torch.utils.cpp_extension import ROCM_HOME
includes += ['{}/hiprand/include'.format(ROCM_HOME), '{}/rocrand/include'.format(ROCM_HOME)]
return includes
3 changes: 0 additions & 3 deletions op_builder/transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,4 @@ def sources(self):

def include_paths(self):
includes = ['csrc/includes']
if self.is_rocm_pytorch():
from torch.utils.cpp_extension import ROCM_HOME
includes += ['{}/hiprand/include'.format(ROCM_HOME), '{}/rocrand/include'.format(ROCM_HOME)]
return includes

0 comments on commit e238351

Please sign in to comment.