Skip to content

Commit

Permalink
Prevent the OMR process from finishing prematurely (#53)
Browse files Browse the repository at this point in the history
* Fixed typos in comments

* IndexError while scanning for a dot should not abort the whole process

* Bound check while getting the note label

* Added check if label is in the note_type_map

* Filter staffs instead of aborting with an exception

* Bound check during symbol extraction

* Marking notes as invalid instead of aborting with an exception

* Bound check

* Fixed type error

* Fixed TypeError at start of unet or segnet training (#52)

* Fixed 'TypeError: Cannot convert 4.999899999999999e-07 to EagerTensor of dtype int64' in training, fixes #39

https://stackoverflow.com/questions/76511182/tensorflow-custom-learning-rate-scheduler-gives-unexpected-eagertensor-type-erro

* --format was deprecated in ruff and replaced with --output-format

* HoughLinesP can return None if no lines are found

* Fixed error which happens if no rest bboxes were found

* Limited try/except block

* Fixed typo

* Pin the linter dependencies to fixed versions so that results for the same source code do not differ between test runs because of dependency updates

* Fixed type errors which came up with the recent version of cv2

* Going back to the newest version of ruff and mypy as the type errors were introduced by cv2
  • Loading branch information
liebharc authored Jan 29, 2024
1 parent 57f49d7 commit 49cef64
Show file tree
Hide file tree
Showing 6 changed files with 70 additions and 47 deletions.
14 changes: 8 additions & 6 deletions oemer/bbox.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from typing import Union, Any, List, Tuple, Dict

import cv2
from cv2.typing import RotatedRect
import numpy as np
from numpy import ndarray
from sklearn.cluster import AgglomerativeClustering
Expand Down Expand Up @@ -118,11 +119,12 @@ def find_lines(data: ndarray, min_len: int = 10, max_gap: int = 20) -> List[BBox

lines = cv2.HoughLinesP(data.astype(np.uint8), 1, np.pi/180, 50, None, min_len, max_gap)
new_line = []
for line in lines:
line = line[0]
top_x, bt_x = (line[0], line[2]) if line[0] < line[2] else (line[2], line[0])
top_y, bt_y = (line[1], line[3]) if line[1] < line[3] else (line[3], line[1])
new_line.append((top_x, top_y, bt_x, bt_y))
if lines is not None:
for line in lines:
line = line[0]
top_x, bt_x = (line[0], line[2]) if line[0] < line[2] else (line[2], line[0])
top_y, bt_y = (line[1], line[3]) if line[1] < line[3] else (line[3], line[1])
new_line.append((top_x, top_y, bt_x, bt_y))
return new_line


Expand Down Expand Up @@ -159,7 +161,7 @@ def draw_bounding_boxes(
return img


def get_rotated_bbox(data: ndarray) -> List[Tuple[Tuple[float, float], Tuple[float, float], float]]:
def get_rotated_bbox(data: ndarray) -> List[RotatedRect]:
contours, _ = cv2.findContours(data.astype(np.uint8), cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
bboxes = []
for cnt in contours:
Expand Down
4 changes: 2 additions & 2 deletions oemer/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,8 +58,8 @@ def inference(
image_pil = Image.open(img_path)
if "GIF" != image_pil.format:
# Tricky workaround to avoid random mistery transpose when loading with 'Image'.
image_pil = cv2.imread(img_path)
image_pil = Image.fromarray(image_pil)
image_cv = cv2.imread(img_path)
image_pil = Image.fromarray(image_cv)

image_pil = image_pil.convert("RGB")
image = np.array(resize_image(image_pil))
Expand Down
6 changes: 3 additions & 3 deletions oemer/notehead_extraction.py
Original file line number Diff line number Diff line change
Expand Up @@ -306,7 +306,7 @@ def fill_hole(region: ndarray) -> ndarray:
def gen_notes(bboxes: List[ndarray], symbols: ndarray) -> List[NoteHead]:
notes = []
for bbox in bboxes:
# Instanitiate notehead.
# Instantiate notehead.
nn = NoteHead()
nn.bbox = typing.cast(BBox, bbox)

Expand Down Expand Up @@ -343,7 +343,7 @@ def assign_group_track(st: Staff) -> None:
# The value could also be negative. The zero index starts from the position
# same as D4, assert the staffline is in treble clef. The value increases
# as the pitch goes up.
# Build centers of each postion first.
# Build centers of each position first.
step = st_master.unit_size / 2
pos_cen = [l.y_center for l in st_master.lines[::-1]]
tmp_inter = []
Expand All @@ -354,7 +354,7 @@ def assign_group_track(st: Staff) -> None:
pos_cen.insert(idx*2+1, interp)
pos_cen = [pos_cen[0]+step] + pos_cen + [pos_cen[-1]-step]

# Estimate position by the closeset center.
# Estimate position by the closest center.
pos_idx = np.argmin(np.abs(np.array(pos_cen)-cen_y))
if 0 < pos_idx < len(pos_cen)-1:
nn.staff_line_pos = int(pos_idx)
Expand Down
39 changes: 25 additions & 14 deletions oemer/rhythm_extraction.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import math

import cv2
from cv2.typing import RotatedRect
import scipy.ndimage
import numpy as np
from numpy import ndarray
Expand Down Expand Up @@ -37,7 +38,12 @@ def scan_dot(
# Find the right most bound for scan the dot.
# Should have width less than unit_size, and can't
# touch the nearby note.
cur_scan_line = note_id_map[int(start_y):int(bbox[3]), int(right_bound)]
try:
cur_scan_line = note_id_map[int(start_y):int(bbox[3]), int(right_bound)]
except IndexError as e:
print(e)
break

ids = set(np.unique(cur_scan_line))
if -1 in ids:
ids.remove(-1)
Expand Down Expand Up @@ -134,7 +140,7 @@ def parse_beams(
min_area_ratio: float = 0.07,
min_tp_ratio: float = 0.4,
min_width_ratio: float = 0.2
) -> Tuple[ndarray, List[Tuple[Tuple[float, float], Tuple[float, float], float]], ndarray]:
) -> Tuple[ndarray, List[RotatedRect], ndarray]:
# Fetch parameters
symbols = layers.get_layer('symbols_pred')
staff_pred = layers.get_layer('staff_pred')
Expand All @@ -156,14 +162,14 @@ def parse_beams(
ratio_map = np.copy(poly_map)

null_color = (255, 255, 255)
valid_box = []
valid_box: List[RotatedRect] = []
valid_idxs = []
idx_map = np.zeros_like(poly_map) - 1
for idx, rbox in enumerate(rboxes): # type: ignore
for box_idx, rbox in enumerate(rboxes): # type: ignore
# Used to find indexes of contour areas later. Must be check before
# any 'continue' statement.
idx %= 255 # type: ignore
if idx == 0:
box_idx %= 255 # type: ignore
if box_idx == 0:
idx_map = np.zeros_like(poly_map) - 1

# Get the contour of the rotated box
Expand All @@ -186,8 +192,8 @@ def parse_beams(
continue

# Tricky way to get the index of the contour area
cv2.fillPoly(idx_map, [cnt], color=(idx, 0, 0))
yi, xi = np.where(idx_map[..., 0] == idx)
cv2.fillPoly(idx_map, [cnt], color=(box_idx, 0, 0))
yi, xi = np.where(idx_map[..., 0] == box_idx)
pts = beams[yi, xi]
meta_idx = np.where(pts>0)[0]
ryi = yi[meta_idx]
Expand Down Expand Up @@ -429,7 +435,9 @@ def get_label(nbox, stem_up):
end_x=min(set_box[2], cen_x+half_scan_width),
end_y=end_y,
threshold=threshold
)
)
if count >= len(note_type_map):
return note_type_map[len(note_type_map) - 1]
return note_type_map[count]

if len(nts) == 2:
Expand Down Expand Up @@ -501,7 +509,7 @@ def parse_rhythm(beam_map: ndarray, map_info: Dict[int, Dict[str, Any]], agree_t
rev_map_info[gid] = {'reg': reg, 'bbox': box}

# Define beam count to note type mapping
note_type_map = {
note_type_map: Dict[int, NoteType] = {
0: NoteType.QUARTER,
1: NoteType.EIGHTH,
2: NoteType.SIXTEENTH,
Expand Down Expand Up @@ -580,7 +588,7 @@ def parse_rhythm(beam_map: ndarray, map_info: Dict[int, Dict[str, Any]], agree_t
end_y = gbox[3]

# Calculate how many beams/flags are there.
count = scan_beam_flag( # type: ignore
beam_flag_count = scan_beam_flag( # type: ignore
bin_beam_map,
max(reg_box[0], cen_x-half_scan_width),
start_y,
Expand All @@ -590,12 +598,15 @@ def parse_rhythm(beam_map: ndarray, map_info: Dict[int, Dict[str, Any]], agree_t
)

#cv2.rectangle(beam_img, (gbox[0], gbox[1]), (gbox[2], gbox[3]), (255, 0, 255), 1)
cv2.putText(beam_img, str(count), (int(cen_x), int(gbox[3])+2), cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 0, 0), 1)
cv2.putText(beam_img, str(beam_flag_count), (int(cen_x), int(gbox[3])+2), cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 0, 0), 1)

# Assign note label
for nid in group.note_ids:
if notes[nid].label is None:
notes[nid].label = note_type_map[count] # type: ignore
if beam_flag_count in note_type_map:
notes[nid].label = note_type_map[beam_flag_count]
else:
notes[nid].invalid = True

return beam_img

Expand All @@ -605,7 +616,7 @@ def extract(
dot_max_area_ratio: float = 0.2,
beam_min_area_ratio: float = 0.07,
agree_th: float = 0.15
) -> Tuple[ndarray, List[Tuple[Tuple[float, float], Tuple[float, float], float]]]:
) -> Tuple[ndarray, List[RotatedRect]]:

logger.debug("Parsing dot")
parse_dot(max_area_ratio=dot_max_area_ratio, min_area_ratio=dot_min_area_ratio)
Expand Down
32 changes: 17 additions & 15 deletions oemer/staffline_extraction.py
Original file line number Diff line number Diff line change
Expand Up @@ -331,51 +331,53 @@ def extract(

# Start process
zones, *_ = init_zones(staff_pred, splits=splits)
all_staffs = []
all_staffs: List[List[Staff]] = []
for rr in zones:
print(rr[0], rr[-1], end=' ')
rr = np.array(rr, dtype=np.int64)
staffs = extract_part(staff_pred[:, rr], x_offset=rr[0], line_threshold=line_threshold)
if staffs is not None:
all_staffs.append(staffs)
print(len(staffs))
all_staffs = align_staffs(all_staffs) # type: ignore
aligned_staffs: np.ndarray = align_staffs(all_staffs)

# Use barline information to infer the number of tracks for each group.
num_track = further_infer_track_nums(all_staffs, min_degree=barline_min_degree) # type: ignore
num_track = further_infer_track_nums(aligned_staffs, min_degree=barline_min_degree)
logger.debug(f"Tracks: {num_track}")
for col_sts in all_staffs:
for col_sts in aligned_staffs:
for idx, st in enumerate(col_sts):
st.track = idx % num_track
st.group = idx // num_track

# Validate staffs across zones.
# Should have same number of staffs
if not all([len(staff) == len(all_staffs[0]) for staff in all_staffs]):
if not all([len(staff) == len(aligned_staffs[0]) for staff in aligned_staffs]):
raise Exception
assert all([len(staff) == len(all_staffs[0]) for staff in all_staffs])
assert all([len(staff) == len(aligned_staffs[0]) for staff in aligned_staffs])

norm = lambda data: np.abs(np.array(data) / np.mean(data) - 1)
for staffs in all_staffs.T: # type: ignore
valid_staffs: list[list[Staff]] = []
for staffs in aligned_staffs.T:
# Should all have 5 lines
line_num = [len(staff.lines) for staff in staffs]
if len(set(line_num)) != 1:
raise E.StafflineCountInconsistent(
f"Some of the stafflines contains less or more than 5 lines: {line_num}")
print(f"Some of the stafflines contains less or more than 5 lines: {line_num}")
continue

# Check Staffs that are approximately at the same row.
centers = np.array([staff.y_center for staff in staffs])
if not np.all(norm(centers) < horizontal_diff_th):
raise E.StafflineNotAligned(
f"Centers of staff parts at the same row not aligned (Th: {horizontal_diff_th}): {norm(centers)}")
print(f"Centers of staff parts at the same row not aligned (Th: {horizontal_diff_th}): {norm(centers)}")
continue

# Unit sizes should roughly all the same
unit_size = np.array([staff.unit_size for staff in staffs])
if not np.all(norm(unit_size) < unit_size_diff_th):
raise E.StafflineUnitSizeInconsistent(
f"Unit sizes not consistent (th: {unit_size_diff_th}): {norm(unit_size)}")
if not np.all(norm(unit_size) < unit_size_diff_th):
print(f"Unit sizes not consistent (th: {unit_size_diff_th}): {norm(unit_size)}")
continue
valid_staffs.append(staffs)

return np.array(all_staffs), zones
return np.array(valid_staffs).T, zones


def extract_part(pred: ndarray, x_offset: int, line_threshold: float = 0.8) -> List[Staff]:
Expand Down
22 changes: 15 additions & 7 deletions oemer/symbol_extraction.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,7 +270,7 @@ def parse_clefs_keys(
for box in bboxes:
w = box[2] - box[0]
h = box[3] - box[1]
region = clefs_keys[box[1]:box[3], box[0]:box[2]]
region: ndarray = clefs_keys[box[1]:box[3], box[0]:box[2]]
usize = get_unit_size(*get_center(box))
area_size_ratio = w * h / usize**2
area_tp_ratio = region[region>0].size / (w * h)
Expand Down Expand Up @@ -317,6 +317,8 @@ def parse_rests(line_box: ndarray, unit_size: float) -> Tuple[List[BBox], List[s

bboxes = get_bbox(rests)
bboxes = filter_out_of_range_bbox(bboxes)
if len(bboxes) == 0:
return [], []
bboxes = merge_nearby_bbox(bboxes, unit_size*1.2)
bboxes = rm_merge_overlap_bbox(bboxes)
bboxes = filter_out_small_area(bboxes, area_size_func=lambda usize: usize**2 * 0.7)
Expand Down Expand Up @@ -373,6 +375,9 @@ def get_nearby_note_id(box: BBox, note_id_map: ndarray) -> Union[int, None]:
unit_size = int(round(get_unit_size(cen_x, cen_y)))
nid = None
for x in range(box[2], box[2]+unit_size):
is_in_range = (0 <= cen_y < note_id_map.shape[0]) and (0 <= x < note_id_map.shape[1])
if not is_in_range:
continue
if note_id_map[cen_y, x] != -1:
nid = note_id_map[cen_y, x]
break
Expand Down Expand Up @@ -401,11 +406,14 @@ def gen_sfns(bboxes: List[BBox], labels: List[str]) -> List[Sfn]:
if ss.note_id is not None:
note = notes[ss.note_id]
if ss.track != note.track:
raise E.SfnNoteTrackMismatch(f"Track of sfn and note not mismatch: {ss}\n{note}")
if ss.group != note.group:
raise E.SfnNoteGroupMismatch(f"Group of sfn and note not mismatch: {ss}\n{note}")
notes[ss.note_id].sfn = ss.label
ss.is_key = False
print(f"Track of sfn and note mismatch: {ss}\n{note}")
notes[ss.note_id].invalid = True
elif ss.group != note.group:
print(f"Group of sfn and note mismatch: {ss}\n{note}")
notes[ss.note_id].invalid = True
else:
notes[ss.note_id].sfn = ss.label
ss.is_key = False

sfns.append(ss)
return sfns
Expand All @@ -432,7 +440,7 @@ def gen_rests(bboxes: List[BBox], labels: List[str]) -> List[Rest]:
rr.group = st1.group

unit_size = int(round(get_unit_size(*get_center(box))))
dot_range = range(box[2]+1, box[2]+unit_size)
dot_range = range(box[2]+1, min(box[2]+unit_size, symbols.shape[1] - 1))
dot_region = symbols[box[1]:box[3], dot_range]
if 0 < np.sum(dot_region) < unit_size**2 / 7:
rr.has_dot = True
Expand Down

0 comments on commit 49cef64

Please sign in to comment.