From 1084f01ebd196849aa6f8d9e205308fa6253c07c Mon Sep 17 00:00:00 2001 From: Fabrice Normandin Date: Mon, 21 Nov 2022 18:21:59 -0500 Subject: [PATCH 1/4] Add a warm-start example (WIP) Signed-off-by: Fabrice Normandin Use more samples for training of Profet model Signed-off-by: Fabrice Normandin Warm-start example is a bit better Signed-off-by: Fabrice Normandin Add more comments in the notebook Signed-off-by: Fabrice Normandin Use Pytorch mnist/fashionmnist instead, bad Signed-off-by: Fabrice Normandin [dirty] Trying to find good settings Signed-off-by: Fabrice Normandin Update example a bit, seems to work much better! Signed-off-by: Fabrice Normandin Try to make the plot appear Signed-off-by: Fabrice Normandin Add the pickleddb files for the databases Signed-off-by: Fabrice Normandin Rename things, image still does show :( Signed-off-by: Fabrice Normandin Add a thumbnail (image shows up now!) Signed-off-by: Fabrice Normandin Removed the torch multiprocessing context setting Signed-off-by: Fabrice Normandin Make path relative so CI will work too Signed-off-by: Fabrice Normandin Add tutorial to index, add text to describe plot Signed-off-by: Fabrice Normandin Move databases to avoid any possible abs paths Signed-off-by: Fabrice Normandin Attempt to fix KnowledgeBase config merge issue Signed-off-by: Fabrice Normandin Update the pickle files for the example Signed-off-by: Fabrice Normandin Seed all the randomness, update pickles Signed-off-by: Fabrice Normandin Fix issue with portability of PickledDB objects Signed-off-by: Fabrice Normandin Update the pickleddb files for the example Signed-off-by: Fabrice Normandin --- docs/requirements.txt | 3 + docs/src/index.rst | 1 + examples/tutorials/README.rst | 1 - examples/tutorials/current_db.pkl | Bin 0 -> 126872 bytes examples/tutorials/plot_5_warm_starting.py | 380 +++++++++++++++++++++ examples/tutorials/previous_db.pkl | Bin 0 -> 14841 bytes setup.py | 3 + src/orion/core/io/database/pickleddb.py | 14 + src/orion/core/io/experiment_builder.py | 2 + tox.ini | 2 + 10 files changed, 405 insertions(+), 1 deletion(-) create mode 100644 examples/tutorials/current_db.pkl create mode 100644 examples/tutorials/plot_5_warm_starting.py create mode 100644 examples/tutorials/previous_db.pkl diff --git a/docs/requirements.txt b/docs/requirements.txt index 38bd1e244..f069d27c6 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -9,3 +9,6 @@ plotly matplotlib kaleido dask[complete] +torch +torchvision +tqdm diff --git a/docs/src/index.rst b/docs/src/index.rst index 93a0ede07..2d4b46a9a 100644 --- a/docs/src/index.rst +++ b/docs/src/index.rst @@ -38,6 +38,7 @@ tutorials/cluster tutorials/pytorch_a2c_ppo tutorials/speechbrain_tutorial + auto_tutorials/plot_5_warm_starting .. toctree:: :caption: Plugins diff --git a/examples/tutorials/README.rst b/examples/tutorials/README.rst index 63fc63ae9..ecb76636a 100644 --- a/examples/tutorials/README.rst +++ b/examples/tutorials/README.rst @@ -1,4 +1,3 @@ Examples -------- -bla bla bla diff --git a/examples/tutorials/current_db.pkl b/examples/tutorials/current_db.pkl new file mode 100644 index 0000000000000000000000000000000000000000..d1183d3596f09c96beddbf0574ed7a8f814ccc0d GIT binary patch literal 126872 zcmeEv2Ut_f);7KODo96)Adp_{qKJy3fW2YL4gqP>6csxv8VedLDwcEXVh0rzM6n}^ zdIY=JQ4vw>&*37)K)+~)xb$wM;1po0|;u0d` zVtrY035aiGoG+V^#9%TK5nm*pi^L!ajA%BK&(jZR_*Wl){z`tw+59xVCr^VLmzcy4 z=4r$-Vi3XK@o@=)f9oVCA_?MqI`NFe#L009Y`Cu-7oP;5l?a7->fyqAS#ijO6Gfb|fnaK835oGjIId-i?OvPC7pnK3*T05|1QA#vrjtiF~f9_)F=9v*TFF zqMPIRo=}NtV{}^Q;=*5#k@7 z0v`a_@OiE&IA4M<$u}Ob#S8x=kx0I5Bn9UkL1l9A2#!SO^Gq9yL+wZsfWYVJNPD@G z-&;@nxq-R4iH^=)4XU>R@fHov?V-kIn;t}6O1Yz@oRCtaO z_hZ8&VUO^4i0)J^!|NA zKp9WS_SHp0Q+Up0;2T5tp3A59*rCOl0zW<)acA6B*3_uRk>$W+E`9X>=bo zlznh0>pp7!2R%6SAO45AkNlaD2mRWjOqXixz*; zRm1s$I=D_js`L5W{=z>m3jd51{<%{4=O@9RP+xK~lc&c-Vp-f6M#4mN5tnN{jvwU+ z3)IrSAbbri(bs7BaKYCWRi-6g)ofS|7F6W<$6}Uvjfz8w?*)Z5t9hZwV1^& zHNldnhG@Y~BLQJV^Ld&njOb(pRvu4X{%(0R@pUqTr=5T#!g{s1Y*ZM0W1dEITw)?W zYM9`)ZX7d8u)3Xs@S{e+jlja*yVH6TPpWn9Piclwz=Pw8%i0i$DcTgptHY>qpBQ!wPr!A1i`(O%}d=Qt*kBxs)m4 zRA{}Hz!;XW0WbJ!uzm?nEGtgnlP1A(FHa4656_87VlkKqmBt{lSWFU|M8&fh3=W%# zXAp2ighl4_^w5=cDJzX4a@1Nvj{k(AH)D4HYD!KThMbv_)#5E7M1q!2h18k@nO5E)bonW-c> z>MbGX*6PrcT`8tb$=QJ+XO|>7dxYdv-0PrGJ%A^J99PZv(Q?QE@}3-(M-CX2`pPB; zQ@(A;VN%FAGL^$&lZaFzg@JJB2pi9)vN=QoolGDqNsdNK$jLMg*}MLiLsN3DW5~H7 zNzN@HIay)LnS;vfq?2Q&t}{#A@%5xV)=XWsR(+65j;mJM4LRhX<=ciFBAZHMkvMoV zhlR(nXk-RLB@)RL5}t_AadaB|2rBbf(`*SjQC%!{X~gerO3rr-IX@-I`7I=;YJ31m zqdG_iIZmpcnj&nal4GW>+a*UHIbh(lS2j7A@@+#7jl-cbs3bOn#v-#=6g-oJ;Hg9g zg-K#kh%_QBwkSi6R!hin*%p7OcFvKe% zTG3Q_yH(8jwL`5b;D3n@YjaICK`mfL|OA3pTQq zAxFC<!VM$HAg8n1>?QKZvDZ01NG>_{I$1sBkb{=2Sun{B;nww9Z8 zax83hk}t|5r?dK48@c3kR_~f2ha9wg+mOS+lc*$c&?szJ%_p-tBqo)Hqr(6q(Lnq{ zQ}VH<*Aj9l->erNI-Sv!oKg%qWs>Av7n1X?Lg#$doL(}>p{n<7*dUhXty$RW4impG zeG&==$X5i$cD0XCUNCXnzf9FZ)T^U%mN27yZANRZP+h#^`bFlY%m zT#d@P{G>Ha$#GUky){?CYpl2C1~&o=pBO}sw%e;CgB+4(ro0sbAn(&TdE|gWc}Lmg zV9K`*Ie3H)Q%wqmNh45U0?NU&DHIxxN&tU=$fh$v3HgV&CW9QqmXPzNR|YlcoRoP^ z5QdyVlH?2)lJj$0iN3K%zI1Y|EH#tXiLjMA&#@$Fbuo}j4oS<|Qw}+3`L-d4O@LJ! z5`{s+!_oke$|ACf92hAU2aizjG&YGtL8w#)9*@JZiF5{q zz$D}8M1(<-oP^34ImRs@r^n(uAvRaF!<&+mjUi`?BstrJ{4LpnLuW(L=` z${9IU1oeX-<&r~CUl}fYITtbHT$UuKR7g%WZAsbg(rg*z`04%ZBZnMoGsErAkAUhk(Pen0O9Cq|$LLItR~Sk?0I6hfSf< zXy6%0t`Nxh%`tBYIk(1s+gtgvdsA`(G34}@BqvBnPDS<*QcQ_p)JfS1wJ|g*5WC7! z9fEc~nz~+c$??&6F7_y-?Sx{=w+%Vqsj+E90-44}5E??E5r7dU8AoGtI5hY@A}Kd= zELuX2-<4qJD{5Ptk~0ZIPNF0^$wG48@ZhA%{=qWH@zVcBlrwT{j8uA^l1C00*w91< z(#XM-ZyR!mB-jW=XjBsH7Z7P2HjYhW;_*Neks$C8lqAQpCFES#yxixJVf&`!WMRnJ zC`nGXket#1b>C-Xsz@itR^Nz{D32U3gK}ND@U{a8bM)CvD_l%C`+UEZ7O9vq(6Y zuy9Y4~B5Tyi`O{~?GBq>_V{ZzFOL4vR@=GuZH0V>3u30-Hus?cFY&21Ct-ati)?2fO8-ax%^R?#}yU&(Ej=Soq zd^zL*d8KpZkpl)f!LrH0ly4hy*h~V8!XlDLu#`rmlL#Cdm5qbVQ8H{+lfh3=jvSko zki*R!b3cXKwJABhG2{eDk`pK-=Lv1)?Av$irITZ)rF&f5@uhle?cLSVcgiKlU5(vc z4moJ~wjqbe0CtEt4jiwh5a=vWNl_VOBAx)sD1=62ORh@G7&*2rA*WkLTJg3Sf{CPr zs4xaYPMjn;lZ52FPpcztFe;Qmj;o66ArZDx$+6SYTOn3HN#D-_gZRF($-$Iw8*}X5r%>%7%|15<|{JNpfNYZRL?;rg2SY42^W#D9^7d1v6Idr|Tm2%0U>$!_1qtYI0j%fL|A*Y$A3}lew+7fbP zo-%lWA?KAOIj@D}{Bl?;Tes#!F*v+k4mnOX`pZViBL@tIpBEWOBL`EyZOCcnDFYeg zbZ7}VT~kH|y^ohN2{qM5y)|>eYpl0s2{!_Z)IR9v25h-3edIV>s}1##M-IiP+gZ8f zPz>v`<&cAxZyR#ROd1{(H*6M}NawH#L^>?s;D|&9osAIiIJ)FEwG49HT0&00N{h-p z6%znpiRGIv7;?HvlG8&-j;8C>(KRe58RQUkvm@k?<7};dP&~4v|K@-}l8bC|Fy-5Z z9JW9v&7@P=ED{@z=fD~pjX=i1x($&`Byc#AD+Dsgac>DZfyd@X?u{1ARV2tkFyur? zk`pN+2e%3>J~aAmMO$ zA`wreAh0?>hOHbr=yDhmlTaDtc(jC^0qY*pbMJg;O3pG2Ihm5=tQ3*+?GC7M@iNH4 z>*xF}ha49R!?T^`kpl)y@rIDJ-yBT&HX;X(>aq}o1!tu|y~d<~1_vRsXe>}nQy6r- zq`*T4Ii4*c$6u#sPsY24P02ZqAtzswoRdOw)E|ajthxMMIytWG^tS($M-JXFK`e!q zK5|^}hVJWSj~sZrp?uqr!^Bf@f^BI!Jk8iF4u?zwB>@fxv2!Ub2AQek$m!S;a%7$| zc!DA4nIt(cgyg6gY7cOwpO!%mRfp*;haA^-23^eMjT|sg6CYEMHgYiKl^{oO${>6y zb0U=)&Q)@SpqZx(ngqEKtq^#%gdCZt3=DNpZ_QZn8tbi@!i~Vf-+_G{2B!#|2`TfO z4%S8+PRb*Ps+%T08zp_@xKeckxw1zNyxmZ~ZOCcnDFYegc(;TcnWqe>7;@;6E zK}ov2kpl+v<%kL~<=ciF*leS*AS@A_(q#&w-mnk~m539ZtY$()4}#==jtp}AT0+jt zQ3F2KRtSDKB*=M)A?L9qIZuS-d<>o8*+GAb403!8JvNHKmr9PincDl;^2h;$eXC?Y z)-dJUh8&o8;_#ptppr;r77+`AGkLnvYB}Z2m z_15$Sud&{mA>0To{JAe8uG~>DftBK|d6;N~i+?htljCc2uR<<4zDE3+a>zl;w+%T! z2#9YubUMURL5LtFq(M*uI+IDG!iilXTT&}6gB*NI$axiIC3c4;Mh+1}4q1{Ms*s%e zwZjg7fgCX(QQ8;(JSI4hC| zc4d%5Yza9sPZ=!0kdq-v&SC*Mh3D&2@Lx7BmqCt~`j&~JjxUuQPb0Mx;#F(u-WnLp zlC$uMDX#=Mf>Q?Z;Ve=#iyWck3PCeZ83+PrN(m!LY6&?qPZ=DIyoKnw2#x}k>jOtdWGnH8RU3rEX2uvta+m4+lHKGo-&X@4!I@d$UJ3m7emf{ zNpc|SZLz*`16?uN@T204_LkR$Vy zfujoLFw3n9Ak4ZMf6A;(Kw<+HK8#~K*q zi31f#a|kfy+lHKGo-&X@4y`5Rj1IcH`EPdtz!KitC=5AcB*_^kAg4$(ulM~mpJY7N zyw!~^YRV(W)u6qpTyk9XZ-`e;rILe|ZyR#RI5vw-fiP=yG6LEG0)jw@To#c6XWdC~ za7Xfdjtp|>Eg@&*$bwPF)H*jMXD)`EbV+g+h{$aea}2G;%QI+lCwp4h|YpfEzfH13Cf@0$~Y&BGBaGaWoo1auOi8*X$d(YGw1c$c&fPR$hnOnr&5xfyCQPf6`kQY zwG49X4K9n~L_GM5f{yBr1VH zg_J6P903uKgUC#a;2A`OJPmZX1i2f!!X?PxpbwW0Tt;x2!etGY1zfgp*~8@umlIs> zaCL&KBV2xP5#XZ1MTV;jTn%|3wBfZLTqbZC!etJZ4O~`m*}>%qmm6FzaCySz4Hq6R zU${tcQQ_(g7vz%&(uGTqYr+67Gq{Z5vV^N0Tn=!xhsznR4sdzFJ3d1Fd%oN6k0W%E%L{VUKv6-UG0bmBWDLif_GE;c)N%WzJBJ~rPY4V&D?(BdW zu%K{aBr;RzXox;k;a(q@K@U*|Uf@)rT1E8TSc)dzUGk{m&-VB%l z^$M5Mz)Zk`qF*C06X2#qe*-gxQFx&?=;R8I3xSysHwD2j!A#InMMuTOAF3#ygXlvQ zMxqhbrcmGvW&p3kbu&?I3Pqn_CZJwXOhb{G!VMzPhbk1%gBftD3|Y`pW%`?-Hsx4Q z7_3S3-IQ#C!kG7>+7t%(1~X`>LLULl1f5)w91+Z*_bT%kgt4H^^9^cK=7}VHI4C$w zDQbd3p|}VTh0_gjOBF>805fQ*!U%U@24g{)M>Loz+>aC0rp&Wf0I#B8JL1|D?K6qY z6w2-zGjeGNQqc1ir_hqN6aew1w_O*x)S7}kQ3&7PuY-!pER9ti;4i^{X^=RoF~_|y z<(5tY5)lbm?Ni}jJne9C3NJx2F2#=no;4=fswljaudo8h*ELj-CzUVE%Ni`G?GI;K zRh0287(P&1J}dMfs)-=kakEloVQ}H;umJ;MmM_f9mHfth^m+%SF1%Qy1`P?l{+Qxm z-fyS?BR*(+7tJT86ON`T zOqkh)HzrYo+Vw@X4-w>5Lo4U{@B~@axPClQLNzWPO{FHR1Wn1N51xEn3X+fz$wuI3 zAsOCQ!YknFL4HjE(_F(KRjyGm*BG0kjB6UqHRH+8a>ljbqS!U0^WwJ0+-ykU#dXBo zY_xQ3yoqKD<9dwabG;hAnU6R>9H4+f6`&|=NIoWjf=h%l6tqke+#1dw3yr%5bGrqp zi9k~Li`S|kp=q%&Afc(WFd$(pF(5T&?NSz`h9qE0fz)sg7zIgP+>;bdds-Ky#`IKJ zkQ!53VL?JIQIOE=S;~Uckc>+ykQ&b8q9AFAKvL-Kv?@qwnkWoNXlf}8NEk~DNR3&n zlm)3FDV9RizTB0DK zS&ozisUc~SQXn;CXF@^J6@jF1U9wd{Leu$RKtfXhVL-xIVnAxlHl!>_4atj?0;wTO z5(<)@0HowZuByOKR8Qk`HH8xm4xLPcBd~BT8Nre9EH;b4VA1GAA{;4aQ18gNtWxpPr8Od1__2{gFATt6;O;3hUC5LTx7MJn(8 z{#Mz)PS9iqSLh~+eQ7+A29i-|B#wq7K@J`U1BT=N8*s)Dms3OedKe2F)GXT(7vu{{6Kfxe0EV=8C;yDxCm9>1b3OhbT~D!j3+f zDhMLTB(MoAh@dX5xf4XqWnh|%abnRqK`Xes+)2>f1VM9~Z-1#p=rD-cG-=J#BoA7ff@-J|~1D;|O-z;n*>OgFv`n4jYb#6G%)p zi$;TV+tQkQLe$)RY;#c;5!2k3JCoeY(A+D6<~GNBEOO3UkmU007N)sk?~%lSP7BW` z77GtKC&-ZEfC)J%@H9x5OaO6?)UjJ3YHlT_xfo{;ZEh>Lh1^%r+-gB{o8t!-buj$$ z`c0Y-l`x)C@)NR8wx8)9A zls2Rk=K4kH2>yabH^;9latWSxZjbFrOrynq9q@Gk->w}a7b?}9@b zSF3Cb=i<_B@Rw3^4jwyk;>$!|+lbDY+_{T~2bIvzhoD?>xOsqW#jJXo#{xc@CY8*Ymcqzk3O?L*e{65NITN0v(H3LdX&SM+P!bD z9gqF08}OQYHkAJ@G{G=?)XS&@E zsw|ERNA}-kUUWZ|oftB}Z`ikk4!cW+r}bQJwP5IX)rV!fNf&zVxE5`&`%#^BD?f2Ww{^|GkrFP3bw&P&t>RgwFj^B=S>5|>=X?pUE-l6_UX6Fv& z<8o)ddh|6dDMkOr^ynk?eGX4Ah*S5nK7t>#u|96xMVr?*7c9u<@9ubDa{NC>e(rqW z?57oPw(U&c0DnK9O_5yneVgt^?iyD;dBf7$gN4z4gZJz_X^8Wk&AAZ`%(J z=67m0XH?aG!@=Ci>zCJUUiaiir>IoK?6bp^!I$ptKRfF$-{Xdfj-MBdomLor=lbrE zgO>l^GLZH0-uee78Shjs+^G(D_1*Kr;x(K9((W5ZKI1s-;Lht;qQ|Wp!P=fVy5Me< z%Z!B7vWH)K{nptNwPD`E%Y3?V-|Y0*_G7d1irXQXJu-XUSafPw=**A{*4wP}Pv0iK zs!iyap*cS~OW%3n$`j<(HV22BEbAJJ>q*}}`fT8Va+7UV{IRB{f0gXfU)Xo7-M#&# zgTi+W9XC8a&(*LnVUO+XE;HX55QiRCUwPblZSLC5Kd*l+Ed0#QI_J82-e9j4m;6_! zgc>&!nNmDkyQ7W3o0VaG#G z-*Frr_W0L;U47iDs?$b{`ORB?dC&3A^UfqO=3YZ$hXqs*Kh@WEU)=6U-h)%GW@V(z zn=tu7T8`_FtwFOUJFE-KQ0@0T$E@$7?zcMKJ(IqA(%+TlQ;ttqzjgnloQ!M7ZAbeo zPd_Des!HeZD zml!!u_^8ZNx`Ae24odG(2cyzGBMcJ(;K zn7^*T3PDVq^50!A>~&=SQPzU5zo#!<8nUmlBy;FW=h6G--ZWiX_N}fCKP?vVe1UsY zw}D!hWjpirgzXOVXrnqBs_Paw^ceBvMdvq8skXKcR&-iO$t<|@i@E+l=%DuSt6HPQ7)y!`#k$b7|`? z{D{zTr+S+wYrk?%^vZo&aId28g=K!31@Fw_tGzRyb+7om%HgiZeQrn9&h9Q3Ug;bi zzw|=Nt|wcg7;c=BVH$mIfKl9cYUh#fH?ss@f#Ntl9f@bR<`!cO=`=$9~kMzBsvfp&q>8p48Tz*9_*RR|? zzwm-{cL&!_r!-dHEGAy+vzoAu(zQc8GiXIX%|U$^Bb9rmCFeKZTK@Ty+26jtFCtTR z&A+m;f5O`!O2-^`C!2mhe$HLE@z;+JL$7nY(yCumhHug9q4|PyJU;!l-M1I!L9R!a z4|Jb>VA;}w-amF z`dx+p&XqJAPpV8o&iDUu=JcyMuWn?09rv7THvH7wyX!|^o;hg3X9J`3E!#?8oun65 zdac~p>!4NmzICIkQlGB*j8_T#sImH#fA2tsdGhQ!r|^>B*Y1BFm$!FT#q{Im!>dh) z?;7Jb;EVq3*mDDaYWip1y*R2MrcCo7b6VV^Ng0gKx>fTNX>p6j&7E!@_2f;b8)oC4 zR56c_J7URg7xN|9Xd|9AsG>yU!y+WNg9n6xx6OEG@s%*~eSL{vEyZHqO1|gP^>d!4 zE@uT>S>lpe`o&>3_50d8Z!c{>M`w#}zaOP*GuHgRp0_>h(ZlEY+eZBC`_KA~Tc#U~ z;;wo9&L_18=j}mfPt~aO4bdGt_E!IvKN;D1*!ptk4#PVKd-P=C6UGnkmp5i^$fLcx zkL168)nkTh(X1=Gx_;RDes0C&t!}Az?ir1;*b(pQ-u|Hr=i?&SIy^LZa8 z#_9d^O?@#uA-2%?(t)1(J9qXTUQ_#!zc<_~XH?i81EOwqQgH1Rm0my2oh66wFq}vk zG4rj%A-d*oT)=>07rI9Uzg_pcU8epMV$Xgww(r5ANQbzHqyg&Pai_bkPCnkhrthMJ zbc53d=IZnGM)v7e81-n;Gu)&wrft4EVi%UUfoVUwu-9;6`$ykHKNrpKTWm<%VE(0W zsJmrtNNExy_sPcBj0Z=nf@iz8S2+X8lX#P7 zW8A8@tQh(+bNS(u{?A;zsmG&JtTzunfBqh6{?QlllrpG}6;H#WQG@Oavj(A+GQoB1Kz52-&>iJF^!^{}~h{^)5D zT{f1i$l0B&9`MTR<;Wqgq6RXCyr{mT7k}>7mcpu>ZUJOp$KM{#bEnLSw98ZHxj+5T zLz9;pGVp8JP`cT}pPVD(IsVC=ZYR0+ZJ#iJx8%2m-==N5_~6M6N>>e12^2{WRKT_4 zX(U35cJLE+E@qresVI;q(RdV9>a)_kBp5F)TPt-51u|NE;<@vB}IZi%jfBft7Aus99tb` zawPb@g16KY5rMCpF7RiS&_gT}SSei0T&~1!y09jN>x58~!YyZEO$wLb1m8)~zO1k& zg{xaolfunQVND7bj-V!mJAA^L6t2lYO$xWcgf%H#?tq#U?iC4ZQn->JXq}=B6=6*Z zr{_?U!d(YpO$ukb%aaq66MbRWMaQuwipFVgnDcQWBa&gD!D#VKibs3}ViOoX(T&5; zw_zNHi;NR}MZ?xNjw9Bvm=9u15eQNQ#lji}Ezet-x*O|g9|nr1J~3S#q|o4(IUj*bqWj6_Colfq#x3sQQXLsP5xgvhuAu^B(GN3G~#Ce0Fo zG&2nq@-ukWP3wwa#KbV5l565hqTbJ^|Wls#Eg~^SCyMKxlWexEv zg~r0ILIV*<3dO3e3ld4Ii$N2RNLtRKkOx?hP)igfD@)BJQMe_E4OE%3yPw-i^o>38 z1#ML58k~5fAXF;k#)QmQL^>?%GRbfn5@A5n(pCk@%2F%XP7)-z z=_?x87?3cQ7?23+wl>Y`N!KrWjP-R2Ol~K4rWGw7SMk0_Ds>)jzB!c=u z(P@9oaI+$)uM`~_#Daucq99qD8C=`iERa?wr5?!G?4cA$jd2vwo@6WnN#OynRt3r0 z%y9cN$({r^=ZnJoVnD)JVnFiK``M>iAXU?rlA4 ze(Qqdr(elx0+OHpaZ%%BsBdy1|&3PJO(6;CEAm08ne?YJKP$Q-7D3T{H_E$Us2nN0m(w} zx~Za!%|Y-i@G$9Qhz?Jukq9g}=1nHUMmPM;Wd13_5jmIA6lJJCjzI%;;xTAoEKxL2 zhg|{TgC~`t-D>D+`Vwx}ZrRfm0lTOEjhz=4?lV}W(-$1fF5(A=kP^m2T zpE3<`bEPQCFs?7UgWU>FGO8$Rj#pe{Jyfs%{%*0D=8C;yJb}faG1-ux0ner}5CV>a zM_3#Jl`P1#iAVn2wih>7ilWRCbz(8iZMnOPD$1JU-xe0$**++sY8s}w4gM{a3~_Nd z1Ui94q_RkG=99<)h6M?&i8vw^k>cNSr6|fU?kL*aR&ZQVMOkw^)greQch!4ZFUB-i z?5T1{cnC?$V3TPS_@BiDQ3hxS39wp50BHt8YG;6!?a zevcv1E*l)J^9=Km!-CYHBUuE&Fe!jn7)JbYO%$MBhS|tU!`x;I7cvZeh_nA zBXO7K_B#)*J^C>qb!3*WYS&XL$Mfd;X*-?w|DIy(`hmQf_)5FL-{I2KjrOZ!))odO z3@RUelaaTMQQYZx@iLo;Rp(1b2dXi6b}t5RdYd_tcWtN!Y zup#3@}e9y5}N3sl5 zI@q+*&V}tfz4SmQr>FCibFE)piLX6Xq1(Bef8kuqEXJ7rCKHeS9zsYq zsX00DMUbcNz2AX;9|u;+_6}1E;uWF1*z- zQ~kZEgVhP5=g*tJU23lQdTr~iM|_^OdJ;=-yZ!!etLuC(M}JvFTCyzr{a=;CtiE+T zszI^2c8zW|S+mT|ch=?O$o+za)O8+)T{KRR4q5>sU<>_on;|Pa;u!O||Sf zdgL+2G~@088au+>XOCS*>C*n&qlkSrk6(|8{J6%LGWtkK6?a0#iJSq`41OfJZRqX^XS061Sku3+ zyXJepFLPYWYNvTQ()C{SNgjbO7JTe9fvIxmRJ$cpJ{@G<**QJn^~9}A zmn`$qSFh|pRln(>En;YIviib}H{2!oH5%iCQxCsfnSbWO$+z1-9d>h%KljRe;QAS- zbb~W~>W5Eq*xWtxpV~~5u(^gC&pkf0)%2N4@S@6m|I{db(`Q9RgRL${?+ajcDe&3Y zncROlr_00{;hW}^Q!O?pj0qUSoFD$|efuH*gzeviyZrXe`XS`thsS#E>OJacVfe^m z#F6^j(!==F2A{0mRk*Q}pXj<@DAqb_Jp-X%dN?aT%yimd^S!KHe}|mnnX6yg_D@Ou ztrGp4RnyD18hl=MV=-s+{VSUM{Rem6D61Zl{J8{gH}}?GshgY*sPLQ*jyusOrbEF> z#OadT%9=GZq7U7e8KnBfIn}@ai^u@IwD8)5U-Y-ba%fNc^iMlL)cn5tw}!P|Avv_F z^zi13Ni^T=@80)g@6`<}eRt|=_OByb50n=4wBt+-^t$|a@5R^0Y&x)My6H2=mHw{% z`#kriEOVQ?JF9k{tws02xi8X;>eJ>@n6c~Td^*i^8?_@QjX9<#d#i74$s2vcZ=XA_ z+}FGK;Lp)WPGN4NR(Pa9vxKiqw7|Kw{Kw>FWk?BX8#cKdlv zkATiYLps^7%w3yV=^ol&^KhW%S8l0g#RF@LRkzo9ubtha=6<`oD+lkHUBw(y#+*{n^o}e z{3~i%zSgU9<}rN8Ik!W4?I_vvZg{_L1Jj56Wn#2x!OV3U=NY{$k7-A(`CxGFC6%`2!Z=!y)oJyq10Q&;Nyis0kBtYHVI{a_N;5rI%j!tM)d%HY2n+A!!Ld z@0XqSuc~uv%uK>RIEF@i$ZA&_GxthO-jvzZKiM{~3LhLgu=04vfhnPrcBma1GACnx zb$6!~fj`Rc_Bi*?)3Q&m#_CU>HjOm-bGb=Yaq;D=YMa*opw1j<9dbFnqTJ=L>95pS z3zxclaCtHQ_@d+OUJa=Er!cH?(&tqPi7QRZpBUB;c%3wA>ZZk~FW>SrT3$Nc+3{iSC4$t8}6l1^)f5;TsyC?gO}n4Tl^SpV79q; z?7h>oR;-<5xP0))uyNcs)1Qu;m0syk@TAMslU44Q+}>y%yE5$axvn4VGpjmJ>8iKu zUBNH>$0dVLbj#Mb^ds$pg-Z0-sV0Zjk?+q|pRz9)wlImjbi&r&_m+C~Td^*_UU%oh zvpZkit*6#&?Y=vH$Bd7^C;G6pUfb4M`FUlYaQDBFvjMlmWY=nZzg640wU%$Eojt63 zVfx2kyZWV`Jo>4y!-4}lyS_-!8}WPI!t>F%ic>1DR0mu4T|Y6cFpn9z*CF7aq^%|S z-$#C#y1BqxhZxJrb#NQ)xP0EJ6P~t50?#bof7|EWnrH7<`0Tqow)FJ1^&L8GArze3 z^Vf|#e+`}xkTT?+-nH}fXR_@ERy`Z!eB|rxs_{c{1$Qe~Mt55@bInAH3*_#(-EW3& zip+}FOTWMDXnW(hcN#W?Z#W|N@|(nu>&~Bs@g1 z_G-r$leHf4?>_pQU)jq#cE`Ou&(EA)hLsmjY^&Usoo3m)KeI>j_z~L+v`*R($|hS> z`#YLZX=_q9JPP=saf&rLu!VcA2^N3^(=0j+134F4tT)>fUo|NdLgq`*O|v z$=lByD)#v6;jh4>U7r!0Y6>cItG|A_sz31RByM@v4_3JD_l!3u=l5O_l9PSm!R)6` zxUbc(4S5(D9Z1`{h9LF-5iCW6pl-pZmtZdEpp4Z(%+uwX$X2q++l5LY54LSta7SRQRd*Fn=R5mw2GokVJyObSy86mnu;<(lAIWJ ze9DyQBsfc_yrK;BcjdamomrlAMHv_zl@=qrL2-tuK#HQwR^Nz{D5}2o6=kIZ>L9eY zig2OmUsjYgBvBG<4`SkA{~4pqQt%l{gzUral&UMMDEostjG_#%t9v3TOn{pgb3`}M z7-bksbZ^ks(~vttWU1tK@+H-ZUK!MZqJ=j_+sO_2jfC5(!nIi|F-VF6*|si7o`xT0 zH37-f@E<}GkWfn$r1rXo#-gOL5*w&86=iiFDqp8N98(UYhCD+kNY)~d6h`W7Rgj>S zrE?`gf}1(Pk|1F$F(A3Cp2`rNDcXZP>|Y*KvEctw{<~sS4-d7 z1SEGgwkYEdW-GouYKekmr=@#b+~<@8=?QJ+?Av$il>@0E?+yx*jkqT%iiFy#AlYf@ zt?-ou32w&smF!6vOAJV^Dz1l`)sx<*)e$!s6)Fc(w~VynZ8OBmD&aH9Rs@p5nD4C% zlB=rvq$VJ_syq~{d(a2B9cqb!WUp?pSgbu$vL}74kH5Qk? zUIdcD0Pn2|lD)d&iY!Tx;AV7)WKY6aVnA{-T$9_Zo>bf8g7SK4=Ba*vFYdt>Au8-i zb|R1zhPQ8BkerO>mp1{)$!LQ3haK%n_NXNal7p)5m#}7mRJS9s_o~^#*er_m(i3Q7 zBOX5S!hmEi0*Ode5E8moL2^*ls}*+zbQ2S9UL;EPB#b2nBnRC-SDFVB%@&f*3d0jA z2C2uqjwM$-uqqS>0Z2_1WdtUN3L#O*5VMMn1E~UnC*tu`JcK$T(%4MapNcY(b16+x zhWg_eMH%YEV}8akmM9vi!!F?vp^7rB+pS1ZR`_jv?Vj7`He>pT*q0_#>5%@Jz~Rsc z6xhxew3t9;;~}300V0$U@qfrP#LbnWD8snEDEY16B%_M5=6J=0KfZ3t);l{C(_FDv zOr#+Q1;?b)aj=|3XHg)c6OM!fjTx1~V$&p!5Qv*AMNtM$eVe$usG_Vn{%w(iHg(yI z(O8v&*uNz+sT4K`PXHAioS2})QV*U%#6id#P^Umpt3O3J;^s32b35PU5`M4a6br@#1x3>AWbucgb-N(!=HK-adV|8$}o;4 zrnxP5K~Y6nb9~Ms>#i>*bWggB>C9rElSW}t5E5ulhYg{Ra0 zj4vo|t`tQX#zjP%+X~JkswiuY_gLg`eNCT)vDnTm_8yr80u7{c@Nj|_BODw&xj-F> zXOVFbbdSaOQ*tA2t`tQX#@WL(x8-gjswiuYA6V30x6jr=?Xcmx#C{;8L!==vT5%AK zjY1*NX)HVgGFMTcsZ1J^NRswsmZB)bxNT^2TfxDTRFpNxuPbU7YV@%ySty+oJHKMT zj!6cgFOERR5#fnNgTR7Jfx-${XW}^w5>Xn*vK3_*XAINymb-0|in8W-Yensz)mx=n zOvj97vA4#;6X_g0l}ZD*4Z%|oCb%0+4iSM&nN+wfZ8XbPlwn*fwCSzjj7ciW>YCw$ z6*(-U5TZPOVVW-X!T#eFWpk5fE?a!$`}JMxi;6>dXV2u)W**aAur2Y_gwO1)2M6p< zHD2_XIsTk`SlGRe_r`CX((T2;<(IO4rrlY|etBW6*+h>DWYFaKedm@lp5&E%x)^kO z>fPg^{*ptqRb7`CNXpB&#_8h1x&e8adw;T^J7oOeRdpk z#$tz-tJb9KV%qm9-7jUYOf**SXH;)du3|^t&ehk4#-T>?A_&) z`kInqmXAigde?Es&_`RXMsy51e=B&b>F>^_BNoj`Ua{5aRME?RmG9SOXb#EW6QElA z>{02K{-K1A`QMg&cTXrOx#!ql=gPYgI%)Sbv-TV$myL4$VKQ`It@_N&tHHZ>`)6Dax8hy*EcE(ZXztq9)_HirfSdg~*?wcrn|MUlSMzQyHjtW@Vc~z#t zj(ahi#+9lio|*neldnDPgZ_{;144Hb$ArunH<LLH(`?HiBB@yCZv zTfFmkQhCkZzO(^Go&27k-GV>H=UDi1hB%C`->tW9*qgGqW}7y;I940Ad;Mrl#hn8= z@1B?EF3JDowA1iv*PjVP-(Go{n=z?mFfQwIs8LM*`!n=ggRvX`nK+58Gkf#8;c0 zALuNsoU`rpAf7=#@^$2(O0|Ps&h`@tXOGTU({0%5XO7`NZ@cXbi7lg#<(pU)_go&f zbXv`(uhFG02~oGS)@RMY*`I%cTVLUKZ1}9P`R@4>D<=JToaM6Y#M*raJhoodr9SO( z|9+SAf%StsZyJ-E#UxJ0Egk)SiUG~dw#4;Une&LvtN!{jvfn1Nb-^#%A6@^=YuL8Z znjct zxXq9I6{FtYePB=On#EY!$De#6YO71s)G~`q<9$w+duIjIoObZly0tZIwED@R->2l5 z2E9w)zIDgV2c6b!E)AfhT<8#KR$eftcR1mz=^djXrXEkc$MjCPa>La;G&Crut0vuV z@{B(1`tC^czG6Pnr&E_hq%LzmRui1`2Ytyo)j2tBbLn!`v0wYoFYvS)v~_Ix`?)LT znA!5*6dX)7D+uei^-+)&K5EBTtk}&HwVjrQ6qGWJcc|VH>*tdQg0G($S+U zcK=AV&-xsEV}mW9yMOlKiB39ZxK!fAZU$HCE-tBmk2i@dLJrDHB3n@IY*6#k`XN!Vy?cN*jdp`NaH~RPUp?!v};%_n?HubPoU*`pW&&!`} za5DhuC;6g%zEv7 z*YnM%zsfhP88zt8%}*Jt2)PFL;@{}$yE%>jmT@9*)+54$q=}{1?>j!XyXN_7UV-(& zALAbA^fA&ZbNf+x-hwx7X>jaLhWj1dN=wy{DreV$j2n+X1Z>)OV}BRp1FP~J`uC;! z&h8SilF9f?z2K~Y?D{@GgH1bN-_dkd_^wMT22Y&|uP1H4f2sP`G?T)CgREyfOLyQn zJgcZVGMwOYsf+8}$bl8TkL*2jW1255%Pei=(*+Y4DL*>-s`grv>p45+=jk6WeT_T@ zk2kFzVIDq8^VcEmPsLBdo_?Woll3PKTwM$$-HfSA^EpOmZWHeQGx*Rx5X` z&aR89TOL&!`rNp?>fB!MZWj$0IgcKebsA0&Tk^s1#I8erNq=Q@rVPD%CFJHe-BWh` zgFY^$zHvU!PvN}2@L=o(60ygG(}&ak4k|oJDm)X>r^o0)q1}gNzt6J%ZQO~X_WfR` zg$tkbI~5aA>Y%fD*y%p?NK_A(kuzt)kV|!A`33An$Q;>9N^s%)VKEovS|d7@YfQ*A%m!1LmEvLD+9SQYq^X`0gL- zwr$>>jq3*w-FtCciv8|@TqGsG+k$zkFYJ833ODR`=RD4_UYIQq%moq2ZSTbx>VjbLgib;6ok`$`VhR;<|= zWb3W1dnw|(i6!$ydPU!RW4A3^`F4Mob!;osuB>zU5tCQO+^yMsbgSxB-*o{my5jAk z-dub)b-$MG_Y;i%>#uI4oa}&O27AB$d-(K|0|Om)Z@!!Lx%N(JMj(efXu4X(qDg-D zDc|q8p2g832>E^M%l>gB?)Ldbit63}mBF1EiJ`k*TEzBr9AuMob?g)7z_S~8VaVc} z4{P5OHr;zSR1o+@U`G;M6@~{Z8JF%}yI+I+gJ1)6>wb+3MpT-Mn$-&JO;& z$PTZ~*^7(aD`HOsbkJTnp>JSkP1*;q2;)n{jLHoi^C;zeHa_YWWS*EwkDhwscGvJ1 zlkF~6c7K;;##)q_a3k`u%8;IyV}~xJbjUfgHDQb4e8&uxU39DC)pb~9VjnqxA0pB8|!w3{Sx}-MO=M+@`Y6uE?7D_A+n)b}5n_Wsc-gqyl+Zr%

