diff --git a/learned_optimization/research/univ_nfn/learned_opt/learned_opts.py b/learned_optimization/research/univ_nfn/learned_opt/learned_opts.py index d3ffeab..12649df 100644 --- a/learned_optimization/research/univ_nfn/learned_opt/learned_opts.py +++ b/learned_optimization/research/univ_nfn/learned_opt/learned_opts.py @@ -35,7 +35,6 @@ from learned_optimization.learned_optimizers import base as lopt_base from learned_optimization.learned_optimizers import common from learned_optimization.optimizers import base as opt_base -from learned_optimization.research.univ_nfn.nfn import ff_layers as nf_layers from learned_optimization.research.univ_nfn.nfn import universal_layers from learned_optimization.research.univ_nfn.nfn import utils as nfu @@ -89,42 +88,6 @@ class SimpleOptState(flax.struct.PyTreeNode): state: Any -def flax_to_hk(input_dict): - """Maps flax parameter structure to haiku parameter structure. - - Example: - >>> input_dict = { - ... 'params': { - ... 'Dense_0': {'kernel': W0, 'bias': b0}, - ... 'Dense_1': {'kernel': W1, 'bias': b1} - ... } - ... } - >>> transform_dict(input_dict) - {'mlp/~/linear_0': {'w': W0, 'b': b0}, 'mlp/~/linear_1': {'w': W1, 'b': b1}} - """ - params = input_dict.get('params', {}) - output_dict = {} - for i, (_, layer_data) in enumerate(params.items()): - # Constructing new key and sub-dictionary format - new_key = f'mlp/~/linear_{i}' - new_data = {'w': layer_data['kernel'], 'b': layer_data['bias']} - output_dict[new_key] = new_data - return output_dict - - -def hk_to_flax(input_dict): - """Maps haiku parameter structure to flax parameter structure.""" - output_dict = {'params': {}} - - for key, layer_data in input_dict.items(): - # Extracting the layer number from the key - layer_num = key.split('_')[-1] # Get the part after the last '_' - original_layer_name = f'Dense_{layer_num}' - original_data = {'kernel': layer_data['w'], 'bias': layer_data['b']} - output_dict['params'][original_layer_name] = original_data - return output_dict - - def make_hk_perm_spec(mlp_params): """Produces perm spec for a haiku mlp.""" perm_spec = {} @@ -153,17 +116,57 @@ def make_hk_cnn_perm_spec(mlp_params): return perm_spec +def build_init_fn(scale, shape): + return lambda rng, _shape: scale * jax.random.normal(rng, shape) + + +class PosEmbConv(nn.Module): + """Add learned position embeddings for spatial dims of conv input.""" + + @nn.compact + def __call__(self, inp_features): + features, tree_def = jtu.tree_flatten(inp_features) + out_features = [] + for i, val in enumerate(features): + if len(val.shape) == 5: # conv2d filter: HxWxC1xC2xC + shape = (val.shape[0], val.shape[1], 1, 1, val.shape[-1]) + scale = 0.17 # roughly 1 / sqrt(32), to match scale of kernel at init + pos_emb = self.param(f'pos_emb_{i}', build_init_fn(scale, shape), shape) + out_features.append(pos_emb + val) + else: + out_features.append(val) + out_features = jtu.tree_unflatten(tree_def, out_features) + return out_features + + +def make_hk_irnn_perm_spec(mlp_params): + """Tested on RNNLM_lm1bbytes_Patch32_IRNN128_Embed64.""" + # -1: vocab, 0: embed, 1: hidden + del mlp_params + perm_spec = { + 'embed': {'embeddings': (-1, 0)}, + 'irnn/linear': {'b': (1,), 'w': (0, 1)}, + 'irnn/linear_1': {'b': (1,), 'w': (1, 1)}, + 'linear': {'b': (-1,), 'w': (1, -1)}, + '~': {'initial_state_0': (-2, 1)}, + } + return perm_spec + + class MLPForOpt(nn.Module): """MLP for learned opt.""" hidden_channels: int out_channels: int num_layers: int + pos_emb: bool = False def setup(self): layers = [] - for _ in range(self.num_layers - 1): + for i in range(self.num_layers - 1): layers.append(nn.Dense(self.hidden_channels)) + if i == 0 and self.pos_emb: + layers.append(PosEmbConv()) layers.append(jax.nn.relu) layers.append(nn.Dense(self.out_channels)) self.mod = nn.Sequential(layers) @@ -173,38 +176,6 @@ def __call__(self, inp_features): return jtu.tree_map(self.mod, inp_features) -class NFNForOpt(nn.Module): - """NFN for learned opt.""" - - in_channels: int - hidden_channels: int - out_channels: int - num_layers: int - pos_enc: bool = True - hnet: bool = False - - def setup(self): - assert not (self.hnet and self.pos_enc), 'Only one of these can be on.' - in_channels, hidden_channels = self.in_channels, self.hidden_channels - layer_cls = lambda *args, **kwargs: nf_layers.NFLinearMlp( - *args, **kwargs, pe_enabled=self.pos_enc - ) - if self.hnet: - layer_cls = nf_layers.NFLinearMlpHNet - layers = [layer_cls(hidden_channels, in_channels), nf_layers.nf_relu] - for _ in range(self.num_layers - 2): - layers.append(layer_cls(hidden_channels, hidden_channels)) - layers.append(nf_layers.nf_relu) - layers.append(layer_cls(self.out_channels, hidden_channels)) - self.mod = nn.Sequential(layers) - - def __call__(self, inp_features): - # add batch dimension for nf layers - inp_features = nfu.tree_expand_dims(inp_features, 0) - out = flax_to_hk(self.mod(hk_to_flax(inp_features))[0]) - return nfu.tree_squeeze(out, 0) - - class UnivNFNForOpt(nn.Module): """Univeral NFN for learned opt.""" @@ -214,6 +185,7 @@ class UnivNFNForOpt(nn.Module): num_layers: int perm_spec: Any ptwise_init: bool = False + pos_emb: bool = False def setup(self): in_channels, hidden_channels = self.in_channels, self.hidden_channels @@ -224,10 +196,10 @@ def make_layer(out_chan, in_chan): else: return universal_layers.NFLinear(out_chan, in_chan, w_init='lecun') - layers = [ - make_layer(hidden_channels, in_channels), - universal_layers.nf_relu, - ] + layers = [make_layer(hidden_channels, in_channels)] + if self.pos_emb: + layers.append(PosEmbConv()) + layers.append(universal_layers.nf_relu) for _ in range(self.num_layers - 1): layers.append(make_layer(hidden_channels, hidden_channels)) layers.append(universal_layers.nf_relu) @@ -434,10 +406,14 @@ def norm_second_moment(p): class ResidualOptNFN(ResidualOpt): """NFN learning a residual on base optimizer.""" - def __init__(self, task, step_mult=0.1, out_mult=1e-4, ptwise_init=False): + def __init__( + self, task, step_mult=0.1, out_mult=1e-4, ptwise_init=False, pos_emb=False + ): example_params = task.init(jax.random.PRNGKey(0)) if 'conv2_d' in example_params: perm_spec = make_hk_cnn_perm_spec(example_params) + elif 'irnn/linear' in example_params: + perm_spec = make_hk_irnn_perm_spec(example_params) else: perm_spec = make_hk_perm_spec(example_params) network = UnivNFNForOpt( @@ -447,6 +423,7 @@ def __init__(self, task, step_mult=0.1, out_mult=1e-4, ptwise_init=False): num_layers=4, perm_spec=perm_spec, ptwise_init=ptwise_init, + pos_emb=pos_emb, ) super().__init__( network, example_params, step_mult=step_mult, out_mult=out_mult @@ -456,9 +433,11 @@ def __init__(self, task, step_mult=0.1, out_mult=1e-4, ptwise_init=False): @gin.configurable class ResidualOptMLP(ResidualOpt): - def __init__(self, task, step_mult=0.1, out_mult=1e-4): + def __init__(self, task, step_mult=0.1, out_mult=1e-4, pos_emb=False): example_params = task.init(jax.random.PRNGKey(0)) - network = MLPForOpt(hidden_channels=32, out_channels=1, num_layers=4) + network = MLPForOpt( + hidden_channels=32, out_channels=1, num_layers=4, pos_emb=pos_emb + ) super().__init__( network, example_params, step_mult=step_mult, out_mult=out_mult ) diff --git a/learned_optimization/research/univ_nfn/nfn/siren.py b/learned_optimization/research/univ_nfn/nfn/siren.py index 6e93e57..cdb8d01 100644 --- a/learned_optimization/research/univ_nfn/nfn/siren.py +++ b/learned_optimization/research/univ_nfn/nfn/siren.py @@ -19,7 +19,6 @@ from typing import Any, Callable, Optional, Tuple from flax import linen as nn -from flax import nn as fnn import jax from jax import lax import jax.numpy as jnp @@ -113,7 +112,7 @@ class ModulatedLayer(nn.Module): features: int = 32 is_first: bool = False synthesis_act: Callable = jnp.sin - modulator_act: Callable = fnn.relu + modulator_act: Callable = nn.relu precision: Any = None dtype: Any = jnp.float32 w0_first_layer: float = 30.0 @@ -196,7 +195,7 @@ class ModulatedSiren(nn.Module): output_dim: int = 3 num_layers: int = 5 synthesis_act: Callable = jnp.sin - modulator_act: Callable = fnn.relu + modulator_act: Callable = nn.relu final_activation: Callable = lambda x: x w0_first_layer: float = 30.0 dtype: Any = jnp.float32