前面已經做了很多像素級的分類了,這裏繼續深入,這裏用的特徵是LBP(局部二值模式,Local Binary Pattern),思路就是先手動選點取樣本點鄰域並提取LBP特徵進行訓練得到模型,然後取每個像素的鄰域提取LBP特徵,再用訓練好的模型對每個像素對應鄰域的LBP特徵進行判斷,最終確定像素的類型
這裏放了一個初版和一個改進後的版本,初版裏面有想記錄的東西所以就也放這裏了,過客們可以直接往下翻看終版。
# -*- coding: utf-8 -*-
import os, sys, time
import warnings
import gdal
import numpy as np
from numpy import average, dot, linalg
import cv2
import skimage
from osgeo import ogr
from osgeo import gdal
from osgeo import gdal_array as ga
from gdalconst import *
from skimage.feature import local_binary_pattern
from skimage.util.shape import view_as_windows #這是影像切片的模塊,存在邊界無法處理的缺陷
from sklearn.svm import SVC
from numba import jit
def read_img(filename):
    """Open a raster with GDAL and return (projection, geotransform, width, height, pixel array)."""
    ds = gdal.Open(filename)
    width = ds.RasterXSize
    height = ds.RasterYSize
    geotrans = ds.GetGeoTransform()
    proj = ds.GetProjection()
    # read every band of the full extent into one numpy array
    data = ds.ReadAsArray(0, 0, width, height)
    del ds  # release the GDAL dataset handle
    return proj, geotrans, width, height, data
def write_img(filename, im_proj, im_geotrans, im_data):
    """Write a 2-D (rows, cols) or 3-D (bands, rows, cols) array to a GeoTIFF.

    The GDAL data type is inferred from the array dtype name:
    *int8 -> Byte, *int16 -> UInt16, anything else -> Float32.
    """
    dtype_name = im_data.dtype.name
    if 'int8' in dtype_name:
        datatype = gdal.GDT_Byte
    elif 'int16' in dtype_name:
        datatype = gdal.GDT_UInt16
    else:
        datatype = gdal.GDT_Float32
    if im_data.ndim == 3:
        im_bands, im_height, im_width = im_data.shape
    else:
        im_bands = 1
        im_height, im_width = im_data.shape
    ds = gdal.GetDriverByName("GTiff").Create(filename, im_width, im_height, im_bands, datatype)
    ds.SetGeoTransform(im_geotrans)
    ds.SetProjection(im_proj)
    if im_bands == 1:
        ds.GetRasterBand(1).WriteArray(im_data)
    else:
        # GDAL band indices are 1-based
        for b in range(im_bands):
            ds.GetRasterBand(b + 1).WriteArray(im_data[b])
    del ds  # flush and close the file
def stretch_n(bands, img_min, img_max, lower_percent=0, higher_percent=100):
    """Percentile-stretch an array to [img_min, img_max] and return it as uint8.

    BUG FIX: when the image is constant (d == c) the original assigned the
    scalar 0 to `t` and then crashed on `t[t < a]`; a constant image now maps
    to the lower bound `img_min` instead.
    NOTE: the @jit decorator was dropped — np.percentile is not
    nopython-compilable, so numba only ran this in object mode anyway.
    """
    out = np.zeros_like(bands).astype(np.float32)
    a = img_min
    b = img_max
    c = np.percentile(bands[:, :], lower_percent)
    d = np.percentile(bands[:, :], higher_percent)
    if d - c == 0:
        # degenerate (constant) image: everything goes to the lower bound
        out[:, :] = a
    else:
        # linear stretch of [c, d] onto [a, b], clipped to the target range
        t = a + (bands[:, :] - c) * (b - a) / (d - c)
        t[t < a] = a
        t[t > b] = b
        out[:, :] = t
    return np.uint8(out)
