兩幅圖像線性拼接

1. 收藏鏈接:https://blog.csdn.net/wd1603926823/article/details/49582461

一、2幅圖像拼接

流程1>>>提取局部特徵並計算特徵點對應:本文使用rootSIFT描述子為每個特徵點計算最近鄰與次近鄰,然後篩選優質匹配(最近鄰距離 < 0.2 × 次近鄰距離,這裏的0.2是自定義閾值)

流程2>>>用RANSAC估計單應矩陣,並檢驗其是否符合要求:參考論文[1] 中的條件“內點對總數 > 5.9 + 0.22 × 總點對數”

流程3>>>按重疊區範圍線性加權融合圖像

import os

import cv2, math
import numpy as np
from pyflann import *
import matplotlib.pyplot as plt


class Stitching2Img:
    '''
    Stitch two overlapping images.

    Pipeline: rootSIFT features -> 2-NN ratio-test matching -> RANSAC
    homography accepted with the Brown & Lowe inlier check -> warp the right
    image into the left image's frame -> linear blending across the overlap.

    Homography convention: ``adjacent[i][j]`` holds ``(M, matched_pts)`` with
    X_j = M * X_i; the cell stays empty when no reliable homography was found.
    '''

    def __init__(self):
        pass

    @staticmethod
    def _rootsift_normalize(kp, des):
        '''
        Apply the rootSIFT transform sqrt(des / sum(des)) to every descriptor row.

        Rows whose sum is zero cannot be normalised; they are dropped together
        with their keypoints.

        :param kp: keypoints, aligned with the rows of ``des``
        :param des: 2-D descriptor array (one row per keypoint), or None
        :return: (kp_filted, des_filted); ``kp`` and ``des`` are returned
            unchanged when ``des`` is None or empty
        '''
        if des is None or len(des) == 0:
            return kp, des
        des_sum = np.sum(des, axis=1)
        keep = np.flatnonzero(des_sum)
        kp_filted = [kp[i] for i in keep]
        des_filted = np.sqrt(des[keep] / des_sum[keep][:, np.newaxis])
        return kp_filted, des_filted

    def __rootSIFT_extract(self, img_path, crop=None):
        '''
        Detect SIFT keypoints on a grayscale image and return rootSIFT descriptors.

        rootsift = sqrt( sift / sum(sift) )
        :param img_path: the abs_path of image
        :param crop: optional [row_start, row_end, col_start, col_end]; keypoint
            coordinates are then relative to the cropped region
        :return: (keypoints, np.sqrt(sift_vector / sum(sift_vector))); apply an
            l2-normalization afterwards if needed
        :raises Exception: if the image cannot be read
        '''
        # SIFT moved from the contrib module (xfeatures2d) into the main
        # module in OpenCV 4.4; support both layouts.
        sift_create = getattr(cv2, 'SIFT_create', None) or cv2.xfeatures2d.SIFT_create
        sift = sift_create()
        img = cv2.imread(img_path, cv2.IMREAD_GRAYSCALE)
        if img is None:
            print(img_path, "is not readable")
            raise Exception(img_path + " is not readable")
        if crop is not None:
            img = img[crop[0]:crop[1], crop[2]:crop[3]]
        kp, des = sift.detectAndCompute(img, mask=None)
        return self._rootsift_normalize(kp, des)

    def __loadImages(self, imgs_path):
        '''
        Extract rootSIFT features for every image path and build the pairwise
        homography table (self.__adjacent).
        '''
        self.__imgspath = list(imgs_path)
        self.__imgs = [self.__rootSIFT_extract(img_path) for img_path in self.__imgspath]
        self.__get_adjacent()

    def __get_adjacent(self):
        '''
        Fill self.__adjacent[i][j] with (M, matched_pts) for every image pair.

        The reverse direction stores inv(M) with the point lists swapped.
        Pairs without a reliable homography keep empty cells.
        '''
        num_imgs = len(self.__imgs)
        adjacent = [[[] for _ in range(num_imgs)] for _ in range(num_imgs)]
        for img_i in range(num_imgs):
            for img_j in range(img_i + 1, num_imgs):
                M, matched_pts = self.__get_M_ptpair(img_i, img_j)
                if M is None:
                    # Rejected pair: inverting None would raise, so leave both cells empty.
                    continue
                adjacent[img_i][img_j].append((M, matched_pts))
                adjacent[img_j][img_i].append((np.linalg.inv(M), [matched_pts[1], matched_pts[0]]))
        self.__adjacent = adjacent

    @staticmethod
    def _ratio_matches(kp_i, kp_j, result, dists, threshold):
        '''
        Keep nearest-neighbour matches that pass the ratio test.

        :param kp_i: keypoints of image i (the search dataset)
        :param kp_j: keypoints of image j (the queries), one per row of result/dists
        :param result: (len(kp_j), 2) indices into kp_i of the two nearest neighbours
        :param dists: (len(kp_j), 2) distances matching ``result``
        :param threshold: keep a match when dists[:, 0] < threshold * dists[:, 1]
        :return: [pts_i, pts_j], two aligned lists of (x, y) keypoint positions
        '''
        # Multiplying instead of dividing: a zero second distance simply fails
        # the test instead of producing inf/nan and a RuntimeWarning.
        keep = dists[:, 0] < threshold * dists[:, 1]
        matched_pts = [[], []]
        for q in np.flatnonzero(keep):
            matched_pts[0].append(kp_i[result[q][0]].pt)
            matched_pts[1].append(kp_j[q].pt)
        return matched_pts

    def __get_M_ptpair(self, img_i, img_j, threshold=0.2):
        '''
        Estimate the homography X_j = M * X_i between two loaded images.

        :param img_i: the index of images
        :param img_j: the index of images
        :param threshold: the ratio-test threshold (the RANSAC reprojection
            threshold is fixed at 1.0 px)
        :return: (M, matched_pts), or (None, None) when there are too few
            matches or the homography fails the inlier check
        '''
        kp_i, descriptor_img0 = self.__imgs[img_i]
        kp_j, descriptor_img1 = self.__imgs[img_j]
        # Two neighbours are needed for the ratio test.
        if descriptor_img0 is None or descriptor_img1 is None \
                or len(descriptor_img0) < 2 or len(descriptor_img1) == 0:
            return None, None
        myflann = FLANN()
        # Each descriptor of img_j is a query; its 2 nearest neighbours are
        # searched among the descriptors of img_i with a linear (exhaustive) scan.
        result, dists = myflann.nn(descriptor_img0, descriptor_img1, 2, algorithm='linear')
        matched_pts = self._ratio_matches(kp_i, kp_j, result, dists, threshold)
        if len(matched_pts[0]) <= 8:
            return None, None
        M, mask = cv2.findHomography(np.float32(matched_pts[0]), np.float32(matched_pts[1]), cv2.RANSAC, 1.0)
        if M is None or mask is None:
            return None, None
        num_inliers = int(mask.sum())
        # Brown & Lowe acceptance test n_inliers > 5.9 + 0.22 * n_f. The paper's
        # n_f (features in the overlap) is approximated by the number of
        # ratio-test matches.
        if num_inliers <= 5.9 + 0.22 * len(mask):
            return None, None
        return M, matched_pts

    @staticmethod
    def _blend_row(row3, row0):
        '''
        Linearly blend one row of the left image into the same canvas row, in place.

        A pixel has content when any of its channels is non-zero. The overlap
        runs from the first column where both rows have content (cur_l) to the
        last such column (cur_r) seen before the canvas continues on its own.
        Left of cur_l the left image is copied. Inside [cur_l, cur_r) the weight
        moves linearly from the left image to the canvas. Elsewhere, empty
        canvas pixels are filled from the left image.

        :param row3: canvas row of shape (W3, C), modified in place; W3 >= W0
        :param row0: left-image row of shape (W0, C)
        '''
        width = len(row0)
        # any() instead of sum(pixel)/3 > 0.001: summing uint8 channels can wrap
        # around to 0 (e.g. 128 + 128) and misclassify a bright pixel as empty.
        has3 = row3[:width].any(axis=-1)
        has0 = row0.any(axis=-1)

        cur_l = cur_r = -1
        for w_i in range(width):
            if has3[w_i] and has0[w_i]:
                if cur_l < 0:
                    cur_l = w_i
                else:
                    cur_r = w_i
            elif has3[w_i] and cur_r >= 0:
                break  # the canvas continues alone past the overlap

        if cur_l > 0:
            row3[:cur_l] = row0[:cur_l]
        if cur_r > cur_l >= 0:
            span = cur_r - cur_l
            cols = np.arange(cur_l, cur_r)
            f_l = ((cols - cur_l) / span)[:, np.newaxis]
            f_r = ((cur_r - cols) / span)[:, np.newaxis]
            row3[cur_l:cur_r] = row3[cur_l:cur_r] * f_l + row0[cur_l:cur_r] * f_r
        # Columns not handled above: fill empty canvas pixels from the left image,
        # so rows without any overlap still show the left image.
        start = max(cur_l, cur_r, 0)
        fill = has0 & ~has3
        fill[:start] = False
        row3[:width][fill] = row0[fill]

    def stitching2images(self, left_img, right_img, show=True, canvas_scale=2.5):
        '''
        Warp right_img into left_img's frame and blend the overlap linearly.

        :param left_img: the absolute path of left_img
        :param right_img: the absolute path of right_img
        :param show: when True, plot the inputs and the result with matplotlib
        :param canvas_scale: canvas width as a multiple of left_img's width (>= 1)
        :return: the stitched BGR image
        :raises RuntimeError: if no reliable homography links the two images
        '''
        if canvas_scale < 1:
            raise ValueError("canvas_scale must be >= 1")
        self.__loadImages([left_img, right_img])
        pairs = self.__adjacent[1][0]
        if not pairs:
            raise RuntimeError("no reliable homography between {} and {}".format(left_img, right_img))
        M = pairs[0][0]  # maps image 1 (right) coordinates into image 0 (left)
        im0 = cv2.imread(self.__imgspath[0], 1)
        im1 = cv2.imread(self.__imgspath[1], 1)
        # The canvas lives in im0's frame, so its height must be im0's; rows are
        # indexed into im0 below.
        h, w = im0.shape[:2]
        im3 = cv2.warpPerspective(im1, M, (math.floor(canvas_scale * w), h))
        for h_i in range(h):
            self._blend_row(im3[h_i], im0[h_i])
        if show:
            fig = plt.figure()
            for position, image in ((221, im0), (222, im1), (223, im3)):
                fig.add_subplot(position)
                # OpenCV images are BGR; matplotlib expects RGB.
                plt.imshow(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
            plt.show()
        return im3

    def __call__(self, *args, **kwargs):
        '''Shortcut for stitching2images(left_img=..., right_img=..., ...).'''
        return self.stitching2images(*args, **kwargs)


if __name__ == '__main__':
    stitch = Stitching2Img()
    # Input images are looked up in the current working directory.
    # stitch(left_img=os.path.join(os.getcwd(), "stitch1.jpg"), right_img=os.path.join(os.getcwd(), "stitch2.jpg"))
    stitch(left_img=os.path.join(os.getcwd(), "2_2.jpg"), right_img=os.path.join(os.getcwd(), "1_1.jpg"))

線性加權融合效果展示:


將拼接函數修改為固定權重(兩幅圖像各佔0.5)融合:

    def stitching2images(self, left_img, right_img, show=True):
        '''
        Fixed-weight variant meant to replace Stitching2Img.stitching2images.

        Warp right_img into left_img's frame. Where both images have content,
        average them with weights 0.5/0.5. Where only the left image has
        content, copy it.

        :param left_img: the absolute path of left_img
        :param right_img: the absolute path of right_img
        :param show: when True, plot the inputs and the result with matplotlib
        :return: the stitched BGR image
        :raises RuntimeError: if no reliable homography links the two images
        '''
        # The class only defines the private __loadImages; self.loadImages does not exist.
        self.__loadImages([left_img, right_img])
        pairs = self.__adjacent[1][0]
        if not pairs:
            raise RuntimeError("no reliable homography between {} and {}".format(left_img, right_img))
        M = pairs[0][0]  # maps image 1 (right) coordinates into image 0 (left)
        im0 = cv2.imread(self.__imgspath[0], 1)
        im1 = cv2.imread(self.__imgspath[1], 1)
        # The canvas lives in im0's frame, so its height must be im0's.
        h, w = im0.shape[:2]
        im3 = cv2.warpPerspective(im1, M, (math.floor(2.5 * w), h))  # warpAffine
        canvas = im3[:, :w]  # a view: writes go straight into im3
        # any() instead of sum(pixel)/3 > 0.001, which can wrap around in uint8.
        has3 = canvas.any(axis=2)
        has0 = im0.any(axis=2)
        both = has3 & has0
        only0 = has0 & ~has3
        canvas[both] = canvas[both] * 0.5 + im0[both] * 0.5
        canvas[only0] = im0[only0]
        if show:
            fig = plt.figure()
            for position, image in ((221, im0), (222, im1), (223, im3)):
                fig.add_subplot(position)
                # OpenCV images are BGR; matplotlib expects RGB.
                plt.imshow(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
            plt.show()
        return im3

此時的拼接效果:


參考文獻

[1] Brown M, Lowe D G. Recognising Panoramas[C]. IEEE International Conference on Computer Vision. IEEE Computer Society, 2003: 1218-1225.

-----------未完成---------

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章