Skip to content

Commit

Permalink
Merge branch 'main' of github.com:orlox/SideKicks.jl
Browse files Browse the repository at this point in the history
  • Loading branch information
Pablo Marchant committed Aug 14, 2024
2 parents 4aec5a3 + 9d6f633 commit f437e1e
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 31 deletions.
2 changes: 1 addition & 1 deletion src/KickMCMC.jl
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ function KickMCMC(; which_model, observations::Tuple{Observations, String}, prio
end

# for the general model, need to reweight by the true anomaly
if model_type==:general
if which_model==:general
results[:weights] .= results[:weights].*sqrt.(1 .- results[:e_i].^2).^3 ./ (1 .+ results[:e_i].*cos.(results[:ν_i])).^2
end

Expand Down
76 changes: 46 additions & 30 deletions src/TuringModels.jl
Original file line number Diff line number Diff line change
Expand Up @@ -83,11 +83,18 @@ function create_simplified_mcmc_model(;
vkick_dist = priors.vkick_dist
frac_dist = priors.frac_dist

valid_values = [:P_f, :e_f, :K1, :K2, :m1_f, :m2_f, :i_f]
for prop observations.props
if prop valid_values
throw(DomainError(observation.props, "Allowed observations are only [:P_f, :e_f, :K1, :K2, :m1_f, :m2_f, :i_f]"))
end
end

if !(likelihood == :Cauchy || likelihood == :Normal)
throw(DomainError(likelihood, "likelihood must be either :Cauchy or :Normal"))
end

@model function create_mcmc_model(obs_vals, obs_errs)
@model function create_mcmc_model(props, obs_vals, obs_errs)

# set priors
#Pre-explosion masses and orbital period
Expand All @@ -97,7 +104,7 @@ function create_simplified_mcmc_model(;
m2_i = 10^(logm2)*m_sun
logP ~ logP_dist
P_i = 10^(logP)*day
a_i = kepler_a_from_P(m1=m1, m2=m2, P=P_i)
a_i = kepler_a_from_P(m1=m1_i, m2=m2_i, P=P_i)
cosi ~ Uniform(0,1)
i_f = acos(cosi)

Expand All @@ -118,34 +125,43 @@ function create_simplified_mcmc_model(;

#m1 is assumed to remain constant
m1_f = m1_i
a_f, e_f = post_supernova_circular_orbit_a(m1_i=m1_i, m2_i=m2_i, a=a_i, m2_f=m2_f, vkick=vkick, θ=θ, ϕ=ϕ)
a_f, e_f = post_supernova_circular_orbit_a(m1_i=m1_i, m2_i=m2_i, a_i=a_i, m2_f=m2_f, vkick=vkick, θ=θ, ϕ=ϕ)
P_f = kepler_P_from_a(m1=m1_f, m2=m2_f, a=a_f)
K1 = RV_semiamplitude_K1(m1=m1_f, m2=m2_f, P=P_f, e=e_f, i=i_f)
K2 = RV_semiamplitude_K1(m1=m2_f, m2=m1_f, P=P_f, e=e_f, i=i_f)

likelihood == :Cauchy ?
likelihood_dist = Cauchy :
likelihood_dist = Normal

for ii in eachindex(obs.props)
obs_symbol = obs.props[ii]
use_cauchy = likelihood == :Cauchy
for ii in eachindex(props)
obs_symbol = props[ii]
if obs_symbol == :P_f
param = P_f
use_cauchy ?
obs_vals[ii] ~ Cauchy(P_f, obs_errs[ii]) :
obs_vals[ii] ~ Normal(P_f, obs_errs[ii])
elseif obs_symbol == :e_f
param = e_f
use_cauchy ?
obs_vals[ii] ~ Cauchy(e_f, obs_errs[ii]) :
obs_vals[ii] ~ Normal(e_f, obs_errs[ii])
elseif obs_symbol == :K1
param = K1
use_cauchy ?
obs_vals[ii] ~ Cauchy(K1, obs_errs[ii]) :
obs_vals[ii] ~ Normal(K1, obs_errs[ii])
elseif obs_symbol == :K2
param = K2
use_cauchy ?
obs_vals[ii] ~ Cauchy(K2, obs_errs[ii]) :
obs_vals[ii] ~ Normal(K2, obs_errs[ii])
elseif obs_symbol == :m1_f
param = m1_f
use_cauchy ?
obs_vals[ii] ~ Cauchy(m1_f, obs_errs[ii]) :
obs_vals[ii] ~ Normal(m1_f, obs_errs[ii])
elseif obs_symbol == :m2_f
param = m2_f
else
continue
use_cauchy ?
obs_vals[ii] ~ Cauchy(m2_f, obs_errs[ii]) :
obs_vals[ii] ~ Normal(m2_f, obs_errs[ii])
elseif obs_symbol == :i_f
use_cauchy ?
obs_vals[ii] ~ Cauchy(i_f, obs_errs[ii]) :
obs_vals[ii] ~ Normal(i_f, obs_errs[ii])
end

obs_vals[ii] ~ likelihood_dist(param, obs_errs[ii])
end

# other params
Expand All @@ -154,9 +170,9 @@ function create_simplified_mcmc_model(;
end
return_props = [:m1_i, :m2_i, :P_i, :a_i, :i_f, :vkick, :m2_f, :a_f, :P_f, :e_f, :K1, :K2, :frac, :dm2]

obs_vals_cgs = obs.vals .* obs.units
obs_errs_cgs = obs.errs .* obs.units
return [create_mcmc_model(obs_vals_cgs, obs_errs_cgs), return_props]
obs_vals_cgs = observations.vals .* observations.units
obs_errs_cgs = observations.errs .* observations.units
return [create_mcmc_model(observations.props, obs_vals_cgs, obs_errs_cgs), return_props]
end


Expand Down Expand Up @@ -193,12 +209,12 @@ function create_general_mcmc_model(;
venv_E_100kms_dist = priors.venv_E_100kms_dist
venv_r_100kms_dist = priors.venv_r_100kms_dist

#valid_values = [:P_f, :e_f, :K1, :K2, :m1_f, :m2_f, :Ω_i, :ω_i, :i_f, :v_N, :v_E, :v_r]
#for prop ∈ observations.props
# if prop ∉ valid_values
# throw(DomainError(observation.props, "Allowed observations are only [:P_f, :e_f, :K1, :K2, :m1_f, :m2_f, :Ω_i, :ω_i, :i_f, :v_N, :v_E, :v_r]
# end
#end
valid_values = [:P_f, :e_f, :K1, :K2, :m1_f, :m2_f, :Ω_f, :ω_f, :i_f, :v_N, :v_E, :v_r]
for prop observations.props
if prop valid_values
throw(DomainError(observation.props, "Allowed observations are only [:P_f, :e_f, :K1, :K2, :m1_f, :m2_f, :Ω_f, :ω_f, :i_f, :v_N, :v_E, :v_r]"))
end
end
if !(likelihood == :Cauchy || likelihood == :Normal)
throw(DomainError(likelihood, "likelihood must be either :Cauchy or :Normal"))
end
Expand Down Expand Up @@ -336,9 +352,9 @@ function create_general_mcmc_model(;
end

dm2 = m2_i - m2_f
return ( m1_i, m2_i, P_i, e_i, a_i, vkick, frac, dm2, i_f, ω_f, Ω_f, m1_f, m2_f, a_f, P_f, e_f, K1, K2, v_N, v_E, v_r, vsys)
return ( m1_i, m2_i, P_i, e_i, a_i, ν_i, vkick, frac, dm2, i_f, ω_f, Ω_f, m1_f, m2_f, a_f, P_f, e_f, K1, K2, v_N, v_E, v_r, vsys)
end
return_props = [:m1_i, :m2_i, :P_i, :e_i, :a_i, :vkick, :frac, :dm2, :i_f, :ω_f, :Ω_f, :m1_f,:m2_f, :a_f, :P_f, :e_f, :K1, :K2, :v_N, :v_E, :v_r, :vsys]
return_props = [:m1_i, :m2_i, :P_i, :e_i, :a_i, :ν_i, :vkick, :frac, :dm2, :i_f, :ω_f, :Ω_f, :m1_f,:m2_f, :a_f, :P_f, :e_f, :K1, :K2, :v_N, :v_E, :v_r, :vsys]

# Need to combine some of the observations to compare against the predicted output
obs_vals_cgs = observations.vals .* observations.units
Expand Down

0 comments on commit f437e1e

Please sign in to comment.