python gdal根據圖像座標生成矢量框(含圖像座標轉地理座標)

要生成矢量框需要將圖像座標轉換爲地理座標或者投影座標,以下代碼是生成了滿足條件的1000*1000區域對應的矢量框,關鍵在於紅色字體部分。

# -*- coding: utf-8 -*-

import os

from osgeo import ogr, osr

import gdal

import heapq

import numba

import numpy as np

 

def read_img(filename):

    dataset=gdal.Open(filename)

 

    im_width = dataset.RasterXSize

    im_height = dataset.RasterYSize

 

    im_geotrans = dataset.GetGeoTransform()

    im_proj = dataset.GetProjection()

    im_data = dataset.ReadAsArray(0,0,im_width,im_height)

 

    del dataset 

    return im_proj,im_geotrans,im_width, im_height,im_data


 

def write_img(filename, im_proj, im_geotrans, im_data):

    if 'int8' in im_data.dtype.name:

        datatype = gdal.GDT_Byte

    elif 'int16' in im_data.dtype.name:

        datatype = gdal.GDT_UInt16

    else:

        datatype = gdal.GDT_Float32

 

    if len(im_data.shape) == 3:

        im_bands, im_height, im_width = im_data.shape

    else:

        im_bands, (im_height, im_width) = 1,im_data.shape 

 

    driver = gdal.GetDriverByName("GTiff")

    dataset = driver.Create(filename, im_width, im_height, im_bands, datatype)

 

    dataset.SetGeoTransform(im_geotrans)

    dataset.SetProjection(im_proj)

 

    if im_bands == 1:

        dataset.GetRasterBand(1).WriteArray(im_data)

    else:

        for i in range(im_bands):

            dataset.GetRasterBand(i+1).WriteArray(im_data[i])

 

    del dataset

 

@numba.jit

def conv2(X, k):

    x_row, x_col = X.shape

    k_row, k_col = k.shape

    ret_row, ret_col = x_row - k_row + 1, x_col - k_col + 1

    ret = np.empty((ret_row, ret_col))

    for y in range(ret_row):

        for x in range(ret_col):

            sub = X[y : y + k_row, x : x + k_col]

            ret[y,x] = np.sum(sub * k)

    return ret



 

def imagexy2geo(dataset, row, col):

    '''

    根據GDAL的六參數模型將影像圖上座標(行列號)轉爲投影座標或地理座標(根據具體數據的座標系統轉換)

    :param dataset: GDAL地理數據

    :param row: 像素的行號

    :param col: 像素的列號

    :return: 行列號(row, col)對應的投影座標或地理座標(x, y)

    '''

    trans = dataset.GetGeoTransform()

    px = trans[0] + col * trans[1] + row * trans[2]

    py = trans[3] + col * trans[4] + row * trans[5]

    return px, py

 

if __name__ == '__main__':

    

    ogr.RegisterAll()

 

    img_path = 'E:/wsl/pre/xm_15rgb.tif'

    temp = 'E:/wsl/pre/xm_15rgb_temp.tif'

    out = 'E:/wsl/shp/'

 

    im_proj,im_geotrans,im_width, im_height,im_data = read_img(img_path)

    im_data[im_data<50] = 0

    im_data[im_data>250] = 0

    write_img(temp, im_proj, im_geotrans, im_data)

    print('save temp')

 

    kernel = np.ones((1000,1000))


 

    dataset=gdal.Open(temp)

    im_width = dataset.RasterXSize

    im_height = dataset.RasterYSize

    im_geotrans = dataset.GetGeoTransform()

    im_proj = dataset.GetProjection()

    w_left = im_width%1000

    h_left = im_height%1000

    im_width_count = im_width/1000

    im_height_count = im_height/1000

    data_list = {}

    for i in range(im_width_count):

        for j in range(im_height_count):

            data = dataset.ReadAsArray(i*1000, j*1000, 1000, 1000)

            conv = np.sum(data * kernel)

            data_list[str(i)+ "_" + str(j)] = conv

    

    new_list = sorted(data_list,key=data_list.__getitem__, reverse=True)

    # new_list = sorted(data_list,key=data_list.__getitem__)

    count = 1

    for ele in new_list[0:50]:  #輸出滿足條件的前50個矢量框

        w, h = ele.split('_')

 

        w = int(w)*1000

        h = int(h)*1000

        wa, ha = imagexy2geo(dataset, h, w)

 

        w1 = int(w) + 1000

        h1 = int(h)

        wa1, ha1 = imagexy2geo(dataset, h1, w1)

 

        w2 = int(w) + 1000

        h2 = int(h) + 1000

        wa2, ha2 = imagexy2geo(dataset, h2, w2)

 

        w3 = int(w)

        h3 = int(h) + 1000

        wa3, ha3 = imagexy2geo(dataset, h3, w3)

 

        shp_path = os.path.join(out, str(count)+'.shp')

        driver = ogr.GetDriverByName("ESRI Shapefile")

        data_source = driver.CreateDataSource(shp_path)

        srs = osr.SpatialReference()

        srs.ImportFromEPSG(4326)

        layer = data_source.CreateLayer("polygon", srs, ogr.wkbPolygon)

        feature = ogr.Feature(layer.GetLayerDefn())

        wkt = "POLYGON((" + str(wa)+ " " +str(ha)+ "," + str(wa1) + " " + str(ha1) + "," + str(wa2)+ " " +str(ha2)+ "," + str(wa3)+ " " +str(ha3) + "))"

        point = ogr.CreateGeometryFromWkt(wkt)

        point.CloseRings()

        feature.SetGeometry(point)

        layer.CreateFeature(feature)

        feature = None

        data_source = None

        count += 1

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章