Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add crf and lstm-crf example; update the config file #84

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 25 additions & 15 deletions research/huawei-gts/CRF/default_config.yaml
Original file line number Diff line number Diff line change
@@ -1,27 +1,37 @@
# --------------------------------------------------
# Builtin Configurations
data_path: "../conll2003"
# Location of the training data
data_path: '../conll2003'
# Name of the output ckpt file
ckpt_save_path: '../ckpt_lstm_crf'
# Device type: Ascend or CPU. Default: Ascend
device_target: 'CPU'
device_id: 1
enable_profiling: False
# When device_target is Ascend, selects the NPU that takes part in the computation
device_id: 2
# Prefix of the exported file
export_prefix: 'model-crf'
export_suffix: ''

# --------------------------------------------------
# LSTM_CRF CONFIG
num_epochs: 20
batch_size: 20
embed_size: 300
num_hidden: 320
# Number of training epochs (default 20)
num_epochs: 2
# Amount of data processed per step, i.e. the batch size (default 16)
batch_size: 16
# Maximum length — presumably the padded token-sequence length; TODO confirm
vocab_max_length: 113
# Size of embedding_dim (default 384)
embedding_dim: 128
# Size of hidden_dim (default 384)
hidden_dim: 128
num_layers: 2
bidirectional: True

# train.py
device_num: 1
data_CoNLL_path: "../data/conll2003"
learning_ratelearning_rate: True
# A learning rate that is too small makes it hard for the loss to converge; one that is
# too large can drive the prediction accuracy to 0 (training for more epochs can help)
learning_rate: 0.00001

# export.py
model_format: "MINDIR"
model_path: "./"
ckpt_path: "./lstm-crf.ckpt"

---
device_target: ['Ascend', 'CPU']
file_format: ['AIR', 'MINDIR']
file_format: ['AIR', 'MINDIR']
10 changes: 4 additions & 6 deletions research/huawei-gts/CRF/scripts/export_ascend.sh
Original file line number Diff line number Diff line change
@@ -1,23 +1,21 @@
#!/bin/bash
# Export helper: prints a usage banner, reads three positional arguments and
# runs ../src/export.py with ../default_config.yaml on an Ascend device.
echo "=============================================================================================================="
echo "Please run the script as: "
echo "bash export_ascend.sh DEVICE_ID TRAIN_DATA_FOLDER CKPT_FILE"
echo "for example: bash export_ascend.sh 0 ../conll2003 lstm_crf.ckpt"
echo "bash export_ascend.sh DEVICE_ID TRAIN_DATA_FOLDER EXPORT_SUFFIX"
echo "for example: bash export_ascend.sh 0 ../conll2003 my_suffix"
echo "=============================================================================================================="

# Positional arguments.
# NOTE(review): CKPT_FILE and EXPORT_SUFFIX are both read from $3 — confirm
# which of the two meanings the third argument is supposed to carry.
DEVICE_ID=$1
TRAIN_DATA_FOLDER=$2
CKPT_FILE=$3
EXPORT_SUFFIX=$3

# Directory containing this script (resolved relative to the current working
# directory); the config file is looked up one level above it.
BASE_PATH=$(cd ./"`dirname $0`" || exit; pwd)
CONFIG_FILE="${BASE_PATH}/../default_config.yaml"


# Run the export entry point. Command-line values are passed alongside the
# YAML config file given via --config_path.
python "${BASE_PATH}/../src/export.py" \
--config_path=$CONFIG_FILE \
--device_target="Ascend" \
--model_format="ckpt" \
--device_id=${DEVICE_ID}\
--data_path=${TRAIN_DATA_FOLDER}\
--ckpt_path=${CKPT_FILE}
--export_suffix=${EXPORT_SUFFIX}
# --ckpt_path=${CKPT_FILE} > log_export.txt 2>&1 &
Empty file.
26 changes: 26 additions & 0 deletions research/huawei-gts/CRF/src/example/crf_decode.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
import sys
from pathlib import Path

# Put the current working directory on the import path so that the `model`
# package resolves (presumably the script is launched from its parent folder).
sys.path.append(str(Path.cwd()))
print(f'======CurrentPath: {Path.cwd()}')
import mindspore as ms

