Skip to content

Commit

Permalink
deploy: 306eb24
Browse files Browse the repository at this point in the history
  • Loading branch information
jeffjennings committed Nov 30, 2023
1 parent 4ca5616 commit 00fc1d9
Show file tree
Hide file tree
Showing 18 changed files with 180 additions and 115 deletions.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file not shown.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file not shown.
Binary file not shown.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
41 changes: 20 additions & 21 deletions _modules/mpol/crossval.html
Original file line number Diff line number Diff line change
Expand Up @@ -461,7 +461,7 @@ <h1>Source code for mpol.crossval</h1><div class="highlight"><pre>
<span class="k">for</span> <span class="n">kk</span><span class="p">,</span> <span class="p">(</span><span class="n">train_set</span><span class="p">,</span> <span class="n">test_set</span><span class="p">)</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="n">split_iterator</span><span class="p">):</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">_verbose</span><span class="p">:</span>
<span class="n">logging</span><span class="o">.</span><span class="n">info</span><span class="p">(</span>
<span class="s2">&quot;</span><span class="se">\n</span><span class="s2">Cross-validation: k-fold </span><span class="si">{}</span><span class="s2"> of </span><span class="si">{}</span><span class="s2">&quot;</span><span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="n">kk</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">_kfolds</span><span class="p">)</span>
<span class="s2">&quot;</span><span class="se">\n</span><span class="s2">Cross-validation: k-fold </span><span class="si">{}</span><span class="s2"> of </span><span class="si">{}</span><span class="s2">&quot;</span><span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="n">kk</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">_kfolds</span> <span class="o">-</span> <span class="mi">1</span><span class="p">)</span>
<span class="p">)</span>

<span class="c1"># if hasattr(self._device,&#39;type&#39;) and self._device.type == &#39;cuda&#39;: # TODO: confirm which objects need to be passed to gpu</span>
Expand Down Expand Up @@ -510,15 +510,14 @@ <h1>Source code for mpol.crossval</h1><div class="highlight"><pre>

<span class="c1"># store objects from the most recent kfold for diagnostics </span>
<span class="bp">self</span><span class="o">.</span><span class="n">_model</span> <span class="o">=</span> <span class="n">model</span>
<span class="bp">self</span><span class="o">.</span><span class="n">_train_figure</span> <span class="o">=</span> <span class="n">trainer</span><span class="o">.</span><span class="n">train_figure</span>

<span class="c1"># collect objects from this kfold to store</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">_store_cv_diagnostics</span><span class="p">:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">_diagnostics</span><span class="p">[</span><span class="s2">&quot;loss_histories&quot;</span><span class="p">]</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">loss_history</span><span class="p">)</span>
<span class="c1"># update regularizer strength values</span>
<span class="bp">self</span><span class="o">.</span><span class="n">_regularizers</span> <span class="o">=</span> <span class="n">trainer</span><span class="o">.</span><span class="n">regularizers</span>
<span class="c1"># store the most recent train figure for diagnostics</span>
<span class="bp">self</span><span class="o">.</span><span class="n">_train_figure</span> <span class="o">=</span> <span class="n">trainer</span><span class="o">.</span><span class="n">train_figure</span>

<span class="c1"># run testing</span>
<span class="n">all_scores</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">trainer</span><span class="o">.</span><span class="n">test</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">_model</span><span class="p">,</span> <span class="n">test_set</span><span class="p">))</span>
<span class="bp">self</span><span class="o">.</span><span class="n">_diagnostics</span><span class="p">[</span><span class="s2">&quot;models&quot;</span><span class="p">]</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">_model</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">_diagnostics</span><span class="p">[</span><span class="s2">&quot;regularizers&quot;</span><span class="p">]</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">_regularizers</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">_diagnostics</span><span class="p">[</span><span class="s2">&quot;loss_histories&quot;</span><span class="p">]</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">loss_history</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">_diagnostics</span><span class="p">[</span><span class="s2">&quot;train_figures&quot;</span><span class="p">]</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">_train_figure</span><span class="p">)</span>

<span class="c1"># average individual test scores to get the cross-val metric for chosen</span>
<span class="c1"># hyperparameters</span>
Expand All @@ -527,33 +526,33 @@ <h1>Source code for mpol.crossval</h1><div class="highlight"><pre>
<span class="s2">&quot;std&quot;</span><span class="p">:</span> <span class="n">np</span><span class="o">.</span><span class="n">std</span><span class="p">(</span><span class="n">all_scores</span><span class="p">),</span>
<span class="s2">&quot;all&quot;</span><span class="p">:</span> <span class="n">all_scores</span><span class="p">,</span>
<span class="p">}</span>

<span class="k">return</span> <span class="n">cv_score</span></div>

<span class="nd">@property</span>
<span class="k">def</span> <span class="nf">model</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;SimpleNet class instance&quot;&quot;&quot;</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;For the most recent kfold, trained model (`SimpleNet` class instance)&quot;&quot;&quot;</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_model</span>

<span class="nd">@property</span>
<span class="k">def</span> <span class="nf">regularizers</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;Dict containing regularizers used and their strengths&quot;&quot;&quot;</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;For the most recent kfold, dict containing regularizers used and their strengths&quot;&quot;&quot;</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_regularizers</span>

<span class="nd">@property</span>
<span class="k">def</span> <span class="nf">diagnostics</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;Dict containing diagnostics of the cross-validation loop&quot;&quot;&quot;</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_diagnostics</span>

<span class="k">def</span> <span class="nf">train_figure</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;For the most recent kfold, (fig, axes) showing training progress&quot;&quot;&quot;</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_train_figure</span>
<span class="nd">@property</span>
<span class="k">def</span> <span class="nf">split_figure</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;(fig, axes) of train/test splitting diagnostic figure&quot;&quot;&quot;</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_split_figure</span>

<span class="nd">@property</span>
<span class="k">def</span> <span class="nf">train_figure</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;(fig, axes) of most recent training diagnostic figure&quot;&quot;&quot;</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_train_figure</span></div>
<span class="k">def</span> <span class="nf">diagnostics</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;Dict containing diagnostics of the cross-validation loop across all kfolds: models, regularizers, loss values, training figures&quot;&quot;&quot;</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_diagnostics</span></div>


<div class="viewcode-block" id="RandomCellSplitGridded"><a class="viewcode-back" href="../../api.html#mpol.crossval.RandomCellSplitGridded">[docs]</a><span class="k">class</span> <span class="nc">RandomCellSplitGridded</span><span class="p">:</span>
Expand Down
Loading

0 comments on commit 00fc1d9

Please sign in to comment.