-
Notifications
You must be signed in to change notification settings - Fork 0
/
plotting_styles.py
103 lines (68 loc) · 2.56 KB
/
plotting_styles.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
import seaborn as sb
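'''
Helpers that configure seaborn styles and figure sizes for one- and
two-column figures, plus utilities for aligning the y-axes of two axes.
'''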
def default_figure(font_scale=1.0):
    '''
    Apply the base seaborn settings shared by the figure presets.

    Sets the "ticks" style with the Times New Roman font and the
    ``text.usetex`` rc entry enabled, switches to the "poster" context
    scaled by ``font_scale``, and selects the "colorblind" palette.
    '''
    latex_rc = {'text.usetex': True}
    sb.set(font='Times New Roman', style='ticks', rc=latex_rc)

    sb.set_context("poster", font_scale=font_scale)

    sb.set_palette("colorblind")
def twocolumn_figure(fig_ratio=None, font_scale=1.25):
    '''
    Configure seaborn for a figure 8.4 inches wide in the "paper" context.

    fig_ratio is height / width: the figure height is set to
    ``width * fig_ratio``. When fig_ratio is None the fallback ratio
    4.4 / 6.4 is used.
    '''
    default_figure()

    full_width = 8.4

    if fig_ratio is None:
        # Fallback ratio when the caller does not request one.
        fig_height = (4.4 / 6.4) * full_width
    else:
        fig_height = full_width * fig_ratio

    size_rc = {"figure.figsize": (full_width, fig_height)}
    sb.set_context("paper", font_scale=font_scale, rc=size_rc)
def twocolumn_twopanel_figure(fig_ratio=None, **kwargs):
    '''
    Configure a two-column-wide figure intended to hold two panels.

    When fig_ratio is None, half of the 4.4 / 6.4 ratio is used. Any extra
    keyword arguments are forwarded to twocolumn_figure.
    '''
    chosen_ratio = (4.4 / 6.4) / 2 if fig_ratio is None else fig_ratio
    twocolumn_figure(fig_ratio=chosen_ratio, **kwargs)
def onecolumn_figure(fig_ratio=None, font_scale=1.2):
    '''
    Configure seaborn for a figure 4.2 inches wide in the "paper" context.

    fig_ratio multiplies the width to give the height, i.e. it is
    height / width. When fig_ratio is None the fallback ratio 4.4 / 6.4
    is used.
    '''
    default_figure()

    # About the width (in inches) of a column
    column_width = 4.2

    if fig_ratio is None:
        # Fallback ratio when the caller does not request one.
        fig_height = (4.4 / 6.4) * column_width
    else:
        fig_height = column_width * fig_ratio

    size_rc = {"figure.figsize": (column_width, fig_height)}
    sb.set_context("paper", font_scale=font_scale, rc=size_rc)
def onecolumn_Npanel_figure(N, font_scale=1.2):
    '''
    Configure a one-column figure whose height is N times the fallback
    single-panel height ((4.4 / 6.4) * width).
    '''
    column_width = 4.2
    stacked_height = (4.4 / 6.4) * N * column_width
    # onecolumn_figure expects the height expressed as a multiple of width.
    stacked_ratio = stacked_height / column_width
    onecolumn_figure(fig_ratio=stacked_ratio, font_scale=font_scale)
def onecolumn_twopanel_figure(font_scale=1.2):
    '''
    Configure a one-column figure sized for two panels (N=2).
    '''
    panel_count = 2
    onecolumn_Npanel_figure(N=panel_count, font_scale=font_scale)
def image_shape_ratio(shape):
    '''
    Return shape[0] / shape[1] as a float, for adjusting figure sizes.

    Only the first two entries of shape are used.
    '''
    numerator = shape[0]
    denominator = float(shape[1])
    return numerator / denominator
def align_yaxis(ax1, v1, ax2, v2):
    """Shift the y-limits of both axes so that v2 on ax2 lines up with v1 on ax1.

    Each axis is moved by half of the gap between the two transformed
    positions, in opposite directions.
    """
    # Positions of the two values after each axis' transData transform,
    # read before either axis has its limits changed.
    _x1, pos1 = ax1.transData.transform((0, v1))
    _x2, pos2 = ax2.transData.transform((0, v2))

    shift_for_ax2 = (pos1 - pos2) / 2
    shift_for_ax1 = (pos2 - pos1) / 2

    adjust_yaxis(ax2, shift_for_ax2, v2)
    adjust_yaxis(ax1, shift_for_ax1, v1)
def adjust_yaxis(ax, ydif, v):
    """Shift axis ax by ydif, maintaining point v at the same location.

    ydif is passed through the inverse of ax.transData to obtain a shift in
    data units; one y-limit is kept and the other is rescaled around v.
    """
    to_data = ax.transData.inverted()
    at_origin = to_data.transform((0, 0))
    at_offset = to_data.transform((0, ydif))
    _dx, dy = at_origin - at_offset

    lower, upper = ax.get_ylim()
    # Work with limits measured relative to the anchor value v.
    rel_lower = lower - v
    rel_upper = upper - v

    lower_dominates = -rel_lower > rel_upper
    tie_moving_up = -rel_lower == rel_upper and dy > 0

    if lower_dominates or tie_moving_up:
        # Keep the lower limit and rescale the upper one.
        new_lower = rel_lower
        new_upper = rel_lower * (rel_upper + dy) / (rel_lower + dy)
    else:
        # Keep the upper limit and rescale the lower one.
        new_upper = rel_upper
        new_lower = rel_upper * (rel_lower + dy) / (rel_upper + dy)

    ax.set_ylim(new_lower + v, new_upper + v)