fix issue I8CJMD (#606)
Co-authored-by: panshaowu <[email protected]>
panshaowu and panshaowu authored Nov 8, 2023
1 parent 2e08ac8 commit 8ea4da4
Showing 7 changed files with 35 additions and 8 deletions.
2 changes: 1 addition & 1 deletion docs/cn/tutorials/yaml_configuration.md
@@ -24,7 +24,7 @@
| mode | MindSpore running mode (static graph / dynamic graph) | 0 | 0 / 1 | 0: run in GRAPH_MODE; 1: PYNATIVE_MODE |
| distribute | Whether to enable parallel training | True | True / False | \ |
| device_id | Device id for standalone training | 7 | ids of the devices available on the machine | Only takes effect when distribute=False (standalone training) and the environment variable DEVICE_ID is not set. In standalone training, if neither this argument nor DEVICE_ID is set, device 0 is used by default. |
- | amp_level | Mixed precision mode | O0 | O0/O1/O2/O3 | 'O0' - no change. <br> 'O1' - cast the cells and operations in the whitelist to float16 precision, keep the rest in float32 precision. <br> 'O2' - keep the cells and operations in the blacklist in float32 precision, cast the rest to float16 precision. <br> 'O3' - cast the whole network to float16 precision. |
+ | amp_level | Mixed precision mode | O0 | O0/O1/O2/O3 | 'O0' - no change. <br> 'O1' - cast the cells and operations in the whitelist to float16 precision, keep the rest in float32 precision. <br> 'O2' - keep the cells and operations in the blacklist in float32 precision, cast the rest to float16 precision. <br> 'O3' - cast the whole network to float16 precision. <br> Note: model inference or evaluation on the GPU platform does not support 'O3' yet; if set to 'O3', the program switches it to 'O2' automatically. |
| seed | Random seed | 42 | Integer | \ |
| ckpt_save_policy | Policy for saving model weights | top_k | "top_k" or "latest_k" | "top_k" keeps the k checkpoints with the highest metric scores; "latest_k" keeps the latest k checkpoints. The value of `k` is set via `ckpt_max_keep` |
| ckpt_max_keep | Maximum number of checkpoints to keep | 5 | Integer | \ |
2 changes: 1 addition & 1 deletion docs/en/tutorials/yaml_configuration.md
@@ -23,7 +23,7 @@ This document takes `configs/rec/crnn/crnn_icdar15.yaml` as an example to descri
| mode | Mindspore running mode (static graph/dynamic graph) | 0 | 0 / 1 | 0: means running in GRAPH_MODE mode; 1: PYNATIVE_MODE mode |
| distribute | Whether to enable parallel training | True | True / False | \ |
| device_id | Specify the device id while standalone training | 7 | The ids of all devices in the server | Only valid when distribute=False (standalone training) and environment variable 'DEVICE_ID' is NOT set. While standalone training, if both this arg and environment variable 'DEVICE_ID' are NOT set, use device 0 by default. |
- | amp_level | Mixed precision mode | O0 | O0/O1/O2/O3 | 'O0' - no change. <br> 'O1' - convert the cells and operations in the whitelist to float16 precision, and keep the rest in float32 precision. <br> 'O2' - keep the cells and operations in the blacklist in float32 precision, and convert the rest to float16 precision. <br> 'O3' - convert the whole network to float16 precision. |
+ | amp_level | Mixed precision mode | O0 | O0/O1/O2/O3 | 'O0' - no change. <br> 'O1' - convert the cells and operations in the whitelist to float16 precision, and keep the rest in float32 precision. <br> 'O2' - keep the cells and operations in the blacklist in float32 precision, and convert the rest to float16 precision. <br> 'O3' - convert the whole network to float16 precision. <br> Note: model prediction or evaluation does not support 'O3' on the GPU platform yet; if amp_level is set to 'O3' there, the program switches it to 'O2' automatically. |
| seed | Random seed | 42 | Integer | \ |
| ckpt_save_policy | The policy for saving model weights | top_k | "top_k" or "latest_k" | "top_k" means to keep the top k checkpoints according to the metric score; "latest_k" means to keep the last k checkpoints. The value of `k` is set via `ckpt_max_keep` |
| ckpt_max_keep | The maximum number of checkpoints to keep during training | 5 | Integer | \ |
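The options above sit under the `system` key of the YAML file, as the `cfg.system.get(...)` calls changed in this commit read them. A minimal illustrative sketch (values are examples, not defaults; `amp_level_infer` is the inference-time counterpart consumed by the eval/predict scripts below):

```yaml
system:
  mode: 0                # GRAPH_MODE
  distribute: False
  device_id: 0
  amp_level: O2          # training precision
  amp_level_infer: O2    # inference precision; O3 is downgraded to O2 on GPU
  seed: 42
  ckpt_save_policy: top_k
  ckpt_max_keep: 5
```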
6 changes: 6 additions & 0 deletions tools/benchmarking/multi_dataset_eval.py
@@ -102,6 +102,12 @@ def main(cfg):
# model
cfg.model.backbone.pretrained = False
amp_level = cfg.system.get("amp_level_infer", "O0")
if ms.get_context("device_target") == "GPU" and amp_level == "O3":
logger.warning(
"Model inference does not support amp_level O3 on GPU currently. "
"The program has switched to amp_level O2 automatically."
)
amp_level = "O2"
network = build_model(cfg.model, ckpt_load_path=cfg.eval.ckpt_load_path, amp_level=amp_level)
num_params = sum([param.size for param in network.get_parameters()])
num_trainable_params = sum([param.size for param in network.trainable_params()])
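The same GPU/O3 guard recurs in `tools/eval.py`, `tools/infer/text/predict_det.py`, `tools/infer/text/predict_rec.py`, and `tools/train.py` below. A hypothetical shared helper could centralize it; the name `resolve_amp_level` is illustrative and not part of the MindOCR codebase:

```python
def resolve_amp_level(amp_level: str, device_target: str, warn=None) -> str:
    """Return the effective amp_level, downgrading O3 to O2 on GPU,
    where O3 inference/evaluation is currently unsupported."""
    if device_target == "GPU" and amp_level == "O3":
        if warn is not None:
            warn(
                "amp_level O3 is not supported for inference/evaluation on GPU; "
                "switching to O2 automatically."
            )
        return "O2"
    return amp_level


# Usage sketch mirroring the guard added in this commit:
# amp_level = resolve_amp_level(
#     cfg.system.get("amp_level_infer", "O0"),
#     ms.get_context("device_target"),
#     warn=logger.warning,
# )
```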
6 changes: 6 additions & 0 deletions tools/eval.py
@@ -70,6 +70,12 @@ def main(cfg):
# model
cfg.model.backbone.pretrained = False
amp_level = cfg.system.get("amp_level_infer", "O0")
if ms.get_context("device_target") == "GPU" and amp_level == "O3":
logger.warning(
"Model evaluation does not support amp_level O3 on GPU currently. "
"The program has switched to amp_level O2 automatically."
)
amp_level = "O2"
network = build_model(cfg.model, ckpt_load_path=cfg.eval.ckpt_load_path, amp_level=amp_level)
num_params = sum([param.size for param in network.get_parameters()])
num_trainable_params = sum([param.size for param in network.trainable_params()])
13 changes: 9 additions & 4 deletions tools/infer/text/predict_det.py
@@ -2,7 +2,7 @@
Text detection inference
Example:
-    $ python tools/infer/text/predict_det.py --image_dir {path_to_img} --rec_algorithm DB++
+    $ python tools/infer/text/predict_det.py --image_dir {path_to_img} --det_algorithm DB++
"""

import json
@@ -54,9 +54,14 @@ def __init__(self, args):
f"Supported detection algorithms are {list(algo_to_model_name.keys())}"
)
model_name = algo_to_model_name[args.det_algorithm]
- self.model = build_model(
-     model_name, pretrained=pretrained, ckpt_load_path=ckpt_load_path, amp_level=args.det_amp_level
- )
+ amp_level = args.det_amp_level
+ if ms.get_context("device_target") == "GPU" and amp_level == "O3":
+     logger.warning(
+         "Detection model prediction does not support amp_level O3 on GPU currently. "
+         "The program has switched to amp_level O2 automatically."
+     )
+     amp_level = "O2"
+ self.model = build_model(model_name, pretrained=pretrained, ckpt_load_path=ckpt_load_path, amp_level=amp_level)
self.model.set_train(False)
logger.info(
"Init detection model: {} --> {}. Model weights loaded from {}".format(
8 changes: 6 additions & 2 deletions tools/infer/text/predict_rec.py
@@ -65,14 +65,18 @@ def __init__(self, args):
)
model_name = algo_to_model_name[args.rec_algorithm]

# amp_level = 'O2' if args.rec_algorithm.startswith('SVTR') else args.rec_amp_level
amp_level = args.rec_amp_level
if args.rec_algorithm.startswith("SVTR") and amp_level != "O2":
logger.warning(
"SVTR recognition model is optimized for amp_level O2. amp_level for rec model is changed to O2"
)
amp_level = "O2"

if ms.get_context("device_target") == "GPU" and amp_level == "O3":
logger.warning(
"Recognition model prediction does not support amp_level O3 on GPU currently. "
"The program has switched to amp_level O2 automatically."
)
amp_level = "O2"
self.model = build_model(model_name, pretrained=pretrained, ckpt_load_path=ckpt_load_path, amp_level=amp_level)
self.model.set_train(False)

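In `predict_rec.py` two downgrades compose: SVTR models are pinned to O2 first, and the GPU guard then applies to whatever level remains. A pure-function sketch of that combined decision (hypothetical helper with illustrative names, not part of the codebase):

```python
def rec_amp_level(requested: str, algorithm: str, device_target: str) -> str:
    """Effective amp_level for the recognition model: SVTR is forced to O2,
    and O3 falls back to O2 on GPU."""
    level = "O2" if algorithm.startswith("SVTR") else requested
    if device_target == "GPU" and level == "O3":
        level = "O2"
    return level
```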
6 changes: 6 additions & 0 deletions tools/train.py
@@ -107,6 +107,12 @@ def main(cfg):

# create model
amp_level = cfg.system.get("amp_level", "O0")
if ms.get_context("device_target") == "GPU" and cfg.system.val_while_train and amp_level == "O3":
logger.warning(
"Model evaluation does not support amp_level O3 on GPU currently. "
"The program has switched to amp_level O2 automatically."
)
amp_level = "O2"
network = build_model(cfg.model, ckpt_load_path=cfg.model.pop("pretrained", None), amp_level=amp_level)
num_params = sum([param.size for param in network.get_parameters()])
num_trainable_params = sum([param.size for param in network.trainable_params()])