# Select the MindSpore execution target (Ascend, CPU or GPU); graph mode is
# recommended. Note: mindspore is imported and configured before the model
# imports on purpose, so that the context setting takes effect.
ms.set_context(device_target="CPU", mode=ms.GRAPH_MODE)
from model.lstm_crf_model import CRF
import mindspore.numpy as mnp

if __name__ == '__main__':
    # Tag-to-index mapping; a BIO or BIOES scheme can be used as required.
    tag_to_idx = {"B": 0, "I": 1, "O": 2}
    num_tags = len(tag_to_idx)

    # Build the CRF layer; it only needs the number of tags.
    crf = CRF(num_tags)

    # Random emissions laid out as (seq_length, batch_size, num_tags).
    seq_length = 3
    batch_size = 2
    emissions = mnp.randn(seq_length, batch_size, num_tags)

    # Calling the layer without tags yields the decode outputs.
    score, history = crf(emissions)

    # NOTE(review): every sequence length is given as 1 although the emissions
    # span `seq_length` (3) steps — confirm whether `seq_length` was intended.
    decode_lengths = mnp.full(batch_size, 1)
    best_tags_list = CRF.post_decode(score, history, decode_lengths)
    print(best_tags_list)
22 changes: 22 additions & 0 deletions research/huawei-gts/CRF/src/example/crf_init.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
import sys
from pathlib import Path

# Put the current working directory on the import path so that the `model`
# package resolves (presumably the script is launched from its parent folder).
sys.path.append(str(Path.cwd()))
print(f'======CurrentPath: {Path.cwd()}')

import mindspore as ms

# Select the MindSpore execution target (Ascend, CPU or GPU); graph mode is
# recommended. Note: mindspore is imported and configured before the model
# imports on purpose, so that the context setting takes effect.
ms.set_context(device_target="CPU", mode=ms.GRAPH_MODE)

from model.lstm_crf_model import CRF

if __name__ == '__main__':
    # Tag-to-index mapping; a BIO or BIOES scheme can be used as required.
    tag_to_idx = {"B": 0, "I": 1, "O": 2}
    num_tags = len(tag_to_idx)

    # Build the CRF layer; it only needs the number of tags.
    crf = CRF(num_tags)

    # Initialisation is done; print the layer to inspect it.
    print(crf)
27 changes: 27 additions & 0 deletions research/huawei-gts/CRF/src/example/crf_loss.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
import sys
from pathlib import Path

# Put the current working directory on the import path so that the `model`
# package resolves (presumably the script is launched from its parent folder).
sys.path.append(str(Path.cwd()))
print(f'======CurrentPath: {Path.cwd()}')
import mindspore as ms

# Select the MindSpore execution target (Ascend, CPU or GPU); graph mode is
# recommended. Note: mindspore is imported and configured before the model
# imports on purpose, so that the context setting takes effect.
ms.set_context(device_target="CPU", mode=ms.GRAPH_MODE)
from model.lstm_crf_model import CRF

import mindspore.numpy as mnp

if __name__ == '__main__':
    # Tag-to-index mapping; a BIO or BIOES scheme can be used as required.
    tag_to_idx = {"B": 0, "I": 1, "O": 2}
    num_tags = len(tag_to_idx)

    # Build the CRF layer; it only needs the number of tags.
    crf = CRF(num_tags)

    # Inputs for the loss: random emissions laid out as
    # (seq_length, batch_size, num_tags) and one gold tag per step and sample.
    seq_length = 3
    batch_size = 2
    emissions = mnp.randn(seq_length, batch_size, num_tags)
    gold_tags = [[0, 1], [1, 2], [2, 1]]
    tags = mnp.array(gold_tags)

    # Calling the layer with tags yields the loss.
    loss = crf(emissions, tags)
    print(loss)
69 changes: 69 additions & 0 deletions research/huawei-gts/CRF/src/example/lstm_crf_ckpt_create.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
import sys
from pathlib import Path

import numpy as np
from tqdm import tqdm

