Skip to content

Commit

Permalink
fix apply! for wrapper transforms
Browse files Browse the repository at this point in the history
  • Loading branch information
lorenzoh committed Apr 28, 2021
1 parent 2393d97 commit a9794d5
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 17 deletions.
10 changes: 9 additions & 1 deletion src/sequence.jl
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,15 @@ function makebuffer(pipeline::Sequence, items)
end


function apply!(buffers, pipeline::Sequence, items; randstate = getrandstate(pipeline))
function apply!(buffers::Item, pipeline::Sequence, items::Item; randstate = getrandstate(pipeline))
@assert length(buffers) == length(pipeline.transforms)
for (tfm, buffer, r) in zip(pipeline.transforms, buffers, randstate)
items = apply!(buffer, tfm, items; randstate = r)
end
return items
end

function apply!(buffers::Vector, pipeline::Sequence, items::Vector; randstate = getrandstate(pipeline))
@assert length(buffers) == length(pipeline.transforms)
for (tfm, buffer, r) in zip(pipeline.transforms, buffers, randstate)
items = apply!(buffer, tfm, items; randstate = r)
Expand Down
46 changes: 30 additions & 16 deletions test/buffered.jl
Original file line number Diff line number Diff line change
@@ -1,21 +1,6 @@
include("./imports.jl")


# TODO: test apply!(Sequential)

a = randn(5, 5)
item = ArrayItem(a)

tfm = MapElem(x -> x + 1)

buf = apply(tfm, item)

apply!(buf, tfm, item)

DataAugmentation.copyitemdata!([item], [buf])
item
buf

@testset ExtendedTestSet "apply!(buf, ::Map, ...)" begin
newitem() = ArrayItem(randn(5, 5))
tfm = MapElem(x -> x + 1)
Expand Down Expand Up @@ -61,7 +46,7 @@ end
@testset ExtendedTestSet "`Buffered`" begin
newitem() = ArrayItem(randn(5, 5))
# buffer should be created
tb = Buffered(MapElem(x -> x + 1))
tb = Buffered(MapElem(x -> x + one(typeof(x))))
@test isnothing(tb.buffer)
@test_nowarn apply(tb, newitem())
@test !isnothing(tb.buffer)
Expand All @@ -75,6 +60,7 @@ end
buf2 = deepcopy(buf)
@test_nowarn apply!(buf, tb, newitem())
@test !(buf2.data tb.buffer.data)
testapply(Buffered(MapElem(x -> x + one(typeof(x)))), ArrayItem)
end

@testset ExtendedTestSet "`BufferedThreadsafe`" begin
Expand All @@ -95,4 +81,32 @@ end
buf2 = deepcopy(buf)
@test_nowarn apply!(buf, tbt, newitem())
@test !(buf2.data tb.buffer.data)

testapply(BufferedThreadsafe(MapElem(x -> x + one(typeof(x)))), ArrayItem)
end


struct PlusRand <: Transform
end
DataAugmentation.getrandstate(::PlusRand) = rand()
function DataAugmentation.apply(tfm::PlusRand, item::DataAugmentation.AbstractItem; randstate = getrandstate(tfm))
return DataAugmentation.setdata(item, map(x -> x + randstate, itemdata(item)))
end
function DataAugmentation.apply!(buf, tfm::PlusRand, item::DataAugmentation.AbstractItem; randstate = getrandstate(tfm))
map!(x -> x + randstate, itemdata(buf), itemdata(item))
return buf
end

@testset ExtendedTestSet "" begin
item1 = ArrayItem(ones(10))
item2 = ArrayItem(ones(10))
tfm = PlusRand()
@test_nowarn apply(tfm, item1)
titem1, titem2 = apply(tfm, (item1, item2))
a1, a2 = itemdata.((titem1, titem2))
b1, b2 = copy.((a1, a2))
@test a1 == a2
apply!((titem1, titem2), tfm, (item1, item2))
@test a1 == a2
@test a1 != b1
end

0 comments on commit a9794d5

Please sign in to comment.