Skip to content

Commit

Permalink
chore: format code (#371)
Browse files Browse the repository at this point in the history
Co-authored-by: mofeing <[email protected]>
  • Loading branch information
github-actions[bot] and mofeing authored Dec 13, 2024
1 parent 311498b commit b56e661
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 5 deletions.
6 changes: 3 additions & 3 deletions src/TracedRArray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -736,18 +736,18 @@ function Base.repeat(x::AnyTracedRArray{T,N}, counts::Vararg{Int,M}) where {T,N,

# (d1, d2, ..., dP) -> (d1, 1, d2, 1, ..., dP, 1)
interleaved_size = ones(Int, 2P)
interleaved_size[1:2:2N] .= size(x)
interleaved_size[1:2:(2N)] .= size(x)

x_interleaved = reshape(x, interleaved_size...)

# (d1, 1, d2, 1, ..., dP, 1) -> (d1, r1, d2, r2, ..., dP, rP)
broadcast_target_size = interleaved_size
broadcast_target_size[2:2:2M] .= counts
broadcast_target_size[2:2:(2M)] .= counts

x_broadcasted = broadcast_to_size(x_interleaved, broadcast_target_size)

# (d1, r1, d2, r2, ..., dP, rP) -> (d1*r1, d2*r2, ..., dP*rP)
final_size = vec(prod(reshape(broadcast_target_size, 2, :), dims=1))
final_size = vec(prod(reshape(broadcast_target_size, 2, :); dims=1))

x_final = reshape(x_broadcasted, final_size...)

Expand Down
4 changes: 2 additions & 2 deletions test/basic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -366,8 +366,8 @@ end

@testset "repeat" begin
@testset for (size, counts) in Iterators.product(
[(2,), (2,3), (2,3,4), (2,3,4,5)],
[(), (1,), (2,), (2,1), (1,2), (2,2), (2,2,2), (1,1,1,1,1)]
[(2,), (2, 3), (2, 3, 4), (2, 3, 4, 5)],
[(), (1,), (2,), (2, 1), (1, 2), (2, 2), (2, 2, 2), (1, 1, 1, 1, 1)],
)
x = rand(size...)
@test (@jit repeat(Reactant.to_rarray(x), counts...)) == repeat(x, counts...)
Expand Down

0 comments on commit b56e661

Please sign in to comment.