diff --git a/.gitignore b/.gitignore
index c896353cb..9105d89e7 100644
--- a/.gitignore
+++ b/.gitignore
@@ -113,3 +113,4 @@ examples/app_yolov8_classification/xcore_flash_binary.out
 examples/app_yolov8_classification/yolov8n-cls.onnx
 examples/app_yolov8_classification/yolov8n-cls.pt
 examples/app_yolov8_classification/yolov8n-cls_saved_model/
+examples/app_denoising/data
diff --git a/examples/app_denoising/README.rst b/examples/app_denoising/README.rst
new file mode 100644
index 000000000..ea3c3da8a
--- /dev/null
+++ b/examples/app_denoising/README.rst
@@ -0,0 +1,39 @@
+================
+De-noising model
+================
+
+Installation
+============
+
+1. **Install Dependencies**:
+
+.. code-block:: shell
+
+   pip install -r requirements.txt
+
+2. **Download Dataset**:
+
+.. code-block:: shell
+
+   bash download.sh
+
+This script downloads a sample of the ``MS-SNSD`` dataset for noise samples and a sample of the ``DNS-Challenge`` dataset for clean speech. The samples are saved in the ``data/`` directory.
+
+3. **Convert Dataset**:
+
+.. code-block:: shell
+
+   python dataset.py
+
+This converts the downloaded datasets into training samples and saves them as ``TFRecord`` files in ``data/records_64/``.
+
+Training
+========
+
+1. **Initiate Training**:
+
+.. code-block:: shell
+
+   python train.py
+
+This script trains a de-noising and de-reverberation model on the prepared data and saves it as ``models/model_64f_114k.h5``.
diff --git a/examples/app_denoising/dataset.py b/examples/app_denoising/dataset.py
new file mode 100644
index 000000000..a2ad3af6d
--- /dev/null
+++ b/examples/app_denoising/dataset.py
@@ -0,0 +1,223 @@
+import os
+import tensorflow as tf
+import random
+from tqdm import tqdm
+import numpy as np
+from scipy.io import wavfile
+from scipy.signal import stft, resample, fftconvolve
+import noisereduce as nr
+from glob import glob
+
+SAMPLES = None
+WINDOW_SIZE = 512
+HOP_SIZE = 128
+FFT_SIZE = 512
+FL = 512 * 8 * 40
+FADE_IN_LENGTH = 512 * 50
+SR = 16000
+NUM_BINS = 64
+POWER_FACTOR = .3
+tf.keras.utils.set_random_seed(42)
+
+
+def unique_log_bins(low, high, nbins):
+    if low < 1:
+        bins = np.geomspace(1, high, nbins - 1, dtype=int)
+        bins = np.concatenate(([0], bins))
+    else:
+        bins = np.geomspace(low, high, nbins, dtype=int)
+    while len(np.unique(bins)) != nbins:
+        unique_vals, counts = np.unique(bins, return_counts=True)
+        duplicates = np.argwhere(counts > 1)
+        arg_first_unique = duplicates[-1][0] + 1
+        first_unique = unique_vals[arg_first_unique]
+        total_duplicates = np.sum(unique_vals < first_unique)
+        next_bins = np.geomspace(
+            first_unique, high, nbins - total_duplicates, dtype=int)
+        bins = np.concatenate((unique_vals[:arg_first_unique], next_bins))
+    return bins
+
+
+def log_filterbank(Fs, nfft, n_filters=24, f_low=0, f_high=None, window_function=np.hanning):
+    f_high = f_high or Fs / 2.0
+    assert f_high <= Fs / 2.0, "Log filterbank higher frequency cannot exceed Fs/2!"
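+    # Map the frequency limits to FFT bin indices, pick n_filters unique
+    # log-spaced centre bins, then build overlapping triangular windows
+    # between neighbouring centres, pooling the nfft // 2 + 1 STFT bins
+    # into a compact log-frequency representation.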
+    bin_low = np.floor(f_low * nfft / Fs)
+    bin_high = np.floor(f_high * nfft / Fs)
+    Hz_bins = unique_log_bins(bin_low, bin_high, n_filters)
+    fbank = np.zeros([n_filters, nfft // 2 + 1])
+    for n in range(n_filters - 1):
+        dist = int(Hz_bins[n + 1] - Hz_bins[n])
+        wind = window_function(2 * dist + 1)
+        fbank[n, Hz_bins[n]:Hz_bins[n + 1]] = wind[dist:-1]
+        fbank[n + 1, Hz_bins[n]:Hz_bins[n + 1]] = wind[:dist]
+    fbank[0, :Hz_bins[0]] = 1.0
+    fbank[-1, Hz_bins[-1]:] = 1.0
+    return fbank
+
+
+F_BANK = log_filterbank(SR, WINDOW_SIZE, NUM_BINS)
+_, NOISE_AUDIO = wavfile.read("data/noise.wav")
+
+
+def infinite(gen, *args, **kwargs):
+    while True:
+        yield from gen(*args, **kwargs)
+
+
+def nsf(signal, noise, snr):
+    signal_power = np.mean(signal ** 2)
+    noise_power = np.mean(noise ** 2)
+    return np.sqrt((signal_power / noise_power) * 10 ** (-snr / 10.0))
+
+
+def apply_reverb(signal, rir):
+    ratio = np.random.uniform(0, 1)
+    r_signal = fftconvolve(signal, rir, mode="full")[:len(signal)]
+    return ratio * r_signal + (1. - ratio) * signal
+
+
+def process_wave(signal, return_orig=False):
+    _, _, s = stft(signal, fs=SR, nperseg=WINDOW_SIZE,
+                   noverlap=WINDOW_SIZE - HOP_SIZE, nfft=FFT_SIZE)
+    mag = np.abs(s.T) @ F_BANK.T
+    mag = mag[..., None].astype(np.float32) ** POWER_FACTOR
+    if return_orig:
+        return mag, s.T
+    return mag
+
+
+def pad(sig):
+    tot_pad = FL - len(sig)
+    left_pad = np.random.randint(0, tot_pad + 1)
+    right_pad = tot_pad - left_pad
+    return np.concatenate([np.zeros(left_pad), sig, np.zeros(right_pad)])
+
+
+def get_input(signal, noise, rirs, return_phase=False):
+    snr = np.random.uniform(0, 30)
+    noise_factor = nsf(signal, noise, snr)
+    if len(signal) > FADE_IN_LENGTH:
+        signal[:FADE_IN_LENGTH] *= np.arange(0, 1, 1 / FADE_IN_LENGTH)
+    signal, noise = pad(signal), pad(noise)
+    r_signal = apply_reverb(signal, rirs[..., 0])
+    r_noise = apply_reverb(noise, rirs[..., 1])
+    noisy_signal = r_signal + r_noise * noise_factor
+    if return_phase:
+        ins, orig = process_wave(noisy_signal, True)
+        outs, perf = process_wave(signal, True)
+        return ins, outs, orig, perf
+    else:
+        ins, outs = process_wave(noisy_signal), process_wave(signal)
+        return ins, outs
+
+
+def signal_gen(folder, chop=True, is_clean=False):
+    paths = glob(f"{folder}/**/*.wav", recursive=True)
+    random.shuffle(paths)
+    for path in paths:
+        fs, s = wavfile.read(path)
+        if is_clean:
+            s = nr.reduce_noise(s, sr=fs, y_noise=NOISE_AUDIO,
+                                stationary=True, n_fft=512,
+                                time_mask_smooth_ms=32,
+                                freq_mask_smooth_hz=188,
+                                n_std_thresh_stationary=.8)
+        s = resample(s, int(len(s) / fs * SR))
+        if chop:
+            yield from (s[i:i + FL] for i in range(0, len(s), FL))
+        elif s.shape[0] <= FL and len(s.shape) == 2:
+            yield from (s[:, i - 2:i] for i in range(2, s.shape[1], 2))
+
+
+def chop_silence(wav):
+    aw = np.abs(wav[..., 0])
+    maw = np.max(aw)
+    index = np.where(aw / maw > 0.7)[0][0]
+    return wav[index:] / maw
+
+
+def data_gen(sig_fol, noise_fol, rir_fol, phase=False):
+    voices = signal_gen(sig_fol, is_clean=True)
+    noises = infinite(signal_gen, noise_fol)
+    rirs = map(chop_silence, infinite(signal_gen, rir_fol, False))
+    combined = zip(voices, noises, rirs)
+    yield from (get_input(s, n, r, phase) for s, n, r in combined)
+
+
+def _bytes_feature(value):
+    return tf.train.Feature(bytes_list=tf.train.BytesList(
+        value=[tf.io.serialize_tensor(value).numpy()]))
+
+
+def serialize_example(signal, noise):
+    feature = {
+        'signal': _bytes_feature(signal),
+        'noise': _bytes_feature(noise),
+    }
+    example_proto = tf.train.Example(
+        features=tf.train.Features(feature=feature))
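+    # Despite the names, 'signal' holds the noisy input features and 'noise'
+    # the clean targets (see data_gen); both are serialized as whole tensors,
+    # so variable-length clips survive the TFRecord round trip.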
+    return example_proto.SerializeToString()
+
+
+def write_tfrecords(folder_path, data_generator, samples_per_file=256):
+    os.makedirs(folder_path, exist_ok=True)
+    file_count = samples_written = 0
+    tfrecord_writer = None
+    for signal, noise in tqdm(data_generator):
+        if not samples_written:
+            file_name = f"{folder_path}/data_{file_count}.tfrecord"
+            tfrecord_writer = tf.io.TFRecordWriter(file_name)
+        tf_example = serialize_example(signal, noise)
+        tfrecord_writer.write(tf_example)
+        samples_written += 1
+        if samples_written == samples_per_file:
+            tfrecord_writer.close()
+            file_count += 1
+            samples_written = 0
+    if samples_written:
+        tfrecord_writer.close()
+
+
+def ds_from_paths(paths, batch_size):
+    ds = tf.data.TFRecordDataset(filenames=paths)
+    ds = ds.map(_parse_function).shuffle(buffer_size=100)
+    return ds.batch(batch_size).prefetch(tf.data.AUTOTUNE)
+
+
+def read_tfrecords(folder_path, bs=16):
+    fp = glob(os.path.join(folder_path, '*.tfrecord'))
+    random.shuffle(fp)
+    nt = len(fp) // 10
+    return ds_from_paths(fp[nt:], bs), ds_from_paths(fp[:nt], bs)
+
+
+def load_dataset(batch_size):
+    return tf.data.Dataset.from_generator(
+        lambda: data_gen("data/datasets_fullband/",
+                         "data/MS-SNSD/", "data/rirs_noises/"),
+        output_signature=(
+            tf.TensorSpec(shape=(SAMPLES, NUM_BINS, 1), dtype=np.float32),
+            tf.TensorSpec(shape=(SAMPLES, NUM_BINS, 1), dtype=np.float32),
+        )
+    ).batch(batch_size)
+
+
+def _parse_function(proto):
+    keys_to_features = {
+        'signal': tf.io.FixedLenFeature([], tf.string),
+        'noise': tf.io.FixedLenFeature([], tf.string)
+    }
+    parsed_features = tf.io.parse_single_example(proto, keys_to_features)
+    parsed_features['signal'] = tf.io.parse_tensor(
+        parsed_features['signal'], out_type=tf.float32)
+    parsed_features['noise'] = tf.io.parse_tensor(
+        parsed_features['noise'], out_type=tf.float32)
+
+    return parsed_features['signal'], parsed_features['noise']
+
+
+if __name__ == "__main__":
+    gen = data_gen("data/datasets_fullband/",
+                   "data/MS-SNSD/", "data/rirs_noises/")
+    write_tfrecords(f"data/records_{NUM_BINS}", gen)
+    # for a, b in gen:
+    #     print(a.shape, b.shape)
+    #     break
diff --git a/examples/app_denoising/download.sh b/examples/app_denoising/download.sh
new file mode 100755
index 000000000..e5272a689
--- /dev/null
+++ b/examples/app_denoising/download.sh
@@ -0,0 +1,37 @@
+#!/usr/bin/env bash
+
+### DOWNLOAD VOICE SAMPLES ###
+BLOB_NAMES=(
+    clean_fullband/datasets_fullband.clean_fullband.french_speech_000_NA_NA.tar.bz2
+)
+AZURE_URL="https://dns4public.blob.core.windows.net/dns4archive/datasets_fullband"
+
+DATA_DIR="data"
+OUTPUT_PATH="$DATA_DIR/datasets_fullband"
+
+mkdir -p "$OUTPUT_PATH"/{clean_fullband,noise_fullband}
+
+for BLOB in "${BLOB_NAMES[@]}"
+do
+    URL="$AZURE_URL/$BLOB"
+    echo "Download: $BLOB"
+    curl "$URL" | tar -C "$OUTPUT_PATH" -f - -x -j
+done
+
+### DOWNLOAD NOISE SAMPLES ###
+repo_subdir="noise_train"
+
+git -C "$DATA_DIR" init
+git -C "$DATA_DIR" config core.sparseCheckout true
+echo "$repo_subdir/*" > "$DATA_DIR/.git/info/sparse-checkout"
+git -C "$DATA_DIR" remote add -f origin https://github.com/microsoft/MS-SNSD.git
+git -C "$DATA_DIR" pull origin master
+rm -rf "$DATA_DIR/.git"
+
+### DOWNLOAD RIRS ###
+openslr_url="https://www.openslr.org/resources/28/rirs_noises.zip"
+openslr_dir="$DATA_DIR/rirs_noises"
+mkdir -p "$openslr_dir"
+wget -O "$openslr_dir/rirs_noises.zip" "$openslr_url"
+unzip "$openslr_dir/rirs_noises.zip" -d "$openslr_dir"
+rm "$openslr_dir/rirs_noises.zip"
diff --git a/examples/app_denoising/model.py b/examples/app_denoising/model.py
new file mode 100644
index 000000000..8508e3e67
--- /dev/null
+++ b/examples/app_denoising/model.py
@@ -0,0 +1,84 @@
+import tensorflow as tf
+import numpy as np
+from tensorflow.keras.layers import (Input, GRU, ReLU, Reshape,
+                                     BatchNormalization, Conv2D,
+                                     Conv2DTranspose, Concatenate)
+
+TIMESTEPS = None
+SAMPLES = 1281
+
+
+def bn_relu(x):
+    return ReLU()(BatchNormalization()(x))
+
+
+def simple_sigmoid(x):
+    return tf.clip_by_value(ReLU()(x + .5), 0, 1)
+
+
+def simple_tanh(x):
+    return tf.clip_by_value(x, -1, 1)
+
+
+def gru_block(x, num_samples, enc_f, state_input):
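+    # Collapse the (time, freq, channel) feature map to one vector per frame
+    # and run a single GRU over time. For inference the GRU is unrolled for
+    # one step and its state is passed in and out explicitly, so the model
+    # can be streamed frame by frame.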
+    num_chans = 32
+    x = Reshape([num_samples, num_chans * enc_f])(x)
+    if state_input is not None:
+        x, state = GRU(num_chans * enc_f, return_state=True,
+                       return_sequences=True, unroll=True)(x, state_input)
+        # activation=simple_tanh,
+        # recurrent_activation=simple_sigmoid)(x, state_input)
+    else:
+        x, state = GRU(num_chans * enc_f, return_sequences=True)(x), None
+        # activation=simple_tanh,
+        # recurrent_activation=simple_sigmoid)(x), None
+    x = Reshape([num_samples, enc_f, 32])(x)
+    return x, state
+
+
+def get_trunet(num_freqs=64, num_samples=SAMPLES, inference=False):
+    channels = [24, 32, 48, 48, 64, 64]
+    strides = [2, 1, 2, 1, 2, 2]
+    k_sizes = [5, 3, 5, 3, 5, 3]
+    zipped = list(zip(k_sizes, strides, channels))
+    inp = Input(shape=(num_samples, num_freqs, 1))
+    state_input = Input(shape=(64,)) if inference else None
+    x = BatchNormalization()(inp)
+    x = Conv2D(channels[0], kernel_size=[1, k_sizes[0]],
+               strides=[1, strides[0]], padding="same", use_bias=False)(x)
+    x = bn_relu(x)
+    xs = [x]
+    for k, s, c in zipped[1:]:
+        x = Conv2D(c, kernel_size=[1, k], strides=[1, s],
+                   padding="same", use_bias=False)(x)
+        x = bn_relu(x)
+        xs.append(x)
+    x = Conv2D(32, kernel_size=[1, 2], strides=[1, 2],
+               padding="same", use_bias=False)(x)
+    x = bn_relu(x)
+    x, new_state = gru_block(x, num_samples, 2, state_input)
+    x = Conv2DTranspose(32, kernel_size=[1, 2], strides=[1, 2],
+                        padding="same", use_bias=False)(x)
+    x = bn_relu(x)
+    for (k, s, c), skip in list(zip(zipped, xs))[:1:-1]:
+        cs = (c * 2) // 3
+        x = Concatenate()([x, skip])
+        x = Conv2D(cs, kernel_size=[1, 1], use_bias=False)(x)
+        x = BatchNormalization()(x)
+        x = Conv2DTranspose(cs, kernel_size=[1, k], strides=[1, s],
+                            padding="same", use_bias=False)(x)
+        x = bn_relu(x)
+    x = Concatenate()([x, xs[0]])
+    x = Conv2DTranspose(1, kernel_size=[1, k_sizes[0]],
+                        strides=[1, strides[0]], padding="same")(x)
+    out = tf.keras.activations.sigmoid(x)
+    inputs = [inp, state_input] if inference else inp
+    outputs = [out, new_state] if inference else out * inp
+    model = tf.keras.models.Model(inputs=inputs, outputs=outputs)
+    return model
+
+
+if __name__ == "__main__":
+    model = get_trunet(64, 1, inference=True)
+    print(np.sum([np.prod(i.shape) for i in model.trainable_weights]))
+    tf.keras.utils.plot_model(model, show_shapes=True)
diff --git a/examples/app_denoising/samples/input_1.wav b/examples/app_denoising/samples/input_1.wav
new file mode 100644
index 000000000..142a463d1
Binary files /dev/null and b/examples/app_denoising/samples/input_1.wav differ
diff --git a/examples/app_denoising/samples/input_2.wav b/examples/app_denoising/samples/input_2.wav
new file mode 100644
index 000000000..f2f432b7d
Binary files /dev/null and b/examples/app_denoising/samples/input_2.wav differ
diff --git a/examples/app_denoising/samples/input_3.wav b/examples/app_denoising/samples/input_3.wav
new file mode 100644
index 000000000..7626524ac
Binary files /dev/null and b/examples/app_denoising/samples/input_3.wav differ
diff --git a/examples/app_denoising/samples/input_4.wav b/examples/app_denoising/samples/input_4.wav
new file mode 100644
index 000000000..1bbd97c49
Binary files /dev/null and b/examples/app_denoising/samples/input_4.wav differ
diff --git a/examples/app_denoising/samples/input_5.wav b/examples/app_denoising/samples/input_5.wav
new file mode 100644
index 000000000..210123b96
Binary files /dev/null and b/examples/app_denoising/samples/input_5.wav differ
diff --git a/examples/app_denoising/samples/input_6.wav b/examples/app_denoising/samples/input_6.wav
new file mode 100644
index 000000000..b363eb38f
Binary files /dev/null and b/examples/app_denoising/samples/input_6.wav differ
diff --git a/examples/app_denoising/tflite_converter.py b/examples/app_denoising/tflite_converter.py
new file mode 100644
index 000000000..0bddba18b
--- /dev/null
+++ b/examples/app_denoising/tflite_converter.py
@@ -0,0 +1,62 @@
+import tensorflow as tf
+from model import get_trunet
+from xmos_ai_tools import xformer
+from dataset import data_gen
+from tqdm import tqdm
+import numpy as np
+
+
+def get_rep_dataset(model):
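+    # Representative dataset for post-training quantization: run the
+    # single-step model over one training clip, feeding each frame together
+    # with the GRU state produced by the previous frame, so the converter
+    # sees realistic ranges for both inputs.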
+    d = data_gen("data/datasets_fullband/",
+                 "data/MS-SNSD/", "data/rirs_noises/")
+    train_sample, _ = next(d)
+    outputs = np.zeros(train_sample.shape)
+    states = np.zeros([len(outputs), 64], dtype=np.float32)
+    for i in tqdm(range(len(outputs))):
+        out, state = model([train_sample[i:i + 1][None], states[i:i + 1]])
+        outputs[i:i + 1] = out
+        if i != len(outputs) - 1:
+            states[i + 1:i + 2] = state
+    for t, s in zip(train_sample[:, None, None], states[:, None]):
+        yield [t, s]
+
+
+def save_tflite(model, ws_path, quant="16x8"):
+    model.load_weights(ws_path)
+    converter = tf.lite.TFLiteConverter.from_keras_model(model)
+    converter.representative_dataset = lambda: get_rep_dataset(model)
+    converter.optimizations = [tf.lite.Optimize.DEFAULT]
+    if quant == "16x8":
+        print("Using experimental 16x8 quantization...")
+        converter.target_spec.supported_ops = [
+            tf.lite.OpsSet.EXPERIMENTAL_TFLITE_BUILTINS_ACTIVATIONS_INT16_WEIGHTS_INT8]
+        converter.inference_input_type = tf.int16
+        converter.inference_output_type = tf.int16
+    elif quant == "8x8":
+        print("Using 8x8 quantization...")
+        converter.target_spec.supported_ops = [
+            tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
+        converter.inference_input_type = tf.int8
+        converter.inference_output_type = tf.int8
+    else:
+        print("Using float32...")
+    return converter.convert()
+
+
+def save_xformed(model_in, model_out):
+    hyper_params = {"xcore-thread-count": 5}
+    xformer.convert(model_in, model_out, hyper_params)
+
+
+if __name__ == "__main__":
+    USE_XINTERPRETER = False
+    QUANT_TYPE = "float32"
+    MODEL_PATH = "models/model_64f_114k.h5"
+    OUTPUT_XC_MODEL = f"models/model_xc_{QUANT_TYPE}.tflite"
+    OUTPUT_TFLITE_MODEL = f"models/model_{QUANT_TYPE}.tflite"
+    model = get_trunet(64, 1, True)
+    tflite_model = save_tflite(model, MODEL_PATH, quant=QUANT_TYPE)
+    with open(OUTPUT_TFLITE_MODEL, "wb") as f:
+        f.write(tflite_model)
+    if USE_XINTERPRETER:
+        save_xformed(OUTPUT_TFLITE_MODEL, OUTPUT_XC_MODEL)
diff --git a/examples/app_denoising/tflite_tester.py b/examples/app_denoising/tflite_tester.py
new file mode 100644
index 000000000..327318455
--- /dev/null
+++ b/examples/app_denoising/tflite_tester.py
@@ -0,0 +1,89 @@
+import tensorflow as tf
+import numpy as np
+from tqdm import tqdm
+from scipy.signal import istft
+from scipy.io import wavfile
+from dataset import F_BANK, process_wave, read_tfrecords
+from xmos_ai_tools.xinterpreters import TFLMHostInterpreter
+from train import weighted_mse
+
+
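+# The model predicts a magnitude mask; time-domain audio is rebuilt by
+# applying the mask to the noisy STFT magnitudes and reusing the noisy phase.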
+def reconstruct_signal(magnitudes, phases, fs=16000):
+    complex_spectrogram = magnitudes * np.exp(1j * phases)
+    _, reconstructed_signal = istft(
+        complex_spectrogram.T,
+        fs=fs,
+        nperseg=512,
+        noverlap=512 - 128,
+        nfft=512
+    )
+    return reconstructed_signal
+
+
+def get_xinterpreter(model_path):
+    with open(model_path, "rb") as fd:
+        model = fd.read()
+    ie = TFLMHostInterpreter()
+    ie.set_model(model_content=model, secondary_memory=False)
+    return ie
+
+
+def get_tflite_interpreter(model_path):
+    ie = tf.lite.Interpreter(model_path=model_path)
+    ie.allocate_tensors()
+    return ie
+
+
+def get_preds(ie, x):
+    in_dets = ie.get_input_details()
+    out_dets = ie.get_output_details()
+    in_scale, in_zp = in_dets[0]['quantization']
+    out_scale, out_zp = out_dets[1]['quantization']
+    if in_scale == 0. and out_scale == 0.:
+        # Float models report (0.0, 0) quantization parameters.
+        in_scale = out_scale = 1.
+    x = (x / in_scale + in_zp).astype(in_dets[0]["dtype"])
+    outputs = np.zeros(x.shape, dtype=out_dets[1]["dtype"])
+    state = np.zeros([1, 64], dtype=out_dets[0]["dtype"])
+    for i in range(len(outputs)):
+        ie.set_tensor(in_dets[1]["index"], state)
+        ie.set_tensor(in_dets[0]["index"], x[i:i + 1][None])
+        ie.invoke()
+        state = ie.get_tensor(out_dets[0]['index'])
+        outputs[i:i + 1] = ie.get_tensor(out_dets[1]['index'])
+    outputs = (outputs.astype(np.float32) - out_zp) * out_scale
+    return outputs[..., 0]
+
+
+def write_wav(signal, path):
+    signal = (signal / np.max(np.abs(signal)) * (2 ** 14)).astype(np.int16)
+    wavfile.write(path, 16000, signal)
+
+
+def evaluate_model(ie):
+    _, test = read_tfrecords("data/records_64/", bs=1)
+    losses = []
+    for x, y in tqdm(test):
+        preds = get_preds(ie, x.numpy()[0])[None, ..., None]
+        losses.append(weighted_mse(y, preds * x))
+    return np.mean(losses)
+
+
+if __name__ == "__main__":
+    QUANT_TYPE = "float32"
+    MODEL_NAME = f"models/model_{QUANT_TYPE}.tflite"
+    USE_XINTERPRETER = False
+    if USE_XINTERPRETER:
+        ie = get_xinterpreter(MODEL_NAME)
+    else:
+        ie = get_tflite_interpreter(MODEL_NAME)
+    print(evaluate_model(ie))
+    for num in range(1, 7):
+        input_path = f"samples/input_{num}.wav"
+        output_path = f"samples/output_{num}_{QUANT_TYPE}_relu.wav"
+        _, wav = wavfile.read(input_path)
+        x, orig = process_wave(wav, True)
+        outs = get_preds(ie, x)
+        mask = ((outs @ F_BANK) ** (1 / .3)).clip(0.1, 1)
+        mags = np.abs(orig) * mask
+        signal = reconstruct_signal(mags, np.angle(orig))
+        write_wav(signal, output_path)
diff --git a/examples/app_denoising/train.py b/examples/app_denoising/train.py
new file mode 100644
index 000000000..24a3e5e42
--- /dev/null
+++ b/examples/app_denoising/train.py
@@ -0,0 +1,57 @@
+from dataset import read_tfrecords
+import tensorflow as tf
+from model import get_trunet
+from tensorflow.keras.callbacks import EarlyStopping, ReduceLROnPlateau, ModelCheckpoint
+# import os
+# os.environ['CUDA_VISIBLE_DEVICES'] = '-1'
+
+EPOCHS = 100
+BATCH_SIZE = 4
+MODEL_NAME = "models/model_64f_114k.h5"
+
+
+def weighted_mse(y_true, y_pred):
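+    # Asymmetric MSE: frames where the output is under-predicted (speech
+    # attenuated) are weighted 4x more than over-predictions, biasing the
+    # model towards preserving speech at the cost of passing more noise.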
+    squared_difference = tf.square(y_true - y_pred)
+    weights = tf.where(y_pred < y_true, 4.0, 1.0)
+    weighted_squared_difference = weights * squared_difference
+    wmse = tf.reduce_mean(weighted_squared_difference)
+    return wmse
+
+
+early_stopping = EarlyStopping(
+    monitor='val_loss',
+    patience=5,
+    restore_best_weights=True,
+)
+
+reduce_lr = ReduceLROnPlateau(
+    monitor='val_loss',
+    factor=0.5,
+    patience=1,
+    min_lr=2e-5,
+    min_delta=0.01,
+)
+
+
+model_checkpoint = ModelCheckpoint(
+    filepath=MODEL_NAME,
+    save_best_only=True,
+    monitor='val_loss',
+    save_weights_only=False,
+    save_freq='epoch'
+)
+
+
+def train():
+    train_ds, test_ds = read_tfrecords("data/records_64", bs=BATCH_SIZE)
+    model = get_trunet(64)
+    optimizer = tf.keras.optimizers.Adam(4e-4)
+    model.compile(optimizer=optimizer, loss=weighted_mse)
+    model.fit(train_ds,
+              epochs=EPOCHS, validation_data=test_ds,
+              callbacks=[reduce_lr, early_stopping, model_checkpoint])
+    model.save(MODEL_NAME)
+
+
+if __name__ == "__main__":
+    train()