diff --git a/scripts/plot2d.py b/scripts/plot2d.py index 3773c3d66..0dae94a00 100644 --- a/scripts/plot2d.py +++ b/scripts/plot2d.py @@ -20,6 +20,18 @@ Specify a file that contains all the kernel files with results. ''') +parser.add_argument('-o', '--outfile', action='store', + help= +''' +Specify a the path in which you want to save the plot. +''') + +parser.add_argument('-s', '--show', action='store', + help= +''' +Show the plot. +''') + args = parser.parse_args() # Roofline plot @@ -34,9 +46,13 @@ def heatmap(db): axisfont = {'fontname' : 'Garamond'}; sns.heatmap(db, cmap="Greens", annot=True) plt.title('Relative kernel performance on maximum achievable', **titlefont) - plt.xlabel('Vector Length [elements]', **axisfont) + plt.xlabel('Vector Length [Byte]', **axisfont) plt.ylabel('Kernel', **axisfont) - plt.show() + if args.show: + plt.show() + if args.outfile: + plt.savefig(args.outfile) + # Append a new entry to the main database def append_entry(lst, template): @@ -63,6 +79,10 @@ def update_db(fpath, db, template): 'sb_full' : int(elm[9]), } +def elm_2_byte_db(db): + for i, elm in enumerate(db): + db[i]['vsize'] = elm['vsize'] * elm['sew'] + def kernel_list_gen(db): return list(set([d['kernel'] for d in db])) @@ -94,6 +114,9 @@ def main(): # Update the database with the information from the input file update_db(args.infile, db, template) + # Size in Byte + elm_2_byte_db(db) + # Build list of available kernels and common vsizes kernel_list = kernel_list_gen(db) vsize_list = vsize_list_gen(db, kernel_list)