diff --git a/oemer/bbox.py b/oemer/bbox.py index 60baa3b..9ba88d3 100755 --- a/oemer/bbox.py +++ b/oemer/bbox.py @@ -2,6 +2,7 @@ from typing import Union, Any, List, Tuple, Dict import cv2 +from cv2.typing import RotatedRect import numpy as np from numpy import ndarray from sklearn.cluster import AgglomerativeClustering @@ -118,11 +119,12 @@ def find_lines(data: ndarray, min_len: int = 10, max_gap: int = 20) -> List[BBox lines = cv2.HoughLinesP(data.astype(np.uint8), 1, np.pi/180, 50, None, min_len, max_gap) new_line = [] - for line in lines: - line = line[0] - top_x, bt_x = (line[0], line[2]) if line[0] < line[2] else (line[2], line[0]) - top_y, bt_y = (line[1], line[3]) if line[1] < line[3] else (line[3], line[1]) - new_line.append((top_x, top_y, bt_x, bt_y)) + if lines is not None: + for line in lines: + line = line[0] + top_x, bt_x = (line[0], line[2]) if line[0] < line[2] else (line[2], line[0]) + top_y, bt_y = (line[1], line[3]) if line[1] < line[3] else (line[3], line[1]) + new_line.append((top_x, top_y, bt_x, bt_y)) return new_line @@ -159,7 +161,7 @@ def draw_bounding_boxes( return img -def get_rotated_bbox(data: ndarray) -> List[Tuple[Tuple[float, float], Tuple[float, float], float]]: +def get_rotated_bbox(data: ndarray) -> List[RotatedRect]: contours, _ = cv2.findContours(data.astype(np.uint8), cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE) bboxes = [] for cnt in contours: diff --git a/oemer/inference.py b/oemer/inference.py index 74db2c0..082815c 100755 --- a/oemer/inference.py +++ b/oemer/inference.py @@ -58,8 +58,8 @@ def inference( image_pil = Image.open(img_path) if "GIF" != image_pil.format: # Tricky workaround to avoid random mistery transpose when loading with 'Image'. - image_pil = cv2.imread(img_path) - image_pil = Image.fromarray(image_pil) + image_cv = cv2.imread(img_path) + image_pil = Image.fromarray(image_cv) image_pil = image_pil.convert("RGB") image = np.array(resize_image(image_pil)) diff --git a/oemer/notehead_extraction.py b/oemer/notehead_extraction.py index 13b8174..21eeb28 100755 --- a/oemer/notehead_extraction.py +++ b/oemer/notehead_extraction.py @@ -306,7 +306,7 @@ def fill_hole(region: ndarray) -> ndarray: def gen_notes(bboxes: List[ndarray], symbols: ndarray) -> List[NoteHead]: notes = [] for bbox in bboxes: - # Instanitiate notehead. + # Instantiate notehead. nn = NoteHead() nn.bbox = typing.cast(BBox, bbox) @@ -343,7 +343,7 @@ def assign_group_track(st: Staff) -> None: # The value could also be negative. The zero index starts from the position # same as D4, assert the staffline is in treble clef. The value increases # as the pitch goes up. - # Build centers of each postion first. + # Build centers of each position first. step = st_master.unit_size / 2 pos_cen = [l.y_center for l in st_master.lines[::-1]] tmp_inter = [] @@ -354,7 +354,7 @@ def assign_group_track(st: Staff) -> None: pos_cen.insert(idx*2+1, interp) pos_cen = [pos_cen[0]+step] + pos_cen + [pos_cen[-1]-step] - # Estimate position by the closeset center. + # Estimate position by the closest center. pos_idx = np.argmin(np.abs(np.array(pos_cen)-cen_y)) if 0 < pos_idx < len(pos_cen)-1: nn.staff_line_pos = int(pos_idx) diff --git a/oemer/rhythm_extraction.py b/oemer/rhythm_extraction.py index 26bf8f2..73d4e59 100755 --- a/oemer/rhythm_extraction.py +++ b/oemer/rhythm_extraction.py @@ -3,6 +3,7 @@ import math import cv2 +from cv2.typing import RotatedRect import scipy.ndimage import numpy as np from numpy import ndarray @@ -37,7 +38,12 @@ def scan_dot( # Find the right most bound for scan the dot. # Should have width less than unit_size, and can't # touch the nearby note. - cur_scan_line = note_id_map[int(start_y):int(bbox[3]), int(right_bound)] + try: + cur_scan_line = note_id_map[int(start_y):int(bbox[3]), int(right_bound)] + except IndexError as e: + print(e) + break + ids = set(np.unique(cur_scan_line)) if -1 in ids: ids.remove(-1) @@ -134,7 +140,7 @@ def parse_beams( min_area_ratio: float = 0.07, min_tp_ratio: float = 0.4, min_width_ratio: float = 0.2 -) -> Tuple[ndarray, List[Tuple[Tuple[float, float], Tuple[float, float], float]], ndarray]: +) -> Tuple[ndarray, List[RotatedRect], ndarray]: # Fetch parameters symbols = layers.get_layer('symbols_pred') staff_pred = layers.get_layer('staff_pred') @@ -156,14 +162,14 @@ def parse_beams( ratio_map = np.copy(poly_map) null_color = (255, 255, 255) - valid_box = [] + valid_box: List[RotatedRect] = [] valid_idxs = [] idx_map = np.zeros_like(poly_map) - 1 - for idx, rbox in enumerate(rboxes): # type: ignore + for box_idx, rbox in enumerate(rboxes): # type: ignore # Used to find indexes of contour areas later. Must be check before # any 'continue' statement. - idx %= 255 # type: ignore - if idx == 0: + box_idx %= 255 # type: ignore + if box_idx == 0: idx_map = np.zeros_like(poly_map) - 1 # Get the contour of the rotated box @@ -186,8 +192,8 @@ def parse_beams( continue # Tricky way to get the index of the contour area - cv2.fillPoly(idx_map, [cnt], color=(idx, 0, 0)) - yi, xi = np.where(idx_map[..., 0] == idx) + cv2.fillPoly(idx_map, [cnt], color=(box_idx, 0, 0)) + yi, xi = np.where(idx_map[..., 0] == box_idx) pts = beams[yi, xi] meta_idx = np.where(pts>0)[0] ryi = yi[meta_idx] @@ -429,7 +435,9 @@ def get_label(nbox, stem_up): end_x=min(set_box[2], cen_x+half_scan_width), end_y=end_y, threshold=threshold - ) + ) + if count >= len(note_type_map): + return note_type_map[len(note_type_map) - 1] return note_type_map[count] if len(nts) == 2: @@ -501,7 +509,7 @@ def parse_rhythm(beam_map: ndarray, map_info: Dict[int, Dict[str, Any]], agree_t rev_map_info[gid] = {'reg': reg, 'bbox': box} # Define beam count to note type mapping - note_type_map = { + note_type_map: Dict[int, NoteType] = { 0: NoteType.QUARTER, 1: NoteType.EIGHTH, 2: NoteType.SIXTEENTH, @@ -580,7 +588,7 @@ def parse_rhythm(beam_map: ndarray, map_info: Dict[int, Dict[str, Any]], agree_t end_y = gbox[3] # Calculate how many beams/flags are there. - count = scan_beam_flag( # type: ignore + beam_flag_count = scan_beam_flag( # type: ignore bin_beam_map, max(reg_box[0], cen_x-half_scan_width), start_y, @@ -590,12 +598,15 @@ def parse_rhythm(beam_map: ndarray, map_info: Dict[int, Dict[str, Any]], agree_t ) #cv2.rectangle(beam_img, (gbox[0], gbox[1]), (gbox[2], gbox[3]), (255, 0, 255), 1) - cv2.putText(beam_img, str(count), (int(cen_x), int(gbox[3])+2), cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 0, 0), 1) + cv2.putText(beam_img, str(beam_flag_count), (int(cen_x), int(gbox[3])+2), cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 0, 0), 1) # Assign note label for nid in group.note_ids: if notes[nid].label is None: - notes[nid].label = note_type_map[count] # type: ignore + if beam_flag_count in note_type_map: + notes[nid].label = note_type_map[beam_flag_count] + else: + notes[nid].invalid = True return beam_img @@ -605,7 +616,7 @@ def extract( dot_max_area_ratio: float = 0.2, beam_min_area_ratio: float = 0.07, agree_th: float = 0.15 -) -> Tuple[ndarray, List[Tuple[Tuple[float, float], Tuple[float, float], float]]]: +) -> Tuple[ndarray, List[RotatedRect]]: logger.debug("Parsing dot") parse_dot(max_area_ratio=dot_max_area_ratio, min_area_ratio=dot_min_area_ratio) diff --git a/oemer/staffline_extraction.py b/oemer/staffline_extraction.py index 70bd92d..fc4b24f 100755 --- a/oemer/staffline_extraction.py +++ b/oemer/staffline_extraction.py @@ -331,7 +331,7 @@ def extract( # Start process zones, *_ = init_zones(staff_pred, splits=splits) - all_staffs = [] + all_staffs: List[List[Staff]] = [] for rr in zones: print(rr[0], rr[-1], end=' ') rr = np.array(rr, dtype=np.int64) @@ -339,43 +339,45 @@ def extract( if staffs is not None: all_staffs.append(staffs) print(len(staffs)) - all_staffs = align_staffs(all_staffs) # type: ignore + aligned_staffs: np.ndarray = align_staffs(all_staffs) # Use barline information to infer the number of tracks for each group. - num_track = further_infer_track_nums(all_staffs, min_degree=barline_min_degree) # type: ignore + num_track = further_infer_track_nums(aligned_staffs, min_degree=barline_min_degree) logger.debug(f"Tracks: {num_track}") - for col_sts in all_staffs: + for col_sts in aligned_staffs: for idx, st in enumerate(col_sts): st.track = idx % num_track st.group = idx // num_track # Validate staffs across zones. # Should have same number of staffs - if not all([len(staff) == len(all_staffs[0]) for staff in all_staffs]): + if not all([len(staff) == len(aligned_staffs[0]) for staff in aligned_staffs]): raise Exception - assert all([len(staff) == len(all_staffs[0]) for staff in all_staffs]) + assert all([len(staff) == len(aligned_staffs[0]) for staff in aligned_staffs]) norm = lambda data: np.abs(np.array(data) / np.mean(data) - 1) - for staffs in all_staffs.T: # type: ignore + valid_staffs: list[list[Staff]] = [] + for staffs in aligned_staffs.T: # Should all have 5 lines line_num = [len(staff.lines) for staff in staffs] if len(set(line_num)) != 1: - raise E.StafflineCountInconsistent( - f"Some of the stafflines contains less or more than 5 lines: {line_num}") + print(f"Some of the stafflines contains less or more than 5 lines: {line_num}") + continue # Check Staffs that are approximately at the same row. centers = np.array([staff.y_center for staff in staffs]) if not np.all(norm(centers) < horizontal_diff_th): - raise E.StafflineNotAligned( - f"Centers of staff parts at the same row not aligned (Th: {horizontal_diff_th}): {norm(centers)}") + print(f"Centers of staff parts at the same row not aligned (Th: {horizontal_diff_th}): {norm(centers)}") + continue # Unit sizes should roughly all the same unit_size = np.array([staff.unit_size for staff in staffs]) - if not np.all(norm(unit_size) < unit_size_diff_th): - raise E.StafflineUnitSizeInconsistent( - f"Unit sizes not consistent (th: {unit_size_diff_th}): {norm(unit_size)}") + if not np.all(norm(unit_size) < unit_size_diff_th): + print(f"Unit sizes not consistent (th: {unit_size_diff_th}): {norm(unit_size)}") + continue + valid_staffs.append(staffs) - return np.array(all_staffs), zones + return np.array(valid_staffs).T, zones def extract_part(pred: ndarray, x_offset: int, line_threshold: float = 0.8) -> List[Staff]: diff --git a/oemer/symbol_extraction.py b/oemer/symbol_extraction.py index 369dc84..9d0b2f9 100755 --- a/oemer/symbol_extraction.py +++ b/oemer/symbol_extraction.py @@ -270,7 +270,7 @@ def parse_clefs_keys( for box in bboxes: w = box[2] - box[0] h = box[3] - box[1] - region = clefs_keys[box[1]:box[3], box[0]:box[2]] + region: ndarray = clefs_keys[box[1]:box[3], box[0]:box[2]] usize = get_unit_size(*get_center(box)) area_size_ratio = w * h / usize**2 area_tp_ratio = region[region>0].size / (w * h) @@ -317,6 +317,8 @@ def parse_rests(line_box: ndarray, unit_size: float) -> Tuple[List[BBox], List[s bboxes = get_bbox(rests) bboxes = filter_out_of_range_bbox(bboxes) + if len(bboxes) == 0: + return [], [] bboxes = merge_nearby_bbox(bboxes, unit_size*1.2) bboxes = rm_merge_overlap_bbox(bboxes) bboxes = filter_out_small_area(bboxes, area_size_func=lambda usize: usize**2 * 0.7) @@ -373,6 +375,9 @@ def get_nearby_note_id(box: BBox, note_id_map: ndarray) -> Union[int, None]: unit_size = int(round(get_unit_size(cen_x, cen_y))) nid = None for x in range(box[2], box[2]+unit_size): + is_in_range = (0 <= cen_y < note_id_map.shape[0]) and (0 <= x < note_id_map.shape[1]) + if not is_in_range: + continue if note_id_map[cen_y, x] != -1: nid = note_id_map[cen_y, x] break @@ -401,11 +406,14 @@ def gen_sfns(bboxes: List[BBox], labels: List[str]) -> List[Sfn]: if ss.note_id is not None: note = notes[ss.note_id] if ss.track != note.track: - raise E.SfnNoteTrackMismatch(f"Track of sfn and note not mismatch: {ss}\n{note}") - if ss.group != note.group: - raise E.SfnNoteGroupMismatch(f"Group of sfn and note not mismatch: {ss}\n{note}") - notes[ss.note_id].sfn = ss.label - ss.is_key = False + print(f"Track of sfn and note mismatch: {ss}\n{note}") + notes[ss.note_id].invalid = True + elif ss.group != note.group: + print(f"Group of sfn and note mismatch: {ss}\n{note}") + notes[ss.note_id].invalid = True + else: + notes[ss.note_id].sfn = ss.label + ss.is_key = False sfns.append(ss) return sfns @@ -432,7 +440,7 @@ def gen_rests(bboxes: List[BBox], labels: List[str]) -> List[Rest]: rr.group = st1.group unit_size = int(round(get_unit_size(*get_center(box)))) - dot_range = range(box[2]+1, box[2]+unit_size) + dot_range = range(box[2]+1, min(box[2]+unit_size, symbols.shape[1] - 1)) dot_region = symbols[box[1]:box[3], dot_range] if 0 < np.sum(dot_region) < unit_size**2 / 7: rr.has_dot = True