diff --git a/learned_optimization/research/univ_nfn/learned_opt/learned_opts.py b/learned_optimization/research/univ_nfn/learned_opt/learned_opts.py
index bbeff1f..bec0f64 100644
--- a/learned_optimization/research/univ_nfn/learned_opt/learned_opts.py
+++ b/learned_optimization/research/univ_nfn/learned_opt/learned_opts.py
@@ -261,6 +261,34 @@ def __call__(self, inp_features):
     return self.mod(inp_features, self.perm_spec.unfreeze())
 
 
+class HybridMLPNFN(nn.Module):
+  """MLP + NFN Lopt."""
+
+  in_channels: int
+  hidden_channels: int
+  out_channels: int
+  num_layers: int
+  perm_spec: Any
+  ptwise_init: bool = False
+
+  def setup(self):
+    out_channels, hidden_channels = self.out_channels, self.hidden_channels
+
+    self.mlp = MLPForOpt(hidden_channels, hidden_channels, self.num_layers - 1)
+
+    def make_layer(out_chan, in_chan):
+      if self.ptwise_init:
+        return universal_layers.PointwiseInitNFLinear(out_chan, in_chan)
+      else:
+        return universal_layers.NFLinear(out_chan, in_chan, w_init='lecun')
+
+    self.final = make_layer(out_channels, hidden_channels)
+
+  def __call__(self, inp_features):
+    features = universal_layers.nf_relu(self.mlp(inp_features))
+    return self.final(features, self.perm_spec.unfreeze())
+
+
 class SGDControl(lopt_base.LearnedOptimizer):
   """SGD where per-parameter learning rates are controlled by a network."""
 
@@ -457,7 +485,13 @@ class ResidualOptNFN(ResidualOpt):
   """NFN learning a residual on base optimizer."""
 
   def __init__(
-      self, task, step_mult=0.1, out_mult=1e-4, ptwise_init=False, pos_emb=False
+      self,
+      task,
+      step_mult=0.1,
+      out_mult=1e-4,
+      ptwise_init=False,
+      pos_emb=False,
+      hybrid=False,
   ):
     example_params = task.init(jax.random.PRNGKey(0))
     if 'conv2_d' in example_params:
@@ -468,15 +502,25 @@ def __init__(
       perm_spec = make_hk_transformer_perm_spec(example_params)
     else:
       perm_spec = make_hk_perm_spec(example_params)
-    network = UnivNFNForOpt(
-        in_channels=19,
-        hidden_channels=32,
-        out_channels=1,
-        num_layers=4,
-        perm_spec=perm_spec,
-        ptwise_init=ptwise_init,
-        pos_emb=pos_emb,
-    )
+    if hybrid:
+      assert not pos_emb
+      network = HybridMLPNFN(
+          in_channels=19,
+          hidden_channels=32,
+          out_channels=1,
+          num_layers=4,
+          perm_spec=perm_spec,
+      )
+    else:
+      network = UnivNFNForOpt(
+          in_channels=19,
+          hidden_channels=32,
+          out_channels=1,
+          num_layers=4,
+          perm_spec=perm_spec,
+          ptwise_init=ptwise_init,
+          pos_emb=pos_emb,
+      )
     super().__init__(
         network, example_params, step_mult=step_mult, out_mult=out_mult
     )
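
For context, a minimal usage sketch (not part of the diff) of the new hybrid path, assuming the standard learned_optimization LearnedOptimizer interface (init / opt_fn); the task constructor is a hypothetical placeholder for any task exposing .init(rng):

# Sketch only: exercises ResidualOptNFN with the new hybrid=True flag,
# which selects HybridMLPNFN instead of the full UnivNFNForOpt network.
import jax
from learned_optimization.research.univ_nfn.learned_opt import learned_opts

task = make_task()  # hypothetical helper: any task providing .init(rng)
lopt = learned_opts.ResidualOptNFN(task, hybrid=True)
theta = lopt.init(jax.random.PRNGKey(0))  # learned-optimizer parameters
opt = lopt.opt_fn(theta)                  # inner optimizer built from theta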