def getPixels(shp, img, size):
    """Clip a size x size patch around every point of a point shapefile.

    Returns a list of [patch_array, patch_geotransform] pairs, where the
    patch geotransform places the sample point at the patch centre.
    """
    driver = ogr.GetDriverByName('ESRI Shapefile')
    shp_ds = driver.Open(shp, 0)
    if shp_ds is None:
        print('Could not open ' + shp)
        sys.exit(1)
    layer = shp_ds.GetLayer()
    xValues = []
    yValues = []
    feature = layer.GetNextFeature()
    while feature:
        geometry = feature.GetGeometryRef()
        xValues.append(geometry.GetX())
        yValues.append(geometry.GetY())
        feature = layer.GetNextFeature()
    del shp_ds  # release the shapefile handle (the original leaked it)

    gdal.AllRegister()
    ds = gdal.Open(img, GA_ReadOnly)
    if ds is None:
        print('Could not open image')
        sys.exit(1)
    transform = ds.GetGeoTransform()
    values = []
    for x, y in zip(xValues, yValues):
        # Patch geotransform: shift the origin half a window up and left.
        # BUG FIX: the original read the module-level global `im_geotrans`
        # here; the geotransform of the image actually being read must be used.
        new_transform = list(transform)
        new_transform[0] = x - transform[1] * int(size) / 2.0
        new_transform[3] = y - transform[5] * int(size) / 2.0
        # pixel/line offset of the patch's upper-left corner in the image
        Xpix = (new_transform[0] - transform[0]) / transform[1]
        Ypix = (new_transform[3] - transform[3]) / transform[5]
        data = ds.ReadAsArray(int(Xpix), int(Ypix), int(size), int(size))
        values.append([data, tuple(new_transform)])
    return values
def lbp(img, n_points, radius, level=256, method='default'):
    """Return a density-normalized histogram of LBP codes for one image band.

    BUG FIX: with method='default', local_binary_pattern produces codes in
    [0, 2**n_points), not [0, level).  The original hard-coded 256 bins and
    range (0, 256), so for n_points > 8 (here n_points = 16) every code
    >= 256 was silently dropped from the histogram — a likely cause of poor
    classification.  For large n_points, method='uniform' (n_points + 2
    distinct codes) is the usual compact alternative.
    """
    codes = local_binary_pattern(img, n_points, radius, method)
    if method == 'default':
        n_bins = 2 ** n_points  # full code range for the default method
    else:
        n_bins = level  # caller-chosen bin count for other methods
    hist, _ = np.histogram(codes, density=True, bins=n_bins, range=(0, n_bins))
    return hist
def get_data(data_point, img_path, patch_path, im_proj, band, size, radius, n_points, data_type='false'):
    """Build per-point training features and labels from a point shapefile.

    For every sample point: clip a size x size patch, save it to patch_path
    as a GeoTIFF (for inspection), compute one LBP histogram per band and
    concatenate them into a single feature vector.
    Returns (feature_vectors, labels); data_type == 'false' -> 0, else 1.
    """
    patches = getPixels(data_point, img_path, size)
    hists = []
    labels = []
    label = 0 if data_type == 'false' else 1
    for idx, (arr, geotrans) in enumerate(patches):
        out_tif = os.path.join(patch_path, data_type + '_' + str(idx) + '.tif')
        write_img(out_tif, im_proj, geotrans, arr)
        # one histogram per band, joined into one feature vector
        band_hists = [lbp(arr[i, ...], n_points, radius) for i in range(band)]
        hists.append(np.concatenate(band_hists, axis=0))
        labels.append(label)
    return hists, labels
