tensorflow中的datasets以及自定義的數據增強--1.15.0版本

作爲一個 TensorFlow 初學者，感覺這個框架用起來好……被動。

之前看 TensorFlow 代碼的時候，遇到了下面這個項目：

https://github.com/guoqiangqi/PFLD

它使用了如下的方式進行數據讀取和 batch 打包。_parse_data 是對每一個樣本（每張圖像）單獨調用的，所以如果我們想對載入的圖像做數據增強，應該寫在 _parse_data 中：

# Build a dataset of (file, landmarks, attributes, euler_angles) tuples, then
# load/resize each image in _parse_data and shuffle the elements.
dataset = tf.data.Dataset.from_tensor_slices((file_list, landmarks, attributes,euler_angles))
dataset = dataset.map(_parse_data)
dataset = dataset.shuffle(buffer_size=10000)
.......
# In the training script: batch the dataset and repeat it indefinitely.
train_dataset, num_train_file = DateSet(args.file_list, args, debug)
batch_train_dataset = train_dataset.batch(args.batch_size).repeat()
# NOTE: the creation of train_iterator is elided from this excerpt.
train_next_element = train_iterator.get_next()

但是，在 _parse_data 中數據已經是 tensor 格式。如果我們的數據增強全部是用 numpy 和 OpenCV 自定義實現的，該怎麼處理呢？

原本我聽說過 @tf.function，準備試試，但發現這主要是 TensorFlow 2.x 的用法。而且它的作用是把 Python 函數追蹤（trace）成 TensorFlow 計算圖。函數裡的 numpy/OpenCV 調用只會在追蹤時執行，並不會對每個樣本執行，所以不適合這個場景。

 

查找了許多資料,找到了解決方式

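這時候需要引入 tf.py_func。它把一個普通的 Python 函數包裝成 TensorFlow 運算：運行時輸入以 numpy 數組的形式傳入函數，函數返回的 numpy 數組再轉換回 tensor，因此函數內部可以自由使用 numpy 和 OpenCV。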

# Run the numpy/OpenCV function _read_py_function on every element via
# tf.py_func. The last argument lists the output dtypes: the image comes back
# as uint8 and the two label arrays keep their input dtypes.
dataset = dataset.map(
    lambda filename, landmarks, euler_angles: tf.py_func(
        _read_py_function, [filename, landmarks, euler_angles], [tf.uint8, landmarks.dtype, euler_angles.dtype]))

具體參考: http://d0evi1.com/tensorflow/datasets/

原始 https://github.com/guoqiangqi/PFLD 的代碼如下：

import tensorflow as tf
import numpy as np
import cv2

def DateSet(file_list, args, debug=False):
    """Build the original (non-augmented) PFLD training dataset.

    Args:
        file_list: path of the annotation list file handed to ``gen_data``.
        args: options object; reads ``image_channels``, ``image_size`` and,
            in debug mode, ``batch_size``.
        debug: if True, keep only the first ``args.batch_size * 10`` samples.

    Returns:
        ``(dataset, num_samples)``. Each element of ``dataset`` is
        ``(image, landmarks, attributes, euler_angles)`` where ``image`` is the
        decoded PNG resized to ``args.image_size`` and divided by 256.
    """
    paths, landmarks, attributes, euler_angles = gen_data(file_list)
    if debug:
        limit = args.batch_size * 10
        paths, landmarks, attributes, euler_angles = (
            paths[:limit], landmarks[:limit], attributes[:limit], euler_angles[:limit])

    def _parse_data(filename, landmarks, attributes, euler_angles):
        """Decode and resize one PNG image; the label tensors pass through."""
        raw_bytes = tf.read_file(filename)
        decoded = tf.image.decode_png(raw_bytes, channels=args.image_channels)
        resized = tf.image.resize_images(
            decoded, (args.image_size, args.image_size), method=0)
        scaled = tf.cast(resized, tf.float32) / 256.0
        return (scaled, landmarks, attributes, euler_angles)

    dataset = (tf.data.Dataset
               .from_tensor_slices((paths, landmarks, attributes, euler_angles))
               .map(_parse_data)
               .shuffle(buffer_size=10000))
    return dataset, len(paths)

def gen_data(file_list):
    """Parse an annotation list file into numpy arrays.

    Each non-empty line is whitespace separated: the image path, then the
    fields sliced as landmarks [1:197], attributes [197:203] and Euler
    angles [203:206]. Blank lines (e.g. a trailing newline) are skipped.

    Args:
        file_list: path of the annotation list file.

    Returns:
        ``(filenames, landmarks, attributes, euler_angles)``: a str array and
        float32 / int32 / float32 arrays with one row per parsed line.
    """
    with open(file_list, 'r') as f:
        lines = f.readlines()
    filenames, landmarks, attributes, euler_angles = [], [], [], []
    for line in lines:
        fields = line.split()
        if not fields:
            # An empty line has no path; indexing it would raise IndexError.
            continue
        filenames.append(fields[0])
        landmarks.append(np.asarray(fields[1:197], dtype=np.float32))
        attributes.append(np.asarray(fields[197:203], dtype=np.int32))
        euler_angles.append(np.asarray(fields[203:206], dtype=np.float32))

    # `np.str` was only an alias of the builtin `str` and no longer exists in
    # recent NumPy releases; use the builtin directly.
    filenames = np.asarray(filenames, dtype=str)
    landmarks = np.asarray(landmarks, dtype=np.float32)
    attributes = np.asarray(attributes, dtype=np.int32)
    euler_angles = np.asarray(euler_angles, dtype=np.float32)
    return (filenames, landmarks, attributes, euler_angles)


if __name__ == '__main__':
    # Visual sanity check: draw each sample's landmarks on its image.
    file_list = 'data/train_data/list.txt'
    # gen_data returns four arrays; the Euler angles are not needed here.
    filenames, landmarks, attributes, _ = gen_data(file_list)
    for filename, landmark, attribute in zip(filenames, landmarks, attributes):
        print(attribute)
        img = cv2.imread(filename)
        if img is None:
            print('could not read image:', filename)
            continue
        h, w = img.shape[:2]
        # Landmarks are normalised (x, y) pairs: x scales with the width and
        # y with the height.
        landmark = landmark.reshape(-1, 2) * [w, h]
        for (x, y) in landmark.astype(np.int32):
            # OpenCV expects plain Python ints for point coordinates.
            cv2.circle(img, (int(x), int(y)), 1, (0, 0, 255))
        cv2.imshow('0', img)
        cv2.waitKey(0)

改進後:

 

import os
import random

import cv2
import numpy as np
import tensorflow as tf

def DateSet(file_list, args, debug=False):
    """Build the training dataset with random-rotation augmentation.

    Images are loaded with OpenCV inside ``tf.py_func`` so that numpy/OpenCV
    augmentation can be applied. They are then resized and scaled with
    TensorFlow ops.

    Args:
        file_list: path of the annotation list file handed to ``gen_data``.
        args: options object; reads ``image_size`` and, in debug mode,
            ``batch_size``. Images are always loaded as 3-channel colour
            (``args.image_channels`` is not consulted here).
        debug: if True, keep only the first ``args.batch_size * 10`` samples.

    Returns:
        ``(dataset, num_samples)``. Each element is
        ``(image, landmarks, euler_angles)`` with ``image`` float32 in [0, 1).
    """
    file_list, landmarks, euler_angles = gen_data(file_list)
    if debug:
        n = args.batch_size * 10
        file_list = file_list[:n]
        landmarks = landmarks[:n]
        euler_angles = euler_angles[:n]

    dataset = tf.data.Dataset.from_tensor_slices((file_list, landmarks, euler_angles))
    # tf.py_func loses static shape information; remember the per-element
    # label shapes so they can be restored after the call.
    _, landmark_shape, euler_shape = dataset.output_shapes

    def _read_py_function(filename, landmarks, euler_angles):
        """Load one image with OpenCV and randomly rotate it (numpy side).

        Runs inside ``tf.py_func``: ``filename`` arrives as bytes and the
        labels as numpy arrays. When ``flag == 0`` (probability 1/2), the
        image is rotated by +15 or -15 degrees about its centre. The
        landmarks are transformed with the same matrix, and ``alpha`` is
        added to ``euler_angles[1]``.
        """
        path = filename.decode('utf-8')
        image = cv2.imread(path)
        if image is None:
            raise IOError('cv2.imread could not read image: %s' % path)
        # OpenCV loads BGR, whereas tf.image.decode_png in the original
        # pipeline yields RGB; convert so the channel order stays the same.
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        H, W = image.shape[:2]

        alpha = random.choice([-15, 15])
        flag = np.random.randint(2)
        if flag == 0:
            center = ((W - 1) / 2, (H - 1) / 2)
            rot_mat = cv2.getRotationMatrix2D(center, alpha, 1)
            image = cv2.warpAffine(image, rot_mat, (W, H))
            # Landmarks are stored as normalised [x0, y0, x1, y1, ...]: map
            # them to pixels, apply the same affine transform, normalise back.
            scale = np.array([W, H], dtype=np.float64)
            points = landmarks.reshape(-1, 2) * scale
            points = points.dot(rot_mat[:, :2].T) + rot_mat[:, 2]
            landmarks = (points / scale).reshape(-1).astype(np.float32)
            # Copy before modifying so the array passed in by py_func is never
            # written in place.
            euler_angles = euler_angles.copy()
            # NOTE(review): an in-plane rotation usually changes roll; confirm
            # that index 1 (and this sign) is the intended angle to adjust.
            euler_angles[1] = euler_angles[1] + alpha
        return image, landmarks, euler_angles

    def _resize_function(image_decoded, landmarks, euler_angles):
        """Restore static shapes, resize to args.image_size and scale by 1/256."""
        image_decoded.set_shape([None, None, 3])
        landmarks.set_shape(landmark_shape)
        euler_angles.set_shape(euler_shape)
        image = tf.image.resize_images(image_decoded, (args.image_size, args.image_size), method=0)
        image = tf.cast(image, tf.float32)
        image = image / 256.0
        return (image, landmarks, euler_angles)

    dataset = dataset.map(
        lambda filename, landmarks, euler_angles: tf.py_func(
            _read_py_function, [filename, landmarks, euler_angles],
            [tf.uint8, landmarks.dtype, euler_angles.dtype]))
    dataset = dataset.map(_resize_function)

    dataset = dataset.shuffle(buffer_size=10000)
    return dataset, len(file_list)

def gen_data(file_list, image_dir="D:/code/python/untitled/data/train_data/imgs/"):
    """Parse an annotation list file into numpy arrays.

    Each non-empty line is whitespace separated: the image path, then the
    fields sliced as landmarks [1:43] and Euler angles [49:52]. The fields
    [43:49] between them are not returned. Only the image's file name is
    kept; it is looked up in ``image_dir``. Blank lines are skipped.

    Args:
        file_list: path of the annotation list file.
        image_dir: directory holding the images (defaults to the previously
            hard-coded location).

    Returns:
        ``(filenames, landmarks, euler_angles)``: a str array and two float32
        arrays with one row per parsed line.
    """
    with open(file_list, 'r') as f:
        lines = f.readlines()
    filenames, landmarks, euler_angles = [], [], []
    for line in lines:
        fields = line.split()
        if not fields:
            # An empty line has no path; indexing it would raise IndexError.
            continue
        file_name = fields[0].split("/")[-1]
        filenames.append(os.path.join(image_dir, file_name))
        landmarks.append(np.asarray(fields[1:43], dtype=np.float32))
        euler_angles.append(np.asarray(fields[49:52], dtype=np.float32))

    # `np.str` was only an alias of the builtin `str` and no longer exists in
    # recent NumPy releases; use the builtin directly.
    filenames = np.asarray(filenames, dtype=str)
    landmarks = np.asarray(landmarks, dtype=np.float32)
    euler_angles = np.asarray(euler_angles, dtype=np.float32)
    return (filenames, landmarks, euler_angles)


if __name__ == '__main__':
    # Visual sanity check: draw each sample's landmarks on its image.
    file_list = 'data/train_data/list.txt'
    # This gen_data returns (filenames, landmarks, euler_angles); the third
    # value is the Euler angles, not attributes.
    filenames, landmarks, euler_angles = gen_data(file_list)
    for filename, landmark, euler_angle in zip(filenames, landmarks, euler_angles):
        print(euler_angle)
        img = cv2.imread(filename)
        if img is None:
            print('could not read image:', filename)
            continue
        h, w = img.shape[:2]
        # Landmarks are normalised (x, y) pairs: x scales with the width and
        # y with the height (same convention as the rotation in DateSet).
        landmark = landmark.reshape(-1, 2) * [w, h]
        for (x, y) in landmark.astype(np.int32):
            # OpenCV expects plain Python ints for point coordinates.
            cv2.circle(img, (int(x), int(y)), 1, (0, 0, 255))
        cv2.imshow('0', img)
        cv2.waitKey(0)

 

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章