-
Notifications
You must be signed in to change notification settings - Fork 0
/
example.py
49 lines (34 loc) · 1.32 KB
/
example.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
import cmdstanpy
import matplotlib.pyplot as plt
import seaborn
import pandas as pd
from retrospectr.importance_weights import calculate_log_weights, extract_samples
from retrospectr.resampling import resample
# Example: compare a refit on new data against importance-resampled draws.
#   1. Fit the Bernoulli Stan model to an original dataset and to a new
#      dataset (the original observations followed by ten more).
#   2. Compute log importance weights for the original draws given the
#      original and new data, and resample the original draws with them.
#   3. Overlay KDEs of the original, new, and resampled draws.
#
# The model path is relative to the current working directory, so run this
# script from the repository root.


def _samples_frame(samples, label):
    """Return a long-form frame of *samples* with *label* in the "model" column.

    ``samples.reshape(len(samples))`` flattens to one value per draw and
    raises if a draw holds more than one value.
    NOTE(review): assumes extract_samples returns a NumPy array holding a
    single parameter per draw (e.g. shape (draws, 1)); confirm.
    """
    return pd.DataFrame({
        "theta": samples.reshape(len(samples)),
        "model": label,
    })


model_file = "test/test_models/bernoulli/bernoulli.stan"
stan_model = cmdstanpy.CmdStanModel(stan_file=model_file)

original_y = [0, 1, 0, 0, 0, 0, 0, 0, 0, 1]
original_data = {
    "N": len(original_y),  # derived so N can never disagree with len(y)
    "y": original_y,
}
original_fit = stan_model.sample(data=original_data, chains=1)

# The new dataset is the original observations followed by ten more.
new_y = original_y + [1, 1, 0, 1, 1, 1, 0, 1, 0, 1]
new_data = {
    "N": len(new_y),
    "y": new_y,
}
new_fit = stan_model.sample(data=new_data, chains=1)

original_samples = extract_samples(original_fit)
new_samples = extract_samples(new_fit)

# Log importance weights for the original draws given the original and new
# data; the original draws are then resampled according to those weights.
log_weights = calculate_log_weights(model_file, original_samples, original_data, new_data)
resampled_original_samples = resample(original_samples, log_weights)

df = pd.concat(
    (
        _samples_frame(original_samples, "Original"),
        _samples_frame(new_samples, "New"),
        _samples_frame(resampled_original_samples, "Resampled"),
    ),
    # One unique 0..n-1 index instead of three overlapping ranges with
    # duplicate labels.
    ignore_index=True,
)
# common_norm=False normalises each KDE on its own, so the curves remain
# comparable even if the groups contain different numbers of draws.
seaborn.displot(df, x="theta", hue="model", kind="kde", common_norm=False)
plt.show()