-
Notifications
You must be signed in to change notification settings - Fork 0
/
plot_vietnam_calibration.py
176 lines (145 loc) · 5.84 KB
/
plot_vietnam_calibration.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
import datetime as dt
import os

import covasim as cv
import matplotlib.patches as patches
import numpy as np
import pandas as pd
import pylab as pl
import sciris as sc
from matplotlib import ticker
# Input/output locations, plus the end of the calibration window used for plotting
figsfolder = 'figs234'                     # folder the final figure is written to
simsfilepath = 'results/vietnam_sim.obj'   # saved simulations object loaded below
calibration_end = '2020-10-15'             # last date of the calibration window

# Start the run timer, then load the saved simulations from disk
T = sc.tic()
simsfile = sc.loadobj(simsfilepath)
# Define plotting functions
#%% Helper functions
def format_ax(ax, sim, key=None):
    """
    Apply the shared x-axis formatting to one panel.

    X values are days since ``sim['start_day']``; ticks are labelled with the
    abbreviated month name of that date. The x-range is fixed to
    [0, day of ``calibration_end``] and the frame is tidied with ``sc.boxoff``.
    All changes are applied to ``ax`` itself, not to whichever axes happen to
    be current.

    Args:
        ax: matplotlib Axes to format.
        sim: simulation providing ``sim['start_day']`` and ``sim.day()``.
        key: unused; kept for backwards compatibility.
    """
    @ticker.FuncFormatter
    def date_formatter(x, pos):
        return (sim['start_day'] + dt.timedelta(days=x)).strftime('%b')

    ax.xaxis.set_major_formatter(date_formatter)
    ax.set_xlim([0, sim.day(calibration_end)])
    sc.boxoff(ax)
def _plot_color(which):
    """Return the plot color for a result type (e.g. 'diagnoses'), with local overrides."""
    try:
        color = cv.get_colors()[which]
    except KeyError:
        color = [0.5, 0.5, 0.5]  # neutral grey for result types covasim has no color for
    overrides = {
        'diagnoses': [0.03137255, 0.37401, 0.63813918, 1.],
        # NOTE(review): '' never matches a real key ('new_deaths' -> 'deaths'), so this
        # red override is currently unreachable; confirm which result type it was meant for.
        '': [0.82400815, 0., 0., 1.],
    }
    return overrides.get(which, color)


def _print_stats(key, which, sim, best, low, high, end, pct='95'):
    """Print central estimate and interval for the keys that have summary stats."""
    def summary(i):
        return f'{best[i]} ({pct}%: {low[i]}-{high[i]})'

    if key == 'cum_infections':
        print(f'Estimated {which} on July 25: {summary(sim.day("2020-07-25"))}')
        print(f'Estimated {which} overall: {summary(end)}')
    elif key == 'n_infectious':
        peakday = int(np.argmax(best))  # first index of the maximum
        print(f'Estimated peak {which} on {sim.date(peakday)}: {summary(peakday)}')
        print(f'Estimated {which} on last day: {summary(end)}')
    elif key == 'cum_diagnoses':
        print(f'Estimated {which} overall: {summary(end)}')


def plotter(key, sims, ax, ys=None, calib=False, label='', ylabel='', low_q=0.025, high_q=0.975, flabel=True, startday=None, subsample=2, chooseseed=None):
    """
    Plot the central estimate and a quantile band of one result across sims.

    Draws on ``ax`` the following:
    - the observed data as diamonds, if ``key`` is a data column;
    - a shaded band between the ``low_q`` and ``high_q`` quantiles across sims;
    - a line for the central estimate: the median across sims, or
      ``sims[chooseseed]`` if given.

    Curves run from ``startday`` (or the first day) up to, but excluding,
    ``calibration_end``. Some keys also print summary statistics.

    Args:
        key (str): result key such as 'new_diagnoses'; the text after the first
            underscore picks the color and is used in printed stats.
        sims (list): simulations; ``sims[0]`` supplies dates and data.
        ax: matplotlib Axes to draw on.
        ys (list): optional per-sim arrays; read from ``sims`` when None.
        calib (bool): unused; kept for backwards compatibility.
        label (str): legend label for the central-estimate line.
        ylabel (str): y-axis label.
        low_q, high_q (float): quantiles bounding the shaded band.
        flabel (bool): whether the shaded band gets a legend label.
        startday: optional first date to draw (anything ``sim.day`` accepts).
        subsample (int): plot every ``subsample``-th data point.
        chooseseed (int): index of a sim to plot instead of the median.
    """
    which = key.split('_')[1]
    color = _plot_color(which)

    if ys is None:
        ys = [s.results[key].values for s in sims]
    yarr = np.array(ys)
    if chooseseed is not None:
        best = sims[chooseseed].results[key].values
    else:
        best = np.median(yarr, axis=0)
    low = np.quantile(yarr, q=low_q, axis=0)
    high = np.quantile(yarr, q=high_q, axis=0)

    sim = sims[0]  # reference sim for dates and data
    tvec = np.arange(len(best))

    if key in sim.data:
        data_t = np.array((sim.data.index - sim['start_day']) / np.timedelta64(1, 'D'))
        inds = np.arange(0, len(data_t), subsample)
        # inds are positions, not index labels (the data index holds dates)
        ax.plot(data_t[inds], sim.data[key].iloc[inds], 'd', c=color, markersize=15, alpha=0.75, label='Data')

    # Slice with the converted day index; the raw startday may be a date string
    start = sim.day(startday) if startday is not None else None
    end = sim.day(calibration_end)
    pct = f'{(high_q - low_q) * 100:.0f}'  # '95' for the default quantiles
    fill_label = f'{pct}% projected interval' if flabel else None
    ax.fill_between(tvec[start:end], low[start:end], high[start:end], facecolor=color, alpha=0.2, label=fill_label)
    ax.plot(tvec[start:end], best[start:end], c=color, label=label, lw=4, alpha=1.0)

    _print_stats(key, which, sim, best, low, high, end, pct)

    sc.setylim(ax=ax)
    ax.set_ylabel(ylabel)
    datemarks = np.array([sim.day(d) for d in ('2020-07-01', '2020-08-01', '2020-09-01', '2020-10-01')], dtype=float)
    ax.set_xticks(datemarks)
