python gdal + skimage實現基於遙感影像的傳統圖像分割及合併外加矢量化

根據我前述博客中對圖像傳分割算法及圖像塊合併方法的實驗探究,在此將這些方法用於遙感影像並嘗試矢量化。
這個過程中我自己遇到了一個棘手的問題,在最後的結果那裏有描述,希望知道的朋友幫忙解答一下,謝謝!
直接上代碼:

# -*- coding: utf-8 -*-
import os
import cv2
import gdal
from osgeo import ogr,osr
import numpy as np
from skimage import morphology, color, measure
from skimage.segmentation import felzenszwalb, slic, quickshift
from skimage.segmentation import mark_boundaries
from skimage.util import img_as_float
from skimage.future import graph

def read_img(filename):
    dataset=gdal.Open(filename)

    im_width = dataset.RasterXSize
    im_height = dataset.RasterYSize

    im_geotrans = dataset.GetGeoTransform()
    im_proj = dataset.GetProjection()
    im_data = dataset.ReadAsArray(0,0,im_width,im_height)

    del dataset 
    return im_width,im_height,im_proj,im_geotrans,im_data

def write_img(filename,im_proj,im_geotrans,im_data):
    if 'int8' in im_data.dtype.name:
        datatype = gdal.GDT_Byte
    elif 'int16' in im_data.dtype.name:
        datatype = gdal.GDT_UInt16
    else:
        datatype = gdal.GDT_Float32

    if len(im_data.shape) == 3:
        im_bands, im_height, im_width = im_data.shape
    else:
        im_bands, (im_height, im_width) = 1,im_data.shape 

    driver = gdal.GetDriverByName("GTiff")
    dataset = driver.Create(filename, im_width, im_height, im_bands, datatype)

    dataset.SetGeoTransform(im_geotrans)
    dataset.SetProjection(im_proj)
    if im_bands == 1:
        dataset.GetRasterBand(1).WriteArray(im_data)
    else:
        for i in range(im_bands):
            dataset.GetRasterBand(i+1).WriteArray(im_data[i])

    del dataset

def DoesDriverHandleExtension(drv, ext):
    exts = drv.GetMetadataItem(gdal.DMD_EXTENSIONS)
    return exts is not None and exts.lower().find(ext.lower()) >= 0


def GetExtension(filename):
    ext = os.path.splitext(filename)[1]
    if ext.startswith('.'):
        ext = ext[1:]
    return ext


def GetOutputDriversFor(filename):
    drv_list = []
    ext = GetExtension(filename)
    for i in range(gdal.GetDriverCount()):
        drv = gdal.GetDriver(i)
        if (drv.GetMetadataItem(gdal.DCAP_CREATE) is not None or
            drv.GetMetadataItem(gdal.DCAP_CREATECOPY) is not None) and \
           drv.GetMetadataItem(gdal.DCAP_VECTOR) is not None:
            if ext and DoesDriverHandleExtension(drv, ext):
                drv_list.append(drv.ShortName)
            else:
                prefix = drv.GetMetadataItem(gdal.DMD_CONNECTION_PREFIX)
                if prefix is not None and filename.lower().startswith(prefix.lower()):
                    drv_list.append(drv.ShortName)

    return drv_list

def GetOutputDriverFor(filename):
    drv_list = GetOutputDriversFor(filename)
    ext = GetExtension(filename)
    if not drv_list:
        if not ext:
            return 'ESRI Shapefile'
        else:
            raise Exception("Cannot guess driver for %s" % filename)
    elif len(drv_list) > 1:
        print("Several drivers matching %s extension. Using %s" % (ext if ext else '', drv_list[0]))
    return drv_list[0]

def _weight_mean_color(graph, src, dst, n):
    """Callback to handle merging nodes by recomputing mean color.
    The method expects that the mean color of `dst` is already computed.
    Parameters
    ----------
    graph : RAG
        The graph under consideration.
    src, dst : int
        The vertices in `graph` to be merged.
    n : int
        A neighbor of `src` or `dst` or both.

    Returns
    -------
    data : dict
        A dictionary with the `"weight"` attribute set as the absolute
        difference of the mean color between node `dst` and `n`.
    """
    diff = graph.nodes[dst]['mean color'] - graph.nodes[n]['mean color']
    diff = np.linalg.norm(diff)
    return {'weight': diff}

def merge_mean_color(graph, src, dst):
    """Callback called before merging two nodes of a mean color distance graph.
    This method computes the mean color of `dst`.
    Parameters
    ----------
    graph : RAG
        The graph under consideration.
    src, dst : int
        The vertices in `graph` to be merged.
    """
    graph.nodes[dst]['total color'] += graph.nodes[src]['total color']
    graph.nodes[dst]['pixel count'] += graph.nodes[src]['pixel count']
    graph.nodes[dst]['mean color'] = (graph.nodes[dst]['total color'] /
                                      graph.nodes[dst]['pixel count'])

