Skip to content

Commit

Permalink
plt_temp -> plt_viz, plt_temp2 -> plt_temp
Browse files Browse the repository at this point in the history
  • Loading branch information
arthurfeeney committed Nov 4, 2023
1 parent a95bef8 commit e9b81f3
Show file tree
Hide file tree
Showing 2 changed files with 319 additions and 65 deletions.
159 changes: 159 additions & 0 deletions scripts/viz_temp.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,159 @@
import argparse
import torch
import matplotlib.pyplot as plt
from matplotlib.colors import LinearSegmentedColormap, BoundaryNorm
import numpy as np
import os
from pathlib import Path
import subprocess
import scipy.fft as sfft
from dataclasses import dataclass

def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument('--path', required=True, type=str,
help='Path to directory with model and sim output.pt files')
return parser.parse_args()

@dataclass
class BoilingData:
temp: torch.Tensor

def load_vel_data(temp_path):
pred = BoilingData(
torch.load(f'{temp_path}/model_ouput.pt').numpy())
label = BoilingData(
torch.load(f'{temp_path}/sim_ouput.pt').numpy())
return pred, label

def main():
args = parse_args()

job_id = '25057303/'
pred, label = load_vel_data(f'test_im/temp/{job_id}')

plt_temp(pred.temp, label.temp, args.path, 'model')

subprocess.call(
f'ffmpeg -y -framerate 25 -pattern_type glob -i "{args.path}/*.png" output.mp4',
shell=True)

def temp_cmap():
temp_ranges = [0.0, 0.02, 0.04, 0.06, 0.08, 0.1, 0.134, 0.167,
0.2, 0.25, 0.3, 0.35, 0.4, 0.45, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0]
color_codes = ['#0000FF', '#0443FF', '#0E7AFF', '#16B4FF', '#1FF1FF', '#21FFD3',
'#22FF9B', '#22FF67', '#22FF15', '#29FF06', '#45FF07', '#6DFF08',
'#9EFF09', '#D4FF0A', '#FEF30A', '#FEB709', '#FD7D08', '#FC4908',
'#FC1407', '#FB0007']
colors = list(zip(temp_ranges, color_codes))
cmap = LinearSegmentedColormap.from_list('temperature_colormap', colors)
return cmap

def fft(x):
x_fft = sfft.fft2(x)
x_shift = np.abs(sfft.fftshift(x_fft))
return x_shift

def mag(velx, vely):
return np.sqrt(velx**2 + vely**2)

def plt_vel(pred, label, path, model_name):
plt.rc("font", family="serif", size=16, weight="bold")
plt.rc("axes", labelweight="bold")

label_mag = mag(label.velx, label.vely)
pred_mag = mag(pred.velx, pred.vely)
mag_vmax = abs(pred_mag[:50]).max()
print(label_mag.max(), pred_mag.max())

frames = min(pred.temp.shape[0], 100)
for i in range(frames):
i_str = str(i).zfill(3)
f, ax = plt.subplots(2, 2, layout='constrained')

#x_vmax, x_vmin = label.velx.max(), label.velx.min()
#y_vmax, y_vmin = label.vely.max(), label.vely.min()

cm_object = ax[0, 0].imshow(np.flipud(label.temp[i]), vmin=0, vmax=1, cmap=temp_cmap())
#ax[1, 0].imshow(np.flipud(label.velx[i]), vmin=x_vmin, vmax=x_vmax, cmap='jet')
#ax[2, 0].imshow(np.flipud(label.vely[i]), vmin=y_vmin, vmax=y_vmax, cmap='jet')
#ax[1, 0].imshow(np.flipud(label_mag[i]), vmin=0, vmax=mag_vmax, cmap='jet')

ax[0, 1].imshow(np.flipud(np.nan_to_num(pred.temp[i])), vmin=0, vmax=1, cmap=temp_cmap())
#ax[1, 1].imshow(np.flipud(pred.velx[i]), vmin=x_vmin, vmax=x_vmax, cmap='jet')
#ax[2, 1].imshow(np.flipud(pred.vely[i]), vmin=y_vmin, vmax=x_vmax, cmap='jet')
#ax[1, 1].imshow(np.flipud(pred_mag[i]), vmin=0, vmax=mag_vmax, cmap='jet')