if __name__ == "__main__":
    warnings.filterwarnings("ignore")
    # --- inputs -----------------------------------------------------------
    img_path = "E:/20200210/forest/gf2/dys_gf2.tif"        # full scene
    false_point = 'E:/20200210/forest/gf2/point/1.shp'     # negative sample points
    true_point = 'E:/20200210/forest/gf2/point/2.shp'      # positive sample points
    patch_path = "E:/20200210/forest/tif_temp/patches2/"   # patches clipped around each sample point
    im_proj, im_geotrans, im_width, im_height, im_data = read_img(img_path)
    # source imagery is 16 bit; stretch to 8 bit to cut computation cost
    temp8bit_path = "E:/20200210/forest/gf2/dys_gf2_8bit.tif"
    temp8bit = stretch_n(im_data, 0, 255)
    write_img(temp8bit_path, im_proj, im_geotrans, temp8bit)
    # --- LBP feature extraction + SVM training ----------------------------
    band = 4               # number of image bands
    size = 10              # neighbourhood / patch size
    radius = 2             # LBP radius
    n_points = 8 * radius  # LBP sampling points
    false_lbp_hists, false_label = get_data(false_point, temp8bit_path, patch_path, im_proj, band, size, radius, n_points, data_type='false')
    true_lbp_hists, true_label = get_data(true_point, temp8bit_path, patch_path, im_proj, band, size, radius, n_points, data_type='true')
    train_data = np.array(true_lbp_hists + false_lbp_hists)
    train_label = np.array(true_label + false_label)
    # SVM chosen here; any other classifier (e.g. random forest) works too
    svc = SVC(C=0.8, kernel='rbf', gamma='scale', cache_size=1000)
    svc.fit(train_data, train_label)
    # --- sliding-window prediction over the test area (slow; keep it small)
    test_area = "E:/20200210/forest/gf2/dys_gf2_test.tif"
    ds = gdal.Open(test_area)
    im_width0 = ds.RasterXSize
    im_height0 = ds.RasterYSize
    im_geotrans0 = ds.GetGeoTransform()
    im_proj0 = ds.GetProjection()
    test_data = ds.ReadAsArray(0, 0, im_width0, im_height0)
    test_data = stretch_n(test_data, 0, 255)
    # view_as_windows slides a (band, size, size) window with step 1, so the
    # rightmost/bottom size-1 pixels never get a window of their own.
    # FIXES vs the original: window shape uses the band/size constants instead
    # of a hard-coded (4, 10, 10); the window count along each axis is
    # dim - size + 1 (the original's dim - size dropped one extra row/column);
    # xrange -> range for Python 3.
    xlength = im_width0 - size + 1
    ylength = im_height0 - size + 1
    windows = np.squeeze(view_as_windows(test_data, (band, size, size)))
    all_arr = []
    for i in range(ylength):
        for j in range(xlength):
            hists = [lbp(windows[i, j][h, ...], n_points, radius) for h in range(band)]
            all_arr.append(np.concatenate(hists, axis=0))
    predict = svc.predict(np.array(all_arr))
    # the unprocessed margins make the output raster smaller than the input
    re = predict.reshape((ylength, xlength))
    seg_path = "E:/20200210/forest/tif_temp/dys_seg.tif"
    write_img(seg_path, im_proj0, im_geotrans0, re)
    del ds
改進版本
由於上述的圖像切片不但不能處理邊界,而且每次的切片都是以掩膜左上角爲準的鄰域,而不是以每個待判別像素爲中心的鄰域,問題很大。下面的代碼給圖像上下左右增加了padding,輔助判別邊界點,代碼的整體邏輯沒什麼問題了,但是效果不太理想,我推測是特徵部分有問題,有時間再探究,你們就作爲進一步探索的參考吧。
# -*- coding: utf-8 -*-
import os, sys, time
import warnings
import gdal
import numpy as np
from numpy import average, dot, linalg
import cv2
import skimage
from osgeo import ogr
from osgeo import gdal
from osgeo import gdal_array as ga
from gdalconst import *
from skimage.feature import local_binary_pattern
from skimage.util.shape import view_as_windows, view_as_blocks
from sklearn.svm import SVC
from numba import jit
def read_img(filename):
    """Open a raster with GDAL and return (projection, geotransform, width, height, pixel array)."""
    ds = gdal.Open(filename)
    width = ds.RasterXSize
    height = ds.RasterYSize
    geotrans = ds.GetGeoTransform()
    proj = ds.GetProjection()
    # read every band of the full extent into one numpy array
    data = ds.ReadAsArray(0, 0, width, height)
    del ds  # release the GDAL dataset handle
    return proj, geotrans, width, height, data
def write_img(filename, im_proj, im_geotrans, im_data):
    """Write a 2-D (rows, cols) or 3-D (bands, rows, cols) array to a GeoTIFF.

    The GDAL data type is inferred from the array dtype name:
    *int8 -> Byte, *int16 -> UInt16, anything else -> Float32.
    """
    dtype_name = im_data.dtype.name
    if 'int8' in dtype_name:
        datatype = gdal.GDT_Byte
    elif 'int16' in dtype_name:
        datatype = gdal.GDT_UInt16
    else:
        datatype = gdal.GDT_Float32
    if im_data.ndim == 3:
        im_bands, im_height, im_width = im_data.shape
    else:
        im_bands = 1
        im_height, im_width = im_data.shape
    ds = gdal.GetDriverByName("GTiff").Create(filename, im_width, im_height, im_bands, datatype)
    ds.SetGeoTransform(im_geotrans)
    ds.SetProjection(im_proj)
    if im_bands == 1:
        ds.GetRasterBand(1).WriteArray(im_data)
    else:
        # GDAL band indices are 1-based
        for b in range(im_bands):
            ds.GetRasterBand(b + 1).WriteArray(im_data[b])
    del ds  # flush and close the file
