Skip to content

Commit

Permalink
Support struct types in create_result (#6)
Browse files Browse the repository at this point in the history
* Add `struct` type case to `create_result`

* Fix type matching

* Fix `Symbol` case in `create_result`

* Add `Array` case in `create_result`

* Test compilation on custom `struct` types

* Fix forgotten `include`

* Test mutable struct types

* Fix forgotten `mutable` specifier

* Add comments

* Refactor test

* Test against recursive data structure (linked list)

* Test non-`ConcreteRArray` field

* Fix bug in `traced_type`

* Refactor test of recursive data structure (linked list)

* Fix call to `datatype_fieldcount` in `traced_type`

* Mark commented test as broken test

---------

Co-authored-by: Sergio Sánchez Ramírez <Sergio Sánchez Ramírez>
  • Loading branch information
mofeing authored May 28, 2024
1 parent 417ca7d commit 1d5ab77
Show file tree
Hide file tree
Showing 3 changed files with 125 additions and 2 deletions.
34 changes: 32 additions & 2 deletions src/Reactant.jl
Original file line number Diff line number Diff line change
Expand Up @@ -294,7 +294,7 @@ using Enzyme
if aT === nothing
throw("Unhandled type $T")
end
if datatype_fieldcount(aT) === nothing
if Base.datatype_fieldcount(aT) === nothing
throw("Unhandled type $T")
end
end
Expand Down Expand Up @@ -342,7 +342,7 @@ using Enzyme
end

if Val(T) seen
return seen[T]
return T
end

seen = (Val(T), seen...)
Expand Down Expand Up @@ -720,10 +720,40 @@ function generate_jlfunc(concrete_result, client, mod, Nargs, linear_args, linea
end)
return
end
if T <: Array
elems = Symbol[]
for (i, v) in enumerate(tocopy)
sym = Symbol(string(resname)*"_"*string(i))
create_result(v, sym, (path...,i))
push!(elems, sym)
end
push!(concrete_result_maker, quote
$resname = $(eltype(T))[$(elems...)]
end)
return
end
if T <: Int || T <: AbstractFloat || T <: AbstractString || T <: Nothing
push!(concrete_result_maker, :($resname = $tocopy))
return
end
if T <: Symbol
push!(concrete_result_maker, :($resname = $(QuoteNode(tocopy))))
return
end
if isstructtype(T)
elems = Symbol[]
nf = fieldcount(T)
for i in 1:nf
sym = Symbol(resname, :_, i)
create_result(getfield(tocopy, i), sym, (path..., i))
push!(elems, sym)
end
push!(concrete_result_maker, quote
flds = Any[$(elems...)]
$resname = ccall(:jl_new_structv, Any, (Any, Ptr{Cvoid}, UInt32), $T, flds, $nf)
end)
return
end

error("canot copy $T")
end
Expand Down
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -44,3 +44,4 @@ include("layout.jl")
include("basic.jl")
include("bcast.jl")
include("nn.jl")
include("struct.jl")
92 changes: 92 additions & 0 deletions test/struct.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
using Reactant
using Test

# from bsc-quantic/Tenet.jl
struct MockTensor{T,N,A<:AbstractArray{T,N}}
data::A
inds::Vector{Symbol}
end

MockTensor(data::A, inds) where {T,N,A<:AbstractArray{T,N}} = MockTensor{T,N,A}(data, inds)
Base.parent(t::MockTensor) = t.data

Base.cos(x::MockTensor) = MockTensor(cos(parent(x)), x.inds)

mutable struct MutableMockTensor{T,N,A<:AbstractArray{T,N}}
data::A
inds::Vector{Symbol}
end

MutableMockTensor(data::A, inds) where {T,N,A<:AbstractArray{T,N}} = MutableMockTensor{T,N,A}(data, inds)
Base.parent(t::MutableMockTensor) = t.data

Base.cos(x::MutableMockTensor) = MutableMockTensor(cos(parent(x)), x.inds)

# modified from JuliaCollections/DataStructures.jl
# NOTE original uses abstract type instead of union, which is not supported
mutable struct MockLinkedList{T}
head::T
tail::Union{MockLinkedList{T},Nothing}
end

function list(x::T...) where {T}
l = nothing
for i in Iterators.reverse(eachindex(x))
l = MockLinkedList{T}(x[i], l)
end
return l
end

Base.sum(x::MockLinkedList{T}) where {T} = sum(x.head) + (!isnothing(x.tail) ? sum(x.tail) : 0)

@testset "Struct" begin
@testset "MockTensor" begin
@testset "immutable" begin
x = MockTensor(rand(4, 4), [:i, :j])
x2 = MockTensor(Reactant.ConcreteRArray(parent(x)), x.inds)

f = Reactant.compile(cos, (x2,))
y = f(x2)

@test y isa MockTensor{Float64,2,Reactant.ConcreteRArray{Float64,(4, 4),2}}
@test isapprox(parent(y), cos.(parent(x)))
@test x.inds == [:i, :j]
end

@testset "mutable" begin
x = MutableMockTensor(rand(4, 4), [:i, :j])
x2 = MutableMockTensor(Reactant.ConcreteRArray(parent(x)), x.inds)

f = Reactant.compile(cos, (x2,))
y = f(x2)

@test y isa MutableMockTensor{Float64,2,Reactant.ConcreteRArray{Float64,(4, 4),2}}
@test isapprox(parent(y), cos.(parent(x)))
@test x.inds == [:i, :j]
end
end

@testset "MockLinkedList" begin
x = [rand(2, 2) for _ in 1:2]
x2 = list(x...)
x3 = Reactant.make_tracer(IdDict(), x2, (), Reactant.ArrayToConcrete, nothing)
x4 = list(Reactant.ConcreteRArray.(x)...)


# TODO this should be able to run without problems, but crashes
@test_broken begin
f = Reactant.compile(identity, (x3,))
isapprox(f(x3), x3)
end

f3 = Reactant.compile(sum, (x3,))
f4 = Reactant.compile(sum, (x4,))

y = sum(x2)
y3 = f3(x3)
y4 = f4(x4)

@test isapprox(y, y3)
@test isapprox(y, y4)
end
end

0 comments on commit 1d5ab77

Please sign in to comment.