# Put the current working directory on the import path so that the `model`
# and `utils` packages resolve.
sys.path.append(f'{Path.cwd()}')
print(f'======CurrentPath: {Path.cwd()}')
import mindspore as ms
import mindspore.dataset as ds
import mindspore.nn as nn

# Select the MindSpore execution target (Ascend, CPU or GPU); graph mode is
# recommended. Note: the context is configured before the model imports on
# purpose, so that the setting takes effect.
ms.set_context(device_target="CPU", mode=ms.GRAPH_MODE)
from model.lstm_crf_model import BiLSTM_CRF
from utils.dataset import read_data, GetDatasetGenerator, get_dict, COLUMN_NAME
from utils.config import config

if __name__ == '__main__':
    # Step 1: hyper-parameters taken from the shared config.
    embedding_dim = config.embedding_dim
    hidden_dim = config.hidden_dim
    batch_size = config.batch_size

    # BIO tagging scheme over four entity classes: PER (person), LOC
    # (location), ORG (organisation) and MISC (miscellaneous). Each class
    # contributes a B- (begin) and an I- (inside) tag, plus one shared O
    # (outside) tag, hence 2 * len(Entity) + 1 tags in total.
    Entity = ['PER', 'LOC', 'ORG', 'MISC']
    num_tags = len(Entity) * 2 + 1

    # Step 2: read the training split and build the batched dataset.
    train_dataset = read_data('../../conll2003/train.txt')
    _, id_indexs = get_dict(train_dataset[0])

    train_dataset_generator = GetDatasetGenerator(train_dataset, id_indexs)
    train_dataset_ds = ds.GeneratorDataset(train_dataset_generator, COLUMN_NAME, shuffle=False)
    train_dataset_batch = train_dataset_ds.batch(batch_size, drop_remainder=True)

    # Step 3: create the model, the optimizer and the gradient function.
    model = BiLSTM_CRF(vocab_size=len(id_indexs), embedding_dim=embedding_dim, hidden_dim=hidden_dim,
                       num_tags=num_tags)
    optimizer = nn.Adam(model.trainable_params(), learning_rate=config.learning_rate)
    grad_fn = ms.value_and_grad(model, None, optimizer.parameters)

    # Step 4: a single optimisation step.
    def train_step(token_ids, seq_len, labels):
        """Run forward/backward on one batch, apply the update, return the loss."""
        loss, grads = grad_fn(token_ids, seq_len, labels)
        optimizer(grads)
        return loss

    # Step 5: train. `tloss` is never reset, so the value shown in the
    # progress bar is the mean loss over all steps of all epochs so far.
    tloss = []
    steps_per_epoch = train_dataset_batch.get_dataset_size()
    for _ in range(config.num_epochs):
        model.set_train()
        with tqdm(total=steps_per_epoch) as t:
            for token_ids, seq_len, labels in train_dataset_batch.create_tuple_iterator():
                loss = train_step(token_ids, seq_len, labels)
                tloss.append(loss.asnumpy())
                t.set_postfix(loss=np.array(tloss).mean())
                t.update(1)

    # Step 6: save the trained parameters as a checkpoint.
    file_name = 'lstm_crf.ckpt'
    ms.save_checkpoint(model, ckpt_file_name=file_name)
    print(f'======Create CKPT SUCCEEDED, file: {file_name}')
62 changes: 62 additions & 0 deletions research/huawei-gts/CRF/src/example/lstm_crf_ckpt_predict.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
import sys
from pathlib import Path

from tqdm import tqdm

# Put the current working directory on the import path so that the `model`
# and `utils` packages resolve.
sys.path.append(f'{Path.cwd()}')
print(f'======CurrentPath: {Path.cwd()}')
import mindspore as ms
import mindspore.dataset as ds
from mindspore.train.serialization import load_checkpoint, load_param_into_net

# Select the MindSpore execution target (Ascend, CPU or GPU); graph mode is
# recommended. Note: the context is configured before the model imports on
# purpose, so that the setting takes effect.
ms.set_context(device_target="CPU", mode=ms.GRAPH_MODE)
from model.lstm_crf_model import BiLSTM_CRF, CRF
from utils.dataset import read_data, GetDatasetGenerator, get_dict, COLUMN_NAME, get_entity
from utils.metrics import get_metric
from utils.config import config

