Skip to content

Commit

Permalink
origin
Browse files Browse the repository at this point in the history
  • Loading branch information
jqs2019011556 committed Jul 25, 2024
0 parents commit 901e1f3
Show file tree
Hide file tree
Showing 31 changed files with 3,863 additions and 0 deletions.
57 changes: 57 additions & 0 deletions Readme.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
## Networks

This folder contains models for training and associated code. Models that are currently supported can be queried by calling `networks.models.list_models()`.

### Directory structure
This folder is organized as follows:

```
makani
├── ...
├── models # code realted to ML models
│ ├── common # folder containing common features used in the neworks
│ │ ├── activations.py # complex activation functions
│ │ ├── contractions.py # einsum wrappers for complex contractions
│ │ ├── factorizations.py # tensor factorizations
│ │ ├── layers.py # common layers such as MLPs and wrappers for FFTs
│ │ └── spectral_convolution.py # spectral convolution layers for (S)FNO architectures
│ ├── networks # contains the actual architectures
│ │ ├── afnonet_v2.py # optimized AFNO
│ │ ├── afnonet.py # AFNO implementation
│ │ ├── debug.py # dummy network for debugging purposes
│ │ ├── sfnonet.py # implementation of (S)FNO
│ │ └── vit.py # implementation of a VIT
│ ├── helpers.py # helper functions
│ ├── model_package.py # model package implementation
│ ├── model_registry.py # model registry with get_model routine that takes care of wrapping the model
│ ├── preprocessor.py # implementation of preprocessor for dealing with unpredicted channels
│ ├── steppers.py # implements multistep and singlestep wrappers
│ └── Readme.md # this file
...
```

### Model registry

The model registry is a central place for organizing models in makani. By default, it contains the architectures contained in the `networks` directory, to which makani also exposes entrypoints. Models can be instantiated via

```python
from makani.models import model_registry

model = model_registry.get_model(params)
```

where `params` is the parameters object used to instantiate the model. Custom models can be registered in the registry using the `register` method. Models are required to take keyword arguments. These are automatically parsed from the `params` datastructure and passed to the model.

In addition, models can be automatically registered through the `nettype` field in the configuration yaml file. To do so, the user can specify

```yaml
nettype: "path/to/model_file.py:ModelName"
```
using the path to the model file and the class name `ModelName`.

### Model packages

