Skip to content

Commit

Permalink
Merge pull request #53 from sintefmath/dev
Browse files Browse the repository at this point in the history
Entity tags to main
  • Loading branch information
moyner authored Jan 4, 2024
2 parents 402db0f + 6544903 commit 590b633
Show file tree
Hide file tree
Showing 12 changed files with 416 additions and 84 deletions.
25 changes: 21 additions & 4 deletions ext/JutulMakieExt/mesh_plots.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,25 @@ function Jutul.plot_mesh_impl!(ax, m; cells = nothing, is_depth = true, outer =
keep[i] = cell_ix[tri[i, 1]] in cells
end
tri = tri[keep, :]
tri, pts = remove_unused_points(tri, pts)
end
f = mesh!(ax, pts, tri; color = color, kwarg...)
return f
end

function remove_unused_points(tri, pts)
unique_pts_ix = unique(vec(tri))
renum = Dict{Int, Int}()
for (i, ix) in enumerate(unique_pts_ix)
renum[ix] = i
end
pts = pts[unique_pts_ix, :]
for i in eachindex(tri)
tri[i] = renum[tri[i]]
end
return (tri, pts)
end

function Jutul.plot_cell_data_impl(m, data;
colorbar = :horizontal,
resolution = default_jutul_resolution(),
Expand Down Expand Up @@ -55,17 +69,20 @@ function Jutul.plot_cell_data_impl!(ax, m, data::AbstractVecOrMat; cells = nothi
@assert length(cells) == nc
cells = findall(cells)
end
new_data = zeros(nc)
@. new_data = NaN
new_data = fill(NaN, nc)
if length(data) == length(cells)
new_data[cells] = data
for (i, j) in enumerate(cells)
new_data[j] = data[i]
end
else
@assert length(data) == nc
for i in cells
new_data[i] = data[i]
end
end
data = new_data
end
@assert length(data) == nc
return mesh!(ax, pts, tri; color = mapper.Cells(data), kwarg...)
color = mapper.Cells(data)
return mesh!(ax, pts, tri; color = color, kwarg...)
end
146 changes: 146 additions & 0 deletions src/core_types/core_types.jl
Original file line number Diff line number Diff line change
Expand Up @@ -466,6 +466,7 @@ An entity for something that isn't associated with an entity
"""
struct NoEntity <: JutulEntity end


# Sim model

"""
Expand Down Expand Up @@ -1039,3 +1040,148 @@ function renumber!(x, im::IndexRenumerator)
x[i] = im[v]
end
end

struct EntityTags{T}
tags::Dict{Symbol, Dict{Symbol, Vector{T}}}
count::Int
end

function EntityTags(n; tag_type = Int)
data = Dict{Symbol, Dict{Symbol, Vector{tag_type}}}()
return EntityTags{tag_type}(data, n)
end

Base.haskey(et::EntityTags, k) = Base.haskey(et.tags, k)
Base.keys(et::EntityTags) = Base.keys(et.tags)
Base.getindex(et::EntityTags, arg...) = Base.getindex(et.tags, arg...)
Base.setindex!(et::EntityTags, arg...) = Base.setindex!(et.tags, arg...)

export set_mesh_entity_tag!, get_mesh_entity_tag, mesh_entity_has_tag
struct MeshEntityTags{T}
tags::Dict{JutulEntity, EntityTags{T}}
end

Base.haskey(et::MeshEntityTags, k) = Base.haskey(et.tags, k)
Base.keys(et::MeshEntityTags) = Base.keys(et.tags)
Base.getindex(et::MeshEntityTags, arg...) = Base.getindex(et.tags, arg...)
Base.setindex!(et::MeshEntityTags, arg...) = Base.setindex!(et.tags, arg...)

function Base.show(io::IO, t::MIME"text/plain", options::MeshEntityTags{T}) where T
println(io, "MeshEntityTags stored as $T:")
for (k, v) in pairs(options.tags)
kv = keys(v)
if length(kv) == 0
kv = "<no tags>"
else
s = map(x -> "x $(keys(v[x]))", collect(kv))
kv = join(s, ",")
end
println(io, " $k: $(kv)")
end
end

function MeshEntityTags(g::JutulMesh; kwarg...)
return MeshEntityTags(declare_entities(g); kwarg...)
end

function MeshEntityTags(entities = missing; tag_type = Int)
tags = Dict{JutulEntity, EntityTags{tag_type}}()
if !ismissing(entities)
for epair in entities
tags[epair.entity] = EntityTags(epair.count, tag_type = tag_type)
end
end
return MeshEntityTags{tag_type}(tags)
end

function initialize_entity_tags!(g::JutulMesh)
tags = mesh_entity_tags(g)
for epair in declare_entities(g)
tags[epair.entity] = EntityTags(epair.count)
end
return g
end

function mesh_entity_tags(x)
return x.tags
end

function set_mesh_entity_tag!(m::JutulMesh, arg...; kwarg...)
set_mesh_entity_tag!(mesh_entity_tags(m), arg...; kwarg...)
return m
end

function set_mesh_entity_tag!(met::MeshEntityTags{T}, entity::JutulEntity, tag_group::Symbol, tag_value::Symbol, ix::Vector{T}; allow_merge = true, allow_new = true) where T
tags = met.tags[entity]
tags::EntityTags{T}
if !haskey(tags, tag_group)
@assert allow_new "allow_new = false and tag group $tag_group for entity $entity already exists."
tags[tag_group] = Dict{Symbol, Vector{T}}()
end
@assert maximum(ix, init = one(T)) <= tags.count "Tag value must not exceed $(tags.count) for $entity"
@assert minimum(ix, init = tags.count) > 0 "Tags must have positive indices."

tg = tags[tag_group]
if haskey(tg, tag_value)
@assert allow_merge "allow_merge = false and tag $tag_value in group $tag_group for entity $entity already exists."
vals = tg[tag_value]
for i in ix
push!(vals, i)
end
else
tg[tag_value] = copy(ix)
end
vals = tg[tag_value]
sort!(vals)
unique!(vals)
return met
end


"""
get_mesh_entity_tag(met::JutulMesh, entity::JutulEntity, tag_group::Symbol, tag_value = missing; throw = true)
Get the indices tagged for `entity` in group `tag_group`, optionally for the
specific `tag_value`. If `ismissing(tag_value)`, the Dict containing the tag
group will be returned.
"""
function get_mesh_entity_tag(m::JutulMesh, arg...; kwarg...)
return get_mesh_entity_tag(mesh_entity_tags(m), arg...; kwarg...)
end

function get_mesh_entity_tag(met::MeshEntityTags, entity::JutulEntity, tag_group::Symbol, tag_value = missing; throw = true)
out = missing
tags = met.tags[entity]
if haskey(tags, tag_group)
tg = tags[tag_group]
if ismissing(tag_value)
out = tg
elseif haskey(tg, tag_value)
tag_value::Symbol
out = tg[tag_value]
end
end
if ismissing(out) && throw
if ismissing(tag_value)
error("Tag $tag_group not found in $entity.")
else
error("Tag $tag_group.$tag_value not found in $entity.")
end
end
return out
end

function mesh_entity_has_tag(m::JutulMesh, arg...; kwarg...)
return mesh_entity_has_tag(mesh_entity_tags(m), arg...; kwarg...)
end

function mesh_entity_has_tag(met::MeshEntityTags, entity::JutulEntity, tag_group::Symbol, tag_value::Symbol, ix)
tag = get_mesh_entity_tag(met, entity, tag_group, tag_value)
pos = searchsortedfirst(tag, ix)
if pos > length(tag)
out = false
else
out = tag[pos] == ix
end
return out
end
Loading

0 comments on commit 590b633

Please sign in to comment.