Skip to content

Commit

Permalink
Revert "fix: traced_getfield move to TracedUtils"
Browse files Browse the repository at this point in the history
This reverts commit f35d911.
  • Loading branch information
avik-pal committed Dec 16, 2024
1 parent d9c167f commit e2d7377
Show file tree
Hide file tree
Showing 5 changed files with 29 additions and 31 deletions.
6 changes: 5 additions & 1 deletion src/Compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,11 @@ import ..Reactant:
TracedToConcrete,
append_path,
TracedType
import ..TracedUtils: TracedUtils, traced_getfield

@inline traced_getfield(@nospecialize(obj), field) = Base.getfield(obj, field)
@inline traced_getfield(
@nospecialize(obj::AbstractArray{<:Union{ConcreteRNumber,ConcreteRArray}}), field
) = Base.getindex(obj, field)

function create_result(tocopy::T, path, result_stores) where {T}
if !isstructtype(typeof(tocopy))
Expand Down
17 changes: 17 additions & 0 deletions src/ConcreteRArray.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,20 @@
struct XLAArray{T,N} <: RArray{T,N}
# size::NTuple{N,Int}
end

mutable struct ConcreteRArray{T,N} <: RArray{T,N}
data::XLA.AsyncBuffer
# data::XLAArray{T, N}
shape::NTuple{N,Int}
end

const WrappedConcreteRArray{T,N} = WrappedArray{T,N,ConcreteRArray,ConcreteRArray{T,N}}
const AnyConcreteRArray{T,N} = Union{ConcreteRArray{T,N},WrappedConcreteRArray{T,N}}

mutable struct ConcreteRNumber{T} <: RNumber{T}
data::XLA.AsyncBuffer
end

function ConcreteRNumber{T}(
data::T2; client=XLA.default_backend[], idx=XLA.default_device_idx[], device=nothing
) where {T<:Number,T2<:Number}
Expand Down
2 changes: 1 addition & 1 deletion src/Interpreter.jl
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,7 @@ function set_act!(inp, path, reverse, tostore; emptypath=false)
end

for p in path
x = TracedUtils.traced_getfield(x, p)
x = traced_getfield(x, p)
end

#if inp isa Enzyme.Active || !reverse
Expand Down
23 changes: 3 additions & 20 deletions src/Reactant.jl
Original file line number Diff line number Diff line change
Expand Up @@ -114,25 +114,6 @@ mutable struct TracedRNumber{T} <: RNumber{T}
end
end

struct XLAArray{T,N} <: RArray{T,N}
# size::NTuple{N,Int}
end

mutable struct ConcreteRArray{T,N} <: RArray{T,N}
data::XLA.AsyncBuffer
# data::XLAArray{T, N}
shape::NTuple{N,Int}
end

const WrappedConcreteRArray{T,N} = WrappedArray{T,N,ConcreteRArray,ConcreteRArray{T,N}}
const AnyConcreteRArray{T,N} = Union{ConcreteRArray{T,N},WrappedConcreteRArray{T,N}}

mutable struct ConcreteRNumber{T} <: RNumber{T}
data::XLA.AsyncBuffer
end

const TracedType = Union{TracedRArray,TracedRNumber,MissingTracedValue}

include("Ops.jl")
include("TracedUtils.jl")

Expand All @@ -143,6 +124,8 @@ include("ConcreteRArray.jl")

include("linear_algebra.jl")

const TracedType = Union{TracedRArray,TracedRNumber,MissingTracedValue}

include("ControlFlow.jl")
include("Tracing.jl")
include("Compiler.jl")
Expand All @@ -163,7 +146,7 @@ function Enzyme.make_zero(
return res
end

using .Compiler: @compile, @code_hlo, @jit, create_result, compile
using .Compiler: @compile, @code_hlo, @jit, traced_getfield, create_result, compile
export ConcreteRArray, ConcreteRNumber, @compile, @code_hlo, @jit, @trace

const registry = Ref{MLIR.IR.DialectRegistry}()
Expand Down
12 changes: 3 additions & 9 deletions src/TracedUtils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,18 +14,12 @@ using ..Reactant:
AnyTracedRArray,
MissingTracedValue,
OrderedIdDict,
ConcreteRArray,
ConcreteRNumber
Compiler
import ..Reactant
import ..Reactant.MLIR
import ..ReactantPrimitive
import ..Ops

@inline traced_getfield(@nospecialize(obj), field) = Base.getfield(obj, field)
@inline traced_getfield(
@nospecialize(obj::AbstractArray{<:Union{ConcreteRNumber,ConcreteRArray}}), field
) = Base.getindex(obj, field)

materialize_traced_array(x::TracedRArray) = x
materialize_traced_array(x::WrappedTracedRArray) = x[axes(x)...]
function materialize_traced_array(
Expand Down Expand Up @@ -330,7 +324,7 @@ end

function push_val!(ad_inputs, x, path)
for p in path
x = traced_getfield(x, p)
x = Compiler.traced_getfield(x, p)
end
x = x.mlir_data
return push!(ad_inputs, x)
Expand All @@ -350,7 +344,7 @@ end

function set!(x, path, tostore; emptypath=false)
for p in path
x = traced_getfield(x, p)
x = Compiler.traced_getfield(x, p)
end

x.mlir_data = tostore
Expand Down

0 comments on commit e2d7377

Please sign in to comment.