Introduce GaLoreWrappedParameter to decouple grad from param
matthewdouglas committed Oct 30, 2024
1 parent 0c87b70 commit 7e3b5ff
Showing 2 changed files with 25 additions and 13 deletions.
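
In outline, the commit replaces the old pattern of writing the GaLore-projected gradient onto p.grad before calling update_step with a small wrapper that carries the parameter and its gradient together. A minimal sketch of the idea (illustrative only; the real code is in the diff below):

# Illustrative sketch; see the diff below for the actual bitsandbytes changes.
from dataclasses import dataclass

import torch


@dataclass
class GaLoreWrappedParameter:
    p: torch.Tensor     # the parameter the optimizer state is keyed on
    grad: torch.Tensor  # the (GaLore-projected) gradient to apply


# Before: the projected gradient had to be stashed on the parameter itself.
#     p.grad = grad
#     self.update_step(group, p, gindex, pindex, return_updates=lor_update)
#
# After: parameter and gradient travel together and p.grad is left alone.
#     galore_p = GaLoreWrappedParameter(p=p, grad=grad)
#     self.update_step(group, galore_p, gindex, pindex, return_updates=lor_update)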
bitsandbytes/optim/adamw.py: 6 changes (3 additions, 3 deletions)
@@ -4,7 +4,7 @@
 # LICENSE file in the root directory of this source tree.
 import torch
 
-from bitsandbytes.optim.optimizer import Optimizer2State
+from bitsandbytes.optim.optimizer import GaLoreWrappedParameter, Optimizer2State
 
 _galore_available = False
 try:
@@ -220,15 +220,15 @@ def step(self, closure=None):
                     lor_update = torch.zeros_like(
                         grad, dtype=p.data.dtype, device=p.data.device, requires_grad=grad.requires_grad
                     )
-                    p.grad = grad
 
                 if "state1" not in state:
                     self.init_state(group, p, gindex, pindex)
 
                 self.prefetch_state(p)
 
                 if "rank" in group:
-                    self.update_step(group, p, gindex, pindex, return_updates=lor_update)
+                    galore_p = GaLoreWrappedParameter(p=p, grad=grad)
+                    self.update_step(group, galore_p, gindex, pindex, return_updates=lor_update)
 
                     # GaLore Projection Back
                     p.data.add_(state["projector"].project_back(lor_update))
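To illustrate the effect at the call site, here is a hypothetical standalone example (assuming a build with this commit applied; the tensors are toy stand-ins, not real GaLore projections):

import torch

from bitsandbytes.optim.optimizer import GaLoreWrappedParameter

# Toy values: inside the optimizer, `projected` would be the GaLore-projected gradient.
p = torch.nn.Parameter(torch.zeros(4, 4))
projected = torch.randn(4, 4)

wrapped = GaLoreWrappedParameter(p=p, grad=projected)

# The wrapper carries the projected gradient, so the parameter's own .grad
# is never reassigned by the GaLore path.
assert p.grad is None
assert wrapped.grad is projected
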
bitsandbytes/optim/optimizer.py: 32 changes (22 additions, 10 deletions)
@@ -4,8 +4,9 @@
 # LICENSE file in the root directory of this source tree.
 from collections import abc as container_abcs, defaultdict
 from copy import deepcopy
+from dataclasses import dataclass
 from itertools import chain
-from typing import Any, Dict, Optional
+from typing import Any, Dict, Optional, Union
 
 import torch
 
@@ -18,6 +19,12 @@ def __init__(self, initial_data):
             setattr(self, key, initial_data[key])
 
 
+@dataclass
+class GaLoreWrappedParameter:
+    p: torch.Tensor
+    grad: torch.Tensor
+
+
 class GlobalOptimManager:
     """
     A global optimizer manager for enabling custom optimizer configs.
@@ -497,17 +504,22 @@ def init_state(self, group, p, gindex, pindex):
     def update_step(
         self,
         group: Dict[str, Any],
-        p: torch.Tensor,
+        p: Union[torch.Tensor, GaLoreWrappedParameter],
         gindex: int,
         pindex: int,
         return_updates: Optional[torch.Tensor] = None,
     ):
-        # avoid update error from non-contiguous memory layout
-        p.data = p.data.contiguous()
-        p.grad = p.grad.contiguous()
+        if isinstance(p, GaLoreWrappedParameter):
+            # Unwrap for GaLore
+            param_to_optimize = p.p
+        else:
+            param_to_optimize = p
 
-        state = self.state[p]
-        grad = p.grad
+        state = self.state[param_to_optimize]
+
+        # avoid update error from non-contiguous memory layout
+        param_to_optimize.data = param_to_optimize.data.contiguous()
+        grad = p.grad.contiguous()
 
         config = self.get_config(gindex, pindex, group)
 
@@ -528,7 +540,7 @@ def update_step(
             F.optimizer_update_32bit(
                 self.optimizer_name,
                 grad,
-                p,
+                param_to_optimize,
                 state["state1"],
                 config["betas"][0],
                 config["eps"],
@@ -550,7 +562,7 @@
             F.optimizer_update_8bit(
                 self.optimizer_name,
                 grad,
-                p,
+                param_to_optimize,
                 state["state1"],
                 state["state2"],
                 config["betas"][0],
@@ -578,7 +590,7 @@ def update_step(
             F.optimizer_update_8bit_blockwise(
                 self.optimizer_name,
                 grad,
-                p,
+                param_to_optimize,
                 state["state1"],
                 state["state2"],
                 config["betas"][0],
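For reference, a standalone sketch of the dispatch that update_step now performs (the helper name unwrap_for_update is hypothetical; in the diff above the same logic lives inline in update_step):

from typing import Tuple, Union

import torch

from bitsandbytes.optim.optimizer import GaLoreWrappedParameter


def unwrap_for_update(p: Union[torch.Tensor, GaLoreWrappedParameter]) -> Tuple[torch.Tensor, torch.Tensor]:
    # Mirrors the dispatch added to update_step: state stays keyed on the
    # underlying parameter; the gradient comes from whichever object was passed in.
    param_to_optimize = p.p if isinstance(p, GaLoreWrappedParameter) else p
    grad = p.grad.contiguous()
    return param_to_optimize, grad


# Plain tensor path: unchanged behavior, the gradient is read from p.grad.
plain = torch.nn.Parameter(torch.ones(2, 2))
plain.grad = torch.full((2, 2), 0.5)
assert unwrap_for_update(plain)[0] is plain

# GaLore path: the gradient travels with the wrapper and p.grad stays untouched.
wrapped = GaLoreWrappedParameter(p=torch.nn.Parameter(torch.ones(2, 2)), grad=torch.ones(2, 2))
assert unwrap_for_update(wrapped)[0] is wrapped.p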