def plot_intervs(sim, labels=True):
    """
    Mark the three intervention dates on the current axes.

    Draws a dashed black vertical line at each of 2020-07-25, 2020-09-05 and
    2020-09-14. When ``labels`` is true, it then adds an italic caption for
    each line at 85% of the current upper y-limit. Each caption is shifted
    horizontally by a fixed number of days from its line.
    """
    line_color = [0, 0, 0]
    # (date, horizontal text offset in days, caption)
    events = [
        ('2020-07-25', -20, 'Da Nang\noutbreak'),
        ('2020-09-05', -17, 'Work\nreopens'),
        ('2020-09-14', 2, 'School\nreopens'),
    ]
    event_days = [sim.day(date) for date, _, _ in events]

    for day in event_days:
        pl.axvline(day, c=line_color, linestyle='--', alpha=0.4, lw=3)

    if not labels:
        return

    label_y = pl.ylim()[1] * 0.85
    for day, (_, offset, caption) in zip(event_days, events):
        pl.text(day + offset, label_y, caption, color=line_color, alpha=0.9, style='italic')
# Fonts and sizes
font_size = 36
font_family = 'Libertinus Sans'
pl.rcParams['font.size'] = font_size
pl.rcParams['font.family'] = font_family
pl.figure(figsize=(24,16))

# Extract a sim to refer to
sims = simsfile.sims
sim = sims[0]

# Panel layout in figure-fraction units: gaps below / between / above the two
# rows, and left of / between / right of the two columns
ygapb = 0.05
ygapm = 0.05
ygapt = 0.01
xgapl = 0.065
xgapm = 0.05
xgapr = 0.02
remainingy = 1-(ygapb+ygapm+ygapt)
dy = remainingy/2  # height of each row
dx1 = 0.5  # width of the left column
dx2 = 1-dx1-(xgapl+xgapm+xgapr)  # right column takes the remaining width
ax = {}

# a: daily diagnoses
ax[0] = pl.axes([xgapl, ygapb+ygapm+dy, dx1, dy])
format_ax(ax[0], sim)
plotter('new_diagnoses', sims, ax[0], calib=True, label='Model', ylabel='Daily diagnoses')
plot_intervs(sim)

# b. cumulative diagnoses
ax[1] = pl.axes([xgapl+xgapm+dx1, ygapb+ygapm+dy, dx2, dy])
format_ax(ax[1], sim)
plotter('cum_diagnoses', sims, ax[1], calib=True, label='Diagnoses\n(modeled)', ylabel='Cumulative diagnoses', flabel=False)
pl.legend(loc='upper left', frameon=False)

# c. cumulative and active infections
ax[2] = pl.axes([xgapl, ygapb, dx1, dy])
format_ax(ax[2], sim)
plotter('cum_infections', sims, ax[2], calib=True, label='Cumulative infections\n(modeled)', ylabel='', flabel=False)
plotter('n_infectious', sims, ax[2], calib=True, label='Active infections\n(modeled)', ylabel='Estimated infections', flabel=False)
pl.legend(loc='upper left', frameon=False)

# d. cumulative deaths
ax[3] = pl.axes([xgapl+xgapm+dx1, ygapb, dx2, dy])
format_ax(ax[3], sim)
plotter('cum_deaths', sims, ax[3], calib=True, label='Deaths\n(modeled)', ylabel='Cumulative deaths', flabel=False)
pl.legend(loc='upper left', frameon=False)

# Make sure the output folder exists before saving into it
os.makedirs(figsfolder, exist_ok=True)
cv.savefig(f'{figsfolder}/fig2_calibration.pdf')
sc.toc(T)