diff --git a/README.md b/README.md
index cbbfffd..5d2521c 100644
--- a/README.md
+++ b/README.md
@@ -1,5 +1,7 @@
 # BubbleML
 
+[![Paper](https://img.shields.io/badge/arXiv-2307.14623-blue)](https://arxiv.org/abs/2307.14623)
+
 A multi-physics dataset of boiling processes. This repository includes downloads,
 visualizations, and sample applications.
 
@@ -82,3 +84,18 @@ then, to train a UNet model on the subcooled boiling dataset, just run
 ~~~~
 python src/train.py data_base_dir=/your/path/to/BubbleML dataset=PB_SubCooled experiment=temp_unet
 ~~~~
+
+## Citation
+
+If you find this dataset useful in your research, please consider citing the following paper:
+
+```bibtex
+@article{hassan2023bubbleml,
+  title={BubbleML: A Multi-Physics Dataset and Benchmarks for Machine Learning},
+  author={Sheikh Md Shakeel Hassan and Arthur Feeney and Akash Dhruv and Jihoon Kim and Youngjoon Suh and Jaiyoung Ryu and Yoonjin Won and Aparna Chandramowlishwaran},
+  year={2023},
+  eprint={2307.14623},
+  archivePrefix={arXiv},
+  primaryClass={cs.LG}
+}
+```
diff --git a/src/models/PINO_util/fno.py b/src/models/PINO_util/fno.py
new file mode 100644
index 0000000..f9d695a
--- /dev/null
+++ b/src/models/PINO_util/fno.py
@@ -0,0 +1,654 @@
+import torch.nn as nn
+import torch.nn.functional as F
+from functools import partialmethod
+
+from neuralop.layers.spectral_convolution import SpectralConv
+from neuralop.layers.spherical_convolution import SphericalConv
+from neuralop.layers.padding import DomainPadding
+from .fno_block import FNOBlocks, resample
+from neuralop.layers.mlp import MLP
+
+
+class FNO(nn.Module):
+    """N-Dimensional Fourier Neural Operator
+
+    Parameters
+    ----------
+    n_modes : int tuple
+        number of modes to keep in Fourier Layer, along each dimension
+        The dimensionality of the TFNO is inferred from ``len(n_modes)``
+    hidden_channels : int
+        width of the FNO (i.e. number of channels)
+    in_channels : int, optional
+        Number of input channels, by default 3
+    out_channels : int, optional
+        Number of output channels, by default 1
+    lifting_channels : int, optional
+        number of hidden channels of the lifting block of the FNO, by default 256
+    projection_channels : int, optional
+        number of hidden channels of the projection block of the FNO, by default 256
+    n_layers : int, optional
+        Number of Fourier Layers, by default 4
+    incremental_n_modes : None or int tuple, default is None
+        * If not None, this allows to incrementally increase the number of modes in Fourier domain
+          during training. Has to verify n <= N for (n, N) in zip(incremental_n_modes, n_modes).
+
+        * If None, all the n_modes are used.
+
+        This can be updated dynamically during training.
+ fno_block_precision : str {'full', 'half', 'mixed'} + if 'full', the FNO Block runs in full precision + if 'half', the FFT, contraction, and inverse FFT run in half precision + if 'mixed', the contraction and inverse FFT run in half precision + stabilizer : str {'tanh'} or None, optional + By default None, otherwise tanh is used before FFT in the FNO block + use_mlp : bool, optional + Whether to use an MLP layer after each FNO block, by default False + mlp_dropout : float + droupout parameter of MLP layer (default is 0) + mlp_expansion : float + expansion parameter of MLP layer (default is 0.5) + non_linearity : nn.Module, optional + Non-Linearity module to use, by default F.gelu + norm : F.module, optional + Normalization layer to use, by default None + preactivation : bool, default is False + if True, use resnet-style preactivation + skip : {'linear', 'identity', 'soft-gating'}, optional + Type of skip connection to use, by default 'soft-gating' + separable : bool, default is False + if True, use a depthwise separable spectral convolution + factorization : str or None, {'tucker', 'cp', 'tt'} + Tensor factorization of the parameters weight to use, by default None. + * If None, a dense tensor parametrizes the Spectral convolutions + * Otherwise, the specified tensor factorization is used. + joint_factorization : bool, optional + Whether all the Fourier Layers should be parametrized by a single tensor (vs one per layer), by default False + rank : float or rank, optional + Rank of the tensor factorization of the Fourier weights, by default 1.0 + fixed_rank_modes : bool, optional + Modes to not factorize, by default False + implementation : {'factorized', 'reconstructed'}, optional, default is 'factorized' + If factorization is not None, forward mode to use:: + * `reconstructed` : the full weight tensor is reconstructed from the factorization and used for the forward pass + * `factorized` : the input is directly contracted with the factors of the decomposition + decomposition_kwargs : dict, optional, default is {} + Optionaly additional parameters to pass to the tensor decomposition + domain_padding : None or float, optional + If not None, percentage of padding to use, by default None + domain_padding_mode : {'symmetric', 'one-sided'}, optional + How to perform domain padding, by default 'one-sided' + fft_norm : str, optional + by default 'forward' + """ + def __init__(self, n_modes, hidden_channels, + in_channels=3, + out_channels=1, + lifting_channels=256, + projection_channels=256, + n_layers=4, + output_scaling_factor=None, + incremental_n_modes=None, + fno_block_precision='full', + use_mlp=False, mlp_dropout=0, mlp_expansion=0.5, + non_linearity=F.gelu, + stabilizer=None, + norm=None, preactivation=False, + fno_skip='linear', + mlp_skip='soft-gating', + separable=False, + factorization=None, + rank=1.0, + joint_factorization=False, + fixed_rank_modes=False, + implementation='factorized', + decomposition_kwargs=dict(), + domain_padding=None, + domain_padding_mode='one-sided', + fft_norm='forward', + SpectralConv=SpectralConv, + render_default_scale = False, + **kwargs): + super().__init__() + self.n_dim = len(n_modes) + self.n_modes = n_modes + self.hidden_channels = hidden_channels + self.lifting_channels = lifting_channels + self.projection_channels = projection_channels + self.in_channels = in_channels + self.out_channels = out_channels + self.n_layers = n_layers + self.joint_factorization = joint_factorization + self.non_linearity = non_linearity + self.rank = rank + 
self.factorization = factorization + self.fixed_rank_modes = fixed_rank_modes + self.decomposition_kwargs = decomposition_kwargs + self.fno_skip = fno_skip, + self.mlp_skip = mlp_skip, + self.fft_norm = fft_norm + self.implementation = implementation + self.separable = separable + self.preactivation = preactivation + self.fno_block_precision = fno_block_precision + + # See the class' property for underlying mechanism + # When updated, change should be reflected in fno blocks + self._incremental_n_modes = incremental_n_modes + self.render_default_scale = render_default_scale + + if domain_padding is not None and domain_padding > 0: + self.domain_padding = DomainPadding(domain_padding=domain_padding, padding_mode=domain_padding_mode, output_scaling_factor=output_scaling_factor) + if self.render_default_scale: + self.domain_padding_default_scale = DomainPadding(domain_padding=domain_padding, padding_mode=domain_padding_mode, output_scaling_factor=None) + else: + self.domain_padding = None + self.domain_padding_mode = domain_padding_mode + + if output_scaling_factor is not None and not joint_factorization: + if isinstance(output_scaling_factor, (float, int)): + output_scaling_factor = [output_scaling_factor]*self.n_layers + self.output_scaling_factor = output_scaling_factor + + self.fno_blocks = FNOBlocks( + in_channels=hidden_channels, + out_channels=hidden_channels, + n_modes=self.n_modes, + output_scaling_factor=output_scaling_factor, + use_mlp=use_mlp, mlp_dropout=mlp_dropout, mlp_expansion=mlp_expansion, + non_linearity=non_linearity, + stabilizer=stabilizer, + norm=norm, preactivation=preactivation, + fno_skip=fno_skip, + mlp_skip=mlp_skip, + incremental_n_modes=incremental_n_modes, + fno_block_precision=fno_block_precision, + rank=rank, + fft_norm=fft_norm, + fixed_rank_modes=fixed_rank_modes, + implementation=implementation, + separable=separable, + factorization=factorization, + decomposition_kwargs=decomposition_kwargs, + joint_factorization=joint_factorization, + SpectralConv=SpectralConv, + n_layers=n_layers) + #render_default_scale=self.render_default_scale) + + self.lifting = MLP(in_channels=in_channels, out_channels=self.hidden_channels, hidden_channels=self.hidden_channels, n_layers=1, n_dim=self.n_dim) + self.projection = MLP(in_channels=self.hidden_channels, out_channels=out_channels, hidden_channels=self.projection_channels, n_layers=2, n_dim=self.n_dim, non_linearity=non_linearity) + + def forward(self, x): + """TFNO's forward pass + """ + x = self.lifting(x) + + if self.domain_padding is not None: + x = self.domain_padding.pad(x) + if self.render_default_scale: + x_default_scale = self.domain_padding_default_scale.pad(x) + + for layer_idx in range(self.n_layers): + if self.render_default_scale: + x, x_default_scale = self.fno_blocks(x, layer_idx, default_render = x_default_scale) + else: + x = self.fno_blocks(x, layer_idx) + + + if self.domain_padding is not None: + x = self.domain_padding.unpad(x) + if self.render_default_scale: + x_default_scale = self.domain_padding_default_scale.unpad(x_default_scale) + + x = self.projection(x) + + if self.render_default_scale: + x_default_scale = self.projection(x_default_scale) + return x, x_default_scale + return x + + @property + def incremental_n_modes(self): + return self._incremental_n_modes + + @incremental_n_modes.setter + def incremental_n_modes(self, incremental_n_modes): + self.fno_blocks.incremental_n_modes = incremental_n_modes + + +class FNO1d(FNO): + """1D Fourier Neural Operator + + Parameters + ---------- + 
modes_height : int + number of Fourier modes to keep along the height + hidden_channels : int + width of the FNO (i.e. number of channels) + in_channels : int, optional + Number of input channels, by default 3 + out_channels : int, optional + Number of output channels, by default 1 + lifting_channels : int, optional + number of hidden channels of the lifting block of the FNO, by default 256 + projection_channels : int, optional + number of hidden channels of the projection block of the FNO, by default 256 + n_layers : int, optional + Number of Fourier Layers, by default 4 + incremental_n_modes : None or int tuple, default is None + * If not None, this allows to incrementally increase the number of modes in Fourier domain + during training. Has to verify n <= N for (n, m) in zip(incremental_n_modes, n_modes). + + * If None, all the n_modes are used. + + This can be updated dynamically during training. + fno_block_precision : str {'full', 'half', 'mixed'} + if 'full', the FNO Block runs in full precision + if 'half', the FFT, contraction, and inverse FFT run in half precision + if 'mixed', the contraction and inverse FFT run in half precision + stabilizer : str {'tanh'} or None, optional + By default None, otherwise tanh is used before FFT in the FNO block + use_mlp : bool, optional + Whether to use an MLP layer after each FNO block, by default False + mlp : dict, optional + Parameters of the MLP, by default None + {'expansion': float, 'dropout': float} + non_linearity : nn.Module, optional + Non-Linearity module to use, by default F.gelu + norm : F.module, optional + Normalization layer to use, by default None + preactivation : bool, default is False + if True, use resnet-style preactivation + skip : {'linear', 'identity', 'soft-gating'}, optional + Type of skip connection to use, by default 'soft-gating' + separable : bool, default is False + if True, use a depthwise separable spectral convolution + factorization : str or None, {'tucker', 'cp', 'tt'} + Tensor factorization of the parameters weight to use, by default None. + * If None, a dense tensor parametrizes the Spectral convolutions + * Otherwise, the specified tensor factorization is used. 
+ joint_factorization : bool, optional + Whether all the Fourier Layers should be parametrized by a single tensor (vs one per layer), by default False + rank : float or rank, optional + Rank of the tensor factorization of the Fourier weights, by default 1.0 + fixed_rank_modes : bool, optional + Modes to not factorize, by default False + implementation : {'factorized', 'reconstructed'}, optional, default is 'factorized' + If factorization is not None, forward mode to use:: + * `reconstructed` : the full weight tensor is reconstructed from the factorization and used for the forward pass + * `factorized` : the input is directly contracted with the factors of the decomposition + decomposition_kwargs : dict, optional, default is {} + Optionaly additional parameters to pass to the tensor decomposition + domain_padding : None or float, optional + If not None, percentage of padding to use, by default None + domain_padding_mode : {'symmetric', 'one-sided'}, optional + How to perform domain padding, by default 'one-sided' + fft_norm : str, optional + by default 'forward' + """ + def __init__( + self, + n_modes_height, + hidden_channels, + in_channels=3, + out_channels=1, + lifting_channels=256, + projection_channels=256, + incremental_n_modes=None, + fno_block_precision='full', + n_layers=4, + output_scaling_factor=None, + non_linearity=F.gelu, + stabilizer=None, + use_mlp=False, mlp_dropout=0, mlp_expansion=0.5, + norm=None, + skip='soft-gating', + separable=False, + preactivation=False, + factorization=None, + rank=1.0, + joint_factorization=False, + fixed_rank_modes=False, + implementation='factorized', + decomposition_kwargs=dict(), + domain_padding=None, + domain_padding_mode='one-sided', + fft_norm='forward', + **kwargs): + super().__init__( + n_modes=(n_modes_height, ), + hidden_channels=hidden_channels, + in_channels=in_channels, + out_channels=out_channels, + lifting_channels=lifting_channels, + projection_channels=projection_channels, + n_layers=n_layers, + output_scaling_factor=None, + non_linearity=non_linearity, + stabilizer=stabilizer, + use_mlp=use_mlp, mlp_dropout=mlp_dropout, mlp_expansion=mlp_expansion, + incremental_n_modes=incremental_n_modes, + fno_block_precision=fno_block_precision, + norm=norm, + skip=skip, + separable=separable, + preactivation=preactivation, + factorization=factorization, + rank=rank, + joint_factorization=joint_factorization, + fixed_rank_modes=fixed_rank_modes, + implementation=implementation, + decomposition_kwargs=decomposition_kwargs, + domain_padding=domain_padding, + domain_padding_mode=domain_padding_mode, + fft_norm=fft_norm + ) + self.n_modes_height = n_modes_height + + +class FNO2d(FNO): + """2D Fourier Neural Operator + + Parameters + ---------- + n_modes_width : int + number of modes to keep in Fourier Layer, along the width + n_modes_height : int + number of Fourier modes to keep along the height + hidden_channels : int + width of the FNO (i.e. 
number of channels) + in_channels : int, optional + Number of input channels, by default 3 + out_channels : int, optional + Number of output channels, by default 1 + lifting_channels : int, optional + number of hidden channels of the lifting block of the FNO, by default 256 + projection_channels : int, optional + number of hidden channels of the projection block of the FNO, by default 256 + n_layers : int, optional + Number of Fourier Layers, by default 4 + incremental_n_modes : None or int tuple, default is None + * If not None, this allows to incrementally increase the number of modes in Fourier domain + during training. Has to verify n <= N for (n, m) in zip(incremental_n_modes, n_modes). + + * If None, all the n_modes are used. + + This can be updated dynamically during training. + fno_block_precision : str {'full', 'half', 'mixed'} + if 'full', the FNO Block runs in full precision + if 'half', the FFT, contraction, and inverse FFT run in half precision + if 'mixed', the contraction and inverse FFT run in half precision + stabilizer : str {'tanh'} or None, optional + By default None, otherwise tanh is used before FFT in the FNO block + use_mlp : bool, optional + Whether to use an MLP layer after each FNO block, by default False + mlp : dict, optional + Parameters of the MLP, by default None + {'expansion': float, 'dropout': float} + non_linearity : nn.Module, optional + Non-Linearity module to use, by default F.gelu + norm : F.module, optional + Normalization layer to use, by default None + preactivation : bool, default is False + if True, use resnet-style preactivation + skip : {'linear', 'identity', 'soft-gating'}, optional + Type of skip connection to use, by default 'soft-gating' + separable : bool, default is False + if True, use a depthwise separable spectral convolution + factorization : str or None, {'tucker', 'cp', 'tt'} + Tensor factorization of the parameters weight to use, by default None. + * If None, a dense tensor parametrizes the Spectral convolutions + * Otherwise, the specified tensor factorization is used. 
+ joint_factorization : bool, optional + Whether all the Fourier Layers should be parametrized by a single tensor (vs one per layer), by default False + rank : float or rank, optional + Rank of the tensor factorization of the Fourier weights, by default 1.0 + fixed_rank_modes : bool, optional + Modes to not factorize, by default False + implementation : {'factorized', 'reconstructed'}, optional, default is 'factorized' + If factorization is not None, forward mode to use:: + * `reconstructed` : the full weight tensor is reconstructed from the factorization and used for the forward pass + * `factorized` : the input is directly contracted with the factors of the decomposition + decomposition_kwargs : dict, optional, default is {} + Optionaly additional parameters to pass to the tensor decomposition + domain_padding : None or float, optional + If not None, percentage of padding to use, by default None + domain_padding_mode : {'symmetric', 'one-sided'}, optional + How to perform domain padding, by default 'one-sided' + fft_norm : str, optional + by default 'forward' + """ + def __init__( + self, + n_modes_height, + n_modes_width, + hidden_channels, + in_channels=3, + out_channels=1, + lifting_channels=256, + projection_channels=256, + n_layers=4, + output_scaling_factor=None, + incremental_n_modes=None, + fno_block_precision='full', + non_linearity=F.gelu, + stabilizer=None, + use_mlp=False, mlp_dropout=0, mlp_expansion=0.5, + norm=None, + skip='soft-gating', + separable=False, + preactivation=False, + factorization=None, + rank=1.0, + joint_factorization=False, + fixed_rank_modes=False, + implementation='factorized', + decomposition_kwargs=dict(), + domain_padding=None, + domain_padding_mode='one-sided', + fft_norm='forward', + **kwargs): + super().__init__( + n_modes=(n_modes_height, n_modes_width), + hidden_channels=hidden_channels, + in_channels=in_channels, + out_channels=out_channels, + lifting_channels=lifting_channels, + projection_channels=projection_channels, + n_layers=n_layers, + output_scaling_factor=None, + non_linearity=non_linearity, + stabilizer=stabilizer, + use_mlp=use_mlp, mlp_dropout=mlp_dropout, mlp_expansion=mlp_expansion, + incremental_n_modes=incremental_n_modes, + fno_block_precision=fno_block_precision, + norm=norm, + skip=skip, + separable=separable, + preactivation=preactivation, + factorization=factorization, + rank=rank, + joint_factorization=joint_factorization, + fixed_rank_modes=fixed_rank_modes, + implementation=implementation, + decomposition_kwargs=decomposition_kwargs, + domain_padding=domain_padding, + domain_padding_mode=domain_padding_mode, + fft_norm=fft_norm + ) + self.n_modes_height = n_modes_height + self.n_modes_width = n_modes_width + + + +class FNO3d(FNO): + """3D Fourier Neural Operator + + Parameters + ---------- + modes_width : int + number of modes to keep in Fourier Layer, along the width + modes_height : int + number of Fourier modes to keep along the height + modes_depth : int + number of Fourier modes to keep along the depth + hidden_channels : int + width of the FNO (i.e. 
number of channels) + in_channels : int, optional + Number of input channels, by default 3 + out_channels : int, optional + Number of output channels, by default 1 + lifting_channels : int, optional + number of hidden channels of the lifting block of the FNO, by default 256 + projection_channels : int, optional + number of hidden channels of the projection block of the FNO, by default 256 + n_layers : int, optional + Number of Fourier Layers, by default 4 + incremental_n_modes : None or int tuple, default is None + * If not None, this allows to incrementally increase the number of modes in Fourier domain + during training. Has to verify n <= N for (n, m) in zip(incremental_n_modes, n_modes). + + * If None, all the n_modes are used. + + This can be updated dynamically during training. + fno_block_precision : str {'full', 'half', 'mixed'} + if 'full', the FNO Block runs in full precision + if 'half', the FFT, contraction, and inverse FFT run in half precision + if 'mixed', the contraction and inverse FFT run in half precision + stabilizer : str {'tanh'} or None, optional + By default None, otherwise tanh is used before FFT in the FNO block + use_mlp : bool, optional + Whether to use an MLP layer after each FNO block, by default False + mlp : dict, optional + Parameters of the MLP, by default None + {'expansion': float, 'dropout': float} + non_linearity : nn.Module, optional + Non-Linearity module to use, by default F.gelu + norm : F.module, optional + Normalization layer to use, by default None + preactivation : bool, default is False + if True, use resnet-style preactivation + skip : {'linear', 'identity', 'soft-gating'}, optional + Type of skip connection to use, by default 'soft-gating' + separable : bool, default is False + if True, use a depthwise separable spectral convolution + factorization : str or None, {'tucker', 'cp', 'tt'} + Tensor factorization of the parameters weight to use, by default None. + * If None, a dense tensor parametrizes the Spectral convolutions + * Otherwise, the specified tensor factorization is used. 
+ joint_factorization : bool, optional + Whether all the Fourier Layers should be parametrized by a single tensor (vs one per layer), by default False + rank : float or rank, optional + Rank of the tensor factorization of the Fourier weights, by default 1.0 + fixed_rank_modes : bool, optional + Modes to not factorize, by default False + implementation : {'factorized', 'reconstructed'}, optional, default is 'factorized' + If factorization is not None, forward mode to use:: + * `reconstructed` : the full weight tensor is reconstructed from the factorization and used for the forward pass + * `factorized` : the input is directly contracted with the factors of the decomposition + decomposition_kwargs : dict, optional, default is {} + Optionaly additional parameters to pass to the tensor decomposition + domain_padding : None or float, optional + If not None, percentage of padding to use, by default None + domain_padding_mode : {'symmetric', 'one-sided'}, optional + How to perform domain padding, by default 'one-sided' + fft_norm : str, optional + by default 'forward' + """ + def __init__(self, + n_modes_height, + n_modes_width, + n_modes_depth, + hidden_channels, + in_channels=3, + out_channels=1, + lifting_channels=256, + projection_channels=256, + n_layers=4, + output_scaling_factor=None, + incremental_n_modes=None, + fno_block_precision='full', + non_linearity=F.gelu, + stabilizer=None, + use_mlp=False, mlp_dropout=0, mlp_expansion=0.5, + norm=None, + skip='soft-gating', + separable=False, + preactivation=False, + factorization=None, + rank=1.0, + joint_factorization=False, + fixed_rank_modes=False, + implementation='factorized', + decomposition_kwargs=dict(), + domain_padding=None, + domain_padding_mode='one-sided', + fft_norm='forward', + **kwargs): + super().__init__( + n_modes=(n_modes_height, n_modes_width, n_modes_depth), + hidden_channels=hidden_channels, + in_channels=in_channels, + out_channels=out_channels, + lifting_channels=lifting_channels, + projection_channels=projection_channels, + n_layers=n_layers, + output_scaling_factor=None, + non_linearity=non_linearity, + stabilizer=stabilizer, + incremental_n_modes=incremental_n_modes, + fno_block_precision=fno_block_precision, + use_mlp=use_mlp, mlp_dropout=mlp_dropout, mlp_expansion=mlp_expansion, + norm=norm, + skip=skip, + separable=separable, + preactivation=preactivation, + factorization=factorization, + rank=rank, + joint_factorization=joint_factorization, + fixed_rank_modes=fixed_rank_modes, + implementation=implementation, + decomposition_kwargs=decomposition_kwargs, + domain_padding=domain_padding, + domain_padding_mode=domain_padding_mode, + fft_norm=fft_norm + ) + self.n_modes_height = n_modes_height + self.n_modes_width = n_modes_width + self.n_modes_height = n_modes_height + + +def partialclass(new_name, cls, *args, **kwargs): + """Create a new class with different default values + + Notes + ----- + An obvious alternative would be to use functools.partial + >>> new_class = partial(cls, **kwargs) + + The issue is twofold: + 1. the class doesn't have a name, so one would have to set it explicitly: + >>> new_class.__name__ = new_name + + 2. the new class will be a functools object and one cannot inherit from it. + + Instead, here, we define dynamically a new class, inheriting from the existing one. 
+ """ + __init__ = partialmethod(cls.__init__, *args, **kwargs) + new_class = type(new_name, (cls,), { + '__init__': __init__, + '__doc__': cls.__doc__, + 'forward': cls.forward, + }) + return new_class + + +TFNO = partialclass('TFNO', FNO, factorization='Tucker') +TFNO1d = partialclass('TFNO1d', FNO1d, factorization='Tucker') +TFNO2d = partialclass('TFNO2d', FNO2d, factorization='Tucker') +TFNO3d = partialclass('TFNO3d', FNO3d, factorization='Tucker') + +SFNO = partialclass('SFNO', FNO, factorization='dense', SpectralConv=SphericalConv) +SFNO.__doc__ = SFNO.__doc__.replace('Fourier', 'Spherical Fourier', 1) +SFNO.__doc__ = SFNO.__doc__.replace('FNO', 'SFNO') +SFNO.__doc__ = SFNO.__doc__.replace('fno', 'sfno') \ No newline at end of file diff --git a/src/models/PINO_util/fno_block.py b/src/models/PINO_util/fno_block.py new file mode 100644 index 0000000..e4a6c8d --- /dev/null +++ b/src/models/PINO_util/fno_block.py @@ -0,0 +1,253 @@ +from torch import nn +import torch.nn.functional as F +import torch +from neuralop.layers.spectral_convolution import SpectralConv +from neuralop.layers.skip_connections import skip_connection +from neuralop.layers.resample import resample +from neuralop.layers.mlp import MLP +from neuralop.layers.normalization_layers import AdaIN + +class FNOBlocks(nn.Module): + def __init__(self, in_channels, out_channels, n_modes, + output_scaling_factor=None, + n_layers=1, + incremental_n_modes=None, + fno_block_precision='full', + use_mlp=False, mlp_dropout=0, mlp_expansion=0.5, + non_linearity=F.gelu, + stabilizer=None, + norm=None, ada_in_features=None, + preactivation=False, + fno_skip='linear', + mlp_skip='soft-gating', + separable=False, + factorization=None, + rank=1.0, + SpectralConv=SpectralConv, + joint_factorization=False, + fixed_rank_modes=False, + implementation='factorized', + decomposition_kwargs=dict(), + fft_norm='forward', + #render_default_scale = False, + **kwargs): + super().__init__() + if isinstance(n_modes, int): + n_modes = [n_modes] + self.n_modes = n_modes + self.n_dim = len(n_modes) + + if output_scaling_factor is not None: + if isinstance(output_scaling_factor, (float, int)): + output_scaling_factor = [[float(output_scaling_factor)]*len(self.n_modes)]*n_layers + elif isinstance(output_scaling_factor[0], (float, int)): + output_scaling_factor = [[s]*len(self.n_modes) for s in output_scaling_factor] + self.output_scaling_factor = output_scaling_factor + + self._incremental_n_modes = incremental_n_modes + self.fno_block_preicison = fno_block_precision + self.in_channels = in_channels + self.out_channels = out_channels + self.n_layers = n_layers + self.joint_factorization = joint_factorization + self.non_linearity = non_linearity + self.stabilizer = stabilizer + self.rank = rank + self.factorization = factorization + self.fixed_rank_modes = fixed_rank_modes + self.decomposition_kwargs = decomposition_kwargs + self.fno_skip = fno_skip + self.mlp_skip = mlp_skip + self.use_mlp = use_mlp + self.mlp_expansion = mlp_expansion + self.mlp_dropout = mlp_dropout + self.fft_norm = fft_norm + self.implementation = implementation + self.separable = separable + self.preactivation = preactivation + self.ada_in_features = ada_in_features + #self.render_default_scale = render_default_scale + + self.convs = SpectralConv( + self.in_channels, self.out_channels, self.n_modes, + output_scaling_factor=output_scaling_factor, + incremental_n_modes=incremental_n_modes, + fno_block_precision=fno_block_precision, + rank=rank, + fft_norm=fft_norm, + 
fixed_rank_modes=fixed_rank_modes, + implementation=implementation, + separable=separable, + factorization=factorization, + decomposition_kwargs=decomposition_kwargs, + joint_factorization=joint_factorization, + n_layers=n_layers, + ) + + self.fno_skips = nn.ModuleList([skip_connection(self.in_channels, self.out_channels, type=fno_skip, n_dim=self.n_dim) for _ in range(n_layers)]) + + if use_mlp: + self.mlp = nn.ModuleList( + [MLP(in_channels=self.out_channels, + hidden_channels=int(round(self.out_channels*mlp_expansion)), + dropout=mlp_dropout, n_dim=self.n_dim) for _ in range(n_layers)] + ) + self.mlp_skips = nn.ModuleList([skip_connection(self.in_channels, self.out_channels, type=mlp_skip, n_dim=self.n_dim) for _ in range(n_layers)]) + else: + self.mlp = None + + # Each block will have 2 norms if we also use an MLP + self.n_norms = 1 if self.mlp is None else 2 + if norm is None: + self.norm = None + elif norm == 'instance_norm': + self.norm = nn.ModuleList([getattr(nn, f'InstanceNorm{self.n_dim}d')(num_features=self.out_channels) for _ in range(n_layers*self.n_norms)]) + elif norm == 'group_norm': + self.norm = nn.ModuleList([nn.GroupNorm(num_groups=1, num_channels=self.out_channels) for _ in range(n_layers*self.n_norms)]) + # elif norm == 'layer_norm': + # self.norm = nn.ModuleList([nn.LayerNorm(elementwise_affine=False) for _ in range(n_layers*self.n_norms)]) + elif norm == 'ada_in': + self.norm = nn.ModuleList([AdaIN(ada_in_features, out_channels) for _ in range(n_layers*self.n_norms)]) + else: + raise ValueError(f'Got {norm=} but expected None or one of [instance_norm, group_norm, layer_norm]') + + def set_ada_in_embeddings(self, *embeddings): + """Sets the embeddings of each Ada-IN norm layers + + Parameters + ---------- + embeddings : tensor or list of tensor + if a single embedding is given, it will be used for each norm layer + otherwise, each embedding will be used for the corresponding norm layer + """ + if len(embeddings) == 1: + for norm in self.norm: + norm.set_embedding(embeddings[0]) + else: + for norm, embedding in zip(self.norm, embeddings): + norm.set_embedding(embedding) + + def forward(self, x, index=0, output_shape = None, default_render = None): + + if self.preactivation: + x = self.non_linearity(x) + if default_render is not None: + default_render = self.non_linearity(default_render) + + if self.norm is not None: + x = self.norm[self.n_norms*index](x) + if default_render is not None: + default_render = self.norm[self.n_norms*index](default_render) + + x_skip_fno = self.fno_skips[index](x) + if default_render is not None: + x_skip_fno_default_scale = self.fno_skips[index](default_render)#no need to resample + + if self.convs.output_scaling_factor is not None: + # x_skip_fno = resample(x_skip_fno, self.convs.output_scaling_factor[index], list(range(-len(self.convs.output_scaling_factor[index]), 0))) + x_skip_fno = resample(x_skip_fno, self.output_scaling_factor[index]\ + , list(range(-len(self.output_scaling_factor[index]), 0)), output_shape = output_shape ) + + if self.mlp is not None: + x_skip_mlp = self.mlp_skips[index](x) + if default_render is not None: + x_skip_mlp_default_scale = self.mlp_skips[index](default_render)#no need to resample + if self.convs.output_scaling_factor is not None: + x_skip_mlp = resample(x_skip_mlp, self.output_scaling_factor[index]\ + , list(range(-len(self.output_scaling_factor[index]), 0)), output_shape = output_shape ) + + if self.stabilizer == 'tanh': + x = torch.tanh(x) + if default_render is not None: + default_render = 
torch.tanh(default_render)
+
+        x_fno = self.convs(x, index, output_shape=output_shape)
+        if default_render is not None:
+            _, _, *default_mode_size = default_render.shape  # keep the default-scale branch at its own spatial resolution
+            x_fno_default_scale = self.convs(default_render, index, output_shape=default_mode_size)
+
+        if not self.preactivation and self.norm is not None:
+            x_fno = self.norm[self.n_norms*index](x_fno)
+            if default_render is not None:
+                x_fno_default_scale = self.norm[self.n_norms*index](x_fno_default_scale)
+
+        x = x_fno + x_skip_fno
+        if default_render is not None:
+            default_render = x_fno_default_scale + x_skip_fno_default_scale
+
+        if not self.preactivation and (self.mlp is not None) or (index < (self.n_layers - index)):
+            x = self.non_linearity(x)
+            if default_render is not None:
+                default_render = self.non_linearity(default_render)
+
+        if self.mlp is not None:
+            # x_skip = self.mlp_skips[index](x)
+
+            if self.preactivation:
+                if index < (self.n_layers - 1):
+                    x = self.non_linearity(x)
+                    if default_render is not None:
+                        default_render = self.non_linearity(default_render)
+
+                if self.norm is not None:
+                    x = self.norm[self.n_norms*index+1](x)
+                    if default_render is not None:
+                        default_render = self.norm[self.n_norms*index+1](default_render)
+
+            x = self.mlp[index](x) + x_skip_mlp
+            if default_render is not None:
+                default_render = self.mlp[index](default_render) + x_skip_mlp_default_scale
+
+            if not self.preactivation and self.norm is not None:
+                x = self.norm[self.n_norms*index+1](x)
+                if default_render is not None:
+                    default_render = self.norm[self.n_norms*index+1](default_render)
+
+            if not self.preactivation:
+                if index < (self.n_layers - 1):
+                    x = self.non_linearity(x)
+                    if default_render is not None:
+                        default_render = self.non_linearity(default_render)
+
+        if default_render is not None:
+            return x, default_render
+        return x
+
+    @property
+    def incremental_n_modes(self):
+        return self._incremental_n_modes
+
+    @incremental_n_modes.setter
+    def incremental_n_modes(self, incremental_n_modes):
+        self.convs.incremental_n_modes = incremental_n_modes
+
+    def get_block(self, indices):
+        """Returns a sub-FNO Block layer from the jointly parametrized main block
+
+        The parametrization of an FNOBlock layer is shared with the main one.
+        """
+        if self.n_layers == 1:
+            raise ValueError('A single layer is parametrized, directly use the main class.')
+
+        return SubModule(self, indices)
+
+    def __getitem__(self, indices):
+        return self.get_block(indices)
+
+
+class SubModule(nn.Module):
+    """Class representing one of the sub-modules of the mother joint module
+
+    Notes
+    -----
+    This relies on the fact that nn.Parameters are not duplicated:
+    if the same nn.Parameter is assigned to multiple modules, they all point to the same data,
+    which is shared.
+    """
+    def __init__(self, main_module, indices):
+        super().__init__()
+        self.main_module = main_module
+        self.indices = indices
+
+    def forward(self, x):
+        return self.main_module.forward(x, self.indices)
\ No newline at end of file
diff --git a/src/op_lib/losses.py b/src/op_lib/losses.py
index bb17fa3..cbf4a62 100644
--- a/src/op_lib/losses.py
+++ b/src/op_lib/losses.py
@@ -4,9 +4,11 @@
 """
 import math
 import torch
+import torch.nn.functional as F
+from neuralop.layers.resample import resample
 
 class LpLoss(object):
-    def __init__(self, d=1, p=2, L=2*math.pi, reduce_dims=0, reductions='sum'):
+    def __init__(self, d=1, p=2, L=2*math.pi, reduce_dims=0, reductions='sum', add_PDE_LOSS=False):
         super().__init__()
 
         self.d = d
@@ -180,3 +182,48 @@ def rel(self, x, y, h=None):
 
     def __call__(self, x, y, h=None):
         return self.rel(x, y, h=h)
+
+def temp_stokes_loss2D(T, u, v, T_prev, resolution_scaling, dt):
+    batchsize = T.size(0)
+    nx = T.size(2)
+    ny = T.size(3)
+
+    device = T.device
+    T = T.reshape(batchsize, nx, ny)
+    u = u.reshape(batchsize, nx, ny)
+    v = v.reshape(batchsize, nx, ny)
+    if T_prev.size(-2) != T.size(-2) or T_prev.size(-1) != T.size(-1):
+        T_prev = resample(T_prev, resolution_scaling, [-2, -1], output_shape=T.shape)
+
+    T_h = torch.fft.fft2(T, dim=[-2, -1])
+    u_h = torch.fft.fft2(u, dim=[-2, -1])
+    v_h = torch.fft.fft2(v, dim=[-2, -1])
+    # Wavenumbers in x- and y-directions
+    k_maxx = nx//2
+    k_maxy = ny//2
+    Nx = nx
+    Ny = ny
+    k_x = torch.cat((torch.arange(start=0, end=k_maxx, step=1, device=device),
+                     torch.arange(start=-k_maxx, end=0, step=1, device=device)), 0).reshape(Nx, 1).repeat(1, Ny).reshape(1,Nx,Ny)
+    k_y = torch.cat((torch.arange(start=0, end=k_maxy, step=1, device=device),
+                     torch.arange(start=-k_maxy, end=0, step=1, device=device)), 0).reshape(1, Ny).repeat(Nx, 1).reshape(1,Nx,Ny)
+    # Laplacian in Fourier space
+    lap = (k_x ** 2 + k_y ** 2)
+    lap[0, 0, 0] = 1.0
+
+    Ty_h = 1j * k_y * T_h
+    Tx_h = 1j * k_x * T_h
+    Tlap_h = lap * T_h
+
+    Txu_conv_h = F.conv2d(u_h, Tx_h, stride=1, padding=(0, 0))*(1/(4*(math.pi**2)))
+    Tyv_conv_h = F.conv2d(v_h, Ty_h, stride=1, padding=(0, 0))*(1/(4*(math.pi**2)))
+
+    gradTdotu_h = Tyv_conv_h + Txu_conv_h
+
+    gradTdotu = torch.fft.irfft2(gradTdotu_h[:, :, :k_maxy + 1], dim=[-2, -1])
+    Tlap = torch.fft.irfft2(Tlap_h[:, :, :k_maxy+1], dim=[-2, -1])
+
+    Tdt = (T-T_prev)/dt
+
+    PDE_LOSS = torch.sum(torch.square(Tdt + gradTdotu - Tlap))
+    return PDE_LOSS
diff --git a/src/op_lib/temp_trainer.py b/src/op_lib/temp_trainer.py
index d7079fe..990c65c 100644
--- a/src/op_lib/temp_trainer.py
+++ b/src/op_lib/temp_trainer.py
@@ -36,7 +36,8 @@ def __init__(self,
                  lr_scheduler,
                  val_variable,
                  writer,
-                 cfg):
+                 cfg,
+                 add_PDE_LOSS=False):
         self.model = model
         self.train_dataloader = train_dataloader
         self.val_dataloader = val_dataloader
@@ -45,7 +46,7 @@ def __init__(self,
         self.val_variable = val_variable
         self.writer = writer
         self.cfg = cfg
-        self.loss = LpLoss(d=2, reduce_dims=[0, 1])
+        self.loss = LpLoss(d=2, add_PDE_LOSS=add_PDE_LOSS, reduce_dims=[0, 1])
         self.push_forward_steps = push_forward_steps
         self.future_window = future_window
 
diff --git a/src/train.py b/src/train.py
index a06cacf..0cb6d04 100644
--- a/src/train.py
+++ b/src/train.py
@@ -9,6 +9,8 @@
 import torchvision.transforms.functional as TF
 import matplotlib.pyplot as plt
 import numpy as np
+from neuralop.models import UNO
+from models.PINO_util.fno import FNO
 from pathlib import Path
 import os
 import time
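
For reference, a minimal usage sketch of the vendored operator in `src/models/PINO_util/fno.py`. This is not taken from the repository's tests; it assumes the pinned `neuralop` version is installed and that `src/` is on `PYTHONPATH` (as it is when running `python src/train.py`), and the shapes are illustrative only.

```python
import torch
from models.PINO_util.fno import FNO, TFNO  # import path assumes src/ is on PYTHONPATH

# 2D operator: the spatial dimensionality is inferred from len(n_modes).
model = FNO(n_modes=(16, 16), hidden_channels=32, in_channels=3, out_channels=1)
x = torch.randn(4, 3, 64, 64)   # (batch, in_channels, H, W)
y = model(x)                    # (batch, out_channels, H, W)

# TFNO is the same class with factorization='Tucker' preset via partialclass.
tucker = TFNO(n_modes=(16, 16), hidden_channels=32, in_channels=3, out_channels=1)

# The number of retained Fourier modes can be grown dynamically during training:
model.incremental_n_modes = (8, 8)

# With render_default_scale=True (together with a non-trivial output_scaling_factor
# and domain_padding), forward() returns a pair: the rescaled prediction and the
# prediction rendered at the default scale.
```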
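The `partialclass` helper at the bottom of `fno.py` builds a properly named subclass with new constructor defaults; that is how `TFNO*` and `SFNO` are derived from `FNO`. A self-contained sketch of the mechanism, using a hypothetical `Toy` class (not part of the repository) so it runs without any dependencies; the helper body mirrors the definition in the diff above:

```python
from functools import partialmethod

class Toy:
    """Hypothetical stand-in for FNO, used only to illustrate the mechanism."""
    def __init__(self, channels, factorization=None):
        self.channels = channels
        self.factorization = factorization

    def forward(self, x):
        return x

def partialclass(new_name, cls, *args, **kwargs):
    # Same recipe as fno.py: a dynamically created subclass with preset defaults.
    __init__ = partialmethod(cls.__init__, *args, **kwargs)
    return type(new_name, (cls,), {'__init__': __init__,
                                   '__doc__': cls.__doc__,
                                   'forward': cls.forward})

TuckerToy = partialclass('TuckerToy', Toy, factorization='Tucker')
layer = TuckerToy(channels=32)
assert isinstance(layer, Toy)                # still a subclass, unlike functools.partial
assert layer.factorization == 'Tucker'       # default overridden at construction
assert TuckerToy.__name__ == 'TuckerToy'     # gets a real class name
```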
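The new `temp_stokes_loss2D` in `src/op_lib/losses.py` builds integer wavenumber grids and uses `1j * k * T_h` as a spectral derivative on a 2π-periodic domain. The snippet below is a small stand-alone sanity check of that identity only (pure PyTorch, not the repository's API; the grid size and test function are arbitrary choices), not a claim about the full loss:

```python
import math
import torch

N = 64
coords = torch.arange(N) * (2 * math.pi / N)            # uniform grid on [0, 2*pi)
X, Y = torch.meshgrid(coords, coords, indexing='ij')
T = torch.sin(X) * torch.cos(Y)

# Same wavenumber layout as temp_stokes_loss2D: [0, ..., N/2-1, -N/2, ..., -1]
k = torch.cat((torch.arange(0, N // 2), torch.arange(-N // 2, 0))).float()
k_x = k.reshape(N, 1).repeat(1, N)                      # varies along dim 0 (the x axis)

T_h = torch.fft.fft2(T)
Tx = torch.fft.ifft2(1j * k_x * T_h).real               # d/dx via multiplication by i*k

assert torch.allclose(Tx, torch.cos(X) * torch.cos(Y), atol=1e-4)
```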