From 1a3b3634756f8158cb6239f30695b17c7ea63621 Mon Sep 17 00:00:00 2001 From: Ruihao Zeng <52482911+AndrewBarker0621@users.noreply.github.com> Date: Thu, 14 Dec 2023 17:27:49 +1100 Subject: [PATCH] Add files via upload --- zed_utils/__pycache__/click.cpython-310.pyc | Bin 0 -> 1850 bytes .../__pycache__/color_init.cpython-310.pyc | Bin 0 -> 1074 bytes .../__pycache__/grid_map.cpython-310.pyc | Bin 0 -> 9225 bytes zed_utils/__pycache__/init.cpython-310.pyc | Bin 0 -> 1198 bytes zed_utils/__pycache__/tools.cpython-310.pyc | Bin 0 -> 4595 bytes zed_utils/click.py | 69 +++ zed_utils/color_init.py | 31 ++ zed_utils/grid_map.py | 526 ++++++++++++++++++ zed_utils/init.py | 56 ++ zed_utils/segment_anything/__init__.py | 15 + .../__pycache__/__init__.cpython-310.pyc | Bin 0 -> 407 bytes .../__pycache__/__init__.cpython-39.pyc | Bin 0 -> 418 bytes .../automatic_mask_generator.cpython-310.pyc | Bin 0 -> 11408 bytes .../automatic_mask_generator.cpython-39.pyc | Bin 0 -> 11340 bytes .../__pycache__/build_sam.cpython-310.pyc | Bin 0 -> 2152 bytes .../__pycache__/build_sam.cpython-39.pyc | Bin 0 -> 2218 bytes .../__pycache__/predictor.cpython-310.pyc | Bin 0 -> 9963 bytes .../__pycache__/predictor.cpython-39.pyc | Bin 0 -> 9946 bytes .../automatic_mask_generator.py | 372 +++++++++++++ zed_utils/segment_anything/build_sam.py | 107 ++++ .../segment_anything/modeling/__init__.py | 11 + .../__pycache__/__init__.cpython-310.pyc | Bin 0 -> 394 bytes .../__pycache__/__init__.cpython-39.pyc | Bin 0 -> 405 bytes .../__pycache__/common.cpython-310.pyc | Bin 0 -> 1749 bytes .../__pycache__/common.cpython-39.pyc | Bin 0 -> 1747 bytes .../__pycache__/image_encoder.cpython-310.pyc | Bin 0 -> 12630 bytes .../__pycache__/image_encoder.cpython-39.pyc | Bin 0 -> 12448 bytes .../__pycache__/mask_decoder.cpython-310.pyc | Bin 0 -> 5470 bytes .../__pycache__/mask_decoder.cpython-39.pyc | Bin 0 -> 5419 bytes .../prompt_encoder.cpython-310.pyc | Bin 0 -> 7678 bytes .../__pycache__/prompt_encoder.cpython-39.pyc | Bin 0 -> 7668 bytes .../modeling/__pycache__/sam.cpython-310.pyc | Bin 0 -> 6659 bytes .../modeling/__pycache__/sam.cpython-39.pyc | Bin 0 -> 6650 bytes .../__pycache__/transformer.cpython-310.pyc | Bin 0 -> 6603 bytes .../__pycache__/transformer.cpython-39.pyc | Bin 0 -> 6547 bytes zed_utils/segment_anything/modeling/common.py | 43 ++ .../modeling/image_encoder.py | 395 +++++++++++++ .../segment_anything/modeling/mask_decoder.py | 176 ++++++ .../modeling/prompt_encoder.py | 214 +++++++ zed_utils/segment_anything/modeling/sam.py | 174 ++++++ .../segment_anything/modeling/transformer.py | 240 ++++++++ zed_utils/segment_anything/predictor.py | 269 +++++++++ zed_utils/segment_anything/utils/__init__.py | 5 + .../__pycache__/__init__.cpython-310.pyc | Bin 0 -> 154 bytes .../utils/__pycache__/__init__.cpython-39.pyc | Bin 0 -> 165 bytes .../utils/__pycache__/amg.cpython-310.pyc | Bin 0 -> 12107 bytes .../utils/__pycache__/amg.cpython-39.pyc | Bin 0 -> 12115 bytes .../__pycache__/transforms.cpython-310.pyc | Bin 0 -> 3937 bytes .../__pycache__/transforms.cpython-39.pyc | Bin 0 -> 3971 bytes zed_utils/segment_anything/utils/amg.py | 346 ++++++++++++ zed_utils/segment_anything/utils/onnx.py | 144 +++++ .../segment_anything/utils/transforms.py | 102 ++++ zed_utils/tools.py | 183 ++++++ 53 files changed, 3478 insertions(+) create mode 100644 zed_utils/__pycache__/click.cpython-310.pyc create mode 100644 zed_utils/__pycache__/color_init.cpython-310.pyc create mode 100644 zed_utils/__pycache__/grid_map.cpython-310.pyc create mode 
100644 zed_utils/__pycache__/init.cpython-310.pyc create mode 100644 zed_utils/__pycache__/tools.cpython-310.pyc create mode 100644 zed_utils/click.py create mode 100644 zed_utils/color_init.py create mode 100644 zed_utils/grid_map.py create mode 100644 zed_utils/init.py create mode 100644 zed_utils/segment_anything/__init__.py create mode 100644 zed_utils/segment_anything/__pycache__/__init__.cpython-310.pyc create mode 100644 zed_utils/segment_anything/__pycache__/__init__.cpython-39.pyc create mode 100644 zed_utils/segment_anything/__pycache__/automatic_mask_generator.cpython-310.pyc create mode 100644 zed_utils/segment_anything/__pycache__/automatic_mask_generator.cpython-39.pyc create mode 100644 zed_utils/segment_anything/__pycache__/build_sam.cpython-310.pyc create mode 100644 zed_utils/segment_anything/__pycache__/build_sam.cpython-39.pyc create mode 100644 zed_utils/segment_anything/__pycache__/predictor.cpython-310.pyc create mode 100644 zed_utils/segment_anything/__pycache__/predictor.cpython-39.pyc create mode 100644 zed_utils/segment_anything/automatic_mask_generator.py create mode 100644 zed_utils/segment_anything/build_sam.py create mode 100644 zed_utils/segment_anything/modeling/__init__.py create mode 100644 zed_utils/segment_anything/modeling/__pycache__/__init__.cpython-310.pyc create mode 100644 zed_utils/segment_anything/modeling/__pycache__/__init__.cpython-39.pyc create mode 100644 zed_utils/segment_anything/modeling/__pycache__/common.cpython-310.pyc create mode 100644 zed_utils/segment_anything/modeling/__pycache__/common.cpython-39.pyc create mode 100644 zed_utils/segment_anything/modeling/__pycache__/image_encoder.cpython-310.pyc create mode 100644 zed_utils/segment_anything/modeling/__pycache__/image_encoder.cpython-39.pyc create mode 100644 zed_utils/segment_anything/modeling/__pycache__/mask_decoder.cpython-310.pyc create mode 100644 zed_utils/segment_anything/modeling/__pycache__/mask_decoder.cpython-39.pyc create mode 100644 zed_utils/segment_anything/modeling/__pycache__/prompt_encoder.cpython-310.pyc create mode 100644 zed_utils/segment_anything/modeling/__pycache__/prompt_encoder.cpython-39.pyc create mode 100644 zed_utils/segment_anything/modeling/__pycache__/sam.cpython-310.pyc create mode 100644 zed_utils/segment_anything/modeling/__pycache__/sam.cpython-39.pyc create mode 100644 zed_utils/segment_anything/modeling/__pycache__/transformer.cpython-310.pyc create mode 100644 zed_utils/segment_anything/modeling/__pycache__/transformer.cpython-39.pyc create mode 100644 zed_utils/segment_anything/modeling/common.py create mode 100644 zed_utils/segment_anything/modeling/image_encoder.py create mode 100644 zed_utils/segment_anything/modeling/mask_decoder.py create mode 100644 zed_utils/segment_anything/modeling/prompt_encoder.py create mode 100644 zed_utils/segment_anything/modeling/sam.py create mode 100644 zed_utils/segment_anything/modeling/transformer.py create mode 100644 zed_utils/segment_anything/predictor.py create mode 100644 zed_utils/segment_anything/utils/__init__.py create mode 100644 zed_utils/segment_anything/utils/__pycache__/__init__.cpython-310.pyc create mode 100644 zed_utils/segment_anything/utils/__pycache__/__init__.cpython-39.pyc create mode 100644 zed_utils/segment_anything/utils/__pycache__/amg.cpython-310.pyc create mode 100644 zed_utils/segment_anything/utils/__pycache__/amg.cpython-39.pyc create mode 100644 zed_utils/segment_anything/utils/__pycache__/transforms.cpython-310.pyc create mode 100644 
zed_utils/segment_anything/utils/__pycache__/transforms.cpython-39.pyc
 create mode 100644 zed_utils/segment_anything/utils/amg.py
 create mode 100644 zed_utils/segment_anything/utils/onnx.py
 create mode 100644 zed_utils/segment_anything/utils/transforms.py
 create mode 100644 zed_utils/tools.py
[GIT binary patch data omitted for zed_utils/__pycache__/click.cpython-310.pyc, color_init.cpython-310.pyc, grid_map.cpython-310.pyc, init.cpython-310.pyc, and tools.cpython-310.pyc]
z>*Skw%1|QX@>9m)>-YF6gLuYI8C>?S;M1>K80Da5zd|glypj-@?;UR(#$Ik+sTI~K?8s-XR$y&#K|#r5YrD#vmJdZvBq@rWz8nen z7y;4n3)D~Oa3_4jRfBp;4U}k|iqJ6i#f1i{SKeV&p2Dk;97Jd_o$^w2()2}Ww%qX7B z^)Q_bIRE9`w{#G3rt=Dpeu=@Qn<1&VCl$qIeslB7lwTv9nQ$7)5twnFJzSA-^ST#y sPqwXAX-fYA(&O1E`Gho(lhP=3;MIHjv;BtO@~`>6zvjP#^4&)3KgobDy8r+H literal 0 HcmV?d00001 diff --git a/zed_utils/click.py b/zed_utils/click.py new file mode 100644 index 0000000..d3144c4 --- /dev/null +++ b/zed_utils/click.py @@ -0,0 +1,69 @@ +####################################################################### +# Click +import pyzed.sl as sl +import cv2 +import numpy as np + +class Click: + def __init__(self, grid_map): + self.grid_map = grid_map + self.background_color = 'white' + self.current_color = 'purple' + + def click_grid_cell(self, x, y): + """ + Finds which grid cell the point belongs to in a grid defined by a NumPy array. + + Parameters: + coordinate_map (numpy.ndarray): A NumPy array of shape (num_grids_y, num_grids_x, 4), + where each element contains the top-left and bottom-right + coordinates of a grid cell. + point (tuple): The (x, y) coordinates of the point. + + Returns: + tuple: The (row, column) of the grid cell that the point belongs to, or (-1, -1) if the point is not in any grid cell. + """ + flag = 0 + for row in range(self.grid_map.coordinate_map.shape[0]): + for col in range(self.grid_map.coordinate_map.shape[1]): + top_left_x, top_left_y, bottom_right_x, bottom_right_y = self.grid_map.coordinate_map[row, col] + if top_left_x <= x <= bottom_right_x and top_left_y <= y <= bottom_right_y: + flag = 1 + if flag == 1: + break + if flag == 1: + break + return row, col + + def mouse_callback(self, event, x, y, flags, param): + if event == cv2.EVENT_LBUTTONDOWN: + row, col = self.click_grid_cell(x, y) + self.grid_map.add_color(col, row, self.current_color) + + # def click_draw(self): + # # grid_image = cv2.cvtColor(grid_image, cv2.COLOR_BGR2RGB) + # for y in range(self.grid_map.num_grids_y): + # for x in range(self.grid_map.num_grids_x): + # # draw the color + # if self.grid_map.color_map[y,x] == 0: + # grid_color = (255,255,255) # white + # elif self.grid_map.color_map[y,x] == 1: + # grid_color = (50,127,205) # brown + # elif self.grid_map.color_map[y,x] == 2: + # grid_color = (128,0,128) # purple + # elif self.grid_map.color_map[y,x] == 3: + # grid_color = (50,205,50) # green + # elif self.grid_map.color_map[y,x] == 4: + # grid_color = (0,191,255) # yellow + # elif self.grid_map.color_map[y,x] == 5: + # grid_color = (0,0,0) # black + # elif self.grid_map.color_map[y,x] == 10: + # grid_color = (32,0,128) # red + # else: + # grid_color = (255,255,255) # unknow->blue + # # print(self.coordinate_map[y,x]) + # # BGR format + # cv2.rectangle(self.grid_map.grid_image, + # (int(self.grid_map.coordinate_map[y,x,0]),int(self.grid_map.coordinate_map[y,x,1])), + # (int(self.grid_map.coordinate_map[y,x,2]), int(self.grid_map.coordinate_map[y,x,3])), + # grid_color, -1) \ No newline at end of file diff --git a/zed_utils/color_init.py b/zed_utils/color_init.py new file mode 100644 index 0000000..18998e8 --- /dev/null +++ b/zed_utils/color_init.py @@ -0,0 +1,31 @@ +####################################################################### +# Color setting +import numpy as np + + +def color_init(): + lower_white = np.array([0, 0, 254]) + upper_white = np.array([180, 1, 255]) + lower_purple = np.array([120,25,120]) + upper_purple = np.array([170,255,255]) + lower_green = np.array([40, 40, 75]) + upper_green = np.array([80,255, 255]) + lower_yellow = 
np.array([25, 0, 254]) + upper_yellow = np.array([40, 255, 255]) + lower_brown = np.array([150, 0, 0]) + upper_brown = np.array([180, 145, 245]) + lower_red_1 = np.array([0, 25, 200]) + upper_red_1 = np.array([40, 255, 255]) + + color_range = (np.concatenate((lower_white, upper_white, + lower_purple, upper_purple, + lower_green, upper_green, + lower_yellow, upper_yellow, + lower_brown, upper_brown, + lower_red_1, upper_red_1))).reshape((-1,3)) + + color_dict = [ 'white', 'purple', 'green', 'yellow', 'black', 'red', 'unknow'] + + color_persentage = np.array([0.55, 0.5, 0.5, 0.5, 0.5, 0.5]) + + return color_range, color_dict, color_persentage \ No newline at end of file diff --git a/zed_utils/grid_map.py b/zed_utils/grid_map.py new file mode 100644 index 0000000..ae61d59 --- /dev/null +++ b/zed_utils/grid_map.py @@ -0,0 +1,526 @@ +####################################################################### +# Grid Map +import math +import cv2 +import numpy as np +import random +import os +from sklearn.cluster import DBSCAN +import open3d as o3d + + +def create_cuboid(bottom_corners, height): + vertices = np.vstack([bottom_corners, bottom_corners + np.array([0, 0, height])]) + + faces = np.array([ + [0, 2, 1], [0, 3, 2], # 底面 + [4, 5, 6], [4, 6, 7], # 顶面 + [0, 4, 7], [0, 7, 3], # 前侧面 + [1, 2, 6], [1, 6, 5], # 后侧面 + [2, 3, 7], [2, 7, 6], # 右侧面 + [0, 1, 5], [0, 5, 4] # 左侧面 + ]) + + mesh = o3d.geometry.TriangleMesh() + mesh.vertices = o3d.utility.Vector3dVector(vertices) + mesh.triangles = o3d.utility.Vector3iVector(faces) + + return mesh + + +def distanc_cal(x, y, point_cloud, depth): + err_depth, depth_value = depth.get_value(x, y) + print(depth_value) + return (depth_value) + + +class GridMap: + ''' + generate the color map with depth + ''' + def __init__(self, num_grids_x, num_grids_y, baseplate, color_range, color_dict, color_persentage, bg_width, bg_length): + self.num_grids_x = num_grids_x + self.num_grids_y = num_grids_y + self.color_map = np.zeros((self.num_grids_y,self.num_grids_x), dtype = int) + self.coordinate_map = np.zeros((self.num_grids_y,self.num_grids_x,4), dtype = int) + self.baseplate = baseplate + self.depth_info = np.zeros((self.num_grids_y, self.num_grids_x), dtype = int) + self.building_clusters = 0 + self.road_clusters = 0 + self.grid_image = np.zeros((bg_length, bg_width, 3), dtype = "uint8") + # 0->white, 1->brown, 2->purple, 3->green, 4->yellow, 5->black, 6->red, 7->unknow + self.color_range = color_range + self.color_dict = color_dict + self.color_persentage = color_persentage + + def file_initialization(self,root): + if not os.path.exists(root): + os.makedirs(root) + return root + '/path.txt' + + # 0->white, 1->brown, 2->purple, 3->green, 4->yellow, 5->black, 10->red, -1->unknow + def add_color(self, grid_x, grid_y, color): + if color == 'white': + self.color_map[grid_y,grid_x] = 0 + elif color == 'brown': + self.color_map[grid_y,grid_x] = 1 + elif color == 'purple': + self.color_map[grid_y,grid_x] = 2 + elif color == 'green': + self.color_map[grid_y,grid_x] = 3 + elif color == 'yellow': + self.color_map[grid_y,grid_x] = 4 + elif color == 'black': + self.color_map[grid_y,grid_x] = 5 + elif color == 'red': + self.color_map[grid_y,grid_x] = 10 + elif color == 'unknow': + self.color_map[grid_y,grid_x] = -1 + + def init_coordinate(self, grid_width, grid_length): + grid_tl_x = self.baseplate[0] + grid_tl_y = self.baseplate[1] + + for y in range(self.num_grids_x): + for x in range(self.num_grids_y): + self.add_coordinate(x, y, grid_tl_x, grid_tl_y, grid_width[y], 
grid_length[x]) + grid_tl_y += grid_length[x] + grid_tl_x += grid_width[y] + # print(grid_tl_y) + grid_tl_y = self.baseplate[1] + + # tl_x, tl_y, br_x, br_y + def add_coordinate(self, grid_x, grid_y, tl_x, tl_y, width, length): + self.coordinate_map[grid_y,grid_x,0] = tl_x + self.coordinate_map[grid_y,grid_x,1] = tl_y + self.coordinate_map[grid_y,grid_x,2] = tl_x + width + self.coordinate_map[grid_y,grid_x,3] = tl_y + length + + # cv2 version + def color_detector(self, hsv, blur, tl_x, tl_y, br_x, br_y): + total = (br_x - tl_x) * (br_y - tl_y) + unknow_flag = -1 + for index in range(len(self.color_persentage)): + # mask1 = cv2.inRange(image_hsv, lower_red1, upper_red1) + # mask2 = cv2.inRange(image_hsv, lower_red2, upper_red2) + # mask = cv2.bitwise_or(mask1, mask2) + mask = cv2.inRange(hsv, self.color_range[index * 2], self.color_range[index * 2 + 1]) + res = cv2.bitwise_and(blur, blur, mask = mask) + # print(res[tl_y:br_y, tl_x:br_x]) + # color_num = np.count_nonzero(res[tl_y:br_y + 1, tl_x:br_x + 1]) + if (np.count_nonzero(res[tl_y:br_y, tl_x:br_x]) / 3) >= (self.color_persentage[index] * total): + unknow_flag = index + break + return self.color_dict[unknow_flag] + + # read new frame + def reload_image(self, image): + blurred_image = cv2.GaussianBlur(image, (1,1), cv2.BORDER_DEFAULT) + hsv_image = cv2.cvtColor(blurred_image, cv2.COLOR_BGR2HSV) + for y in range(self.num_grids_y): + for x in range(self.num_grids_x): + color = self.color_detector(hsv_image, blurred_image, + self.coordinate_map[y,x,0], + self.coordinate_map[y,x,1], self.coordinate_map[y,x,2], + self.coordinate_map[y,x,3]) + self.add_color(x, y, color) + + # generate margin with a width equals 1 + def margin_detector(self, grid_x, grid_y): + grid_up = grid_right = grid_bottom = grid_left = -1 + # top-left grid (2 by 2 grids for calibration) + if self.color_map[grid_y, grid_y] == 2: + pass + else: + if grid_x - 1 < 0 and grid_y - 1 < 0: + # self.color_map[grid_y, grid_x] = 1 + grid_bottom = self.color_map[grid_y + 1, grid_x] + grid_right = self.color_map[grid_y, grid_x + 1] + + # top-right grid (2 by 2 grids for calibration) + elif grid_x + 2 > self.num_grids_x and grid_y - 1 < 0: + # self.color_map[grid_y, grid_x] = 1 + grid_bottom = self.color_map[grid_y + 1, grid_x] + grid_left = self.color_map[grid_y, grid_x - 1] + + # bottom-right grid (2 by 2 grids for calibration) + elif grid_x + 2 > self.num_grids_x and grid_y + 2 > self.num_grids_y: + # self.color_map[grid_y, grid_x] = 1 + grid_up = self.color_map[grid_y - 1, grid_x] + grid_left = self.color_map[grid_y, grid_x - 1] + + # bottom-left grid (2 by 2 grids for calibration) + elif grid_x - 1 < 0 and grid_y + 2 > self.num_grids_y: + # self.color_map[grid_y,grid_x] = 1 + grid_up = self.color_map[grid_y - 1, grid_x] + grid_right = self.color_map[grid_y, grid_x + 1] + + # if the grid in on the edge of the board + # left edge + elif grid_y - 1 < 0: + grid_bottom = self.color_map[grid_y + 1, grid_x] + grid_left = self.color_map[grid_y, grid_x - 1] + grid_right = self.color_map[grid_y, grid_x + 1] + # up edge + elif grid_x - 1 < 0: + grid_up = self.color_map[grid_y - 1, grid_x] + grid_right = self.color_map[grid_y, grid_x + 1] + grid_bottom = self.color_map[grid_y + 1, grid_x] + # right edge + elif grid_x + 1 >= self.num_grids_x: + grid_up = self.color_map[grid_y - 1, grid_x] + grid_left = self.color_map[grid_y, grid_x - 1] + grid_bottom = self.color_map[grid_y + 1, grid_x] + # bottom edge + elif grid_y + 1 >= self.num_grids_y: + grid_up = self.color_map[grid_y - 1, grid_x] + 
grid_right = self.color_map[grid_y, grid_x + 1] + grid_left = self.color_map[grid_y, grid_x - 1] + else: + grid_up = self.color_map[grid_y - 1, grid_x] + grid_bottom = self.color_map[grid_y + 1, grid_x] + grid_left = self.color_map[grid_y, grid_x - 1] + grid_right = self.color_map[grid_y, grid_x + 1] + # print(grid_up, grid_right, grid_bottom, grid_left) + return grid_up, grid_right, grid_bottom, grid_left + + def soft_margin(self, grid_x, grid_y): + up, right, bottom, left = self.margin_detector(int(grid_x), int(grid_y)) + if up == -1 and right == -1 and bottom == -1 and left == -1: + pass + elif up == right and up != (-1 or 0) and right != (-1 or 0): + # if self.color_map[grid_y - 1, grid_x] != 2: + self.color_map[grid_y, grid_x] = self.color_map[grid_y - 1, grid_x] + elif right == bottom and right != (-1 or 0) and bottom != (-1 or 0): + # if self.color_map[grid_y, grid_x + 1] != 2: + self.color_map[grid_y, grid_x] = self.color_map[grid_y, grid_x + 1] + elif bottom == left and bottom != (-1 or 0) and left != (-1 or 0): + # if self.color_map[grid_y + 1, grid_x] != 2: + self.color_map[grid_y, grid_x] = self.color_map[grid_y + 1, grid_x] + elif left == up and left != (-1 or 0) and up != (-1 or 0): + # if self.color_map[grid_y, grid_x - 1] != 2: + self.color_map[grid_y, grid_x] = self.color_map[grid_y, grid_x - 1] + # if islated by white block + elif (up and right and bottom == 0) or (up and right and left == 0) or \ + (up and bottom and left == 0) or (right and bottom and left == 0): + self.color_map[grid_y, grid_x] = 0 + else: + pass + + # 0->white, 1->brown, 2->purple, 3->green, 4->yellow, 5->black, 10->red, -1->unknow + def building_clustering(self): + backup = np.empty((self.num_grids_y, self.num_grids_x, 3), dtype = int) + for x in range(self.num_grids_x): + for y in range(self.num_grids_y): + backup[y, x, 0:2] = y, x + backup[...,2] = self.color_map + # delete background color, road color, and calibration color + no_background = backup[backup[...,2] > 2] + # DBSCAN clustering + dbscan = DBSCAN(eps = 1, min_samples = 2) + building_labels = dbscan.fit_predict(no_background) + self.building_clusters = max(building_labels) + 1 + for cluster_num in range(self.building_clusters): + # print('cluster_num', cluster_num) + origin_building_cluster = np.empty((1,3), dtype = int) + for index in range(len(building_labels)): + if building_labels[index] == cluster_num: + origin_building_cluster = np.append(origin_building_cluster, + np.array(no_background[index]).reshape(1,3), axis = 0) + # delete the first row generated by np.empty + building_cluster = np.delete(origin_building_cluster, 0, 0) + # print('cluster', cluster) + y_min = min(building_cluster[:,0]) + y_max = max(building_cluster[:,0]) + x_min = min(building_cluster[:,1]) + x_max = max(building_cluster[:,1]) + color_tag = building_cluster[0,2] + for x in range(x_min, x_max + 1): + for y in range(y_min, y_max + 1): + # print('x_coor, y_coor', x, y) + # print('color_tag', color_tag) + self.color_map[y,x] = color_tag + print('all constructures are auto-constructed (road and background excluded)') + +# # 0->white, 1->brown, 2->purple, 3->green, 4->yellow, 5->black, 10->red, -1->unknow +# def road_vertex_detector(self, grid_x, grid_y): +# # check top, right, bottom, and right +# if (self.color_map[grid_y - 1, grid_x] > 2) or (self.color_map[grid_y, grid_x + 1] > 2) or \ +# (self.color_map[grid_y + 1, grid_x] > 2) or (self.color_map[grid_y, grid_x - 1] > 2): +# return True + +# def lane_number_detector(self, r_v): +# l_n = 1 +# for index in 
range(1, r_v.shape[0]): +# if (r_v[index, 0] == r_v[0, 0] and abs(r_v[index, 1] - r_v[0, 1]) <= 3) or \ +# (r_v[index, 1] == r_v[index, 1] and abs(r_v[index, 0] - r_v[0, 0]) <= 3): +# l_n += 1 +# return int(l_n / 2) + +# def inflection_detector(self, grid_x, grid_y): +# turn_number = 0 +# # check top, right, bottom, and right +# if self.color_map[grid_y - 1, grid_x] == 2: +# turn_number += 1 +# if self.color_map[grid_y, grid_x + 1] == 2: +# turn_number += 1 +# if self.color_map[grid_y + 1, grid_x] == 2: +# turn_number += 1 +# if self.color_map[grid_y, grid_x - 1] == 2: +# turn_number += 1 +# if turn_number == 2 or turn_number == 4: +# return True + +# def lane_render(self, default_width=1): +# # default_width = 1 means +# backup = np.empty((self.num_grids_y, self.num_grids_x, 3), dtype = int) +# for x in range(self.num_grids_x): +# for y in range(self.num_grids_y): +# backup[y, x, 0:2] = y, x +# backup[...,2] = self.color_map +# no_background = backup[backup[...,2] == 2] +# for x in range(self.num_grids_x-default_width): +# least_check_time = default_width - 1 +# start_grid = end_grid = 0 +# while end_grid != num_grids_y - 1: + + + + +# def inflection_expension(self, grid_x, grid_y): + +# # 0->white, 1->brown, 2->purple, 3->green, 4->yellow, 5->black, 10->red, -1->unknow +# # only clustering the road +# def road_clustering(self): +# backup = np.empty((self.num_grids_y, self.num_grids_x, 3), dtype = int) +# for x in range(self.num_grids_x): +# for y in range(self.num_grids_y): +# backup[y, x, 0:2] = y, x +# backup[...,2] = self.color_map +# # delete background color, road color, and calibration color +# no_background = backup[backup[...,2] == 2] +# # DBSCAN clustering +# dbscan = DBSCAN(eps = 1, min_samples = 2) +# road_labels = dbscan.fit_predict(no_background) +# self.road_clusters = max(road_labels) + 1 +# for cluster_num in range(self.road_clusters): +# # print('cluster_num', cluster_num) +# origin_road_cluster = np.empty((1,3), dtype = int) +# for index in range(len(road_labels)): +# if road_labels[index] == cluster_num: +# origin_road_cluster = np.append(origin_road_cluster, +# np.array(no_background[index]).reshape(1,3), +# axis = 0) +# # each road cluster +# # delete the first row generated by np.empty +# road_cluster = np.delete(origin_road_cluster, 0, 0) + +# # find road vertex and inflection +# origin_road_vertex = np.empty((1,2), dtype = int) +# origin_road_inflection = np.empty((1,2), dtype = int) +# for index in range(road_cluster.shape[0]): +# if self.road_vertex_detector(road_cluster[index,1], road_cluster[index,0]) == True: +# origin_road_vertex = np.append(origin_road_vertex, +# np.array((road_cluster[index,0], +# road_cluster[index,1])).reshape(1,2), axis = 0) +# elif self.inflection_detector(road_cluster[index,1], road_cluster[index,0]) == True: +# origin_road_inflection = np.append(origin_road_inflection, +# np.array((road_cluster[index,0], +# road_cluster[index,1])).reshape(1,2), axis = 0) +# else: +# pass +# # delete the first row generated by np.empty +# road_vertex = np.delete(origin_road_vertex, 0, 0) +# road_inflection = np.delete(origin_road_inflection, 0, 0) +# # print('road vertx', road_vertex) +# # print('road inflection', road_inflection) +# # road_point = (np.append(road_vertex, road_inflection)).reshape((-1,2)) + +# # lane number +# # one-way lane-based +# ow_lane_number = self.lane_number_detector(road_vertex) +# # two-way lane-based +# total_lane_number = 2 * self.lane_number_detector(road_vertex) + +# # find the path from vertex to inflection +# # create 
a path dict +# # path_number = (road_vertex.shape[0]/total_lane_number) * ((road_vertex.shape[0]/total_lane_number) - 1) * ow_lane_number +# path_number = int((road_vertex.shape[0]/2) * ((road_vertex.shape[0]/total_lane_number) - 1)) + +# # path format - start -> inflection -> end + +# with open(self.path,'w+') as f: +# for path_index in range(path_number): +# # only support 1-lane right now +# # add origin-point of road +# temp_path = np.full((1,2),-1) +# # type flag +# # left-right format -> 0 +# # up-bottom format -> 1 +# type_flag = 0 +# for lane_index in range(ow_lane_number): +# # left-right format +# index = path_index * 2 + lane_index +# if road_vertex[index, 1] == road_vertex[index + ow_lane_number, 1]: +# temp_path[0,:] = [max(road_vertex[index, 0], +# road_vertex[index + ow_lane_number, 0]), road_vertex[index, 1]] +# # up-bottom format +# elif road_vertex[index, 0] == road_vertex[index + ow_lane_number, 0]: +# temp_path[0,:] = [road_vertex[index, 0], +# max(road_vertex[index, 1],road_vertex[index + ow_lane_number, 1])] +# type_flag = 1 +# else: +# pass +# # extend path from origin-point +# # no inflection +# if len(road_inflection) == 0: +# for connection in road_vertex: +# match = temp_path[-1,:] == connection +# if np.any(match): +# if connection[0] == temp_path[0,0] and connection[1] == temp_path[0,1]: +# pass +# elif type_flag == 0 and connection[1] == temp_path[0,1] and first_match == 0: +# pass +# elif type_flag == 1 and connection[0] == temp_path[0,0] and first_match == 0: +# pass +# else: +# temp_path = (np.append(temp_path, connection)).reshape(-1,2) +# # has inflections +# else: +# break_flag = 0 +# temp_inflection = road_inflection +# while break_flag == 0: +# break_flag = 1 +# for index in range(temp_inflection.shape[0]): +# match = temp_path[-1,:] == temp_inflection[index,:] +# if np.any(match): +# if temp_inflection[index][0] == temp_path[0,0] and temp_inflection[index][1] == temp_path[0,1]: +# pass +# elif type_flag == 0 and temp_inflection[index][1] == temp_path[0,1] and first_match == 0: +# pass +# elif type_flag == 1 and temp_inflection[index][0] == temp_path[0,0] and first_match == 0: +# pass +# else: +# temp_path = (np.append(temp_path, temp_inflection[index,:])).reshape(-1,2) +# temp_inflection = np.delete(temp_inflection, index, 0) +# break_flag = 0 +# break +# for connection in road_vertex: +# # print('last one', temp_path[-1,:]) +# match = temp_path[-1,:] == connection +# if np.any(match): +# if connection[0] == temp_path[0,0] and connection[1] == temp_path[0,1]: +# pass +# elif type_flag == 0 and connection[1] == temp_path[0,1] and first_match == 0: +# pass +# elif type_flag == 1 and connection[0] == temp_path[0,0] and first_match == 0: +# pass +# else: +# temp_path = (np.append(temp_path, connection)).reshape(-1,2) +# break +# for index in range(len(temp_path)): +# print('%d %d' % (temp_path[index,0], temp_path[index,1]), file = f) +# print('%d %d' % (-1, -1), file = f) +# print('path:', temp_path) + + # draw the urban structure including buildings and road links, which are extinguished by color and capacity + # capacity is specified by the building type and size/depth + + def urban(self): + # transfer BGR to RGB + self.grid_image = cv2.cvtColor(self.grid_image, cv2.COLOR_BGR2RGB) + # cv2.rectangle(grid_image, (0,0), (bg_width,bg_length), (220,220,220), -1) + # cv2.rectangle(grid_image, (0,0), (bg_width,bg_length), (0,0,0), -1) + cv2.rectangle(self.grid_image, (self.baseplate[0], self.baseplate[1]), + (self.baseplate[0] + self.baseplate[2], 
self.baseplate[1] + self.baseplate[3]), + (255,255,255), -1) + + for y in range(self.num_grids_y): + for x in range(self.num_grids_x): + # adopte soft margin detector + self.soft_margin(x, y) + + self.building_clustering() + + def urban_pure(self): + # transfer BGR to RGB + self.grid_image = cv2.cvtColor(self.grid_image, cv2.COLOR_BGR2RGB) + # cv2.rectangle(grid_image, (0,0), (bg_width,bg_length), (220,220,220), -1) + # cv2.rectangle(grid_image, (0,0), (bg_width,bg_length), (0,0,0), -1) + cv2.rectangle(self.grid_image, (self.baseplate[0], self.baseplate[1]), + (self.baseplate[0] + self.baseplate[2], self.baseplate[1] + self.baseplate[3]), + (255,255,255), -1) + + def draw(self): + # grid_image = cv2.cvtColor(grid_image, cv2.COLOR_BGR2RGB) + for y in range(self.num_grids_y): + for x in range(self.num_grids_x): + # draw the color + if self.color_map[y,x] == 0: + grid_color = (255,255,255) # white + elif self.color_map[y,x] == 1: + grid_color = (50,127,205) # brown + elif self.color_map[y,x] == 2: + grid_color = (128,0,128) # purple + elif self.color_map[y,x] == 3: + grid_color = (50,205,50) # green + elif self.color_map[y,x] == 4: + grid_color = (0,191,255) # yellow + elif self.color_map[y,x] == 5: + grid_color = (0,0,0) # black + elif self.color_map[y,x] == 10: + grid_color = (32,0,128) # red + else: + grid_color = (255,255,255) # unknow->blue + # print(self.coordinate_map[y,x]) + # BGR format + cv2.rectangle(self.grid_image, + (int(self.coordinate_map[y,x,0]),int(self.coordinate_map[y,x,1])), + (int(self.coordinate_map[y,x,2]), int(self.coordinate_map[y,x,3])), + grid_color, -1) + # draw the scratch of rectangles + # cv2.rectangle(grid_image, (int(self.coordinate_map[y,x,0]),int(self.coordinate_map[y,x,1])), + # (int(self.coordinate_map[y,x,2]), int(self.coordinate_map[y,x,3])), grid_color, 1) + + def color_accumulate(self, previous_color_map): + for y in range(self.num_grids_y): + for x in range(self.num_grids_x): + # draw the color + # 0->white, 1->brown, 2->purple, 3->green, 4->yellow, 5->black, 6->red, 7->unknow + if previous_color_map[y,x] != 0 and previous_color_map[y,x] != 2: + if self.color_map[y,x] == 0: + self.color_map[y,x] = previous_color_map[y,x] + + def add_mesh(self, point_cloud, depth): + mesh_list = [] + + for x in range(self.num_grids_x): + for y in range(self.num_grids_y): + if self.color_map[y,x] == 2: + height = 1 + color = [0.5, 0, 0.5] + elif self.color_map[y,x] == 3: + height = distanc_cal(x+2, y+2, point_cloud, depth) + color = [0, 1, 0] + elif self.color_map[y,x] == 4: + height = distanc_cal(x+2, y+2, point_cloud, depth) + color = [1, 1, 0] + elif self.color_map[y,x] == 1: + height = distanc_cal(x+2, y+2, point_cloud, depth) + color = [0.6, 0.3, 0] + elif self.color_map[y,x] == 10: + height = distanc_cal(x+2, y+2, point_cloud, depth) + color = [1,0,0] + else: + height = 1 + color = [0.9, 0.9, 0.9] + tl = [self.coordinate_map[y,x,0],self.coordinate_map[y,x,1],0] + tr = [self.coordinate_map[y,x,0],self.coordinate_map[y,x,3],0] + br = [self.coordinate_map[y,x,2],self.coordinate_map[y,x,3],0] + bl = [self.coordinate_map[y,x,2],self.coordinate_map[y,x,1],0] + bottom_corners = np.array([tl,tr,br,bl]) + + cuboid = create_cuboid(bottom_corners, height) + cuboid.paint_uniform_color(color) + mesh_list.append(cuboid) + return mesh_list \ No newline at end of file diff --git a/zed_utils/init.py b/zed_utils/init.py new file mode 100644 index 0000000..333ea66 --- /dev/null +++ b/zed_utils/init.py @@ -0,0 +1,56 @@ 
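For reference, a minimal sketch of how the `create_cuboid` helper defined in `grid_map.py` above can be exercised on its own. The corner coordinates, height, and import path are illustrative assumptions rather than values taken from this patch, and `open3d` must be installed:

```python
import numpy as np
import open3d as o3d

# Assumes the repository root is on sys.path; adjust the import to your layout.
from zed_utils.grid_map import create_cuboid

# Four bottom corners of one grid cell in the z = 0 plane, plus an arbitrary
# extrusion height, mirroring what GridMap.add_mesh does per cell.
bottom_corners = np.array([
    [0.0,  0.0, 0.0],
    [0.0, 10.0, 0.0],
    [10.0, 10.0, 0.0],
    [10.0, 0.0, 0.0],
])
cuboid = create_cuboid(bottom_corners, height=5.0)
cuboid.paint_uniform_color([0.5, 0.0, 0.5])  # the purple used for road cells in add_mesh
cuboid.compute_vertex_normals()              # needed for shaded rendering

# Opens an Open3D window showing the single extruded cell.
o3d.visualization.draw_geometries([cuboid])
```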
+#######################################################################
+# Init
+import pyzed.sl as sl
+import math
+
+def zed_init(args):
+    # Create a Camera object
+    zed = sl.Camera()
+
+    # Create an InitParameters object and set configuration parameters
+    init_params = sl.InitParameters()
+    if args.resolution == "HD720":
+        init_params.camera_resolution = sl.RESOLUTION.HD720
+        frame_width = 1280
+        frame_length = 720
+    elif args.resolution == "HD1080":
+        init_params.camera_resolution = sl.RESOLUTION.HD1080
+        frame_width = 1920
+        frame_length = 1080
+    elif args.resolution == "HD2K":
+        init_params.camera_resolution = sl.RESOLUTION.HD2K
+        frame_width = 2560
+        frame_length = 1440
+    init_params.camera_fps = args.fps  # Set fps based on input
+
+    init_params.depth_mode = sl.DEPTH_MODE.ULTRA  # Use ULTRA depth mode
+    init_params.coordinate_units = sl.UNIT.MILLIMETER  # Use millimeter units (for depth measurements)
+    # Point cloud setting
+    mirror_ref = sl.Transform()
+    mirror_ref.set_translation(sl.Translation(2.75,4.0,0))
+
+    # Open the camera
+    status = zed.open(init_params)
+    if status != sl.ERROR_CODE.SUCCESS:  # Ensure the camera has opened successfully
+        print('Camera Open : '+repr(status)+'. Exit program.')
+        exit()
+
+    # Create and set RuntimeParameters after opening the camera
+    image = sl.Mat()
+    depth = sl.Mat()
+    point_cloud = sl.Mat()
+    runtime_parameters = sl.RuntimeParameters()
+
+    # zed.set_camera_settings(sl.VIDEO_SETTINGS.EXPOSURE, -1)
+    zed.set_camera_settings(sl.VIDEO_SETTINGS.EXPOSURE, 30)
+    # zed.set_camera_settings(sl.VIDEO_SETTINGS.BRIGHTNESS, 50)
+    # zed.set_camera_settings(sl.VIDEO_SETTINGS.CONTRAST, 30)
+    # zed.set_camera_settings(sl.VIDEO_SETTINGS.HUE, 50)
+    # zed.set_camera_settings(sl.VIDEO_SETTINGS.SATURATION, 30)
+    zed.set_camera_settings(sl.VIDEO_SETTINGS.SHARPNESS, 60)
+
+    # calibration_params = zed.get_camera_information().camera_configuration.calibration_parameters
+    # print(calibration_params.left_cam.disto[0])
+
+    return zed, image, depth, point_cloud, runtime_parameters, frame_width, frame_length
+
diff --git a/zed_utils/segment_anything/__init__.py b/zed_utils/segment_anything/__init__.py
new file mode 100644
index 0000000..34383d8
--- /dev/null
+++ b/zed_utils/segment_anything/__init__.py
@@ -0,0 +1,15 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+from .build_sam import (
+    build_sam,
+    build_sam_vit_h,
+    build_sam_vit_l,
+    build_sam_vit_b,
+    sam_model_registry,
+)
+from .predictor import SamPredictor
+from .automatic_mask_generator import SamAutomaticMaskGenerator
[GIT binary patch data omitted for the zed_utils/segment_anything/__pycache__/ *.pyc files (__init__, automatic_mask_generator, build_sam, predictor; cpython-310 and cpython-39)]
zqd(DT+c+5|>M|1)H|eg-lzf4+>`FvBW$TA$OGB#@(Nm-4`nQ2jj85I?P`=_KwLuB} zLy14vpfqJ!Wu$?|;fnz390$s@;7?Jn?)B}CrDREB>{+z?;qGR?_tTn0{_gJV=KfP> zMCUCQ%}9z>w>0%d3U|IVH{_YK^+3u^hb0}0GWpxrDWylLi#jT|JLl=o7Xh*j{Q08c z(%)L9KHWE6mv%|i8$mLWe}qW&V3O9dWA@a**|zuwy`$_L^`Scy9!o8C_$K6u6eXwf z+(sDG$fUhW@dK(PZz66{r44U~9fJ5kttHJeUC#c#(dWMJx`m~sBVa*i!-6tVI&+DB zhQh8u$#nhqy3u-~pEaK#*~Pv7L^m6sS{I>ycy7q&=BN6L%0)@F{H={QHs08#{;6Hn ze}}`QPxPAp+f}>q8|z)Ys@j^L)2{iO_PhG>soaR6eGzOWFZxvWRcq6RcrS>vKFZxwuK!5y{$L( HRpWmF__!ky literal 0 HcmV?d00001 diff --git a/zed_utils/segment_anything/__pycache__/predictor.cpython-39.pyc b/zed_utils/segment_anything/__pycache__/predictor.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e67419529d2a76ed30d04657351b7134bf345b70 GIT binary patch literal 9946 zcmeHNTW=i4mG0YIIi!XXWlN3{H*zD#OATYi_IjOluo7cg32eqIX(4;oNuxPkLk`u` zJ*n!ENTLUejb!8{S}cMfFL40!E&en40sWY~L_Z)v;`oy9RCV`s59!#uNV32pGnkt0 z>Z((xPMynls@h##Y-;!zm;crK&&!(j8@efeOx*kne&RbQJk8U6t*6U>qi5jX@Xe9c zvvhi9`jt_&R~^-QHQZZX#jlSVy@sy6s(Dqf_E7U`4~$;Z({E{=`eV@3wmN!Z-f~AB zBdLFP91X+3^^?l&$=GL)bv(!8<(o_l57}#Bu){=jYv{2)^(=l`>5Z__he`f=1~Wa~ zGu}0Rqz!b>^sI-ip5-Yv>xcaCGYd>4E4RJed2$cwb7sGr@w&-I%nRkd2MeQ<1DLjyf1(P#*&my|5ucV zpY$gwodu_H|$3t--7HL!W}uI&|`j5UCUs)GEjcxxOFXw>cX#5b)M*HW&>1Ll#8#$Q5_(v6>sS z$=(^>V*x6{(Kr&_W=`QM-)}kdyRDRCHpp5lVTx`CxS`7=7#9z1hzW$ zq#`C`#*@0^41;0hI7vnJ=T%58S(MXsSkQ;4(P%i@aY(n5M#SAfVA7*xfiuA(N4CVu zSc3TjUPtfu^rYc9ecu(rao*G3e`RAg9I*|^hO_${-wN*se&~8)1M=Mn!ia5$;a#!O z?QRIRLo;>U;2_$Ca5qx%guFXG;0qvq7C%8_XbVlfsxRq>=ZhNB%_n2teDzw#Oe&JQ zq@|`qY`YO}pnIP*B>PlwLDSyEPy7r;YN_Uc;S?IRXB@THhZ=yOQu~TDNMCShO2@pQfBr*!Eg5?gpOAxqFc6r>ZA| z$xMn2NDkh$U)dkCek5Q)Ia_3ls?RX6Z`|3mC&M6mk*0-JTvz;!hx^R8@45a228(9@ z{H-h4M%me*F?lqe)^#gUefw-Fcadsq_heZT3xqSuf2M6wLq_9@F$=m ziHZ8de<&)JJ(`G!WJr9JdV`>_6TRErdi*gyh!s_j-5d7VV{%rU;s+-s>@0AK5c~oa zXHfJO-0|2yP)3$i-~{`-Ny8OR!~zj=ehF=o8tiy9iCEI?he0&lnS>LOoFUtk?X%8e z*B!G?g)bBFC#j%el7(R~o}^~nTQ1zF-h72x%HljVZ{#-rMcfK9NNuTUHb1nimVVZ3 z>23WP_{g(zyG|BJZHN+OI8lGR>!?tUwZEHtW87E-lPA#g-{&m$~ zMvDgZIch5Vme*En9#b4y@bJuONgHjtO{{6I!~bVCnj4f1TC~JDPs*m0xw}4Ru6JM; z7KI!EbxI9`NyCIc*hZ@?%6ys&D3DWYYOhIM?B14on>hqxEAvXy-AEnH86^ks?EcsH zZ|vWJx$h49yY{XtU||T9g%Ok|I03$jF5nae14jhj*5Fgf-lI;+6oomv=nYEjY%_=j zNFa-XE+IkbFV=3X+jlyt)MjHpCY6=O{g^@y|2&GV!z**v2?j_MnFG1DdE<_nl@Bm! 
z3wUpvKZTx;KZ5(X-C0abp9TC&s7cI`yRW>)BGF%>qD>DL^GVCdU6IpGn##guRTGXK z!J*HDNa_P1lnFvQf~CmpkVPF!`HQa-*;OjOOhwr#aB>CT#7~ev($2O%HjLIshWQkn zLi+=&Qa7slir%7ded&Z_DEO$`*E_YOH4d>FPCpE}2WHfr?`|_+B<6OwpEN0A0JK10 zNqaQ$Bj__K!bwCmiwYGwez-G?gz_Ra;wt@+_aguyy~rLsFDzLTOA)*XMR?D6+nUy< z^=V@YOZUu)&VK?lfPZlXegg;_C?JjV__vOYsU87E$5hr(1CQ~@oYrKWfs(w2s;kE} zpnH4~?`26?Ij(1Qgdq%##*Xn60cBMR`!sh;dLNdf(qn*`V-E9 z5&Z)jCZv4Fr#YjdLKR_&jS;uPfWh#u+FScqYzbD#_GOO@i>LlDXPdl|gvtOdE?kmo zi~x_YzxL+(TOGQPxCAhs8mmuHkw1|OSoIE2*Vc>`rTLk33iRwUmtFf=jEh7Gc?M7v zbznO@oFE}l3{q0z>>@;`+k~w%eyD+Q&$M%Sczf5-qGGl1ZckQbgO>Op&~Gr3ivBxI7a> z?jbS(NU`b4jM0CcYUPH#11KNCJzKdG%T$`DY%?mQT15Zli)Hfqu337 z%Dd!Z+A1Q=@g*X_FLF7s86aSOpk_Kt|8GjIC?)?VGb&XZewH}=Y$Hn)se~{;dt1%8G_PlSfcSkXITNyuM@@RsCb@=7pVA0Du}~gOAW5b^ja!drgset zm0q?<@(MDhBUYroCuiXPotY@X7#7@tkrmH+R6(2Obb+26)qpF3h4E_~o6}ZY6+etyvQP7_h5J^# z5H}@Gthh^n=($n7@fX9185a9?DR{cDsckpv?@C+74vVdJS7zPUUEB&NTnJEea{K=cTCeSYcGzmaPH;0tpip9L^%T(YEQTRc&LfD_38h=Zwinvi z(w8XE9!xk|WBp5;C9}{6oIeM|P9iE0NY6F?j*)YRNN()!82Qg*mElA1>$vCFP;@Fu zTOGsXr?S}2R5fFkTo7ZI3kH-&C#z1O&66HBqfdDt@g0@DP(g<>$x=okJrr#b zHgs^vPe-KuS!&YS!J$h^l_ksLP(P61@(rR!NI6-Oz4I)wvp5fT<^0NUm%mCq zm(vLqX3|z;&T9C#Xdps&QbcpOloJ&FC4TaVT)PZRYSi_nVVPBhIzMdIs=5U%DpBYk zn~fD9(w1@f)M@xrA!_6d@?4e9va%0d<1j>_Oxi&P+AE?=phCi#!p6?3I$uQ%_6e=S zU+2Cq&3_fBDxba2QpHm?1=aarPDYv7v&!O6B!#))Y!4|;AZR0qupcsEZ0Q35=5#0? z&01q6qM_>-dLcs6YXfWf_LIFy8$)V`=!NGxv*>rY`PmacIQtzI(8d%4Uo!@}TZfmg zr~CHm@PO?Pv0H*v56x6onV*;-nZ)^9L90n+-_w=el{fU0bEL>~M)f;17xD#`s#KGY zE%U5?XqQytly3cbji7s(chELGX~Z7AWuxQm4t0z{X9)Zj6=kMmE{UUXO^{A$#^Jfr z(CR?-#Avzxt)Y)cr*(9uuQ|wYP(uGunzBTWgX@Y%8fYFq4ZO~99!(2=ih6agZ*{6l zmL$erm9|No-Sq!{Qj-aQvzuMqKXpcQPE(~BNwMOTB>Ns-D9sHCX|^7)Gdh0mR4L`h zfeCVdlo#kdjsVF9-s5OG^n%R9r|XvE&^C&C!bukLHHyRt#z`$boli_0iSw_}GfLl4 zA38hS^^&SOtmEt%ZWMJh~cWYK0Oe}(SJnef|GS%|RT{TP1x>kkasJJduH2+T+9qX|RLPuiLsOUH=(Vx^ p?2B#=ZH0VqMv^bgQ7W)JHn0D)T9a?mWCgv*2G)(1(KcG None: + """ + Using a SAM model, generates masks for the entire image. + Generates a grid of point prompts over the image, then filters + low quality and duplicate masks. The default settings are chosen + for SAM with a ViT-H backbone. + + Arguments: + model (Sam): The SAM model to use for mask prediction. + points_per_side (int or None): The number of points to be sampled + along one side of the image. The total number of points is + points_per_side**2. If None, 'point_grids' must provide explicit + point sampling. + points_per_batch (int): Sets the number of points run simultaneously + by the model. Higher numbers may be faster but use more GPU memory. + pred_iou_thresh (float): A filtering threshold in [0,1], using the + model's predicted mask quality. + stability_score_thresh (float): A filtering threshold in [0,1], using + the stability of the mask under changes to the cutoff used to binarize + the model's mask predictions. + stability_score_offset (float): The amount to shift the cutoff when + calculated the stability score. + box_nms_thresh (float): The box IoU cutoff used by non-maximal + suppression to filter duplicate masks. + crop_n_layers (int): If >0, mask prediction will be run again on + crops of the image. Sets the number of layers to run, where each + layer has 2**i_layer number of image crops. + crop_nms_thresh (float): The box IoU cutoff used by non-maximal + suppression to filter duplicate masks between different crops. + crop_overlap_ratio (float): Sets the degree to which crops overlap. + In the first crop layer, crops will overlap by this fraction of + the image length. Later layers with more crops scale down this overlap. 
+ crop_n_points_downscale_factor (int): The number of points-per-side + sampled in layer n is scaled down by crop_n_points_downscale_factor**n. + point_grids (list(np.ndarray) or None): A list over explicit grids + of points used for sampling, normalized to [0,1]. The nth grid in the + list is used in the nth crop layer. Exclusive with points_per_side. + min_mask_region_area (int): If >0, postprocessing will be applied + to remove disconnected regions and holes in masks with area smaller + than min_mask_region_area. Requires opencv. + output_mode (str): The form masks are returned in. Can be 'binary_mask', + 'uncompressed_rle', or 'coco_rle'. 'coco_rle' requires pycocotools. + For large resolutions, 'binary_mask' may consume large amounts of + memory. + """ + + assert (points_per_side is None) != ( + point_grids is None + ), "Exactly one of points_per_side or point_grid must be provided." + if points_per_side is not None: + self.point_grids = build_all_layer_point_grids( + points_per_side, + crop_n_layers, + crop_n_points_downscale_factor, + ) + elif point_grids is not None: + self.point_grids = point_grids + else: + raise ValueError("Can't have both points_per_side and point_grid be None.") + + assert output_mode in [ + "binary_mask", + "uncompressed_rle", + "coco_rle", + ], f"Unknown output_mode {output_mode}." + if output_mode == "coco_rle": + from pycocotools import mask as mask_utils # type: ignore # noqa: F401 + + if min_mask_region_area > 0: + import cv2 # type: ignore # noqa: F401 + + self.predictor = SamPredictor(model) + self.points_per_batch = points_per_batch + self.pred_iou_thresh = pred_iou_thresh + self.stability_score_thresh = stability_score_thresh + self.stability_score_offset = stability_score_offset + self.box_nms_thresh = box_nms_thresh + self.crop_n_layers = crop_n_layers + self.crop_nms_thresh = crop_nms_thresh + self.crop_overlap_ratio = crop_overlap_ratio + self.crop_n_points_downscale_factor = crop_n_points_downscale_factor + self.min_mask_region_area = min_mask_region_area + self.output_mode = output_mode + + @torch.no_grad() + def generate(self, image: np.ndarray) -> List[Dict[str, Any]]: + """ + Generates masks for the given image. + + Arguments: + image (np.ndarray): The image to generate masks for, in HWC uint8 format. + + Returns: + list(dict(str, any)): A list over records for masks. Each record is + a dict containing the following keys: + segmentation (dict(str, any) or np.ndarray): The mask. If + output_mode='binary_mask', is an array of shape HW. Otherwise, + is a dictionary containing the RLE. + bbox (list(float)): The box around the mask, in XYWH format. + area (int): The area in pixels of the mask. + predicted_iou (float): The model's own prediction of the mask's + quality. This is filtered by the pred_iou_thresh parameter. + point_coords (list(list(float))): The point coordinates input + to the model to generate this mask. + stability_score (float): A measure of the mask's quality. This + is filtered on using the stability_score_thresh parameter. + crop_box (list(float)): The crop of the image used to generate + the mask, given in XYWH format. 
+ """ + + # Generate masks + mask_data = self._generate_masks(image) + + # Filter small disconnected regions and holes in masks + if self.min_mask_region_area > 0: + mask_data = self.postprocess_small_regions( + mask_data, + self.min_mask_region_area, + max(self.box_nms_thresh, self.crop_nms_thresh), + ) + + # Encode masks + if self.output_mode == "coco_rle": + mask_data["segmentations"] = [coco_encode_rle(rle) for rle in mask_data["rles"]] + elif self.output_mode == "binary_mask": + mask_data["segmentations"] = [rle_to_mask(rle) for rle in mask_data["rles"]] + else: + mask_data["segmentations"] = mask_data["rles"] + + # Write mask records + curr_anns = [] + for idx in range(len(mask_data["segmentations"])): + ann = { + "segmentation": mask_data["segmentations"][idx], + "area": area_from_rle(mask_data["rles"][idx]), + "bbox": box_xyxy_to_xywh(mask_data["boxes"][idx]).tolist(), + "predicted_iou": mask_data["iou_preds"][idx].item(), + "point_coords": [mask_data["points"][idx].tolist()], + "stability_score": mask_data["stability_score"][idx].item(), + "crop_box": box_xyxy_to_xywh(mask_data["crop_boxes"][idx]).tolist(), + } + curr_anns.append(ann) + + return curr_anns + + def _generate_masks(self, image: np.ndarray) -> MaskData: + orig_size = image.shape[:2] + crop_boxes, layer_idxs = generate_crop_boxes( + orig_size, self.crop_n_layers, self.crop_overlap_ratio + ) + + # Iterate over image crops + data = MaskData() + for crop_box, layer_idx in zip(crop_boxes, layer_idxs): + crop_data = self._process_crop(image, crop_box, layer_idx, orig_size) + data.cat(crop_data) + + # Remove duplicate masks between crops + if len(crop_boxes) > 1: + # Prefer masks from smaller crops + scores = 1 / box_area(data["crop_boxes"]) + scores = scores.to(data["boxes"].device) + keep_by_nms = batched_nms( + data["boxes"].float(), + scores, + torch.zeros_like(data["boxes"][:, 0]), # categories + iou_threshold=self.crop_nms_thresh, + ) + data.filter(keep_by_nms) + + data.to_numpy() + return data + + def _process_crop( + self, + image: np.ndarray, + crop_box: List[int], + crop_layer_idx: int, + orig_size: Tuple[int, ...], + ) -> MaskData: + # Crop the image and calculate embeddings + x0, y0, x1, y1 = crop_box + cropped_im = image[y0:y1, x0:x1, :] + cropped_im_size = cropped_im.shape[:2] + self.predictor.set_image(cropped_im) + + # Get points for this crop + points_scale = np.array(cropped_im_size)[None, ::-1] + points_for_image = self.point_grids[crop_layer_idx] * points_scale + + # Generate masks for this crop in batches + data = MaskData() + for (points,) in batch_iterator(self.points_per_batch, points_for_image): + batch_data = self._process_batch(points, cropped_im_size, crop_box, orig_size) + data.cat(batch_data) + del batch_data + self.predictor.reset_image() + + # Remove duplicates within this crop. 
+ keep_by_nms = batched_nms( + data["boxes"].float(), + data["iou_preds"], + torch.zeros_like(data["boxes"][:, 0]), # categories + iou_threshold=self.box_nms_thresh, + ) + data.filter(keep_by_nms) + + # Return to the original image frame + data["boxes"] = uncrop_boxes_xyxy(data["boxes"], crop_box) + data["points"] = uncrop_points(data["points"], crop_box) + data["crop_boxes"] = torch.tensor([crop_box for _ in range(len(data["rles"]))]) + + return data + + def _process_batch( + self, + points: np.ndarray, + im_size: Tuple[int, ...], + crop_box: List[int], + orig_size: Tuple[int, ...], + ) -> MaskData: + orig_h, orig_w = orig_size + + # Run model on this batch + transformed_points = self.predictor.transform.apply_coords(points, im_size) + in_points = torch.as_tensor(transformed_points, device=self.predictor.device) + in_labels = torch.ones(in_points.shape[0], dtype=torch.int, device=in_points.device) + masks, iou_preds, _ = self.predictor.predict_torch( + in_points[:, None, :], + in_labels[:, None], + multimask_output=True, + return_logits=True, + ) + + # Serialize predictions and store in MaskData + data = MaskData( + masks=masks.flatten(0, 1), + iou_preds=iou_preds.flatten(0, 1), + points=torch.as_tensor(points.repeat(masks.shape[1], axis=0)), + ) + del masks + + # Filter by predicted IoU + if self.pred_iou_thresh > 0.0: + keep_mask = data["iou_preds"] > self.pred_iou_thresh + data.filter(keep_mask) + + # Calculate stability score + data["stability_score"] = calculate_stability_score( + data["masks"], self.predictor.model.mask_threshold, self.stability_score_offset + ) + if self.stability_score_thresh > 0.0: + keep_mask = data["stability_score"] >= self.stability_score_thresh + data.filter(keep_mask) + + # Threshold masks and calculate boxes + data["masks"] = data["masks"] > self.predictor.model.mask_threshold + data["boxes"] = batched_mask_to_box(data["masks"]) + + # Filter boxes that touch crop boundaries + keep_mask = ~is_box_near_crop_edge(data["boxes"], crop_box, [0, 0, orig_w, orig_h]) + if not torch.all(keep_mask): + data.filter(keep_mask) + + # Compress to RLE + data["masks"] = uncrop_masks(data["masks"], crop_box, orig_h, orig_w) + data["rles"] = mask_to_rle_pytorch(data["masks"]) + del data["masks"] + + return data + + @staticmethod + def postprocess_small_regions( + mask_data: MaskData, min_area: int, nms_thresh: float + ) -> MaskData: + """ + Removes small disconnected regions and holes in masks, then reruns + box NMS to remove any new duplicates. + + Edits mask_data in place. + + Requires open-cv as a dependency. 
+ """ + if len(mask_data["rles"]) == 0: + return mask_data + + # Filter small disconnected regions and holes + new_masks = [] + scores = [] + for rle in mask_data["rles"]: + mask = rle_to_mask(rle) + + mask, changed = remove_small_regions(mask, min_area, mode="holes") + unchanged = not changed + mask, changed = remove_small_regions(mask, min_area, mode="islands") + unchanged = unchanged and not changed + + new_masks.append(torch.as_tensor(mask).unsqueeze(0)) + # Give score=0 to changed masks and score=1 to unchanged masks + # so NMS will prefer ones that didn't need postprocessing + scores.append(float(unchanged)) + + # Recalculate boxes and remove any new duplicates + masks = torch.cat(new_masks, dim=0) + boxes = batched_mask_to_box(masks) + keep_by_nms = batched_nms( + boxes.float(), + torch.as_tensor(scores), + torch.zeros_like(boxes[:, 0]), # categories + iou_threshold=nms_thresh, + ) + + # Only recalculate RLEs for masks that have changed + for i_mask in keep_by_nms: + if scores[i_mask] == 0.0: + mask_torch = masks[i_mask].unsqueeze(0) + mask_data["rles"][i_mask] = mask_to_rle_pytorch(mask_torch)[0] + mask_data["boxes"][i_mask] = boxes[i_mask] # update res directly + mask_data.filter(keep_by_nms) + + return mask_data diff --git a/zed_utils/segment_anything/build_sam.py b/zed_utils/segment_anything/build_sam.py new file mode 100644 index 0000000..37cd245 --- /dev/null +++ b/zed_utils/segment_anything/build_sam.py @@ -0,0 +1,107 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import torch + +from functools import partial + +from .modeling import ImageEncoderViT, MaskDecoder, PromptEncoder, Sam, TwoWayTransformer + + +def build_sam_vit_h(checkpoint=None): + return _build_sam( + encoder_embed_dim=1280, + encoder_depth=32, + encoder_num_heads=16, + encoder_global_attn_indexes=[7, 15, 23, 31], + checkpoint=checkpoint, + ) + + +build_sam = build_sam_vit_h + + +def build_sam_vit_l(checkpoint=None): + return _build_sam( + encoder_embed_dim=1024, + encoder_depth=24, + encoder_num_heads=16, + encoder_global_attn_indexes=[5, 11, 17, 23], + checkpoint=checkpoint, + ) + + +def build_sam_vit_b(checkpoint=None): + return _build_sam( + encoder_embed_dim=768, + encoder_depth=12, + encoder_num_heads=12, + encoder_global_attn_indexes=[2, 5, 8, 11], + checkpoint=checkpoint, + ) + + +sam_model_registry = { + "default": build_sam_vit_h, + "vit_h": build_sam_vit_h, + "vit_l": build_sam_vit_l, + "vit_b": build_sam_vit_b, +} + + +def _build_sam( + encoder_embed_dim, + encoder_depth, + encoder_num_heads, + encoder_global_attn_indexes, + checkpoint=None, +): + prompt_embed_dim = 256 + image_size = 1024 + vit_patch_size = 16 + image_embedding_size = image_size // vit_patch_size + sam = Sam( + image_encoder=ImageEncoderViT( + depth=encoder_depth, + embed_dim=encoder_embed_dim, + img_size=image_size, + mlp_ratio=4, + norm_layer=partial(torch.nn.LayerNorm, eps=1e-6), + num_heads=encoder_num_heads, + patch_size=vit_patch_size, + qkv_bias=True, + use_rel_pos=True, + global_attn_indexes=encoder_global_attn_indexes, + window_size=14, + out_chans=prompt_embed_dim, + ), + prompt_encoder=PromptEncoder( + embed_dim=prompt_embed_dim, + image_embedding_size=(image_embedding_size, image_embedding_size), + input_image_size=(image_size, image_size), + mask_in_chans=16, + ), + mask_decoder=MaskDecoder( + num_multimask_outputs=3, + transformer=TwoWayTransformer( + depth=2, 
+ embedding_dim=prompt_embed_dim, + mlp_dim=2048, + num_heads=8, + ), + transformer_dim=prompt_embed_dim, + iou_head_depth=3, + iou_head_hidden_dim=256, + ), + pixel_mean=[123.675, 116.28, 103.53], + pixel_std=[58.395, 57.12, 57.375], + ) + sam.eval() + if checkpoint is not None: + with open(checkpoint, "rb") as f: + state_dict = torch.load(f) + sam.load_state_dict(state_dict) + return sam diff --git a/zed_utils/segment_anything/modeling/__init__.py b/zed_utils/segment_anything/modeling/__init__.py new file mode 100644 index 0000000..38e9062 --- /dev/null +++ b/zed_utils/segment_anything/modeling/__init__.py @@ -0,0 +1,11 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +from .sam import Sam +from .image_encoder import ImageEncoderViT +from .mask_decoder import MaskDecoder +from .prompt_encoder import PromptEncoder +from .transformer import TwoWayTransformer diff --git a/zed_utils/segment_anything/modeling/__pycache__/__init__.cpython-310.pyc b/zed_utils/segment_anything/modeling/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4aae6f0410a999bf21e3e1408ed49cfe6bbd23f3 GIT binary patch literal 394 zcmYk2OHRWu5QgofNnZ(%V`Kro076A73l>G7R3K$zS=giz9=o!GqAWQMSIU+ZS760B zMZj2|`RCK<_rvD-f?z~v_wtPr@)4T<5n;2!?L(48QcZwXlm;E?2;wRpb*v|lRLQ6( zI)!O99d)8-Ft6sQ(~=a^PdtMZbcjm-HHBf&Pma~AgP z&j^CkZD$`vzjea+XX`+@+hQ@qUVsaC9r7kva&_=@abOvsCzl6PHwoxyKl=X8B>^tm zKUbF`c_%KzGN!K;+gVUd80l2U*0wX+3h7ySwP9Y>piIk!>DygnY6kcs9R%l%XVuJW%vajK0Zq`8rBK@2D?vbq5uE@ literal 0 HcmV?d00001 diff --git a/zed_utils/segment_anything/modeling/__pycache__/__init__.cpython-39.pyc b/zed_utils/segment_anything/modeling/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c47a9c76816c0790e157fc6ccdc58453799c125d GIT binary patch literal 405 zcmYk1O-{ow5QXieY5J3ZI7S!L7eEN4vS1NJ)d-|)EQ3uO)ni9?P?aUe;Y!)E;tH%7 zrwAC!GjBeP-q<3~X9OcUU&}X2$VX`YM}*B4ULTSql4=69qBQ78M-W%>pkqCPq)G-o z(ql-gbkK>Oz_gm89&gAZ{lpz)LHl?k0L$quh^@LY&sHk;P?w12h4A~88VI)ey|d7? zKO+dv%7c9rUFn4JFV=x_cZ*pcdjT%Iwa@Ed$<@)*#erpjo?ISH-6)`wX7K%;O9Gs? 
zf39wf~3c~TYFX^ z(e{*}p87u|$K1H^FTnrcD+iPVXCx?y_dMCnR(SN@+Rxr+`}y}?(Q4HRH1^r|;#Z%L z-*B+lI1Ii5G2ej_M9_@%s$ZJ39;4W2nV0h(r{pmap5W(1aLGpgnb!-z;)?(*!Hp#p z^nkR(Yw$+8ZEE@9$yla1T5PqOf!T+SBt0rfk4bOD`&3Zr3kKrJ04(7K>?Ztv^pe0_ zAj0!jFA|Zgjp}d;(0wjyqJGYMjhBRzUK4h2*l2He|6x|7&umNPPo)q;brg$XZo@nq zyN;Knr46(!Ct7t=jKZv8AH@SP+n`EvMlR?jp`>IN-jq(+C7Y6%yAhjGH}a?683Ah` z$N{;c-F9ewGZ{;5qc|R_VHwAkDP@D*p^}NVJR7Rdo!-^~Y|`eI8=3VrZmr>^h%;YN zTZ`i~ON@!*EArP)XHev_lPIC(ap#-jSY<^bOy}V1y^fJbxm0DGsFQL47w+VRkg&~B zCoS^4P+Q{@jW8^MmyUB^v-9qQv;`dj7{Cx4}FF)z_(p7|Wjy6Qe(bc>u+6@d0{4gLWZzgt)jP ztq(vdTIn);b$Yt85P&(b4z4$l#YDo(K0*fy)Rc~(z6mD0SL_8qJ7wztu;7=Ac9-Al zRi=AWGU99GlHMaP0Q8X$wUGSCQ1`tJ@V5dfFoKLJ!Pz1wgQjVeg zTqBWC8F(%eWxa86d=n@0T$fb@K7tkc6I{k`lg84CrDg$57agScV-Vw#ny$n5DdqI^ z?qYmq8ASC1bXwP#ym}24!mQr~W&J)>k+OP7d^r6cD)g^6QQhY3ZBD!o6DS6&S48g9 z%|PS09aVRVT&is;RiP1-Cu1l@E+_RpSm)Bk{bRxX{ON7u%JaE87<20WiGAG*ptkp+dW(E zQAjjatz+Rb{PKli?7b{h)pe z8T*|!^F!k1OC3QvMwgpyE2D2c8m zvEn@B=xlJyqx)Wjv*TTi~u%l<*PEnQ0&yIQHC*-O>1aB-23{UXTv zm5V@E6EJ;x|7_5jC_u8C$SQWi&iMu7tPe9{anaV$=nS09hvsw7H-X}SH*NY;>&NE3~Q7%p6tw`1npy zDvf9Eb^2vdlxAam1cZq@NH#&nq9(tC^Z3nK^fqP}zZ8!D2r=691zBDXT9fb*l6@Q5 zl${fCBnpe4I}FHsI{-S#zX@6-5by>G-XgYq!|vkB#e)(KGf?ZevHnDUymWJ`Yh}Ad zku_PNF|y4x#rI8eS}&&Q^GTN18dxBQ^6|G@&s;Dx6_7g@ZF^%M02fvz^!Hu#R2y3Y z&_i77I_7~3za`WHT@nZ$^bkv;B$WROA-LcP7ym3Rc;-~*6k<1IKMCT4U>*IfAnZq2`Sa3MvZXE|`V`w> ze4zChyU-P(5Ick{bY@&IE)TEcWS>J-N8n?$fKSLw-WH1E6UXfw9DQ^Mv1D$2kH->T z!}%#czB?bES=Dg(fSfjA6+f^>jLHn|A#>pXyGd0ik`f2+5hMTbKC#<6y{&`$s34`K zI46890+O_o`bCK^t_@>qU_{l?7<-VfLU;%5e4+S1=A7Ro-X#y}vVKT?#%Dj_i()ZY GUj7e1GITTm literal 0 HcmV?d00001 diff --git a/zed_utils/segment_anything/modeling/__pycache__/image_encoder.cpython-310.pyc b/zed_utils/segment_anything/modeling/__pycache__/image_encoder.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9458d07584177e0169f225f28c1d7ae8fc42dc17 GIT binary patch literal 12630 zcmeHN%X1vZd7tTd?Ck7f@uI+oC~73q7FLD`k(6bLiENP+NlBq#T7*nzE!AqUGmBkd zcXrV;3y{d*63S67yyTh-HiuLx$5bw%=W`~`DzRjQKU65H`3mMn|;ecdzr0wg(3 z#i>*>Tk~yCPj^p0zQ^zTx;s--wt|a2_Z@HbB}Ms1dKkS*cz6|ecw1EzPw~{2Qjz~^ zn^l-9(^{J`MOE5*#Sl5HWwvvbT-&Nx?R+JVJk8TvcDqn1w2PIZs@zmO!!tinJkw_z zQxCLC2{}2>LXMT@Od}`n*~qceoEfj+D;u+3(Vz3Ce2dbcbN*bNom9kadL?iA1HCfu z%?wJ8l=0?<<;ue#txIJ&F1Y7>naQg|R%9j^((g-cia=AJAN{l(oeAb~kEv0=JbItG#Z^PxRIM zUH<@+u9R6)xaHpW`AUbkFM5gf)~(xDTbBod) z%4w2)Z9~Uxcortu@@%hApTVS>m=a7N?rGdJxMy+C`8jVM|Hpg_QLVqw zn(g&!*xdCKyX!`^hD=+{pjvCVL73$I_L}ciy=FTxJii+?l6=r>R~x?T;bptkt#TI& zk>oqQD6L{`ZEjcBnr@ib0al>eqP0l!ZY@fagC(B=h(qVWepY{{-BTaxc%yp^)QF4x*~r|_Mb>`a(@=+XxDZdprFb@;*;e^$@pPbi`e<8q<{1y| zy}4M~DD0QwxxT6@TbI!8bYyPKpdAam2rhi6N3*f@Koz+ba%F!0p&|48)blxyfh%m+ z02A0wdR^h`;pM>-WoAq6EH;CveA#&;==LJyf|DF#h4W*jhis9RzfFvCM9GY4vT`du z@D6u6b*DM<1>&o~Z-wJu8L(MaS$;%hd(B|o0SEd)2$3JJCwMZe^BOTIeGif_UTna# zSTYHX|L^ZQHArVq01wt5o_ zH54{yF$m7T)$w{QAMKJFbX(0`>Sf>vX&mp#Kr+YcU8zOQ?a!?=BEyGAfR_#gywmDM zz5^+5(x=_la5^SdpBhU=|8Qe89e9bJY=2H;G&xZNpByY@+wG2z@sMIU{2kIcSbg-B zX3IbFn_ZW?Z9nq)!~jN=RK{UzlE7fOgTA;!QEip zPmHT1#EI#4yMEv$_8lK`1c3*1Ft2rj?Ta2aMc)Izwuv6DZgH|F!@1(e=y5nb z$LVomjvlAH@rmM3A}>*!i6Q#+k(%VI)mqC9LyUeJ?>@P{+~~CZWjFA+zq9;WXD5JS z@xtXhSKe9<{dMTQsOkpyqXy=2xeZI-qJ+@bRbQI^^WFP=5pB`Jh7}O|E2herseG;RWcS1b<~X0NWuo4z zF?>6_$n;vcui*}FfyBx_OlV)-gYnz|*y^c#nvC&=dR!5C3kdc|{b2sP4Gj+rUFe_`z1ikJ;)1)7kFv$-$@P%d^DO0hB|vn%LDUmI>My>0-6I z)pJ|&4PT;fMm!B8et1yHe&acqJ zRLV;=ztu|lNsi|Cs1gl{(-IIThcqTjVcKxjWeS`zSaYoMj~HX0q13N2-cOmqe#K1o z3CppMnZJKV>EMZ)}?Vgw3!Y+q3LPY_9aLXTB%CTf-;E zE|M3z!aoBULhg!2>+9| zZpJVAQ$A`Dphl@#w6Y)|^_({kNL>OjpWc`OR0jp5e#~0{q@ER!dJbH5oL@r6cJHi0 
zd4yQoHULbjYJ?C>0!N{F7I^Rb!k3j1AGcCr^tJSpT)z7 zGMUgL6Ds5!6?Q;s1(58bG%YUiM_30Lfw%@q5FjDJT!2;p$y{v3c~9MBd?B_24d65n z@W~#UD5b|)c|h7&jdG}?9aGSP>1hu<5k3yr+}oxlMd>01^)aOaebCwjAOti2;vB zfBz59{yX5=|5ZdqNV*w0u=8#!be>-#Y#0pKHS>!IqkcYe5`GeSEdt;;ixCAGh^LnT z5u>sLK)2Ef2wqF?335+pmk>6Ay_(yq^$4;8o)I$5@av&ojDX_-ww!RSBsw#svJiS= zs2xZ#AR00?@H2#`o+Bp(LYQcP_A6zZKZ9bzf+Plc2^W(5l?*B;Iyuua|02&(!J8n7 zhBgGyD;K26{Ld>aVS)(ZNIOLMWEXcxXwp7ykJuch&Z*WXcK(EFv)z&aYy(swZ1otf zAh<^NHEQfnKw=e`VH{hin>v3oRyUyGPU1cOwUY{Cd)l5ZG$Js^1NJW4Q}+zigiAg^ zB~UsLhV47H)u5DY2{56mySFL)T8ej#%7Ji+` z4I;!K{BQFg-q@a?vTD}p|?OBz$8OdqW^pn zlUTqc)&wRg0h6#|>1St_gb}Zyv9XI>4Z6Y>$SC&6v%HqqaU0$oBCI-4hYr+X0CgBZ z9R^T`2~45;-~@pxik?v~2~;uVm4GTNV2}KU4NL+GQ~`VeR8bJ9q6p5I=ik96cVBs{ z*NU3YQJldU2%v-)KodY0nVJ~*%%dLj0onT18>7(S2yGx?0?Bc9Q8jwjaEE^dG6EAQ z(gn7rsMcOC&UxxQMYFUBKn%G6-N%$zOCf+6>H8wLPqlJ@6a)*D4U69aOwfC3c;QWi zsUN6+tfFK-viA#~i8r~E$}xngPbhJ&&qUm;IQHH;^nQw9hJYG*_Y`<)pf!J|P-Eyu zlQB3HuK6#=YyMZ8=b`ieKWZD-!OpmT`1dmCpo%cV%I;NZlqs0$w0mLXtRZs5JKKnT zcuupP#!U`M8%3zdpN7>w&q;wx50mtHximDQ!V)8r0VD5AkIuZ+3=pnMG;m3x<4YSd z2=E-WK~5WA21)Yil2;K$&yqVZ5-(ASvqVI6;RvZtEK(X|)AnD5dISgY^UYGs>PV)XG~C1<>g|Jr|@UPVKRYGWJFhUZeF;rF5L1QTC_g*QiX|BkEJyqpZ(( zL)KhDxE&)&qY3#H!9fR)a7`@z0ok{M&+wwGid8C4^LLQNs~|}Z$d3%WT)I~&%j^Yf z2>$W+=pi4GCkdG63zkS05D%+*ey!7{JTD7-aub5`=BSuz40Lc3O5O+ustdhFKzRDJ0V$Y~Wb`jhH zBrrSVjSLMK{9Tk@y|KDdo|*s#{QGz#QGfsdzopqHoA|f5Ly{0-_Khqwp?^GS@F@sU z5(dBY%QNSMJr&4MfGcVAs2kZue4NKE`1o&-00)ioG5p(DT~J_@K*6gj%-Qfc2eSnn z7(PiR5qR)x5pg_toqm)Tn{~6#_K2&^WALTvrw}Camm&D2Nwd>fJ0&cL}chuDUA@V8lBKaR3C+Gyvb+&!pLRgcW z3<)1Y2_w>&D`~pG_o4HoDI2$*_n^2%APGK<^v$ErCNg9b-@m%#yuRezU2?9yH#CGP zNELcDb?{|Y>aqw!{kGVR8`*P8`*wNINw9SQ`b#h23J!U3Cj812*I z>s|D-i&4<9q2C*RbG;Ec6ph(wdO%fcK6n}iAe%wa%b{wQ5k#TxeK5*_fbSksfZZhL ziFzZ^&^9O7$X}%#yVgMjdA--^g%C}WO38+8N{tBTwTdCox%}Hn*z4Vd-R&m1bc~aN z`d!rjDejQsxTV4jn*n4$0q5v6JEs;^PVf1BA{66ISYxL~$|<{SBna&ZrvD2hVsdeG zsBI1ZCJ-Kkz*FH6A>ad*D_HcD0_+DIx37ZnAOv7rpgh66U{$a!0!5_Ag>n}=;$T}1 zY%9wfLY?MO(}cjhg1R{f%(oG0!XD)hpq^t>&>T_9sp}`f@_mKc1&in1KhIeFuG}#h zae*YScX)HXN!v%E2;JBGP%JL@+py^1{gmC0V)AsJzl7C)JI0f3r||&_3S{I z8r(&>0vIGD=glfv>M8ZKdQP2zJm@N?w~4tW)*vxAo?jabL6)jC{kv8TJJhhTAtXt&@d z;Rn$sw2(jSLkmt%w=l=8i(Hc|3M>o;Lgr!z+u0g&xLmTyV$|};wtKpskf6+79qQtU zpX@B&g3&5}o&a3G2S8Sf!1*Jgw?X&}nq_o44t*ybtK@%SC@|9sC+SE`Y|Ft3!0{uU z#+&H0RkxCQC_nxdq~RtN(*W(Hw1o6nfadUdw}W7z&#ADmFzHSYb~|YcQ*U*Gb@&N? zM6Y8admxn}S`zrt!?tH-I-Mru(C591wNsw(g;h%)3M#x_GZKr z!&3oIvoK~jz9Y&4OZAO-1`?}5`i`J=!vH+$4OnI}tREt01a=Z`7)4)7eLO?PSI zQOL)&H0tX+LxAtRogH1cJv;|-op!oO=0c~=J8kE3mpq8e@3*&`(e}mnrF=A^s2g5h zUS3CJsJC_=9>a3ot@+p=+{6-t(`t=ns<4dhrM2Y?FMs2u7hk+|!M}9rd9VI*{nCpU zFZy49;rWYeFSsvUc-g)9;wATD?WN~s6=Ka+xQx)RJo|$EV+i+Iz{)Mxx%%?gpBLj9 z7ikm@Mr3eM*BLu(u{Ba86?a(iiQ^gK5Q>C-!g+EQ0XMg*4H)9B>W&m&QZ{sDY~2`X zIOF1h43EoygIW-A_{><|YmOYW$XX=p`WtE#(Wmqn2SG}Atiep#Wfnff5eA$u7}!>! 
zCTYtsn-F1Y7cpQO^$v8E9IE4F`tVCz5-L*8;Nm8=vnkpU4IKjDU$keCtY@e^w5F@AdMfOHJy1_ z)#g<8nXYpxRl_}Y(T*)$VFWS}F_tmoG-B)snlOO`IkwYww z`yOqx%%}8QE^5C_rKYBD6b!IL;<_a^Ixk(7K8ZF&Q)y1G`Ci zBAP%hlvLOol%g?>fWm>gnT n9$|A=+~B6q&YWaVQNE5D(J>=BX2iy^A{!AsJ!gNs@XUVz9@b}V literal 0 HcmV?d00001 diff --git a/zed_utils/segment_anything/modeling/__pycache__/image_encoder.cpython-39.pyc b/zed_utils/segment_anything/modeling/__pycache__/image_encoder.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0067ee46690edbb9a4ef41a6999365ccb9f955ff GIT binary patch literal 12448 zcmeHNTW=)Cb?)2r^jt^|xhw9)iqf{$5=Ws(i7U&}CYJZEmAzu*T}o?=rHq@?J)9wD zdWO|KupU_RHbcHGF+my%yzbtZCe$qovY+frn^Qf-!4=N?P8^tpT+`FG zXCCU6S(Id53nf-kGKZ3!n@34LDJgl%_Pn>?S^RGywNTg2DRjGq$GTf|XWlm|r-rqX z=cu2i(B7SWoN?#e5^A2Na;{C~^N$U8!99iYC(^Rv|EPyiirSohoOPdY7tv-BqdnuE z#c0naqdnuQca`$FBbv?ivL0K%-VK`_-)Y5}jb69q#m2^iu6KlKRmxgixa~af#Ck`x z*WB28^Y&X;Tbm=Blnn6O1Xvq|!8c`nG{P<64+P%gwum zddV}#ep#d~FNZq$bIS1FLYajp1yWV{-h)v;zJ>l2p`70%%wVTCF6A}Ck zNT@ti4%GoBAW|cxs)c%_MOqyQqN3Oc&BJVH9p+peZM5^$lBh8- z2GmCe&+`KXsX|zcG7r`3%6r%Q^WjWnL}rwYa#21iL^IKBG#{1rRPlS!oUggYcwcqR z%{> zZOE}{)zR3%NuE!-!e*AnSZ&Y2_n{AalhndwPeX>cm+GyKL-Tis_@7>jXLh6Mx}HzN zwZo1*G5*PpQW>$AHandbEkoVj5It|%c3jtX?3U*U-`Q+=Hc1k7w%hZ>gJpZi1KD$0 zy{Uy7iJHCS`zvpD++NE=zoh1zR`Y;H>Dw%elOq|*G>t|Oeb7kXmq1IAQT%3=3dYq6Nv^^UM*m9a50b4~Sgf_Af)TFpR;DQi3K zOt!Uq0f~leH0pWugH2NJFlR3?m`@LD##FweeAXXKGamuh2|J?Jh%>Zmurrd|Z)2unQ%+8t z5sts*#hI%l#IfmgyPogH`MVzE2m%jiU|#F^duy&RdF;Md+rbEpz-!fcSc2@yXs*O* z%8sVz3}vV0=u1>KIZ@)5Q5LJsIK$)mP>pldYOUo20lt0?&py7h+UT^sRmXRQx4(M5 zv+qN(xWVe(D{rp)ozR1R*a=owR#pRV3raAoI{t&Ofk|C$!|=B#!OFYpNhGk+eIS<5 zF|Be?0a3o1S4~Y*P5d-f`>B;RRYNn>#ZOHmuNJ9n>Nl&-{)<_hMI9r3@>9b^t2C!- z|7K*(e>gZh1wmHQh5b;A)%(K6%WE19A-TvL6Ux1YJGc!JDTgqt1GNu>x($HUQ^iwc zptsdC3YTdxs>ME;Y4ss$?Os3&nBzm3YYD{27^5cPzsx$bGIJnv_7T0ioDt9APy7mz z7m0iwB+h~n+fYzoRbX*Nd=sRiZ?<-LRV4WaSA_WPJHp*YCLm0qsCoKnbT4WLXQt;M zsZZdlTJ@nYs?|7OtzxyHk&!M|tGhj?C7+08dPi=ENMqym8@Jz%b=c-O!zKn|1q&fy z)Rowjj4o&%d6C5m7jM8gc7(V}nVE#YYF?|A@R31 zIgj!J;Bxw(Ls`}XB&He!m*Iy17(Fg9U&xp zBQ9bH2X{B17y_(O8{i~Sr29<(2dJiQM}%a_k11;iCk#T+bg^!BS9H40mNbpB12Kc1 z6I{JsE{La*9a9xkT4G8;T%gJ}NUZ>pUX-R~qLmKnAY%YFL$UzCM<5oE6fiLxSy9ea zcQmma<$V*-GY4p?Ju*?th|>B1tdSaK(MCV5pa;{{AL*!pwxBuJAQ^I+c;yfR9*2(v zZ4%@-LQxXoos8}NAAJ3HfUp0nc!_XwGqhpFomODKuuRC$AF^xe^$-I6Y+S@%5cwnw zU+krjLIlLq%K(63*#=x&@Aw3srEdhuCoD_&nV?+FY1MiJNCCbGcc$p{*eJ%J?+`go zIZhILWu&rTJu%V_q!GHz?1g8^1%VL8I)M9nIWL|?HMRwDhJ1h*qx%}};CDeH71nfvqMmvZ$6c`ai8%xDRIJOs8Pvl%N?E*NNN!iL=$J z+o=H{5C9Q3iM&pP7(~25zWC=6x)@=KSylU|(lQ(D8k$X92_7*z*~N4msyMcx0#rf1%J}$bqk zxPJ2IGDM(qP+$d`ZBP1I=Iy4B;99JMOJW0W+LkW+^VA0^DX|I?=aMC_BI2DU_o4Ks z?Bxj(9b2R_$VQ73tmL1e^MH)JatdxayeK%4>Z}ImTr)p4H0=|^u+@Wwp~CqbTbcR@ zY%Bi?J;D|$?;3rLEFs%IsDy!fNZvCn9fX?f9z3r>2L3Z`VyNVG4KjU*{h@}ZP!!wn z4U*d001Cu~fBmlB&-Sf;4t|ERZQ{9gn8ntIHqc5+KhG-V4?+re({sufq-9O)tBE@K z?IJur>yb+7C|6PrXXM+cPBJ1IQ!=7-%y>f9SwV0dUy}Iwxpn5Cqfc>7E&UPMtD{fw zpscdqEzgN}P$V3XI1404hFZd{D^z6m{7r=FM3pkRkbFbHHjh~%UPSb&>UyedDrnfIPQI=vV3Mjl@%_Vji8G7*^W!b8; zLH{|;I9bF$#U0onFyk5Qq{!b#BLl9fFx>xIx*(`F0xL8H43MOeqgSUF@oo+`^X{J@ z0j?S6UHGVxx~RYifr39(n6FVj3o`{Q7raI$5SZ?rkhmQDO+PA&%(^+y`oziR8En}^ z*)xO?iH9kEq`_ZSc|#|c{AUIxY6C5zhjfF_9GHih0qVQtGXpPz7m0h}FW0ek1f*a& zZ>x#PLu^uFKJpo#SI}oer#HOjRwJ}2rn293fulA(@G+Er zI)gmQkvf<0yvyJbK_=vgA$T1rKxdNeSiKqR=v(+yg{3-Q>mbIv)$8;Eh~tP2J5*y# zAg@(2jIib3eyrW<#@fAZoK3!QS~%~a{XgOk7C@B3tX5L#cUE1{o>VWWf}SJZ_?!5X zeE%*|u?7ThBT*yelwJg53F)x_{1}OdTozK>I{rJR{S2$q99o(Xo&&VaLU{fNp(E@l{tnc0 zWHQYXrINaF4(vZrs9&&q&iTWX-S5f$jWL%-@_k1%x0d}`oxST<$qk?1wXS`8=R7^^Rl%$7$S 
z&$}Zl`IF`H{v~^Cejz!dIc5LjxDJp)L$V=vc8JNOg(S%lA;@|e>XI@;X%Xa}A&@Z1 zwIZTQ4~URsKav{3Lc}t~t{xLMUEf8$0vIHqT^(hX+`02AB_u{7ggk@*0|hB%E8X@D&N zJ+d4CMn6QTlN9?vhj#?Icm#kHF}R8Jfrh70`7jrTkG2RRWDg-@0~HX5B!R=tcrV`k zCHO`3?tX^r0FqE0szr{b9NVyt3HdU~fYC@NpCPP3{#f=fH$WhpCF|W(vR?PXp(MF?1XtjLl5N?hx%9vlwFYm>E9Vcbqgha$(ZoF53{q!g<@--W@Tz zT({|J)brS0d9tVwq|}5R8{(L=Y%kr09V>s99Nf4MTviJq0AmrgVN?tXW_&u1y(jrr zDF$piFw=@4A&E_FpTQTvnIMC{H#TTXZ9Q>R{{EW~3Ia}umpS*P(DW!H`E%3Lx{D52N86>GOz>tJbv zinWubs1omfCk0bBDK2bzVTw#Iko29$iPEDq7p<{hd7=< z>zQ*3P~QAW&t#dbk@iR)Oq4sV;61>65I@7AtwI)MslU934+uhv+l z2CLXe+FX6{)vv$u^2--r^e$d}!L7eqzxeXnn)kJrURc|F$$9C;SDm$&FFI?rS6+}! zc>6b4MTl3PV8O;Or2T0?&t1p9`s&wS;P07~aa8ulgmT!{o;WhGJ60u;e^~daLl%6L zL_$M$s_Z2|(%otU270%;FXfrE4_zsrH^w?n*?K6{lM10hJ#f4|HKO;MW5*@Z9?9zd zj4Fo5lpMw&*y&7km}g&k!L!%x8BbUK&@?l+NPTaR(PbMsed%+eVm$draIFt$62BS!theKm0pE8q2k47CN zdC1cjw@rM8QRB2A#VB9g#V`7BVsZd#v~cTi3!lZB$Wt53lzN%SCP=yk^qEV2LKtxu zixx0rm-4Xl`A-i{F_STmeNM~}W^M$kpF#$977$g&)(#OpkxpFGa zuGD)zvpkNITi&{~ou+XONdQbY@h6ap2fU|c|)?E)xwKd*uX-D3j zUH8mdv1+ymC8x+GMR5fV(uJxN7ml3yC#dQxC;bH^PKf!vo{=*RNl{ z&lxwHzJ-Us@N@Bxir+$Flto07l zJdOeVjcMFZqj4M#JGOSNMM*V>Nn7^od0InJZqB zzo8?oE)!Ol3nuL?M=-&K{U=8_vLY%NafORsRaU8gkr=Ls8k#Ta8?E138%#7r^FgcY ziI%Kw`C`t%wt?pioR@VRszH5f8~mcR(GUw_@qyEAK5^CQ6}~txmaqbQI#zcMqFL7U z>tS-|s_c)2R0lTBa$w>39Y>#!r=#9zI!vPxFnZ%@I+>=4_Ctv7UP$88t!O;$4P+>K zLQc|wUO0LmL_)~8C!&#_OH~*r8<3(@W}sdgPz~&;ngm(|NUGXZGM%b;y-)W*kBe4C ztD_N%$=?A|YuDOi`{4E?E3_5BF&543S3tK3I{7`iqWZu~=V*^zcGcQt z`wa8#^c;9JJ7<+$2P5;sgY;|RU%6%ZLGj#BQbheU2}s~cuo>NxaS+A>kQ7A}IAa^UVk^G-qGr)2OtmrmiS_X6 z+F(4AYcLKax7V(Yx8vbB6v^7nA75Wfg2xm zFAos}`Y^65KLp)`cy}$%+mGBwJ1f_LU3kdVUN4HHwAWj`JRIYulPl;PE?w^|=t?r3 zNTt0)Jj16t2YR`yBO&P8YljA`Yjktz_>PSAoY8mU%)BSMZrChd{G0Ob6sL*8p`Lqn z9N)V|ZgnzFQ2%d4JV1ZKtzlS zFANfvZoGS^s+~j*H%_1Bq}%?4=~}PXABIWN>#48MiN8u90HE_H&ja|ET!tTSai6uB z!xx#)9CiWC{oJYg%;V0N?jw$#|7d6NEWzuxafiKQAE>>2WFqy2WD=@G_HsX2##u%< zI&-3ZI>hxU+8jPLbm;w6v_t^Fvs)QMtg$o1`H!7<`OKYFvx-_sP1N7zGcT^7?+8x) z%%zy}ABJiHASX>=IZZY zLRWHZk6a-U==wC!GmF(zwA29dCc275aywMw4ZI}ev7BYj@}K{O<29JipMq!c?45<{ zAadxoi&gy!RvbP{=#YII+s<%;)Y?L}w#B!QLvA@+_9pUq!9V7+N>t)rg+PluIHuGoo zStGOdeBi@(?A!1!O{{LB@Tq)A{Fwg@8~&noG>S?C9M_jEU>wiSW!1d~xN`UTUNfs@ zepb&K`wpmCd#$Xx&t~mSYj)0Lt98m&-#b5>&(7_&LF-&Lf6>b3a%$Q9F;&WFnRWLA zoCCPUnHRG53I6eZyyK#^Z?`Ra7k0tHA~;ycTHxS(wwRqaKCim7cH>0;yBge-GI*}I zx#t3Og3(yXK*%(VhRJHKbjS>lmmWOAQ|RI%?zFS0D@N6|-49ds42HD3Eu+mrs-C5n z+B8lOPbDQR^8E?&2gEo^U(P8xkztB3pGq~9;XMg8CRddQ@voc3CX1!mtDSq1+*Z$F zRCUZI;bbC_jC_+`DoGml@2GE~-&uO15pUXQyrH77Yw>1MCXtu5S zAn7)b5(g^Hx^u>4@|9*_K2HOrIF5qdriuDET90xTbqRcRy~8o}GG5j92%NMz^<9ij z(2^B^OSMaEo-NrPTVgFXZ!fWcU9c&Nx7Z3=i#@}9+v7VI%l3D|FLc_JAI0HF_Ildy z^^lgKKtR9I>rrZ(&nS~BId8mn?JaHNHnf{(Gm7G|$tu(j(9>=nx3ojfR*`|L?-L-d zb+gDw`f@l#R&so>7YN1VC*%SsnM{5G;91;tj6ZQdzZo5ioBs^a=6!#R8nX%#kPN35NN;9Gpfpc zN;Ze|eIwDWB-$K}BT>}&Y98FIpGbqq$y2k@k{bZW#dk*KEQ;?}FIvSA_-EopocqCkh7IjqWp{76AOl6bx{QY{pX* z|CEUB^SAK{;4U{**gC2Pqc`rgD3+n>*b3jFEX`ug?v$?I#FEuq>Y;Uw-71TecJi(p|Xr*_9|L(4IrhIl^OTG z1j+H4O&?0O5K1mscPzC`cZFmY82c=y&nM{3*j?PrRz*FFSvSMX46>$@EWafb1T_f- zk}8Cg46%zji)OEDN6F#T99mZkwL;=R@n5s|jmOayBbif|VQls<#5C1=^XwZa6UeFT z)aj^25a_ViA@_CK?`H%^6V;mp-XahXFt-0T_0BT(3z+x=no%bhd)0Z!VaR0ayfAe# zWL|B3)@~bXf>Dv1$+gPM)2BFAq?@V_pxyNT1eu#6$>;8e9F%rZO^(K~Db3VPnj;rr jNR~J|x^C@V&fVY@N~IEVNS@nZ4YtS|eBQQxy;T1Xyx5uv literal 0 HcmV?d00001 diff --git a/zed_utils/segment_anything/modeling/__pycache__/mask_decoder.cpython-39.pyc b/zed_utils/segment_anything/modeling/__pycache__/mask_decoder.cpython-39.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..aa1cd9444a40cbbf2b00782907834e5b68dcf6fb GIT binary patch literal 5419 zcma)ATaz0{6`r1(Mx&KhUT zuP10WT@8=^;@jb0*EH>KRH;5DDnG)T4nRUJWFxIh|N2N9vo2#)Hb(l`=$fi*j;yiW zwa2w?ZR~U%v{|7&s*l~S%e1$&uogNGwb0>g*w{C^OK7QwE?V5Or5Un2TBmV@*|g0L zlX@J3y4%yZmqke&j5>xiZ$@b*t-I67h)eVCg9$%kUovU5D3)?uoa- zm*(|teGSWdWUFPF-R8V~@qd!D@yH67!sbJ>+xkQ^weB(&-;#}6LAvug?p0INgS+xS%%Ks5FQrOe&g0T11vXq%C+h6>*PjO^=PYhPQ#2luZ8ul4<+e z0Xu}gpK7`0>$whU&%YAdQWtwB3y{_%ELx%o#b^$u7nR9F3 zM9WI(!2b2%&$_0$UisV>JdAo->XFG)uOHp#u@}T1wxG{Z2)&6&#*@tB<1HSBQQS{o z@TMvLy* zV?OfG%;OYW?< zq)lP5V_6j`p0_)I9m6OM##>Q;io*+osR0$P<8ytE&hUGhq{4%lpuLh@5#o8%N!klW z=S{G14&^M>RTK@SNpIb}ecCQ3CLq1NdsEDjJ*adnq#wIpb&8_(>;_`CiV&!O&W@a! zO^_=dQMf+B>D4;AZ0vA>1LCLA9+!5Si74bp4*C`zjIu_|H1s0g^cx_3_~eI|W-NnW zDG$`(TWsI_oqo@Rcj)3Wm;_}F1Y85UVk&+8CCQ@CnP{W=bM4EwHwMX=Z@@VO-`#jU z*^Ng@5T+Y4#l#jMgujOR;{+!9W@ApPQn)<%@8m;*2L_82$fBFK5cGYDzvvpf*6o#MZs+vb`!v-8@#k8(Stpk-T!Cs@I6&dC!e(!_ts5&tXlpEja*ElZF% z*6+ehi74atv5mcLBq0_4i#d6$Jf{*z=6i5P5DsZnU@K3JqByx5WdkKG9Rx*~J*`hg z@>lKYnVBeURa7p6>m)%^OI0S_Nxe4dzuI@@F=m5ctxJ4dMn%cj)F-i(|Bd;UiG$utT#ZCE&UiJ(zL2$ zb*sy+y*x=Mw(O^_AQI3X99MA%NXb9|CR;&7si;ybtxy^qG}g2UZB! z2cT&j)Tl&GeQ3-qHO9(~ZJmMvcx|*e2p;yZ9x{M=l3>8?Yq@zD?yq|FQ@u`J8`eKC zXZ4voYs{8%?Z5>;!pfn6(9^`|ChkGj$E1(?-!>3RT8maQ1pFK}E^FX8?qAMp2TRb( zKGYAIc|CXYM!s}tLYj8a%4>&g*6wSw3kuja2-tr3!fYkKaL|UV3;D{LmM<65@|9z$ z1b&(JTXv7E1@3ay7xVUs6`+1R=c;ySv^A=SMybOpbhwzepu-dSYW{=@h_$_q+b6*C zx_5_j@7Z!&pY>4i#);rw$g?0CrR&9s1dITPJ$j6%JgBSK)y}H4lpjc=7i8jTG)a4x zNBu!2zD1S#G)|FoIiVZ=(F9-vd62Nq1;Hmg$N&X0E=D}K&vC{njunain`*Fvyp*4% zc|YR2;#stcj+!KxOgNCsb5ya?Ea>fs=TYul`brP!Hp{$Sq5-@7H&W5-)-a2Yf@E8g zfes|wmHHsEse4ofk zhZNsM+XQcV4MeM7WmnmXafvzX5^EVNpq_C>W%?F-jVW>D5 zbKjS)?*q2sHi7bz?-LF#dW3>g=38&vd{-LS4QUsMM^Lg>$V2=9C21E)O_~&m1-ZC* zg$QXan*6m-_5u@_CGUD#Co5-put z;l}S1%g>4Yf`}SM{oLh~^MVCWO5D*fsW+s@nBv=J`W+6tdN8 zc+;C886i_vglgDKpBeOhWT?{~?Py}1_6V329EU8%-cV`A?z!Nx#81#`=jfTi%5)Pe zb_ejk&T($I2oUZ9%%avJ*_+Z7d^A;ymbLOU0UuDBm+T`IJEWy7b7~uu_x*{KRv!o1 z3VQ|%cG+HomO9VU>`mBxlji#w5juh59U|`%@rWodeveA$xh(+$`WBZ`xOZyWm%3hJ z1ajCf+HK`Kot6p`#Tjx6WE97O&Y|dmNIQEl0Voq>?IL(kc#$^lisK|!moIUL`pC(Z g$ci?LEtbypB8c4}pHlnmuq6PMRei}=`L)*g54od*761SM literal 0 HcmV?d00001 diff --git a/zed_utils/segment_anything/modeling/__pycache__/prompt_encoder.cpython-310.pyc b/zed_utils/segment_anything/modeling/__pycache__/prompt_encoder.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..578bb608020f523032906c7704bfee438d9c3a30 GIT binary patch literal 7678 zcmb7J%X1t@8K3ES?CeUrl4V(b28L(6pvX>m8{@=Dj6+o-LTm^ugsIVXue9>+%%*47 zwzgIURSuP0e1HpA>6qfel`H=M9H}``MHQ59T$4b;@9UnKot0%MMm4=X{rLLp@BMv> zo6UxShduKH{+rJj#^30p`q=n*0dKN{h8x@*7#&mnvJS&98?d3(u}qq?2KLbDI77GN z4r`s-(Cc^@x4AQ@4;!6^Y1}lp%WLl#ye6#Og$Gt=5hEV2W2BysG`V@pXg40=G-I=E zW-N-@c4l3V_A>ALqqHAK;UIHvk4FQM*|+yb;*p7!Sh(xiXK{;4-^QFH2wM*`z`eKlu7N#-bF)+Du4wEo^cZjO zg?H@EB44C$-c)T%H1m0nFY#r}xA+M)qj%YSg`dRe5?`GgJ;m2Bx(rw+fC=EyOPB$J z6Bxr6daGp*nD;z>nxDb06?NiCV04zB!}sd{0cGwiKs?Vc0K%zv4cq9ffhSL8&DUi- z9Hlp+Zp?+u>SDMhc)lKwWabO-ArtzLJH3BxZ&{|W5)=r)G zAfM-q2PVE;PGdF2EsNXlc&Rrvc54sJmodNZruBoy)IFq?3+OL$>^-C%P3&XbJ!`B& zw$XE_*Q1#h)-9#=-DU1fZO+c&Ed13T+UW`XjdPs`2CdNFXe%qKv%u5g-b3rK2DmG8 z=bz+tlK=1pW*7}We>BAtf28Q2_`O*Aine|l`?s#Y_H^PCVSO$XXEqw;&g*h}JQPuy zTq%cqzvP>Lu^*-FEB@^rQNVs$&k_@Ta8^VKm{Zw%9Hae>v}0fpjN zEyYhbbWTpICq`oky`Q^9P;E6r1MfDwn+wqFxpFZu>H~y_G3UDEzp6wl8_sty{vI{&`;7X?982tat)ngo5EQtufVc}J3@jx+8(({D*IgET4w=uv%Q)* z$#?`y<`wYNZR8?mF* zH4n`L#?e9oS$GD^N>P79`G`cR@hE2v3{0UCS_h#lTA+MI?836Gmr%mdXwZkUVqplI 
z89}oWk&;fU4H|s16OIb2RW%u)Z^gdm2vx0X8q4m^#%+L&F1RyM@WY;jwpfx2@r z64x-IGz=U&H_x%^vdq1yh$6ry%-@%+SHK+DA;}&nUGI?XBSM?H@Y>!TcGW0)4k9Pa*LNG=fV}17><-Jv7YMJpIjb^fRp%g(^~ot^ z3*d}N=L_N5_=hINU!PYz(KcuqV^AxXi@W_U7?N4=sQE$|=1=gn7c)l*Oy*34jA1yG zT*&9JHM8SLfY{wIRUFPv!n@oHB@ak|f^Hm3o@iD6EG>8MiT?IZs(mfqH)o7~L{_ZrW;H1^_mn>XVQtT*+XOP>fBvfL~Hv2*z3CmI0-w(N!3bV#HD-Pfe|2 zgj-;)J^6WQKW~au8~d#=?ZPC_;#5B#r|`V~cKA}U)YqQByN3?2w*JB>Gq@Ct7XkmhC zFvsLwHLl=jDWb~q$S)s*--5zV+p>MajUczo$=PG@(W~ZM1jQZXCKMn4jE$dyF zfkL^HCZ&dl_H0kv3x5lel){yp5Z>?gghbrtr>IsWD$39zv+nDZ7kcAw#S{{ZLWx)e z5vGnxN>C~+ApJt_P!1lO+G$5!BxKc4WF6toEHG*NNFgKm7vvh+ktu7`kHS z5H{NFQ@o@wnaR2j#9_s8UIlgqw4N>YRp2T;f8sVhTW%}&{m*DCr{IiUDpe-+%?U%d zQAn-(Pl&=^yuVqE!?11#?DIGMegdwfLJs>8V#lpLWff-WU~Ua4D0rNJ>c}eDs`R>& z$Dqextq4ImwL0waknzPTszv{mB3umKJQ6C(XQ6kfs;OW=$0%RHh;nM}mi#XDzkw!m z1nN9{I$^j;qhxsHx2Ylhl{cvQA~odJa;+u%5=7@!_^sij#|!I9vA%*yAvNx>-0HTOAjf@in5w$8SJOK`ma^f z{=~BD;W2 z9>pTb&(L^=U9&%AjD6g&*uNV#`xkTAN1n?*WVO}@%g?X{8*?`180KuJ@zL7qTAev7 z81*pfsU3}vmRFcn9cwu&A4xi?i+AoJb52k49-5-6%FT|&PjXwJ?u(i#Dru;jT0K`) z;%eMO-L!^kueVFpQ?xxhubwtgt6UJ+r>d~UT$8s@g>a7YRBNd9t~xM%9xlyYBWyUY~jN0-$Unl~(C}67)ovQh1IO zTgWBMbuBc-Lw!od={KRsZ&s*;Co;xH(xeFGjwzQoqYxGuF3Ogvb%4-yYWCovR#A7M zGCHo{Y|5_z#x`C&|9n3C+$`DE#S+{g*bc`@(hnn~`lxH%S5ax(k>{~RNno}l#Wso# zh%L9q1|$eB9$!M_|v z85{K#wUo`vEJzIUusA^}>}f2~7Hg)FI zDAWMYIl__Vk#Tv6AWqS!IEd07kZjjPn? zl2_1VwH``~=~tdZ1~ncg`)Jf8?dnc;_=y040>{XoFqB(OXGwOs>t2BVD@E`FeD`-6CnQM@K$UpwhAo-en`biwk-AT5nGd$EPkxFEZ+5bW2 zOb0oDQZpv_n3@O1p`r3x=mYXv-Lf9CSqf{<8s`nT;VsxmsK*9Q-$pN{`j5YU<-;PIrI+Q*(zQ;2{oXj$&Ss5e+aTht!Zs zILZLBqkqDy6v?wMo0s9wCtofwJgP`J*{jtNNZXqLu-RUcn{>eUsd*huRto|ici~sx zpdsaiN7TDOjgqz7)FaW-4xXez5=Qbz)XbD$*-a%dBWn=qqp00o#!+dxZI)X-APTbjbW zYjo<*u$r;mF*7TQI!S6=n*SVEzuB7w5`tDT4-ApZM3!JIjo#hp?zQ{ z=jgli>xy%ryC%+K{2Kbs^A6U<1$?jnKjtsdXIdA6T>mO7s2z(S@V4rkH+ck zs22;cvg2hnIogpT5Q9;+J{X0>U=h<1ZyFqK(*5RE5kp!>L`ISP~e!5|8HdtsDh z^{|%?9)u*g%u_O*sHo?aL{mY%iPy%P{2oPWJT{KZW5YClYRrtln%OArnT66x-Je1p zSB=LezC1x~mAPc9X?14oS00=1p#RWIYe)5&cT6K!P;Us#J*F8=%(H}d#aM^-qUKVq zPdzP+TTN^GZQ;%wVO=HJo1*f>NzdtTtm{5DXoUVoSsGF9#Xen8ePSP1KzD6v{qv$m z8Xn%V45RMnkLG#ek68kV-;b5gto75_zx(DlUrv1Dt1l$WX0tw@`KH>PjAWE1*XK>X zKNp*SZ4jlM>;AnxSqK z=Ikr=Vdo(4KitED1<9p#&LsZ~I%GT}AKfE2SUl89Ms^-=av8-gOzFe+p_#(81Ptrg z`kD2y12r@i8DPg!G$pIuCUd9LbgazUd^Ix%pPGsS#HYqz?r!bHBe@kuLdl0)@5B$I zVH}EN3);CA#i`thLGL|Z<7WDXFs)1g$f+D;0ew+-*QO&vV9UL|uW!13^h7?`8>rjqh9K!+@ zjxfnw6-K*Kv#_oO{2n%#wNH*zYqZWea4QJQWV)-lHj*!8FIvAVKc0YzFs{sdJB}W_ zE))f^tj!|@xSxO*U~1Nk;(#03=sT6nN#wAv4@c+HCnO)MT2%Z4icig~76iRvm?S}< zE>QC&Dqccy4sUV|g>kWNdi>Y2JoB>IFk9AD(=pf0I+q^G)}OuF^y1m_0_#^nWl1wh zQsnC2#+wi~X2ucTee>8nvII)#q!nJqu(|Sohl57KhCR+D1ea47iMC3ZkJc)md0zOs zjX9NYJRS~Uz8Dz7i$<`uM5birJVBFB_QG-D$I4CwY%Y1+d_!e_o5rfQw|Ng_V|Cpx z^r*bchl6y_XZrIkk?Vi$(Y3j+?&Niju6s*g4?FW&TG?&&c`#4wZEX@IP!>6rFjf+r zx;QULs##-KrU9Qh7|WYzNyyrb+VtwO;L5`Qx;)L9%o~O~aDSOS4uzH|Ih^7l|KZ|g ziT)2NGlaxT#-Ta0)m{q8&dhyF0MrnC>f89Th5N)hbkOI_tfCJA*wQ^K#-WQocV^C< zT??~3jE3hzi#2oSEeP2A-cbdkjQ#4dbqHWN^AOIf_pKX7QF8&M&|llHe-&z%w_EqE zynTiCwbW;q`s{Ka^s+X+z@7n7nmj|BUS-9O*Ze2LFlPc8IUD|0S{|15_SBy#( zPT)oSt>Lc?7jc=trY|gO*3xXm7I!CpBonPntcxH>4=L7MnI7O{u1w|;&>ED%KDG}@ z^l%r13HNP?-kw3Vg!5xyIsMzWa{)Vogy`nNFGhK%MpeqAAnNM$d@gEibS`N18g`_< zN(ISxdm-;1pR367I*u03Ez~nl%dy+;#fV$#`P^^h{9~?t zh^YxtHLR-H0@J^Yzmuk1oA7a&?w{so{Zm|~M!URRjoOVA%hBSyz@D?hh zF%wf=#4i~)wC0L&*MMW7)@yJwMT^ZX*JpOoBJ8_B{5Z8U{b=rP@5ofghMh3&!GkY= zt{+cQ#A| zQ8vu!l{5CENA*|@XcQhoD-=#xmcKzoYa{HOtlmknBg(kQ>JdncBj*^hL~uZ;5H_Uy z(aijJ1cF24TnHB4tOBXq6Wl;Nu@9@5?aka-6#>NMkb-c+vwj3h7zv%P>SpyV!cm%tcE`aGv;cQ 
zdlxb2_oV{P^HW@9Bn(HivTXP|S%xk7J27F0ai{=VfJilbnUiAtKsJZ8W8QdXR;c57 zk;|1a%LZ`K0>N~AWSex{YqA-=odS~M2w3q>7FzCOa4Yhw5KLk+~9=La_yP5lVs<#j24kFQRECXj3EcqZGl^*?yRkG?Sefg zeNrVyPCvx#30 z?WP0hW|_`sjrH}~dWBelp3fR;ulqx+V$vPo(Ea{aWmR8sZrvNm{G#UBEl}C+*hr;0y-$L^3{wK+ zNU^2H95#j&57n!9={F(0U))(DrWg|=xr$%&zNtPZEJCYDZE=H{+DAaCGqaCavyLkm zx--Wi+->zQps|Y=&mZ5&zPQM8^;HSZ5bTDNBpHMeQhnTOJYoRaan*~M!a9|$D!GeG z2f)ak2?dqhR_W|ffz-ZFqq|hFqOjh~^Zs9=U4sc|NgKLB+R^^&be_)ME-r`kwaENl z2De9V6HQL`IpX7s9;%bFP;WwR0I?e<7uID;4=IjAMGt}3D6+2uxvk?X#jQLzg_wSf z`jJP!ob)+QD2}?*uPp>UJXp&538;msAHpL?_&)%yN2WUC|G}#puYI`5JkySiF+6!U zBuhWQ5B|+C&aCl(nZ~Xxv!Oajz~ThQ}B!L&Z(Ruh5c^UKn-*ENofR zFBMxo#k$foz;lOS$W9IzsYHbzmbP9e3vLyOb+E zF@*z81NGqBoquzh`~tOCW;T9-TaG-;uHf4Pr#*}$G*q2ckE%d8z)QzK#>k0o-jPUC z00Iy&^-Wp%h6A3`5B*(Sy5%>)FgV#8^y$`x&bJ9ha&GGi0y{c+;t0cO6#>6RZq=!= zTr|`q7pDFdMONwK&N2PU>qwm@qhvBtJJjkXY1q4;RX}t!NhRj8z$9D&?r$%V2#keB^6UcREDJgaJ6giW-`M_QN7>F1v(*dkomSwkmxxVuD> zkkuFeAvx_&$j{I1WQvcOd1M?LoY%rSkk{(c>LZ(lXF>Ubft>e5AFel++kldJJTjnD z`PGw{oi=CSvghnRs2o=so3F3v&y+&z9fOteGde-JFDtjw95Nbd>chp2?7!K z5I*lx6UW+*srC{Tta9I_8flo0gn=f&E3?P{xK zI)CL*)oEoeV$AqJeMBU96$`b@gGs>6E1Q~mFD2-J>+w+>@%^v5M+3+)Ful?y=FU8; Y-q7geCYe;e69PDwt%kk!AI6J+2VK8tTmS$7 literal 0 HcmV?d00001 diff --git a/zed_utils/segment_anything/modeling/__pycache__/sam.cpython-310.pyc b/zed_utils/segment_anything/modeling/__pycache__/sam.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..975e55738327300d22ba6d76e88531aa52c32716 GIT binary patch literal 6659 zcmbVRNpBp-74B_%7S3>yC0Uegx#BG1AQWZGagZpoK`rqDNQ@)N7BaBY=$Wb+Hr3NT zs_x;2;Ykwgi(urQ18L;qTaGy=|Dta}fB=H8$sx97=X+JNav0f26J1kXU0wCwSMPoA zRWYtqN(L_L{BQZC7YySc^e}!ocv!`qNEF=QW@yyt*9wicSu;(l+o9FAYmTZrVXp1g z+;+Z}Zx?EXcCl7OdyczdsXbGhQT2RSZdYm*RWF3I?YY{VX}oRlA}>8OcuANW^AGIW z0yl3P)tM)l*I28XnH5DSOU-W7NP{@?!>W_nSE8NFxgIpq%()RHX_mXw?S$fqiEq$( z?%lTE62FcbF&FZK;0~H*Zu-f+>q52QvGSgb+nu!Er9pFzUzPqt@-I|`Q8T$wv)%^^ zHCs3$$F0lIzUGQN>ds}O1x-IQQM#X*O^fG6ffq!P7extL%?M}F;4}T-GOskt9~(`R z&+@s4m0DR;cv;Nit9icg(0O3i=EVH|!hyk$;frIK0q+(KjPRUz4n1|-{Kj$J+I0DG zenNeBf}a#<=g*6i;(30G-t)z=dYRrf3;Z-c^U$fC;xDM)y4B$?^0R1N{Mc}e+G*hb zQfA-w+pPg8*=!s5dFi?o$$A{}%)a&Z)vO#4&R#Q?Z9lErt=H)K;`ghqTYowG=!=Dq z`jx-`ZWS;8_V1++KfQQ!we{&=sdi&k5B}#Ht64>lBlLo0vmLcK-r%{c=H`0wvQj75 z7NOS`ev}n;JxO`yN|APDwAOG3GUJ!(1>837_JPfv0~`J6ZBD(-O~2(XFw6UiWf&#aziwZ-$vRSSSZTuI z`i5wvj5JAD6Z$bfVoGgPI18dqH$`tU%4DT99OH^?b=x9JlS{)Urav2?!)f~IHFw_oCi@WL0>;*ei+7E3~S6pvl%o3 zNTbvk=mwu$85?nO?byGAnpq5Cy5{ z$$7kx3sf9KkenLanO zqUSY2KS@0AbK}cjU09FX;sWeiimeOR;zA-?WY3--?WF6#?Lr$I9a1In z+bf+Nc?xo!#GRB;l#aV5{hc*mHh0S)q?Lgt4WjzNMRmHb;Z6vFp0QzUP#7`Ss+JNo zbGg_I8sc+PzK9Nz4T~R8?PV0N;wA?+3f69k=F(qemAaoc)&&QKaMEnP-3?Q$Ea))q zrf6vpy9};bRECW6Pdua$MDe}Jt%rr4dEIzmf>)f!)}H-|v2XS419Q)58GE_Z1}8ZW ztn0={AAD%-xqEq@>$#hT{AthSxkUq|yJ$46y+VW-vsd`U*l>D*o~YItJR+oUKu1q zF(hI}8}>$1#hXI12H1=~ARkqh$9RBFX}I%Uq&r~*efK45CtkS1h$r+6n#r|Z%S`x< zb>A#uh#`; z79&x70(6Uy;j&)5 zWZmxwc6EDg`yy8D2#7Q=Qv+$#n$mBRs?WeqOCI`l5k5n5{a&rdgc8%ss7!7s^%PL` zcw0PU9=z5Io*dHuXNN~(&p@MZGv7md#j*;kX9z`|`chGzGql5L<@>>ztQTwcD%tzE z+=q#sE!E>Vgn!>%7Z_EJhGs~DyCG!{D!pUsOW5?pI`aS_tsbTb+9(M*&l1&Zfrdeizwe zxfrLkMM&VNYC*;4!cByOm|cCG&S*Xl5yzYdE@RI7hnJR z3cIEP4eKC}1D;@0wV=xVrbdFJCEa==?st`uVHFe!t4w}`(MXqxe2tPU+L}??nw&<_ zLdNz{aStiVJ~n0?+uZ~657vE)GPo?KRs)q121Z zE>kE*)bG6M)95tSX=C>#%IDzsTfXEgm%|u&K=LLkgT9i|M@(V!pj|7Ehl|0Jrr5PC zrzo;EL)hqyFsa&^t3FBOGRcrr?o!QV^Gb4Ww4cEs-IN`rmNyA5GkppQcseS*{Us9NS?VvhFhDRG^uJu(js8DkYdxQ z()`Tc4LWM0kl8Jf>fJ)I8zoq8u`6n2G-6Mtc9*D{T@ zR0g^#x#Vj!IiaF$Mp=(CoP6IpUdOv+6~zn7CDXBvnFaG0ejU>?OZ4~D$raFMslU=! 
zPTn%-X^-->ltUkj+MTEFw@!X{X3XaFlov7Vidvcmov#fps*^ih!=2nhL3;)|pV+WA za56z9&WEsn;Pzu@4{x}Goq~<1(g^os`VjV~aCf46jQ))IAsg69IP&gyU8&9z=4Rb90Rvmz-KuuoM-`h`8 zJSg$%wh9sblyt-;IJBUhm@@R~WK!q3{Umb4=KZAOhT^Mq)xf zMB7%dDX}%|Y|JDz^@!cMiS&gj*fA z%-hH8vRO8d6K=B#w^77BL&r!FHY}c(l{^oN+YPB+_B_gC^czk1Idx)|&3QQH!M0sz z1~`41Q!*y0ltt9!3>EZm2uawhlb@`BFoL*|?nq)qbuOcWQ$C8lmh`#Ip*?dbGnZ$m z%VjEP;W8Vhq|8B*R#~NgCe;vOh{+lCQ^i@TDchklCNi7k28xnFy1RMTbiM?Ue`DF! zQ_Hcwc5~KOu4{d14F9YM9i)9if)FHx>M9nla57byeZ4a9JYcL67X6;I{|ma`GL!%S literal 0 HcmV?d00001 diff --git a/zed_utils/segment_anything/modeling/__pycache__/sam.cpython-39.pyc b/zed_utils/segment_anything/modeling/__pycache__/sam.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9d253f3ebcde38fc9161c08e1e4e85df580cbd78 GIT binary patch literal 6650 zcmbVRNsk-H74D5q4u_o0lC7~ExwC2rj5xMz2Z$gW+G5LqG_WPvLI^vJCacKlQL~$> zYR(4ENfPXfVC1gDjFF3h9CFM#`4@c%5CjMy_!#66FLAzC#dSE5jr5?4)z#Hi?|t>& z_g>Wm<#N%$Wu1GUU3keb{!S0$mxG5(xTQeB3}!|~jef1jXqz?Dq`DnhZM){Ex)bHv zZq04yYx#DeR%p-EX3(BvZd7c~)@D^bAC=nWT3OW#(Oi4JHg6j57;J_W9~!L4&CP`e zc5RWFw~fl|Q_O3uSIo?cW0b{aH*Tb15(iPm$?VJVZsuGI8)@d;2xXe(?sPj5e`?|z zbe?~=9klqb<3_@`_%OVKrrDc;ymyVO7Ce^U6G^+1_PaD_uJNnTe~^Dgg&Q@K88zzz zpir~9!*k5K2<>Yw&!g^KG+NN~a}%Zex!JUMfzR+FpXJW7p-Yx$g(vPOM$=TfmTM(m z=5uWJiOEW={Lpz|*5>)b!QvtHU&5$!O>5b}*U$Bz7dB7m7B=_9X7g-8eRGndoj=cC zV2kv2tcG^B^u%G$u@euS+9}nK+D<;P*z@cKv@L&PI7aP7Abl#cZwKwxK>cj44W#^Z zLvXo~L@cv!y>lfig@mZzOhh|KD|YKOy1xAFrPi%Koq6=-;>Z2U-+prmFaPqdm5)9< zfAdo7vp-Yq#w9)YA8%Y*uee!R55@KBWOE(068_-1q^4(jrLtlt+~JYm=0Tjz=(Y()IG@eoI~@V zz$|7n`_N|2p^g6ZHmBa^rr$CbsO1A`8Aj3TU$-yc^g4nwSfuol`X+Cr9w{QdCe-tS z*i%}g!n`o+08u#i+%)PGU z%u5qbgXyIkd~B%xV3;30v^cz7=ZuAMOQLU(dO;K=+a6ZgDLrI~?#dk3uj&(NmN8{l)iI#(nmJ_uc^CF<_zLlqzEa zdJ1rj1{*;>nr~!Y-bw@_S!qTI3>>oGgvmkhs0)b+keEzA3ikk(r@oTVf)MU2qX3-% zM_Dec_2|gA(e(chES`gDKm87uAZveHoItr=naOgp+uuVcHo3FvH1>at~mTboWkjXVT#9ExByq+ZYCp2Swdl)D4MSNJ7eyZz)#hIS`|~A#QO$O{otZH-B)o-LZxSH8k-bG%=L<; zgwI^Yx55Vh!W2J32jLBiA5-loC|<=)PK~IzSEMQR7Zp1z*MqdN!5NT+>t+k>Zj@rZ zL6b>0MT?e*h-Yxkp)w?+f8Zg7JSqT~%z9Ypnb(X5CMd^wZ0*~h8V6R-J~a28ma(5p zZP1nTz`AC9{D+UMeRn_4ay@s;5FhqjmRmMZy30n>+AqWo%kLLHH8!1Iq343K3M_wU zsoD(HSn;uq5oY#_=;QYC2YFDON&OEBhemIvSL|7A_Hhn+@>T zXN1FsawvvCG-<>7XsTq33$Fp9qYub^l^uF4M5i>`{XRaPu!O#66YCRYT=s}U^bDHh zwO*?p4;mX@FqDT>qmL+^5xd@Ym~JRhnhEhQ33iX~$|?rbXwV~Mj=lOMpYxqY)K&O_ zQd4NG&_NuM@U(F*q6yC`Vb1D^q3T`$$Pq5U0$!>}fG~Y+=lafFqC7=@~KuYwe^+jMALOL~nBEnXP{LPo)9)}z&DdlV{2uG{<90;mdHt6M1FT{!| z2}YD(=a`v~LS}(XfqW|khNAjM5M&D+BfcA=N6;{s>iaPbBq7-OohttaU<2G3@ z*6dZX_i?$8vb(o}AQr*D?{096Dn>(l$dJ1cB^E04^VFBH>4|k_Ap%@IY;|Ft@NAX5 zta@-6ox{QdiMt{Y5fm7spIa)$ok#|Cit_z*vkfV^CLJDe3cKTMsN7io5CWRe0}P^M z|KE;Nhj)X{)I#(jfkVpa1(>kDtAs!7NXdMKNQ{;)?jki6Q@P^wSXb!_#*x>w!x&^# zhX6;yGbjNOM{`1d{3DuqSPxB31KLqSC>I|$4TA10QdxgYy=gq|45G?eIWP^vVS3J{ z#_E0-iDkM=EC`FzOmcN2^gB!@)@ji*VM1!90rs!#E;^8FY&|iP^ z?wjY|_+{0*ssfGIL5c@F!KP|KmHSPN1V@wIy5#q}%E+(^iiA~?Kf-9F%S75nxfbo( zD5I4xqG%y;`*>*|Darx%Yz#Z#L-TjmeTNdcET>ijl_Ca4Tx@VDD^_MzU&+iJ@l$-R zlRC=hvf@Pw#fbWyw*ox>6SuMVG9`8J`t3ll>cuEQ79ih3WzbVl4v7hD7Pf1p@nA7{ z(geGn53c97kYo}x<5R${3Xfrt8 z=LLp1H)BO6?hcYLa8Tlp7gQB8p$(%#B6JGlLl2O3t9qf zBU0=pRsNsZdtpcIATqnfQ@xFt>BbVP&G&e%ghp)El$<?AVnuZypg$5>>|qDgkQ(B%%WNRwwQC! 
zQos}bzAomeeZh3Ub#q12u~h4~&UbEpZ+6_~DvqA@H4Lm4M{gYHtZi^noqXadZg~p@ zZ5-(QV$<5h$pw{7i(wPN?8nYN-Y^H-1B|$)8~bMeTYJ1ee5nxlaJ``~ z-{Nf~#qf7^5}>%U+mUb%9g@JuI+_kFt-s->{A?^-9eaAA;6|c1o@ZScCR+NfIv`Qe znw+V>_n+i>@Zyyn6({<6>4;Eps6l%&W$)7&rA~JHndFGl`ipg#DshZxvg-EV?(k#jGRjE5g=Ki9%7Hh1MxOk$su5~GU8k-MglHkE@X9G5w1GM zz?rW3b~cOjoQqBp1^A)kiwsoRQUk+v)V^JL?Z|&Tbka^t{wsck;t_61CmzO0^2#~# z&~IG#419R+BzfZ$PvMB5g6`TGDJTsta>P9Dp))E+G)~d^f_WMa)-yMiGty2OyJ+m? z@y5n}+&X3K z_X|95F71Amu~!8hRcQ1o`O(>oxPWGTazhe}%P2A%LsVQ19%-4G8=1LXbI`n@%59CG zM_(uq&!YBw-1;EKUa*(U68=t@dvglDQS?1cr%f?7FTS4@eIIMwji_GoeadO{8%_T? zb+DGr`#1o?7G7rwIHa1#M42s9A!bN`St{s16N0c;XF^#4K?PAH-4#TR>ZnGCs(hUI zEfFx0Lwn{>mM%_Hmp7@Ph0APMl(GszT4m+_(N%*-5hnMjpDNB!O=(Z3Ph>rkf@hI5 z&*LEKD@&8QI+C*fy>J?IZ@V@b*TOe*?>iI}Oy}#;%YweIEGYp6*-}=}UKhv0#-m;o z7bl@e`n5tNy;he**6XLhATcx;>J*Q#$=sU?WHrs0{2p~B`Ln`By$o;BU{ak!krLJ^ Jix#nD{}*K?EIj}K literal 0 HcmV?d00001 diff --git a/zed_utils/segment_anything/modeling/__pycache__/transformer.cpython-310.pyc b/zed_utils/segment_anything/modeling/__pycache__/transformer.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ccdf4fbc29e565d9627b7e5c5790bad735834354 GIT binary patch literal 6603 zcmbVQNs}B$6|T%$yLy?$V>u)*6t;nBkY?}_1PEi0uv?TRv($Q9(^*Qd32ZFRG}iPHJj`-87oDhj_QK z)#S>#6(?CL6;BeB>|4V@FILvAgF*a|p-Yn~@5c3auJ+RQUYo|0pQTA5+eIceT1*%% z9-F(CU>z=aYzq^~Lb8#Z*g>0np5|-QJUfj~j1JTDtG7H+iz}jz)IjoLKd!`;4nJ>n zJh3E}Kek%cPYlaw)iC=NwQ_4e{UAEHC8H$kq_Q8&ha7WIwulG$j;hD~?O2Fzayt~= zzN#d{ez+4yB2#X^H_&x{)Xuwiqr96Y>U@;vaYETpr29!0^mB)d75z>0vq#Rz6>JZ-d*B>1X56`u*L0gB%m<{c4n2(#Kc=@V2@AHcqi3*kKbSmM zf_!e1K*a4-#4^ZtqCC)x6j_j`L6ijDWH8KwZa=ym2g3}OnyI)mjAb{@g8iK|i-SR$ zb;&YOFPIrM=w?AS91MEhSZw&Sfmh`1?DDJ>;O+Pp!Fo5zo0kK8f^Ezs?F7B(AeI>> zfN90$>6V7hUPa$(s)D+=ccLVTdjYE9x0qLe~3z`kgO4??hzpQCKA?+DjPk-3U zyHqTUMeE)OZVa=$qzvoo-tCHbZVsAbsl&v>%I`WoYZxV#WK?_&S{L?MuuYq+q-gG; z!FrNx+(^Z+7sHg)ztjc|Lwu--Jx?o=dLKKIJXY37)JZf*&<<8lWP9B~*y+kF5B0|h z2l1y&EiLSYt>&__@!4akyfEx0-8>8>LEWNe+7`+$rhC1c<;p2Oxw7G!x3Q<9!5~hA zS}Qj3mFecwAfqa?vO3@C!bBk)Hwp7}FD4mkmsOSPMF4kBH!*(yqvM~JA`25=DPTG$*qeRGffAb9>K`#YnY~H+bV>63y!{PEUN)GZJ z?3T?wHdc=^rTw-sIFL_b@heE#cOm>I9aj8&R^txirJuvB&z(oS>M-xovUzyn6rycR zjjWdRn!}lIE3c>+#A<1f!jE*j1t@+*xx7(o7(); z+i5?RVKM>e!ZY-I1Wl&ECxY5D&}1jN3lxM3gq_J;g9z|Q-RIx6JduS@e&&2;zOpb2 z0`~*V7EY6ubp{&d0zDaRFKwf(u;_WpTm-sgrfG*#6nIYu%oBm2{Iv@B8pyOAXIU(Y zsm#||viH}GTCmP+l&rwZHQ-4Hxu&JuNjOv1UVM-#Pa6OLryP_C=uE}#D7L5A-KGP4 zBVSQ&r>CQ(a)EaJA)xvx()$dF=SXiueuMJgBJph!-+@qkJ5#K!*p6cRa-FK1RBdbc zlb5JsgDO0YXz4JQFHoHpFH(aGyY5F)ypBRfHI213w#F>eV!k;CkJ^`>=doqxF`s$- z@ce1$nCNNN391t~Rv0Qj3;|jQA;{Ol@Xj#m6)jQ&w)I|o{e5M^F{Dj%aS8y}fF_At zNrD#nBWkNpV6z?fdI$pZReY1W(8MKiq-CV+2N0fNxtv*_an8TuCjU3L_?O(~Uw97x zoV)xp?lu0!FO^$;d-V&gqmAJwo*C(ssGgEe!8GmU!t7Wj!fAtQPD!T+(piZ;qMH-q zSrhe->n$RjzF0z9Mp{Ag<0|Ncz_pEXTB>wu~?c>yo(Voiqjwzn3;-z>yP ziUX32aFRX}B$J8B?AdQ2=JYY=AdvjNrjNM=m;^k7>N!RI$QDe~K)2MU2_@0T`$Q1^ zQcJazuPtb(gK8RKLlF3yQSSu!mN{8LN`n0^m|Q0v$_abB9IU_83@FN8(i?@C5^=k1 z5Xo$gwVIP<7lZYen?YNqSvKcY^WFs&)x1=I(-D&aHKkFnGy{;S+2+>#D8P$~x6Jwx zQm?;y<{Qci2?@`fUkSWGVTTmwPr%mZbE6j2X=a0fEV4L{ zovy$_OPqKz4Rrl~N#)mQID;()>)UDCLm;L$ZGm^*qz^^5(7dv6o7pPlD!hKHxg^){ zUm(oW$f*i>l;(RHljO5hX%oMDNuAdYbS4?pJzkcyuat>J7eiht{$4#XvIrcYk(y8H z9olpYFn=6BF1E2oHVt$d-iYQUNEyWh|6GmL7ze@g9@WeSM~QO~Kkt!m9X@*&%bjt% zd6-PiFjyyWK zjy1$-*2lNxMXabLB(9bjV1DW33{JPNJzfjjKL#h_16T;_NeLe#vR5ZOF^?-XD{AjRe8F{0RvP z=Q^m$pCYF$+AL?qRrxdOsGUrUc{HVv(?`lkm1l6<(kF?N5!FLP<>iRl1Z8+sPCj4# zN{6pD)SP(8&FIPQLCm9S%dBv*hHD&(som!!*`p-NTpy%LxqD%AyWPTmag+8{^)GpeYR|%@#$uycEx|8=U#m8~m7oAq 
zTour&lykkC#F13aoscXD>OA?sTl{<2VmjBf7NJxd?F{Vv3t5YzfbTfiXAV2Zmazwy z!FCTW+DZezBNtxedLuZRtN9i6rHd2AwiQd{t5kkrCSIhLr7(-Z zeQ;%1Tr~b2l^MOHh85M=;rY^_p##!*P>>cy2UhZ47l{cX5FG3+1Ep_kUc z!dBVgh0@X%y3Wr$aucIYTuy*L&p+XxP(XmdXVf7Dm#LFbK|#wz9k}fgd>#A4I_5f< z>D{DD6k7CUt1h9}uB|UQ^hF28@FRzeCQM-+TgwIloRhlyI3P3+e}) z+D0b&ASfPHM=st>x!NS>Vvo7S9@8H5kbF|R(Zff0a9M6*=QQi`msEe9#0?VflOPY$ zz9wHrPFej3H`P{lN9M}bXIe^?H8G7sgwLbfx!^A-&kVS9xHys?hkd0!?MCsXargzF$Wn9UAi9&3& zgfv=aY~8b^*fL{h{Bxo4R*cw@rnDY;9~)bOX7N#@<;NBNRY`7{E11tCTYA#}$Zpj> zHf*C+$5<=+%$>dTy=ecAijr(ARlQg}Hu1%crl%bl_w#MthmC_6n}Nd~=e zJC0;NJ+9vFw>$PmU*42{qnnvmAeg0zP=q?5ac-jRWlpI3(Bfw*pbaAD*PjeGuKHLsz) zg?7%+9ePsin9_Xc9tmOWUdrqI%$3$d>aBr(7KY{#ePvPFu!a*oi@8^W@pUE0XI2Sh z+)ibzf_yv51D>PIf;DB3{XsIS$Q1kY7lq7LCfEs>^aRqPy?x>j1^u?^y+jRBIJY)Os!pOoVu8dWgnlC1>GZ>`j>__o7bon;p9sA; z7YZBACGB9@W2OBt>?EB$3>87!e9sgXx>^kPMkmX)TP(SD;F@<4Qc=GjCsHpLVSHs0 zdkiYNGOg>Ctqx2S!f}%@Pj_OHA-k-u+#oO7K=A9u1QNraG;R9Y&f@OYsbibXT68eS zN*K1gQI>__C&s6@*0gO8{Ud=Adf768Y;JV*7_!Zw%bXE{ z)jU=0?X@C(Axii;gB|&oL}@L!y`2uaGT4lRY&+`5rC1AstHFoILmKsM+9vXBPkfzXv<#gIfEW(KF&0<#EyPeGXN=su7UT_6*U#~MU{Olm$`-10!? z*8SA+OdV!!6a@JD7%e;}EB6^N^aVT^winxIBP?2;Fc$$anQ79Xlo{UP9C|D{lvk^O zw7yE)ahAoh7|Lv(C3}C>TXWW#zNLtkFbFCaB?OQ+PI21aiT5+@vjG5l+C`nf&(dO7 ziybZQH(g*J`HJ?ox|~6^2L$X50N+nhzt58R8uh!NzE0(DlK2*hZ$oHvGt;82#kLlE zYL(PYQacQZ>P1qlk-}$`O9#1no^&iO&;t*4-HVib4V8>E#`1z#7R#0`Y^yp$kk9MB zFMP3tyJ{Y+oJ5duKg|ZgcLK)>LtPC+U>DK~%JneZ9Yo#YiDD49--&O$t1UQ&a%e1W z0qTrviYS*NXi-0)XN@tGw&QLWiD5R2Z%`8&xFX&}fRqIgzGn;bU#2+$L7%4b&!zRE zUj2+)x6$y}HKmxphI&FVrC^c?I=N-flg!vBayh1%M2{P-N?aZPYV^-U(|_d1imZZY z)|qHF06;Z$9v|r#7&iORn;@9;xo(_I#Mv@^GeT{JQ7uWKxD95dQvNu>vKIcKr4-KK@zB1QTG_AmibpfK!Uvv zSld=QP-CWcIaqzE8Bj)jkpqBO5%IXJAE|7Hq0GOsi^1y4&7iH)ESqttS?hv!G6$95 zZ^T+aL+PzongNK@G-#O%1^7_$m1#Rd;MGf~exckUA>66sD}fT|>VVStF<6>nBi_DK>yCX5t68{|!15N*50{Im^oWcx)Rq&uL5-dGq z3*35x7K&`a9I|km=`7S)c>P9mQ7z-YK#pfn(iQS3=5~xV>N!$6#N}SnE9^k05DZe^LI7oC?#Ka(0A`T>GA>+1F(VcO7{eB^xCD1bSZ|!Gin^!+ zwuq_K%>|U}fG*VRgXd0Tv@`1F%o2biMyS3+f+-qd#N;B3qxv=6*z zN60V|Nq{nuD~F#r*gh6dT$p=+kC(#7M`a&fHKJF{5Usu=Ev_)m0KSS#_Pn8zkN& z@e>knL7+J=+cbZc%*QG8`_z*1Xp8svC(|hBsNbVOMtvBzE6$qZWNA&)CekUA=d5lX zJbUVbPBQ6&x{SWm4@taAf>2)lhy-P8&Y9}RC~2Ew=Crh_UZ;laYb<8Tl5$TEcaba| zQy^WI5+E(=~*gp6jNB@XU;7${WRQpJt`q8=WMMly-MRvLTr6`2T+K?;~Wm1ar+HB-`v1Nc=Ndi@@D*X=6+G zIdcgFbV(eXy)xO<&(IH~`om&GjCeGx$wiJzh<*)sM)<&+Qfp+(_6n-+@yr^{=5zm;yXmWu!&sO!1s30F;WjTZ3xOAKu zC`Vb)fzci!6c9f4k;%D>gN zBk>GsdF5Vp=t@A`LzA8jEwn)nJgf~pe3?r1amhmndPNAbQ#~d(WzTy21uh(aTL`RX zL;Z~OH%Qzh@h%B+Dt15hGD_O+ML3GKvt5;Ihj*^p%hIj97k#MSqox}WO~3HU?!I!g{`2pI4K>VrFH z6aNZT?dJP^q=I5sNfb!C@XubFFyG@;%G*Ltcid5F6g(b|UjNl%Q~MfWN=92Wq None: + super().__init__() + self.lin1 = nn.Linear(embedding_dim, mlp_dim) + self.lin2 = nn.Linear(mlp_dim, embedding_dim) + self.act = act() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.lin2(self.act(self.lin1(x))) + + +# From https://github.com/facebookresearch/detectron2/blob/main/detectron2/layers/batch_norm.py # noqa +# Itself from https://github.com/facebookresearch/ConvNeXt/blob/d1fa8f6fef0a165b27399986cc2bdacc92777e40/models/convnext.py#L119 # noqa +class LayerNorm2d(nn.Module): + def __init__(self, num_channels: int, eps: float = 1e-6) -> None: + super().__init__() + self.weight = nn.Parameter(torch.ones(num_channels)) + self.bias = nn.Parameter(torch.zeros(num_channels)) + self.eps = eps + + def forward(self, x: torch.Tensor) -> torch.Tensor: + u = x.mean(1, keepdim=True) + s = (x - u).pow(2).mean(1, keepdim=True) + x = (x - u) / torch.sqrt(s + self.eps) + x = self.weight[:, None, None] * x + 
self.bias[:, None, None] + return x diff --git a/zed_utils/segment_anything/modeling/image_encoder.py b/zed_utils/segment_anything/modeling/image_encoder.py new file mode 100644 index 0000000..66351d9 --- /dev/null +++ b/zed_utils/segment_anything/modeling/image_encoder.py @@ -0,0 +1,395 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from typing import Optional, Tuple, Type + +from .common import LayerNorm2d, MLPBlock + + +# This class and its supporting functions below lightly adapted from the ViTDet backbone available at: https://github.com/facebookresearch/detectron2/blob/main/detectron2/modeling/backbone/vit.py # noqa +class ImageEncoderViT(nn.Module): + def __init__( + self, + img_size: int = 1024, + patch_size: int = 16, + in_chans: int = 3, + embed_dim: int = 768, + depth: int = 12, + num_heads: int = 12, + mlp_ratio: float = 4.0, + out_chans: int = 256, + qkv_bias: bool = True, + norm_layer: Type[nn.Module] = nn.LayerNorm, + act_layer: Type[nn.Module] = nn.GELU, + use_abs_pos: bool = True, + use_rel_pos: bool = False, + rel_pos_zero_init: bool = True, + window_size: int = 0, + global_attn_indexes: Tuple[int, ...] = (), + ) -> None: + """ + Args: + img_size (int): Input image size. + patch_size (int): Patch size. + in_chans (int): Number of input image channels. + embed_dim (int): Patch embedding dimension. + depth (int): Depth of ViT. + num_heads (int): Number of attention heads in each ViT block. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool): If True, add a learnable bias to query, key, value. + norm_layer (nn.Module): Normalization layer. + act_layer (nn.Module): Activation layer. + use_abs_pos (bool): If True, use absolute positional embeddings. + use_rel_pos (bool): If True, add relative positional embeddings to the attention map. + rel_pos_zero_init (bool): If True, zero initialize relative positional parameters. + window_size (int): Window size for window attention blocks. + global_attn_indexes (list): Indexes for blocks using global attention. + """ + super().__init__() + self.img_size = img_size + + self.patch_embed = PatchEmbed( + kernel_size=(patch_size, patch_size), + stride=(patch_size, patch_size), + in_chans=in_chans, + embed_dim=embed_dim, + ) + + self.pos_embed: Optional[nn.Parameter] = None + if use_abs_pos: + # Initialize absolute positional embedding with pretrain image size. 
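+            # One learnable embedding per patch token, stored in the same (1, H // patch, W // patch, C)
+            # layout as the patch grid so it can be added directly to the patch embeddings in forward().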
+ self.pos_embed = nn.Parameter( + torch.zeros(1, img_size // patch_size, img_size // patch_size, embed_dim) + ) + + self.blocks = nn.ModuleList() + for i in range(depth): + block = Block( + dim=embed_dim, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + norm_layer=norm_layer, + act_layer=act_layer, + use_rel_pos=use_rel_pos, + rel_pos_zero_init=rel_pos_zero_init, + window_size=window_size if i not in global_attn_indexes else 0, + input_size=(img_size // patch_size, img_size // patch_size), + ) + self.blocks.append(block) + + self.neck = nn.Sequential( + nn.Conv2d( + embed_dim, + out_chans, + kernel_size=1, + bias=False, + ), + LayerNorm2d(out_chans), + nn.Conv2d( + out_chans, + out_chans, + kernel_size=3, + padding=1, + bias=False, + ), + LayerNorm2d(out_chans), + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.patch_embed(x) + if self.pos_embed is not None: + x = x + self.pos_embed + + for blk in self.blocks: + x = blk(x) + + x = self.neck(x.permute(0, 3, 1, 2)) + + return x + + +class Block(nn.Module): + """Transformer blocks with support of window attention and residual propagation blocks""" + + def __init__( + self, + dim: int, + num_heads: int, + mlp_ratio: float = 4.0, + qkv_bias: bool = True, + norm_layer: Type[nn.Module] = nn.LayerNorm, + act_layer: Type[nn.Module] = nn.GELU, + use_rel_pos: bool = False, + rel_pos_zero_init: bool = True, + window_size: int = 0, + input_size: Optional[Tuple[int, int]] = None, + ) -> None: + """ + Args: + dim (int): Number of input channels. + num_heads (int): Number of attention heads in each ViT block. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool): If True, add a learnable bias to query, key, value. + norm_layer (nn.Module): Normalization layer. + act_layer (nn.Module): Activation layer. + use_rel_pos (bool): If True, add relative positional embeddings to the attention map. + rel_pos_zero_init (bool): If True, zero initialize relative positional parameters. + window_size (int): Window size for window attention blocks. If it equals 0, then + use global attention. + input_size (tuple(int, int) or None): Input resolution for calculating the relative + positional parameter size. + """ + super().__init__() + self.norm1 = norm_layer(dim) + self.attn = Attention( + dim, + num_heads=num_heads, + qkv_bias=qkv_bias, + use_rel_pos=use_rel_pos, + rel_pos_zero_init=rel_pos_zero_init, + input_size=input_size if window_size == 0 else (window_size, window_size), + ) + + self.norm2 = norm_layer(dim) + self.mlp = MLPBlock(embedding_dim=dim, mlp_dim=int(dim * mlp_ratio), act=act_layer) + + self.window_size = window_size + + def forward(self, x: torch.Tensor) -> torch.Tensor: + shortcut = x + x = self.norm1(x) + # Window partition + if self.window_size > 0: + H, W = x.shape[1], x.shape[2] + x, pad_hw = window_partition(x, self.window_size) + + x = self.attn(x) + # Reverse window partition + if self.window_size > 0: + x = window_unpartition(x, self.window_size, pad_hw, (H, W)) + + x = shortcut + x + x = x + self.mlp(self.norm2(x)) + + return x + + +class Attention(nn.Module): + """Multi-head Attention block with relative position embeddings.""" + + def __init__( + self, + dim: int, + num_heads: int = 8, + qkv_bias: bool = True, + use_rel_pos: bool = False, + rel_pos_zero_init: bool = True, + input_size: Optional[Tuple[int, int]] = None, + ) -> None: + """ + Args: + dim (int): Number of input channels. + num_heads (int): Number of attention heads. 
+ qkv_bias (bool): If True, add a learnable bias to query, key, value. + rel_pos (bool): If True, add relative positional embeddings to the attention map. + rel_pos_zero_init (bool): If True, zero initialize relative positional parameters. + input_size (tuple(int, int) or None): Input resolution for calculating the relative + positional parameter size. + """ + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = head_dim**-0.5 + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.proj = nn.Linear(dim, dim) + + self.use_rel_pos = use_rel_pos + if self.use_rel_pos: + assert ( + input_size is not None + ), "Input size must be provided if using relative positional encoding." + # initialize relative positional embeddings + self.rel_pos_h = nn.Parameter(torch.zeros(2 * input_size[0] - 1, head_dim)) + self.rel_pos_w = nn.Parameter(torch.zeros(2 * input_size[1] - 1, head_dim)) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + B, H, W, _ = x.shape + # qkv with shape (3, B, nHead, H * W, C) + qkv = self.qkv(x).reshape(B, H * W, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) + # q, k, v with shape (B * nHead, H * W, C) + q, k, v = qkv.reshape(3, B * self.num_heads, H * W, -1).unbind(0) + + attn = (q * self.scale) @ k.transpose(-2, -1) + + if self.use_rel_pos: + attn = add_decomposed_rel_pos(attn, q, self.rel_pos_h, self.rel_pos_w, (H, W), (H, W)) + + attn = attn.softmax(dim=-1) + x = (attn @ v).view(B, self.num_heads, H, W, -1).permute(0, 2, 3, 1, 4).reshape(B, H, W, -1) + x = self.proj(x) + + return x + + +def window_partition(x: torch.Tensor, window_size: int) -> Tuple[torch.Tensor, Tuple[int, int]]: + """ + Partition into non-overlapping windows with padding if needed. + Args: + x (tensor): input tokens with [B, H, W, C]. + window_size (int): window size. + + Returns: + windows: windows after partition with [B * num_windows, window_size, window_size, C]. + (Hp, Wp): padded height and width before partition + """ + B, H, W, C = x.shape + + pad_h = (window_size - H % window_size) % window_size + pad_w = (window_size - W % window_size) % window_size + if pad_h > 0 or pad_w > 0: + x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h)) + Hp, Wp = H + pad_h, W + pad_w + + x = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C) + windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) + return windows, (Hp, Wp) + + +def window_unpartition( + windows: torch.Tensor, window_size: int, pad_hw: Tuple[int, int], hw: Tuple[int, int] +) -> torch.Tensor: + """ + Window unpartition into original sequences and removing padding. + Args: + windows (tensor): input tokens with [B * num_windows, window_size, window_size, C]. + window_size (int): window size. + pad_hw (Tuple): padded height and width (Hp, Wp). + hw (Tuple): original height and width (H, W) before padding. + + Returns: + x: unpartitioned sequences with [B, H, W, C]. + """ + Hp, Wp = pad_hw + H, W = hw + B = windows.shape[0] // (Hp * Wp // window_size // window_size) + x = windows.view(B, Hp // window_size, Wp // window_size, window_size, window_size, -1) + x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, -1) + + if Hp > H or Wp > W: + x = x[:, :H, :W, :].contiguous() + return x + + +def get_rel_pos(q_size: int, k_size: int, rel_pos: torch.Tensor) -> torch.Tensor: + """ + Get relative positional embeddings according to the relative positions of + query and key sizes. + Args: + q_size (int): size of query q. + k_size (int): size of key k. 
+ rel_pos (Tensor): relative position embeddings (L, C). + + Returns: + Extracted positional embeddings according to relative positions. + """ + max_rel_dist = int(2 * max(q_size, k_size) - 1) + # Interpolate rel pos if needed. + if rel_pos.shape[0] != max_rel_dist: + # Interpolate rel pos. + rel_pos_resized = F.interpolate( + rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1), + size=max_rel_dist, + mode="linear", + ) + rel_pos_resized = rel_pos_resized.reshape(-1, max_rel_dist).permute(1, 0) + else: + rel_pos_resized = rel_pos + + # Scale the coords with short length if shapes for q and k are different. + q_coords = torch.arange(q_size)[:, None] * max(k_size / q_size, 1.0) + k_coords = torch.arange(k_size)[None, :] * max(q_size / k_size, 1.0) + relative_coords = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0) + + return rel_pos_resized[relative_coords.long()] + + +def add_decomposed_rel_pos( + attn: torch.Tensor, + q: torch.Tensor, + rel_pos_h: torch.Tensor, + rel_pos_w: torch.Tensor, + q_size: Tuple[int, int], + k_size: Tuple[int, int], +) -> torch.Tensor: + """ + Calculate decomposed Relative Positional Embeddings from :paper:`mvitv2`. + https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py # noqa B950 + Args: + attn (Tensor): attention map. + q (Tensor): query q in the attention layer with shape (B, q_h * q_w, C). + rel_pos_h (Tensor): relative position embeddings (Lh, C) for height axis. + rel_pos_w (Tensor): relative position embeddings (Lw, C) for width axis. + q_size (Tuple): spatial sequence size of query q with (q_h, q_w). + k_size (Tuple): spatial sequence size of key k with (k_h, k_w). + + Returns: + attn (Tensor): attention map with added relative positional embeddings. + """ + q_h, q_w = q_size + k_h, k_w = k_size + Rh = get_rel_pos(q_h, k_h, rel_pos_h) + Rw = get_rel_pos(q_w, k_w, rel_pos_w) + + B, _, dim = q.shape + r_q = q.reshape(B, q_h, q_w, dim) + rel_h = torch.einsum("bhwc,hkc->bhwk", r_q, Rh) + rel_w = torch.einsum("bhwc,wkc->bhwk", r_q, Rw) + + attn = ( + attn.view(B, q_h, q_w, k_h, k_w) + rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :] + ).view(B, q_h * q_w, k_h * k_w) + + return attn + + +class PatchEmbed(nn.Module): + """ + Image to Patch Embedding. + """ + + def __init__( + self, + kernel_size: Tuple[int, int] = (16, 16), + stride: Tuple[int, int] = (16, 16), + padding: Tuple[int, int] = (0, 0), + in_chans: int = 3, + embed_dim: int = 768, + ) -> None: + """ + Args: + kernel_size (Tuple): kernel size of the projection layer. + stride (Tuple): stride of the projection layer. + padding (Tuple): padding size of the projection layer. + in_chans (int): Number of input image channels. + embed_dim (int): Patch embedding dimension. + """ + super().__init__() + + self.proj = nn.Conv2d( + in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.proj(x) + # B C H W -> B H W C + x = x.permute(0, 2, 3, 1) + return x diff --git a/zed_utils/segment_anything/modeling/mask_decoder.py b/zed_utils/segment_anything/modeling/mask_decoder.py new file mode 100644 index 0000000..5d2fdb0 --- /dev/null +++ b/zed_utils/segment_anything/modeling/mask_decoder.py @@ -0,0 +1,176 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
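An aside on image_encoder.py above: the following is a minimal, hedged smoke-test sketch of the encoder's input/output contract. The import path assumes the vendored zed_utils/ directory is on sys.path so the package resolves as segment_anything; the input tensor is a placeholder, and the hyperparameters are deliberately shrunk from the ViT-B sized defaults (depth=12, embed_dim=768) so the sketch runs quickly with random weights.

import torch
from segment_anything.modeling.image_encoder import ImageEncoderViT

# Shrunk-down, untrained encoder; patch size and out_chans keep their defaults,
# so the output grid is still 64x64 with 256 channels for a 1024x1024 input.
encoder = ImageEncoderViT(depth=2, embed_dim=256, num_heads=8)
image = torch.zeros(1, 3, 1024, 1024)      # placeholder batch of one padded RGB image
with torch.no_grad():
    embedding = encoder(image)
print(embedding.shape)                      # torch.Size([1, 256, 64, 64])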
+ +import torch +from torch import nn +from torch.nn import functional as F + +from typing import List, Tuple, Type + +from .common import LayerNorm2d + + +class MaskDecoder(nn.Module): + def __init__( + self, + *, + transformer_dim: int, + transformer: nn.Module, + num_multimask_outputs: int = 3, + activation: Type[nn.Module] = nn.GELU, + iou_head_depth: int = 3, + iou_head_hidden_dim: int = 256, + ) -> None: + """ + Predicts masks given an image and prompt embeddings, using a + transformer architecture. + + Arguments: + transformer_dim (int): the channel dimension of the transformer + transformer (nn.Module): the transformer used to predict masks + num_multimask_outputs (int): the number of masks to predict + when disambiguating masks + activation (nn.Module): the type of activation to use when + upscaling masks + iou_head_depth (int): the depth of the MLP used to predict + mask quality + iou_head_hidden_dim (int): the hidden dimension of the MLP + used to predict mask quality + """ + super().__init__() + self.transformer_dim = transformer_dim + self.transformer = transformer + + self.num_multimask_outputs = num_multimask_outputs + + self.iou_token = nn.Embedding(1, transformer_dim) + self.num_mask_tokens = num_multimask_outputs + 1 + self.mask_tokens = nn.Embedding(self.num_mask_tokens, transformer_dim) + + self.output_upscaling = nn.Sequential( + nn.ConvTranspose2d(transformer_dim, transformer_dim // 4, kernel_size=2, stride=2), + LayerNorm2d(transformer_dim // 4), + activation(), + nn.ConvTranspose2d(transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2), + activation(), + ) + self.output_hypernetworks_mlps = nn.ModuleList( + [ + MLP(transformer_dim, transformer_dim, transformer_dim // 8, 3) + for i in range(self.num_mask_tokens) + ] + ) + + self.iou_prediction_head = MLP( + transformer_dim, iou_head_hidden_dim, self.num_mask_tokens, iou_head_depth + ) + + def forward( + self, + image_embeddings: torch.Tensor, + image_pe: torch.Tensor, + sparse_prompt_embeddings: torch.Tensor, + dense_prompt_embeddings: torch.Tensor, + multimask_output: bool, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Predict masks given image and prompt embeddings. + + Arguments: + image_embeddings (torch.Tensor): the embeddings from the image encoder + image_pe (torch.Tensor): positional encoding with the shape of image_embeddings + sparse_prompt_embeddings (torch.Tensor): the embeddings of the points and boxes + dense_prompt_embeddings (torch.Tensor): the embeddings of the mask inputs + multimask_output (bool): Whether to return multiple masks or a single + mask. + + Returns: + torch.Tensor: batched predicted masks + torch.Tensor: batched predictions of mask quality + """ + masks, iou_pred = self.predict_masks( + image_embeddings=image_embeddings, + image_pe=image_pe, + sparse_prompt_embeddings=sparse_prompt_embeddings, + dense_prompt_embeddings=dense_prompt_embeddings, + ) + + # Select the correct mask or masks for output + if multimask_output: + mask_slice = slice(1, None) + else: + mask_slice = slice(0, 1) + masks = masks[:, mask_slice, :, :] + iou_pred = iou_pred[:, mask_slice] + + # Prepare output + return masks, iou_pred + + def predict_masks( + self, + image_embeddings: torch.Tensor, + image_pe: torch.Tensor, + sparse_prompt_embeddings: torch.Tensor, + dense_prompt_embeddings: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Predicts masks. 
See 'forward' for more details.""" + # Concatenate output tokens + output_tokens = torch.cat([self.iou_token.weight, self.mask_tokens.weight], dim=0) + output_tokens = output_tokens.unsqueeze(0).expand(sparse_prompt_embeddings.size(0), -1, -1) + tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=1) + + # Expand per-image data in batch direction to be per-mask + src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0) + src = src + dense_prompt_embeddings + pos_src = torch.repeat_interleave(image_pe, tokens.shape[0], dim=0) + b, c, h, w = src.shape + + # Run the transformer + hs, src = self.transformer(src, pos_src, tokens) + iou_token_out = hs[:, 0, :] + mask_tokens_out = hs[:, 1 : (1 + self.num_mask_tokens), :] + + # Upscale mask embeddings and predict masks using the mask tokens + src = src.transpose(1, 2).view(b, c, h, w) + upscaled_embedding = self.output_upscaling(src) + hyper_in_list: List[torch.Tensor] = [] + for i in range(self.num_mask_tokens): + hyper_in_list.append(self.output_hypernetworks_mlps[i](mask_tokens_out[:, i, :])) + hyper_in = torch.stack(hyper_in_list, dim=1) + b, c, h, w = upscaled_embedding.shape + masks = (hyper_in @ upscaled_embedding.view(b, c, h * w)).view(b, -1, h, w) + + # Generate mask quality predictions + iou_pred = self.iou_prediction_head(iou_token_out) + + return masks, iou_pred + + +# Lightly adapted from +# https://github.com/facebookresearch/MaskFormer/blob/main/mask_former/modeling/transformer/transformer_predictor.py # noqa +class MLP(nn.Module): + def __init__( + self, + input_dim: int, + hidden_dim: int, + output_dim: int, + num_layers: int, + sigmoid_output: bool = False, + ) -> None: + super().__init__() + self.num_layers = num_layers + h = [hidden_dim] * (num_layers - 1) + self.layers = nn.ModuleList( + nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]) + ) + self.sigmoid_output = sigmoid_output + + def forward(self, x): + for i, layer in enumerate(self.layers): + x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x) + if self.sigmoid_output: + x = F.sigmoid(x) + return x diff --git a/zed_utils/segment_anything/modeling/prompt_encoder.py b/zed_utils/segment_anything/modeling/prompt_encoder.py new file mode 100644 index 0000000..c3143f4 --- /dev/null +++ b/zed_utils/segment_anything/modeling/prompt_encoder.py @@ -0,0 +1,214 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import numpy as np +import torch +from torch import nn + +from typing import Any, Optional, Tuple, Type + +from .common import LayerNorm2d + + +class PromptEncoder(nn.Module): + def __init__( + self, + embed_dim: int, + image_embedding_size: Tuple[int, int], + input_image_size: Tuple[int, int], + mask_in_chans: int, + activation: Type[nn.Module] = nn.GELU, + ) -> None: + """ + Encodes prompts for input to SAM's mask decoder. + + Arguments: + embed_dim (int): The prompts' embedding dimension + image_embedding_size (tuple(int, int)): The spatial size of the + image embedding, as (H, W). + input_image_size (int): The padded size of the image as input + to the image encoder, as (H, W). + mask_in_chans (int): The number of hidden channels used for + encoding input masks. + activation (nn.Module): The activation to use when encoding + input masks. 
+ """ + super().__init__() + self.embed_dim = embed_dim + self.input_image_size = input_image_size + self.image_embedding_size = image_embedding_size + self.pe_layer = PositionEmbeddingRandom(embed_dim // 2) + + self.num_point_embeddings: int = 4 # pos/neg point + 2 box corners + point_embeddings = [nn.Embedding(1, embed_dim) for i in range(self.num_point_embeddings)] + self.point_embeddings = nn.ModuleList(point_embeddings) + self.not_a_point_embed = nn.Embedding(1, embed_dim) + + self.mask_input_size = (4 * image_embedding_size[0], 4 * image_embedding_size[1]) + self.mask_downscaling = nn.Sequential( + nn.Conv2d(1, mask_in_chans // 4, kernel_size=2, stride=2), + LayerNorm2d(mask_in_chans // 4), + activation(), + nn.Conv2d(mask_in_chans // 4, mask_in_chans, kernel_size=2, stride=2), + LayerNorm2d(mask_in_chans), + activation(), + nn.Conv2d(mask_in_chans, embed_dim, kernel_size=1), + ) + self.no_mask_embed = nn.Embedding(1, embed_dim) + + def get_dense_pe(self) -> torch.Tensor: + """ + Returns the positional encoding used to encode point prompts, + applied to a dense set of points the shape of the image encoding. + + Returns: + torch.Tensor: Positional encoding with shape + 1x(embed_dim)x(embedding_h)x(embedding_w) + """ + return self.pe_layer(self.image_embedding_size).unsqueeze(0) + + def _embed_points( + self, + points: torch.Tensor, + labels: torch.Tensor, + pad: bool, + ) -> torch.Tensor: + """Embeds point prompts.""" + points = points + 0.5 # Shift to center of pixel + if pad: + padding_point = torch.zeros((points.shape[0], 1, 2), device=points.device) + padding_label = -torch.ones((labels.shape[0], 1), device=labels.device) + points = torch.cat([points, padding_point], dim=1) + labels = torch.cat([labels, padding_label], dim=1) + point_embedding = self.pe_layer.forward_with_coords(points, self.input_image_size) + point_embedding[labels == -1] = 0.0 + point_embedding[labels == -1] += self.not_a_point_embed.weight + point_embedding[labels == 0] += self.point_embeddings[0].weight + point_embedding[labels == 1] += self.point_embeddings[1].weight + return point_embedding + + def _embed_boxes(self, boxes: torch.Tensor) -> torch.Tensor: + """Embeds box prompts.""" + boxes = boxes + 0.5 # Shift to center of pixel + coords = boxes.reshape(-1, 2, 2) + corner_embedding = self.pe_layer.forward_with_coords(coords, self.input_image_size) + corner_embedding[:, 0, :] += self.point_embeddings[2].weight + corner_embedding[:, 1, :] += self.point_embeddings[3].weight + return corner_embedding + + def _embed_masks(self, masks: torch.Tensor) -> torch.Tensor: + """Embeds mask inputs.""" + mask_embedding = self.mask_downscaling(masks) + return mask_embedding + + def _get_batch_size( + self, + points: Optional[Tuple[torch.Tensor, torch.Tensor]], + boxes: Optional[torch.Tensor], + masks: Optional[torch.Tensor], + ) -> int: + """ + Gets the batch size of the output given the batch size of the input prompts. + """ + if points is not None: + return points[0].shape[0] + elif boxes is not None: + return boxes.shape[0] + elif masks is not None: + return masks.shape[0] + else: + return 1 + + def _get_device(self) -> torch.device: + return self.point_embeddings[0].weight.device + + def forward( + self, + points: Optional[Tuple[torch.Tensor, torch.Tensor]], + boxes: Optional[torch.Tensor], + masks: Optional[torch.Tensor], + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Embeds different types of prompts, returning both sparse and dense + embeddings. 
+ + Arguments: + points (tuple(torch.Tensor, torch.Tensor) or none): point coordinates + and labels to embed. + boxes (torch.Tensor or none): boxes to embed + masks (torch.Tensor or none): masks to embed + + Returns: + torch.Tensor: sparse embeddings for the points and boxes, with shape + BxNx(embed_dim), where N is determined by the number of input points + and boxes. + torch.Tensor: dense embeddings for the masks, in the shape + Bx(embed_dim)x(embed_H)x(embed_W) + """ + bs = self._get_batch_size(points, boxes, masks) + sparse_embeddings = torch.empty((bs, 0, self.embed_dim), device=self._get_device()) + if points is not None: + coords, labels = points + point_embeddings = self._embed_points(coords, labels, pad=(boxes is None)) + sparse_embeddings = torch.cat([sparse_embeddings, point_embeddings], dim=1) + if boxes is not None: + box_embeddings = self._embed_boxes(boxes) + sparse_embeddings = torch.cat([sparse_embeddings, box_embeddings], dim=1) + + if masks is not None: + dense_embeddings = self._embed_masks(masks) + else: + dense_embeddings = self.no_mask_embed.weight.reshape(1, -1, 1, 1).expand( + bs, -1, self.image_embedding_size[0], self.image_embedding_size[1] + ) + + return sparse_embeddings, dense_embeddings + + +class PositionEmbeddingRandom(nn.Module): + """ + Positional encoding using random spatial frequencies. + """ + + def __init__(self, num_pos_feats: int = 64, scale: Optional[float] = None) -> None: + super().__init__() + if scale is None or scale <= 0.0: + scale = 1.0 + self.register_buffer( + "positional_encoding_gaussian_matrix", + scale * torch.randn((2, num_pos_feats)), + ) + + def _pe_encoding(self, coords: torch.Tensor) -> torch.Tensor: + """Positionally encode points that are normalized to [0,1].""" + # assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shape + coords = 2 * coords - 1 + coords = coords @ self.positional_encoding_gaussian_matrix + coords = 2 * np.pi * coords + # outputs d_1 x ... x d_n x C shape + return torch.cat([torch.sin(coords), torch.cos(coords)], dim=-1) + + def forward(self, size: Tuple[int, int]) -> torch.Tensor: + """Generate positional encoding for a grid of the specified size.""" + h, w = size + device: Any = self.positional_encoding_gaussian_matrix.device + grid = torch.ones((h, w), device=device, dtype=torch.float32) + y_embed = grid.cumsum(dim=0) - 0.5 + x_embed = grid.cumsum(dim=1) - 0.5 + y_embed = y_embed / h + x_embed = x_embed / w + + pe = self._pe_encoding(torch.stack([x_embed, y_embed], dim=-1)) + return pe.permute(2, 0, 1) # C x H x W + + def forward_with_coords( + self, coords_input: torch.Tensor, image_size: Tuple[int, int] + ) -> torch.Tensor: + """Positionally encode points that are not normalized to [0,1].""" + coords = coords_input.clone() + coords[:, :, 0] = coords[:, :, 0] / image_size[1] + coords[:, :, 1] = coords[:, :, 1] / image_size[0] + return self._pe_encoding(coords.to(torch.float)) # B x N x C diff --git a/zed_utils/segment_anything/modeling/sam.py b/zed_utils/segment_anything/modeling/sam.py new file mode 100644 index 0000000..8074cff --- /dev/null +++ b/zed_utils/segment_anything/modeling/sam.py @@ -0,0 +1,174 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
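An aside on prompt_encoder.py above: this hedged sketch makes the sparse/dense embedding shapes concrete for a single foreground click, using SAM-style settings (256-dim embeddings, a 64x64 embedding grid for a 1024x1024 model input). The import path, the click coordinates, and the untrained weights are all assumptions; only the shapes are the point.

import torch
from segment_anything.modeling.prompt_encoder import PromptEncoder

prompt_encoder = PromptEncoder(
    embed_dim=256,
    image_embedding_size=(64, 64),
    input_image_size=(1024, 1024),
    mask_in_chans=16,
)
coords = torch.tensor([[[512.0, 512.0]]])   # B=1, N=1 click, (x, y) in input-frame pixels
labels = torch.tensor([[1]])                # 1 = foreground point
sparse, dense = prompt_encoder(points=(coords, labels), boxes=None, masks=None)
print(sparse.shape)   # torch.Size([1, 2, 256]): the click plus one padding "not a point" token
print(dense.shape)    # torch.Size([1, 256, 64, 64]): the learned no-mask embedding, broadcast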
+ +import torch +from torch import nn +from torch.nn import functional as F + +from typing import Any, Dict, List, Tuple + +from .image_encoder import ImageEncoderViT +from .mask_decoder import MaskDecoder +from .prompt_encoder import PromptEncoder + + +class Sam(nn.Module): + mask_threshold: float = 0.0 + image_format: str = "RGB" + + def __init__( + self, + image_encoder: ImageEncoderViT, + prompt_encoder: PromptEncoder, + mask_decoder: MaskDecoder, + pixel_mean: List[float] = [123.675, 116.28, 103.53], + pixel_std: List[float] = [58.395, 57.12, 57.375], + ) -> None: + """ + SAM predicts object masks from an image and input prompts. + + Arguments: + image_encoder (ImageEncoderViT): The backbone used to encode the + image into image embeddings that allow for efficient mask prediction. + prompt_encoder (PromptEncoder): Encodes various types of input prompts. + mask_decoder (MaskDecoder): Predicts masks from the image embeddings + and encoded prompts. + pixel_mean (list(float)): Mean values for normalizing pixels in the input image. + pixel_std (list(float)): Std values for normalizing pixels in the input image. + """ + super().__init__() + self.image_encoder = image_encoder + self.prompt_encoder = prompt_encoder + self.mask_decoder = mask_decoder + self.register_buffer("pixel_mean", torch.Tensor(pixel_mean).view(-1, 1, 1), False) + self.register_buffer("pixel_std", torch.Tensor(pixel_std).view(-1, 1, 1), False) + + @property + def device(self) -> Any: + return self.pixel_mean.device + + @torch.no_grad() + def forward( + self, + batched_input: List[Dict[str, Any]], + multimask_output: bool, + ) -> List[Dict[str, torch.Tensor]]: + """ + Predicts masks end-to-end from provided images and prompts. + If prompts are not known in advance, using SamPredictor is + recommended over calling the model directly. + + Arguments: + batched_input (list(dict)): A list over input images, each a + dictionary with the following keys. A prompt key can be + excluded if it is not present. + 'image': The image as a torch tensor in 3xHxW format, + already transformed for input to the model. + 'original_size': (tuple(int, int)) The original size of + the image before transformation, as (H, W). + 'point_coords': (torch.Tensor) Batched point prompts for + this image, with shape BxNx2. Already transformed to the + input frame of the model. + 'point_labels': (torch.Tensor) Batched labels for point prompts, + with shape BxN. + 'boxes': (torch.Tensor) Batched box inputs, with shape Bx4. + Already transformed to the input frame of the model. + 'mask_inputs': (torch.Tensor) Batched mask inputs to the model, + in the form Bx1xHxW. + multimask_output (bool): Whether the model should predict multiple + disambiguating masks, or return a single mask. + + Returns: + (list(dict)): A list over input images, where each element is + as dictionary with the following keys. + 'masks': (torch.Tensor) Batched binary mask predictions, + with shape BxCxHxW, where B is the number of input prompts, + C is determined by multimask_output, and (H, W) is the + original size of the image. + 'iou_predictions': (torch.Tensor) The model's predictions + of mask quality, in shape BxC. + 'low_res_logits': (torch.Tensor) Low resolution logits with + shape BxCxHxW, where H=W=256. Can be passed as mask input + to subsequent iterations of prediction. 
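+ 
+ Example (illustrative sketch only; the sizes, box values, and the assumption
+ that `sam` is a trained instance of this class are placeholders):
+     batched_input = [{
+         "image": torch.zeros(3, 1024, 683),                  # long side already resized to 1024
+         "original_size": (1500, 1000),                       # (H, W) before resizing
+         "boxes": torch.tensor([[100., 100., 600., 650.]]),   # Bx4 XYXY, in the resized frame
+     }]
+     outputs = sam(batched_input, multimask_output=False)
+     outputs[0]["masks"].shape   # torch.Size([1, 1, 1500, 1000])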
+ """ + input_images = torch.stack([self.preprocess(x["image"]) for x in batched_input], dim=0) + image_embeddings = self.image_encoder(input_images) + + outputs = [] + for image_record, curr_embedding in zip(batched_input, image_embeddings): + if "point_coords" in image_record: + points = (image_record["point_coords"], image_record["point_labels"]) + else: + points = None + sparse_embeddings, dense_embeddings = self.prompt_encoder( + points=points, + boxes=image_record.get("boxes", None), + masks=image_record.get("mask_inputs", None), + ) + low_res_masks, iou_predictions = self.mask_decoder( + image_embeddings=curr_embedding.unsqueeze(0), + image_pe=self.prompt_encoder.get_dense_pe(), + sparse_prompt_embeddings=sparse_embeddings, + dense_prompt_embeddings=dense_embeddings, + multimask_output=multimask_output, + ) + masks = self.postprocess_masks( + low_res_masks, + input_size=image_record["image"].shape[-2:], + original_size=image_record["original_size"], + ) + masks = masks > self.mask_threshold + outputs.append( + { + "masks": masks, + "iou_predictions": iou_predictions, + "low_res_logits": low_res_masks, + } + ) + return outputs + + def postprocess_masks( + self, + masks: torch.Tensor, + input_size: Tuple[int, ...], + original_size: Tuple[int, ...], + ) -> torch.Tensor: + """ + Remove padding and upscale masks to the original image size. + + Arguments: + masks (torch.Tensor): Batched masks from the mask_decoder, + in BxCxHxW format. + input_size (tuple(int, int)): The size of the image input to the + model, in (H, W) format. Used to remove padding. + original_size (tuple(int, int)): The original size of the image + before resizing for input to the model, in (H, W) format. + + Returns: + (torch.Tensor): Batched masks in BxCxHxW format, where (H, W) + is given by original_size. + """ + masks = F.interpolate( + masks, + (self.image_encoder.img_size, self.image_encoder.img_size), + mode="bilinear", + align_corners=False, + ) + masks = masks[..., : input_size[0], : input_size[1]] + masks = F.interpolate(masks, original_size, mode="bilinear", align_corners=False) + return masks + + def preprocess(self, x: torch.Tensor) -> torch.Tensor: + """Normalize pixel values and pad to a square input.""" + # Normalize colors + x = (x - self.pixel_mean) / self.pixel_std + + # Pad + h, w = x.shape[-2:] + padh = self.image_encoder.img_size - h + padw = self.image_encoder.img_size - w + x = F.pad(x, (0, padw, 0, padh)) + return x diff --git a/zed_utils/segment_anything/modeling/transformer.py b/zed_utils/segment_anything/modeling/transformer.py new file mode 100644 index 0000000..28fafea --- /dev/null +++ b/zed_utils/segment_anything/modeling/transformer.py @@ -0,0 +1,240 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import torch +from torch import Tensor, nn + +import math +from typing import Tuple, Type + +from .common import MLPBlock + + +class TwoWayTransformer(nn.Module): + def __init__( + self, + depth: int, + embedding_dim: int, + num_heads: int, + mlp_dim: int, + activation: Type[nn.Module] = nn.ReLU, + attention_downsample_rate: int = 2, + ) -> None: + """ + A transformer decoder that attends to an input image using + queries whose positional embedding is supplied. 
+ + Args: + depth (int): number of layers in the transformer + embedding_dim (int): the channel dimension for the input embeddings + num_heads (int): the number of heads for multihead attention. Must + divide embedding_dim + mlp_dim (int): the channel dimension internal to the MLP block + activation (nn.Module): the activation to use in the MLP block + """ + super().__init__() + self.depth = depth + self.embedding_dim = embedding_dim + self.num_heads = num_heads + self.mlp_dim = mlp_dim + self.layers = nn.ModuleList() + + for i in range(depth): + self.layers.append( + TwoWayAttentionBlock( + embedding_dim=embedding_dim, + num_heads=num_heads, + mlp_dim=mlp_dim, + activation=activation, + attention_downsample_rate=attention_downsample_rate, + skip_first_layer_pe=(i == 0), + ) + ) + + self.final_attn_token_to_image = Attention( + embedding_dim, num_heads, downsample_rate=attention_downsample_rate + ) + self.norm_final_attn = nn.LayerNorm(embedding_dim) + + def forward( + self, + image_embedding: Tensor, + image_pe: Tensor, + point_embedding: Tensor, + ) -> Tuple[Tensor, Tensor]: + """ + Args: + image_embedding (torch.Tensor): image to attend to. Should be shape + B x embedding_dim x h x w for any h and w. + image_pe (torch.Tensor): the positional encoding to add to the image. Must + have the same shape as image_embedding. + point_embedding (torch.Tensor): the embedding to add to the query points. + Must have shape B x N_points x embedding_dim for any N_points. + + Returns: + torch.Tensor: the processed point_embedding + torch.Tensor: the processed image_embedding + """ + # BxCxHxW -> BxHWxC == B x N_image_tokens x C + bs, c, h, w = image_embedding.shape + image_embedding = image_embedding.flatten(2).permute(0, 2, 1) + image_pe = image_pe.flatten(2).permute(0, 2, 1) + + # Prepare queries + queries = point_embedding + keys = image_embedding + + # Apply transformer blocks and final layernorm + for layer in self.layers: + queries, keys = layer( + queries=queries, + keys=keys, + query_pe=point_embedding, + key_pe=image_pe, + ) + + # Apply the final attention layer from the points to the image + q = queries + point_embedding + k = keys + image_pe + attn_out = self.final_attn_token_to_image(q=q, k=k, v=keys) + queries = queries + attn_out + queries = self.norm_final_attn(queries) + + return queries, keys + + +class TwoWayAttentionBlock(nn.Module): + def __init__( + self, + embedding_dim: int, + num_heads: int, + mlp_dim: int = 2048, + activation: Type[nn.Module] = nn.ReLU, + attention_downsample_rate: int = 2, + skip_first_layer_pe: bool = False, + ) -> None: + """ + A transformer block with four layers: (1) self-attention of sparse + inputs, (2) cross attention of sparse inputs to dense inputs, (3) mlp + block on sparse inputs, and (4) cross attention of dense inputs to sparse + inputs. 
+ + Arguments: + embedding_dim (int): the channel dimension of the embeddings + num_heads (int): the number of heads in the attention layers + mlp_dim (int): the hidden dimension of the mlp block + activation (nn.Module): the activation of the mlp block + skip_first_layer_pe (bool): skip the PE on the first layer + """ + super().__init__() + self.self_attn = Attention(embedding_dim, num_heads) + self.norm1 = nn.LayerNorm(embedding_dim) + + self.cross_attn_token_to_image = Attention( + embedding_dim, num_heads, downsample_rate=attention_downsample_rate + ) + self.norm2 = nn.LayerNorm(embedding_dim) + + self.mlp = MLPBlock(embedding_dim, mlp_dim, activation) + self.norm3 = nn.LayerNorm(embedding_dim) + + self.norm4 = nn.LayerNorm(embedding_dim) + self.cross_attn_image_to_token = Attention( + embedding_dim, num_heads, downsample_rate=attention_downsample_rate + ) + + self.skip_first_layer_pe = skip_first_layer_pe + + def forward( + self, queries: Tensor, keys: Tensor, query_pe: Tensor, key_pe: Tensor + ) -> Tuple[Tensor, Tensor]: + # Self attention block + if self.skip_first_layer_pe: + queries = self.self_attn(q=queries, k=queries, v=queries) + else: + q = queries + query_pe + attn_out = self.self_attn(q=q, k=q, v=queries) + queries = queries + attn_out + queries = self.norm1(queries) + + # Cross attention block, tokens attending to image embedding + q = queries + query_pe + k = keys + key_pe + attn_out = self.cross_attn_token_to_image(q=q, k=k, v=keys) + queries = queries + attn_out + queries = self.norm2(queries) + + # MLP block + mlp_out = self.mlp(queries) + queries = queries + mlp_out + queries = self.norm3(queries) + + # Cross attention block, image embedding attending to tokens + q = queries + query_pe + k = keys + key_pe + attn_out = self.cross_attn_image_to_token(q=k, k=q, v=queries) + keys = keys + attn_out + keys = self.norm4(keys) + + return queries, keys + + +class Attention(nn.Module): + """ + An attention layer that allows for downscaling the size of the embedding + after projection to queries, keys, and values. + """ + + def __init__( + self, + embedding_dim: int, + num_heads: int, + downsample_rate: int = 1, + ) -> None: + super().__init__() + self.embedding_dim = embedding_dim + self.internal_dim = embedding_dim // downsample_rate + self.num_heads = num_heads + assert self.internal_dim % num_heads == 0, "num_heads must divide embedding_dim." 
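+ # For example, with embedding_dim=256 and downsample_rate=2 (the rate the
+ # TwoWayTransformer above passes to its cross-attention layers), queries, keys
+ # and values are projected to an internal width of 128, and out_proj maps the
+ # attended result back to the full 256-dim embedding.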
+ + self.q_proj = nn.Linear(embedding_dim, self.internal_dim) + self.k_proj = nn.Linear(embedding_dim, self.internal_dim) + self.v_proj = nn.Linear(embedding_dim, self.internal_dim) + self.out_proj = nn.Linear(self.internal_dim, embedding_dim) + + def _separate_heads(self, x: Tensor, num_heads: int) -> Tensor: + b, n, c = x.shape + x = x.reshape(b, n, num_heads, c // num_heads) + return x.transpose(1, 2) # B x N_heads x N_tokens x C_per_head + + def _recombine_heads(self, x: Tensor) -> Tensor: + b, n_heads, n_tokens, c_per_head = x.shape + x = x.transpose(1, 2) + return x.reshape(b, n_tokens, n_heads * c_per_head) # B x N_tokens x C + + def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor: + # Input projections + q = self.q_proj(q) + k = self.k_proj(k) + v = self.v_proj(v) + + # Separate into heads + q = self._separate_heads(q, self.num_heads) + k = self._separate_heads(k, self.num_heads) + v = self._separate_heads(v, self.num_heads) + + # Attention + _, _, _, c_per_head = q.shape + attn = q @ k.permute(0, 1, 3, 2) # B x N_heads x N_tokens x N_tokens + attn = attn / math.sqrt(c_per_head) + attn = torch.softmax(attn, dim=-1) + + # Get output + out = attn @ v + out = self._recombine_heads(out) + out = self.out_proj(out) + + return out diff --git a/zed_utils/segment_anything/predictor.py b/zed_utils/segment_anything/predictor.py new file mode 100644 index 0000000..8a6e6d8 --- /dev/null +++ b/zed_utils/segment_anything/predictor.py @@ -0,0 +1,269 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import numpy as np +import torch + +from segment_anything.modeling import Sam + +from typing import Optional, Tuple + +from .utils.transforms import ResizeLongestSide + + +class SamPredictor: + def __init__( + self, + sam_model: Sam, + ) -> None: + """ + Uses SAM to calculate the image embedding for an image, and then + allow repeated, efficient mask prediction given prompts. + + Arguments: + sam_model (Sam): The model to use for mask prediction. + """ + super().__init__() + self.model = sam_model + self.transform = ResizeLongestSide(sam_model.image_encoder.img_size) + self.reset_image() + + def set_image( + self, + image: np.ndarray, + image_format: str = "RGB", + ) -> None: + """ + Calculates the image embeddings for the provided image, allowing + masks to be predicted with the 'predict' method. + + Arguments: + image (np.ndarray): The image for calculating masks. Expects an + image in HWC uint8 format, with pixel values in [0, 255]. + image_format (str): The color format of the image, in ['RGB', 'BGR']. + """ + assert image_format in [ + "RGB", + "BGR", + ], f"image_format must be in ['RGB', 'BGR'], is {image_format}." + if image_format != self.model.image_format: + image = image[..., ::-1] + + # Transform the image to the form expected by the model + input_image = self.transform.apply_image(image) + input_image_torch = torch.as_tensor(input_image, device=self.device) + input_image_torch = input_image_torch.permute(2, 0, 1).contiguous()[None, :, :, :] + + self.set_torch_image(input_image_torch, image.shape[:2]) + + @torch.no_grad() + def set_torch_image( + self, + transformed_image: torch.Tensor, + original_image_size: Tuple[int, ...], + ) -> None: + """ + Calculates the image embeddings for the provided image, allowing + masks to be predicted with the 'predict' method. 
Expects the input + image to be already transformed to the format expected by the model. + + Arguments: + transformed_image (torch.Tensor): The input image, with shape + 1x3xHxW, which has been transformed with ResizeLongestSide. + original_image_size (tuple(int, int)): The size of the image + before transformation, in (H, W) format. + """ + assert ( + len(transformed_image.shape) == 4 + and transformed_image.shape[1] == 3 + and max(*transformed_image.shape[2:]) == self.model.image_encoder.img_size + ), f"set_torch_image input must be BCHW with long side {self.model.image_encoder.img_size}." + self.reset_image() + + self.original_size = original_image_size + self.input_size = tuple(transformed_image.shape[-2:]) + input_image = self.model.preprocess(transformed_image) + self.features = self.model.image_encoder(input_image) + self.is_image_set = True + + def predict( + self, + point_coords: Optional[np.ndarray] = None, + point_labels: Optional[np.ndarray] = None, + box: Optional[np.ndarray] = None, + mask_input: Optional[np.ndarray] = None, + multimask_output: bool = True, + return_logits: bool = False, + ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: + """ + Predict masks for the given input prompts, using the currently set image. + + Arguments: + point_coords (np.ndarray or None): A Nx2 array of point prompts to the + model. Each point is in (X,Y) in pixels. + point_labels (np.ndarray or None): A length N array of labels for the + point prompts. 1 indicates a foreground point and 0 indicates a + background point. + box (np.ndarray or None): A length 4 array given a box prompt to the + model, in XYXY format. + mask_input (np.ndarray): A low resolution mask input to the model, typically + coming from a previous prediction iteration. Has form 1xHxW, where + for SAM, H=W=256. + multimask_output (bool): If true, the model will return three masks. + For ambiguous input prompts (such as a single click), this will often + produce better masks than a single prediction. If only a single + mask is needed, the model's predicted quality score can be used + to select the best mask. For non-ambiguous prompts, such as multiple + input prompts, multimask_output=False can give better results. + return_logits (bool): If true, returns un-thresholded masks logits + instead of a binary mask. + + Returns: + (np.ndarray): The output masks in CxHxW format, where C is the + number of masks, and (H, W) is the original image size. + (np.ndarray): An array of length C containing the model's + predictions for the quality of each mask. + (np.ndarray): An array of shape CxHxW, where C is the number + of masks and H=W=256. These low resolution logits can be passed to + a subsequent iteration as mask input. + """ + if not self.is_image_set: + raise RuntimeError("An image must be set with .set_image(...) before mask prediction.") + + # Transform input prompts + coords_torch, labels_torch, box_torch, mask_input_torch = None, None, None, None + if point_coords is not None: + assert ( + point_labels is not None + ), "point_labels must be supplied if point_coords is supplied." 
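+ # Prompts arrive in the original image's pixel coordinates. apply_coords /
+ # apply_boxes below rescale them into the resized frame the image embedding
+ # was computed in (long side = image_encoder.img_size), and a batch
+ # dimension is added before calling predict_torch.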
+ point_coords = self.transform.apply_coords(point_coords, self.original_size) + coords_torch = torch.as_tensor(point_coords, dtype=torch.float, device=self.device) + labels_torch = torch.as_tensor(point_labels, dtype=torch.int, device=self.device) + coords_torch, labels_torch = coords_torch[None, :, :], labels_torch[None, :] + if box is not None: + box = self.transform.apply_boxes(box, self.original_size) + box_torch = torch.as_tensor(box, dtype=torch.float, device=self.device) + box_torch = box_torch[None, :] + if mask_input is not None: + mask_input_torch = torch.as_tensor(mask_input, dtype=torch.float, device=self.device) + mask_input_torch = mask_input_torch[None, :, :, :] + + masks, iou_predictions, low_res_masks = self.predict_torch( + coords_torch, + labels_torch, + box_torch, + mask_input_torch, + multimask_output, + return_logits=return_logits, + ) + + masks_np = masks[0].detach().cpu().numpy() + iou_predictions_np = iou_predictions[0].detach().cpu().numpy() + low_res_masks_np = low_res_masks[0].detach().cpu().numpy() + return masks_np, iou_predictions_np, low_res_masks_np + + @torch.no_grad() + def predict_torch( + self, + point_coords: Optional[torch.Tensor], + point_labels: Optional[torch.Tensor], + boxes: Optional[torch.Tensor] = None, + mask_input: Optional[torch.Tensor] = None, + multimask_output: bool = True, + return_logits: bool = False, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Predict masks for the given input prompts, using the currently set image. + Input prompts are batched torch tensors and are expected to already be + transformed to the input frame using ResizeLongestSide. + + Arguments: + point_coords (torch.Tensor or None): A BxNx2 array of point prompts to the + model. Each point is in (X,Y) in pixels. + point_labels (torch.Tensor or None): A BxN array of labels for the + point prompts. 1 indicates a foreground point and 0 indicates a + background point. + boxes (np.ndarray or None): A Bx4 array given a box prompt to the + model, in XYXY format. + mask_input (np.ndarray): A low resolution mask input to the model, typically + coming from a previous prediction iteration. Has form Bx1xHxW, where + for SAM, H=W=256. Masks returned by a previous iteration of the + predict method do not need further transformation. + multimask_output (bool): If true, the model will return three masks. + For ambiguous input prompts (such as a single click), this will often + produce better masks than a single prediction. If only a single + mask is needed, the model's predicted quality score can be used + to select the best mask. For non-ambiguous prompts, such as multiple + input prompts, multimask_output=False can give better results. + return_logits (bool): If true, returns un-thresholded masks logits + instead of a binary mask. + + Returns: + (torch.Tensor): The output masks in BxCxHxW format, where C is the + number of masks, and (H, W) is the original image size. + (torch.Tensor): An array of shape BxC containing the model's + predictions for the quality of each mask. + (torch.Tensor): An array of shape BxCxHxW, where C is the number + of masks and H=W=256. These low res logits can be passed to + a subsequent iteration as mask input. + """ + if not self.is_image_set: + raise RuntimeError("An image must be set with .set_image(...) 
before mask prediction.") + + if point_coords is not None: + points = (point_coords, point_labels) + else: + points = None + + # Embed prompts + sparse_embeddings, dense_embeddings = self.model.prompt_encoder( + points=points, + boxes=boxes, + masks=mask_input, + ) + + # Predict masks + low_res_masks, iou_predictions = self.model.mask_decoder( + image_embeddings=self.features, + image_pe=self.model.prompt_encoder.get_dense_pe(), + sparse_prompt_embeddings=sparse_embeddings, + dense_prompt_embeddings=dense_embeddings, + multimask_output=multimask_output, + ) + + # Upscale the masks to the original image resolution + masks = self.model.postprocess_masks(low_res_masks, self.input_size, self.original_size) + + if not return_logits: + masks = masks > self.model.mask_threshold + + return masks, iou_predictions, low_res_masks + + def get_image_embedding(self) -> torch.Tensor: + """ + Returns the image embeddings for the currently set image, with + shape 1xCxHxW, where C is the embedding dimension and (H,W) are + the embedding spatial dimension of SAM (typically C=256, H=W=64). + """ + if not self.is_image_set: + raise RuntimeError( + "An image must be set with .set_image(...) to generate an embedding." + ) + assert self.features is not None, "Features must exist if an image has been set." + return self.features + + @property + def device(self) -> torch.device: + return self.model.device + + def reset_image(self) -> None: + """Resets the currently set image.""" + self.is_image_set = False + self.features = None + self.orig_h = None + self.orig_w = None + self.input_h = None + self.input_w = None diff --git a/zed_utils/segment_anything/utils/__init__.py b/zed_utils/segment_anything/utils/__init__.py new file mode 100644 index 0000000..5277f46 --- /dev/null +++ b/zed_utils/segment_anything/utils/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
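To close out predictor.py, a hedged end-to-end sketch of the SamPredictor API defined above. The assembly below mirrors the modules added by this patch but uses deliberately small, untrained weights, so the predicted masks are meaningless; the point is the calling convention and the output shapes. The import paths assume zed_utils/ is on sys.path so the vendored package resolves as segment_anything, matching the absolute import in predictor.py; the image and the click are placeholders.

import numpy as np

from segment_anything.modeling.image_encoder import ImageEncoderViT
from segment_anything.modeling.mask_decoder import MaskDecoder
from segment_anything.modeling.prompt_encoder import PromptEncoder
from segment_anything.modeling.sam import Sam
from segment_anything.modeling.transformer import TwoWayTransformer
from segment_anything.predictor import SamPredictor

# A deliberately tiny, untrained Sam: 1024-pixel input frame, 64x64x256 image
# embedding, and a 2-layer two-way transformer in the decoder.
sam = Sam(
    image_encoder=ImageEncoderViT(depth=2, embed_dim=256, num_heads=8),
    prompt_encoder=PromptEncoder(
        embed_dim=256,
        image_embedding_size=(64, 64),
        input_image_size=(1024, 1024),
        mask_in_chans=16,
    ),
    mask_decoder=MaskDecoder(
        transformer_dim=256,
        transformer=TwoWayTransformer(depth=2, embedding_dim=256, mlp_dim=2048, num_heads=8),
    ),
)

predictor = SamPredictor(sam)
image = np.zeros((720, 1280, 3), dtype=np.uint8)      # HWC uint8 RGB placeholder frame
predictor.set_image(image)

masks, scores, low_res_logits = predictor.predict(
    point_coords=np.array([[640, 360]]),              # one (x, y) click in original-image pixels
    point_labels=np.array([1]),                       # 1 = foreground
    multimask_output=True,
)
print(masks.shape)            # (3, 720, 1280) boolean masks at the original resolution
print(scores.shape)           # (3,) predicted mask-quality scores
print(low_res_logits.shape)   # (3, 256, 256) logits reusable as mask_input
best_mask = masks[int(np.argmax(scores))]

With trained weights the same three calls (construct, set_image, predict) are all that is needed; set_image does the expensive encoder pass once, and predict can then be called repeatedly with different prompts.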
diff --git a/zed_utils/segment_anything/utils/__pycache__/__init__.cpython-310.pyc b/zed_utils/segment_anything/utils/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f71759f20553d6c22454ce1dc4a541cc42766510 GIT binary patch literal 154 zcmd1j<>g`k0w&$Sl%qiUF^Gc<7=auIATDMB5-AM944RC7D;bJF!U*D*t$s#+ZmND_ zUP@7FxxPz&d0tL_VoI@ou%oYjacX*QYF%~K!Sw<0Hu5*#{d8T literal 0 HcmV?d00001 diff --git a/zed_utils/segment_anything/utils/__pycache__/__init__.cpython-39.pyc b/zed_utils/segment_anything/utils/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a890ac59298c8c652c8bd1adf7983eeebdf3c85f GIT binary patch literal 165 zcmYe~<>g`k0w&$Sl%qiUF^Gc<7=auIATDMB5-AM944RC7D;bJF!U*D*mwrZmZmND_ zUP@7FxxPz&d0tL_VoI@ou%oYjUVcexQht7RvA&+3esOAgZfaghd}3Z@Nk(Q~x_)U% jW=^qwe0*kJW=VX!UP0w84x8Nkl+v73JCGfpftUdRT2v^` literal 0 HcmV?d00001 diff --git a/zed_utils/segment_anything/utils/__pycache__/amg.cpython-310.pyc b/zed_utils/segment_anything/utils/__pycache__/amg.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a13ba4f704337823325363c492d6fd8b4e62e198 GIT binary patch literal 12107 zcmaJ{TWlQHd7j(O-j@_blhnoX_--mPWhqYL#Ij>qwsU7vTe+>RE|z;{mrL%2Ix`e6 zodq1zhN%K_oAiRVNm{B%(WtkF7QH;QuT6pWsX$R+-r7DC0@|k(c`y*E-}leXl1s@U zF=x-5^Pk&){`3Fe|DWlX%LM~}*3-YFetga_{);aM|B85d2~VIb!%&7YJ4V%%U#n`# zuU)nAYjvD%rkXKHV|Uzcwwjf^)5&%7)qJ;5Ep&_3Vz*Q+b<5Rqw^FT`#^(%`QSQ5j za(AriL_Vu7s+@1FOsc$J_KzGhRAKmA^ruvbHO+!5tIE5X>QOZzlu4<#;nM2!S z>Nx6-9jZHmniJ|IYL2T1rPouKrwR>%1%nz<0yYZ>dG7m z^T??u)l;Z{P|f=l)STx06zM)D{Z_Cp;mk23`DHy?ol#Gt)u}_=OL*g|1$7p+52kWYQ5na(%O$>%x7h83D-ZNJ_)La)u%X9^&DpNX@3SBs_@-^WH^t{sOLfT zsJbx5#Sy2$u@}^fpm=OJTDp6mRWG6Rab>lGZ?Ln2e=c5L z!V`QCi7&zoK{g@8bIMT}Ov6#GpHW%gRhF`wuF9!AO0!bR5=RwO5hb~##CK6rQe~9n zm4$_8P{M7oR7Fjow5TTi(n?u+bJP(vg|Z5SYFa&j+yur=gl|&WSpM-i_mz6E`bs^l zZ|@X5{9W{dP-E%q+V`4$jkmsT_gY>TWPw-jDKA|1+uGB9rylx{IyCj#J#VQVHkR>@ zHr;x-AbowI3!ZfSUKpS{Xg9oOyA%4-9Xo3Ddkyr_tB37=FKH3G+Q-nn#fAxnrGGXu zYa5zhzsqY-3!8cui?iI+s9UKEzb6ow2w1jU}HIUFH76lQ4}S&h6Z6U zZ^usHcbc)gx>47yAU3b*3f3pdhtHi|?sxsOV3qba&c4#$=ym#a6`XzJ;#bZFev6E) z)q7jvGWO@}df4s+XY1Y8!rE4xtJT`Qc37*;kYntzWLjp$bnjZ0duQ6-K9VkfA*GD% zRUd-hhFmwu@AQxP44zAPg40Ms1N=5aYsC(om5lJalJYyUB4fpk%nMM=;Ip-;%V;s5 z)sy(qN0>}8Im%?3$pcJgk<8nAj`;_foMJL!gg$}SxLB(Ne#rUNY9tCuNYX_)cPjSw z^npcR7#@2y5Y)*lfpU9UJr{8YY!;Dhd~oUry}sq;Nv03-rg3U{c4Kml<<%?zd?rm%H@XgoS6Lo2dHEK!EzVQUa(ws9Vv5THO!(IH@c4@SW7<}x&q^DW-jU$5_sgipjIn&^i z;1clecx3HmprTwn*_%R3D|t0Tl_A=HKKGGwub)Ntn#Wsz5=knmAGNSJ7xrs%1U?Vy zfLebFtC}&bx1Bo&EPqM`F{de91lNRNbi#%uM~EbP^{!v5#f4g}3sBr)zErDSU9WeB zHQ8EC^&7RCrc#S-;73iwpdVuLD3iySA#_`8Wpv+mf6E?B0>3HF)gO zKTeGBQnvB~WI_tIiIhjihW)Vzj+9wrWtq^4a9nq+J?o}1V5u;ev&XJz8A3NH#g*G0M2cA7ZULb%0WJ19fBE`mfr+X(E2QQoe4!Z@v$Cm-xyMFMY!>jjIcSh=%I_9v*Xsl&pa2;XB~*mGS_i<3uHM z0Git|S1suYCX5b+aB~qJy4~@6coh&rbBrQ~9@M_#i1Qa(@Ygrj^otD+LheBW8q;Rs z?)Jk83IXro2Vmrzyd)8nLAelXFjMug}FVH44f=0-pf6Rnh zVE}*?p+EMI^J8gnq6dj?nS(-9*egVN`YwAVJms6lP17_Ycdx=xT2@q0*&kReySfm$ zI0L!uHzm08YOm2(elWBNVLt)X^aQjY1VGfUfBDrjjLgsgBvBa;r@?P9YO}N@eXMOY z`i;I90vrcP07ed&u3|4l1Q&1L#h?Bf-XVMrdt zAeyx;F+rpgWmIyf&8O^~=|FN%LfcMTE+p9cxpQ|edmKn=h8lQ-ld{o#Oq7Y`M__{M zg{VTA4iTXMi^!UWS|uVM03Hk=$O8Zel{rQSd1~N7IPD0c8@aQ_dlt+Ia-X6w0}w1i zn6J)8_8$EL6JPafUMu;z}Yl z%Vb8sqvFXh&DQ$OCW205u`^DSKvSm5GXymNk7?FebyhIt&|DEI80PInA`+UpUqs_j zK2C7Y?d_k)K^|^=keH&dzSKsD3SrGge;CGvM9`ekV#!JtyvtAv2tp>5(n&(yPCEHQ z|Lfk8A8z=5qCpM_6ztwoSa0`4k6^?hGZ2N8Zm4gdgBa%o1xNFU2w^UPOqbiOWeKQu z`Ws#nKpiTcHu$5rkia^9hk2?E1g3i7xlhW^$InA9cUbc-llPER%c%Bs0Hp@|9K`N= 
zkH3ed?HJH#W1aD8I^fBQYqHuOpcC=r3q_pvyE#}Z>u$l_&*w7wI%wiTuLdm#q@~7> zJCFM4xbqCqQboqJal#;$*%QndnLvxiO&gfW*n?9D8-f3k-GW`JakJO@;85WC*L(?cuhkpa z6BvazpE+~x%BG|X3E1EfS=6!%7m-bMK(2$}Hk%<%((d@s`3SVE>R%y=t9f=~Hxd2n zL^4>d-vgh)sA_IlusOykt%ZHzC7KjHrDknh9q!9SosarPZ&5(;#Z?D`Fq!7=NKhp-Tr``3Kksjt-l zdfWXmW=(>U2-%iipkrJ3w;+`;J2YMydZjQ-xH|=UvIUWu@K`e~NAQZn66T20-X)5H zc^aavJc2`cm6-(y9u(*`MA>PewWx3c=VZscZbdotRM@q4?CU^W7$wOYL$AVBSww4r zgxDW8srbGj-bU_Fc?k%w7$JI#t@SYH@}LrxqYCC--o=)4?#Nemfw!^2R?8Voa3#@1 zx`Jeue7jn^`WVk!lm2&pTZ>}bmw~$66Y+e>&cj>7xx*W=iLY@lBc4F8;Ju2VWFi>S z1LK}goj-lLJ%9T2c`s?kOE#*^b~vOw0K~?&h+9LMeuG=Y80{_j?!5W z1BkjZz=o!oPNqv|L|mc&U}gcy_LIMGioYNZo=#2w-l+TUGdDK)^8E(?&vfwlOl)`C zz2xZcYXX9S`APkoBoig@Tgb%)S>TE{GNp*ow=zZVfwzbfa1T-q4|_RsFkclx-&} zpgK9+zma{a7SsM}o6j70*?K0nHu2!s#$duX2*(dO8QQ^QLE_aNy$|dD2_(T2ND5hk zvnjh^=ByF`^0Zm9;NV$vDDzM=37`zK`Ntp~i*!?V$=XYJg4d8l=35ROyn&O3*Kp+M z^f+*>5Pajs&maSZVc=2-W)+UpLs4cgBQ79AZBQw17`FiQoO(vB1TW;cV5Pc^X z1@v}x2T7dAox|02-{1D3jQ68bKO(^&k-)6!5BM?;5A|h~{VyKTuq7CyiiP!?Qm%gu z756N41pHn`?fsV8vOaQ|58k^>{bO{aZy~|NIbae_8%n7XCNCc zVkTY0pP>v^oMBR2RLk#1gY!@S=%0_aj^<3(1Gi=1!WuZ@)6T&?p>I0}U2e-S{ABsX zCShaU$eCl*4B9`+FbC_BgJI!$_qWX9PFG z#0o;RdcC;N#2N_!>xgHp)#*3xV3rAs{+zkLU@~GB4#GJcgmcvlh-;h4&zAm65{++s zTpDsuQ1e|7B%lP6?%<+_!8rZbY&Jj=7sZK45HS=k{PQslUIGo-Dn>R&hKuU}l@JnM z!?t|UcnAK0a#jKPz!>8q03>hP5n3QJp$C|`Y2nm>?ZUhq%q>FAO1=S6_||~|he1KY zd;F3F!Hj*g8M;wXLV%XO2$!iq<0e6l;;aF~>jVwC^ejp{1fUahd)*9OUUSOkHbNNU zy?j)ljZxWI1VK=5A%v0Lr8i1ftRzOWynS6v-s^rBb|vtFE`gFlkhtILB|!-7w}{ro z9e_6lA^_96cnk~P>+&6l_#)De_!0bS1owvG_47L%+T3}7aVyunHG2kdAf-D1ehxQLrc0?8%6b0BUi6FZeoY|gkmU&OWQwezuC4|r2L z3=OQquXrv#IoeHnpaA}GHogfY^9H|e(VymuMfZ+OAbQcQZhaGG4PVX7m*m_etod6~ z6UA1`a9cHGq$b-{Ep7DmD!faH7E}fObIJ@!hyaOCg5jtEkw=UwsFugrnCys^jPv&# zN6g&x{a9Y)(ml}RvL!1AYc`G0&x`~k3OMt~;nYiL;vYHQB=o~5=El&$*zQnijuhtw zyd)>Si^~KA`4Fy<$PlAAL~>M)2y*i5^EzSA#%ag7RL%v&favC>Dlp+1qc;-c*v_nn z&8MH6cj96a(UX9l5LJs!{k5&HF0Fu_@QwT5ZlnJhuVag!K*4D7KSB%fhXGGTl^lnbwg5^NH+`NE_G&pP*xpeeG#Rh}l9u!IKtQv~Wp1fm+;nu)c7hJ_m4; zjA0MQ$fHGG#;~LOeaE0}%fYIpvnrxR0oE};9E0_U36JVasE1FpiuiGqTQy!b24zs7 z1Xi;cWm`@{2~P*U&IH$Y8B%`M*Lc=?3arqKHGa{u-3FT%LFm zP4q6E+q`ssb75g&^YfeEJR_PMp!DT zxqn{be7je}Cy}k?P&80=a`awn^!r)`)zVTw3=vr5tz}Tnhq#!QY@7q>Mq7o;)rmBh zwkZOQEr-J& zKtL|hKinHc^taZayOS>h=mKa)$diu&^2||-$Q6F_Euha=y0~&fE7zQIZ=}D-GREAK zuLs{|?j0uAnJ~N};STXT5?D@Hn}9P$e>By#C@ERc0j3H2#31A@B%e6E2yPhxy~`Y8 zfpBxcF3+sk)@Pl&j%nR>hmVC{$12>JDSKHjfAXQ@70Y2g%hS5)S*wt~^EU^2t5`Td eHG`5CZ*Z7Nod+h!Z17M?Ck8$ z&EGlWz37a9Y@)0Y4lUP0jaZigb>RW21tAbTL4g-uKtehXNL(c32Y5l^p;B#rzyCKg zJ8L_Yn9=#pclm#}|M$Otr#dl_Gw`#X_>%vfbB6KnER6rsC_IZNJYgAzZ}?`@D4X)O z%9gzCvW>UZbXuu$$|Q{4bX)0iTJlaa)5?~!tz0?R%9r!4Lb=eIC{MJC<)UeP*6>rl z`>x@;J63tpe>$+%rhMB-bH&EHh}KDe3at+e#>@q?L3YJDVxaX9TBrTPXq^j2^?|uI@6Vv+0e==PN5)zX zN&TEZkNO1}Z#vK`J96Fe7yP5BKZ;dwgku=tBmQyJ977HN6RgLq4+7TnA41DVf?{yU zKf%?YeHx_+z2D)!+2TJ8_(%Ll0e^h32jYc$f9$&LKklDI%Y%2dl+bd@|0r6#u@28(GB;8dVzgLS&Zel;z-b$QHjarmsKO3|IRgF3- z$-W!~t?#VQ(|X$-Ra^D2+_0m<#ZoqLuT-18AWU+Nu+a{qYP%LB z4tpn7yPKrjzVuE~Kwxb(abK67iQ|);+zL8GVaZOMFleqM?)pYm)x*TRa??-?*xwId zIJ4Sm1!sUK6>OZj)Y)h^J5@hC^V)N-oN0HWV7b#-56>(vo(Y3GQCz9Ex1v>?)R|t? 
zXohF1t@>hjE6G$UjdmldRAz~}A`-(an3h?*<672j*WNxfM8o0`5{t6e0}y@#bgl~M zQ6u_?Kju?-F5(GKA&Cs2-;As^J95@ig8pMe^v70gthupy0kZZrLrnm@lva~?sVOFh zm`pP{%w&cMS6Z^w9P=Mxa-0c~qLZV@CHYDv3?itbsD3MIrjF<*v3YnB})$} zyMmS&{g~$?yJ9QW4Wjvstox=3AtCF!X^pMx-X-}#Q!nXZj1bEB=7EKb0BRpPIdz<# z{0LFzp{O2Wa)QadbJmAwcnt0Ln{xzkfTPG=qYV$|+%s+vTN=dd)XVtUJg(G29)Yua0Nhp zch^{(z{tgQ2a*Hius=5Y{KNYu;2gs3V%-N|ruNJw4n0W%p4=09e@gR7- zA0F{mtJF1KuRRdUP%P0_H}Dd3u{14&DYU00$w+i0*=ktP5?Y<);3M%uTt`)iVY}At zB`J~aB7u^W&w&#M^0b*`1GG@h2TAsIRKK7U6uv`AU_+~gd-;WUC%;_86aEvDk^cJu zxutK5pf#_U;Ym;`^dXEga?jI|7286Yr~?OOB0Cl_4f^#VuH7MYY}a9%_|`QF>~UDl zMPG~3YuUSC?R#Kz81?%@C_VEWxotRJ-UsP|7gAAyt0I34xI(mYDZQj9mcSae-)Q0G zR5u%;I)gfOmI;aYj6f@$PqNdeIFzHM+GkktX(szr zx=*Ulqk)YfY2GlW%y~Qo)00OIvoJbLOnrc7hU_fo#wme~p`1IVnHqQ!+XcAG#nxU5 zQp&}Xz9FzwotH&sh$bjyK9u;?$1uElk_nl7NUtBZtt1n5D#8aZ0XlS%kXQ@mm_;yz zC>^35(XO_FN+rovDlHhuP38-g%H>|QIcP~&Dt@O{sVK^=#D*QIXeg)$nLNzo5hfWX z)Er|v{{o9dw7SGZkiVahwNxGnZHAEMLdJAb#q_j0Wl!0&_Vnb`gtY%Bn(%w_*-|P= zEq69yon>mO)2$$%WMIhFCFj77Z)zqZV-0^?81W#7^#L*wN!LV5rm&%8?7ky))`7ZI zXtb+MuhFX3gGGqw zPhguOD>dsuqAj5&aTG9gi_EBKZG-GY$ew3gY6X^PB%P8&z4j2lL0H+`+T2Rgjj-12 zgh7&TcIpji6j-KyV#7SeD74?z=b601gmaaP!x1WCDVK}Ga~hW0K~+@*h+8dBN?CSL zVZ|7=kWi8b8X-@)aN)f=dX@XfFX*RRQfJ0Je2_iQ{u7UVA& zywLW2kcJPvhZF}q&o_7#DV%zw&Y;d_9S$$Gbc+LtfV;)v)R&mB4oLqBUgg0WHdgQ6 z(bRBpzz7z(vHB*!Lz1fj?ZWLS)?gbsal`HaBb=2C`8b|%8%bpBVk5q}?x+)y zxo`Gpb7sd3Pr;6NXtofNo82*Ytpx+l)>wHKE$a~Ha|T>7C?Xn6Z!cUm-#6a}?Mu{T zU?6?jFpbM)F@Of5{(!I~F1S7lvhV~FSUwAV%njbsD}eE~W3F2=4%8LUY6(K=Hrmag zjiT5bA4O+WIrOPCGKDg-7wYihH@oVY8mHiVj=BAi#u+SD@$bj%pds}3)6WmQTMYzv z)ob;_2xiKa0M2i>2fvXV*lY?%hLDsV&tUPM{ZYP71`afvoot7RqvCCDuvoQ>?(4Pe+g(4MqYeEWvI$JFej9agyW9!g9_+zx9HW6?Fw z8GR>q_G}2~RNa7Rbx2eo)IaN|<5Zj)j9^Od^g703uAhlgdoH31Y2U1yab_)pCkt_G z#~I(AGuQ&l%bqa~w$`~W!0Hu29%L{+RzjK3haDATJdO{>vDAOmjkd$_`?)x`my5IX zQuYdXCT4irPw(Z(|^eieiv ztmjuh|H5g8X{ZHsh{gjW1T}_mmbYY#?pCc+>v$0i;ZR3ogtydJaTY>@OSkUer@nwP z&d(u%CvL6NLwu(+nWVziDh%m#H&Cq}4CwR<0z^1tL_#`kh@*}QIFRU(=b5WB8R?(I z4K@Mo@0&58kdA~?+-kRrmQr4KA!Q}zYGNY(<~LSWpg>khvj}t`wgQ!JcY2Yy9py}| zrz#D9GlXo7HYBIkIPF2(B01bbbc@1O!K~?Si5^-9oN%e(7R(utZpL;%zQ;_*nlW8- z{zujwCp~X2n6vn+D}V-4Yz%O(;HBk0GAziSpjZggf&C{&Mr_SMtkVAO!&K;lp8GH& z{L~S8#En5K(DdcS*xsWTV8U>Kjlbv6lP{&W&s+o% z;5@vlw+z*Rn>^Ll9Uk5-G`7T@^J=){{{5a_O?}kRqEhh zfYqIW^8^Xb&?d<8YV?eKC47Wr|BnC(J))-tgImTq)Ea?r{RyQ zrR^yz{vt9E3UHI=jCE%XE04@IA%;O7A!8DThME6yFb>J&D39LW{##+|!TE<;Geyy81c1 zM2H|V)sD`6LS7#`5AuA6Ejvu!MN*zXbD%C|aeHmP?kfhO+5;>m65Q zx7RU9NP;l(1=B^i)V=N8u^hkZy{ABzQ^G7I9eZ|I%anzt~ zvnOU`Y{GyvZrHG(j6FDnP#pM=lo^x-*?PZ?%1$eg*pN2*C1Cs5fg2B1bIsa=2=Unq z$_>=uZhv}kg8?m3_0C`N>Z;-M%)1>RD)fRYfrP=k)f$cln&Qo~r_a5;De0m%LC7#m zN)F>$WQP(V(?s~2-4G~gGy^CCL|WF>HbGo<6|cnAY*3!mlT|uxpc~jKX9g9Uquf^F z>bkp6YHGRHX!;dttZ?z)!^k@9D>C93MDv2Feu>jMdfH68(yVTQ2$IPV=M2$~Q<9o{ z8Ingd5+P5c04o6qr&T`(;Q=Xp0?r@efRM4yVH_Sj0LbJ$n!JeYiB6u1ZJ^rPKDCc( zS+3wZ)XWAGk*$uUm7jd-lGI%K6 z!W<-V9e)eRgL0yZOVN*o`og^_Of6f?HxoW?s_uw!F9$HRf;-#5f3 z$=p?6fH|0t5!A&oKZ`X_^o#LCT*SI3c5(b%JMzU{Sm!uwtM2qCxs!NuxC6aPmM;I* z<(J5ZEA-g2+e+l#K0Q?Beu;c1Q~<{g{D=|_nLyzlNBo9x(R%^WN=+m(0t2Ea&!0Ng zD4jZW-qYQ9Jx7V!hz5{jR@*&SHl**@ctpHSiHyjkm?3q~1C=p1W*&vLDQePM5{8GN zPl}`lk5&_>7BreAyR5?~aJI_nws_@Xk{*!A-7KqPF97{#PHZ7D%U3%Hd zzRx9YK6>n1od*;>#~_!|iX01%Q;^byq8CoW7>h;2lo7UCL=T8+S0>m{zk#A&E}a)~ zoc@&=ahx9iu}l0h;qm|9($sHq#sdpmz2CzAZMZN*wd_`-twqGX#Bk7Fno_?)FcB8N zid>SDoeFxBoJbq7e{rqXs^Bz{Zj^JdS5%}G7tI^6z7`ebslnNf6k9nz9Ivu@cC4^< zHnBGG;B8|v!6^vmn_SFqF%e9Qhj{otX#G2YXe&FNqj_gfn>j0EGh#jiv0kv|&3V*$ zF#M)1^?QKE;QLV|aUO-R(8AvVG5#VdV)HGBj$_{$#_qtKbfz4*Y>4df>ZqTBY%y@H z1kDVG?3p;VmlD^KaXd<o}W0WY|QNnb3T^>1tO-yqJ@$r6aBAH6&TwTU_o1!FB)z 
zxWCwM5#V7uoxHB=_5P z-TKfif9l>XtKa8ze}Dv=*q?f|33Ze0_t{1=F$fB3t@OVD8*OM-Z|J)mbzre^Log5+ zo5d};yf?I#;f994JTl}i*&v+@*&v*xz6-4a_ZjDQ7S51;86< zTfwGjX;u+O=~n5izJoS>Tz|;iA2AtW4@`3oOmnWB0&rzhzqZsL6X+mSB!vMw#W=qM z0B!VuM?CnV3yl7K54BiLj8~Eur$w9I5Zwq1+^e%mKg6!+>;xI^u?d|`%kv66ug@9p zz#H+Mb&Fw>*mw?hDsT7^)}YZtw=;Fa0%rmHg2N27H)7qs{wyWUJi;eAi796MEVM9m z^H2*%Zk(66qov+~pO&MplmJV9&VU|w!XL=hYF_#w05TTGy;E2H7`L8 zrDPh$%9|J}RDX&yxJ4^Uo7WUD2lO5fiU&q|3q2$fI-J)LgcujY>%11mjrrX09`L#=xKHVu5gmE-l{)EtffuuYE2U&qW zDsoii!bV4}!;h8l1Z{4;V8%d%c*O0&bd-w7BQ_S6Ck`M{p9XG7ImMqd5v@3LKZNJ8 zaXDuroq-z6z&xKp9Ca2tfiV!roJH*nVtk=MX@5?k3$^`ez0 zo`xHd*8xR|CkjwGzoNdl^@Zg%U=zuEfuOb{rE1`i8$)HcrE>}Q+A`DoO9(szti|^t_#Yd^QS-Ja6Kk!!tpHq0Kz|(59{{qQiHnTkaSRY# ze*1Ss{$3mgY#?F9dxEU~%2&Vg)nQn>st}RZ!6mp+T6!~>A^*1tCBr~^OC)n3Z!Rt_Zoahn#%U4kh$?)pb_?B9m<- zrAbZbv_1EEC&@P26@02$UyVd4m8VAKO0CmTepoInccKWvP2Qr0*tcfNnJW zXtg{!%nkeGVW_P}gD$ydurkaS1f3QAIg9FXEq56*hxILzFCvtYUYy?&01u1|X5P;^QxJetHm+{1jHUh&eW1)f$Jb->$%FprgGg1sh6ElfzUY-t zM1Tv#0<|RK$>3^Xs|#VSzfYjgpctdqpC)9PqdbvM3iLOLK8vq1d4tKDOjek@%;XA_ zX(o(JOK@01gA&G(u!?vtV#llR66AYKo?=4zE#gWxG=?%ko@j{71^vOrGhAc%5|WIW zvGiv}o?ZM2*Pd9sc5HW=oZYut!6{^0pSM_9?~U@VPKrEETXyM7p^}Lx z8@h;wNg5?t+!x>qxDdl`F+9ZFdoUWNqDumnpEWE#!mIuTB8qgU8GljaeWCRIgp0{v>mi->zEFWNysKl`7dlEs z-q(Y>?}V`m4@8>MIF!X>@F*iI6M9f^74`}l8n-2RKkOET|nTb$OzFg&J`qNbof!B2Vp+ShKFG+WeiFZV%;4G4qat;5NeUDLaMeoTsBv9rRdQk znSb835d!mv=Jf92RTTMCHnk!IFD%6nfil1zLeEtq?U}_hOHwd>b#Hf2WMUUeNO7?H zNpX;;MaO6XOkN* z_@+H)D?Q_v74<0_J-kN!?_x5wp5hAojI-<2h!B`s`}T=*WNYWCt3B;c*)f~i$My;3 znlt&iu@?FNr{kgMYPe>_%m)eVt_JZ?grAKcj`zYO4@XI^KOo;`v97F5=Ou#%r$kz{ za+|_rC0Od}HC?ve`gr-sSGtiQ)%YX83dYV*pp?8L!0@5t9S8TP58-4k}iri?q_@~7@iYa6cM&WA)@?0HjXw&&6woZwCV z_!1`+y|7BT@IA6W*_3 zAa2?6I~d9uTuE6sPl*tl#fG`8> z=YLvBve*#kuIUx>#)r!*`*8~C25k2 zSbqAT<&-Wlmd!X#`gznXWG`*Z_HbrT8aJ@9Qy-A{eL*N?|ls<`_l_S5U*6{ z&;-GGFiQa@-x!w3>$tk4D$(-KqnNFW(lhp!w}|UDk#oM;T85{}%!DTuNWnP-QX6xS zdgC%DRUmbNmzGmc68Zsw&N5)0N*7u2f5V}LLtX~{gpZP_Hg~7))aw!0JClbC`|K6jaHN3~z=~p| zF|X=nKY1*2RL#W(b)HpkV!Gq1HLgPX3_to%5zjyF%_e`lhd$ED;&JH!3=oO$7Z@n0D7d2n4SE%RA<4iFllhRf{db-b=!|F^xd zvZ}w$%-PmDx(oVwVAQQN3QT6V^ITG`lkb9bUNd`-6Vo&;N`h*i{nR{bM@yG}R?hkz nbUWzftDBRuOZl+4MHzfpnf1ri1}kb~13!S9Hdw=Mur2356zT54 literal 0 HcmV?d00001 diff --git a/zed_utils/segment_anything/utils/__pycache__/transforms.cpython-39.pyc b/zed_utils/segment_anything/utils/__pycache__/transforms.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..69863951e360989914756d564119636cdf28f376 GIT binary patch literal 3971 zcmcInOK&5`5uPVWQKVLqm0hpBUbkb%4uK76I}YrF5j&RS2VfgW{ir`PrH^in?yx}i&bep<<@-D+0r*35g5 zu4MIY-L%7WHEVPmuJftGE4=#9;ngFr+vM(Dr&T+_+Rj%kx2*4u@}5qLJWg9)S&>2| z6HzvF5e<_xO0u{wFsjCd81{RSk5wv6q%U0(#;$)pY?quwk7VBS0QCH# zHr?B}75VqnojMU*6t;5VXR-zzmI{0gloc=_FO1#pzUgcu>}&D9IBYMK@?8W{2+N zg$q!%=fkk2&;TV@FyYLRC=Ns1=16a%F?Al}SbHvaZ`vnZW9saA$Nr(G{l|e0b!F-v zxl`}RJEkb}C!ZS|l2`nCJQO_*z|I(dKLHwQ5Dx|WYh-b037|?v_ zVh1NATGeukoM$Kne5EIpip&aSXdcG6rI!!O&;Tsy95R4-;!8qCu0fQmGfXMm4G>l| zOkX@4rUy3oUw}f%{*TJ!S=uTbkV^n1tyy$EJ5$?ZFGJr0e z`ZHZXI$#k_{_+egid|;ok-beYP|)tu0T@rNrmVzD9}Gnq?jk1ild8p!Tw6X5(G81~ zM+5uV?}oY9j|QcA?4QT9VR5qPyn@M!{>uq#Zs4vzt@&&27Jif4R{?A%0^rcljro5Y zGs~z_n~>ect$qw)W9Bhp21x*sGWCuDBFpR?nBd9^F484}BX>C=2P$hj zNt)y$mY>{%!v!?3Y{qHQ&!b);b0Jl^66ZRBDpvET*&viC%araXEzX`V!@c#&gw0D{~1pE6IS}~IBhTdA6V_L7n~qI zlg62Gg7d_df>{1%hqKxz+giA28GuzD~I5Zn_QpCa+)g&C3e6)^k+0=5Z#X zsH{g3_->R^yAeg-jN)|GQ;i~C^rA=-Q^_WYP1;{I=eE$hPAWpfBsBRtwJ5cem3yY5 zmu=D*7kZo24W&WzJ^Q&}1NGFQOpB^PNwGqqrzl_%eU){8t9pCwc5SP<<_7<6Y#3Sh zTr;w@h602>FU;ob|JO{wxAR<5)|Gc4x~~}BgT&NTi%g-~r=O#zg=-nmr{}CBLP3P$ 
h-=34UwiE=5J(sojtr34oBNQE@=GFkox(8tV{{W#g^*aCn literal 0 HcmV?d00001 diff --git a/zed_utils/segment_anything/utils/amg.py b/zed_utils/segment_anything/utils/amg.py new file mode 100644 index 0000000..be06407 --- /dev/null +++ b/zed_utils/segment_anything/utils/amg.py @@ -0,0 +1,346 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import numpy as np +import torch + +import math +from copy import deepcopy +from itertools import product +from typing import Any, Dict, Generator, ItemsView, List, Tuple + + +class MaskData: + """ + A structure for storing masks and their related data in batched format. + Implements basic filtering and concatenation. + """ + + def __init__(self, **kwargs) -> None: + for v in kwargs.values(): + assert isinstance( + v, (list, np.ndarray, torch.Tensor) + ), "MaskData only supports list, numpy arrays, and torch tensors." + self._stats = dict(**kwargs) + + def __setitem__(self, key: str, item: Any) -> None: + assert isinstance( + item, (list, np.ndarray, torch.Tensor) + ), "MaskData only supports list, numpy arrays, and torch tensors." + self._stats[key] = item + + def __delitem__(self, key: str) -> None: + del self._stats[key] + + def __getitem__(self, key: str) -> Any: + return self._stats[key] + + def items(self) -> ItemsView[str, Any]: + return self._stats.items() + + def filter(self, keep: torch.Tensor) -> None: + for k, v in self._stats.items(): + if v is None: + self._stats[k] = None + elif isinstance(v, torch.Tensor): + self._stats[k] = v[torch.as_tensor(keep, device=v.device)] + elif isinstance(v, np.ndarray): + self._stats[k] = v[keep.detach().cpu().numpy()] + elif isinstance(v, list) and keep.dtype == torch.bool: + self._stats[k] = [a for i, a in enumerate(v) if keep[i]] + elif isinstance(v, list): + self._stats[k] = [v[i] for i in keep] + else: + raise TypeError(f"MaskData key {k} has an unsupported type {type(v)}.") + + def cat(self, new_stats: "MaskData") -> None: + for k, v in new_stats.items(): + if k not in self._stats or self._stats[k] is None: + self._stats[k] = deepcopy(v) + elif isinstance(v, torch.Tensor): + self._stats[k] = torch.cat([self._stats[k], v], dim=0) + elif isinstance(v, np.ndarray): + self._stats[k] = np.concatenate([self._stats[k], v], axis=0) + elif isinstance(v, list): + self._stats[k] = self._stats[k] + deepcopy(v) + else: + raise TypeError(f"MaskData key {k} has an unsupported type {type(v)}.") + + def to_numpy(self) -> None: + for k, v in self._stats.items(): + if isinstance(v, torch.Tensor): + self._stats[k] = v.detach().cpu().numpy() + + +def is_box_near_crop_edge( + boxes: torch.Tensor, crop_box: List[int], orig_box: List[int], atol: float = 20.0 +) -> torch.Tensor: + """Filter masks at the edge of a crop, but not at the edge of the original image.""" + crop_box_torch = torch.as_tensor(crop_box, dtype=torch.float, device=boxes.device) + orig_box_torch = torch.as_tensor(orig_box, dtype=torch.float, device=boxes.device) + boxes = uncrop_boxes_xyxy(boxes, crop_box).float() + near_crop_edge = torch.isclose(boxes, crop_box_torch[None, :], atol=atol, rtol=0) + near_image_edge = torch.isclose(boxes, orig_box_torch[None, :], atol=atol, rtol=0) + near_crop_edge = torch.logical_and(near_crop_edge, ~near_image_edge) + return torch.any(near_crop_edge, dim=1) + + +def box_xyxy_to_xywh(box_xyxy: torch.Tensor) -> torch.Tensor: + box_xywh = deepcopy(box_xyxy) + box_xywh[2] = 
box_xywh[2] - box_xywh[0] + box_xywh[3] = box_xywh[3] - box_xywh[1] + return box_xywh + + +def batch_iterator(batch_size: int, *args) -> Generator[List[Any], None, None]: + assert len(args) > 0 and all( + len(a) == len(args[0]) for a in args + ), "Batched iteration must have inputs of all the same size." + n_batches = len(args[0]) // batch_size + int(len(args[0]) % batch_size != 0) + for b in range(n_batches): + yield [arg[b * batch_size : (b + 1) * batch_size] for arg in args] + + +def mask_to_rle_pytorch(tensor: torch.Tensor) -> List[Dict[str, Any]]: + """ + Encodes masks to an uncompressed RLE, in the format expected by + pycoco tools. + """ + # Put in fortran order and flatten h,w + b, h, w = tensor.shape + tensor = tensor.permute(0, 2, 1).flatten(1) + + # Compute change indices + diff = tensor[:, 1:] ^ tensor[:, :-1] + change_indices = diff.nonzero() + + # Encode run length + out = [] + for i in range(b): + cur_idxs = change_indices[change_indices[:, 0] == i, 1] + cur_idxs = torch.cat( + [ + torch.tensor([0], dtype=cur_idxs.dtype, device=cur_idxs.device), + cur_idxs + 1, + torch.tensor([h * w], dtype=cur_idxs.dtype, device=cur_idxs.device), + ] + ) + btw_idxs = cur_idxs[1:] - cur_idxs[:-1] + counts = [] if tensor[i, 0] == 0 else [0] + counts.extend(btw_idxs.detach().cpu().tolist()) + out.append({"size": [h, w], "counts": counts}) + return out + + +def rle_to_mask(rle: Dict[str, Any]) -> np.ndarray: + """Compute a binary mask from an uncompressed RLE.""" + h, w = rle["size"] + mask = np.empty(h * w, dtype=bool) + idx = 0 + parity = False + for count in rle["counts"]: + mask[idx : idx + count] = parity + idx += count + parity ^= True + mask = mask.reshape(w, h) + return mask.transpose() # Put in C order + + +def area_from_rle(rle: Dict[str, Any]) -> int: + return sum(rle["counts"][1::2]) + + +def calculate_stability_score( + masks: torch.Tensor, mask_threshold: float, threshold_offset: float +) -> torch.Tensor: + """ + Computes the stability score for a batch of masks. The stability + score is the IoU between the binary masks obtained by thresholding + the predicted mask logits at high and low values. + """ + # One mask is always contained inside the other. + # Save memory by preventing unnecessary cast to torch.int64 + intersections = ( + (masks > (mask_threshold + threshold_offset)) + .sum(-1, dtype=torch.int16) + .sum(-1, dtype=torch.int32) + ) + unions = ( + (masks > (mask_threshold - threshold_offset)) + .sum(-1, dtype=torch.int16) + .sum(-1, dtype=torch.int32) + ) + return intersections / unions + + +def build_point_grid(n_per_side: int) -> np.ndarray: + """Generates a 2D grid of points evenly spaced in [0,1]x[0,1].""" + offset = 1 / (2 * n_per_side) + points_one_side = np.linspace(offset, 1 - offset, n_per_side) + points_x = np.tile(points_one_side[None, :], (n_per_side, 1)) + points_y = np.tile(points_one_side[:, None], (1, n_per_side)) + points = np.stack([points_x, points_y], axis=-1).reshape(-1, 2) + return points + + +def build_all_layer_point_grids( + n_per_side: int, n_layers: int, scale_per_layer: int +) -> List[np.ndarray]: + """Generates point grids for all crop layers.""" + points_by_layer = [] + for i in range(n_layers + 1): + n_points = int(n_per_side / (scale_per_layer**i)) + points_by_layer.append(build_point_grid(n_points)) + return points_by_layer + + +def generate_crop_boxes( + im_size: Tuple[int, ...], n_layers: int, overlap_ratio: float +) -> Tuple[List[List[int]], List[int]]: + """ + Generates a list of crop boxes of different sizes. 
Each layer + has (2**i)**2 boxes for the ith layer. + """ + crop_boxes, layer_idxs = [], [] + im_h, im_w = im_size + short_side = min(im_h, im_w) + + # Original image + crop_boxes.append([0, 0, im_w, im_h]) + layer_idxs.append(0) + + def crop_len(orig_len, n_crops, overlap): + return int(math.ceil((overlap * (n_crops - 1) + orig_len) / n_crops)) + + for i_layer in range(n_layers): + n_crops_per_side = 2 ** (i_layer + 1) + overlap = int(overlap_ratio * short_side * (2 / n_crops_per_side)) + + crop_w = crop_len(im_w, n_crops_per_side, overlap) + crop_h = crop_len(im_h, n_crops_per_side, overlap) + + crop_box_x0 = [int((crop_w - overlap) * i) for i in range(n_crops_per_side)] + crop_box_y0 = [int((crop_h - overlap) * i) for i in range(n_crops_per_side)] + + # Crops in XYWH format + for x0, y0 in product(crop_box_x0, crop_box_y0): + box = [x0, y0, min(x0 + crop_w, im_w), min(y0 + crop_h, im_h)] + crop_boxes.append(box) + layer_idxs.append(i_layer + 1) + + return crop_boxes, layer_idxs + + +def uncrop_boxes_xyxy(boxes: torch.Tensor, crop_box: List[int]) -> torch.Tensor: + x0, y0, _, _ = crop_box + offset = torch.tensor([[x0, y0, x0, y0]], device=boxes.device) + # Check if boxes has a channel dimension + if len(boxes.shape) == 3: + offset = offset.unsqueeze(1) + return boxes + offset + + +def uncrop_points(points: torch.Tensor, crop_box: List[int]) -> torch.Tensor: + x0, y0, _, _ = crop_box + offset = torch.tensor([[x0, y0]], device=points.device) + # Check if points has a channel dimension + if len(points.shape) == 3: + offset = offset.unsqueeze(1) + return points + offset + + +def uncrop_masks( + masks: torch.Tensor, crop_box: List[int], orig_h: int, orig_w: int +) -> torch.Tensor: + x0, y0, x1, y1 = crop_box + if x0 == 0 and y0 == 0 and x1 == orig_w and y1 == orig_h: + return masks + # Coordinate transform masks + pad_x, pad_y = orig_w - (x1 - x0), orig_h - (y1 - y0) + pad = (x0, pad_x - x0, y0, pad_y - y0) + return torch.nn.functional.pad(masks, pad, value=0) + + +def remove_small_regions( + mask: np.ndarray, area_thresh: float, mode: str +) -> Tuple[np.ndarray, bool]: + """ + Removes small disconnected regions and holes in a mask. Returns the + mask and an indicator of if the mask has been modified. + """ + import cv2 # type: ignore + + assert mode in ["holes", "islands"] + correct_holes = mode == "holes" + working_mask = (correct_holes ^ mask).astype(np.uint8) + n_labels, regions, stats, _ = cv2.connectedComponentsWithStats(working_mask, 8) + sizes = stats[:, -1][1:] # Row 0 is background label + small_regions = [i + 1 for i, s in enumerate(sizes) if s < area_thresh] + if len(small_regions) == 0: + return mask, False + fill_labels = [0] + small_regions + if not correct_holes: + fill_labels = [i for i in range(n_labels) if i not in fill_labels] + # If every region is below threshold, keep largest + if len(fill_labels) == 0: + fill_labels = [int(np.argmax(sizes)) + 1] + mask = np.isin(regions, fill_labels) + return mask, True + + +def coco_encode_rle(uncompressed_rle: Dict[str, Any]) -> Dict[str, Any]: + from pycocotools import mask as mask_utils # type: ignore + + h, w = uncompressed_rle["size"] + rle = mask_utils.frPyObjects(uncompressed_rle, h, w) + rle["counts"] = rle["counts"].decode("utf-8") # Necessary to serialize with json + return rle + + +def batched_mask_to_box(masks: torch.Tensor) -> torch.Tensor: + """ + Calculates boxes in XYXY format around masks. Return [0,0,0,0] for + an empty mask. For input shape C1xC2x...xHxW, the output shape is C1xC2x...x4. 
+ """ + # torch.max below raises an error on empty inputs, just skip in this case + if torch.numel(masks) == 0: + return torch.zeros(*masks.shape[:-2], 4, device=masks.device) + + # Normalize shape to CxHxW + shape = masks.shape + h, w = shape[-2:] + if len(shape) > 2: + masks = masks.flatten(0, -3) + else: + masks = masks.unsqueeze(0) + + # Get top and bottom edges + in_height, _ = torch.max(masks, dim=-1) + in_height_coords = in_height * torch.arange(h, device=in_height.device)[None, :] + bottom_edges, _ = torch.max(in_height_coords, dim=-1) + in_height_coords = in_height_coords + h * (~in_height) + top_edges, _ = torch.min(in_height_coords, dim=-1) + + # Get left and right edges + in_width, _ = torch.max(masks, dim=-2) + in_width_coords = in_width * torch.arange(w, device=in_width.device)[None, :] + right_edges, _ = torch.max(in_width_coords, dim=-1) + in_width_coords = in_width_coords + w * (~in_width) + left_edges, _ = torch.min(in_width_coords, dim=-1) + + # If the mask is empty the right edge will be to the left of the left edge. + # Replace these boxes with [0, 0, 0, 0] + empty_filter = (right_edges < left_edges) | (bottom_edges < top_edges) + out = torch.stack([left_edges, top_edges, right_edges, bottom_edges], dim=-1) + out = out * (~empty_filter).unsqueeze(-1) + + # Return to original shape + if len(shape) > 2: + out = out.reshape(*shape[:-2], 4) + else: + out = out[0] + + return out diff --git a/zed_utils/segment_anything/utils/onnx.py b/zed_utils/segment_anything/utils/onnx.py new file mode 100644 index 0000000..3196bdf --- /dev/null +++ b/zed_utils/segment_anything/utils/onnx.py @@ -0,0 +1,144 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import torch +import torch.nn as nn +from torch.nn import functional as F + +from typing import Tuple + +from ..modeling import Sam +from .amg import calculate_stability_score + + +class SamOnnxModel(nn.Module): + """ + This model should not be called directly, but is used in ONNX export. + It combines the prompt encoder, mask decoder, and mask postprocessing of Sam, + with some functions modified to enable model tracing. Also supports extra + options controlling what information. See the ONNX export script for details. 
+ """ + + def __init__( + self, + model: Sam, + return_single_mask: bool, + use_stability_score: bool = False, + return_extra_metrics: bool = False, + ) -> None: + super().__init__() + self.mask_decoder = model.mask_decoder + self.model = model + self.img_size = model.image_encoder.img_size + self.return_single_mask = return_single_mask + self.use_stability_score = use_stability_score + self.stability_score_offset = 1.0 + self.return_extra_metrics = return_extra_metrics + + @staticmethod + def resize_longest_image_size( + input_image_size: torch.Tensor, longest_side: int + ) -> torch.Tensor: + input_image_size = input_image_size.to(torch.float32) + scale = longest_side / torch.max(input_image_size) + transformed_size = scale * input_image_size + transformed_size = torch.floor(transformed_size + 0.5).to(torch.int64) + return transformed_size + + def _embed_points(self, point_coords: torch.Tensor, point_labels: torch.Tensor) -> torch.Tensor: + point_coords = point_coords + 0.5 + point_coords = point_coords / self.img_size + point_embedding = self.model.prompt_encoder.pe_layer._pe_encoding(point_coords) + point_labels = point_labels.unsqueeze(-1).expand_as(point_embedding) + + point_embedding = point_embedding * (point_labels != -1) + point_embedding = point_embedding + self.model.prompt_encoder.not_a_point_embed.weight * ( + point_labels == -1 + ) + + for i in range(self.model.prompt_encoder.num_point_embeddings): + point_embedding = point_embedding + self.model.prompt_encoder.point_embeddings[ + i + ].weight * (point_labels == i) + + return point_embedding + + def _embed_masks(self, input_mask: torch.Tensor, has_mask_input: torch.Tensor) -> torch.Tensor: + mask_embedding = has_mask_input * self.model.prompt_encoder.mask_downscaling(input_mask) + mask_embedding = mask_embedding + ( + 1 - has_mask_input + ) * self.model.prompt_encoder.no_mask_embed.weight.reshape(1, -1, 1, 1) + return mask_embedding + + def mask_postprocessing(self, masks: torch.Tensor, orig_im_size: torch.Tensor) -> torch.Tensor: + masks = F.interpolate( + masks, + size=(self.img_size, self.img_size), + mode="bilinear", + align_corners=False, + ) + + prepadded_size = self.resize_longest_image_size(orig_im_size, self.img_size).to(torch.int64) + masks = masks[..., : prepadded_size[0], : prepadded_size[1]] # type: ignore + + orig_im_size = orig_im_size.to(torch.int64) + h, w = orig_im_size[0], orig_im_size[1] + masks = F.interpolate(masks, size=(h, w), mode="bilinear", align_corners=False) + return masks + + def select_masks( + self, masks: torch.Tensor, iou_preds: torch.Tensor, num_points: int + ) -> Tuple[torch.Tensor, torch.Tensor]: + # Determine if we should return the multiclick mask or not from the number of points. + # The reweighting is used to avoid control flow. 
+ score_reweight = torch.tensor( + [[1000] + [0] * (self.model.mask_decoder.num_mask_tokens - 1)] + ).to(iou_preds.device) + score = iou_preds + (num_points - 2.5) * score_reweight + best_idx = torch.argmax(score, dim=1) + masks = masks[torch.arange(masks.shape[0]), best_idx, :, :].unsqueeze(1) + iou_preds = iou_preds[torch.arange(masks.shape[0]), best_idx].unsqueeze(1) + + return masks, iou_preds + + @torch.no_grad() + def forward( + self, + image_embeddings: torch.Tensor, + point_coords: torch.Tensor, + point_labels: torch.Tensor, + mask_input: torch.Tensor, + has_mask_input: torch.Tensor, + orig_im_size: torch.Tensor, + ): + sparse_embedding = self._embed_points(point_coords, point_labels) + dense_embedding = self._embed_masks(mask_input, has_mask_input) + + masks, scores = self.model.mask_decoder.predict_masks( + image_embeddings=image_embeddings, + image_pe=self.model.prompt_encoder.get_dense_pe(), + sparse_prompt_embeddings=sparse_embedding, + dense_prompt_embeddings=dense_embedding, + ) + + if self.use_stability_score: + scores = calculate_stability_score( + masks, self.model.mask_threshold, self.stability_score_offset + ) + + if self.return_single_mask: + masks, scores = self.select_masks(masks, scores, point_coords.shape[1]) + + upscaled_masks = self.mask_postprocessing(masks, orig_im_size) + + if self.return_extra_metrics: + stability_scores = calculate_stability_score( + upscaled_masks, self.model.mask_threshold, self.stability_score_offset + ) + areas = (upscaled_masks > self.model.mask_threshold).sum(-1).sum(-1) + return upscaled_masks, scores, stability_scores, areas, masks + + return upscaled_masks, scores, masks diff --git a/zed_utils/segment_anything/utils/transforms.py b/zed_utils/segment_anything/utils/transforms.py new file mode 100644 index 0000000..c08ba1e --- /dev/null +++ b/zed_utils/segment_anything/utils/transforms.py @@ -0,0 +1,102 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import numpy as np +import torch +from torch.nn import functional as F +from torchvision.transforms.functional import resize, to_pil_image # type: ignore + +from copy import deepcopy +from typing import Tuple + + +class ResizeLongestSide: + """ + Resizes images to the longest side 'target_length', as well as provides + methods for resizing coordinates and boxes. Provides methods for + transforming both numpy array and batched torch tensors. + """ + + def __init__(self, target_length: int) -> None: + self.target_length = target_length + + def apply_image(self, image: np.ndarray) -> np.ndarray: + """ + Expects a numpy array with shape HxWxC in uint8 format. + """ + target_size = self.get_preprocess_shape(image.shape[0], image.shape[1], self.target_length) + return np.array(resize(to_pil_image(image), target_size)) + + def apply_coords(self, coords: np.ndarray, original_size: Tuple[int, ...]) -> np.ndarray: + """ + Expects a numpy array of length 2 in the final dimension. Requires the + original image size in (H, W) format. 
+ """ + old_h, old_w = original_size + new_h, new_w = self.get_preprocess_shape( + original_size[0], original_size[1], self.target_length + ) + coords = deepcopy(coords).astype(float) + coords[..., 0] = coords[..., 0] * (new_w / old_w) + coords[..., 1] = coords[..., 1] * (new_h / old_h) + return coords + + def apply_boxes(self, boxes: np.ndarray, original_size: Tuple[int, ...]) -> np.ndarray: + """ + Expects a numpy array shape Bx4. Requires the original image size + in (H, W) format. + """ + boxes = self.apply_coords(boxes.reshape(-1, 2, 2), original_size) + return boxes.reshape(-1, 4) + + def apply_image_torch(self, image: torch.Tensor) -> torch.Tensor: + """ + Expects batched images with shape BxCxHxW and float format. This + transformation may not exactly match apply_image. apply_image is + the transformation expected by the model. + """ + # Expects an image in BCHW format. May not exactly match apply_image. + target_size = self.get_preprocess_shape(image.shape[2], image.shape[3], self.target_length) + return F.interpolate( + image, target_size, mode="bilinear", align_corners=False, antialias=True + ) + + def apply_coords_torch( + self, coords: torch.Tensor, original_size: Tuple[int, ...] + ) -> torch.Tensor: + """ + Expects a torch tensor with length 2 in the last dimension. Requires the + original image size in (H, W) format. + """ + old_h, old_w = original_size + new_h, new_w = self.get_preprocess_shape( + original_size[0], original_size[1], self.target_length + ) + coords = deepcopy(coords).to(torch.float) + coords[..., 0] = coords[..., 0] * (new_w / old_w) + coords[..., 1] = coords[..., 1] * (new_h / old_h) + return coords + + def apply_boxes_torch( + self, boxes: torch.Tensor, original_size: Tuple[int, ...] + ) -> torch.Tensor: + """ + Expects a torch tensor with shape Bx4. Requires the original image + size in (H, W) format. + """ + boxes = self.apply_coords_torch(boxes.reshape(-1, 2, 2), original_size) + return boxes.reshape(-1, 4) + + @staticmethod + def get_preprocess_shape(oldh: int, oldw: int, long_side_length: int) -> Tuple[int, int]: + """ + Compute the output size given input size and target long side length. 
+ """ + scale = long_side_length * 1.0 / max(oldh, oldw) + newh, neww = oldh * scale, oldw * scale + neww = int(neww + 0.5) + newh = int(newh + 0.5) + return (newh, neww) diff --git a/zed_utils/tools.py b/zed_utils/tools.py new file mode 100644 index 0000000..bd53aeb --- /dev/null +++ b/zed_utils/tools.py @@ -0,0 +1,183 @@ +####################################################################### +# Useful Tools +import cv2 +import math +from segment_anything import sam_model_registry, SamAutomaticMaskGenerator +import cv2 +import numpy as np +import random +import open3d as o3d + + + +def draw_roi(rect, image, frame_width, frame_length): + + x = (frame_width - rect) // 2 + y = (frame_length - rect) // 2 + + # Draw the rectangle + cv2.rectangle(image, (x, y), (x + rect, y + rect), (0, 0, 255), 2) + # cv2.rectangle(image, (x, y), (x + rect_width, y + rect_length), (0, 0, 255), 2) + + return image + + +def distanc_cal(x, y, point_cloud, depth): + err_pcd, point_cloud_value = point_cloud.get_value(x, y) + err_depth, depth_value = depth.get_value(x, y) + # if math.isfinite(point_cloud_value[2]): + # distance = math.sqrt(point_cloud_value[0] * point_cloud_value[0] + # + point_cloud_value[1] * point_cloud_value[1] + # + point_cloud_value[2] * point_cloud_value[2]) + # print(f'Point cloud distance to Camera at {{{x};{y}}}: {distance}') + # print(f'Depth distance to Camera at {{{x};{y}}}: {depth_value}') + # else : + # print(f'The distance can not be computed at {{{x};{y}}}') + # return distance + return (depth_value - 50) + +def blend_images(front_image, back_image, alpha): + """ + Blends two images with a given alpha for the front image. + :param front_image: The image to be displayed in front. + :param back_image: The image to be displayed in the background. + :param alpha: The alpha value for the front image (0 to 1). + :return: Blended image. 
+ """ + # Resize images to match if they are different sizes + if front_image.shape != back_image.shape: + back_image = cv2.resize(back_image, (front_image.shape[1], front_image.shape[0])) + + # Blend the images + blended = cv2.addWeighted(front_image, alpha, back_image, 1 - alpha, 0) + + return blended + + +def vertex_bbox(frame_width, frame_length, side_length): + + x = (frame_width - side_length) // 2 + y = (frame_length - side_length) // 2 + + return x, y, (x + side_length), (y + side_length) + + +def segmentation(checkpoint, image): + + model_type = "vit_b" + device = "cuda" + + sam = sam_model_registry[model_type](checkpoint=checkpoint) + sam.to(device=device) + + mask_generator = SamAutomaticMaskGenerator(sam) + + image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) + image = cv2.GaussianBlur(image, (3,3), 0) + masks = mask_generator.generate(image) + + return masks + + +def detect_roi(masks, iou_threshold, frame_width, frame_length, side_length): + + bg_tl_x, bg_tl_y, bg_br_x, bg_br_y = vertex_bbox(frame_width, frame_length, side_length) + + white_image = np.zeros((frame_length, frame_width, 3), dtype="uint8") + baseplate = np.array([-1, -1, -1, -1], dtype=int) + + for num in range(len(masks)): + x,y,w,h = masks[num]['bbox'] + try: + if x < bg_tl_x or x > bg_br_x or y < bg_tl_y or y > bg_br_y: + continue + elif (w * h) > ((bg_br_x - bg_tl_x) * (bg_br_y - bg_tl_y)): + continue + elif (w * h) < (iou_threshold * ((bg_br_x - bg_tl_x) * (bg_br_y - bg_tl_y))): + continue + else: + # feathering for 1 pixel + if baseplate[-1] == -1: + baseplate[0:4] = x, y, w, h + else: + baseplate = np.append(baseplate, [x, y, w, h], axis = 0) + rect_color = (0,random.randint(0, 255),random.randint(0, 255)) + cv2.rectangle(white_image, (x, y), ((x + w), (y + h)), rect_color, 1, 1) + except: + print('Baseplate is not found') + continue + + return white_image, baseplate + + +def grid_eastablish(baseplate, stud_num_x, stud_num_y): + + baseplate_width = baseplate[2] + baseplate_length = baseplate[3] + stud_width_original = baseplate_width / stud_num_x + stud_length_original = baseplate_length / stud_num_y + stud_width = int(math.floor(stud_width_original)) + stud_length = int(math.floor(stud_length_original)) + + if stud_width == stud_length: + stud_size = stud_width + else: + stud_size = int(math.floor((stud_width + stud_length)/2)) + print('Stud size is aligned') + + return stud_size + + +def baseplate_aligned(baseplate, stud_num_x, stud_num_y, distortion_coefficient): + + # Calculate the index and position of the middle stud + middle_x = stud_num_x // 2 + middle_y = stud_num_y // 2 + + # Initialize arrays for stud sizes + grid_width = np.zeros(stud_num_x, dtype=int) + grid_length = np.zeros(stud_num_y, dtype=int) + + # Simple sum + # if (baseplate[2] % stud_num_x != 0) or (baseplate[3] % stud_num_y != 0): + # grid_width[0,:] = stud_size + # grid_length[0,:] = stud_size + # remainder_x = baseplate[2] % stud_num_x + # remainder_y = baseplate[3] % stud_num_y + # if remainder_x != 0: + # grid_width[0, stud_num_x - remainder_x : stud_num_x] += 1 + # if remainder_y != 0: + # grid_length[0, stud_num_y - remainder_y : stud_num_y] += 1 + + # Calculate the size of studs with distortion, except for the middle one + for i in range(stud_num_x): + distance_to_center = abs(i - middle_x) + distorted_size = (baseplate[2] // stud_num_x) * (1 + distortion_coefficient * distance_to_center) + grid_width[i] = int(distorted_size) + + for j in range(stud_num_y): + distance_to_center = abs(j - middle_y) + distorted_size = (baseplate[3] 
// stud_num_y) * (1 + distortion_coefficient * distance_to_center) + grid_length[j] = int(distorted_size) + + # Adjust the size of the middle stud to ensure the canvas is fully filled + grid_width[middle_x] = baseplate[2] - sum(grid_width) + grid_width[middle_x] + grid_length[middle_y] = baseplate[3] - sum(grid_length) + grid_length[middle_y] + + return grid_width, grid_length + + +def draw_grid(image, baseplate, grid_width, grid_length): + + # Draw the studs + x_start = baseplate[0] + for width in grid_width: + y_start = baseplate[1] + for height in grid_length: + cv2.rectangle(image, (x_start, y_start), (x_start + width, y_start + height), (255, 0, 0), 1) + y_start += height + x_start += width + + return image + +
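The middle-stud adjustment in baseplate_aligned() is easiest to see with concrete numbers. The sketch below is illustrative only and not part of the patch: the baseplate box, stud counts and distortion coefficient are made-up values, and it assumes zed_utils is importable as a package (importing zed_utils.tools also requires segment_anything and open3d to be installed).

import numpy as np
from zed_utils.tools import baseplate_aligned  # assumes zed_utils is on the import path

# Hypothetical baseplate bounding box: (x, y, width, height) in pixels.
baseplate = np.array([40, 30, 100, 100], dtype=int)

# 5x5 studs, with a 5% size increase per stud of distance from the centre.
grid_width, grid_length = baseplate_aligned(baseplate, 5, 5, 0.05)

# Outer columns come out larger (about 22 px), the centre column absorbs the
# remainder (about 14 px), and the row still sums to the measured width.
print(grid_width)        # e.g. [22 21 14 21 22]
print(grid_width.sum())  # 100 == baseplate[2]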
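For orientation, here is a minimal sketch of how these helpers might chain together on a single frame. All file paths, the ROI side length, the stud counts and the distortion coefficient are placeholders rather than values taken from this patch, and segmentation() as written expects a CUDA device and a ViT-B SAM checkpoint.

import cv2
from zed_utils import tools  # assumes zed_utils is importable as a package

frame = cv2.imread("frame.png")                     # placeholder input image
frame_length, frame_width = frame.shape[:2]         # height, width
side_length = 400                                   # centred square ROI, in pixels

frame = tools.draw_roi(side_length, frame, frame_width, frame_length)
masks = tools.segmentation("sam_vit_b.pth", frame)  # placeholder SAM checkpoint path
roi_vis, baseplate = tools.detect_roi(masks, 0.3, frame_width, frame_length, side_length)

if baseplate[-1] != -1:  # detect_roi found at least one candidate baseplate
    grid_w, grid_l = tools.baseplate_aligned(baseplate, 32, 32, 0.01)
    frame = tools.draw_grid(frame, baseplate, grid_w, grid_l)

cv2.imwrite("grid_overlay.png", frame)              # roi_vis can be saved too for debugging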