Skip to content

Commit

Permalink
Simplify rule according to feedback in #531
Browse files Browse the repository at this point in the history
  • Loading branch information
simsurace authored Feb 7, 2024
1 parent 49049b1 commit 23c531e
Showing 1 changed file with 14 additions and 15 deletions.
29 changes: 14 additions & 15 deletions src/chainrules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -111,10 +111,9 @@ end

function ChainRulesCore.rrule(s::Sinus, x::AbstractVector, y::AbstractVector)
d = x - y
sind = sinpi.(d)
abs2_sind_r = abs2.(sind) ./ s.r .^ 2
abs2_sind_r = (sinpi.(d) ./ s.r) .^ 2

Check warning on line 114 in src/chainrules.jl

View check run for this annotation

Codecov / codecov/patch

src/chainrules.jl#L114

Added line #L114 was not covered by tests
val = sum(abs2_sind_r)
gradx = twoπ .* cospi.(d) .* sind ./ s.r .^ 2
gradx = π .* cospi.(2 .* d) ./ s.r .^ 2

Check warning on line 116 in src/chainrules.jl

View check run for this annotation

Codecov / codecov/patch

src/chainrules.jl#L116

Added line #L116 was not covered by tests
function evaluate_pullback::Any)
= -2Δ .* abs2_sind_r ./ s.r
= ChainRulesCore.Tangent{typeof(s)}(; r=r̄)
Expand All @@ -136,7 +135,7 @@ function ChainRulesCore.rrule(
for j in 1:n, i in 1:n
xi = view(x, i, :)
xj = view(x, j, :)
ds = twoπ .* Δ[i, j] .* sinpi.(xi .- xj) .* cospi.(xi .- xj) ./ d.r .^ 2
ds = π .* Δ[i, j] .* cospi.(2 .* (xi .- xj)) ./ d.r .^ 2

Check warning on line 138 in src/chainrules.jl

View check run for this annotation

Codecov / codecov/patch

src/chainrules.jl#L138

Added line #L138 was not covered by tests
.-= 2 .* Δ[i, j] .* sinpi.(xi .- xj) .^ 2 ./ d.r .^ 3
x̄[i, :] += ds
x̄[j, :] -= ds
Expand All @@ -147,8 +146,8 @@ function ChainRulesCore.rrule(
xj = view(x, :, j)
ds = twoπ .* Δ[i, j] .* sinpi.(xi .- xj) .* cospi.(xi .- xj) ./ d.r .^ 2
.-= 2 .* Δ[i, j] .* sinpi.(xi .- xj) .^ 2 ./ d.r .^ 3
x̄[:, i] += ds
x̄[:, j] -= ds
x̄[:, i] .+= ds
x̄[:, j] .-= ds

Check warning on line 150 in src/chainrules.jl

View check run for this annotation

Codecov / codecov/patch

src/chainrules.jl#L149-L150

Added lines #L149 - L150 were not covered by tests
end
end
= ChainRulesCore.Tangent{typeof(d)}(; r=r̄)
Expand All @@ -173,19 +172,19 @@ function ChainRulesCore.rrule(
for j in 1:m, i in 1:n
xi = view(x, i, :)
yj = view(y, j, :)
ds = twoπ .* Δ[i, j] .* sinpi.(xi .- yj) .* cospi.(xi .- yj) ./ d.r .^ 2
ds = π .* Δ[i, j] .* cospi.(2 .* (xi .- yj)) ./ d.r .^ 2

Check warning on line 175 in src/chainrules.jl

View check run for this annotation

Codecov / codecov/patch

src/chainrules.jl#L175

Added line #L175 was not covered by tests
.-= 2 .* Δ[i, j] .* sinpi.(xi .- yj) .^ 2 ./ d.r .^ 3
x̄[i, :] += ds
ȳ[j, :] -= ds
x̄[i, :] .+= ds
ȳ[j, :] .-= ds

Check warning on line 178 in src/chainrules.jl

View check run for this annotation

Codecov / codecov/patch

src/chainrules.jl#L177-L178

Added lines #L177 - L178 were not covered by tests
end
elseif dims == 2
for j in 1:m, i in 1:n
xi = view(x, :, i)
yj = view(y, :, j)
ds = twoπ .* Δ[i, j] .* sinpi.(xi .- yj) .* cospi.(xi .- yj) ./ d.r .^ 2
ds = π .* Δ[i, j] .* cospi.(2 .* (xi .- yj)) ./ d.r .^ 2

Check warning on line 184 in src/chainrules.jl

View check run for this annotation

Codecov / codecov/patch

src/chainrules.jl#L184

Added line #L184 was not covered by tests
.-= 2 .* Δ[i, j] .* sinpi.(xi .- yj) .^ 2 ./ d.r .^ 3
x̄[:, i] += ds
ȳ[:, j] -= ds
x̄[:, i] .+= ds
ȳ[:, j] .-= ds

Check warning on line 187 in src/chainrules.jl

View check run for this annotation

Codecov / codecov/patch

src/chainrules.jl#L186-L187

Added lines #L186 - L187 were not covered by tests
end
end
= ChainRulesCore.Tangent{typeof(d)}(; r=r̄)
Expand All @@ -208,10 +207,10 @@ function ChainRulesCore.rrule(
for i in 1:n
xi = view(x, :, i)
yi = view(y, :, i)
ds = twoπ .* Δ[i] .* sinpi.(xi .- yi) .* cospi.(xi .- yi) ./ d.r .^ 2
ds = π .* Δ[i] .* cospi.(2 .* (xi .- yi)) ./ d.r .^ 2

Check warning on line 210 in src/chainrules.jl

View check run for this annotation

Codecov / codecov/patch

src/chainrules.jl#L210

Added line #L210 was not covered by tests
.-= 2 .* Δ[i] .* sinpi.(xi .- yi) .^ 2 ./ d.r .^ 3
x̄[:, i] += ds
ȳ[:, i] -= ds
x̄[:, i] .+= ds
ȳ[:, i] .-= ds

Check warning on line 213 in src/chainrules.jl

View check run for this annotation

Codecov / codecov/patch

src/chainrules.jl#L212-L213

Added lines #L212 - L213 were not covered by tests
end
= ChainRulesCore.Tangent{typeof(d)}(; r=r̄)
return NoTangent(), d̄, @thunk(project_x(x̄)), @thunk(project_y(ȳ))
Expand Down

0 comments on commit 23c531e

Please sign in to comment.