From 072b69e780a176ccd4a0d47662e747eaca4767f9 Mon Sep 17 00:00:00 2001 From: Christian Liebhardt <544520+liebharc@users.noreply.github.com> Date: Sat, 17 Feb 2024 03:18:09 +0100 Subject: [PATCH] Make training of segnet, unet and classifiers easier by providing a single entry point to all training steps (#54) * Fixed 'TypeError: Cannot convert 4.999899999999999e-07 to EagerTensor of dtype int64' in training, fixes #39 https://stackoverflow.com/questions/76511182/tensorflow-custom-learning-rate-scheduler-gives-unexpected-eagertensor-type-erro * --format was deprecated in ruff and replaced with --output-format * Added a single entry point to train all models * Added convenience wrapper for oemer * Tried to figure out the definitions for the dense dataset and to document them in code. There is likely an official definition somewhere but I just couldn't find it. So I looked at examples and tried to reconstruct the mapping. Unknown basically means that I just couldn't see the symbol on the picture. 
* Decreased queue sizes as otherwise the training process crashed with an out of memory exception after it used up about 30GB of memory * Added model outputs to git ignore * Added checks for dataset folders * Using default training params * Added workarounds for removal of np.float * Using dataset definitions * Added type annotations * Added a train_all_rests even if the resulting model is right now not used in oemer * segnet and unet should now pick the correct model * Changed label definitions from what appears to be used in oemer right now * With this commit the resulting arch.json matches the one inside of oemer/checkpoints/seg_net/arch.json * Prevent the OMR process from finishing prematurely (#53) * Fixed typos in comments * IndexError while scanning for a dot should not abort the whole process * Bound check while getting the note label * Added check if label is in the note_type_map * Filter staffs instead of aborting with an exception * Bound check during symbol extraction * Marking notes as invalid instead of aborting with an exception * Bound check * Fixed type error * Fixed TypeError at start of unet or segnet training (#52) * Fixed 'TypeError: Cannot convert 4.999899999999999e-07 to EagerTensor of dtype int64' in training, fixes #39 https://stackoverflow.com/questions/76511182/tensorflow-custom-learning-rate-scheduler-gives-unexpected-eagertensor-type-erro * --format was deprecated in ruff and replaced with --output-format * HoughLinesP can return None if no lines are found * Fixed error which happens if no rest bboxes were found * Limited try/except block * Fixed typo * Use fixed versions for the linter dependencies so that results do not differ for the same source code level on different test runs due to updates of the dependencies * Fixed type errors which came up with the recent version of cv2 * Going back to the newest version of ruff and mypy as the type errors were introduced by cv2 * Fix install from github command in README --------- 
Co-authored-by: Yoyo --- .gitignore | 21 +++- main.py | 3 + oemer/build_label.py | 23 ++-- oemer/classifier.py | 52 ++++++++- oemer/constant.py | 30 +++--- oemer/constant_min.py | 17 ++- oemer/dense_dataset_definitions.py | 78 ++++++++++++++ oemer/models/unet.py | 46 +++++--- oemer/train.py | 166 +++++++++++++++++------------ train.py | 64 +++++++++++ 10 files changed, 376 insertions(+), 124 deletions(-) create mode 100644 main.py create mode 100644 oemer/dense_dataset_definitions.py create mode 100644 train.py diff --git a/.gitignore b/.gitignore index 2b08b63..1142069 100755 --- a/.gitignore +++ b/.gitignore @@ -12,4 +12,23 @@ checkpoints/ *.musicxml *.mp3 -*.swp \ No newline at end of file +*.swp + +# Model training datasets +/ds2_dense +/CvcMuscima-Distortions + +# Model training checkpoints and outputs +/seg_unet +/test_data +/train_data +/*.model +/*.h5 +/*.json + +/segnet_* +/unet_* +/rests_* +/all_rests_* +/sfn_* +/clef_* \ No newline at end of file diff --git a/main.py b/main.py new file mode 100644 index 0000000..aee9622 --- /dev/null +++ b/main.py @@ -0,0 +1,3 @@ +from oemer import ete + +ete.main() \ No newline at end of file diff --git a/oemer/build_label.py b/oemer/build_label.py index 1490266..6026696 100755 --- a/oemer/build_label.py +++ b/oemer/build_label.py @@ -1,3 +1,4 @@ +import sys import os import random from PIL import Image @@ -5,10 +6,11 @@ import cv2 import numpy as np -from .constant_min import CLASS_CHANNEL_MAP +from .constant_min import CLASS_CHANNEL_MAP, CHANNEL_NUM +from .dense_dataset_definitions import DENSE_DATASET_DEFINITIONS as DEF -HALF_WHOLE_NOTE = [39, 41, 42, 43, 45, 46, 47, 49] +HALF_WHOLE_NOTE = DEF.NOTEHEADS_HOLLOW + DEF.NOTEHEADS_WHOLE + [42] def fill_hole(gt, tar_color): @@ -75,12 +77,12 @@ def build_label(seg_path): color_set = set(np.unique(arr)) color_set.remove(0) # Remove background color from the candidates - total_chs = len(set(CLASS_CHANNEL_MAP.values())) + 2 # Plus 'background' and 'others' channel. 
+ total_chs = CHANNEL_NUM output = np.zeros(arr.shape + (total_chs,)) output[..., 0] = np.where(arr==0, 1, 0) for color in color_set: - ch = CLASS_CHANNEL_MAP.get(color, -1) + ch = CLASS_CHANNEL_MAP.get(color, 0) if (ch != 0) and color in HALF_WHOLE_NOTE: note = fill_hole(arr, color) output[..., ch] += note @@ -101,12 +103,7 @@ def find_example(dataset_path: str, color: int, max_count=100, mark_value=200): if __name__ == "__main__": - seg_folder = '/media/kohara/ADATA HV620S/dataset/ds2_dense/segmentation' - files = os.listdir(seg_folder) - path = os.path.join(seg_folder, random.choice(files)) - #out = build_label(path) - - color = 45 - arr = find_example(color) # type: ignore - arr = np.where(arr==200, color, arr) - out = fill_hole(arr, color) + seg_folder = 'ds2_dense/segmentation' + color = int(sys.argv[1]) + with_background, without_background = find_example(seg_folder, color) + cv2.imwrite("example.png", with_background) diff --git a/oemer/classifier.py b/oemer/classifier.py index 8d8536f..d94a1ee 100755 --- a/oemer/classifier.py +++ b/oemer/classifier.py @@ -58,6 +58,7 @@ def _collect(color, out_path, samples=100): img = imaugs.resize(Image.fromarray(patch.astype(np.uint8)), width=tar_w, height=tar_h) seed = random.randint(0, 1000) + np.float = float # Monkey patch to workaround removal of np.float img = imaugs.perspective_transform(img, seed=seed, sigma=3) img = np.where(np.array(img)>0, 255, 0) Image.fromarray(img.astype(np.uint8)).save(out_path / f"{idx}.png") @@ -118,10 +119,12 @@ def train(folders): model.fit(train_x, train_y) return model, class_map +def build_class_map(folders): + return {idx: Path(ff).name for idx, ff in enumerate(folders)} def train_tf(folders): import tensorflow as tf - class_map = {idx: Path(ff).name for idx, ff in enumerate(folders)} + class_map = build_class_map(folders) train_x = [] train_y = [] samples = None @@ -234,6 +237,53 @@ def predict(region, model_name): pred = model.predict(np.array(region).reshape(1, -1)) return 
m_info['class_map'][pred[0]] +def train_rests_above8(filename = "rests_above8.model"): + folders = ["rest_8th", "rest_16th", "rest_32nd", "rest_64th"] + model, class_map = train_tf([f"train_data/{folder}" for folder in folders]) + test_tf(model, [f"test_data/{folder}" for folder in folders]) + output = {'model': model, 'w': TARGET_WIDTH, 'h': TARGET_HEIGHT, 'class_map': class_map} + pickle.dump(output, open(filename, "wb")) + + +def train_rests(filename = "rests.model"): + folders = ["rest_whole", "rest_quarter", "rest_8th"] + model, class_map = train_tf([f"train_data/{folder}" for folder in folders]) + test_tf(model, [f"test_data/{folder}" for folder in folders]) + output = {'model': model, 'w': TARGET_WIDTH, 'h': TARGET_HEIGHT, 'class_map': class_map} + pickle.dump(output, open(filename, "wb")) + + +def train_all_rests(filename = "all_rests.model"): + folders = ["rest_whole", "rest_quarter", "rest_8th", "rest_16th", "rest_32nd", "rest_64th"] + model, class_map = train_tf([f"train_data/{folder}" for folder in folders]) + test_tf(model, [f"test_data/{folder}" for folder in folders]) + output = {'model': model, 'w': TARGET_WIDTH, 'h': TARGET_HEIGHT, 'class_map': class_map} + pickle.dump(output, open(filename, "wb")) + + +def train_sfn(filename = "sfn.model"): + folders = ["sharp", "flat", "natural"] + model, class_map = train_tf([f"train_data/{folder}" for folder in folders]) + test_tf(model, [f"test_data/{folder}" for folder in folders]) + output = {'model': model, 'w': TARGET_WIDTH, 'h': TARGET_HEIGHT, 'class_map': class_map} + pickle.dump(output, open(filename, "wb")) + + +def train_clefs(filename = "clef.model"): + folders = ["gclef", "fclef"] + model, class_map = train_tf([f"train_data/{folder}" for folder in folders]) + test_tf(model, [f"test_data/{folder}" for folder in folders]) + output = {'model': model, 'w': TARGET_WIDTH, 'h': TARGET_HEIGHT, 'class_map': class_map} + pickle.dump(output, open(filename, "wb")) + + +def train_noteheads(): + folders = 
["notehead_solid", "notehead_hollow"] + model, class_map = train_tf([f"train_data/{folder}" for folder in folders]) + test_tf(model, [f"test_data/{folder}" for folder in folders]) + output = {'model': model, 'w': TARGET_WIDTH, 'h': TARGET_HEIGHT, 'class_map': class_map} + pickle.dump(output, open(f"notehead.model", "wb")) + if __name__ == "__main__": samples = 400 diff --git a/oemer/constant.py b/oemer/constant.py index 3cc789a..3d7a320 100755 --- a/oemer/constant.py +++ b/oemer/constant.py @@ -1,21 +1,23 @@ from enum import Enum, auto +from oemer.dense_dataset_definitions import DENSE_DATASET_DEFINITIONS as DEF + CLASS_CHANNEL_LIST = [ - [165, 2], # staff, ledgerLine - [35, 37, 38], # noteheadBlack - [39, 41, 42], # noteheadHalf - [43, 45, 46, 47, 49], # noteheadWhole - [64, 58, 59, 60, 66, 63, 69, 68, 61, 62, 67, 65], # flags - [146, 51], # beam, augmentationDot - [3, 52], # barline, stem - [74, 70, 72, 76], # accidentalSharp, accidentalFlat, accidentalNatural, accidentalDoubleSharp - [80, 78, 79], # keySharp, keyFlat, keyNatural - [97, 100, 99, 98, 101, 102, 103, 104, 96, 163], # rests - [136, 156, 137, 155, 152, 151, 153, 154, 149, 155], # tuplets - [145, 147], # slur, tie - [10, 13, 12, 19, 11, 20], # clefs - [25, 24, 29, 22, 23, 28, 27, 34, 30, 21, 33, 26], # timeSigs + DEF.STAFF + DEF.LEDGERLINE, + DEF.NOTEHEADS_SOLID + [38], + DEF.NOTEHEADS_HOLLOW + [42], + DEF.NOTEHEADS_WHOLE + [46], + DEF.FLAG_DOWN + DEF.FLAG_UP + [59, 65], + DEF.BEAM + DEF.DOT, + DEF.BARLINE_BETWEEN + DEF.STEM, + DEF.ALL_ACCIDENTALS, + DEF.ALL_KEYS, + DEF.ALL_RESTS + [163], + DEF.TUPETS, + DEF.SLUR_AND_TIE, + DEF.ALL_CLEFS + DEF.NUMBERS, + DEF.TIME_SIGNATURE_SUBSET ] CLASS_CHANNEL_MAP = { diff --git a/oemer/constant_min.py b/oemer/constant_min.py index ea64b4d..eb761b4 100755 --- a/oemer/constant_min.py +++ b/oemer/constant_min.py @@ -1,13 +1,10 @@ +from oemer.dense_dataset_definitions import DENSE_DATASET_DEFINITIONS as DEF + + CLASS_CHANNEL_LIST = [ - [165, 2], # staff, ledgerLine - 
[35, 37, 38, 39, 41, 42, 43, 45, 46, 47, 49, 52], # notehead, stem - [ - 64, 58, 60, 66, 63, 69, 68, 61, 62, 67, 65, 59, 146, # flags, beam - 97, 100, 99, 98, 101, 102, 103, 104, 96, 163, # rests - 80, 78, 79, 74, 70, 72, 76, 3, # sharp, flat, natural, barline - 10, 13, 12, 19, 11, 20, 51, # clefs, augmentationDot, - 25, 24, 29, 22, 23, 28, 27, 34, 30, 21, 33, 26, # timeSigs - ] + DEF.STEM + DEF.ALL_RESTS_EXCEPT_LARGE + DEF.BARLINE_BETWEEN + DEF.BARLINE_END, + DEF.NOTEHEADS_ALL, + DEF.ALL_CLEFS + DEF.ALL_KEYS + DEF.ALL_ACCIDENTALS, ] CLASS_CHANNEL_MAP = { @@ -16,4 +13,4 @@ for color in colors } -CHANNEL_NUM = len(CLASS_CHANNEL_LIST) + 2 +CHANNEL_NUM = len(CLASS_CHANNEL_LIST) + 1 # Plus 'background' and 'others' channel. diff --git a/oemer/dense_dataset_definitions.py b/oemer/dense_dataset_definitions.py new file mode 100644 index 0000000..b0fa1a3 --- /dev/null +++ b/oemer/dense_dataset_definitions.py @@ -0,0 +1,78 @@ +class Symbols: + BACKGROUND = [0] + LEDGERLINE = [2] + BARLINE_BETWEEN = [3] + BARLINE_END = [4] + ALL_BARLINES = BARLINE_BETWEEN + BARLINE_END + REPEAT_DOTS = [7] + G_GLEF = [10] + C_CLEF = [11, 12] + F_CLEF = [13] + ALL_CLEFS = G_GLEF + C_CLEF + F_CLEF + NUMBERS = [19, 20] + TIME_SIGNATURE_SUBSET = [21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 33, 34] + TIME_SIGNATURE = TIME_SIGNATURE_SUBSET + [31, 32] # Oemer hasn't used these in the past + NOTEHEAD_FULL_ON_LINE = [35] + UNKNOWN = [36, 38, 40, 128, 143, 144, 148, 150, 157, 159, 160, 161, 162, 163, 164, 167, 170, 171] + NOTEHEAD_FULL_BETWEEN_LINES = [37] + NOTEHEAD_HOLLOW_ON_LINE = [39] + NOTEHEAD_HOLLOW_BETWEEN_LINE = [41] + WHOLE_NOTE_ON_LINE = [43] + WHOLE_NOTE_BETWEEN_LINE = [45] + DOUBLE_WHOLE_NOTE_ON_LINE = [47] + DOUBLE_WHOLE_NOTE_BETWEEN_LINE = [49] + NOTEHEADS_SOLID = NOTEHEAD_FULL_ON_LINE + NOTEHEAD_FULL_BETWEEN_LINES + NOTEHEADS_HOLLOW = NOTEHEAD_HOLLOW_ON_LINE + NOTEHEAD_HOLLOW_BETWEEN_LINE + NOTEHEADS_WHOLE = WHOLE_NOTE_ON_LINE + WHOLE_NOTE_BETWEEN_LINE + DOUBLE_WHOLE_NOTE_ON_LINE + 
DOUBLE_WHOLE_NOTE_BETWEEN_LINE + NOTEHEADS_ALL = NOTEHEAD_FULL_ON_LINE + NOTEHEAD_FULL_BETWEEN_LINES + NOTEHEAD_HOLLOW_ON_LINE + NOTEHEAD_HOLLOW_BETWEEN_LINE + WHOLE_NOTE_ON_LINE + WHOLE_NOTE_BETWEEN_LINE + DOUBLE_WHOLE_NOTE_ON_LINE + DOUBLE_WHOLE_NOTE_BETWEEN_LINE + DOT = [51] + STEM = [52] + TREMOLO = [53, 54, 55, 56] + FLAG_DOWN = [58, 60, 61, 62, 63] + FLAG_UP = [64, 66, 67, 68, 69] + FLAT = [70] + NATURAL = [72] + SHARP = [74] + DOUBLE_SHARP = [76] + ALL_ACCIDENTALS = FLAT + NATURAL + SHARP + DOUBLE_SHARP + KEY_FLAT = [78] + KEY_NATURAL = [79] + KEY_SHARP = [80] + ALL_KEYS = KEY_FLAT + KEY_NATURAL + KEY_SHARP + ACCENT_ABOVE = [81] + ACCENT_BELOW = [82] + STACCATO_ABOVE = [83] + STACCATO_BELOW = [84] + TENUTO_ABOVE = [85] + TENUTO_BELOW = [86] + STACCATISSIMO_ABOVE = [87] + STACCATISSIMO_BELOW = [88] + MARCATO_ABOVE = [89] + MARCATO_BELOW = [90] + FERMATA_ABOVE = [91] + FERMATA_BELOW = [92] + BREATH_MARK = [93] + REST_LARGE = [95] + REST_LONG = [96] + REST_BREVE = [97] + REST_FULL = [98] + REST_QUARTER = [99] + REST_EIGHTH = [100] + REST_SIXTEENTH = [101] + REST_THIRTY_SECOND = [102] + REST_SIXTY_FOURTH = [103] + REST_ONE_HUNDRED_TWENTY_EIGHTH = [104] + ALL_RESTS_EXCEPT_LARGE = REST_LONG + REST_BREVE + REST_FULL + REST_QUARTER + REST_EIGHTH + REST_SIXTEENTH + REST_THIRTY_SECOND + REST_SIXTY_FOURTH + REST_ONE_HUNDRED_TWENTY_EIGHTH + ALL_RESTS = ALL_RESTS_EXCEPT_LARGE + TRILL = [127] + GRUPPETO = [129] + MORDENT = [130] + DOWN_BOW = [131] + UP_BOW = [132] + SYMBOL = [133, 134, 135, 138, 139, 141, 142] + TUPETS = [136, 137, 149, 151, 152, 153, 154, 155, 156] + SLUR_AND_TIE = [145, 147] + BEAM = [146] + STAFF = [165] + +DENSE_DATASET_DEFINITIONS = Symbols() \ No newline at end of file diff --git a/oemer/models/unet.py b/oemer/models/unet.py index 66efd1d..ae6ee92 100755 --- a/oemer/models/unet.py +++ b/oemer/models/unet.py @@ -142,10 +142,19 @@ def my_conv_block(inp, kernels, kernel_size=(3, 3), strides=(1, 1)): return out +def my_conv_small_block(inp, kernels, 
kernel_size=(3, 3), strides=(1, 1)): + inp = L.Conv2D(kernels, kernel_size, strides=strides, padding='same', dtype=tf.float32)(inp) + out = L.Activation("relu")(L.LayerNormalization()(inp)) + out = L.Dropout(0.3)(out) + out = L.Add()([inp, out]) + out = L.Activation("relu")(L.LayerNormalization()(out)) + return out + + def my_trans_conv_block(inp, kernels, kernel_size=(3, 3), strides=(1, 1)): inp = L.Conv2DTranspose(kernels, kernel_size, strides=strides, padding='same', dtype=tf.float32)(inp) - out = L.Activation("relu")(L.LayerNormalization()(inp)) - out = L.Conv2D(kernels, kernel_size, padding='same', dtype=tf.float32)(out) + #out = L.Activation("relu")(L.LayerNormalization()(inp)) + out = L.Conv2D(kernels, kernel_size, padding='same', dtype=tf.float32)(inp) out = L.Activation("relu")(L.LayerNormalization()(out)) out = L.Dropout(0.3)(out) out = L.Add()([inp, out]) @@ -157,25 +166,25 @@ def u_net(win_size=288, out_class=3): inp = L.Input(shape=(win_size, win_size, 3)) tensor = L.SeparableConv2D(128, (3, 3), activation="relu", padding='same')(inp) - l1 = my_conv_block(tensor, 64, (3, 3), strides=(2, 2)) # 128 - l1 = my_conv_block(l1, 128, (3, 3)) - l1 = my_conv_block(l1, 128, (3, 3)) + l1 = my_conv_small_block(tensor, 64, (3, 3), strides=(2, 2)) + l1 = my_conv_small_block(l1, 64, (3, 3)) + l1 = my_conv_small_block(l1, 64, (3, 3)) - skip = my_conv_block(l1, 128, (3, 3), strides=(2, 2)) # 64 - l2 = my_conv_block(skip, 128, (3, 3)) - l2 = my_conv_block(l2, 128, (3, 3)) - l2 = my_conv_block(l2, 128, (3, 3)) - l2 = my_conv_block(l2, 128, (3, 3)) + skip = my_conv_small_block(l1, 128, (3, 3), strides=(2, 2)) + l2 = my_conv_small_block(skip, 128, (3, 3)) + l2 = my_conv_small_block(l2, 128, (3, 3)) + l2 = my_conv_small_block(l2, 128, (3, 3)) + l2 = my_conv_small_block(l2, 128, (3, 3)) l2 = L.Concatenate()([skip, l2]) - l3 = my_conv_block(l2, 256, (3, 3)) - l3 = my_conv_block(l3, 256, (3, 3)) - l3 = my_conv_block(l3, 256, (3, 3)) - l3 = my_conv_block(l3, 256, (3, 3)) - l3 = 
my_conv_block(l3, 256, (3, 3)) + l3 = my_conv_small_block(l2, 256, (3, 3)) + l3 = my_conv_small_block(l3, 256, (3, 3)) + l3 = my_conv_small_block(l3, 256, (3, 3)) + l3 = my_conv_small_block(l3, 256, (3, 3)) + l3 = my_conv_small_block(l3, 256, (3, 3)) l3 = L.Concatenate()([l2, l3]) - bot = my_conv_block(l3, 256, (3, 3), strides=(2, 2)) # 32 + bot = my_conv_small_block(l3, 256, (3, 3), strides=(2, 2)) # 32 st1 = L.SeparableConv2D(256, (3, 3), padding='same', dtype=tf.float32)(bot) st1 = L.Activation("relu")(L.LayerNormalization()(st1)) st2 = L.SeparableConv2D(256, (3, 3), dilation_rate=(2, 2), padding='same', dtype=tf.float32)(bot) @@ -189,20 +198,23 @@ def u_net(win_size=288, out_class=3): norm = L.Activation("relu")(L.LayerNormalization()(st)) bot = my_trans_conv_block(norm, 256, (3, 3), strides=(2, 2)) # 64 - tl3 = L.Conv2D(256, (3, 3), padding='same', dtype=tf.float32)(bot) + tl3 = L.Conv2D(128, (3, 3), padding='same', dtype=tf.float32)(bot) tl3 = L.Activation("relu")(L.LayerNormalization()(tl3)) tl3 = L.Concatenate()([tl3, l3]) + tl3 = my_conv_small_block(tl3, 128, (3, 3)) tl3 = my_trans_conv_block(tl3, 128, (3, 3)) # Head 1 tl2 = L.Conv2D(128, (3, 3), padding='same', dtype=tf.float32)(tl3) tl2 = L.Activation("relu")(L.LayerNormalization()(tl2)) tl2 = L.Concatenate()([tl2, l2]) + tl2 = my_conv_small_block(tl2, 128, (3, 3)) tl2 = my_trans_conv_block(tl2, 128, (3, 3), strides=(2, 2)) # 128 tl1 = L.Conv2D(128, (3, 3), padding='same', dtype=tf.float32)(tl2) tl1 = L.Activation("relu")(L.LayerNormalization()(tl1)) tl1 = L.Concatenate()([tl1, l1]) + tl1 = my_conv_small_block(tl1, 128, (3, 3)) tl1 = my_trans_conv_block(tl1, 128, (3, 3), strides=(2, 2)) # 256 out1 = L.Conv2D(out_class, (1, 1), activation='softmax', padding='same', dtype=tf.float32)(tl1) diff --git a/oemer/train.py b/oemer/train.py index dbe5b46..475acc2 100755 --- a/oemer/train.py +++ b/oemer/train.py @@ -11,11 +11,14 @@ from .build_label import build_label from .models.unet import semantic_segmentation, 
u_net -from .constant import CHANNEL_NUM +from .constant_min import CHANNEL_NUM def get_cvc_data_paths(dataset_path): + if not os.path.exists(dataset_path): + raise FileNotFoundError(f"{dataset_path} not found, download the dataset first.") + dirs = ["curvature", "ideal", "interrupted", "kanungo", "rotated", "staffline-thickness-variation-v1", "staffline-thickness-variation-v2", "staffline-y-variation-v1", "staffline-y-variation-v2", "thickness-ratio", "typeset-emulation", "whitespeckles"] @@ -37,6 +40,9 @@ def get_cvc_data_paths(dataset_path): def get_deep_score_data_paths(dataset_path): + if not os.path.exists(dataset_path): + raise FileNotFoundError(f"{dataset_path} not found, download the dataset first.") + imgs = os.listdir(os.path.join(dataset_path, "images")) paths = [] for img in imgs: @@ -122,9 +128,42 @@ def batch_transform(img, trans_func): result.append(np.array(tmp_img)) return np.dstack(result) + +class MultiprocessingDataLoader: + def __init__(self, num_worker: int): + self._queue: Queue = Queue(maxsize=20) + self._dist_queue: Queue = Queue(maxsize=30) + self._process_pool = [] + for _ in range(num_worker): + processor = Process(target=self._preprocess_image) + processor.daemon = True + self._process_pool.append(processor) + self._pdist = Process(target=self._distribute_process) + self._pdist.daemon = True + + def _start_processes(self): + if not self._pdist.is_alive(): + self._pdist.start() + for process in self._process_pool: + if not process.is_alive(): + process.start() -class DataLoader: + def _terminate_processes(self): + self._pdist.terminate() + for process in self._process_pool: + process.terminate() + + + def _distribute_process(self): + pass + + def _preprocess_image(self): + pass + + +class DataLoader(MultiprocessingDataLoader): def __init__(self, feature_files, win_size=256, num_samples=100, min_step_size=0.2, num_worker=4): + super().__init__(num_worker) self.feature_files = feature_files random.shuffle(self.feature_files) self.win_size 
= win_size @@ -138,16 +177,6 @@ def __init__(self, feature_files, win_size=256, num_samples=100, min_step_size=0 self.file_idx = 0 - self._queue = Queue(maxsize=200) - self._dist_queue = Queue(maxsize=300) - self._process_pool = [] - for _ in range(num_worker): - processor = Process(target=self._preprocess_image) - processor.daemon = True - self._process_pool.append(processor) - self._pdist = Process(target=self._distribute_process) - self._pdist.daemon = True - def _distribute_process(self): while True: paths = self.feature_files[self.file_idx] @@ -175,6 +204,7 @@ def _preprocess_image(self): # Random perspective transform seed = random.randint(0, 1000) + np.float = float # Monkey patch to workaround removal of np.float perspect_trans = lambda img: imaugs.perspective_transform(img, seed=seed, sigma=70) image = np.array(perspect_trans(image)) # RGB image staff_img = np.array(perspect_trans(staff_img)) # 1-bit mask @@ -187,11 +217,7 @@ def _preprocess_image(self): def __iter__(self): samples = 0 - if not self._pdist.is_alive(): - self._pdist.start() - for process in self._process_pool: - if not process.is_alive(): - process.start() + self._start_processes() while samples < self.num_samples: image, staff_img, symbol_img, ratio = self._queue.get() @@ -218,9 +244,7 @@ def __iter__(self): start_y = min(start_y + y_step, max_y) start_x = min(start_x + x_step, max_x) - self._pdist.terminate() - for process in self._process_pool: - process.terminate() + self._terminate_processes() def get_dataset(self, batch_size, output_types=None, output_shapes=None): def gen_wrapper(): @@ -240,8 +264,9 @@ def gen_wrapper(): .prefetch(tf.data.experimental.AUTOTUNE) -class DsDataLoader: +class DsDataLoader(MultiprocessingDataLoader): def __init__(self, feature_files, win_size=256, num_samples=100, step_size=0.5, num_worker=4): + super().__init__(num_worker) self.feature_files = feature_files random.shuffle(self.feature_files) self.win_size = win_size @@ -255,16 +280,6 @@ def 
__init__(self, feature_files, win_size=256, num_samples=100, step_size=0.5, self.file_idx = 0 - self._queue = Queue(maxsize=200) - self._dist_queue = Queue(maxsize=100) - self._process_pool = [] - for _ in range(num_worker): - processor = Process(target=self._preprocess_image) - processor.daemon = True - self._process_pool.append(processor) - self._pdist = Process(target=self._distribute_process) - self._pdist.daemon = True - def _distribute_process(self): while True: paths = self.feature_files[self.file_idx] @@ -293,6 +308,7 @@ def _preprocess_image(self): # Random perspective transform seed = random.randint(0, 1000) + np.float = float # Monkey patch to workaround removal of np.float perspect_trans = lambda img: imaugs.perspective_transform(img, seed=seed, sigma=70) image = np.array(batch_transform(image, perspect_trans)) # RGB image label = np.array(batch_transform(label, perspect_trans)) @@ -302,11 +318,7 @@ def _preprocess_image(self): def __iter__(self): samples = 0 - if not self._pdist.is_alive(): - self._pdist.start() - for process in self._process_pool: - if not process.is_alive(): - process.start() + self._start_processes() while samples < self.num_samples: image, label, ratio = self._queue.get() @@ -337,10 +349,7 @@ def __iter__(self): ll = label[index] yield feat, ll - self._pdist.terminate() - for process in self._process_pool: - process.terminate() - + self._terminate_processes() def get_dataset(self, batch_size, output_types=None, output_shapes=None): def gen_wrapper(): for data in self: @@ -410,7 +419,6 @@ def focal_tversky_loss(y_true, y_pred, fw=0.7, alpha=0.7, smooth=1., gamma=0.75) def train_model( dataset_path, - win_size=288, train_val_split=0.1, learning_rate=5e-4, epochs=15, @@ -418,33 +426,51 @@ def train_model( batch_size=8, val_steps=200, val_batch_size=8, - early_stop=8 + early_stop=8, + data_model="segnet" ): - # feat_files = get_cvc_data_paths(dataset_path) - feat_files = get_deep_score_data_paths(dataset_path) + if data_model == 
"segnet": + feat_files = get_deep_score_data_paths(dataset_path) + else: + feat_files = get_cvc_data_paths(dataset_path) random.shuffle(feat_files) split_idx = round(train_val_split * len(feat_files)) train_files = feat_files[split_idx:] val_files = feat_files[:split_idx] print(f"Loading dataset. Train/validation: {len(train_files)}/{len(val_files)}") - train_data = DsDataLoader( - train_files, - win_size=win_size, - num_samples=epochs*steps*batch_size - ) \ - .get_dataset(batch_size) - val_data = DsDataLoader( - val_files, - win_size=win_size, - num_samples=epochs*val_steps*val_batch_size - ) \ - .get_dataset(val_batch_size) + if data_model == "segnet": + win_size=288 + train_data = DsDataLoader( + train_files, + win_size=win_size, + num_samples=epochs*steps*batch_size + ) \ + .get_dataset(batch_size) + val_data = DsDataLoader( + val_files, + win_size=win_size, + num_samples=epochs*val_steps*val_batch_size + ) \ + .get_dataset(val_batch_size) + model = u_net(win_size=win_size, out_class=CHANNEL_NUM) + else: + win_size=256 + train_data = DataLoader( + train_files, + win_size=win_size, + num_samples=epochs*steps*batch_size + ) \ + .get_dataset(batch_size) + val_data = DataLoader( + val_files, + win_size=win_size, + num_samples=epochs*val_steps*val_batch_size + ) \ + .get_dataset(val_batch_size) + model = semantic_segmentation(win_size=256, out_class=3) print("Initializing model") - #model = naive_conv(win_size=win_size) - #model = u_net(win_size=win_size, out_class=CHANNEL_NUM) - model = semantic_segmentation(win_size=win_size, out_class=CHANNEL_NUM) optim = tf.keras.optimizers.Adam(learning_rate=WarmUpLearningRate(learning_rate)) #loss = tf.keras.losses.BinaryCrossentropy(label_smoothing=0.1) #loss = tf.keras.losses.CategoricalCrossentropy() @@ -457,15 +483,19 @@ def train_model( ] print("Start training") - model.fit( - train_data, - validation_data=val_data, - epochs=epochs, - steps_per_epoch=steps, - validation_steps=val_steps, - callbacks=callbacks - ) - return 
model + try: + model.fit( + train_data, + validation_data=val_data, + epochs=epochs, + steps_per_epoch=steps, + validation_steps=val_steps, + callbacks=callbacks + ) + return model + except Exception as e: + print(e) + return model def resize_image(image: Image.Image): diff --git a/train.py b/train.py new file mode 100644 index 0000000..2b90a51 --- /dev/null +++ b/train.py @@ -0,0 +1,64 @@ +import sys +import time +import os + +import tensorflow as tf + +from oemer import train +from oemer import classifier + + +def write_text_to_file(text, path): + with open(path, "w") as f: + f.write(text) + +if len(sys.argv) != 2: + print("Usage: python train.py ") + sys.exit(1) + +def get_model_base_name(model_name: str) -> str: + timestamp = str(round(time.time())) + return f"{model_name}_{timestamp}" + +model_type = sys.argv[1] + +def prepare_classifier_data(): + if not os.path.exists("train_data"): + classifier.collect_data(2000) + +if model_type == "segnet": + model = train.train_model("ds2_dense", data_model=model_type, steps=1500, epochs=15) + filename = get_model_base_name(model_type) + os.makedirs(filename) + write_text_to_file(model.to_json(), os.path.join(filename, "arch.json")) + model.save_weights(os.path.join(filename, "weights.h5")) +elif model_type == "unet": + model = train.train_model("CvcMuscima-Distortions", data_model=model_type, steps=1500, epochs=15) + filename = get_model_base_name(model_type) + os.makedirs(filename) + write_text_to_file(model.to_json(), os.path.join(filename, "arch.json")) + model.save_weights(os.path.join(filename, "weights.h5")) +elif model_type == "unet_from_checkpoint" or model_type == "segnet_from_checkpoint": + model = tf.keras.models.load_model("seg_unet", custom_objects={"WarmUpLearningRate": train.WarmUpLearningRate}) + filename = get_model_base_name(model_type.split("_")[0]) + os.makedirs(filename) + write_text_to_file(model.to_json(), os.path.join(filename, "arch.json")) + model.save_weights(os.path.join(filename, 
"weights.h5")) +elif model_type == "rests_above8": + prepare_classifier_data() + classifier.train_rests_above8(get_model_base_name(model_type)) +elif model_type == "rests": + prepare_classifier_data() + classifier.train_rests(get_model_base_name(model_type)) +elif model_type == "all_rests": + prepare_classifier_data() + classifier.train_all_rests(get_model_base_name(model_type)) +elif model_type == "sfn": + prepare_classifier_data() + classifier.train_sfn(get_model_base_name(model_type)) +elif model_type == "clef": + prepare_classifier_data() + classifier.train_clefs(get_model_base_name(model_type)) +else: + print("Unknown model: " + model_type) + sys.exit(1) \ No newline at end of file