feat: add entity embedding (#15)
dataloader error
eenzeenee committed Nov 23, 2022
1 parent 4d56d6f commit e3d2e4e
Showing 17 changed files with 77 additions and 17 deletions.
Binary file modified code/__pycache__/load_data.cpython-38.pyc
Binary file not shown.
15 changes: 10 additions & 5 deletions dataset/datacheck.ipynb
@@ -352,7 +352,7 @@
},
{
"cell_type": "code",
"execution_count": 51,
"execution_count": 55,
"metadata": {},
"outputs": [],
"source": [
@@ -381,25 +381,30 @@
},
{
"cell_type": "code",
"execution_count": 52,
"execution_count": 56,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'호남이 기반인 바른미래당·<object>대안신당<\\\\object>·민주평화당이 우여곡절 끝에 합당해 민생당(가칭)으로 재탄생한다.'"
"'호남이 기반인 바른미래당·<obj>대안신당<\\\\obj>·<sub>민주평화당</sub>이 우여곡절 끝에 합당해 민생당(가칭)으로 재탄생한다.'"
]
},
"execution_count": 52,
"execution_count": 56,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# object tag 11\n",
"sent"
"sent[:19+11] + '<sub>' + sent[19+11:24+11] + '</sub>' + sent[24+11:]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": []
},
{
"cell_type": "code",
"execution_count": null,
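The cell above splices <obj>/<sub> markers into the sentence by character offset, shifting the subject span by the 11 characters that the object tags add. A minimal sketch of the same index arithmetic; the helper name and tag strings are illustrative, not from the repository:

def wrap_span(sentence: str, start: int, end: int, open_tag: str, close_tag: str) -> str:
    """Wrap sentence[start:end] in open_tag/close_tag."""
    return sentence[:start] + open_tag + sentence[start:end] + close_tag + sentence[end:]

sent = '호남이 기반인 바른미래당·대안신당·민주평화당이 우여곡절 끝에 합당해 민생당(가칭)으로 재탄생한다.'

# The object entity '대안신당' occupies characters [14:18); wrapping it adds
# len('<obj>') + len('</obj>') == 11 characters, so the subject offsets
# (19 and 24 in the raw sentence) must be shifted by that amount -- the
# "+11" in the cell above.
shift = len('<obj>') + len('</obj>')  # 11
sent = wrap_span(sent, 14, 18, '<obj>', '</obj>')
sent = wrap_span(sent, 19 + shift, 24 + shift, '<sub>', '</sub>')
print(sent)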
Binary file added template/__pycache__/dataloader.cpython-38.pyc
Binary file not shown.
Binary file added template/__pycache__/losses.cpython-38.pyc
Binary file not shown.
Binary file added template/__pycache__/metrics.cpython-38.pyc
Binary file not shown.
Binary file added template/__pycache__/models.cpython-38.pyc
Binary file not shown.
50 changes: 48 additions & 2 deletions template/dataloader.py
@@ -1,5 +1,6 @@
import os
import re
import sys
import argparse
import ast

@@ -94,31 +95,75 @@ def add_entity_token(self, item: pd.Series):
types = item['types']

slide_size = 0
# print(item[0])
# print(item[1])
# sys.exit()
# print(item['subject_entity'])
# print(item['object_entity'])
# print()
for i, entity in enumerate([item['subject_entity'], item['object_entity']]):
print(i, entity)
special_token_pair = f'[{types[i]}]', f'[/{types[i]}]'
attached = special_token_pair[0] + entity + special_token_pair[1]
if special_token_pair not in self.special_tokens:
self.special_tokens += special_token_pair

sentence = sentence[:ids[i]+slide_size] + attached + sentence[ids[i]+len(entity)+slide_size:]
slide_size += len(f'[{types[i]}]' + f'[/{types[i]}]')
if ids[0] < ids[1]:
slide_size += len(f'[{types[i]}]' + f'[/{types[i]}]')
print('here')
return sentence

def tokenizing(self, df: pd.DataFrame) -> List[dict]:
data = []
#[added]
# tmp = ' '.join(self.special_tokens)
# print(tmp)
# print(self.tokenizer.encode(tmp))
# # print(self.tokenizer)
# sys.exit()

for idx, item in tqdm(df.iterrows(), desc='tokenizing', total=len(df)):
# concat_entity = '[SEP]'.join([item[column] for column in self.using_columns])
concat_entity = self.add_entity_token(item)

# print(concat_entity)
outputs = self.tokenizer(
concat_entity,
add_special_tokens=True,
padding='max_length',
truncation=True,
max_length=256
)
# print(outputs)
# sys.exit()
# outputs['entity_mask'] = [0]*len(outputs['input_ids'])
# print(outputs)
# [ORG] [/ORG] [PER] [/PER] [POH] [/POH] [LOC] [/LOC] [DAT] [/DAT] [NOH] [/NOH]
# [32000, 32001, 32002, 32003, 32004, 32005, 32006, 32007, 32008, 32009, 32010, 32011]

# [added] - add entity mask
outputs['entity_mask'] = []
entity_mask = 0
for i in range(len(outputs['input_ids'])):
idx = outputs['input_ids'][i]
if idx in list(range(32000,32011,2)):
# print(list(range(32000,32011,2)))
entity_mask = 1
outputs['entity_mask'].append(entity_mask)
elif idx in list(range(32001,32012,2)):
outputs['entity_mask'].append(entity_mask)
entity_mask = 0
else:
outputs['entity_mask'].append(entity_mask)

data.append(outputs)
if len(data) == 3:
# print(data[1]['input_ids'])
# print(data[1]['entity_mask'])
# print()
# print(data[2]['input_ids'])
# print(data[2]['entity_mask'])
sys.exit()
return data

def preprocessing(self, df: pd.DataFrame):
@@ -147,6 +192,7 @@ def preprocessing(self, df: pd.DataFrame):
'types': types,
'label': df['label'],
})
# print(preprocessed_df.iloc[2])

inputs = self.tokenizing(preprocessed_df)
targets = self.label_to_num(preprocessed_df['label'])
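The new add_entity_token wraps each entity with type-specific markers such as [ORG]…[/ORG], and tokenizing then derives an entity_mask that is 1 for tokens lying between an opening marker and its closing marker. A standalone sketch of that mask construction, assuming (as the in-code comment suggests) that the twelve markers receive the consecutive ids 32000–32011, openers on even ids and closers on odd ids:

OPEN_IDS = set(range(32000, 32012, 2))   # assumed ids for [ORG], [PER], [POH], [LOC], [DAT], [NOH]
CLOSE_IDS = set(range(32001, 32012, 2))  # assumed ids for [/ORG], [/PER], ...

def build_entity_mask(input_ids):
    """Return a 0/1 mask that is 1 from an opening marker through its closing marker."""
    mask, inside = [], 0
    for token_id in input_ids:
        if token_id in OPEN_IDS:
            inside = 1
            mask.append(inside)
        elif token_id in CLOSE_IDS:
            mask.append(inside)  # keep the closing marker inside the span
            inside = 0
        else:
            mask.append(inside)
    return mask

# e.g. [CLS] ... [ORG] tok tok [/ORG] ... [SEP]
print(build_entity_mask([0, 100, 32000, 7, 8, 32001, 101, 2]))
# -> [0, 0, 1, 1, 1, 1, 0, 0]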
Binary file not shown.
3 changes: 3 additions & 0 deletions template/lightning_logs/version_0/hparams.yaml
@@ -0,0 +1,3 @@
lr: 1.0e-05
model_name: klue/roberta-large
pooling: true
Binary file not shown.
3 changes: 3 additions & 0 deletions template/lightning_logs/version_1/hparams.yaml
@@ -0,0 +1,3 @@
lr: 1.0e-05
model_name: klue/roberta-large
pooling: true
3 changes: 3 additions & 0 deletions template/models.py
@@ -1,5 +1,6 @@
import argparse
from copy import deepcopy
import sys

import torch
import numpy as np
@@ -55,6 +56,8 @@ def forward(self, x: Dict[str, torch.Tensor]) -> torch.Tensor:

def training_step(self, batch, batch_idx):
x, y = batch
print(x)
sys.exit()
logits = self(x)
loss = self.criterion(logits, y)
self.log("train_loss", loss)
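The only model change in this commit is the debug print above; how the entity_mask is ultimately consumed is not shown. One plausible use, sketched purely as a hypothetical (not the repository's implementation), is to pool the encoder's hidden states over the masked entity tokens:

import torch

def entity_average_pool(hidden_states: torch.Tensor, entity_mask: torch.Tensor) -> torch.Tensor:
    """Average the hidden states of tokens flagged by entity_mask.

    hidden_states: (batch, seq_len, hidden)
    entity_mask:   (batch, seq_len), 1 on entity-span tokens
    """
    mask = entity_mask.unsqueeze(-1).float()    # (batch, seq_len, 1)
    summed = (hidden_states * mask).sum(dim=1)  # (batch, hidden)
    counts = mask.sum(dim=1).clamp(min=1.0)     # avoid division by zero
    return summed / counts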
20 changes: 10 additions & 10 deletions template/train.py
@@ -3,7 +3,7 @@
import argparse

import torch
import wandb
# import wandb
import transformers
import torchmetrics
import pandas as pd
@@ -36,14 +36,14 @@
parser.add_argument('--predict_path', default='../dataset/test/test_data.csv')
args = parser.parse_args(args=[])

try:
wandb.login(key='4c0a01eaa2bd589d64c5297c5bc806182d126350')
except:
anony = "must"
print('If you want to use your W&B account, go to Add-ons -> Secrets and provide your W&B access token. Use the Label name as wandb_api. \nGet your W&B access token from here: https://wandb.ai/authorize')
# try:
# wandb.login(key='4c0a01eaa2bd589d64c5297c5bc806182d126350')
# except:
# anony = "must"
# print('If you want to use your W&B account, go to Add-ons -> Secrets and provide your W&B access token. Use the Label name as wandb_api. \nGet your W&B access token from here: https://wandb.ai/authorize')

wandb.init(project="level2", name= f"{args.model_name}-pooling-focal-typing_entity")
wandb_logger = WandbLogger('level2')
# wandb.init(project="level2", name= f"{args.model_name}-pooling-focal-typing_entity")
# wandb_logger = WandbLogger('level2')

dataloader = Dataloader(
args.tokenizer_name,
@@ -69,12 +69,12 @@

trainer = pl.Trainer(
accelerator='gpu',
devices=4,
devices=1,
max_epochs=args.max_epoch,
log_every_n_steps=1,
num_sanity_val_steps=0,
precision=16,
logger=wandb_logger
# logger=wandb_logger
)

# Train part
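The commit disables W&B by commenting the calls out, which leaves a hard-coded API key in the history. A hedged alternative sketch (not part of this commit): read the key from an environment variable and turn the logger off when it is absent; args and pl refer to the parsed arguments and the pytorch_lightning import already present in train.py:

import os
import wandb
from pytorch_lightning.loggers import WandbLogger

wandb_key = os.environ.get('WANDB_API_KEY')  # hypothetical: key supplied via the environment
if wandb_key:
    wandb.login(key=wandb_key)
    logger = WandbLogger(project='level2', name=f'{args.model_name}-pooling-focal-typing_entity')
else:
    logger = False  # Trainer(logger=False) disables experiment logging

trainer = pl.Trainer(
    accelerator='gpu',
    devices=1,
    max_epochs=args.max_epoch,
    log_every_n_steps=1,
    num_sanity_val_steps=0,
    precision=16,
    logger=logger,
)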
