diff --git a/Readme.md b/Readme.md new file mode 100644 index 0000000..29b4089 --- /dev/null +++ b/Readme.md @@ -0,0 +1,57 @@ +## Networks
+
+This folder contains models for training and associated code. The models that are currently supported can be queried by calling `makani.models.model_registry.list_models()`.
+
+### Directory structure
+This folder is organized as follows:
+
+```
+makani
+├── ...
+├── models # code related to ML models
+│ ├── common # folder containing common features used in the networks
+│ │ ├── activations.py # complex activation functions
+│ │ ├── contractions.py # einsum wrappers for complex contractions
+│ │ ├── factorizations.py # tensor factorizations
+│ │ ├── layers.py # common layers such as MLPs and wrappers for FFTs
+│ │ └── spectral_convolution.py # spectral convolution layers for (S)FNO architectures
+│ ├── networks # contains the actual architectures
+│ │ ├── afnonet_v2.py # optimized AFNO
+│ │ ├── afnonet.py # AFNO implementation
+│ │ ├── debug.py # dummy network for debugging purposes
+│ │ ├── sfnonet.py # implementation of (S)FNO
+│ │ └── vit.py # implementation of a ViT
+│ ├── helpers.py # helper functions
+│ ├── model_package.py # model package implementation
+│ ├── model_registry.py # model registry with get_model routine that takes care of wrapping the model
+│ ├── preprocessor.py # implementation of preprocessor for dealing with unpredicted channels
+│ ├── stepper.py # implements multistep and singlestep wrappers
+│ └── Readme.md # this file
+...
+
+```
+
+### Model registry
+
+The model registry is a central place for organizing models in makani. By default, it contains the architectures in the `networks` directory, which makani also exposes as entrypoints. Models can be instantiated via
+
+```python
+from makani.models import model_registry
+
+model = model_registry.get_model(params)
+```
+
+where `params` is the parameter object used to instantiate the model. Custom models can be registered in the registry using the `register_model` routine. Models are required to accept keyword arguments, which are automatically parsed from the `params` datastructure and passed to the model.
+
+In addition, models can be registered automatically through the `nettype` field in the configuration YAML file. To do so, the user can specify
+
+```yaml
+nettype: "path/to/model_file.py:ModelName"
+```
+
+using the path to the model file and the class name `ModelName`.
+
+### Model packages
+
+Model packages are used for seamless inference outside of this repository. They define a flexible interface which takes care of normalization, unpredicted channels, etc. Model packages integrate seamlessly with [earth2mip](https://github.com/NVIDIA/earth2mip).
+
diff --git a/__init__.py b/__init__.py new file mode 100644 index 0000000..4da6442 --- /dev/null +++ b/__init__.py @@ -0,0 +1,19 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
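# A minimal sketch of the registry workflow described in the Readme above; the class
# `MyCustomModel` and the name "my_custom_model" are hypothetical placeholders used
# only for illustration and are not part of makani.
import torch.nn as nn
from makani.models import model_registry

class MyCustomModel(nn.Module):
    """Toy architecture accepting keyword arguments, as the registry requires."""

    def __init__(self, inp_chans=2, out_chans=2, **kwargs):
        super().__init__()
        self.conv = nn.Conv2d(inp_chans, out_chans, 1)

    def forward(self, x):
        return self.conv(x)

# built-in architectures exposed through the "makani.models" entrypoints
print(model_registry.list_models())

# register a custom model by class; a file path string such as
# "path/to/model_file.py:MyCustomModel" works as well, mirroring the `nettype` syntax
model_registry.register_model(MyCustomModel, name="my_custom_model")

# with params.nettype set to "my_custom_model", model_registry.get_model(params)
# constructs the model and wraps it for single- or multi-step training.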
+ +from .preprocessor import Preprocessor2D +from .stepper import SingleStepWrapper, MultiStepWrapper + +import makani.models.model_registry \ No newline at end of file diff --git a/__pycache__/__init__.cpython-310.pyc b/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000..3645545 Binary files /dev/null and b/__pycache__/__init__.cpython-310.pyc differ diff --git a/__pycache__/helpers.cpython-310.pyc b/__pycache__/helpers.cpython-310.pyc new file mode 100644 index 0000000..e019867 Binary files /dev/null and b/__pycache__/helpers.cpython-310.pyc differ diff --git a/__pycache__/model_package.cpython-310.pyc b/__pycache__/model_package.cpython-310.pyc new file mode 100644 index 0000000..10a5af4 Binary files /dev/null and b/__pycache__/model_package.cpython-310.pyc differ diff --git a/__pycache__/model_registry.cpython-310.pyc b/__pycache__/model_registry.cpython-310.pyc new file mode 100644 index 0000000..985b90b Binary files /dev/null and b/__pycache__/model_registry.cpython-310.pyc differ diff --git a/__pycache__/preprocessor.cpython-310.pyc b/__pycache__/preprocessor.cpython-310.pyc new file mode 100644 index 0000000..2a54fd0 Binary files /dev/null and b/__pycache__/preprocessor.cpython-310.pyc differ diff --git a/__pycache__/stepper.cpython-310.pyc b/__pycache__/stepper.cpython-310.pyc new file mode 100644 index 0000000..fa22307 Binary files /dev/null and b/__pycache__/stepper.cpython-310.pyc differ diff --git a/common/__init__.py b/common/__init__.py new file mode 100644 index 0000000..0e8a082 --- /dev/null +++ b/common/__init__.py @@ -0,0 +1,18 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from .activations import ComplexReLU, ComplexActivation +from .layers import DropPath, PatchEmbed, EncoderDecoder, MLP, RealFFT2, InverseRealFFT2 +from .spectral_convolution import SpectralConv, FactorizedSpectralConv, SpectralAttention \ No newline at end of file diff --git a/common/__pycache__/__init__.cpython-310.pyc b/common/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000..9a94c2a Binary files /dev/null and b/common/__pycache__/__init__.cpython-310.pyc differ diff --git a/common/__pycache__/activations.cpython-310.pyc b/common/__pycache__/activations.cpython-310.pyc new file mode 100644 index 0000000..9c4d002 Binary files /dev/null and b/common/__pycache__/activations.cpython-310.pyc differ diff --git a/common/__pycache__/contractions.cpython-310.pyc b/common/__pycache__/contractions.cpython-310.pyc new file mode 100644 index 0000000..8756ddf Binary files /dev/null and b/common/__pycache__/contractions.cpython-310.pyc differ diff --git a/common/__pycache__/factorizations.cpython-310.pyc b/common/__pycache__/factorizations.cpython-310.pyc new file mode 100644 index 0000000..70236b3 Binary files /dev/null and b/common/__pycache__/factorizations.cpython-310.pyc differ diff --git a/common/__pycache__/layers.cpython-310.pyc b/common/__pycache__/layers.cpython-310.pyc new file mode 100644 index 0000000..a54e618 Binary files /dev/null and b/common/__pycache__/layers.cpython-310.pyc differ diff --git a/common/__pycache__/spectral_convolution.cpython-310.pyc b/common/__pycache__/spectral_convolution.cpython-310.pyc new file mode 100644 index 0000000..fb13050 Binary files /dev/null and b/common/__pycache__/spectral_convolution.cpython-310.pyc differ diff --git a/common/activations.py b/common/activations.py new file mode 100644 index 0000000..dc4052e --- /dev/null +++ b/common/activations.py @@ -0,0 +1,100 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
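# A minimal usage sketch of one of the building blocks re-exported above (the MLP from
# layers.py); channel counts and grid size are illustrative only.
import torch
from makani.models.common import MLP

mlp = MLP(in_features=16, hidden_features=32, out_features=16)  # default "nchw" format uses 1x1 convolutions
x = torch.randn(2, 16, 8, 8)                                    # (batch, channels, lat, lon)
y = mlp(x)                                                      # -> (2, 16, 8, 8)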
+ +import torch +from torch import nn + + +class ComplexReLU(nn.Module): + """ + Complex-valued variants of the ReLU activation function + """ + + def __init__(self, negative_slope=0.0, mode="real", bias_shape=None, scale=1.0): + super().__init__() + + # store parameters + self.mode = mode + if self.mode in ["modulus", "halfplane"]: + if bias_shape is not None: + self.bias = nn.Parameter(scale * torch.ones(bias_shape, dtype=torch.float32)) + else: + self.bias = nn.Parameter(scale * torch.ones((1), dtype=torch.float32)) + else: + self.bias = 0 + + self.negative_slope = negative_slope + self.act = nn.LeakyReLU(negative_slope=negative_slope) + + def forward(self, z: torch.Tensor) -> torch.Tensor: + if self.mode == "cartesian": + zr = torch.view_as_real(z) + za = self.act(zr) + out = torch.view_as_complex(za) + + elif self.mode == "modulus": + zabs = torch.sqrt(torch.square(z.real) + torch.square(z.imag)) + out = torch.where(zabs + self.bias > 0, (zabs + self.bias) * z / zabs, 0.0) + # out = self.act(zabs - self.bias) * torch.exp(1.j * z.angle()) + + elif self.mode == "halfplane": + # bias is an angle parameter in this case + modified_angle = torch.angle(z) - self.bias + condition = torch.logical_and((0.0 <= modified_angle), (modified_angle < torch.pi / 2.0)) + out = torch.where(condition, z, self.negative_slope * z) + + elif self.mode == "real": + zr = torch.view_as_real(z) + outr = zr.clone() + outr[..., 0] = self.act(zr[..., 0]) + out = torch.view_as_complex(outr) + + else: + raise NotImplementedError + + return out + + +class ComplexActivation(nn.Module): + def __init__(self, activation, mode="cartesian", bias_shape=None): + super().__init__() + + # store parameters + self.mode = mode + if self.mode == "modulus": + if bias_shape is not None: + self.bias = nn.Parameter(torch.zeros(bias_shape, dtype=torch.float32)) + else: + self.bias = nn.Parameter(torch.zeros((1), dtype=torch.float32)) + else: + bias = torch.zeros((1), dtype=torch.float32) + self.register_buffer("bias", bias) + + # real valued activation + self.act = activation + + def forward(self, z: torch.Tensor) -> torch.Tensor: + if self.mode == "cartesian": + zr = torch.view_as_real(z) + za = self.act(zr) + out = torch.view_as_complex(za) + elif self.mode == "modulus": + zabs = torch.sqrt(torch.square(z.real) + torch.square(z.imag)) + out = self.act(zabs + self.bias) * torch.exp(1.0j * z.angle()) + else: + # identity + out = z + + return out diff --git a/common/contractions.py b/common/contractions.py new file mode 100644 index 0000000..edbf6cf --- /dev/null +++ b/common/contractions.py @@ -0,0 +1,178 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
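# A minimal sketch of the complex activations defined above; shapes and the choice of
# GELU are illustrative only.
import torch
from makani.models.common import ComplexReLU, ComplexActivation

z = torch.randn(2, 4, 6, 6, dtype=torch.complex64)

# "cartesian": the (leaky) ReLU is applied to real and imaginary parts independently
out = ComplexReLU(mode="cartesian")(z)                      # same shape and dtype as z

# "modulus": shifts the magnitude by a learnable bias, zeroes it where non-positive,
# and keeps the phase
out = ComplexReLU(mode="modulus", bias_shape=(4, 1, 1))(z)

# ComplexActivation wraps an arbitrary real activation, here applied to the modulus
out = ComplexActivation(torch.nn.GELU(), mode="modulus")(z)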
+ +import torch + + +@torch.jit.script +def _contract_rank(xc: torch.Tensor, wc: torch.Tensor, ac: torch.Tensor, bc: torch.Tensor) -> torch.Tensor: + # return torch.einsum("bixy,ior,xr,yr->boxy", x, w, a, b) + # xc = torch.view_as_complex(x) + # wc = w #torch.view_as_complex(w) + # ac = torch.view_as_complex(a) + # bc = torch.view_as_complex(b) + resc = torch.einsum("bixy,ior,xr,yr->boxy", xc, wc, ac, bc) + # res = torch.view_as_real(resc) + return resc + + +# # Helper routines for FNOs +@torch.jit.script +def compl_mul1d_fwd(ac: torch.Tensor, bc: torch.Tensor) -> torch.Tensor: + # ac = torch.view_as_complex(a) + # bc = torch.view_as_complex(b) + resc = torch.einsum("bix,io->box", ac, bc) + # res = torch.view_as_real(resc) + return resc + + +@torch.jit.script +def compl_muladd1d_fwd(ac: torch.Tensor, bc: torch.Tensor, cc: torch.Tensor) -> torch.Tensor: + tmpcc = compl_mul1d_fwd(ac, bc) + # cc = torch.view_as_complex(c) + return tmpcc + cc + + +@torch.jit.script +def compl_mul2d_fwd(ac: torch.Tensor, bc: torch.Tensor) -> torch.Tensor: + # ac = torch.view_as_complex(a) + # bc = torch.view_as_complex(b) + resc = torch.einsum("bixy,io->boxy", ac, bc) + # res = torch.view_as_real(resc) + return resc + + +@torch.jit.script +def compl_muladd2d_fwd(ac: torch.Tensor, bc: torch.Tensor, cc: torch.Tensor) -> torch.Tensor: + tmpcc = compl_mul2d_fwd(ac, bc) + # cc = torch.view_as_complex(c) + return tmpcc + cc + + +@torch.jit.script +def _contract_localconv_fwd(ac: torch.Tensor, bc: torch.Tensor) -> torch.Tensor: + # ac = torch.view_as_complex(a) + # bc = torch.view_as_complex(b) + resc = torch.einsum("bixy,iox->boxy", ac, bc) + # res = torch.view_as_real(resc) + return resc + + +@torch.jit.script +def _contract_blockconv_fwd(ac: torch.Tensor, bc: torch.Tensor) -> torch.Tensor: + # ac = torch.view_as_complex(a) + # bc = torch.view_as_complex(b) + resc = torch.einsum("bim,imn->bin", ac, bc) + # res = torch.view_as_real(resc) + return resc + + +@torch.jit.script +def _contractadd_blockconv_fwd(ac: torch.Tensor, bc: torch.Tensor, cc: torch.Tensor) -> torch.Tensor: + tmpcc = _contract_blockconv_fwd(ac, bc) + # cc = torch.view_as_complex(c) + return tmpcc + cc + + +# for the experimental layer +@torch.jit.script +def compl_exp_mul2d_fwd(ac: torch.Tensor, bc: torch.Tensor) -> torch.Tensor: + # ac = torch.view_as_complex(a) + # bc = torch.view_as_complex(b) + resc = torch.einsum("bixy,xio->boxy", ac, bc) + # res = torch.view_as_real(resc) + return resc + + +@torch.jit.script +def compl_exp_muladd2d_fwd(ac: torch.Tensor, bc: torch.Tensor, cc: torch.Tensor) -> torch.Tensor: + tmpcc = compl_exp_mul2d_fwd(ac, bc) + # cc = torch.view_as_complex(c) + return tmpcc + cc + + +@torch.jit.script +def real_mul2d_fwd(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor: + res = torch.einsum("bixy,io->boxy", a, b) + return res + + +@torch.jit.script +def real_muladd2d_fwd(a: torch.Tensor, b: torch.Tensor, c: torch.Tensor) -> torch.Tensor: + res = real_mul2d_fwd(a, b) + c + return res + + +# new contractions set to replace older ones. 
We use complex + + +@torch.jit.script +def _contract_diagonal(ac: torch.Tensor, bc: torch.Tensor) -> torch.Tensor: + # ac = torch.view_as_complex(a) + # bc = torch.view_as_complex(b) + resc = torch.einsum("bixy,ioxy->boxy", ac, bc) + # res = torch.view_as_real(resc) + return resc + + +@torch.jit.script +def _contract_dhconv(ac: torch.Tensor, bc: torch.Tensor) -> torch.Tensor: + # ac = torch.view_as_complex(a) + # bc = torch.view_as_complex(b) + resc = torch.einsum("bixy,iox->boxy", ac, bc) + # res = torch.view_as_real(resc) + return resc + + +@torch.jit.script +def _contract_sep_diagonal(ac: torch.Tensor, bc: torch.Tensor) -> torch.Tensor: + # ac = torch.view_as_complex(a) + # bc = torch.view_as_complex(b) + resc = torch.einsum("bixy,ixy->boxy", ac, bc) + # res = torch.view_as_real(resc) + return resc + + +@torch.jit.script +def _contract_sep_dhconv(ac: torch.Tensor, bc: torch.Tensor) -> torch.Tensor: + # ac = torch.view_as_complex(a) + # bc = torch.view_as_complex(b) + resc = torch.einsum("bixy,ix->boxy", ac, bc) + # res = torch.view_as_real(resc) + return resc + + +@torch.jit.script +def _contract_diagonal_real(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor: + res = torch.einsum("bixys,ioxy->boxys", a, b).contiguous() + return res + + +@torch.jit.script +def _contract_dhconv_real(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor: + res = torch.einsum("bixys,iox->boxys", a, b).contiguous() + return res + + +@torch.jit.script +def _contract_sep_diagonal_real(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor: + res = torch.einsum("bixys,ixy->boxys", a, b).contiguous() + return res + + +@torch.jit.script +def _contract_sep_dhconv_real(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor: + res = torch.einsum("bixys,ix->boxys", a, b).contiguous() + return res diff --git a/common/factorizations.py b/common/factorizations.py new file mode 100644 index 0000000..0fd2759 --- /dev/null +++ b/common/factorizations.py @@ -0,0 +1,247 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch + +from functools import partial + +import tensorly as tl + +tl.set_backend("pytorch") + +from makani.models.common.contractions import _contract_diagonal, _contract_dhconv, _contract_sep_diagonal, _contract_sep_dhconv +from makani.models.common.contractions import _contract_diagonal_real, _contract_dhconv_real, _contract_sep_diagonal_real, _contract_sep_dhconv_real + + +from tltorch.factorized_tensors.core import FactorizedTensor + +einsum_symbols = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ" + + +def _contract_dense(x, weight, separable=False, operator_type="diagonal"): + order = tl.ndim(x) + # batch-size, in_channels, x, y... + x_syms = list(einsum_symbols[:order]) + + # in_channels, out_channels, x, y... + weight_syms = list(x_syms[1:]) # no batch-size + + # batch-size, out_channels, x, y... 
+ if separable: + out_syms = [x_syms[0]] + list(weight_syms) + else: + weight_syms.insert(1, einsum_symbols[order]) # outputs + out_syms = list(weight_syms) + out_syms[0] = x_syms[0] + + if operator_type == "diagonal": + pass + elif operator_type == "block-diagonal": + weight_syms.insert(-1, einsum_symbols[order + 1]) + out_syms[-1] = weight_syms[-2] + elif operator_type == "dhconv": + weight_syms.pop() + else: + raise ValueError(f"Unkonw operator type {operator_type}") + + eq = "".join(x_syms) + "," + "".join(weight_syms) + "->" + "".join(out_syms) + + if not torch.is_tensor(weight): + weight = weight.to_tensor() + + res = tl.einsum(eq, x, weight).contiguous() + + return res + + +def _contract_cp(x, cp_weight, separable=False, operator_type="diagonal"): + order = tl.ndim(x) + + x_syms = str(einsum_symbols[:order]) + rank_sym = einsum_symbols[order] + out_sym = einsum_symbols[order + 1] + out_syms = list(x_syms) + + if separable: + factor_syms = [einsum_symbols[1] + rank_sym] # in only + else: + out_syms[1] = out_sym + factor_syms = [einsum_symbols[1] + rank_sym, out_sym + rank_sym] # in, out + + factor_syms += [xs + rank_sym for xs in x_syms[2:]] # x, y, ... + + if operator_type == "diagonal": + pass + elif operator_type == "block-diagonal": + out_syms[-1] = einsum_symbols[order + 2] + factor_syms += [out_syms[-1] + rank_sym] + elif operator_type == "dhconv": + factor_syms.pop() + else: + raise ValueError(f"Unkonw operator type {operator_type}") + + eq = x_syms + "," + rank_sym + "," + ",".join(factor_syms) + "->" + "".join(out_syms) + + res = tl.einsum(eq, x, cp_weight.weights, *cp_weight.factors).contiguous() + + return res + + +def _contract_tucker(x, tucker_weight, separable=False, operator_type="diagonal"): + order = tl.ndim(x) + + x_syms = str(einsum_symbols[:order]) + out_sym = einsum_symbols[order] + out_syms = list(x_syms) + if separable: + core_syms = einsum_symbols[order + 1 : 2 * order] + factor_syms = [xs + rs for (xs, rs) in zip(x_syms[1:], core_syms)] # x, y, ... + + else: + core_syms = einsum_symbols[order + 1 : 2 * order + 1] + out_syms[1] = out_sym + factor_syms = [einsum_symbols[1] + core_syms[0], out_sym + core_syms[1]] # out, in + factor_syms += [xs + rs for (xs, rs) in zip(x_syms[2:], core_syms[2:])] # x, y, ... 
+ + if operator_type == "diagonal": + pass + elif operator_type == "block-diagonal": + raise NotImplementedError(f"Operator type {operator_type} not implemented for Tucker") + else: + raise ValueError(f"Unkonw operator type {operator_type}") + + eq = x_syms + "," + core_syms + "," + ",".join(factor_syms) + "->" + "".join(out_syms) + + res = tl.einsum(eq, x, tucker_weight.core, *tucker_weight.factors).contiguous() + + return res + + +def _contract_tt(x, tt_weight, separable=False, operator_type="diagonal"): + order = tl.ndim(x) + + x_syms = list(einsum_symbols[:order]) + weight_syms = list(x_syms[1:]) # no batch-size + + if not separable: + weight_syms.insert(1, einsum_symbols[order]) # outputs + out_syms = list(weight_syms) + out_syms[0] = x_syms[0] + else: + out_syms = list(x_syms) + + if operator_type == "diagonal": + pass + elif operator_type == "block-diagonal": + weight_syms.insert(-1, einsum_symbols[order + 1]) + out_syms[-1] = weight_syms[-2] + elif operator_type == "dhconv": + weight_syms.pop() + else: + raise ValueError(f"Unkonw operator type {operator_type}") + + rank_syms = list(einsum_symbols[order + 2 :]) + tt_syms = [] + for i, s in enumerate(weight_syms): + tt_syms.append([rank_syms[i], s, rank_syms[i + 1]]) + eq = "".join(x_syms) + "," + ",".join("".join(f) for f in tt_syms) + "->" + "".join(out_syms) + + res = tl.einsum(eq, x, *tt_weight.factors).contiguous() + + return res + + +# jitted PyTorch contractions: +def _contract_dense_pytorch(x, weight, separable=False, operator_type="diagonal", complex=True): + # make sure input is contig + x = x.contiguous() + + if separable: + if operator_type == "diagonal": + if complex: + x = _contract_sep_diagonal(x, weight) + else: + x = _contract_sep_diagonal_real(x, weight) + elif operator_type == "dhconv": + if complex: + x = _contract_sep_dhconv(x, weight) + else: + x = _contract_sep_dhconv_real(x, weight) + else: + raise ValueError(f"Unkonw operator type {operator_type}") + else: + if operator_type == "diagonal": + if complex: + x = _contract_diagonal(x, weight) + else: + x = _contract_diagonal_real(x, weight) + elif operator_type == "dhconv": + if complex: + x = _contract_dhconv(x, weight) + else: + x = _contract_dhconv_real(x, weight) + else: + raise ValueError(f"Unkonw operator type {operator_type}") + + # make contiguous + x = x.contiguous() + return x + + +def _contract_dense_reconstruct(x, weight, separable=False, operator_type="diagonal", complex=True): + """Contraction for dense tensors, factorized or not""" + if not torch.is_tensor(weight): + weight = weight.to_tensor() + # weight = torch.view_as_real(weight) + + return _contract_dense_pytorch(x, weight, separable=separable, operator_type=operator_type, complex=complex) + + +def get_contract_fun(weight, implementation="reconstructed", separable=False, operator_type="diagonal", complex=True): + """Generic ND implementation of Fourier Spectral Conv contraction + + Parameters + ---------- + weight : tensorly-torch's FactorizedTensor + implementation : {'reconstructed', 'factorized'}, default is 'reconstructed' + whether to reconstruct the weight and do a forward pass (reconstructed) + or contract directly the factors of the factorized weight with the input (factorized) + + Returns + ------- + function : (x, weight) -> x * weight in Fourier space + """ + if implementation == "reconstructed": + handle = partial(_contract_dense_reconstruct, separable=separable, complex=complex, operator_type=operator_type) + return handle + elif implementation == "factorized": + if 
torch.is_tensor(weight): + handle = partial(_contract_dense_pytorch, separable=separable, complex=complex, operator_type=operator_type) + return handle + elif isinstance(weight, FactorizedTensor): + if weight.name.lower() == "complexdense" or weight.name.lower() == "dense": + return _contract_dense + elif weight.name.lower() == "complextucker": + return _contract_tucker + elif weight.name.lower() == "complextt": + return _contract_tt + elif weight.name.lower() == "complexcp": + return _contract_cp + else: + raise ValueError(f"Got unexpected factorized weight type {weight.name}") + else: + raise ValueError(f"Got unexpected weight type of class {weight.__class__.__name__}") + else: + raise ValueError(f'Got {implementation=}, expected "reconstructed" or "factorized"') diff --git a/common/layers.py b/common/layers.py new file mode 100644 index 0000000..1c18bd3 --- /dev/null +++ b/common/layers.py @@ -0,0 +1,288 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from functools import partial +from collections import OrderedDict +from copy import Error, deepcopy +from re import S +from numpy.lib.arraypad import pad +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.fft +from torch.nn.modules.container import Sequential +from torch.utils.checkpoint import checkpoint, checkpoint_sequential +from torch.cuda import amp +from typing import Optional +import math + +from makani.models.common.contractions import compl_muladd2d_fwd, compl_mul2d_fwd +from makani.models.common.contractions import _contract_diagonal + + +@torch.jit.script +def drop_path(x: torch.Tensor, drop_prob: float = 0.0, training: bool = False) -> torch.Tensor: + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). + This is the same as the DropConnect impl I created for EfficientNet, etc networks, however, + the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper... + See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for + changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use + 'survival rate' as the argument. 
+ """ + if drop_prob == 0.0 or not training: + return x + keep_prob = 1.0 - drop_prob + shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2d ConvNets + random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device) + random_tensor.floor_() # binarize + output = x.div(keep_prob) * random_tensor + return output + + +class DropPath(nn.Module): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).""" + + def __init__(self, drop_prob=None): + super(DropPath, self).__init__() + self.drop_prob = drop_prob + + def forward(self, x): + return drop_path(x, self.drop_prob, self.training) + + +class PatchEmbed(nn.Module): + def __init__(self, img_size=(224, 224), patch_size=(16, 16), in_chans=3, embed_dim=768): + super(PatchEmbed, self).__init__() + self.red_img_size = ((img_size[0] // patch_size[0]), (img_size[1] // patch_size[1])) + num_patches = self.red_img_size[0] * self.red_img_size[1] + self.img_size = img_size + self.patch_size = patch_size + self.num_patches = num_patches + self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=True) + self.proj.weight.is_shared_mp = ["spatial"] + self.proj.bias.is_shared_mp = ["spatial"] + + def forward(self, x): + # gather input + B, C, H, W = x.shape + assert H == self.img_size[0] and W == self.img_size[1], f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." + # new: B, C, H*W + x = self.proj(x).flatten(2) + return x + + +class EncoderDecoder(nn.Module): + def __init__(self, num_layers, input_dim, output_dim, hidden_dim, act_layer, gain=1.0, input_format="nchw"): + super(EncoderDecoder, self).__init__() + + encoder_modules = [] + current_dim = input_dim + for i in range(num_layers): + # fully connected layer + if input_format == "nchw": + encoder_modules.append(nn.Conv2d(current_dim, hidden_dim, 1, bias=True)) + elif input_format == "traditional": + encoder_modules.append(nn.Linear(current_dim, hidden_dim, bias=True)) + else: + raise NotImplementedError(f"Error, input format {input_format} not supported.") + + # weight sharing + encoder_modules[-1].weight.is_shared_mp = ["spatial"] + + # proper initializaiton + scale = math.sqrt(2.0 / current_dim) + nn.init.normal_(encoder_modules[-1].weight, mean=0.0, std=scale) + if encoder_modules[-1].bias is not None: + encoder_modules[-1].bias.is_shared_mp = ["spatial"] + nn.init.constant_(encoder_modules[-1].bias, 0.0) + + encoder_modules.append(act_layer()) + current_dim = hidden_dim + + # final output layer + if input_format == "nchw": + encoder_modules.append(nn.Conv2d(current_dim, output_dim, 1, bias=False)) + elif input_format == "traditional": + encoder_modules.append(nn.Linear(current_dim, output_dim, bias=False)) + + # weight sharing + encoder_modules[-1].weight.is_shared_mp = ["spatial"] + + # proper initializaiton + scale = math.sqrt(gain / current_dim) + nn.init.normal_(encoder_modules[-1].weight, mean=0.0, std=scale) + if encoder_modules[-1].bias is not None: + encoder_modules[-1].bias.is_shared_mp = ["spatial"] + nn.init.constant_(encoder_modules[-1].bias, 0.0) + + self.fwd = nn.Sequential(*encoder_modules) + + def forward(self, x): + return self.fwd(x) + + +class MLP(nn.Module): + def __init__( + self, + in_features, + hidden_features=None, + out_features=None, + act_layer=nn.GELU, + output_bias=True, + input_format="nchw", + drop_rate=0.0, + drop_type="iid", + checkpointing=0, + gain=1.0, + **kwargs, + ): + super(MLP, self).__init__() + 
self.checkpointing = checkpointing + out_features = out_features or in_features + hidden_features = hidden_features or in_features + + # First fully connected layer + if input_format == "nchw": + fc1 = nn.Conv2d(in_features, hidden_features, 1, bias=True) + fc1.weight.is_shared_mp = ["spatial"] + fc1.bias.is_shared_mp = ["spatial"] + elif input_format == "traditional": + fc1 = nn.Linear(in_features, hidden_features, bias=True) + else: + raise NotImplementedError(f"Error, input format {input_format} not supported.") + + # initialize the weights correctly + scale = math.sqrt(2.0 / in_features) + nn.init.normal_(fc1.weight, mean=0.0, std=scale) + nn.init.constant_(fc1.bias, 0.0) + + # activation + act = act_layer() + + # sanity checks + if (input_format == "traditional") and (drop_type == "features"): + raise NotImplementedError(f"Error, traditional input format and feature dropout cannot be selected simultaneously") + + # output layer + if input_format == "nchw": + fc2 = nn.Conv2d(hidden_features, out_features, 1, bias=output_bias) + fc2.weight.is_shared_mp = ["spatial"] + if output_bias: + fc2.bias.is_shared_mp = ["spatial"] + elif input_format == "traditional": + fc2 = nn.Linear(hidden_features, out_features, bias=output_bias) + else: + raise NotImplementedError(f"Error, input format {input_format} not supported.") + + # gain factor for the output determines the scaling of the output init + scale = math.sqrt(gain / hidden_features) + nn.init.normal_(fc2.weight, mean=0.0, std=scale) + if fc2.bias is not None: + nn.init.constant_(fc2.bias, 0.0) + + if drop_rate > 0.0: + if drop_type == "iid": + drop = nn.Dropout(drop_rate) + elif drop_type == "features": + drop = nn.Dropout2d(drop_rate) + else: + raise NotImplementedError(f"Error, drop_type {drop_type} not supported") + else: + drop = nn.Identity() + + # create forward pass + self.fwd = nn.Sequential(fc1, act, drop, fc2, drop) + + @torch.jit.ignore + def checkpoint_forward(self, x): + return checkpoint(self.fwd, x, use_reentrant=False) + + def forward(self, x): + if self.checkpointing >= 2: + return self.checkpoint_forward(x) + else: + return self.fwd(x) + + +class RealFFT2(nn.Module): + """ + Helper routine to wrap FFT similarly to the SHT + """ + + def __init__(self, nlat, nlon, lmax=None, mmax=None): + super(RealFFT2, self).__init__() + + # use local FFT here + self.fft_handle = torch.fft.rfft2 + + self.nlat = nlat + self.nlon = nlon + self.lmax = min(lmax or self.nlat, self.nlat) + self.mmax = min(mmax or self.nlon // 2 + 1, self.nlon // 2 + 1) + + self.truncate = True + if (self.lmax == self.nlat) and (self.mmax == (self.nlon // 2 + 1)): + self.truncate = False + + self.lmax_high = math.ceil(self.lmax / 2) + self.lmax_low = math.floor(self.lmax / 2) + + def forward(self, x): + y = self.fft_handle(x, s=(self.nlat, self.nlon), dim=(-2, -1), norm="ortho") + + if self.truncate: + y = torch.cat((y[..., : self.lmax_high, : self.mmax], y[..., -self.lmax_low :, : self.mmax]), dim=-2) + + return y + + +class InverseRealFFT2(nn.Module): + """ + Helper routine to wrap FFT similarly to the SHT + """ + + def __init__(self, nlat, nlon, lmax=None, mmax=None): + super(InverseRealFFT2, self).__init__() + + # use local FFT here + self.ifft_handle = torch.fft.irfft2 + + self.nlat = nlat + self.nlon = nlon + self.lmax = min(lmax or self.nlat, self.nlat) + self.mmax = min(mmax or self.nlon // 2 + 1, self.nlon // 2 + 1) + + self.truncate = True + if (self.lmax == self.nlat) and (self.mmax == (self.nlon // 2 + 1)): + self.truncate = False + + self.lmax_high = 
math.ceil(self.lmax / 2) + self.lmax_low = math.floor(self.lmax / 2) + + def forward(self, x): + # truncation is implicit but better do it manually + xt = x[..., : self.mmax] + + if self.truncate: + # pad + xth = xt[..., : self.lmax_high, :] + xtl = xt[..., -self.lmax_low :, :] + xthp = F.pad(xth, (0, 0, 0, self.nlat - self.lmax)) + xt = torch.cat([xthp, xtl], dim=-2) + + out = torch.fft.irfft2(xt, s=(self.nlat, self.nlon), dim=(-2, -1), norm="ortho") + + return out diff --git a/common/spectral_convolution.py b/common/spectral_convolution.py new file mode 100644 index 0000000..57d768f --- /dev/null +++ b/common/spectral_convolution.py @@ -0,0 +1,405 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F +import numpy as np +import math + +from torch.cuda import amp + +# import FactorizedTensor from tensorly for tensorized operations +import tensorly as tl + +tl.set_backend("pytorch") +from tltorch.factorized_tensors.core import FactorizedTensor + +# import convenience functions for factorized tensors +from makani.utils import comm +from makani.models.common.activations import ComplexReLU +from makani.models.common.contractions import compl_muladd2d_fwd, compl_mul2d_fwd, _contract_rank +from makani.models.common.factorizations import get_contract_fun + +# for the experimental module +from makani.models.common.contractions import compl_exp_muladd2d_fwd, compl_exp_mul2d_fwd + +import torch_harmonics as th +import torch_harmonics.distributed as thd + + +class SpectralConv(nn.Module): + """ + Spectral Convolution implemented via SHT or FFT. Designed for convolutions on the two-sphere S2 + using the Spherical Harmonic Transforms in torch-harmonics, but supports convolutions on the periodic + domain via the RealFFT2 and InverseRealFFT2 wrappers. 
+ """ + + def __init__(self, forward_transform, inverse_transform, in_channels, out_channels, operator_type="diagonal", separable=False, bias=False, gain=1.0): + super(SpectralConv, self).__init__() + + self.forward_transform = forward_transform + self.inverse_transform = inverse_transform + + self.in_channels = in_channels + self.out_channels = out_channels + + self.modes_lat = self.inverse_transform.lmax + self.modes_lon = self.inverse_transform.mmax + + self.scale_residual = (self.forward_transform.nlat != self.inverse_transform.nlat) or (self.forward_transform.nlon != self.inverse_transform.nlon) + if hasattr(self.forward_transform, "grid"): + self.scale_residual = self.scale_residual or (self.forward_transform.grid != self.inverse_transform.grid) + + # remember factorization details + self.operator_type = operator_type + self.separable = separable + + assert self.inverse_transform.lmax == self.modes_lat + assert self.inverse_transform.mmax == self.modes_lon + + weight_shape = [in_channels] + + if not self.separable: + weight_shape += [out_channels] + + if isinstance(self.inverse_transform, thd.DistributedInverseRealSHT): + self.modes_lat_local = self.inverse_transform.l_shapes[comm.get_rank("h")] + self.modes_lon_local = self.inverse_transform.m_shapes[comm.get_rank("w")] + self.nlat_local = self.inverse_transform.lat_shapes[comm.get_rank("h")] + self.nlon_local = self.inverse_transform.lon_shapes[comm.get_rank("w")] + else: + self.modes_lat_local = self.modes_lat + self.modes_lon_local = self.modes_lon + self.nlat_local = self.inverse_transform.nlat + self.nlon_local = self.inverse_transform.nlon + + # unpadded weights + if self.operator_type == "diagonal": + weight_shape += [self.modes_lat_local, self.modes_lon_local] + elif self.operator_type == "dhconv": + weight_shape += [self.modes_lat_local] + else: + raise ValueError(f"Unsupported operator type f{self.operator_type}") + + # Compute scaling factor for correct initialization + scale = math.sqrt(gain / in_channels) * torch.ones(self.modes_lat_local, dtype=torch.complex64) + # seemingly the first weight is not really complex, so we need to account for that + scale[0] *= math.sqrt(2.0) + init = scale * torch.randn(*weight_shape, dtype=torch.complex64) + self.weight = nn.Parameter(init) + + if self.operator_type == "dhconv": + self.weight.is_shared_mp = ["matmul", "w"] + self.weight.sharded_dims_mp = [None for _ in weight_shape] + self.weight.sharded_dims_mp[-1] = "h" + else: + self.weight.is_shared_mp = ["matmul"] + self.weight.sharded_dims_mp = [None for _ in weight_shape] + self.weight.sharded_dims_mp[-1] = "w" + self.weight.sharded_dims_mp[-2] = "h" + + # get the contraction handle. 
This should return a pyTorch contraction + self._contract = get_contract_fun(self.weight, implementation="factorized", separable=separable, complex=True, operator_type=operator_type) + + if bias == "constant": + self.bias = nn.Parameter(torch.zeros(1, self.out_channels, 1, 1)) + elif bias == "position": + self.bias = nn.Parameter(torch.zeros(1, self.out_channels, self.nlat_local, self.nlon_local)) + self.bias.is_shared_mp = ["matmul"] + self.bias.sharded_dims_mp = [None, None, "h", "w"] + + def forward(self, x): + dtype = x.dtype + residual = x + x = x.float() + B, C, H, W = x.shape + + with amp.autocast(enabled=False): + x = self.forward_transform(x).contiguous() + if self.scale_residual: + residual = self.inverse_transform(x) + residual = residual.to(dtype) + + # approach with unpadded weights + xp = self._contract(x, self.weight, separable=self.separable, operator_type=self.operator_type) + x = xp.contiguous() + + with amp.autocast(enabled=False): + x = self.inverse_transform(x) + + if hasattr(self, "bias"): + x = x + self.bias + + x = x.to(dtype=dtype) + + return x, residual + + +class FactorizedSpectralConv(nn.Module): + """ + Factorized version of SpectralConv. Uses tensorly-torch to keep the weights factorized + """ + + def __init__( + self, + forward_transform, + inverse_transform, + in_channels, + out_channels, + operator_type="diagonal", + rank=0.2, + factorization=None, + separable=False, + decomposition_kwargs=dict(), + bias=False, + gain=1.0, + ): + super(FactorizedSpectralConv, self).__init__() + + self.forward_transform = forward_transform + self.inverse_transform = inverse_transform + + self.in_channels = in_channels + self.out_channels = out_channels + + self.modes_lat = self.inverse_transform.lmax + self.modes_lon = self.inverse_transform.mmax + + self.scale_residual = (self.forward_transform.nlat != self.inverse_transform.nlat) or (self.forward_transform.nlon != self.inverse_transform.nlon) + if hasattr(self.forward_transform, "grid"): + self.scale_residual = self.scale_residual or (self.forward_transform.grid != self.inverse_transform.grid) + + # Make sure we are using a Complex Factorized Tensor + if factorization is None: + factorization = "ComplexDense" # No factorization + complex_weight = factorization[:7].lower() == "complex" + + # remember factorization details + self.operator_type = operator_type + self.rank = rank + self.factorization = factorization + self.separable = separable + + assert self.inverse_transform.lmax == self.modes_lat + assert self.inverse_transform.mmax == self.modes_lon + + weight_shape = [in_channels] + + if not self.separable: + weight_shape += [out_channels] + + if isinstance(self.inverse_transform, thd.DistributedInverseRealSHT): + self.modes_lat_local = self.inverse_transform.l_shapes[comm.get_rank("h")] + self.modes_lon_local = self.inverse_transform.m_shapes[comm.get_rank("w")] + else: + self.modes_lat_local = self.modes_lat + self.modes_lon_local = self.modes_lon + + # unpadded weights + if self.operator_type == "diagonal": + weight_shape += [self.modes_lat_local, self.modes_lon_local] + elif self.operator_type == "dhconv": + weight_shape += [self.modes_lat_local] + elif self.operator_type == "rank": + weight_shape += [self.rank] + else: + raise ValueError(f"Unsupported operator type f{self.operator_type}") + + # form weight tensors + self.weight = FactorizedTensor.new(weight_shape, rank=self.rank, factorization=factorization, fixed_rank_modes=False, **decomposition_kwargs) + # initialization of weights + scale = math.sqrt(gain / 
float(weight_shape[0])) + self.weight.normal_(mean=0.0, std=scale) + + # get the contraction handle + if operator_type == "rank": + self._contract = _contract_rank + else: + self._contract = get_contract_fun(self.weight, implementation="reconstructed", separable=separable, complex=complex_weight, operator_type=operator_type) + + if bias == "constant": + self.bias = nn.Parameter(torch.zeros(1, self.out_channels, 1, 1)) + elif bias == "position": + self.bias = nn.Parameter(torch.zeros(1, self.out_channels, self.nlat_local, self.nlon_local)) + self.bias.is_shared_mp = ["matmul"] + self.bias.sharded_dims_mp = [None, None, "h", "w"] + + def forward(self, x): + dtype = x.dtype + residual = x + x = x.float() + + with amp.autocast(enabled=False): + x = self.forward_transform(x).contiguous() + if self.scale_residual: + residual = self.inverse_transform(x) + residual = residual.to(dtype) + + if self.operator_type == "rank": + xp = self._contract(x, self.weight, self.lat_weight, self.lon_weight) + else: + xp = self._contract(x, self.weight, separable=self.separable, operator_type=self.operator_type) + x = xp.contiguous() + + with amp.autocast(enabled=False): + x = self.inverse_transform(x) + + if hasattr(self, "bias"): + x = x + self.bias + + x = x.type(dtype) + + return x, residual + + +class SpectralAttention(nn.Module): + """ + Spherical non-linear FNO layer + """ + + def __init__( + self, + forward_transform, + inverse_transform, + in_channels, + out_channels, + operator_type="diagonal", + hidden_size_factor=2, + complex_activation="real", + bias=False, + spectral_layers=1, + drop_rate=0.0, + gain=1.0, + ): + super(SpectralAttention, self).__init__() + + self.in_channels = in_channels + self.out_channels = out_channels + self.operator_type = operator_type + self.spectral_layers = spectral_layers + + self.modes_lat = forward_transform.lmax + self.modes_lon = forward_transform.mmax + + # only storing the forward handle to be able to call it + self.forward_transform = forward_transform + self.inverse_transform = inverse_transform + + self.scale_residual = ( + (self.forward_transform.nlat != self.inverse_transform.nlat) + or (self.forward_transform.nlon != self.inverse_transform.nlon) + or (self.forward_transform.grid != self.inverse_transform.grid) + ) + + assert inverse_transform.lmax == self.modes_lat + assert inverse_transform.mmax == self.modes_lon + + hidden_size = int(hidden_size_factor * self.in_channels) + + if operator_type == "diagonal": + self.mul_add_handle = compl_muladd2d_fwd + self.mul_handle = compl_mul2d_fwd + + # weights + scale = math.sqrt(2.0 / float(in_channels)) + w = [scale * torch.randn(self.in_channels, hidden_size, dtype=torch.complex64)] + for l in range(1, self.spectral_layers): + scale = math.sqrt(2.0 / float(hidden_size)) + w.append(scale * torch.randn(hidden_size, hidden_size, dtype=torch.complex64)) + self.w = nn.ParameterList(w) + + scale = math.sqrt(gain / float(in_channels)) + self.wout = nn.Parameter(scale * torch.randn(hidden_size, self.out_channels, dtype=torch.complex64)) + + if bias: + self.b = nn.ParameterList([scale * torch.randn(hidden_size, 1, 1, dtype=torch.complex64) for _ in range(self.spectral_layers)]) + + self.activations = nn.ModuleList([]) + for l in range(0, self.spectral_layers): + self.activations.append(ComplexReLU(mode=complex_activation, bias_shape=(hidden_size, 1, 1), scale=scale)) + + elif operator_type == "l-dependant": + self.mul_add_handle = compl_exp_muladd2d_fwd + self.mul_handle = compl_exp_mul2d_fwd + + # weights + scale = 
math.sqrt(2.0 / float(in_channels)) + w = [scale * torch.randn(self.modes_lat, self.in_channels, hidden_size, dtype=torch.complex64)] + for l in range(1, self.spectral_layers): + scale = math.sqrt(2.0 / float(hidden_size)) + w.append(scale * torch.randn(self.modes_lat, hidden_size, hidden_size, dtype=torch.complex64)) + self.w = nn.ParameterList(w) + + if bias: + self.b = nn.ParameterList([scale * torch.randn(hidden_size, 1, 1, dtype=torch.complex64) for _ in range(self.spectral_layers)]) + + scale = math.sqrt(gain / float(in_channels)) + self.wout = nn.Parameter(scale * torch.randn(self.modes_lat, hidden_size, self.out_channels, dtype=torch.complex64)) + + self.activations = nn.ModuleList([]) + for l in range(0, self.spectral_layers): + self.activations.append(ComplexReLU(mode=complex_activation, bias_shape=(hidden_size, 1, 1), scale=scale)) + + else: + raise ValueError("Unknown operator type") + + self.drop = nn.Dropout(drop_rate) if drop_rate > 0.0 else nn.Identity() + + def forward_mlp(self, x): + B, C, H, W = x.shape + + xr = torch.view_as_real(x) + + for l in range(self.spectral_layers): + if hasattr(self, "b"): + xr = self.mul_add_handle(xr, self.w[l], self.b[l]) + else: + xr = self.mul_handle(xr, self.w[l]) + xr = torch.view_as_complex(xr) + xr = self.activations[l](xr) + xr = self.drop(xr) + xr = torch.view_as_real(xr) + + # final MLP + x = self.mul_handle(xr, self.wout) + + x = torch.view_as_complex(x) + + return x + + def forward(self, x): + dtype = x.dtype + residual = x + x = x.to(torch.float32) + + # FWD transform + with amp.autocast(enabled=False): + x = self.forward_transform(x) + if self.scale_residual: + residual = self.inverse_transform(x) + residual = residual.to(dtype) + + # MLP + x = self.forward_mlp(x) + + # BWD transform + with amp.autocast(enabled=False): + x = self.inverse_transform(x) + + # cast back to initial precision + x = x.to(dtype) + + return x, residual diff --git a/helpers.py b/helpers.py new file mode 100644 index 0000000..ad12ea1 --- /dev/null +++ b/helpers.py @@ -0,0 +1,44 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
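# A minimal sketch of SpectralConv on a periodic domain, pairing it with the RealFFT2 /
# InverseRealFFT2 wrappers from layers.py instead of the spherical harmonic transforms;
# grid size and channel counts are illustrative. With operator_type="dhconv" the learned
# weight depends only on the latitudinal (l) index and is shared across the m index.
import torch
from makani.models.common import RealFFT2, InverseRealFFT2, SpectralConv

nlat, nlon = 32, 64
fwd = RealFFT2(nlat, nlon)
inv = InverseRealFFT2(nlat, nlon)

conv = SpectralConv(fwd, inv, in_channels=8, out_channels=8, operator_type="dhconv")

x = torch.randn(2, 8, nlat, nlon)
y, residual = conv(x)  # y: (2, 8, 32, 64); the residual is rescaled only if the grids differ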
+ +import torch +import torch.distributed as dist + +from makani.utils import comm + + +def count_parameters(model, device): + with torch.no_grad(): + total_count = 0 + for p in model.parameters(): + if not p.requires_grad: + continue + # reduce over model group + pcount = torch.tensor(p.numel(), device=device) + if hasattr(p, "is_shared_mp") and p.is_shared_mp: + if comm.get_size("model") > 1: + dist.all_reduce(pcount, group=comm.get_group("model")) + # divide by shared dims: + for cname in p.is_shared_mp: + pcount = pcount / comm.get_size(cname) + total_count += int(pcount.item()) + + return total_count + + +def check_parameters(model): + for p in model.parameters(): + if p.requires_grad: + print(p.shape, p.stride(), p.is_contiguous()) diff --git a/model_package.py b/model_package.py new file mode 100644 index 0000000..842708c --- /dev/null +++ b/model_package.py @@ -0,0 +1,268 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Model package for easy inference/packaging. Model packages contain all the necessary data to +perform inference and its interface is compatible with earth2mip +""" +import os +import shutil +import json +import jsbeautifier +import numpy as np +import torch +from makani.utils.YParams import ParamsBase +from makani.third_party.climt.zenith_angle import cos_zenith_angle + +from makani.models import model_registry + +import datetime + +import logging + + +class LocalPackage: + """ + Implements the earth2mip/modulus Package interface. + """ + + def __init__(self, root): + self.root = root + + def get(self, path): + return os.path.join(self.root, path) + + +logger = logging.getLogger(__name__) + +THIS_MODULE = "makani.models.model_package" +MODEL_PACKAGE_CHECKPOINT_PATH = "training_checkpoints/best_ckpt_mp0.tar" +MINS_FILE = "mins.npy" +MAXS_FILE = "maxs.npy" +MEANS_FILE = "global_means.npy" +STDS_FILE = "global_stds.npy" + + +class ModelWrapper(torch.nn.Module): + """ + Model wrapper to make inference simple outside of makani. + + Attributes + ---------- + model : torch.nn.Module + ML model that is wrapped. 
+ params : ParamsBase + parameter object containing information on how the model was initialized in makani + + Methods + ------- + forward(x, time): + performs a single prediction steps + """ + + def __init__(self, model, params): + super().__init__() + self.model = model + self.params = params + nlat = params.img_shape_x + nlon = params.img_shape_y + + self.lats = 90 - 180 * np.arange(nlat) / (nlat - 1) + self.lons = 360 * np.arange(nlon) / nlon + self.add_zenith = params.add_zenith + + def forward(self, x, time): + if self.add_zenith: + lon_grid, lat_grid = np.meshgrid(self.lons, self.lats) + cosz = cos_zenith_angle(time, lon_grid, lat_grid) + cosz = cosz.astype(np.float32) + z = torch.from_numpy(cosz).to(device=x.device) + while z.ndim != x.ndim: + z = z[None] + x = torch.cat([x, z], dim=1) + + return self.model(x) + + +def save_model_package(params): + """ + Saves out a self-contained model-package. + The idea is to save anything necessary for inference beyond the checkpoints in one location. + """ + # save out the current state of the parameters, make it human readable + config_path = os.path.join(params.experiment_dir, "config.json") + jsopts = jsbeautifier.default_options() + jsopts.indent_size = 2 + + with open(config_path, "w") as f: + msg = jsbeautifier.beautify(json.dumps(params.to_dict()), jsopts) + f.write(msg) + + if hasattr(params, "add_orography") and params.add_orography: + shutil.copy(params.orography_path, os.path.join(params.experiment_dir, "orography.nc")) + + if hasattr(params, "add_landmask") and params.add_landmask: + shutil.copy(params.landmask_path, os.path.join(params.experiment_dir, "land_mask.nc")) + + # a bit hacky - we should change this to get the normalization from the dataloader. + if hasattr(params, "global_means_path") and params.global_means_path is not None: + shutil.copy(params.global_means_path, os.path.join(params.experiment_dir, MEANS_FILE)) + if hasattr(params, "global_stds_path") and params.global_stds_path is not None: + shutil.copy(params.global_stds_path, os.path.join(params.experiment_dir, STDS_FILE)) + + if params.normalization == "minmax": + if hasattr(params, "min_path") and params.min_path is not None: + shutil.copy(params.min_path, os.path.join(params.experiment_dir, MINS_FILE)) + if hasattr(params, "max_path") and params.max_path is not None: + shutil.copy(params.max_path, os.path.join(params.experiment_dir, MAXS_FILE)) + + # write out earth2mip metadata.json + fcn_mip_data = { + "entrypoint": {"name": f"{THIS_MODULE}:load_time_loop"}, + } + with open(os.path.join(params.experiment_dir, "metadata.json"), "w") as f: + msg = jsbeautifier.beautify(json.dumps(fcn_mip_data), jsopts) + f.write(msg) + + +def _load_static_data(package, params): + if hasattr(params, "add_orography") and params.add_orography: + params.orography_path = package.get("orography.nc") + + if hasattr(params, "add_landmask") and params.add_landmask: + params.landmask_path = package.get("land_mask.nc") + + # a bit hacky - we should change this to correctly + if params.normalization == "zscore": + if hasattr(params, "global_means_path") and params.global_means_path is not None: + params.global_means_path = package.get(MEANS_FILE) + if hasattr(params, "global_stds_path") and params.global_stds_path is not None: + params.global_stds_path = package.get(STDS_FILE) + elif params.normalization == "minmax": + if hasattr(params, "min_path") and params.min_path is not None: + params.min_path = package.get(MINS_FILE) + if hasattr(params, "max_path") and params.max_path is not None: 
+ params.max_path = package.get(MAXS_FILE) + else: + raise ValueError("Unknown normalization mode.") + + +def load_model_package(package, pretrained=True, device="cpu"): + """ + Loads model package and return the wrapper which can be used for inference. + """ + path = package.get("config.json") + params = ParamsBase.from_json(path) + logger.info(str(params.to_dict())) + _load_static_data(package, params) + + # assume we are not distributed + # distributed checkpoints might be saved with different params values + params.img_local_offset_x = 0 + params.img_local_offset_y = 0 + params.img_local_shape_x = params.img_shape_x + params.img_local_shape_y = params.img_shape_y + + # get the model and + model = model_registry.get_model(params).to(device) + + if pretrained: + best_checkpoint_path = package.get(MODEL_PACKAGE_CHECKPOINT_PATH) + # critical that this map_location be cpu, rather than the device to + # avoid out of memory errors. + checkpoint = torch.load(best_checkpoint_path, map_location=device) + state_dict = checkpoint["model_state"] + torch.nn.modules.utils.consume_prefix_in_state_dict_if_present(state_dict, "module.") + model.load_state_dict(state_dict, strict=True) + + model = ModelWrapper(model, params=params) + + # by default we want to do evaluation so setting it to eval here + # 1-channel difference in training/eval mode + model.eval() + + return model + + +def load_time_loop(package, device=None, time_step_hours=None): + """This function loads an earth2mip TimeLoop object that + can be used for inference. + + A TimeLoop encapsulates normalization, regridding, and other logic, so is a + very minimal interface to expose to a framework like earth2mip. + + See https://github.com/NVIDIA/earth2mip/blob/main/docs/concepts.rst + for more info on this interface. 
+ """ + + from earth2mip.networks import Inference + from earth2mip.grid import equiangular_lat_lon_grid + + config = package.get("config.json") + params = ParamsBase.from_json(config) + + if params.in_channels != params.out_channels: + raise NotImplementedError("Non-equal input and output channels are not implemented yet.") + + names = [params.channel_names[i] for i in params.in_channels] + + if params.normalization == "minmax": + min_path = package.get(MINS_FILE) + max_path = package.get(MAXS_FILE) + + a = np.load(min_path) + a = np.squeeze(a)[params.in_channels] + + b = np.load(max_path) + b = np.squeeze(b)[params.in_channels] + + # work around to implement minmax scaling based with the earth2mip + # Inference class below + center = (a + b) / 2 + scale = (b - a) / 2 + else: + center_path = package.get(MEANS_FILE) + scale_path = package.get(STDS_FILE) + + center = np.load(center_path) + center = np.squeeze(center)[params.in_channels] + + scale = np.load(scale_path) + scale = np.squeeze(scale)[params.in_channels] + + model = load_model_package(package, pretrained=True, device=device) + shape = (params.img_shape_x, params.img_shape_y) + + grid = equiangular_lat_lon_grid(nlat=params.img_shape_x, nlon=params.img_shape_y, includes_south_pole=True) + + if time_step_hours is None: + hour = datetime.timedelta(hours=1) + time_step = hour * params.get("dt", 6) + else: + time_step = datetime.timedelta(hours=time_step_hours) + + # Here we use the built-in class earth2mip.networks.Inference + # will later be extended to use the makani inferencer + inference = Inference( + model=model, + channel_names=names, + center=center, + scale=scale, + grid=grid, + n_history=params.n_history, + time_step=time_step, + ) + inference.to(device) + return inference diff --git a/model_registry.py b/model_registry.py new file mode 100644 index 0000000..01bd64d --- /dev/null +++ b/model_registry.py @@ -0,0 +1,170 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import importlib.util + +# we need this here for the code to work +import importlib_metadata +from importlib.metadata import EntryPoint, entry_points + +import logging + +from typing import List, Union +from functools import partial + +import torch.nn as nn + +from makani.utils.YParams import ParamsBase +from makani.models import SingleStepWrapper, MultiStepWrapper + + +def _construct_registry() -> dict: + registry = {} + entrypoints = entry_points(group="makani.models") + for entry_point in entrypoints: + registry[entry_point.name] = entry_point + return registry + + +def _register_from_module(model: nn.Module, name: Union[str, None] = None) -> None: + """ + registers a module in the registry + """ + + # Check if model is a torch module + if not issubclass(model, nn.Module): + raise ValueError(f"Only subclasses of torch.nn.Module can be registered. 
" f"Provided model is of type {type(model)}") + + # If no name provided, use the model's name + if name is None: + name = model.__name__ + + # Check if name already in use + if name in _model_registry: + raise ValueError(f"Name {name} already in use") + + # Add this class to the dict of model registry + _model_registry[name] = model + + +def _register_from_file(model_string: str, name: Union[str, None] = None) -> None: + """ + parses a string and attempts to get the module from the specified location + """ + + assert len(model_string.split(":")) == 2 + model_path, model_handle = model_string.split(":") + + if not os.path.exists(model_path): + raise ValueError(f"Expected string of format 'path/to/model_file.py:ModuleName' but {model_path} does not exist.") + + module_spec = importlib.util.spec_from_file_location(model_handle, model_path) + module = importlib.util.module_from_spec(module_spec) + module_spec.loader.exec_module(module) + model = getattr(module, model_handle) + + _register_from_module(model, name) + + +def register_model(model: Union[str, nn.Module], name: Union[str, None] = None) -> None: + """ + Registers a model in the model registry under the provided name. If no name + is provided, the model's name (from its `__name__` attribute) is used. If the + name is already in use, raises a ValueError. + + Parameters + ---------- + model : torch.nn.Module + The model to be registered. Can be an instance of any class. + name : str, optional + The name to register the model under. If None, the model's name is used. + + Raises + ------ + ValueError + If the provided name is already in use in the registry. + """ + + if isinstance(model, str): + _register_from_file(model, name) + else: + _register_from_module(model, name) + +def list_models() -> List[str]: + """ + Returns a list of the names of all models currently registered in the registry. + + Returns + ------- + List[str] + A list of the names of all registered models. The order of the names is not + guaranteed to be consistent. + """ + return list(_model_registry.keys()) + + +def get_model(params: ParamsBase, **kwargs) -> "torch.nn.Module": + """ + Convenience routine that constructs the model passing parameters and kwargs. + Unloads all the parameters in the params datastructure as a dict. + + Parameters + ---------- + params : ParamsBase + parameter struct. + + Returns + ------- + model : torch.nn.Module + The registered model. + + Raises + ------ + KeyError + If no model is registered under the provided name. + """ + + if params is not None: + # makani requires that these entries are set in params for now + inp_shape = (params.img_crop_shape_x, params.img_crop_shape_y) + out_shape = (params.out_shape_x, params.out_shape_y) if hasattr(params, "out_shape_x") and hasattr(params, "out_shape_y") else inp_shape + inp_chans = params.N_in_channels + out_chans = params.N_out_channels + + if params.nettype not in _model_registry: + logging.warning(f"Net type {params.nettype} does not exist in the registry. 
Trying to register it.")
+            register_model(params.nettype, params.nettype)
+
+        model_handle = _model_registry.get(params.nettype)
+        if model_handle is not None:
+            if isinstance(model_handle, (EntryPoint, importlib_metadata.EntryPoint)):
+                model_handle = model_handle.load()
+
+            model_handle = partial(model_handle, inp_shape=inp_shape, out_shape=out_shape, inp_chans=inp_chans, out_chans=out_chans, **params.to_dict())
+        else:
+            raise KeyError(f"No model is registered under the name {params.nettype}")
+
+    # wrap into Multi-Step if requested
+    if params.n_future > 0:
+        model = MultiStepWrapper(params, model_handle)
+    else:
+        model = SingleStepWrapper(params, model_handle)
+
+    return model
+
+
+# initialize the internal state upon import
+_model_registry = _construct_registry()
diff --git a/networks/__pycache__/sfnonet.cpython-310.pyc b/networks/__pycache__/sfnonet.cpython-310.pyc
new file mode 100644
index 0000000..16d8b41
Binary files /dev/null and b/networks/__pycache__/sfnonet.cpython-310.pyc differ
diff --git a/networks/afnonet.py b/networks/afnonet.py
new file mode 100644
index 0000000..9777793
--- /dev/null
+++ b/networks/afnonet.py
@@ -0,0 +1,268 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
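For reference, the registry defined in `model_registry.py` above can also be used to register a custom architecture programmatically. The sketch below is illustrative only; `MyNet` and its hyperparameters are hypothetical, and its keyword-argument signature mirrors what `get_model` forwards to every registered architecture:

```python
import torch.nn as nn

from makani.models import model_registry


class MyNet(nn.Module):
    """Hypothetical architecture; get_model() passes inp_shape, out_shape,
    inp_chans, out_chans and all params entries as keyword arguments."""

    def __init__(self, inp_shape=(721, 1440), out_shape=(721, 1440), inp_chans=2, out_chans=2, **kwargs):
        super().__init__()
        self.proj = nn.Conv2d(inp_chans, out_chans, kernel_size=1)

    def forward(self, x):
        return self.proj(x)


# register under an explicit name and confirm it shows up in the registry
model_registry.register_model(MyNet, name="MyNet")
assert "MyNet" in model_registry.list_models()
```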
+ +from functools import partial +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.fft +from makani.utils.img_utils import PeriodicPad2d + +from makani.models.common import DropPath, PatchEmbed + + +class Mlp(nn.Module): + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.0): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +class AFNO2D(nn.Module): + def __init__(self, hidden_size, num_blocks=8, sparsity_threshold=0.01, hard_thresholding_fraction=1, hidden_size_factor=1): + super().__init__() + assert hidden_size % num_blocks == 0, f"hidden_size {hidden_size} should be divisble by num_blocks {num_blocks}" + + self.hidden_size = hidden_size + self.sparsity_threshold = sparsity_threshold + self.num_blocks = num_blocks + self.block_size = self.hidden_size // self.num_blocks + self.hard_thresholding_fraction = hard_thresholding_fraction + self.hidden_size_factor = hidden_size_factor + self.scale = 0.02 + + self.w1 = nn.Parameter(self.scale * torch.randn(2, self.num_blocks, self.block_size, self.block_size * self.hidden_size_factor)) + self.b1 = nn.Parameter(self.scale * torch.randn(2, self.num_blocks, self.block_size * self.hidden_size_factor)) + self.w2 = nn.Parameter(self.scale * torch.randn(2, self.num_blocks, self.block_size * self.hidden_size_factor, self.block_size)) + self.b2 = nn.Parameter(self.scale * torch.randn(2, self.num_blocks, self.block_size)) + + def forward(self, x): + bias = x + + dtype = x.dtype + x = x.float() + B, H, W, C = x.shape + + x = torch.fft.rfft2(x, dim=(1, 2), norm="ortho") + x = x.reshape(B, H, W // 2 + 1, self.num_blocks, self.block_size) + + o1_real = torch.zeros([B, H, W // 2 + 1, self.num_blocks, self.block_size * self.hidden_size_factor], device=x.device) + o1_imag = torch.zeros([B, H, W // 2 + 1, self.num_blocks, self.block_size * self.hidden_size_factor], device=x.device) + o2_real = torch.zeros(x.shape, device=x.device) + o2_imag = torch.zeros(x.shape, device=x.device) + + total_modes = H // 2 + 1 + kept_modes = int(total_modes * self.hard_thresholding_fraction) + + o1_real[:, total_modes - kept_modes : total_modes + kept_modes, :kept_modes] = F.relu( + torch.einsum("...bi,bio->...bo", x[:, total_modes - kept_modes : total_modes + kept_modes, :kept_modes].real, self.w1[0]) + - torch.einsum("...bi,bio->...bo", x[:, total_modes - kept_modes : total_modes + kept_modes, :kept_modes].imag, self.w1[1]) + + self.b1[0] + ) + + o1_imag[:, total_modes - kept_modes : total_modes + kept_modes, :kept_modes] = F.relu( + torch.einsum("...bi,bio->...bo", x[:, total_modes - kept_modes : total_modes + kept_modes, :kept_modes].imag, self.w1[0]) + + torch.einsum("...bi,bio->...bo", x[:, total_modes - kept_modes : total_modes + kept_modes, :kept_modes].real, self.w1[1]) + + self.b1[1] + ) + + o2_real[:, total_modes - kept_modes : total_modes + kept_modes, :kept_modes] = ( + torch.einsum("...bi,bio->...bo", o1_real[:, total_modes - kept_modes : total_modes + kept_modes, :kept_modes], self.w2[0]) + - torch.einsum("...bi,bio->...bo", o1_imag[:, total_modes - kept_modes : total_modes + kept_modes, :kept_modes], self.w2[1]) + + self.b2[0] 
+ ) + + o2_imag[:, total_modes - kept_modes : total_modes + kept_modes, :kept_modes] = ( + torch.einsum("...bi,bio->...bo", o1_imag[:, total_modes - kept_modes : total_modes + kept_modes, :kept_modes], self.w2[0]) + + torch.einsum("...bi,bio->...bo", o1_real[:, total_modes - kept_modes : total_modes + kept_modes, :kept_modes], self.w2[1]) + + self.b2[1] + ) + + x = torch.stack([o2_real, o2_imag], dim=-1) + x = F.softshrink(x, lambd=self.sparsity_threshold) + x = torch.view_as_complex(x) + x = x.reshape(B, H, W // 2 + 1, C) + x = torch.fft.irfft2(x, s=(H, W), dim=(1, 2), norm="ortho") + x = x.type(dtype) + + return x + bias + + +class Block(nn.Module): + def __init__( + self, + dim, + mlp_ratio=4.0, + drop=0.0, + drop_path=0.0, + act_layer=nn.GELU, + norm_layer=nn.LayerNorm, + double_skip=True, + num_blocks=8, + sparsity_threshold=0.01, + hard_thresholding_fraction=1.0, + ): + super().__init__() + self.norm1 = norm_layer(dim) + self.filter = AFNO2D(dim, num_blocks, sparsity_threshold, hard_thresholding_fraction) + self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + # self.drop_path = nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + self.double_skip = double_skip + + def forward(self, x): + residual = x + x = self.norm1(x) + x = self.filter(x) + + if self.double_skip: + x = x + residual + residual = x + + x = self.norm2(x) + x = self.mlp(x) + x = self.drop_path(x) + x = x + residual + return x + + +class PrecipNet(nn.Module): + def __init__(self, backbone, patch_size=(16, 16), inp_chans=2, out_chans=2, **kwargs): + super().__init__() + self.patch_size = patch_size + self.inp_chans = inp_chans + self.out_chans = out_chans + self.backbone = backbone + self.ppad = PeriodicPad2d(1) + self.conv = nn.Conv2d(self.out_chans, self.out_chans, kernel_size=3, stride=1, padding=0, bias=True) + self.act = nn.ReLU() + + def forward(self, x): + x = self.backbone(x) + x = self.ppad(x) + x = self.conv(x) + x = self.act(x) + return x + + +class AdaptiveFourierNeuralOperatorNet(nn.Module): + def __init__( + self, + inp_shape=(720, 1440), + patch_size=(16, 16), + inp_chans=2, + out_chans=2, + embed_dim=768, + num_layers=12, + mlp_ratio=4.0, + drop_rate=0.0, + drop_path_rate=0.0, + num_blocks=16, + sparsity_threshold=0.01, + hard_thresholding_fraction=1.0, + **kwargs, + ): + super(AdaptiveFourierNeuralOperatorNet, self).__init__() + self.img_size = inp_shape + self.patch_size = patch_size + self.inp_chans = inp_chans + self.out_chans = out_chans + self.embed_dim = embed_dim + + norm_layer = partial(nn.LayerNorm, eps=1e-6) + + self.patch_embed = PatchEmbed(img_size=self.img_size, patch_size=self.patch_size, in_chans=self.inp_chans, embed_dim=self.embed_dim) + num_patches = self.patch_embed.num_patches + + self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, self.embed_dim)) + self.pos_drop = nn.Dropout(p=drop_rate) + + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, num_layers)] + + self.h = self.img_size[0] // self.patch_size[0] + self.w = self.img_size[1] // self.patch_size[1] + + self.blocks = nn.ModuleList( + [ + Block( + dim=self.embed_dim, + mlp_ratio=mlp_ratio, + drop=drop_rate, + drop_path=dpr[i], + norm_layer=norm_layer, + num_blocks=num_blocks, + sparsity_threshold=sparsity_threshold, + hard_thresholding_fraction=hard_thresholding_fraction, + ) + for i in range(num_layers) + ] + ) + + self.head = nn.Linear(self.embed_dim, 
self.out_chans * self.patch_size[0] * self.patch_size[1], bias=False) + + with torch.no_grad(): + nn.init.trunc_normal_(self.pos_embed, std=0.02) + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + nn.init.trunc_normal_(m.weight, std=0.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + @torch.jit.ignore + def no_weight_decay(self): + return {"pos_embed", "cls_token"} + + def forward_features(self, x): + B = x.shape[0] + x = self.patch_embed(x).transpose(1, 2) + x = x + self.pos_embed + x = self.pos_drop(x) + + x = x.reshape(B, self.h, self.w, self.embed_dim) + for blk in self.blocks: + x = blk(x) + + return x + + def forward(self, x): + x = self.forward_features(x) + x = self.head(x) + + # rearrange + b = x.shape[0] + xv = x.view(b, self.h, self.w, self.patch_size[0], self.patch_size[1], -1) + xvt = torch.permute(xv, (0, 5, 1, 3, 2, 4)).contiguous() + x = xvt.view(b, -1, (self.h * self.patch_size[0]), (self.w * self.patch_size[1])) + + return x diff --git a/networks/afnonet_v2.py b/networks/afnonet_v2.py new file mode 100644 index 0000000..a131f36 --- /dev/null +++ b/networks/afnonet_v2.py @@ -0,0 +1,315 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from functools import partial +from collections import OrderedDict +from copy import Error, deepcopy +from re import S +from numpy.lib.arraypad import pad +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.fft +from torch.nn.modules.container import Sequential +from torch.utils.checkpoint import checkpoint_sequential +from typing import Optional +import math + +# helpers +from makani.models.common import ComplexReLU, PatchEmbed, DropPath, MLP + + +@torch.jit.script +def compl_mul_add_fwd(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor: + tmp = torch.einsum("bkixys,kior->srbkoxy", a, b) + res = torch.stack([tmp[0, 0, ...] - tmp[1, 1, ...], tmp[1, 0, ...] 
+ tmp[0, 1, ...]], dim=-1) + return res + + +@torch.jit.script +def compl_mul_add_fwd_c(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor: + ac = torch.view_as_complex(a) + bc = torch.view_as_complex(b) + resc = torch.einsum("bkixy,kio->bkoxy", ac, bc) + res = torch.view_as_real(resc) + return res + + +class AFNO2D(nn.Module): + def __init__(self, hidden_size, num_blocks=8, sparsity_threshold=0.0, hard_thresholding_fraction=1, hidden_size_factor=1, use_complex_kernels=False): + super(AFNO2D, self).__init__() + assert hidden_size % num_blocks == 0, f"hidden_size {hidden_size} should be divisble by num_blocks {num_blocks}" + + self.hidden_size = hidden_size + self.sparsity_threshold = sparsity_threshold + self.num_blocks = num_blocks + self.block_size = self.hidden_size // self.num_blocks + self.hard_thresholding_fraction = hard_thresholding_fraction + self.hidden_size_factor = hidden_size_factor + self.scale = 0.02 + self.mult_handle = compl_mul_add_fwd_c if use_complex_kernels else compl_mul_add_fwd + + # new + self.w1 = nn.Parameter(self.scale * torch.randn(self.num_blocks, self.block_size, self.block_size * self.hidden_size_factor, 2)) + self.b1 = nn.Parameter(self.scale * torch.randn(1, self.num_blocks * self.block_size, 1, 1)) + self.w2 = nn.Parameter(self.scale * torch.randn(self.num_blocks, self.block_size * self.hidden_size_factor, self.block_size, 2)) + # self.b2 = nn.Parameter(self.scale * torch.randn(self.num_blocks, self.block_size, 1, 1, 2)) + + # self.act = nn.ReLU() + self.act = ComplexReLU(negative_slope=0.0, mode="cartesian") + + def forward(self, x): + bias = x + + dtype = x.dtype + x = x.float() + B, C, H, W = x.shape + total_modes_H = H // 2 + 1 + total_modes_W = W // 2 + 1 + kept_modes_H = int(total_modes_H * self.hard_thresholding_fraction) + kept_modes_W = int(total_modes_W * self.hard_thresholding_fraction) + + x = torch.fft.rfft2(x, dim=(-2, -1), norm="ortho") + x = x.view(B, self.num_blocks, self.block_size, H, W // 2 + 1) + + # do spectral conv + x = torch.view_as_real(x) + x_fft = torch.zeros(x.shape, device=x.device) + + if kept_modes_H == total_modes_H: + oac = torch.view_as_complex(self.mult_handle(x[:, :, :, :, :kept_modes_W, :], self.w1)) + oa = torch.view_as_real(self.act(oac)) + x_fft[:, :, :, :, :kept_modes_W, :] = self.mult_handle(oa, self.w2) + else: + olc = torch.view_as_complex(self.mult_handle(x[:, :, :, :kept_modes_H, :kept_modes_W, :], self.w1)) + ohc = torch.view_as_complex(self.mult_handle(x[:, :, :, -kept_modes_H:, :kept_modes_W, :], self.w1)) + + ol = torch.view_as_real(self.act(olc)) + oh = torch.view_as_real(self.act(ohc)) + + x_fft[:, :, :, :kept_modes_H, :kept_modes_W, :] = self.mult_handle(ol, self.w2) + x_fft[:, :, :, -kept_modes_H:, :kept_modes_W, :] = self.mult_handle(oh, self.w2) + + # finalize + x = F.softshrink(x_fft, lambd=self.sparsity_threshold) + x = torch.view_as_complex(x) + x = x.reshape(B, C, H, W // 2 + 1) + x = torch.fft.irfft2(x, s=(H, W), dim=(-2, -1), norm="ortho") + x = x.type(dtype) + + return x + self.b1 + bias + + +class Block(nn.Module): + def __init__( + self, + h, + w, + dim, + mlp_ratio=4.0, + drop=0.0, + drop_path=0.0, + act_layer=nn.GELU, + norm_layer=nn.LayerNorm, + num_blocks=8, + sparsity_threshold=0.01, + hard_thresholding_fraction=1.0, + use_complex_kernels=True, + skip_fno="linear", + nested_skip_fno=True, + checkpointing=False, + verbose=True, + ): + super(Block, self).__init__() + + # norm layer + self.norm1 = norm_layer() # ((h,w)) + + if skip_fno is None: + if verbose: + print("Using no skip 
connection around FNO.") + + elif skip_fno == "linear": + # self.skip_layer = nn.Linear(dim, dim) + self.skip_layer = nn.Conv2d(dim, dim, 1, 1) + if verbose: + print("Using Linear skip connection around FNO.") + + elif skip_fno == "identity": + self.skip_layer = nn.Identity() + if verbose: + print("Using Identity skip connection around FNO.") + + else: + if verbose: + print(f"Got skip_fno={skip_fno}, not using any skip around FNO -- use linear or identity to change this.") + self.skip_fno = skip_fno + + self.nested_skip_fno = nested_skip_fno + + # filter + self.filter = AFNO2D(dim, num_blocks, sparsity_threshold, hard_thresholding_fraction, use_complex_kernels=use_complex_kernels) + self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + + # norm layer + self.norm2 = norm_layer() # ((h,w)) + + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = MLP(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop_rate=drop, checkpointing=checkpointing) + + def forward(self, x): + residual = x + + x = self.norm1(x) + x = self.filter(x) + + if self.skip_fno is not None: + x = x + self.skip_layer(residual) + if not self.nested_skip_fno: + residual = x + + x = self.norm2(x) + x = self.mlp(x) + x = self.drop_path(x) + x = x + residual + return x + + +class AdaptiveFourierNeuralOperatorNet(nn.Module): + def __init__( + self, + inp_shape=(720, 1440), + patch_size=(16, 16), + inp_chans=2, + out_chans=2, + embed_dim=768, + num_layers=12, + mlp_ratio=4.0, + drop_rate=0.0, + drop_path_rate=0.0, + num_blocks=16, + sparsity_threshold=0.01, + normalization_layer="instance_norm", + skip_fno="linear", + nested_skip_fno=True, + hard_thresholding_fraction=1.0, + checkpointing=False, + use_complex_kernels=True, + verbose=False, + **kwargs, + ): + super(AdaptiveFourierNeuralOperatorNet, self).__init__() + self.img_size = inp_shape + self.patch_size = patch_size + self.inp_chans = inp_chans + self.out_chans = out_chans + self.embed_dim = embed_dim + + # some sanity checks + assert len(patch_size) == 2, f"Expected patch_size to have two entries but got {patch_size} instead" + assert (self.img_size[0] % self.patch_size[0] == 0) and ( + self.img_size[1] % self.patch_size[1] == 0 + ), f"Error, the patch size {self.patch_size} does not divide the image dimensions {self.img_size} evenly." 
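+        # with the default inp_shape=(720, 1440) and patch_size=(16, 16), the patch
+        # embedding below produces a 45 x 90 grid of tokens (num_patches = 4050),
+        # each of dimension embed_dim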
+ + self.patch_embed = PatchEmbed(img_size=self.img_size, patch_size=self.patch_size, in_chans=self.inp_chans, embed_dim=self.embed_dim) + num_patches = self.patch_embed.num_patches + + self.pos_embed = nn.Parameter(torch.zeros(1, embed_dim, num_patches)) + self.pos_drop = nn.Dropout(p=drop_rate) if drop_rate > 0.0 else nn.Identity() + + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, num_layers)] + + # compute the downscaled image size + self.h = self.img_size[0] // self.patch_size[0] + self.w = self.img_size[1] // self.patch_size[1] + + # pick norm layer + if normalization_layer == "layer_norm": + norm_layer = partial(nn.LayerNorm, normalized_shape=(self.h, self.w), eps=1e-6) + elif normalization_layer == "instance_norm": + norm_layer = partial(nn.InstanceNorm2d, num_features=embed_dim, eps=1e-6, affine=True, track_running_stats=False) + else: + raise NotImplementedError(f"Error, normalization {normalization_layer} not implemented.") + + self.blocks = nn.ModuleList( + [ + Block( + h=self.h, + w=self.w, + dim=self.embed_dim, + mlp_ratio=mlp_ratio, + drop=drop_rate, + drop_path=dpr[i], + norm_layer=norm_layer, + num_blocks=num_blocks, + sparsity_threshold=sparsity_threshold, + hard_thresholding_fraction=hard_thresholding_fraction, + use_complex_kernels=use_complex_kernels, + skip_fno=skip_fno, + nested_skip_fno=nested_skip_fno, + checkpointing=checkpointing, + verbose=verbose, + ) + for i in range(num_layers) + ] + ) + + # head + self.head = nn.Conv2d(embed_dim, self.out_chans * self.patch_size[0] * self.patch_size[1], 1, bias=False) + + with torch.no_grad(): + nn.init.trunc_normal_(self.pos_embed, std=0.02) + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear) or isinstance(m, nn.Conv2d): + nn.init.trunc_normal_(m.weight, std=0.02) + # nn.init.normal_(m.weight, std=0.02) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm) or isinstance(m, nn.InstanceNorm3d): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + @torch.jit.ignore + def no_weight_decay(self): + return {"pos_embed", "cls_token"} + + def forward_features(self, x): + B = x.shape[0] + x = self.patch_embed(x) + x = x + self.pos_embed + x = self.pos_drop(x) + + # reshape + x = x.reshape(B, self.embed_dim, self.h, self.w) + + for blk in self.blocks: + x = blk(x) + + return x + + def forward(self, x): + x = self.forward_features(x) + x = self.head(x) + + # new: B, C, H, W + b = x.shape[0] + xv = x.view(b, self.patch_size[0], self.patch_size[1], -1, self.h, self.w) + xvt = torch.permute(xv, (0, 3, 4, 1, 5, 2)).contiguous() + x = xvt.view(b, -1, (self.h * self.patch_size[0]), (self.w * self.patch_size[1])) + + return x diff --git a/networks/debug.py b/networks/debug.py new file mode 100644 index 0000000..5165465 --- /dev/null +++ b/networks/debug.py @@ -0,0 +1,28 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +from torch import nn + + +class DebugNet(nn.Module): + def __init__(self, **kwargs): + super().__init__() + + # create dummy param so that it won't crash in optimizer instantiation + self.factor = nn.Parameter(torch.ones((1), dtype=torch.float32)) + + def forward(self, x): + return self.factor * x diff --git a/networks/sfnonet.py b/networks/sfnonet.py new file mode 100644 index 0000000..9bf4fdf --- /dev/null +++ b/networks/sfnonet.py @@ -0,0 +1,673 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass +import math +import torch +import torch.nn as nn +from torch.utils.checkpoint import checkpoint +from torch.cuda import amp + +from functools import partial + +# helpers +from makani.models.common import DropPath, MLP, EncoderDecoder + +# import global convolution and non-linear spectral layers +from makani.models.common import SpectralConv, FactorizedSpectralConv, SpectralAttention + +# get spectral transforms from torch_harmonics +import torch_harmonics as th +import torch_harmonics.distributed as thd + +# wrap fft, to unify interface to spectral transforms +from makani.models.common import RealFFT2, InverseRealFFT2 +from makani.mpu.layers import DistributedRealFFT2, DistributedInverseRealFFT2, DistributedMLP, DistributedEncoderDecoder + +# more distributed stuff +from makani.utils import comm + +# layer normalization +from modulus.distributed.mappings import scatter_to_parallel_region, gather_from_parallel_region +from makani.mpu.layer_norm import DistributedInstanceNorm2d, DistributedLayerNorm + +# for annotation of models +import modulus +from modulus.models.meta import ModelMetaData + + +class SpectralFilterLayer(nn.Module): + def __init__( + self, + forward_transform, + inverse_transform, + embed_dim, + filter_type="linear", + operator_type="diagonal", + hidden_size_factor=1, + factorization=None, + rank=1.0, + separable=False, + complex_activation="real", + spectral_layers=1, + bias=False, + drop_rate=0.0, + gain=1.0, + ): + super(SpectralFilterLayer, self).__init__() + + if filter_type == "non-linear": + self.filter = SpectralAttention( + forward_transform, + inverse_transform, + embed_dim, + embed_dim, + operator_type=operator_type, + hidden_size_factor=hidden_size_factor, + complex_activation=complex_activation, + spectral_layers=spectral_layers, + drop_rate=drop_rate, + bias=bias, + gain=gain, + ) + + elif filter_type == "linear" and factorization is None: + self.filter = SpectralConv( + forward_transform, + inverse_transform, + embed_dim, + embed_dim, + operator_type=operator_type, + separable=separable, + bias=bias, + gain=gain, + ) + + elif filter_type == "linear" and factorization is not None: + self.filter = FactorizedSpectralConv( + forward_transform, + inverse_transform, + embed_dim, + 
embed_dim, + operator_type=operator_type, + rank=rank, + factorization=factorization, + separable=separable, + bias=bias, + gain=gain, + ) + + else: + raise (NotImplementedError) + + def forward(self, x): + return self.filter(x) + + +class FourierNeuralOperatorBlock(nn.Module): + def __init__( + self, + forward_transform, + inverse_transform, + embed_dim, + filter_type="linear", + operator_type="diagonal", + mlp_ratio=2.0, + mlp_drop_rate=0.0, + path_drop_rate=0.0, + act_layer=nn.GELU, + norm_layer=(nn.Identity, nn.Identity), + rank=1.0, + factorization=None, + separable=False, + inner_skip="linear", + outer_skip=None, + use_mlp=False, + comm_feature_inp_name=None, + comm_feature_hidden_name=None, + complex_activation="real", + spectral_layers=1, + bias=False, + final_activation=False, + checkpointing=0, + ): + super(FourierNeuralOperatorBlock, self).__init__() + + # determine some shapes + if comm.get_size("spatial") > 1: + self.input_shape_loc = (forward_transform.lat_shapes[comm.get_rank("h")], + forward_transform.lon_shapes[comm.get_rank("w")]) + self.output_shape_loc = (inverse_transform.lat_shapes[comm.get_rank("h")], + inverse_transform.lon_shapes[comm.get_rank("w")]) + else: + self.input_shape_loc = (forward_transform.nlat, forward_transform.nlon) + self.output_shape_loc = (inverse_transform.nlat, inverse_transform.nlon) + + # norm layer + self.norm0 = norm_layer[0]() + + if act_layer == nn.Identity: + gain_factor = 1.0 + else: + gain_factor = 2.0 + + if inner_skip == "linear": + self.inner_skip = nn.Conv2d(embed_dim, embed_dim, 1, 1, bias=False) + gain_factor /= 2.0 + nn.init.normal_(self.inner_skip.weight, std=math.sqrt(gain_factor / embed_dim)) + elif inner_skip == "identity": + self.inner_skip = nn.Identity() + gain_factor /= 2.0 + elif inner_skip == "none": + pass + else: + raise ValueError(f"Unknown skip connection type {inner_skip}") + + # convolution layer + self.filter = SpectralFilterLayer( + forward_transform, + inverse_transform, + embed_dim, + filter_type, + operator_type, + hidden_size_factor=mlp_ratio, + factorization=factorization, + rank=rank, + separable=separable, + complex_activation=complex_activation, + spectral_layers=spectral_layers, + bias=bias, + drop_rate=path_drop_rate, + gain=gain_factor, + ) + + self.act_layer0 = act_layer() + + # norm layer + self.norm1 = norm_layer[1]() + + if final_activation and act_layer != nn.Identity: + gain_factor = 2.0 + else: + gain_factor = 1.0 + + if outer_skip == "linear": + self.outer_skip = nn.Conv2d(embed_dim, embed_dim, 1, 1, bias=False) + gain_factor /= 2.0 + torch.nn.init.normal_(self.outer_skip.weight, std=math.sqrt(gain_factor / embed_dim)) + elif outer_skip == "identity": + self.outer_skip = nn.Identity() + gain_factor /= 2.0 + elif outer_skip == "none": + pass + else: + raise ValueError(f"Unknown skip connection type {outer_skip}") + + if use_mlp == True: + MLPH = DistributedMLP if (comm.get_size("matmul") > 1) else MLP + mlp_hidden_dim = int(embed_dim * mlp_ratio) + self.mlp = MLPH( + in_features=embed_dim, + hidden_features=mlp_hidden_dim, + act_layer=act_layer, + drop_rate=mlp_drop_rate, + drop_type="features", + comm_inp_name=comm_feature_inp_name, + comm_hidden_name=comm_feature_hidden_name, + checkpointing=checkpointing, + gain=gain_factor, + ) + + # dropout + self.drop_path = DropPath(path_drop_rate) if path_drop_rate > 0.0 else nn.Identity() + + if final_activation: + self.act_layer1 = act_layer() + + def forward(self, x): + """ + Updated FNO block + """ + + x, residual = self.filter(x) + + x = 
self.norm0(x) + + if hasattr(self, "inner_skip"): + x = x + self.inner_skip(residual) + + if hasattr(self, "act_layer0"): + x = self.act_layer0(x) + + if hasattr(self, "mlp"): + x = self.mlp(x) + + x = self.norm1(x) + + x = self.drop_path(x) + + if hasattr(self, "outer_skip"): + x = x + self.outer_skip(residual) + + if hasattr(self, "act_layer1"): + x = self.act_layer1(x) + + return x + + +class SphericalFourierNeuralOperatorNet(nn.Module): + """ + SFNO implementation as in Bonev et al.; Spherical Fourier Neural Operators: Learning Stable Dynamics on the Sphere + """ + + def __init__( + self, + spectral_transform="sht", + model_grid_type="equiangular", + sht_grid_type="legendre-gauss", + filter_type="linear", + operator_type="dhconv", + inp_shape=(721, 1440), + out_shape=(721, 1440), + scale_factor=8, + inp_chans=2, + out_chans=2, + embed_dim=32, + num_layers=4, + repeat_layers=1, + use_mlp=True, + mlp_ratio=2.0, + encoder_ratio=1, + decoder_ratio=1, + activation_function="gelu", + encoder_layers=1, + pos_embed="none", + pos_drop_rate=0.0, + path_drop_rate=0.0, + mlp_drop_rate=0.0, + normalization_layer="instance_norm", + max_modes=None, + hard_thresholding_fraction=1.0, + big_skip=True, + rank=1.0, + factorization=None, + separable=False, + complex_activation="real", + spectral_layers=3, + bias=False, + checkpointing=0, + **kwargs, + ): + super(SphericalFourierNeuralOperatorNet, self).__init__() + + self.inp_shape = inp_shape + self.out_shape = out_shape + self.inp_chans = inp_chans + self.out_chans = out_chans + self.embed_dim = embed_dim + self.repeat_layers = repeat_layers + self.big_skip = big_skip + self.checkpointing = checkpointing + + # compute the downscaled image size + self.h = int(self.inp_shape[0] // scale_factor) + self.w = int(self.inp_shape[1] // scale_factor) + + # initialize spectral transforms + self._init_spectral_transforms(spectral_transform, model_grid_type, sht_grid_type, hard_thresholding_fraction, max_modes) + + # determine activation function + if activation_function == "relu": + activation_function = nn.ReLU + elif activation_function == "gelu": + activation_function = nn.GELU + elif activation_function == "silu": + activation_function = nn.SiLU + else: + raise ValueError(f"Unknown activation function {activation_function}") + + # set up encoder + if comm.get_size("matmul") > 1: + self.encoder = DistributedEncoderDecoder( + num_layers=encoder_layers, + input_dim=self.inp_chans, + output_dim=self.embed_dim, + hidden_dim=int(encoder_ratio * self.embed_dim), + act_layer=activation_function, + input_format="nchw", + comm_inp_name="fin", + comm_out_name="fout", + ) + fblock_mlp_inp_name = self.encoder.comm_out_name + fblock_mlp_hidden_name = "fout" if (self.encoder.comm_out_name == "fin") else "fin" + else: + self.encoder = EncoderDecoder( + num_layers=encoder_layers, + input_dim=self.inp_chans, + output_dim=self.embed_dim, + hidden_dim=int(encoder_ratio * self.embed_dim), + act_layer=activation_function, + input_format="nchw", + ) + fblock_mlp_inp_name = "fin" + fblock_mlp_hidden_name = "fout" + + # dropout + self.pos_drop = nn.Dropout(p=pos_drop_rate) if pos_drop_rate > 0.0 else nn.Identity() + dpr = [x.item() for x in torch.linspace(0, path_drop_rate, num_layers)] + + # pick norm layer + if normalization_layer == "layer_norm": + norm_layer_inp = partial(DistributedLayerNorm, normalized_shape=(embed_dim), elementwise_affine=True, eps=1e-6) + norm_layer_out = norm_layer_mid = norm_layer_inp + elif normalization_layer == "instance_norm": + if 
comm.get_size("spatial") > 1: + norm_layer_inp = partial(DistributedInstanceNorm2d, num_features=embed_dim, eps=1e-6, affine=True) + else: + norm_layer_inp = partial(nn.InstanceNorm2d, num_features=embed_dim, eps=1e-6, affine=True, track_running_stats=False) + norm_layer_out = norm_layer_mid = norm_layer_inp + elif normalization_layer == "none": + norm_layer_out = norm_layer_mid = norm_layer_inp = nn.Identity + else: + raise NotImplementedError(f"Error, normalization {normalization_layer} not implemented.") + + # FNO blocks + self.blocks = nn.ModuleList([]) + for i in range(num_layers): + first_layer = i == 0 + last_layer = i == num_layers - 1 + + forward_transform = self.trans_down if first_layer else self.trans + inverse_transform = self.itrans_up if last_layer else self.itrans + + inner_skip = "none" + outer_skip = "linear" + + if first_layer: + norm_layer = (norm_layer_inp, norm_layer_mid) + elif last_layer: + norm_layer = (norm_layer_mid, norm_layer_out) + else: + norm_layer = (norm_layer_mid, norm_layer_mid) + + block = FourierNeuralOperatorBlock( + forward_transform, + inverse_transform, + embed_dim, + filter_type=filter_type, + operator_type=operator_type, + mlp_ratio=mlp_ratio, + mlp_drop_rate=mlp_drop_rate, + path_drop_rate=dpr[i], + act_layer=activation_function, + norm_layer=norm_layer, + inner_skip=inner_skip, + outer_skip=outer_skip, + use_mlp=use_mlp, + comm_feature_inp_name=fblock_mlp_inp_name, + comm_feature_hidden_name=fblock_mlp_hidden_name, + rank=rank, + factorization=factorization, + separable=separable, + complex_activation=complex_activation, + spectral_layers=spectral_layers, + bias=bias, + checkpointing=checkpointing, + ) + + self.blocks.append(block) + + # decoder takes the output of FNO blocks and the residual from the big skip connection + if comm.get_size("matmul") > 1: + comm_inp_name = fblock_mlp_inp_name + comm_out_name = fblock_mlp_hidden_name + self.decoder = DistributedEncoderDecoder( + num_layers=encoder_layers, + input_dim=embed_dim, + output_dim=self.out_chans, + hidden_dim=int(decoder_ratio * embed_dim), + act_layer=activation_function, + gain=0.5 if self.big_skip else 1.0, + comm_inp_name=comm_inp_name, + comm_out_name=comm_out_name, + input_format="nchw", + ) + self.gather_shapes = compute_split_shapes(self.out_chans, + comm.get_size(self.decoder.comm_out_name)) + + else: + self.decoder = EncoderDecoder( + num_layers=encoder_layers, + input_dim=embed_dim, + output_dim=self.out_chans, + hidden_dim=int(decoder_ratio * embed_dim), + act_layer=activation_function, + gain=0.5 if self.big_skip else 1.0, + input_format="nchw", + ) + + # output transform + if self.big_skip: + self.residual_transform = nn.Conv2d(self.inp_chans, self.out_chans, 1, bias=False) + self.residual_transform.weight.is_shared_mp = ["spatial"] + self.residual_transform.weight.sharded_dims_mp = [None, None, None, None] + scale = math.sqrt(0.5 / self.inp_chans) + nn.init.normal_(self.residual_transform.weight, mean=0.0, std=scale) + + # learned position embedding + if pos_embed == "direct": + # currently using deliberately a differently shape position embedding + self.pos_embed = nn.Parameter(torch.zeros(1, embed_dim, self.inp_shape_loc[0], self.inp_shape_loc[1])) + # information about how tensors are shared / sharded across ranks + self.pos_embed.is_shared_mp = [] # no reduction required since pos_embed is already serial + self.pos_embed.sharded_dims_mp = [None, None, "h", "w"] + self.pos_embed.type = "direct" + with torch.no_grad(): + nn.init.trunc_normal_(self.pos_embed, std=0.02) + 
elif pos_embed == "frequency": + if comm.get_size("spatial") > 1: + lmax_loc = self.itrans_up.l_shapes[comm.get_rank("h")] + mmax_loc = self.itrans_up.m_shapes[comm.get_rank("w")] + else: + lmax_loc = self.itrans_up.lmax + mmax_loc = self.itrans_up.mmax + + rcoeffs = nn.Parameter(torch.tril(torch.randn(1, embed_dim, lmax_loc, mmax_loc), diagonal=0)) + ccoeffs = nn.Parameter(torch.tril(torch.randn(1, embed_dim, lmax_loc, mmax_loc - 1), diagonal=-1)) + with torch.no_grad(): + nn.init.trunc_normal_(rcoeffs, std=0.02) + nn.init.trunc_normal_(ccoeffs, std=0.02) + self.pos_embed = nn.ParameterList([rcoeffs, ccoeffs]) + self.pos_embed.type = "frequency" + self.pos_embed.is_shared_mp = [] + self.pos_embed.sharded_dims_mp = [None, None, "h", "w"] + + elif pos_embed == "none" or pos_embed == "None" or pos_embed == None: + pass + else: + raise ValueError("Unknown position embedding type") + + @torch.jit.ignore + def _init_spectral_transforms( + self, + spectral_transform="sht", + model_grid_type="equiangular", + sht_grid_type="legendre-gauss", + hard_thresholding_fraction=1.0, + max_modes=None, + ): + """ + Initialize the spectral transforms based on the maximum number of modes to keep. Handles the computation + of local image shapes and domain parallelism, based on the + """ + + if max_modes is not None: + modes_lat, modes_lon = max_modes + else: + modes_lat = int(self.h * hard_thresholding_fraction) + modes_lon = int((self.w // 2 + 1) * hard_thresholding_fraction) + + # prepare the spectral transforms + if spectral_transform == "sht": + sht_handle = th.RealSHT + isht_handle = th.InverseRealSHT + + # parallelism + if comm.get_size("spatial") > 1: + polar_group = None if (comm.get_size("h") == 1) else comm.get_group("h") + azimuth_group = None if (comm.get_size("w") == 1) else comm.get_group("w") + thd.init(polar_group, azimuth_group) + sht_handle = thd.DistributedRealSHT + isht_handle = thd.DistributedInverseRealSHT + + # set up + self.trans_down = sht_handle(*self.inp_shape, lmax=modes_lat, mmax=modes_lon, grid=model_grid_type).float() + self.itrans_up = isht_handle(*self.out_shape, lmax=modes_lat, mmax=modes_lon, grid=model_grid_type).float() + self.trans = sht_handle(self.h, self.w, lmax=modes_lat, mmax=modes_lon, grid=sht_grid_type).float() + self.itrans = isht_handle(self.h, self.w, lmax=modes_lat, mmax=modes_lon, grid=sht_grid_type).float() + + elif spectral_transform == "fft": + fft_handle = RealFFT2 + ifft_handle = InverseRealFFT2 + + if comm.get_size("spatial") > 1: + h_group = None if (comm.get_size("h") == 1) else comm.get_group("h") + w_group = None if (comm.get_size("w") == 1) else comm.get_group("w") + thd.init(h_group, w_group) + fft_handle = DistributedRealFFT2 + ifft_handle = DistributedInverseRealFFT2 + + self.trans_down = fft_handle(*self.inp_shape, lmax=modes_lat, mmax=modes_lon).float() + self.itrans_up = ifft_handle(*self.out_shape, lmax=modes_lat, mmax=modes_lon).float() + self.trans = fft_handle(self.h, self.w, lmax=modes_lat, mmax=modes_lon).float() + self.itrans = ifft_handle(self.h, self.w, lmax=modes_lat, mmax=modes_lon).float() + else: + raise (ValueError("Unknown spectral transform")) + + # use the SHT/FFT to compute the local, downscaled grid dimensions + if comm.get_size("spatial") > 1: + self.inp_shape_loc = (self.trans_down.lat_shapes[comm.get_rank("h")], + self.trans_down.lon_shapes[comm.get_rank("w")]) + self.out_shape_loc = (self.itrans_up.lat_shapes[comm.get_rank("h")], + self.itrans_up.lon_shapes[comm.get_rank("w")]) + self.h_loc = 
self.itrans.lat_shapes[comm.get_rank("h")] + self.w_loc = self.itrans.lon_shapes[comm.get_rank("w")] + else: + self.inp_shape_loc = (self.trans_down.nlat, self.trans_down.nlon) + self.out_shape_loc = (self.itrans_up.nlat, self.itrans_up.nlon) + self.h_loc = self.itrans.nlat + self.w_loc = self.itrans.nlon + + @torch.jit.ignore + def no_weight_decay(self): + return {"pos_embed", "cls_token"} + + def _forward_features(self, x): + for r in range(self.repeat_layers): + for blk in self.blocks: + if self.checkpointing >= 3: + x = checkpoint(blk, x, use_reentrant=False) + else: + x = blk(x) + + return x + + def forward(self, x): + # save big skip + if self.big_skip: + # if output shape differs, use the spectral transforms to change resolution + if self.out_shape != self.inp_shape: + xtype = x.dtype + # only take the predicted channels as residual + residual = x.to(torch.float32) + with amp.autocast(enabled=False): + residual = self.trans_down(residual) + residual = residual.contiguous() + residual = self.itrans_up(residual) + residual = residual.to(dtype=xtype) + else: + # only take the predicted channels + residual = x + + if comm.get_size("fin") > 1: + x = scatter_to_parallel_region(x, 1, "fin") + + if self.checkpointing >= 1: + x = checkpoint(self.encoder, x, use_reentrant=False) + else: + x = self.encoder(x) + + if hasattr(self, "pos_embed"): + if self.pos_embed.type == "frequency": + pos_embed = torch.stack([self.pos_embed[0], nn.functional.pad(self.pos_embed[1], (1, 0), "constant", 0)], dim=-1) + with amp.autocast(enabled=False): + pos_embed = self.itrans_up(torch.view_as_complex(pos_embed)) + else: + pos_embed = self.pos_embed + + # add pos embed + x = x + pos_embed + + # maybe clean the padding just in case + x = self.pos_drop(x) + + # do the feature extraction + x = self._forward_features(x) + + if self.checkpointing >= 1: + x = checkpoint(self.decoder, x, use_reentrant=False) + else: + x = self.decoder(x) + + if hasattr(self.decoder, "comm_out_name") and (comm.get_size(self.decoder.comm_out_name) > 1): + x = gather_from_parallel_region(x, 1, self.gather_shapes, self.decoder.comm_out_name) + + if self.big_skip: + x = x + self.residual_transform(residual) + + return x + +# this part exposes the model to modulus by constructing modulus Modules +@dataclass +class SphericalFourierNeuralOperatorNetMetaData(ModelMetaData): + name: str = "SFNO" + + jit: bool = False + cuda_graphs: bool = False + amp_cpu: bool = False + amp_gpu: bool = True + +SFNO = modulus.Module.from_torch( + SphericalFourierNeuralOperatorNet, + SphericalFourierNeuralOperatorNetMetaData() +) + +class FourierNeuralOperatorNet(SphericalFourierNeuralOperatorNet): + def __init__(self, *args, **kwargs): + return super().__init__(*args, spectral_transform="fft", **kwargs) + +@dataclass +class FourierNeuralOperatorNetMetaData(ModelMetaData): + name: str = "FNO" + + jit: bool = False + cuda_graphs: bool = False + amp_cpu: bool = False + amp_gpu: bool = True + +FNO = modulus.Module.from_torch( + FourierNeuralOperatorNet, + FourierNeuralOperatorNetMetaData() +) \ No newline at end of file diff --git a/networks/vit.py b/networks/vit.py new file mode 100644 index 0000000..426520a --- /dev/null +++ b/networks/vit.py @@ -0,0 +1,231 @@ +import math + +import torch.nn.functional as F +import torch +import torch.nn as nn +from functools import partial + +# mp stuff +from makani.utils import comm +from makani.models.common import DropPath, MLP, PatchEmbed +from makani.mpu.layers import DistributedMatmul, DistributedMLP, DistributedAttention + 
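+
+# vision transformer (ViT) building blocks; when the "fin"/"fout" model-parallel
+# communicators have size > 1, the distributed attention and MLP variants from
+# makani.mpu.layers are used in place of the local ones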
+ +class Attention(nn.Module): + def __init__( + self, + dim, + input_format="traditional", + num_heads=8, + qkv_bias=False, + qk_norm=False, + attn_drop_rate=0.0, + proj_drop_rate=0.0, + norm_layer=nn.LayerNorm, + ): + super().__init__() + assert dim % num_heads == 0, "dim should be divisible by num_heads" + self.num_heads = num_heads + self.head_dim = dim // num_heads + self.scale = self.head_dim**-0.5 + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() + self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() + self.attn_drop_rate = attn_drop_rate + + self.proj = nn.Linear(dim, dim) + + if proj_drop_rate > 0: + self.proj_drop = nn.Dropout(proj_drop_rate) + else: + self.proj_drop = nn.Identity() + + def forward(self, x): + B, N, C = x.shape + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4) + q, k, v = qkv.unbind(0) + q, k = self.q_norm(q), self.k_norm(k) + + x = F.scaled_dot_product_attention(q, k, v, dropout_p=self.attn_drop_rate) + + x = x.transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class Block(nn.Module): + def __init__( + self, + dim, + num_heads, + mlp_ratio=4.0, + qkv_bias=False, + mlp_drop_rate=0.0, + attn_drop_rate=0.0, + path_drop_rate=0.0, + act_layer=nn.GELU, + norm_layer=nn.LayerNorm, + comm_inp_name="fin", + comm_hidden_name="fout", + ): + super().__init__() + + if (comm.get_size(comm_inp_name) * comm.get_size(comm_hidden_name)) > 1: + self.attn = DistributedAttention( + dim, + input_format="traditional", + comm_inp_name=comm_inp_name, + comm_hidden_name=comm_hidden_name, + num_heads=num_heads, + qkv_bias=qkv_bias, + attn_drop_rate=attn_drop_rate, + proj_drop_rate=mlp_drop_rate, + norm_layer=norm_layer, + ) + else: + self.attn = Attention( + dim, input_format="traditional", num_heads=num_heads, qkv_bias=qkv_bias, attn_drop_rate=attn_drop_rate, proj_drop_rate=mlp_drop_rate, norm_layer=norm_layer + ) + self.drop_path = DropPath(path_drop_rate) if path_drop_rate > 0.0 else nn.Identity() + + self.norm1 = norm_layer(dim) + self.norm2 = norm_layer(dim) + + mlp_hidden_dim = int(dim * mlp_ratio) + + # distribute MLP for model parallelism + if (comm.get_size(comm_inp_name) * comm.get_size(comm_hidden_name)) > 1: + self.mlp = DistributedMLP( + in_features=dim, + hidden_features=mlp_hidden_dim, + out_features=dim, + act_layer=act_layer, + drop_rate=mlp_drop_rate, + input_format="traditional", + comm_inp_name=comm_inp_name, + comm_hidden_name=comm_hidden_name, + ) + else: + self.mlp = MLP(in_features=dim, hidden_features=mlp_hidden_dim, out_features=dim, act_layer=act_layer, drop_rate=mlp_drop_rate, input_format="traditional") + + def forward(self, x): + # flatten transpose: + y = self.attn(self.norm1(x)) + x = x + self.drop_path(y) + x = self.norm2(x) + x = x + self.drop_path(self.mlp(x)) + + return x + + +class VisionTransformer(nn.Module): + def __init__( + self, + inp_shape=[224, 224], + patch_size=(16, 16), + inp_chans=3, + out_chans=3, + embed_dim=768, + depth=12, + num_heads=12, + mlp_ratio=4.0, + qkv_bias=True, + mlp_drop_rate=0.0, + attn_drop_rate=0.0, + path_drop_rate=0.0, + norm_layer="layer_norm", + comm_inp_name="fin", + comm_hidden_name="fout", + **kwargs, + ): + super().__init__() + self.num_features = self.embed_dim = embed_dim + self.patch_size = patch_size + self.img_size = inp_shape + self.out_ch = out_chans + self.comm_inp_name = comm_inp_name + self.comm_hidden_name = comm_hidden_name + + 
self.patch_embed = PatchEmbed(img_size=self.img_size, patch_size=patch_size, in_chans=inp_chans, embed_dim=self.embed_dim) + num_patches = self.patch_embed.num_patches + + # annotate for distributed + self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, self.embed_dim)) + self.pos_embed.is_shared_mp = [] + + self.pos_drop = nn.Dropout(p=path_drop_rate) + + dpr = [x.item() for x in torch.linspace(0, path_drop_rate, depth)] # stochastic depth decay rule + + if norm_layer == "layer_norm": + norm_layer_handle = nn.LayerNorm + else: + raise NotImplementedError(f"Error, normalization layer type {norm_layer} not implemented for ViT.") + + self.blocks = nn.ModuleList( + [ + Block( + dim=embed_dim, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + mlp_drop_rate=mlp_drop_rate, + attn_drop_rate=attn_drop_rate, + path_drop_rate=dpr[i], + norm_layer=norm_layer_handle, + comm_inp_name=comm_inp_name, + comm_hidden_name=comm_hidden_name, + ) + for i in range(depth) + ] + ) + + self.norm = norm_layer_handle(embed_dim) + + self.out_size = self.out_ch * self.patch_size[0] * self.patch_size[1] + + self.head = nn.Linear(embed_dim, self.out_size, bias=False) + + nn.init.trunc_normal_(self.pos_embed, std=0.02) + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + nn.init.trunc_normal_(m.weight, std=0.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + def prepare_tokens(self, x): + B, C, H, W = x.shape + x = self.patch_embed(x).transpose(1, 2) # patch linear embedding + + # add positional encoding to each token + x = x + self.pos_embed + return self.pos_drop(x) + + def forward_head(self, x): + B, _, _ = x.shape # B x N x embed_dim + x = x.reshape(B, self.patch_embed.red_img_size[0], self.patch_embed.red_img_size[1], self.embed_dim) + B, h, w, _ = x.shape + + # apply head + x = self.head(x) + x = x.reshape(shape=(B, h, w, self.patch_size[0], self.patch_size[1], self.out_ch)) + x = torch.einsum("nhwpqc->nchpwq", x) + x = x.reshape(shape=(B, self.out_ch, self.img_size[0], self.img_size[1])) + + return x + + def forward(self, x): + x = self.prepare_tokens(x) + for blk in self.blocks: + x = blk(x) + + x = self.norm(x) + x = self.forward_head(x) + return x diff --git a/preprocessor.py b/preprocessor.py new file mode 100644 index 0000000..6689112 --- /dev/null +++ b/preprocessor.py @@ -0,0 +1,426 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
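As a shape sketch, the patch-to-image reassembly performed in `VisionTransformer.forward_head` above can be checked in isolation; the sizes below are illustrative (16x16 patches on a 224x224 image with 2 output channels):

```python
import torch

# per-patch head outputs: (batch, patch rows, patch cols, patch height, patch width, channels)
B, h, w, p, C = 1, 14, 14, 16, 2
tokens = torch.randn(B, h, w, p, p, C)

# reorder to (batch, channels, patch rows, patch height, patch cols, patch width) ...
img = torch.einsum("nhwpqc->nchpwq", tokens)
# ... and stitch the patches back into a full-resolution image
img = img.reshape(B, C, h * p, w * p)

assert img.shape == (1, 2, 224, 224)
```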
+ +import copy +from functools import partial + +import numpy as np + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from makani.utils import comm +from makani.utils.grids import GridConverter +from modulus.distributed.mappings import reduce_from_parallel_region, copy_to_parallel_region + + +class Preprocessor2D(nn.Module): + def __init__(self, params): + super(Preprocessor2D, self).__init__() + + self.n_history = params.n_history + self.history_normalization_mode = params.history_normalization_mode + if self.history_normalization_mode == "exponential": + self.history_normalization_decay = params.history_normalization_decay + # inverse ordering, since first element is oldest + history_normalization_weights = torch.exp((-self.history_normalization_decay) * torch.arange(start=self.n_history, end=-1, step=-1, dtype=torch.float32)) + history_normalization_weights = history_normalization_weights / torch.sum(history_normalization_weights) + history_normalization_weights = torch.reshape(history_normalization_weights, (1, -1, 1, 1, 1)) + elif self.history_normalization_mode == "mean": + history_normalization_weights = torch.Tensor(1.0 / float(self.n_history + 1), dtype=torch.float32) + history_normalization_weights = torch.reshape(history_normalization_weights, (1, -1, 1, 1, 1)) + else: + history_normalization_weights = torch.ones(self.n_history + 1, dtype=torch.float32) + self.register_buffer("history_normalization_weights", history_normalization_weights, persistent=False) + self.history_mean = None + self.history_std = None + self.history_diff_mean = None + self.history_diff_var = None + self.history_eps = 1e-6 + + # residual normalization + self.learn_residual = params.target == "residual" + if self.learn_residual and (params.normalize_residual): + with torch.no_grad(): + residual_scale = torch.from_numpy(np.load(params.time_diff_stds_path)).to(torch.float32) + self.register_buffer("residual_scale", residual_scale, persistent=False) + else: + self.residual_scale = None + + # image shape + self.img_shape = [params.img_shape_x, params.img_shape_y] + + # unpredicted input channels: + self.unpredicted_inp_train = None + self.unpredicted_tar_train = None + self.unpredicted_inp_eval = None + self.unpredicted_tar_eval = None + + # process static features + static_features = None + # needed for sharding + start_x = params.img_local_offset_x + end_x = min(start_x + params.img_local_shape_x, params.img_shape_x) + start_y = params.img_local_offset_y + end_y = min(start_y + params.img_local_shape_y, params.img_shape_y) + + # set up grid + if params.add_grid: + with torch.no_grad(): + if hasattr(params, "lat") and hasattr(params, "lon"): + lat = torch.tensor(params.lat).to(torch.float32) + lon = torch.tensor(params.lon).to(torch.float32) + + # convert grid if required + gconv = GridConverter(params.data_grid_type, params.model_grid_type, torch.deg2rad(lat), torch.deg2rad(lon)) + tx, ty = gconv.get_dst_coords() + tx = tx.to(torch.float32) + ty = ty.to(torch.float32) + else: + tx = torch.linspace(0, 1, params.img_shape_x + 1, dtype=torch.float32)[0:-1] + ty = torch.linspace(0, 1, params.img_shape_y + 1, dtype=torch.float32)[0:-1] + + x_grid, y_grid = torch.meshgrid(tx, ty, indexing="ij") + x_grid, y_grid = x_grid.unsqueeze(0).unsqueeze(0), y_grid.unsqueeze(0).unsqueeze(0) + grid = torch.cat([x_grid, y_grid], dim=1) + + # shard spatially: + grid = grid[:, :, start_x:end_x, start_y:end_y] + + # transform if requested + if params.gridtype == "sinusoidal": + num_freq = 1 + if 
hasattr(params, "grid_num_frequencies"): + num_freq = int(params.grid_num_frequencies) + + singrid = None + for freq in range(1, num_freq + 1): + if singrid is None: + singrid = torch.sin(grid) + else: + singrid = torch.cat([singrid, torch.sin(freq * grid)], dim=1) + + static_features = singrid + else: + static_features = grid + + if params.add_orography: + from makani.utils.conditioning_inputs import get_orography + + with torch.no_grad(): + oro = torch.tensor(get_orography(params.orography_path), dtype=torch.float32) + oro = torch.reshape(oro, (1, 1, oro.shape[0], oro.shape[1])) + + # normalize + eps = 1.0e-6 + oro = (oro - torch.mean(oro)) / (torch.std(oro) + eps) + + # shard + oro = oro[:, :, start_x:end_x, start_y:end_y] + + if static_features is None: + static_features = oro + else: + static_features = torch.cat([static_features, oro], dim=1) + + if params.add_landmask: + from makani.utils.conditioning_inputs import get_land_mask + + with torch.no_grad(): + lsm = torch.tensor(get_land_mask(params.landmask_path), dtype=torch.long) + # one hot encode and move channels to front: + lsm = torch.permute(torch.nn.functional.one_hot(lsm), (2, 0, 1)).to(torch.float32) + lsm = torch.reshape(lsm, (1, lsm.shape[0], lsm.shape[1], lsm.shape[2])) + + # shard + lsm = lsm[:, :, start_x:end_x, start_y:end_y] + + if static_features is None: + static_features = lsm + else: + static_features = torch.cat([static_features, lsm], dim=1) + + self.do_add_static_features = False + if static_features is not None: + self.do_add_static_features = True + self.register_buffer("static_features", static_features, persistent=False) + + def flatten_history(self, x): + # flatten input + if x.dim() == 5: + b_, t_, c_, h_, w_ = x.shape + x = torch.reshape(x, (b_, t_ * c_, h_, w_)) + + return x + + def expand_history(self, x, nhist): + if x.dim() == 4: + b_, ct_, h_, w_ = x.shape + x = torch.reshape(x, (b_, nhist, ct_ // nhist, h_, w_)) + return x + + def add_residual(self, x, dx): + if self.learn_residual: + if self.residual_scale is not None: + dx = dx * self.residual_scale + + # add residual: deal with history + x = self.expand_history(x, nhist=self.n_history + 1) + x[:, -1, ...] = x[:, -1, ...] 
+ dx + x = self.flatten_history(x) + else: + x = dx + + return x + + def add_static_features(self, x): + if self.do_add_static_features: + # we need to replicate the grid for each batch: + static = torch.tile(self.static_features, dims=(x.shape[0], 1, 1, 1)) + x = torch.cat([x, static], dim=1) + + return x + + def remove_static_features(self, x): + # only remove if something was added in the first place + if self.do_add_static_features: + nfeat = self.static_features.shape[1] + x = x[:, : x.shape[1] - nfeat, :, :] + return x + + def append_history(self, x1, x2, step): + # take care of unpredicted features first + # this is necessary in order to copy the targets unpredicted features + # (such as zenith angle) into the inputs unpredicted features, + # such that they can be forward in the next autoregressive step + # extract utar + + # update the unpredicted input + if self.training: + if (self.unpredicted_tar_train is not None) and (step < self.unpredicted_tar_train.shape[1]): + utar = self.unpredicted_tar_train[:, step : (step + 1), :, :, :] + if self.n_history == 0: + self.unpredicted_inp_train.copy_(utar) + else: + self.unpredicted_inp_train.copy_(torch.cat([self.unpredicted_inp_train[:, 1:, :, :, :], utar], dim=1)) + else: + if (self.unpredicted_tar_eval is not None) and (step < self.unpredicted_tar_eval.shape[1]): + utar = self.unpredicted_tar_eval[:, step : (step + 1), :, :, :] + if self.n_history == 0: + self.unpredicted_inp_eval.copy_(utar) + else: + self.unpredicted_inp_eval.copy_(torch.cat([self.unpredicted_inp_eval[:, 1:, :, :, :], utar], dim=1)) + + if self.n_history > 0: + # this is more complicated + x1 = self.expand_history(x1, nhist=self.n_history + 1) + x2 = self.expand_history(x2, nhist=1) + + # append + res = torch.cat([x1[:, 1:, :, :, :], x2], dim=1) + + # flatten again + res = self.flatten_history(res) + else: + res = x2 + + return res + + def append_channels(self, x, xc): + xdim = x.dim() + x = self.expand_history(x, self.n_history + 1) + + xc = self.expand_history(xc, self.n_history + 1) + + # concatenate + xo = torch.cat([x, xc], dim=2) + + # flatten if requested + if xdim == 4: + xo = self.flatten_history(xo) + + return xo + + def history_compute_stats(self, x): + if self.history_normalization_mode == "none": + self.history_mean = torch.zeros((1, 1, 1, 1), dtype=torch.float32, device=x.device) + self.history_std = torch.ones((1, 1, 1, 1), dtype=torch.float32, device=x.device) + elif self.history_normalization_mode == "timediff": + # reshaping + xdim = x.dim() + if xdim == 4: + b_, c_, h_, w_ = x.shape + xr = torch.reshape(x, (b_, (self.n_history + 1), c_ // (self.n_history + 1), h_, w_)) + else: + xshape = x.shape + xr = x + + # time difference mean: + self.history_diff_mean = torch.mean(torch.sum(xr[:, 1:, ...] - xr[:, 0:-1, ...], dim=(4, 5)), dim=(1, 2)) + # reduce across gpus + if comm.get_size("spatial") > 1: + self.history_diff_mean = reduce_from_parallel_region(self.history_diff_mean, "spatial") + self.history_diff_mean = self.history_diff_mean / float(self.img_shape[0] * self.img_shape[1]) + + # time difference std + self.history_diff_var = torch.mean(torch.sum(torch.square((xr[:, 1:, ...] 
- xr[:, 0:-1, ...]) - self.history_diff_mean), dim=(4, 5)), dim=(1, 2)) + # reduce across gpus + if comm.get_size("spatial") > 1: + self.history_diff_var = reduce_from_parallel_region(self.history_diff_var, "spatial") + self.history_diff_var = self.history_diff_var / float(self.img_shape[0] * self.img_shape[1]) + + # time difference stds + self.history_diff_mean = copy_to_parallel_region(self.history_diff_mean, "spatial") + self.history_diff_var = copy_to_parallel_region(self.history_diff_var, "spatial") + else: + xdim = x.dim() + if xdim == 4: + b_, c_, h_, w_ = x.shape + xr = torch.reshape(x, (b_, (self.n_history + 1), c_ // (self.n_history + 1), h_, w_)) + else: + xshape = x.shape + xr = x + + # mean + # compute weighted mean over dim 1, but sum over dim=3,4 + self.history_mean = torch.sum(xr * self.history_normalization_weights, dim=(1, 3, 4), keepdim=True) + # reduce across gpus + if comm.get_size("spatial") > 1: + self.history_mean = reduce_from_parallel_region(self.history_mean, "spatial") + self.history_mean = self.history_mean / float(self.img_shape[0] * self.img_shape[1]) + + # compute std + self.history_std = torch.sum(torch.square(xr - self.history_mean) * self.history_normalization_weights, dim=(1, 3, 4), keepdim=True) + # reduce across gpus + if comm.get_size("spatial") > 1: + self.history_std = reduce_from_parallel_region(self.history_std, "spatial") + self.history_std = torch.sqrt(self.history_std / float(self.img_shape[0] * self.img_shape[1])) + + # squeeze + self.history_mean = torch.squeeze(self.history_mean, dim=1) + self.history_std = torch.squeeze(self.history_std, dim=1) + + # copy to parallel region + self.history_mean = copy_to_parallel_region(self.history_mean, "spatial") + self.history_std = copy_to_parallel_region(self.history_std, "spatial") + + return + + def history_normalize(self, x, target=False): + if self.history_normalization_mode in ["none", "timediff"]: + return x + + xdim = x.dim() + if xdim == 4: + b_, c_, h_, w_ = x.shape + xr = torch.reshape(x, (b_, (self.n_history + 1), c_ // (self.n_history + 1), h_, w_)) + else: + xshape = x.shape + xr = x + x = self.flatten_history(x) + + # normalize + if target: + # strip off the unpredicted channels + xn = (x - self.history_mean[:, : x.shape[1], :, :]) / self.history_std[:, : x.shape[1], :, :] + else: + # tile to include history + hm = torch.tile(self.history_mean, (1, self.n_history + 1, 1, 1)) + hs = torch.tile(self.history_std, (1, self.n_history + 1, 1, 1)) + xn = (x - hm) / hs + + if xdim == 5: + xn = torch.reshape(xn, xshape) + + return xn + + def history_denormalize(self, xn, target=False): + if self.history_normalization_mode in ["none", "timediff"]: + return xn + + assert self.history_mean is not None + assert self.history_std is not None + + xndim = xn.dim() + if xndim == 5: + xnshape = xn.shape + xn = self.flatten_history(xn) + + # de-normalize + if target: + # strip off the unpredicted channels + x = xn * self.history_std[:, : xn.shape[1], :, :] + self.history_mean[:, : xn.shape[1], :, :] + else: + # tile to include history + hm = torch.tile(self.history_mean, (1, self.n_history + 1, 1, 1)) + hs = torch.tile(self.history_std, (1, self.n_history + 1, 1, 1)) + x = xn * hs + hm + + if xndim == 5: + x = torch.reshape(x, xnshape) + + return x + + def cache_unpredicted_features(self, x, y, xz=None, yz=None): + if self.training: + if (self.unpredicted_inp_train is not None) and (xz is not None): + self.unpredicted_inp_train.copy_(xz) + else: + self.unpredicted_inp_train = xz + + if 
(self.unpredicted_tar_train is not None) and (yz is not None): + self.unpredicted_tar_train.copy_(yz) + else: + self.unpredicted_tar_train = yz + else: + if (self.unpredicted_inp_eval is not None) and (xz is not None): + self.unpredicted_inp_eval.copy_(xz) + else: + self.unpredicted_inp_eval = xz + + if (self.unpredicted_tar_eval is not None) and (yz is not None): + self.unpredicted_tar_eval.copy_(yz) + else: + self.unpredicted_tar_eval = yz + + return x, y + + def append_unpredicted_features(self, inp): + if self.training: + if self.unpredicted_inp_train is not None: + inp = self.append_channels(inp, self.unpredicted_inp_train) + else: + if self.unpredicted_inp_eval is not None: + inp = self.append_channels(inp, self.unpredicted_inp_eval) + return inp + + def remove_unpredicted_features(self, inp): + if self.training: + if self.unpredicted_inp_train is not None: + inpf = self.expand_history(inp, nhist=self.n_history + 1) + inpc = inpf[:, :, : inpf.shape[2] - self.unpredicted_inp_train.shape[2], :, :] + inp = self.flatten_history(inpc) + else: + if self.unpredicted_inp_eval is not None: + inpf = self.expand_history(inp, nhist=self.n_history + 1) + inpc = inpf[:, :, : inpf.shape[2] - self.unpredicted_inp_eval.shape[2], :, :] + inp = self.flatten_history(inpc) + + return inp + + +def get_preprocessor(params): + return Preprocessor2D(params) diff --git a/stepper.py b/stepper.py new file mode 100644 index 0000000..26cfbd7 --- /dev/null +++ b/stepper.py @@ -0,0 +1,128 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
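Much of `Preprocessor2D` above revolves around switching between a flattened `(B, T*C, H, W)` layout and an expanded `(B, T, C, H, W)` layout of the history dimension. The following stand-alone sketch mirrors what `flatten_history` and `expand_history` do to the tensor shapes; all sizes are made up for illustration:

```python
import torch

# Made-up sizes: batch 2, n_history = 1 (two time steps), 4 channels, 8x16 grid.
b, n_history, c, h, w = 2, 1, 4, 8, 16
x = torch.randn(b, n_history + 1, c, h, w)

# flatten_history: fold the time dimension into the channel dimension,
# giving the (B, T*C, H, W) layout that the networks consume.
x_flat = x.reshape(b, (n_history + 1) * c, h, w)

# expand_history: recover the (B, T, C, H, W) layout, which append_history
# and add_residual use to address the most recent time step.
x_back = x_flat.reshape(b, n_history + 1, c, h, w)
assert torch.equal(x, x_back)
```

Keeping the two layouts interchangeable is what lets `append_history` and `add_residual` manipulate individual time steps while the networks only ever see a flat channel dimension.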
+
+import torch
+import torch.nn as nn
+
+from makani.models.preprocessor import Preprocessor2D
+
+class SingleStepWrapper(nn.Module):
+    """Wrapper for training/inference with a single autoregressive step."""
+
+    def __init__(self, params, model_handle):
+        super(SingleStepWrapper, self).__init__()
+        self.preprocessor = Preprocessor2D(params)
+        self.model = model_handle()
+
+    def forward(self, inp):
+        # first append unpredicted features
+        inpa = self.preprocessor.append_unpredicted_features(inp)
+
+        # now normalize
+        self.preprocessor.history_compute_stats(inpa)
+        inpan = self.preprocessor.history_normalize(inpa, target=False)
+
+        # now add static features if requested
+        inpans = self.preprocessor.add_static_features(inpan)
+
+        # forward pass
+        yn = self.model(inpans)
+
+        # undo normalization
+        y = self.preprocessor.history_denormalize(yn, target=True)
+
+        # add residual (for residual learning; no-op for direct learning)
+        y = self.preprocessor.add_residual(inp, y)
+
+        return y
+
+
+class MultiStepWrapper(nn.Module):
+    """Wrapper for training/inference with multiple autoregressive steps."""
+
+    def __init__(self, params, model_handle):
+        super(MultiStepWrapper, self).__init__()
+        self.preprocessor = Preprocessor2D(params)
+        self.model = model_handle()
+        self.residual_mode = params.target == "residual"
+
+        # collect parameters for history
+        self.n_future = params.n_future
+
+    def _forward_train(self, inp):
+        result = []
+        inpt = inp
+        for step in range(self.n_future + 1):
+            # add unpredicted features
+            inpa = self.preprocessor.append_unpredicted_features(inpt)
+
+            # do history normalization
+            self.preprocessor.history_compute_stats(inpa)
+            inpan = self.preprocessor.history_normalize(inpa, target=False)
+
+            # add static features
+            inpans = self.preprocessor.add_static_features(inpan)
+
+            # prediction
+            predn = self.model(inpans)
+
+            # append the denormalized result to the output list;
+            # important to do this here, before the normalization stats
+            # are updated in the next iteration:
+            pred = self.preprocessor.history_denormalize(predn, target=True)
+            # add residual (for residual learning; no-op for direct learning)
+            pred = self.preprocessor.add_residual(inpt, pred)
+            # append output
+            result.append(pred)
+
+            if step == self.n_future:
+                break
+
+            # append history
+            inpt = self.preprocessor.append_history(inpt, pred, step)
+
+        # concat the tensors along the channel dim to be compatible with the flattened target
+        result = torch.cat(result, dim=1)
+
+        return result
+
+    def _forward_eval(self, inp):
+        # first append unpredicted features
+        inpa = self.preprocessor.append_unpredicted_features(inp)
+
+        # do history normalization
+        self.preprocessor.history_compute_stats(inpa)
+        inpan = self.preprocessor.history_normalize(inpa, target=False)
+
+        # add static features
+        inpans = self.preprocessor.add_static_features(inpan)
+
+        # forward pass
+        yn = self.model(inpans)
+
+        # undo the normalization here; doing it any later would use
+        # already outdated normalization stats
+        y = self.preprocessor.history_denormalize(yn, target=True)
+
+        # add residual (for residual learning; no-op for direct learning)
+        y = self.preprocessor.add_residual(inp, y)
+
+        return y
+
+    def forward(self, inp):
+        # decide which routine to call
+        if self.training:
+            y = self._forward_train(inp)
+        else:
+            y = self._forward_eval(inp)
+
+        return y
\ No newline at end of file
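For context, an autoregressive inference loop built on these wrappers might look roughly like the sketch below. `params`, `model_handle` and `initial_condition` are placeholders (in makani they come from the configuration, the model registry and the data pipeline, respectively), so treat this as an outline of the intended usage rather than a runnable script:

```python
import torch

from makani.models.stepper import SingleStepWrapper

# Placeholders: `params` and `model_handle` are assumed to be provided by the
# makani configuration and model registry; they are not defined in this sketch.
model = SingleStepWrapper(params, model_handle)
model.eval()

num_steps = 4  # hypothetical rollout length
state = initial_condition  # placeholder tensor of shape (B, (n_history + 1) * C, H, W)

with torch.no_grad():
    for step in range(num_steps):
        # one forward step: normalize, predict, denormalize, add residual
        pred = model(state)
        # roll the prediction into the history so it becomes the next input;
        # append_history also rotates any cached unpredicted channels
        # (e.g. zenith angle) forward to the next time step
        state = model.preprocessor.append_history(state, pred, step)
```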