圖像中的天空區域檢測!

一、引言

       天空區域作爲圖像中的背景信息,爲機器人導航、自動駕駛等領域的圖像理解提供了重要依據,因此如何檢測圖像中的天空區域非常重要,本文提供了一個基於傳統視覺算法(非機器學習方法)的提取天空區域的方法,參考文獻:https://journals.sagepub.com/doi/pdf/10.5772/56884

二、算法思路

1.使用sobel算子爲圖像提取梯度信息圖,從梯度信息圖中根據給定閾值提取天空的邊界線

於是,天空區域和地面區域可以表示如下:

2.使用一個能量函數作爲適應度函數,來優化提取天空邊界線,當能量函數最大時,天空邊界線最準確

3.通過檢測邊界線是否發生大的跳變來判斷提取的天空區域包含大量建築,若是,對天空區域進行聚類分析,分爲兩類,通過與非天空區域的馬氏距離來得到真實的天空區域。

4.掃描每一列的像素,通過每一列像素與真實天空區域的馬氏距離來判斷該像素是否屬於天空的像素,從而可以改善天空邊界線

算法是這樣假設的:

(1)圖像中的天空區域在圖像的上方

(2)天空區域的像素變化比較平滑

然而實際的城市道路中由於電線杆、紅路燈杆等的存在,以及太陽、雲彩的影響,會使得天空邊界線的提取被阻擋。而且上述假設在很多情況下不成立。此外能量函數只是表明了同一區域的同質性,而實際圖像內容的情況比這要複雜的多。

我的改進:

(1)對梯度圖像進行閾值分割,可以使得天空的提取魯棒性更強

(2)使用多項式擬合修正錯誤的邊界線(只是一個實驗性的想法)

(3)計算真實天空區域的像素的R、G、B均值,通過對每一列的像素的RGB各通道像素值與天空區域的RGB均值比較,可以得到更完整的天空邊界線

三、代碼實現

      下面是我的python代碼實現,寫的不夠好,大家適當參考吧,裏面部分代碼僅僅是爲了實現論文中的算法而已。有些經過我改進後的函數,比如梯度信息提取函數,已經不需要再使用extract_border_optimal()函數了。

       大家有什麼問題、建議和更好的改進方法請留言評論,我將感激不盡。代碼中增加了計算天空消失點的程序,當然通過天空區域判斷消失點肯定沒有通過車道線計算消失點的準確性更強,但是車道線的檢測難度較大,這個問題留待以後再考慮吧!

import cv2
import os
import math
import time
import sys
import numpy as np
import matplotlib.pyplot as plt
from numba import jit
from scipy import spatial
from scipy.optimize import curve_fit


#加載原始圖像
def load_image(image_file_path):

    if not os.path.exists(image_file_path):
        print("圖像文件不存在!")
        #sys.exit()
    else:
        img = cv2.imread(image_file_path)
        if img is None:
            print('讀取圖像失敗!')
            #sys.exit()
        else:
            return img

#提取圖像天空區域
def extract_sky(src_image):

    height, width = src_image.shape[0:2]

    sky_border_optimal = extract_border_optimal(src_image)
    border_correct = correct_border_polynomial(sky_border_optimal,src_image)
    sky_exists = has_sky_region(sky_border_optimal, height / 30, height / 10, 5)

    if sky_exists == 0:
        print('沒有檢測到天空區域')
        #sys.exit()

    """
    if has_partial_sky_region(border_correct, width / 3):
        border_new = refine_border(border_correct, src_image)
        sky_mask = make_sky_mask(src_image, border_new,1)
        return sky_mask, sky_exists
        #sky_image = display_sky_region(src_image, sky_border_optimal)
    """


    sky_mask = make_sky_mask(src_image, border_correct, 1)

    return sky_mask, sky_exists

#檢測圖像天空區域
def detect(image_file_path, output_path):

    #加載圖像
    src_image = load_image(image_file_path)
    src_image = cv2.pyrDown(src_image)
    #x, y = src_image.shape[0:2]
    #src_image = cv2.resize(src_image, (int(2*y/3),int(2*x/3)), cv2.INTER_CUBIC)

    #提取圖像天空區域
    sky_mask,sky_exists = extract_sky(src_image)

    #製作掩碼輸出
    tic = time.time()
    height = src_image.shape[0]
    width = src_image.shape[1]

    """
    sky_image_full = np.zeros(src_image.shape, dtype= np.uint8)
    for row in range(height):
        for col in range(width):
            if sky_mask[row, col] != 0:
                sky_image_full[row, col, 0] = 0
                sky_image_full[row, col, 1] = 0
                sky_image_full[row, col, 2] = 255

    sky_image = cv2.addWeighted(src_image, 1, sky_image_full, 1, 0)
    """

    for row in range(height):
        for col in range(width):
            if sky_mask[row, col] != 0:
                src_image[row, col, 0] = 0
                src_image[row, col, 1] = 0
                src_image[row, col, 2] = 255

    cv2.imwrite(output_path, src_image)
    toc = time.time()
    print('display mask time: ',(toc - tic), 's')
    print('圖像檢測完畢!')

