diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index e1718031..2daa2af8 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -41,7 +41,7 @@ repos: hooks: - id: pytest-on-commit name: Running single sample test - entry: pytest -rfpsxEX --disable-warnings --verbose -k sample1 + entry: python3 -m pytest -rfpsxEX --disable-warnings --verbose -k sample1 language: system pass_filenames: false always_run: true @@ -51,7 +51,7 @@ repos: hooks: - id: pytest-on-push name: Running all tests before push... - entry: pytest -rfpsxEX --disable-warnings --verbose --durations=3 + entry: python3 -m pytest -rfpsxEX --disable-warnings --verbose --durations=3 language: system pass_filenames: false always_run: true diff --git a/README.md b/README.md index 4444a851..cd9acb4d 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,10 @@ # OMR Checker -Read OMRs fast and accurately using a scanner 🖨 or your phone 🤳. +Read OMR sheets fast and accurately using a scanner 🖨 or your phone 🤳. + +## What is OMR? + +OMR stands for Optical Mark Recognition, used to detect and interpret human-marked data on documents. OMR refers to the process of reading and evaluating OMR sheets, commonly used in exams, surveys, and other forms. #### **Quick Links** - [Installation](#getting-started) diff --git a/main.py b/main.py index 2b0b1bd3..f8fc09cb 100644 --- a/main.py +++ b/main.py @@ -12,7 +12,7 @@ from pathlib import Path from src.entry import entry_point -from src.logger import logger +from src.utils.logger import logger def parse_args(): @@ -88,12 +88,18 @@ def entry_point_for_args(args): # Disable tracebacks sys.tracebacklimit = 0 # TODO: set log levels - for root in args["input_paths"]: - entry_point( - Path(root), - args, - ) + try: + entry_point( + Path(root), + args, + ) + except Exception: + if args["debug"] is True: + logger.critical( + f"OMRChecker crashed. 
add --debug and run again to see error details" + ) + raise if __name__ == "__main__": diff --git a/samples/answer-key/using-csv/adrian_omr.png b/samples/answer-key/using-csv/adrian_omr.png new file mode 100644 index 00000000..d8db0994 Binary files /dev/null and b/samples/answer-key/using-csv/adrian_omr.png differ diff --git a/samples/sample2/answer_key.csv b/samples/answer-key/using-csv/answer_key.csv similarity index 58% rename from samples/sample2/answer_key.csv rename to samples/answer-key/using-csv/answer_key.csv index 0708b190..566201d1 100644 --- a/samples/sample2/answer_key.csv +++ b/samples/answer-key/using-csv/answer_key.csv @@ -1,5 +1,5 @@ -q1,B +q1,C q2,E q3,A -q4,C +q4,B q5,B \ No newline at end of file diff --git a/samples/sample2/evaluation.json b/samples/answer-key/using-csv/evaluation.json similarity index 90% rename from samples/sample2/evaluation.json rename to samples/answer-key/using-csv/evaluation.json index 1dec9254..14a3db25 100644 --- a/samples/sample2/evaluation.json +++ b/samples/answer-key/using-csv/evaluation.json @@ -4,7 +4,7 @@ "answer_key_csv_path": "answer_key.csv", "should_explain_scoring": true }, - "marking_scheme": { + "marking_schemes": { "DEFAULT": { "correct": "1", "incorrect": "0", diff --git a/samples/answer-key/using-csv/template.json b/samples/answer-key/using-csv/template.json new file mode 100644 index 00000000..41ec9ffa --- /dev/null +++ b/samples/answer-key/using-csv/template.json @@ -0,0 +1,35 @@ +{ + "pageDimensions": [ + 300, + 400 + ], + "bubbleDimensions": [ + 25, + 25 + ], + "preProcessors": [ + { + "name": "CropPage", + "options": { + "morphKernel": [ + 10, + 10 + ] + } + } + ], + "fieldBlocks": { + "MCQ_Block_1": { + "fieldType": "QTYPE_MCQ5", + "origin": [ + 65, + 60 + ], + "fieldLabels": [ + "q1..5" + ], + "labelsGap": 52, + "bubblesGap": 41 + } + } +} diff --git a/samples/answer-key/weighted-answers/evaluation.json b/samples/answer-key/weighted-answers/evaluation.json new file mode 100644 index 
00000000..c0daefcf --- /dev/null +++ b/samples/answer-key/weighted-answers/evaluation.json @@ -0,0 +1,35 @@ +{ + "source_type": "custom", + "options": { + "questions_in_order": [ + "q1..5" + ], + "answers_in_order": [ + "C", + "E", + [ + "A", + "C" + ], + [ + [ + "B", + 2 + ], + [ + "C", + "3/2" + ] + ], + "C" + ], + "should_explain_scoring": true + }, + "marking_schemes": { + "DEFAULT": { + "correct": "3", + "incorrect": "-1", + "unmarked": "0" + } + } +} diff --git a/samples/answer-key/weighted-answers/images/adrian_omr.png b/samples/answer-key/weighted-answers/images/adrian_omr.png new file mode 100644 index 00000000..69a53823 Binary files /dev/null and b/samples/answer-key/weighted-answers/images/adrian_omr.png differ diff --git a/samples/answer-key/weighted-answers/images/adrian_omr_2.png b/samples/answer-key/weighted-answers/images/adrian_omr_2.png new file mode 100644 index 00000000..d8db0994 Binary files /dev/null and b/samples/answer-key/weighted-answers/images/adrian_omr_2.png differ diff --git a/samples/answer-key/weighted-answers/template.json b/samples/answer-key/weighted-answers/template.json new file mode 100644 index 00000000..41ec9ffa --- /dev/null +++ b/samples/answer-key/weighted-answers/template.json @@ -0,0 +1,35 @@ +{ + "pageDimensions": [ + 300, + 400 + ], + "bubbleDimensions": [ + 25, + 25 + ], + "preProcessors": [ + { + "name": "CropPage", + "options": { + "morphKernel": [ + 10, + 10 + ] + } + } + ], + "fieldBlocks": { + "MCQ_Block_1": { + "fieldType": "QTYPE_MCQ5", + "origin": [ + 65, + 60 + ], + "fieldLabels": [ + "q1..5" + ], + "labelsGap": 52, + "bubblesGap": 41 + } + } +} diff --git a/samples/community/UPSC-mock/evaluation.json b/samples/community/UPSC-mock/evaluation.json index 33a6d8e8..42fbec8b 100644 --- a/samples/community/UPSC-mock/evaluation.json +++ b/samples/community/UPSC-mock/evaluation.json @@ -8,7 +8,7 @@ ], "should_explain_scoring": true }, - "marking_scheme": { + "marking_schemes": { "DEFAULT": { "correct": "2", 
"incorrect": "-2/3", diff --git a/samples/community/UmarFarootAPS/answer_key.csv b/samples/community/UmarFarootAPS/answer_key.csv index 39a9b6dc..b40e8959 100644 --- a/samples/community/UmarFarootAPS/answer_key.csv +++ b/samples/community/UmarFarootAPS/answer_key.csv @@ -2,8 +2,8 @@ q1,C q2,C q3,"D,E" q4,"A,AB" -q5,"['A', ['1', '-2/3', '0']]" -q6,"['A', ['2']]" +q5,"[['A', '1'], ['B', '2']]" +q6,"['A', 'B']" q7,C q8,D q9,B diff --git a/samples/community/UmarFarootAPS/evaluation.json b/samples/community/UmarFarootAPS/evaluation.json index 1dec9254..14a3db25 100644 --- a/samples/community/UmarFarootAPS/evaluation.json +++ b/samples/community/UmarFarootAPS/evaluation.json @@ -4,7 +4,7 @@ "answer_key_csv_path": "answer_key.csv", "should_explain_scoring": true }, - "marking_scheme": { + "marking_schemes": { "DEFAULT": { "correct": "1", "incorrect": "0", diff --git a/samples/sample4/evaluation.json b/samples/sample4/evaluation.json index a2d68f24..ec9b5071 100644 --- a/samples/sample4/evaluation.json +++ b/samples/sample4/evaluation.json @@ -11,7 +11,11 @@ "B", "D", "C", - "BC", + [ + "B", + "C", + "BC" + ], "A", "C", "D", @@ -19,7 +23,7 @@ ], "should_explain_scoring": true }, - "marking_scheme": { + "marking_schemes": { "DEFAULT": { "correct": "3", "incorrect": "-1", diff --git a/samples/sample5/evaluation.json b/samples/sample5/evaluation.json index 8abc70a3..3332fe6a 100644 --- a/samples/sample5/evaluation.json +++ b/samples/sample5/evaluation.json @@ -33,7 +33,7 @@ ], "should_explain_scoring": true }, - "marking_scheme": { + "marking_schemes": { "DEFAULT": { "correct": "1", "incorrect": "0", diff --git a/src/__init__.py b/src/__init__.py index 7fc528b4..1fb98d93 100644 --- a/src/__init__.py +++ b/src/__init__.py @@ -1,5 +1,5 @@ # https://docs.python.org/3/tutorial/modules.html#:~:text=The%20__init__.py,on%20the%20module%20search%20path. 
-from src.logger import logger +from src.utils.logger import logger # It takes a few seconds for the imports logger.info(f"Loading OMRChecker modules...") diff --git a/src/algorithm/__init__.py b/src/algorithm/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/algorithm/core.py b/src/algorithm/core.py new file mode 100644 index 00000000..1f875086 --- /dev/null +++ b/src/algorithm/core.py @@ -0,0 +1,944 @@ +""" + + OMRChecker + + Author: Udayraj Deshmukh + Github: https://github.com/Udayraj123 + +""" + +import math +import os +import random +import re +from collections import defaultdict +from copy import deepcopy +from typing import Any + +import cv2 +import matplotlib.pyplot as plt +import numpy as np +from matplotlib import colormaps + +import src.utils.constants as constants +from src.algorithm.detection import BubbleMeanValue, FieldStdMeanValue +from src.utils.image import ImageUtils +from src.utils.logger import logger + + +class ImageInstanceOps: + """Class to hold fine-tuned utilities for a group of images. 
One instance for each processing directory.""" + + save_img_list: Any = defaultdict(list) + + def __init__(self, tuning_config): + super().__init__() + self.tuning_config = tuning_config + self.save_image_level = tuning_config.outputs.save_image_level + + def apply_preprocessors(self, file_path, in_omr, template): + tuning_config = self.tuning_config + # resize to conform to common preprocessor input requirements + in_omr = ImageUtils.resize_util( + in_omr, + tuning_config.dimensions.processing_width, + tuning_config.dimensions.processing_height, + ) + # Copy template for this instance op + template = deepcopy(template) + # run pre_processors in sequence + for pre_processor in template.pre_processors: + result = pre_processor.apply_filter(in_omr, template, file_path) + out_omr, next_template = ( + result if type(result) is tuple else (result, template) + ) + # resize the image if its shape is changed by the filter + if out_omr.shape[:2] != in_omr.shape[:2]: + out_omr = ImageUtils.resize_util( + out_omr, + tuning_config.dimensions.processing_width, + tuning_config.dimensions.processing_height, + ) + in_omr = out_omr + template = next_template + return in_omr, template + + def read_omr_response(self, template, image, name, save_dir=None): + config = self.tuning_config + + img = image.copy() + # origDim = img.shape[:2] + img = ImageUtils.resize_util( + img, template.page_dimensions[0], template.page_dimensions[1] + ) + if img.max() > img.min(): + img = ImageUtils.normalize_util(img) + # Processing copies + transp_layer = img.copy() + final_marked = img.copy() + + # Move them to data class if needed + # Overlay Transparencies + alpha = 0.65 + omr_response = {} + multi_marked, multi_roll = 0, 0 + + # TODO Make this part useful for visualizing status checks + # blackVals=[0] + # whiteVals=[255] + + # if config.outputs.show_image_level >= 5: + # all_c_box_vals = {"int": [], "mcq": []} + # # TODO: simplify this logic + # q_nums = {"int": [], "mcq": []} + + # Get mean 
bubbleValues n other stats + ( + global_bubble_means_and_refs, + field_number_to_field_bubble_means, + global_field_bubble_means_stds, + ) = ( + [], + [], + [], + ) + for field_block in template.field_blocks: + field_bubble_means_stds = [] + box_w, box_h = field_block.bubble_dimensions + for field in field_block.fields: + field_bubbles = field.field_bubbles + field_bubble_means = [] + for unit_bubble in field_bubbles: + # TODO: move this responsibility into the plugin(not pre-processor) (of shifting every point) + # shifted + x, y = ( + unit_bubble.x + field_block.shift_x, + unit_bubble.y + field_block.shift_y, + ) + rect = [y, y + box_h, x, x + box_w] + mean_value = cv2.mean(img[rect[0] : rect[1], rect[2] : rect[3]])[0] + field_bubble_means.append( + BubbleMeanValue(mean_value, unit_bubble) + # TODO: cross/check mark detection support (#167) + # detectCross(img, rect) ? 0 : 255 + ) + + # TODO: move std calculation inside the class + field_std = round( + np.std([item.mean_value for item in field_bubble_means]), 2 + ) + field_bubble_means_stds.append( + FieldStdMeanValue(field_std, field_block) + ) + + field_number_to_field_bubble_means.append(field_bubble_means) + global_bubble_means_and_refs.extend(field_bubble_means) + global_field_bubble_means_stds.extend(field_bubble_means_stds) + + ( + PAGE_TYPE_FOR_THRESHOLD, + GLOBAL_PAGE_THRESHOLD_WHITE, + GLOBAL_PAGE_THRESHOLD_BLACK, + MIN_JUMP, + JUMP_DELTA, + GLOBAL_THRESHOLD_MARGIN, + MIN_JUMP_SURPLUS_FOR_GLOBAL_FALLBACK, + CONFIDENT_JUMP_SURPLUS_FOR_DISPARITY, + ) = map( + config.threshold_params.get, + [ + "PAGE_TYPE_FOR_THRESHOLD", + "GLOBAL_PAGE_THRESHOLD_WHITE", + "GLOBAL_PAGE_THRESHOLD_BLACK", + "MIN_JUMP", + "JUMP_DELTA", + "GLOBAL_THRESHOLD_MARGIN", + "MIN_JUMP_SURPLUS_FOR_GLOBAL_FALLBACK", + "CONFIDENT_JUMP_SURPLUS_FOR_DISPARITY", + ], + ) + global_default_threshold = ( + GLOBAL_PAGE_THRESHOLD_WHITE + if PAGE_TYPE_FOR_THRESHOLD == "white" + else GLOBAL_PAGE_THRESHOLD_BLACK + ) + # TODO: see if this is needed, 
then take from config.json + MIN_JUMP_STD = 15 + JUMP_DELTA_STD = 5 + global_default_std_threshold = 10 + global_std_thresh, _, _ = self.get_global_threshold( + global_field_bubble_means_stds, + global_default_std_threshold, + MIN_JUMP=MIN_JUMP_STD, + JUMP_DELTA=JUMP_DELTA_STD, + plot_title="Q-wise Std-dev Plot", + plot_show=config.outputs.show_image_level >= 5, + sort_in_plot=True, + ) + # plt.show() + # hist = getPlotImg() + # InteractionUtils.show("StdHist", hist, 0, 1,config=config) + + # Note: Plotting takes Significant times here --> Change Plotting args + # to support show_image_level + global_threshold_for_template, j_low, j_high = self.get_global_threshold( + global_bubble_means_and_refs, # , looseness=4 + global_default_threshold, + plot_title="Mean Intensity Barplot", + MIN_JUMP=MIN_JUMP, + JUMP_DELTA=JUMP_DELTA, + plot_show=config.outputs.show_image_level >= 5, + sort_in_plot=True, + looseness=4, + ) + global_max_jump = j_high - j_low + + logger.info( + f"Thresholding:\t global_threshold_for_template: {round(global_threshold_for_template, 2)} \tglobal_std_THR: {round(global_std_thresh, 2)}\t{'(Looks like a Xeroxed OMR)' if (global_threshold_for_template == 255) else ''}" + ) + # plt.show() + # hist = getPlotImg() + # InteractionUtils.show("StdHist", hist, 0, 1,config=config) + + # if(config.outputs.show_image_level>=1): + # hist = getPlotImg() + # InteractionUtils.show("Hist", hist, 0, 1,config=config) + # appendSaveImg(4,hist) + # appendSaveImg(5,hist) + # appendSaveImg(2,hist) + + per_omr_threshold_avg, absolute_field_number = 0, 0 + global_field_confidence_metrics = [] + for field_block in template.field_blocks: + block_field_number = 1 + key = field_block.name[:3] + box_w, box_h = field_block.bubble_dimensions + + for field in field_block.fields: + field_bubbles = field.field_bubbles + # All Black or All White case + no_outliers = ( + # TODO: rename mean_value in parent class to suit better + 
global_field_bubble_means_stds[absolute_field_number].mean_value + < global_std_thresh + ) + # print(absolute_field_number, field.field_label, + # global_field_bubble_means_stds[absolute_field_number].mean_value, "no_outliers:", no_outliers) + + field_bubble_means = field_number_to_field_bubble_means[ + absolute_field_number + ] + + ( + local_threshold_for_field, + local_max_jump, + ) = self.get_local_threshold( + field_bubble_means, + global_threshold_for_template, + no_outliers, + plot_title=f"Mean Intensity Barplot for {key}.{field.field_label}.block{block_field_number}", + # plot_show=field.field_label in ["q72", "q52", "roll5"], # Temp + # plot_show=field.field_label in ["q70", "q69"], # Temp + plot_show=config.outputs.show_image_level >= 6, + ) + # TODO: move get_local_threshold into FieldDetection + field.local_threshold = local_threshold_for_field + # print(field.field_label,key,block_field_number, "THR: ", + # round(local_threshold_for_field,2)) + per_omr_threshold_avg += local_threshold_for_field + + # TODO: @staticmethod + def apply_field_detection( + field, + field_bubble_means, + local_threshold_for_field, + global_threshold_for_template, + ): + # TODO: see if deepclone is really needed given parent's instance + # field_bubble_means = [ + # deepcopy(bubble) for bubble in field_bubble_means + # ] + + bubbles_in_doubt = { + "by_disparity": [], + "by_jump": [], + "global_higher": [], + "global_lower": [], + } + + for bubble in field_bubble_means: + global_bubble_is_marked = ( + global_threshold_for_template > bubble.mean_value + ) + local_bubble_is_marked = ( + local_threshold_for_field > bubble.mean_value + ) + bubble.is_marked = local_bubble_is_marked + # 1. Disparity in global/local threshold output + if global_bubble_is_marked != local_bubble_is_marked: + bubbles_in_doubt["by_disparity"].append(bubble) + + # 5. 
High confidence if the gap is very large compared to MIN_JUMP + is_global_jump_confident = ( + global_max_jump + > MIN_JUMP + CONFIDENT_JUMP_SURPLUS_FOR_DISPARITY + ) + is_local_jump_confident = ( + local_max_jump > MIN_JUMP + CONFIDENT_JUMP_SURPLUS_FOR_DISPARITY + ) + + # TODO: FieldDetection.bubbles = field_bubble_means + thresholds_string = f"global={round(global_threshold_for_template,2)} local={round(local_threshold_for_field,2)} global_margin={GLOBAL_THRESHOLD_MARGIN}" + jumps_string = f"global_max_jump={round(global_max_jump,2)} local_max_jump={round(local_max_jump,2)} MIN_JUMP={MIN_JUMP} SURPLUS={CONFIDENT_JUMP_SURPLUS_FOR_DISPARITY}" + if len(bubbles_in_doubt["by_disparity"]) > 0: + logger.warning( + f"found disparity in field: {field.field_label}", + list(map(str, bubbles_in_doubt["by_disparity"])), + thresholds_string, + ) + # 5.2 if the gap is very large compared to MIN_JUMP, but still there is disparity + if is_global_jump_confident: + logger.warning( + f"is_global_jump_confident but still has disparity", + jumps_string, + ) + elif is_local_jump_confident: + logger.warning( + f"is_local_jump_confident but still has disparity", + jumps_string, + ) + else: + logger.info( + f"party_matched for field: {field.field_label}", + thresholds_string, + ) + # 5.1 High confidence if the gap is very large compared to MIN_JUMP + if is_local_jump_confident: + # Higher weightage for confidence + logger.info( + f"is_local_jump_confident = increased confidence", + jumps_string, + ) + # No output disparity, but - + # 2.1 global threshold is "too close" to lower bubbles + bubbles_in_doubt["global_lower"] = [ + bubble + for bubble in field_bubble_means + if GLOBAL_THRESHOLD_MARGIN + > max( + GLOBAL_THRESHOLD_MARGIN, + global_threshold_for_template - bubble.mean_value, + ) + ] + + if len(bubbles_in_doubt["global_lower"]) > 0: + logger.warning( + 'bubbles_in_doubt["global_lower"]', + list(map(str, bubbles_in_doubt["global_lower"])), + ) + # 2.2 global threshold is "too close" 
to higher bubbles + bubbles_in_doubt["global_higher"] = [ + bubble + for bubble in field_bubble_means + if GLOBAL_THRESHOLD_MARGIN + > max( + GLOBAL_THRESHOLD_MARGIN, + bubble.mean_value - global_threshold_for_template, + ) + ] + + if len(bubbles_in_doubt["global_higher"]) > 0: + logger.warning( + 'bubbles_in_doubt["global_higher"]', + list(map(str, bubbles_in_doubt["global_higher"])), + ) + + # 3. local jump outliers are close to the configured min_jump but below it. + # Note: This factor indicates presence of cases like partially filled bubbles, + # mis-aligned box boundaries or some form of unintentional marking over the bubble + if len(field_bubble_means) > 1: + # TODO: move util + def get_jumps_in_bubble_means(field_bubble_means): + # get sorted array + sorted_field_bubble_means = sorted( + field_bubble_means, + ) + # get jumps + jumps_in_bubble_means = [] + previous_bubble = sorted_field_bubble_means[0] + previous_mean = previous_bubble.mean_value + for i in range(1, len(sorted_field_bubble_means)): + bubble = sorted_field_bubble_means[i] + current_mean = bubble.mean_value + jumps_in_bubble_means.append( + [ + round(current_mean - previous_mean, 2), + previous_bubble, + ] + ) + previous_bubble = bubble + previous_mean = current_mean + return jumps_in_bubble_means + + jumps_in_bubble_means = get_jumps_in_bubble_means( + field_bubble_means + ) + bubbles_in_doubt["by_jump"] = [ + bubble + for jump, bubble in jumps_in_bubble_means + if MIN_JUMP_SURPLUS_FOR_GLOBAL_FALLBACK + > max( + MIN_JUMP_SURPLUS_FOR_GLOBAL_FALLBACK, + MIN_JUMP - jump, + ) + ] + + if len(bubbles_in_doubt["by_jump"]) > 0: + logger.warning( + 'bubbles_in_doubt["by_jump"]', + list(map(str, bubbles_in_doubt["by_jump"])), + ) + logger.warning( + list(map(str, jumps_in_bubble_means)), + ) + + # TODO: aggregate the bubble metrics into the Field objects + # collect_bubbles_in_doubt(bubbles_in_doubt["by_disparity"], bubbles_in_doubt["global_higher"], bubbles_in_doubt["global_lower"], 
bubbles_in_doubt["by_jump"]) + confidence_metrics = { + "bubbles_in_doubt": bubbles_in_doubt, + "is_global_jump_confident": is_global_jump_confident, + "is_local_jump_confident": is_local_jump_confident, + "local_max_jump": local_max_jump, + "field_label": field.field_label, + } + return field_bubble_means, confidence_metrics + + field_bubble_means, confidence_metrics = apply_field_detection( + field, + field_bubble_means, + local_threshold_for_field, + global_threshold_for_template, + ) + global_field_confidence_metrics.append(confidence_metrics) + for bubble_detection in field_bubble_means: + bubble = bubble_detection.item_reference + x, y, field_value = ( + bubble.x + field_block.shift_x, + bubble.y + field_block.shift_y, + bubble.field_value, + ) + if bubble_detection.is_marked: + # Draw the shifted box + cv2.rectangle( + final_marked, + (int(x + box_w / 12), int(y + box_h / 12)), + ( + int(x + box_w - box_w / 12), + int(y + box_h - box_h / 12), + ), + constants.CLR_DARK_GRAY, + 3, + ) + + cv2.putText( + final_marked, + str(field_value), + (x, y), + cv2.FONT_HERSHEY_SIMPLEX, + constants.TEXT_SIZE, + (20, 20, 10), + int(1 + 3.5 * constants.TEXT_SIZE), + ) + else: + cv2.rectangle( + final_marked, + (int(x + box_w / 10), int(y + box_h / 10)), + ( + int(x + box_w - box_w / 10), + int(y + box_h - box_h / 10), + ), + constants.CLR_GRAY, + -1, + ) + + detected_bubbles = [ + bubble_detection + for bubble_detection in field_bubble_means + if bubble_detection.is_marked + ] + for bubble_detection in detected_bubbles: + bubble = bubble_detection.item_reference + field_label, field_value = ( + bubble.field_label, + bubble.field_value, + ) + multi_marked_local = field_label in omr_response + # Apply concatenation + omr_response[field_label] = ( + (omr_response[field_label] + field_value) + if multi_marked_local + else field_value + ) + # TODO: generalize this into rolls -> identifier + # Only send rolls multi-marked in the directory () + # multi_roll = multi_marked_local and 
"Roll" in str(q) + + multi_marked = multi_marked or multi_marked_local + + # Empty value logic + if len(detected_bubbles) == 0: + field_label = field.field_label + omr_response[field_label] = field_block.empty_val + + # TODO: fix after all_c_box_vals is refactored + # if config.outputs.show_image_level >= 5: + # if key in all_c_box_vals: + # q_nums[key].append(f"{key[:2]}_c{str(block_field_number)}") + # all_c_box_vals[key].append(field_number_to_field_bubble_means[absolute_field_number]) + + block_field_number += 1 + absolute_field_number += 1 + # /for field_block + + # TODO: aggregate with weightages + # overall_confidence = self.get_confidence_metrics(fields_confidence) + # underconfident_fields = filter(lambda x: x.confidence < 0.8, fields_confidence) + + # TODO: Make the plot for underconfident_fields + # logger.info(name, overall_confidence, underconfident_fields) + + per_omr_threshold_avg /= absolute_field_number + per_omr_threshold_avg = round(per_omr_threshold_avg, 2) + # Translucent + cv2.addWeighted(final_marked, alpha, transp_layer, 1 - alpha, 0, final_marked) + + # TODO: refactor all_c_box_vals + # Box types + # if config.outputs.show_image_level >= 5: + # # plt.draw() + # f, axes = plt.subplots(len(all_c_box_vals), sharey=True) + # f.canvas.manager.set_window_title( + # f"Bubble Intensity by question type for {name}" + # ) + # ctr = 0 + # type_name = { + # "int": "Integer", + # "mcq": "MCQ", + # "med": "MED", + # "rol": "Roll", + # } + # for k, boxvals in all_c_box_vals.items(): + # axes[ctr].title.set_text(type_name[k] + " Type") + # axes[ctr].boxplot(boxvals) + # # thrline=axes[ctr].axhline(per_omr_threshold_avg,color='red',ls='--') + # # thrline.set_label("Average THR") + # axes[ctr].set_ylabel("Intensity") + # axes[ctr].set_xticklabels(q_nums[k]) + # # axes[ctr].legend() + # ctr += 1 + # # imshow will do the waiting + # plt.tight_layout(pad=0.5) + # plt.show() + + if config.outputs.save_detections and save_dir is not None: + if multi_roll: + 
save_dir = save_dir.joinpath("_MULTI_") + image_path = str(save_dir.joinpath(name)) + ImageUtils.save_img(image_path, final_marked) + + self.append_save_img(2, final_marked) + + if save_dir is not None: + for i in range(config.outputs.save_image_level): + self.save_image_stacks(i + 1, name, save_dir) + + return ( + omr_response, + final_marked, + multi_marked, + multi_roll, + field_number_to_field_bubble_means, + global_threshold_for_template, + global_field_confidence_metrics, + ) + + # def get_confidence_metrics(self): + # config = self.tuning_config + # overall_confidence, fields_confidence = 0.0, [] + # PAGE_TYPE_FOR_THRESHOLD = map( + # config.threshold_params.get, ["PAGE_TYPE_FOR_THRESHOLD"] + # ) + # # Note: currently building with assumptions + # if PAGE_TYPE_FOR_THRESHOLD == "black": + # logger.warning(f"Confidence metric not implemented for black pages yet") + # return 0.0, [] + # # global_threshold_for_template + # # field + # return overall_confidence, fields_confidence + + @staticmethod + def draw_template_layout(img, template, shifted=True, draw_qvals=False, border=-1): + img = ImageUtils.resize_util( + img, template.page_dimensions[0], template.page_dimensions[1] + ) + final_align = img.copy() + for field_block in template.field_blocks: + field_block_name, s, d, bubble_dimensions, shift_x, shift_y = map( + lambda attr: getattr(field_block, attr), + [ + "name", + "origin", + "dimensions", + "bubble_dimensions", + "shift_x", + "shift_y", + ], + ) + box_w, box_h = bubble_dimensions + + if shifted: + cv2.rectangle( + final_align, + (s[0] + shift_x, s[1] + shift_y), + (s[0] + shift_x + d[0], s[1] + shift_y + d[1]), + constants.CLR_BLACK, + 3, + ) + else: + cv2.rectangle( + final_align, + (s[0], s[1]), + (s[0] + d[0], s[1] + d[1]), + constants.CLR_BLACK, + 3, + ) + for field in field_block.fields: + field_bubbles = field.field_bubbles + for unit_bubble in field_bubbles: + x, y = ( + (unit_bubble.x + shift_x, unit_bubble.y + shift_y) + if shifted + else 
(unit_bubble.x, unit_bubble.y) + ) + cv2.rectangle( + final_align, + (int(x + box_w / 10), int(y + box_h / 10)), + (int(x + box_w - box_w / 10), int(y + box_h - box_h / 10)), + constants.CLR_GRAY, + border, + ) + + if draw_qvals: + rect = [y, y + box_h, x, x + box_w] + cv2.putText( + final_align, + f"{int(cv2.mean(img[rect[0] : rect[1], rect[2] : rect[3]])[0])}", + (rect[2] + 2, rect[0] + (box_h * 2) // 3), + cv2.FONT_HERSHEY_SIMPLEX, + 0.6, + constants.CLR_BLACK, + 2, + ) + + if shifted: + text_in_px = cv2.getTextSize( + field_block_name, cv2.FONT_HERSHEY_SIMPLEX, constants.TEXT_SIZE, 4 + ) + cv2.putText( + final_align, + field_block_name, + ( + int(s[0] + shift_x + d[0] - text_in_px[0][0]), + int(s[1] + shift_y - text_in_px[0][1]), + ), + cv2.FONT_HERSHEY_SIMPLEX, + constants.TEXT_SIZE, + constants.CLR_BLACK, + 4, + ) + + return final_align + + def get_global_threshold( + self, + bubble_means_and_refs, + global_default_threshold, + MIN_JUMP, + JUMP_DELTA, + plot_title=None, + plot_show=True, + sort_in_plot=True, + looseness=1, + ): + """ + Note: Cannot assume qStrip has only-gray or only-white bg + (in which case there is only one jump). + So there will be either 1 or 2 jumps. + 1 Jump : + ...... + |||||| + |||||| <-- risky THR + |||||| <-- safe THR + ....|||||| + |||||||||| + + 2 Jumps : + ...... + |||||| <-- wrong THR + ....|||||| + |||||||||| <-- safe THR + ..|||||||||| + |||||||||||| + + The abstract "First LARGE GAP" is perfect for this. 
+ Current code is considering ONLY TOP 2 jumps(>= MIN_GAP) to be big, + gives the smaller one (looseness factor) + + """ + # Sort the Q bubbleValues + sorted_bubble_means_and_refs = sorted( + bubble_means_and_refs, + ) + sorted_bubble_means = [item.mean_value for item in sorted_bubble_means_and_refs] + + # Find the FIRST LARGE GAP and set it as threshold: + ls = (looseness + 1) // 2 + l = len(sorted_bubble_means) - ls + max1, thr1 = MIN_JUMP, global_default_threshold + for i in range(ls, l): + jump = sorted_bubble_means[i + ls] - sorted_bubble_means[i - ls] + if jump > max1: + max1 = jump + thr1 = sorted_bubble_means[i - ls] + jump / 2 + + # NOTE: thr2 is deprecated, thus is JUMP_DELTA + # TODO: make use of outliers using percentile logic and report the benchmarks + # Make use of the fact that the JUMP_DELTA(Vertical gap ofc) between + # values at detected jumps would be atleast 20 + max2, thr2 = MIN_JUMP, global_default_threshold + # Requires atleast 1 gray box to be present (Roll field will ensure this) + for i in range(ls, l): + jump = sorted_bubble_means[i + ls] - sorted_bubble_means[i - ls] + new_thr = sorted_bubble_means[i - ls] + jump / 2 + if jump > max2 and abs(thr1 - new_thr) > JUMP_DELTA: + max2 = jump + thr2 = new_thr + # global_threshold_for_template = min(thr1,thr2) + global_threshold_for_template, j_low, j_high = ( + thr1, + thr1 - max1 // 2, + thr1 + max1 // 2, + ) + + # # For normal images + # thresholdRead = 116 + # if(thr1 > thr2 and thr2 > thresholdRead): + # print("Note: taking safer thr line.") + # global_threshold_for_template, j_low, j_high = thr2, thr2 - max2//2, thr2 + max2//2 + + if plot_title: + _, ax = plt.subplots() + # TODO: move into individual utils + plot_means_and_refs = ( + sorted_bubble_means_and_refs if sort_in_plot else bubble_means_and_refs + ) + plot_values = [x.mean_value for x in plot_means_and_refs] + original_bin_names = [ + x.item_reference.plot_bin_name for x in plot_means_and_refs + ] + plot_labels = 
[x.item_reference_name for x in plot_means_and_refs] + + # TODO: move into individual utils + sorted_unique_bin_names, unique_label_indices = np.unique( + original_bin_names, return_inverse=True + ) + + plot_color_sampler = colormaps["Spectral"].resampled( + len(sorted_unique_bin_names) + ) + + shuffled_color_indices = random.sample( + list(unique_label_indices), len(unique_label_indices) + ) + # logger.info(list(zip(original_bin_names, shuffled_color_indices))) + plot_colors = plot_color_sampler( + [shuffled_color_indices[i] for i in unique_label_indices] + ) + # plot_colors = plot_color_sampler(unique_label_indices) + bar_container = ax.bar( + range(len(plot_means_and_refs)), + plot_values, + color=plot_colors, + label=plot_labels, + ) + + # TODO: move into individual utils + low = min(plot_values) + high = max(plot_values) + margin_factor = 0.1 + plt.ylim( + [ + math.ceil(low - margin_factor * (high - low)), + math.ceil(high + margin_factor * (high - low)), + ] + ) + + # Show field labels + ax.bar_label(bar_container, labels=plot_labels) + handles, labels = ax.get_legend_handles_labels() + # Naturally sorted unique legend labels https://stackoverflow.com/a/27512450/6242649 + ax.legend( + *zip( + *sorted( + [ + (h, l) + for i, (h, l) in enumerate(zip(handles, labels)) + if l not in labels[:i] + ], + key=lambda s: [ + int(t) if t.isdigit() else t.lower() + for t in re.split("(\\d+)", s[1]) + ], + ) + ) + ) + ax.set_title(plot_title) + ax.axhline( + global_threshold_for_template, color="green", ls="--", linewidth=5 + ).set_label("Global Threshold") + ax.axhline(thr2, color="red", ls=":", linewidth=3).set_label("THR2 Line") + # ax.axhline(j_low,color='red',ls='-.', linewidth=3) + # ax.axhline(j_high,color='red',ls='-.', linewidth=3).set_label("Boundary Line") + # ax.set_ylabel("Mean Intensity") + ax.set_ylabel("Values") + ax.set_xlabel("Position") + + if plot_show: + plt.title(plot_title) + plt.show() + + return global_threshold_for_template, j_low, j_high + + def 
get_local_threshold( + self, + bubble_means_and_refs, + global_threshold_for_template, + no_outliers, + plot_title=None, + plot_show=True, + ): + """ + TODO: Update this documentation too- + //No more - Assumption : Colwise background color is uniformly gray or white, + but not alternating. In this case there is atmost one jump. + + 0 Jump : + <-- safe THR? + ....... + ...||||||| + |||||||||| <-- safe THR? + // How to decide given range is above or below gray? + -> global bubble_means_list shall absolutely help here. Just run same function + on total bubble_means_list instead of colwise _// + How to decide it is this case of 0 jumps + + 1 Jump : + ...... + |||||| + |||||| <-- risky THR + |||||| <-- safe THR + ....|||||| + |||||||||| + + """ + config = self.tuning_config + # Sort the Q bubbleValues + sorted_bubble_means_and_refs = sorted( + bubble_means_and_refs, + ) + sorted_bubble_means = [item.mean_value for item in sorted_bubble_means_and_refs] + # Small no of pts cases: + # base case: 1 or 2 pts + if len(sorted_bubble_means) < 3: + max1, thr1 = config.threshold_params.MIN_JUMP, ( + global_threshold_for_template + if np.max(sorted_bubble_means) - np.min(sorted_bubble_means) + < config.threshold_params.MIN_GAP + else np.mean(sorted_bubble_means) + ) + else: + l = len(sorted_bubble_means) - 1 + max1, thr1 = config.threshold_params.MIN_JUMP, 255 + for i in range(1, l): + jump = sorted_bubble_means[i + 1] - sorted_bubble_means[i - 1] + if jump > max1: + max1 = jump + thr1 = sorted_bubble_means[i - 1] + jump / 2 + # print(field_label,sorted_bubble_means,max1) + + confident_jump = ( + config.threshold_params.MIN_JUMP + + config.threshold_params.MIN_JUMP_SURPLUS_FOR_GLOBAL_FALLBACK + ) + + # TODO: seek improvement here because of the empty cases failing here(boundary walls) + # Can see erosion make a lot of sense here? 
+ # If not confident, then only take help of global_threshold_for_template + if max1 < confident_jump: + # Threshold hack: local can never be 255 + if no_outliers or thr1 == 255: + # All Black or All White case + thr1 = global_threshold_for_template + else: + # TODO: Low confidence parameters here + pass + + # TODO: Make a common plot util to show local and global thresholds + if plot_show and plot_title is not None: + # TODO: add plot labels via the util + _, ax = plt.subplots() + ax.bar(range(len(sorted_bubble_means)), sorted_bubble_means) + thrline = ax.axhline(thr1, color="green", ls=("-."), linewidth=3) + thrline.set_label("Local Threshold") + thrline = ax.axhline( + global_threshold_for_template, color="red", ls=":", linewidth=5 + ) + thrline.set_label("Global Threshold") + ax.set_title(plot_title) + ax.set_ylabel("Bubble Mean Intensity") + ax.set_xlabel("Bubble Number(sorted)") + ax.legend() + # TODO append QStrip to this plot- + # appendSaveImg(6,getPlotImg()) + if plot_show: + plt.show() + return thr1, max1 + + def append_save_img(self, key, img): + if self.save_image_level >= int(key): + self.save_img_list[key].append(img.copy()) + + def save_image_stacks(self, key, filename, save_dir): + config = self.tuning_config + if self.save_image_level >= int(key) and self.save_img_list[key] != []: + name = os.path.splitext(filename)[0] + result = np.hstack( + tuple( + [ + ImageUtils.resize_util_h(img, config.dimensions.display_height) + for img in self.save_img_list[key] + ] + ) + ) + result = ImageUtils.resize_util( + result, + min( + len(self.save_img_list[key]) * config.dimensions.display_width // 3, + int(config.dimensions.display_width * 2.5), + ), + ) + ImageUtils.save_img(f"{save_dir}stack/{name}_{str(key)}_stack.jpg", result) + + def reset_all_save_img(self): + for i in range(self.save_image_level): + self.save_img_list[i + 1] = [] diff --git a/src/algorithm/detection.py b/src/algorithm/detection.py new file mode 100644 index 00000000..fb0ed061 --- 
/dev/null +++ b/src/algorithm/detection.py @@ -0,0 +1,78 @@ +""" + + OMRChecker + + Author: Udayraj Deshmukh + Github: https://github.com/Udayraj123 + +""" + +import functools + +from src.utils.parsing import default_dump + + +@functools.total_ordering +class MeanValueItem: + def __init__(self, mean_value, item_reference): + self.mean_value = mean_value + self.item_reference = item_reference + self.item_reference_name = item_reference.name + + def __str__(self): + return f"{self.item_reference} : {round(self.mean_value, 2)}" + + def _is_valid_operand(self, other): + return hasattr(other, "mean_value") and hasattr(other, "item_reference") + + def __eq__(self, other): + if not self._is_valid_operand(other): + return NotImplemented + return self.mean_value == other.mean_value + + def __lt__(self, other): + if not self._is_valid_operand(other): + return NotImplemented + return self.mean_value < other.mean_value + + +# TODO: merge with FieldBubbleDetection +class BubbleMeanValue(MeanValueItem): + def __init__(self, mean_value, item_reference): + super().__init__(mean_value, item_reference) + self.is_marked = None + + def to_json(self): + # TODO: mini util for this loop + return { + key: default_dump(getattr(self, key)) + for key in [ + "is_marked", + "item_reference_name", + "mean_value", + ] + } + + +# TODO: see if this one can be merged in above +class FieldStdMeanValue(MeanValueItem): + def __init__(self, mean_value, item_reference): + super().__init__(mean_value, item_reference) + + def to_json(self): + return { + key: default_dump(getattr(self, key)) + for key in [ + "item_reference_name", + "mean_value", + ] + } + + +# TODO: use this or merge it +class FieldDetection: + def __init__(self, field, confidence): + self.field = field + self.confidence = confidence + # TODO: use local_threshold from here + # self.local_threshold = None diff --git a/src/evaluation.py b/src/algorithm/evaluation.py similarity index 72% rename from src/evaluation.py rename to
src/algorithm/evaluation.py index 31104823..54a5ba1e 100644 --- a/src/evaluation.py +++ b/src/algorithm/evaluation.py @@ -1,3 +1,11 @@ +""" + + OMRChecker + + Author: Udayraj Deshmukh + Github: https://github.com/Udayraj123 + +""" import ast import os import re @@ -7,12 +15,12 @@ import pandas as pd from rich.table import Table -from src.logger import console, logger from src.schemas.constants import ( BONUS_SECTION_PREFIX, DEFAULT_SECTION_KEY, MARKING_VERDICT_TYPES, ) +from src.utils.logger import console, logger from src.utils.parsing import ( get_concatenated_response, open_evaluation_with_validation, @@ -22,125 +30,127 @@ class AnswerMatcher: - def __init__(self, answer_item, marking_scheme): - self.answer_type = self.get_answer_type(answer_item) - self.parsed_answer = self.parse_answer_item(answer_item) - self.set_defaults_from_scheme(marking_scheme) - self.marking_scheme = marking_scheme - - def get_answer_type(self, answer_item): - item_type = type(answer_item) - if item_type == str: + def __init__(self, answer_item, section_marking_scheme): + self.section_marking_scheme = section_marking_scheme + self.answer_item = answer_item + self.answer_type = self.validate_and_get_answer_type(answer_item) + self.set_defaults_from_scheme(section_marking_scheme) + + @staticmethod + def is_a_marking_score(answer_element): + # Note: strict type checking is already done at schema validation level, + # Here we focus on overall struct type + return type(answer_element) == str or type(answer_element) == int + + @staticmethod + def is_standard_answer(answer_element): + return type(answer_element) == str and len(answer_element) >= 1 + + def validate_and_get_answer_type(self, answer_item): + if self.is_standard_answer(answer_item): return "standard" - elif item_type == list: + elif type(answer_item) == list: if ( + # Array of answer elements: ['A', 'B', 'AB'] len(answer_item) >= 2 - and type(answer_item[0]) == str - and type(answer_item[1]) == str + and all( + 
self.is_standard_answer(answers_or_score) + for answers_or_score in answer_item + ) ): return "multiple-correct" elif ( - len(answer_item) == 2 - and type(answer_item[0]) == str - and type(answer_item[1]) == list + # Array of two-tuples: [['A', 1], ['B', 1], ['C', 3], ['AB', 2]] + len(answer_item) >= 1 + and all( + type(answer_and_score) == list and len(answer_and_score) == 2 + for answer_and_score in answer_item + ) + and all( + self.is_standard_answer(allowed_answer) + and self.is_a_marking_score(answer_score) + for allowed_answer, answer_score in answer_item + ) ): return "multiple-correct-weighted" - elif ( - len(answer_item) == 2 - and type(answer_item[0]) in list - and type(answer_item[1]) == list - ): - return "answer-weights" - else: - logger.critical( - f"Unable to determine answer type for answer item: {answer_item}" - ) - raise Exception("Unable to determine answer type") - def parse_answer_item(self, answer_item): - return answer_item + logger.critical( + f"Unable to determine answer type for answer item: {answer_item}" + ) + raise Exception("Unable to determine answer type") - def set_defaults_from_scheme(self, marking_scheme): + def set_defaults_from_scheme(self, section_marking_scheme): answer_type = self.answer_type - self.empty_val = marking_scheme.empty_val - parsed_answer = self.parsed_answer - self.marking = deepcopy(marking_scheme.marking) + self.empty_val = section_marking_scheme.empty_val + answer_item = self.answer_item + self.marking = deepcopy(section_marking_scheme.marking) # TODO: reuse part of parse_scheme_marking here - if answer_type == "standard": # no local overrides pass elif answer_type == "multiple-correct": - # no local overrides - for allowed_answer in parsed_answer: + # override marking scheme scores for each allowed answer + for allowed_answer in answer_item: self.marking[f"correct-{allowed_answer}"] = self.marking["correct"] elif answer_type == "multiple-correct-weighted": - custom_marking = list(map(parse_float_or_fraction, 
parsed_answer[1])) - verdict_types_length = min(len(MARKING_VERDICT_TYPES), len(custom_marking)) - # override the given marking - for i in range(verdict_types_length): - verdict_type = MARKING_VERDICT_TYPES[i] - self.marking[verdict_type] = custom_marking[i] - - if type(parsed_answer[0] == str): - allowed_answer = parsed_answer[0] - self.marking[f"correct-{allowed_answer}"] = self.marking["correct"] - else: - for allowed_answer in parsed_answer[0]: - self.marking[f"correct-{allowed_answer}"] = self.marking["correct"] + # Note: No override using marking scheme as answer scores are provided in answer_item + for allowed_answer, answer_score in answer_item: + self.marking[f"correct-{allowed_answer}"] = parse_float_or_fraction( + answer_score + ) def get_marking_scheme(self): - return self.marking_scheme + return self.section_marking_scheme def get_section_explanation(self): answer_type = self.answer_type - if answer_type == "standard" or answer_type == "multiple-correct": - return self.marking_scheme.section_key + if answer_type in ["standard", "multiple-correct"]: + return self.section_marking_scheme.section_key elif answer_type == "multiple-correct-weighted": return f"Custom: {self.marking}" def get_verdict_marking(self, marked_answer): answer_type = self.answer_type - + question_verdict = "incorrect" if answer_type == "standard": question_verdict = self.get_standard_verdict(marked_answer) - return question_verdict, self.marking[question_verdict] elif answer_type == "multiple-correct": question_verdict = self.get_multiple_correct_verdict(marked_answer) - return question_verdict, self.marking[question_verdict] elif answer_type == "multiple-correct-weighted": - question_verdict = self.get_multi_weighted_verdict(marked_answer) - return question_verdict, self.marking[question_verdict] + question_verdict = self.get_multiple_correct_weighted_verdict(marked_answer) + return question_verdict, self.marking[question_verdict] def get_standard_verdict(self, marked_answer): - 
parsed_answer = self.parsed_answer + allowed_answer = self.answer_item if marked_answer == self.empty_val: return "unmarked" - elif marked_answer == parsed_answer: + elif marked_answer == allowed_answer: return "correct" else: return "incorrect" - def get_multi_weighted_verdict(self, marked_answer): - return self.get_multiple_correct_verdict(marked_answer) - def get_multiple_correct_verdict(self, marked_answer): - parsed_answer = self.parsed_answer + allowed_answers = self.answer_item if marked_answer == self.empty_val: return "unmarked" - elif marked_answer in parsed_answer: + elif marked_answer in allowed_answers: return f"correct-{marked_answer}" else: return "incorrect" - def __str__(self): - answer_type, parsed_answer = self.answer_type, self.parsed_answer + def get_multiple_correct_weighted_verdict(self, marked_answer): + allowed_answers = [ + allowed_answer for allowed_answer, _answer_score in self.answer_item + ] + if marked_answer == self.empty_val: + return "unmarked" + elif marked_answer in allowed_answers: + return f"correct-{marked_answer}" + else: + return "incorrect" - if answer_type == "multiple-correct": - return str(parsed_answer) - elif answer_type == "multiple-correct-weighted": - return f"{parsed_answer[0]}" - # TODO: case of multi-lines in multi-weights - return parsed_answer + def __str__(self): + return f"{self.answer_item}" class SectionMarkingScheme: @@ -156,6 +166,9 @@ def __init__(self, section_key, section_scheme, empty_val): self.questions = parse_fields(section_key, section_scheme["questions"]) self.marking = self.parse_scheme_marking(section_scheme["marking"]) + def __str__(self): + return self.section_key + def parse_scheme_marking(self, marking): parsed_marking = {} for verdict_type in MARKING_VERDICT_TYPES: @@ -186,15 +199,13 @@ class EvaluationConfig: def __init__(self, curr_dir, evaluation_path, template, tuning_config): self.path = evaluation_path evaluation_json = open_evaluation_with_validation(evaluation_path) - options, 
marking_scheme, source_type = map( - evaluation_json.get, ["options", "marking_scheme", "source_type"] + options, marking_schemes, source_type = map( + evaluation_json.get, ["options", "marking_schemes", "source_type"] ) self.should_explain_scoring = options.get("should_explain_scoring", False) self.has_non_default_section = False self.exclude_files = [] - marking_scheme = marking_scheme - if source_type == "csv": csv_path = curr_dir.joinpath(options["answer_key_csv_path"]) if not os.path.exists(csv_path): @@ -236,12 +247,8 @@ def __init__(self, curr_dir, evaluation_path, template, tuning_config): raise Exception( f"Could not read answer key from image {image_path}" ) - ( - response_dict, - _final_marked, - _multi_marked, - _multi_roll, - ) = template.image_instance_ops.read_omr_response( + + (response_dict, *_) = template.image_instance_ops.read_omr_response( template, image=in_omr, name=image_path, @@ -291,13 +298,13 @@ def __init__(self, curr_dir, evaluation_path, template, tuning_config): self.validate_questions(answers_in_order) - self.marking_scheme, self.question_to_scheme = {}, {} - for section_key, section_scheme in marking_scheme.items(): + self.section_marking_schemes, self.question_to_scheme = {}, {} + for section_key, section_scheme in marking_schemes.items(): section_marking_scheme = SectionMarkingScheme( section_key, section_scheme, template.global_empty_val ) if section_key != DEFAULT_SECTION_KEY: - self.marking_scheme[section_key] = section_marking_scheme + self.section_marking_schemes[section_key] = section_marking_scheme for q in section_marking_scheme.questions: # TODO: check the answer key for custom scheme here? 
self.question_to_scheme[q] = section_marking_scheme @@ -305,7 +312,7 @@ def __init__(self, curr_dir, evaluation_path, template, tuning_config): else: self.default_marking_scheme = section_marking_scheme - self.validate_marking_scheme() + self.validate_marking_schemes() self.question_to_answer_matcher = self.parse_answers_and_map_questions( answers_in_order @@ -350,7 +357,8 @@ def match_answer_for_question(self, current_score, question, marked_answer): question, current_score, ) - return delta + expected_answer_string = str(answer_matcher) + return delta, question_verdict, expected_answer_string def conditionally_print_explanation(self): if self.should_explain_scoring: @@ -367,10 +375,13 @@ def parse_answer_column(answer_column): # Remove all whitespaces answer_column = answer_column.replace(" ", "") if answer_column[0] == "[": + # multiple-correct-weighted or multiple-correct parsed_answer = ast.literal_eval(answer_column) elif "," in answer_column: + # multiple-correct parsed_answer = answer_column.split(",") else: + # single-correct parsed_answer = answer_column return parsed_answer @@ -392,12 +403,13 @@ def validate_answers(self, answers_in_order, tuning_config): multi_marked_answer = True break if answer_type == "multiple-correct-weighted": - if len(answer_item[0]) > 1: - multi_marked_answer = True + for single_answer, _answer_score in answer_item: + if len(single_answer) > 1: + multi_marked_answer = True if multi_marked_answer: raise Exception( - f"Answer key contains multiple correct answer(s), but filter_out_multimarked_files is True. Scoring will get skipped." + f"Provided answer key contains multiple correct answer(s), but config.filter_out_multimarked_files is True. Scoring will get skipped." 
) def validate_questions(self, answers_in_order): @@ -413,10 +425,10 @@ def validate_questions(self, answers_in_order): f"Unequal lengths for questions_in_order and answers_in_order ({len_questions_in_order} != {len_answers_in_order})" ) - def validate_marking_scheme(self): - marking_scheme = self.marking_scheme + def validate_marking_schemes(self): + section_marking_schemes = self.section_marking_schemes section_questions = set() - for section_key, section_scheme in marking_scheme.items(): + for section_key, section_scheme in section_marking_schemes.items(): if section_key == DEFAULT_SECTION_KEY: continue current_set = set(section_scheme.questions) @@ -438,9 +450,15 @@ def parse_answers_and_map_questions(self, answers_in_order): question_to_answer_matcher = {} for question, answer_item in zip(self.questions_in_order, answers_in_order): section_marking_scheme = self.get_marking_scheme_for_question(question) - question_to_answer_matcher[question] = AnswerMatcher( - answer_item, section_marking_scheme - ) + answer_matcher = AnswerMatcher(answer_item, section_marking_scheme) + question_to_answer_matcher[question] = answer_matcher + if ( + answer_matcher.answer_type == "multiple-correct-weighted" + and section_marking_scheme.section_key != DEFAULT_SECTION_KEY + ): + logger.warning( + f"The custom scheme '{section_marking_scheme}' will not apply to question '{question}' as it will use the given answer weights f{answer_item}" + ) return question_to_answer_matcher # Then unfolding lower abstraction levels @@ -488,9 +506,11 @@ def conditionally_add_explanation( str.title(question_verdict), str(round(delta, 2)), str(round(next_score, 2)), - answer_matcher.get_section_explanation() - if self.has_non_default_section - else None, + ( + answer_matcher.get_section_explanation() + if self.has_non_default_section + else None + ), ] if item is not None ] @@ -500,13 +520,25 @@ def conditionally_add_explanation( def evaluate_concatenated_response(concatenated_response, 
evaluation_config): evaluation_config.prepare_and_validate_omr_response(concatenated_response) current_score = 0.0 + question_meta = {} for question in evaluation_config.questions_in_order: marked_answer = concatenated_response[question] - delta = evaluation_config.match_answer_for_question( + ( + delta, + question_verdict, + expected_answer_string, + ) = evaluation_config.match_answer_for_question( current_score, question, marked_answer ) current_score += delta + question_meta[question] = { + "question_verdict": question_verdict, + "marked_answer": marked_answer, + "delta": delta, + "current_score": current_score, + "expected_answer_string": expected_answer_string, + } evaluation_config.conditionally_print_explanation() - - return current_score + evaluation_meta = {"final_score": current_score, "question_meta": question_meta} + return current_score, evaluation_meta diff --git a/src/template.py b/src/algorithm/template.py similarity index 77% rename from src/template.py rename to src/algorithm/template.py index f726ed76..5929d3da 100644 --- a/src/template.py +++ b/src/algorithm/template.py @@ -6,12 +6,14 @@ Github: https://github.com/Udayraj123 """ -from src.constants import FIELD_TYPES -from src.core import ImageInstanceOps -from src.logger import logger + +from src.algorithm.core import ImageInstanceOps from src.processors.manager import PROCESSOR_MANAGER +from src.utils.constants import FIELD_TYPES +from src.utils.logger import logger from src.utils.parsing import ( custom_sort_output_columns, + default_dump, open_template_with_defaults, parse_fields, ) @@ -208,13 +210,46 @@ def validate_parsed_labels(self, field_labels, block_instance): def __str__(self): return str(self.path) + # Make the class serializable + def to_json(self): + return { + key: default_dump(getattr(self, key)) + for key in [ + "page_dimensions", + "field_blocks", + # Not needed as local props are overridden - + # "bubble_dimensions", + # 'options', + # "global_empty_val", + ] + } + class 
FieldBlock: def __init__(self, block_name, field_block_object): self.name = block_name self.shift_x, self.shift_y = 0, 0 + # TODO: Move plot_bin_name into child class + self.plot_bin_name = block_name self.setup_field_block(field_block_object) + # Make the class serializable + def to_json(self): + return { + key: default_dump(getattr(self, key)) + for key in [ + "bubble_dimensions", + "dimensions", + "empty_val", + "fields", + "name", + "origin", + # "plot_bin_name", + # "shift_x", + # "shift_y", + ] + } + def setup_field_block(self, field_block_object): # case mapping ( @@ -292,22 +327,60 @@ def generate_bubble_grid( labels_gap, ): _h, _v = (1, 0) if (direction == "vertical") else (0, 1) - self.traverse_bubbles = [] + self.fields = [] # Generate the bubble grid lead_point = [float(self.origin[0]), float(self.origin[1])] for field_label in self.parsed_field_labels: bubble_point = lead_point.copy() field_bubbles = [] - for bubble_value in bubble_values: + for bubble_index, bubble_value in enumerate(bubble_values): field_bubbles.append( - Bubble(bubble_point.copy(), field_label, field_type, bubble_value) + FieldBubble( + bubble_point.copy(), + # TODO: move field_label into field_label_ref + field_label, + field_type, + bubble_value, + bubble_index, + ) ) bubble_point[_h] += bubbles_gap - self.traverse_bubbles.append(field_bubbles) + self.fields.append(Field(field_label, field_type, field_bubbles, direction)) lead_point[_v] += labels_gap -class Bubble: +class Field: + """ + Container for a Field on the OMR i.e. 
a group of FieldBubbles with a collective field_label + + """ + + def __init__(self, field_label, field_type, field_bubbles, direction): + self.field_label = field_label + self.field_type = field_type + self.field_bubbles = field_bubbles + self.direction = direction + # TODO: move local_threshold into child detection class + self.local_threshold = None + + def __str__(self): + return self.field_label + + # Make the class serializable + def to_json(self): + return { + key: default_dump(getattr(self, key)) + for key in [ + "field_label", + "field_type", + "direction", + "field_bubbles", + "local_threshold", + ] + } + + +class FieldBubble: """ Container for a Point Box on the OMR @@ -316,12 +389,32 @@ class Bubble: It can also correspond to a single digit of integer type Q (eg q5d1) """ - def __init__(self, pt, field_label, field_type, field_value): + def __init__(self, pt, field_label, field_type, field_value, bubble_index): + self.name = f"{field_label}_{field_value}" + self.plot_bin_name = field_label self.x = round(pt[0]) self.y = round(pt[1]) self.field_label = field_label self.field_type = field_type self.field_value = field_value + self.bubble_index = bubble_index def __str__(self): - return str([self.x, self.y]) + return self.name # f"{self.field_label}: [{self.x}, {self.y}]" + + # Make the class serializable + def to_json(self): + return { + key: default_dump(getattr(self, key)) + for key in [ + "field_label", + "field_value", + # for item_reference_name + "name", + "x", + "y", + # "plot_bin_name", + # "field_type", + # "bubble_index", + ] + } diff --git a/src/core.py b/src/core.py deleted file mode 100644 index 94221bd2..00000000 --- a/src/core.py +++ /dev/null @@ -1,586 +0,0 @@ -import os -from collections import defaultdict -from copy import deepcopy -from typing import Any - -import cv2 -import matplotlib.pyplot as plt -import numpy as np - -import src.constants as constants -from src.logger import logger -from src.utils.image import ImageUtils - - -class 
ImageInstanceOps: - """Class to hold fine-tuned utilities for a group of images. One instance for each processing directory.""" - - save_img_list: Any = defaultdict(list) - - def __init__(self, tuning_config): - super().__init__() - self.tuning_config = tuning_config - self.save_image_level = tuning_config.outputs.save_image_level - - def apply_preprocessors(self, file_path, in_omr, template): - tuning_config = self.tuning_config - # resize to conform to common preprocessor input requirements - in_omr = ImageUtils.resize_util( - in_omr, - tuning_config.dimensions.processing_width, - tuning_config.dimensions.processing_height, - ) - # Copy template for this instance op - template = deepcopy(template) - # run pre_processors in sequence - for pre_processor in template.pre_processors: - result = pre_processor.apply_filter(in_omr, template, file_path) - out_omr, next_template = ( - result if type(result) is tuple else (result, template) - ) - # resize the image if its shape is changed by the filter - if out_omr.shape[:2] != in_omr.shape[:2]: - out_omr = ImageUtils.resize_util( - out_omr, - tuning_config.dimensions.processing_width, - tuning_config.dimensions.processing_height, - ) - in_omr = out_omr - template = next_template - return in_omr, template - - def read_omr_response(self, template, image, name, save_dir=None): - config = self.tuning_config - - img = image.copy() - # origDim = img.shape[:2] - img = ImageUtils.resize_util( - img, template.page_dimensions[0], template.page_dimensions[1] - ) - if img.max() > img.min(): - img = ImageUtils.normalize_util(img) - # Processing copies - transp_layer = img.copy() - final_marked = img.copy() - - # Move them to data class if needed - # Overlay Transparencies - alpha = 0.65 - omr_response = {} - multi_marked, multi_roll = 0, 0 - - # TODO Make this part useful for visualizing status checks - # blackVals=[0] - # whiteVals=[255] - - if config.outputs.show_image_level >= 5: - all_c_box_vals = {"int": [], "mcq": []} - # TODO: 
simplify this logic - q_nums = {"int": [], "mcq": []} - - # Get mean bubbleValues n other stats - all_q_vals, all_q_strip_arrs, all_q_std_vals = [], [], [] - total_q_strip_no = 0 - for field_block in template.field_blocks: - q_std_vals = [] - box_w, box_h = field_block.bubble_dimensions - for field_block_bubbles in field_block.traverse_bubbles: - q_strip_vals = [] - for pt in field_block_bubbles: - # TODO: move this responsibility into the plugin (of shifting every point) - # shifted - x, y = (pt.x + field_block.shift_x, pt.y + field_block.shift_y) - rect = [y, y + box_h, x, x + box_w] - q_strip_vals.append( - cv2.mean(img[rect[0] : rect[1], rect[2] : rect[3]])[0] - # detectCross(img, rect) ? 100 : 0 - ) - q_std_vals.append(round(np.std(q_strip_vals), 2)) - all_q_strip_arrs.append(q_strip_vals) - # _, _, _ = get_global_threshold(q_strip_vals, "QStrip Plot", - # plot_show=False, sort_in_plot=True) - # hist = getPlotImg() - # InteractionUtils.show("QStrip "+field_block_bubbles[0].field_label, hist, 0, 1,config=config) - all_q_vals.extend(q_strip_vals) - # print(total_q_strip_no, field_block_bubbles[0].field_label, q_std_vals[len(q_std_vals)-1]) - total_q_strip_no += 1 - all_q_std_vals.extend(q_std_vals) - - global_std_thresh, _, _ = self.get_global_threshold( - all_q_std_vals - ) # , "Q-wise Std-dev Plot", plot_show=True, sort_in_plot=True) - # plt.show() - # hist = getPlotImg() - # InteractionUtils.show("StdHist", hist, 0, 1,config=config) - - # Note: Plotting takes Significant times here --> Change Plotting args - # to support show_image_level - # , "Mean Intensity Histogram",plot_show=True, sort_in_plot=True) - global_thr, _, _ = self.get_global_threshold(all_q_vals, looseness=4) - - logger.info( - f"Thresholding:\tglobal_thr: {round(global_thr, 2)} \tglobal_std_THR: {round(global_std_thresh, 2)}\t{'(Looks like a Xeroxed OMR)' if (global_thr == 255) else ''}" - ) - # plt.show() - # hist = getPlotImg() - # InteractionUtils.show("StdHist", hist, 0, 1,config=config) 
- - # if(config.outputs.show_image_level>=1): - # hist = getPlotImg() - # InteractionUtils.show("Hist", hist, 0, 1,config=config) - # appendSaveImg(4,hist) - # appendSaveImg(5,hist) - # appendSaveImg(2,hist) - - per_omr_threshold_avg, total_q_strip_no, total_q_box_no = 0, 0, 0 - for field_block in template.field_blocks: - block_q_strip_no = 1 - key = field_block.name[:3] - box_w, box_h = field_block.bubble_dimensions - for field_block_bubbles in field_block.traverse_bubbles: - # All Black or All White case - no_outliers = all_q_std_vals[total_q_strip_no] < global_std_thresh - # print(total_q_strip_no, field_block_bubbles[0].field_label, - # all_q_std_vals[total_q_strip_no], "no_outliers:", no_outliers) - per_q_strip_threshold = self.get_local_threshold( - all_q_strip_arrs[total_q_strip_no], - global_thr, - no_outliers, - f"Mean Intensity Histogram for {key}.{field_block_bubbles[0].field_label}.{block_q_strip_no}", - config.outputs.show_image_level >= 6, - ) - # print(field_block_bubbles[0].field_label,key,block_q_strip_no, "THR: ", - # round(per_q_strip_threshold,2)) - per_omr_threshold_avg += per_q_strip_threshold - - # TODO: get rid of total_q_box_no - detected_bubbles = [] - for bubble in field_block_bubbles: - bubble_is_marked = ( - per_q_strip_threshold > all_q_vals[total_q_box_no] - ) - total_q_box_no += 1 - - x, y, field_value = ( - bubble.x + field_block.shift_x, - bubble.y + field_block.shift_y, - bubble.field_value, - ) - if bubble_is_marked: - detected_bubbles.append(bubble) - # Draw the shifted box - cv2.rectangle( - final_marked, - (int(x + box_w / 12), int(y + box_h / 12)), - ( - int(x + box_w - box_w / 12), - int(y + box_h - box_h / 12), - ), - constants.CLR_DARK_GRAY, - 3, - ) - - cv2.putText( - final_marked, - str(field_value), - (x, y), - cv2.FONT_HERSHEY_SIMPLEX, - constants.TEXT_SIZE, - (20, 20, 10), - int(1 + 3.5 * constants.TEXT_SIZE), - ) - else: - cv2.rectangle( - final_marked, - (int(x + box_w / 10), int(y + box_h / 10)), - ( - int(x + 
box_w - box_w / 10), - int(y + box_h - box_h / 10), - ), - constants.CLR_GRAY, - -1, - ) - - for bubble in detected_bubbles: - field_label, field_value = ( - bubble.field_label, - bubble.field_value, - ) - # Only send rolls multi-marked in the directory - multi_marked_local = field_label in omr_response - omr_response[field_label] = ( - (omr_response[field_label] + field_value) - if multi_marked_local - else field_value - ) - # TODO: generalize this into identifier - # multi_roll = multi_marked_local and "Roll" in str(q) - multi_marked = multi_marked or multi_marked_local - - if len(detected_bubbles) == 0: - field_label = field_block_bubbles[0].field_label - omr_response[field_label] = field_block.empty_val - - if config.outputs.show_image_level >= 5: - if key in all_c_box_vals: - q_nums[key].append(f"{key[:2]}_c{str(block_q_strip_no)}") - all_c_box_vals[key].append(all_q_strip_arrs[total_q_strip_no]) - - block_q_strip_no += 1 - total_q_strip_no += 1 - # /for field_block - - per_omr_threshold_avg /= total_q_strip_no - per_omr_threshold_avg = round(per_omr_threshold_avg, 2) - # Translucent - cv2.addWeighted(final_marked, alpha, transp_layer, 1 - alpha, 0, final_marked) - # Box types - if config.outputs.show_image_level >= 6: - # plt.draw() - f, axes = plt.subplots(len(all_c_box_vals), sharey=True) - f.canvas.manager.set_window_title( - f"Bubble Intensity by question type for {name}" - ) - ctr = 0 - # TODO: generalize - type_name = { - "int": "Integer", - "mcq": "MCQ", - "med": "MED", - "rol": "Roll", - } - for k, boxvals in all_c_box_vals.items(): - axes[ctr].title.set_text(type_name[k] + " Type") - axes[ctr].boxplot(boxvals) - # thrline=axes[ctr].axhline(per_omr_threshold_avg,color='red',ls='--') - # thrline.set_label("Average THR") - axes[ctr].set_ylabel("Intensity") - axes[ctr].set_xticklabels(q_nums[k]) - # axes[ctr].legend() - ctr += 1 - # imshow will do the waiting - plt.tight_layout(pad=0.5) - plt.show() - - if config.outputs.save_detections and save_dir is 
not None: - if multi_roll: - save_dir = save_dir.joinpath("_MULTI_") - image_path = str(save_dir.joinpath(name)) - ImageUtils.save_img(image_path, final_marked) - - self.append_save_img(2, final_marked) - - if save_dir is not None: - for i in range(config.outputs.save_image_level): - self.save_image_stacks(i + 1, name, save_dir) - - return omr_response, final_marked, multi_marked, multi_roll - - @staticmethod - def draw_template_layout(img, template, shifted=True, draw_qvals=False, border=-1): - img = ImageUtils.resize_util( - img, template.page_dimensions[0], template.page_dimensions[1] - ) - final_align = img.copy() - for field_block in template.field_blocks: - field_block_name, s, d, bubble_dimensions, shift_x, shift_y = map( - lambda attr: getattr(field_block, attr), - [ - "name", - "origin", - "dimensions", - "bubble_dimensions", - "shift_x", - "shift_y", - ], - ) - box_w, box_h = bubble_dimensions - - if shifted: - cv2.rectangle( - final_align, - (s[0] + shift_x, s[1] + shift_y), - (s[0] + shift_x + d[0], s[1] + shift_y + d[1]), - constants.CLR_BLACK, - 3, - ) - else: - cv2.rectangle( - final_align, - (s[0], s[1]), - (s[0] + d[0], s[1] + d[1]), - constants.CLR_BLACK, - 3, - ) - for field_block_bubbles in field_block.traverse_bubbles: - for pt in field_block_bubbles: - x, y = (pt.x + shift_x, pt.y + shift_y) if shifted else (pt.x, pt.y) - cv2.rectangle( - final_align, - (int(x + box_w / 10), int(y + box_h / 10)), - (int(x + box_w - box_w / 10), int(y + box_h - box_h / 10)), - constants.CLR_GRAY, - border, - ) - if draw_qvals: - rect = [y, y + box_h, x, x + box_w] - cv2.putText( - final_align, - f"{int(cv2.mean(img[rect[0] : rect[1], rect[2] : rect[3]])[0])}", - (rect[2] + 2, rect[0] + (box_h * 2) // 3), - cv2.FONT_HERSHEY_SIMPLEX, - 0.6, - constants.CLR_BLACK, - 2, - ) - if shifted: - text_in_px = cv2.getTextSize( - field_block_name, cv2.FONT_HERSHEY_SIMPLEX, constants.TEXT_SIZE, 4 - ) - cv2.putText( - final_align, - field_block_name, - ( - int(s[0] + shift_x 
+ d[0] - text_in_px[0][0]), - int(s[1] + shift_y - text_in_px[0][1]), - ), - cv2.FONT_HERSHEY_SIMPLEX, - constants.TEXT_SIZE, - constants.CLR_BLACK, - 4, - ) - - return final_align - - def get_global_threshold( - self, - q_vals_orig, - plot_title=None, - plot_show=True, - sort_in_plot=True, - looseness=1, - ): - """ - Note: Cannot assume qStrip has only-gray or only-white bg - (in which case there is only one jump). - So there will be either 1 or 2 jumps. - 1 Jump : - ...... - |||||| - |||||| <-- risky THR - |||||| <-- safe THR - ....|||||| - |||||||||| - - 2 Jumps : - ...... - |||||| <-- wrong THR - ....|||||| - |||||||||| <-- safe THR - ..|||||||||| - |||||||||||| - - The abstract "First LARGE GAP" is perfect for this. - Current code is considering ONLY TOP 2 jumps(>= MIN_GAP) to be big, - gives the smaller one - - """ - config = self.tuning_config - PAGE_TYPE_FOR_THRESHOLD, MIN_JUMP, JUMP_DELTA = map( - config.threshold_params.get, - [ - "PAGE_TYPE_FOR_THRESHOLD", - "MIN_JUMP", - "JUMP_DELTA", - ], - ) - - global_default_threshold = ( - constants.GLOBAL_PAGE_THRESHOLD_WHITE - if PAGE_TYPE_FOR_THRESHOLD == "white" - else constants.GLOBAL_PAGE_THRESHOLD_BLACK - ) - - # Sort the Q bubbleValues - # TODO: Change var name of q_vals - q_vals = sorted(q_vals_orig) - # Find the FIRST LARGE GAP and set it as threshold: - ls = (looseness + 1) // 2 - l = len(q_vals) - ls - max1, thr1 = MIN_JUMP, global_default_threshold - for i in range(ls, l): - jump = q_vals[i + ls] - q_vals[i - ls] - if jump > max1: - max1 = jump - thr1 = q_vals[i - ls] + jump / 2 - - # NOTE: thr2 is deprecated, thus is JUMP_DELTA - # Make use of the fact that the JUMP_DELTA(Vertical gap ofc) between - # values at detected jumps would be atleast 20 - max2, thr2 = MIN_JUMP, global_default_threshold - # Requires atleast 1 gray box to be present (Roll field will ensure this) - for i in range(ls, l): - jump = q_vals[i + ls] - q_vals[i - ls] - new_thr = q_vals[i - ls] + jump / 2 - if jump > max2 and abs(thr1 
- new_thr) > JUMP_DELTA: - max2 = jump - thr2 = new_thr - # global_thr = min(thr1,thr2) - global_thr, j_low, j_high = thr1, thr1 - max1 // 2, thr1 + max1 // 2 - - # # For normal images - # thresholdRead = 116 - # if(thr1 > thr2 and thr2 > thresholdRead): - # print("Note: taking safer thr line.") - # global_thr, j_low, j_high = thr2, thr2 - max2//2, thr2 + max2//2 - - if plot_title: - _, ax = plt.subplots() - ax.bar(range(len(q_vals_orig)), q_vals if sort_in_plot else q_vals_orig) - ax.set_title(plot_title) - thrline = ax.axhline(global_thr, color="green", ls="--", linewidth=5) - thrline.set_label("Global Threshold") - thrline = ax.axhline(thr2, color="red", ls=":", linewidth=3) - thrline.set_label("THR2 Line") - # thrline=ax.axhline(j_low,color='red',ls='-.', linewidth=3) - # thrline=ax.axhline(j_high,color='red',ls='-.', linewidth=3) - # thrline.set_label("Boundary Line") - # ax.set_ylabel("Mean Intensity") - ax.set_ylabel("Values") - ax.set_xlabel("Position") - ax.legend() - if plot_show: - plt.title(plot_title) - plt.show() - - return global_thr, j_low, j_high - - def get_local_threshold( - self, q_vals, global_thr, no_outliers, plot_title=None, plot_show=True - ): - """ - TODO: Update this documentation too- - //No more - Assumption : Colwise background color is uniformly gray or white, - but not alternating. In this case there is atmost one jump. - - 0 Jump : - <-- safe THR? - ....... - ...||||||| - |||||||||| <-- safe THR? - // How to decide given range is above or below gray? - -> global q_vals shall absolutely help here. Just run same function - on total q_vals instead of colwise _// - How to decide it is this case of 0 jumps - - 1 Jump : - ...... 
- |||||| - |||||| <-- risky THR - |||||| <-- safe THR - ....|||||| - |||||||||| - - """ - config = self.tuning_config - # Sort the Q bubbleValues - q_vals = sorted(q_vals) - - # Small no of pts cases: - # base case: 1 or 2 pts - if len(q_vals) < 3: - thr1 = ( - global_thr - if np.max(q_vals) - np.min(q_vals) < config.threshold_params.MIN_GAP - else np.mean(q_vals) - ) - else: - # qmin, qmax, qmean, qstd = round(np.min(q_vals),2), round(np.max(q_vals),2), - # round(np.mean(q_vals),2), round(np.std(q_vals),2) - # GVals = [round(abs(q-qmean),2) for q in q_vals] - # gmean, gstd = round(np.mean(GVals),2), round(np.std(GVals),2) - # # DISCRETION: Pretty critical factor in reading response - # # Doesn't work well for small number of values. - # DISCRETION = 2.7 # 2.59 was closest hit, 3.0 is too far - # L2MaxGap = round(max([abs(g-gmean) for g in GVals]),2) - # if(L2MaxGap > DISCRETION*gstd): - # no_outliers = False - - # # ^Stackoverflow method - # print(field_label, no_outliers,"qstd",round(np.std(q_vals),2), "gstd", gstd, - # "Gaps in gvals",sorted([round(abs(g-gmean),2) for g in GVals],reverse=True), - # '\t',round(DISCRETION*gstd,2), L2MaxGap) - - # else: - # Find the LARGEST GAP and set it as threshold: //(FIRST LARGE GAP) - l = len(q_vals) - 1 - max1, thr1 = config.threshold_params.MIN_JUMP, 255 - for i in range(1, l): - jump = q_vals[i + 1] - q_vals[i - 1] - if jump > max1: - max1 = jump - thr1 = q_vals[i - 1] + jump / 2 - # print(field_label,q_vals,max1) - - confident_jump = ( - config.threshold_params.MIN_JUMP - + config.threshold_params.CONFIDENT_SURPLUS - ) - # If not confident, then only take help of global_thr - if max1 < confident_jump: - if no_outliers or thr1 == 255: - # All Black or All White case - thr1 = global_thr - else: - # TODO: Low confidence parameters here - pass - - # if(thr1 == 255): - # print("Warning: threshold is unexpectedly 255! 
(Outlier Delta issue?)",plot_title) - - # Make a common plot function to show local and global thresholds - if plot_show and plot_title is not None: - _, ax = plt.subplots() - ax.bar(range(len(q_vals)), q_vals) - thrline = ax.axhline(thr1, color="green", ls=("-."), linewidth=3) - thrline.set_label("Local Threshold") - thrline = ax.axhline(global_thr, color="red", ls=":", linewidth=5) - thrline.set_label("Global Threshold") - ax.set_title(plot_title) - ax.set_ylabel("Bubble Mean Intensity") - ax.set_xlabel("Bubble Number(sorted)") - ax.legend() - # TODO append QStrip to this plot- - # appendSaveImg(6,getPlotImg()) - if plot_show: - plt.show() - return thr1 - - def append_save_img(self, key, img): - if self.save_image_level >= int(key): - self.save_img_list[key].append(img.copy()) - - def save_image_stacks(self, key, filename, save_dir): - config = self.tuning_config - if self.save_image_level >= int(key) and self.save_img_list[key] != []: - name = os.path.splitext(filename)[0] - result = np.hstack( - tuple( - [ - ImageUtils.resize_util_h(img, config.dimensions.display_height) - for img in self.save_img_list[key] - ] - ) - ) - result = ImageUtils.resize_util( - result, - min( - len(self.save_img_list[key]) * config.dimensions.display_width // 3, - int(config.dimensions.display_width * 2.5), - ), - ) - ImageUtils.save_img(f"{save_dir}stack/{name}_{str(key)}_stack.jpg", result) - - def reset_all_save_img(self): - for i in range(self.save_image_level): - self.save_img_list[i + 1] = [] diff --git a/src/defaults/config.py b/src/defaults/config.py index 540ec9c3..f102a9c1 100644 --- a/src/defaults/config.py +++ b/src/defaults/config.py @@ -12,14 +12,19 @@ "GAMMA_LOW": 0.7, "MIN_GAP": 30, "MIN_JUMP": 25, - "CONFIDENT_SURPLUS": 5, + "CONFIDENT_JUMP_SURPLUS_FOR_DISPARITY": 25, + "MIN_JUMP_SURPLUS_FOR_GLOBAL_FALLBACK": 5, + "GLOBAL_THRESHOLD_MARGIN": 10, "JUMP_DELTA": 30, "PAGE_TYPE_FOR_THRESHOLD": "white", + "GLOBAL_PAGE_THRESHOLD_WHITE": 200, + "GLOBAL_PAGE_THRESHOLD_BLACK": 
100, }, "outputs": { "show_image_level": 0, "save_image_level": 0, "save_detections": True, + "save_image_metrics": False, "filter_out_multimarked_files": False, }, }, diff --git a/src/entry.py b/src/entry.py index ff6500d1..8f6214eb 100644 --- a/src/entry.py +++ b/src/entry.py @@ -6,6 +6,8 @@ Github: https://github.com/Udayraj123 """ + +import json import os from csv import QUOTE_NONNUMERIC from pathlib import Path @@ -15,14 +17,14 @@ import pandas as pd from rich.table import Table -from src import constants +from src.algorithm.evaluation import EvaluationConfig, evaluate_concatenated_response +from src.algorithm.template import Template from src.defaults import CONFIG_DEFAULTS -from src.evaluation import EvaluationConfig, evaluate_concatenated_response -from src.logger import console, logger -from src.template import Template +from src.utils import constants from src.utils.file import Paths, setup_dirs_for_paths, setup_outputs_for_template from src.utils.image import ImageUtils from src.utils.interaction import InteractionUtils, Stats +from src.utils.logger import console, logger from src.utils.parsing import get_concatenated_response, open_config_with_defaults # Load processors @@ -36,6 +38,46 @@ def entry_point(input_dir, args): return process_dir(input_dir, curr_dir, args) +def export_omr_metrics( + outputs_namespace, + file_name, + image, + final_marked, + template, + field_number_to_field_bubble_means, + global_threshold_for_template, + global_field_confidence_metrics, + evaluation_meta, +): + global_bubble_means_and_refs = [] + for field_bubble_means in field_number_to_field_bubble_means: + global_bubble_means_and_refs.extend(field_bubble_means) + # sorted_global_bubble_means_and_refs = sorted(global_bubble_means_and_refs) + + image_metrics_path = outputs_namespace.paths.image_metrics_dir.joinpath( + f"{os.path.splitext(file_name)[0]}.js" + ) + with open( + image_metrics_path, + "w", + ) as f: + json_string = json.dumps( + { + 
"global_threshold_for_template": global_threshold_for_template, + "template": template, + "evaluation_meta": ( + evaluation_meta if evaluation_meta is not None else {} + ), + "global_bubble_means_and_refs": global_bubble_means_and_refs, + "global_field_confidence_metrics": global_field_confidence_metrics, + }, + default=lambda x: x.to_json(), + indent=4, + ) + f.write(f"export default {json_string}") + logger.info(f"Exported image metrics to: {image_metrics_path}") + + def print_config_summary( curr_dir, omr_files, @@ -51,15 +93,12 @@ def print_config_summary( table.add_column("Value", style="magenta") table.add_row("Directory Path", f"{curr_dir}") table.add_row("Count of Images", f"{len(omr_files)}") + table.add_row("Debug Mode ", "ON" if args["debug"] else "OFF") table.add_row("Set Layout Mode ", "ON" if args["setLayout"] else "OFF") table.add_row( "Markers Detection", "ON" if "CropOnMarkers" in template.pre_processors else "OFF", ) - table.add_row( - "Auto Alignment", - "ON" if "AutoAlignTemplate" in template.pre_processors else "OFF", - ) table.add_row("Detected Template Path", f"{template}") if local_config_path: table.add_row("Detected Local Config", f"{local_config_path}") @@ -70,6 +109,17 @@ def print_config_summary( "Detected pre-processors", f"{[pp.__class__.__name__ for pp in template.pre_processors]}", ) + + alignment_preprocessors = list( + filter( + lambda p: p in template.pre_processors, + ["AutoAlignTemplate", "FeatureBasedAlignment"], + ) + ) + table.add_row( + "Auto Alignment", + (alignment_preprocessors if len(alignment_preprocessors) else "OFF"), + ) console.print(table, justify="center") @@ -262,6 +312,9 @@ def process_files( final_marked, multi_marked, _, + field_number_to_field_bubble_means, + global_threshold_for_template, + global_field_confidence_metrics, ) = template.image_instance_ops.read_omr_response( template, image=in_omr, name=file_id, save_dir=save_dir ) @@ -276,15 +329,30 @@ def process_files( ): logger.info(f"Read Response: 
\n{omr_response}") - score = 0 + score, evaluation_meta = 0, None if evaluation_config is not None: - score = evaluate_concatenated_response(omr_response, evaluation_config) + score, evaluation_meta = evaluate_concatenated_response( + omr_response, evaluation_config + ) logger.info( f"(/{files_counter}) Graded with score: {round(score, 2)}\t for file: '{file_id}'" ) else: logger.info(f"(/{files_counter}) Processed file: '{file_id}'") + if tuning_config.outputs.save_image_metrics: + export_omr_metrics( + outputs_namespace, + file_name, + in_omr, + final_marked, + template, + field_number_to_field_bubble_means, + global_threshold_for_template, + global_field_confidence_metrics, + evaluation_meta, + ) + if tuning_config.outputs.show_image_level >= 2: InteractionUtils.show( f"Final Marked Bubbles : '{file_id}'", diff --git a/src/processors/CropOnMarkers.py b/src/processors/CropOnMarkers.py index 10c5bbe4..6a486b78 100644 --- a/src/processors/CropOnMarkers.py +++ b/src/processors/CropOnMarkers.py @@ -3,12 +3,12 @@ import cv2 import numpy as np -from src.logger import logger from src.processors.interfaces.ImageTemplatePreprocessor import ( ImageTemplatePreprocessor, ) from src.utils.image import ImageUtils from src.utils.interaction import InteractionUtils +from src.utils.logger import logger # Internal Processor for separation of code @@ -608,6 +608,7 @@ def find_largest_patch_area_corners(self, area_start, area, patch_type): elif patch_type == "line": # Rotated rectangle can correct slight rotations better rotated_rect = cv2.minAreaRect(bounding_cnt) + # TODO: less confidence if angle = rotated_rect[2] is too skew rotated_rect_points = cv2.boxPoints(rotated_rect) patch_corners = np.intp(rotated_rect_points) patch_corners = ImageUtils.order_points(patch_corners) diff --git a/src/processors/CropPage.py b/src/processors/CropPage.py index 690c9954..322df043 100644 --- a/src/processors/CropPage.py +++ b/src/processors/CropPage.py @@ -4,12 +4,12 @@ import cv2 import numpy as 
np -from src.logger import logger from src.processors.interfaces.ImageTemplatePreprocessor import ( ImageTemplatePreprocessor, ) from src.utils.image import ImageUtils from src.utils.interaction import InteractionUtils +from src.utils.logger import logger MIN_PAGE_AREA = 80000 diff --git a/src/processors/manager.py b/src/processors/manager.py index 9406b198..3ba72b13 100644 --- a/src/processors/manager.py +++ b/src/processors/manager.py @@ -5,7 +5,7 @@ import inspect import pkgutil -from src.logger import logger +from src.utils.logger import logger class Processor: diff --git a/src/schemas/config_schema.py b/src/schemas/config_schema.py index 2f58f072..2ede8ef3 100644 --- a/src/schemas/config_schema.py +++ b/src/schemas/config_schema.py @@ -21,14 +21,39 @@ "additionalProperties": False, "properties": { "GAMMA_LOW": {"type": "number", "minimum": 0, "maximum": 1}, + # TODO: rename these variables for better usability "MIN_GAP": {"type": "integer", "minimum": 10, "maximum": 100}, "MIN_JUMP": {"type": "integer", "minimum": 10, "maximum": 100}, - "CONFIDENT_SURPLUS": {"type": "integer", "minimum": 0, "maximum": 20}, + "MIN_JUMP_SURPLUS_FOR_GLOBAL_FALLBACK": { + "type": "integer", + "minimum": 0, + "maximum": 20, + }, + "GLOBAL_THRESHOLD_MARGIN": { + "type": "integer", + "minimum": 0, + "maximum": 20, + }, + "CONFIDENT_JUMP_SURPLUS_FOR_DISPARITY": { + "type": "integer", + "minimum": 0, + "maximum": 100, + }, "JUMP_DELTA": {"type": "integer", "minimum": 10, "maximum": 100}, "PAGE_TYPE_FOR_THRESHOLD": { "enum": ["white", "black"], "type": "string", }, + "GLOBAL_PAGE_THRESHOLD_WHITE": { + "type": "integer", + "minimum": 0, + "maximum": 255, + }, + "GLOBAL_PAGE_THRESHOLD_BLACK": { + "type": "integer", + "minimum": 0, + "maximum": 255, + }, }, }, "outputs": { @@ -38,6 +63,7 @@ "show_image_level": {"type": "integer", "minimum": 0, "maximum": 6}, "save_image_level": {"type": "integer", "minimum": 0, "maximum": 6}, "save_detections": {"type": "boolean"}, + "save_image_metrics": 
{"type": "boolean"}, # This option moves multimarked files into a separate folder for manual checking, skipping evaluation "filter_out_multimarked_files": {"type": "boolean"}, }, diff --git a/src/schemas/evaluation_schema.py b/src/schemas/evaluation_schema.py index ec2635e9..8de1af76 100644 --- a/src/schemas/evaluation_schema.py +++ b/src/schemas/evaluation_schema.py @@ -18,6 +18,7 @@ "required": ["correct", "incorrect", "unmarked"], "type": "object", "properties": { + # TODO: can support streak marking if we allow array of marking_scores here "correct": marking_score, "incorrect": marking_score, "unmarked": marking_score, @@ -31,12 +32,12 @@ "description": "OMRChecker evaluation schema i.e. the marking scheme", "type": "object", "additionalProperties": True, - "required": ["source_type", "options", "marking_scheme"], + "required": ["source_type", "options", "marking_schemes"], "properties": { "additionalProperties": False, "source_type": {"type": "string", "enum": ["csv", "custom"]}, "options": {"type": "object"}, - "marking_scheme": { + "marking_schemes": { "type": "object", "required": [DEFAULT_SECTION_KEY], "patternProperties": { @@ -102,57 +103,38 @@ "type": "array", "items": { "oneOf": [ - # "standard": single correct, multimarked single-correct - # Example: "q1" --> 'AB' + # "standard": single correct, multi-marked single-correct + # Example: "q1" --> '67' {"type": "string"}, - # "multiple-correct": multiple correct answers (for ambiguos/bonus questions) + # "multiple-correct": multiple-correct (for ambiguous/bonus questions) # Example: "q1" --> [ 'A', 'B' ] { "type": "array", "items": {"type": "string"}, "minItems": 2, }, - # "multiple-correct-weighted": array of answer-wise weights - # Example: "q1" --> [['A', 1], ['B', 2], ['C', 3]] + # "multiple-correct-weighted": array of answer-wise weights (marking scheme not applicable) + # Example 1: "q1" --> [['A', 1], ['B', 2], ['C', 3]] or + # Example 2: "q2" --> [['A', 1], ['B', 1], ['AB', 2]] { "type": "array", 
- "items": False, - "maxItems": 2, - "minItems": 2, - "prefixItems": [ - {"type": "string"}, - { - "type": "array", - "items": marking_score, - "minItems": 1, - "maxItems": 3, - }, - ], + "items": { + "type": "array", + "items": False, + "minItems": 2, + "maxItems": 2, + "prefixItems": [ + {"type": "string"}, + marking_score, + ], + }, }, + # Multiple-correct with custom marking scheme + # ["A", ["1", "2", "3"]], + # [["A", "B", "AB"], ["1", "2", "3"]] ], }, }, - { - # TODO: answer_weight format - "type": "array", # two column array for weights - "items": False, - "maxItems": 2, - "minItems": 2, - "prefixItems": [ - { - "type": "array", - "items": {"type": "string"}, - "minItems": 2, - "maxItems": 2, - }, - { - "type": "array", - "items": marking_score, - "minItems": 1, - "maxItems": 3, - }, - ], - }, ] }, "questions_in_order": ARRAY_OF_STRINGS, diff --git a/src/schemas/template_schema.py b/src/schemas/template_schema.py index b889ccf4..31d86e87 100644 --- a/src/schemas/template_schema.py +++ b/src/schemas/template_schema.py @@ -1,5 +1,5 @@ -from src.constants import FIELD_TYPES from src.schemas.constants import ARRAY_OF_STRINGS, FIELD_STRING_TYPE +from src.utils.constants import FIELD_TYPES positive_number = {"type": "number", "minimum": 0} positive_integer = {"type": "integer", "minimum": 0} diff --git a/src/tests/__snapshots__/test_all_samples.ambr b/src/tests/__snapshots__/test_all_samples.ambr index 8778ddd8..acd13959 100644 --- a/src/tests/__snapshots__/test_all_samples.ambr +++ b/src/tests/__snapshots__/test_all_samples.ambr @@ -1,4 +1,39 @@ # serializer version: 1 +# name: test_run_answer_key_using_csv + dict({ + 'Manual/ErrorFiles.csv': ''' + "file_id","input_path","output_path","score","q1","q2","q3","q4","q5" + + ''', + 'Manual/MultiMarkedFiles.csv': ''' + "file_id","input_path","output_path","score","q1","q2","q3","q4","q5" + + ''', + 'Results/Results_05AM.csv': ''' + "file_id","input_path","output_path","score","q1","q2","q3","q4","q5" + 
"adrian_omr.png","samples/answer-key/using-csv/adrian_omr.png","outputs/answer-key/using-csv/CheckedOMRs/adrian_omr.png","5.0","C","E","A","B","B" + + ''', + }) +# --- +# name: test_run_answer_key_weighted_answers + dict({ + 'images/Manual/ErrorFiles.csv': ''' + "file_id","input_path","output_path","score","q1","q2","q3","q4","q5" + + ''', + 'images/Manual/MultiMarkedFiles.csv': ''' + "file_id","input_path","output_path","score","q1","q2","q3","q4","q5" + + ''', + 'images/Results/Results_05AM.csv': ''' + "file_id","input_path","output_path","score","q1","q2","q3","q4","q5" + "adrian_omr.png","samples/answer-key/weighted-answers/images/adrian_omr.png","outputs/answer-key/weighted-answers/images/CheckedOMRs/adrian_omr.png","5.5","B","E","A","C","B" + "adrian_omr_2.png","samples/answer-key/weighted-answers/images/adrian_omr_2.png","outputs/answer-key/weighted-answers/images/CheckedOMRs/adrian_omr_2.png","10.0","C","E","A","B","B" + + ''', + }) +# --- # name: test_run_community_Antibodyy dict({ 'Manual/ErrorFiles.csv': ''' @@ -97,7 +132,7 @@ 'scans/Results/Results_05AM.csv': ''' 
"file_id","input_path","output_path","score","Roll_no","q1","q2","q3","q4","q5","q6","q7","q8","q9","q10","q11","q12","q13","q14","q15","q16","q17","q18","q19","q20","q21","q22","q23","q24","q25","q26","q27","q28","q29","q30","q31","q32","q33","q34","q35","q36","q37","q38","q39","q40","q41","q42","q43","q44","q45","q46","q47","q48","q49","q50","q51","q52","q53","q54","q55","q56","q57","q58","q59","q60","q61","q62","q63","q64","q65","q66","q67","q68","q69","q70","q71","q72","q73","q74","q75","q76","q77","q78","q79","q80","q81","q82","q83","q84","q85","q86","q87","q88","q89","q90","q91","q92","q93","q94","q95","q96","q97","q98","q99","q100","q101","q102","q103","q104","q105","q106","q107","q108","q109","q110","q111","q112","q113","q114","q115","q116","q117","q118","q119","q120","q121","q122","q123","q124","q125","q126","q127","q128","q129","q130","q131","q132","q133","q134","q135","q136","q137","q138","q139","q140","q141","q142","q143","q144","q145","q146","q147","q148","q149","q150","q151","q152","q153","q154","q155","q156","q157","q158","q159","q160","q161","q162","q163","q164","q165","q166","q167","q168","q169","q170","q171","q172","q173","q174","q175","q176","q177","q178","q179","q180","q181","q182","q183","q184","q185","q186","q187","q188","q189","q190","q191","q192","q193","q194","q195","q196","q197","q198","q199","q200" 
"scan-type-1.jpg","samples/community/UmarFarootAPS/scans/scan-type-1.jpg","outputs/community/UmarFarootAPS/scans/CheckedOMRs/scan-type-1.jpg","49.0","2468","A","C","B","C","A","D","B","C","B","D","C","A","C","D","B","C","A","B","C","A","C","B","D","C","A","B","D","C","A","C","B","D","B","A","C","D","B","C","A","C","D","A","C","D","A","B","D","C","A","C","D","B","C","A","C","D","B","C","D","A","B","C","B","C","D","B","D","A","C","B","D","A","B","C","B","A","C","D","B","A","C","B","C","B","A","D","B","A","C","D","B","D","B","C","B","D","A","C","B","C","B","C","D","B","C","A","B","C","A","D","C","B","D","B","A","B","C","D","D","C","B","A","B","C","D","C","B","A","B","C","D","C","B","A","B","C","D","C","B","A","B","C","B","A","C","B","A","C","A","B","C","B","C","B","A","C","A","C","B","B","C","B","A","C","A","B","A","B","A","B","C","D","B","C","A","C","D","C","A","C","B","A","C","A","B","C","B","D","A","B","C","D","C","B","B","C","A","B","C","B" - "scan-type-2.jpg","samples/community/UmarFarootAPS/scans/scan-type-2.jpg","outputs/community/UmarFarootAPS/scans/CheckedOMRs/scan-type-2.jpg","18.333333333333336","0234","A","B","C","D","C","B","A","B","C","D","C","B","A","B","C","D","C","B","A","B","C","D","C","B","A","B","C","D","C","B","A","B","C","D","C","B","A","B","C","D","C","B","A","B","C","D","C","B","A","B","A","D","","","AD","","","","A","D","","","","","","","D","A","","D","","A","","D","","","","A","","","C","","","D","","","A","","","","D","","C","","A","","C","","D","B","B","","","A","","D","","","","D","","","","","A","D","","","B","","","D","","","A","","","D","","","B","","","D","","","","A","D","","","A","","B","","D","","","","C","C","D","D","A","","D","","A","D","","","D","","B","D","","","D","","D","B","","","","D","","A","","","","D","","B","","","","","","D","","","A","","","A","","D","","","D" + 
"scan-type-2.jpg","samples/community/UmarFarootAPS/scans/scan-type-2.jpg","outputs/community/UmarFarootAPS/scans/CheckedOMRs/scan-type-2.jpg","20.0","0234","A","B","C","D","C","B","A","B","C","D","C","B","A","B","C","D","C","B","A","B","C","D","C","B","A","B","C","D","C","B","A","B","C","D","C","B","A","B","C","D","C","B","A","B","C","D","C","B","A","B","A","D","","","AD","","","","A","D","","","","","","","D","A","","D","","A","","D","","","","A","","","C","","","D","","","A","","","","D","","C","","A","","C","","D","B","B","","","A","","D","","","","D","","","","","A","D","","","B","","","D","","","A","","","D","","","B","","","D","","","","A","D","","","A","","B","","D","","","","C","C","D","D","A","","D","","A","D","","","D","","B","D","","","D","","D","B","","","","D","","A","","","","D","","B","","","","","","D","","","A","","","A","","D","","","D" ''', }) @@ -148,8 +183,8 @@ ''', 'AdrianSample/Results/Results_05AM.csv': ''' "file_id","input_path","output_path","score","q1","q2","q3","q4","q5" - "adrian_omr.png","samples/sample2/AdrianSample/adrian_omr.png","outputs/sample2/AdrianSample/CheckedOMRs/adrian_omr.png","5.0","B","E","A","C","B" - "adrian_omr_2.png","samples/sample2/AdrianSample/adrian_omr_2.png","outputs/sample2/AdrianSample/CheckedOMRs/adrian_omr_2.png","3.0","C","E","A","B","B" + "adrian_omr.png","samples/sample2/AdrianSample/adrian_omr.png","outputs/sample2/AdrianSample/CheckedOMRs/adrian_omr.png","0","B","E","A","C","B" + "adrian_omr_2.png","samples/sample2/AdrianSample/adrian_omr_2.png","outputs/sample2/AdrianSample/CheckedOMRs/adrian_omr_2.png","0","C","E","A","B","B" ''', }) diff --git a/src/tests/test_all_samples.py b/src/tests/test_all_samples.py index 00a25ff6..9efe94ff 100644 --- a/src/tests/test_all_samples.py +++ b/src/tests/test_all_samples.py @@ -42,6 +42,16 @@ def extract_sample_outputs(output_dir): return sample_outputs +def test_run_answer_key_using_csv(mocker, snapshot): + sample_outputs = run_sample(mocker, 
"answer-key/using-csv") + assert snapshot == sample_outputs + + +def test_run_answer_key_weighted_answers(mocker, snapshot): + sample_outputs = run_sample(mocker, "answer-key/weighted-answers") + assert snapshot == sample_outputs + + def test_run_sample1(mocker, snapshot): sample_outputs = run_sample(mocker, "sample1") assert snapshot == sample_outputs diff --git a/src/constants.py b/src/utils/constants.py similarity index 84% rename from src/constants.py rename to src/utils/constants.py index 1d329e1c..dfcad28a 100644 --- a/src/constants.py +++ b/src/utils/constants.py @@ -1,11 +1,3 @@ -""" - - OMRChecker - - Author: Udayraj Deshmukh - Github: https://github.com/Udayraj123 - -""" from dotmap import DotMap # Filenames @@ -48,7 +40,3 @@ CLR_WHITE = (255, 255, 255) CLR_GRAY = (130, 130, 130) CLR_DARK_GRAY = (100, 100, 100) - -# TODO: move to config.json -GLOBAL_PAGE_THRESHOLD_WHITE = 200 -GLOBAL_PAGE_THRESHOLD_BLACK = 100 diff --git a/src/utils/file.py b/src/utils/file.py index 5513381f..0b4c28a2 100644 --- a/src/utils/file.py +++ b/src/utils/file.py @@ -6,7 +6,7 @@ import pandas as pd -from src.logger import logger +from src.utils.logger import logger def load_json(path, **rest): @@ -23,6 +23,7 @@ class Paths: def __init__(self, output_dir): self.output_dir = output_dir self.save_marked_dir = output_dir.joinpath("CheckedOMRs") + self.image_metrics_dir = output_dir.joinpath("ImageMetrics") self.results_dir = output_dir.joinpath("Results") self.manual_dir = output_dir.joinpath("Manual") self.errors_dir = self.manual_dir.joinpath("ErrorFiles") @@ -39,12 +40,13 @@ def setup_dirs_for_paths(paths): os.mkdir(save_output_dir.joinpath("_MULTI_")) os.mkdir(save_output_dir.joinpath("_MULTI_", "stack")) - for save_output_dir in [paths.manual_dir, paths.results_dir]: - if not os.path.exists(save_output_dir): - logger.info(f"Created : {save_output_dir}") - os.makedirs(save_output_dir) - - for save_output_dir in [paths.multi_marked_dir, paths.errors_dir]: + for save_output_dir in 
[ + paths.manual_dir, + paths.results_dir, + paths.image_metrics_dir, + paths.multi_marked_dir, + paths.errors_dir, + ]: if not os.path.exists(save_output_dir): logger.info(f"Created : {save_output_dir}") os.makedirs(save_output_dir) diff --git a/src/utils/image.py b/src/utils/image.py index 9fa2445c..ff990c2a 100644 --- a/src/utils/image.py +++ b/src/utils/image.py @@ -1,18 +1,9 @@ -""" - - OMRChecker - - Author: Udayraj Deshmukh - Github: https://github.com/Udayraj123 - -""" - import cv2 import matplotlib.pyplot as plt import numpy as np -from src.constants import CLR_WHITE -from src.logger import logger +from src.utils.constants import CLR_WHITE +from src.utils.logger import logger plt.rcParams["figure.figsize"] = (10.0, 8.0) CLAHE_HELPER = cv2.createCLAHE(clipLimit=5.0, tileGridSize=(8, 8)) diff --git a/src/utils/interaction.py b/src/utils/interaction.py index c88e15bd..33eae329 100644 --- a/src/utils/interaction.py +++ b/src/utils/interaction.py @@ -3,8 +3,8 @@ import cv2 from screeninfo import get_monitors -from src.logger import logger from src.utils.image import ImageUtils +from src.utils.logger import logger monitor_window = get_monitors()[0] diff --git a/src/logger.py b/src/utils/logger.py similarity index 100% rename from src/logger.py rename to src/utils/logger.py diff --git a/src/utils/parsing.py b/src/utils/parsing.py index cb6901c6..f4eabcaf 100644 --- a/src/utils/parsing.py +++ b/src/utils/parsing.py @@ -2,12 +2,13 @@ from copy import deepcopy from fractions import Fraction +import numpy as np from deepmerge import Merger from dotmap import DotMap -from src.constants import FIELD_LABEL_NUMBER_REGEX from src.defaults import CONFIG_DEFAULTS, TEMPLATE_DEFAULTS from src.schemas.constants import FIELD_STRING_REGEX_GROUPS +from src.utils.constants import FIELD_LABEL_NUMBER_REGEX from src.utils.file import load_json from src.utils.validations import ( validate_config_json, @@ -111,3 +112,17 @@ def parse_float_or_fraction(result): else: result = 
float(result) return result + + +def default_dump(obj): + return ( + bool(obj) + if isinstance(obj, np.bool_) + else ( + obj.to_json() + if hasattr(obj, "to_json") + else obj.__dict__ + if hasattr(obj, "__dict__") + else obj + ) + ) diff --git a/src/utils/validations.py b/src/utils/validations.py index fc833216..4426b2c0 100644 --- a/src/utils/validations.py +++ b/src/utils/validations.py @@ -1,19 +1,11 @@ -""" - - OMRChecker - - Author: Udayraj Deshmukh - Github: https://github.com/Udayraj123 - -""" import re import jsonschema from jsonschema import validate from rich.table import Table -from src.logger import console, logger from src.schemas import SCHEMA_JSONS, SCHEMA_VALIDATORS +from src.utils.logger import console, logger def validate_evaluation_json(json_data, evaluation_path):