if __name__ == '__main__':
    img_path = "E:/geo_test/test.tif"
    temp_path = "E:/geo_test/temp/"
    im_width,im_height,im_proj,im_geotrans,im_data = read_img(img_path)  
    temp = im_data.transpose((2,1,0))
    segments_quick = quickshift(temp, kernel_size=3, max_dist=6, ratio=0.5)
    
    mark0 = mark_boundaries(temp, segments_quick)
    save_path = temp_path + "qs_seg0.tif"
    re0 = mark0.transpose((2,1,0))
    write_img(save_path,im_proj,im_geotrans,re0)

    grid_path = temp_path + "qs_grid0.tif"
    grid0 = np.uint8(re0[0,...])
    write_img(grid_path,im_proj,im_geotrans,grid0)

    skeleton = morphology.skeletonize(grid0)
    border0 = np.multiply(grid0, skeleton)
    ret,border0 = cv2.threshold(border0,0,1,cv2.THRESH_BINARY+cv2.THRESH_OTSU)
    border_path = temp_path + "qs_border0.tif"
    write_img(border_path,im_proj,im_geotrans,border0)
    
    g = graph.rag_mean_color(temp, segments_quick)
    labels2 = graph.merge_hierarchical(segments_quick, g, thresh=5, 
		      rag_copy=False,
              in_place_merge=True,
              merge_func=merge_mean_color,
              weight_func=_weight_mean_color)
    label_rgb2 = color.label2rgb(labels2, temp, kind='avg')
    rgb_path = temp_path + "qs_label.tif"
    lb = labels2.transpose((1,0))
    # lb = median(lb, disk(3))
    write_img(rgb_path,im_proj,im_geotrans,lb)
    
    mark = mark_boundaries(label_rgb2, labels2)
    save_path = temp_path + "qs_seg.tif"
    re = mark.transpose((2,1,0))
    write_img(save_path,im_proj,im_geotrans,re)

    grid_path = temp_path + "qs_grid.tif"
    grid = np.uint8(re[0,...])
    write_img(grid_path,im_proj,im_geotrans,grid)

    skeleton = morphology.skeletonize(grid)
    border = np.multiply(grid, skeleton)
    ret,border = cv2.threshold(border,0,1,cv2.THRESH_BINARY+cv2.THRESH_OTSU)
    border_path = temp_path + "qs_border.tif"
    write_img(border_path,im_proj,im_geotrans,border)

    # out_shp = temp_path + "temp.shp"
    # RasterToLineshp(border_path, out_shp, 1)

    border_driver = gdal.Open(rgb_path)
    border_band = border_driver.GetRasterBand(1)
    border_mask = border_band.GetMaskBand()

    dst_filename = temp_path + 'temp.shp'
    frmt = GetOutputDriverFor(dst_filename)
    drv = ogr.GetDriverByName(frmt)
    dst_ds = drv.CreateDataSource(dst_filename)
    
    dst_layername = 'out'
    srs = osr.SpatialReference(wkt=border_driver.GetProjection())
    dst_layer = dst_ds.CreateLayer(dst_layername, geom_type=ogr.wkbPolygon, srs=srs)
    # dst_layer = dst_ds.CreateLayer(dst_layername, geom_type=ogr.wkbLineString, srs=srs)


    dst_fieldname = 'DN'
    fd = ogr.FieldDefn(dst_fieldname, ogr.OFTInteger)
    dst_layer.CreateField(fd)
    dst_field = 0

    options = [""]
    options.append('DATASET_FOR_GEOREF=' + rgb_path)
    prog_func = gdal.TermProgress_nocb
    gdal.Polygonize(border_band, border_mask, dst_layer, dst_field, options,
                         callback=prog_func)

    srcband = None
    src_ds = None
    dst_ds = None
    mask_ds = None

