diff --git a/ext/JutulMakieExt/mesh_plots.jl b/ext/JutulMakieExt/mesh_plots.jl index 7ac51df9..25ece817 100644 --- a/ext/JutulMakieExt/mesh_plots.jl +++ b/ext/JutulMakieExt/mesh_plots.jl @@ -27,14 +27,28 @@ function Jutul.plot_mesh_impl!(ax, m; has_face_filter = !isnothing(faces) has_bface_filter = !isnothing(boundaryfaces) if has_cell_filter || has_face_filter || has_bface_filter + keep_cells = Dict{Int, Bool}() + keep_faces = Dict{Int, Bool}() + keep_bf = Dict{Int, Bool}() + if eltype(cells) == Bool @assert length(cells) == number_of_cells(m) cells = findall(cells) end + if has_cell_filter + for c in cells + keep_cells[c] = true + end + end if eltype(faces) == Bool @assert length(faces) == number_of_faces(m) faces = findall(faces) end + if has_face_filter + for f in faces + keep_faces[f] = true + end + end if eltype(boundaryfaces) == Bool @assert length(boundaryfaces) == number_of_boundary_faces(m) boundaryfaces = findall(boundaryfaces) @@ -43,6 +57,9 @@ function Jutul.plot_mesh_impl!(ax, m; nf = number_of_faces(m) boundaryfaces = deepcopy(boundaryfaces) boundaryfaces .+= nf + for f in boundaryfaces + keep_bf[f] = true + end end ntri = size(tri, 1) keep = fill(false, ntri) @@ -53,13 +70,13 @@ function Jutul.plot_mesh_impl!(ax, m; tri_tmp = tri[i, 1] keep_this = true if has_cell_filter - keep_this = keep_this && cell_ix[tri_tmp] in cells + keep_this = keep_this && haskey(keep_cells, cell_ix[tri_tmp]) end if has_face_filter - keep_this = keep_this && face_ix[tri_tmp] in faces + keep_this = keep_this && haskey(keep_faces, face_ix[tri_tmp]) end if has_bface_filter - keep_this = keep_this && face_ix[tri_tmp] in boundaryfaces + keep_this = keep_this && haskey(keep_bf, face_ix[tri_tmp]) end keep[i] = keep_this end