
Commit

Merge pull request #11 from husichao666/master
Use try to adapt to all MindSpore versions
suhaibo1 authored Aug 15, 2023
2 parents e185c41 + a658647 commit f165d99
Showing 6 changed files with 32 additions and 31 deletions.
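Every module below follows the same refactor: the explicit `is_version_ge(ms.__version__, '1.11.0')` gate is replaced by a try/except import fallback, so the code adapts to whichever MindSpore release is installed instead of parsing its version string. A minimal sketch of the idiom, assuming only that a missing module path raises ImportError (the actual diffs use a bare `except:`, which catches more than ImportError):

# Sketch of the import-fallback idiom this PR adopts; symbol names mirror
# the tk/delta/adapter.py hunk below.
try:
    # Path that exists on MindSpore builds which still ship nn.transformer.
    from mindspore.nn.transformer.layers import _Linear
    from mindspore._checkparam import Validator
except ImportError:
    # Fallback: the layers moved to mindformers; alias Linear to the old
    # private name so downstream code is unchanged.
    from mindformers.modules.layers import Linear as _Linear
    import mindspore._checkparam as Validator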
2 changes: 1 addition & 1 deletion set_up.py
@@ -11,7 +11,7 @@
 
 
 def get_version():
-    version = '1.0.0'
+    version = '1.0.1'
     return version
 
 
11 changes: 5 additions & 6 deletions tk/delta/adapter.py
@@ -3,7 +3,6 @@
 # Copyright (c) Huawei Technologies Co., Ltd. 2022-2023, All rights reserved.
 
 from collections import OrderedDict
-from tk.utils.version_utils import is_version_ge
 
 import mindspore as ms
 import mindspore.nn as nn
@@ -13,14 +12,14 @@
 from mindspore.ops import functional as F
 from tk.delta.delta_constants import VALID_TENSOR_DATATYPE
 
-if is_version_ge(ms.__version__, '1.11.0'):
+try:
+    from mindspore.nn.transformer.layers import _Linear, _args_type_validator_check, _valid_value_checks
+    from mindspore._checkparam import Validator
+except:
     from mindformers.modules.layers import Linear, _args_type_validator_check, _valid_value_checks
     import mindspore._checkparam as Validator
     _Linear = Linear
-else:
-    from mindspore.nn.transformer.layers import _Linear, _args_type_validator_check, _valid_value_checks
-    from mindspore._checkparam import Validator


 class AdapterLayer(nn.Cell):
     """
     Define the adapter layer of the fine-tuning algorithm and initialize its parameters, including the matrix parameters and the activation layer.
16 changes: 6 additions & 10 deletions tk/delta/lora.py
@@ -3,8 +3,6 @@
 # Copyright (c) Huawei Technologies Co., Ltd. 2022-2023, All rights reserved.
 
 import math
-from tk.utils.version_utils import is_version_ge
-
 import mindspore as ms
 import mindspore.nn as nn
 from mindspore import ops
@@ -15,13 +13,14 @@
 from mindspore.ops import functional as F
 from mindspore.common.initializer import initializer, HeUniform
 from tk.delta.delta_constants import VALID_TENSOR_DATATYPE
+from tk.utils.version_control import get_dropout
 
-if is_version_ge(ms.__version__, '1.11.0'):
-    import mindspore._checkparam as Validator
-    INC_LEFT = Validator.INC_LEFT
-else:
+try:
     from mindspore._checkparam import Validator, Rel
     INC_LEFT = Rel.INC_LEFT
+except:
+    import mindspore._checkparam as Validator
+    INC_LEFT = Validator.INC_LEFT
 
 class LoRADense(nn.Dense):
     """Define a dense layer with LoRA structure.
@@ -60,10 +59,7 @@ def __init__(
         # Define and initialize params
         self.lora_rank = lora_rank
         self.lora_alpha = lora_alpha
-        if is_version_ge(ms.__version__, '1.11.0'):
-            self.lora_dropout = nn.Dropout(p=lora_dropout)
-        else:
-            self.lora_dropout = nn.Dropout(keep_prob=1 - lora_dropout)
+        self.lora_dropout = get_dropout(lora_dropout)
         self.tk_delta_lora_a = Parameter(
             initializer(lora_a_init, [lora_rank, in_channels], param_init_type),
             name='tk_delta_lora_A')
9 changes: 4 additions & 5 deletions tk/delta/low_rank_adapter.py
@@ -3,7 +3,6 @@
 # Copyright (c) Huawei Technologies Co., Ltd. 2022-2023, All rights reserved.
 
 import numbers
-from tk.utils.version_utils import is_version_ge
 
 import mindspore as ms
 import mindspore.nn as nn
@@ -17,12 +16,12 @@
 from mindspore.ops import functional as F
 from tk.delta.delta_constants import VALID_TENSOR_DATATYPE
 
-if is_version_ge(ms.__version__, '1.11.0'):
-    from mindformers.modules.layers import _args_type_validator_check, _valid_type_checks, _valid_value_checks
-    import mindspore._checkparam as Validator
-else:
+try:
     from mindspore.nn.transformer.layers import _args_type_validator_check, _valid_type_checks, _valid_value_checks
     from mindspore._checkparam import Validator
+except:
+    from mindformers.modules.layers import _args_type_validator_check, _valid_type_checks, _valid_value_checks
+    import mindspore._checkparam as Validator
 
 
 class LowRankLinear(nn.Cell):
15 changes: 6 additions & 9 deletions tk/delta/prefix_layer.py
@@ -5,14 +5,14 @@
 import mindspore as ms
 import mindspore.nn as nn
 
-from tk.utils.version_utils import is_version_ge
+from tk.utils.version_control import get_dropout
 
-if is_version_ge(ms.__version__, '1.11.0'):
-    import mindspore._checkparam as Validator
-    INC_LEFT = Validator.INC_LEFT
-else:
+try:
     from mindspore._checkparam import Validator, Rel
     INC_LEFT = Rel.INC_LEFT
+except:
+    import mindspore._checkparam as Validator
+    INC_LEFT = Validator.INC_LEFT
 
 
 def check_multiple(param_dividend, value_dividend, param_divisor, value_divisor):
@@ -55,10 +55,7 @@ def __init__(self,
         except ValueError as ex:
             raise ValueError(f"Invalid param [prefix_token_num] when initializing"
                              f"PrefixLayer, error message:{str(ex)}") from ex
-        if is_version_ge(ms.__version__, '1.11.0'):
-            self.dropout = nn.Dropout(p=dropout_rate)
-        else:
-            self.dropout = nn.Dropout(keep_prob=1 - dropout_rate)
+        self.dropout = get_dropout(dropout_rate)
         self.past_value_reparam = None
         self.past_key_reparam = None
         self.__define_network()
10 changes: 10 additions & 0 deletions tk/utils/version_control.py
@@ -0,0 +1,10 @@
+import mindspore as ms
+from mindspore import nn
+from .version_utils import is_version_ge
+
+def get_dropout(dropout_prob):
+    if is_version_ge(ms.__version__, '1.11.0'):
+        dropout = nn.Dropout(p=dropout_prob)
+    else:
+        dropout = nn.Dropout(keep_prob=1 - dropout_prob)
+    return dropout
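The new helper is the single place that knows about the Dropout API change: MindSpore 1.11+ constructs `nn.Dropout` from the drop probability `p`, while earlier releases expect the complementary `keep_prob`. A hypothetical call site (the 0.1 rate is illustrative, not taken from the diff):

from tk.utils.version_control import get_dropout

# Callers always pass the drop probability; on pre-1.11 MindSpore the
# helper converts it to keep_prob = 1 - rate internally.
lora_dropout = get_dropout(0.1)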