# enum WKBGeometryType {
# wkbPoint = 1,
# wkbLineString = 2,
# wkbPolygon = 3,
# wkbTriangle = 17
# wkbMultiPoint = 4,
# wkbMultiLineString = 5,
# wkbMultiPolygon = 6,
# wkbGeometryCollection = 7,
# wkbPolyhedralSurface = 15,
# wkbTIN = 16
# wkbPointZ = 1001,
# wkbLineStringZ = 1002,
# wkbPolygonZ = 1003,
# wkbTrianglez = 1017
# wkbMultiPointZ = 1004,
# wkbMultiLineStringZ = 1005,
# wkbMultiPolygonZ = 1006,
# wkbGeometryCollectionZ = 1007,
# wkbPolyhedralSurfaceZ = 1015,
# wkbTINZ = 1016
# wkbPointM = 2001,
# wkbLineStringM = 2002,
# wkbPolygonM = 2003,
# wkbTriangleM = 2017
# wkbMultiPointM = 2004,
# wkbMultiLineStringM = 2005,
# wkbMultiPolygonM = 2006,
# wkbGeometryCollectionM = 2007,
# wkbPolyhedralSurfaceM = 2015,
# wkbTINM = 2016
# wkbPointZM = 3001,
# wkbLineStringZM = 3002,
# wkbPolygonZM = 3003,
# wkbTriangleZM = 3017
# wkbMultiPointZM = 3004,
# wkbMultiLineStringZM = 3005,
# wkbMultiPolygonZM = 3006,
# wkbGeometryCollectionZM = 3007,
# wkbPolyhedralSurfaceZM = 3015,
# wkbTinZM = 3016,
# }

對應的結果圖如下:
原圖:
原圖
粗分割結果(代碼中的qs_seg0.tif)
粗分割結果
粗分割格網(代碼中的qs_grid0.tif)
粗分割格網
粗分割格網骨架(代碼中的qs_border0.tif),格網的結果不是單線的,這裏取了中心線。
粗分割格網骨架
合併後的分割結果(代碼中的qs_seg.tif):
合併後的粗分割結果
合併後的格網結果(代碼中的qs_grid.tif)
合併後的格網結果
合併後的格網骨架結果(代碼中的qs_border.tif):
合併後的格網骨架結果
下面是矢量化以後的最終結果,這是代碼中的qs_label.tif經過矢量化以後得到的結果,這裏說明一下,之所以不用柵格線來直接轉矢量線是因爲我在GDAL裏面並沒有找到直接轉化的方法,目前的方法強行轉的話只能得到雙線,完全不對,找了很久也沒找到解決辦法只能折中一下先得到面了,後面再面轉線,看到的朋友如果知道的話煩請告知一下用什麼辦法可以直接把柵格線轉爲矢量線,要求脫離arcgis哈。
矢量化以後的結果

TO DO:
1.矢量面轉線
2.線簡化
3.線平滑
做完更新,感興趣的朋友可以關注一下。

後續:
目前矢量面轉矢量線肯定是沒問題的,但是有個大問題就是矢量線的平滑對我來說還有一定難度,想不到具體高效的方式,唯一想到的方式就是將圖層裏的每一個節點找到,在節點位置不變的情況下取出節點之間的線條逐個平滑再放回到圖層中,這樣做有點慢,並且實現起來也比較複雜感覺,所以再次折中,我直接進行面的平滑,平滑完了再轉線看看有沒有可能對結果有幫助。
雖然不做線平滑了,下面還是先給出面專線的代碼:

# -*- coding: utf-8 -*-
import os
import gdal
from osgeo import ogr,osr
import numpy as np

def Test_Poly2Line(input_poly,output_line):
    ogr.RegisterAll()
    
    driver = ogr.GetDriverByName('ESRI Shapefile')
    source_ds = driver.Open(input_poly,1)   
    source_layer = source_ds.GetLayer(0)

    # polygon2geometryCollection
    geomcol =  ogr.Geometry(ogr.wkbGeometryCollection)
    for feat in source_layer:
        geom = feat.GetGeometryRef()
        ring = geom.GetGeometryRef(0)
        geomcol.AddGeometry(ring)
        
    # geometryCollection2shp
    shpDriver = ogr.GetDriverByName("ESRI Shapefile")
    if os.path.exists(output_line):
            shpDriver.DeleteDataSource(output_line)
    outDataSource = shpDriver.CreateDataSource(output_line)
    outLayer = outDataSource.CreateLayer(output_line, geom_type=ogr.wkbMultiLineString)
    featureDefn = outLayer.GetLayerDefn()
    outFeature = ogr.Feature(featureDefn)
    outFeature.SetGeometry(geomcol)
    outLayer.CreateFeature(outFeature)
    outFeature = None


if __name__ == "__main__":
    poly_path = "E:/geo_test/temp/temp.shp"
    line_path = "E:/geo_test/temp/temp2line.shp"
    Test_Poly2Line(poly_path, line_path)

結果如下,可以看到這個結果和麪完全保持一致,畢竟是gdal源碼哈哈。
面專線

下面說一下在面未轉爲線的時候就平滑,在下面的位置加入了中值濾波
在這裏插入圖片描述
這是柵格面平滑後轉化爲面矢量的結果
面平滑
這是和之前沒有進行平滑的結果的疊加對比,變化是有的,但是這裏有一個大問題,就是鋸齒狀太嚴重。
對比

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章