From ae4d0d2b8ca14fb79bce9c3c8b1f29180671bd3c Mon Sep 17 00:00:00 2001
From: nnguyen19
Date: Tue, 14 May 2024 03:54:25 -0500
Subject: [PATCH] fix module issues

---
 ProteinReDiff/__pycache__/data.cpython-39.pyc |  Bin 9531 -> 9481 bytes
 .../__pycache__/difffusion.cpython-39.pyc     |  Bin 946 -> 962 bytes
 .../__pycache__/features.cpython-39.pyc       |  Bin 2294 -> 2292 bytes
 .../__pycache__/mask_utils.cpython-39.pyc     |  Bin 3472 -> 3488 bytes
 .../__pycache__/model.cpython-39.pyc          |  Bin 15623 -> 15621 bytes
 .../__pycache__/modules.cpython-39.pyc        |  Bin 12676 -> 13413 bytes
 ProteinReDiff/__pycache__/mol.cpython-39.pyc  |  Bin 2551 -> 2549 bytes
 .../__pycache__/protein.cpython-39.pyc        |  Bin 6418 -> 6416 bytes
 .../__pycache__/tmalign.cpython-39.pyc        |  Bin 1536 -> 1861 bytes
 .../__pycache__/utils.cpython-39.pyc          |  Bin 1814 -> 2551 bytes
 ProteinReDiff/data.py                         |   12 +-
 .../__pycache__/AF2_modules.cpython-39.pyc    |  Bin 17178 -> 15342 bytes
 README.md                                     |    4 +-
 generate.py                                   |   59 ++++++-
 scripts/__pycache__/__init__.cpython-39.pyc   |  Bin 146 -> 155 bytes
 .../predict_batch_seq_msk_inp.cpython-39.pyc  |  Bin 0 -> 8287 bytes
 scripts/predict_batch_seq_msk_inp.py          |  161 ++++++++++--------
 scripts/predict_batch_strc_msk_inp.py         |   44 +----
 train.py                                      |   14 +-
 train_from_ckpt.py                            |   13 +-
 20 files changed, 168 insertions(+), 139 deletions(-)
 create mode 100644 scripts/__pycache__/predict_batch_seq_msk_inp.cpython-39.pyc

diff --git a/ProteinReDiff/__pycache__/data.cpython-39.pyc b/ProteinReDiff/__pycache__/data.cpython-39.pyc
Binary files a/ProteinReDiff/__pycache__/data.cpython-39.pyc and b/ProteinReDiff/__pycache__/data.cpython-39.pyc differ
diff --git a/ProteinReDiff/__pycache__/difffusion.cpython-39.pyc b/ProteinReDiff/__pycache__/difffusion.cpython-39.pyc
Binary files a/ProteinReDiff/__pycache__/difffusion.cpython-39.pyc and b/ProteinReDiff/__pycache__/difffusion.cpython-39.pyc differ
diff --git a/ProteinReDiff/__pycache__/features.cpython-39.pyc b/ProteinReDiff/__pycache__/features.cpython-39.pyc
Binary files a/ProteinReDiff/__pycache__/features.cpython-39.pyc and b/ProteinReDiff/__pycache__/features.cpython-39.pyc differ
diff --git a/ProteinReDiff/__pycache__/mask_utils.cpython-39.pyc b/ProteinReDiff/__pycache__/mask_utils.cpython-39.pyc
Binary files a/ProteinReDiff/__pycache__/mask_utils.cpython-39.pyc and b/ProteinReDiff/__pycache__/mask_utils.cpython-39.pyc differ
diff --git a/ProteinReDiff/__pycache__/model.cpython-39.pyc b/ProteinReDiff/__pycache__/model.cpython-39.pyc
Binary files a/ProteinReDiff/__pycache__/model.cpython-39.pyc and b/ProteinReDiff/__pycache__/model.cpython-39.pyc differ
diff --git a/ProteinReDiff/__pycache__/modules.cpython-39.pyc b/ProteinReDiff/__pycache__/modules.cpython-39.pyc
Binary files a/ProteinReDiff/__pycache__/modules.cpython-39.pyc and b/ProteinReDiff/__pycache__/modules.cpython-39.pyc differ
diff --git a/ProteinReDiff/__pycache__/mol.cpython-39.pyc b/ProteinReDiff/__pycache__/mol.cpython-39.pyc
Binary files a/ProteinReDiff/__pycache__/mol.cpython-39.pyc and b/ProteinReDiff/__pycache__/mol.cpython-39.pyc differ
diff --git a/ProteinReDiff/__pycache__/protein.cpython-39.pyc b/ProteinReDiff/__pycache__/protein.cpython-39.pyc
Binary files a/ProteinReDiff/__pycache__/protein.cpython-39.pyc and b/ProteinReDiff/__pycache__/protein.cpython-39.pyc differ
diff --git a/ProteinReDiff/__pycache__/tmalign.cpython-39.pyc b/ProteinReDiff/__pycache__/tmalign.cpython-39.pyc
Binary files a/ProteinReDiff/__pycache__/tmalign.cpython-39.pyc and b/ProteinReDiff/__pycache__/tmalign.cpython-39.pyc differ
diff --git a/ProteinReDiff/__pycache__/utils.cpython-39.pyc b/ProteinReDiff/__pycache__/utils.cpython-39.pyc
Binary files a/ProteinReDiff/__pycache__/utils.cpython-39.pyc and b/ProteinReDiff/__pycache__/utils.cpython-39.pyc differ
diff --git a/ProteinReDiff/data.py b/ProteinReDiff/data.py
--- a/ProteinReDiff/data.py
+++ b/ProteinReDiff/data.py
@@ ... @@ def __getitem__(self, index: int) -> Mapping[str, Any]:
         return self.data[index]
 
 
-class PDBbindDataset(Dataset):
+class PDBDataset(Dataset):
     def __init__(self, root_dir: Union[str, Path], pdb_ids: Sequence[str]):
         super().__init__()
         if isinstance(root_dir, str):
@@ -203,7 +203,7 @@ def get_stream(self):
     def __iter__(self):
         return self.get_stream()
 
-class PDBbindDataModule(pl.LightningDataModule):
+class PDBDataModule(pl.LightningDataModule):
     def __init__(
         self,
         data_dir: Union[str, Path] = "data",
@@ -214,7 +214,7 @@ def __init__(
         if isinstance(data_dir, str):
             data_dir = Path(data_dir)
         self.data_dir = data_dir
-        self.cache_dir = data_dir / "PDBBind_processed_cache"
+        self.cache_dir = data_dir / "PDB_processed_cache"
         self.batch_size = batch_size
         self.num_workers = num_workers
@@ -231,7 +231,7 @@ def setup(self, stage: Optional[str] = None) -> None:
 
     def train_dataloader(self) -> DataLoader:
         return DataLoader(
-            PDBbindDataset(self.cache_dir, self.train_pdb_ids),
+            PDBDataset(self.cache_dir, self.train_pdb_ids),
            batch_size=self.batch_size,
             shuffle=True,
             num_workers=self.num_workers,
@@ -242,7 +242,7 @@ def train_dataloader(self) -> DataLoader:
 
     def val_dataloader(self) -> DataLoader:
         return DataLoader(
-            PDBbindDataset(self.cache_dir, self.val_pdb_ids),
+            PDBDataset(self.cache_dir, self.val_pdb_ids),
             batch_size=self.batch_size,
             num_workers=self.num_workers,
             collate_fn=collate_fn,
@@ -251,7 +251,7 @@ def test_dataloader(self) -> DataLoader:
         return DataLoader(
-            PDBbindDataset(self.cache_dir, self.test_pdb_ids),
+            PDBDataset(self.cache_dir, self.test_pdb_ids),
             batch_size=self.batch_size,
             num_workers=self.num_workers,
             collate_fn=collate_fn,
diff --git a/ProteinReDiff/models/__pycache__/AF2_modules.cpython-39.pyc b/ProteinReDiff/models/__pycache__/AF2_modules.cpython-39.pyc
Binary files a/ProteinReDiff/models/__pycache__/AF2_modules.cpython-39.pyc and b/ProteinReDiff/models/__pycache__/AF2_modules.cpython-39.pyc differ
diff --git a/README.md b/README.md
index 442d158..ca0d0d5 100644
--- a/README.md
+++ b/README.md
@@ -121,7 +121,7 @@ python -m scripts.predict_batch_seq_msk_inp \
 
 Download the PDBbind dataset from https://zenodo.org/record/6408497 and unzip it.
 
-Move the resulting PDBBind_processed directory to data/.
+Move the resulting PDBBind_processed directory to ./data/.
 
 Preprocess the dataset:
 ```bash
@@ -140,7 +140,7 @@ python train.py \
     --num_blocks 4
 ```
 
-Please modify the batch_size and accumulate_grad_batches arguments according to your machine(s).
+Please modify the batch_size, gpus, and accumulate_grad_batches arguments according to your machine(s).
 Also, use the flag data_dir for the directory containing your weights (i.e. "./data") and save_dir for the directory in which to save training log files.
 Default values can be used to reproduce the settings used in our paper:
diff --git a/generate.py b/generate.py
index c2cffda..a0534f3 100644
--- a/generate.py
+++ b/generate.py
@@ -1,3 +1,13 @@
+"""
+Adapted from Nakata, S., Mori, Y. & Tanaka, S.
+End-to-end protein–ligand complex structure generation with diffusion-based generative models.
+BMC Bioinformatics 24, 233 (2023).
+https://doi.org/10.1186/s12859-023-05354-5 + +Repository: https://github.com/shuyana/DiffusionProteinLigand + +""" + import dataclasses import itertools import warnings @@ -18,6 +28,7 @@ from ProteinReDiff.mol import get_mol_positions, mol_from_file, update_mol_positions from ProteinReDiff.protein import ( RESIDUE_TYPES, + RESIDUE_TYPE_INDEX, Protein, protein_from_pdb_file, protein_from_sequence, @@ -25,7 +36,7 @@ ) from ProteinReDiff.tmalign import run_tmalign - +RESIDUE_TYPES_MASK = RESIDUE_TYPES + [""] def compute_residue_esm(protein: Protein) -> torch.Tensor: esm_model, esm_alphabet = torch.hub.load( "facebookresearch/esm:main", "esm2_t33_650M_UR50D" @@ -36,7 +47,7 @@ def compute_residue_esm(protein: Protein) -> torch.Tensor: data = [] for chain, _ in itertools.groupby(protein.chain_index): sequence = "".join( - [RESIDUE_TYPES[aa] for aa in protein.aatype[protein.chain_index == chain]] + [RESIDUE_TYPES_MASK[aa] for aa in protein.aatype[protein.chain_index == chain]] ) data.append(("", sequence)) batch_tokens = esm_batch_converter(data)[2].cuda() @@ -45,7 +56,7 @@ def compute_residue_esm(protein: Protein) -> torch.Tensor: token_representations = results["representations"][esm_model.num_layers].cpu() residue_representations = [] for i, (_, sequence) in enumerate(data): - residue_representations.append(token_representations[i, 1 : len(sequence) + 1]) + residue_representations.append(token_representations[i, 1 : len(protein.aatype) + 1]) residue_esm = torch.cat(residue_representations, dim=0) assert residue_esm.size(0) == len(protein.aatype) return residue_esm @@ -62,6 +73,23 @@ def update_pos( ligand = update_mol_positions(ligand, pos[: ligand.GetNumAtoms()]) return protein, ligand +def predict_seq( + proba: torch.Tensor +) -> list : + tokens = torch.argmax(torch.softmax((torch.tensor(proba)), dim = -1), dim = -1) + RESIDUE_TYPES_NEW = ["X"] + RESIDUE_TYPES + return list(map(lambda i : RESIDUE_TYPES_NEW[i], tokens)) + +def update_seq( + protein: Protein, proba: torch.Tensor +) -> Protein: + tokens = torch.argmax(torch.softmax((torch.tensor(proba)), dim = -1), dim = -1) + RESIDUE_TYPES_NEW = ["X"] + RESIDUE_TYPES + sequence = "".join(map(lambda i : RESIDUE_TYPES_NEW[i], tokens)).lstrip("X").rstrip("X") + aatype = np.array([RESIDUE_TYPES.index(s) for s in sequence], dtype=np.int64) + protein = dataclasses.replace(protein, aatype = aatype) + return protein + def main(args): pl.seed_everything(args.seed, workers=True) @@ -75,24 +103,30 @@ def main(args): model = ProteinReDiffModel.load_from_checkpoint( args.ckpt_path, num_steps=args.num_steps ) + + model.training_mode = False + args.num_gpus = 1 + model.mask_prob = args.mask_prob + # Inputs if args.protein.endswith(".pdb"): protein = protein_from_pdb_file(args.protein) else: + protein = protein_from_sequence(args.protein) if args.ligand.endswith(".sdf") or args.ligand.endswith(".mol2"): ligand = mol_from_file(args.ligand) else: ligand = Chem.MolFromSmiles(args.ligand) - ligand = update_mol_positions(ligand, np.zeros((ligand.GetNumAtoms(), 3))) + ligand = update_mol_positions(ligand, np.zeros((ligand.GetNumAtoms(), 3))) total_num_atoms = len(protein.aatype) + ligand.GetNumAtoms() print(f"Total number of atoms: {total_num_atoms}") if total_num_atoms > 384: warnings.warn("Too many atoms. 
May take a long time for sample generation.")
-
+
     data = {
         **ligand_to_data(ligand),
         **protein_to_data(protein, residue_esm=compute_residue_esm(protein)),
@@ -104,10 +138,11 @@ def main(args):
     trainer = pl.Trainer.from_argparse_args(
         args,
         accelerator="auto",
+        gpus = args.num_gpus,
         default_root_dir=args.output_dir,
         max_epochs=-1,
     )
-    positions = trainer.predict(
+    results = trainer.predict( ## (NN)
         model,
         dataloaders=DataLoader(
             RepeatDataset(data, args.num_samples),
@@ -116,13 +151,20 @@
             collate_fn=collate_fn,
         ),
     )
+
+    positions = [p[0] for p in results]
+    sequences = [s[1] for s in results]
+
     positions = torch.cat(positions, dim=0).detach().cpu().numpy()
+    probabilities = torch.cat(sequences, dim=0).detach().cpu().numpy()
+    #torch.save(probabilities, "sampled_seq_gvp.pt") # can save embedding
 
     # Save samples
     sample_proteins, sample_ligands = [], []
     tmscores = []
-    for pos in positions:
+    for pos, seq_prob in zip(positions, probabilities):
         sample_protein, sample_ligand = update_pos(protein, ligand, pos)
+        sample_protein = update_seq(sample_protein, seq_prob)
         if ref_protein is None:
             warnings.warn(
                 "Using the first sample as a reference. The resulting structures may be mirror images."
             )
@@ -155,10 +197,13 @@ def main(args):
 
 if __name__ == "__main__":
     parser = ArgumentParser()
+
     parser.add_argument("--seed", type=int, default=1234)
     parser.add_argument("--batch_size", type=int, default=1)
     parser.add_argument("--num_workers", type=int, default=2)
     parser.add_argument("--num_steps", type=int, default=64)
+    parser.add_argument("--mask_prob", type=float, default=0.3)
+    parser.add_argument("--training_mode", action="store_true")
     parser.add_argument("-c", "--ckpt_path", type=Path, required=True)
     parser.add_argument("-o", "--output_dir", type=Path, required=True)
     parser.add_argument("-p", "--protein", type=str, required=True)
diff --git a/scripts/__pycache__/__init__.cpython-39.pyc b/scripts/__pycache__/__init__.cpython-39.pyc
Binary files a/scripts/__pycache__/__init__.cpython-39.pyc and b/scripts/__pycache__/__init__.cpython-39.pyc differ
diff --git a/scripts/__pycache__/predict_batch_seq_msk_inp.cpython-39.pyc b/scripts/__pycache__/predict_batch_seq_msk_inp.cpython-39.pyc
new file mode 100644
Binary files /dev/null and b/scripts/__pycache__/predict_batch_seq_msk_inp.cpython-39.pyc differ
diff --git a/scripts/predict_batch_seq_msk_inp.py b/scripts/predict_batch_seq_msk_inp.py
index f443d06..e0a0551 100644
--- a/scripts/predict_batch_seq_msk_inp.py
+++ b/scripts/predict_batch_seq_msk_inp.py
@@ -5,16 +5,17 @@
 from argparse import ArgumentParser
 from operator import itemgetter
 from pathlib import Path
-from typing import Iterable, List, Union, Tuple
+from typing import Iterable, List, Union, Tuple, Any
 
 import numpy as np
+import random
 import pytorch_lightning as pl
 import torch
 from rdkit import Chem
 from torch.utils.data import DataLoader
 
 from ProteinReDiff.data import InferenceDataset, collate_fn, ligand_to_data, protein_to_data
-from ProteinReDiff.model import ProteinReDiffModel ## (NN)
+from ProteinReDiff.model import ProteinReDiffModel
 from ProteinReDiff.mol import get_mol_positions, mol_from_file, update_mol_positions
 from ProteinReDiff.protein import (
     RESIDUE_TYPES,
@@ -25,14 +26,36 @@
     proteins_to_pdb_file,
 )
 from ProteinReDiff.tmalign import run_tmalign
+
+torch.multiprocessing.set_start_method('fork')
 
 RESIDUE_TYPES_MASK = RESIDUE_TYPES + [""]
 
-def compute_residue_esm(protein: Protein) -> torch.Tensor:
-    esm_model, esm_alphabet = torch.hub.load(
-        "facebookresearch/esm:main", "esm2_t33_650M_UR50D"
-    )
-    esm_model.cuda().eval()
-    esm_batch_converter = esm_alphabet.get_batch_converter()
+
+
+
+esm_model = None
+esm_batch_converter
= None + +def load_esm_model(accelerator): + global esm_model, esm_batch_converter + if esm_model is None or esm_batch_converter is None: + esm_model, esm_alphabet = torch.hub.load( + "facebookresearch/esm:main", "esm2_t33_650M_UR50D" + ) + + # esm_model.cuda().eval() + if accelerator == "gpu": + esm_model.cuda().eval() + else: + esm_model.eval() + esm_batch_converter = esm_alphabet.get_batch_converter() + + + +def compute_residue_esm(protein: Protein, accelerator: str) -> torch.Tensor: + + global esm_model, esm_batch_converter + load_esm_model(accelerator) data = [] for chain, _ in itertools.groupby(protein.chain_index): @@ -40,7 +63,11 @@ def compute_residue_esm(protein: Protein) -> torch.Tensor: [RESIDUE_TYPES_MASK[aa] for aa in protein.aatype[protein.chain_index == chain]] ) data.append(("", sequence)) - batch_tokens = esm_batch_converter(data)[2].cuda() + # batch_tokens = esm_batch_converter(data)[2].cuda() + if accelerator == "gpu": + batch_tokens = esm_batch_converter(data)[2].cuda() + else: + batch_tokens = esm_batch_converter(data)[2] with torch.inference_mode(): results = esm_model(batch_tokens, repr_layers=[esm_model.num_layers]) token_representations = results["representations"][esm_model.num_layers].cpu() @@ -66,6 +93,24 @@ def proteins_from_fasta(fasta_file: Union[str, Path]): return proteins, names +def proteins_from_fasta_with_mask(fasta_file: Union[str, Path], mask_percent: float = 0.0): + names = [] + proteins = [] + sequences = [] + with open(fasta_file, "r") as f: + for line in f: + if line.startswith(">"): + name = line.lstrip(">").rstrip("\n").replace(" ","_") + names.append(name) + elif not line in ['\n', '\r\n']: + sequence = line.rstrip("\n") + sequence = mask_sequence_by_percent(sequence, mask_percent) + protein = protein_from_sequence(sequence) + proteins.append(protein) + sequences.append(sequence) + + return proteins, names, sequences + def parse_ligands(ligand_input: Union[str, Path, list]): ligands = [] if isinstance(ligand_input, list): @@ -110,9 +155,17 @@ def update_seq( protein = dataclasses.replace(protein, aatype = aatype) return protein +def mask_sequence_by_percent(seq, percentage=0.2): + aa_to_replace = random.sample(range(len(seq)), int(len(seq)*percentage)) + + output_aa = [char if idx not in aa_to_replace else 'X' for idx, char in enumerate(seq)] + masked_seq = ''.join(output_aa) + + return masked_seq def main(args): - pl.seed_everything(args.seed, workers=True) + pl.seed_everything(np.random.randint(999999), workers=True) + # Check if the directory exists if os.path.exists(args.output_dir): # Remove the existing directory @@ -123,45 +176,45 @@ def main(args): model = ProteinReDiffModel.load_from_checkpoint( args.ckpt_path, num_steps=args.num_steps ) - ## (NN) model.training_mode = False - args.num_gpus = 1 model.mask_prob = args.mask_prob ## (NN) # Inputs - proteins, names = proteins_from_fasta(args.fasta) - + proteins, names, masked_sequences = proteins_from_fasta_with_mask(args.fasta, args.mask_prob) + + with open(args.output_dir / "masked_sequences.fasta", "w") as f: + for i, (name, seq) in enumerate(zip(names, masked_sequences)): + f.write(">{}_sample_{}\n".format(name,i%args.num_samples)) + f.write("{}\n".format(seq)) + if args.ligand_file is None: - ligand_input = ["*"]*args.num_samples*len(names) + ligand_input = ["*"]*len(names) ligands = parse_ligands(ligand_input) else: ligands = parse_ligands(args.ligand_file) - # total_num_atoms = len(protein.aatype) + ligand.GetNumAtoms() - # print(f"Total number of atoms: {total_num_atoms}") - # if 
total_num_atoms > 400: - # warnings.warn("Too many atoms (> 400). May take a long time for sample generation.") datas = [] - for protein, ligand in zip(proteins, ligands): + for name, protein, ligand in zip(names,proteins, ligands): data = { **ligand_to_data(ligand), - **protein_to_data(protein, residue_esm=compute_residue_esm(protein)), + **protein_to_data(protein, residue_esm=compute_residue_esm(protein, args.accelerator)), } datas.extend([data]*args.num_samples) - # Generate samples - trainer = pl.Trainer.from_argparse_args( - args, - accelerator="auto", - gpus = args.num_gpus, - default_root_dir=args.output_dir, - max_epochs=-1, - ) + + trainer = pl.Trainer( + accelerator=args.accelerator, + devices=args.num_gpus, + default_root_dir=args.output_dir, + max_epochs=-1, + strategy='ddp' + + ) results = trainer.predict( model, dataloaders=DataLoader( @@ -169,15 +222,14 @@ def main(args): batch_size=args.batch_size, num_workers=args.num_workers, collate_fn=collate_fn, - ), + ) ) - torch.save(results,"results.pt") - positions = [p[0] for p in results] ## (NN) - probabilities = [s[1] for s in results] ## (NN) + + + + probabilities = [s[1] for s in results] - # positions = torch.cat(positions, dim=0).detach().cpu().numpy() - # probabilities = torch.cat(probabilities, dim=0).detach().cpu().numpy() #(NN) names = [n for n in names for _ in range(args.num_samples)] @@ -187,47 +239,16 @@ def main(args): sequence = predict_seq(seq_prob.squeeze()) f.write("{}\n".format(sequence)) - # Save samples - # sample_proteins, sample_ligands = [], [] - # tmscores = [] - # for pos, seq_prob in zip(positions, probabilities): - # sample_protein, sample_ligand = update_pos(protein, ligand, pos) - # sample_protein = update_seq(sample_protein, seq_prob) - # if ref_protein is None: - # warnings.warn( - # "Using the first sample as a reference. The resulting structures may be mirror images." 
- # ) - # ref_protein = sample_protein - # tmscore, t, R = max( - # run_tmalign(sample_protein, ref_protein), - # run_tmalign(sample_protein, ref_protein, mirror=True), - # key=itemgetter(0), - # ) - # sample_proteins.append( - # dataclasses.replace( - # sample_protein, atom_pos=t + sample_protein.atom_pos @ R - # ) - # ) - # sample_ligands.append( - # update_mol_positions( - # sample_ligand, t + get_mol_positions(sample_ligand) @ R - # ) - # ) - # tmscores.append(tmscore) - # proteins_to_pdb_file(sample_proteins, args.output_dir / "sample_protein.pdb") - # with Chem.SDWriter(str(args.output_dir / "sample_ligand.sdf")) as w: - # for sample_ligand in sample_ligands: - # w.write(sample_ligand) - # with open(args.output_dir / "sample_tmscores.txt", "w") as f: - # for tmscore in tmscores: - # f.write(str(tmscore) + "\n") + if __name__ == "__main__": parser = ArgumentParser() - parser.add_argument("--seed", type=int, default=1234) + # parser.add_argument("--seed", type=int, default=1234) + parser.add_argument("--accelerator", type=str, default="gpu") parser.add_argument("--batch_size", type=int, default=1) + parser.add_argument("--num_gpus", type=int, default=1) parser.add_argument("--num_workers", type=int, default=torch.get_num_threads()) parser.add_argument("--num_steps", type=int, default=64) parser.add_argument("--mask_prob", type=float, default=0.3) diff --git a/scripts/predict_batch_strc_msk_inp.py b/scripts/predict_batch_strc_msk_inp.py index 3321530..f948320 100644 --- a/scripts/predict_batch_strc_msk_inp.py +++ b/scripts/predict_batch_strc_msk_inp.py @@ -51,25 +51,10 @@ def load_esm_model(accelerator): esm_model.eval() esm_batch_converter = esm_alphabet.get_batch_converter() -# esm_model, esm_alphabet = torch.hub.load( -# "facebookresearch/esm:main", "esm2_t33_650M_UR50D" -# ) -# esm_model.cuda().eval() -# esm_batch_converter = esm_alphabet.get_batch_converter() def compute_residue_esm(protein: Protein, accelerator: str) -> torch.Tensor: - # esm_model, esm_alphabet = torch.hub.load( - # "facebookresearch/esm:main", "esm2_t33_650M_UR50D" - # ) - # esm_model.cuda().eval() - # esm_batch_converter = esm_alphabet.get_batch_converter() + global esm_model, esm_batch_converter - # if esm_model is None or esm_batch_converter is None: - # esm_model, esm_alphabet = torch.hub.load( - # "facebookresearch/esm:main", "esm2_t33_650M_UR50D" - # ) - # esm_model.cuda().eval() - # esm_batch_converter = esm_alphabet.get_batch_converter() load_esm_model(accelerator) data = [] @@ -179,7 +164,7 @@ def mask_sequence_by_percent(seq, percentage=0.2): return masked_seq def main(args): - pl.seed_everything(np.random.randint(999999999), workers=True) + pl.seed_everything(np.random.randint(99999), workers=True) # Check if the directory exists if os.path.exists(args.output_dir): @@ -210,11 +195,6 @@ def main(args): else: ligands = parse_ligands(args.ligand_file) - # total_num_atoms = len(protein.aatype) + ligand.GetNumAtoms() - # print(f"Total number of atoms: {total_num_atoms}") - # if total_num_atoms > 400: - # warnings.warn("Too many atoms (> 400). 
May take a long time for sample generation.") - datas = [] for name, protein, ligand in zip(names,proteins, ligands): data = { @@ -245,24 +225,8 @@ def main(args): ) - - #torch.save(results,"results.pt") - positions = [p[0] for p in results] ## (NN) - probabilities = [s[1] for s in results] ## (NN) - - # positions = torch.cat([p[0] for p in results], dim=-1).detach().cpu().numpy() - # probabilities = torch.cat(probabilities, dim=0).detach().cpu().numpy() #(NN) - - - - # with open(args.output_dir / "sample_sequences.fasta", "w") as f: - # for i, (name, seq_prob) in enumerate(zip(names, probabilities)): - # f.write(">{}_sample_{}\n".format(name,i%args.num_samples)) - # sequence = predict_seq(seq_prob.squeeze()) - # f.write("{}\n".format(sequence)) - - # Save samples - # Repeat samples by num_samples + positions = [p[0] for p in results] + probabilities = [s[1] for s in results] proteins, ligands, names = [protein for protein in proteins for _ in range(args.num_samples)],\ [ligand for ligand in ligands for _ in range(args.num_samples)], \ [name for name in names for _ in range(args.num_samples)] diff --git a/train.py b/train.py index 9cae8aa..fbdef2e 100644 --- a/train.py +++ b/train.py @@ -17,8 +17,8 @@ from pytorch_lightning.callbacks import ModelCheckpoint from pytorch_lightning.loggers import WandbLogger -from ProteinReDiff.data import PDBbindDataModule -from ProteinReDiff.model import DiffusionModel +from ProteinReDiff.data import PDBDataModule +from ProteinReDiff.model import ProteinReDiffModel @@ -29,14 +29,13 @@ def main(args): rmtree(args.save_dir) args.save_dir.mkdir(parents=True) - datamodule = PDBbindDataModule.from_argparse_args(args) - model = DiffusionModel(args) + datamodule = PDBDataModule.from_argparse_args(args) + model = ProteinReDiffModel(args) trainer = pl.Trainer.from_argparse_args( args, accelerator="auto", precision=16, strategy="ddp_find_unused_parameters_false", - #logger=WandbLogger(save_dir=args.save_dir, project="DiffusionProteinLigand"), callbacks=[ ModelCheckpoint( filename="{epoch:03d}-{val_loss:.2f}", @@ -53,10 +52,11 @@ def main(args): if __name__ == "__main__": parser = ArgumentParser() - parser = PDBbindDataModule.add_argparse_args(parser) - parser = DiffusionModel.add_argparse_args(parser) + parser = PDBDataModule.add_argparse_args(parser) + parser = ProteinReDiffModel.add_argparse_args(parser) parser = pl.Trainer.add_argparse_args(parser) parser.add_argument("--seed", type=int, default=1234) + parser.add_argument("--num_gpus", type = int, default = 1) parser.add_argument("--save_dir", type=Path, required=True) args = parser.parse_args() diff --git a/train_from_ckpt.py b/train_from_ckpt.py index 871f8c5..6cce808 100644 --- a/train_from_ckpt.py +++ b/train_from_ckpt.py @@ -15,8 +15,8 @@ from pytorch_lightning.callbacks import ModelCheckpoint from pytorch_lightning.loggers import WandbLogger -from ProteinReDiff.data import PDBbindDataModule -from ProteinReDiff.model import DiffusionModel +from ProteinReDiff.data import PDBDataModule +from ProteinReDiff.model import ProteinReDiffModel @@ -24,15 +24,14 @@ def main(args): pl.seed_everything(args.seed, workers=True) args.save_dir.mkdir(parents=True) - datamodule = PDBbindDataModule.from_argparse_args(args) - model = DiffusionModel(args) + datamodule = PDBDataModule.from_argparse_args(args) + model = ProteinReDiffModel(args) trainer = pl.Trainer.from_argparse_args( args, accelerator="auto", precision=16, strategy="ddp_find_unused_parameters_false", resume_from_checkpoint=args.trained_ckpt, - 
#logger=WandbLogger(save_dir=args.save_dir, project="DiffusionProteinLigand"), callbacks=[ ModelCheckpoint( filename="{epoch:03d}-{val_loss:.2f}", @@ -49,8 +48,8 @@ def main(args): if __name__ == "__main__": parser = ArgumentParser() - parser = PDBbindDataModule.add_argparse_args(parser) - parser = DiffusionModel.add_argparse_args(parser) + parser = PDBDataModule.add_argparse_args(parser) + parser = ProteinReDiffModel.add_argparse_args(parser) parser = pl.Trainer.add_argparse_args(parser) parser.add_argument("--seed", type=int, default=1234) parser.add_argument("--save_dir", type=Path, required=True)
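
As a usage illustration (not part of the patch itself): the sequence-masking helper that this patch adds to scripts/predict_batch_seq_msk_inp.py can be exercised on its own. The sketch below mirrors mask_sequence_by_percent from the diff above; the example sequence and the fixed random seed are illustrative assumptions, not values taken from the repository.

```python
import random


def mask_sequence_by_percent(seq: str, percentage: float = 0.2) -> str:
    # Mirror of the helper added in scripts/predict_batch_seq_msk_inp.py:
    # pick a random subset of positions and replace those residues with "X".
    aa_to_replace = random.sample(range(len(seq)), int(len(seq) * percentage))
    output_aa = [char if idx not in aa_to_replace else "X" for idx, char in enumerate(seq)]
    return "".join(output_aa)


if __name__ == "__main__":
    random.seed(0)  # illustrative seed; the script itself seeds via pl.seed_everything
    sequence = "MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQ"  # hypothetical example sequence
    masked = mask_sequence_by_percent(sequence, percentage=0.3)
    print(masked)  # roughly 30% of the residues come back as "X"
```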