ax[0, 0].axis('off')
ax[1, 0].axis('off')
ax[0, 1].axis('off')
ax[1, 1].axis('off')

#ax[0, 2].imshow(np.flipud(fft(label.temp[i])))
#ax[1, 2].imshow(np.flipud(fft(label.velx[i])))
#ax[2, 2].imshow(np.flipud(fft(label.vely[i])))
#ax[3, 2].imshow(np.flipud(fft(label_mag)))

#ax[0, 3].imshow(np.flipud(fft(pred.temp[i])))
#ax[1, 3].imshow(np.flipud(fft(pred.velx[i])))
#ax[2, 3].imshow(np.flipud(fft(pred.vely[i])))
#ax[3, 3].imshow(np.flipud(fft(pred_mag)))

im_path = Path(path)
im_path.mkdir(parents=True, exist_ok=True)
plt.savefig(f'{str(im_path)}/{i_str}.png',
dpi=200,
bbox_inches='tight',
transparent=True)
plt.close()


def plt_temp(temps, labels, path, model_name):
print(temps.min(), temps.max(),
labels.min(), labels.max())

plt.rc("font", family="serif", size=16, weight="bold")
plt.rc("axes", labelweight="bold")
for i in range(len(temps)):
i_str = str(i).zfill(3)

def plt_temp_arr(f, ax, arr, title):
cm_object = ax.imshow(arr, vmin=0, vmax=1, cmap=temp_cmap())
#ax.set_title(title)
ax.axis('off')
return cm_object

temp = temps[i]
label = labels[i]
f, axarr = plt.subplots(2, 3, layout="constrained")
cm_object = plt_temp_arr(f, axarr[0, 0], np.flipud(label), 'Ground Truth')
cm_object = plt_temp_arr(f, axarr[0, 1], np.flipud(temp), model_name)

err = np.abs(temp - label)
cm_object = plt_temp_arr(f, axarr[0, 2], np.flipud(err), 'Absolute Error')
f.tight_layout()
f.colorbar(cm_object,
ax=axarr.ravel().tolist(),
ticks=[0, 0.2, 0.6, 0.9],
fraction=0.04,
pad=0.02)
f.set_size_inches(w=6, h=3)

label_h = fft(label)
temp_h = fft(temp)
err_h = np.abs(label_h - temp_h)

axarr[1, 0].imshow(np.flipud(label_h))
axarr[1, 1].imshow(np.flipud(temp_h))
axarr[1, 2].imshow(np.flipud(err_h))

im_path = Path(path)
im_path.mkdir(parents=True, exist_ok=True)
plt.savefig(f'{str(im_path)}/{i_str}.png',
dpi=600,
bbox_inches='tight',
transparent=True)
plt.close()

if __name__ == '__main__':
main()
225 changes: 160 additions & 65 deletions scripts/viz_vel.py
Original file line number Diff line number Diff line change
@@ -1,70 +1,165 @@

import argparse
import torch
import matplotlib.pyplot as plt
from matplotlib.colors import LinearSegmentedColormap, BoundaryNorm
from matplotlib.gridspec import GridSpec
import numpy as np
import os
from pathlib import Path
import subprocess
import scipy.fft as sfft
from dataclasses import dataclass

def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument('--path', required=True, type=str,
help='Path to directory with model and sim output.pt files')
return parser.parse_args()

@dataclass
class BoilingData:
temp: torch.Tensor
velx: torch.Tensor
vely: torch.Tensor

def load_vel_data(temp_path, vel_path):
pred = BoilingData(
torch.load(f'{temp_path}/model_ouput.pt').numpy(),
torch.load(f'{vel_path}/velx_output.pt').numpy(),
torch.load(f'{vel_path}/vely_output.pt').numpy())
label = BoilingData(
torch.load(f'{temp_path}/sim_ouput.pt').numpy(),
torch.load(f'{vel_path}/velx_label.pt').numpy(),
torch.load(f'{vel_path}/vely_label.pt').numpy())
return pred, label

def main():
args = parse_args()

job_id = '25042240/'
pred, label = load_vel_data(f'test_im/temp/{job_id}', f'test_im/vel/{job_id}')

plt_vel(pred, label, args.path, 'model')

subprocess.call(
f'ffmpeg -y -framerate 25 -pattern_type glob -i "{args.path}/*.png" output.mp4',
shell=True)