def stretch_n(bands, img_min, img_max, lower_percent=0, higher_percent=100):
    """Percentile-stretch an array to [img_min, img_max] and return it as uint8.

    BUG FIX: when the image is constant (d == c) the original assigned the
    scalar 0 to `t` and then crashed on `t[t < a]`; a constant image now maps
    to the lower bound `img_min` instead.
    NOTE: the @jit decorator was dropped — np.percentile is not
    nopython-compilable, so numba only ran this in object mode anyway.
    """
    out = np.zeros_like(bands).astype(np.float32)
    a = img_min
    b = img_max
    c = np.percentile(bands[:, :], lower_percent)
    d = np.percentile(bands[:, :], higher_percent)
    if d - c == 0:
        # degenerate (constant) image: everything goes to the lower bound
        out[:, :] = a
    else:
        # linear stretch of [c, d] onto [a, b], clipped to the target range
        t = a + (bands[:, :] - c) * (b - a) / (d - c)
        t[t < a] = a
        t[t > b] = b
        out[:, :] = t
    return np.uint8(out)
def pad_data(data, nei_size):
    """Zero-pad a (bands, rows, cols) array by nei_size // 2 pixels on every side.

    IMPROVEMENTS: np.pad preserves the input dtype — the original
    concatenated float64 zero blocks, silently upcasting uint8 imagery to
    float64.  The @jit decorator was dropped: np.zeros with a list shape and
    np.concatenate in this form are not nopython-compilable, so numba only
    added object-mode overhead.
    """
    pad = nei_size // 2
    # pad only the spatial axes, never the band axis
    return np.pad(data, ((0, 0), (pad, pad), (pad, pad)), mode='constant')
def getPixels(shp, img, size):
    """Clip a size x size patch around every point of a point shapefile.

    Returns a list of [patch_array, patch_geotransform] pairs, where the
    patch geotransform places the sample point at the patch centre.
    """
    driver = ogr.GetDriverByName('ESRI Shapefile')
    shp_ds = driver.Open(shp, 0)
    if shp_ds is None:
        print('Could not open ' + shp)
        sys.exit(1)
    layer = shp_ds.GetLayer()
    xValues = []
    yValues = []
    feature = layer.GetNextFeature()
    while feature:
        geometry = feature.GetGeometryRef()
        xValues.append(geometry.GetX())
        yValues.append(geometry.GetY())
        feature = layer.GetNextFeature()
    del shp_ds  # release the shapefile handle (the original leaked it)

    gdal.AllRegister()
    ds = gdal.Open(img, GA_ReadOnly)
    if ds is None:
        print('Could not open image')
        sys.exit(1)
    transform = ds.GetGeoTransform()
    values = []
    for x, y in zip(xValues, yValues):
        # Patch geotransform: shift the origin half a window up and left.
        # BUG FIX: the original read the module-level global `im_geotrans`
        # here; the geotransform of the image actually being read must be used.
        new_transform = list(transform)
        new_transform[0] = x - transform[1] * int(size) / 2.0
        new_transform[3] = y - transform[5] * int(size) / 2.0
        # pixel/line offset of the patch's upper-left corner in the image
        Xpix = (new_transform[0] - transform[0]) / transform[1]
        Ypix = (new_transform[3] - transform[3]) / transform[5]
        data = ds.ReadAsArray(int(Xpix), int(Ypix), int(size), int(size))
        values.append([data, tuple(new_transform)])
    return values
def lbp(img, n_points, radius, level=256, method='default'):
    """Return a density-normalized histogram of LBP codes for one image band.

    BUG FIX: with method='default', local_binary_pattern produces codes in
    [0, 2**n_points), not [0, level).  The original hard-coded 256 bins and
    range (0, 256), so for n_points > 8 (here n_points = 16) every code
    >= 256 was silently dropped from the histogram — a likely cause of poor
    classification.  For large n_points, method='uniform' (n_points + 2
    distinct codes) is the usual compact alternative.
    """
    codes = local_binary_pattern(img, n_points, radius, method)
    if method == 'default':
        n_bins = 2 ** n_points  # full code range for the default method
    else:
        n_bins = level  # caller-chosen bin count for other methods
    hist, _ = np.histogram(codes, density=True, bins=n_bins, range=(0, n_bins))
    return hist