#檢測圖像天空區域--批量
def batch_detect(image_dir, output_dir):

    img_filelist = os.listdir(image_dir)

    print('開始批量提取天空區域')
    i = 1
    for img_file in img_filelist:
        src_img = load_image(image_dir + img_file)
        src_img = cv2.pyrDown(src_img)

        sky_mask,sky_exists = extract_sky(src_img)
        if sky_exists == 0:
            i += 1
            cv2.imwrite(output_dir+img_file, src_img)
            continue
        height = src_img.shape[0]
        width  = src_img.shape[1]

        #sky_image_full = np.zeros(src_img.shape,dtype= src_img.dtype)
        for row in range(height):
            for col in range(width):
                if sky_mask[row, col] != 0:
                    src_img[row, col, 0] = 0
                    src_img[row, col, 1] = 0
                    src_img[row, col, 2] = 255
        #sky_img = cv2.addWeighted(src_img, 1, sky_image_full, 1, 0)
        cv2.imwrite(output_dir+img_file, src_img)

        print('已提取完成第',i,'張')
        i += 1

    print('批量提取完畢')

#計算天空滅點
def compute_vanish(image_file_path):
    # 加載圖像
    src_img = load_image(image_file_path)
    src_img = cv2.pyrDown(src_img)
    src_img = cv2.pyrDown(src_img)
    height, width = src_img.shape[0:2]

    # 計算天空邊界線
    sky_border_optimal = extract_border_optimal(src_img)
    border_correct = correct_border_polynomial(sky_border_optimal, src_img)

    # 判斷是否存在天空
    sky_exists = has_sky_region(border_correct, height / 30, height / 10, 5)
    if sky_exists == 0:
        #print('沒有檢測到天空區域')
        #cv2.imwrite(output_path, src_img)
        return 2*(src_img.shape[0]//3)-15

    # 計算天空消失點的高度,並畫圖
    vanish_h = refine_vanishpoint(border_correct, src_img)
    #cv2.circle(src_img, (src_img.shape[1]//2, vanish_h), 4, (0, 255, 0), 8)
    #cv2.imwrite(output_path, src_img)

    return 2*vanish_h

#計算天空滅點--批量
def batch_compute_vanish(image_dir, output_dir):

    vanishs = []
    img_filelist = sorted(os.listdir(image_dir))

    print('開始批量計算天空滅點')
    i = 1
    for img_file in img_filelist:
        #加載圖像
        src_image = load_image(image_dir + img_file)
        src_img = cv2.pyrDown(src_image)
        height, width = src_img.shape[0:2]

        #計算天空邊界線
        sky_border_optimal = extract_border_optimal(src_img)
        border_correct = correct_border_polynomial(sky_border_optimal, src_img)

        #判斷是否存在天空
        sky_exists = has_sky_region(border_correct, height / 30, height / 10, 5)
        if sky_exists == 0:
            print('沒有檢測到天空區域')
            i += 1
            cv2.imwrite(output_dir + img_file, src_image)
            continue

        #計算天空消失點的高度,並畫圖
        vanish_h = refine_vanishpoint(border_correct, src_img)
        vanishs.append(2*vanish_h)
        cv2.circle(src_image, (src_image.shape[1]//2, 4*vanish_h), 4, (0, 255, 0), 8)
        cv2.imwrite(output_dir+img_file, src_image)

        print('已計算完成第',i,'張')
        i += 1

    print('批量計算完畢')
    return vanishs

#提取圖像梯度信息
def extract_image_gradient(src_image):
    #轉灰度圖像
    gray_image = cv2.cvtColor(src_image, cv2.COLOR_BGR2GRAY)

    #Sobel算子提取圖像梯度信息
    x_gradient = cv2.Sobel(gray_image, cv2.CV_64F, 1, 0, 3)
    y_gradient = cv2.Sobel(gray_image, cv2.CV_64F, 0, 1, 3)

    #計算梯度幅值
    gradient_image = np.hypot(x_gradient, y_gradient)
    ret, gradient_image = cv2.threshold(gradient_image, 40, 1000, cv2.THRESH_BINARY)
    #gradient_image = np.uint8(np.sqrt(np.multiply(x_gradient,x_gradient) + np.multiply(y_gradient,y_gradient)))

    return gradient_image

#利用能量函數優化計算計算天空邊界線
def extract_border_optimal(src_image, thres_sky_min = 5, thres_sky_max = 600, thres_sky_search_step = 6):

    #提取梯度信息圖
    gradient_info_map = extract_image_gradient(src_image)

    n = math.floor((thres_sky_max - thres_sky_min)/ thres_sky_search_step) + 1

    border_opt = None
    jn_max = 0

    for i in range(n + 1):
        t = thres_sky_min + (math.floor((thres_sky_max - thres_sky_min) / n) - 1) * i
        b_tmp = extract_border(gradient_info_map, t)
        jn = calculate_sky_energy(b_tmp, src_image)
        #print('threshold= ',t,'energy= ',jn)

        if jn > jn_max:
            jn_max = jn
            border_opt = b_tmp

    return border_opt

# 計算天空圖像能量函數
def calculate_sky_energy(border, src_image):

    # 製作天空圖像掩碼和地面圖像掩碼
    sky_mask = make_sky_mask(src_image, border, 1)
    ground_mask = make_sky_mask(src_image, border, 0)

    # 扣取天空圖像和地面圖像
    sky_image_ma = np.ma.array(src_image, mask = cv2.cvtColor(sky_mask, cv2.COLOR_GRAY2BGR))
    ground_image_ma = np.ma.array(src_image, mask = cv2.cvtColor(ground_mask, cv2.COLOR_GRAY2BGR))

    # 計算天空和地面圖像協方差矩陣
    sky_image = sky_image_ma.compressed()
    ground_image = ground_image_ma.compressed()

    sky_image.shape = (sky_image.size//3, 3)
    ground_image.shape = (ground_image.size//3, 3)

    sky_covar, sky_mean = cv2.calcCovarMatrix(sky_image, mean=None, flags=cv2.COVAR_ROWS|cv2.COVAR_NORMAL|cv2.COVAR_SCALE)
    sky_retval, sky_eig_val, sky_eig_vec = cv2.eigen(sky_covar)

    ground_covar, ground_mean = cv2.calcCovarMatrix(ground_image, mean=None,flags=cv2.COVAR_ROWS|cv2.COVAR_NORMAL|cv2.COVAR_SCALE)
    ground_retval, ground_eig_val, ground_eig_vec = cv2.eigen(ground_covar)

    gamma = 2  # 論文原始參數

    sky_det = cv2.determinant(sky_covar)
    #sky_eig_det = cv2.determinant(sky_eig_vec)
    ground_det = cv2.determinant(ground_covar)
    #ground_eig_det = cv2.determinant(ground_eig_vec)

    sky_energy = 1 / ((gamma * sky_det + ground_det) + (gamma * sky_eig_val[0,0] + ground_eig_val[0,0]))

    return sky_energy

# 判斷圖像是否存在天空區域
def has_sky_region(border, thresh_1, thresh_2, thresh_3):

    border_mean = np.average(border)

    #求天際線位置差,取絕對值,取均值
    border_diff_mean = np.average(np.absolute(np.diff(border)))

    sky_exists = 0
    if border_mean < thresh_1 or (border_diff_mean > thresh_3 and border_mean < thresh_2):
        return sky_exists
    else:
        sky_exists = 1
        return sky_exists

#判斷圖像是否有部分區域爲天空區域
def has_partial_sky_region(border, thresh_4):

    border_diff = np.diff(border)

    '''
    if np.any(border_diff > thresh_4):
        index = np.argmax(border_diff)
        print(border_diff[index])
    '''

    return np.any(border_diff > thresh_4)

#計算天空邊界線
def extract_border(gradient_info_map, thresh):

    height, width = gradient_info_map.shape[0:2]
    border = np.full(width, height - 1)

    for col in range(width):
        #返回該列第一個大於閾值的元素的索引
        border_pos = np.argmax(gradient_info_map[:, col] > thresh)
        if border_pos > 0:
            border[col] = border_pos

    return border

#天空區域和原始圖像融合圖,顯示天空區域
def display_sky_region(src_image, border):

    height = src_image.shape[0]
    width = src_image.shape[1]

    #製作天空圖掩碼
    sky_mask = make_sky_mask(src_image, border, 1)

    #天空和原始圖像融合
    sky_image_full = np.zeros(src_image.shape, dtype = src_image.dtype)
    for row in range(height):
        for col in range(width):
            if sky_mask[row, col] != 0:
                src_image[row, col, 0] = 0
                src_image[row, col, 1] = 0
                src_image[row, col, 2] = 255
    sky_image = cv2.addWeighted(src_image, 1, sky_image_full, 1, 0)

    return sky_image

#製作天空掩碼圖像,type: 1: 天空 0: 地面
def make_sky_mask(src_image, border, type):

    height = src_image.shape[0]
    width = src_image.shape[1]

    mask = np.zeros((height,width),dtype= np.uint8)

    if type == 1:
        for col, row in enumerate(border):
            mask[0:row +1, col] = 255
    elif type == 0:
        for col, row in enumerate(border):
            mask[row + 1:, col] = 255
    else:
        assert type is 0 or type is 1,'type參數必須爲0或1'

    return mask

#改善天空邊界線
def refine_border(border, src_image):

    sky_covar, sky_mean, ic_s, ground_covar, ground_mean, ic_g = true_sky(border, src_image)

    for col in range(src_image.shape[1]):
        cnt = np.sum(np.greater(spatial.distance.cdist(src_image[0:border[col], col], sky_mean, 'mahalanobis', VI=ic_s), spatial.distance.cdist(src_image[0:border[col], col], ground_mean, 'mahalanobis', VI=ic_g)))

        if cnt < (border[col] / 2):
            border[col] = 0

    return border

#改善天空邊界線————alpha版本
def refine_border_alpha(border, src_image):

    sky_covar, sky_mean, ic_s, ground_covar, ground_mean, ic_g = true_sky(border, src_image)

    for col in range(src_image.shape[1]):
        for row in range(src_image.shape[0]):
            mahalanobis_sky = spatial.distance.cdist(src_image[row, col].reshape(1, 3), sky_mean, 'mahalanobis',VI=ic_s)
            mahalanobis_gr = spatial.distance.cdist(src_image[row, col].reshape(1, 3), ground_mean, 'mahalanobis',VI=ic_g)
            delta1 = abs(src_image[row, col, 0] - sky_mean[0,0]) < sky_mean[0,0] / 3.6
            delta2 = abs(src_image[row, col, 1] - sky_mean[0,1]) < sky_mean[0,1] / 3.6
            delta3 = abs(src_image[row, col, 2] - sky_mean[0,2]) < sky_mean[0,2] / 3.6
            if mahalanobis_sky < mahalanobis_gr and delta1 and delta2 and delta3:
                border[col] = row

    """
    sky_mean = np.mean(sky_image_true, axis= 0)
    for col in range(width):
        for row in range(height):
            delta1 = abs(src_image[row,col,0] - sky_mean[0]) < sky_mean[0]/3.6
            delta2 = abs(src_image[row,col,1] - sky_mean[1]) < sky_mean[1]/3.6
            delta3 = abs(src_image[row,col,2] - sky_mean[2]) < sky_mean[2]/3.6
            if delta1 and delta2 and delta3:
                border[col] = row
    """
    return border

#獲取更真實天空像素和地面像素的均值、協方差及其逆
def true_sky(border, src_image):

    #製作天空圖像掩碼和地面圖像掩碼
    sky_mask = make_sky_mask(src_image, border, 1)
    ground_mask = make_sky_mask(src_image, border, 0)

    #扣取天空圖像和地面圖像
    sky_image_ma = np.ma.array(src_image, mask = cv2.cvtColor(sky_mask, cv2.COLOR_GRAY2BGR))
    ground_image_ma = np.ma.array(src_image, mask = cv2.cvtColor(ground_mask, cv2.COLOR_GRAY2BGR))

    #將天空和地面區域shape轉換爲n*3
    sky_image = sky_image_ma.compressed()
    ground_image = ground_image_ma.compressed()

    sky_image.shape = (sky_image.size//3, 3)
    ground_image.shape = (ground_image.size//3, 3)

    # k均值聚類調整天空區域邊界--2類
    sky_image_float = np.float32(sky_image)
    criteria = (cv2.TERM_CRITERIA_EPS + cv2.TERM_CRITERIA_MAX_ITER, 10, 1.0)
    flags = cv2.KMEANS_RANDOM_CENTERS
    compactness, labels, centers = cv2.kmeans(sky_image_float, 2, None, criteria, 10, flags)

    sky_label_0 = sky_image[labels.ravel() == 0]
    sky_label_1 = sky_image[labels.ravel() == 1]

    sky_covar_0, sky_mean_0 = cv2.calcCovarMatrix(sky_label_0, mean= None, flags= cv2.COVAR_ROWS + cv2.COVAR_NORMAL + cv2.COVAR_SCALE)
    sky_covar_1, sky_mean_1 = cv2.calcCovarMatrix(sky_label_1, mean= None, flags= cv2.COVAR_ROWS + cv2.COVAR_NORMAL + cv2.COVAR_SCALE)
    ground_covar, ground_mean = cv2.calcCovarMatrix(ground_image, mean= None, flags= cv2.COVAR_ROWS + cv2.COVAR_NORMAL + cv2.COVAR_SCALE)

    ic_s0 = cv2.invert(sky_covar_0, cv2.DECOMP_SVD)[1]
    ic_s1 = cv2.invert(sky_covar_1, cv2.DECOMP_SVD)[1]
    ic_g = cv2.invert(ground_covar, cv2.DECOMP_SVD)[1]

    #推斷真實的天空區域
    if cv2.Mahalanobis(sky_mean_0, ground_mean, ic_s0) > cv2.Mahalanobis(sky_mean_1, ground_mean, ic_s1):
        sky_mean = sky_mean_0
        sky_covar = sky_covar_0
        ic_s = ic_s0
    else:
        sky_mean = sky_mean_1
        sky_covar = sky_covar_1
        ic_s = ic_s1


    return sky_covar,sky_mean,ic_s,ground_covar, ground_mean,ic_g

#修正天空滅點
def refine_vanishpoint(border,src_image):

    src_image = cv2.GaussianBlur(src_image, (7,7), 0)
    index = np.argmax(border)

    if border[index] >= 3*(src_image.shape[0]//4):
        dist = np.full(border[index], 0)
        width = src_image.shape[1]
        sky_covar,sky_mean,ic_s,ground_covar, ground_mean,ic_g = true_sky(border, src_image)
        for row in range(border[index]):
            distance = spatial.distance.cdist(src_image[width // 2, row].reshape(1, 3), sky_mean, 'mahalanobis',VI=ic_s)
            dist[row] = distance
        diff1 = np.diff(dist)
        diff2 = abs(np.diff(diff1))
        vanish_h = np.argmax(diff2)
    elif border[index] < src_image.shape[0]//2 :
        dist = np.full(src_image.shape[0], 0)
        width = src_image.shape[1]
        sky_covar,sky_mean,ic_s,ground_covar, ground_mean,ic_g = true_sky(border, src_image)
        for row in range(src_image.shape[0]):
            distance = spatial.distance.cdist(src_image[width//2, row].reshape(1, 3), sky_mean, 'mahalanobis', VI=ic_s)
            dist[row] = distance
        diff1 = np.diff(dist)
        diff2 = abs(np.diff(diff1))
        vanish_h = np.argmax(diff2)
    else:
        vanish_h = border[index]

    return vanish_h

#修正錯誤邊界線--多項式擬合
def correct_border_polynomial(border, src_image):

    x = np.arange(0, src_image.shape[1], 1)
    border_line_argument = np.polyfit(x, border, 10)
    border_line_function = np.poly1d(border_line_argument)
    border_polynomial = np.int64(border_line_function(x))

    outlier = np.percentile(border,90)
    for col in range(len(border)):
        if border[col] >= outlier: # or abs(border[col]-border_polynomial[col]) > src_image.shape[0]/3 :
            border[col] = border_polynomial[col]
        #elif border[col] <= src_image.shape[0]//3:
            #border[col] = border_polynomial[col]

    return border

'''
#修正錯誤邊界線--二次函數擬合
def correct_border_quardratic(border, src_image):
    outlier = np.percentile(border, 90)
    for col in range(len(border)):
        if border[col] >= outlier:
            if col == 0:
                border[col] = border[col + 1]
            elif col == src_image.shape[1] - 1:
                border[col] = border[col - 1]
            else:
                border[col] = (border[col - 1] + border[col + 1]) / 2
    x = np.arange(0, src_image.shape[1], 1)
    def fun(x,a,b,c):
        return a*(x**2) + b*x +c
    ppot,pcov = curve_fit(fun, x, border)
    a = ppot[0]
    b = ppot[1]
    c = ppot[2]
    border_new = np.int64(fun(x,a,b,c))

    return border_new
'''


if __name__ == '__main__':

    image_file_path = '/home/dulingwen/Pictures/skydetect/images/'
    out_path = '/home/dulingwen/Pictures/skydetect/output/'

    tic = time.time()
    batch_detect(image_file_path, out_path)
    toc = time.time()
    times = 1000*(toc- tic)
    print('運行時間:',times,'ms')

效果如下:

 

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章