監督分類:用SVM做遙感影像特徵級分類

前面已經做了很多像素級的分類了,這裏繼續深入。這裏用的特徵是LBP(Local Binary Pattern,局部二值模式)。思路是:先手動選取樣本點,截取每個樣本點的鄰域並提取LBP特徵,訓練得到模型;然後對影像中的每個像素取其鄰域並提取LBP特徵,再用訓練好的模型對該特徵進行判別,最終確定每個像素的類別。
這裏放了一個初版和一個改進後的版本,初版裏面有想記錄的東西所以就也放這裏了,過客們可以直接往下翻看終版。

# -*- coding: utf-8 -*-
import os, sys, time
import warnings
import gdal
import numpy as np
from numpy import average, dot, linalg
import cv2
import skimage
from osgeo import ogr
from osgeo import gdal
from osgeo import gdal_array as ga
from gdalconst import *
from skimage.feature import local_binary_pattern
from skimage.util.shape import view_as_windows #這是影像切片的模塊,存在邊界無法處理的缺陷
from sklearn.svm import SVC
from numba import jit

def read_img(filename):
    """Read a whole raster into memory with GDAL.

    Args:
        filename: Path of the raster to open.

    Returns:
        A tuple ``(im_proj, im_geotrans, im_width, im_height, im_data)``:
        the result of ``GetProjection()``, the result of
        ``GetGeoTransform()``, the raster width and height in pixels, and
        the array ``ReadAsArray`` returns for the full extent.

    Raises:
        IOError: If GDAL cannot open ``filename``.
    """
    dataset = gdal.Open(filename)
    if dataset is None:
        # gdal.Open reports failure by returning None; fail here with a clear
        # message instead of an AttributeError on the next line.
        raise IOError('Could not open ' + str(filename))

    im_width = dataset.RasterXSize
    im_height = dataset.RasterYSize

    im_geotrans = dataset.GetGeoTransform()
    im_proj = dataset.GetProjection()
    im_data = dataset.ReadAsArray(0, 0, im_width, im_height)

    # Drop the reference so GDAL closes the file.
    del dataset
    return im_proj, im_geotrans, im_width, im_height, im_data


def write_img(filename, im_proj, im_geotrans, im_data):
    """Write a 2-D or 3-D array to a georeferenced GeoTIFF.

    Args:
        filename: Output path.
        im_proj: Projection passed to ``SetProjection``.
        im_geotrans: Geotransform passed to ``SetGeoTransform``.
        im_data: Array shaped ``(height, width)`` or
            ``(bands, height, width)``.

    Raises:
        IOError: If the GTiff driver cannot create ``filename``.
    """
    dtype_name = im_data.dtype.name
    if dtype_name in ('int8', 'uint8'):
        # NOTE(review): signed int8 is still stored as unsigned Byte, as before.
        datatype = gdal.GDT_Byte
    elif dtype_name == 'uint16':
        datatype = gdal.GDT_UInt16
    elif dtype_name == 'int16':
        # Used to be written as GDT_UInt16, which cannot hold negative values.
        datatype = gdal.GDT_Int16
    else:
        # Every other dtype (including int32/int64 label maps) becomes Float32.
        datatype = gdal.GDT_Float32

    if im_data.ndim == 3:
        im_bands, im_height, im_width = im_data.shape
    else:
        im_bands, (im_height, im_width) = 1, im_data.shape
    # View the input as (bands, height, width) so that a (1, H, W) array is
    # also handed to WriteArray one 2-D band at a time.
    band_stack = im_data.reshape(im_bands, im_height, im_width)

    driver = gdal.GetDriverByName("GTiff")
    dataset = driver.Create(filename, im_width, im_height, im_bands, datatype)
    if dataset is None:
        raise IOError('Could not create ' + str(filename))

    dataset.SetGeoTransform(im_geotrans)
    dataset.SetProjection(im_proj)

    for index in range(im_bands):
        dataset.GetRasterBand(index + 1).WriteArray(band_stack[index])

    # Flush, then drop the reference so GDAL closes the file.
    dataset.FlushCache()
    del dataset