def temp_cmap():
temp_ranges = [0.0, 0.02, 0.04, 0.06, 0.08, 0.1, 0.134, 0.167,
0.2, 0.25, 0.3, 0.35, 0.4, 0.45, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0]
color_codes = ['#0000FF', '#0443FF', '#0E7AFF', '#16B4FF', '#1FF1FF', '#21FFD3',
'#22FF9B', '#22FF67', '#22FF15', '#29FF06', '#45FF07', '#6DFF08',
'#9EFF09', '#D4FF0A', '#FEF30A', '#FEB709', '#FD7D08', '#FC4908',
'#FC1407', '#FB0007']
colors = list(zip(temp_ranges, color_codes))
cmap = LinearSegmentedColormap.from_list('temperature_colormap', colors)
return cmap

def fft(x):
x_fft = sfft.fft2(x)
x_shift = np.abs(sfft.fftshift(x_fft))
return x_shift

def mag(velx, vely):
return np.sqrt(velx**2 + vely**2)

def plt_vel(pred, label, path, model_name):
plt.rc("font", family="serif", size=16, weight="bold")
plt.rc("axes", labelweight="bold")

label_mag = mag(label.velx, label.vely)
pred_mag = mag(pred.velx, pred.vely)
mag_vmax = abs(pred_mag[:50]).max()
print(label_mag.max(), pred_mag.max())

frames = min(pred.temp.shape[0], 100)
for i in range(frames):
i_str = str(i).zfill(3)
f, ax = plt.subplots(2, 2, layout='constrained')

x_vmax, x_vmin = label.velx.max(), label.velx.min()
y_vmax, y_vmin = label.vely.max(), label.vely.min()

cm_object = ax[0, 0].imshow(np.flipud(label.temp[i]), vmin=0, vmax=1, cmap=temp_cmap())
#ax[1, 0].imshow(np.flipud(label.velx[i]), vmin=x_vmin, vmax=x_vmax, cmap='jet')
#ax[2, 0].imshow(np.flipud(label.vely[i]), vmin=y_vmin, vmax=y_vmax, cmap='jet')
ax[1, 0].imshow(np.flipud(label_mag[i]), vmin=0, vmax=mag_vmax, cmap='jet')

ax[0, 1].imshow(np.flipud(np.nan_to_num(pred.temp[i])), vmin=0, vmax=1, cmap=temp_cmap())
#ax[1, 1].imshow(np.flipud(pred.velx[i]), vmin=x_vmin, vmax=x_vmax, cmap='jet')
#ax[2, 1].imshow(np.flipud(pred.vely[i]), vmin=y_vmin, vmax=x_vmax, cmap='jet')
ax[1, 1].imshow(np.flipud(pred_mag[i]), vmin=0, vmax=mag_vmax, cmap='jet')

ax[0, 0].axis('off')
ax[1, 0].axis('off')
ax[0, 1].axis('off')
ax[1, 1].axis('off')

#ax[0, 2].imshow(np.flipud(fft(label.temp[i])))
#ax[1, 2].imshow(np.flipud(fft(label.velx[i])))
#ax[2, 2].imshow(np.flipud(fft(label.vely[i])))
#ax[3, 2].imshow(np.flipud(fft(label_mag)))

#ax[0, 3].imshow(np.flipud(fft(pred.temp[i])))
#ax[1, 3].imshow(np.flipud(fft(pred.velx[i])))
#ax[2, 3].imshow(np.flipud(fft(pred.vely[i])))
#ax[3, 3].imshow(np.flipud(fft(pred_mag)))

im_path = Path(path)
im_path.mkdir(parents=True, exist_ok=True)
plt.savefig(f'{str(im_path)}/{i_str}.png',
dpi=200,
bbox_inches='tight',
transparent=True)
plt.close()


def plt_temp(temps, labels, path, model_name):
print(temps.min(), temps.max(),
labels.min(), labels.max())

plt.rc("font", family="serif", size=16, weight="bold")
plt.rc("axes", labelweight="bold")
for i in range(len(temps)):
i_str = str(i).zfill(3)

def plt_temp_arr(f, ax, arr, title):
cm_object = ax.imshow(arr, vmin=0, vmax=1, cmap=temp_cmap())
#ax.set_title(title)
ax.axis('off')
return cm_object

