Skip to content

Commit

Permalink
refactor bounds with new Bounds data type
Browse files Browse the repository at this point in the history
  • Loading branch information
lorenzoh committed Apr 28, 2021
1 parent a9794d5 commit 266764e
Show file tree
Hide file tree
Showing 23 changed files with 288 additions and 250 deletions.
30 changes: 16 additions & 14 deletions src/items/image.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,21 +33,20 @@ showitems(item)
```
"""
struct Image{N,T,B} <: AbstractArrayItem{N,T}
struct Image{N,T} <: AbstractArrayItem{N,T}
data::AbstractArray{T,N}
bounds::AbstractArray{<:SVector{N,B},N}
bounds::Bounds{N}
end

Image(data) = Image(data, size(data))
Image(data) = Image(data, Bounds(axes(data)))

function Image(data::AbstractArray{T,N}, sz::NTuple{N,Int}) where {T,N}
bounds = makebounds(sz)
return Image(data, bounds)
return Image(data, Bounds(sz))
end


Base.show(io::IO, item::Image{N,T}) where {N,T} =
print(io, "Image{$N, $T}() with size $(size(itemdata(item)))")
print(io, "Image{$N, $T}() with bounds $(item.bounds)")


function showitem!(img, image::Image{2, <:Colorant})
Expand All @@ -72,22 +71,25 @@ getbounds(image::Image) = image.bounds
# We have to pass the inverse of the projection `P` as it uses backward
# mode warping.

function project(P, image::Image{N, T}, indices) where {N, T}
## Transform the bounds along with the image
bounds_ = P.(getbounds(image))
data_ = warp(itemdata(image), inv(P), indices, zero(T))
return Image(data_, makebounds(indices))
function project(P, image::Image{N, T}, bounds::Bounds) where {N, T}
# TODO: make interpolation scheme and boundary conditions configurable
data_ = warp(
itemdata(image),
inv(P),
bounds.rs,
zero(T))
return Image(data_, bounds)
end

# The inplace version `project!` is quite similar. Note `indices` are not needed
# as they are implicitly given by the buffer.

function project!(bufimage::Image, P, image::Image{N, T}, indices) where {N, T}
a = OffsetArray(parent(itemdata(bufimage)), indices)
function project!(bufimage::Image, P, image::Image{N, T}, bounds::Bounds{N}) where {N, T}
a = OffsetArray(parent(itemdata(bufimage)), bounds.rs)
res = warp!(
a,
box_extrapolation(itemdata(image), zero(T)),
inv(P),
)
return Image(res, P.(getbounds(image)))
return Image(res, bounds)
end
11 changes: 6 additions & 5 deletions src/items/keypoints.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,12 @@ showitems(item)
"""
struct Keypoints{N, T, S<:Union{SVector{N, T}, Nothing}, M} <: AbstractArrayItem{M, S}
data::AbstractArray{S, M}
bounds::AbstractArray{<:SVector{N, Float32}, N}
bounds::Bounds{N}
end


function Keypoints(data::AbstractArray{S, M}, sz::NTuple{N, Int}) where {T, N, S<:Union{SVector{N, T}, Nothing}, M}
return Keypoints{N, T, S, M}(data, makebounds(sz, Float32))
function Keypoints(data, sz::NTuple{N, Int}) where N
return Keypoints(data, Bounds(sz))
end


Expand All @@ -39,10 +39,11 @@ Base.show(io::IO, item::Keypoints{N, T, M}) where {N, T, M} =
getbounds(keypoints::Keypoints) = keypoints.bounds


function project(P, keypoints::Keypoints{N, T}, indices) where {N, T}
function project(P, keypoints::Keypoints{N, T}, bounds::Bounds{N}) where {N, T}
# TODO: convert back to `T`?
return Keypoints(
map(fmap(P), keypoints.data),
makebounds(indices),
bounds,
)
end

Expand Down
38 changes: 18 additions & 20 deletions src/items/mask.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,15 @@ mask = MaskMulti(rand(1:3, 100, 100))
showitems(mask)
```
"""
struct MaskMulti{N, T<:Integer, U, B} <: AbstractArrayItem{N, T}
struct MaskMulti{N, T<:Integer, U} <: AbstractArrayItem{N, T}
data::AbstractArray{T, N}
classes::AbstractVector{U}
bounds::AbstractArray{<:SVector{N, B}, N}
bounds::Bounds{N}
end


function MaskMulti(a::AbstractArray, classes = unique(a))
bounds = makebounds(size(a))
bounds = Bounds(size(a))
minimum(a) >= 1 || error("Class values must start at 1")
return MaskMulti(a, classes, bounds)
end
Expand All @@ -39,21 +39,20 @@ Base.show(io::IO, mask::MaskMulti{N, T}) where {N, T} =
getbounds(mask::MaskMulti) = mask.bounds


function project(P, mask::MaskMulti, indices)
function project(P, mask::MaskMulti, bounds::Bounds)
a = itemdata(mask)
etp = mask_extrapolation(a)
res = warp(etp, inv(P), indices)
res = warp(etp, inv(P), bounds.rs)
return MaskMulti(
res,
mask.classes,
makebounds(indices)
bounds
)
end


function project!(bufmask::MaskMulti, P, mask::MaskMulti, indices)
a = OffsetArray(parent(itemdata(bufmask)), indices)
bounds_ = P.(getbounds(mask))
function project!(bufmask::MaskMulti, P, mask::MaskMulti, bounds)
a = OffsetArray(parent(itemdata(bufmask)), bounds.rs)
warp!(
a,
mask_extrapolation(itemdata(mask)),
Expand All @@ -62,7 +61,7 @@ function project!(bufmask::MaskMulti, P, mask::MaskMulti, indices)
return MaskMulti(
a,
mask.classes,
makebounds(indices)
bounds
)
end

Expand Down Expand Up @@ -96,12 +95,12 @@ mask = MaskBinary(rand(Bool, 100, 100))
showitems(mask)
```
"""
struct MaskBinary{N, B} <: AbstractArrayItem{N, Bool}
struct MaskBinary{N} <: AbstractArrayItem{N, Bool}
data::AbstractArray{Bool, N}
bounds::AbstractArray{<:SVector{N, B}, N}
bounds::Bounds{N}
end

