Added MLP; FIX: Respect bias argument
kunaldahiya committed Sep 9, 2021
1 parent 81dedfb commit eae0f47
Showing 4 changed files with 92 additions and 8 deletions.
6 changes: 6 additions & 0 deletions deepxml/libs/parameters.py
@@ -329,6 +329,12 @@ def _construct(self):
'--validate',
action='store_true',
help='Validate or just train')
self.parser.add_argument(
'--bias',
action='store',
default=True,
type=bool,
help='Use bias term or not!')
self.parser.add_argument(
'--shuffle',
action='store',
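A note on how the new flag parses (a standalone sketch of standard argparse behavior, not part of the commit): with type=bool, argparse calls bool() on the raw command-line string, and bool() of any non-empty string, including the literal string 'False', is True, so the option yields False only for the empty string or an explicit default change.

# Standalone sketch (not part of the commit) of how the --bias flag parses.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--bias', action='store', default=True, type=bool,
                    help='Use bias term or not!')

print(parser.parse_args([]).bias)                    # True  (default)
print(parser.parse_args(['--bias', 'False']).bias)   # True  (bool('False') is True)
print(parser.parse_args(['--bias', '']).bias)        # False (empty string is the only falsy input)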
77 changes: 77 additions & 0 deletions deepxml/models/mlp.py
@@ -0,0 +1,77 @@
import torch
import torch.nn as nn


__author__ = 'KD'


class MLP(nn.Module):
"""
A multi-layer perceptron with flexibility for non-linearity
* no non-linearity after the last layer
* support for 2D or 3D inputs
Parameters:
-----------
input_size: int
input size of embeddings
hidden_size: int or list of ints or str (comma separated)
e.g., 512: a single hidden layer with 512 neurons
"512": a single hidden layer with 512 neurons
"512,300": 512 -> nnl -> 300
[512, 300]: 512 -> nnl -> 300
dimensionality of layers in MLP
nnl: str, optional, default='relu'
which non-linearity to use
device: str, default="cuda:0"
keep on this device
"""
def __init__(self, input_size, hidden_size, nnl='relu', device="cuda:0"):
super(MLP, self).__init__()
hidden_size = self.parse_hidden_size(hidden_size)
assert len(hidden_size) >= 1, "Should contain at least 1 hidden layer"
hidden_size = [input_size] + hidden_size
self.device = torch.device(device)
layers = []
for i, (i_s, o_s) in enumerate(zip(hidden_size[:-1], hidden_size[1:])):
layers.append(nn.Linear(i_s, o_s, bias=True))
if i < len(hidden_size) - 2:
layers.append(self._get_nnl(nnl))
self.transform = torch.nn.Sequential(*layers)

def parse_hidden_size(self, hidden_size):
if isinstance(hidden_size, int):
return [hidden_size]
elif isinstance(hidden_size, str):
_hidden_size = []
for item in hidden_size.split(","):
_hidden_size.append(int(item))
return _hidden_size
elif isinstance(hidden_size, list):
return hidden_size
else:
raise NotImplementedError("hidden_size must be an int, str, or list")

def _get_nnl(self, nnl):
if nnl == 'sigmoid':
return torch.nn.Sigmoid()
elif nnl == 'relu':
return torch.nn.ReLU()
elif nnl == 'gelu':
return torch.nn.GELU()
elif nnl == 'tanh':
return torch.nn.Tanh()
else:
raise NotImplementedError(f"{nnl} not implemented!")

def forward(self, x):
return self.transform(x)

def to(self):
"""Transfer to device
"""
super().to(self.device)

@property
def sparse(self):
return False
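
For context, a minimal usage sketch of the MLP class added above (assuming the file is importable as models.mlp; the sizes and device here are illustrative, not taken from the commit):

import torch
from models.mlp import MLP

# These two describe the same 300 -> 512 -> ReLU -> 300 stack:
net = MLP(input_size=300, hidden_size=[512, 300], nnl='relu', device="cpu")
net = MLP(input_size=300, hidden_size="512,300", nnl='relu', device="cpu")

# A single int gives one hidden layer (here 300 -> 512, no non-linearity appended):
single = MLP(input_size=300, hidden_size=512, device="cpu")

x2d = torch.randn(8, 300)        # (batch, features)
x3d = torch.randn(8, 4, 300)     # nn.Linear acts on the last dimension, so 3D inputs work too
print(net(x2d).shape)            # torch.Size([8, 300])
print(net(x3d).shape)            # torch.Size([8, 4, 300])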
11 changes: 6 additions & 5 deletions deepxml/models/network.py
@@ -1,6 +1,5 @@
import torch
import torch.nn as nn
import numpy as np
import math
import os
import models.transform_layer as transform_layer
@@ -154,6 +153,7 @@ def __init__(self, params):
trans_config_coarse = transform_config_dict['transform_coarse']
self.representation_dims = int(
transform_config_dict['representation_dims'])
self._bias = params.bias
super(DeepXMLf, self).__init__(trans_config_coarse)
if params.freeze_intermediate:
print("Freezing intermediate model parameters!")
@@ -226,7 +226,7 @@ def forward(self, batch_data, bypass_coarse=False):

def _construct_classifier(self):
if self.num_clf_partitions > 1: # Run the distributed version
_bias = [True for _ in range(self.num_clf_partitions)]
_bias = [self._bias for _ in range(self.num_clf_partitions)]
_clf_devices = ["cuda:{}".format(
idx) for idx in range(self.num_clf_partitions)]
return linear_layer.ParallelLinear(
@@ -239,7 +239,7 @@ def _construct_classifier(self):
return linear_layer.Linear(
input_size=self.representation_dims,
output_size=self.num_labels, # last one is padding index
bias=True
bias=self._bias
)

def get_token_embeddings(self):
@@ -297,6 +297,7 @@ def __init__(self, params):
trans_config_coarse = transform_config_dict['transform_coarse']
self.representation_dims = int(
transform_config_dict['representation_dims'])
self._bias = params.bias
super(DeepXMLs, self).__init__(trans_config_coarse)
if params.freeze_intermediate:
print("Freezing intermediate model parameters!")
@@ -383,7 +384,7 @@ def _construct_classifier(self):
# last one is padding index for each partition
_num_labels = self.num_labels + offset
_padding_idx = [None for _ in range(self.num_clf_partitions)]
_bias = [True for _ in range(self.num_clf_partitions)]
_bias = [self._bias for _ in range(self.num_clf_partitions)]
_clf_devices = ["cuda:{}".format(
idx) for idx in range(self.num_clf_partitions)]
return linear_layer.ParallelSparseLinear(
@@ -399,7 +400,7 @@ def _construct_classifier(self):
input_size=self.representation_dims,
output_size=self.num_labels + offset,
padding_idx=self.label_padding_index,
bias=True)
bias=self._bias)

def to(self):
"""Send layers to respective devices
6 changes: 3 additions & 3 deletions deepxml/models/transform_layer.py
@@ -1,10 +1,9 @@
import sys
import re
import torch.nn as nn
import models.residual_layer as residual_layer
import models.astec as astec
import json
from collections import OrderedDict
import models.mlp as mlp


class _Identity(nn.Module):
@@ -38,7 +37,8 @@ def initialize(self, *args, **kwargs):
'residual': residual_layer.Residual,
'identity': Identity,
'_identity': _Identity,
'astec': astec.Astec
'astec': astec.Astec,
'mlp': mlp.MLP
}


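
With 'mlp' registered in the name-to-class mapping above, a transform configuration that names "mlp" can be resolved to the new class. The registry's variable name and the surrounding config plumbing are not visible in this hunk, so the sketch below is hypothetical and inlines a stand-in dictionary:

# Hypothetical sketch of the lookup enabled by registering 'mlp'; the real
# registry lives in transform_layer.py under a name not shown in this hunk.
import models.mlp as mlp

registry = {
    'mlp': mlp.MLP,
}

layer = registry['mlp'](input_size=300, hidden_size="512,300", nnl='relu', device="cpu")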
