Airtest源碼分析--圖像識別整體流程

上期回顧:Airtest-API精講之Template


以下基於
python3.8;airtest1.2.2;pocoui1.0.83

之前講了圖像識別的基礎——Template類:======Template類
這次我們看下Airtest圖像識別的整體流程。

我們以touch()接口爲例,AirtestIDE中touch怎麼用可以看:AirtestIDE基本功能(一)

進入查看touch源碼

# 源碼路徑 your_python_path/site-packages/airtest/core/api.py
def touch(v, times=1, **kwargs):
    """
    Perform the touch action on the device screen

    :param v: target to touch, either a ``Template`` instance or absolute coordinates (x, y)
    :param times: how many touches to be performed
    :param kwargs: platform specific `kwargs`, please refer to corresponding docs
    :return: finial position to be clicked, e.g. (100, 100)
    """
    if isinstance(v, Template):
        pos = loop_find(v, timeout=ST.FIND_TIMEOUT)
    else:
        try_log_screen()
        pos = v
    for _ in range(times):
        G.DEVICE.touch(pos, **kwargs)
        time.sleep(0.05)
    delay_after_operation()
    return pos

touch是兼容傳入圖片或座標的,我們只看圖片的邏輯。

pos = loop_find(v, timeout=ST.FIND_TIMEOUT)

可以看到是通過loop_find去循環找圖,超時時間ST.FIND_TIMEOUT默認是20S,這裏找到圖片的話會返回座標,後面的代碼會去點擊這個座標,就完成了touch操作。

繼續進入loop_find源碼:

# 源碼路徑 your_python_path/site-packages/airtest/core/cv.py
def loop_find(query, timeout=ST.FIND_TIMEOUT, threshold=None, interval=0.5, intervalfunc=None):
    G.LOGGING.info("Try finding: %s", query)
    start_time = time.time()
    while True:
        screen = G.DEVICE.snapshot(filename=None, quality=ST.SNAPSHOT_QUALITY)

        if screen is None:
            G.LOGGING.warning("Screen is None, may be locked")
        else:
            if threshold:
                query.threshold = threshold
            match_pos = query.match_in(screen)
            if match_pos:
                try_log_screen(screen)
                return match_pos

        if intervalfunc is not None:
            intervalfunc()

        # 超時則raise,未超時則進行下次循環:
        if (time.time() - start_time) > timeout:
            try_log_screen(screen)
            raise TargetNotFoundError('Picture %s not found in screen' % query)
        else:
            time.sleep(interval)

loop_find整體邏輯就是循環去屏幕截圖上找圖,找到返回其座標,超時未找到報錯。第1個參數query就是我們前面傳入的Template類實例(我們截的圖)

其中關鍵是match_pos = query.match_in(screen),前一步給手機截圖賦值給screen,然後在截圖中查找給定圖片,用的方法是Template類中的match_in方法。

繼續看match_in源碼:

# 源碼路徑 your_python_path/site-packages/airtest/core/cv.py
def match_in(self, screen):
    match_result = self._cv_match(screen)
    G.LOGGING.debug("match result: %s", match_result)
    if not match_result:
        return None
    focus_pos = TargetPos().getXY(match_result, self.target_pos)
    return focus_pos

其中核心代碼是match_result = self._cv_match(screen)圖像匹配

如果找到後面代碼會返回9宮點中我們要求的座標:

focus_pos = TargetPos().getXY(match_result, self.target_pos)

還得記得9宮點嗎?就是Template實例化時我們指定的target_pos,忘了可以看這篇=========中的第X點

 

繼續看_cv_match源碼:

# 源碼路徑 your_python_path/site-packages/airtest/core/cv.py
    def _cv_match(self, screen):
        # in case image file not exist in current directory:
        ori_image = self._imread()
        image = self._resize_image(ori_image, screen, ST.RESIZE_METHOD)
        ret = None
        for method in ST.CVSTRATEGY:
            # get function definition and execute:
            func = MATCHING_METHODS.get(method, None)
            if func is None:
                raise InvalidMatchingMethodError("Undefined method in CVSTRATEGY: '%s', try 'kaze'/'brisk'/'akaze'/'orb'/'surf'/'sift'/'brief' instead." % method)
            else:
                if method in ["mstpl", "gmstpl"]:
                    ret = self._try_match(func, ori_image, screen, threshold=self.threshold, rgb=self.rgb, record_pos=self.record_pos,resolution=self.resolution, scale_max=self.scale_max, scale_step=self.scale_step)
                else:
                    ret = self._try_match(func, image, screen, threshold=self.threshold, rgb=self.rgb)
            if ret:
                break
        return ret

其中ori_image = self._imread()讀取圖像

image = self._resize_image(ori_image, screen, ST.RESIZE_METHOD)

根據分辨率,將輸入的截圖適配成 等待模板匹配的截圖

之後會循環各種算法去匹配圖片,默認算法爲ST.CVSTRATEGY = ["mstpl", "tpl", "surf", "brisk"]

循環中用到的匹配方法爲_try_match

繼續看_try_match源碼:

# 源碼路徑 your_python_path/site-packages/airtest/core/cv.py
    def _try_match(func, *args, **kwargs):
        G.LOGGING.debug("try match with %s" % func.__name__)
        try:
            ret = func(*args, **kwargs).find_best_result()
        except aircv.NoModuleError as err:
            G.LOGGING.warning("'surf'/'sift'/'brief' is in opencv-contrib module. You can use 'tpl'/'kaze'/'brisk'/'akaze'/'orb' in CVSTRATEGY, or reinstall opencv with the contrib module.")
            return None
        except aircv.BaseError as err:
            G.LOGGING.debug(repr(err))
            return None
        else:
            return ret

其核心代碼爲ret = func(*args, **kwargs).find_best_result()

不同的算法對應不同的find_best_result()方法,目前一共有4種,我們以TemplateMatching類中的爲例看一下

# 源碼路徑 your_python_path/site-packages/airtest/aircv/template_matching.py
def find_best_result(self):
    """基於kaze進行圖像識別,只篩選出最優區域."""
    """函數功能:找到最優結果."""
    # 第一步:校驗圖像輸入
    check_source_larger_than_search(self.im_source, self.im_search)
    # 第二步:計算模板匹配的結果矩陣res
    res = self._get_template_result_matrix()
    # 第三步:依次獲取匹配結果
    min_val, max_val, min_loc, max_loc = cv2.minMaxLoc(res)
    h, w = self.im_search.shape[:2]
    # 求取可信度:
    confidence = self._get_confidence_from_matrix(max_loc, max_val, w, h)
    # 求取識別位置: 目標中心 + 目標區域:
    middle_point, rectangle = self._get_target_rectangle(max_loc, w, h)
    best_match = generate_result(middle_point, rectangle, confidence)
    LOGGING.debug("[%s] threshold=%s, result=%s" % (self.METHOD_NAME, self.threshold, best_match))

    return best_match if confidence >= self.threshold else None

到這裏就是基於cv2庫去找圖了,步驟註釋寫的很清楚了。對opencv感興趣的同學,可以到這裏學一學http://www.woshicver.com/

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章