Skip to content

Commit

Permalink
reorg imports
Browse files Browse the repository at this point in the history
  • Loading branch information
deven367 committed Jul 19, 2022
1 parent 360f199 commit 9c8c2ac
Show file tree
Hide file tree
Showing 2 changed files with 85 additions and 10 deletions.
42 changes: 37 additions & 5 deletions clean_plot/plot/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,9 @@
__all__ = ['Plot']

# Cell
from fastcore.basics import store_attr
from fastcore.basics import store_attr, patch_to, patch
from fastcore.xtras import globtastic
from fastcore.meta import delegates
from fastcore.basics import patch_to, patch
from pathlib import Path
import os
import numpy as np
Expand Down Expand Up @@ -53,16 +52,48 @@ def create_ssms(self):
del norm_ssm


def get_standardized(self):
def get_standardized(self, start, end):
pass


def get_corr_plots(self):
pass

def get_sectional_ssms(self):
pass
def get_sectional_ssms(self, start, end):
import gc
if start == 0 and end == -1:
pass
else:
assert start < end, 'Incorrect bounds'
new_path = self.path/f'sections_{start} {end}'
new_path.mkdir(exist_ok=True)

if start == 0:
labels = np.linspace(start + 1, end, y, dtype=int)
else:
labels = np.linspace(start, end, y, dtype=int)

ticks = np.linspace(1, end - start, y, dtype=int)

for method, norm_ssm in self.norm.items():
title = f'{self.book_name} {method}'
sns.heatmap(norm_ssm[start:end, start:end], cmap='hot',
vmin=0, vmax=1, square=True,
xticklabels=False)
length = norm_ssm.shape[0]



plt.yticks(ticks, ticks, rotation = 0)
plt.ylabel('sentence number')
plt.savefig(new_path/f'{title}.png', dpi = 300, bbox_inches='tight')
print(f'Done plotting {title}.png')
plt.clf()
del norm_ssm
_ = gc.collect()

def __repr__(self):
# remember __str__ calls the __repr_ internally
dir_path = os.path.dirname(os.path.realpath(self.path))
return f'This object contains the path to `{dir_path}`'

Expand All @@ -71,6 +102,7 @@ def __repr__(self):
def get_normalized(self:Plot):
"Returns the normalized ssms"
files = self.view_all_files(file_glob='*.npy')

for f in files:
f = Path(f)
fname = f.stem.split('_cleaned_')
Expand Down
53 changes: 48 additions & 5 deletions nbs/05_plot.utils.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,17 @@
"> This module tries to include most of the plotting functionality available in the package"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "8a11a137",
"metadata": {},
"outputs": [],
"source": [
"%load_ext autoreload\n",
"%autoreload 2"
]
},
{
"cell_type": "code",
"execution_count": null,
Expand All @@ -27,10 +38,9 @@
"outputs": [],
"source": [
"#|export\n",
"from fastcore.basics import store_attr\n",
"from fastcore.basics import store_attr, patch_to, patch\n",
"from fastcore.xtras import globtastic\n",
"from fastcore.meta import delegates\n",
"from fastcore.basics import patch_to, patch\n",
"from pathlib import Path\n",
"import os \n",
"import numpy as np\n",
Expand Down Expand Up @@ -101,16 +111,48 @@
" del norm_ssm\n",
" \n",
" \n",
" def get_standardized(self):\n",
" def get_standardized(self, start, end):\n",
" pass\n",
" \n",
" \n",
" def get_corr_plots(self):\n",
" pass\n",
" \n",
" def get_sectional_ssms(self):\n",
" pass\n",
" def get_sectional_ssms(self, start, end):\n",
" import gc\n",
" if start == 0 and end == -1:\n",
" pass\n",
" else:\n",
" assert start < end, 'Incorrect bounds'\n",
" new_path = self.path/f'sections_{start} {end}'\n",
" new_path.mkdir(exist_ok=True)\n",
" \n",
" if start == 0:\n",
" labels = np.linspace(start + 1, end, y, dtype=int)\n",
" else:\n",
" labels = np.linspace(start, end, y, dtype=int)\n",
" \n",
" ticks = np.linspace(1, end - start, y, dtype=int)\n",
" \n",
" for method, norm_ssm in self.norm.items():\n",
" title = f'{self.book_name} {method}'\n",
" sns.heatmap(norm_ssm[start:end, start:end], cmap='hot', \n",
" vmin=0, vmax=1, square=True, \n",
" xticklabels=False)\n",
" length = norm_ssm.shape[0]\n",
" \n",
" \n",
" \n",
" plt.yticks(ticks, ticks, rotation = 0)\n",
" plt.ylabel('sentence number')\n",
" plt.savefig(new_path/f'{title}.png', dpi = 300, bbox_inches='tight')\n",
" print(f'Done plotting {title}.png')\n",
" plt.clf()\n",
" del norm_ssm\n",
" _ = gc.collect()\n",
" \n",
" def __repr__(self):\n",
" # remember __str__ calls the __repr_ internally\n",
" dir_path = os.path.dirname(os.path.realpath(self.path))\n",
" return f'This object contains the path to `{dir_path}`'"
]
Expand All @@ -127,6 +169,7 @@
"def get_normalized(self:Plot):\n",
" \"Returns the normalized ssms\"\n",
" files = self.view_all_files(file_glob='*.npy')\n",
" \n",
" for f in files:\n",
" f = Path(f)\n",
" fname = f.stem.split('_cleaned_')\n",
Expand Down

0 comments on commit 9c8c2ac

Please sign in to comment.