
add deepcopy and copy for Param4bit #1060

Merged
29 changes: 29 additions & 0 deletions bitsandbytes/nn/modules.py
@@ -2,6 +2,7 @@
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import copy
from typing import Any, Dict, Optional, TypeVar, Union, overload
import warnings

@@ -213,6 +214,34 @@ def __new__(
        self.data = data
        self.module = module
        return self

    def __getstate__(self):
        state = self.__dict__
        state["data"] = self.data
        state["requires_grad"] = self.requires_grad
        return state

    def __setstate__(self, state):
        self.requires_grad = state["requires_grad"]
        self.blocksize = state["blocksize"]
        self.compress_statistics = state["compress_statistics"]
        self.quant_type = state["quant_type"]
        self.quant_state = state["quant_state"]
        self.data = state["data"]

    def __deepcopy__(self, memo):
        new_instance = type(self).__new__(type(self))
        state = self.__getstate__()
        new_instance.__setstate__(state)
        new_instance.quant_state = copy.deepcopy(state["quant_state"])
        new_instance.data = copy.deepcopy(state["data"])
        return new_instance

    def __copy__(self):
        new_instance = type(self).__new__(type(self))
        state = self.__getstate__()
        new_instance.__setstate__(state)
        return new_instance
Contributor
@akx akx Feb 15, 2024

Is having to do this dance common in Torch world? 🤔

I'm a little worried that someone adding a new field in __init__ will inevitably miss adding them here...

Contributor Author
@SunMarc SunMarc Feb 15, 2024

Is having to do this dance common in Torch world? 🤔

I don't think so, but I wasn't able to find a better solution. I based my solution on this specific code from torch.

I'm a little worried that someone adding a new field in __init__ will inevitably miss adding them here...

Yeah, that's true :/. I tried modifying __setstate__ so that we update state.__dict__ using self.__dict__, but some attributes were not copied properly.

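For reference, the torch code SunMarc mentions basing this on is most likely the __deepcopy__ defined on torch.nn.Parameter. A rough sketch from memory, not part of this diff, and the upstream implementation may differ in detail:

    def __deepcopy__(self, memo):
        # Reuse the already-copied instance if this parameter was seen earlier in the same deepcopy pass.
        if id(self) in memo:
            return memo[id(self)]
        # Otherwise clone the underlying tensor and register the new parameter in the memo dict.
        result = type(self)(self.data.clone(memory_format=torch.preserve_format), self.requires_grad)
        memo[id(self)] = result
        return result

Params4bit cannot reuse that pattern as-is, because rebuilding the parameter also has to carry over blocksize, compress_statistics, quant_type and quant_state, which is what the explicit __getstate__/__setstate__ pair above restores field by field.
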

    @classmethod
    def from_prequantized(cls, data: torch.Tensor, quantized_stats: Dict[str, Any], requires_grad: bool = False, device='cuda', **kwargs) -> "Params4bit":
2 changes: 1 addition & 1 deletion tests/test_functional.py
@@ -9,8 +9,8 @@
from scipy.stats import norm
import torch

import bitsandbytes as bnb
from bitsandbytes import functional as F
import bitsandbytes as bnb
from tests.helpers import (
BOOLEAN_TUPLES,
TRUE_FALSE,
16 changes: 16 additions & 0 deletions tests/test_linear4bit.py
@@ -1,3 +1,4 @@
import copy
import os
from tempfile import TemporaryDirectory

@@ -146,3 +147,18 @@ def test_linear_serialization(quant_type, compress_statistics, bias, quant_stora
    target_compression = 0.143 if original_dtype == torch.float32 else 0.29 # these numbers get lower as weight shape increases
    ratio_error_msg = f"quantized_size {size_4:,} is larger on disk than {target_compression:.2%} of original size {size_orig:,}"
    assert size_ratio < target_compression, ratio_error_msg

def test_copy_param():
    tensor = torch.tensor([1.0, 2.0, 3.0, 4.0])
    param = bnb.nn.Params4bit(data=tensor, requires_grad=False).cuda(0)

    shallow_copy_param = copy.copy(param)
    assert param.quant_state is shallow_copy_param.quant_state
    assert param.data.data_ptr() == shallow_copy_param.data.data_ptr()

def test_deepcopy_param():
    tensor = torch.tensor([1.0, 2.0, 3.0, 4.0])
    param = bnb.nn.Params4bit(data=tensor, requires_grad=False).cuda(0)
    copy_param = copy.deepcopy(param)
    assert param.quant_state is not copy_param.quant_state
    assert param.data.data_ptr() != copy_param.data.data_ptr()
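
Beyond these parameter-level tests, the practical scenario the new hooks unlock is copying a whole quantized layer. A minimal sketch, not part of the PR, assuming a CUDA device is available and using the public bnb.nn.Linear4bit layer (the 64x64 shape is arbitrary):

import copy

import bitsandbytes as bnb

# Moving the layer to the GPU quantizes its weight into a Params4bit with a populated quant_state.
layer = bnb.nn.Linear4bit(64, 64, quant_type="nf4")
layer.cuda()

# deepcopy of the module deep-copies its parameters, which now goes through Params4bit.__deepcopy__,
# so the copy owns its own packed weight data and its own quant_state.
layer_copy = copy.deepcopy(layer)
assert layer_copy.weight.quant_state is not layer.weight.quant_state
assert layer_copy.weight.data.data_ptr() != layer.weight.data.data_ptr()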
2 changes: 1 addition & 1 deletion tests/test_linear8bitlt.py
@@ -5,8 +5,8 @@
import pytest
import torch

import bitsandbytes as bnb
from bitsandbytes import functional as F
import bitsandbytes as bnb
from bitsandbytes.autograd import get_inverse_transform_indices, undo_layout
from bitsandbytes.nn.modules import Linear8bitLt
from tests.helpers import TRUE_FALSE, id_formatter