def stretch_n(bands, img_min, img_max, lower_percent=0, higher_percent=100):
    """Linearly stretch an array into ``[img_min, img_max]`` and cast to uint8.

    The two percentiles are computed over the whole array (all bands
    together), mapped onto ``img_min`` and ``img_max``, and values outside
    that interval are clamped.

    The former ``@jit`` decorator was removed: the body cannot be compiled in
    numba's nopython mode, so it only ever ran as plain Python, and it fails
    outright on numba releases where nopython is the default.

    Args:
        bands: Array of pixel values.
        img_min: Output value assigned to the lower percentile.
        img_max: Output value assigned to the higher percentile.
        lower_percent: Percentile used as the input minimum.
        higher_percent: Percentile used as the input maximum.

    Returns:
        A ``uint8`` array with the shape of ``bands``; all zeros when both
        percentiles are equal (nothing to stretch).
    """
    bands = np.asarray(bands)
    low = np.percentile(bands, lower_percent)
    high = np.percentile(bands, higher_percent)
    if high == low:
        # Constant input: avoid dividing by zero.
        return np.zeros(bands.shape, dtype=np.uint8)

    scaled = img_min + (bands - low) * (img_max - img_min) / (high - low)
    scaled[scaled < img_min] = img_min
    scaled[scaled > img_max] = img_max
    # Go through float32 first, as before, so truncation to uint8 is unchanged.
    return scaled.astype(np.float32).astype(np.uint8)

def getPixels(shp, img, size):
    """Cut a ``size`` x ``size`` patch out of ``img`` around every point of ``shp``.

    Args:
        shp: Path of an ESRI Shapefile; each feature's geometry is read with
            ``GetX()``/``GetY()``.
        img: Path of the raster the patches are read from.
        size: Patch edge length in pixels (converted with ``int``).

    Returns:
        A list of ``[data, geotransform]`` pairs, one per sample point whose
        window lies completely inside the raster. ``data`` is what
        ``ReadAsArray`` returns for the window; ``geotransform`` is the raster
        geotransform (as a tuple) with its origin moved to the window.

    Exits the process with status 1 if either file cannot be opened.
    """
    driver = ogr.GetDriverByName('ESRI Shapefile')
    shp_ds = driver.Open(shp, 0)
    if shp_ds is None:
        print('Could not open ' + shp)
        sys.exit(1)

    layer = shp_ds.GetLayer()

    points = []
    feature = layer.GetNextFeature()
    while feature:
        geometry = feature.GetGeometryRef()
        points.append((geometry.GetX(), geometry.GetY()))
        feature = layer.GetNextFeature()

    gdal.AllRegister()

    ds = gdal.Open(img, gdal.GA_ReadOnly)
    if ds is None:
        print('Could not open image')
        sys.exit(1)

    rows = ds.RasterYSize
    cols = ds.RasterXSize
    # Use the geotransform of the raster actually being read; the previous
    # code reached for a module-level ``im_geotrans`` that only exists when
    # the file is run as a script.
    transform = ds.GetGeoTransform()
    size = int(size)

    values = []
    for x, y in points:
        # Upper-left corner of a window centred on the sample point.
        # NOTE(review): the rotation terms transform[2]/transform[4] are
        # ignored, and this origin is not snapped to the pixel grid, so it
        # can differ from the pixels actually read by a fraction of a pixel.
        origin_x = x - transform[1] * size / 2.0
        origin_y = y - transform[5] * size / 2.0
        patch_transform = list(transform)
        patch_transform[0] = origin_x
        patch_transform[3] = origin_y

        xoff = int((origin_x - transform[0]) / transform[1])
        yoff = int((origin_y - transform[3]) / transform[5])
        if xoff < 0 or yoff < 0 or xoff + size > cols or yoff + size > rows:
            # A window crossing the raster edge cannot be read as a full patch.
            print('Skipping sample point outside image: (%s, %s)' % (x, y))
            continue

        data = ds.ReadAsArray(xoff, yoff, size, size)
        values.append([data, tuple(patch_transform)])
    return values

