Skip to content

Commit

Permalink
hotfix: fix syntax error python 3.6 (#59)
Browse files Browse the repository at this point in the history
  • Loading branch information
jlsneto authored May 28, 2020
1 parent ae35b0c commit 44688f9
Showing 1 changed file with 2 additions and 10 deletions.
12 changes: 2 additions & 10 deletions cereja/datatools/split_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ def split_data(self, test_max_size: int = None, source_vocab_size: int = None, t
data = list(self._get_vocab_data(source_vocab_size=source_vocab_size,
target_vocab_size=target_vocab_size))
else:
data = zip(self._x, self._y)
data = list(zip(self._x, self._y))

if shuffle:
random.shuffle(data)
Expand All @@ -160,7 +160,7 @@ def split_data(self, test_max_size: int = None, source_vocab_size: int = None, t
continue
train.append([x, y])
if take_paralel_data is False:
return *get_cols(train), *get_cols(test)
return (*get_cols(train), *get_cols(test))
return train, test

def split_data_and_save(self, save_on_dir: str, test_max_size: int = None, source_vocab_size: int = None,
Expand All @@ -175,11 +175,3 @@ def split_data_and_save(self, save_on_dir: str, test_max_size: int = None, sourc
File(save_on, x).save(**kwargs)
save_on = os.path.join(save_on_dir, f'{prefix}_{self.target_language}.{ext.strip(".")}')
File(save_on, y).save(**kwargs)


if __name__ == '__main__':
en = File.read('C:/Users/handtalk/Downloads/train_original/en.align')
asl = File.read("C:/Users/handtalk/Downloads/train_original/asl.align")
a = Corpus(zip(en.lines, asl.lines), 'en', 'asl')
a.split_data_and_save(save_on_dir="C:/Users/handtalk/Downloads/train_original/test_remove", source_vocab_size=1000,
exist_ok=True)

0 comments on commit 44688f9

Please sign in to comment.