Skip to content

Commit

Permalink
Add AMDGPU extension (#95)
Browse files Browse the repository at this point in the history
* Add AMDGPU extension

* Update Project.toml [skip ci]
  • Loading branch information
jipolanco authored Jun 14, 2024
1 parent 61c80d2 commit 8864677
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 1 deletion.
5 changes: 4 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "PencilArrays"
uuid = "0e08944d-e94e-41b1-9406-dcf66b6a9d2e"
authors = ["Juan Ignacio Polanco <[email protected]> and contributors"]
version = "0.19.4"
version = "0.19.5"

[deps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
Expand All @@ -19,15 +19,18 @@ TimerOutputs = "a759f4b9-e2f1-59dc-863e-4aeb61b1ea8f"
VersionParsing = "81def892-9a0e-5fdd-b105-ffc91e053289"

[weakdeps]
AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e"
DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e"
HDF5 = "f67ccb44-e63f-5c2f-98bd-6dc0ccc4ba2f"

[extensions]
PencilArraysAMDGPUExt = ["AMDGPU"]
PencilArraysDiffEqExt = ["DiffEqBase"]
PencilArraysHDF5Ext = ["HDF5"]

[compat]
Adapt = "3, 4"
AMDGPU = "0.8, 0.9"
DiffEqBase = "6"
HDF5 = "0.16, 0.17"
JSON3 = "1.4"
Expand Down
22 changes: 22 additions & 0 deletions ext/PencilArraysAMDGPUExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
module PencilArraysAMDGPUExt

using PencilArrays: typeof_array, typeof_ptr
using PencilArrays.Transpositions: Transpositions
using AMDGPU: ROCVector

# Workaround `unsafe_wrap` not allowing the `own` keyword argument in the AMDGPU
# implementation.
# Moreover, one needs to set the `lock = false` argument to indicate that we want to wrap an
# array which is already in the GPU.
function Transpositions.unsafe_as_array(::Type{T}, x::ROCVector{UInt8}, dims::Tuple) where {T}
p = typeof_ptr(x){T}(pointer(x))
unsafe_wrap(typeof_array(x), p, dims; lock = false)
end

# Workaround `unsafe_wrap` for ROCArrays not providing a definition for dims::Integer.
# We convert that argument to a tuple, which is accepted by the implementation in AMDGPU.
function Transpositions.unsafe_as_array(::Type{T}, x::ROCVector{UInt8}, N::Integer) where {T}
Transpositions.unsafe_as_array(T, x, (N,))
end

end

0 comments on commit 8864677

Please sign in to comment.