if __name__ == '__main__':
    # Step 1: runtime parameters taken from the shared config.
    batch_size = config.batch_size

    # BIO tagging scheme over four entity classes: PER (person), LOC
    # (location), ORG (organisation) and MISC (miscellaneous); B marks the
    # beginning of an entity, I its inside, O a non-entity token.
    Entity = ['PER', 'LOC', 'ORG', 'MISC']

    # Step 2: load the checkpoint from the given file.
    file_name = 'lstm_crf.ckpt'
    param_dict = load_checkpoint(file_name)

    # Step 3: recover the model dimensions from the stored embedding table.
    embedding_shape = param_dict.get('embedding.embedding_table').shape
    vocab_size = embedding_shape[0]
    embedding_dim = embedding_shape[1]

    # Step 4: build the model.
    # NOTE(review): hidden_dim is taken from the embedding width, which only
    # matches the checkpoint when it was trained with
    # hidden_dim == embedding_dim — confirm, or read it from the config.
    model = BiLSTM_CRF(vocab_size=vocab_size, embedding_dim=embedding_dim, hidden_dim=embedding_dim,
                       num_tags=len(Entity) * 2 + 1)

    # Step 5: load the checkpoint parameters into the model.
    load_param_into_net(model, param_dict)
    print(model)

    # Step 6: read the data. The token index is built from the training split
    # and then used to encode the test split.
    train_dataset = read_data('../../conll2003/train.txt')
    test_dataset = read_data('../../conll2003/test.txt')
    _, id_indexs = get_dict(train_dataset[0])

    test_dataset_generator = GetDatasetGenerator(test_dataset, id_indexs)
    test_dataset_ds = ds.GeneratorDataset(test_dataset_generator, COLUMN_NAME, shuffle=False)
    # NOTE(review): drop_remainder=True discards a trailing partial batch, so
    # `decodes` may hold fewer sentences than the generator passed to
    # get_metric — confirm that get_metric tolerates this.
    test_dataset_batch = test_dataset_ds.batch(batch_size, drop_remainder=True)

    # Step 7: predict batch by batch and collect one tag list per sentence.
    decodes = []
    model.set_train(False)
    with tqdm(total=test_dataset_batch.get_dataset_size()) as t:
        for token_ids, seq_len, _labels in test_dataset_batch.create_tuple_iterator():
            score, history = model(token_ids, seq_len)
            best_tag = CRF.post_decode(score, history, seq_len)
            decodes.extend(list(tags) for tags in best_tag)
            t.update(1)

    # Convert the tag sequences to entities and report the metrics.
    pred = [get_entity(x) for x in decodes]
    get_metric(pred, test_dataset_generator)
22 changes: 22 additions & 0 deletions research/huawei-gts/CRF/src/example/lstm_crf_init.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
import sys
from pathlib import Path

# Put the current working directory on the import path so that the `model`
# and `utils` packages resolve.
sys.path.append(str(Path.cwd()))
print(f'======CurrentPath: {Path.cwd()}')
import mindspore as ms

# Select the MindSpore execution target (Ascend, CPU or GPU); graph mode is
# recommended. Note: mindspore is imported and configured before the model
# imports on purpose, so that the context setting takes effect.
ms.set_context(device_target="CPU", mode=ms.GRAPH_MODE)
from model.lstm_crf_model import BiLSTM_CRF
from utils.config import config

if __name__ == '__main__':
    # Tag-to-index mapping; a BIO or BIOES scheme can be used as required.
    tag_to_idx = {"B": 0, "I": 1, "O": 2}

    # Vocabulary size used for this demonstration.
    len_id_index = 1024

    # Build the model from the demo sizes above and the configured dimensions.
    model_kwargs = {
        'vocab_size': len_id_index,
        'embedding_dim': config.embedding_dim,
        'hidden_dim': config.hidden_dim,
        'num_tags': len(tag_to_idx),
    }
    model = BiLSTM_CRF(**model_kwargs)
    print(model)
