Skip to content

Commit

Permalink
#8 remove stacked_chart_quad - replaced by horizontal_grouped_bar_cha…
Browse files Browse the repository at this point in the history
…rt()
  • Loading branch information
Zhangyixue1537 committed Jun 6, 2024
1 parent 4019257 commit 36d5778
Showing 1 changed file with 0 additions and 122 deletions.
122 changes: 0 additions & 122 deletions rick.py
Original file line number Diff line number Diff line change
Expand Up @@ -623,128 +623,6 @@ def multi_stacked_bar_chart(data, xlab, lab1, lab2, lab3, **kwargs):

return fig, ax

def stacked_chart_quad(data_in, xlab, lab1, lab2, lab3, lab4, **kwargs):
"""Creates a stacked bar chart comparing 4 sets of data
Parameters
-----------
data : dataframe
Data for the stacked bar chart. The dataframe must have 5 columns, the first representing the y ticks, the second representing the baseline, and the third representing the next series of data.
xlab : str
Label for the x axis.
lab1 : str
Label in the legend for the baseline
lab2 : str
Label in the legend fot the next data series
xmax : int, optional, default is the max s value
The max value of the y axis
xmin : int, optional, default is 0
The minimum value of the x axis
precision : int, optional, default is -1
Decimal places in the annotations
percent : boolean, optional, default is False
Whether the annotations should be formatted as percentages
xinc : int, optional
The increment of ticks on the x axis.
Returns
--------
fig
Matplotlib fig object
ax
Matplotlib ax object
"""

func()
data = data_in.copy(deep=True)

data.columns = ['name', 'values1', 'values2', 'values3', 'values4']

xmin = kwargs.get('xmin', 0)
xmax = kwargs.get('xmax', None)
precision = kwargs.get('precision', -1)
percent = kwargs.get('percent', False)

xmax_flag = True
if xmax == None:
xmax = int(max(data[['values1', 'values2', 'values3', 'values4']].max()))
xmax_flag = False

delta = (xmax - xmin)/4
i = 0
while True:
delta /= 10
i += 1
if delta < 10:
break
xinc = kwargs.get('xinc', int(round(delta+1)*pow(10,i)))

if xmax_flag == True:
upper = xmax
else:
upper = int(4*xinc+xmin)

ind = np.arange(len(data))
print(len(data))

fig, ax = plt.subplots()
fig.set_size_inches(6.1, len(data)*1.5)
ax.grid(color='k', linestyle='-', linewidth=0.25)

p1 = ax.barh(ind+0.6, data['values1'], 0.2, align='center', color = colour.green)
p2 = ax.barh(ind+0.4, data['values2'], 0.2, align='center', color = colour.blue)
p3 = ax.barh(ind+0.2, data['values3'], 0.2, align='center', color = colour.grey)
p4 = ax.barh(ind, data['values4'], 0.2, align='center', color=colour.purple)
ax.xaxis.set_major_formatter(mpl.ticker.StrMethodFormatter('{x:,.0f}'))

ax.xaxis.grid(True)
ax.yaxis.grid(False)
ax.set_yticks(ind+0.6/2)
ax.set_xlim(0,upper)
ax.set_yticklabels(data['name'])
ax.set_xlabel(xlab, horizontalalignment='left', x=0, labelpad=10, fontname = font.normal, fontsize=10, fontweight = 'bold')

ax.set_facecolor('xkcd:white')


if precision < 1:
data[['values1', 'values2', 'values3', 'values4']] = data[['values1', 'values2', 'values3', 'values4']].astype(int)

j = 0.0
for k in range(4,0,-1):

for i in data[f'values{k}']:
if i < 0.1*upper:
ax.annotate(str(format(round(i,precision), ',')), xy=(i+0.015*upper, j-0.05), ha = 'left', color = 'k', fontname = font.normal, fontsize=10)
else:
ax.annotate(str(format(round(i,precision), ',')), xy=(i-0.015*upper, j-0.05), ha = 'right', color = 'w', fontname = font.normal, fontsize=10)
j=j+1
j = j-len(data[f'values{k}']) + 0.2


ax.legend((p1[0], p2[0], p3[0], p4[0]), (lab1, lab2, lab3, lab4), loc=4, frameon=False, prop=font.leg_font)
plt.xticks(range(xmin,upper+int(0.1*xinc), xinc), fontname = font.normal, fontsize =10)
plt.yticks( fontname = font.normal, fontsize =10)

if percent == True:
j = 0.15
data_yoy = data
for k in range(3,0,-1):
data_yoy[f'percent{k}'] = (data['values4']-data[f'values{k}'])*100/data[f'values{k}']
if k == 1:
print(data_yoy)

for index, row in data_yoy.iterrows():
ax.annotate(('+' if row[f'percent{k}'] > 0 else '')+str(format(int(round(row[f'percent{k}'],0)), ','))+'%',
xy=(max(row[['values1', 'values2', 'values3', 'values4']]) + (0.12 if row['values4'] < 0.1*upper else 0.03)*upper, j), color = 'k', fontname = font.normal, fontsize=10)
j+=1
j = j-len(data_yoy) + 0.2


return fig, ax

def bar_chart(data_in, xlab, ylab, horizontal=False, **kwargs):
"""Creates a bar chart
Expand Down

0 comments on commit 36d5778

Please sign in to comment.