Commit 753a66b

10 version

yana-xuyan committed Apr 4, 2020
1 parent 1c4a32e
Showing 7 changed files with 82 additions and 74 deletions.
76 changes: 40 additions & 36 deletions build/lib/caireCovid/qa.py
@@ -9,10 +9,14 @@


 class QaModule():
-    def __init__(self, model_name):
+    def __init__(self, model_name, model_path, spiece_model, bert_config, bert_vocab):
         # init QA models
         self.model_name = model_name
-        self.predict_fn = self.getPredictors()
+        self.model_path = model_path
+        self.spiece_model = spiece_model
+        self.bert_config = bert_config
+        self.bert_vocab = bert_vocab
+        self.getPredictors()

     def readIR(self, data):
         synthetic = []
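Note: the constructor no longer hardcodes model locations; callers now pass them in. A hypothetical call, assembled from the paths the old code hardcoded (the positional pairing of model_name and model_path follows the new getModelPath below):

    qa_model = QaModule(
        model_name=["mrqa", "biobert"],
        model_path=[
            "/kaggle/input/pretrained-qa-models/mrqa/1564469515",
            "/kaggle/input/pretrained-qa-models/biobert/1585470591",
        ],
        spiece_model="./mrqa/model/spiece.model",
        bert_config="./biobert/model/bert_config.json",
        bert_vocab="./biobert/model/vocab.txt",
    )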
@@ -38,6 +42,7 @@ def readIR(self, data):
                 "doi": doi,
                 "title": title,
             }
+
             data_sample["qas"].append(qas_item)
         synthetic.append(data_sample)

@@ -57,6 +62,7 @@ def getPredictors(self):
            self.bio_predict_fn = self.getPredictor("biobert")

     def getPredictor(self, model_name):
+        modelpath = self.getModelPath(model_name)
         if model_name == 'mrqa':
             d = {
                 "uncased": False,
@@ -66,24 +72,24 @@ def getPredictor(self, model_name):
                 "train_batch_size": 1,
                 "predict_batch_size": 1,
                 "shuffle_buffer": 2048,
-                "spiece_model_file": "./mrqa/model/spiece.model",
+                "spiece_model_file": self.spiece_model,
                 "max_seq_length": 512,
                 "doc_stride": 128,
                 "max_query_length": 64,
                 "n_best_size": 5,
                 "max_answer_length": 64,
             }
             self.mrqaFLAGS = namedtuple("FLAGS", d.keys())(*d.values())
-            return tf.contrib.predictor.from_saved_model("/kaggle/input/pretrained-qa-models/mrqa/1564469515")
+            return tf.contrib.predictor.from_saved_model(modelpath)
         elif model_name == 'biobert':
             d = {
                 "version_2_with_negative": False,
                 "null_score_diff_threshold": 0.0,
                 "verbose_logging": False,
                 "init_checkpoint": None,
                 "do_lower_case": False,
-                "bert_config_file": "./biobert/model/bert_config.json",
-                "vocab_file": "./biobert/model/vocab.txt",
+                "bert_config_file": self.bert_config,
+                "vocab_file": self.bert_vocab,
                 "train_batch_size": 1,
                 "predict_batch_size": 1,
                 "max_seq_length": 384,
@@ -93,9 +99,13 @@ def getPredictor(self, model_name):
                 "max_answer_length": 30,
             }
             self.bioFLAGS = namedtuple("FLAGS", d.keys())(*d.values())
-            return tf.contrib.predictor.from_saved_model("/kaggle/input/pretrained-qa-models/biobert/1585470591")
+            return tf.contrib.predictor.from_saved_model(modelpath)
         else:
             raise ValueError("invalid model name")

+    def getModelPath(self, model_name):
+        index = self.model_name.index(model_name)
+        return self.model_path[index]
+
     def getAnswers(self, data):
         """
@@ -163,27 +173,30 @@ def getAnswers(self, data):
             if "biobert" in self.model_name:
                 raw_bio = self.biobertPredictor([qa])
                 # get sentence from BioBERT
-                raw = raw_bio[qa["qas"][0]["id"]]
-                # question answering one by one
-                answer_start = context.find(raw, 0)
-                answer_end = answer_start + len(raw)
-                answer_span = []
-                for idx, span in enumerate(spans):
-                    if not (answer_end <= span[0] or answer_start >= span[1]):
-                        answer_span.append(idx)
-
-                y1, y2 = answer_span[0], answer_span[-1]
-                if not y1 == y2:
-                    # context tokens in index y1 and y2 should be merged together
-                    # print("Merge knowledge sentence")
-                    answer_sent_bio = " ".join(sents[y1:y2+1])
+                raw = raw_bio[qa["qas"][0]["id"]]
+                if raw == "empty" or "":
+                    answer_sent_bio = ""
                 else:
-                    answer_sent_bio = sents[y1]
-
-                # if raw not in answer_sent_bio:
-                #     print("RAW", raw)
-                #     print("BIO", answer_sent_bio)
-                # assert raw in answer_sent_bio
+                    # question answering one by one
+                    answer_start = context.find(raw, 0)
+                    answer_end = answer_start + len(raw)
+                    answer_span = []
+                    for idx, span in enumerate(spans):
+                        if not (answer_end <= span[0] or answer_start >= span[1]):
+                            answer_span.append(idx)
+
+                    y1, y2 = answer_span[0], answer_span[-1]
+                    if not y1 == y2:
+                        # context tokens in index y1 and y2 should be merged together
+                        # print("Merge knowledge sentence")
+                        answer_sent_bio = " ".join(sents[y1:y2+1])
+                    else:
+                        answer_sent_bio = sents[y1]
+
+                    # if raw not in answer_sent_bio:
+                    #     print("RAW", raw)
+                    #     print("BIO", answer_sent_bio)
+                    assert raw in answer_sent_bio
             else:
                 answer_sent_bio = ""

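Note on the new guard: in Python, raw == "empty" or "" parses as (raw == "empty") or "", and a bare "" is falsy, so the branch fires only for the literal string "empty" and never for an empty prediction. If the intent is to catch both cases, the test would be:

    if raw == "empty" or raw == "":   # equivalently: raw in ("empty", "")
        answer_sent_bio = ""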
@@ -198,15 +211,6 @@ def getAnswers(self, data):
             answer_sent= " ".join([answer_sent_mrqa, answer_sent_bio])

             answers[-1]["data"]["answer"].append(answer_sent)
-
-
-            # print("context:", context)
-            # print("-"*80)
-            # print("query:", question)
-            # print("-"*80)
-            # print("answer:", answer_sent)
-            # input()
-            # break
         return answers

     def convert_idx(self, text, tokens):
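For reference, the sentence-selection logic in getAnswers locates the model's raw answer in the context by character offsets, then keeps every sentence whose span overlaps that range, merging adjacent ones. A self-contained sketch (variable names mirror the diff; sents and spans are assumed to come from the module's sentence splitting, with spans holding per-sentence (start, end) character offsets):

    context = "Coronaviruses are RNA viruses. They infect many species."
    sents = ["Coronaviruses are RNA viruses.", "They infect many species."]
    spans = [(0, 30), (31, 57)]

    raw = "RNA viruses"                       # raw answer returned by the model
    answer_start = context.find(raw, 0)
    answer_end = answer_start + len(raw)
    answer_span = [idx for idx, span in enumerate(spans)
                   if not (answer_end <= span[0] or answer_start >= span[1])]
    y1, y2 = answer_span[0], answer_span[-1]
    answer_sent = " ".join(sents[y1:y2 + 1])  # "Coronaviruses are RNA viruses."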
2 changes: 1 addition & 1 deletion caireCovid.egg-info/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: caireCovid
-Version: 0.0.7
+Version: 0.1.0
 Summary: system for covid-19.
 Home-page: https://github.com/yana-xuyan/caire-covid
 Author: yana
76 changes: 40 additions & 36 deletions caireCovid/qa.py
(Changes identical to build/lib/caireCovid/qa.py above.)
Binary file removed dist/caireCovid-0.0.7.tar.gz
Binary file added dist/caireCovid-0.1.0.tar.gz
2 changes: 1 addition & 1 deletion setup.py
@@ -5,7 +5,7 @@

 setuptools.setup(
     name="caireCovid",
-    version="0.0.7",
+    version="0.1.0",
     author="yana",
     author_email="[email protected]",
     description="system for covid-19.",