def lbp(img, n_points, radius, level=256, method='default'):
    """Return the normalised LBP-code histogram of a single-band image.

    Args:
        img: 2-D image handed to ``local_binary_pattern``.
        n_points: Number of neighbour points of the LBP operator.
        radius: Radius of the LBP operator.
        level: Number of unit-width histogram bins; only codes in
            ``[0, level)`` are counted.
        method: LBP variant passed through to ``local_binary_pattern``.

    Returns:
        A float array of length ``level`` that sums to 1, or all zeros when
        no code falls inside ``[0, level)``.

    NOTE(review): ``level`` is a count of LBP codes, not of grey levels. Per
    the scikit-image documentation the 'default' method yields codes up to
    ``2 ** n_points - 1``, so with ``n_points > 8`` every code >= ``level``
    is silently left out of the histogram -- confirm this is intended.
    """
    codes = local_binary_pattern(img, n_points, radius, method)
    n_bins = level
    counts, _ = np.histogram(codes, bins=n_bins, range=(0, n_bins))
    total = counts.sum()
    if total == 0:
        # density=True would divide 0 by 0 here and return NaN, which the
        # classifier rejects; return an empty histogram instead.
        return np.zeros(n_bins, dtype=np.float64)
    # Bins have unit width, so counts / total equals the density histogram.
    return counts / float(total)
    
def get_data(data_point, img_path, patch_path, im_proj, band, size, radius, n_points, data_type='false'):
    """Build LBP feature vectors and labels for one set of sample points.

    For every patch returned by ``getPixels`` the patch is written to
    ``patch_path`` as ``<data_type>_<index>.tif``, then the ``lbp`` histograms
    of its first ``band`` bands are joined into one feature vector.

    Args:
        data_point: Shapefile with the sample points.
        img_path: Raster the patches are cut from.
        patch_path: Directory receiving the patch GeoTIFFs.
        im_proj: Projection written into each patch file.
        band: Number of bands to extract features from.
        size: Patch edge length in pixels.
        radius: LBP radius.
        n_points: LBP neighbour count.
        data_type: ``'false'`` labels every sample 0; anything else labels 1.

    Returns:
        ``(data_lbp_hists, data_label)``: a list of feature vectors and the
        list of matching integer labels.
    """
    label = 0 if data_type == 'false' else 1
    data_lbp_hists = []
    data_label = []
    for index, patch in enumerate(getPixels(data_point, img_path, size)):
        pixels, geotrans = patch[0], patch[1]
        out_name = data_type + '_' + str(index) + '.tif'
        write_img(os.path.join(patch_path, out_name), im_proj, geotrans, pixels)

        per_band = [lbp(pixels[b, ...], n_points, radius) for b in range(band)]
        # Chain the per-band histograms into one flat vector.
        features = per_band[0]
        for extra in per_band[1:]:
            features = np.concatenate((features, extra), axis=0)

        data_lbp_hists.append(features)
        data_label.append(label)
    return data_lbp_hists, data_label

