Skip to content

Commit

Permalink
Merge pull request #49 from crux-bphc/ml-training-validation
Browse files Browse the repository at this point in the history
Freeze selected layers and create model training notebook
  • Loading branch information
majimearun authored Aug 31, 2023
2 parents 5ebc354 + 6111873 commit a486c7c
Show file tree
Hide file tree
Showing 3 changed files with 943 additions and 4 deletions.
15 changes: 13 additions & 2 deletions clusterer/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,13 +25,14 @@ def __init__(self, embedding_size: int=256, use_default: bool=False):
use_default (bool) : allows user to use unaltered model, defaults to False
"""
self.embedding_size: int = embedding_size
self.model = InceptionResnetV1(pretrained='vggface2').eval()
self.model: InceptionResnetV1 = InceptionResnetV1(pretrained='vggface2').eval()
if not use_default:
self.n_features = self.model.last_linear.in_features
self.model.last_linear = nn.Linear(self.n_features, self.embedding_size)
self.n_bn_features = self.model.last_bn.num_features
self.model.last_bn = nn.BatchNorm1d(embedding_size)
self.model.classify = False
self.model.eval()

def forward(self, x: torch.Tensor):
"""
Expand Down Expand Up @@ -102,7 +103,7 @@ def load_model(cls, filepath: str, embedding_size: int=256):
facenet.model.classify = False
return facenet

def train(self, train_data: TripletDataset, batch_size : int, n_epochs : int, learning_rate : float):
def train(self, train_data: TripletDataset, batch_size : int, n_epochs : int, learning_rate : float, frozen: int = 250):
"""
Trains a model on given data
Expand All @@ -111,6 +112,7 @@ def train(self, train_data: TripletDataset, batch_size : int, n_epochs : int, le
batch_size (int)
n_epochs (int)
learning_rate (float)
frozen (int) : Number of layers upto which parameters are to be froze. defaults to 250
Returns:
None
Expand All @@ -122,6 +124,15 @@ def train(self, train_data: TripletDataset, batch_size : int, n_epochs : int, le
optimizer = optim.Adam(self.model.parameters(), lr=learning_rate)

self.model.train()

counter = 0
for param in self.model.parameters():
if counter < frozen:
param.requires_grad=False
elif counter >= frozen:
param.requires_grad=True
counter+=1

start = time.time()

print(f"Training starts: {time.ctime(start)}\n")
Expand Down
Loading

0 comments on commit a486c7c

Please sign in to comment.