Support using timm optimizers for alternative to adamw default #979

Merged (2 commits) on Nov 22, 2024
6 changes: 4 additions & 2 deletions src/open_clip/coca_model.py
@@ -96,6 +96,7 @@ def __init__(
quick_gelu: bool = False,
init_logit_scale: float = np.log(1 / 0.07),
init_logit_bias: Optional[float] = None,
nonscalar_logit_scale: bool = False,
cast_dtype: Optional[torch.dtype] = None,
pad_id: int = 0,
):
@@ -131,9 +132,10 @@ def __init__(
cast_dtype=cast_dtype,
)

self.logit_scale = nn.Parameter(torch.ones([]) * init_logit_scale)
lshape = [1] if nonscalar_logit_scale else []
self.logit_scale = nn.Parameter(torch.ones(lshape) * init_logit_scale)
if init_logit_bias is not None:
self.logit_bias = nn.Parameter(torch.ones([]) * init_logit_bias)
self.logit_bias = nn.Parameter(torch.ones(lshape) * init_logit_bias)
else:
self.logit_bias = None
self.pad_id = pad_id
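(Not part of the diff.) A minimal sketch of what the new `nonscalar_logit_scale` flag changes: the logit scale/bias become 1-element, 1-d parameters instead of 0-dim scalars, which matters for optimizer factories that decide weight-decay exclusion by dimensionality. Values below are illustrative.

```python
import numpy as np
import torch
import torch.nn as nn

init_logit_scale = np.log(1 / 0.07)

# Default (nonscalar_logit_scale=False): a 0-dim scalar parameter, as before.
scalar = nn.Parameter(torch.ones([]) * init_logit_scale)
print(scalar.ndim, tuple(scalar.shape))        # 0 ()

# nonscalar_logit_scale=True: a 1-element, 1-d parameter.
lshape = [1]
nonscalar = nn.Parameter(torch.ones(lshape) * init_logit_scale)
print(nonscalar.ndim, tuple(nonscalar.shape))  # 1 (1,)

# Broadcasting makes the two behave identically when scaling logits.
logits = torch.randn(4, 4)
assert torch.allclose(scalar.exp() * logits, nonscalar.exp() * logits)
```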
8 changes: 8 additions & 0 deletions src/open_clip/factory.py
@@ -188,6 +188,14 @@ def load_checkpoint(
if 'positional_embedding' in state_dict and not hasattr(model, 'positional_embedding'):
state_dict = convert_to_custom_text_state_dict(state_dict)

# correct if logit_scale differs in being scalar vs 1d param
if 'logit_scale' in state_dict and model.logit_scale.ndim != state_dict['logit_scale'].ndim:
state_dict['logit_scale'] = state_dict['logit_scale'].reshape(model.logit_scale.shape)

# correct if logit_bias differs in being scalar vs 1d param
if 'logit_bias' in state_dict and model.logit_bias.ndim != state_dict['logit_bias'].ndim:
state_dict['logit_bias'] = state_dict['logit_bias'].reshape(model.logit_bias.shape)

# If loading a non-SigLIP model for SigLIP training. See https://github.com/mlfoundations/open_clip/issues/712
if 'logit_bias' not in state_dict and model.logit_bias is not None:
state_dict["logit_bias"] = torch.zeros_like(state_dict["logit_scale"])
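(Not part of the diff.) A self-contained sketch of why the reshape above is needed when a checkpoint saved with a scalar `logit_scale` is loaded into a model built with `nonscalar_logit_scale=True`. `TinyModel` is a made-up stand-in, not an open_clip class.

```python
import torch
import torch.nn as nn


class TinyModel(nn.Module):
    # hypothetical stand-in: only the logit_scale shape matters here
    def __init__(self, nonscalar_logit_scale: bool = False):
        super().__init__()
        lshape = [1] if nonscalar_logit_scale else []
        self.logit_scale = nn.Parameter(torch.ones(lshape) * 2.0)


state_dict = TinyModel(nonscalar_logit_scale=False).state_dict()  # old scalar checkpoint
model = TinyModel(nonscalar_logit_scale=True)                     # new 1-d model

# Without this reshape, strict loading fails: torch.Size([]) vs torch.Size([1]).
if 'logit_scale' in state_dict and model.logit_scale.ndim != state_dict['logit_scale'].ndim:
    state_dict['logit_scale'] = state_dict['logit_scale'].reshape(model.logit_scale.shape)

model.load_state_dict(state_dict, strict=True)
```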
34 changes: 30 additions & 4 deletions src/open_clip/model.py
@@ -230,6 +230,7 @@ def __init__(
quick_gelu: bool = False,
init_logit_scale: float = np.log(1 / 0.07),
init_logit_bias: Optional[float] = None,
nonscalar_logit_scale: bool = False,
cast_dtype: Optional[torch.dtype] = None,
output_dict: bool = False,
):
@@ -249,9 +250,10 @@ def __init__(
self.text_pool_type = text.pool_type
self.register_buffer('attn_mask', text.attn_mask, persistent=False)

self.logit_scale = nn.Parameter(torch.ones([]) * init_logit_scale)
lshape = [1] if nonscalar_logit_scale else []
self.logit_scale = nn.Parameter(torch.ones(lshape) * init_logit_scale)
if init_logit_bias is not None:
self.logit_bias = nn.Parameter(torch.ones([]) * init_logit_bias)
self.logit_bias = nn.Parameter(torch.ones(lshape) * init_logit_bias)
else:
self.logit_bias = None

@@ -264,6 +266,15 @@ def set_grad_checkpointing(self, enable=True):
self.visual.set_grad_checkpointing(enable)
self.transformer.grad_checkpointing = enable

@torch.jit.ignore
def no_weight_decay(self):
# for timm optimizers, 1d params like logit_scale, logit_bias, ln/bn scale, biases are excluded by default
no_wd = {'positional_embedding'}
if hasattr(self.visual, 'no_weight_decay'):
for n in self.visual.no_weight_decay():
no_wd.add('visual.' + n)
return no_wd

def encode_image(self, image, normalize: bool = False):
features = self.visual(image)
return F.normalize(features, dim=-1) if normalize else features
@@ -328,6 +339,7 @@ def __init__(
quick_gelu: bool = False,
init_logit_scale: float = np.log(1 / 0.07),
init_logit_bias: Optional[float] = None,
nonscalar_logit_scale: bool = False,
cast_dtype: Optional[torch.dtype] = None,
output_dict: bool = False,
):
@@ -337,9 +349,11 @@ def __init__(
self.text = _build_text_tower(embed_dim, text_cfg, quick_gelu, cast_dtype)
self.context_length = self.text.context_length
self.vocab_size = self.text.vocab_size
self.logit_scale = nn.Parameter(torch.ones([]) * init_logit_scale)

lshape = [1] if nonscalar_logit_scale else []
self.logit_scale = nn.Parameter(torch.ones(lshape) * init_logit_scale)
if init_logit_bias is not None:
self.logit_bias = nn.Parameter(torch.ones([]) * init_logit_bias)
self.logit_bias = nn.Parameter(torch.ones(lshape) * init_logit_bias)
else:
self.logit_bias = None

@@ -355,6 +369,18 @@ def set_grad_checkpointing(self, enable=True):
self.visual.set_grad_checkpointing(enable)
self.text.set_grad_checkpointing(enable)

@torch.jit.ignore
def no_weight_decay(self):
# for timm optimizers, 1d params like logit_scale, logit_bias, ln/bn scale, biases are excluded by default
no_wd = set()
if hasattr(self.visual, 'no_weight_decay'):
for n in self.visual.no_weight_decay():
no_wd.add('visual.' + n)
if hasattr(self.text, 'no_weight_decay'):
for n in self.text.no_weight_decay():
no_wd.add('text.' + n)
return no_wd

def encode_image(self, image, normalize: bool = False):
features = self.visual(image)
return F.normalize(features, dim=-1) if normalize else features
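(Not part of the diff.) The `no_weight_decay()` hooks return parameter names relative to the top-level model, which is why tower-level entries are re-prefixed with `visual.`/`text.`. A small sketch of that aggregation pattern; `VisionTower` and `Wrapper` are made-up stand-ins, not open_clip classes.

```python
import torch
import torch.nn as nn


class VisionTower(nn.Module):
    # hypothetical tower exposing its own exclusions, like VisionTransformer does above
    def __init__(self):
        super().__init__()
        self.positional_embedding = nn.Parameter(torch.zeros(196, 768))
        self.class_embedding = nn.Parameter(torch.zeros(768))

    def no_weight_decay(self):
        return {'positional_embedding', 'class_embedding'}


class Wrapper(nn.Module):
    # mirrors the CustomTextCLIP pattern: collect tower names and prefix them
    def __init__(self):
        super().__init__()
        self.visual = VisionTower()

    @torch.jit.ignore
    def no_weight_decay(self):
        no_wd = set()
        if hasattr(self.visual, 'no_weight_decay'):
            for n in self.visual.no_weight_decay():
                no_wd.add('visual.' + n)
        return no_wd


print(sorted(Wrapper().no_weight_decay()))
# ['visual.class_embedding', 'visual.positional_embedding']
```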
14 changes: 14 additions & 0 deletions src/open_clip/transformer.py
@@ -595,6 +595,12 @@ def init_parameters(self):
def set_grad_checkpointing(self, enable=True):
self.transformer.grad_checkpointing = enable

@torch.jit.ignore
def no_weight_decay(self):
# for timm optimizers, 1d params like logit_scale, logit_bias, ln/bn scale, biases are excluded by default
no_wd = {'positional_embedding', 'class_embedding'}
return no_wd

def _global_pool(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
if self.pool_type == 'avg':
pooled, tokens = x[:, 1:].mean(dim=1), x[:, 1:]
@@ -759,6 +765,14 @@ def init_parameters(self):
def set_grad_checkpointing(self, enable=True):
self.transformer.grad_checkpointing = enable

@torch.jit.ignore
def no_weight_decay(self):
# for timm optimizers, 1d params like logit_scale, logit_bias, ln/bn scale, biases are excluded by default
no_wd = {'positional_embedding'}
if self.cls_emb is not None:
no_wd.add('cls_emb')
return no_wd

def build_causal_mask(self):
# lazily create causal attention mask, with full attention between the tokens
# pytorch uses additive attention mask; fill with -inf
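(Not part of the diff.) For context, the comments above refer to how timm-style optimizer factories group parameters: 1-d/0-d params, biases, and anything named by `no_weight_decay()` land in a zero-decay group. A simplified re-implementation of that grouping under those assumptions; this is a sketch, not timm's actual code.

```python
import torch.nn as nn


def split_param_groups(model: nn.Module, weight_decay: float):
    # Names flagged by the model's no_weight_decay() hook, if it has one.
    no_wd_names = set(model.no_weight_decay()) if hasattr(model, 'no_weight_decay') else set()

    decay, no_decay = [], []
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue
        # 0-d/1-d params (biases, norm scales, logit_scale/logit_bias) skip decay,
        # as do explicitly listed names such as 'positional_embedding'.
        if param.ndim <= 1 or name.endswith('.bias') or name in no_wd_names:
            no_decay.append(param)
        else:
            decay.append(param)

    return [
        {'params': no_decay, 'weight_decay': 0.},
        {'params': decay, 'weight_decay': weight_decay},
    ]
```

The resulting groups can then be handed to any torch.optim-style optimizer constructor.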
69 changes: 52 additions & 17 deletions src/open_clip_train/main.py
@@ -1,3 +1,4 @@
import copy
import glob
import logging
import os
@@ -309,22 +310,56 @@ def main(args):
if args.train_data or args.dataset_type == "synthetic":
assert not args.trace, 'Cannot train with traced model'

exclude = lambda n, p: p.ndim < 2 or "bn" in n or "ln" in n or "bias" in n or 'logit_scale' in n
include = lambda n, p: not exclude(n, p)

named_parameters = list(model.named_parameters())
gain_or_bias_params = [p for n, p in named_parameters if exclude(n, p) and p.requires_grad]
rest_params = [p for n, p in named_parameters if include(n, p) and p.requires_grad]

optimizer = optim.AdamW(
[
{"params": gain_or_bias_params, "weight_decay": 0.},
{"params": rest_params, "weight_decay": args.wd},
],
lr=args.lr,
betas=(args.beta1, args.beta2),
eps=args.eps,
)
opt = getattr(args, 'opt', 'adamw').lower()
if opt.startswith('timm/'):
from timm.optim import create_optimizer_v2
timm_opt = opt.split('timm/')[-1]
opt_kwargs = {}
assert (args.beta1 is None) == (args.beta2 is None), \
'When using timm optimizer, BOTH beta1 and beta2 must be specified (or not specified).'
if args.beta1 is not None:
opt_kwargs['betas'] = (args.beta1, args.beta2)
if args.momentum is not None:
opt_kwargs['momentum'] = args.momentum
optimizer = create_optimizer_v2(
model,
timm_opt,
lr=args.lr,
weight_decay=args.wd,
eps=args.eps,
**opt_kwargs,
)
else:
# If some params are not passed, we use the default values based on model name.
exclude = lambda n, p: p.ndim < 2 or "bn" in n or "ln" in n or "bias" in n or 'logit_scale' in n
include = lambda n, p: not exclude(n, p)

named_parameters = list(model.named_parameters())
gain_or_bias_params = [p for n, p in named_parameters if exclude(n, p) and p.requires_grad]
rest_params = [p for n, p in named_parameters if include(n, p) and p.requires_grad]

if opt == 'adamw':
optimizer = optim.AdamW(
[
{"params": gain_or_bias_params, "weight_decay": 0.},
{"params": rest_params, "weight_decay": args.wd},
],
lr=args.lr,
betas=(args.beta1, args.beta2),
eps=args.eps,
)
else:
assert False, f'Unknown optimizer {opt}'

if is_master(args):
defaults = copy.deepcopy(optimizer.defaults)
defaults['weight_decay'] = args.wd
defaults = ', '.join([f'{k}: {v}' for k, v in defaults.items()])
logging.info(
f'Created {type(optimizer).__name__} ({args.opt}) optimizer: {defaults}'
)

if args.horovod:
optimizer = hvd.DistributedOptimizer(optimizer, named_parameters=model.named_parameters())
hvd.broadcast_parameters(model.state_dict(), root_rank=0)
@@ -425,7 +460,7 @@ def main(args):

if args.grad_checkpointing and args.distributed:
logging.info('Disabling DDP dynamo optimizer when grad checkpointing enabled.')
# As of now (~PyTorch 2.4/2.5), compile + checkpointing but DDP optimizer must be disabled
# As of now (~PyTorch 2.4/2.5), compile + grad checkpointing work, but DDP optimizer must be disabled
torch._dynamo.config.optimize_ddp = False

model = torch.compile(original_model)
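(Not part of the diff.) Putting the pieces together: anything passed as `--opt timm/<name>` is routed to `timm.optim.create_optimizer_v2`, with the whole module passed in so timm's default exclusions and the `no_weight_decay()` hooks above take effect. The hyperparameters, the `lamb` choice, and the toy model below are illustrative only, and the command lines assume the `open_clip_train` package layout in this repo.

```python
# Existing default path, unchanged:
#   python -m open_clip_train.main --opt adamw --lr 1e-3 --wd 0.2 ...
# New path, any timm optimizer selected by name after the 'timm/' prefix:
#   python -m open_clip_train.main --opt timm/lamb --lr 1e-3 --wd 0.2 ...

import torch.nn as nn
from timm.optim import create_optimizer_v2

model = nn.Sequential(nn.Linear(512, 512), nn.GELU(), nn.Linear(512, 2))  # any nn.Module

# Roughly what the timm/ branch ends up calling (betas/momentum are only
# forwarded when given explicitly on the command line):
optimizer = create_optimizer_v2(
    model,            # passing the module lets timm build decay/no-decay groups itself
    'lamb',           # i.e. args.opt with the 'timm/' prefix stripped
    lr=1e-3,
    weight_decay=0.2,
)
print(type(optimizer).__name__, [len(g['params']) for g in optimizer.param_groups])
```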
16 changes: 11 additions & 5 deletions src/open_clip_train/params.py
@@ -143,9 +143,14 @@ def parse_args(args):
parser.add_argument("--beta2", type=float, default=None, help="Adam beta 2.")
parser.add_argument("--eps", type=float, default=None, help="Adam epsilon.")
parser.add_argument("--wd", type=float, default=0.2, help="Weight decay.")
parser.add_argument("--momentum", type=float, default=None, help="Momentum (for timm optimizers).")
parser.add_argument(
"--warmup", type=int, default=10000, help="Number of steps to warmup for."
)
parser.add_argument(
"--opt", type=str, default='adamw',
help="Which optimizer to use. Choices are ['adamw', or any timm optimizer 'timm/{opt_name}']."
)
parser.add_argument(
"--use-bn-sync",
default=False,
@@ -467,10 +472,11 @@ def parse_args(args):

args = parser.parse_args(args)

# If some params are not passed, we use the default values based on model name.
default_params = get_default_params(args.model)
for name, val in default_params.items():
if getattr(args, name) is None:
setattr(args, name, val)
if 'timm' not in args.opt:
# set default opt params based on model name (only if timm optimizer not used)
default_params = get_default_params(args.model)
for name, val in default_params.items():
if getattr(args, name) is None:
setattr(args, name, val)

return args
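(Not part of the diff.) A quick sketch of the effect of that guard, assuming the package layout in this PR: with a `timm/` optimizer the model-based defaults are not back-filled, so `beta1`/`beta2`/`eps` stay None unless given explicitly, and main.py then lets the timm optimizer fall back to its own defaults.

```python
from open_clip_train.params import parse_args

# Default AdamW path: unset betas/eps are filled from get_default_params(args.model).
args = parse_args(['--model', 'ViT-B-32'])
print(args.opt, args.beta1, args.beta2, args.eps)   # adamw, with the ViT presets filled in

# timm path: the back-fill is skipped, so the optimizer's own defaults apply.
args = parse_args(['--model', 'ViT-B-32', '--opt', 'timm/lamb'])
print(args.opt, args.beta1, args.beta2, args.eps)   # timm/lamb None None None
```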