Add Support for Convolutional Neural Operator
Add the Convolutional Neural Operator (CNO) as an optional model. This includes a large number of supporting utilities that, unfortunately, seem necessary. For now, it only exposes a few parameters to adjust the input/output dimensions and the number of layers.
neelsankaran authored Feb 29, 2024
1 parent 41cc536 commit dc6cca8
Showing 33 changed files with 6,394 additions and 2 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -14,6 +14,7 @@ im/
 test_im/

 # slurm and tensorboard logs output files
+outputs/
 logs/
 maes*
 slurm*.out
2 changes: 1 addition & 1 deletion conf/default.yaml
@@ -12,4 +12,4 @@ model_checkpoint:
 defaults:
   - _self_
   - dataset: PB_WallSuperHeat
-  - experiment: temp_unet2d
+  - experiment: paper/unet_arena/pb_temp
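
The new experiment can also be selected without editing conf/default.yaml. Below is a minimal sketch using Hydra's compose API; it assumes conf/ is a standard Hydra config root (the defaults list above suggests it is), and the entry-point details are illustrative rather than taken from this repository.

# Illustrative only: select the new CNO experiment via a Hydra override
# instead of editing conf/default.yaml. Paths mirror the files in this diff.
from hydra import compose, initialize
from omegaconf import OmegaConf

with initialize(version_base=None, config_path="conf"):
    cfg = compose(
        config_name="default",
        overrides=["experiment=experimental/cno/pb_temp"],
    )

# Print the fully resolved configuration; the exact package the experiment
# keys land in depends on how sciml structures its defaults list.
print(OmegaConf.to_yaml(cfg))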
28 changes: 28 additions & 0 deletions conf/experiment/experimental/cno/pb_temp.yaml
@@ -0,0 +1,28 @@
torch_dataset_name: temp_input_dataset

distributed: False

train:
  max_epochs: 250
  batch_size: 4
  shuffle_data: True
  time_window: 5
  future_window: 5
  push_forward_steps: 1
  use_coords: True
  noise: True
  downsample_factor: 1

model:
  model_name: cno
  in_size: 512
  n_layers: 6

optimizer:
  initial_lr: 1e-3
  weight_decay: 1e-6

lr_scheduler:
  name: cosine
  eta_min: 1e-5
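
The model block exposes only the two parameters mentioned in the commit message: in_size (the square spatial resolution) and n_layers. Below is a hedged sketch of how a model factory might dispatch on model_name; get_model, the CNO class name, and its argument names are illustrative assumptions, since CNOModule.py itself is not rendered in this diff.

# Hypothetical sketch only: the names get_model, CNO, in_dim, out_dim,
# size, and N_layers are assumptions, not taken from CNOModule.py.
def get_model(model_cfg, in_channels, out_channels):
    if model_cfg.model_name == 'cno':
        from sciml.models.ConvolutionalNeuralOperator.CNOModule import CNO
        return CNO(
            in_dim=in_channels,           # input channels
            out_dim=out_channels,         # predicted output channels
            size=model_cfg.in_size,       # spatial resolution, 512 here
            N_layers=model_cfg.n_layers,  # number of up/down-sampling blocks, 6 here
        )
    raise ValueError(f'unknown model: {model_cfg.model_name}')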
556 changes: 556 additions & 0 deletions sciml/models/ConvolutionalNeuralOperator/CNOModule.py

Large diffs are not rendered by default.

103 changes: 103 additions & 0 deletions sciml/models/ConvolutionalNeuralOperator/debug_tools.py
@@ -0,0 +1,103 @@
import math
import torch

units = {
    0: 'B',
    1: 'KiB',
    2: 'MiB',
    3: 'GiB',
    4: 'TiB'
}


def format_mem(x):
    """
    Takes integer 'x' in bytes and returns a number in [0, 1024) and
    the corresponding unit.
    """
    if abs(x) < 1024:
        return round(x, 2), 'B'

    scale = math.log2(abs(x)) // 10
    scaled_x = x / 1024 ** scale
    unit = units[scale]

    if int(scaled_x) == scaled_x:
        return int(scaled_x), unit

    # rounding leads to 2 or fewer decimal places, as required
    return round(scaled_x, 2), unit


def format_tensor_size(x):
    val, unit = format_mem(x)
    return f'{val}{unit}'


class CudaMemoryDebugger():
    """
    Helper to track changes in CUDA memory.
    """
    DEVICE = 'cuda'
    LAST_MEM = 0
    ENABLED = True

    def __init__(self, print_mem):
        self.print_mem = print_mem
        if not CudaMemoryDebugger.ENABLED:
            return

        cur_mem = torch.cuda.memory_allocated(CudaMemoryDebugger.DEVICE)
        cur_mem_fmt, cur_mem_unit = format_mem(cur_mem)
        print(f'cuda allocated (initial): {cur_mem_fmt:.2f}{cur_mem_unit}')
        CudaMemoryDebugger.LAST_MEM = cur_mem

    def print(self, id_str=None):
        if not CudaMemoryDebugger.ENABLED:
            return

        desc = 'cuda allocated'

        if id_str is not None:
            desc += f' ({id_str})'

        desc += ':'

        cur_mem = torch.cuda.memory_allocated(CudaMemoryDebugger.DEVICE)
        cur_mem_fmt, cur_mem_unit = format_mem(cur_mem)

        diff = cur_mem - CudaMemoryDebugger.LAST_MEM
        if self.print_mem:
            if diff == 0:
                print(f'{desc} {cur_mem_fmt:.2f}{cur_mem_unit} (no change)')
            else:
                diff_fmt, diff_unit = format_mem(diff)
                print(f'{desc} {cur_mem_fmt:.2f}{cur_mem_unit}'
                      f' ({diff_fmt:+}{diff_unit})')

        CudaMemoryDebugger.LAST_MEM = cur_mem


def print_tensor_mem(x, id_str=None):
    """
    Prints the memory required by tensor 'x'.
    """
    if not CudaMemoryDebugger.ENABLED:
        return

    desc = 'memory'

    if id_str is not None:
        desc += f' ({id_str})'

    desc += ':'

    val, unit = format_mem(x.element_size() * x.nelement())

    print(f'{desc} {val}{unit}')
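
For reference, a short usage sketch of the helpers above; the absolute import path is an assumption, and within the package it would be a relative import.

import torch
# Assumed import path; inside the package this would be a relative import.
from sciml.models.ConvolutionalNeuralOperator.debug_tools import (
    CudaMemoryDebugger, print_tensor_mem)

debugger = CudaMemoryDebugger(print_mem=True)   # prints the initial CUDA allocation
x = torch.randn(4, 5, 512, 512, device='cuda')
debugger.print('after input alloc')             # current usage plus delta since last call
print_tensor_mem(x, 'input batch')              # footprint of this tensor alone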

9 changes: 9 additions & 0 deletions sciml/models/ConvolutionalNeuralOperator/dnnlib/__init__.py
@@ -0,0 +1,9 @@
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.

from .util import EasyDict, make_cache_dir_path
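
The vendored dnnlib package re-exports two utilities. Assuming this copy matches upstream NVIDIA dnnlib, EasyDict is a dict subclass that also supports attribute-style access, and make_cache_dir_path builds paths under a shared cache directory. A quick sketch of EasyDict:

# Sketch assuming this vendored copy matches upstream NVIDIA dnnlib.
from dnnlib import EasyDict  # actual import path inside sciml may differ

opts = EasyDict(channels=32, use_bn=True)
opts.dropout = 0.1               # attribute-style write
assert opts['dropout'] == 0.1    # still an ordinary dict underneath
assert opts.channels == 32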