監督分類:用 SVM(支持向量機)實現遙感影像監督分類

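環境:Python 2 或 3 均可,只需安裝 osgeo(GDAL,本機版本 2.2.4)、scikit-learn(本機版本 0.20.3)和 numpy(本機版本 1.15.0)。注意:`SVC(gamma='scale')` 需要 scikit-learn ≥ 0.20;NumPy ≥ 1.24 已移除 `np.float`,需改用內置的 `float`;較新版本的 GDAL 不再提供頂層的 `gdalconst` 模塊,需把 `from gdalconst import *` 改為 `from osgeo.gdalconst import *`。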

這段代碼以4波段遙感影像為例編寫;如果影像的波段數不同,請檢查代碼中與波段數相關的部分並相應修改。

測試數據(百度網盤):https://pan.baidu.com/s/14i-ePeWm-gnIPSsrgHmnMw
提取碼:qkgz

首先在影像上選取樣本點,這實際上就是選取影像上的像素值。我用 ArcGIS 選點,每個點矢量文件(.shp)代表一個類別(測試數據見上方鏈接)。然後修改代碼中對應的路徑即可。以下是代碼:

# -*- coding: utf-8 -*-

from osgeo import ogr

from osgeo import gdal

from gdalconst import *

import os, sys, time

import numpy as np

from sklearn.svm import SVC

def getPixels(shp, img, win=10):
    """Collect training pixels around every point of a point shapefile.

    For each point feature in ``shp`` a ``win`` x ``win`` pixel window is read
    from every band of ``img``.  The window's top-left corner lies
    ``win // 2`` pixels left of / above the pixel containing the point.

    Parameters
    ----------
    shp : str
        Path to an ESRI Shapefile of point features (one file per class).
    img : str
        Path to the raster the samples are read from.  The points are assumed
        to be in the raster's coordinate system (no reprojection is done).
    win : int, optional
        Side length in pixels of the square sampling window (default 10).

    Returns
    -------
    numpy.ndarray
        Array of shape ``(win * win * n_points, n_bands)``: one row per sampled
        pixel, one column per band.  Rows are ordered by position inside the
        window first, then by point.

    Prints a message and exits with status 1 if either file cannot be opened.
    Raises ValueError if no point has a window lying fully inside the raster.
    """
    driver = ogr.GetDriverByName('ESRI Shapefile')
    vector_ds = driver.Open(shp, 0)
    if vector_ds is None:
        print('Could not open ' + shp)
        sys.exit(1)

    layer = vector_ds.GetLayer()
    points = []
    feature = layer.GetNextFeature()
    while feature:
        geometry = feature.GetGeometryRef()
        points.append((geometry.GetX(), geometry.GetY()))
        feature = layer.GetNextFeature()

    gdal.AllRegister()
    ds = gdal.Open(img, GA_ReadOnly)
    if ds is None:
        print('Could not open image')
        sys.exit(1)

    rows = ds.RasterYSize
    cols = ds.RasterXSize
    bands = [ds.GetRasterBand(j + 1) for j in range(ds.RasterCount)]
    xOrigin, pixelWidth, _, yOrigin, _, pixelHeight = ds.GetGeoTransform()

    half = win // 2
    windows = []
    for x, y in points:
        xOffset = int((x - xOrigin) / pixelWidth) - half
        yOffset = int((y - yOrigin) / pixelHeight) - half
        # A window that sticks out of the raster cannot be read (ReadAsArray
        # gives None), so skip the point instead of crashing later.
        if (xOffset < 0 or yOffset < 0
                or xOffset + win > cols or yOffset + win > rows):
            print('Skipping point (%s, %s): sample window outside image'
                  % (x, y))
            continue
        # (n_bands, win * win): one flattened window per band.
        windows.append(np.stack(
            [b.ReadAsArray(xOffset, yOffset, win, win).ravel() for b in bands]))

    if not windows:
        raise ValueError('No usable sample points in ' + shp)

    # (n_points, n_bands, win*win) -> (win*win, n_points, n_bands), then one
    # row of band values per sampled pixel (pixel-major, point-minor order).
    samples = np.stack(windows).transpose(2, 0, 1)
    return samples.reshape(-1, len(bands))

 

def svmDeal(classArray, img_arr, outPath, im_proj, im_geotrans):
    """Train an SVM on per-class samples, classify every pixel, save the map.

    Parameters
    ----------
    classArray : sequence of numpy.ndarray
        One ``(n_samples_k, n_bands)`` array per class (e.g. from
        ``getPixels``).  The value written for a class is its index ``k`` in
        this sequence.  Classes may have different sample counts.
    img_arr : numpy.ndarray
        Image laid out as ``(cols, rows, n_bands)``, i.e. GDAL's
        ``(bands, rows, cols)`` array transposed with ``(2, 1, 0)`` as the
        ``__main__`` block does.  Any band count is accepted; the array is
        not modified.
    outPath : str
        Path of the output GeoTIFF.
    im_proj, im_geotrans
        Projection and geotransform copied to the output.

    Writes (via ``write_img``) a single-band ``(rows, cols)`` raster of class
    indices stored with ``img_arr``'s dtype.
    """
    # Concatenate directly: np.asarray() on arrays with different sample
    # counts builds a ragged array, which NumPy >= 1.24 rejects.
    train_x = np.concatenate([np.asarray(c) for c in classArray], axis=0)
    train_y = np.concatenate([np.full(len(c), k, dtype=float)
                              for k, c in enumerate(classArray)])

    n_cols, n_rows, n_bands = img_arr.shape
    pixels = img_arr.reshape(n_cols * n_rows, n_bands)

    # Alternative kernel: SVC(kernel='poly', degree=4, cache_size=1000, max_iter=100)
    svc = SVC(C=0.8, kernel='rbf', gamma='scale', cache_size=1000)
    svc.fit(train_x, train_y)
    predict = svc.predict(pixels)

    # predict[c * n_rows + r] is the class of the pixel at (row r, col c);
    # reshape and transpose back to GDAL's (rows, cols) layout.  Vectorized,
    # so it works for any number of bands.
    class_map = np.ascontiguousarray(predict.reshape(n_cols, n_rows).T,
                                     dtype=img_arr.dtype)
    write_img(outPath, im_proj, im_geotrans, class_map)

 

def read_img(filename):
    """Read an entire raster with GDAL.

    Returns
    -------
    tuple
        ``(im_proj, im_geotrans, im_width, im_height, im_data)``: projection
        string, geotransform, width and height in pixels, and the pixel array
        from ``Dataset.ReadAsArray`` (``(bands, height, width)`` for a
        multi-band raster; GDAL returns ``(height, width)`` for one band).

    Prints a message and exits with status 1 if the file cannot be opened,
    matching ``getPixels``.
    """
    dataset = gdal.Open(filename)
    if dataset is None:
        # Fail with a clear message instead of an AttributeError below.
        print('Could not open image ' + filename)
        sys.exit(1)

    im_width = dataset.RasterXSize
    im_height = dataset.RasterYSize
    im_geotrans = dataset.GetGeoTransform()
    im_proj = dataset.GetProjection()
    im_data = dataset.ReadAsArray(0, 0, im_width, im_height)
    del dataset  # release the file handle
    return im_proj, im_geotrans, im_width, im_height, im_data

 

def write_img(filename, im_proj, im_geotrans, im_data):
    """Write ``im_data`` to ``filename`` as a GeoTIFF.

    Parameters
    ----------
    filename : str
        Output path.
    im_proj, im_geotrans
        Projection string and geotransform attached to the output.
    im_data : numpy.ndarray
        ``(rows, cols)`` for a single band or ``(bands, rows, cols)``.

    The GDAL pixel type follows ``im_data.dtype``:
    - uint8 and int8 -> Byte;
    - uint16/int16/uint32/int32 -> the matching GDAL type;
    - anything else -> Float32.

    Prints a message and exits with status 1 if the file cannot be created.
    """
    # Exact dtype mapping; signed int16 must not become UInt16 (negative values
    # would wrap) and 32-bit ints must not lose precision in Float32.  int8
    # stays Byte because older GDAL versions have no signed 8-bit type.
    gdal_types = {
        'uint8': gdal.GDT_Byte,
        'int8': gdal.GDT_Byte,
        'uint16': gdal.GDT_UInt16,
        'int16': gdal.GDT_Int16,
        'uint32': gdal.GDT_UInt32,
        'int32': gdal.GDT_Int32,
    }
    datatype = gdal_types.get(im_data.dtype.name, gdal.GDT_Float32)

    if im_data.ndim == 3:
        im_bands, im_height, im_width = im_data.shape
        layers = im_data
    else:
        im_bands, (im_height, im_width) = 1, im_data.shape
        layers = im_data[np.newaxis]

    driver = gdal.GetDriverByName("GTiff")
    dataset = driver.Create(filename, im_width, im_height, im_bands, datatype)
    if dataset is None:
        print('Could not create ' + filename)
        sys.exit(1)
    dataset.SetGeoTransform(im_geotrans)
    dataset.SetProjection(im_proj)

    # Treat every input as a stack of bands so a (1, rows, cols) array also
    # writes correctly.
    for i, layer in enumerate(layers, 1):
        dataset.GetRasterBand(i).WriteArray(layer)
    del dataset  # release the dataset so GDAL writes it out

 

def array_change(inlist, outlist):
    """Append the transpose of ``inlist`` to ``outlist`` and return ``outlist``.

    The column count comes from the first row of ``inlist``; the ``i``-th
    list appended holds every row's ``i``-th element.  ``outlist`` is
    modified in place.  An empty ``inlist`` raises IndexError.
    """
    n_columns = len(inlist[0])
    outlist.extend([row[col] for row in inlist] for col in range(n_columns))
    return outlist

 

def array_change2(inlist, outlist):
    """Flatten ``inlist`` by one level onto the end of ``outlist`` (in place) and return it."""
    outlist.extend(item for group in inlist for item in group)
    return outlist

 

if __name__ == '__main__':
    img_p = 'E:/1/1/data2/cgnr/0/cgnr_0.tif'   # input image
    # Directory of point shapefiles, one per class: 0.shp -> class value 0,
    # 1.shp -> class value 1, 2.shp -> class value 2, ...  The output classes
    # follow these files.  Without shapefiles you can pass your own sample
    # arrays to svmDeal directly; shapefiles are used here for remote sensing.
    shp_path = 'E:/1/1/data2/cgnr/0/point2/'

    def _class_order(name):
        # Numeric stems sort numerically (so 10.shp follows 2.shp) and come
        # first; other names follow alphabetically.
        stem = os.path.splitext(name)[0]
        return (0, int(stem), name) if stem.isdigit() else (1, 0, name)

    # os.listdir() order is arbitrary; sort so that class k really is k.shp.
    shp_files = sorted((f for f in os.listdir(shp_path)
                        if f.lower().endswith('.shp')), key=_class_order)
    class_list = [getPixels(os.path.join(shp_path, f), img_p)
                  for f in shp_files]
    if not class_list:
        print('No .shp files found in ' + shp_path)
        sys.exit(1)

    time1 = time.time()
    im_proj, im_geotrans, im_width, im_height, im_data = read_img(img_p)
    if im_data.ndim == 2:
        # Single-band raster: add the band axis so the transpose below works.
        im_data = im_data[np.newaxis]
    # (bands, rows, cols) -> (cols, rows, bands), the layout svmDeal expects.
    im_data = im_data.transpose((2, 1, 0))
    out_path = 'E:/abg_test/1/data2/cgnr/0/cgnr_0_sd.tif'  # output result
    svmDeal(class_list, im_data, out_path, im_proj, im_geotrans)
    time2 = time.time()
    print((time2 - time1) / 3600)  # elapsed classification time in hours

 

 

有讀者反映上面的代碼運行不通,可以試試下面這個版本(邏輯相同,格式更緊湊):

# -*- coding: utf-8 -*-
from osgeo import ogr
from osgeo import gdal
from gdalconst import *
import os, sys, time
import numpy as np
from sklearn.svm import SVC


def getPixels(shp, img, win=10):
    """Collect training pixels around every point of a point shapefile.

    For each point feature in ``shp`` a ``win`` x ``win`` pixel window is read
    from every band of ``img``.  The window's top-left corner lies
    ``win // 2`` pixels left of / above the pixel containing the point.

    Parameters
    ----------
    shp : str
        Path to an ESRI Shapefile of point features (one file per class).
    img : str
        Path to the raster the samples are read from.  The points are assumed
        to be in the raster's coordinate system (no reprojection is done).
    win : int, optional
        Side length in pixels of the square sampling window (default 10).

    Returns
    -------
    numpy.ndarray
        Array of shape ``(win * win * n_points, n_bands)``: one row per sampled
        pixel, one column per band.  Rows are ordered by position inside the
        window first, then by point.

    Prints a message and exits with status 1 if either file cannot be opened.
    Raises ValueError if no point has a window lying fully inside the raster.
    """
    driver = ogr.GetDriverByName('ESRI Shapefile')
    vector_ds = driver.Open(shp, 0)
    if vector_ds is None:
        print('Could not open ' + shp)
        sys.exit(1)

    layer = vector_ds.GetLayer()
    points = []
    feature = layer.GetNextFeature()
    while feature:
        geometry = feature.GetGeometryRef()
        points.append((geometry.GetX(), geometry.GetY()))
        feature = layer.GetNextFeature()

    gdal.AllRegister()
    ds = gdal.Open(img, GA_ReadOnly)
    if ds is None:
        print('Could not open image')
        sys.exit(1)

    rows = ds.RasterYSize
    cols = ds.RasterXSize
    bands = [ds.GetRasterBand(j + 1) for j in range(ds.RasterCount)]
    xOrigin, pixelWidth, _, yOrigin, _, pixelHeight = ds.GetGeoTransform()

    half = win // 2
    windows = []
    for x, y in points:
        xOffset = int((x - xOrigin) / pixelWidth) - half
        yOffset = int((y - yOrigin) / pixelHeight) - half
        # A window that sticks out of the raster cannot be read (ReadAsArray
        # gives None), so skip the point instead of crashing later.
        if (xOffset < 0 or yOffset < 0
                or xOffset + win > cols or yOffset + win > rows):
            print('Skipping point (%s, %s): sample window outside image'
                  % (x, y))
            continue
        # (n_bands, win * win): one flattened window per band.
        windows.append(np.stack(
            [b.ReadAsArray(xOffset, yOffset, win, win).ravel() for b in bands]))

    if not windows:
        raise ValueError('No usable sample points in ' + shp)

    # (n_points, n_bands, win*win) -> (win*win, n_points, n_bands), then one
    # row of band values per sampled pixel (pixel-major, point-minor order).
    samples = np.stack(windows).transpose(2, 0, 1)
    return samples.reshape(-1, len(bands))


def svmDeal(classArray, img_arr, outPath, im_proj, im_geotrans):
    """Train an SVM on per-class samples, classify every pixel, save the map.

    Parameters
    ----------
    classArray : sequence of numpy.ndarray
        One ``(n_samples_k, n_bands)`` array per class (e.g. from
        ``getPixels``).  The value written for a class is its index ``k`` in
        this sequence.  Classes may have different sample counts.
    img_arr : numpy.ndarray
        Image laid out as ``(cols, rows, n_bands)``, i.e. GDAL's
        ``(bands, rows, cols)`` array transposed with ``(2, 1, 0)`` as the
        ``__main__`` block does.  Any band count is accepted; the array is
        not modified.
    outPath : str
        Path of the output GeoTIFF.
    im_proj, im_geotrans
        Projection and geotransform copied to the output.

    Writes (via ``write_img``) a single-band ``(rows, cols)`` raster of class
    indices stored with ``img_arr``'s dtype.
    """
    # Concatenate directly: np.asarray() on arrays with different sample
    # counts builds a ragged array, which NumPy >= 1.24 rejects.
    train_x = np.concatenate([np.asarray(c) for c in classArray], axis=0)
    train_y = np.concatenate([np.full(len(c), k, dtype=float)
                              for k, c in enumerate(classArray)])

    n_cols, n_rows, n_bands = img_arr.shape
    pixels = img_arr.reshape(n_cols * n_rows, n_bands)

    # Alternative kernel: SVC(kernel='poly', degree=4, cache_size=1000, max_iter=100)
    svc = SVC(C=0.8, kernel='rbf', gamma='scale', cache_size=1000)
    svc.fit(train_x, train_y)
    predict = svc.predict(pixels)

    # predict[c * n_rows + r] is the class of the pixel at (row r, col c);
    # reshape and transpose back to GDAL's (rows, cols) layout.  Vectorized,
    # so it works for any number of bands.
    class_map = np.ascontiguousarray(predict.reshape(n_cols, n_rows).T,
                                     dtype=img_arr.dtype)
    write_img(outPath, im_proj, im_geotrans, class_map)


def read_img(filename):
    """Read an entire raster with GDAL.

    Returns
    -------
    tuple
        ``(im_proj, im_geotrans, im_width, im_height, im_data)``: projection
        string, geotransform, width and height in pixels, and the pixel array
        from ``Dataset.ReadAsArray`` (``(bands, height, width)`` for a
        multi-band raster; GDAL returns ``(height, width)`` for one band).

    Prints a message and exits with status 1 if the file cannot be opened,
    matching ``getPixels``.
    """
    dataset = gdal.Open(filename)
    if dataset is None:
        # Fail with a clear message instead of an AttributeError below.
        print('Could not open image ' + filename)
        sys.exit(1)

    im_width = dataset.RasterXSize
    im_height = dataset.RasterYSize
    im_geotrans = dataset.GetGeoTransform()
    im_proj = dataset.GetProjection()
    im_data = dataset.ReadAsArray(0, 0, im_width, im_height)
    del dataset  # release the file handle
    return im_proj, im_geotrans, im_width, im_height, im_data


def write_img(filename, im_proj, im_geotrans, im_data):
    """Write ``im_data`` to ``filename`` as a GeoTIFF.

    Parameters
    ----------
    filename : str
        Output path.
    im_proj, im_geotrans
        Projection string and geotransform attached to the output.
    im_data : numpy.ndarray
        ``(rows, cols)`` for a single band or ``(bands, rows, cols)``.

    The GDAL pixel type follows ``im_data.dtype``:
    - uint8 and int8 -> Byte;
    - uint16/int16/uint32/int32 -> the matching GDAL type;
    - anything else -> Float32.

    Prints a message and exits with status 1 if the file cannot be created.
    """
    # Exact dtype mapping; signed int16 must not become UInt16 (negative values
    # would wrap) and 32-bit ints must not lose precision in Float32.  int8
    # stays Byte because older GDAL versions have no signed 8-bit type.
    gdal_types = {
        'uint8': gdal.GDT_Byte,
        'int8': gdal.GDT_Byte,
        'uint16': gdal.GDT_UInt16,
        'int16': gdal.GDT_Int16,
        'uint32': gdal.GDT_UInt32,
        'int32': gdal.GDT_Int32,
    }
    datatype = gdal_types.get(im_data.dtype.name, gdal.GDT_Float32)

    if im_data.ndim == 3:
        im_bands, im_height, im_width = im_data.shape
        layers = im_data
    else:
        im_bands, (im_height, im_width) = 1, im_data.shape
        layers = im_data[np.newaxis]

    driver = gdal.GetDriverByName("GTiff")
    dataset = driver.Create(filename, im_width, im_height, im_bands, datatype)
    if dataset is None:
        print('Could not create ' + filename)
        sys.exit(1)
    dataset.SetGeoTransform(im_geotrans)
    dataset.SetProjection(im_proj)

    # Treat every input as a stack of bands so a (1, rows, cols) array also
    # writes correctly.
    for i, layer in enumerate(layers, 1):
        dataset.GetRasterBand(i).WriteArray(layer)
    del dataset  # release the dataset so GDAL writes it out

def array_change(inlist, outlist):
    """Append the transpose of ``inlist`` to ``outlist`` and return ``outlist``.

    The column count comes from the first row of ``inlist``; the ``i``-th
    list appended holds every row's ``i``-th element.  ``outlist`` is
    modified in place.  An empty ``inlist`` raises IndexError.
    """
    n_columns = len(inlist[0])
    outlist.extend([row[col] for row in inlist] for col in range(n_columns))
    return outlist

def array_change2(inlist, outlist):
    """Flatten ``inlist`` by one level onto the end of ``outlist`` (in place) and return it."""
    outlist.extend(item for group in inlist for item in group)
    return outlist

if __name__ == '__main__':
    img_p = 'C:/Users/DELL/Desktop/data/g1_test.tif'
    # One point shapefile per class: 0.shp -> class value 0, 1.shp -> 1, ...
    shp_path = 'C:/Users/DELL/Desktop/data/point/'

    def _class_order(name):
        # Numeric stems sort numerically (so 10.shp follows 2.shp) and come
        # first; other names follow alphabetically.
        stem = os.path.splitext(name)[0]
        return (0, int(stem), name) if stem.isdigit() else (1, 0, name)

    # os.listdir() order is arbitrary; sort so that class k really is k.shp.
    shp_files = sorted((f for f in os.listdir(shp_path)
                        if f.lower().endswith('.shp')), key=_class_order)
    class_list = [getPixels(os.path.join(shp_path, f), img_p)
                  for f in shp_files]
    if not class_list:
        print('No .shp files found in ' + shp_path)
        sys.exit(1)

    time1 = time.time()
    im_proj, im_geotrans, im_width, im_height, im_data = read_img(img_p)
    if im_data.ndim == 2:
        # Single-band raster: add the band axis so the transpose below works.
        im_data = im_data[np.newaxis]
    # (bands, rows, cols) -> (cols, rows, bands), the layout svmDeal expects.
    im_data = im_data.transpose((2, 1, 0))

    out_path = 'C:/Users/DELL/Desktop/data/11.tif'
    svmDeal(class_list, im_data, out_path, im_proj, im_geotrans)

    time2 = time.time()
    print((time2 - time1) / 3600)  # elapsed classification time in hours
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章