-
Notifications
You must be signed in to change notification settings - Fork 0
/
shell_train.py
36 lines (26 loc) · 1.15 KB
/
shell_train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
import argparse
import os
import subprocess
import sys
# All PACS domains; the one chosen as target is held out, the rest are sources.
DOMAINS = ("photo", "cartoon", "art_painting", "sketch")
INPUT_DIR = '/path/to/data'
OUTPUT_DIR = '/path/to/outputs'
CONFIG = "PACS/ResNet18"


def parse_args(argv=None):
    """Parse the launcher's command-line options.

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.

    Returns:
        The parsed ``argparse.Namespace`` (``domain``, ``gpu``, ``times``,
        ``root``).
    """
    parser = argparse.ArgumentParser()
    # choices= rejects unknown domains with a usage error instead of a
    # bare ValueError later on.
    parser.add_argument("--domain", "-d", default="sketch", choices=DOMAINS,
                        help="Target")
    parser.add_argument("--gpu", "-g", default=0, type=int, help="Gpu ID")
    parser.add_argument("--times", "-t", default=1, type=int,
                        help="Repeat times")
    # NOTE(review): --root is parsed but never read below; presumably it was
    # meant to override INPUT_DIR -- confirm before wiring it in.
    parser.add_argument("--root", default=None, type=str)
    return parser.parse_args(argv)


def build_command(source, target, input_dir, output_dir, config,
                  python=sys.executable):
    """Return the argv list that launches one ``train_DCG.py`` run.

    Args:
        source: Source-domain names, passed in order after ``--source``.
        target: Target-domain name.
        input_dir: Value for ``--input_dir``.
        output_dir: Value for ``--output_dir``.
        config: Value for ``--config``.
        python: Interpreter to run; defaults to the one running this script.

    Returns:
        A list of strings suitable for ``subprocess.run`` without a shell, so
        no value is ever interpreted by a shell.
    """
    return [
        python, 'train_DCG.py',
        '--source', *source,
        '--target', target,
        '--input_dir', input_dir,
        '--output_dir', output_dir,
        '--config', config,
    ]


def main(argv=None):
    """Run ``train_DCG.py`` ``--times`` times for the chosen target domain.

    Returns:
        0 if every run exited with status 0, otherwise 1.
    """
    args = parse_args(argv)
    target = args.domain
    # Keep the original domain order for the remaining source domains.
    source = [domain for domain in DOMAINS if domain != target]
    cmd = build_command(source, target, INPUT_DIR, OUTPUT_DIR, CONFIG)
    # Restrict the child to the requested GPU without going through a shell.
    env = {**os.environ, "CUDA_VISIBLE_DEVICES": str(args.gpu)}

    failures = 0
    for run in range(args.times):
        result = subprocess.run(cmd, env=env, check=False)
        if result.returncode != 0:
            # Report but keep going so the remaining repeats still run.
            failures += 1
            print(f"run {run + 1}/{args.times} exited with status "
                  f"{result.returncode}", file=sys.stderr)
    return 1 if failures else 0


if __name__ == "__main__":
    sys.exit(main())