if __name__ == "__main__":
    # Silence library warnings for the whole run.
    warnings.filterwarnings("ignore")

    img_path = "E:/20200210/forest/gf2/dys_gf2.tif"   # full scene
    false_point = 'E:/20200210/forest/gf2/point/1.shp' # negative sample points
    true_point = 'E:/20200210/forest/gf2/point/2.shp'  # positive sample points
    patch_path = "E:/20200210/forest/tif_temp/patches2/" # patches cropped around the sample points

    im_proj,im_geotrans,im_width, im_height,im_data = read_img(img_path)
    temp8bit_path = "E:/20200210/forest/gf2/dys_gf2_8bit.tif"
    temp8bit = stretch_n(im_data,0,255)
    # The scene is 16-bit; an 8-bit copy is written to reduce computation.
    write_img(temp8bit_path, im_proj, im_geotrans, temp8bit)

    band = 4
    size = 10  # neighbourhood (patch) edge length in pixels

    radius = 2  # LBP parameter
    n_points = 8 * radius  # LBP parameter

    false_lbp_hists, false_label = get_data(false_point,temp8bit_path,patch_path,im_proj,band,size,radius,n_points, data_type='false')

    true_lbp_hists, true_label = get_data(true_point,temp8bit_path,patch_path,im_proj,band,size,radius,n_points,data_type='true')

    train_data = np.array(true_lbp_hists + false_lbp_hists)
    train_label = np.array(true_label + false_label)

    # An SVM is used here; any other classifier could be trained instead.
    svc = SVC(C=0.8, kernel='rbf', gamma='scale', cache_size=1000)
    svc.fit(train_data, train_label)

    test_area = "E:/20200210/forest/gf2/dys_gf2_test.tif"  # test area; keep it small, prediction is slow
    ds = gdal.Open(test_area)
    im_width0 = ds.RasterXSize
    im_height0 = ds.RasterYSize
    im_geotrans0 = ds.GetGeoTransform()
    im_proj0 = ds.GetProjection()
    test_data = ds.ReadAsArray(0,0,im_width0,im_height0)
    # NOTE(review): the test image is stretched with its own min/max, which
    # need not match the stretch applied to the training scene.
    test_data = stretch_n(test_data, 0, 255)

    # The window must match the training patches, so derive it from band/size
    # instead of hard-coding (4, 10, 10).
    window_shape = (band, size, size)
    # Stride-1 sliding windows; index 0 drops the singleton band-step axis,
    # leaving shape (rows, cols, band, size, size).
    windows = view_as_windows(test_data, window_shape)[0]
    n_rows, n_cols = windows.shape[:2]

    # Predict one image row at a time so only one row of feature vectors is
    # held in memory. range() works on Python 2 and 3 (xrange is 2-only).
    label_map = np.empty((n_rows, n_cols), dtype=train_label.dtype)
    for i in range(n_rows):
        row_features = []
        for j in range(n_cols):
            window = windows[i, j]
            hists = [lbp(window[h, ...], n_points, radius) for h in range(band)]
            row_features.append(np.concatenate(hists, axis=0))
        label_map[i] = svc.predict(np.array(row_features))

    # Each label belongs to the centre of its window, not its upper-left
    # corner, so move the output origin by (size - 1) / 2 pixels.
    shift = (size - 1) / 2.0
    seg_geotrans = list(im_geotrans0)
    seg_geotrans[0] += shift * (im_geotrans0[1] + im_geotrans0[2])
    seg_geotrans[3] += shift * (im_geotrans0[4] + im_geotrans0[5])

    seg_path = "E:/20200210/forest/tif_temp/dys_seg.tif"
    write_img(seg_path, im_proj0, tuple(seg_geotrans), label_map)
    del ds

改進版本
由於上述的圖像切片不但不能處理邊界,而且每次切片都是以掩膜左上角爲準的鄰域,而不是以每個待判別像素爲中心的鄰域,問題很大。下面的代碼給圖像上下左右增加了padding,用來輔助判別邊界點。代碼的整體邏輯沒什麼問題了,但是效果不太理想,我推測是特徵部分有問題,有時間再探究,你們就把它當作進一步探索的參考吧。

# -*- coding: utf-8 -*-
import os, sys, time
import warnings
import gdal
import numpy as np
from numpy import average, dot, linalg
import cv2
import skimage
from osgeo import ogr
from osgeo import gdal
from osgeo import gdal_array as ga
from gdalconst import *
from skimage.feature import local_binary_pattern
from skimage.util.shape import view_as_windows, view_as_blocks
from sklearn.svm import SVC
from numba import jit

def read_img(filename):
    """Load an entire raster through GDAL.

    Args:
        filename: Path of the raster to open.

    Returns:
        A tuple ``(im_proj, im_geotrans, im_width, im_height, im_data)``:
        the result of ``GetProjection()``, the result of
        ``GetGeoTransform()``, the raster width and height in pixels, and
        the array ``ReadAsArray`` returns for the full extent.

    Raises:
        IOError: If GDAL cannot open ``filename``.
    """
    dataset = gdal.Open(filename)
    if dataset is None:
        # gdal.Open reports failure by returning None; raise a clear error
        # rather than letting the attribute access below fail.
        raise IOError('Could not open ' + str(filename))

    im_width = dataset.RasterXSize
    im_height = dataset.RasterYSize
    im_geotrans = dataset.GetGeoTransform()
    im_proj = dataset.GetProjection()
    im_data = dataset.ReadAsArray(0, 0, im_width, im_height)

    # Drop the reference so GDAL closes the file.
    del dataset
    return im_proj, im_geotrans, im_width, im_height, im_data


