Merge pull request bfsujason#5 from ProMeText/add_device_conf
Add device selection
Jean-Baptiste-Camps authored Apr 19, 2024
2 parents 5e90057 + 03d5a8a commit 4b1c071
Showing 5 changed files with 46 additions and 41 deletions.
22 changes: 0 additions & 22 deletions aquilign/align/bertalign/Bertalign.py

This file was deleted.

19 changes: 11 additions & 8 deletions aquilign/align/bertalign/aligner.py
@@ -1,13 +1,13 @@
 import numpy as np

-from aquilign.align.bertalign.Bertalign import model
 import aquilign.align.bertalign.corelib as core
 import aquilign.align.bertalign.utils as utils
 import torch.nn as nn
 import torch

 class Bertalign:
     def __init__(self,
+                 model,
                  src,
                  tgt,
                  max_align=3,
@@ -17,14 +17,17 @@ def __init__(self,
                  margin=True,
                  len_penalty=True,
                  is_split=False,
-                 ):
+                 device="cpu"):

         self.max_align = max_align
         self.top_k = top_k
         self.win = win
         self.skip = skip
         self.margin = margin
         self.len_penalty = len_penalty
+        self.device = device
+        self.model = model

@@ -38,11 +41,11 @@ def __init__(self,
         assert len(src_sents) != 0, "Problemo"

         print("Embedding source and target text using {} ...".format(model.model_name))
-        src_vecs, src_lens = model.transform(src_sents, max_align - 1)
-        tgt_vecs, tgt_lens = model.transform(tgt_sents, max_align - 1)
+        src_vecs, src_lens = self.model.transform(src_sents, max_align - 1)
+        tgt_vecs, tgt_lens = self.model.transform(tgt_sents, max_align - 1)

-        self.search_simple_vecs = model.simple_vectorization(src_sents)
-        self.tgt_simple_vecs = model.simple_vectorization(tgt_sents)
+        self.search_simple_vecs = self.model.simple_vectorization(src_sents)
+        self.tgt_simple_vecs = self.model.simple_vectorization(tgt_sents)

         char_ratio = np.sum(src_lens[0,]) / np.sum(tgt_lens[0,])

@@ -57,15 +60,15 @@ def __init__(self,
         self.tgt_vecs = tgt_vecs

     def compute_distance(self):
-        if torch.cuda.is_available():  # GPU version
+        if torch.cuda.is_available() and self.device == 'cuda:0':  # GPU version
             cos = nn.CosineSimilarity(dim=1, eps=1e-6)
             output = cos(torch.from_numpy(self.search_simple_vecs), torch.from_numpy(self.tgt_simple_vecs))
             return output

     def align_sents(self, first_alignment_only=False):

         print("Performing first-step alignment ...")
-        D, I = core.find_top_k_sents(self.src_vecs[0,:], self.tgt_vecs[0,:], k=self.top_k)
+        D, I = core.find_top_k_sents(self.src_vecs[0,:], self.tgt_vecs[0,:], k=self.top_k, device=self.device)
         first_alignment_types = core.get_alignment_types(2)  # 0-1, 1-0, 1-1
         first_w, first_path = core.find_first_search_path(self.src_num, self.tgt_num)
         first_pointers = core.first_pass_align(self.src_num, self.tgt_num, first_w, first_path, first_alignment_types, D, I)
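Taken together, the aligner.py changes replace the module-level model import (from the deleted Bertalign.py) with an encoder injected at construction time. A minimal sketch of the new calling convention, assuming the LaBSE encoder and two illustrative sentence lists (the real pipeline passes tokenized witness texts):

    from aquilign.align.bertalign.encoder import Encoder
    from aquilign.align.bertalign.aligner import Bertalign

    # Build the sentence encoder once and share it across alignments.
    model = Encoder("LaBSE", device="cpu")  # "cuda:0" to allow GPU use

    aligner = Bertalign(model,
                        ["First source sentence.", "Second source sentence."],
                        ["Première phrase cible.", "Seconde phrase cible."],
                        max_align=3,
                        device="cpu")
    aligner.align_sents()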
4 changes: 2 additions & 2 deletions aquilign/align/bertalign/corelib.py
@@ -377,7 +377,7 @@ def get_alignment_types(max_alignment_size):
             alignment_types.append([x, y])
     return np.array(alignment_types)

-def find_top_k_sents(src_vecs, tgt_vecs, k=3):
+def find_top_k_sents(src_vecs, tgt_vecs, k=3, device='cpu'):
     """
     Find the top_k similar vecs in tgt_vecs for each vec in src_vecs.
     Args:
@@ -389,7 +389,7 @@ def find_top_k_sents(src_vecs, tgt_vecs, k=3):
         I: numpy array. Target index matrix of shape (num_src_sents, k).
     """
     embedding_size = src_vecs.shape[1]
-    if torch.cuda.is_available() and platform == 'linux':  # GPU version
+    if torch.cuda.is_available() and platform == 'linux' and device != "cpu":  # GPU version
         res = faiss.StandardGpuResources()
         index = faiss.IndexFlatIP(embedding_size)
         gpu_index = faiss.index_cpu_to_gpu(res, 0, index)
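The new device parameter gates the standard Faiss CPU-to-GPU pattern already used here. A self-contained sketch of the same top-k search, with random vectors standing in for the sentence embeddings and a hypothetical name top_k_demo:

    from sys import platform

    import faiss
    import numpy as np
    import torch

    def top_k_demo(src_vecs, tgt_vecs, k=3, device="cpu"):
        embedding_size = src_vecs.shape[1]
        index = faiss.IndexFlatIP(embedding_size)  # exact inner-product search
        if torch.cuda.is_available() and platform == "linux" and device != "cpu":
            res = faiss.StandardGpuResources()     # GPU scratch space
            index = faiss.index_cpu_to_gpu(res, 0, index)
        index.add(tgt_vecs)                        # index the target sentence vectors
        return index.search(src_vecs, k)           # D (scores) and I (indices), each (num_src, k)

    src = np.random.rand(4, 768).astype("float32") # faiss expects float32
    tgt = np.random.rand(6, 768).astype("float32")
    D, I = top_k_demo(src, tgt, k=3)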
8 changes: 4 additions & 4 deletions aquilign/align/bertalign/encoder.py
@@ -6,8 +6,8 @@


 class Encoder:
-    def __init__(self, model_name):
-        device = "cuda:0" if torch.cuda.is_available() else "cpu"
+    def __init__(self, model_name, device):
+        self.device = "cuda:0" if torch.cuda.is_available() and device != "cpu" else "cpu"
         if model_name == "LaBSE":
             self.model = SentenceTransformer(model_name_or_path=model_name, device=device)
             self.model_name = model_name
@@ -21,7 +21,7 @@ def simple_vectorization(self, sents):
         This function produces a simple vectorisation of a sentence, without
         taking into account its length as transform does
         """
-        sent_vecs = self.model.encode(sents)
+        sent_vecs = self.model.encode(sents, device=self.device)
         return sent_vecs

     def transform(self, sents, num_overlaps):
def transform(self, sents, num_overlaps):
Expand All @@ -30,7 +30,7 @@ def transform(self, sents, num_overlaps):
overlaps.append(line)

if self.model_name == "LaBSE":
sent_vecs = self.model.encode(overlaps)
sent_vecs = self.model.encode(overlaps, device=self.device)
else:
sents_vecs = self.t2vec_model.predict()
embedding_dim = sent_vecs.size // (len(sents) * num_overlaps)
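A short sketch of the reworked Encoder in use; the sentences are illustrative, and any device value other than "cpu" merely permits CUDA when it is actually available:

    from aquilign.align.bertalign.encoder import Encoder

    encoder = Encoder("LaBSE", device="cpu")  # or device="cuda:0"
    vecs = encoder.simple_vectorization(["A sentence.", "Une phrase."])
    print(vecs.shape)                         # (num_sentences, embedding_dim)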
34 changes: 29 additions & 5 deletions main.py
@@ -9,7 +9,9 @@
 import aquilign.align.graph_merge as graph_merge
 import aquilign.align.bertalign.utils as utils
 import aquilign.align.bertalign.syntactic_tokenization as syntactic_tokenization
-from aquilign.align.bertalign.Bertalign import Bertalign
+#from aquilign.align.bertalign.Bertalign import Bertalign
+from aquilign.align.bertalign.encoder import Encoder
+from aquilign.align.bertalign.aligner import Bertalign
 import pandas as pd
 import argparse
 import glob
@@ -60,16 +62,21 @@ class Aligner:
     The Aligner class initializes the alignment engine, based on Bertalign
     """

-    def __init__(self, corpus_size:None,
+    def __init__(self,
+                 model,
+                 corpus_size:None,
                  max_align=3,
                  out_dir="out",
                  use_punctuation=True,
                  input_dir="in",
                  main_wit=None,
-                 prefix=None):
+                 prefix=None,
+                 device="cpu"):
+        self.model = model
         self.alignment_dict = dict()
         self.text_dict = dict()
         self.files_path = glob.glob(f"{input_dir}/*.txt")
+        self.device = device
         print(input_dir)
         if main_wit is not None:
             self.main_file_index = [index for index, path in enumerate(self.files_path) if main_wit in path][0]
@@ -133,7 +140,14 @@ def parallel_align(self):
             else:
                 margin = False
             len_penality = True
-            aligner = Bertalign(first_tokenized_text, second_tokenized_text, max_align= self.max_align, win=5, skip=-.2, margin=margin, len_penalty=len_penality)
+            aligner = Bertalign(self.model,
+                                first_tokenized_text,
+                                second_tokenized_text,
+                                max_align=self.max_align,
+                                win=5, skip=-.2,
+                                margin=margin,
+                                len_penalty=len_penality,
+                                device=self.device)
             aligner.align_sents()

             # We append the result to the alignment dictionary
@@ -215,15 +229,25 @@ def run_alignments():
                         help="Pivot witness.")
     parser.add_argument("-p", "--prefix", default=None,
                         help="Prefix for produced files.")
+    parser.add_argument("-d", "--device", default='cpu',
+                        help="Device to be used.")

     args = parser.parse_args()
     out_dir = args.out_dir
     input_dir = args.input_dir
     main_wit = args.main_wit
     prefix = args.prefix
+    device = args.device
     use_punctuation = args.use_punctuation

+    # Initialize model
+    models = {0: "distiluse-base-multilingual-cased-v2", 1: "LaBSE", 2: "Sonar"}
+    model = Encoder(models[int(1)], device=device)
+
     print(f"Punctuation for tokenization: {use_punctuation}")
-    MyAligner = Aligner(corpus_size=None, max_align=3, out_dir=out_dir, use_punctuation=use_punctuation, input_dir=input_dir, main_wit=main_wit, prefix=prefix)
+    MyAligner = Aligner(model, corpus_size=None, max_align=3, out_dir=out_dir, use_punctuation=use_punctuation, input_dir=input_dir, main_wit=main_wit, prefix=prefix, device=device)
     MyAligner.parallel_align()
     utils.write_json(f"result_dir/{out_dir}/alignment_dict.json", MyAligner.alignment_dict)
     align_dict = utils.read_json(f"result_dir/{out_dir}/alignment_dict.json")
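With the new -d/--device argument, CPU stays the default and GPU use is opt-in, e.g. (the long spellings of the other flags are inferred from the args.* names above and may differ):

    python main.py --input_dir in --out_dir out --device cuda:0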
