From 901e1f32b56ec2d5851321e518dc4f074ae763ce Mon Sep 17 00:00:00 2001 From: jqs2019011556 <2019011556@secoder.net> Date: Fri, 26 Jul 2024 03:40:01 +0800 Subject: [PATCH] origin --- Readme.md | 57 ++ __init__.py | 19 + __pycache__/__init__.cpython-310.pyc | Bin 0 -> 359 bytes __pycache__/helpers.cpython-310.pyc | Bin 0 -> 1030 bytes __pycache__/model_package.cpython-310.pyc | Bin 0 -> 7158 bytes __pycache__/model_registry.cpython-310.pyc | Bin 0 -> 4400 bytes __pycache__/preprocessor.cpython-310.pyc | Bin 0 -> 9893 bytes __pycache__/stepper.cpython-310.pyc | Bin 0 -> 2375 bytes common/__init__.py | 18 + common/__pycache__/__init__.cpython-310.pyc | Bin 0 -> 523 bytes .../__pycache__/activations.cpython-310.pyc | Bin 0 -> 2648 bytes .../__pycache__/contractions.cpython-310.pyc | Bin 0 -> 3551 bytes .../factorizations.cpython-310.pyc | Bin 0 -> 5942 bytes common/__pycache__/layers.cpython-310.pyc | Bin 0 -> 8093 bytes .../spectral_convolution.cpython-310.pyc | Bin 0 -> 9163 bytes common/activations.py | 100 +++ common/contractions.py | 178 +++++ common/factorizations.py | 247 +++++++ common/layers.py | 288 ++++++++ common/spectral_convolution.py | 405 +++++++++++ helpers.py | 44 ++ model_package.py | 268 +++++++ model_registry.py | 170 +++++ networks/__pycache__/sfnonet.cpython-310.pyc | Bin 0 -> 13751 bytes networks/afnonet.py | 268 +++++++ networks/afnonet_v2.py | 315 ++++++++ networks/debug.py | 28 + networks/sfnonet.py | 673 ++++++++++++++++++ networks/vit.py | 231 ++++++ preprocessor.py | 426 +++++++++++ stepper.py | 128 ++++ 31 files changed, 3863 insertions(+) create mode 100644 Readme.md create mode 100644 __init__.py create mode 100644 __pycache__/__init__.cpython-310.pyc create mode 100644 __pycache__/helpers.cpython-310.pyc create mode 100644 __pycache__/model_package.cpython-310.pyc create mode 100644 __pycache__/model_registry.cpython-310.pyc create mode 100644 __pycache__/preprocessor.cpython-310.pyc create mode 100644 __pycache__/stepper.cpython-310.pyc create mode 100644 common/__init__.py create mode 100644 common/__pycache__/__init__.cpython-310.pyc create mode 100644 common/__pycache__/activations.cpython-310.pyc create mode 100644 common/__pycache__/contractions.cpython-310.pyc create mode 100644 common/__pycache__/factorizations.cpython-310.pyc create mode 100644 common/__pycache__/layers.cpython-310.pyc create mode 100644 common/__pycache__/spectral_convolution.cpython-310.pyc create mode 100644 common/activations.py create mode 100644 common/contractions.py create mode 100644 common/factorizations.py create mode 100644 common/layers.py create mode 100644 common/spectral_convolution.py create mode 100644 helpers.py create mode 100644 model_package.py create mode 100644 model_registry.py create mode 100644 networks/__pycache__/sfnonet.cpython-310.pyc create mode 100644 networks/afnonet.py create mode 100644 networks/afnonet_v2.py create mode 100644 networks/debug.py create mode 100644 networks/sfnonet.py create mode 100644 networks/vit.py create mode 100644 preprocessor.py create mode 100644 stepper.py diff --git a/Readme.md b/Readme.md new file mode 100644 index 0000000..29b4089 --- /dev/null +++ b/Readme.md @@ -0,0 +1,57 @@ +## Networks + +This folder contains models for training and associated code. Models that are currently supported can be queried by calling `networks.models.list_models()`. + +### Directory structure +This folder is organized as follows: + +``` +makani +├── ... 
+├── models                          # code related to ML models
+│   ├── common                      # folder containing common features used in the networks
+│   │   ├── activations.py          # complex activation functions
+│   │   ├── contractions.py         # einsum wrappers for complex contractions
+│   │   ├── factorizations.py       # tensor factorizations
+│   │   ├── layers.py               # common layers such as MLPs and wrappers for FFTs
+│   │   └── spectral_convolution.py # spectral convolution layers for (S)FNO architectures
+│   ├── networks                    # contains the actual architectures
+│   │   ├── afnonet_v2.py           # optimized AFNO
+│   │   ├── afnonet.py              # AFNO implementation
+│   │   ├── debug.py                # dummy network for debugging purposes
+│   │   ├── sfnonet.py              # implementation of (S)FNO
+│   │   └── vit.py                  # implementation of a ViT
+│   ├── helpers.py                  # helper functions
+│   ├── model_package.py            # model package implementation
+│   ├── model_registry.py           # model registry with get_model routine that takes care of wrapping the model
+│   ├── preprocessor.py             # implementation of the preprocessor for dealing with unpredicted channels
+│   ├── stepper.py                  # implements multistep and singlestep wrappers
+│   └── Readme.md                   # this file
+...
+
+```
+
+### Model registry
+
+The model registry is a central place for organizing models in makani. By default, it contains the architectures found in the `networks` directory, to which makani also exposes entrypoints. Models can be instantiated via
+
+```python
+from makani.models import model_registry
+
+model = model_registry.get_model(params)
+```
+
+where `params` is the parameters object used to instantiate the model. Custom models can be registered in the registry using the `register` method. Models are required to take keyword arguments; these are automatically parsed from the `params` data structure and passed to the model.
+
+In addition, models can be automatically registered through the `nettype` field in the configuration YAML file. To do so, the user can specify
+
+```yaml
+nettype: "path/to/model_file.py:ModelName"
+```
+
+using the path to the model file and the class name `ModelName`.
+
+### Model packages
+
+Model packages are used for seamless inference outside of this repository. They define a flexible interface that takes care of normalization, unpredicted channels, etc. Model packages integrate seamlessly with [earth2mip](https://github.com/NVIDIA/earth2mip).
+
diff --git a/__init__.py b/__init__.py
new file mode 100644
index 0000000..4da6442
--- /dev/null
+++ b/__init__.py
@@ -0,0 +1,19 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
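+
+# This package exposes the 2D preprocessor, the single- and multi-step model wrappers
+# and the model registry (imported below). Models are typically constructed through
+# makani.models.model_registry.get_model(params); see the Readme in this folder.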
+ +from .preprocessor import Preprocessor2D +from .stepper import SingleStepWrapper, MultiStepWrapper + +import makani.models.model_registry \ No newline at end of file diff --git a/__pycache__/__init__.cpython-310.pyc b/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..36455456266155845b57b4f8315b2d67a3c6febe GIT binary patch literal 359 zcmYjM%}T^D5Kh|e57O23?gQ9^4SE*w;8hU{3cZBV#+o%v>ZFBMpT!68m2%zFzJezw z7NG7uV`~Uy| literal 0 HcmV?d00001 diff --git a/__pycache__/helpers.cpython-310.pyc b/__pycache__/helpers.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e019867c8c0391e1b030280a43cddc3472382396 GIT binary patch literal 1030 zcmYjQOK%e~5cYVzn|G6zs^W1#;+6x6DiFN@ssvYhDF_h>t(LV(Hr-cjFQjE7A(d;z zFF;C;`~iN9ubg`1z=cD_ctb_J@{Gqbp2s({S*Oz^xITST?+-jeep+F3Q4|hwtA|K9 z;SiJ2nnOZI4%jn|S>lddJUg6yAtR5w+(TaDa7_BOOMFR=`e0a?B}r&KUN4MD>bRfb zRzHwvvV;}+`T~8yf|g`SSL8W)`()vi&eT!cvz8`PRys>}MM~iGf-ac$N{2h=R34&D zO&;p{vccI3xO>9n1KpgqN``ms>dw2MK&%~lZ7keM(&+F?pS6;-tQV$XEL0aTAKM|8F16g0~sVj3#m*;idhj!q1Gx-i&SOOxM`7y*wiM0 z3bdAzVOVJ)HTAL3zKYI-VLVdWG=n(yF@AyBG_0T*kQ*EW_mDJ!K^r`vutR&W1-sBh zz611k^NMwQzgSn^Mzhc&d%Y|gcCf0U8sGuPh*V?LTR;g@kh_{9cS=|?J4SYxY5 literal 0 HcmV?d00001 diff --git a/__pycache__/model_package.cpython-310.pyc b/__pycache__/model_package.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..10a5af4b5887b19c99f4e99752f7035d46fc2f3f GIT binary patch literal 7158 zcmaJ`&669)b)Ojwz6KwQ{UEvIQubJ}9HT;FEjg*A7+Mic(b7iC6_Z`bF(PX+i*5h| z4rajahISVoT)c|qREd>Jm2W9mVRNz$x#f_Xj`;_2njCWptCBw;I#~I=9$&LiOa}4}`_G{Tae8n*SlM1Up8-+LU$bUz|jJCl{W`$36Q3 z({Hvz6g(C=cqh<7e54a@DaN#EL`Gpw{LI-d`#!&=Jzs}a6zs8G$TZ5hRT>Ca_ zu=S6P_DgJoZQ}h6c8gtSH$FDo-()*%o83gtW!9n5ygivYk7LF|cjR}T`aSM;W8rdN zp1DEP<$^~YzMF}HsMm6@=t{Q}N6HT(*AGKi^|>4I4wupwXD;)V@2c2!MqHrD;j%f` zkC+=Mi5dl&-{EdREe%J$3XVhWo&>6oAqdrfZ5WK47g&4aplPPg6JPj4`9oiFlp38_ zdJ`T&+4G}bNOkogO~n(u7f2<}I`rl0&qCo1JkmwN4eX;qdx-tBSPtuCGn?h{&a(p6 zrkGauShyXboyj-@+SUrw*3zw1!viy!6bd5A}n z4JKw{crdX;<1>@yYKu5lO)Iq}54)lWjn52G#_Nl}-i@O$i1@3!XC2w`dV^=Dr>B$C z-2wKe@5hT_EZ%pA{!>2+yjTRiAll7JB;=vYUfz7?TcfkIpFdDa1jl2=r7qsR{OO`zFZUjr6`g)7idyJx9C9two$tGuit77{ zpdcFu%Z|-RpLu+~Jp7kmY8eNJxbg9TchmxounjbLpbiivj$Hif$0z!@W<8zwgg$`^ z{4khspiwlRP&S{(T=ipiWiq-?w0PnR_TAGvt_p^{d2czA4CJKiW332`+>zic=;-dH z;-jpw&SKwa@5t2Xp*r~?5@nnlGxNeQGc>RgoF6ZjSOa@XX@08Riz>{DN$$ejGLjt2 zZ5unrwgJ5ov$e$55_^a0%J|B`+aX+O<_})9(aVc}p#(c`1bw8}KBzQVB%<4C5C&)6K6oB8Zq-$Y;H7LwGC zph;;QeP*U*&+CM~lsJ_76w-m$qj(1yNhso&h56rYYs;(xieNjJ_Pd0l)nz5Dyo92k zqW};x#C^4ZBV%A+SRYu^+|*WiEO9Qe&vVm!lAjfnIVfJ3)4VFp%1Qph#NwHQ3NxRZ z(?U`hR8guv&57?P1!mkdR_kO_l?F9{c%f&}@q2Df0qCM%7FIXt2hwt-uWEI2D&#Z2bbq1FL^yW%Y5?e%Hz~`_RDJ z|Ig~$Q(bdo@+*zNhyF)gl6P`_m-y1%IV}aUYdE{RU^rU)2&@fW%qLd^S10i?DCwE% zliOGlSQs)nk&FVn$NVe?Z6(IH;^4_~#1VHx&_sgz%w>I1LI2&L*BZz;dO;*AHBS!G z8X;pW;vU>=|LlwEWwsS{(kcl#}X>}lvxj$AxH{hbZ#(38shsujbv<$Lc2_i?|(#C=k!Emke7d!6c> z$B!Ny9C|<6+kYhfj4H~9M{j23v=$LBfZx|UCLU<<&%58V&y& z83i`MOq^agqk)Okm;niK;!H26WjKxEHe@zZk^%W3%cG{i@<5BXv{nIHucEHV3}Di$ zq^U|;6Q>(J5V}2tysYJqwA|5h)Ew%Y4ZU@%CA2E^SA=vQpsu37`apkG)n8S#oUko| zyr$(NEw5`iYL>04HC8z{e{MFb2ik7LpJL+LJj4%>OA8a(i3JyA9NHkn8v)!~)m5;@k+dU47M9)^Rj(NTq_McCYm%)qLV4IYd6B`uC_9}S8V|q1JpRJMSgdc>3fmE5|3AL}*Q7Wj7!DdoOGI~@@ zip&OyDxr33aBXmXb|WdVJbB9`KiEbdL6GOsDxOFP0RMHFz4X&n)_L%0DT_L-?SvbmGliG?=OvSw)-@(N+x>v56VbU8_L_(uFW z^+j8iZV=U`o%ojAi=!@NaSq*MUn; zH#kMeKpQF7K<))ys+WNObnViBGnpoDsm(v~!)96hlzMTXtMAU7X#$)1oLdlS{+e!O zcj(JYwFcePSvuj=(c7W>{^|+Z(EVS5+kZmKk~pSgJLZOo8-`slOZaQ#Y|An0xe{vY 
zfZ>vDTjC6gSB}p$RL~J3h?SqhNOaWB7~=ToNM0_qhPYY4h>aPdNpuJlZWeT#_&ej5 z#-NC&ggb<*%Ak&bEPY|8cwdR6Z{oO33&=T)K=N;>6;(kiIDW_h8gXE1xL43UrHtHK zQob8O&f{5C^3gsAzcRq)W(MR zXN6-=bP@Nfjb!7Ov|@-`URYbkbTcZD6q+R6#4T+N&{0XM$vRrQk!-S!=f&w(vNhYL zI~LY^3y`qse}FAH>Vw&J$5D3sCKJHDJX%KT+#~RqeXyQI*nk2L8cKc#GHCgO<);XY z`Xf2U9SYR*ir2W~BL0XA76YHWqmd{EZ4@Ahy`Xc)l`#=dU$^%WN2C;lEdU__mt2HH z!Ae25V#$f0`feAI8h#Rsr*0S^EOQxgoJU^{Id}U?jpV)E-5#Q$&jJ+0G05J?vnC4k{-_4Pm5EsLQspF#y?okDA-6R6yvJRTi4M+33qF?zL%PT z=s=u7g^n>g_d|q>DFFTYe^A5J%XMS&vIdV!irl)d!CAhJ%tAe*5q|>xBLp$W)GS@X zn>q0^<$n)JdozeS;h1shNzA$LjpC3WEP?>}emoX3wU~NA*I3+|b+9BsJ}q{~LMAT4GJ{xmB;X)H3&4CC(wa>~NR0*t@ z)4;Rgg9n2o|AKVLFEy`er=S>Z2Pp^`mkD&2mY4imv&C_K@E_qyP`l)#=&Z)hR zk+dt>8fHx0(F>{bnQzH!-zolqvO|eQ?nH zRu(_ex3b5m6mL`VBP3V^iU4K1pdI(tJUFYJb?YZt;8c)Rr?1RKINE|?C4@R_b;4k% zTC4w)l7*dWb6xE|rREWW2xwX)8Ut|CVR20@_BF4`B0ii3MiC=1q-1F(OPh$PmBM39 zpEM=V*BDKMw8z!cU-9KSE~Gc0bU!0quM-T(*O1t2C7axL_D8`fYy$r3Ij3q{wz>G* O%9Wa>rgOdi%Krg(!h>!A literal 0 HcmV?d00001 diff --git a/__pycache__/model_registry.cpython-310.pyc b/__pycache__/model_registry.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..985b90b787e8a9f6e3fe42a0dfc1019f5cdc4190 GIT binary patch literal 4400 zcma)A-E$jP72jR0R?=FQWjl_OG-Y-&&{UmNro%^z=>ToW1C$6)lENTb)Q$E^US+kb z+`URIvE~8y75)LpV9)T5m(HLC~)`Tl|19M^}u{+v~GMwdh4{T+`ShoIXg(g zSmkv={oaHIDc0;?gsIBQAID)5x6HgU@udp=2wnHUm;N~Y&`$;SZ9WX+VI&?ZG5NLh zClev_=KX1;!qpi^3N@8+pBOKn5;`{CyLi(#(1gaZpILdaRZ#=>dY z76UEmmxgTO$k(6jC2nM$>fp`?kJ7FOlIe zjQ8}27I7W@(_8XxPiFacub;%Y$h5DD`CBA5xwLJ%X4R~lwkfyh_GKm`>*)6_G|Pwj zg}x3jP)24bb7Y-Z>4qu=rm{2R#Qcl#+&VG;2>k`t)R0#*V^k(>D(9IYZ((04!=Au? zUgBMbax&;Pu%OxU{EqA4(N`BT^?Z*77hV{9>QF3JH{bC5ND4og z5#uTRtdhH7y5y8wahyAvLGJ#_kEY_Dl!?r31uA*9*CPk$_417#vnX`_r6sqaIMyxO zT{oMU%Xb#LgzODuGNzNb|65o}8))jzmRT)1c$;RuByZA@x~2}p!b>jLiomF_)gLe* zTb-JLd1ju#E>EmveL)zWFi7Os zSKjT3uMYPVhmeSRgD?^ZiT(SmUb6OW@9|W5^QIRhxLTYjPaGj?+nrV;FGV8G%jqNv zRc<9|ZWFQG;SH(WB_q_r@g$Kd3Loe8RE1H#OPAF)Bxb$kjq|3KN-xo-+=&uD0AQ_& zBOKAoc?C*L!r8ixz(4+Z-^jN~@!m<-due zY_$iBibZ%+25-PHhwHUxzGabFP3`SHwS)N3;lzVJRd{~a%xP<;5& z1?s#3#Fx-|t|zT#7Qd6k;w2hfYDi1^fOWvcwUFFBS4hED>$5P8w#ExeUE01#SA4$H zvI%gP66YmI#=giwAzy>@UB=JGQZ$XZ56C|qH12kT9Qhq|`*<%rg(Fk4Q%Gi~))`V5 z0Lz^Jn=kT!>jpxs2$;YL2DTz?a@O#pNEOeJu%YhC@6s^`P>&*QzHuob_fZ`ePZd2% zRUhF^|Ai()`874=Effvtk5IK9TLdcsiVS&!dq^+x2G#&CBYR-QW`@*-e)$;X$H+lK zGDa0ueP+t{RPEFaXtshxR(@XM6tzZUdMP=n>S>+&NMugtFz??{>QW80F=`BmYbSF~ zn`Z`6+Ggeqzzs>z&dP!FS5Q9PIy3%m{u!s^irNk;>Pk`m7^9t_`oxw`6-uDYK~k(_ z6~N-|>Zn!o{|1GyZSE%VGZBY^A{^DELNy2@B7+sR8&YHe6cbcD$od|unSqCs>66cW zIZU-3KaMHA5Q;^7aAoVlFSU)lz*j!Zjp}eJh3AvE2Vo!H!1p>JL%tH)79{pP-o{0F z;c2`^CkdE9{eQUs|AUg3y?#mZrJ(eYm~mZJTu3Jl%?DAjLhMz!hL3jgHQ1w<9{Ll} zJ6avg=Cux$EL1$s-@Q~pyb%4~BRM0Rt>SQ~+Q7P67*90)yjHN%vwfuaRL^pE{zfM` zuUua$0)xGzKx&s+8ANRjJGs`Q7e##}s_i;wPizZ+1^uoAD9AkV_Qc?!Boraw7wy4p_exbdD>L? 
z$!`(Id$e*cX?rOttAIZb<%hJzf%ZsIRHe%@&C~|dxUZ5VO0_V)zv|>w;`424=+i;xQp$;p)+`yD@})pueIFr%kC7jd jOQj@xb)(8(B2}wNwW7sUW9e4a+j(ny>)O`3`+?cyN**gacswS>^5zL#x`7purr`}M-sK< zkiIioOX@IoD!Wb%EaIYnY5GUCXp2Q2+Lsn>(e^EQY0=jLgQ7*zzU)JRA_!Wb1>An; z%#fle?IK1g%(?$R=iGD7J?C6?i^aTw-_QMau>Eh(8^*s%{?quuPw9V;cr6G(?uAb)_vpv-*)C-)q+Qm+(UgCVFJ>4nS z%bl6}j4)m_l&!Mw8Y=5$_bSI`eO5Ut_pVW&Q{rW#mcN7D7+Wd`864IQS zR%N7lHKQs>r_?1ir{>=k^#XDWY7x1j+E7bs8M%^LQLE@Rt=7~!((?O;Wz=Uti~Hiz z&uFiwyG<_$y88O_O*yvhCFFDXX!k*wF^td%MPwYyfjq&fL}Jg34A(?zA;q_6DRD#U zCs30?O(wJlS=45xjL_b5kTaE368W4mZ&`E3iRjA+0jX) zGnnyZj43EPEb<)LMdKK~c^>B$dY&0Ih5FL4-)SGzOsgE`&a0_gCUFgRSY|lGu)=Uw z6;MA14CYnwhD3e=BNy~{QM0JOid%Yzlvubzk#F<7IufDvedl}eF0 zTo-vEp!CTPh5p-g@BfUwe<~7c`dFg&3g^pd z-X1)H-kW3Cfru8!5^!$(nb4c~RNJ{AFMPV*M3Q-e>R7_Z5GpHZJT@M32uzm{(QQ>@5>3mip6CY4Bre z2D@b#O^=r(_rRzpqEa-yBOzsT;Gc2?K2XMcGJ&A0{hkCBo%iIng`JeK#~)YoH;7(; zp|brMrr841>|5by(fYY#(FYgxE8&yC2v~tb<@3htYyH^>e6??2fd#T*3FpISA`N^`X?C0p%9kPj{Z_&oOrr zpT0u5y+kS$>7=qR$s*U-v?GzVD&YDOQWX>P5!tCm5?LcKMT7=}UV*HTdbMdLq=)i5&49wN6Dw7+3R(JjvznT?if)KY?Y)-!$@L%8m zlc(yA_JWq$Z?xmmFy{u%M%#cYRIUiDhdKb2u$uc?T_j=eCGV)#}9gpylre z-Il@(b33?uU4OrW%hr2y-}9R-FQ7Yl#|z!A?(XPDZ})Im(Qf$4?KFZLHS-Q#TD(kx z@EVG*#M2;9*lN1lUL)Mc=4xA@**=iaS=x}}x-GvIx^A5F-4ywF`C=+`y2`7sT->I- zX5%o`?I?QNIa%V(-?HgtV7!s`XG4ORgW<5Jse zXdl)~a%T0)@UXp6U7Yp1E-0<)5v_hSTb~$==>g;+64Q>ddtj zTBO|t|GC}m?Eq|Z5Sty$J#*SH+}-)Uhv&N;#!f>i(4(c|>~14ygrNpC`2h7SaorpN zGIU{OHNpmWbx986Y2uppSVdNOJJ*S_C1Pk5gl@Cj)hdXccFPZXjiwhn9WV*6mfQCO z$bfeQNoqop9IEAc&6XePWukT_;WHC56c=fqBO$48bQ(7rert0-Y_)^UX4h9Ol4nRB zQG7oP^a}5JG$J|33{yu>)9fbkT7jk|Q6ovyt2CxKEG9$17XA)QX-{|dLl2zdgLC%% zCM}19%7VS!-EJ5!sIE(sO-khww3pZ3j%_loJAp1^4Q3|2PP@pmn!`C}py-gKWk?*A z4^I!Fp4kDKH^GYs_zvs#0h7LNA123V5drJYP`%W%pc+%!B5jk9!vr}X!@VHE)NBVG zD5{+6HrtILaNXO+f7h>eeH@o}?doAOXu3OlZyp>R935QU!)&{a?o}N6etSO{rjI1& zfRkzMwEU|{i5k6ja1~N8GQBr@hk64DU&0ss6+lTi(h+&#hzDdzl#p8oT##j9kNyf~ z0rktGB33_=@7Pw!w9TZ|GHqEwzlw0oyr_yf{H>#R0driIm&F{uir6py=G8m8|#Liq`RNDPb{W_9ZfG^%~!5jFR_;&HV5}1pyyk z2`8uw9Do8SUj?0x)+cbA7|3-cszHT-{S8M7P5ST!%gv*I8&GfFgjV=jhIQL~?iKDd%%+B^dgBT%6Bxl5S zk%7}==tlwJ+U~=q02;{8!KFhgadPDRiA6a7Fh7A(YGlFf-8Nx@Npe@l>z>4n1GqXc zmze<3kPw+IzT|HUeG|*FnX_5T=<8IVJ4nAp;64H&K+WXSPiQ68!6Z>w!3{6OKgNK7 zCb1z5c`;``K%!w;AK}b9rn*$$P9W6V~6N4&!*+ ziPD|Kef9!<{MQ)ETw#lK=ma2s`2xfzl=$88e+JVr{_B9C0AUW67#=nGu3Q(%>d5~} z*nmB+A@3)qs+r(S_W~n2j$=LJlc6_>;)3=%-J9Ozt@FJRma-roCfdXX+F zfmz^Z3*7n;Ml-LX#!3_z)0SK^R$O`?IcP7ws83b({nXA>WU^&Wm3CboG-K-knh&|| zB*msv44bqf1TZP^ zNow4@MFG$Z%J#Y1EIIF>qQP|eCs5ZpA=M`p1vI~i@wwMdFMoj$s9=elQa0{Fl7G&P z83I|K0)9CP=poD}_6qol5!PYpw;5*nHXYU>>Ugj0D=?dV8K^0@NcUy&Kbct`4XazixBz!I3;lvD=XJp>r&iN@s0JCH_$`@e#5q(kED4bSUAJJgo-Wojf%5@(J) z-3{~=s`y0$j}Z7IflUIZZUU*-#9Ab$SAT-QRRXjUyl!~vK?}2ys?-k?_#8FZaLyar z(~nXO83OFL!j*j!F}ZB&xz}>1YE1tEbs{TWzezw7AS)gB0zwK0iQ_$CvCj{~I(VIe z)1H_y0qg)HcnL%EVgY~FllkwFt5^~b2{>=!S+j&oVMA=CX~`@J>ptgS$h}K?RIM`i zc)+Ceil~?s!M)_s0~2P@g&TO`%tb=s!;z1(o|-{Q+%z;UXsAFj5Kw$%@Q^w4FvL$n z@v&0}HD;2L&+0CtOej7U6v{#I*->_zJRrx1a@>>s926g3kp_6_V7EZ(247S|b(dI^ za3{_hLu5%-Nr}PmNem^LOI0&4k$i%%Jp_?i9KQ=cZ`W@hAw#uUfc z>>72y8?WhKrq*8p0N43&fjjulSs|nN;wU=!EO8u37w%FHOAI*)W%E_JYOW%P;J}o| zO}lXy4jN5%hIxo_Mlk@qh9x)VDf6VZbWr+|@$-pI0`D@I=aLwhV~q764i#+0_Q#Mt zUSjr0pLOtn#&`Bb?16{^uQXy0-0c_jTQvF^db;9gO<0x8RAAN-d)R3k*y~&y#39OJ zCEV|@JQH#FI{N3p^z`V$R%y-J(!WNh#QKo7(M(6f5VuB5HCMF;iLrRAn<#Q}B z=`mg~*;bPj>hMQ-(znmy3;q%Sv>T5m%z|1JRhXqy8o+yEUj`}#HyUw8(ppQNRCJAM z(2G4L&|Gdys3C9$JyoYA_MI@}GwICFV2m{}GudGHmjv{BHJ$lYs=07xcHU{7+hPmW z?K{9hlb4{sL4e%{_6*2JK(yj@s-+`{1%m4bM|F94G(AbH@YS!>5eHp*&K6 zuLP}Mf$TXY`-e`U@}Hx!Wq;(zquPW#-hI|HMle08kh9B2LR$ywh9)C1ac-v}Xh~Q{ 
zu!}&Nes~aTqc(Hcrhf}lz+89-NkW+SDMu6IA}4ahh->#IX|R&o=@${$JF$CARTD2D+PhV&(EFX= zz*CeP=Utbh?vyXME)27F(!v4t1QkZ_h%oy)fwu{~M}XPzS1I=$0>4S%cL|&j_&tEQ zn8yE`Uc2p%2aI>A3q_6jMR6S&){l;Xe^7ecF&)RTof2Tyx#CRczmQ+euQ=6mj(ayj zlhKdzBycQjbR~vhzGH|~A~u3Y;HCC!XqqH^nM~B)*h|7~qBhOu2>nM$YV-8pO?br1 z@v}^~+YU4}BXUe{d-`#j^nR*v_z`o{_w`p%tj`XMo2TzNx=tO%3vprmfhjSTE8|t0 z^tKZ4%UHacK8kEAid47WfF-LoJCI9+R)WN`CKEcgzf6xO9%IW}AH$U|D2~uOPbR1R c6j3c$1#qUx`XKwGA_{gfF zo%P8sa_UbgARm40&*8PFoO|n`=nPkil(=cT0B488;S4!5-*C9J)FSXb|J82)e2I|X zakBUcm^^@CzX76%qB)7{KOL}`QRG?92O<`fJR*uKaYB^PtREcnI0Q+cLXd=0Nkh>s z(rdf~Yh<%WUFTb)M`oC6Ylr6EMvC&qk54D|Ltr$CsUk6ZLAI%)ik(C;uRX@%1ra0; zz;@s+ZDqwyuD41(`pzVykv8ef3|d2eHoV!Al(ZtpY{E_nB_*%uG0df?*fD{BpvV?E zqno|Z1$GS94DNQ3lv$CLQW}QS_GsZEE*K1z&d<0LR_EI;93>_h*tqS*ayKbduFt4z zN}1+~wNjpuKcB7-i##jzXX^*4P32C1Z-0OPaDTlIx7kgG>#%p6kL~RL=^%NQ6qy{F zY$q$$>lrTU+^$==JDhZFbYMa_-8C2%nGS8yHa%QD7ooM;ZBo3=>}vpn74S8TKYjpG zkv^Hwik{LV=6P1JQ%b$O&nt#m`Gp;Eza}aUdZHqepc23ZDD8(;069ag&}%g+T7{=L z3n9#6^C83~EEBV%%O_`%!6Uevn8iZLaRE_QSz2l(w{=pEjkfOUZf47&IgsgaFdCOy z+A=AvyEGFQLo-P7>`=SqL@B?TrLVS|2pufXl$D;Vw^1EDWNMt`XAOf$aXc$VP5?<< zfV6O-|5?Wl3P%r22efK(5%n9}L-SK&)J>S!E|4~B^Db@k7Vq*77wquH`D9G@K*d3o z70EzL=~`0y3P8RsV!Sndqy$XykhlsRkH}J9) zWUqZ#EY;AApzt0%Aq`p`y>2Sy8N@ z3aU&(_4Fh7M+nlM*fh&v=gGrf(*(enyGXF`JxF~q`GkA5cxjegjI4H1A-Bg^Q>KYK ze9U|Zec;f_=P)dm0pcj4S0IgtH{QJCKSK$3p$UYN{SC+t0SM6}ULts)ovn~#xR@RRcH(MxQ}QA|U1!_AizPX5qb2KofLH&@x9 z*KvVKiXCk}0cjm#EHHyj$E+Z^js$ULZXj6&;=&rzt}})96o4)OD5txQr%BmsdB~Y} zyguX#ipfI{C^k5A3$5KoatBBpU~?#(z~4gcdl+`)6jQP^NmkqkefusDK?S``7cd!g z0GTZY(i)w2&ZqY^OK?cjrMK0@O0`O2g3WaeXBn7f+yFGTb`7~UWD zuc#T_?Oyjt??q+}mf#*(Lu*ChSW(2Q_0zakKaFdP?{xiihmY}P)xHhZ9b@j;bH0M2 V`*r*M3L#K?zu1P#xy-KJ`Wx+nCR+di literal 0 HcmV?d00001 diff --git a/common/__init__.py b/common/__init__.py new file mode 100644 index 0000000..0e8a082 --- /dev/null +++ b/common/__init__.py @@ -0,0 +1,18 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from .activations import ComplexReLU, ComplexActivation +from .layers import DropPath, PatchEmbed, EncoderDecoder, MLP, RealFFT2, InverseRealFFT2 +from .spectral_convolution import SpectralConv, FactorizedSpectralConv, SpectralAttention \ No newline at end of file diff --git a/common/__pycache__/__init__.cpython-310.pyc b/common/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9a94c2a491994e86cd8106391fdaa806c5e066b7 GIT binary patch literal 523 zcmY+B%}yIJ5Xbi;*=({&C=x;(d)-63YR?FPAW@MhLQpQ2(8~3?iNRy1_9oh-Z^4}h z;FWyk)K}=Ko{&gj$-nV`#+n(A^UPf6f*Z@Xjy!h#n;B=dq4HF#h1(xfR%;_xyn5%~-Kvwi<_cT5%!efNmp2_*gA*}+ zRDnw*UHU4fg`PGdL39u)qU%BS?51$Zd}S@?nAPCaO@`=uJO^foXT-m|LMtg0|57ZO zW##1lchfYhrnnbsGNrlzUCrg(ZtW+a4^)ZLzh)v-@v1P$r7f5S(5fg)p+sGlpJz*# U9#Q7#@WG0q?t(L;wH) literal 0 HcmV?d00001 diff --git a/common/__pycache__/activations.cpython-310.pyc b/common/__pycache__/activations.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9c4d002746d6a99ea6797542753bbe4be79f48cc GIT binary patch literal 2648 zcmbVOTaVjB6rQoij_qWxU2WSfgivlW(2d$6;Gxh~6;wziLR=)27vp*+S!W$P9Xs3I z;C*Q;^{GDrY4g~>G_O3MN<5S&1P_RF#<^@&0TL6<$z0By@tpa-&(2m>JObs%pYq=3 zDj`2(=X7DvxdAo528I((6Vk(lo;M>4N)U$v*}oNSAWbH8|~>yz!XyD7Sg-k=te55!S%wF4`XB9gXFj4d*y;O(TF%US;NETkK~O7wLLbKsSfTTif8 z=LC{qcYH(Dh8Z6x<6OCWVbU8VVJeK0hIQW4G@+2DxzoD#r zJQ7klK@g{L5d?}sEFlf>SMFD#42PnCd8^1|w5M2>id@xtNfs7YuPAmm4s+SizIQ~p ze}w5(Cgj(*6w5`@)01)yJv?*{^1W~*VAm*2@EP48h>|eRgWxIo^M`hpCUGh*w~wMc z3i^Y4hlhug!}cIf`+H&5o}F_3e`y%*hiM#SGVaG|yXxVjNb+`+4To9UUS4&+J35j8 zdU76Wz6r*ow&_uaHfhta;kQXATc-fEyGRM0C|inREYk?{GA}MPY^b-O<{o4iE*?0u zZ^-R}o){&a(qjO_iCG%ZGGUCBCbR$w-&DQ^_+E!`3#whxL2XKBEoVy40zi&&#w{&~ z?41Sf?X!8iq+6u4E&vm!=BbDlL=>>5j=STkbL7~Degj}+@6!_6gB5OHAp2x$9M^R8 z8iH%EiVKsFtUZKtT62-Le-P39|P`!Fy+b^#T9U^C;?EbGo7sX0%HzE zS_uA`l$rEhc^1~oCI&Gk5^AW-Y+NWd3Ge3JRlr)@iv@>>xPcpGDUVT3x(vLed<9o) z^D2g@^)h)sl-!3-{w^4gRzJ4w{o15&Q?}{-wr=duEz0WN?}*J!@L(4EYjWl^GUstv zl~2rIMRn}bBqyMb-AAsvZ?Om<(_%?At_btH&@E$=6N?+p+ zD7~$fz6K)asg2q3Z!V-;wjf$3n!ie*qm>!6@E^(yuzqi*%uUqZ7j=22$`|G{OHH25 ztfK#N6hKc-T%Z@58t?`{W%jf(Tb&B`C=;@d2J_|D}olRX?a=>6;xS|LF__&5G9#XdJw_G+WL&16 zihaf`C19S3=ZslSz`PJI8BusinAHT#p*Uj9S^}mbUNdGr0dp)) z7_%YdGpY3EU+R!kPjL-T`GALxN6qjwsCT+x87X=l{L*-y!{d8AoOLuE>5DAnzSNid zN?*NA-=z8}TvIvyxLPyaUfHyv>_WK*hi_}P+Z!lex?_WJ$xVvFT@lSF3b*Qw(pUTlEMNpM0IVD+Y6Ep75ol1cKHn_jtye(z+RbXK+dLN4 ziw|NKAKsB4`4sxGL=eR=4B{XD=q=qxe4%W7Qf}0~+E+HDGie~7mSp~Vkcv1!j-RJH zt+ruI`cVx1(1q|_DIuC&$toEIDD_iA0(pEpz!Jy>CE(|tLjVPQH%T{=eZsO|6K0vr z7MiJYfzkx4Xiy#@0Z0BdjNX&ZVIi|S{}<2&MF zK@VmUQtvF*%vRZKS;%Kv)8iJQt0CX@knV?&?jg}lE0?kO5i&M@+o2|`h3Xj3kMs@4 z@e%P%YZ0yolI6$I4O*f%UD%3We|5J_z2GP9oX9wakBRJFq6ir)YJH3Vh4~x?;aSiw zUU{c#JJj-f$oJ}A ze8D)JIYm|{=vdqzAz3+z>~4HhzE!^{x6)`!TBh7NxG4=qJGduhXz6)0UY2%2^RjDq z9e7GlCyFOBaE1x3Qayb}EZ1_8YMI9O^voXOTZ;Zt=!zeH2abp_j&^Mn5W&G<(mx4) zOy&g5{Ru^tq;_aK)nN0&_IR8$>ab7 z$pzw%R0Y=u&y~z~{5P!rRz`$FJ42?RsRy#Q_(;|=`?97*&%yOvc1(s<9u Gp8P*9`?UrD literal 0 HcmV?d00001 diff --git a/common/__pycache__/factorizations.cpython-310.pyc b/common/__pycache__/factorizations.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..70236b3d88908f259253004aace381f44e41741e GIT binary patch literal 5942 zcma)A%X8bt8OJU`5(FP+WZ9PD#D*QkHeEfO$0V6lahm#-q=_8Ij@z)!aA+<>Q6d3) zv5X}JoM~%k+8)xGo->(By>xu*rT;;1z4nA_PCn=6L)+iCAP7=UJi!bW`}p?5V)ys^ z9`MU$+rah3H^KTpPaDR!)R??7X#5O!@DCKi5Ugv|>B+mS$LpNMeVMxHntNH1*=M@P z?B;s;dcJ4XtzMyC=-G8!_h->xtQUKwdZ|~gmwT0ZMfc~rhkDg|74>}m8S$*JMezZv z&!AQkWz-I%RuP9#n?{zU=I};+E?LzKmXN-f0PUy}i-d 
z-0Jo`zkgQ-YVi5J?Va6setzllm8;j@eeeBWeDLAg^^b0ReDl^Pw}1JoPwQ7BD_LmE zr2QIQ1dJT+OSl7z8HO_USZM6Cp{cma2qQGVW(5<*sGSSBJcS;Nwi(89 z_t`OHnD=JF+@2|z%7o^;u@CUXW5awb8|DRHHNuR@Y;pO9_abU0Y9H!ZS(+uFC!QJR zL~h=|jO?o6RUw&&7T1xiu z(f@Eg`Y(zoC-j*;KJtmT<$L!WzwgQ>PR&s}ebHGc2zTNB&+$&Pz)Xh!5<3(8w>!{EP8Y)IZLHS8F z7wd{_LVTw^@CQN7j@WjTyXSV=8!9RQdrjHA({-bAvcv{0G0OTD3y(A=L(8j#IRKWeLO$sab9$5yt!CG4i$Y2Zya$A!6=-7EAc-|5 zMhUH$7zJf)T48}`uvutlP}+ZE9~*xLIY0?Jv_XmTnmj~1Nt8$~jtUu+!?=T46m26k zKFbYR$mpGa;dlAX7%mz!H6HQh3lZNA9c{!rHAAN%#Q3;;swyqzj%{WlKxp^ z*GW;(>MyT?h|O+r5tT&v5eZ7-sjO1*3>CB}NsE$4sW?W(5frtOoTK_YiYOB(DUVb0 zSt_2V;sg~ZsdxcJl#lrkMEMw>fm}d$ttel_qh4tB8?k^qOH;GsFIUBJ*1Fx?Lfl?B2^ZXvYpb3*g5;JrLOgN!;k4xoIB zAE70(^S}vuE#eDU#K{1Q+*mh{S%UFVEq({k50q<{(Dn|aMl1;65;XRO{Y{3Pg362H z6c#jpr29E!OhU#E?pBm|ih zA_SrLJ3^Tzg!LOfWJ4aZkUvR^hZ*@Y1m!CzB78)Q1`P>9gM1Bzhxxg)ZKnZjN88sQ(=*#?;!aT%U4pyT={RS5Bwb@47wLQ?78< z;redrm3=S!$|?Kg8>D;8z9_TX=|_jxeDz+cPCLMZlr&qubmf~gTXV4#bFmz6$W!y>e|_3n?3wr>1Tv_;V&VzmN@lSG78 z86qo**UhF$HYw3udS{)0 ztyUIga#5YEtLDtVPziA!FvteNu$SNf7#zS{Xv!Id3?^jv`Oxs*2M3b+Wyo;IoQKQ; z(3|wXfp9LVzpRkpPB@HM>m|&1OXbE=CxP;0#{5X@Cy=?WbRMDKTvj@m()IiSt;xFK zBx>ebjMOnSYqqv%lF~9s`I#iYEh%J;jUa(XlJStVBeH_>@q+WX* zM@H!(HZbI&PWZmcVk7wOg-fGU9kIXmAv)l9;!G!4bdux%9tI8uPN$YhKUS{cMKx18 zc*s*!oTh>nCqJY@oAcDBPo(7fHOR%0Hz5`rwA8;LAJG2m3?d^xLNzK+2)JVEgGrQb zjm-a`5;Bd(iZXyjgiM5yugoDrrJFQHx$hmt_pk~3OrC<2c~A;rm7#L?x%@OVHuHPd zzCn3%@anilslsatnd-QQQiit`0`+kZ<;(9Xge&773X*Q9(oI8^5hy`(l0Bo1fM>I^ zd1#`i7_zUqD8->vt=#+;hV>fknC!Hiwacl?*V&rmuRB-$f$X@_x!HGH`03JhaH5|( zl#kLM7#yo@JGeeV7TR-_D+AqJ8eh5@$I{N5ql546Ea|NGY~Wly@GnG*j{}c*6>mCU zoSizUvy0Bz6gg*yiw^#tGzVRTEP-=CR4b$v-rI2120ldj&V(`tea>ikO;0$&cbWv$ zdrc{velrN1a}(08sEzdjugNYrq5~n--O<|Ol?AkUq&!_{@}7GgwV{VPUVosRbE!lv z-pr3(HIUxaPU_dL4?Ml)*vNC+iz9V4XX&D|?VL}vcf9eL1^s5r)idP{KsezR(MqQ@ zk)6J}UK&{zO`RyKANBI6GmeocywcW3sd2T{-#v2;!Qy~?AvS?x#qpUYYtlbWb{7uF zjkaA6VzXLZ+V0)MBsGnIw|37jI%D;xrw-&a8IIHGD^J&|QYVhY@CnmFRwJ@I0sh9Q zrq^<1i8@RVsEx8+|DG%LKLvS(rje)8370m{QK8Y$i)%C@tI?n-jYch}O>p%H^*iU*lJ`qp4-$ z=PlSuI_IP{uRlL3XxRAp7o|5QvSRDo-O(|7R1C0`J53}Vo`_Dxzi^i0UpUJ+Zv4W* zW6btIYG0+@6xqkvlWEsQUMu!)@RMkc&#_X*=*2XO5y+7FP*Re literal 0 HcmV?d00001 diff --git a/common/__pycache__/layers.cpython-310.pyc b/common/__pycache__/layers.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a54e6180329c4f6f46c97e943a27cffd82521dea GIT binary patch literal 8093 zcmb7JOLH98b?(>n^gJ*#7?J=8QBsq7fFqNDO-r(+a%{zXNL0y?Vo-51vD0bx?ZNb* zr+aw2ha`|j#WJYKc~E&ek84%24g9Dy6kh z8|qzMruEPmmb#^(*)?U}2+KpOYsqvetPHE&s!W?&>I9%#3$$TX|JzVZC%XBq-W_YH1rs(@@_gTzP z3!fW4-+g{)cWqf-NB(U0tg85PgL8uy*jd&@>P6P_Ut;tA%dAbgp2beEllMw{YIo&r zg`Hvx_Y}6^HwWkU^lk?wi);xcOGU{mC^^lRQLO zv`22*Pwev>Y3%jgBn>?Kia$#G9ed<+I{}A6-#&k*??<*f8ij$+>>#p-*pcjM$32_- zNx;T#Xx|KDZ!77nR&6{t`axpjKkfT8jBmSnnkKy*N0INPc7W0C>$b;zH^socnA=x- zy}%0~$8|rwVEd_OM}B%I=3B`HyC2{2Z~Od$9G}LGd9WFP_jW`RV6tJ5guctbY?^LG zZm^hah3Z|K1-%}$DoSa#AuXRI6_2qP?9fhpC?FPt&0+tmtNC6wd>hw*?_Y*wd!uTj*{>2yVn|mv4?~0J&2zGPYM}-BJcCHdfq6 z!xVkNg1h!CkHH7m&IgiQqS3~QFDF|`#{7108w{j9uS`TVnJnKtZ$T>y?f%biaDXBX z>I3@W3zWvuWa;+E7bf#>2cF-tL@DWmeZq)XFchUU=3ZYI95eGp>b3+TOzx+eLR2;2IjwP$3?a5IF zi(kZC32~!dF;q)kR+m5Wft1-2!v+^f&jCE$*t8@*9 z1A9;DmOfDorAvp?5*7*Y1I)PdtCaeGozb87kgj_`faybNy?7InPpYJ9rtE3^nuHLf z9i0J=4%PLJ0RfM3JUB5#SdIgDOC3jw59(kfe%Oxr z#G-~;Rd?nmJ6@fP_So{M?Pcmj5=eoQgo1drJ$*yrFQKrbOF{+HC;tfDTh5Q(;n&YIF6_~j?^Yf*B$2;aIdH-I}VFI?2~rK zFVGxw#pXRf3?~e*sTO`m!3tvGSMVklh@}{6Lu(n5u}{^fQhDP66sKdG>ci# zB}3{`3A$7jRp^n|zdF3>vqSB{zo9FC4FqZ8Bd56^kmMhb{Jj!VlZVU_nJ21EehBiY zPB8Vfl0lX30RO_t0wi$L-yE1P)5)3lRRq`^uj6} 
zon%Wiu0|eg6u;ujAASlWz{Yz0J05|@p;u5V=CICuAJl^O@uQi2=(1k!V`k1`)fss4 zbFBIwW=43*s!8sKvX%H9jX#OL=Be)7Ub8G&riy$&jom>bePC* zbyNHErX9tp4MRMNIXuW!{tFo8!T;j<84~ooC0(~eSmdGaMnX?gcCB+F$HDheFLhd! z;96}0HsRj;@X&sdDoBfk`5s&H4Ex&c}v?_JbpT#49$dXbS??f62fReDB$@3z?pk6 zDkreuUm#*niL?uZXdmO1&I(+ng-0fw;2A)jI!c`)1e_V?2&zc@-=V>#P<4@08O?s| zAta$Ec^72rbl{Ha>w^L#2&!1QmpdXj^Ke8;zBzzdXH{0ary!zGVRbeK$D}G9lN!uK zL+Ian?}POl`8rz|Tk z0kg6aVb*R_uCGZe%}V_HMQ(PrpUTyukL;oEHlMuOteR=ZR*Nx@R+}}kPD0(kE7nc< zeBEVM-7@$;W5o@u`0rB#C``t_6)hCI_$QRj7iD$qsF_*dRIA{WvEM#Wc3Wt3BCBVO ztd%vh)&aSACnfsL&!QjyCCFFe82>hOE_t*uzH9Xj1hR)`LMvo8N57> zlXTxLT2HU~o*NN%-h@x$hX_O<_K*aK`2eN-c$|dWJi;tH^TpTqHzPzd^($auY<9={5o+6qO@)Oc@Lz~0C>`$-o}*=5UaK&yq<(t zYw9h1Qe&W2`~FdF74?>8$sX{4HA`F4cj`cor(-N#V_f+^rWN;c$Bzr?w$Dg7@es$D#7HB<>H+k$9t4~(;91>U; z;>%klUWFaN3ds&koNBC7Tj%ef6|S{X1iMH-b$#-}`7dZ85xoi)R<)fPEox#H>W_M~ zWa)+R_SN@3%$?h3D6bD-z|CM2Cfw(8MzT9k_jyV4e+O@J2E!hn0c9a!u$o)Qsm>?z8hYDS{w>IHcQ}O^gL%L`RNhR>@@8tK6&RC6h^I{N9^F=} zC^JScA#X8t8DZH)65li}+thG-Qk*)p3dzzs3ntT1SL(B zw2qa`6GR@Cay7^u#A}^)4mlrTE38$_lW#Hkq!}? z$XOx;hX0QDA;kz0;?YgGPAAnDRa3KZn{8_g2>MKf{p5p85+kF38E+mM$8jMk=Xedqp^&+TQ8XO7@;X8z z{{t_Ye~5WS3C?vtKGcPpP^fYUR>HYmD7f#BXKKe3klT(KHO(s%2gm}FPw-10ji0G8Gv)BJB}^h(|e%H3U_ zcF5zz2|f=+DK4u%&+oqrc>@@<5Oq_hB6mNkwCHNy!W%P}(~U=%@WaTPtRo|#x_dVN zXtA1qw4g6g_-G+#qK}f2A9K{WL8M9Kw?GgZ^>{FXfxjU?S4n^;^KcV9m9cYSev7`9 z_%Z^W2TR|l-a{fE5!okVA19;^!undi;jcp_l!St0^(F-@>(X%}_pJ?=@2I4u)#3J? RQ`_x!t$nI}slBma{~x>diDUo( literal 0 HcmV?d00001 diff --git a/common/__pycache__/spectral_convolution.cpython-310.pyc b/common/__pycache__/spectral_convolution.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fb130504148ed5e27b57bbf9cf0e214b7162fa8d GIT binary patch literal 9163 zcmc&)OOPAab?w*P02<%Hd^s~5lHgAwD9%u#C4XW@KhiWM+apCJEg7-W0`YogfN3!>jm){{PBOmjNr&opN%v(356T=P)n5bGR&zPV6Y;B*@Ki0M<^H`-M zjISCpBeNeHGV7U*<9k-+1T;CBhbG_GoP?$zC!m?=YfedV%_tWiVC=?fS!k=;Z0l6} zPSjGh4LldGRBK4rcRhLC3&NIy&aSnZP5fprwVG|;+j+yCbj6k_>3nQd5}z2hQAuJFGCFsy?bXmZ|57WsUCj%_ z$dvFIXiKDS`CC!F6_n~U^rja?o-Ez2S4-EfUN5y&>B^PsE2Ya`SYHp2u?7vsJ17*E zP%zr?O3`*}Ic#ru$}3&Fz-?@W^IkW3Mq$EX{0hpS&b||XE&|(IeyOZb8eH$=A#1Xq}Vgnt4QVB zFSE@A+f2kZodhh49kk$1vL3B8oJ}))>rLY$^akyswMGV-_gG`Ci!Ny5 z#_h%>F?Kh$$Ee+-uqjG=(^P+lei;pFadF+G-UTz9mq6Y>#(8nC11CBrX>N;uVf@V8 z6Ca7F)H~jBI@V!hUqmMwC#fF%o|371=8VzJbTZOO7z+m3$@ISX5*G5J(aFH(bSEw8 z_f%bcPrx=CJ=1%(lijBkJu_(k2SQEAthCm#!i_U-Id(ps^krj@=K73#R&q|`tSnqK zf<>8KkM(6aL2l}RCDyocCbk@wDqb|A=SVLnJ4rd!NyzDYw8G~)DLJ!gqW!DXujoW? 
zSPDQD&EzRkT{k;XLxc5`f3KWGPAn(L!O+qh>dZh761hyIFsxq{xK2KtugSGB#a z>QajhEw)tw+g&H+2A%X@AHEUv&mzu$3x}tQ^ID4Tw7t^WeW4SQk_xWH=$|3uRW-*I zG^-JOXN|H39~>b*OxN)H%Ob=&w%hx57cT9G<}LG@abK+7H}0F-TzOs#R~?24i(l8z zw*tQ&crTp4QwwYEdgIp4&d%=6`3Bs~MzwXm*^*oSRyg=wZdNy|LEVL?U#|z}<0s1W z{O~-SF8t8>uCYudC#~ zbvpFgRaL#|dpgBI#oN*O!E(8@M%uwR!>AfW4`@lW({6=zDtq5vT~(9R+7v|jyiQ=j zJ*Dz&H5jO#gu~vhD(NDi#NJy^)?@!Vrpnd>x3*CY0!*~dwYCrn2T!)&tnTPEEkfw} z)kxdTaRmlp2doZ!qzfUY)pHd>kKC&II<-*^t5KvBZ5vf4@ho*PQ%9tsVbwfsMH^Bt zT*f-6I)aEN$NWFynRa|P+^Dv_P$hV56saj{>xp5fTwYG?H2aprtX9y^%Ak*f%nWsS zsE0{hy@q~d|IMnu<$Xn|mcodu(T29eTPo5Cj(gf}1zxDr{Wbgii#mZG$v~R{7WHi` zY14~5rIXuUeSIU+xq3*AC{MZoHWQSTNXvRNghFL#TI*xzuBFA?7^(ju^oIUs8b-ngm3sC5FODlx@Jdot&Zr4eWNR& zF~4nh?QWu*?4}~CVM`&+d*b^`2=qoT(GmNimt;$$ky3EmJ%@Ct!RncxWtXj$o$VtfiCLf=8-D%0yEJV2&`|iY=-@*Y+ z@OD--liI0nMeuYuTC{z12F@g^uWz-sLiJlzoc5D8v8S)o)9;X=^BFhrWlB*i<+OSi zKU(a>4LQT}Ene5+k``CBcvqV{?UCzOw6S!m?|nLW2+u$yg)JOxi-L)5MjJ<%A1?j( zd$35Wt*eil`Cojb(K@vaoqq%&2@Fgbvc+&I_J8#;C%Wy zDc6;)8umNwa@B;;`TilOdA>h52<*ncO(uLK$c;XSZ%7-^F>Fyt7T;v&`Uuib!Gb;^ z&FBH7Z%RwrGU3^K2I2j+6eE8^{OL+U0?PpTCmH#tu+&*SH@fDG4}bRq2EL=r65(Tl z|E;y*wSePW=|<>F}SQ>>Tu9~NH{m=i5KL7N}YCOL+ z0@8yIP52O02iRNx@1Pp@-Yr92k!CjqKxIhlyXJz?EAHc5ZJOO2gKhvc0JEtM!TM={ zPQxXb$}pb+xc3ME4~76d56hYGPAA_7;Q5HaoE^^(FfayS+~Tx zWM6kApz8#+Ot}QCkKlJwrT~=x7Ihm9(jQ?EoexT!hpU15`#6t5be_TWq)bEqbJkDw z^-gqD(j{eb2(FJcN}c=~`74YXFgW*9tCRj~z-6?29FX*6CmWsWeNVgFE>>Fbx2-)Y5-O&!Eh+K{g_2eOwkXf0$J{$zbNGWbmIV0;b&kY5iDiiL1+_xykOVuP$A#|~DfcB3pC|DJ z5?>_oB@*nQ*#{GN(h0w{4SdHTpT?qYlVEQRLdj6SP2yz|uaF?Dq+TWQ8j06QP#{%T zNf0}st-#xkef}Gydy_bRSJEq<;E^deTy2ZkYJ}hLg6&qKgBov zJ%~c#Q843JHoyx(8u~j3Yx+ZXbLK-J=Ge$5k2~z)J>vkJp%oz) z%5cWUMtc;_L?&?wV*tix0F&qgn4~%dfRP#k80QH9W@Q{|;E<)Ig#&#VWgmhv`+!aX zftrz?+(BB}fH-!??s3awy^0}D+=E9UPGh{?2pot^z*XL1h*OLqPH&R?d;s7mS}8uC z>H|tK(wV2tR*pcO0= zB~S%Rx%yVguio+02!lLe)M_iD17p9sjh->wI{-(B>>!t9;Tt{!`4Ai-Zq8y}&t@Qy zAWjM#pXf7uf*3dOdlVM4%mvaZ$Tlb-=Nfr%eP85fIt{r-VfakEUak@Idc}@OwK0%* zoneDH;{glG0z$yG!;_=++2tdT`^LNpF$BKjTr{b=N78=C8S$7Sj z4&$FKkC^3K>yp$3|%ABbC-@ww_a=+kRsW6wNc>`$4~xFyKJymDI@--##A z*yPh^487b4?5HfW#Hj*|45QBgn(=9+D(_n{JXl3qeS^djX-%Eo zs7vXAZozw;ICPIZYY8S-5fhoWJ7w&yJmN`)LG^Lu9pg5<{jmwf>C8-<`b?YzD1u+{4BhRz)tWyc`3- z52-u@Kq6R1tPMz@0g=frFjS8NgU?eLx}%T5!XiCV^wU-y@32yoDuA+e@xIFiYbSkA z5zNpTQuY-G0p3IF!{pVvr<3%WC6EmBbbiQUeVv%ODQfj?s@C2HKd8k`o%sf_w;pBd zENIU9Z3fk$N|J4s1l=GRaO%`$)oN2g=M8Lus2<%xc!R23csfkfQ)8RPjnRAm9iO-3 zj#`b8;vZ25(^M}oJqS2Lj@)%E{1_+ldJai1y#F-928{YR5qOX&Ey9LHt02tJ3a0&= zVzP*M?hq*^rmRKKS5u&|KyR5PVVftwR}}GVg1d+!XFjv)&(O)d>P5kg?)4AJqZ`;MGq|mOuPkLG`^<@(GNbnA$a_V~!dSXzZ zkNee6DC)p*7J0f~bEpI|i2QDip&1sLQnGGho>!gEIq zB?0^qLn;1m03U0?HH_8#bS`pr9H(M& zINqaj_SotWC{+xpN7$85Aff5srzI#ML1>yM9j9!-3uaOUi&;S%7<%D}3sFwRLG2Hy zp^qRQhK|@QkHOcmTkcUI>dBMg^@qsKQ9tLA7>Cy}$S(ZKd9OI9$f|rg-t7)EZ!1B- zzO*tucwE6h7Qj)`%WJ$E8l}Y*od~v?c#UNSZAH#nXW}xoEm`Hct~di)4}7h!vmxGs z+^p6%;jk;mB0shsup8hsV(W0F)>2-4-Ot7E(^l}G1AhE2rBdpvC~Ep-ORv~gMqV}c z<@)$z&4^2;f1$u|LJmV5c9=7$oP2DFLHGJ`f;ojO2Fr0mPw*HJvEu(X2rG=09{Yj!8_^#^ehv%a{)shH~v(#HbO!UdGDe4gujz&^rqUf1sdo`tjZZ T+iwQjvS>N torch.Tensor: + if self.mode == "cartesian": + zr = torch.view_as_real(z) + za = self.act(zr) + out = torch.view_as_complex(za) + + elif self.mode == "modulus": + zabs = torch.sqrt(torch.square(z.real) + torch.square(z.imag)) + out = torch.where(zabs + self.bias > 0, (zabs + self.bias) * z / zabs, 0.0) + # out = self.act(zabs - self.bias) * torch.exp(1.j * z.angle()) + + elif self.mode == "halfplane": + # bias is an angle parameter in this case + modified_angle = torch.angle(z) - self.bias + condition = torch.logical_and((0.0 <= 
modified_angle), (modified_angle < torch.pi / 2.0)) + out = torch.where(condition, z, self.negative_slope * z) + + elif self.mode == "real": + zr = torch.view_as_real(z) + outr = zr.clone() + outr[..., 0] = self.act(zr[..., 0]) + out = torch.view_as_complex(outr) + + else: + raise NotImplementedError + + return out + + +class ComplexActivation(nn.Module): + def __init__(self, activation, mode="cartesian", bias_shape=None): + super().__init__() + + # store parameters + self.mode = mode + if self.mode == "modulus": + if bias_shape is not None: + self.bias = nn.Parameter(torch.zeros(bias_shape, dtype=torch.float32)) + else: + self.bias = nn.Parameter(torch.zeros((1), dtype=torch.float32)) + else: + bias = torch.zeros((1), dtype=torch.float32) + self.register_buffer("bias", bias) + + # real valued activation + self.act = activation + + def forward(self, z: torch.Tensor) -> torch.Tensor: + if self.mode == "cartesian": + zr = torch.view_as_real(z) + za = self.act(zr) + out = torch.view_as_complex(za) + elif self.mode == "modulus": + zabs = torch.sqrt(torch.square(z.real) + torch.square(z.imag)) + out = self.act(zabs + self.bias) * torch.exp(1.0j * z.angle()) + else: + # identity + out = z + + return out diff --git a/common/contractions.py b/common/contractions.py new file mode 100644 index 0000000..edbf6cf --- /dev/null +++ b/common/contractions.py @@ -0,0 +1,178 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
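+
+# Index conventions shared by the jit-scripted einsum wrappers below:
+#   b - batch, i - input channels, o - output channels,
+#   x, y - the two (spatial or spectral) grid dimensions,
+#   r - low-rank dimension, m, n - block dimensions,
+#   s - extra trailing dimension carried through unchanged by the *_real variants.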
+ +import torch + + +@torch.jit.script +def _contract_rank(xc: torch.Tensor, wc: torch.Tensor, ac: torch.Tensor, bc: torch.Tensor) -> torch.Tensor: + # return torch.einsum("bixy,ior,xr,yr->boxy", x, w, a, b) + # xc = torch.view_as_complex(x) + # wc = w #torch.view_as_complex(w) + # ac = torch.view_as_complex(a) + # bc = torch.view_as_complex(b) + resc = torch.einsum("bixy,ior,xr,yr->boxy", xc, wc, ac, bc) + # res = torch.view_as_real(resc) + return resc + + +# # Helper routines for FNOs +@torch.jit.script +def compl_mul1d_fwd(ac: torch.Tensor, bc: torch.Tensor) -> torch.Tensor: + # ac = torch.view_as_complex(a) + # bc = torch.view_as_complex(b) + resc = torch.einsum("bix,io->box", ac, bc) + # res = torch.view_as_real(resc) + return resc + + +@torch.jit.script +def compl_muladd1d_fwd(ac: torch.Tensor, bc: torch.Tensor, cc: torch.Tensor) -> torch.Tensor: + tmpcc = compl_mul1d_fwd(ac, bc) + # cc = torch.view_as_complex(c) + return tmpcc + cc + + +@torch.jit.script +def compl_mul2d_fwd(ac: torch.Tensor, bc: torch.Tensor) -> torch.Tensor: + # ac = torch.view_as_complex(a) + # bc = torch.view_as_complex(b) + resc = torch.einsum("bixy,io->boxy", ac, bc) + # res = torch.view_as_real(resc) + return resc + + +@torch.jit.script +def compl_muladd2d_fwd(ac: torch.Tensor, bc: torch.Tensor, cc: torch.Tensor) -> torch.Tensor: + tmpcc = compl_mul2d_fwd(ac, bc) + # cc = torch.view_as_complex(c) + return tmpcc + cc + + +@torch.jit.script +def _contract_localconv_fwd(ac: torch.Tensor, bc: torch.Tensor) -> torch.Tensor: + # ac = torch.view_as_complex(a) + # bc = torch.view_as_complex(b) + resc = torch.einsum("bixy,iox->boxy", ac, bc) + # res = torch.view_as_real(resc) + return resc + + +@torch.jit.script +def _contract_blockconv_fwd(ac: torch.Tensor, bc: torch.Tensor) -> torch.Tensor: + # ac = torch.view_as_complex(a) + # bc = torch.view_as_complex(b) + resc = torch.einsum("bim,imn->bin", ac, bc) + # res = torch.view_as_real(resc) + return resc + + +@torch.jit.script +def _contractadd_blockconv_fwd(ac: torch.Tensor, bc: torch.Tensor, cc: torch.Tensor) -> torch.Tensor: + tmpcc = _contract_blockconv_fwd(ac, bc) + # cc = torch.view_as_complex(c) + return tmpcc + cc + + +# for the experimental layer +@torch.jit.script +def compl_exp_mul2d_fwd(ac: torch.Tensor, bc: torch.Tensor) -> torch.Tensor: + # ac = torch.view_as_complex(a) + # bc = torch.view_as_complex(b) + resc = torch.einsum("bixy,xio->boxy", ac, bc) + # res = torch.view_as_real(resc) + return resc + + +@torch.jit.script +def compl_exp_muladd2d_fwd(ac: torch.Tensor, bc: torch.Tensor, cc: torch.Tensor) -> torch.Tensor: + tmpcc = compl_exp_mul2d_fwd(ac, bc) + # cc = torch.view_as_complex(c) + return tmpcc + cc + + +@torch.jit.script +def real_mul2d_fwd(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor: + res = torch.einsum("bixy,io->boxy", a, b) + return res + + +@torch.jit.script +def real_muladd2d_fwd(a: torch.Tensor, b: torch.Tensor, c: torch.Tensor) -> torch.Tensor: + res = real_mul2d_fwd(a, b) + c + return res + + +# new contractions set to replace older ones. 
We use complex + + +@torch.jit.script +def _contract_diagonal(ac: torch.Tensor, bc: torch.Tensor) -> torch.Tensor: + # ac = torch.view_as_complex(a) + # bc = torch.view_as_complex(b) + resc = torch.einsum("bixy,ioxy->boxy", ac, bc) + # res = torch.view_as_real(resc) + return resc + + +@torch.jit.script +def _contract_dhconv(ac: torch.Tensor, bc: torch.Tensor) -> torch.Tensor: + # ac = torch.view_as_complex(a) + # bc = torch.view_as_complex(b) + resc = torch.einsum("bixy,iox->boxy", ac, bc) + # res = torch.view_as_real(resc) + return resc + + +@torch.jit.script +def _contract_sep_diagonal(ac: torch.Tensor, bc: torch.Tensor) -> torch.Tensor: + # ac = torch.view_as_complex(a) + # bc = torch.view_as_complex(b) + resc = torch.einsum("bixy,ixy->boxy", ac, bc) + # res = torch.view_as_real(resc) + return resc + + +@torch.jit.script +def _contract_sep_dhconv(ac: torch.Tensor, bc: torch.Tensor) -> torch.Tensor: + # ac = torch.view_as_complex(a) + # bc = torch.view_as_complex(b) + resc = torch.einsum("bixy,ix->boxy", ac, bc) + # res = torch.view_as_real(resc) + return resc + + +@torch.jit.script +def _contract_diagonal_real(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor: + res = torch.einsum("bixys,ioxy->boxys", a, b).contiguous() + return res + + +@torch.jit.script +def _contract_dhconv_real(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor: + res = torch.einsum("bixys,iox->boxys", a, b).contiguous() + return res + + +@torch.jit.script +def _contract_sep_diagonal_real(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor: + res = torch.einsum("bixys,ixy->boxys", a, b).contiguous() + return res + + +@torch.jit.script +def _contract_sep_dhconv_real(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor: + res = torch.einsum("bixys,ix->boxys", a, b).contiguous() + return res diff --git a/common/factorizations.py b/common/factorizations.py new file mode 100644 index 0000000..0fd2759 --- /dev/null +++ b/common/factorizations.py @@ -0,0 +1,247 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch + +from functools import partial + +import tensorly as tl + +tl.set_backend("pytorch") + +from makani.models.common.contractions import _contract_diagonal, _contract_dhconv, _contract_sep_diagonal, _contract_sep_dhconv +from makani.models.common.contractions import _contract_diagonal_real, _contract_dhconv_real, _contract_sep_diagonal_real, _contract_sep_dhconv_real + + +from tltorch.factorized_tensors.core import FactorizedTensor + +einsum_symbols = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ" + + +def _contract_dense(x, weight, separable=False, operator_type="diagonal"): + order = tl.ndim(x) + # batch-size, in_channels, x, y... + x_syms = list(einsum_symbols[:order]) + + # in_channels, out_channels, x, y... + weight_syms = list(x_syms[1:]) # no batch-size + + # batch-size, out_channels, x, y... 
+ if separable: + out_syms = [x_syms[0]] + list(weight_syms) + else: + weight_syms.insert(1, einsum_symbols[order]) # outputs + out_syms = list(weight_syms) + out_syms[0] = x_syms[0] + + if operator_type == "diagonal": + pass + elif operator_type == "block-diagonal": + weight_syms.insert(-1, einsum_symbols[order + 1]) + out_syms[-1] = weight_syms[-2] + elif operator_type == "dhconv": + weight_syms.pop() + else: + raise ValueError(f"Unkonw operator type {operator_type}") + + eq = "".join(x_syms) + "," + "".join(weight_syms) + "->" + "".join(out_syms) + + if not torch.is_tensor(weight): + weight = weight.to_tensor() + + res = tl.einsum(eq, x, weight).contiguous() + + return res + + +def _contract_cp(x, cp_weight, separable=False, operator_type="diagonal"): + order = tl.ndim(x) + + x_syms = str(einsum_symbols[:order]) + rank_sym = einsum_symbols[order] + out_sym = einsum_symbols[order + 1] + out_syms = list(x_syms) + + if separable: + factor_syms = [einsum_symbols[1] + rank_sym] # in only + else: + out_syms[1] = out_sym + factor_syms = [einsum_symbols[1] + rank_sym, out_sym + rank_sym] # in, out + + factor_syms += [xs + rank_sym for xs in x_syms[2:]] # x, y, ... + + if operator_type == "diagonal": + pass + elif operator_type == "block-diagonal": + out_syms[-1] = einsum_symbols[order + 2] + factor_syms += [out_syms[-1] + rank_sym] + elif operator_type == "dhconv": + factor_syms.pop() + else: + raise ValueError(f"Unkonw operator type {operator_type}") + + eq = x_syms + "," + rank_sym + "," + ",".join(factor_syms) + "->" + "".join(out_syms) + + res = tl.einsum(eq, x, cp_weight.weights, *cp_weight.factors).contiguous() + + return res + + +def _contract_tucker(x, tucker_weight, separable=False, operator_type="diagonal"): + order = tl.ndim(x) + + x_syms = str(einsum_symbols[:order]) + out_sym = einsum_symbols[order] + out_syms = list(x_syms) + if separable: + core_syms = einsum_symbols[order + 1 : 2 * order] + factor_syms = [xs + rs for (xs, rs) in zip(x_syms[1:], core_syms)] # x, y, ... + + else: + core_syms = einsum_symbols[order + 1 : 2 * order + 1] + out_syms[1] = out_sym + factor_syms = [einsum_symbols[1] + core_syms[0], out_sym + core_syms[1]] # out, in + factor_syms += [xs + rs for (xs, rs) in zip(x_syms[2:], core_syms[2:])] # x, y, ... 
+ + if operator_type == "diagonal": + pass + elif operator_type == "block-diagonal": + raise NotImplementedError(f"Operator type {operator_type} not implemented for Tucker") + else: + raise ValueError(f"Unkonw operator type {operator_type}") + + eq = x_syms + "," + core_syms + "," + ",".join(factor_syms) + "->" + "".join(out_syms) + + res = tl.einsum(eq, x, tucker_weight.core, *tucker_weight.factors).contiguous() + + return res + + +def _contract_tt(x, tt_weight, separable=False, operator_type="diagonal"): + order = tl.ndim(x) + + x_syms = list(einsum_symbols[:order]) + weight_syms = list(x_syms[1:]) # no batch-size + + if not separable: + weight_syms.insert(1, einsum_symbols[order]) # outputs + out_syms = list(weight_syms) + out_syms[0] = x_syms[0] + else: + out_syms = list(x_syms) + + if operator_type == "diagonal": + pass + elif operator_type == "block-diagonal": + weight_syms.insert(-1, einsum_symbols[order + 1]) + out_syms[-1] = weight_syms[-2] + elif operator_type == "dhconv": + weight_syms.pop() + else: + raise ValueError(f"Unkonw operator type {operator_type}") + + rank_syms = list(einsum_symbols[order + 2 :]) + tt_syms = [] + for i, s in enumerate(weight_syms): + tt_syms.append([rank_syms[i], s, rank_syms[i + 1]]) + eq = "".join(x_syms) + "," + ",".join("".join(f) for f in tt_syms) + "->" + "".join(out_syms) + + res = tl.einsum(eq, x, *tt_weight.factors).contiguous() + + return res + + +# jitted PyTorch contractions: +def _contract_dense_pytorch(x, weight, separable=False, operator_type="diagonal", complex=True): + # make sure input is contig + x = x.contiguous() + + if separable: + if operator_type == "diagonal": + if complex: + x = _contract_sep_diagonal(x, weight) + else: + x = _contract_sep_diagonal_real(x, weight) + elif operator_type == "dhconv": + if complex: + x = _contract_sep_dhconv(x, weight) + else: + x = _contract_sep_dhconv_real(x, weight) + else: + raise ValueError(f"Unkonw operator type {operator_type}") + else: + if operator_type == "diagonal": + if complex: + x = _contract_diagonal(x, weight) + else: + x = _contract_diagonal_real(x, weight) + elif operator_type == "dhconv": + if complex: + x = _contract_dhconv(x, weight) + else: + x = _contract_dhconv_real(x, weight) + else: + raise ValueError(f"Unkonw operator type {operator_type}") + + # make contiguous + x = x.contiguous() + return x + + +def _contract_dense_reconstruct(x, weight, separable=False, operator_type="diagonal", complex=True): + """Contraction for dense tensors, factorized or not""" + if not torch.is_tensor(weight): + weight = weight.to_tensor() + # weight = torch.view_as_real(weight) + + return _contract_dense_pytorch(x, weight, separable=separable, operator_type=operator_type, complex=complex) + + +def get_contract_fun(weight, implementation="reconstructed", separable=False, operator_type="diagonal", complex=True): + """Generic ND implementation of Fourier Spectral Conv contraction + + Parameters + ---------- + weight : tensorly-torch's FactorizedTensor + implementation : {'reconstructed', 'factorized'}, default is 'reconstructed' + whether to reconstruct the weight and do a forward pass (reconstructed) + or contract directly the factors of the factorized weight with the input (factorized) + + Returns + ------- + function : (x, weight) -> x * weight in Fourier space + """ + if implementation == "reconstructed": + handle = partial(_contract_dense_reconstruct, separable=separable, complex=complex, operator_type=operator_type) + return handle + elif implementation == "factorized": + if 
torch.is_tensor(weight): + handle = partial(_contract_dense_pytorch, separable=separable, complex=complex, operator_type=operator_type) + return handle + elif isinstance(weight, FactorizedTensor): + if weight.name.lower() == "complexdense" or weight.name.lower() == "dense": + return _contract_dense + elif weight.name.lower() == "complextucker": + return _contract_tucker + elif weight.name.lower() == "complextt": + return _contract_tt + elif weight.name.lower() == "complexcp": + return _contract_cp + else: + raise ValueError(f"Got unexpected factorized weight type {weight.name}") + else: + raise ValueError(f"Got unexpected weight type of class {weight.__class__.__name__}") + else: + raise ValueError(f'Got {implementation=}, expected "reconstructed" or "factorized"') diff --git a/common/layers.py b/common/layers.py new file mode 100644 index 0000000..1c18bd3 --- /dev/null +++ b/common/layers.py @@ -0,0 +1,288 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from functools import partial +from collections import OrderedDict +from copy import Error, deepcopy +from re import S +from numpy.lib.arraypad import pad +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.fft +from torch.nn.modules.container import Sequential +from torch.utils.checkpoint import checkpoint, checkpoint_sequential +from torch.cuda import amp +from typing import Optional +import math + +from makani.models.common.contractions import compl_muladd2d_fwd, compl_mul2d_fwd +from makani.models.common.contractions import _contract_diagonal + + +@torch.jit.script +def drop_path(x: torch.Tensor, drop_prob: float = 0.0, training: bool = False) -> torch.Tensor: + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). + This is the same as the DropConnect impl I created for EfficientNet, etc networks, however, + the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper... + See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for + changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use + 'survival rate' as the argument. 
+ """ + if drop_prob == 0.0 or not training: + return x + keep_prob = 1.0 - drop_prob + shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2d ConvNets + random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device) + random_tensor.floor_() # binarize + output = x.div(keep_prob) * random_tensor + return output + + +class DropPath(nn.Module): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).""" + + def __init__(self, drop_prob=None): + super(DropPath, self).__init__() + self.drop_prob = drop_prob + + def forward(self, x): + return drop_path(x, self.drop_prob, self.training) + + +class PatchEmbed(nn.Module): + def __init__(self, img_size=(224, 224), patch_size=(16, 16), in_chans=3, embed_dim=768): + super(PatchEmbed, self).__init__() + self.red_img_size = ((img_size[0] // patch_size[0]), (img_size[1] // patch_size[1])) + num_patches = self.red_img_size[0] * self.red_img_size[1] + self.img_size = img_size + self.patch_size = patch_size + self.num_patches = num_patches + self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=True) + self.proj.weight.is_shared_mp = ["spatial"] + self.proj.bias.is_shared_mp = ["spatial"] + + def forward(self, x): + # gather input + B, C, H, W = x.shape + assert H == self.img_size[0] and W == self.img_size[1], f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." + # new: B, C, H*W + x = self.proj(x).flatten(2) + return x + + +class EncoderDecoder(nn.Module): + def __init__(self, num_layers, input_dim, output_dim, hidden_dim, act_layer, gain=1.0, input_format="nchw"): + super(EncoderDecoder, self).__init__() + + encoder_modules = [] + current_dim = input_dim + for i in range(num_layers): + # fully connected layer + if input_format == "nchw": + encoder_modules.append(nn.Conv2d(current_dim, hidden_dim, 1, bias=True)) + elif input_format == "traditional": + encoder_modules.append(nn.Linear(current_dim, hidden_dim, bias=True)) + else: + raise NotImplementedError(f"Error, input format {input_format} not supported.") + + # weight sharing + encoder_modules[-1].weight.is_shared_mp = ["spatial"] + + # proper initializaiton + scale = math.sqrt(2.0 / current_dim) + nn.init.normal_(encoder_modules[-1].weight, mean=0.0, std=scale) + if encoder_modules[-1].bias is not None: + encoder_modules[-1].bias.is_shared_mp = ["spatial"] + nn.init.constant_(encoder_modules[-1].bias, 0.0) + + encoder_modules.append(act_layer()) + current_dim = hidden_dim + + # final output layer + if input_format == "nchw": + encoder_modules.append(nn.Conv2d(current_dim, output_dim, 1, bias=False)) + elif input_format == "traditional": + encoder_modules.append(nn.Linear(current_dim, output_dim, bias=False)) + + # weight sharing + encoder_modules[-1].weight.is_shared_mp = ["spatial"] + + # proper initializaiton + scale = math.sqrt(gain / current_dim) + nn.init.normal_(encoder_modules[-1].weight, mean=0.0, std=scale) + if encoder_modules[-1].bias is not None: + encoder_modules[-1].bias.is_shared_mp = ["spatial"] + nn.init.constant_(encoder_modules[-1].bias, 0.0) + + self.fwd = nn.Sequential(*encoder_modules) + + def forward(self, x): + return self.fwd(x) + + +class MLP(nn.Module): + def __init__( + self, + in_features, + hidden_features=None, + out_features=None, + act_layer=nn.GELU, + output_bias=True, + input_format="nchw", + drop_rate=0.0, + drop_type="iid", + checkpointing=0, + gain=1.0, + **kwargs, + ): + super(MLP, self).__init__() + 
self.checkpointing = checkpointing + out_features = out_features or in_features + hidden_features = hidden_features or in_features + + # First fully connected layer + if input_format == "nchw": + fc1 = nn.Conv2d(in_features, hidden_features, 1, bias=True) + fc1.weight.is_shared_mp = ["spatial"] + fc1.bias.is_shared_mp = ["spatial"] + elif input_format == "traditional": + fc1 = nn.Linear(in_features, hidden_features, bias=True) + else: + raise NotImplementedError(f"Error, input format {input_format} not supported.") + + # initialize the weights correctly + scale = math.sqrt(2.0 / in_features) + nn.init.normal_(fc1.weight, mean=0.0, std=scale) + nn.init.constant_(fc1.bias, 0.0) + + # activation + act = act_layer() + + # sanity checks + if (input_format == "traditional") and (drop_type == "features"): + raise NotImplementedError(f"Error, traditional input format and feature dropout cannot be selected simultaneously") + + # output layer + if input_format == "nchw": + fc2 = nn.Conv2d(hidden_features, out_features, 1, bias=output_bias) + fc2.weight.is_shared_mp = ["spatial"] + if output_bias: + fc2.bias.is_shared_mp = ["spatial"] + elif input_format == "traditional": + fc2 = nn.Linear(hidden_features, out_features, bias=output_bias) + else: + raise NotImplementedError(f"Error, input format {input_format} not supported.") + + # gain factor for the output determines the scaling of the output init + scale = math.sqrt(gain / hidden_features) + nn.init.normal_(fc2.weight, mean=0.0, std=scale) + if fc2.bias is not None: + nn.init.constant_(fc2.bias, 0.0) + + if drop_rate > 0.0: + if drop_type == "iid": + drop = nn.Dropout(drop_rate) + elif drop_type == "features": + drop = nn.Dropout2d(drop_rate) + else: + raise NotImplementedError(f"Error, drop_type {drop_type} not supported") + else: + drop = nn.Identity() + + # create forward pass + self.fwd = nn.Sequential(fc1, act, drop, fc2, drop) + + @torch.jit.ignore + def checkpoint_forward(self, x): + return checkpoint(self.fwd, x, use_reentrant=False) + + def forward(self, x): + if self.checkpointing >= 2: + return self.checkpoint_forward(x) + else: + return self.fwd(x) + + +class RealFFT2(nn.Module): + """ + Helper routine to wrap FFT similarly to the SHT + """ + + def __init__(self, nlat, nlon, lmax=None, mmax=None): + super(RealFFT2, self).__init__() + + # use local FFT here + self.fft_handle = torch.fft.rfft2 + + self.nlat = nlat + self.nlon = nlon + self.lmax = min(lmax or self.nlat, self.nlat) + self.mmax = min(mmax or self.nlon // 2 + 1, self.nlon // 2 + 1) + + self.truncate = True + if (self.lmax == self.nlat) and (self.mmax == (self.nlon // 2 + 1)): + self.truncate = False + + self.lmax_high = math.ceil(self.lmax / 2) + self.lmax_low = math.floor(self.lmax / 2) + + def forward(self, x): + y = self.fft_handle(x, s=(self.nlat, self.nlon), dim=(-2, -1), norm="ortho") + + if self.truncate: + y = torch.cat((y[..., : self.lmax_high, : self.mmax], y[..., -self.lmax_low :, : self.mmax]), dim=-2) + + return y + + +class InverseRealFFT2(nn.Module): + """ + Helper routine to wrap FFT similarly to the SHT + """ + + def __init__(self, nlat, nlon, lmax=None, mmax=None): + super(InverseRealFFT2, self).__init__() + + # use local FFT here + self.ifft_handle = torch.fft.irfft2 + + self.nlat = nlat + self.nlon = nlon + self.lmax = min(lmax or self.nlat, self.nlat) + self.mmax = min(mmax or self.nlon // 2 + 1, self.nlon // 2 + 1) + + self.truncate = True + if (self.lmax == self.nlat) and (self.mmax == (self.nlon // 2 + 1)): + self.truncate = False + + self.lmax_high = 
math.ceil(self.lmax / 2) + self.lmax_low = math.floor(self.lmax / 2) + + def forward(self, x): + # truncation is implicit but better do it manually + xt = x[..., : self.mmax] + + if self.truncate: + # pad + xth = xt[..., : self.lmax_high, :] + xtl = xt[..., -self.lmax_low :, :] + xthp = F.pad(xth, (0, 0, 0, self.nlat - self.lmax)) + xt = torch.cat([xthp, xtl], dim=-2) + + out = torch.fft.irfft2(xt, s=(self.nlat, self.nlon), dim=(-2, -1), norm="ortho") + + return out diff --git a/common/spectral_convolution.py b/common/spectral_convolution.py new file mode 100644 index 0000000..57d768f --- /dev/null +++ b/common/spectral_convolution.py @@ -0,0 +1,405 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F +import numpy as np +import math + +from torch.cuda import amp + +# import FactorizedTensor from tensorly for tensorized operations +import tensorly as tl + +tl.set_backend("pytorch") +from tltorch.factorized_tensors.core import FactorizedTensor + +# import convenience functions for factorized tensors +from makani.utils import comm +from makani.models.common.activations import ComplexReLU +from makani.models.common.contractions import compl_muladd2d_fwd, compl_mul2d_fwd, _contract_rank +from makani.models.common.factorizations import get_contract_fun + +# for the experimental module +from makani.models.common.contractions import compl_exp_muladd2d_fwd, compl_exp_mul2d_fwd + +import torch_harmonics as th +import torch_harmonics.distributed as thd + + +class SpectralConv(nn.Module): + """ + Spectral Convolution implemented via SHT or FFT. Designed for convolutions on the two-sphere S2 + using the Spherical Harmonic Transforms in torch-harmonics, but supports convolutions on the periodic + domain via the RealFFT2 and InverseRealFFT2 wrappers. 
+ """ + + def __init__(self, forward_transform, inverse_transform, in_channels, out_channels, operator_type="diagonal", separable=False, bias=False, gain=1.0): + super(SpectralConv, self).__init__() + + self.forward_transform = forward_transform + self.inverse_transform = inverse_transform + + self.in_channels = in_channels + self.out_channels = out_channels + + self.modes_lat = self.inverse_transform.lmax + self.modes_lon = self.inverse_transform.mmax + + self.scale_residual = (self.forward_transform.nlat != self.inverse_transform.nlat) or (self.forward_transform.nlon != self.inverse_transform.nlon) + if hasattr(self.forward_transform, "grid"): + self.scale_residual = self.scale_residual or (self.forward_transform.grid != self.inverse_transform.grid) + + # remember factorization details + self.operator_type = operator_type + self.separable = separable + + assert self.inverse_transform.lmax == self.modes_lat + assert self.inverse_transform.mmax == self.modes_lon + + weight_shape = [in_channels] + + if not self.separable: + weight_shape += [out_channels] + + if isinstance(self.inverse_transform, thd.DistributedInverseRealSHT): + self.modes_lat_local = self.inverse_transform.l_shapes[comm.get_rank("h")] + self.modes_lon_local = self.inverse_transform.m_shapes[comm.get_rank("w")] + self.nlat_local = self.inverse_transform.lat_shapes[comm.get_rank("h")] + self.nlon_local = self.inverse_transform.lon_shapes[comm.get_rank("w")] + else: + self.modes_lat_local = self.modes_lat + self.modes_lon_local = self.modes_lon + self.nlat_local = self.inverse_transform.nlat + self.nlon_local = self.inverse_transform.nlon + + # unpadded weights + if self.operator_type == "diagonal": + weight_shape += [self.modes_lat_local, self.modes_lon_local] + elif self.operator_type == "dhconv": + weight_shape += [self.modes_lat_local] + else: + raise ValueError(f"Unsupported operator type f{self.operator_type}") + + # Compute scaling factor for correct initialization + scale = math.sqrt(gain / in_channels) * torch.ones(self.modes_lat_local, dtype=torch.complex64) + # seemingly the first weight is not really complex, so we need to account for that + scale[0] *= math.sqrt(2.0) + init = scale * torch.randn(*weight_shape, dtype=torch.complex64) + self.weight = nn.Parameter(init) + + if self.operator_type == "dhconv": + self.weight.is_shared_mp = ["matmul", "w"] + self.weight.sharded_dims_mp = [None for _ in weight_shape] + self.weight.sharded_dims_mp[-1] = "h" + else: + self.weight.is_shared_mp = ["matmul"] + self.weight.sharded_dims_mp = [None for _ in weight_shape] + self.weight.sharded_dims_mp[-1] = "w" + self.weight.sharded_dims_mp[-2] = "h" + + # get the contraction handle. 
This should return a pyTorch contraction + self._contract = get_contract_fun(self.weight, implementation="factorized", separable=separable, complex=True, operator_type=operator_type) + + if bias == "constant": + self.bias = nn.Parameter(torch.zeros(1, self.out_channels, 1, 1)) + elif bias == "position": + self.bias = nn.Parameter(torch.zeros(1, self.out_channels, self.nlat_local, self.nlon_local)) + self.bias.is_shared_mp = ["matmul"] + self.bias.sharded_dims_mp = [None, None, "h", "w"] + + def forward(self, x): + dtype = x.dtype + residual = x + x = x.float() + B, C, H, W = x.shape + + with amp.autocast(enabled=False): + x = self.forward_transform(x).contiguous() + if self.scale_residual: + residual = self.inverse_transform(x) + residual = residual.to(dtype) + + # approach with unpadded weights + xp = self._contract(x, self.weight, separable=self.separable, operator_type=self.operator_type) + x = xp.contiguous() + + with amp.autocast(enabled=False): + x = self.inverse_transform(x) + + if hasattr(self, "bias"): + x = x + self.bias + + x = x.to(dtype=dtype) + + return x, residual + + +class FactorizedSpectralConv(nn.Module): + """ + Factorized version of SpectralConv. Uses tensorly-torch to keep the weights factorized + """ + + def __init__( + self, + forward_transform, + inverse_transform, + in_channels, + out_channels, + operator_type="diagonal", + rank=0.2, + factorization=None, + separable=False, + decomposition_kwargs=dict(), + bias=False, + gain=1.0, + ): + super(FactorizedSpectralConv, self).__init__() + + self.forward_transform = forward_transform + self.inverse_transform = inverse_transform + + self.in_channels = in_channels + self.out_channels = out_channels + + self.modes_lat = self.inverse_transform.lmax + self.modes_lon = self.inverse_transform.mmax + + self.scale_residual = (self.forward_transform.nlat != self.inverse_transform.nlat) or (self.forward_transform.nlon != self.inverse_transform.nlon) + if hasattr(self.forward_transform, "grid"): + self.scale_residual = self.scale_residual or (self.forward_transform.grid != self.inverse_transform.grid) + + # Make sure we are using a Complex Factorized Tensor + if factorization is None: + factorization = "ComplexDense" # No factorization + complex_weight = factorization[:7].lower() == "complex" + + # remember factorization details + self.operator_type = operator_type + self.rank = rank + self.factorization = factorization + self.separable = separable + + assert self.inverse_transform.lmax == self.modes_lat + assert self.inverse_transform.mmax == self.modes_lon + + weight_shape = [in_channels] + + if not self.separable: + weight_shape += [out_channels] + + if isinstance(self.inverse_transform, thd.DistributedInverseRealSHT): + self.modes_lat_local = self.inverse_transform.l_shapes[comm.get_rank("h")] + self.modes_lon_local = self.inverse_transform.m_shapes[comm.get_rank("w")] + else: + self.modes_lat_local = self.modes_lat + self.modes_lon_local = self.modes_lon + + # unpadded weights + if self.operator_type == "diagonal": + weight_shape += [self.modes_lat_local, self.modes_lon_local] + elif self.operator_type == "dhconv": + weight_shape += [self.modes_lat_local] + elif self.operator_type == "rank": + weight_shape += [self.rank] + else: + raise ValueError(f"Unsupported operator type f{self.operator_type}") + + # form weight tensors + self.weight = FactorizedTensor.new(weight_shape, rank=self.rank, factorization=factorization, fixed_rank_modes=False, **decomposition_kwargs) + # initialization of weights + scale = math.sqrt(gain / 
float(weight_shape[0])) + self.weight.normal_(mean=0.0, std=scale) + + # get the contraction handle + if operator_type == "rank": + self._contract = _contract_rank + else: + self._contract = get_contract_fun(self.weight, implementation="reconstructed", separable=separable, complex=complex_weight, operator_type=operator_type) + + if bias == "constant": + self.bias = nn.Parameter(torch.zeros(1, self.out_channels, 1, 1)) + elif bias == "position": + self.bias = nn.Parameter(torch.zeros(1, self.out_channels, self.nlat_local, self.nlon_local)) + self.bias.is_shared_mp = ["matmul"] + self.bias.sharded_dims_mp = [None, None, "h", "w"] + + def forward(self, x): + dtype = x.dtype + residual = x + x = x.float() + + with amp.autocast(enabled=False): + x = self.forward_transform(x).contiguous() + if self.scale_residual: + residual = self.inverse_transform(x) + residual = residual.to(dtype) + + if self.operator_type == "rank": + xp = self._contract(x, self.weight, self.lat_weight, self.lon_weight) + else: + xp = self._contract(x, self.weight, separable=self.separable, operator_type=self.operator_type) + x = xp.contiguous() + + with amp.autocast(enabled=False): + x = self.inverse_transform(x) + + if hasattr(self, "bias"): + x = x + self.bias + + x = x.type(dtype) + + return x, residual + + +class SpectralAttention(nn.Module): + """ + Spherical non-linear FNO layer + """ + + def __init__( + self, + forward_transform, + inverse_transform, + in_channels, + out_channels, + operator_type="diagonal", + hidden_size_factor=2, + complex_activation="real", + bias=False, + spectral_layers=1, + drop_rate=0.0, + gain=1.0, + ): + super(SpectralAttention, self).__init__() + + self.in_channels = in_channels + self.out_channels = out_channels + self.operator_type = operator_type + self.spectral_layers = spectral_layers + + self.modes_lat = forward_transform.lmax + self.modes_lon = forward_transform.mmax + + # only storing the forward handle to be able to call it + self.forward_transform = forward_transform + self.inverse_transform = inverse_transform + + self.scale_residual = ( + (self.forward_transform.nlat != self.inverse_transform.nlat) + or (self.forward_transform.nlon != self.inverse_transform.nlon) + or (self.forward_transform.grid != self.inverse_transform.grid) + ) + + assert inverse_transform.lmax == self.modes_lat + assert inverse_transform.mmax == self.modes_lon + + hidden_size = int(hidden_size_factor * self.in_channels) + + if operator_type == "diagonal": + self.mul_add_handle = compl_muladd2d_fwd + self.mul_handle = compl_mul2d_fwd + + # weights + scale = math.sqrt(2.0 / float(in_channels)) + w = [scale * torch.randn(self.in_channels, hidden_size, dtype=torch.complex64)] + for l in range(1, self.spectral_layers): + scale = math.sqrt(2.0 / float(hidden_size)) + w.append(scale * torch.randn(hidden_size, hidden_size, dtype=torch.complex64)) + self.w = nn.ParameterList(w) + + scale = math.sqrt(gain / float(in_channels)) + self.wout = nn.Parameter(scale * torch.randn(hidden_size, self.out_channels, dtype=torch.complex64)) + + if bias: + self.b = nn.ParameterList([scale * torch.randn(hidden_size, 1, 1, dtype=torch.complex64) for _ in range(self.spectral_layers)]) + + self.activations = nn.ModuleList([]) + for l in range(0, self.spectral_layers): + self.activations.append(ComplexReLU(mode=complex_activation, bias_shape=(hidden_size, 1, 1), scale=scale)) + + elif operator_type == "l-dependant": + self.mul_add_handle = compl_exp_muladd2d_fwd + self.mul_handle = compl_exp_mul2d_fwd + + # weights + scale = 
math.sqrt(2.0 / float(in_channels)) + w = [scale * torch.randn(self.modes_lat, self.in_channels, hidden_size, dtype=torch.complex64)] + for l in range(1, self.spectral_layers): + scale = math.sqrt(2.0 / float(hidden_size)) + w.append(scale * torch.randn(self.modes_lat, hidden_size, hidden_size, dtype=torch.complex64)) + self.w = nn.ParameterList(w) + + if bias: + self.b = nn.ParameterList([scale * torch.randn(hidden_size, 1, 1, dtype=torch.complex64) for _ in range(self.spectral_layers)]) + + scale = math.sqrt(gain / float(in_channels)) + self.wout = nn.Parameter(scale * torch.randn(self.modes_lat, hidden_size, self.out_channels, dtype=torch.complex64)) + + self.activations = nn.ModuleList([]) + for l in range(0, self.spectral_layers): + self.activations.append(ComplexReLU(mode=complex_activation, bias_shape=(hidden_size, 1, 1), scale=scale)) + + else: + raise ValueError("Unknown operator type") + + self.drop = nn.Dropout(drop_rate) if drop_rate > 0.0 else nn.Identity() + + def forward_mlp(self, x): + B, C, H, W = x.shape + + xr = torch.view_as_real(x) + + for l in range(self.spectral_layers): + if hasattr(self, "b"): + xr = self.mul_add_handle(xr, self.w[l], self.b[l]) + else: + xr = self.mul_handle(xr, self.w[l]) + xr = torch.view_as_complex(xr) + xr = self.activations[l](xr) + xr = self.drop(xr) + xr = torch.view_as_real(xr) + + # final MLP + x = self.mul_handle(xr, self.wout) + + x = torch.view_as_complex(x) + + return x + + def forward(self, x): + dtype = x.dtype + residual = x + x = x.to(torch.float32) + + # FWD transform + with amp.autocast(enabled=False): + x = self.forward_transform(x) + if self.scale_residual: + residual = self.inverse_transform(x) + residual = residual.to(dtype) + + # MLP + x = self.forward_mlp(x) + + # BWD transform + with amp.autocast(enabled=False): + x = self.inverse_transform(x) + + # cast back to initial precision + x = x.to(dtype) + + return x, residual diff --git a/helpers.py b/helpers.py new file mode 100644 index 0000000..ad12ea1 --- /dev/null +++ b/helpers.py @@ -0,0 +1,44 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
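+
+"""
+Helper routines for inspecting model parameters.
+
+`count_parameters` computes the global number of trainable parameters of a
+(possibly model-parallel) model: for parameters flagged via `is_shared_mp`,
+the local counts are reduced over the "model" communicator and divided by the
+sizes of the communicators listed in `is_shared_mp`, so shared parameters are
+not counted multiple times. `check_parameters` prints shape, stride and
+contiguity of all trainable parameters for debugging.
+"""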
+ +import torch +import torch.distributed as dist + +from makani.utils import comm + + +def count_parameters(model, device): + with torch.no_grad(): + total_count = 0 + for p in model.parameters(): + if not p.requires_grad: + continue + # reduce over model group + pcount = torch.tensor(p.numel(), device=device) + if hasattr(p, "is_shared_mp") and p.is_shared_mp: + if comm.get_size("model") > 1: + dist.all_reduce(pcount, group=comm.get_group("model")) + # divide by shared dims: + for cname in p.is_shared_mp: + pcount = pcount / comm.get_size(cname) + total_count += int(pcount.item()) + + return total_count + + +def check_parameters(model): + for p in model.parameters(): + if p.requires_grad: + print(p.shape, p.stride(), p.is_contiguous()) diff --git a/model_package.py b/model_package.py new file mode 100644 index 0000000..842708c --- /dev/null +++ b/model_package.py @@ -0,0 +1,268 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Model package for easy inference/packaging. Model packages contain all the necessary data to +perform inference and its interface is compatible with earth2mip +""" +import os +import shutil +import json +import jsbeautifier +import numpy as np +import torch +from makani.utils.YParams import ParamsBase +from makani.third_party.climt.zenith_angle import cos_zenith_angle + +from makani.models import model_registry + +import datetime + +import logging + + +class LocalPackage: + """ + Implements the earth2mip/modulus Package interface. + """ + + def __init__(self, root): + self.root = root + + def get(self, path): + return os.path.join(self.root, path) + + +logger = logging.getLogger(__name__) + +THIS_MODULE = "makani.models.model_package" +MODEL_PACKAGE_CHECKPOINT_PATH = "training_checkpoints/best_ckpt_mp0.tar" +MINS_FILE = "mins.npy" +MAXS_FILE = "maxs.npy" +MEANS_FILE = "global_means.npy" +STDS_FILE = "global_stds.npy" + + +class ModelWrapper(torch.nn.Module): + """ + Model wrapper to make inference simple outside of makani. + + Attributes + ---------- + model : torch.nn.Module + ML model that is wrapped. 
+ params : ParamsBase + parameter object containing information on how the model was initialized in makani + + Methods + ------- + forward(x, time): + performs a single prediction steps + """ + + def __init__(self, model, params): + super().__init__() + self.model = model + self.params = params + nlat = params.img_shape_x + nlon = params.img_shape_y + + self.lats = 90 - 180 * np.arange(nlat) / (nlat - 1) + self.lons = 360 * np.arange(nlon) / nlon + self.add_zenith = params.add_zenith + + def forward(self, x, time): + if self.add_zenith: + lon_grid, lat_grid = np.meshgrid(self.lons, self.lats) + cosz = cos_zenith_angle(time, lon_grid, lat_grid) + cosz = cosz.astype(np.float32) + z = torch.from_numpy(cosz).to(device=x.device) + while z.ndim != x.ndim: + z = z[None] + x = torch.cat([x, z], dim=1) + + return self.model(x) + + +def save_model_package(params): + """ + Saves out a self-contained model-package. + The idea is to save anything necessary for inference beyond the checkpoints in one location. + """ + # save out the current state of the parameters, make it human readable + config_path = os.path.join(params.experiment_dir, "config.json") + jsopts = jsbeautifier.default_options() + jsopts.indent_size = 2 + + with open(config_path, "w") as f: + msg = jsbeautifier.beautify(json.dumps(params.to_dict()), jsopts) + f.write(msg) + + if hasattr(params, "add_orography") and params.add_orography: + shutil.copy(params.orography_path, os.path.join(params.experiment_dir, "orography.nc")) + + if hasattr(params, "add_landmask") and params.add_landmask: + shutil.copy(params.landmask_path, os.path.join(params.experiment_dir, "land_mask.nc")) + + # a bit hacky - we should change this to get the normalization from the dataloader. + if hasattr(params, "global_means_path") and params.global_means_path is not None: + shutil.copy(params.global_means_path, os.path.join(params.experiment_dir, MEANS_FILE)) + if hasattr(params, "global_stds_path") and params.global_stds_path is not None: + shutil.copy(params.global_stds_path, os.path.join(params.experiment_dir, STDS_FILE)) + + if params.normalization == "minmax": + if hasattr(params, "min_path") and params.min_path is not None: + shutil.copy(params.min_path, os.path.join(params.experiment_dir, MINS_FILE)) + if hasattr(params, "max_path") and params.max_path is not None: + shutil.copy(params.max_path, os.path.join(params.experiment_dir, MAXS_FILE)) + + # write out earth2mip metadata.json + fcn_mip_data = { + "entrypoint": {"name": f"{THIS_MODULE}:load_time_loop"}, + } + with open(os.path.join(params.experiment_dir, "metadata.json"), "w") as f: + msg = jsbeautifier.beautify(json.dumps(fcn_mip_data), jsopts) + f.write(msg) + + +def _load_static_data(package, params): + if hasattr(params, "add_orography") and params.add_orography: + params.orography_path = package.get("orography.nc") + + if hasattr(params, "add_landmask") and params.add_landmask: + params.landmask_path = package.get("land_mask.nc") + + # a bit hacky - we should change this to correctly + if params.normalization == "zscore": + if hasattr(params, "global_means_path") and params.global_means_path is not None: + params.global_means_path = package.get(MEANS_FILE) + if hasattr(params, "global_stds_path") and params.global_stds_path is not None: + params.global_stds_path = package.get(STDS_FILE) + elif params.normalization == "minmax": + if hasattr(params, "min_path") and params.min_path is not None: + params.min_path = package.get(MINS_FILE) + if hasattr(params, "max_path") and params.max_path is not None: 
+ params.max_path = package.get(MAXS_FILE) + else: + raise ValueError("Unknown normalization mode.") + + +def load_model_package(package, pretrained=True, device="cpu"): + """ + Loads model package and return the wrapper which can be used for inference. + """ + path = package.get("config.json") + params = ParamsBase.from_json(path) + logger.info(str(params.to_dict())) + _load_static_data(package, params) + + # assume we are not distributed + # distributed checkpoints might be saved with different params values + params.img_local_offset_x = 0 + params.img_local_offset_y = 0 + params.img_local_shape_x = params.img_shape_x + params.img_local_shape_y = params.img_shape_y + + # get the model and + model = model_registry.get_model(params).to(device) + + if pretrained: + best_checkpoint_path = package.get(MODEL_PACKAGE_CHECKPOINT_PATH) + # critical that this map_location be cpu, rather than the device to + # avoid out of memory errors. + checkpoint = torch.load(best_checkpoint_path, map_location=device) + state_dict = checkpoint["model_state"] + torch.nn.modules.utils.consume_prefix_in_state_dict_if_present(state_dict, "module.") + model.load_state_dict(state_dict, strict=True) + + model = ModelWrapper(model, params=params) + + # by default we want to do evaluation so setting it to eval here + # 1-channel difference in training/eval mode + model.eval() + + return model + + +def load_time_loop(package, device=None, time_step_hours=None): + """This function loads an earth2mip TimeLoop object that + can be used for inference. + + A TimeLoop encapsulates normalization, regridding, and other logic, so is a + very minimal interface to expose to a framework like earth2mip. + + See https://github.com/NVIDIA/earth2mip/blob/main/docs/concepts.rst + for more info on this interface. 
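+
+    A minimal usage sketch (the directory is a placeholder; it has to be an
+    experiment directory written out by `save_model_package`, i.e. containing
+    config.json, the normalization statistics and the checkpoint under
+    training_checkpoints/):
+
+        package = LocalPackage("/path/to/experiment_dir")
+        time_loop = load_time_loop(package, device="cuda:0")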
+ """ + + from earth2mip.networks import Inference + from earth2mip.grid import equiangular_lat_lon_grid + + config = package.get("config.json") + params = ParamsBase.from_json(config) + + if params.in_channels != params.out_channels: + raise NotImplementedError("Non-equal input and output channels are not implemented yet.") + + names = [params.channel_names[i] for i in params.in_channels] + + if params.normalization == "minmax": + min_path = package.get(MINS_FILE) + max_path = package.get(MAXS_FILE) + + a = np.load(min_path) + a = np.squeeze(a)[params.in_channels] + + b = np.load(max_path) + b = np.squeeze(b)[params.in_channels] + + # work around to implement minmax scaling based with the earth2mip + # Inference class below + center = (a + b) / 2 + scale = (b - a) / 2 + else: + center_path = package.get(MEANS_FILE) + scale_path = package.get(STDS_FILE) + + center = np.load(center_path) + center = np.squeeze(center)[params.in_channels] + + scale = np.load(scale_path) + scale = np.squeeze(scale)[params.in_channels] + + model = load_model_package(package, pretrained=True, device=device) + shape = (params.img_shape_x, params.img_shape_y) + + grid = equiangular_lat_lon_grid(nlat=params.img_shape_x, nlon=params.img_shape_y, includes_south_pole=True) + + if time_step_hours is None: + hour = datetime.timedelta(hours=1) + time_step = hour * params.get("dt", 6) + else: + time_step = datetime.timedelta(hours=time_step_hours) + + # Here we use the built-in class earth2mip.networks.Inference + # will later be extended to use the makani inferencer + inference = Inference( + model=model, + channel_names=names, + center=center, + scale=scale, + grid=grid, + n_history=params.n_history, + time_step=time_step, + ) + inference.to(device) + return inference diff --git a/model_registry.py b/model_registry.py new file mode 100644 index 0000000..01bd64d --- /dev/null +++ b/model_registry.py @@ -0,0 +1,170 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import importlib.util + +# we need this here for the code to work +import importlib_metadata +from importlib.metadata import EntryPoint, entry_points + +import logging + +from typing import List, Union +from functools import partial + +import torch.nn as nn + +from makani.utils.YParams import ParamsBase +from makani.models import SingleStepWrapper, MultiStepWrapper + + +def _construct_registry() -> dict: + registry = {} + entrypoints = entry_points(group="makani.models") + for entry_point in entrypoints: + registry[entry_point.name] = entry_point + return registry + + +def _register_from_module(model: nn.Module, name: Union[str, None] = None) -> None: + """ + registers a module in the registry + """ + + # Check if model is a torch module + if not issubclass(model, nn.Module): + raise ValueError(f"Only subclasses of torch.nn.Module can be registered. 
" f"Provided model is of type {type(model)}") + + # If no name provided, use the model's name + if name is None: + name = model.__name__ + + # Check if name already in use + if name in _model_registry: + raise ValueError(f"Name {name} already in use") + + # Add this class to the dict of model registry + _model_registry[name] = model + + +def _register_from_file(model_string: str, name: Union[str, None] = None) -> None: + """ + parses a string and attempts to get the module from the specified location + """ + + assert len(model_string.split(":")) == 2 + model_path, model_handle = model_string.split(":") + + if not os.path.exists(model_path): + raise ValueError(f"Expected string of format 'path/to/model_file.py:ModuleName' but {model_path} does not exist.") + + module_spec = importlib.util.spec_from_file_location(model_handle, model_path) + module = importlib.util.module_from_spec(module_spec) + module_spec.loader.exec_module(module) + model = getattr(module, model_handle) + + _register_from_module(model, name) + + +def register_model(model: Union[str, nn.Module], name: Union[str, None] = None) -> None: + """ + Registers a model in the model registry under the provided name. If no name + is provided, the model's name (from its `__name__` attribute) is used. If the + name is already in use, raises a ValueError. + + Parameters + ---------- + model : torch.nn.Module + The model to be registered. Can be an instance of any class. + name : str, optional + The name to register the model under. If None, the model's name is used. + + Raises + ------ + ValueError + If the provided name is already in use in the registry. + """ + + if isinstance(model, str): + _register_from_file(model, name) + else: + _register_from_module(model, name) + +def list_models() -> List[str]: + """ + Returns a list of the names of all models currently registered in the registry. + + Returns + ------- + List[str] + A list of the names of all registered models. The order of the names is not + guaranteed to be consistent. + """ + return list(_model_registry.keys()) + + +def get_model(params: ParamsBase, **kwargs) -> "torch.nn.Module": + """ + Convenience routine that constructs the model passing parameters and kwargs. + Unloads all the parameters in the params datastructure as a dict. + + Parameters + ---------- + params : ParamsBase + parameter struct. + + Returns + ------- + model : torch.nn.Module + The registered model. + + Raises + ------ + KeyError + If no model is registered under the provided name. + """ + + if params is not None: + # makani requires that these entries are set in params for now + inp_shape = (params.img_crop_shape_x, params.img_crop_shape_y) + out_shape = (params.out_shape_x, params.out_shape_y) if hasattr(params, "out_shape_x") and hasattr(params, "out_shape_y") else inp_shape + inp_chans = params.N_in_channels + out_chans = params.N_out_channels + + if params.nettype not in _model_registry: + logging.warning(f"Net type {params.nettype} does not exist in the registry. 
Trying to register it.") + register_model(params.nettype, params.nettype) + + model_handle = _model_registry.get(params.nettype) + if model_handle is not None: + if isinstance(model_handle, (EntryPoint, importlib_metadata.EntryPoint)): + model_handle = model_handle.load() + + model_handle = partial(model_handle, inp_shape=inp_shape, out_shape=out_shape, inp_chans=inp_chans, out_chans=out_chans, **params.to_dict()) + else: + raise KeyError(f"No model is registered under the name {name}") + + # wrap into Multi-Step if requested + if params.n_future > 0: + model = MultiStepWrapper(params, model_handle) + else: + model = SingleStepWrapper(params, model_handle) + + return model + + +# initialize the internal state upon import +_model_registry = _construct_registry() diff --git a/networks/__pycache__/sfnonet.cpython-310.pyc b/networks/__pycache__/sfnonet.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..16d8b41c0e689862872942679a33cd69ad9631f4 GIT binary patch literal 13751 zcmb7Ld5|2}S?_CRy64_Qdub(&ZjUUjw32N3kmXBW+ljO>k?fEr$#^uccXwucdPc8% zbZmOCvE<4h)aK?0Nyu&>0aH|fgir+(Aq5mrKvf7;6h#xLP*6#fgB(tr#Ds|7_qyj= ztzww1`R04yJHK~**V{%Wla%mz`-k0yMp=^ngA)Cp7!nWR>%JgM5|fy0OLf_lWx^F( zF%|KvnyUEK>M=7W@>DxsPnZdTYqnlDOhe!?J6TVeDS^lBbUkBc>RB^e&zU(vOMsR) z^YtNfs6K2C3z}k&n4_ZoE%h;TT+nosn=mKpljfvI8}?LvtGTs4ZBEy>ncM2y&FzAo zw0G1CWx7(6(5`Yw9)GoU!)T4y+xv4p}qT zq;deb&7Ieu(DUK*k#w9 zM=DuavMS5XM%D3|;))!7P5-5OLvmCG5aEeNxztJk&b93j92DvAlOBQ#n z4xy+8Bd4pb$Ez1wp2a%kV5*O}p-PbL%cG%A_N6zBS`M^Iqh7}lrrb&yBjP2mQ3BIt z+qUcyw-zzLVCy2-M%Dsv)Hi0LYJdArI<8lCD%O01*N?Jbw6E(k<*ODaI_5QV7ISBx zwY>6aS_4{;{*M+V+><~o$&{I7Du!$*hH7X=%!nHaU~xk?Vl49xX+dTNOI|l{giR4qS=&V{^Y&d262F|85|K+gXZ#_5{XdFut zr6sA7i$+90glF(o@pXR{z>_@Lm#!&m%DN;=zQ*Kbg>UmzU-lJW#b3;i`$<3LXZ*CE zWy)&Gmr5Fw&QVzvWsmy_CbcAf+>7vtGQfwC*STutsb(*{R_r zUXY9N`IV&8a1Mn$7ZX9I(X@COyR+n7ZCb(bQk5~wDY@9Jr3JBVgJEoorfpp*Az8gt zrhOXZ+-QTBY+6be2M9-^gqEy8yHG8=fwow#I(!mSQq=ek`~-$;(Qdk6TYL}Fd@tq2 z!#s_bo#h}C_E=pJo$!5Bs#jDGVr~ng4lZk64~DAYaqA-)R{es-N~~HBQepdI zmc)>GfqJ@wz%j+_Zql}xD-w9e#;fL@O$X7UxIRtkRB@wNXi-b#UCFPZ^n{^(Q zE2hMT<5SekR)Bdv4LHZQbqI5OJE24^#%T{Z<+@cW1<6uL%sJudQt4u=Y==2wD)|YN z4>HjTRxI1bw(E`cC^bm~7CZA85SKd8B_+PuEM~u*e^8XHpwPf)G)lQY4Wx6|F_|e$ zWg62gwMO3P;h7IzA{ikxut~RWKl1PZ>n1(iCSUG@RKYy zDP6;sE2VwI%ha;CUC*+3d%{szqCJ7zR|kb7Cr39kW?3QGu^K~-yg$Jb+objca))Ze zzEsQCM!eCrE&hbBZk51AYE_XXX9}&P(aQ5)W-Yr;eRXR(biZei_o9>V##q+Reo*?b za!r0g_Qs>TS`*SBEu_JSKk9Gsb@ZEC&7z$N-woUXsP7jH_66 zJr?Ly$T_d-T@5t4HdnSi?JPTu%TB>vt~LvmhT}kkgQzQzTr3og|JQwgK8V-LUcF@p z`hr#VTHJDrMv$sHrB2e!MUu8d2n>kVP@3}H7_=EmR3Vi!q?ke6HN|yh=0!59PuEnC z=_#y`AVTDbQ(N4_i!=eGJB2sdFt9*hv^*i=gQQ&swp4CfZb%h!IB7Sa0tl*^L$|G{ zObIuHnQyeb0SYD%tq_D_5OW|)j|7V21p1Sm8IF@03>7e0#jNrY&2XTR^bQhKylj_( z_+_iQxP(?PZ%Yu;7r7TCpDEidO9*apUkYgvB)ju;B3}Utv4UkGtB%#6?$u* ze#T#q_RSQjV38JQYEb{5fN6-aM=e3E;X-r?NhahM=?z!#b>9n+oyw|NIV;C+W|ge0 z-=tJpO)7C&Q*>oa*5v}eN&M;146=%%ZYJ$IH2pjJP+HCQ)!sEwJE=^_=}om1{rl6I zxRRE&n=89-FSs^WKu8vm1@R~`^&x?O)?`vg+Oh|n$=4J2QinE&JM^`iWJ!pvE zOW-~N=Kuoj@yE_Q&CgSsHlxtI`2B=!rf<=`b{Sun&TYTGrD0cgcHXVO*{|d;i#w6# zVcU8U->@-aj3Q?gzgvtkV;pdT z?G}33B%6X>roq7xuO*^(QI73lTiNt=84eCT%wyZw_Un+E5CuC}0Xkfg?PR;~n-aQQ z8arlhu=89K_IR}dr@>}Q-Mre8*LyP5>>k9E4#Mu&r?xmNM>gS<4R@dMvyP;zdZXhMyC=j5<5mNrV z5bH28ExW}ZBP#95AXAMz4kfyM`*r^tbPt3Y>uzA34{bZL9_j1;HF_M>SYfG4B^9iId zz9IDaSKbb<;HSv#$3|DR_I`iACbjQq7o8I*u>~c*D)<|V`1==>--+@gs4-6Eog^r^ z+OCKOx8G{4ua!}>GKE%NL7A=8GQLIB(!Jdb+U*A=HvM5Z6`_^q*7htyvj=B;y?uWA zntWP%>+|gc{sHpgwD0r}_-Uq0NWc&JcS0MT!p!Y=AUA4vv=8|QLD>zwC@2ToGi<7T zm~Cqx5oM41`~BjA(w>DLO(pO1i~b!VtpO7`-kts-f5uNOs8xw=_qF%QY=<8MEciLV 

...bo", x[:, total_modes - kept_modes : total_modes + kept_modes, :kept_modes].real, self.w1[0]) + - torch.einsum("...bi,bio->...bo", x[:, total_modes - kept_modes : total_modes + kept_modes, :kept_modes].imag, self.w1[1]) + + self.b1[0] + ) + + o1_imag[:, total_modes - kept_modes : total_modes + kept_modes, :kept_modes] = F.relu( + torch.einsum("...bi,bio->...bo", x[:, total_modes - kept_modes : total_modes + kept_modes, :kept_modes].imag, self.w1[0]) + + torch.einsum("...bi,bio->...bo", x[:, total_modes - kept_modes : total_modes + kept_modes, :kept_modes].real, self.w1[1]) + + self.b1[1] + ) + + o2_real[:, total_modes - kept_modes : total_modes + kept_modes, :kept_modes] = ( + torch.einsum("...bi,bio->...bo", o1_real[:, total_modes - kept_modes : total_modes + kept_modes, :kept_modes], self.w2[0]) + - torch.einsum("...bi,bio->...bo", o1_imag[:, total_modes - kept_modes : total_modes + kept_modes, :kept_modes], self.w2[1]) + + self.b2[0] + ) + + o2_imag[:, total_modes - kept_modes : total_modes + kept_modes, :kept_modes] = ( + torch.einsum("...bi,bio->...bo", o1_imag[:, total_modes - kept_modes : total_modes + kept_modes, :kept_modes], self.w2[0]) + + torch.einsum("...bi,bio->...bo", o1_real[:, total_modes - kept_modes : total_modes + kept_modes, :kept_modes], self.w2[1]) + + self.b2[1] + ) + + x = torch.stack([o2_real, o2_imag], dim=-1) + x = F.softshrink(x, lambd=self.sparsity_threshold) + x = torch.view_as_complex(x) + x = x.reshape(B, H, W // 2 + 1, C) + x = torch.fft.irfft2(x, s=(H, W), dim=(1, 2), norm="ortho") + x = x.type(dtype) + + return x + bias + + +class Block(nn.Module): + def __init__( + self, + dim, + mlp_ratio=4.0, + drop=0.0, + drop_path=0.0, + act_layer=nn.GELU, + norm_layer=nn.LayerNorm, + double_skip=True, + num_blocks=8, + sparsity_threshold=0.01, + hard_thresholding_fraction=1.0, + ): + super().__init__() + self.norm1 = norm_layer(dim) + self.filter = AFNO2D(dim, num_blocks, sparsity_threshold, hard_thresholding_fraction) + self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + # self.drop_path = nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + self.double_skip = double_skip + + def forward(self, x): + residual = x + x = self.norm1(x) + x = self.filter(x) + + if self.double_skip: + x = x + residual + residual = x + + x = self.norm2(x) + x = self.mlp(x) + x = self.drop_path(x) + x = x + residual + return x + + +class PrecipNet(nn.Module): + def __init__(self, backbone, patch_size=(16, 16), inp_chans=2, out_chans=2, **kwargs): + super().__init__() + self.patch_size = patch_size + self.inp_chans = inp_chans + self.out_chans = out_chans + self.backbone = backbone + self.ppad = PeriodicPad2d(1) + self.conv = nn.Conv2d(self.out_chans, self.out_chans, kernel_size=3, stride=1, padding=0, bias=True) + self.act = nn.ReLU() + + def forward(self, x): + x = self.backbone(x) + x = self.ppad(x) + x = self.conv(x) + x = self.act(x) + return x + + +class AdaptiveFourierNeuralOperatorNet(nn.Module): + def __init__( + self, + inp_shape=(720, 1440), + patch_size=(16, 16), + inp_chans=2, + out_chans=2, + embed_dim=768, + num_layers=12, + mlp_ratio=4.0, + drop_rate=0.0, + drop_path_rate=0.0, + num_blocks=16, + sparsity_threshold=0.01, + hard_thresholding_fraction=1.0, + **kwargs, + ): + super(AdaptiveFourierNeuralOperatorNet, self).__init__() + self.img_size = inp_shape + self.patch_size = patch_size + self.inp_chans = inp_chans + 
self.out_chans = out_chans + self.embed_dim = embed_dim + + norm_layer = partial(nn.LayerNorm, eps=1e-6) + + self.patch_embed = PatchEmbed(img_size=self.img_size, patch_size=self.patch_size, in_chans=self.inp_chans, embed_dim=self.embed_dim) + num_patches = self.patch_embed.num_patches + + self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, self.embed_dim)) + self.pos_drop = nn.Dropout(p=drop_rate) + + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, num_layers)] + + self.h = self.img_size[0] // self.patch_size[0] + self.w = self.img_size[1] // self.patch_size[1] + + self.blocks = nn.ModuleList( + [ + Block( + dim=self.embed_dim, + mlp_ratio=mlp_ratio, + drop=drop_rate, + drop_path=dpr[i], + norm_layer=norm_layer, + num_blocks=num_blocks, + sparsity_threshold=sparsity_threshold, + hard_thresholding_fraction=hard_thresholding_fraction, + ) + for i in range(num_layers) + ] + ) + + self.head = nn.Linear(self.embed_dim, self.out_chans * self.patch_size[0] * self.patch_size[1], bias=False) + + with torch.no_grad(): + nn.init.trunc_normal_(self.pos_embed, std=0.02) + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + nn.init.trunc_normal_(m.weight, std=0.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + @torch.jit.ignore + def no_weight_decay(self): + return {"pos_embed", "cls_token"} + + def forward_features(self, x): + B = x.shape[0] + x = self.patch_embed(x).transpose(1, 2) + x = x + self.pos_embed + x = self.pos_drop(x) + + x = x.reshape(B, self.h, self.w, self.embed_dim) + for blk in self.blocks: + x = blk(x) + + return x + + def forward(self, x): + x = self.forward_features(x) + x = self.head(x) + + # rearrange + b = x.shape[0] + xv = x.view(b, self.h, self.w, self.patch_size[0], self.patch_size[1], -1) + xvt = torch.permute(xv, (0, 5, 1, 3, 2, 4)).contiguous() + x = xvt.view(b, -1, (self.h * self.patch_size[0]), (self.w * self.patch_size[1])) + + return x diff --git a/networks/afnonet_v2.py b/networks/afnonet_v2.py new file mode 100644 index 0000000..a131f36 --- /dev/null +++ b/networks/afnonet_v2.py @@ -0,0 +1,315 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
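+
+"""
+Variant of the AFNO (Adaptive Fourier Neural Operator) architecture. The
+`AFNO2D` filter below uses block-diagonal complex spectral weights with an
+optional fused complex-valued kernel, and the network adds configurable skip
+connections around the spectral filter as well as a choice of layer or
+instance normalization.
+"""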
+ +from functools import partial +from collections import OrderedDict +from copy import Error, deepcopy +from re import S +from numpy.lib.arraypad import pad +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.fft +from torch.nn.modules.container import Sequential +from torch.utils.checkpoint import checkpoint_sequential +from typing import Optional +import math + +# helpers +from makani.models.common import ComplexReLU, PatchEmbed, DropPath, MLP + + +@torch.jit.script +def compl_mul_add_fwd(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor: + tmp = torch.einsum("bkixys,kior->srbkoxy", a, b) + res = torch.stack([tmp[0, 0, ...] - tmp[1, 1, ...], tmp[1, 0, ...] + tmp[0, 1, ...]], dim=-1) + return res + + +@torch.jit.script +def compl_mul_add_fwd_c(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor: + ac = torch.view_as_complex(a) + bc = torch.view_as_complex(b) + resc = torch.einsum("bkixy,kio->bkoxy", ac, bc) + res = torch.view_as_real(resc) + return res + + +class AFNO2D(nn.Module): + def __init__(self, hidden_size, num_blocks=8, sparsity_threshold=0.0, hard_thresholding_fraction=1, hidden_size_factor=1, use_complex_kernels=False): + super(AFNO2D, self).__init__() + assert hidden_size % num_blocks == 0, f"hidden_size {hidden_size} should be divisble by num_blocks {num_blocks}" + + self.hidden_size = hidden_size + self.sparsity_threshold = sparsity_threshold + self.num_blocks = num_blocks + self.block_size = self.hidden_size // self.num_blocks + self.hard_thresholding_fraction = hard_thresholding_fraction + self.hidden_size_factor = hidden_size_factor + self.scale = 0.02 + self.mult_handle = compl_mul_add_fwd_c if use_complex_kernels else compl_mul_add_fwd + + # new + self.w1 = nn.Parameter(self.scale * torch.randn(self.num_blocks, self.block_size, self.block_size * self.hidden_size_factor, 2)) + self.b1 = nn.Parameter(self.scale * torch.randn(1, self.num_blocks * self.block_size, 1, 1)) + self.w2 = nn.Parameter(self.scale * torch.randn(self.num_blocks, self.block_size * self.hidden_size_factor, self.block_size, 2)) + # self.b2 = nn.Parameter(self.scale * torch.randn(self.num_blocks, self.block_size, 1, 1, 2)) + + # self.act = nn.ReLU() + self.act = ComplexReLU(negative_slope=0.0, mode="cartesian") + + def forward(self, x): + bias = x + + dtype = x.dtype + x = x.float() + B, C, H, W = x.shape + total_modes_H = H // 2 + 1 + total_modes_W = W // 2 + 1 + kept_modes_H = int(total_modes_H * self.hard_thresholding_fraction) + kept_modes_W = int(total_modes_W * self.hard_thresholding_fraction) + + x = torch.fft.rfft2(x, dim=(-2, -1), norm="ortho") + x = x.view(B, self.num_blocks, self.block_size, H, W // 2 + 1) + + # do spectral conv + x = torch.view_as_real(x) + x_fft = torch.zeros(x.shape, device=x.device) + + if kept_modes_H == total_modes_H: + oac = torch.view_as_complex(self.mult_handle(x[:, :, :, :, :kept_modes_W, :], self.w1)) + oa = torch.view_as_real(self.act(oac)) + x_fft[:, :, :, :, :kept_modes_W, :] = self.mult_handle(oa, self.w2) + else: + olc = torch.view_as_complex(self.mult_handle(x[:, :, :, :kept_modes_H, :kept_modes_W, :], self.w1)) + ohc = torch.view_as_complex(self.mult_handle(x[:, :, :, -kept_modes_H:, :kept_modes_W, :], self.w1)) + + ol = torch.view_as_real(self.act(olc)) + oh = torch.view_as_real(self.act(ohc)) + + x_fft[:, :, :, :kept_modes_H, :kept_modes_W, :] = self.mult_handle(ol, self.w2) + x_fft[:, :, :, -kept_modes_H:, :kept_modes_W, :] = self.mult_handle(oh, self.w2) + + # finalize + x = F.softshrink(x_fft, 
lambd=self.sparsity_threshold) + x = torch.view_as_complex(x) + x = x.reshape(B, C, H, W // 2 + 1) + x = torch.fft.irfft2(x, s=(H, W), dim=(-2, -1), norm="ortho") + x = x.type(dtype) + + return x + self.b1 + bias + + +class Block(nn.Module): + def __init__( + self, + h, + w, + dim, + mlp_ratio=4.0, + drop=0.0, + drop_path=0.0, + act_layer=nn.GELU, + norm_layer=nn.LayerNorm, + num_blocks=8, + sparsity_threshold=0.01, + hard_thresholding_fraction=1.0, + use_complex_kernels=True, + skip_fno="linear", + nested_skip_fno=True, + checkpointing=False, + verbose=True, + ): + super(Block, self).__init__() + + # norm layer + self.norm1 = norm_layer() # ((h,w)) + + if skip_fno is None: + if verbose: + print("Using no skip connection around FNO.") + + elif skip_fno == "linear": + # self.skip_layer = nn.Linear(dim, dim) + self.skip_layer = nn.Conv2d(dim, dim, 1, 1) + if verbose: + print("Using Linear skip connection around FNO.") + + elif skip_fno == "identity": + self.skip_layer = nn.Identity() + if verbose: + print("Using Identity skip connection around FNO.") + + else: + if verbose: + print(f"Got skip_fno={skip_fno}, not using any skip around FNO -- use linear or identity to change this.") + self.skip_fno = skip_fno + + self.nested_skip_fno = nested_skip_fno + + # filter + self.filter = AFNO2D(dim, num_blocks, sparsity_threshold, hard_thresholding_fraction, use_complex_kernels=use_complex_kernels) + self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + + # norm layer + self.norm2 = norm_layer() # ((h,w)) + + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = MLP(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop_rate=drop, checkpointing=checkpointing) + + def forward(self, x): + residual = x + + x = self.norm1(x) + x = self.filter(x) + + if self.skip_fno is not None: + x = x + self.skip_layer(residual) + if not self.nested_skip_fno: + residual = x + + x = self.norm2(x) + x = self.mlp(x) + x = self.drop_path(x) + x = x + residual + return x + + +class AdaptiveFourierNeuralOperatorNet(nn.Module): + def __init__( + self, + inp_shape=(720, 1440), + patch_size=(16, 16), + inp_chans=2, + out_chans=2, + embed_dim=768, + num_layers=12, + mlp_ratio=4.0, + drop_rate=0.0, + drop_path_rate=0.0, + num_blocks=16, + sparsity_threshold=0.01, + normalization_layer="instance_norm", + skip_fno="linear", + nested_skip_fno=True, + hard_thresholding_fraction=1.0, + checkpointing=False, + use_complex_kernels=True, + verbose=False, + **kwargs, + ): + super(AdaptiveFourierNeuralOperatorNet, self).__init__() + self.img_size = inp_shape + self.patch_size = patch_size + self.inp_chans = inp_chans + self.out_chans = out_chans + self.embed_dim = embed_dim + + # some sanity checks + assert len(patch_size) == 2, f"Expected patch_size to have two entries but got {patch_size} instead" + assert (self.img_size[0] % self.patch_size[0] == 0) and ( + self.img_size[1] % self.patch_size[1] == 0 + ), f"Error, the patch size {self.patch_size} does not divide the image dimensions {self.img_size} evenly." 
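+
+        # patch embedding maps (B, inp_chans, H, W) to (B, embed_dim, num_patches)
+        # with num_patches = (H // patch_size[0]) * (W // patch_size[1]);
+        # the assertions above guarantee an integer number of patches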
+ + self.patch_embed = PatchEmbed(img_size=self.img_size, patch_size=self.patch_size, in_chans=self.inp_chans, embed_dim=self.embed_dim) + num_patches = self.patch_embed.num_patches + + self.pos_embed = nn.Parameter(torch.zeros(1, embed_dim, num_patches)) + self.pos_drop = nn.Dropout(p=drop_rate) if drop_rate > 0.0 else nn.Identity() + + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, num_layers)] + + # compute the downscaled image size + self.h = self.img_size[0] // self.patch_size[0] + self.w = self.img_size[1] // self.patch_size[1] + + # pick norm layer + if normalization_layer == "layer_norm": + norm_layer = partial(nn.LayerNorm, normalized_shape=(self.h, self.w), eps=1e-6) + elif normalization_layer == "instance_norm": + norm_layer = partial(nn.InstanceNorm2d, num_features=embed_dim, eps=1e-6, affine=True, track_running_stats=False) + else: + raise NotImplementedError(f"Error, normalization {normalization_layer} not implemented.") + + self.blocks = nn.ModuleList( + [ + Block( + h=self.h, + w=self.w, + dim=self.embed_dim, + mlp_ratio=mlp_ratio, + drop=drop_rate, + drop_path=dpr[i], + norm_layer=norm_layer, + num_blocks=num_blocks, + sparsity_threshold=sparsity_threshold, + hard_thresholding_fraction=hard_thresholding_fraction, + use_complex_kernels=use_complex_kernels, + skip_fno=skip_fno, + nested_skip_fno=nested_skip_fno, + checkpointing=checkpointing, + verbose=verbose, + ) + for i in range(num_layers) + ] + ) + + # head + self.head = nn.Conv2d(embed_dim, self.out_chans * self.patch_size[0] * self.patch_size[1], 1, bias=False) + + with torch.no_grad(): + nn.init.trunc_normal_(self.pos_embed, std=0.02) + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear) or isinstance(m, nn.Conv2d): + nn.init.trunc_normal_(m.weight, std=0.02) + # nn.init.normal_(m.weight, std=0.02) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm) or isinstance(m, nn.InstanceNorm3d): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + @torch.jit.ignore + def no_weight_decay(self): + return {"pos_embed", "cls_token"} + + def forward_features(self, x): + B = x.shape[0] + x = self.patch_embed(x) + x = x + self.pos_embed + x = self.pos_drop(x) + + # reshape + x = x.reshape(B, self.embed_dim, self.h, self.w) + + for blk in self.blocks: + x = blk(x) + + return x + + def forward(self, x): + x = self.forward_features(x) + x = self.head(x) + + # new: B, C, H, W + b = x.shape[0] + xv = x.view(b, self.patch_size[0], self.patch_size[1], -1, self.h, self.w) + xvt = torch.permute(xv, (0, 3, 4, 1, 5, 2)).contiguous() + x = xvt.view(b, -1, (self.h * self.patch_size[0]), (self.w * self.patch_size[1])) + + return x diff --git a/networks/debug.py b/networks/debug.py new file mode 100644 index 0000000..5165465 --- /dev/null +++ b/networks/debug.py @@ -0,0 +1,28 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +from torch import nn + + +class DebugNet(nn.Module): + def __init__(self, **kwargs): + super().__init__() + + # create dummy param so that it won't crash in optimizer instantiation + self.factor = nn.Parameter(torch.ones((1), dtype=torch.float32)) + + def forward(self, x): + return self.factor * x diff --git a/networks/sfnonet.py b/networks/sfnonet.py new file mode 100644 index 0000000..9bf4fdf --- /dev/null +++ b/networks/sfnonet.py @@ -0,0 +1,673 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass +import math +import torch +import torch.nn as nn +from torch.utils.checkpoint import checkpoint +from torch.cuda import amp + +from functools import partial + +# helpers +from makani.models.common import DropPath, MLP, EncoderDecoder + +# import global convolution and non-linear spectral layers +from makani.models.common import SpectralConv, FactorizedSpectralConv, SpectralAttention + +# get spectral transforms from torch_harmonics +import torch_harmonics as th +import torch_harmonics.distributed as thd + +# wrap fft, to unify interface to spectral transforms +from makani.models.common import RealFFT2, InverseRealFFT2 +from makani.mpu.layers import DistributedRealFFT2, DistributedInverseRealFFT2, DistributedMLP, DistributedEncoderDecoder + +# more distributed stuff +from makani.utils import comm + +# layer normalization +from modulus.distributed.mappings import scatter_to_parallel_region, gather_from_parallel_region +from makani.mpu.layer_norm import DistributedInstanceNorm2d, DistributedLayerNorm + +# for annotation of models +import modulus +from modulus.models.meta import ModelMetaData + + +class SpectralFilterLayer(nn.Module): + def __init__( + self, + forward_transform, + inverse_transform, + embed_dim, + filter_type="linear", + operator_type="diagonal", + hidden_size_factor=1, + factorization=None, + rank=1.0, + separable=False, + complex_activation="real", + spectral_layers=1, + bias=False, + drop_rate=0.0, + gain=1.0, + ): + super(SpectralFilterLayer, self).__init__() + + if filter_type == "non-linear": + self.filter = SpectralAttention( + forward_transform, + inverse_transform, + embed_dim, + embed_dim, + operator_type=operator_type, + hidden_size_factor=hidden_size_factor, + complex_activation=complex_activation, + spectral_layers=spectral_layers, + drop_rate=drop_rate, + bias=bias, + gain=gain, + ) + + elif filter_type == "linear" and factorization is None: + self.filter = SpectralConv( + forward_transform, + inverse_transform, + embed_dim, + embed_dim, + operator_type=operator_type, + separable=separable, + bias=bias, + gain=gain, + ) + + elif filter_type == "linear" and factorization is not None: + self.filter = FactorizedSpectralConv( + forward_transform, + inverse_transform, + embed_dim, + 
embed_dim, + operator_type=operator_type, + rank=rank, + factorization=factorization, + separable=separable, + bias=bias, + gain=gain, + ) + + else: + raise (NotImplementedError) + + def forward(self, x): + return self.filter(x) + + +class FourierNeuralOperatorBlock(nn.Module): + def __init__( + self, + forward_transform, + inverse_transform, + embed_dim, + filter_type="linear", + operator_type="diagonal", + mlp_ratio=2.0, + mlp_drop_rate=0.0, + path_drop_rate=0.0, + act_layer=nn.GELU, + norm_layer=(nn.Identity, nn.Identity), + rank=1.0, + factorization=None, + separable=False, + inner_skip="linear", + outer_skip=None, + use_mlp=False, + comm_feature_inp_name=None, + comm_feature_hidden_name=None, + complex_activation="real", + spectral_layers=1, + bias=False, + final_activation=False, + checkpointing=0, + ): + super(FourierNeuralOperatorBlock, self).__init__() + + # determine some shapes + if comm.get_size("spatial") > 1: + self.input_shape_loc = (forward_transform.lat_shapes[comm.get_rank("h")], + forward_transform.lon_shapes[comm.get_rank("w")]) + self.output_shape_loc = (inverse_transform.lat_shapes[comm.get_rank("h")], + inverse_transform.lon_shapes[comm.get_rank("w")]) + else: + self.input_shape_loc = (forward_transform.nlat, forward_transform.nlon) + self.output_shape_loc = (inverse_transform.nlat, inverse_transform.nlon) + + # norm layer + self.norm0 = norm_layer[0]() + + if act_layer == nn.Identity: + gain_factor = 1.0 + else: + gain_factor = 2.0 + + if inner_skip == "linear": + self.inner_skip = nn.Conv2d(embed_dim, embed_dim, 1, 1, bias=False) + gain_factor /= 2.0 + nn.init.normal_(self.inner_skip.weight, std=math.sqrt(gain_factor / embed_dim)) + elif inner_skip == "identity": + self.inner_skip = nn.Identity() + gain_factor /= 2.0 + elif inner_skip == "none": + pass + else: + raise ValueError(f"Unknown skip connection type {inner_skip}") + + # convolution layer + self.filter = SpectralFilterLayer( + forward_transform, + inverse_transform, + embed_dim, + filter_type, + operator_type, + hidden_size_factor=mlp_ratio, + factorization=factorization, + rank=rank, + separable=separable, + complex_activation=complex_activation, + spectral_layers=spectral_layers, + bias=bias, + drop_rate=path_drop_rate, + gain=gain_factor, + ) + + self.act_layer0 = act_layer() + + # norm layer + self.norm1 = norm_layer[1]() + + if final_activation and act_layer != nn.Identity: + gain_factor = 2.0 + else: + gain_factor = 1.0 + + if outer_skip == "linear": + self.outer_skip = nn.Conv2d(embed_dim, embed_dim, 1, 1, bias=False) + gain_factor /= 2.0 + torch.nn.init.normal_(self.outer_skip.weight, std=math.sqrt(gain_factor / embed_dim)) + elif outer_skip == "identity": + self.outer_skip = nn.Identity() + gain_factor /= 2.0 + elif outer_skip == "none": + pass + else: + raise ValueError(f"Unknown skip connection type {outer_skip}") + + if use_mlp == True: + MLPH = DistributedMLP if (comm.get_size("matmul") > 1) else MLP + mlp_hidden_dim = int(embed_dim * mlp_ratio) + self.mlp = MLPH( + in_features=embed_dim, + hidden_features=mlp_hidden_dim, + act_layer=act_layer, + drop_rate=mlp_drop_rate, + drop_type="features", + comm_inp_name=comm_feature_inp_name, + comm_hidden_name=comm_feature_hidden_name, + checkpointing=checkpointing, + gain=gain_factor, + ) + + # dropout + self.drop_path = DropPath(path_drop_rate) if path_drop_rate > 0.0 else nn.Identity() + + if final_activation: + self.act_layer1 = act_layer() + + def forward(self, x): + """ + Updated FNO block + """ + + x, residual = self.filter(x) + + x = 
self.norm0(x) + + if hasattr(self, "inner_skip"): + x = x + self.inner_skip(residual) + + if hasattr(self, "act_layer0"): + x = self.act_layer0(x) + + if hasattr(self, "mlp"): + x = self.mlp(x) + + x = self.norm1(x) + + x = self.drop_path(x) + + if hasattr(self, "outer_skip"): + x = x + self.outer_skip(residual) + + if hasattr(self, "act_layer1"): + x = self.act_layer1(x) + + return x + + +class SphericalFourierNeuralOperatorNet(nn.Module): + """ + SFNO implementation as in Bonev et al.; Spherical Fourier Neural Operators: Learning Stable Dynamics on the Sphere + """ + + def __init__( + self, + spectral_transform="sht", + model_grid_type="equiangular", + sht_grid_type="legendre-gauss", + filter_type="linear", + operator_type="dhconv", + inp_shape=(721, 1440), + out_shape=(721, 1440), + scale_factor=8, + inp_chans=2, + out_chans=2, + embed_dim=32, + num_layers=4, + repeat_layers=1, + use_mlp=True, + mlp_ratio=2.0, + encoder_ratio=1, + decoder_ratio=1, + activation_function="gelu", + encoder_layers=1, + pos_embed="none", + pos_drop_rate=0.0, + path_drop_rate=0.0, + mlp_drop_rate=0.0, + normalization_layer="instance_norm", + max_modes=None, + hard_thresholding_fraction=1.0, + big_skip=True, + rank=1.0, + factorization=None, + separable=False, + complex_activation="real", + spectral_layers=3, + bias=False, + checkpointing=0, + **kwargs, + ): + super(SphericalFourierNeuralOperatorNet, self).__init__() + + self.inp_shape = inp_shape + self.out_shape = out_shape + self.inp_chans = inp_chans + self.out_chans = out_chans + self.embed_dim = embed_dim + self.repeat_layers = repeat_layers + self.big_skip = big_skip + self.checkpointing = checkpointing + + # compute the downscaled image size + self.h = int(self.inp_shape[0] // scale_factor) + self.w = int(self.inp_shape[1] // scale_factor) + + # initialize spectral transforms + self._init_spectral_transforms(spectral_transform, model_grid_type, sht_grid_type, hard_thresholding_fraction, max_modes) + + # determine activation function + if activation_function == "relu": + activation_function = nn.ReLU + elif activation_function == "gelu": + activation_function = nn.GELU + elif activation_function == "silu": + activation_function = nn.SiLU + else: + raise ValueError(f"Unknown activation function {activation_function}") + + # set up encoder + if comm.get_size("matmul") > 1: + self.encoder = DistributedEncoderDecoder( + num_layers=encoder_layers, + input_dim=self.inp_chans, + output_dim=self.embed_dim, + hidden_dim=int(encoder_ratio * self.embed_dim), + act_layer=activation_function, + input_format="nchw", + comm_inp_name="fin", + comm_out_name="fout", + ) + fblock_mlp_inp_name = self.encoder.comm_out_name + fblock_mlp_hidden_name = "fout" if (self.encoder.comm_out_name == "fin") else "fin" + else: + self.encoder = EncoderDecoder( + num_layers=encoder_layers, + input_dim=self.inp_chans, + output_dim=self.embed_dim, + hidden_dim=int(encoder_ratio * self.embed_dim), + act_layer=activation_function, + input_format="nchw", + ) + fblock_mlp_inp_name = "fin" + fblock_mlp_hidden_name = "fout" + + # dropout + self.pos_drop = nn.Dropout(p=pos_drop_rate) if pos_drop_rate > 0.0 else nn.Identity() + dpr = [x.item() for x in torch.linspace(0, path_drop_rate, num_layers)] + + # pick norm layer + if normalization_layer == "layer_norm": + norm_layer_inp = partial(DistributedLayerNorm, normalized_shape=(embed_dim), elementwise_affine=True, eps=1e-6) + norm_layer_out = norm_layer_mid = norm_layer_inp + elif normalization_layer == "instance_norm": + if 
comm.get_size("spatial") > 1: + norm_layer_inp = partial(DistributedInstanceNorm2d, num_features=embed_dim, eps=1e-6, affine=True) + else: + norm_layer_inp = partial(nn.InstanceNorm2d, num_features=embed_dim, eps=1e-6, affine=True, track_running_stats=False) + norm_layer_out = norm_layer_mid = norm_layer_inp + elif normalization_layer == "none": + norm_layer_out = norm_layer_mid = norm_layer_inp = nn.Identity + else: + raise NotImplementedError(f"Error, normalization {normalization_layer} not implemented.") + + # FNO blocks + self.blocks = nn.ModuleList([]) + for i in range(num_layers): + first_layer = i == 0 + last_layer = i == num_layers - 1 + + forward_transform = self.trans_down if first_layer else self.trans + inverse_transform = self.itrans_up if last_layer else self.itrans + + inner_skip = "none" + outer_skip = "linear" + + if first_layer: + norm_layer = (norm_layer_inp, norm_layer_mid) + elif last_layer: + norm_layer = (norm_layer_mid, norm_layer_out) + else: + norm_layer = (norm_layer_mid, norm_layer_mid) + + block = FourierNeuralOperatorBlock( + forward_transform, + inverse_transform, + embed_dim, + filter_type=filter_type, + operator_type=operator_type, + mlp_ratio=mlp_ratio, + mlp_drop_rate=mlp_drop_rate, + path_drop_rate=dpr[i], + act_layer=activation_function, + norm_layer=norm_layer, + inner_skip=inner_skip, + outer_skip=outer_skip, + use_mlp=use_mlp, + comm_feature_inp_name=fblock_mlp_inp_name, + comm_feature_hidden_name=fblock_mlp_hidden_name, + rank=rank, + factorization=factorization, + separable=separable, + complex_activation=complex_activation, + spectral_layers=spectral_layers, + bias=bias, + checkpointing=checkpointing, + ) + + self.blocks.append(block) + + # decoder takes the output of FNO blocks and the residual from the big skip connection + if comm.get_size("matmul") > 1: + comm_inp_name = fblock_mlp_inp_name + comm_out_name = fblock_mlp_hidden_name + self.decoder = DistributedEncoderDecoder( + num_layers=encoder_layers, + input_dim=embed_dim, + output_dim=self.out_chans, + hidden_dim=int(decoder_ratio * embed_dim), + act_layer=activation_function, + gain=0.5 if self.big_skip else 1.0, + comm_inp_name=comm_inp_name, + comm_out_name=comm_out_name, + input_format="nchw", + ) + self.gather_shapes = compute_split_shapes(self.out_chans, + comm.get_size(self.decoder.comm_out_name)) + + else: + self.decoder = EncoderDecoder( + num_layers=encoder_layers, + input_dim=embed_dim, + output_dim=self.out_chans, + hidden_dim=int(decoder_ratio * embed_dim), + act_layer=activation_function, + gain=0.5 if self.big_skip else 1.0, + input_format="nchw", + ) + + # output transform + if self.big_skip: + self.residual_transform = nn.Conv2d(self.inp_chans, self.out_chans, 1, bias=False) + self.residual_transform.weight.is_shared_mp = ["spatial"] + self.residual_transform.weight.sharded_dims_mp = [None, None, None, None] + scale = math.sqrt(0.5 / self.inp_chans) + nn.init.normal_(self.residual_transform.weight, mean=0.0, std=scale) + + # learned position embedding + if pos_embed == "direct": + # currently using deliberately a differently shape position embedding + self.pos_embed = nn.Parameter(torch.zeros(1, embed_dim, self.inp_shape_loc[0], self.inp_shape_loc[1])) + # information about how tensors are shared / sharded across ranks + self.pos_embed.is_shared_mp = [] # no reduction required since pos_embed is already serial + self.pos_embed.sharded_dims_mp = [None, None, "h", "w"] + self.pos_embed.type = "direct" + with torch.no_grad(): + nn.init.trunc_normal_(self.pos_embed, std=0.02) + 
elif pos_embed == "frequency":
+            if comm.get_size("spatial") > 1:
+                lmax_loc = self.itrans_up.l_shapes[comm.get_rank("h")]
+                mmax_loc = self.itrans_up.m_shapes[comm.get_rank("w")]
+            else:
+                lmax_loc = self.itrans_up.lmax
+                mmax_loc = self.itrans_up.mmax
+
+            rcoeffs = nn.Parameter(torch.tril(torch.randn(1, embed_dim, lmax_loc, mmax_loc), diagonal=0))
+            ccoeffs = nn.Parameter(torch.tril(torch.randn(1, embed_dim, lmax_loc, mmax_loc - 1), diagonal=-1))
+            with torch.no_grad():
+                nn.init.trunc_normal_(rcoeffs, std=0.02)
+                nn.init.trunc_normal_(ccoeffs, std=0.02)
+            self.pos_embed = nn.ParameterList([rcoeffs, ccoeffs])
+            self.pos_embed.type = "frequency"
+            self.pos_embed.is_shared_mp = []
+            self.pos_embed.sharded_dims_mp = [None, None, "h", "w"]
+
+        elif pos_embed == "none" or pos_embed == "None" or pos_embed is None:
+            pass
+        else:
+            raise ValueError("Unknown position embedding type")
+
+    @torch.jit.ignore
+    def _init_spectral_transforms(
+        self,
+        spectral_transform="sht",
+        model_grid_type="equiangular",
+        sht_grid_type="legendre-gauss",
+        hard_thresholding_fraction=1.0,
+        max_modes=None,
+    ):
+        """
+        Initialize the spectral transforms based on the maximum number of modes to keep. Also computes the local
+        image shapes when the model is run with spatial (domain) parallelism.
+        """
+
+        if max_modes is not None:
+            modes_lat, modes_lon = max_modes
+        else:
+            modes_lat = int(self.h * hard_thresholding_fraction)
+            modes_lon = int((self.w // 2 + 1) * hard_thresholding_fraction)
+
+        # prepare the spectral transforms
+        if spectral_transform == "sht":
+            sht_handle = th.RealSHT
+            isht_handle = th.InverseRealSHT
+
+            # parallelism
+            if comm.get_size("spatial") > 1:
+                polar_group = None if (comm.get_size("h") == 1) else comm.get_group("h")
+                azimuth_group = None if (comm.get_size("w") == 1) else comm.get_group("w")
+                thd.init(polar_group, azimuth_group)
+                sht_handle = thd.DistributedRealSHT
+                isht_handle = thd.DistributedInverseRealSHT
+
+            # set up
+            self.trans_down = sht_handle(*self.inp_shape, lmax=modes_lat, mmax=modes_lon, grid=model_grid_type).float()
+            self.itrans_up = isht_handle(*self.out_shape, lmax=modes_lat, mmax=modes_lon, grid=model_grid_type).float()
+            self.trans = sht_handle(self.h, self.w, lmax=modes_lat, mmax=modes_lon, grid=sht_grid_type).float()
+            self.itrans = isht_handle(self.h, self.w, lmax=modes_lat, mmax=modes_lon, grid=sht_grid_type).float()
+
+        elif spectral_transform == "fft":
+            fft_handle = RealFFT2
+            ifft_handle = InverseRealFFT2
+
+            if comm.get_size("spatial") > 1:
+                h_group = None if (comm.get_size("h") == 1) else comm.get_group("h")
+                w_group = None if (comm.get_size("w") == 1) else comm.get_group("w")
+                thd.init(h_group, w_group)
+                fft_handle = DistributedRealFFT2
+                ifft_handle = DistributedInverseRealFFT2
+
+            self.trans_down = fft_handle(*self.inp_shape, lmax=modes_lat, mmax=modes_lon).float()
+            self.itrans_up = ifft_handle(*self.out_shape, lmax=modes_lat, mmax=modes_lon).float()
+            self.trans = fft_handle(self.h, self.w, lmax=modes_lat, mmax=modes_lon).float()
+            self.itrans = ifft_handle(self.h, self.w, lmax=modes_lat, mmax=modes_lon).float()
+        else:
+            raise ValueError("Unknown spectral transform")
+
+        # use the SHT/FFT to compute the local, downscaled grid dimensions
+        if comm.get_size("spatial") > 1:
+            self.inp_shape_loc = (self.trans_down.lat_shapes[comm.get_rank("h")],
+                                  self.trans_down.lon_shapes[comm.get_rank("w")])
+            self.out_shape_loc = (self.itrans_up.lat_shapes[comm.get_rank("h")],
+                                  self.itrans_up.lon_shapes[comm.get_rank("w")])
+            self.h_loc = 
self.itrans.lat_shapes[comm.get_rank("h")] + self.w_loc = self.itrans.lon_shapes[comm.get_rank("w")] + else: + self.inp_shape_loc = (self.trans_down.nlat, self.trans_down.nlon) + self.out_shape_loc = (self.itrans_up.nlat, self.itrans_up.nlon) + self.h_loc = self.itrans.nlat + self.w_loc = self.itrans.nlon + + @torch.jit.ignore + def no_weight_decay(self): + return {"pos_embed", "cls_token"} + + def _forward_features(self, x): + for r in range(self.repeat_layers): + for blk in self.blocks: + if self.checkpointing >= 3: + x = checkpoint(blk, x, use_reentrant=False) + else: + x = blk(x) + + return x + + def forward(self, x): + # save big skip + if self.big_skip: + # if output shape differs, use the spectral transforms to change resolution + if self.out_shape != self.inp_shape: + xtype = x.dtype + # only take the predicted channels as residual + residual = x.to(torch.float32) + with amp.autocast(enabled=False): + residual = self.trans_down(residual) + residual = residual.contiguous() + residual = self.itrans_up(residual) + residual = residual.to(dtype=xtype) + else: + # only take the predicted channels + residual = x + + if comm.get_size("fin") > 1: + x = scatter_to_parallel_region(x, 1, "fin") + + if self.checkpointing >= 1: + x = checkpoint(self.encoder, x, use_reentrant=False) + else: + x = self.encoder(x) + + if hasattr(self, "pos_embed"): + if self.pos_embed.type == "frequency": + pos_embed = torch.stack([self.pos_embed[0], nn.functional.pad(self.pos_embed[1], (1, 0), "constant", 0)], dim=-1) + with amp.autocast(enabled=False): + pos_embed = self.itrans_up(torch.view_as_complex(pos_embed)) + else: + pos_embed = self.pos_embed + + # add pos embed + x = x + pos_embed + + # maybe clean the padding just in case + x = self.pos_drop(x) + + # do the feature extraction + x = self._forward_features(x) + + if self.checkpointing >= 1: + x = checkpoint(self.decoder, x, use_reentrant=False) + else: + x = self.decoder(x) + + if hasattr(self.decoder, "comm_out_name") and (comm.get_size(self.decoder.comm_out_name) > 1): + x = gather_from_parallel_region(x, 1, self.gather_shapes, self.decoder.comm_out_name) + + if self.big_skip: + x = x + self.residual_transform(residual) + + return x + +# this part exposes the model to modulus by constructing modulus Modules +@dataclass +class SphericalFourierNeuralOperatorNetMetaData(ModelMetaData): + name: str = "SFNO" + + jit: bool = False + cuda_graphs: bool = False + amp_cpu: bool = False + amp_gpu: bool = True + +SFNO = modulus.Module.from_torch( + SphericalFourierNeuralOperatorNet, + SphericalFourierNeuralOperatorNetMetaData() +) + +class FourierNeuralOperatorNet(SphericalFourierNeuralOperatorNet): + def __init__(self, *args, **kwargs): + return super().__init__(*args, spectral_transform="fft", **kwargs) + +@dataclass +class FourierNeuralOperatorNetMetaData(ModelMetaData): + name: str = "FNO" + + jit: bool = False + cuda_graphs: bool = False + amp_cpu: bool = False + amp_gpu: bool = True + +FNO = modulus.Module.from_torch( + FourierNeuralOperatorNet, + FourierNeuralOperatorNetMetaData() +) \ No newline at end of file diff --git a/networks/vit.py b/networks/vit.py new file mode 100644 index 0000000..426520a --- /dev/null +++ b/networks/vit.py @@ -0,0 +1,231 @@ +import math + +import torch.nn.functional as F +import torch +import torch.nn as nn +from functools import partial + +# mp stuff +from makani.utils import comm +from makani.models.common import DropPath, MLP, PatchEmbed +from makani.mpu.layers import DistributedMatmul, DistributedMLP, DistributedAttention + 
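+# The ViT below follows the standard recipe: the (B, C, H, W) input is split into
+# non-overlapping patches and linearly embedded by PatchEmbed, a learned position
+# embedding is added, the token sequence is processed by a stack of transformer
+# Blocks (self-attention followed by an MLP, each with a residual connection),
+# and a linear head folds the patch tokens back into a (B, out_chans, H, W) field.
+# When model parallelism is enabled, the attention and MLP layers are swapped for
+# their distributed counterparts from makani.mpu.layers.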
+ +class Attention(nn.Module): + def __init__( + self, + dim, + input_format="traditional", + num_heads=8, + qkv_bias=False, + qk_norm=False, + attn_drop_rate=0.0, + proj_drop_rate=0.0, + norm_layer=nn.LayerNorm, + ): + super().__init__() + assert dim % num_heads == 0, "dim should be divisible by num_heads" + self.num_heads = num_heads + self.head_dim = dim // num_heads + self.scale = self.head_dim**-0.5 + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() + self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() + self.attn_drop_rate = attn_drop_rate + + self.proj = nn.Linear(dim, dim) + + if proj_drop_rate > 0: + self.proj_drop = nn.Dropout(proj_drop_rate) + else: + self.proj_drop = nn.Identity() + + def forward(self, x): + B, N, C = x.shape + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4) + q, k, v = qkv.unbind(0) + q, k = self.q_norm(q), self.k_norm(k) + + x = F.scaled_dot_product_attention(q, k, v, dropout_p=self.attn_drop_rate) + + x = x.transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class Block(nn.Module): + def __init__( + self, + dim, + num_heads, + mlp_ratio=4.0, + qkv_bias=False, + mlp_drop_rate=0.0, + attn_drop_rate=0.0, + path_drop_rate=0.0, + act_layer=nn.GELU, + norm_layer=nn.LayerNorm, + comm_inp_name="fin", + comm_hidden_name="fout", + ): + super().__init__() + + if (comm.get_size(comm_inp_name) * comm.get_size(comm_hidden_name)) > 1: + self.attn = DistributedAttention( + dim, + input_format="traditional", + comm_inp_name=comm_inp_name, + comm_hidden_name=comm_hidden_name, + num_heads=num_heads, + qkv_bias=qkv_bias, + attn_drop_rate=attn_drop_rate, + proj_drop_rate=mlp_drop_rate, + norm_layer=norm_layer, + ) + else: + self.attn = Attention( + dim, input_format="traditional", num_heads=num_heads, qkv_bias=qkv_bias, attn_drop_rate=attn_drop_rate, proj_drop_rate=mlp_drop_rate, norm_layer=norm_layer + ) + self.drop_path = DropPath(path_drop_rate) if path_drop_rate > 0.0 else nn.Identity() + + self.norm1 = norm_layer(dim) + self.norm2 = norm_layer(dim) + + mlp_hidden_dim = int(dim * mlp_ratio) + + # distribute MLP for model parallelism + if (comm.get_size(comm_inp_name) * comm.get_size(comm_hidden_name)) > 1: + self.mlp = DistributedMLP( + in_features=dim, + hidden_features=mlp_hidden_dim, + out_features=dim, + act_layer=act_layer, + drop_rate=mlp_drop_rate, + input_format="traditional", + comm_inp_name=comm_inp_name, + comm_hidden_name=comm_hidden_name, + ) + else: + self.mlp = MLP(in_features=dim, hidden_features=mlp_hidden_dim, out_features=dim, act_layer=act_layer, drop_rate=mlp_drop_rate, input_format="traditional") + + def forward(self, x): + # flatten transpose: + y = self.attn(self.norm1(x)) + x = x + self.drop_path(y) + x = self.norm2(x) + x = x + self.drop_path(self.mlp(x)) + + return x + + +class VisionTransformer(nn.Module): + def __init__( + self, + inp_shape=[224, 224], + patch_size=(16, 16), + inp_chans=3, + out_chans=3, + embed_dim=768, + depth=12, + num_heads=12, + mlp_ratio=4.0, + qkv_bias=True, + mlp_drop_rate=0.0, + attn_drop_rate=0.0, + path_drop_rate=0.0, + norm_layer="layer_norm", + comm_inp_name="fin", + comm_hidden_name="fout", + **kwargs, + ): + super().__init__() + self.num_features = self.embed_dim = embed_dim + self.patch_size = patch_size + self.img_size = inp_shape + self.out_ch = out_chans + self.comm_inp_name = comm_inp_name + self.comm_hidden_name = comm_hidden_name + + 
self.patch_embed = PatchEmbed(img_size=self.img_size, patch_size=patch_size, in_chans=inp_chans, embed_dim=self.embed_dim) + num_patches = self.patch_embed.num_patches + + # annotate for distributed + self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, self.embed_dim)) + self.pos_embed.is_shared_mp = [] + + self.pos_drop = nn.Dropout(p=path_drop_rate) + + dpr = [x.item() for x in torch.linspace(0, path_drop_rate, depth)] # stochastic depth decay rule + + if norm_layer == "layer_norm": + norm_layer_handle = nn.LayerNorm + else: + raise NotImplementedError(f"Error, normalization layer type {norm_layer} not implemented for ViT.") + + self.blocks = nn.ModuleList( + [ + Block( + dim=embed_dim, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + mlp_drop_rate=mlp_drop_rate, + attn_drop_rate=attn_drop_rate, + path_drop_rate=dpr[i], + norm_layer=norm_layer_handle, + comm_inp_name=comm_inp_name, + comm_hidden_name=comm_hidden_name, + ) + for i in range(depth) + ] + ) + + self.norm = norm_layer_handle(embed_dim) + + self.out_size = self.out_ch * self.patch_size[0] * self.patch_size[1] + + self.head = nn.Linear(embed_dim, self.out_size, bias=False) + + nn.init.trunc_normal_(self.pos_embed, std=0.02) + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + nn.init.trunc_normal_(m.weight, std=0.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + def prepare_tokens(self, x): + B, C, H, W = x.shape + x = self.patch_embed(x).transpose(1, 2) # patch linear embedding + + # add positional encoding to each token + x = x + self.pos_embed + return self.pos_drop(x) + + def forward_head(self, x): + B, _, _ = x.shape # B x N x embed_dim + x = x.reshape(B, self.patch_embed.red_img_size[0], self.patch_embed.red_img_size[1], self.embed_dim) + B, h, w, _ = x.shape + + # apply head + x = self.head(x) + x = x.reshape(shape=(B, h, w, self.patch_size[0], self.patch_size[1], self.out_ch)) + x = torch.einsum("nhwpqc->nchpwq", x) + x = x.reshape(shape=(B, self.out_ch, self.img_size[0], self.img_size[1])) + + return x + + def forward(self, x): + x = self.prepare_tokens(x) + for blk in self.blocks: + x = blk(x) + + x = self.norm(x) + x = self.forward_head(x) + return x diff --git a/preprocessor.py b/preprocessor.py new file mode 100644 index 0000000..6689112 --- /dev/null +++ b/preprocessor.py @@ -0,0 +1,426 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
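+# Preprocessor2D bundles everything that happens to a batch around the network
+# itself: flattening/expanding the history dimension between (B, T*C, H, W) and
+# (B, T, C, H, W), appending static features (coordinate grid, orography, land
+# mask), optional history-based normalization, handling of unpredicted channels
+# such as the solar zenith angle, and adding back the residual when the network
+# is trained to predict increments (params.target == "residual").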
+
+import copy
+from functools import partial
+
+import numpy as np
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from makani.utils import comm
+from makani.utils.grids import GridConverter
+from modulus.distributed.mappings import reduce_from_parallel_region, copy_to_parallel_region
+
+
+class Preprocessor2D(nn.Module):
+    def __init__(self, params):
+        super(Preprocessor2D, self).__init__()
+
+        self.n_history = params.n_history
+        self.history_normalization_mode = params.history_normalization_mode
+        if self.history_normalization_mode == "exponential":
+            self.history_normalization_decay = params.history_normalization_decay
+            # inverse ordering, since first element is oldest
+            history_normalization_weights = torch.exp((-self.history_normalization_decay) * torch.arange(start=self.n_history, end=-1, step=-1, dtype=torch.float32))
+            history_normalization_weights = history_normalization_weights / torch.sum(history_normalization_weights)
+            history_normalization_weights = torch.reshape(history_normalization_weights, (1, -1, 1, 1, 1))
+        elif self.history_normalization_mode == "mean":
+            history_normalization_weights = torch.full((self.n_history + 1,), 1.0 / float(self.n_history + 1), dtype=torch.float32)
+            history_normalization_weights = torch.reshape(history_normalization_weights, (1, -1, 1, 1, 1))
+        else:
+            history_normalization_weights = torch.ones(self.n_history + 1, dtype=torch.float32)
+        self.register_buffer("history_normalization_weights", history_normalization_weights, persistent=False)
+        self.history_mean = None
+        self.history_std = None
+        self.history_diff_mean = None
+        self.history_diff_var = None
+        self.history_eps = 1e-6
+
+        # residual normalization
+        self.learn_residual = params.target == "residual"
+        if self.learn_residual and (params.normalize_residual):
+            with torch.no_grad():
+                residual_scale = torch.from_numpy(np.load(params.time_diff_stds_path)).to(torch.float32)
+                self.register_buffer("residual_scale", residual_scale, persistent=False)
+        else:
+            self.residual_scale = None
+
+        # image shape
+        self.img_shape = [params.img_shape_x, params.img_shape_y]
+
+        # unpredicted input channels:
+        self.unpredicted_inp_train = None
+        self.unpredicted_tar_train = None
+        self.unpredicted_inp_eval = None
+        self.unpredicted_tar_eval = None
+
+        # process static features
+        static_features = None
+        # needed for sharding
+        start_x = params.img_local_offset_x
+        end_x = min(start_x + params.img_local_shape_x, params.img_shape_x)
+        start_y = params.img_local_offset_y
+        end_y = min(start_y + params.img_local_shape_y, params.img_shape_y)
+
+        # set up grid
+        if params.add_grid:
+            with torch.no_grad():
+                if hasattr(params, "lat") and hasattr(params, "lon"):
+                    lat = torch.tensor(params.lat).to(torch.float32)
+                    lon = torch.tensor(params.lon).to(torch.float32)
+
+                    # convert grid if required
+                    gconv = GridConverter(params.data_grid_type, params.model_grid_type, torch.deg2rad(lat), torch.deg2rad(lon))
+                    tx, ty = gconv.get_dst_coords()
+                    tx = tx.to(torch.float32)
+                    ty = ty.to(torch.float32)
+                else:
+                    tx = torch.linspace(0, 1, params.img_shape_x + 1, dtype=torch.float32)[0:-1]
+                    ty = torch.linspace(0, 1, params.img_shape_y + 1, dtype=torch.float32)[0:-1]
+
+                x_grid, y_grid = torch.meshgrid(tx, ty, indexing="ij")
+                x_grid, y_grid = x_grid.unsqueeze(0).unsqueeze(0), y_grid.unsqueeze(0).unsqueeze(0)
+                grid = torch.cat([x_grid, y_grid], dim=1)
+
+                # shard spatially:
+                grid = grid[:, :, start_x:end_x, start_y:end_y]
+
+                # transform if requested
+                if params.gridtype == "sinusoidal":
+                    num_freq = 1
+                    if 
hasattr(params, "grid_num_frequencies"): + num_freq = int(params.grid_num_frequencies) + + singrid = None + for freq in range(1, num_freq + 1): + if singrid is None: + singrid = torch.sin(grid) + else: + singrid = torch.cat([singrid, torch.sin(freq * grid)], dim=1) + + static_features = singrid + else: + static_features = grid + + if params.add_orography: + from makani.utils.conditioning_inputs import get_orography + + with torch.no_grad(): + oro = torch.tensor(get_orography(params.orography_path), dtype=torch.float32) + oro = torch.reshape(oro, (1, 1, oro.shape[0], oro.shape[1])) + + # normalize + eps = 1.0e-6 + oro = (oro - torch.mean(oro)) / (torch.std(oro) + eps) + + # shard + oro = oro[:, :, start_x:end_x, start_y:end_y] + + if static_features is None: + static_features = oro + else: + static_features = torch.cat([static_features, oro], dim=1) + + if params.add_landmask: + from makani.utils.conditioning_inputs import get_land_mask + + with torch.no_grad(): + lsm = torch.tensor(get_land_mask(params.landmask_path), dtype=torch.long) + # one hot encode and move channels to front: + lsm = torch.permute(torch.nn.functional.one_hot(lsm), (2, 0, 1)).to(torch.float32) + lsm = torch.reshape(lsm, (1, lsm.shape[0], lsm.shape[1], lsm.shape[2])) + + # shard + lsm = lsm[:, :, start_x:end_x, start_y:end_y] + + if static_features is None: + static_features = lsm + else: + static_features = torch.cat([static_features, lsm], dim=1) + + self.do_add_static_features = False + if static_features is not None: + self.do_add_static_features = True + self.register_buffer("static_features", static_features, persistent=False) + + def flatten_history(self, x): + # flatten input + if x.dim() == 5: + b_, t_, c_, h_, w_ = x.shape + x = torch.reshape(x, (b_, t_ * c_, h_, w_)) + + return x + + def expand_history(self, x, nhist): + if x.dim() == 4: + b_, ct_, h_, w_ = x.shape + x = torch.reshape(x, (b_, nhist, ct_ // nhist, h_, w_)) + return x + + def add_residual(self, x, dx): + if self.learn_residual: + if self.residual_scale is not None: + dx = dx * self.residual_scale + + # add residual: deal with history + x = self.expand_history(x, nhist=self.n_history + 1) + x[:, -1, ...] = x[:, -1, ...] 
+ dx + x = self.flatten_history(x) + else: + x = dx + + return x + + def add_static_features(self, x): + if self.do_add_static_features: + # we need to replicate the grid for each batch: + static = torch.tile(self.static_features, dims=(x.shape[0], 1, 1, 1)) + x = torch.cat([x, static], dim=1) + + return x + + def remove_static_features(self, x): + # only remove if something was added in the first place + if self.do_add_static_features: + nfeat = self.static_features.shape[1] + x = x[:, : x.shape[1] - nfeat, :, :] + return x + + def append_history(self, x1, x2, step): + # take care of unpredicted features first + # this is necessary in order to copy the targets unpredicted features + # (such as zenith angle) into the inputs unpredicted features, + # such that they can be forward in the next autoregressive step + # extract utar + + # update the unpredicted input + if self.training: + if (self.unpredicted_tar_train is not None) and (step < self.unpredicted_tar_train.shape[1]): + utar = self.unpredicted_tar_train[:, step : (step + 1), :, :, :] + if self.n_history == 0: + self.unpredicted_inp_train.copy_(utar) + else: + self.unpredicted_inp_train.copy_(torch.cat([self.unpredicted_inp_train[:, 1:, :, :, :], utar], dim=1)) + else: + if (self.unpredicted_tar_eval is not None) and (step < self.unpredicted_tar_eval.shape[1]): + utar = self.unpredicted_tar_eval[:, step : (step + 1), :, :, :] + if self.n_history == 0: + self.unpredicted_inp_eval.copy_(utar) + else: + self.unpredicted_inp_eval.copy_(torch.cat([self.unpredicted_inp_eval[:, 1:, :, :, :], utar], dim=1)) + + if self.n_history > 0: + # this is more complicated + x1 = self.expand_history(x1, nhist=self.n_history + 1) + x2 = self.expand_history(x2, nhist=1) + + # append + res = torch.cat([x1[:, 1:, :, :, :], x2], dim=1) + + # flatten again + res = self.flatten_history(res) + else: + res = x2 + + return res + + def append_channels(self, x, xc): + xdim = x.dim() + x = self.expand_history(x, self.n_history + 1) + + xc = self.expand_history(xc, self.n_history + 1) + + # concatenate + xo = torch.cat([x, xc], dim=2) + + # flatten if requested + if xdim == 4: + xo = self.flatten_history(xo) + + return xo + + def history_compute_stats(self, x): + if self.history_normalization_mode == "none": + self.history_mean = torch.zeros((1, 1, 1, 1), dtype=torch.float32, device=x.device) + self.history_std = torch.ones((1, 1, 1, 1), dtype=torch.float32, device=x.device) + elif self.history_normalization_mode == "timediff": + # reshaping + xdim = x.dim() + if xdim == 4: + b_, c_, h_, w_ = x.shape + xr = torch.reshape(x, (b_, (self.n_history + 1), c_ // (self.n_history + 1), h_, w_)) + else: + xshape = x.shape + xr = x + + # time difference mean: + self.history_diff_mean = torch.mean(torch.sum(xr[:, 1:, ...] - xr[:, 0:-1, ...], dim=(4, 5)), dim=(1, 2)) + # reduce across gpus + if comm.get_size("spatial") > 1: + self.history_diff_mean = reduce_from_parallel_region(self.history_diff_mean, "spatial") + self.history_diff_mean = self.history_diff_mean / float(self.img_shape[0] * self.img_shape[1]) + + # time difference std + self.history_diff_var = torch.mean(torch.sum(torch.square((xr[:, 1:, ...] 
- xr[:, 0:-1, ...]) - self.history_diff_mean), dim=(4, 5)), dim=(1, 2)) + # reduce across gpus + if comm.get_size("spatial") > 1: + self.history_diff_var = reduce_from_parallel_region(self.history_diff_var, "spatial") + self.history_diff_var = self.history_diff_var / float(self.img_shape[0] * self.img_shape[1]) + + # time difference stds + self.history_diff_mean = copy_to_parallel_region(self.history_diff_mean, "spatial") + self.history_diff_var = copy_to_parallel_region(self.history_diff_var, "spatial") + else: + xdim = x.dim() + if xdim == 4: + b_, c_, h_, w_ = x.shape + xr = torch.reshape(x, (b_, (self.n_history + 1), c_ // (self.n_history + 1), h_, w_)) + else: + xshape = x.shape + xr = x + + # mean + # compute weighted mean over dim 1, but sum over dim=3,4 + self.history_mean = torch.sum(xr * self.history_normalization_weights, dim=(1, 3, 4), keepdim=True) + # reduce across gpus + if comm.get_size("spatial") > 1: + self.history_mean = reduce_from_parallel_region(self.history_mean, "spatial") + self.history_mean = self.history_mean / float(self.img_shape[0] * self.img_shape[1]) + + # compute std + self.history_std = torch.sum(torch.square(xr - self.history_mean) * self.history_normalization_weights, dim=(1, 3, 4), keepdim=True) + # reduce across gpus + if comm.get_size("spatial") > 1: + self.history_std = reduce_from_parallel_region(self.history_std, "spatial") + self.history_std = torch.sqrt(self.history_std / float(self.img_shape[0] * self.img_shape[1])) + + # squeeze + self.history_mean = torch.squeeze(self.history_mean, dim=1) + self.history_std = torch.squeeze(self.history_std, dim=1) + + # copy to parallel region + self.history_mean = copy_to_parallel_region(self.history_mean, "spatial") + self.history_std = copy_to_parallel_region(self.history_std, "spatial") + + return + + def history_normalize(self, x, target=False): + if self.history_normalization_mode in ["none", "timediff"]: + return x + + xdim = x.dim() + if xdim == 4: + b_, c_, h_, w_ = x.shape + xr = torch.reshape(x, (b_, (self.n_history + 1), c_ // (self.n_history + 1), h_, w_)) + else: + xshape = x.shape + xr = x + x = self.flatten_history(x) + + # normalize + if target: + # strip off the unpredicted channels + xn = (x - self.history_mean[:, : x.shape[1], :, :]) / self.history_std[:, : x.shape[1], :, :] + else: + # tile to include history + hm = torch.tile(self.history_mean, (1, self.n_history + 1, 1, 1)) + hs = torch.tile(self.history_std, (1, self.n_history + 1, 1, 1)) + xn = (x - hm) / hs + + if xdim == 5: + xn = torch.reshape(xn, xshape) + + return xn + + def history_denormalize(self, xn, target=False): + if self.history_normalization_mode in ["none", "timediff"]: + return xn + + assert self.history_mean is not None + assert self.history_std is not None + + xndim = xn.dim() + if xndim == 5: + xnshape = xn.shape + xn = self.flatten_history(xn) + + # de-normalize + if target: + # strip off the unpredicted channels + x = xn * self.history_std[:, : xn.shape[1], :, :] + self.history_mean[:, : xn.shape[1], :, :] + else: + # tile to include history + hm = torch.tile(self.history_mean, (1, self.n_history + 1, 1, 1)) + hs = torch.tile(self.history_std, (1, self.n_history + 1, 1, 1)) + x = xn * hs + hm + + if xndim == 5: + x = torch.reshape(x, xnshape) + + return x + + def cache_unpredicted_features(self, x, y, xz=None, yz=None): + if self.training: + if (self.unpredicted_inp_train is not None) and (xz is not None): + self.unpredicted_inp_train.copy_(xz) + else: + self.unpredicted_inp_train = xz + + if 
(self.unpredicted_tar_train is not None) and (yz is not None): + self.unpredicted_tar_train.copy_(yz) + else: + self.unpredicted_tar_train = yz + else: + if (self.unpredicted_inp_eval is not None) and (xz is not None): + self.unpredicted_inp_eval.copy_(xz) + else: + self.unpredicted_inp_eval = xz + + if (self.unpredicted_tar_eval is not None) and (yz is not None): + self.unpredicted_tar_eval.copy_(yz) + else: + self.unpredicted_tar_eval = yz + + return x, y + + def append_unpredicted_features(self, inp): + if self.training: + if self.unpredicted_inp_train is not None: + inp = self.append_channels(inp, self.unpredicted_inp_train) + else: + if self.unpredicted_inp_eval is not None: + inp = self.append_channels(inp, self.unpredicted_inp_eval) + return inp + + def remove_unpredicted_features(self, inp): + if self.training: + if self.unpredicted_inp_train is not None: + inpf = self.expand_history(inp, nhist=self.n_history + 1) + inpc = inpf[:, :, : inpf.shape[2] - self.unpredicted_inp_train.shape[2], :, :] + inp = self.flatten_history(inpc) + else: + if self.unpredicted_inp_eval is not None: + inpf = self.expand_history(inp, nhist=self.n_history + 1) + inpc = inpf[:, :, : inpf.shape[2] - self.unpredicted_inp_eval.shape[2], :, :] + inp = self.flatten_history(inpc) + + return inp + + +def get_preprocessor(params): + return Preprocessor2D(params) diff --git a/stepper.py b/stepper.py new file mode 100644 index 0000000..26cfbd7 --- /dev/null +++ b/stepper.py @@ -0,0 +1,128 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
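+# The wrappers below pair a model with Preprocessor2D so that history
+# normalization, static features, unpredicted channels and residual handling are
+# applied consistently around every forward pass. SingleStepWrapper performs one
+# prediction step; MultiStepWrapper unrolls params.n_future + 1 autoregressive
+# steps during training (concatenating the predictions along the channel
+# dimension) and falls back to a single step in eval mode.
+#
+# Minimal usage sketch; the model handle must be a zero-argument callable, and
+# `params`, `SomeModelClass` and `model_kwargs` are placeholders:
+#
+#   from functools import partial
+#   model_handle = partial(SomeModelClass, **model_kwargs)
+#   stepper = SingleStepWrapper(params, model_handle)
+#   y = stepper(x)  # x: (batch, (n_history + 1) * channels, lat, lon)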
+
+import torch
+import torch.nn as nn
+
+from makani.models.preprocessor import Preprocessor2D
+
+class SingleStepWrapper(nn.Module):
+    def __init__(self, params, model_handle):
+        super(SingleStepWrapper, self).__init__()
+        self.preprocessor = Preprocessor2D(params)
+        self.model = model_handle()
+
+    def forward(self, inp):
+        # first append unpredicted features
+        inpa = self.preprocessor.append_unpredicted_features(inp)
+
+        # now normalize
+        self.preprocessor.history_compute_stats(inpa)
+        inpan = self.preprocessor.history_normalize(inpa, target=False)
+
+        # now add static features if requested
+        inpans = self.preprocessor.add_static_features(inpan)
+
+        # forward pass
+        yn = self.model(inpans)
+
+        # undo normalization
+        y = self.preprocessor.history_denormalize(yn, target=True)
+
+        # add residual (for residual learning, no-op for direct learning)
+        y = self.preprocessor.add_residual(inp, y)
+
+        return y
+
+
+class MultiStepWrapper(nn.Module):
+    def __init__(self, params, model_handle):
+        super(MultiStepWrapper, self).__init__()
+        self.preprocessor = Preprocessor2D(params)
+        self.model = model_handle()
+        self.residual_mode = (params.target == "target")
+
+        # collect parameters for history
+        self.n_future = params.n_future
+
+    def _forward_train(self, inp):
+        result = []
+        inpt = inp
+        for step in range(self.n_future + 1):
+            # add unpredicted features
+            inpa = self.preprocessor.append_unpredicted_features(inpt)
+
+            # do history normalization
+            self.preprocessor.history_compute_stats(inpa)
+            inpan = self.preprocessor.history_normalize(inpa, target=False)
+
+            # add static features
+            inpans = self.preprocessor.add_static_features(inpan)
+
+            # prediction
+            predn = self.model(inpans)
+
+            # denormalize the prediction and append it to the output list here,
+            # before the normalization stats are updated in the next iteration
+            pred = self.preprocessor.history_denormalize(predn, target=True)
+            # add residual (for residual learning, no-op for direct learning)
+            pred = self.preprocessor.add_residual(inpt, pred)
+            # append output
+            result.append(pred)
+
+            if step == self.n_future:
+                break
+
+            # append history
+            inpt = self.preprocessor.append_history(inpt, pred, step)
+
+        # concat the tensors along channel dim to be compatible with flattened target
+        result = torch.cat(result, dim=1)
+
+        return result
+
+    def _forward_eval(self, inp):
+        # first append unpredicted features
+        inpa = self.preprocessor.append_unpredicted_features(inp)
+
+        # do history normalization
+        self.preprocessor.history_compute_stats(inpa)
+        inpan = self.preprocessor.history_normalize(inpa, target=False)
+
+        # add static features
+        inpans = self.preprocessor.add_static_features(inpan)
+
+        # forward pass
+        yn = self.model(inpans)
+
+        # undo normalization here, while the stats still match this prediction
+        y = self.preprocessor.history_denormalize(yn, target=True)
+
+        # add residual (for residual learning, no-op for direct learning)
+        y = self.preprocessor.add_residual(inp, y)
+
+        return y
+
+    def forward(self, inp):
+        # decide which routine to call
+        if self.training:
+            y = self._forward_train(inp)
+        else:
+            y = self._forward_eval(inp)
+
+        return y
\ No newline at end of file