Skip to content

Commit

Permalink
🔧 Use the vector indices in get_partition
Browse files Browse the repository at this point in the history
  • Loading branch information
ronisbr committed Jun 14, 2024
1 parent 0412514 commit 1cbe0ee
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 11 deletions.
14 changes: 8 additions & 6 deletions src/helpers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,10 @@ end
# == Threads ===============================================================================

"""
get_partition(cp::Integer, v::AbstractVector, np::Integer) -> Int, Int
get_partition(cp::Integer, inds::AbstractVector, np::Integer) -> Int, Int
Return the `cp`-th partition (start and end indices) of the vector `v` considering that we
are partitioning it into `np` parts.
Return the `cp`-th partition (start and end indices) of a vector with indices `inds`
considering that we are partitioning it into `np` parts.
This function is useful to splitting input information for spawning multiple tasks.
Expand All @@ -52,16 +52,18 @@ This function is useful to splitting input information for spawning multiple tas
- `Int`: Current partition start index.
- `Int`: Current partition last index.
"""
function get_partition(cp::Integer, v::AbstractVector, np::Integer)
num_elements = length(v)
function get_partition(cp::Integer, inds::AbstractVector, np::Integer)
num_elements = length(inds)

num_elements == 0 && return 0, 0

# Check inputs.
np = min(np, num_elements)
cp = min(cp, np)

len, rem = divrem(num_elements, np)

i₀ = firstindex(v) + (cp - 1) * len
i₀ = first(inds) + (cp - 1) * len
i₁ = i₀ + len - 1

i₀ += cp <= rem ? cp - 1 : rem
Expand Down
12 changes: 7 additions & 5 deletions test/helpers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,26 +5,28 @@
############################################################################################

@testset "Helpers" begin
i₀, i₁ = SatelliteToolboxBase.get_partition(1, 1:1:10, 3)
vt = 1:2:20
inds = eachindex(vt)
i₀, i₁ = SatelliteToolboxBase.get_partition(1, vt, 3)
@test i₀ == 1
@test i₁ == 4

i₀, i₁ = SatelliteToolboxBase.get_partition(2, 1:1:10, 3)
i₀, i₁ = SatelliteToolboxBase.get_partition(2, vt, 3)
@test i₀ == 5
@test i₁ == 7

i₀, i₁ = SatelliteToolboxBase.get_partition(3, 1:1:10, 3)
i₀, i₁ = SatelliteToolboxBase.get_partition(3, vt, 3)
@test i₀ == 8
@test i₁ == 10

for i in 1:10
i₀, i₁ = SatelliteToolboxBase.get_partition(i, 1:1:10, 100)
i₀, i₁ = SatelliteToolboxBase.get_partition(i, vt, 100)
@test i₀ == i
@test i₁ == i
end

for i in 11:100
i₀, i₁ = SatelliteToolboxBase.get_partition(i, 1:1:10, 100)
i₀, i₁ = SatelliteToolboxBase.get_partition(i, vt, 100)
@test i₀ == 10
@test i₁ == 10
end
Expand Down

0 comments on commit 1cbe0ee

Please sign in to comment.