Skip to content

Commit

Permalink
black -l 120
Browse files Browse the repository at this point in the history
  • Loading branch information
MatthijsBurgh committed Jun 11, 2024
1 parent 8ee5326 commit b4d446c
Show file tree
Hide file tree
Showing 10 changed files with 391 additions and 359 deletions.
18 changes: 13 additions & 5 deletions __init__.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,19 @@
from .models.inception_resnet_v1 import InceptionResnetV1
from .models.mtcnn import MTCNN, PNet, RNet, ONet, prewhiten, fixed_image_standardization
from .models.mtcnn import (
MTCNN,
PNet,
RNet,
ONet,
prewhiten,
fixed_image_standardization,
)
from .models.utils.detect_face import extract_face
from .models.utils import training

import warnings

warnings.filterwarnings(
action="ignore",
message="This overload of nonzero is deprecated:\n\tnonzero()",
category=UserWarning
)
action="ignore",
message="This overload of nonzero is deprecated:\n\tnonzero()",
category=UserWarning,
)
69 changes: 38 additions & 31 deletions models/inception_resnet_v1.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,15 +14,18 @@ class BasicConv2d(nn.Module):
def __init__(self, in_planes, out_planes, kernel_size, stride, padding=0):
super().__init__()
self.conv = nn.Conv2d(
in_planes, out_planes,
kernel_size=kernel_size, stride=stride,
padding=padding, bias=False
) # verify bias false
in_planes,
out_planes,
kernel_size=kernel_size,
stride=stride,
padding=padding,
bias=False,
) # verify bias false
self.bn = nn.BatchNorm2d(
out_planes,
eps=0.001, # value found in tensorflow
momentum=0.1, # default pytorch value
affine=True
eps=0.001,
momentum=0.1,
affine=True, # value found in tensorflow # default pytorch value
)
self.relu = nn.ReLU(inplace=False)

Expand All @@ -44,13 +47,13 @@ def __init__(self, scale=1.0):

self.branch1 = nn.Sequential(
BasicConv2d(256, 32, kernel_size=1, stride=1),
BasicConv2d(32, 32, kernel_size=3, stride=1, padding=1)
BasicConv2d(32, 32, kernel_size=3, stride=1, padding=1),
)

self.branch2 = nn.Sequential(
BasicConv2d(256, 32, kernel_size=1, stride=1),
BasicConv2d(32, 32, kernel_size=3, stride=1, padding=1),
BasicConv2d(32, 32, kernel_size=3, stride=1, padding=1)
BasicConv2d(32, 32, kernel_size=3, stride=1, padding=1),
)

self.conv2d = nn.Conv2d(96, 256, kernel_size=1, stride=1)
Expand Down Expand Up @@ -78,8 +81,8 @@ def __init__(self, scale=1.0):

self.branch1 = nn.Sequential(
BasicConv2d(896, 128, kernel_size=1, stride=1),
BasicConv2d(128, 128, kernel_size=(1,7), stride=1, padding=(0,3)),
BasicConv2d(128, 128, kernel_size=(7,1), stride=1, padding=(3,0))
BasicConv2d(128, 128, kernel_size=(1, 7), stride=1, padding=(0, 3)),
BasicConv2d(128, 128, kernel_size=(7, 1), stride=1, padding=(3, 0)),
)

self.conv2d = nn.Conv2d(256, 896, kernel_size=1, stride=1)
Expand Down Expand Up @@ -107,8 +110,8 @@ def __init__(self, scale=1.0, noReLU=False):

self.branch1 = nn.Sequential(
BasicConv2d(1792, 192, kernel_size=1, stride=1),
BasicConv2d(192, 192, kernel_size=(1,3), stride=1, padding=(0,1)),
BasicConv2d(192, 192, kernel_size=(3,1), stride=1, padding=(1,0))
BasicConv2d(192, 192, kernel_size=(1, 3), stride=1, padding=(0, 1)),
BasicConv2d(192, 192, kernel_size=(3, 1), stride=1, padding=(1, 0)),
)

self.conv2d = nn.Conv2d(384, 1792, kernel_size=1, stride=1)
Expand Down Expand Up @@ -136,7 +139,7 @@ def __init__(self):
self.branch1 = nn.Sequential(
BasicConv2d(256, 192, kernel_size=1, stride=1),
BasicConv2d(192, 192, kernel_size=3, stride=1, padding=1),
BasicConv2d(192, 256, kernel_size=3, stride=2)
BasicConv2d(192, 256, kernel_size=3, stride=2),
)