Model packages are used for seamless inference outside of this repository. They define a flexible interfact which takes care of normalization, unpredicted channels etc. Model packages seemlessly integrate with [earth2mip](https://github.com/NVIDIA/earth2mip).

19 changes: 19 additions & 0 deletions __init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .preprocessor import Preprocessor2D
from .stepper import SingleStepWrapper, MultiStepWrapper

import makani.models.model_registry
Binary file added __pycache__/__init__.cpython-310.pyc
Binary file not shown.
Binary file added __pycache__/helpers.cpython-310.pyc
Binary file not shown.
Binary file added __pycache__/model_package.cpython-310.pyc
Binary file not shown.
Binary file added __pycache__/model_registry.cpython-310.pyc
Binary file not shown.
Binary file added __pycache__/preprocessor.cpython-310.pyc
Binary file not shown.
Binary file added __pycache__/stepper.cpython-310.pyc
Binary file not shown.
18 changes: 18 additions & 0 deletions common/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .activations import ComplexReLU, ComplexActivation
from .layers import DropPath, PatchEmbed, EncoderDecoder, MLP, RealFFT2, InverseRealFFT2
from .spectral_convolution import SpectralConv, FactorizedSpectralConv, SpectralAttention
Binary file added common/__pycache__/__init__.cpython-310.pyc
Binary file not shown.
Binary file added common/__pycache__/activations.cpython-310.pyc
Binary file not shown.
Binary file added common/__pycache__/contractions.cpython-310.pyc
Binary file not shown.
Binary file not shown.
Binary file added common/__pycache__/layers.cpython-310.pyc
Binary file not shown.
Binary file not shown.
100 changes: 100 additions & 0 deletions common/activations.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
from torch import nn


class ComplexReLU(nn.Module):
"""
Complex-valued variants of the ReLU activation function
"""

def __init__(self, negative_slope=0.0, mode="real", bias_shape=None, scale=1.0):
super().__init__()

# store parameters
self.mode = mode
if self.mode in ["modulus", "halfplane"]:
if bias_shape is not None:
self.bias = nn.Parameter(scale * torch.ones(bias_shape, dtype=torch.float32))
else:
self.bias = nn.Parameter(scale * torch.ones((1), dtype=torch.float32))
else:
self.bias = 0

self.negative_slope = negative_slope
self.act = nn.LeakyReLU(negative_slope=negative_slope)

def forward(self, z: torch.Tensor) -> torch.Tensor:
if self.mode == "cartesian":
zr = torch.view_as_real(z)
za = self.act(zr)
out = torch.view_as_complex(za)

elif self.mode == "modulus":
zabs = torch.sqrt(torch.square(z.real) + torch.square(z.imag))
out = torch.where(zabs + self.bias > 0, (zabs + self.bias) * z / zabs, 0.0)
# out = self.act(zabs - self.bias) * torch.exp(1.j * z.angle())

elif self.mode == "halfplane":
# bias is an angle parameter in this case
modified_angle = torch.angle(z) - self.bias
condition = torch.logical_and((0.0 <= modified_angle), (modified_angle < torch.pi / 2.0))
out = torch.where(condition, z, self.negative_slope * z)

elif self.mode == "real":
zr = torch.view_as_real(z)
outr = zr.clone()
outr[..., 0] = self.act(zr[..., 0])
out = torch.view_as_complex(outr)

else:
raise NotImplementedError

return out


class ComplexActivation(nn.Module):
def __init__(self, activation, mode="cartesian", bias_shape=None):
super().__init__()

# store parameters
self.mode = mode
if self.mode == "modulus":
if bias_shape is not None:
self.bias = nn.Parameter(torch.zeros(bias_shape, dtype=torch.float32))
else:
self.bias = nn.Parameter(torch.zeros((1), dtype=torch.float32))
else:
bias = torch.zeros((1), dtype=torch.float32)
self.register_buffer("bias", bias)

# real valued activation
self.act = activation

def forward(self, z: torch.Tensor) -> torch.Tensor:
if self.mode == "cartesian":
zr = torch.view_as_real(z)
za = self.act(zr)
out = torch.view_as_complex(za)
elif self.mode == "modulus":
zabs = torch.sqrt(torch.square(z.real) + torch.square(z.imag))
out = self.act(zabs + self.bias) * torch.exp(1.0j * z.angle())
else:
# identity
out = z

return out
178 changes: 178 additions & 0 deletions common/contractions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,178 @@
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch


@torch.jit.script
def _contract_rank(xc: torch.Tensor, wc: torch.Tensor, ac: torch.Tensor, bc: torch.Tensor) -> torch.Tensor:
# return torch.einsum("bixy,ior,xr,yr->boxy", x, w, a, b)
# xc = torch.view_as_complex(x)
# wc = w #torch.view_as_complex(w)
# ac = torch.view_as_complex(a)
# bc = torch.view_as_complex(b)
resc = torch.einsum("bixy,ior,xr,yr->boxy", xc, wc, ac, bc)
# res = torch.view_as_real(resc)
return resc


# # Helper routines for FNOs
@torch.jit.script
def compl_mul1d_fwd(ac: torch.Tensor, bc: torch.Tensor) -> torch.Tensor:
# ac = torch.view_as_complex(a)
# bc = torch.view_as_complex(b)
resc = torch.einsum("bix,io->box", ac, bc)
# res = torch.view_as_real(resc)
return resc


@torch.jit.script
def compl_muladd1d_fwd(ac: torch.Tensor, bc: torch.Tensor, cc: torch.Tensor) -> torch.Tensor:
tmpcc = compl_mul1d_fwd(ac, bc)
# cc = torch.view_as_complex(c)
return tmpcc + cc


@torch.jit.script
def compl_mul2d_fwd(ac: torch.Tensor, bc: torch.Tensor) -> torch.Tensor:
# ac = torch.view_as_complex(a)
# bc = torch.view_as_complex(b)
resc = torch.einsum("bixy,io->boxy", ac, bc)
# res = torch.view_as_real(resc)
return resc


@torch.jit.script
def compl_muladd2d_fwd(ac: torch.Tensor, bc: torch.Tensor, cc: torch.Tensor) -> torch.Tensor:
tmpcc = compl_mul2d_fwd(ac, bc)
# cc = torch.view_as_complex(c)
return tmpcc + cc


@torch.jit.script
def _contract_localconv_fwd(ac: torch.Tensor, bc: torch.Tensor) -> torch.Tensor:
# ac = torch.view_as_complex(a)
# bc = torch.view_as_complex(b)
resc = torch.einsum("bixy,iox->boxy", ac, bc)
# res = torch.view_as_real(resc)
return resc


@torch.jit.script
def _contract_blockconv_fwd(ac: torch.Tensor, bc: torch.Tensor) -> torch.Tensor:
# ac = torch.view_as_complex(a)
# bc = torch.view_as_complex(b)
resc = torch.einsum("bim,imn->bin", ac, bc)
# res = torch.view_as_real(resc)
return resc


@torch.jit.script
def _contractadd_blockconv_fwd(ac: torch.Tensor, bc: torch.Tensor, cc: torch.Tensor) -> torch.Tensor:
tmpcc = _contract_blockconv_fwd(ac, bc)
# cc = torch.view_as_complex(c)
return tmpcc + cc


# for the experimental layer
@torch.jit.script
def compl_exp_mul2d_fwd(ac: torch.Tensor, bc: torch.Tensor) -> torch.Tensor:
# ac = torch.view_as_complex(a)
# bc = torch.view_as_complex(b)
resc = torch.einsum("bixy,xio->boxy", ac, bc)
# res = torch.view_as_real(resc)
return resc


@torch.jit.script
def compl_exp_muladd2d_fwd(ac: torch.Tensor, bc: torch.Tensor, cc: torch.Tensor) -> torch.Tensor:
tmpcc = compl_exp_mul2d_fwd(ac, bc)
# cc = torch.view_as_complex(c)
return tmpcc + cc


@torch.jit.script
def real_mul2d_fwd(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
res = torch.einsum("bixy,io->boxy", a, b)
return res


@torch.jit.script
def real_muladd2d_fwd(a: torch.Tensor, b: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
res = real_mul2d_fwd(a, b) + c
return res


# new contractions set to replace older ones. We use complex


@torch.jit.script
def _contract_diagonal(ac: torch.Tensor, bc: torch.Tensor) -> torch.Tensor:
# ac = torch.view_as_complex(a)
# bc = torch.view_as_complex(b)
resc = torch.einsum("bixy,ioxy->boxy", ac, bc)
# res = torch.view_as_real(resc)
return resc


@torch.jit.script
def _contract_dhconv(ac: torch.Tensor, bc: torch.Tensor) -> torch.Tensor:
# ac = torch.view_as_complex(a)
# bc = torch.view_as_complex(b)
resc = torch.einsum("bixy,iox->boxy", ac, bc)
# res = torch.view_as_real(resc)
return resc


@torch.jit.script
def _contract_sep_diagonal(ac: torch.Tensor, bc: torch.Tensor) -> torch.Tensor:
# ac = torch.view_as_complex(a)
# bc = torch.view_as_complex(b)
resc = torch.einsum("bixy,ixy->boxy", ac, bc)
# res = torch.view_as_real(resc)
return resc


@torch.jit.script
def _contract_sep_dhconv(ac: torch.Tensor, bc: torch.Tensor) -> torch.Tensor:
# ac = torch.view_as_complex(a)
# bc = torch.view_as_complex(b)
resc = torch.einsum("bixy,ix->boxy", ac, bc)
# res = torch.view_as_real(resc)
return resc


@torch.jit.script
def _contract_diagonal_real(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
res = torch.einsum("bixys,ioxy->boxys", a, b).contiguous()
return res


@torch.jit.script
def _contract_dhconv_real(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
res = torch.einsum("bixys,iox->boxys", a, b).contiguous()
return res


@torch.jit.script
def _contract_sep_diagonal_real(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
res = torch.einsum("bixys,ixy->boxys", a, b).contiguous()
return res


@torch.jit.script
def _contract_sep_dhconv_real(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
res = torch.einsum("bixys,ix->boxys", a, b).contiguous()
return res
Loading

0 comments on commit 901e1f3

Please sign in to comment.