~lT=s7+$=1>iz0~JK>4Y5pM`PZ`5 zn>Pf_PF~ofCVa`hid|cXedgZ}9dRk^*P=TkojVL)>DA@+mB$~V&KC!#=|1RBsz3Z~ z$?Za+qAX3iqD<|-QcmaO2;cWMRin9M76=iCFD$3NQE6Oym%Be<0nQJ?JFP?CRs8Lb2Dnyv| zauNI`_)m~z=KmzmrAA9MWRn>QF@p4$NJ3-`5}Oo`0}UCQ8XpNe6XfP-$Q&Gu^a@1y z{HYNMG=W zhD=LVo=b00o=Zsen1Db6Qm&tpc`k*r1b9j*%6`q&X~+zzAEd!G2<968k>}Dlm}|nb zlS>`fjEiEY`mBc#JFX4(X3hf9O?&Lk21{r3rk&ysta#ngad^7I`d6dr|c`p3| zi`2h7P6!O3i2qcB^R9hD$y588zCLq!E+?O-~3AIE) za6gRW#-NH^MZ@inJeNX{Tty%$OhDhdAW@9EooxaV#jq~B2}r0V3X-$6+E9s1KxVLrEC!_8!$Y1M zDg|K?AUpwuMyD{z41Vi^MAXNdiCT#PiKtgSy$MLDB?^*@g+Yf$%>xOy@Z}Es7ov76 z>Pdme=0@&~7R!W$Ai0Zrk|2Hq3qfG>m<&gbn2^~F0i6JiMI^H!7Zje=x*)k&7@qAc z2@>37ignMZ3xKi2fP~l2`MY@_eY<0~R|l^gNCVbAr03rGfayuX5R2df;He;xVKM16 z0;EFYaL6<`C4|ER5-x$W*RF&avI>aX+DUor>iI~X(OckaiGE#Aq zr$+EG{-(Z(+8j&C&IUK*#QJfxvtcYz9QEXVhdwHjfxZ5$-`5ezmqOV0Z79`TpI*WWzm^*2y8sq&e8u;mz1vJ6aPw!ZrWdFoiqIepb zsOs-(7M|D7>Ue|>WhmD_u@Q@#iKprMQXyj_1xLZtnV1uqi9iE6 zSgSuo0g#-Hl6b<+)$b+o#8_hR^wG^>HVaRPGZ&a+Ar1_vuzxPT6Jm2!J6wQ%Gw~Fs zI2D9Y1g$fRiLen0nN9fF2L)lb3g#-d)a2?4E2Af8wLx4jz zh}W1n&?pkgtv%k2j5Q1vO7>4csaEVvqj+L0F?f0#I*CuCDfxSJ&QQ9O+eR6BobR{y-4G#U*BtPoGtmydRBn_l-X;Q9aBdlR^t*69Dg zXfQNTXf_oM!kGt@3MnaRH0O-fDMf?Kp)#bwkd!GQHzhJu<}zjoA(S#?p2ep77!I6OueJBQZQJ)XkUwE1m@#+o&b(KO`h1wF#BpFXfZF z4w#$c{u?~ern1mOmp~%ZVKMUIfCR0C`8)=RPh(MF7s`iSeY5klXrn6i{M5Eke-I+Y zQ#?}4vqKy8>jVE4&)Qn`;AKNn738Ts=I*gFo87;`Q`D&z4oOTBfl8sm_MJ!Pu{mt8 zmE=OKm;yE+9Iz3EQWVVkL?rnDtpK4AAJ7qV_7R6cFj~7f^bsJwjWHv|n8aa@fX0u- z0PBFh!s?sh#n@m4{&Vo?`EQ*I!k;V1MM*dgJ?EgI3L=5RW|YX^Is{_jh$OOv z?ldYNDgnSU1$HxJ@cbrHxzM*Anzhs5SWRZi5^(}1A#nV#aKsS>@!q7{pT`~oLIKWH z;i8cYHZ%E7U}hX*__#hfe5Z z@U(5bT|?olMc}N%z`;Zi#BiD+))0^jaCQjcNCp?0hJ%SN{0^L7p$IUQ2XzzPqz`_wn4lr>3JjM`k7T}x{K65043QfboL=}Dq z4jO_0QJq0$(IA3Mr^E8$GbuC%Kw(p@ZOR z@gi%6c!Ge7u;U@@nhYbLCpf{LR>+-;X_A5=hw9yeLeVH$VFNbk2UL(2I1~!Bnt-Zr zB9R8AhEzU>Kqf-qKo*$k4Pc@qz=)N8446NU2?!P=V3wd@6bJ=0fWfp=Q49>K|ATiw z1O5T&4{9wLWEO`?VRCtNK8s2vAno;}SIn9QFlmj!U?h7nm}U_71=$Fg?I;)pH2q&- zu#Hy~1A_|rAUV-#BzSjGpn{%C2SJZW;WOa=f(Aou7GIWdBG}&m=70p4zogO&&LUvW zp(dx8f39Q;ToEvCC>R9<>|bC?%vNkQ9%)mrI2cs5hSdV?pumiW#bQGG z61dX9se;AiLR%~%SOd#^T~BNPGpR8cjL<9w(+rZdAOryuih@x;DCGsz17GZu+V_aEi!0=1R5K>wc$bBXt#^P zXmR4P;1aTm%8^x6Vp3P`UKl149I&@c1v2WIXp{>U#xJoLf;4Ov_*zlnrePAnFo(}& z(HR8jr%2_{XpQ$PD2x^-77Jb>m-!ljQC!EF-uZTFwdBfphm3kJ4FeU2Nx~~;Qu%NY zfAS;#$iNp|t6ZRbi>6Zp^V_DCP7AO7+T#(BgiNfgm zX&u!$>6#=AcVyIPX&9*33(_2o%B4bQ0xEQ;WRqESA{(p=CJhz|4|eVla08zr5~=Y?xMhpG zf~CmC#{WX{BI=(X^QTH8NkT>?OCv$WToAW~Zx1T;J%rwt&&4BnL^!o(oHV`DabYVb}1h*24 z039(gauTvuTDEvAC`B&$3<{>JDtYUs5ss2z%8*ePq`{zK?f;nD)q7PXr36)O#LTI@ zlvODSts*HYOP4H8@kp_7X`P@vD8QK^FNomW?_=CyQtUEjB zozaKBNu9A}j9QQdbj<9CO%ZXx_$@ znRCC4sCThbzG`%D;$!o5%YGV{)%IMLHow4lrD<%4-Hx7nhfWJRefiK0yPHRB_Fwok zBD@Q!Cet~ves7-h(!7^Fm&RoN+T))G z&)BlHM;TK)uX?d;n)wU zcZLmj+Fg{{E4TB{DMyQZdrwQ`-K{KMtG>W*eZ1?ZR-py&wQjqQcpE)WpfOv}@!XYa zeeaV$X9av(Gb_LT>qxUlzmHnRsj4 z{Gas>&t7)Z-S0BaF=nWF8!bokm{I+0+7bHv!=2k>W_PXo&ZD`ria2T3_l}P8HQBA? z;x*o2Ku4?Db|&TT)B1+FZg#V)%~aoip@NyYjJd3b{gzyZ*H-`JM0LKsiL)Z@nSG~q zwyV24y|5aVSz4|WzC76}*6s8;asObbo&Q3=>d3Y3=Vsp_SGC#sPxZm% z*=IdE79Gf5FtKLz)`Vm?THEU*DM=d)E$WlaSJ#idxAn@f?i!OiR`>5{`y+Lz(emqY zYiQeMJNNjx@<3kMrF(;U9zAytB8{<~n!oZ>M%v4zrGvL#kGh#Lk2{{b`t;KID=YHU zzHFNAGWxyIxlxYha||Qn=;NO4c0JVE>gmRaaeTw~{Y01O?X*2fRtbx}YFQsaM@F?)cQ*6~RdXM7{}AJd>xDS=%IvMf<8lnr&L-@& z?xr{GdaE$|b0d9L%>A_D`7*6lD(M!pSC%X%SZp}8UghbfE~8v_4frs~A|$?RaOB`BMPdr)|{HKE7^@}*zHWm^n^*8&32YGOwsIIwCjByz>C%f17SsGed z_c~Afc&+T?%-pEd+*TRa^D9>MiJ&Zck!CwgM+G`abt%E|6b zK`W=popN?}eLu{tXq85;Px-~UdpmXLRc7kBap0JefMXh$ymAf&yeXadCdcB?pE7; z)Sa)?ZXe&u{iRtVujrYEK~&nE1Nv669knF3X5qu{qYSr{m!+O&9DAgjbJnF!Ter5`^PAbkh*j~Udpt0k)cx}7 zp|-rjikiBfU)SxLzO`Sg1fQF|->d9Z9`kO)%hOjJR-SSWZa?F)_eFNc4kvo0Elgvr zc~xt7W7+H{XN^dd*AyMiN4?GHF30Ch^517Z=TPpr=_RiQ4E|QNFZJ-A5!AA+S=`49 z{8t^mnoWBbzJ2MI4`=VX)=u8LGS{Sbzy|LvJ)e#2ZxX_Zj^CG&n0w^Z!XH<8J-rUk ze8f0vTx-28X-P~hd49l{K>^G~fhb`x$w|ZOt z>;48^tjznpztMksrd5euH6znWDLHy*aj4Cxn0Y16y;2vArgyY?e>kh_N%b9HHT)6^ z6HZaj9NV+_@_hq`f*D;Z7_Zk-iKfoCYPzeB(s#D_xWKjV%CSpt?mz55ikej)(#hoJ zq_|!qwa$E-03Z*Kz$;rZ3t)`gVKkodc&X=eLZQ(>$2ywTTE*6D0+ z*`324GTPH$eZ24ab-7*Jk28y`d%B^n-1(O82i-fagG(=MHoV+(;`H!CGm9<K_*sC#k29v8ZP^12#baNzK?Sxol+vxbK{M4xuN@_?Py zFMP)uIrO&)+z53i>jTovNlxv`5WJy_5EDo%!6qa zTf)07-Op5>H1tcvf!VL0+TA)>qf&Ono|aV{_u`u72aeX+?u)KXO3^o2F>myqDaU<{ z$q6e93{LFoT{Xh{&HTbevxZNdHsAd9fPj7zHjP}bzhP(e$D|$3hjKS+tG+9}SutGs z`0}=Xk#s-(UXJlc=tXaaX=yB4w7cAP{S~c@HzDV|QvD}i7?B)#c-2zptRc_Y^ld6S zj}wg2#?6_UQQmrSH_q96{Ap$tV{bg{M&G$={qn%p@$-uK{X+(yo1K&VAcQ|5f9-VL zC9kbaK5ra)VR9m~@X0D$E7KFV7OGxqKYjIxeAent)%?=lJL~(F`8v!o7_mFrc19Fo zzxw_wJ-#d?3^`*q^hy|EbLCstdc6;e@^=RGw0QY$g(bsnNCLOlf+sz%pPigpS5}~T z)0p*;{37^XeNachMTjN~n`*r~bTdQ%e&a)YDByrzm)pQ?PsQ0V7H3O@hcJ<}WUvYmbH$KGFC(*fP zd8Kcw9$p!+4K4NnW(CYbrV_obhi|Kpt!tAE+I`U|b}?M+NhpB&t2eEF-4F+4mxp#x-`vLb^c{LsNO8~u4?WI~59m;wH9rr#9S z)Tx0X{=&K1H2epY^@Rop^EeTKp&{WZ3GLD8xPhDq;SRM8J2EghB2f58t;2c3kK%;8 zRR!uNMV>;qDG%UI#K$WdM0}3$as*%|+;Tb5Ot?XHm7d(7tMql8GxB^v#XJr@MLwE2{#p4G!w2S4Kv~9?TTi?6?BE)6JC<5 zXeL}W6lTKB2o=qQD<#5AxapIknQ(PDm)9^$yLAgfLdXLD_P0(3Q zSR60VgWm*(!e30_7xXc&SesnjVtV+N6>F1=`!ax;6l;@<`!W!HPl~k3#eEqFzbD1o z%{J05joQ zn2Bb>eHjS9C&;{k%ws;60#OYbiwA{)@bjZ{`4IJ?lAv-`G!yR2K=?gz=}?c&CqgzQ zk-&j0c^)02@iZuecdCa496a)Nl=6c?q0 zAu=p2Sz#Yp zuF|9aA884#ozO-4dk^O!Kg4!otI8A?71;)!sxl2cEwq$R28oVz4F;a6?ji%vW$>Hu zAE8-J69%5R`-9(Z0VZiKwX57Yy(di&?+ApAk)AT zGl2{PPg8mub`RwLZgvCD5+zOb5vI&z!eAKdw+%cozCiMuKwAp`4h432*zlsN97`oB z1J6c3I!vmmDZTAnDX$v%aCm?SD5j_wGg4esY^JBR{s=O%aZ#~61JB~0!yUVhnm$>? z6W0e0WBtQIERsasAQjlWHY+4EeWh7Z$W@Ay2;oDATuDfnkr+r6^$OHRPC-cZtWjUU z=2S6A7(X8rQlDQd31?s4ypSkuwWVAU%qVRxr%OQ+j}$}d)me24>KLaWq$m0x7O0us zQV0^(qz8p$jX=UZax^QXUY*sl7?LXqJ`7BjgoGK1fkbT2|5x6W#jD4^C3us>Rz{fQ zG3m;6jPni(sV@Qv>~t6$*x54(OvnO(5+fds4(XSWTgZmpFsX4C*&jcXdJ#2F#YsUT zYAg~+K@yJ?L+Y)goOd1>so+ZbR;8yI>#D5KN*eq^*yWcf$^}6o^+O{LrVDf{^O-{6nvFL+vKP`KHO+?Rn|UZ3XSM5uzeU z1OyU`2!#RQ3&dhVGBacuK~6J?EtD`wWFkY-Rq>A@nRifqA@xjx53^oMLc)y1K1gApJ z(8g)Le-5c{d!@sYE6KcXyF-X`JZ7I{E*>d{1SbUh803Ff3{r^_!#XK`+HJ)ky{w96 z7w(*kSxF?}*UfFB%t$dF@eW>|H%+5`xTH6Y^qUN`;l|#A zO?j}09{;F84Z0P_l~JYl1_3f^3@|%q@%FJ9_Sg*uHBy^uY-zTDFDetnv=hMWsxfc` zLCCnK3^mc+H?AlwgNuxE1I%Vw%%vH&)nacN$$fOwTreByf5`>iaN|ncGPua7JTMdQ zk1<^Sd?ziorIFk^H^~LNcm9`L&^_H~mi*!dM3h%d^Sy zU@4w*j4Bo2`Kwa+|Flu16i+!ul?w1g3ycX&KAXS-#ezsC2n&j#ZW26-pa_aWW`oev z*uL0cT^Gi>AdL2BhK0 zJ`5-f2R#wc$b13?5+E21E(t^(E)!BY85ANFGE*St$B+r9#hnPG;K(@-M@Fsk-_p_>j4F#I#dGYLNg57DJ;$7x6o!LJ?U2Yr03R8l zVG8)ZfwLfyNhdH_6h4Q{;L?b)r1k(AwaXY~TznpC1`WNzsIpj6G;bOXMm7H(IH<%< zgv2Ha5zNoI&@P4yK072H1rp476b4vX5hZGw(YKWhqsrc5MH~a?&$V+2qsn4Q$-HSq z7|r~55TOD)hr@)AV`_<2M`pqyKBYu#h8O14KCnnDQb1z5(VRt;E0Oktug2HffgG z7@|D0O!*Lh-z@XbApRvkO}P-8G}dekQJ%4;e26N-Q?;<}=3g6Y{t?8#<-I8vVv{DE z5)kE>aLR|Mia^9IPLqBn3LZ6>a6bMYGU5Cr3MUFa3x5N%|0Yd1CBVor;gk;s)&GCT zgi~Gk>Jgc6wh?~yKVZVCFZ_G@3JgT=yy3+PX!^fi0P;*Y<%9VH6Hbc;FqRTv{*p>B zU?N}!pkNdb=YN5bXTm8T%paI=IyZnB(HP9%lG_F25HRCWFbb&b4PfM$aLNbsOIpV( z?3D405j21ak^u9!OgLvFVCJA;nl#~*c*e*v;gk;s)yRo73KtYB26Pt$^?}Ra&{X>L$BdbMJ2#1C#d?E{6IN3Z71FUOlB~xIkE-g+77u-ZHvl@XRZ?}908Fg101}c2Bz*nBeA&^Kc0tI@V!GAEIBopFGG$@** zaAjo-2edfhTkr+B%&!OxdAr^p$f%#vFn;B~P#GXfLQibysKelMIM7uOIzMv3l#>nN z5|T_Ux5WwF0)52)%>Z?Yle2p@K}L0zc8NoUZLYAXAcY9F=p+i6!esNPMCfx1?JyxU z0$rG7u8bBZYzyefWik*L#Yex?CivKhwxJjiNiYZ*Wg`vaSI#(tO(qCm9k4>^G$NbF z6_y zsH>$w&nXf#=%TavVDAi_x}miIi9saEgwaBUYe6J(vC~l`$Blj-4%g*NBAJDZnk|h4 z6{_hJ9uu0G@;Ok(PKVTQaB8KKXk3WKvq+HaEbGb9vW03v3UbM-Q7{*xiyX%f%#Z|= zj*MC>4F(map@JB~gnUR&pfjMoCYJ(XF)|s7`@yS$Ok>K59ky&?TCf+nG(K=1sG zf6NW<8tOfbj~?_)@6u(Jq4a!n-%ndx__jXq z=X;+XW|Q-F`d90!ok$OuGiukPWMw07*3hnr9p{)@v|XO-yl~CQ+|(%_6CbS{HKxOc zD~sOr&waMQFSgRxGik@sEpIomJKNTMbm+Vq1_lR^i8ZrK``@W&3 zoHnkwmL|CAuwmBL)^|s1=@p)=ch>sM7}2I%=El_t1%Qc9)tTepxdV@Zjd$`-}eZF6_Uu zXTdhTsbu?>cU+4ToLTU&l6;pDFg3KI&3@MI)ww0rwjTcZSNfg#q!VsG zxp>I^Y_sb0(2Vkl-;6ZJ*sP6DNB?AXpBqNwi-?c5#raQqc7T;;eSTWOoVuBp zuW?nz=|=2Qjxg+zoV~{J=AMkW!nXO|RLf}=hxH?lTEA%3rWHNs(}YB$`3IQ6^=*f2 zXgfZ~ps&`eBbO{+Z0gYcN8QEku1B>tr|(^PuD_Lo7FSnq<89sKnf{$~Y$$W5D#cuV z>v;Mu&24wKhgr1)_2~)i1v!+5w_m+^kVJpHcEXlPW0Z9r@0>X`d1k4Zsp-4tN}3bD zx$HB@UA*1k?!ED=EpAMT+q$~I>|y6lvkTLn*15H=c^27a(wtPC%r^SojL_5|_M1tX zo|{Iv`&{6zGBG@)S@!t!8secG%jtpTR~%l(Cl_3f9dY=yAkTePMqonMu}RjnKFSWk zsTNO8&>}YjVHL%Yr=zkIX2@iRkN=a>=dx;c>n?k&izAFt&30 zIwr_i?N0ChLnF-vWyeM`ejc$pZLVqicKqp2tM?U;UDEkPWopSydsSBx?`rm*s;uSx z+x7}MzV(Nzi<8~@;!Ova&55bXs@gbVe~NDT-aZS`JI&vA)NII(GnG*#G|gP zqQYf^cxuHRionMdn9`A0&-dA~ zeddg3?g3wn%@(*HB-S4azx4XV#f390tLEhDZW){2&dB(4*8>LrFRd(QPBt<#oUKx< z=8~Ql>KJ5s-}r;F=I5k>p;yO0rsnc1)wf4Q@3)|wuKn^pS$p(G!qxknvZT)!l1>&T zjZllebUSs~Mjp{>+RMIiD`wc*S&%dci~5o5M=C|P3fOGp&t;}8Vcs=zsT>q~W@vcy zIu#p}LDQVtX_rozbJqS<=oX{rsa`QV4jL-+mu*me_~BH^)zcA%YhANV9&O+5Wqo&k zKO5x%kvaACqlbqdcgP;sAx3?Pc9h=0tU|u;_!w)a&FRBmAIs0QSeZEJMA3bZtCe$i zPcG_6zd5;g&BUF_BebXZj%ec@RH9qmefOm6Y9-I73_jxGR?=Z{YR&oe2iMp9>{L2b z!zFxg;@YuqR|Y@wVNf668$|IP`mMs|(%u8_dJ6QqPB><{_H$vc+iTg!s!h%NjHuG6 z-*u7t%#bx}p5D?G+afD#!eT=_`|7EG9OvP->f8D=LxPVjObNU3_JiJqi<%sZ*t(r3 zd%f${HS@?Y-{ZHsZY^@qN%GgwDWtwhRrS6-GRKz`J8S!geH+@mEm%KvKlh%B-~#Dq zu5!-cDG|H{`nxZiT;AAWbZOiJR!C^_)ZjIlt^hz5ybi)Ue7XnKiW74Hd%cG-i*z7inaK857dYd`rN=hA~HAfF2 ztZm>A_fhswIHg+aVhij-ImI{pDpizD?YE!T92^&w%FY6Vx(9zP=={up%gT-Nmn(b`8Yx0TDECLe78(w`DIFePY zbn3vs&cykhXKv`0bY!ei_RJG5Bb+BhZtc+7B|T%*q^`HiU(ePvK6u@4-4~r+B|mN? zrtJ3(-(FOg^i6pX>Efb5wKj7+?XPSK-?()4^5{`{6}PjV2+X6#&zscl?c^h_Zw6}x zn7@2A<@?@!Z;sFw1_yl3eeK}tzTy4clhuj$J+~2`4$b>)*Eyr@mOURlZQ`OkeST%q zFW}=li-}tg4mv$EFeJUEcGfE6?UT>#eUaurwR+8`yd9s^UsdY+KRvZne{H3H*@Sf? z{m%qMySTijf7+0LaK8&}^`^<;%ISL#-3Y3mcC}a2<}3wB<+~Gl@1y6hGld& zil3)@>T%t&%1yQZRCVwy^(e~zShV1t*}KFisVhFu3%A_du~t7ei_mYl&d;=syCZcI zd}Cf2U%z=Qu4AnEapD@oyKQsdsv5lxcvkZT-kskqL?)cRvQ0RZ|1;qvKz@1YHp)oCIFcOt@;Oa3;KL zSJ6!14A8Wc2s7cPhKgpw)!~Gk@G>YxGvUftFcWTyp=c&ty{EZMIB{TjdEu5qD{#{auS_9BP~4kV0Rw6H%l`8WRlq=6_(N$n z&rn=nA<po%ngva$x}J_l+yh zY+By;uBH((iP}_F7%wCu#Vw(IAlEN+Ixj3NkjqPPAq)Q)9-Ls2%!FbZ=2m#V3BJT1%(9e-)trq8pLyG zkQhWEaUpq%51!i`9(3U3H!~#5c5NO?c{0L>*%p$JFe5RLdTT~z{u@Xotatjr z1H{6GWdBo5<$JLxy&iSaqau*FkP}1WKr~lqn+NrvdxjyCN`t7eVD~X0cB0~ih zhYX!t`9vBCtnGPJI^>?gz69E(GsqI@W`7LHN?qxbHgc8X+YLS}dM^nHGZF*IwB7V) z{{~X=k1zhg%Z8*P*Nx|-7qDXdt?2=SFp%g7BuJv+aB0vMoJoahS~7_T2Q)H=OoQrC z3M{c^zTK=$HHDZy=SZ77e+z=A(kQ+c=l%OU&8-f)5Y0M?Mro z!i>Z~>aH?j?!SXnS7J!J_gt}+G-k}byg->A;c?>qVK#-so zTbOu4fffL8pvNXe|2d@Ysy9NVAaz$=@mdO!c%>wQ}onkADS8tzBl3-k5s|LQ<)U zNr_m4fg}PZ6e7SSgG*&YtuYOH3a}v|mP==`p?@#`*C6k{NX9^%byT-glXX)UU{R@zrVw#QAAz6vs`ZU2?wWph9!s&6MDt4q8% z|JSN*1)_>Ie51@zWL#55nP}lHZWIsu(r5;rEvxWWil=N(9f1No(HdM_Rk^u&wx}9h zNuF|ScogA@mgD{_W6ow+pDn8#SCXe3=N?6PqLsP-w+%g|c*-&KRN#I`3wm**`^|j6 zw=6?XDV}l+Jr&^jpP^?!AAuB4IfkAJ@chrvQx;D-hMo%W{8e@Qf7;MfYJJKv^i+T+ zTC7YZLvsx#lo~=S5*nEYMqoS!fegM)EJ$7e%SIXJaaqO6;wT12Ar*%{Fz)`Ck>WF? zILsl>-M>+=Lekw|`o6_PC9nl%a=2jYmeEiJg+PFopisCBzIsFglSTl+6be(x&>s{^ ztEfy_Tx8TOV2t*kIn2Mm`c4}n-XXo!ML;?wESofMjl@Tn{U zpTlQ>L;>|?WC$=*>5Xg75Pc#*MjbOIq9BIT46%m>cYmp1L(_0D(T3lFgN7y;R3e7} zN)4GxAal7KVVNcc%H5&Vo6Drp$Xr<~M@C)q--sY-s4i;}ZB#yvYQ|T!?K!Ado2&j*L2IOawu^a+)F5Aolr}3@$Vc2NPZR9XMzx zg3JLeolb#hIfo6HKy3&2Q>X{0k~l;|4F?ld z_#HTC2m2wBvqV)Q zn9wvFOf=zl;Gm%g28lx@64@*!RL&D9PzcTDgTFEfoJUAPqYat2m5jRRKH^9L2F{Bh%_YB1 z8WSxLgKCDjf`p-`R5+n&P?&ha?|?#s4@4#xVh=Qkt%F$$iwmiBki0{uf!#Ze21cK< zZZ!FI(|;?rAZh5SCJ|C-8Wbj`@H?Q;07UZ`dI~&+TCvDp%uD!{*k0^!#S;Wvkvc8l z34as%uFXBMY{xXshHwm)9lEm?>_PoayWklD<~a(cNz+RSFmh~zD$hyPBQM;42`e^?GY;|4dj8^ zwCSY;7&)ew^1+}Q`R|xs8VSEY(Obz__?7sr^tS|Xfi(iAFAAnf)62$S*hql+Tc($;2pBgMOp~UUjlsw>y_65;4@@s7Hh`Jb7|dTYy$nIXgrZ=YG`*An zBggboKA2zf^KY15&T9Y@Edl0lnO-I%V3wm`nl!zX03*lrQa%_|M`o~CY&wU_pz$aq zKAFO!a#%E|X@ibPG&qwo8oMo_+J%^K`+e!{US0|Fn+Wuy^@7=}qp9)0>1F zK+HJ-jrRW9?PB#&R6HiKKm-LpLug&iWl{hKd_@B*I}W5K3r|+EB)}FY9t$oZtEe1Z zMc=<{%xHJgOA){X$f)$A9LYcf1`|UH4HR@Lp9{*B@TP&j zVN5=cF7q~NapJH*6R{Z6LX5~t44(!JACQb_TI@+85OA)_uxgF(gGgtjf*^zub&)t&gm58duWtyf(;`9_$hSL#u>an?PO zl)FE_oIkM7l{t%ZN7b!bq{^u>j$PDlw0E0*#FN#NPsjW17@m={`9b^#lasXujBy6J z)VlQBwjExd>NGoLa(eeA$cClg_+49&;i7>Ury`3OkP(+YDx3J{YWJTc7Z4@PQ6p z=;cw>>io*vE7!I@n6rn??Xx}n;AEZsAJYcf9J;rrT_>ko>lckVl9&AKE7>m1Qu}N7 z`Lz{;d{aHf2J&>;zn*_$_GP`Dbyds}9`PZ**7iLkGCUt#@A}JEV>mK2x-sTFR|G^>l3&?@fYgn4L0kprJ{FUUllCXWKkOgN~#-|NMNd%QVu90bMf#0#tv#*_J%R z*2b4B6luM1FeK zC-k;S?ibTrcWb*iXj*-z`5$`qE`O?Jz{Rzy!z23dbHCMP_QhvaPM14`m-;#^oXl5h z$NFdT!>6oU*^|#4-K=u;8j6_v5u&Og}ndgG+SCr4E6>U3x9 z*z_LyCG(9su3LN~#p$D)SLO)8`PjXQALjW?82mMd^5fm@4QcEc?@o@_ypxVs1v$BfUA(i&a_~`Z z3Sn%=b8lbVE7ckAx2%KKE1k_Xj%{Wxc^(tR(_Gb8MQ~|y?&tj@m|pk}$b z-16(3D4*b$3lAJq;SLUbbjd5EkUV%(;%PsP&*jJCpIpj|8Poc~n8`CX&E9M?l`^qt z?Yr>z<8lrLqEikLYe zLrxA`^DNDexWmafM6jK$)2+)6zhll-uAac;r)&+Tzl6#}jf%OKF`M zS9&~E&FZyiY%3>sE@SMqw_}x39=Ki#xTa98!|LF0-Fb%wx^z4A*s;^Q(d|d1 zeS5d2`veOmhpyKh9wz&)pqI=&PVe7ptIcrl`C-LmZo7^iD>P>Y?&^7A{F}9Yb4~5C z;!09`moH7!YgI#FbUZ#PLoZ?dFz>`gHE)(rs2leE?c1>HtkXeNY*M9d*+$kA?T|bB zH0_*tCr0aiH+eeVu6^+HeNWH7J;L_U8aCDBZTQ}tJx-55@46ivK7Hc*2XC`Kt1sR# z;-l-WpFVXQjcZSh=cQ*@bu>#|?|nIarT>F*vz|mf5)`$i=tlLpK|AJU<+(bq^HHK- zN!m{<3cK@$#W{nw49L^7v3*18_g-E1ZkOkou9g2Qn8 zMUg@KoQ67ge&aHtt>N*zDc(Y=fhuCrlj^!<1%9Q_Bws_R_gNB-*X%f<<)#ipA`7! zo4dP*PsZpi^Ya|Svpm{<**E6Nn}=$3=V#jaf9qPH>R({Hz%Rfp_K@vxgNI(PADw(e z|7Kb6*2GuGbkNHPtGcioDM?-QtIr-@cYfQ^QS^cPJ`~OR7;DHpFf`rc;@)G?S{^Rf z^z5{Q&Y9e|`4CG}uRX)H>|n=<9RFonpF5T3wsi5jReQTn-()4qmj%<)>sOdAOZ%c9 zZ&GuL!TjoVdrQ~blfMkKy?d*uXD6!zJ0D&=S>kvrWQfU?_x(TIKE7?&jZ=Q*IcuZ$ z>(s^#{JL#S?nR+8J@#*}W^u8X;17dQAp7Bdc^9a2k?`&y4uWjz6 z2}9OijvjDiq21vb9tHPSt?O9w(rm`3P3mUDCTb^i2-Mtcb8+s|*0IzidpI=e!~7^;h1ece}jVY^>*T;zNgtQI`v6PVluNZQfpY`*7*! zx{2>fV=TLyGhVBN*Unw$RCKaW^|!+7A(s1oY)Ps+)=$^KDfxYMWXAKyQI$PXm35rX z2i!|LbFXVsNKV$=O>b8AnqL3WdXvHShmUs~8u!Gjss#VkH#w7T9X>tLq;>dzHseP%~^uRKQFWu9QW{pqTc zqm+g{C=7ejhQ7b_yql31qy7EW$s^P5R5@pw_@9sS?vyoTLi~V#`W35Gs+vW_qnnCV9>u9i6JP#roX*i9^Gbs+ZW^{zROS}Cz4$rMq_5=y^(j|QzsoYb`{mXW>%q&N_T>qzP``B2@l&QKi&Tpo?eWU!b-3WLig^5A^S zB7u<#Vl2gBF$hFDk;mgu*d!>SCPMBlheF{{KwzbD*g|tT+_D(aOt^tTn2AVZa6y&g zv7oOnoy`U(04^vnENCQ!=5*cQ!1;j+nWuF%v4v~NBf!lCyzi%e%gZ)vDxCW;L2 z;9yFJRC5-5012`bDj2sxFDN<*`cM<-JSH8Pi35Er=wuRy%VIIX5(WfwVGnO67moTw zD&(gNua1dV1R(mJn0T=omM#>-Dza!M zTuoXy6JFk~XeL}iS9o=JNv@)q$arcf%!Hc}Dw+vbN)%olUizeHCR`m3X2Q)f6wSoP zldoVV+%!DVOt^YabD3T?pQ9yi1vy-NGuc%NG4h(li&FSKDP%N>TQ3VU{pTY2KkXuk zYeFZwnQ&hQ@Uz3UFhg!jyy_w0On5$DqM2}C2EvtzXND!33HN0HGvPW$Av58<3}7Z) zTP5V`@LmQm6R!7=XeQj30nCJJ(1TnZ-pc@H!ga9`&4l|hfSGWuUPLqDz6}2J?F X@Net5l;w+3!Xw#{0mW{S?0)|r1 float: + """Tests the model, returning the average test loss.""" + model.eval() + test_loss = 0.0 + correct = 0 + + num_batches = len(test_loader.dataset) # type: ignore + + with torch.no_grad(): + for data, target in test_loader: + data, target = data.to(device), target.to(device) + output = model(data) + # sum up batch loss + test_loss += F.nll_loss(output, target, reduction="sum").item() + # get the index of the max log-probability + pred = output.argmax(dim=1, keepdim=True) + correct += pred.eq(target.view_as(pred)).sum().item() + + test_loss /= num_batches + + print( + f"\nTest set: Average loss: {test_loss:.4f}, Accuracy: {correct}/{num_batches} " + f"({100.0 * correct / num_batches:.0f}%)\n" + ) + return test_loss + + +def main(**kwargs): + """Main loop. Trains and then tests a model after each epoch.""" + # Training settings + + # note: could also use simple-parsing to parse the Config from the command-line: + # import simple_parsing + # from simple_parsing import parse + # config = parse(Config) + args = Args(**kwargs) + print(f"Args: {args}") + + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + + device = args.device + + train_kwargs = {"batch_size": args.batch_size} + test_kwargs = {"batch_size": args.test_batch_size} + if device.type == "cuda": + # Note: When using Orion with parallel workers (which is the case by default?), + # `num_workers` should be set to 0, because otherwise we get an error about daemonic + # processes having children, etc. + cuda_kwargs = {"num_workers": 0, "pin_memory": True, "shuffle": True} + train_kwargs.update(cuda_kwargs) + test_kwargs.update(cuda_kwargs) + + transform = transforms.Compose( + [ + transforms.ToTensor(), + transforms.Normalize( + normalization_means[args.dataset], normalization_stds[args.dataset] + ), + ] + ) + + data_dir = args.data_dir + dataset_class = getattr(datasets, args.dataset) + train_dataset = dataset_class( + str(data_dir), train=True, download=True, transform=transform + ) + test_dataset = dataset_class(str(data_dir), train=False, transform=transform) + train_loader = DataLoader(train_dataset, **train_kwargs) + test_loader = DataLoader(test_dataset, **test_kwargs) + + model = Net(n_classes=dataset_num_classes[args.dataset]).to(device) + optimizer = optim.Adam(model.parameters(), lr=args.lr) + + scheduler = StepLR(optimizer, step_size=1, gamma=args.gamma) + test_loss = None + for epoch in range(1, args.epochs + 1): + train_epoch(args, model, device, train_loader, optimizer, epoch) + test_loss = test_epoch(model, device, test_loader) + scheduler.step() + + if args.save_model: + run_working_dir = Path(os.environ.get("ORION_WORKING_DIR", ".")) + # use the trial working dir to save the model. + torch.save(model.state_dict(), str(run_working_dir / "model.pt")) + return [dict(name="loss", type="objective", value=test_loss)] + + +# %% +# Controls for this example: +# +previous_experiment_n_runs = 10 +previous_experiment_settings = { + "dataset": "CIFAR100", + "epochs": 3, +} + +current_experiment_n_runs = 10 +current_experiment_settings = { + "dataset": "CIFAR10", + "epochs": 3, +} + +# We're using multiple seeds for a more robust comparison of with/without warm-starting. +n_seeds = 3 + +# The number of initial random suggestions that the optimization algorithm should do. +n_initial_random_suggestions = 5 + +# NOTE: This gets run in the tutorials directory +# NOTE: This needs to be a relative path, otherwise the CI runs will fail. +# Specify the database where the previous experiments are stored. We use a local PickleDB here. +previous_experiment_storage = { + "type": "legacy", + "database": { + "type": "pickleddb", + "host": "previous_db.pkl", + }, +} +current_experiment_storage = { + "type": "legacy", + "database": { + "type": "pickleddb", + "host": "current_db.pkl", + }, +} + + +previous_experiment = build_experiment( + name="previous_experiment", + space={"lr": "loguniform(1e-5, 1.0)"}, + storage=previous_experiment_storage, + algorithms={"random": {"seed": 1}}, + max_trials=previous_experiment_n_runs, + executor=SingleExecutor(), +) + +# %% +# Populate the initial experiment with some trials: + +previous_experiment.workon(main, **previous_experiment_settings) + +# %% +# Run a new experiment, without warm-starting (a.k.a. "cold-start"): + + +cold_experiments = [ + build_experiment( + name=f"cold_experiment_{seed}", + space={"lr": "loguniform(1e-5, 1.0)"}, + storage=current_experiment_storage, + executor=SingleExecutor(), + algorithms={ + "tpe": {"seed": seed, "n_initial_points": n_initial_random_suggestions} + }, + # algorithms={"robo_gp": {"seed": seed, "n_initial_points": n_initial_points}}, + max_trials=current_experiment_n_runs, + ) + for seed in range(n_seeds) +] +for exp in cold_experiments: + exp.workon(main, **previous_experiment_settings) + +#%% +# New experiment with warm-starting: + +assert previous_experiment.storage +assert previous_experiment.max_trials +warm_experiments = [ + build_experiment( + name=f"warm_experiment_{seed}", + space={"lr": "loguniform(1e-5, 1.0)"}, + storage=current_experiment_storage, + executor=SingleExecutor(), + max_trials=current_experiment_n_runs, + # NOTE: This n_initial_points is changed slightly, since it also counts the trials from + # the previous experiment. This is just so the comparison is a bit fairer. + # Both algorithms do random search for the first few trials of the current task and then + # optimize. + algorithms={ + "tpe": { + "seed": seed, + "n_initial_points": previous_experiment_n_runs + + n_initial_random_suggestions, + } + }, + # Pass the knowledge base to `build_experiment`, either with a configuration dictionary: + knowledge_base={"KnowledgeBase": {"storage": previous_experiment_storage}}, + # Or by instianting a KnowledgeBase and passing it directly: + # knowledge_base=KnowledgeBase(storage=previous_experiment.storage), + ) + for seed in range(n_seeds) +] + +for exp in warm_experiments: + exp.workon(main, **current_experiment_settings) + +# %% +# +# Compare the results: +# +# Here we use the :func:`orion.plotting.base.regrets` function to plot the results. +# This shows the performance of each variant. +# in blue, we have the single line which shows the results we had on the previous experiment. +# In yellow, we have the results of the new experiment (same model on the new dataset). +# In red, we have the warm-started experiment, where we give the algorithm access to the old data. +# The previous data gets annotated with a different task id when viewed by the algorithm, making it +# possible for the algorithm to learn what is common and what is different about each task. + +from orion.plotting.base import regrets + +fig = regrets( + { + "previous experiment": [previous_experiment], + "without warm-start": cold_experiments, + "with warm-start": warm_experiments, + }, +) +fig.show() +fig.write_image("../../docs/src/_static/warm_start_thumbnail.png") +fig + +# sphinx_gallery_thumbnail_path = '_static/warm_start_thumbnail.png' diff --git a/examples/tutorials/previous_db.pkl b/examples/tutorials/previous_db.pkl new file mode 100644 index 0000000000000000000000000000000000000000..2b69b030ad8478b402850481753481049389e41f GIT binary patch literal 14841 zcmc&*2Uruy`zEx|i;AM4*hQtJhaiFnDryc?9o>^JYsEYx55oTa09 z$p#8VVZvw^o=^n4gb7{vkObmFV%Pa+&Ek)vvMM@)h#oN%BOwqf8$W ziY0P?nSL}B1*4xaLJ|7gNGgU!?dObQAh9@3DB^?jL}82s+)E4!%k(%3zdT`NB+OHu zm6`Hk0VIu-@WXf#a2G^J)?IyHAGIO4k}3}Xw>O37#K5AkC^%XomLt~fl``S*g*>V9 zWC-fRTJ2+O`vp}5G6N18t}+%yeP!KaL~v}FP%7qN&gC*I5iEd3VwV_51a8I&LxZts zKPM6j6=&E4^(lxC#YkY0JVY)tkAfxOrv!3AF_FuN0*EUL`LK+S(K#`#!1_=@+b`e3*d|{M}7>4r%4G(zP^1)=`+8CIEf&Qe=xHKD}^AYj=I_Oo-vE(^oD4rCT5DKtH%|ZVFK!M6hq++?u zm?w;iiG(GfA<1+>3y}4v(O57YqSE+$x*MHBBlAd9I-f>my3ySr3X>*~%S>>cR?(z7 zU`4kptRVU~lI4=z&skZjVkP?{R+jsgRTs93ldj+Hm=&{bCPj-oV}(rIPU(~tGI3>m zC#+!QzZENFm_?z|84Ru)g(+ac0y3Y$a3e7(EEYuNv$!PfSkdbWD^GvOV6snc`<#{i zDpn4D#7e1e*@wJ)ThopX`vNPjCZSH9uwvHD)GDJhRsgLn)4t3KuKc%Rg-?NLOd12C zky$XG%H&b`Q~~HP5CetDd>M{2V z^f((AaXX<0&{wpkGkSnQPQ#b!!Il43^zc~%7LyKxe>Bir7*r~i!Kct+3Pj^Gd2|Yk zsvSLsU7^S2N#e5iJq4fB!@$ww_9;E?3VNQ|TsSj)=GuA7=mcidyKn6&n{{|_ryP@d`?d>j-Gv=(z9Pd&)&60hD-LBe1RTE z;;`&a=mGRK-tLSZU~ui&m+8Tk|5o&n7%;?yXgrX(G3k5(ok9}86fTR)bE7e-d{|3* zOu9mkWl+_T%F(+&r{^w?o~lpjsZr2VT}_P4^vUU%9xEf_wNahX<7i~i>9fbukysnu z2|ZZ(Z$%Gy=ujyTOmTA~yV3bnH!6b%zSEg3I+M--J*tD=19z+}II;~>4iQv#Mu|yOPe#wX`VGk9W zw9rg?;q+gn3ge>Cw}2KEh$@6ivD6>+_VEKksd52Tm(!7CgDE;zsFeDD)Pr!yXjU~l z)~$M}gfA;ym2PDMDykN~aMx81aI%^DDkhs5aWEkcF_M`n&~XV+h;ulQlwx@MND3&- zN_h`Am{MzeYQZi|4zy1qM=cQo@1{!cQ&80qbB;L6P|Xl=ZI=rXGA0x%oM3{ZDfqJx z0wGrv#)naP61Hqx!E?*&b_>=9D z72<$Dsj?iTK8dMg$S@$(j!@OjStTI2;s>b!47J#Xf`L#$83tD7Gvriap2FC}9~sTo zRl=lEJo(BnVR{e_CQL)Z!GyEK!K9X}w1r8ferW}hN?pTX(o?l24bPFU4U<}+!Na5$ zb?`7@mKaQ!exxl-Dmh9UnACC?29th!Ytp1Dd}Ww0p#%pLCa&OM!dc>AQtLR{!laUg zw1G*jKVdK#C}Gl=M|^FV)Pe;bCbh_chY7R9V8ZkiZDCT$Gg`r<(t$9T43#ix%*wtp zOqg(hg9#HGa4_L4aWJWM4Q*jk$s}6Aq|#$Bn21W4G-iHZ8z%Jt9uJdx(2s`+v&3M+ z^aX8UQppus!KBhfFqn+mJDsMi{42wR4Ww}}VS{fROgKv%OzLsIwlJw=1Fc|E=@S@C z#!8qpj{d$jOzHtG9wzl57!MO>iNS=8x3z^yHGJ0!CY4Tr!DNEMBo!kBG;`ETlp_Y( z4L&}H4Y5)rnMikh7O9?E#-RaLGqk7mI5fb?_FNvp&@fe}@SmW8xFD_w2~8VSGv``` zubP?1!vFtg*NC?=TpxV6>N#6{xaxUce7Kk;7Va05W@H)&H}Jb=$;d2axDYN}TuO_L z5xzq5i_8Y$iqLSsTegeLSBAR)A1;=f;==8E5{qPkaI?{HzgyOdIq6~d#d+6USUkj(GL#KWcI+kMLLhpJq7Ve`AeQbN^Und8X8H>Pj zOQmw`C?om>l36S&Q!cZKmPW<*NtOE$xbns5vRgL6F9ytBgiqH6YeWqA3VCw{SfVGT$urWG7J4#*huq)a ze^MKGDc;yR5;-@+`>?T&&Z>mse4}L-clYYG-m#h#`Lgct7}wDr_9sdvWEfDq1mUC7 zQYtpJjm-Kjmw9;4GVig6u4GqN9v&#UXhv8!iSuw3)Tfa#RF^p=ife1fFke5E(dW)b1;@AI&UyfBbC8>AI^P z{<9A|l(7DB72Q47lFD1-aU^&D_Q8|yz23YoJ>wQ3>8-$f=G+lgS><}<-1K96P4;}) z2@QE@S^943&Y$On2pdXusDnNG_Dkt`VB??>`6oxutJyI!F~0EN+6|Ir1)Gj;UGckf zP;o=sva9Ru?k_pwUc?Hp?em5nSXj9?E+)YQdL)R7?76!BLjY`Zy6IASt50I+>r3M) z9!uXRyt})1SLDm%Et!J)O@}Ib8V_mFbERCp6sNx}c}xbkTWWLBttZ8GUgdoa#d{5d zp3Xfs!QM1|V!_lAOKbL~`dNM$vU%N`Ld$6KyXVW3PcOJ*J9L(j*O0$&Ha1M)#_xN* z!#3~K^nrcTPB#53I$wFO=H$%Ne-1di#=t)!@z0qbLPxJo*%e%TaJ#{RyLr)7HGbRO z))udrmT`l8+CC&?+2(oq0og|`>W=75>6;=LMB>US2Uo`#ePUiYmDHZeUNRwH9Ql#Gf@H6sTQA=yt&XP+D-XGic^FQ?$y5~=> zvfEf^!yg=eJuQEF_29IW^*YBMr<c1dft3u=kSwvAaV6kGD(AeceM$}%mTJ_>_i*0L&p@;ax3^L+k!A&TO<2%nZ7nx)teuVYKJETI(kP z%`ab1jTo?cs$-560*v zJU+91r2pEG^V@!}eUNZzw=|i=iP`z>BKAUqH zyNBO2>}l^)n`7N#oBQ%!QhiPS=Kf3URt#8Y>;OrP9$(R2oZN>NE9L8?660GJNV8)H zyBm}*`mw)1ygsy{_Lq^4eg~{7nTfj=GF> z!%%Qu;u9O=2K&~;V%=G9wrm;d2Mca(Sbz4K+r*lkl=2f>Bg{RnUD{+H_G)P8#FkS} z|9&?m-Oh7F|C@1HAzAD(HS1UeyMJG{E;9|Rvz$J>?nMAqrp*U$(3HQyY?i$;x!EF zzxKxdtY2Ra2u}?zNtidAn+-b~kmOI%JFPDHgA!dD4zi8A(B2=K3Kf2U`oLI-g%!Yx}SdjXJOVO)=@* z&dtlt7xdb!Q+$42BYUFNiG?N4z3+H4Yy!MF{1aQ;Et(&j)}OP?`e$O=+t!wZh;(*r z>^ryk84h#iN?+8M{5pMoYR68h&Z+$V03)W<9*3qmU(GO0} zId~-G_KY*$6@M2`>0e&tLpoI?b+3tAb^PuMJ5Op^@}VrdT&C#QIAUh&fFEaj>Go)d zi5=NCqPyKBn-x9MwXb$gM$5qD?Y$3NIAa6NUlX`z*6BA}xZN!0zYWy$8N1)}$hOIuV-g}9 z{%s2VbIaNa3AeeBU0FKEFW#Zfcy#dFhwpS!H{|pvC^Ks#JzUAtHCx_do!8> zGqN9M49(mg9&jRUNuEXhU{gw;3ew%`J-_}qVF`TuL&`4if?SSy?yH=e*{dA3GNQ@V z^n$?5Q&sNF#aE~2T;EV!&2HvR;Gg@|bjn=Y<*Uq_E)`cU8?YzL<7nj0sE7%UQ9+gt z$JaK$U8H;I)T#+V)cyP0;tL0wgMAxsXsZ|W3;M0`pK$8Cvr5v9bs#LCCiik)2Epp2%~ zdv9}J?Aq);@j>ICnOBOF9hVs5Vq?T>~G2;!ccH)#^VZEnkswBp)Ht0$EFm7r&}ZFi&HtkMk+d^U|6OSWVOCUT>; z^y@kM=?Cxehhv*P1u4{grwN;CUB)fi>px@Md`E{1`c=G9Pi11?Mq+8<%LcL2;LC<@ z2nj3RPvbSS2}eRp92W37{yKFnTtl1w4+nbnoIdD3LpjLg*aO%)zKX>N;7uZzq8p@? zD-bwQP)rQi@Pn>IFjD`JnRd3%OJ>BChJlO}Yz&j>iD7i>mvV!b=G6;elaCf_9I~~T z`Br$+xPS(D(l|M+@TAd7Q6~-ar|ospV1?F6(-fn!PMWNMCyldk3QroXfG3R;RSHiU zt$-(ub2$o68m-XiG);FXqtj#sJZYQ}P2g=fo_@8k@W_nF+o_U oR0N^BB0t?w_wf=ObkiKV%Iu3<<-krZDOiI7Tt>_NrCgW)0poKUF8}}l literal 0 HcmV?d00001 diff --git a/setup.py b/setup.py index 52d8b61dd..897b45727 100644 --- a/setup.py +++ b/setup.py @@ -37,6 +37,9 @@ "sphinxcontrib.httpdomain", "sphinx-autoapi", "sphinx_gallery", + "torch", + "torchvision", + "tqdm", ], "dask": ["dask[complete]"], "track": ["track @ git+https://github.com/Delaunay/track@master#egg=track"], diff --git a/src/orion/core/io/database/pickleddb.py b/src/orion/core/io/database/pickleddb.py index 6650daf84..ca8efd0a1 100644 --- a/src/orion/core/io/database/pickleddb.py +++ b/src/orion/core/io/database/pickleddb.py @@ -101,6 +101,9 @@ def __init__(self, host="", timeout=60, *args, **kwargs): host = DEFAULT_HOST super().__init__(host) + # NOTE: Save the original value of `host`, so this object can be pickled and unpickled on + # different machines more easily if it's a relative path. + self.original_host = host self.host = os.path.abspath(host) self.timeout = timeout @@ -255,6 +258,17 @@ def get_defaults(cls): """ return {"host": DEFAULT_HOST} + def __getstate__(self): + """Return state to be pickled.""" + return self.__dict__.copy() + + def __setstate__(self, state: dict) -> None: + """Restore state from pickled object.""" + self.__dict__.update(state) + # NOTE: `original_host` might not be present when unpickling old databases. + self.original_host = state.setdefault("original_host", self.host) + self.host = os.path.abspath(self.original_host) + local_file_systems = ["ext2", "ext3", "ext4", "ntfs"] diff --git a/src/orion/core/io/experiment_builder.py b/src/orion/core/io/experiment_builder.py index e462baeaa..85f7ca331 100644 --- a/src/orion/core/io/experiment_builder.py +++ b/src/orion/core/io/experiment_builder.py @@ -800,6 +800,8 @@ def consolidate_config(self, name: str, version: int | None, config: dict): merge_algorithm_config(config, new_config) # TODO: Remove for v0.4 merge_producer_config(config, new_config) + if "knowledge_base" in new_config: + config["knowledge_base"] = new_config["knowledge_base"] config.setdefault("name", name) config.setdefault("version", version) diff --git a/tox.ini b/tox.ini index e378c8d84..1b77ed4d3 100644 --- a/tox.ini +++ b/tox.ini @@ -218,6 +218,8 @@ commands = description = Invoke sphinx to build documentation and API reference basepython = python3 skip_install = false +deps = + -rdocs/requirements.txt extras = docs pb2 From 5b9f2f8618f22e92dcd19b8fb16e238676705d41 Mon Sep 17 00:00:00 2001 From: Fabrice Normandin Date: Tue, 20 Dec 2022 15:03:55 -0500 Subject: [PATCH 2/4] Ignore KB instantiation errors if mode != "x" Signed-off-by: Fabrice Normandin --- src/orion/core/io/experiment_builder.py | 46 ++++++++- .../core/io/test_experiment_builder.py | 95 ++++++++++++++++++- tox.ini | 2 +- 3 files changed, 132 insertions(+), 11 deletions(-) diff --git a/src/orion/core/io/experiment_builder.py b/src/orion/core/io/experiment_builder.py index 85f7ca331..43b0d32de 100644 --- a/src/orion/core/io/experiment_builder.py +++ b/src/orion/core/io/experiment_builder.py @@ -84,6 +84,8 @@ import typing from typing import Any, TypeVar +from typing_extensions import Literal + import orion.core from orion.algo.base import BaseAlgorithm, algo_factory from orion.algo.space import Space @@ -207,7 +209,34 @@ def _instantiate_space(config: Space | dict[str, Any]) -> Space: return SpaceBuilder().build(config) -def _instantiate_knowledge_base(kb_config: dict[str, Any]) -> KnowledgeBase: +@typing.overload +def _instantiate_knowledge_base( + kb_config: dict[str, Any], + ignore_instantiation_errors: Literal[True] = True, +) -> KnowledgeBase | None: + ... + + +@typing.overload +def _instantiate_knowledge_base( + kb_config: dict[str, Any], + ignore_instantiation_errors: Literal[False] = False, +) -> KnowledgeBase: + ... + + +@typing.overload +def _instantiate_knowledge_base( + kb_config: dict[str, Any], + ignore_instantiation_errors: bool, +) -> KnowledgeBase | None: + ... + + +def _instantiate_knowledge_base( + kb_config: dict[str, Any], + ignore_instantiation_errors: bool = True, +) -> KnowledgeBase | None: """Instantiate the Knowledge base from its configuration.""" if len(kb_config) != 1: raise ConfigurationError( @@ -232,10 +261,17 @@ def _instantiate_knowledge_base(kb_config: dict[str, Any]) -> KnowledgeBase: kb_kwargs = kb_config[kb_type_name] # Instantiate the storage that is required for the KB. storage_config = kb_kwargs["storage"] - if isinstance(storage_config, dict): - storage = setup_storage(storage_config) - kb_kwargs["storage"] = storage - return kb_type(**kb_kwargs) + try: + if isinstance(storage_config, dict): + storage = setup_storage(storage_config) + kb_kwargs["storage"] = storage + return kb_type(**kb_kwargs) + except (FileNotFoundError, PermissionError) as err: + if not ignore_instantiation_errors: + log.error("Unable to instantiate the KnowledgeBase.") + raise err + log.warning("KnowledgeBase could not be instantiated.") + return None def _instantiate_algo( diff --git a/tests/unittests/core/io/test_experiment_builder.py b/tests/unittests/core/io/test_experiment_builder.py index 5611eecec..ab6ba3c52 100644 --- a/tests/unittests/core/io/test_experiment_builder.py +++ b/tests/unittests/core/io/test_experiment_builder.py @@ -2,9 +2,11 @@ """Example usage and tests for :mod:`orion.core.io.experiment_builder`.""" from __future__ import annotations +import contextlib import copy import datetime import logging +import typing from pathlib import Path import pytest @@ -28,11 +30,15 @@ UnsupportedOperation, ) from orion.core.worker.algo_wrappers import AlgoWrapper +from orion.core.worker.experiment_config import ExperimentConfig from orion.core.worker.warm_start import KnowledgeBase from orion.storage.base import setup_storage from orion.storage.legacy import Legacy from orion.testing import OrionState +if typing.TYPE_CHECKING: + from _pytest.logging import LogCaptureFixture + def count_experiments(): """Count experiments in storage""" @@ -45,10 +51,8 @@ def space(): return {"x": "uniform(-50,50)"} -@pytest.fixture() -def python_api_config(): - """Create a configuration without the cli fluff.""" - new_config = dict( +def _python_api_config() -> ExperimentConfig: + return ExperimentConfig( name="supernaekei", version=1, space={"x": "uniform(0,10)"}, @@ -79,9 +83,14 @@ def python_api_config(): _id="fasdfasfa", something_to_be_ignored="asdfa", refers=dict(root_id="supernaekei", parent_id=None, adapter=[]), + knowledge_base=None, ) - return new_config + +@pytest.fixture() +def python_api_config(): + """Create a configuration without the cli fluff.""" + return _python_api_config() @pytest.fixture() @@ -1374,6 +1383,82 @@ def test_load_unavailable_algo(algo_unavailable_config, capsys): experiment_builder.build("supernaekei") +def _exp_config_with_knowledge_base_at(kb_pickle_path: str | Path) -> ExperimentConfig: + config = _python_api_config() + config["knowledge_base"] = { + "KnowledgeBase": { + "storage": { + "type": "legacy", + "database": {"type": "pickleddb", "host": str(kb_pickle_path)}, + }, + }, + } + return config + + +def test_load_uninstantiable_knowledge_base(caplog: LogCaptureFixture, tmp_path: Path): + """Check that an error is raised when trying to create an experiment where the KB has absolute + paths that aren't on this machine. + """ + + @contextlib.contextmanager + def _logs_warning_about_kb(): + # Now, if trying to open in read mode, but the path doesn't exist, then the exception + # should be caught and a warning should be printed. + caplog.clear() + with caplog.at_level(logging.WARNING): + yield + assert len(caplog.records) >= 1 + assert "KnowledgeBase could not be instantiated" in caplog.text + + @contextlib.contextmanager + def _setup(kb_host: str | Path): + exp_config = _exp_config_with_knowledge_base_at(kb_pickle_path=kb_host) + with OrionState(experiments=[exp_config]): + yield exp_config + + with _setup(kb_host="/I/do/not/exist.pkl") as exp_config_with_invalid_kb_host: + caplog.clear() + with pytest.raises(PermissionError, match="/I"), caplog.at_level(logging.ERROR): + experiment_builder.build( + name=exp_config_with_invalid_kb_host["name"], mode="x" + ) + assert len(caplog.records) >= 1 + + # Now, if trying to open in read mode, but the path doesn't exist, then the exception + # should be caught and a warning should be printed. + with _logs_warning_about_kb(): + experiment = experiment_builder.load( + exp_config_with_invalid_kb_host["name"], mode="r" + ) + assert experiment.knowledge_base is None + + # Now, use a path that could be written to, but doesn't exist. + host = tmp_path / "some_folder" / "db.pkl" + with _setup(kb_host=host) as exp_config_with_absent_kb_file: + assert not host.exists() + + # Try to load the experiment, but the KB points to pickledb host files that don't exist! + # NOTE: This shouldn't create the files, or any of the parent directories! + with _logs_warning_about_kb(): + experiment_builder.load(exp_config_with_absent_kb_file["name"], mode="r") + assert not host.exists() + assert not host.parent.exists() + + with _logs_warning_about_kb(): + experiment_builder.build(exp_config_with_absent_kb_file["name"], mode="r") + assert not host.exists() + assert not host.parent.exists() + + # Try to build the experiment to run it writing, but the KB points to pickledb host files + # that don't exist! + # NOTE: This shouldn't create the files, or any of the parent directories! + with pytest.raises(FileNotFoundError, match=str(host)): + experiment_builder.build(exp_config_with_absent_kb_file["name"], mode="x") + assert not host.exists() + assert not host.parent.exists() + + class TestInitExperimentReadWrite: """Create new Experiment instance that only supports read/write.""" diff --git a/tox.ini b/tox.ini index 1b77ed4d3..c5c0467d8 100644 --- a/tox.ini +++ b/tox.ini @@ -73,7 +73,7 @@ commands = [testenv:backward-compatibility] description = Run all versions of Orion to assert backward compatibility setenv = COVERAGE_FILE=.coverage.backward_compatibility -passenv = CI ORION_DB_TYPE +passenv = CI,ORION_DB_TYPE deps = coverage commands = From 155cd139c2981bf4362a231a0f13ae4f578ffa1c Mon Sep 17 00:00:00 2001 From: Fabrice Normandin Date: Wed, 21 Dec 2022 16:06:19 -0500 Subject: [PATCH 3/4] Small typing improvements to experiment_builder.py Signed-off-by: Fabrice Normandin --- src/orion/core/io/experiment_builder.py | 44 +++++++++++++++++-------- 1 file changed, 31 insertions(+), 13 deletions(-) diff --git a/src/orion/core/io/experiment_builder.py b/src/orion/core/io/experiment_builder.py index 43b0d32de..d025d79f1 100644 --- a/src/orion/core/io/experiment_builder.py +++ b/src/orion/core/io/experiment_builder.py @@ -105,12 +105,19 @@ RaceCondition, ) from orion.core.worker.experiment import Experiment, Mode -from orion.core.worker.experiment_config import ExperimentConfig +from orion.core.worker.experiment_config import ( + ExperimentConfig, + MetaData, + PartialExperimentConfig, + RefersConfig, +) from orion.core.worker.primary_algo import create_algo from orion.core.worker.warm_start import KnowledgeBase from orion.storage.base import setup_storage if typing.TYPE_CHECKING: + from typing_extensions import Unpack + from orion.core.evc.adapters import CompositeAdapter from orion.storage.base import BaseStorageProtocol log = logging.getLogger(__name__) @@ -121,7 +128,7 @@ ## -def clean_config(name: str, config: dict, branching: dict | None): +def clean_config(name: str, config: PartialExperimentConfig, branching: dict | None): """Clean configuration from hidden fields (ex: ``_id``) and update branching if necessary""" log.debug("Cleaning config") @@ -390,7 +397,7 @@ def _fetch_config_version( ### -def get_cmd_config(cmdargs) -> ExperimentConfig: +def get_cmd_config(cmdargs) -> dict: """Fetch configuration defined by commandline and local configuration file. Arguments of commandline have priority over options in configuration file. @@ -459,7 +466,7 @@ def build_from_args(cmdargs): return builder.build(**cmd_config) -def get_from_args(cmdargs, mode="r"): +def get_from_args(cmdargs: dict, mode: Literal["r", "w"] = "r"): """Build an experiment view based on commandline arguments .. seealso:: @@ -477,7 +484,7 @@ def get_from_args(cmdargs, mode="r"): name = cmd_config.get("name") version = cmd_config.get("version") - + assert isinstance(name, str) return builder.load(name, version, mode=mode) @@ -486,7 +493,7 @@ def build( version: int | None = None, branching: dict | None = None, storage: BaseStorageProtocol | dict | None = None, - **config, + **config: Unpack[PartialExperimentConfig], ): """Build an experiment. @@ -498,10 +505,17 @@ def build( if storage is None: storage = setup_storage() - return ExperimentBuilder(storage).build(name, version, branching, **config) + config["name"] = name + config["version"] = version + return ExperimentBuilder(storage).build(branching=branching, **config) -def load(name, version=None, mode="r", storage=None): +def load( + name: str, + version=None, + mode: Literal["r", "w"] = "r", + storage: BaseStorageProtocol | dict | None = None, +) -> Experiment: """Load an experiment. .. seealso:: @@ -549,7 +563,7 @@ def build( name: str, version: int | None = None, branching: dict | None = None, - **config, + **config: Unpack[PartialExperimentConfig], ) -> Experiment: """Build an experiment object @@ -671,7 +685,9 @@ def _get_conflicts(self, experiment: Experiment, branching: dict): return conflicts - def load(self, name: str, version: int | None = None, mode: Mode = "r"): + def load( + self, name: str, version: int | None = None, mode: Literal["r", "w"] = "r" + ): """Load experiment from database An experiment view provides all reading operations of standard experiment but prevents the @@ -813,7 +829,9 @@ def _attempt_branching(self, conflicts, experiment, version, branching): return branched_experiment - def consolidate_config(self, name: str, version: int | None, config: dict): + def consolidate_config( + self, name: str, version: int | None, config: PartialExperimentConfig + ): """Merge together given configuration with db configuration matching for experiment (``name``, ``version``) """ @@ -913,8 +931,8 @@ def create_experiment( max_trials: int | None = None, max_broken: int | None = None, working_dir: str | None = None, - metadata: dict | None = None, - refers: dict | None = None, + metadata: MetaData | None = None, + refers: RefersConfig | None = None, producer: dict | None = None, knowledge_base: KnowledgeBase | dict | None = None, user: str | None = None, From 687447b8982834c3dcad3040574944f981120a36 Mon Sep 17 00:00:00 2001 From: Fabrice Normandin Date: Wed, 21 Dec 2022 16:41:43 -0500 Subject: [PATCH 4/4] Improve tests, leave xfailed test for future work Signed-off-by: Fabrice Normandin --- src/orion/core/io/experiment_builder.py | 4 +- .../core/io/test_experiment_builder.py | 83 +++++++++++-------- 2 files changed, 50 insertions(+), 37 deletions(-) diff --git a/src/orion/core/io/experiment_builder.py b/src/orion/core/io/experiment_builder.py index d025d79f1..2bc989c56 100644 --- a/src/orion/core/io/experiment_builder.py +++ b/src/orion/core/io/experiment_builder.py @@ -981,7 +981,9 @@ def _default(v: T | None, default: V) -> T | V: space = _instantiate_space(space) max_trials = _default(max_trials, orion.core.config.experiment.max_trials) if isinstance(knowledge_base, dict): - knowledge_base = _instantiate_knowledge_base(knowledge_base) + knowledge_base = _instantiate_knowledge_base( + knowledge_base, ignore_instantiation_errors=mode != "x" + ) instantiated_algorithm = _instantiate_algo( space=space, max_trials=max_trials, diff --git a/tests/unittests/core/io/test_experiment_builder.py b/tests/unittests/core/io/test_experiment_builder.py index ab6ba3c52..bbec1ad38 100644 --- a/tests/unittests/core/io/test_experiment_builder.py +++ b/tests/unittests/core/io/test_experiment_builder.py @@ -1396,42 +1396,54 @@ def _exp_config_with_knowledge_base_at(kb_pickle_path: str | Path) -> Experiment return config -def test_load_uninstantiable_knowledge_base(caplog: LogCaptureFixture, tmp_path: Path): - """Check that an error is raised when trying to create an experiment where the KB has absolute - paths that aren't on this machine. +@contextlib.contextmanager +def _logs_warning_about_kb(caplog: LogCaptureFixture): + # Now, if trying to open in read mode, but the path doesn't exist, then the exception + # should be caught and a warning should be printed. + caplog.clear() + with caplog.at_level(logging.WARNING): + yield + assert len(caplog.records) >= 1 + assert "KnowledgeBase could not be instantiated" in caplog.text + + +@contextlib.contextmanager +def _setup(kb_host: str | Path): + exp_config = _exp_config_with_knowledge_base_at(kb_pickle_path=kb_host) + with OrionState(experiments=[exp_config]): + yield exp_config + + +def test_load_uninstantiable_knowledge_base(caplog: LogCaptureFixture): + """Check that if an experiment is loaded in read mode and the knowledge base cannot be + instantiated, then a warning is printed. """ - @contextlib.contextmanager - def _logs_warning_about_kb(): - # Now, if trying to open in read mode, but the path doesn't exist, then the exception - # should be caught and a warning should be printed. - caplog.clear() - with caplog.at_level(logging.WARNING): - yield - assert len(caplog.records) >= 1 - assert "KnowledgeBase could not be instantiated" in caplog.text - - @contextlib.contextmanager - def _setup(kb_host: str | Path): - exp_config = _exp_config_with_knowledge_base_at(kb_pickle_path=kb_host) - with OrionState(experiments=[exp_config]): - yield exp_config - with _setup(kb_host="/I/do/not/exist.pkl") as exp_config_with_invalid_kb_host: - caplog.clear() - with pytest.raises(PermissionError, match="/I"), caplog.at_level(logging.ERROR): - experiment_builder.build( - name=exp_config_with_invalid_kb_host["name"], mode="x" + # experiment_builder.build uses ExperimentBuilder.create_experiment with mode="x", so this + # should try to load the KnowledgeBase and raise an error. + with pytest.raises(PermissionError, match="/I"): + experiment = experiment_builder.build( + name=exp_config_with_invalid_kb_host["name"], ) - assert len(caplog.records) >= 1 + assert experiment.knowledge_base is None # Now, if trying to open in read mode, but the path doesn't exist, then the exception # should be caught and a warning should be printed. - with _logs_warning_about_kb(): + with _logs_warning_about_kb(caplog): experiment = experiment_builder.load( exp_config_with_invalid_kb_host["name"], mode="r" ) - assert experiment.knowledge_base is None + assert experiment.knowledge_base is None + + +@pytest.mark.xfail( + reason="See https://github.com/Epistimio/orion/issues/1053", raises=AssertionError +) +def test_attempt_to_load_knowledge_base_doest_create_files( + caplog: LogCaptureFixture, tmp_path: Path +): + """TODO: https://github.com/Epistimio/orion/issues/1053""" # Now, use a path that could be written to, but doesn't exist. host = tmp_path / "some_folder" / "db.pkl" @@ -1440,21 +1452,20 @@ def _setup(kb_host: str | Path): # Try to load the experiment, but the KB points to pickledb host files that don't exist! # NOTE: This shouldn't create the files, or any of the parent directories! - with _logs_warning_about_kb(): + with _logs_warning_about_kb(caplog): experiment_builder.load(exp_config_with_absent_kb_file["name"], mode="r") assert not host.exists() assert not host.parent.exists() - with _logs_warning_about_kb(): - experiment_builder.build(exp_config_with_absent_kb_file["name"], mode="r") - assert not host.exists() - assert not host.parent.exists() - - # Try to build the experiment to run it writing, but the KB points to pickledb host files - # that don't exist! - # NOTE: This shouldn't create the files, or any of the parent directories! + # Try to build the experiment (execute mode), but the KB points to pickledb host files + # that don't exist. + # NOTE: A bit trickier. Should it create the KB files (and parent directories) so the KB is + # just created and empty? + # In my (@lebrice) view, the KB should always be treated as "read-only", so this should + # raise a FileNotFoundError. + # This should fail to read the pickle file (because it doesn't exist). with pytest.raises(FileNotFoundError, match=str(host)): - experiment_builder.build(exp_config_with_absent_kb_file["name"], mode="x") + experiment_builder.build(exp_config_with_absent_kb_file["name"]) assert not host.exists() assert not host.parent.exists()