diff --git a/arknights_mower/__init__.py b/arknights_mower/__init__.py index 79d67e4e5..d5fe1fda8 100644 --- a/arknights_mower/__init__.py +++ b/arknights_mower/__init__.py @@ -15,4 +15,4 @@ __cli__ = not (__pyinstall__ and not sys.argv[1:]) __system__ = platform.system().lower() -__version__ = '2.0.1' +__version__ = '2.0.2' diff --git a/arknights_mower/data/ocr.py b/arknights_mower/data/ocr.py index 0f6e50720..e4ff450ae 100644 --- a/arknights_mower/data/ocr.py +++ b/arknights_mower/data/ocr.py @@ -30,116 +30,4 @@ 'Castle3': 'Castle-3', 'Lancet2': 'Lancet-2', 'THRMEX': 'THRM-EX', - '佛影': '傀影', - '使影': '傀影', - '倪影': '傀影', - '影': '傀影', - '愧影': '傀影', - '国克洛丝': '克罗丝', - '家克洛丝': '克罗丝', - '卡湿利安': '卡涅利安', - '6卡缇': '卡缇', - 'G卡缇': '卡缇', - '了G卡缇': '卡缇', - '叫米': '古米', - '叶米': '古米', - '叶': '吽', - '哗': '吽', - '耳': '吽', - '6夜刀': '夜刀', - '面夜刀': '夜刀', - '身': '宴', - '唉峨': '嵯峨', - '峡峨': '嵯峨', - '峨': '嵯峨', - '槎峨': '嵯峨', - '爱哦': '嵯峨', - '送哦': '嵯峨', - '送峨': '嵯峨', - '可巡林者': '巡林者', - '巡林': '巡林者', - '幽灵': '幽灵鲨', - '幽灵滥': '幽灵鲨', - '幽灵盗': '幽灵鲨', - '惊': '惊蛰', - '早器': '早露', - '6杜林': '杜林', - 'I6杜林': '杜林', - '了G杜林': '杜林', - '桑甚': '桑葚', - '桑衰': '桑葚', - '来': '梓兰', - '粹兰': '梓兰', - '辣刺': '棘刺', - '森': '森蚺', - '森蚪': '森蚺', - '森蛾': '森蚺', - '森螃': '森蚺', - '深': '深靛', - '灰': '灰烬', - 'm炎熔': '炎熔', - 'n炎熔': '炎熔', - '了元炎熔': '炎熔', - '光恪': '炎熔', - '光烙': '炎熔', - '光焙': '炎熔', - '光熔': '炎熔', - '听炎熔': '炎熔', - '我恪': '炎熔', - '我烙': '炎熔', - '我焙': '炎熔', - '我熔': '炎熔', - '斤炎熔': '炎熔', - '炎恪': '炎熔', - '炎烙': '炎熔', - '炎焙': '炎熔', - '而炎熔': '炎熔', - '石': '燧石', - '狮喝': '狮蝎', - '狮蜴': '狮蝎', - '攻兰莎': '玫兰莎', - '放兰莎': '玫兰莎', - '救兰莎': '玫兰莎', - '欢兰莎': '玫兰莎', - '致兰莎': '玫兰莎', - '琴迎': '琴柳', - '自金': '白金', - '白面': '白面鸮', - '白面男': '白面鸮', - '白面鸭': '白面鸮', - '白面鹑': '白面鸮', - '白面鹗': '白面鸮', - '乐': '砾', - '优乐': '砾', - '既': '砾', - '研': '砾', - '舔': '砾', - '米格各': '米格鲁', - '罗比塔': '罗比菈塔', - '罗比拉塔': '罗比菈塔', - '罗比益塔': '罗比菈塔', - '罗比藏塔': '罗比菈塔', - '5美蓉': '芙蓉', - 'G芙蓉': '芙蓉', - '可美蓉': '芙蓉', - '听美蓉': '芙蓉', - '市美蓉': '芙蓉', - '美蓉': '芙蓉', - '.芬': '芬', - '劳': '芬', - '节草': '苇草', - '草': '苇草', - '荣草': '苇草', - '蒂草': '苇草', - '蛇居箱': '蛇屠箱', - '野': '野鬃', - '野景': '野鬃', - '野秦': '野鬃', - '野紧': '野鬃', - '星': '陨星', - '碗星': '陨星', - '限星': '陨星', - '番草': '香草', - '看草': '香草', - '沾': '黑', } diff --git a/arknights_mower/data/recruit.py b/arknights_mower/data/recruit.py index 8c02a4bba..ce2093d92 100644 --- a/arknights_mower/data/recruit.py +++ b/arknights_mower/data/recruit.py @@ -1,6 +1,6 @@ # recruit database -# TODO: check/update from gamedata +# TODO check/update from gamedata recruit_database = [ ('Lancet-2', 1, ['医疗干员', '远程位', '治疗', '支援机械']), diff --git a/arknights_mower/fonts/SourceHanSansSC-Bold.otf b/arknights_mower/fonts/SourceHanSansSC-Bold.otf index 027675fb9..11af8768c 100644 Binary files a/arknights_mower/fonts/SourceHanSansSC-Bold.otf and b/arknights_mower/fonts/SourceHanSansSC-Bold.otf differ diff --git a/arknights_mower/resources/agent_on_shift.png b/arknights_mower/resources/agent_on_shift.png new file mode 100644 index 000000000..d9f3d7c62 Binary files /dev/null and b/arknights_mower/resources/agent_on_shift.png differ diff --git a/arknights_mower/solvers/base_construct.py b/arknights_mower/solvers/base_construct.py index a46ba7d4e..4068c60c6 100644 --- a/arknights_mower/solvers/base_construct.py +++ b/arknights_mower/solvers/base_construct.py @@ -2,7 +2,7 @@ import numpy as np -from ..utils import detector, segment +from ..utils import detector, segment, character_recognize from ..utils import typealias as tp from ..utils.device import Device from ..utils.log import logger @@ -363,29 +363,43 @@ def drone(self, room: str): def choose_agent(self, agent: list[str]) -> None: logger.info(f'安排干员:{agent}') - agent = set(agent) - - # 滑动到最左边 h, w = self.recog.h, self.recog.w - for _ in range(9): - self.swipe((w//2, h//2), (w//2, 0), interval=0) - self.swipe((w//2, h//2), (w//2, 0), interval=3, rebuild=False) - checked = set() # 已经识别过的干员 - pre = set() # 上次识别出的干员 - error_count = 0 - while True: + # 在 agent 中 'Free' 表示任意空闲干员 + free_num = agent.count('Free') + agent = set(agent) - set(['Free']) + + # 安排指定干员 + if len(agent): + + # 滑动到最左边 + for _ in range(9): + self.swipe((w//2, h//2), (w//2, 0), interval=0) + self.swipe((w//2, h//2), (w//2, 0), interval=3, rebuild=False) + checked = set() # 已经识别过的干员 + pre = set() # 上次识别出的干员 + error_count, restart = 0, False while len(agent): try: # 识别干员 - ret = segment.agent(self.recog.img) # 返回的顺序是从左往右从上往下 + ret = character_recognize.agent(self.recog.img) # 返回的顺序是从左往右从上往下 except RecognizeError as e: - logger.warning(e) error_count += 1 - if error_count >= 3: + if error_count < 3: + logger.debug(e) + self.sleep(3) + elif not restart: + # 重新滑动到最左边并重置变量 + logger.warning(e) + for _ in range(9): + self.swipe((w//2, h//2), (w//2, 0), interval=0) + self.swipe((w//2, h//2), (w//2, 0), interval=3, rebuild=False) + checked = set() + pre = set() + error_count, restart = 0, True + else: raise e - self.sleep(3) continue # 提取识别出来的干员的名字 @@ -394,7 +408,7 @@ def choose_agent(self, agent: list[str]) -> None: error_count += 1 if error_count >= 3: logger.warning(f'未找到干员:{list(agent)}') - return + break else: pre = agent_name @@ -405,18 +419,61 @@ def choose_agent(self, agent: list[str]) -> None: for name in agent_name & agent: for y in ret: if y[0] == name: - self.tap((y[1][0]), rebuild=False) + self.tap((y[1][0]), interval=0, rebuild=False) break agent.remove(name) # 如果已经完成选择则退出 if len(agent) == 0: - return + break st = ret[-2][1][2] # 起点 ed = ret[0][1][1] # 终点 self.swipe_noinertia(st, (ed[0]-st[0], 0)) + # 安排空闲干员 + if free_num: + + # 滑动到最左边 + for _ in range(9): + self.swipe((w//2, h//2), (w//2, 0), interval=0) + self.swipe((w//2, h//2), (w//2, 0), interval=3, rebuild=False) + + while free_num: + try: + # 识别空闲干员 + ret, st, ed = segment.free_agent(self.recog.img) # 返回的顺序是从左往右从上往下 + except RecognizeError as e: + error_count += 1 + if error_count < 3: + logger.debug(e) + self.sleep(3) + elif not restart: + # 重新滑动到最左边并重置变量 + logger.warning(e) + h, w = self.recog.h, self.recog.w + for _ in range(9): + self.swipe((w//2, h//2), (w//2, 0), interval=0) + self.swipe((w//2, h//2), (w//2, 0), + interval=3, rebuild=False) + checked = set() + pre = set() + error_count, restart = 0, True + else: + raise e + continue + + while free_num and len(ret): + self.tap(ret[0], interval=0, rebuild=False) + free_num -= 1 + ret = ret[1:] + + # 如果已经完成选择则退出 + if free_num == 0: + break + + self.swipe_noinertia(st, (ed[0]-st[0], 0)) + def agent_arrange(self, plan: tp.BasePlan) -> None: """ 基建排班 """ logger.info('基建:排班') diff --git a/arknights_mower/utils/character_recognize.py b/arknights_mower/utils/character_recognize.py new file mode 100644 index 000000000..e3b8bcdc2 --- /dev/null +++ b/arknights_mower/utils/character_recognize.py @@ -0,0 +1,203 @@ +from __future__ import annotations + +import cv2 +import traceback +import numpy as np +from copy import deepcopy +from matplotlib import pyplot as plt +from PIL import Image, ImageDraw, ImageFont + +from . import segment +from ..ocr import ocrhandle +from .image import img2bytes +from .log import logger, save_screenshot +from .recognize import RecognizeError +from .. import __rootdir__ +from ..data.agent import agent_list + + +def poly_center(poly): + return (np.average([x[0] for x in poly]), np.average([x[1] for x in poly])) + + +def in_poly(poly, p): + return poly[0, 0] <= p[0] <= poly[2, 0] and poly[0, 1] <= p[1] <= poly[2, 1] + + +char_map = {} +agent_sorted = sorted(deepcopy(agent_list), key=len) +origin = origin_kp = origin_des = None + +FLANN_INDEX_KDTREE = 0 +GOOD_DISTANCE_LIMIT = 0.7 +SIFT = cv2.SIFT_create() + + +def agent_sift_init(): + global origin, origin_kp, origin_des + if origin is None: + logger.debug('agent_sift_init') + + height = width = 2000 + lnum = 25 + cell = height // lnum + + img = np.zeros((height, width, 3), dtype=np.uint8) + img = Image.fromarray(img) + + font = ImageFont.truetype( + f'{__rootdir__}/fonts/SourceHanSansSC-Bold.otf', size=30, encoding='utf-8') + chars = sorted(list(set(''.join([x for x in agent_list])))) + assert len(chars) <= (lnum - 2) * (lnum - 2) + + for idx, char in enumerate(chars): + x, y = idx % (lnum - 2) + 1, idx // (lnum - 2) + 1 + char_map[(x, y)] = char + ImageDraw.Draw(img).text((x * cell, y * cell), + char, (255, 255, 255), font=font) + + origin = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2GRAY) + origin_kp, origin_des = SIFT.detectAndCompute(origin, None) + + +def sift_recog(query, resolution, draw=False): + """ + 使用 SIFT 提取特征点识别干员名称 + """ + agent_sift_init() + + query = cv2.cvtColor(np.array(query), cv2.COLOR_RGB2GRAY) + + # the height & width of query image + height, width = query.shape + + multi = 2 * (resolution / 1080) + query = cv2.resize(query, (int(width * multi), int(height * multi))) + query_kp, query_des = SIFT.detectAndCompute(query, None) + + # build FlannBasedMatcher + index_params = dict(algorithm=FLANN_INDEX_KDTREE, trees=5) + search_params = dict(checks=50) + flann = cv2.FlannBasedMatcher(index_params, search_params) + matches = flann.knnMatch(query_des, origin_des, k=2) + + # store all the good matches as per Lowe's ratio test + good = [] + for x, y in matches: + if x.distance < GOOD_DISTANCE_LIMIT * y.distance: + good.append(x) + + if draw: + result = cv2.drawMatches( + query, query_kp, origin, origin_kp, good, None) + plt.imshow(result, 'gray') + plt.show() + + count = {} + + for x in good: + x, y = origin_kp[x.trainIdx].pt + c = char_map[(int(x) // 80, int(y) // 80)] + count[c] = count.get(c, 0) + 1 + + best = None + best_score = 0 + for x in agent_sorted: + score = 0 + for c in set(x): + score += count.get(c, -1) + if score > best_score: + best = x + best_score = score + + logger.debug(f'segment.sift_recog: {count}, {best}') + + return best + + +def agent(img, draw=False): + """ + 识别干员总览界面的干员名称 + """ + try: + height, width, _ = img.shape + resolution = height + left, right = 0, width + + # 异形屏适配 + while np.max(img[:, right-1]) < 100: + right -= 1 + while np.max(img[:, left]) < 100: + left += 1 + + # 去除左侧干员详情 + x0 = left + 1 + while not (img[height-1, x0-1, 0] > img[height-1, x0, 0] + 10 and abs(int(img[height-1, x0, 0]) - int(img[height-1, x0+1, 0])) < 5): + x0 += 1 + + # ocr 初步识别干员名称 + ocr = ocrhandle.predict(img[:, x0:right]) + + # 获取分割结果 + ret = segment.agent(img, draw) + + # 确定位置后开始精确识别 + ret_succ = [] + ret_fail = [] + ret_agent = [] + for poly in ret: + found_ocr, fx = None, 0 + for x in ocr: + cx, cy = poly_center(x[2]) + if in_poly(poly, (cx+x0, cy)) and cx > fx: + fx = cx + found_ocr = x + + if found_ocr is not None: + x = found_ocr + if x[1] in agent_list and x[1] not in ['砾', '陈']: # ocr 经常会把这两个搞错 + ret_agent.append(x[1]) + ret_succ.append(poly) + continue + __img = img[poly[0, 1]:poly[2, 1], poly[0, 0]:poly[2, 0]] + res = sift_recog(__img, resolution, draw) + if res is not None: + logger.debug(f'干员名称识别修正:{x[1]} -> {res}') + ret_agent.append(res) + ret_succ.append(poly) + continue + logger.warning( + f'干员名称识别异常:{x[1]} 为不存在的数据,请报告至 https://github.com/Konano/arknights-mower/issues') + save_screenshot( + img2bytes(__img), subdir=f'agent/{height}x{width}') + else: + __img = img[poly[0, 1]:poly[2, 1], poly[0, 0]:poly[2, 0]] + if 80 <= np.min(__img): + continue + res = sift_recog(__img, resolution, draw) + if res is not None: + ret_agent.append(res) + ret_succ.append(poly) + continue + logger.warning(f'干员名称识别异常:区域 {poly.tolist()}') + save_screenshot( + img2bytes(__img), subdir=f'agent/{height}x{width}') + ret_fail.append(poly) + + if len(ret_fail): + save_screenshot( + img2bytes(cv2.cvtColor(img, cv2.COLOR_BGR2RGB)), subdir=f'agentlist/{height}x{width}') + if draw: + __img = img.copy() + cv2.polylines(__img, ret_fail, True, + (255, 0, 0), 3, cv2.LINE_AA) + plt.imshow(__img) + plt.show() + + logger.debug(f'character_recognize.agent: {ret_agent}') + logger.debug(f'character_recognize.agent: {[x.tolist() for x in ret]}') + return list(zip(ret_agent, ret_succ)) + + except Exception as e: + logger.debug(traceback.format_exc()) + raise RecognizeError(e) diff --git a/arknights_mower/utils/config.py b/arknights_mower/utils/config.py index 3bad0776f..40bf9b716 100644 --- a/arknights_mower/utils/config.py +++ b/arknights_mower/utils/config.py @@ -117,10 +117,10 @@ def init_config() -> None: COMPATIBILITY_MODE = __get('device/compatibility_mode', False) global ADB_TOUCH_DEVICE - ADB_TOUCH_DEVICE = __get('adb_touch_device', None) + ADB_TOUCH_DEVICE = __get('device/adb_touch_device', None) global ADB_MNT_PORT - ADB_MNT_PORT = __get('adb_mnt_port', 20937) + ADB_MNT_PORT = __get('device/adb_mnt_port', 20937) global APPNAME APPNAME = __get('app/package_name', 'com.hypergryph.arknights') + \ diff --git a/arknights_mower/utils/detector.py b/arknights_mower/utils/detector.py index 4dd870934..4b2bff796 100644 --- a/arknights_mower/utils/detector.py +++ b/arknights_mower/utils/detector.py @@ -1,7 +1,11 @@ +import cv2 import numpy as np from . import typealias as tp from .log import logger +from .matcher import Matcher +from .image import loadimg +from .. import __rootdir__ def confirm(img: tp.Image) -> tp.Coordinate: @@ -158,3 +162,14 @@ def visit_next(img: tp.Image) -> tp.Coordinate: point = (right - 10, (up + down) // 2) logger.debug(f'detector.visit_next: {point}') return point + + +on_shift = loadimg(f'{__rootdir__}/resources/agent_on_shift.png', True) + + +def is_on_shift(img: tp.Image) -> bool: + """ + 检测干员是否正在工作中 + """ + matcher = Matcher(cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)) + return matcher.match(on_shift, judge=False) is not None diff --git a/arknights_mower/utils/device/client.py b/arknights_mower/utils/device/client.py index d9b36b7f9..d264be100 100644 --- a/arknights_mower/utils/device/client.py +++ b/arknights_mower/utils/device/client.py @@ -68,6 +68,7 @@ def __available_devices(self) -> list[str]: def __exec(self, cmd: str, adb_bin: str = None) -> None: """ exec command with adb_bin """ + logger.debug(f'client.__exec: {cmd}') if adb_bin is None: adb_bin = self.adb_bin subprocess.run([adb_bin, cmd], check=True) @@ -83,6 +84,7 @@ def __run(self, cmd: str, restart: bool = True) -> Optional[bytes]: error_limit -= 1 self.__exec('kill-server') self.__exec('start-server') + time.sleep(10) continue return @@ -98,6 +100,7 @@ def __check_adb(self, adb_bin: str) -> bool: return True self.__exec('kill-server', adb_bin) self.__exec('start-server', adb_bin) + time.sleep(10) if self.check_server_alive(False): return True except (FileNotFoundError, subprocess.CalledProcessError): @@ -124,6 +127,7 @@ def run(self, cmd: str) -> Optional(bytes): error_limit -= 1 self.__exec('kill-server') self.__exec('start-server') + time.sleep(10) self.__init_device() continue raise e diff --git a/arknights_mower/utils/device/minitouch.py b/arknights_mower/utils/device/minitouch.py index c6cb4cb08..d428f08de 100644 --- a/arknights_mower/utils/device/minitouch.py +++ b/arknights_mower/utils/device/minitouch.py @@ -68,6 +68,7 @@ class MiniTouch(object): def __init__(self, client: Client, touch_device: str = config.ADB_TOUCH_DEVICE) -> None: self.client = client self.touch_device = touch_device + self.process = None self.start() def start(self) -> None: diff --git a/arknights_mower/utils/device/session.py b/arknights_mower/utils/device/session.py index 11d3bb255..be06cc05f 100644 --- a/arknights_mower/utils/device/session.py +++ b/arknights_mower/utils/device/session.py @@ -103,4 +103,5 @@ def devices_list(self) -> list[tuple[str, str]]: """ returns list of devices that the adb server knows """ resp = self.request('host:devices').response().decode(errors='ignore') devices = [tuple(line.split('\t')) for line in resp.splitlines()] + logger.debug(devices) return devices diff --git a/arknights_mower/utils/device/socket.py b/arknights_mower/utils/device/socket.py index 18a818f17..5f300effd 100644 --- a/arknights_mower/utils/device/socket.py +++ b/arknights_mower/utils/device/socket.py @@ -11,10 +11,12 @@ class Socket(object): def __init__(self, server: tuple[str, int], timeout: int) -> None: logger.debug(f'server: {server}, timeout: {timeout}') try: + self.sock = None self.sock = socket.create_connection(server, timeout=timeout) self.sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) - except ConnectionRefusedError: - logger.error(f'ConnectionRefusedError: {self.server}') + except ConnectionRefusedError as e: + logger.error(f'ConnectionRefusedError: {server}') + raise e def __enter__(self) -> Socket: return self diff --git a/arknights_mower/utils/image.py b/arknights_mower/utils/image.py index 4b65fb153..393e9f82a 100644 --- a/arknights_mower/utils/image.py +++ b/arknights_mower/utils/image.py @@ -14,6 +14,11 @@ def bytes2img(data: bytes, gray: bool = False) -> Union[tp.Image, tp.GrayImage]: return cv2.cvtColor(cv2.imdecode(np.frombuffer(data, np.uint8), cv2.IMREAD_COLOR), cv2.COLOR_BGR2RGB) +def img2bytes(img) -> bytes: + """ bytes -> image """ + return cv2.imencode('.png', img)[1] + + def loadimg(filename: str, gray: bool = False) -> Union[tp.Image, tp.GrayImage]: """ load image from file """ logger.debug(filename) diff --git a/arknights_mower/utils/segment.py b/arknights_mower/utils/segment.py index e746e36dd..7ecb756ba 100644 --- a/arknights_mower/utils/segment.py +++ b/arknights_mower/utils/segment.py @@ -2,18 +2,15 @@ import cv2 import traceback -import imagehash import numpy as np from matplotlib import pyplot as plt -from PIL import Image, ImageDraw, ImageFont from . import typealias as tp -from .image import rgb2gray, thres0 -from .log import logger +from . import detector from .recognize import RecognizeError from .. import __rootdir__ from ..data.agent import agent_list -from ..data.ocr import ocr_error +from .log import logger from ..ocr import ocrhandle @@ -80,7 +77,7 @@ def ptp(j: int) -> int: while ptp(left) < 50: left += 1 - split_x = [left + (right - left) // 5 * i for i in range(0, 6)] + split_x = [left + (right - left) // 5 * i for i in range(0, 6)] split_y = [up_1, (up_1 + down) // 2, down] ret = [] @@ -338,61 +335,30 @@ def worker(img: tp.Image, draw: bool = False) -> tuple[list[tp.Rectangle], tp.Re raise RecognizeError(e) -agent_ahash = None - - -def agent_ahash_init(): - global agent_ahash - if agent_ahash is None: - logger.debug('agent_ahash_init') - agent_ahash = {} - font = ImageFont.truetype( - f'{__rootdir__}/fonts/SourceHanSansSC-Bold.otf', size=30, encoding='utf-8') - for text in agent_list: - dt = np.zeros((500, 500, 3), dtype=int) - img = Image.fromarray(np.uint8(dt)) - ImageDraw.Draw(img).text((0, 0), text, (255, 255, 255), font=font) - img = np.array(img) - - x0 = 0 - while (img[:, x0] == 0).all(): - x0 += 1 - x1 = img.shape[1] - while (img[:, x1-1] == 0).all(): - x1 -= 1 - y0 = 0 - while (img[y0, x0:x1] == 0).all(): - y0 += 1 - y1 = img.shape[0] - while (img[y1-1, x0:x1] == 0).all(): - y1 -= 1 - - agent_ahash[text] = str(imagehash.average_hash( - Image.fromarray(img[y0:y1, x0:x1]), 16)) - - -def agent(im, draw=False): +def agent(img, draw=False): """ 干员总览的图像分割算法 """ try: - h, w, _ = im.shape - - # gray = cv2.cvtColor(im, cv2.COLOR_BGR2GRAY) - # gray = 255 - gray + height, width, _ = img.shape + resolution = height + left, right = 0, width - l, r = 0, w - while np.max(im[:, r-1]) < 100: - r -= 1 - while np.max(im[:, l]) < 100: - l += 1 + # 异形屏适配 + while np.max(img[:, right-1]) < 100: + right -= 1 + while np.max(img[:, left]) < 100: + left += 1 - x0 = l + 1 - while not (im[h-1, x0-1, 0] > im[h-1, x0, 0] + 10 and abs(int(im[h-1, x0, 0]) - int(im[h-1, x0+1, 0])) < 5): + # 去除左侧干员详情 + x0 = left + 1 + while not (img[height-1, x0-1, 0] > img[height-1, x0, 0] + 10 and abs(int(img[height-1, x0, 0]) - int(img[height-1, x0+1, 0])) < 5): x0 += 1 - ocr = ocrhandle.predict(im[:, x0:r]) + # ocr 初步识别干员名称 + ocr = ocrhandle.predict(img[:, x0:right]) + # 保留上下两行皆被成功识别出来的干员名称的识别结果 segs = [(min(x[2][0][1], x[2][1][1]), max(x[2][2][1], x[2][3][1])) for x in ocr if x[1] in agent_list] while True: @@ -411,13 +377,16 @@ def agent(im, draw=False): else: break segs = sorted(segs) + + # 计算纵向的四个高度,[y0, y1] 是第一行干员名称的纵向坐标范围,[y2, y3] 是第二行干员名称的纵向坐标范围 for x in segs: - if x[1] < h // 2: + if x[1] < height // 2: y0, y1 = x y2, y3 = x card_gap = y1 - y0 logger.debug([y0, y1, y2, y3]) + # 预计算:横向坐标范围集合 x_set = set() for x in ocr: if x[1] in agent_list and (y0 <= x[2][0][1] <= y1 or y2 <= x[2][0][1] <= y3): @@ -425,200 +394,102 @@ def agent(im, draw=False): x_set.add(x[2][2][0]) x_set = sorted(x_set) logger.debug(x_set) + + # 排除掉一些重叠的范围,获得最终的横向坐标范围 + x_gap = 40 * (resolution / 1080) x_set = [x_set[0]] + \ - [y for x, y in zip(x_set[:-1], x_set[1:]) if y - x > 80] + [y for x, y in zip(x_set[:-1], x_set[1:]) if y - x > x_gap * 2] gap = [y - x for x, y in zip(x_set[:-1], x_set[1:])] - gap = [x for x in gap if x - np.min(gap) < 40] + gap = [x for x in gap if x - np.min(gap) < x_gap] gap = int(np.average(gap)) for x, y in zip(x_set[:-1], x_set[1:]): - if y - x > 40: + if y - x > x_gap: gap_num = round((y - x) / gap) for i in range(1, gap_num): x_set.append(int(x + (y - x) / gap_num * i)) while np.min(x_set) > 0: x_set.append(np.min(x_set) - gap) - while np.max(x_set) < r - x0: + while np.max(x_set) < right - x0: x_set.append(np.max(x_set) + gap) x_set = sorted(x_set) logger.debug(x_set) + # 获得所有的干员名称对应位置 ret = [] for x1, x2 in zip(x_set[:-1], x_set[1:]): - if 0 <= x1+card_gap and x0+x2+5 <= r: + if 0 <= x1+card_gap and x0+x2+5 <= right: ret += [get_poly(x0+x1+card_gap, x0+x2+5, y0, y1), get_poly(x0+x1+card_gap, x0+x2+5, y2, y3)] - def poly_center(poly): - return (np.average([x[0] for x in poly]), np.average([x[1] for x in poly])) - - def in_poly(poly, p): - return poly[0, 0] <= p[0] <= poly[2, 0] and poly[0, 1] <= p[1] <= poly[2, 1] - - # if draw: - # cv2.polylines(im, ret, True, (255, 0, 0), 3, cv2.LINE_AA) - # plt.imshow(im) - # plt.show() - - def flood(img, dt): - h, w = img.shape - while True: - pre_count = (dt > 0).sum() - for x in range(1, w): - dt[:, x][(dt[:, x-1] > 0) & (img[:, x] > 0)] = 1 - for y in range(h-2, -1, -1): - dt[y][(dt[y+1] > 0) & (img[y] > 0)] = 1 - for x in range(w-2, -1, -1): - dt[:, x][(dt[:, x+1] > 0) & (img[:, x] > 0)] = 1 - for y in range(1, h): - dt[y][(dt[y-1] > 0) & (img[y] > 0)] = 1 - if pre_count == (dt > 0).sum(): - break + # draw for debug + if draw: + __img = img.copy() + cv2.polylines(__img, ret, True, (255, 0, 0), 3, cv2.LINE_AA) + plt.imshow(__img) + plt.show() - def ahash_recog(origin_img, scope): - agent_ahash_init() - origin_img = origin_img[scope[0, 1]:scope[2, 1], scope[0, 0]:scope[2, 0]] - h, w = origin_img.shape[:2] - thresh = 70 - while True: - try: - img = rgb2gray(thres0(origin_img, thresh)) - dt = np.zeros((h, w), dtype=np.uint8) - for y in range(h): - if img[y, w-1] != 0: - dt[y, w-1] = 1 - flood(img, dt) - for y in range(h-1, h//2, -1): - count = 0 - for x in range(w-1, -1, -1): - if dt[y, x] != 0: - count += 1 - else: - break - if not (dt[y, :] > 0).all(): - for x in range(w): - if dt[y, x] != 0: - count += 1 - else: - break - if (dt[y] > 0).sum() != count: - logger.debug(f'{y}, {count}') - raise FloodCheckFailed - - for y in range(h): - if img[y, 0] != 0: - dt[y, 0] = 1 - for x in range(w): - if img[h-1, x] != 0: - dt[h-1, x] = 1 - if img[0, x] != 0: - dt[0, x] = 1 - flood(img, dt) - img[dt > 0] = 0 - if (img > 0).sum() == 0: - raise FloodCheckFailed - - x0, x1, y0, y1 = 0, w, 0, h - while True: - while (img[y0:y1, x0] == 0).all(): - x0 += 1 - while (img[y0:y1, x1-1] == 0).all(): - x1 -= 1 - while (img[y0, x0:x1] == 0).all(): - y0 += 1 - while (img[y1-1, x0:x1] == 0).all(): - y1 -= 1 - for x in range(x0, x1-10+1): - if (img[y0:y1, x:x+10] == 0).all(): - x0 = x - break - if (img[y0:y1, x0] == 0).all(): - continue - for y in range(y0, y1-10+1): - if (img[y:y+10, x0:x1] == 0).all(): - y0 = y - break - if (img[y0, x0:x1] == 0).all(): - continue - break + logger.debug(f'segment.agent: {[x.tolist() for x in ret]}') + return ret + + except Exception as e: + logger.debug(traceback.format_exc()) + raise RecognizeError(e) - dt = np.zeros((y1-y0, x1-x0, 3), dtype=np.uint8) - dt[:, :, 0] = img[y0:y1, x0:x1] - dt[:, :, 1] = img[y0:y1, x0:x1] - dt[:, :, 2] = img[y0:y1, x0:x1] - ahash = str(imagehash.average_hash(Image.fromarray(dt), 16)) - p = [(bin(int(ahash, 16) ^ int(agent_ahash[x], 16)).count('1'), x) for x in agent_ahash.keys()] - p = sorted(p) - logger.debug(p[:10]) - if p[1][0] - p[0][0] < 10: - raise FloodCheckFailed - logger.debug(p[0][1]) - return p[0][1] - - except FloodCheckFailed: - thresh += 5 - logger.debug(f'add thresh to {thresh}') - if thresh > 100: - break - continue - return None - ret_succ = [] - ret_fail = [] - ret_agent = [] +def free_agent(img, draw=False): + """ + 识别未在工作中的干员 + """ + try: + height, width, _ = img.shape + resolution = height + left, right = 0, width + + # 异形屏适配 + while np.max(img[:, right-1]) < 100: + right -= 1 + while np.max(img[:, left]) < 100: + left += 1 + + # 去除左侧干员详情 + x0 = left + 1 + while not (img[height-1, x0-1, 0] > img[height-1, x0, 0] + 10 and abs(int(img[height-1, x0, 0]) - int(img[height-1, x0+1, 0])) < 5): + x0 += 1 + + # 获取分割结果 + ret = agent(img, draw) + st = ret[-2][2] # 起点 + ed = ret[0][1] # 终点 + + # 去除空白的干员框,同时收集 y 坐标 + y_set = set() for poly in ret: - found_ocr, fx = None, 0 - for x in ocr: - cx, cy = poly_center(x[2]) - if in_poly(poly, (cx+x0, cy)) and cx > fx: - fx = cx - found_ocr = x - - if found_ocr is not None: - x = found_ocr - if x[1] in agent_list: - ret_agent.append(x[1]) - ret_succ.append(poly) - continue - res = ocrhandle.predict( - thres0(im[poly[0, 1]-20:poly[2, 1]+20, poly[0, 0]-20:poly[2, 0]+20], 70)) - if len(res) > 0 and res[0][1] in agent_list: - x = res[0] - ret_agent.append(x[1]) - ret_succ.append(poly) - continue - res = ahash_recog(im, poly) - if res is not None: - logger.warning(f'干员名称识别异常:{x[1]} 应为 {res}') - ocr_error[x[1]] = res - ret_agent.append(res) - ret_succ.append(poly) - continue - logger.warning( - f'干员名称识别异常:{x[1]} 为不存在的数据,请报告至 https://github.com/Konano/arknights-mower/issues') - else: - res = ocrhandle.predict( - thres0(im[poly[0, 1]-20:poly[2, 1]+20, poly[0, 0]-20:poly[2, 0]+20], 70)) - if len(res) > 0 and res[0][1] in agent_list: - res = res[0][1] - ret_agent.append(res) - ret_succ.append(poly) - continue - res = ahash_recog(im, poly) - if res is not None: - ret_agent.append(res) - ret_succ.append(poly) - continue - logger.warning(f'干员名称识别异常:区域 {poly}') - ret_fail.append(poly) - - if draw and len(ret_fail): - cv2.polylines(im, ret_fail, True, (255, 0, 0), 3, cv2.LINE_AA) - plt.imshow(im) + __img = img[poly[0, 1]:poly[2, 1], poly[0, 0]:poly[2, 0]] + y_set.add(poly[0, 1]) + y_set.add(poly[2, 1]) + if 80 <= np.min(__img): + ret.remove(poly) + + y1, y2, y4, y5 = sorted(list(y_set)) + y0 = height - y5 + y3 = y0 - y2 + y5 + + ret_free = [] + for poly in ret: + poly[poly == y1] = y0 + poly[poly == y4] = y3 + __img = img[poly[0, 1]:poly[2, 1], poly[0, 0]:poly[2, 0]] + if not detector.is_on_shift(__img): + ret_free.append(poly) + + if draw: + __img = img.copy() + cv2.polylines(__img, ret_free, True, (255, 0, 0), 3, cv2.LINE_AA) + plt.imshow(__img) plt.show() - logger.debug(f'segment.agent: {ret_agent}') - logger.debug(f'segment.agent: {[x.tolist() for x in ret]}') - return list(zip(ret_agent, ret_succ)) + logger.debug(f'segment.free_agent: {[x.tolist() for x in ret_free]}') + return ret_free, st, ed except Exception as e: logger.debug(traceback.format_exc())