def write_img(filename, im_proj, im_geotrans, im_data):
    """Save a 2-D or 3-D array as a georeferenced GeoTIFF.

    Args:
        filename: Output path.
        im_proj: Projection passed to ``SetProjection``.
        im_geotrans: Geotransform passed to ``SetGeoTransform``.
        im_data: Array shaped ``(height, width)`` or
            ``(bands, height, width)``.

    Raises:
        IOError: If the GTiff driver cannot create ``filename``.
    """
    dtype_name = im_data.dtype.name
    if dtype_name in ('int8', 'uint8'):
        # NOTE(review): signed int8 is still stored as unsigned Byte, as before.
        datatype = gdal.GDT_Byte
    elif dtype_name == 'uint16':
        datatype = gdal.GDT_UInt16
    elif dtype_name == 'int16':
        # Used to be written as GDT_UInt16, which cannot hold negative values.
        datatype = gdal.GDT_Int16
    else:
        # Every other dtype (including int32/int64 label maps) becomes Float32.
        datatype = gdal.GDT_Float32

    if im_data.ndim == 3:
        im_bands, im_height, im_width = im_data.shape
    else:
        im_bands, (im_height, im_width) = 1, im_data.shape
    # View the input as (bands, height, width) so that a (1, H, W) array is
    # also handed to WriteArray one 2-D band at a time.
    band_stack = im_data.reshape(im_bands, im_height, im_width)

    driver = gdal.GetDriverByName("GTiff")
    dataset = driver.Create(filename, im_width, im_height, im_bands, datatype)
    if dataset is None:
        raise IOError('Could not create ' + str(filename))

    dataset.SetGeoTransform(im_geotrans)
    dataset.SetProjection(im_proj)

    for index in range(im_bands):
        dataset.GetRasterBand(index + 1).WriteArray(band_stack[index])

    # Flush, then drop the reference so GDAL closes the file.
    dataset.FlushCache()
    del dataset

def stretch_n(bands, img_min, img_max, lower_percent=0, higher_percent=100):
    """Linearly rescale an array into ``[img_min, img_max]`` as uint8.

    The lower and higher percentiles are taken over the whole array (all
    bands together) and mapped onto ``img_min`` and ``img_max``; values
    outside that interval are clamped.

    The former ``@jit`` decorator was removed: the body cannot be compiled in
    numba's nopython mode, so it only ever ran as plain Python, and it fails
    outright on numba releases where nopython is the default.

    Args:
        bands: Array of pixel values.
        img_min: Output value assigned to the lower percentile.
        img_max: Output value assigned to the higher percentile.
        lower_percent: Percentile used as the input minimum.
        higher_percent: Percentile used as the input maximum.

    Returns:
        A ``uint8`` array with the shape of ``bands``; all zeros when both
        percentiles are equal (nothing to stretch).
    """
    bands = np.asarray(bands)
    in_low = np.percentile(bands, lower_percent)
    in_high = np.percentile(bands, higher_percent)
    if in_high == in_low:
        # Constant input: avoid dividing by zero.
        return np.zeros(bands.shape, dtype=np.uint8)

    stretched = img_min + (bands - in_low) * (img_max - img_min) / (in_high - in_low)
    stretched[stretched < img_min] = img_min
    stretched[stretched > img_max] = img_max
    # Go through float32 first, as before, so truncation to uint8 is unchanged.
    return stretched.astype(np.float32).astype(np.uint8)

def pad_data(data, nei_size):
    """Zero-pad the two spatial axes of a ``(bands, rows, cols)`` array.

    ``nei_size // 2`` zeros are added on every side of axes 1 and 2; the band
    axis is left alone.

    The former ``@jit`` decorator was removed: building the shape from a list
    and joining arrays of different dtypes gains nothing from numba, and it
    can fail on numba releases where nopython is the default.

    Args:
        data: 3-D array.
        nei_size: Neighbourhood edge length in pixels.

    Returns:
        The padded array. As before, the result is promoted with float64
        (the old code concatenated float64 zero blocks onto ``data``).
    """
    half = nei_size // 2
    # Keep the historical result dtype so callers see the same array type.
    promoted = np.asarray(data, dtype=np.result_type(data, np.float64))
    pad_width = ((0, 0), (half, half), (half, half))
    return np.pad(promoted, pad_width, mode='constant', constant_values=0)

