
Commit

No public description
PiperOrigin-RevId: 593160972
learned_optimization authors committed Dec 22, 2023
1 parent 35120a4 commit 597524c
Showing 2 changed files with 58 additions and 80 deletions.
133 changes: 56 additions & 77 deletions learned_optimization/research/univ_nfn/learned_opt/learned_opts.py
@@ -35,7 +35,6 @@
from learned_optimization.learned_optimizers import base as lopt_base
from learned_optimization.learned_optimizers import common
from learned_optimization.optimizers import base as opt_base
from learned_optimization.research.univ_nfn.nfn import ff_layers as nf_layers
from learned_optimization.research.univ_nfn.nfn import universal_layers
from learned_optimization.research.univ_nfn.nfn import utils as nfu

@@ -89,42 +88,6 @@ class SimpleOptState(flax.struct.PyTreeNode):
state: Any


def flax_to_hk(input_dict):
"""Maps flax parameter structure to haiku parameter structure.
Example:
>>> input_dict = {
... 'params': {
... 'Dense_0': {'kernel': W0, 'bias': b0},
... 'Dense_1': {'kernel': W1, 'bias': b1}
... }
... }
>>> flax_to_hk(input_dict)
{'mlp/~/linear_0': {'w': W0, 'b': b0}, 'mlp/~/linear_1': {'w': W1, 'b': b1}}
"""
params = input_dict.get('params', {})
output_dict = {}
for i, (_, layer_data) in enumerate(params.items()):
# Constructing new key and sub-dictionary format
new_key = f'mlp/~/linear_{i}'
new_data = {'w': layer_data['kernel'], 'b': layer_data['bias']}
output_dict[new_key] = new_data
return output_dict


def hk_to_flax(input_dict):
"""Maps haiku parameter structure to flax parameter structure."""
output_dict = {'params': {}}

for key, layer_data in input_dict.items():
# Extracting the layer number from the key
layer_num = key.split('_')[-1] # Get the part after the last '_'
original_layer_name = f'Dense_{layer_num}'
original_data = {'kernel': layer_data['w'], 'bias': layer_data['b']}
output_dict['params'][original_layer_name] = original_data
return output_dict


def make_hk_perm_spec(mlp_params):
"""Produces perm spec for a haiku mlp."""
perm_spec = {}
@@ -153,17 +116,57 @@ def make_hk_cnn_perm_spec(mlp_params):
return perm_spec


def build_init_fn(scale, shape):
return lambda rng, _shape: scale * jax.random.normal(rng, shape)


class PosEmbConv(nn.Module):
"""Add learned position embeddings for spatial dims of conv input."""

@nn.compact
def __call__(self, inp_features):
features, tree_def = jtu.tree_flatten(inp_features)
out_features = []
for i, val in enumerate(features):
if len(val.shape) == 5: # conv2d filter: HxWxC1xC2xC
shape = (val.shape[0], val.shape[1], 1, 1, val.shape[-1])
scale = 0.17 # roughly 1 / sqrt(32), to match scale of kernel at init
pos_emb = self.param(f'pos_emb_{i}', build_init_fn(scale, shape), shape)
out_features.append(pos_emb + val)
else:
out_features.append(val)
out_features = jtu.tree_unflatten(tree_def, out_features)
return out_features
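
A minimal usage sketch for PosEmbConv (illustrative, not part of this commit; the feature shapes are placeholders): only leaves shaped like conv2d filter features (HxWxC1xC2xC) receive a learned embedding, every other leaf passes through unchanged.

# Illustrative sketch: apply PosEmbConv to a tiny pytree of per-parameter features.
import jax
import jax.numpy as jnp

features = {
    'conv2_d': {'w': jnp.zeros((3, 3, 8, 16, 4))},  # 5-D conv filter features: HxWxC1xC2xC
    'linear': {'w': jnp.zeros((16, 10, 4))},  # non-conv leaf, passed through unchanged
}
mod = PosEmbConv()
variables = mod.init(jax.random.PRNGKey(0), features)  # creates 'pos_emb_0' of shape (3, 3, 1, 1, 4)
out = mod.apply(variables, features)  # embedding is broadcast over the C1 and C2 axes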


def make_hk_irnn_perm_spec(mlp_params):
"""Tested on RNNLM_lm1bbytes_Patch32_IRNN128_Embed64."""
# -1: vocab, 0: embed, 1: hidden
del mlp_params
perm_spec = {
'embed': {'embeddings': (-1, 0)},
'irnn/linear': {'b': (1,), 'w': (0, 1)},
'irnn/linear_1': {'b': (1,), 'w': (1, 1)},
'linear': {'b': (-1,), 'w': (1, -1)},
'~': {'initial_state_0': (-2, 1)},
}
return perm_spec


class MLPForOpt(nn.Module):
"""MLP for learned opt."""

hidden_channels: int
out_channels: int
num_layers: int
pos_emb: bool = False

def setup(self):
layers = []
for _ in range(self.num_layers - 1):
for i in range(self.num_layers - 1):
layers.append(nn.Dense(self.hidden_channels))
if i == 0 and self.pos_emb:
layers.append(PosEmbConv())
layers.append(jax.nn.relu)
layers.append(nn.Dense(self.out_channels))
self.mod = nn.Sequential(layers)
@@ -173,38 +176,6 @@ def __call__(self, inp_features):
return jtu.tree_map(self.mod, inp_features)
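
For reference, a hedged sketch of how MLPForOpt is applied (not part of the diff; the feature tree and channel count are made up): one shared MLP is mapped over every leaf of the per-parameter feature tree and acts on the trailing channel axis.

# Illustrative sketch: MLPForOpt maps one shared MLP over each leaf's last axis.
import jax
import jax.numpy as jnp

feats = {'w': jnp.zeros((8, 4, 19)), 'b': jnp.zeros((4, 19))}  # trailing dim = input feature channels (placeholder)
net = MLPForOpt(hidden_channels=32, out_channels=1, num_layers=4)
variables = net.init(jax.random.PRNGKey(0), feats)
updates = net.apply(variables, feats)  # same tree structure, trailing dim becomes out_channels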


class NFNForOpt(nn.Module):
"""NFN for learned opt."""

in_channels: int
hidden_channels: int
out_channels: int
num_layers: int
pos_enc: bool = True
hnet: bool = False

def setup(self):
assert not (self.hnet and self.pos_enc), 'Only one of these can be on.'
in_channels, hidden_channels = self.in_channels, self.hidden_channels
layer_cls = lambda *args, **kwargs: nf_layers.NFLinearMlp(
*args, **kwargs, pe_enabled=self.pos_enc
)
if self.hnet:
layer_cls = nf_layers.NFLinearMlpHNet
layers = [layer_cls(hidden_channels, in_channels), nf_layers.nf_relu]
for _ in range(self.num_layers - 2):
layers.append(layer_cls(hidden_channels, hidden_channels))
layers.append(nf_layers.nf_relu)
layers.append(layer_cls(self.out_channels, hidden_channels))
self.mod = nn.Sequential(layers)

def __call__(self, inp_features):
# add batch dimension for nf layers
inp_features = nfu.tree_expand_dims(inp_features, 0)
out = flax_to_hk(self.mod(hk_to_flax(inp_features))[0])
return nfu.tree_squeeze(out, 0)


class UnivNFNForOpt(nn.Module):
"""Univeral NFN for learned opt."""

@@ -214,6 +185,7 @@ class UnivNFNForOpt(nn.Module):
num_layers: int
perm_spec: Any
ptwise_init: bool = False
pos_emb: bool = False

def setup(self):
in_channels, hidden_channels = self.in_channels, self.hidden_channels
@@ -224,10 +196,10 @@ def make_layer(out_chan, in_chan):
else:
return universal_layers.NFLinear(out_chan, in_chan, w_init='lecun')

layers = [
make_layer(hidden_channels, in_channels),
universal_layers.nf_relu,
]
layers = [make_layer(hidden_channels, in_channels)]
if self.pos_emb:
layers.append(PosEmbConv())
layers.append(universal_layers.nf_relu)
for _ in range(self.num_layers - 1):
layers.append(make_layer(hidden_channels, hidden_channels))
layers.append(universal_layers.nf_relu)
@@ -434,10 +406,14 @@ def norm_second_moment(p):
class ResidualOptNFN(ResidualOpt):
"""NFN learning a residual on base optimizer."""

def __init__(self, task, step_mult=0.1, out_mult=1e-4, ptwise_init=False):
def __init__(
self, task, step_mult=0.1, out_mult=1e-4, ptwise_init=False, pos_emb=False
):
example_params = task.init(jax.random.PRNGKey(0))
if 'conv2_d' in example_params:
perm_spec = make_hk_cnn_perm_spec(example_params)
elif 'irnn/linear' in example_params:
perm_spec = make_hk_irnn_perm_spec(example_params)
else:
perm_spec = make_hk_perm_spec(example_params)
network = UnivNFNForOpt(
Expand All @@ -447,6 +423,7 @@ def __init__(self, task, step_mult=0.1, out_mult=1e-4, ptwise_init=False):
num_layers=4,
perm_spec=perm_spec,
ptwise_init=ptwise_init,
pos_emb=pos_emb,
)
super().__init__(
network, example_params, step_mult=step_mult, out_mult=out_mult
@@ -456,9 +433,11 @@ def __init__(self, task, step_mult=0.1, out_mult=1e-4, ptwise_init=False):
@gin.configurable
class ResidualOptMLP(ResidualOpt):

def __init__(self, task, step_mult=0.1, out_mult=1e-4):
def __init__(self, task, step_mult=0.1, out_mult=1e-4, pos_emb=False):
example_params = task.init(jax.random.PRNGKey(0))
network = MLPForOpt(hidden_channels=32, out_channels=1, num_layers=4)
network = MLPForOpt(
hidden_channels=32, out_channels=1, num_layers=4, pos_emb=pos_emb
)
super().__init__(
network, example_params, step_mult=step_mult, out_mult=out_mult
)
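
A hedged usage sketch for the two residual optimizers (not part of the diff; `task` is a placeholder for any learned_optimization task whose init returns haiku-style MLP, CNN, or IRNN parameters, and the init/opt_fn calls assume the standard lopt_base.LearnedOptimizer interface):

# Illustrative sketch; `task` is a placeholder learned_optimization task.
import jax

lopt = ResidualOptNFN(task, pos_emb=True)  # picks the IRNN perm spec when 'irnn/linear' is in the params
mlp_lopt = ResidualOptMLP(task, pos_emb=True)
theta = lopt.init(jax.random.PRNGKey(0))  # meta-parameters (assumes LearnedOptimizer interface)
opt = lopt.opt_fn(theta)  # a regular optimizer used to step the task's parameters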
5 changes: 2 additions & 3 deletions learned_optimization/research/univ_nfn/nfn/siren.py
@@ -19,7 +19,6 @@
from typing import Any, Callable, Optional, Tuple

from flax import linen as nn
from flax import nn as fnn
import jax
from jax import lax
import jax.numpy as jnp
@@ -113,7 +112,7 @@ class ModulatedLayer(nn.Module):
features: int = 32
is_first: bool = False
synthesis_act: Callable = jnp.sin
modulator_act: Callable = fnn.relu
modulator_act: Callable = nn.relu
precision: Any = None
dtype: Any = jnp.float32
w0_first_layer: float = 30.0
@@ -196,7 +195,7 @@ class ModulatedSiren(nn.Module):
output_dim: int = 3
num_layers: int = 5
synthesis_act: Callable = jnp.sin
modulator_act: Callable = fnn.relu
modulator_act: Callable = nn.relu
final_activation: Callable = lambda x: x
w0_first_layer: float = 30.0
dtype: Any = jnp.float32
