Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
MNasert committed Sep 13, 2022
1 parent 0ff2bd2 commit 99fa6df
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 5 deletions.
2 changes: 2 additions & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
numpy
numba
transformers
thermostat

thermostat-datasets
imgkit
tqdm
Expand Down
35 changes: 31 additions & 4 deletions src/processing/process_handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ class TaskBase:
RequiredRamPerProcess: int # in GBytes f.e. 1 = 1024*1024*1024 bytes
DesiredProcesses: int
TaskIndex: int
Task: object # preferably a method
Task: object # preferably a method but anything callable works
done = False
associated_processes = []

Expand Down Expand Up @@ -52,12 +52,15 @@ def concat_task() -> TaskBase:


class ProcessHandler:
#constants
prime_delimiter = -2743

def __init__(self, loader: Verbalizer,
tasks: List[TaskBase],
samples: dict):

self.manager = mp.Manager() # deprecated?
self.root = loader
self.manager = mp.Manager()
self.tasks = self.order_tasks(tasks)
self.samples = samples
# we try reconstructing the dict after calculations
Expand All @@ -74,14 +77,14 @@ def __init__(self, loader: Verbalizer,
offset_indices_attributions = int(0)
# fill buffer with attributions & use prime decompositional as delimiter
for key in samples.keys():
attrs = [*samples[key]["attributions"], -2743]
attrs = [*samples[key]["attributions"], self.prime_delimiter]
self.sample_attributions_buf[:offset_indices_attributions + 1] = attrs[:]
offset_indices_attributions += len(samples[key]["attributions"]) + 1

# fill buffer with texts & use prime decompositional as delimiter
offset_indices_texts = int(0)
for key in samples.keys():
attrs = [*samples[key]["input_ids"], -2743]
attrs = [*samples[key]["input_ids"], self.prime_delimiter]
self.sample_texts_buf[:offset_indices_texts + 1] = attrs[:]
offset_indices_texts += len(samples[key]["attributions"]) + 1

Expand All @@ -107,6 +110,30 @@ def equalsplit_data(data: np.array, pieces: int) -> list: # https://stackoverfl
for i in range(0, len(data), pieces):
yield data[i:i + pieces]

def attr_instance_generator(self) -> np.ndarray:
item = None
index = 0
while index < len(self.sample_texts_buf):
ret = []
while self.sample_texts_buf[index] != self.prime_delimiter:
ret.append(self.sample_texts_buf[index])
index += 1
if self.sample_texts_buf[index] == self.prime_delimiter:
index += 1
yield np.array(ret, dtype=np.float32)

def text_instance_generator(self) -> List[int]:
item = None
index = 0
while index < len(self.sample_texts_buf):
ret = []
while self.sample_texts_buf[index] != self.prime_delimiter:
ret.append(self.sample_texts_buf[index])
index += 1
if self.sample_texts_buf[index] == self.prime_delimiter:
index += 1
yield ret

def get_args(self, task: TaskBase) -> dict:
if task.TaskName == "ConvSearch":
return {"sgn": self.root.sgn,
Expand Down
3 changes: 2 additions & 1 deletion tmp.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@
import datasets
import src.dataloader as dataloader
if __name__ == "__main__":
config_path = "configs/mean_dev.yml"
multiprocessing.freeze_support()
config_path = "configs/toy_dev.yml"

with open(config_path) as stream:
config = yaml.safe_load(stream)
Expand Down

0 comments on commit 99fa6df

Please sign in to comment.