-
Notifications
You must be signed in to change notification settings - Fork 0
/
ldt_train.py
148 lines (126 loc) · 4.56 KB
/
ldt_train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
'''
Author: Emilio Morales ([email protected])
Jun 2023
'''
import argparse
import json
import os
import time

os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' # Disable tensorflow debugging logs

import tensorflow as tf

from autoencoder import Autoencoder
from ae_trainer import AutoencoderKL
from dit import DiT
from ldt_trainer import LDT
from schedule import CosineSchedule
from utils import *
def train(args):
    """Train a latent diffusion transformer (LDT).

    Loads (or snapshots) the LDT config, restores a frozen pretrained
    autoencoder, builds the DiT denoiser, then runs the training loop with
    periodic checkpointing, FID evaluation and sample plotting.

    Args:
        args: argparse.Namespace with file_pattern, model_dir, ldt_name,
            max_ckpt_to_keep, interval, restore_best and total_batches.
    """
    print('\n#########')
    print('LDT Train')
    print('#########\n')
    file_pattern = args.file_pattern
    model_dir = args.model_dir
    ldt_name = args.ldt_name
    max_ckpt_to_keep = args.max_ckpt_to_keep
    interval = args.interval
    restore_best = args.restore_best
    total_batches = args.total_batches

    # LDT config: reuse the saved config when resuming so runs stay
    # reproducible; otherwise snapshot the default config next to the model.
    model_dir = os.path.join(model_dir, ldt_name)
    os.makedirs(model_dir, exist_ok=True)
    config_file = os.path.join(model_dir, f"{ldt_name}_config.json")
    if os.path.exists(config_file):
        with open(config_file, 'r') as file:
            config = json.load(file)
        print(f'{config_file} loaded')
    else:
        from ldt_config import config
        with open(config_file, 'w') as file:
            json.dump(config, file)
        print(f'{config_file} saved')
    print(config)

    # Autoencoder config (the AE is trained separately; only restored here).
    ae_dir = os.path.join(config['ae_dir'], config['ae_name'])
    ae_config_file = os.path.join(ae_dir, f"{config['ae_name']}_config.json")
    with open(ae_config_file, 'r') as file:
        ae_config = json.load(file)
    print(f'{ae_config_file} loaded')

    # Datasets: endless training iterator plus a fixed-seed validation
    # subset used for FID scoring.
    train_ds = create_train_ds(
        file_pattern, config['batch_size'], config['img_size']
    )
    val_ds = create_test_ds(
        file_pattern, config['fid_batch_size'], config['img_size'],
        config['n_fid_images'], config['ds_val_seed'],
    )
    train_ds = iter(train_ds.repeat())
    train_batch = next(train_ds)

    # Frozen autoencoder: one forward pass builds the variables so the
    # checkpoint restore below can map onto them.
    autoencoder = Autoencoder(
        ae_config['encoder_dim'], ae_config['decoder_dim'],
        cuant_dim=ae_config['cuant_dim']
    )
    autoencoder.trainable = False
    autoencoder(train_batch)  # init model
    print(autoencoder.summary())
    ae_kl = AutoencoderKL(
        None, autoencoder, None, None, None, ae_config
    )
    ae_kl.restore_ae(ae_dir)
    test_latent = ae_kl.ae.encoder(train_batch)[0]

    # DiT denoiser, built with one dummy forward pass on a real latent.
    dit = DiT(
        config['latent_size'], config['patch_size'], config['ldt_dim'],
        heads=config['heads'], k=config['k'], mlp_dim=config['mlp_dim'],
        depth=config['depth'], cuant_dim=config['cuant_dim']
    )
    test_noise = tf.ones([config['batch_size'], 1, 1, 1]) * -1e9
    dit([test_latent, test_noise])  # init dit (output discarded)
    print(dit.summary())

    # Diffusion noise schedule and optimizer.
    diffusion_schedule = CosineSchedule(
        start_log_snr=config['start_log_snr'],
        end_log_snr=config['end_log_snr'],
    )
    opt = tf.keras.optimizers.Adam(
        learning_rate=config['learning_rate']
    )

    # Trainer + checkpointing; resumes from the stored image count.
    ldt = LDT(
        network=dit, ae_kl=ae_kl, opt=opt, diffusion_schedule=diffusion_schedule,
        config=config
    )
    ldt.create_ckpt(
        model_dir, max_ckpt_to_keep=max_ckpt_to_keep, restore_best=restore_best
    )
    ldt.plot_images(0, diffusion_steps=config['fid_diffusion_steps'])  # init ldt

    # Training loop: every `interval` batches, checkpoint (with FID eval)
    # and plot generated samples.
    start_batch = int((ldt.ckpt.n_images / ldt.batch_size) + 1)
    start = time.time()
    for n_batch in range(start_batch, total_batches):
        # Use next() consistently (the original mixed next() and get_next()).
        batch = next(train_ds)
        ldt.train_step(batch)
        if n_batch % interval == 0:
            print(f'\nTime for interval is {time.time()-start:.4f} sec')
            n_images = n_batch * ldt.batch_size
            ldt.save_ckpt(
                n_images, config['n_fid_images'], config['fid_diffusion_steps'],
                config['fid_batch_size'], val_ds
            )
            ldt.plot_images(n_images, diffusion_steps=config['fid_diffusion_steps'])
            start = time.time()
def build_parser():
    """Build the command-line argument parser (separate so it is testable)."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--file_pattern', type=str)
    parser.add_argument('--model_dir', type=str, default='ldt')
    parser.add_argument('--ldt_name', type=str, default='model_1')
    parser.add_argument('--max_ckpt_to_keep', type=int, default=2)
    parser.add_argument('--interval', type=int, default=500)
    # BUG FIX: `type=bool` makes argparse call bool() on the raw string, so
    # any non-empty value — including "False" — parsed as True. A store_true
    # flag keeps the default False and is the conventional boolean switch.
    parser.add_argument('--restore_best', action='store_true', default=False)
    parser.add_argument('--total_batches', type=int, default=100000000)
    return parser


def main():
    """Parse CLI arguments and launch LDT training."""
    args = build_parser().parse_args()
    train(args)
# Script entry point: run training only when executed directly, not on import.
if __name__ == '__main__':
    main()