def get_data(data_point, img_path, patch_path, im_proj, band, size, radius, n_points, data_type='false'):
    """Build per-point training features and labels from a point shapefile.

    For every sample point: clip a size x size patch, save it to patch_path
    as a GeoTIFF (for inspection), compute one LBP histogram per band and
    concatenate them into a single feature vector.
    Returns (feature_vectors, labels); data_type == 'false' -> 0, else 1.
    """
    patches = getPixels(data_point, img_path, size)
    hists = []
    labels = []
    label = 0 if data_type == 'false' else 1
    for idx, (arr, geotrans) in enumerate(patches):
        out_tif = os.path.join(patch_path, data_type + '_' + str(idx) + '.tif')
        write_img(out_tif, im_proj, geotrans, arr)
        # one histogram per band, joined into one feature vector
        band_hists = [lbp(arr[i, ...], n_points, radius) for i in range(band)]
        hists.append(np.concatenate(band_hists, axis=0))
        labels.append(label)
    return hists, labels
if __name__ == "__main__":
    warnings.filterwarnings("ignore")
    # --- inputs -----------------------------------------------------------
    img_path = "E:/20200210/forest/gf2/dys_gf2.tif"        # full scene
    false_point = 'E:/20200210/forest/gf2/point/1.shp'     # negative sample points
    true_point = 'E:/20200210/forest/gf2/point/2.shp'      # positive sample points
    patch_path = "E:/20200210/forest/tif_temp/patches2/"   # patches clipped around each sample point
    im_proj, im_geotrans, im_width, im_height, im_data = read_img(img_path)
    # source imagery is 16 bit; stretch to 8 bit to cut computation cost
    temp8bit_path = "E:/20200210/forest/gf2/dys_gf2_8bit.tif"
    temp8bit = stretch_n(im_data, 0, 255)
    write_img(temp8bit_path, im_proj, im_geotrans, temp8bit)
    # --- LBP feature extraction + SVM training ----------------------------
    band = 4               # number of image bands
    size = 10              # neighbourhood / patch size
    radius = 2             # LBP radius
    n_points = 8 * radius  # LBP sampling points
    false_lbp_hists, false_label = get_data(false_point, temp8bit_path, patch_path, im_proj, band, size, radius, n_points, data_type='false')
    true_lbp_hists, true_label = get_data(true_point, temp8bit_path, patch_path, im_proj, band, size, radius, n_points, data_type='true')
    train_data = np.array(true_lbp_hists + false_lbp_hists)
    train_label = np.array(true_label + false_label)
    # SVM chosen here; any other classifier (e.g. random forest) works too
    svc = SVC(C=0.8, kernel='rbf', gamma='scale', cache_size=1000)
    svc.fit(train_data, train_label)
    # --- per-pixel prediction over the test area --------------------------
    test_area = "E:/20200210/forest/gf2/dys_gf2_test.tif"
    ds = gdal.Open(test_area)
    im_width0 = ds.RasterXSize
    im_height0 = ds.RasterYSize
    im_geotrans0 = ds.GetGeoTransform()
    im_proj0 = ds.GetProjection()
    test_data = ds.ReadAsArray(0, 0, im_width0, im_height0)
    test_data = stretch_n(test_data, 0, 255)
    # pad size // 2 pixels on every side so border pixels also get a full
    # centred neighbourhood
    test_data = pad_data(test_data, size)
    # FIX: xrange -> range for Python 3.
    # NOTE(review): the slice below is (size + 1) x (size + 1) (11x11),
    # slightly larger than the 10x10 training patches — confirm whether that
    # mismatch is intended; it may contribute to the poor results mentioned.
    half = size // 2
    all_arr = []
    for i in range(half, im_height0 + half):       # rows of original pixels
        for j in range(half, im_width0 + half):    # cols of original pixels
            window = test_data[:, i - half:i + half + 1, j - half:j + half + 1]
            hists = [lbp(window[h, ...], n_points, radius) for h in range(band)]
            all_arr.append(np.concatenate(hists, axis=0))
    predict = svc.predict(np.array(all_arr))
    # one prediction per original pixel -> full-size result raster
    re = predict.reshape((im_height0, im_width0))
    seg_path = "E:/20200210/forest/tif_temp/dys_seg.tif"
    write_img(seg_path, im_proj0, im_geotrans0, re)
    del ds