temp = temps[i].numpy()
label = labels[i].numpy()
f, axarr = plt.subplots(2, 3, layout="constrained")
cm_object = plt_temp_arr(f, axarr[0, 0], np.flipud(label), 'Ground Truth')
cm_object = plt_temp_arr(f, axarr[0, 1], np.flipud(temp), model_name)

err = np.abs(temp - label)
cm_object = plt_temp_arr(f, axarr[0, 2], np.flipud(err), 'Absolute Error')
f.tight_layout()
f.colorbar(cm_object,
ax=axarr.ravel().tolist(),
ticks=[0, 0.2, 0.6, 0.9],
fraction=0.04,
pad=0.02)
f.set_size_inches(w=6, h=3)

label_h = fft(label)
temp_h = fft(temp)
err_h = np.abs(label_h - temp_h)

axarr[1, 0].imshow(np.flipud(label_h))
axarr[1, 1].imshow(np.flipud(temp_h))
axarr[1, 2].imshow(np.flipud(err_h))

im_path = Path(path)
im_path.mkdir(parents=True, exist_ok=True)
plt.savefig(f'{str(im_path)}/{i_str}.png',
dpi=600,
bbox_inches='tight',
transparent=True)
plt.close()

velx_pred = torch.load('scripts/data/vel_unet_mod_push/velx_output.pt').numpy()
vely_pred = torch.load('scripts/data/vel_unet_mod_push/vely_output.pt').numpy()

velx_label = torch.load('scripts/data/vel_unet_mod_push/velx_label.pt').numpy()
vely_label = torch.load('scripts/data/vel_unet_mod_push/vely_label.pt').numpy()

w = velx_pred.shape[1]
d = 6
y, x = np.mgrid[d:w-d,d:w-d]
print(x, y)

temp_ranges = [0.0, 0.1, 0.4, 0.75, 1.0]
color_codes = ['black', 'purple', 'orange', 'yellow', 'white']
colors = list(zip(temp_ranges, color_codes))
cmap = LinearSegmentedColormap.from_list('vel_colormap', colors)

steps = list(range(1, velx_label.shape[0] // 2 + 1, 20))

plt.rc("font", family="serif", size=18, weight="bold")
plt.rc("axes", labelweight="bold")

fig, ax = plt.subplots(3, len(steps), figsize=(14.5, 7))

mag_label = np.sqrt(velx_label**2 + vely_label**2)
mag_pred = np.sqrt(velx_pred**2 + vely_pred**2)
mag_error = np.abs(mag_label - mag_pred)

vmin, vmax = 0, max(mag_label[steps].max(), mag_pred[steps].max())

for idx, t in enumerate(steps):
label_im = ax[0][idx].imshow(np.flipud(mag_label[t]), cmap=cmap, vmin=vmin, vmax=vmax)
ax[0][idx].streamplot(x, y,
velx_label[t,d:-d,d:-d],
vely_label[t,d:-d,d:-d],
linewidth=0.5,
density=0.75,
color='w',
arrowstyle='fancy')
ax[1][idx].imshow(np.flipud(mag_pred[t]), cmap=cmap, vmin=vmin, vmax=vmax)
ax[1][idx].streamplot(x, y,
velx_pred[t,d:-d,d:-d],
vely_pred[t,d:-d,d:-d],
linewidth=0.5,
density=0.75,
color='w',
arrowstyle='fancy')
error_im = ax[2][idx].imshow(np.flipud(mag_error[t]), cmap=cmap, vmin=vmin, vmax=mag_error[steps].max())
for i in range(3):
ax[i][idx].set_xticks([])
ax[i][idx].set_yticks([])
ax[0][0].set_ylabel('Ground Truth')
ax[1][0].set_ylabel('Prediction')
ax[2][0].set_ylabel('Abs. Error')

for i in range(len(steps)):
ax[0,i].set_title(f'Step {i*20}')

plt.tight_layout()
fig.subplots_adjust(wspace=0, hspace=0)
fig.colorbar(label_im, ax=ax[:2].ravel().tolist(), pad=0.005, shrink=0.5)
fig.colorbar(error_im, ax=ax[2].ravel().tolist(), pad=0.005)
plt.savefig(f'vel.pdf', dpi=500, bbox_inches='tight')
plt.close()
if __name__ == '__main__':
main()

0 comments on commit e9b81f3

Please sign in to comment.