From e9244b5141125873631a45494b6f3b6a5d773a70 Mon Sep 17 00:00:00 2001 From: ShadowMov Date: Sun, 11 Mar 2018 16:47:17 +0800 Subject: [PATCH] Fix Training sort (#407) --- .travis.yml | 2 +- vj4/handler/training.py | 5 +++-- vj4/test/test_misc.py | 16 ++++++++++++++++ vj4/util/misc.py | 11 +++++++++++ 4 files changed, 31 insertions(+), 3 deletions(-) create mode 100644 vj4/test/test_misc.py diff --git a/.travis.yml b/.travis.yml index 863e6367..170d942a 100644 --- a/.travis.yml +++ b/.travis.yml @@ -9,7 +9,7 @@ python: - 3.6 install: - pip install -r requirements.txt -- node scripts/fix_abroad_shrinkwrap.js +- npm config set registry https://registry.npm.taobao.org - npm install script: - npm run build:production diff --git a/vj4/handler/training.py b/vj4/handler/training.py index 9d9df70d..779b4c5a 100644 --- a/vj4/handler/training.py +++ b/vj4/handler/training.py @@ -13,6 +13,7 @@ from vj4.handler import base from vj4.util import json from vj4.util import pagination +from vj4.util import misc def _parse_dag_json(dag): @@ -29,8 +30,8 @@ def _parse_dag_json(dag): raise error.ValidationError('dag') new_node = {'_id': int(node['_id']), 'title': str(node.get('title', '')), - 'require_nids': list(set(map(int, node['require_nids']))), - 'pids': list(set(map(document.convert_doc_id, node['pids'])))} + 'require_nids': misc.dedupe(map(int, node['require_nids'])), + 'pids': misc.dedupe(map(document.convert_doc_id, node['pids']))} new_dag.append(new_node) except ValueError: raise error.ValidationError('dag') from None diff --git a/vj4/test/test_misc.py b/vj4/test/test_misc.py new file mode 100644 index 00000000..54532e15 --- /dev/null +++ b/vj4/test/test_misc.py @@ -0,0 +1,16 @@ +import unittest + +from vj4.util import misc + + +class Test(unittest.TestCase): + def test_dedupe(self): + self.assertListEqual(misc.dedupe([2,1,1,3,2,3]),[2,1,3]) + self.assertListEqual(misc.dedupe([]),[]) + self.assertListEqual(misc.dedupe(map(int,['2','1','1','3','2','3'])),[2,1,3]) + self.assertListEqual(misc.dedupe(['b','a','b','c','b']),['b','a','c']) + self.assertListEqual(misc.dedupe([0]),[0]) + + +if __name__ == '__main__': + unittest.main() diff --git a/vj4/util/misc.py b/vj4/util/misc.py index ef8d670a..fe4d8a59 100644 --- a/vj4/util/misc.py +++ b/vj4/util/misc.py @@ -94,3 +94,14 @@ def format_seconds(seconds): def base64_encode(str): encoded = base64.b64encode(str.encode()) return encoded.decode() + + +def dedupe(list): + result = [] + result_set = set() + for i in list: + if i in result_set: + continue + result.append(i) + result_set.add(i) + return result