def getPixels(shp, img, size):
    """Read a ``size`` x ``size`` patch of ``img`` around every point of ``shp``.

    Args:
        shp: Path of an ESRI Shapefile; each feature's geometry is read with
            ``GetX()``/``GetY()``.
        img: Path of the raster the patches are read from.
        size: Patch edge length in pixels (converted with ``int``).

    Returns:
        A list of ``[data, geotransform]`` pairs, one per sample point whose
        window lies completely inside the raster. ``data`` is what
        ``ReadAsArray`` returns for the window; ``geotransform`` is the raster
        geotransform (as a tuple) with its origin moved to the window.

    Exits the process with status 1 if either file cannot be opened.
    """
    driver = ogr.GetDriverByName('ESRI Shapefile')
    shp_ds = driver.Open(shp, 0)
    if shp_ds is None:
        print('Could not open ' + shp)
        sys.exit(1)

    layer = shp_ds.GetLayer()

    sample_points = []
    feature = layer.GetNextFeature()
    while feature:
        geometry = feature.GetGeometryRef()
        sample_points.append((geometry.GetX(), geometry.GetY()))
        feature = layer.GetNextFeature()

    gdal.AllRegister()

    ds = gdal.Open(img, gdal.GA_ReadOnly)
    if ds is None:
        print('Could not open image')
        sys.exit(1)

    rows = ds.RasterYSize
    cols = ds.RasterXSize
    # Use the geotransform of the raster actually being read; the previous
    # code reached for a module-level ``im_geotrans`` that only exists when
    # the file is run as a script.
    transform = ds.GetGeoTransform()
    size = int(size)

    values = []
    for x, y in sample_points:
        # Upper-left corner of a window centred on the sample point.
        # NOTE(review): the rotation terms transform[2]/transform[4] are
        # ignored, and this origin is not snapped to the pixel grid, so it
        # can differ from the pixels actually read by a fraction of a pixel.
        origin_x = x - transform[1] * size / 2.0
        origin_y = y - transform[5] * size / 2.0
        patch_transform = list(transform)
        patch_transform[0] = origin_x
        patch_transform[3] = origin_y

        xoff = int((origin_x - transform[0]) / transform[1])
        yoff = int((origin_y - transform[3]) / transform[5])
        if xoff < 0 or yoff < 0 or xoff + size > cols or yoff + size > rows:
            # A window crossing the raster edge cannot be read as a full patch.
            print('Skipping sample point outside image: (%s, %s)' % (x, y))
            continue

        data = ds.ReadAsArray(xoff, yoff, size, size)
        values.append([data, tuple(patch_transform)])
    return values

def lbp(img, n_points, radius, level=256, method='default'):
    """Compute the normalised histogram of LBP codes for one image band.

    Args:
        img: 2-D image handed to ``local_binary_pattern``.
        n_points: Number of neighbour points of the LBP operator.
        radius: Radius of the LBP operator.
        level: Number of unit-width histogram bins; only codes in
            ``[0, level)`` are counted.
        method: LBP variant passed through to ``local_binary_pattern``.

    Returns:
        A float array of length ``level`` that sums to 1, or all zeros when
        no code falls inside ``[0, level)``.

    NOTE(review): ``level`` is a count of LBP codes, not of grey levels. Per
    the scikit-image documentation the 'default' method yields codes up to
    ``2 ** n_points - 1``, so with ``n_points > 8`` every code >= ``level``
    is silently left out of the histogram -- confirm this is intended.
    """
    codes = local_binary_pattern(img, n_points, radius, method)
    n_bins = level
    counts, _ = np.histogram(codes, bins=n_bins, range=(0, n_bins))
    total = counts.sum()
    if total == 0:
        # density=True would divide 0 by 0 here and return NaN (for example
        # on a flat, zero-padded window), which the classifier rejects.
        return np.zeros(n_bins, dtype=np.float64)
    # Bins have unit width, so counts / total equals the density histogram.
    return counts / float(total)

