diff --git a/src/OctreeDistributedDiscreteModels.jl b/src/OctreeDistributedDiscreteModels.jl index 144a76a..ee58b56 100644 --- a/src/OctreeDistributedDiscreteModels.jl +++ b/src/OctreeDistributedDiscreteModels.jl @@ -1599,9 +1599,23 @@ end # Assumptions. Either: # A) model.parts MPI tasks are included in parts_redistributed_model MPI tasks; or # B) model.parts MPI tasks include parts_redistributed_model MPI tasks +const WeightsArrayType=Union{Nothing,MPIArray{<:Vector{<:Integer}}} function GridapDistributed.redistribute(model::OctreeDistributedDiscreteModel{Dc,Dp}, - parts_redistributed_model=model.parts) where {Dc,Dp} + parts_redistributed_model=model.parts; + weights::WeightsArrayType=nothing) where {Dc,Dp} parts = (parts_redistributed_model === model.parts) ? model.parts : parts_redistributed_model + _weights=nothing + if (weights !== nothing) + Gridap.Helpers.@notimplementedif parts!==model.parts + _weights=map(model.dmodel.models,weights) do lmodel,weights + # The length of the local weights array has to match the number of + # cells in the model. This includes both owned and ghost cells. + # Only the flags for owned cells are actually taken into account. + @assert num_cells(lmodel)==length(weights) + convert(Vector{Cint},weights) + end + end + comm = parts.comm if (GridapDistributed.i_am_in(model.parts.comm) || GridapDistributed.i_am_in(parts.comm)) if (parts_redistributed_model !== model.parts) @@ -1610,7 +1624,7 @@ function GridapDistributed.redistribute(model::OctreeDistributedDiscreteModel{Dc @assert A || B end if (parts_redistributed_model===model.parts || A) - _redistribute_parts_subseteq_parts_redistributed(model,parts_redistributed_model) + _redistribute_parts_subseteq_parts_redistributed(model,parts_redistributed_model,_weights) else _redistribute_parts_supset_parts_redistributed(model, parts_redistributed_model) end @@ -1619,7 +1633,9 @@ function GridapDistributed.redistribute(model::OctreeDistributedDiscreteModel{Dc end end -function _redistribute_parts_subseteq_parts_redistributed(model::OctreeDistributedDiscreteModel{Dc,Dp}, parts_redistributed_model) where {Dc,Dp} +function _redistribute_parts_subseteq_parts_redistributed(model::OctreeDistributedDiscreteModel{Dc,Dp}, + parts_redistributed_model, + _weights::WeightsArrayType) where {Dc,Dp} parts = (parts_redistributed_model === model.parts) ? model.parts : parts_redistributed_model if (parts_redistributed_model === model.parts) ptr_pXest_old = model.ptr_pXest @@ -1631,7 +1647,15 @@ function _redistribute_parts_subseteq_parts_redistributed(model::OctreeDistribut parts.comm) end ptr_pXest_new = pXest_copy(model.pXest_type, ptr_pXest_old) - pXest_partition!(model.pXest_type, ptr_pXest_new) + if (_weights !== nothing) + init_fn_callback_c = pXest_reset_callbacks(model.pXest_type) + map(_weights) do _weights + pXest_reset_data!(model.pXest_type, ptr_pXest_new, Cint(sizeof(Cint)), init_fn_callback_c, pointer(_weights)) + end + pXest_partition!(model.pXest_type, ptr_pXest_new; weights_set=true) + else + pXest_partition!(model.pXest_type, ptr_pXest_new; weights_set=false) + end # Compute RedistributeGlue parts_snd, lids_snd, old2new = pXest_compute_migration_control_data(model.pXest_type,ptr_pXest_old,ptr_pXest_new) diff --git a/src/PXestTypeMethods.jl b/src/PXestTypeMethods.jl index 6397742..d20c370 100644 --- a/src/PXestTypeMethods.jl +++ b/src/PXestTypeMethods.jl @@ -309,16 +309,31 @@ function pXest_balance!(::P8estType, ptr_pXest; k_2_1_balance=0) end end -function pXest_partition!(::P4estType, ptr_pXest) - p4est_partition(ptr_pXest, 0, C_NULL) +function pXest_partition!(pXest_type::P4estType, ptr_pXest; weights_set=false) + if (!weights_set) + p4est_partition(ptr_pXest, 0, C_NULL) + else + wcallback=pXest_weight_callback(pXest_type) + p4est_partition(ptr_pXest, 0, wcallback) + end end -function pXest_partition!(::P6estType, ptr_pXest) - p6est_partition(ptr_pXest, C_NULL) +function pXest_partition!(pXest_type::P6estType, ptr_pXest; weights_set=false) + if (!weights_set) + p6est_partition(ptr_pXest, C_NULL) + else + wcallback=pXest_weight_callback(pXest_type) + p6est_partition(ptr_pXest, wcallback) + end end -function pXest_partition!(::P8estType, ptr_pXest) - p8est_partition(ptr_pXest, 0, C_NULL) +function pXest_partition!(pXest_type::P8estType, ptr_pXest; weights_set=false) + if (!weights_set) + p8est_partition(ptr_pXest, 0, C_NULL) + else + wcallback=pXest_weight_callback(pXest_type) + p8est_partition(ptr_pXest, 0, wcallback) + end end @@ -805,6 +820,30 @@ function pXest_refine_callbacks(::P8estType) refine_callback_c, refine_replace_callback_c end +function pXest_weight_callback(::P4estType) + function weight_callback(::Ptr{p4est_t}, + which_tree::p4est_topidx_t, + quadrant_ptr::Ptr{p4est_quadrant_t}) + quadrant = quadrant_ptr[] + return unsafe_wrap(Array, Ptr{Cint}(quadrant.p.user_data), 1)[] + end + @cfunction($weight_callback, Cint, (Ptr{p4est_t}, p4est_topidx_t, Ptr{p4est_quadrant_t})) +end + +function pXest_weight_callback(::P6estType) + Gridap.Helpers.@notimplemented +end + +function pXest_weight_callback(::P8estType) + function weight_callback(::Ptr{p8est_t}, + which_tree::p4est_topidx_t, + quadrant_ptr::Ptr{p8est_quadrant_t}) + quadrant = quadrant_ptr[] + return unsafe_wrap(Array, Ptr{Cint}(quadrant.p.user_data), 1)[] + end + @cfunction($weight_callback, Cint, (Ptr{p8est_t}, p4est_topidx_t, Ptr{p8est_quadrant_t})) +end + function _unwrap_ghost_quadrants(::P4estType, pXest_ghost) Ptr{p4est_quadrant_t}(pXest_ghost.ghosts.array) end diff --git a/test/PoissonNonConformingOctreeModelsTests.jl b/test/PoissonNonConformingOctreeModelsTests.jl index 7ec3cb8..e7b4cf5 100644 --- a/test/PoissonNonConformingOctreeModelsTests.jl +++ b/test/PoissonNonConformingOctreeModelsTests.jl @@ -139,7 +139,14 @@ module PoissonNonConformingOctreeModelsTests e = uH - uhH el2 = sqrt(sum( ∫( e⋅e )*dΩH )) - fmodel_red, red_glue=GridapDistributed.redistribute(fmodel); + weights=map(ranks,fmodel.dmodel.models) do rank,lmodel + if (rank%2==0) + zeros(Cint,num_cells(lmodel)) + else + ones(Cint,num_cells(lmodel)) + end + end + fmodel_red, red_glue=GridapDistributed.redistribute(fmodel,weights=weights); Vhred=FESpace(fmodel_red,reffe,conformity=:H1;dirichlet_tags="boundary") Uhred=TrialFESpace(Vhred,u) @@ -274,12 +281,12 @@ module PoissonNonConformingOctreeModelsTests #debug_logger = ConsoleLogger(stderr, Logging.Debug) #global_logger(debug_logger); # Enable the debug logger globally ranks = distribute(LinearIndices((MPI.Comm_size(MPI.COMM_WORLD),))) - for Dc=3:3, perm=1:4, order=1:4, scalar_or_vector in (:scalar,) - test(ranks,Val{Dc},perm,order,_field_type(Val{Dc}(),scalar_or_vector)) - end - for Dc=2:3, perm in (1,2), order in (1,4), scalar_or_vector in (:vector,) - test(ranks,Val{Dc},perm,order,_field_type(Val{Dc}(),scalar_or_vector)) - end + # for Dc=3:3, perm=1:4, order=1:4, scalar_or_vector in (:scalar,) + # test(ranks,Val{Dc},perm,order,_field_type(Val{Dc}(),scalar_or_vector)) + # end + # for Dc=2:3, perm in (1,2), order in (1,4), scalar_or_vector in (:vector,) + # test(ranks,Val{Dc},perm,order,_field_type(Val{Dc}(),scalar_or_vector)) + # end for order=2:2, scalar_or_vector in (:scalar,:vector) test_2d(ranks,order,_field_type(Val{2}(),scalar_or_vector), num_amr_steps=5) test_3d(ranks,order,_field_type(Val{3}(),scalar_or_vector), num_amr_steps=4)