71 changes: 71 additions & 0 deletions research/huawei-gts/CRF/src/example/lstm_crf_mindir_create.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
import sys
from pathlib import Path

import numpy as np
from tqdm import tqdm

# Put the current working directory on the import path so that the `model`
# and `utils` packages resolve.
sys.path.append(f'{Path.cwd()}')
print(f'======CurrentPath: {Path.cwd()}')
import mindspore as ms
import mindspore.ops as ops
import mindspore.dataset as ds
import mindspore.nn as nn

# Select the MindSpore execution target (Ascend, CPU or GPU); graph mode is
# recommended. Note: the context is configured before the model imports on
# purpose, so that the setting takes effect.
ms.set_context(device_target="CPU", mode=ms.GRAPH_MODE)
from model.lstm_crf_model import BiLSTM_CRF
from utils.dataset import read_data, GetDatasetGenerator, get_dict, COLUMN_NAME
from utils.config import config

if __name__ == '__main__':
    # Step 1: hyper-parameters taken from the shared config.
    embedding_dim = config.embedding_dim
    hidden_dim = config.hidden_dim
    Max_Len = config.vocab_max_length
    batch_size = config.batch_size

    # BIO tagging scheme over four entity classes: PER (person), LOC
    # (location), ORG (organisation) and MISC (miscellaneous). Each class
    # contributes a B- (begin) and an I- (inside) tag, plus one shared O
    # (outside) tag, hence 2 * len(Entity) + 1 tags in total.
    Entity = ['PER', 'LOC', 'ORG', 'MISC']
    num_tags = len(Entity) * 2 + 1

    # Step 2: read the training split and build the batched dataset.
    train_dataset = read_data('../../conll2003/train.txt')
    _, id_indexs = get_dict(train_dataset[0])

    train_dataset_generator = GetDatasetGenerator(train_dataset, id_indexs)
    train_dataset_ds = ds.GeneratorDataset(train_dataset_generator, COLUMN_NAME, shuffle=False)
    train_dataset_batch = train_dataset_ds.batch(batch_size, drop_remainder=True)

    # Step 3: create the model, the optimizer and the gradient function.
    model = BiLSTM_CRF(vocab_size=len(id_indexs), embedding_dim=embedding_dim, hidden_dim=hidden_dim,
                       num_tags=num_tags)
    optimizer = nn.Adam(model.trainable_params(), learning_rate=config.learning_rate)
    grad_fn = ms.value_and_grad(model, None, optimizer.parameters)

    # Step 4: a single optimisation step.
    def train_step(token_ids, seq_len, labels):
        """Run forward/backward on one batch, apply the update, return the loss."""
        loss, grads = grad_fn(token_ids, seq_len, labels)
        optimizer(grads)
        return loss

    # Step 5: train. `tloss` is never reset, so the value shown in the
    # progress bar is the mean loss over all steps of all epochs so far.
    tloss = []
    steps_per_epoch = train_dataset_batch.get_dataset_size()
    for _ in range(config.num_epochs):
        model.set_train()
        with tqdm(total=steps_per_epoch) as t:
            for token_ids, seq_len, labels in train_dataset_batch.create_tuple_iterator():
                loss = train_step(token_ids, seq_len, labels)
                tloss.append(loss.asnumpy())
                t.set_postfix(loss=np.array(tloss).mean())
                t.update(1)

    # Step 6: export to MindIR. Leave training mode first so the exported
    # graph is the inference graph; the two dummy inputs only fix the input
    # shapes: token ids (batch_size, Max_Len) and sequence lengths (batch_size,).
    model.set_train(False)
    file_name = 'lstm_crf.mindir'
    dummy_token_ids = ops.ones((batch_size, Max_Len), ms.int64)
    dummy_seq_len = ops.ones(batch_size, ms.int64)
    ms.export(model, dummy_token_ids, dummy_seq_len, file_name=file_name,
              file_format='MINDIR')
    print(f'======Create MINDIR SUCCEEDED, file: {file_name}')
Loading