def get_data(data_point, img_path, patch_path, im_proj, band, size, radius, n_points, data_type='false'):
    """Turn a shapefile of sample points into LBP features and labels.

    Each patch returned by ``getPixels`` is saved under ``patch_path`` as
    ``<data_type>_<index>.tif``; the ``lbp`` histograms of its first ``band``
    bands are then concatenated into a single feature vector.

    Args:
        data_point: Shapefile with the sample points.
        img_path: Raster the patches are cut from.
        patch_path: Directory receiving the patch GeoTIFFs.
        im_proj: Projection written into each patch file.
        band: Number of bands to extract features from.
        size: Patch edge length in pixels.
        radius: LBP radius.
        n_points: LBP neighbour count.
        data_type: ``'false'`` labels every sample 0; anything else labels 1.

    Returns:
        ``(data_lbp_hists, data_label)``: a list of feature vectors and the
        list of matching integer labels.
    """
    is_negative = data_type == 'false'
    data_lbp_hists, data_label = [], []
    sample_index = 0
    for patch in getPixels(data_point, img_path, size):
        tif_name = data_type + '_' + str(sample_index) + '.tif'
        write_img(os.path.join(patch_path, tif_name), im_proj, patch[1], patch[0])
        sample_index += 1

        band_hists = []
        for band_index in range(band):
            band_hists.append(lbp(patch[0][band_index, ...], n_points, radius))

        # Fold the per-band histograms into one flat vector.
        vector = band_hists[0]
        for later in band_hists[1:]:
            vector = np.concatenate((vector, later), axis=0)
        data_lbp_hists.append(vector)

        if is_negative:
            data_label.append(0)
        else:
            data_label.append(1)
    return data_lbp_hists, data_label

if __name__ == "__main__":
    # Silence library warnings for the whole run.
    warnings.filterwarnings("ignore")

    img_path = "E:/20200210/forest/gf2/dys_gf2.tif"
    false_point = 'E:/20200210/forest/gf2/point/1.shp'
    true_point = 'E:/20200210/forest/gf2/point/2.shp'
    patch_path = "E:/20200210/forest/tif_temp/patches2/"

    im_proj,im_geotrans,im_width, im_height,im_data = read_img(img_path)
    temp8bit_path = "E:/20200210/forest/gf2/dys_gf2_8bit.tif"
    temp8bit = stretch_n(im_data,0,255)
    # Training patches are cut from this 8-bit copy of the scene.
    write_img(temp8bit_path, im_proj, im_geotrans, temp8bit)

    band = 4
    size = 10  # neighbourhood (patch) edge length in pixels

    radius = 2  # LBP parameter
    n_points = 8 * radius  # LBP parameter

    false_lbp_hists, false_label = get_data(false_point,temp8bit_path,patch_path,im_proj,band,size,radius,n_points, data_type='false')

    true_lbp_hists, true_label = get_data(true_point,temp8bit_path,patch_path,im_proj,band,size,radius,n_points,data_type='true')

    train_data = np.array(true_lbp_hists + false_lbp_hists)
    train_label = np.array(true_label + false_label)

    svc = SVC(C=0.8, kernel='rbf', gamma='scale', cache_size=1000)
    svc.fit(train_data, train_label)

    test_area = "E:/20200210/forest/gf2/dys_gf2_test.tif"
    ds = gdal.Open(test_area)
    im_width0 = ds.RasterXSize
    im_height0 = ds.RasterYSize
    im_geotrans0 = ds.GetGeoTransform()
    im_proj0 = ds.GetProjection()
    test_data = ds.ReadAsArray(0,0,im_width0,im_height0)
    # NOTE(review): the test image is stretched with its own min/max, which
    # need not match the stretch applied to the training scene.
    test_data = stretch_n(test_data, 0, 255)

    # Zero-pad size // 2 pixels on every side so border pixels get a window.
    test_data = pad_data(test_data,size)

    # Predict one image row at a time so only one row of feature vectors is
    # held in memory. range() works on Python 2 and 3 (xrange is 2-only).
    label_map = np.empty((im_height0, im_width0), dtype=train_label.dtype)
    for row in range(im_height0):
        row_features = []
        for col in range(im_width0):
            # In padded coordinates pixel (row, col) sits at
            # (row + size // 2, col + size // 2), so this slice starts
            # size // 2 before it. It is exactly size x size, like the
            # training patches; the old "+ size // 2 + 1" upper bound
            # produced 11 x 11 windows for size = 10.
            window = test_data[:, row:row + size, col:col + size]
            hists = [lbp(window[h, ...], n_points, radius) for h in range(band)]
            row_features.append(np.concatenate(hists, axis=0))
        label_map[row] = svc.predict(np.array(row_features))

    seg_path = "E:/20200210/forest/tif_temp/dys_seg.tif"
    write_img(seg_path, im_proj0, im_geotrans0, label_map)
    del ds

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章