self.branch2 = nn.MaxPool2d(3, stride=2)
Expand All @@ -156,18 +159,18 @@ def __init__(self):

self.branch0 = nn.Sequential(
BasicConv2d(896, 256, kernel_size=1, stride=1),
BasicConv2d(256, 384, kernel_size=3, stride=2)
BasicConv2d(256, 384, kernel_size=3, stride=2),
)

self.branch1 = nn.Sequential(
BasicConv2d(896, 256, kernel_size=1, stride=1),
BasicConv2d(256, 256, kernel_size=3, stride=2)
BasicConv2d(256, 256, kernel_size=3, stride=2),
)

self.branch2 = nn.Sequential(
BasicConv2d(896, 256, kernel_size=1, stride=1),
BasicConv2d(256, 256, kernel_size=3, stride=1, padding=1),
BasicConv2d(256, 256, kernel_size=3, stride=2)
BasicConv2d(256, 256, kernel_size=3, stride=2),
)

self.branch3 = nn.MaxPool2d(3, stride=2)
Expand Down Expand Up @@ -199,22 +202,29 @@ class InceptionResnetV1(nn.Module):
initialized. (default: {None})
dropout_prob {float} -- Dropout probability. (default: {0.6})
"""
def __init__(self, pretrained=None, classify=False, num_classes=None, dropout_prob=0.6, device=None):

def __init__(
self,
pretrained=None,
classify=False,
num_classes=None,
dropout_prob=0.6,
device=None,
):
super().__init__()

# Set simple attributes
self.pretrained = pretrained
self.classify = classify
self.num_classes = num_classes

if pretrained == 'vggface2':
if pretrained == "vggface2":
tmp_classes = 8631
elif pretrained == 'casia-webface':
elif pretrained == "casia-webface":
tmp_classes = 10575
elif pretrained is None and self.classify and self.num_classes is None:
raise Exception('If "pretrained" is not specified and "classify" is True, "num_classes" must be specified')


# Define layers
self.conv2d_1a = BasicConv2d(3, 32, kernel_size=3, stride=2)
self.conv2d_2a = BasicConv2d(32, 32, kernel_size=3, stride=1)
Expand Down Expand Up @@ -264,7 +274,7 @@ def __init__(self, pretrained=None, classify=False, num_classes=None, dropout_pr
if self.classify and self.num_classes is not None:
self.logits = nn.Linear(512, self.num_classes)

self.device = torch.device('cpu')
self.device = torch.device("cpu")
if device is not None:
self.device = device
self.to(device)
Expand Down Expand Up @@ -312,14 +322,14 @@ def load_weights(mdl, name):
Raises:
ValueError: If 'pretrained' not equal to 'vggface2' or 'casia-webface'.
"""
if name == 'vggface2':
path = 'https://github.com/timesler/facenet-pytorch/releases/download/v2.2.9/20180402-114759-vggface2.pt'
elif name == 'casia-webface':
path = 'https://github.com/timesler/facenet-pytorch/releases/download/v2.2.9/20180408-102900-casia-webface.pt'
if name == "vggface2":
path = "https://github.com/timesler/facenet-pytorch/releases/download/v2.2.9/20180402-114759-vggface2.pt"
elif name == "casia-webface":
path = "https://github.com/timesler/facenet-pytorch/releases/download/v2.2.9/20180408-102900-casia-webface.pt"
else:
raise ValueError('Pretrained models only exist for "vggface2" and "casia-webface"')

model_dir = os.path.join(get_torch_home(), 'checkpoints')
model_dir = os.path.join(get_torch_home(), "checkpoints")
os.makedirs(model_dir, exist_ok=True)

cached_file = os.path.join(model_dir, os.path.basename(path))
Expand All @@ -332,9 +342,6 @@ def load_weights(mdl, name):

def get_torch_home():
torch_home = os.path.expanduser(
os.getenv(
'TORCH_HOME',
os.path.join(os.getenv('XDG_CACHE_HOME', '~/.cache'), 'torch')
)
os.getenv("TORCH_HOME", os.path.join(os.getenv("XDG_CACHE_HOME", "~/.cache"), "torch"))
)
return torch_home
Loading

0 comments on commit b4d446c

Please sign in to comment.