function MaskBinary(a::AbstractArray{Bool, N}, bounds = makebounds(size(a))) where N
function MaskBinary(a::AbstractArray{Bool, N}, bounds = Bounds(size(a))) where N
return MaskBinary(a, bounds)
end

Expand All @@ -110,26 +109,25 @@ Base.show(io::IO, mask::MaskBinary{N}) where {N} =

getbounds(mask::MaskBinary) = mask.bounds

function project(P, mask::MaskBinary, indices)
function project(P, mask::MaskBinary, bounds::Bounds)
etp = mask_extrapolation(itemdata(mask))
return MaskBinary(
warp(etp, inv(P), indices),
makebounds(indices),
warp(etp, inv(P), bounds.rs),
bounds,
)
end


function project!(bufmask::MaskBinary, P, mask::MaskBinary, indices)
bounds_ = P.(getbounds(mask))
a = OffsetArray(parent(itemdata(bufmask)), indices)
function project!(bufmask::MaskBinary, P, mask::MaskBinary, bounds)
a = OffsetArray(parent(itemdata(bufmask)), bounds.rs)
res = warp!(
a,
mask_extrapolation(itemdata(mask)),
inv(P),
)
return MaskBinary(
a,
makebounds(indices)
bounds
)
end

Expand Down
69 changes: 54 additions & 15 deletions src/projective/affine.jl
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,21 @@ end


function getprojection(scale::ScaleKeepAspect{N}, bounds; randstate = nothing) where N
ratio = maximum(scale.minlengths ./ boundssize(bounds))
return scaleprojection(Tuple(ratio for _ in 1:N))
ratio = maximum(scale.minlengths ./ length.(bounds.rs))
upperleft = SVector{N, Float32}(minimum.(bounds.rs)) .- 1
P = scaleprojection(Tuple(ratio for _ in 1:N))
if upperleft != SVector(0, 0)
P = P Translation(-upperleft)
end
return P
end

function projectionbounds(tfm::ScaleKeepAspect{N}, P, bounds::Bounds{N}; randstate = nothing) where N
origsz = length.(bounds.rs)
ratio = maximum(tfm.minlengths ./ origsz)
sz = round.(Int, ratio .* origsz)
bounds_ = transformbounds(bounds, P)
return offsetcropbounds(sz, bounds_, ntuple(_ -> 1., N))
end

