diff --git a/aquilign/align/bertalign/Bertalign.py b/aquilign/align/bertalign/Bertalign.py deleted file mode 100644 index 4cdbe11..0000000 --- a/aquilign/align/bertalign/Bertalign.py +++ /dev/null @@ -1,22 +0,0 @@ -""" -Bertalign initialization -""" - -__author__ = "Jason (bfsujason@163.com)" -__version__ = "1.1.0" - -from aquilign.align.bertalign.encoder import Encoder -# from sonar.inference_pipelines.text import TextToEmbeddingModelPipeline -# See other cross-lingual embedding models at -# https://www.sbert.net/docs/pretrained_models.html -print("Chargement 1") - -models = {0: "distiluse-base-multilingual-cased-v2", 1: "LaBSE", 2: "Sonar"} -# as_list = "" -# for key, value in models.items(): -# as_list += f"{int(key)}\t{value}\n" -# print(as_list) -# nb = input(f'Choose a model:') -model = Encoder(models[int(1)]) - -from aquilign.align.bertalign.aligner import Bertalign diff --git a/aquilign/align/bertalign/aligner.py b/aquilign/align/bertalign/aligner.py index 0f629c3..cbc7f83 100644 --- a/aquilign/align/bertalign/aligner.py +++ b/aquilign/align/bertalign/aligner.py @@ -1,6 +1,5 @@ import numpy as np -from aquilign.align.bertalign.Bertalign import model import aquilign.align.bertalign.corelib as core import aquilign.align.bertalign.utils as utils import torch.nn as nn @@ -8,6 +7,7 @@ class Bertalign: def __init__(self, + model, src, tgt, max_align=3, @@ -17,7 +17,7 @@ def __init__(self, margin=True, len_penalty=True, is_split=False, - ): + device="cpu"): self.max_align = max_align self.top_k = top_k @@ -25,6 +25,9 @@ def __init__(self, self.skip = skip self.margin = margin self.len_penalty = len_penalty + self.device = device + self.model = model + @@ -38,11 +41,11 @@ def __init__(self, assert len(src_sents) != 0, "Problemo" print("Embedding source and target text using {} ...".format(model.model_name)) - src_vecs, src_lens = model.transform(src_sents, max_align - 1) - tgt_vecs, tgt_lens = model.transform(tgt_sents, max_align - 1) + src_vecs, src_lens = self.model.transform(src_sents, max_align - 1) + tgt_vecs, tgt_lens = self.model.transform(tgt_sents, max_align - 1) - self.search_simple_vecs = model.simple_vectorization(src_sents) - self.tgt_simple_vecs = model.simple_vectorization(tgt_sents) + self.search_simple_vecs = self.model.simple_vectorization(src_sents) + self.tgt_simple_vecs = self.model.simple_vectorization(tgt_sents) char_ratio = np.sum(src_lens[0,]) / np.sum(tgt_lens[0,]) @@ -57,7 +60,7 @@ def __init__(self, self.tgt_vecs = tgt_vecs def compute_distance(self): - if torch.cuda.is_available(): # GPU version + if torch.cuda.is_available() and self.device == 'cuda:0': # GPU version cos = nn.CosineSimilarity(dim=1, eps=1e-6) output = cos(torch.from_numpy(self.search_simple_vecs), torch.from_numpy(self.tgt_simple_vecs)) return output @@ -65,7 +68,7 @@ def compute_distance(self): def align_sents(self, first_alignment_only=False): print("Performing first-step alignment ...") - D, I = core.find_top_k_sents(self.src_vecs[0,:], self.tgt_vecs[0,:], k=self.top_k) + D, I = core.find_top_k_sents(self.src_vecs[0,:], self.tgt_vecs[0,:], k=self.top_k, device=self.device) first_alignment_types = core.get_alignment_types(2) # 0-1, 1-0, 1-1 first_w, first_path = core.find_first_search_path(self.src_num, self.tgt_num) first_pointers = core.first_pass_align(self.src_num, self.tgt_num, first_w, first_path, first_alignment_types, D, I) diff --git a/aquilign/align/bertalign/corelib.py b/aquilign/align/bertalign/corelib.py index 85f480f..d900ee0 100644 --- a/aquilign/align/bertalign/corelib.py +++ b/aquilign/align/bertalign/corelib.py @@ -377,7 +377,7 @@ def get_alignment_types(max_alignment_size): alignment_types.append([x, y]) return np.array(alignment_types) -def find_top_k_sents(src_vecs, tgt_vecs, k=3): +def find_top_k_sents(src_vecs, tgt_vecs, k=3, device='cpu'): """ Find the top_k similar vecs in tgt_vecs for each vec in src_vecs. Args: @@ -389,7 +389,7 @@ def find_top_k_sents(src_vecs, tgt_vecs, k=3): I: numpy array. Target index matrix of shape (num_src_sents, k). """ embedding_size = src_vecs.shape[1] - if torch.cuda.is_available() and platform == 'linux': # GPU version + if torch.cuda.is_available() and platform == 'linux' and device != "cpu": # GPU version res = faiss.StandardGpuResources() index = faiss.IndexFlatIP(embedding_size) gpu_index = faiss.index_cpu_to_gpu(res, 0, index) diff --git a/aquilign/align/bertalign/encoder.py b/aquilign/align/bertalign/encoder.py index 40d0fb3..d3e25d4 100644 --- a/aquilign/align/bertalign/encoder.py +++ b/aquilign/align/bertalign/encoder.py @@ -6,8 +6,8 @@ class Encoder: - def __init__(self, model_name): - device = "cuda:0" if torch.cuda.is_available() else "cpu" + def __init__(self, model_name, device): + self.device = "cuda:0" if torch.cuda.is_available() and device != "cpu" else "cpu" if model_name == "LaBSE": self.model = SentenceTransformer(model_name_or_path=model_name, device=device) self.model_name = model_name @@ -21,7 +21,7 @@ def simple_vectorization(self, sents): This function produces a simple vectorisation of a sentence, without taking into account its lenght as transform does """ - sent_vecs = self.model.encode(sents) + sent_vecs = self.model.encode(sents, device=self.device) return sent_vecs def transform(self, sents, num_overlaps): @@ -30,7 +30,7 @@ def transform(self, sents, num_overlaps): overlaps.append(line) if self.model_name == "LaBSE": - sent_vecs = self.model.encode(overlaps) + sent_vecs = self.model.encode(overlaps, device=self.device) else: sents_vecs = self.t2vec_model.predict() embedding_dim = sent_vecs.size // (len(sents) * num_overlaps) diff --git a/main.py b/main.py index d5cca91..78580fe 100644 --- a/main.py +++ b/main.py @@ -9,7 +9,9 @@ import aquilign.align.graph_merge as graph_merge import aquilign.align.bertalign.utils as utils import aquilign.align.bertalign.syntactic_tokenization as syntactic_tokenization -from aquilign.align.bertalign.Bertalign import Bertalign +#from aquilign.align.bertalign.Bertalign import Bertalign +from aquilign.align.bertalign.encoder import Encoder +from aquilign.align.bertalign.aligner import Bertalign import pandas as pd import argparse import glob @@ -60,16 +62,21 @@ class Aligner: La classe Aligner initialise le moteur d'alignement, fondé sur Bertalign """ - def __init__(self, corpus_size:None, + def __init__(self, + model, + corpus_size:None, max_align=3, out_dir="out", use_punctuation=True, input_dir="in", main_wit=None, - prefix=None): + prefix=None, + device="cpu"): + self.model = model self.alignment_dict = dict() self.text_dict = dict() self.files_path = glob.glob(f"{input_dir}/*.txt") + self.device = device print(input_dir) if main_wit is not None: self.main_file_index = [index for index, path in enumerate(self.files_path) if main_wit in path][0] @@ -133,7 +140,14 @@ def parallel_align(self): else: margin = False len_penality = True - aligner = Bertalign(first_tokenized_text, second_tokenized_text, max_align= self.max_align, win=5, skip=-.2, margin=margin, len_penalty=len_penality) + aligner = Bertalign(self.model, + first_tokenized_text, + second_tokenized_text, + max_align= self.max_align, + win=5, skip=-.2, + margin=margin, + len_penalty=len_penality, + device=self.device) aligner.align_sents() # We append the result to the alignment dictionnary @@ -215,15 +229,25 @@ def run_alignments(): help="Pivot witness.") parser.add_argument("-p", "--prefix", default=None, help="Prefix for produced files.") + parser.add_argument("-d", "--device", default='cpu', + help="Device to be used.") args = parser.parse_args() out_dir = args.out_dir input_dir = args.input_dir main_wit = args.main_wit prefix = args.prefix + device = args.device use_punctuation = args.use_punctuation + + # Initialize model + models = {0: "distiluse-base-multilingual-cased-v2", 1: "LaBSE", 2: "Sonar"} + model = Encoder(models[int(1)], device=device) + + + print(f"Punctuation for tokenization: {use_punctuation}") - MyAligner = Aligner(corpus_size=None, max_align=3, out_dir=out_dir, use_punctuation=use_punctuation, input_dir=input_dir, main_wit=main_wit, prefix=prefix) + MyAligner = Aligner(model, corpus_size=None, max_align=3, out_dir=out_dir, use_punctuation=use_punctuation, input_dir=input_dir, main_wit=main_wit, prefix=prefix, device=device) MyAligner.parallel_align() utils.write_json(f"result_dir/{out_dir}/alignment_dict.json", MyAligner.alignment_dict) align_dict = utils.read_json(f"result_dir/{out_dir}/alignment_dict.json")