From 7b61d6cc969e7b459781eacbb817af74536dc17f Mon Sep 17 00:00:00 2001
From: junyoung park
Date: Tue, 7 Nov 2023 15:49:06 +0900
Subject: [PATCH] [Feat] Implement MatNet

---
 rl4co/models/zoo/matnet/__init__.py |   0
 rl4co/models/zoo/matnet/decoder.py  |  52 +++++
 rl4co/models/zoo/matnet/encoder.py  | 309 ++++++++++++++++++++++++++++
 rl4co/models/zoo/matnet/model.py    |  39 ++++
 rl4co/models/zoo/matnet/policy.py   |  61 ++++++
 5 files changed, 461 insertions(+)
 create mode 100644 rl4co/models/zoo/matnet/__init__.py
 create mode 100644 rl4co/models/zoo/matnet/decoder.py
 create mode 100644 rl4co/models/zoo/matnet/encoder.py
 create mode 100644 rl4co/models/zoo/matnet/model.py
 create mode 100644 rl4co/models/zoo/matnet/policy.py

diff --git a/rl4co/models/zoo/matnet/__init__.py b/rl4co/models/zoo/matnet/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/rl4co/models/zoo/matnet/decoder.py b/rl4co/models/zoo/matnet/decoder.py
new file mode 100644
index 00000000..e703bf5c
--- /dev/null
+++ b/rl4co/models/zoo/matnet/decoder.py
@@ -0,0 +1,52 @@
+from dataclasses import dataclass
+from typing import Tuple, Union
+
+import torch
+import torch.nn as nn
+from einops import rearrange
+from rl4co.models.zoo.common.autoregressive.decoder import AutoregressiveDecoder
+from rl4co.utils.ops import batchify, get_num_starts, select_start_nodes, unbatchify
+from tensordict import TensorDict
+from torch import Tensor
+
+
+@dataclass
+class PrecomputedCache:
+    node_embeddings: Tensor
+    graph_context: Union[Tensor, float]
+    glimpse_key: Tensor
+    glimpse_val: Tensor
+    logit_key: Tensor
+
+
+class MatNetDecoder(AutoregressiveDecoder):
+    def _precompute_cache(
+        self, embeddings: Tuple[Tensor, Tensor], num_starts: int = 0, td: TensorDict = None
+    ):
+        col_emb, row_emb = embeddings
+        (
+            glimpse_key_fixed,
+            glimpse_val_fixed,
+            logit_key,
+        ) = self.project_node_embeddings(
+            col_emb
+        ).chunk(3, dim=-1)
+
+        # Optionally disable the graph context from the initial embedding as done in POMO
+        if self.use_graph_context:
+            graph_context = unbatchify(
+                batchify(self.project_fixed_context(col_emb.mean(1)), num_starts),
+                num_starts,
+            )
+        else:
+            graph_context = 0
+
+        # Organize in a dataclass for easy access
+        return PrecomputedCache(
+            node_embeddings=row_emb,
+            graph_context=graph_context,
+            glimpse_key=glimpse_key_fixed,
+            glimpse_val=glimpse_val_fixed,
+            # logit_key=col_emb,
+            logit_key=logit_key,
+        )
\ No newline at end of file
diff --git a/rl4co/models/zoo/matnet/encoder.py b/rl4co/models/zoo/matnet/encoder.py
new file mode 100644
index 00000000..273baa31
--- /dev/null
+++ b/rl4co/models/zoo/matnet/encoder.py
@@ -0,0 +1,309 @@
+import math
+from typing import Optional
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from einops import rearrange
+from rl4co.models.nn.ops import Normalization
+from tensordict import TensorDict
+
+
+class MatNetCrossMHA(nn.Module):
+    def __init__(
+        self,
+        embedding_dim: int,
+        num_heads: int,
+        bias: bool = True,
+        mixer_hidden_dim: int = 16,
+        mix1_init: float = (1 / 2) ** (1 / 2),
+        mix2_init: float = (1 / 16) ** (1 / 2),
+    ):
+        super().__init__()
+        self.embedding_dim = embedding_dim
+        self.num_heads = num_heads
+        assert (
+            self.embedding_dim % num_heads == 0
+        ), "embedding_dim must be divisible by num_heads"
+        self.head_dim = self.embedding_dim // num_heads
+
+        self.Wq = nn.Linear(embedding_dim, embedding_dim, bias=bias)
+        self.Wkv = nn.Linear(embedding_dim, 2 * embedding_dim, bias=bias)
+
+        # Score mixer
+        # Taken from the official MatNet implementation
+        # https://github.com/yd-kwon/MatNet/blob/main/ATSP/ATSP_MatNet/ATSPModel_LIB.py#L72
+        mix_W1 = torch.distributions.Uniform(low=-mix1_init, high=mix1_init).sample(
+            (num_heads, 2, mixer_hidden_dim)
+        )
+        mix_b1 = torch.distributions.Uniform(low=-mix1_init, high=mix1_init).sample(
+            (num_heads, mixer_hidden_dim)
+        )
+        self.mix_W1 = nn.Parameter(mix_W1)
+        self.mix_b1 = nn.Parameter(mix_b1)
+
+        mix_W2 = torch.distributions.Uniform(low=-mix2_init, high=mix2_init).sample(
+            (num_heads, mixer_hidden_dim, 1)
+        )
+        mix_b2 = torch.distributions.Uniform(low=-mix2_init, high=mix2_init).sample(
+            (num_heads, 1)
+        )
+        self.mix_W2 = nn.Parameter(mix_W2)
+        self.mix_b2 = nn.Parameter(mix_b2)
+
+        self.out_proj = nn.Linear(embedding_dim, embedding_dim, bias=bias)
+
+    def forward(self, q_input, kv_input, dmat):
+        """
+        Args:
+            q_input (Tensor): [b, m, d]
+            kv_input (Tensor): [b, n, d]
+            dmat (Tensor): [b, m, n]
+
+        Returns:
+            Tensor: [b, m, d]
+        """
+
+        b, m, n = dmat.shape
+
+        q = rearrange(
+            self.Wq(q_input), "b m (h d) -> b h m d", h=self.num_heads
+        )  # [b, h, m, d]
+        k, v = rearrange(
+            self.Wkv(kv_input), "b n (two h d) -> two b h n d", two=2, h=self.num_heads
+        ).unbind(
+            dim=0
+        )  # [b, h, n, d]
+
+        scale = math.sqrt(q.size(-1))  # scale factor
+        attn_scores = torch.matmul(q, k.transpose(2, 3)) / scale  # [b, h, m, n]
+        mix_attn_scores = torch.stack(
+            [attn_scores, dmat[:, None, :, :].expand(b, self.num_heads, m, n)], dim=-1
+        )  # [b, h, m, n, 2]
+
+        mix_attn_scores = (
+            (
+                torch.matmul(
+                    F.relu(
+                        torch.matmul(mix_attn_scores.transpose(1, 2), self.mix_W1)
+                        + self.mix_b1[None, None, :, None, :]
+                    ),
+                    self.mix_W2,
+                )
+                + self.mix_b2[None, None, :, None, :]
+            )
+            .transpose(1, 2)
+            .squeeze(-1)
+        )  # [b, h, m, n]
+
+        attn_probs = F.softmax(mix_attn_scores, dim=-1)
+        out = torch.matmul(attn_probs, v)
+        return self.out_proj(rearrange(out, "b h s d -> b s (h d)"))
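+
+
+# Illustrative shape walk-through of the score mixer above (the numbers are
+# arbitrary examples, not fixed by the implementation): with b=2, h=16, m=n=20 and
+# mixer_hidden_dim=16, the stacked tensor [attn_scores, dmat] has shape
+# [b, h, m, n, 2]; a per-head two-layer MLP (mix_W1: [h, 2, 16], mix_b1: [h, 16],
+# mix_W2: [h, 16, 1], mix_b2: [h, 1]) maps it back to scalar logits of shape
+# [b, h, m, n], so the cost matrix modulates the attention weights directly
+# instead of entering as an additive bias.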
+
+
+class MatNetMHA(nn.Module):
+    def __init__(self, embedding_dim: int, num_heads: int, bias: bool = True):
+        super().__init__()
+        self.row_encoding_block = MatNetCrossMHA(embedding_dim, num_heads, bias)
+        self.col_encoding_block = MatNetCrossMHA(embedding_dim, num_heads, bias)
+
+    def forward(self, row_emb, col_emb, dmat):
+        """
+        Args:
+            row_emb (Tensor): [b, m, d]
+            col_emb (Tensor): [b, n, d]
+            dmat (Tensor): [b, m, n]
+
+        Returns:
+            Updated row_emb (Tensor): [b, m, d]
+            Updated col_emb (Tensor): [b, n, d]
+        """
+
+        updated_row_emb = self.row_encoding_block(row_emb, col_emb, dmat)
+        updated_col_emb = self.col_encoding_block(
+            col_emb, row_emb, dmat.transpose(-2, -1)
+        )
+        return updated_row_emb, updated_col_emb
+
+
+class MatNetMHALayer(nn.Module):
+    def __init__(
+        self,
+        embedding_dim: int,
+        num_heads: int,
+        bias: bool = True,
+        feed_forward_hidden: int = 512,
+        normalization: Optional[str] = "instance",
+    ):
+        super().__init__()
+        self.MHA = MatNetMHA(embedding_dim, num_heads, bias)
+
+        self.F_a = nn.ModuleDict(
+            {
+                "norm1": Normalization(embedding_dim, normalization),
+                "ffn": nn.Sequential(
+                    nn.Linear(embedding_dim, feed_forward_hidden),
+                    nn.ReLU(),
+                    nn.Linear(feed_forward_hidden, embedding_dim),
+                ),
+                "norm2": Normalization(embedding_dim, normalization),
+            }
+        )
+
+        self.F_b = nn.ModuleDict(
+            {
+                "norm1": Normalization(embedding_dim, normalization),
+                "ffn": nn.Sequential(
+                    nn.Linear(embedding_dim, feed_forward_hidden),
+                    nn.ReLU(),
+                    nn.Linear(feed_forward_hidden, embedding_dim),
+                ),
+                "norm2": Normalization(embedding_dim, normalization),
+            }
+        )
+
+    def forward(self, row_emb, col_emb, dmat):
+        """
+        Args:
+            row_emb (Tensor): [b, m, d]
+            col_emb (Tensor): [b, n, d]
+            dmat (Tensor): [b, m, n]
+
+        Returns:
+            Updated row_emb (Tensor): [b, m, d]
+            Updated col_emb (Tensor): [b, n, d]
+        """
+
+        row_emb_out, col_emb_out = self.MHA(row_emb, col_emb, dmat)
+
+        row_emb_out = self.F_a["norm1"](row_emb + row_emb_out)
+        row_emb_out = self.F_a["norm2"](row_emb_out + self.F_a["ffn"](row_emb_out))
+
+        col_emb_out = self.F_b["norm1"](col_emb + col_emb_out)
+        col_emb_out = self.F_b["norm2"](col_emb_out + self.F_b["ffn"](col_emb_out))
+        return row_emb_out, col_emb_out
+
+
+class MatNetMHANetwork(nn.Module):
+    def __init__(
+        self,
+        embedding_dim: int = 128,
+        num_heads: int = 8,
+        num_layers: int = 3,
+        normalization: str = "batch",
+        feed_forward_hidden: int = 512,
+    ):
+        super().__init__()
+        self.layers = nn.ModuleList(
+            [
+                MatNetMHALayer(
+                    num_heads=num_heads,
+                    embedding_dim=embedding_dim,
+                    feed_forward_hidden=feed_forward_hidden,
+                    normalization=normalization,
+                )
+                for _ in range(num_layers)
+            ]
+        )
+
+    def forward(self, row_emb, col_emb, dmat):
+        """
+        Args:
+            row_emb (Tensor): [b, m, d]
+            col_emb (Tensor): [b, n, d]
+            dmat (Tensor): [b, m, n]
+
+        Returns:
+            Updated row_emb (Tensor): [b, m, d]
+            Updated col_emb (Tensor): [b, n, d]
+        """
+
+        for layer in self.layers:
+            row_emb, col_emb = layer(row_emb, col_emb, dmat)
+        return row_emb, col_emb
+
+
+class MatNetATSPInitEmbedding(nn.Module):
+    """
+    Prepares the initial row and column embeddings for ATSP.
+
+    Reference:
+    https://github.com/yd-kwon/MatNet/blob/782698b60979effe2e7b61283cca155b7cdb727f/ATSP/ATSP_MatNet/ATSPModel.py#L51
+    """
+
+    def __init__(self, embedding_dim: int, mode: str = "RandomOneHot") -> None:
+        super().__init__()
+
+        self.embedding_dim = embedding_dim
+        assert mode in {
+            "RandomOneHot",
+            "Random",
+        }, "mode must be one of ['RandomOneHot', 'Random']"
+        self.mode = mode
+
+        # NOTE: these projections are currently unused in `forward`
+        self.dmat_proj = nn.Linear(1, 2 * embedding_dim, bias=False)
+        self.row_proj = nn.Linear(embedding_dim * 4, embedding_dim, bias=False)
+        self.col_proj = nn.Linear(embedding_dim * 4, embedding_dim, bias=False)
+
+    def forward(self, td: TensorDict):
+        dmat = td["cost_matrix"]  # [b, n, n]
+        b, n, _ = dmat.shape
+
+        row_emb = torch.zeros(b, n, self.embedding_dim, device=dmat.device)
+
+        if self.mode == "RandomOneHot":
+            # MatNet uses one-hot encoding for column embeddings
+            # https://github.com/yd-kwon/MatNet/blob/782698b60979effe2e7b61283cca155b7cdb727f/ATSP/ATSP_MatNet/ATSPModel.py#L60
+            col_emb = torch.zeros(b, n, self.embedding_dim, device=dmat.device)
+            # create the scatter indices on the same device as the embeddings
+            rand = torch.rand(b, n, device=dmat.device)
+            rand_idx = rand.argsort(dim=1)
+            b_idx = torch.arange(b, device=dmat.device)[:, None].expand(b, n)
+            n_idx = torch.arange(n, device=dmat.device)[None, :].expand(b, n)
+            col_emb[b_idx, n_idx, rand_idx] = 1.0
+
+        elif self.mode == "Random":
+            col_emb = torch.rand(b, n, self.embedding_dim, device=dmat.device)
+        else:
+            raise NotImplementedError
+
+        return row_emb, col_emb, dmat
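+
+
+# Illustrative example of the "RandomOneHot" mode above (hypothetical values,
+# n=3, embedding_dim >= 3): every node of an instance receives a one-hot vector
+# at a distinct random position, e.g.
+#   col_emb[0, :, :3] = [[0., 1., 0.],
+#                        [1., 0., 0.],
+#                        [0., 0., 1.]]
+# while the row embeddings start as all zeros, following the official MatNet code.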
+
+
+class MatNetEncoder(nn.Module):
+    def __init__(
+        self,
+        embedding_dim: int = 256,
+        num_heads: int = 16,
+        num_layers: int = 5,
+        normalization: str = "instance",
+        feed_forward_hidden: int = 512,
+        init_embedding: nn.Module = None,
+        init_embedding_kwargs: dict = None,
+    ):
+        super().__init__()
+
+        if init_embedding is None:
+            init_embedding_kwargs = init_embedding_kwargs or {}
+            init_embedding = MatNetATSPInitEmbedding(
+                embedding_dim, **init_embedding_kwargs
+            )
+
+        self.init_embedding = init_embedding
+        self.net = MatNetMHANetwork(
+            embedding_dim=embedding_dim,
+            num_heads=num_heads,
+            num_layers=num_layers,
+            normalization=normalization,
+            feed_forward_hidden=feed_forward_hidden,
+        )
+
+    def forward(self, td):
+        row_emb, col_emb, dmat = self.init_embedding(td)
+        row_emb, col_emb = self.net(row_emb, col_emb, dmat)
+
+        embedding = (row_emb, col_emb)
+        init_embedding = None
+        return embedding, init_embedding  # match output signature for the AR policy class
diff --git a/rl4co/models/zoo/matnet/model.py b/rl4co/models/zoo/matnet/model.py
new file mode 100644
index 00000000..1af9cace
--- /dev/null
+++ b/rl4co/models/zoo/matnet/model.py
@@ -0,0 +1,39 @@
+from typing import Union
+
+import torch.nn as nn
+
+from rl4co.envs.common.base import RL4COEnvBase
+from rl4co.models.zoo.matnet.policy import MatNetPolicy
+from rl4co.models.zoo.pomo.model import POMO
+
+
+class MatNet(POMO):
+    def __init__(
+        self,
+        env: RL4COEnvBase,
+        policy: Union[nn.Module, MatNetPolicy] = None,
+        optimizer_kwargs: dict = {"lr": 4 * 1e-4, "weight_decay": 1e-6},
+        lr_scheduler: str = "MultiStepLR",
+        lr_scheduler_kwargs: dict = {"milestones": [2001, 2101], "gamma": 0.1},
+        use_dihedral_8: bool = False,
+        num_starts: int = None,
+        train_data_size: int = 10_000,
+        batch_size: int = 200,
+        policy_params: dict = {},
+        model_params: dict = {},
+    ):
+        if policy is None:
+            policy = MatNetPolicy(env_name=env.name, **policy_params)
+
+        super(MatNet, self).__init__(
+            env=env,
+            policy=policy,
+            optimizer_kwargs=optimizer_kwargs,
+            lr_scheduler=lr_scheduler,
+            lr_scheduler_kwargs=lr_scheduler_kwargs,
+            use_dihedral_8=use_dihedral_8,
+            num_starts=num_starts,
+            train_data_size=train_data_size,
+            batch_size=batch_size,
+            **model_params,
+        )
diff --git a/rl4co/models/zoo/matnet/policy.py b/rl4co/models/zoo/matnet/policy.py
new file mode 100644
index 00000000..8b4e1761
--- /dev/null
+++ b/rl4co/models/zoo/matnet/policy.py
@@ -0,0 +1,61 @@
+from rl4co.models.zoo.common.autoregressive import AutoregressivePolicy
+from rl4co.models.zoo.matnet.encoder import MatNetEncoder
+from rl4co.models.zoo.matnet.decoder import MatNetDecoder
+from rl4co.utils.pylogger import get_pylogger
+
+log = get_pylogger(__name__)
+
+
+class MatNetPolicy(AutoregressivePolicy):
+    """MatNet Policy from Kwon et al., 2021.
+    Reference: https://arxiv.org/abs/2106.11113
+
+    Warning:
+        This implementation is under development and subject to change.
+
+    Args:
+        env_name: Name of the environment used to initialize embeddings
+        embedding_dim: Dimension of the node embeddings
+        num_encoder_layers: Number of layers in the encoder
+        num_heads: Number of heads in the attention layers
+        normalization: Normalization type in the attention layers
+        **kwargs: keyword arguments passed to the `AutoregressivePolicy`
+
+    Default parameters are adopted from the original implementation.
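+
+    Usage sketch (illustrative only; assumes the ATSP environment shipped with
+    rl4co, and argument names may differ across versions):
+
+        >>> from rl4co.envs import ATSPEnv
+        >>> env = ATSPEnv(num_loc=20)
+        >>> policy = MatNetPolicy(env_name=env.name)
+        >>> td = env.reset(batch_size=[8])
+        >>> out = policy(td, env, phase="test", decode_type="greedy")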
+ """ + + def __init__( + self, + env_name: str, + embedding_dim: int = 256, + num_encoder_layers: int = 5, + num_heads: int = 16, + normalization: str = "instance", + init_embedding_kwargs: dict = {"mode": "RandomOneHot"}, + use_graph_context: bool = False, + **kwargs, + ): + if env_name not in ["atsp"]: + log.error(f"env_name {env_name} is not originally implemented in MatNet") + + super(MatNetPolicy, self).__init__( + env_name=env_name, + encoder=MatNetEncoder( + embedding_dim=embedding_dim, + num_heads=num_heads, + num_layers=num_encoder_layers, + normalization=normalization, + init_embedding_kwargs=init_embedding_kwargs, + ), + decoder=MatNetDecoder( + env_name=env_name, + embedding_dim=embedding_dim, + num_heads=num_heads, + use_graph_context=use_graph_context, + ), + embedding_dim=embedding_dim, + num_encoder_layers=num_encoder_layers, + num_heads=num_heads, + normalization=normalization, + **kwargs, + )