"""
Expand All @@ -57,12 +70,22 @@ struct ScaleFixed{N} <: ProjectiveTransform
end


function getprojection(scale::ScaleFixed{N}, bounds; randstate = nothing) where N
ratios = scale.sizes ./ boundssize(bounds)
return scaleprojection(ratios)
function getprojection(scale::ScaleFixed, bounds; randstate = nothing)
ratios = scale.sizes ./ length.(bounds.rs)
upperleft = SVector{2, Float32}(minimum.(bounds.rs)) .- 1
P = scaleprojection(ratios)
if upperleft != SVector(0, 0)
P = P Translation(-upperleft)
end
return P
end


function projectionbounds(tfm::ScaleFixed{N}, P, bounds::Bounds{N}; randstate = nothing) where N
bounds_ = transformbounds(bounds, P)
return offsetcropbounds(tfm.sizes, bounds_, (1., 1.))
end

"""
Zoom(scales = (1, 1.2)) <: ProjectiveTransform
Zoom(distribution)
Expand All @@ -78,7 +101,7 @@ Zoom(scales::NTuple{2, T} = (1., 1.2)) where T = Zoom(Uniform(scales[1], scales[

getrandstate(tfm::Zoom) = rand(tfm.dist)

function getprojection(tfm::Zoom, bounds::AbstractArray{<:SVector{N}}; randstate = getrandstate(tfm)) where N
function getprojection(tfm::Zoom, bounds::Bounds{N}; randstate = getrandstate(tfm)) where N
ratio = randstate
return scaleprojection(ntuple(_ -> ratio, N))
end
Expand Down Expand Up @@ -109,12 +132,12 @@ getrandstate(tfm::Rotate) = rand(tfm.dist)

function getprojection(
tfm::Rotate,
bounds::AbstractArray{<:SVector{N, T}};
randstate = getrandstate(tfm)) where {N, T}
bounds::Bounds{2};
randstate = getrandstate(tfm))
γ = randstate
middlepoint = sum(bounds) ./ length(bounds)
middlepoint = SVector{2, Float32}(mean.(bounds.rs))
r = γ / 360 * 2pi
return recenter(RotMatrix(convert(T, r)), middlepoint)
return recenter(RotMatrix(convert(Float32, r)), middlepoint)
end


Expand All @@ -140,16 +163,32 @@ end


function getprojection(tfm::Reflect, bounds; randstate = getrandstate(tfm))
midpoint = sum(bounds) ./ length(bounds)
r = tfm.γ / 360 * 2pi
return recenter(reflectionmatrix(r), midpoint)
return centered(LinearMap(reflectionmatrix(r)), bounds)
end

"""
centered(P, bounds)
Transform `P` so that is applied around the center of `bounds`
instead of the origin
"""
function centered(P, bounds::Bounds{2})
upperleft = minimum.(bounds.rs)
bottomright = maximum.(bounds.rs)

midpoint = SVector{2, Float32}((bottomright .- upperleft) ./ 2) .+ SVector{2, Float32}(.5, .5)
return recenter(P, midpoint)
end


FlipX() = Reflect(180)
FlipY() = Reflect(90)

reflectionmatrix(r) = SMatrix{2, 2, Float32}(cos(2r), sin(2r), sin(2r), -cos(2r))
function reflectionmatrix(r)
A = SMatrix{2, 2, Float32}(cos(2r), sin(2r), sin(2r), -cos(2r))
return round.(A; digits = 12)
end


"""
Expand All @@ -171,12 +210,12 @@ struct PinOrigin <: ProjectiveTransform end

function getprojection(::PinOrigin, bounds; randstate = nothing)
# TODO: translate by actual minimum x and y coordinates
return Translation(-bounds[1])
return Translation((-SVector{2, Float32}(minimum.(bounds.rs))) .+ 1)
end

function apply(::PinOrigin, item::Union{Image, MaskMulti, MaskBinary}; randstate = nothing)
item = @set item.data = parent(itemdata(item))
item = @set item.bounds = makebounds(size(itemdata(item)))
item = @set item.bounds = Bounds(size(itemdata(item)))
return item
end

Expand Down
Loading

0 comments on commit 266764e

Please sign in to comment.