Google Open Images Dataset V4 圖片數據集詳解2-分類快速下載

上節我們介紹了Open Images V4數據集的結構信息,這節我們來嘗試真正下載相應的圖片。整個數據集很大,約有561GB,這麼大的數據量,對於學習者來說,傳輸和存儲都是個問題。其實在大多數情況下,只下載某些(或某個)類別對象的圖片就夠了。根據上節講到的各文件之間的關係,以對象檢測爲例,我們可以寫一個腳本,單獨獲取某些類別對象的圖片。這節我們講述如何快速下載一個烏龜(Tortoise)的圖像集,我們先在V4的官網上瀏覽Tortoise,差不多是這樣:

一、安裝 TensorFlow Object Detection API

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
# Create the working directory under the filesystem root.
# -p makes the command succeed when the directory already exists,
# so the snippet can be re-run safely.
mkdir -p /output
cd /output/
 
# Download a pinned older revision of tensorflow/models (the Object Detection
# API lives inside it); at the time of writing (2018-04-20) the latest
# revision of the API had problems.
wget  https://github.com/tensorflow/models/archive/dcfe009a024854207c9067d785c105f5ebf5a01b.zip
unzip dcfe009a024854207c9067d785c105f5ebf5a01b.zip 
mv models-dcfe009a024854207c9067d785c105f5ebf5a01b models
rm dcfe009a024854207c9067d785c105f5ebf5a01b.zip 
 
# Install the Python dependencies.
pip install Cython
pip install pillow
pip install lxml
pip install jupyter
pip install matplotlib
pip install opencv-python
pip install pycocotools
 
# Compile the protobuf definitions, put research/ and research/slim on the
# Python path (quoted so paths containing spaces survive), then run the
# model-builder test to verify the installation.
cd /output/models/research/
protoc object_detection/protos/*.proto --python_out=.
export PYTHONPATH="$PYTHONPATH:$(pwd):$(pwd)/slim"
python object_detection/builders/model_builder_test.py


代碼下載:GitHub


二、根據關鍵字生成tfrecord

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
import pandas as pd
import numpy as np
import os
import tensorflow as tf
import io
import logging
import random
import sys
import PIL.Image
import hashlib
from urllib import request
  
sys.path.append("/output/models/research/")
from object_detection.utils import dataset_util
  
  
class open_image_dataset:
    """Build single-class TFRecord files from Open Images V4 box annotations.

    For each split ("train", "val", "test") three metadata CSVs are cached as
    ``<split>/image.csv``, ``<split>/box.csv`` and ``<split>/classname.csv``.
    ``create_tfrecord`` then keeps the boxes whose class display name equals
    ``keywords`` (e.g. "Tortoise"), downloads the referenced images and writes
    the examples to ``<keywords>/<split>/<keywords>.tfrecord``.
    """

    def __init__(self, base_url="https://storage.googleapis.com/openimages/2018_04/"):
        """Set the metadata download URLs.

        Args:
            base_url: Prefix the CSV paths below are appended to; pass a
                mirror to download from somewhere else.

        NOTE(review): the URL attributes read by the download methods were not
        defined anywhere in the original listing, so every download raised
        AttributeError.  These defaults follow the Open Images V4 (2018_04)
        download layout -- confirm they are still served.  Any attribute may
        also be overwritten after construction.
        """
        # Class table; read below without a header as (labelID, LabelName).
        self.classname_csv = base_url + "class-descriptions-boxable.csv"
        self.train_image_csv = base_url + "train/train-images-boxable-with-rotation.csv"
        self.train_box_csv = base_url + "train/train-annotations-bbox.csv"
        self.val_image_csv = base_url + "validation/validation-images-with-rotation.csv"
        self.val_box_csv = base_url + "validation/validation-annotations-bbox.csv"
        self.test_image_csv = base_url + "test/test-images-with-rotation.csv"
        self.test_box_csv = base_url + "test/test-annotations-bbox.csv"

    @staticmethod
    def _fetch(url, path):
        """Download ``url`` to ``path`` unless ``path`` already exists.

        The transfer goes to ``path + ".part"`` and is renamed only once it
        has completed, so an interrupted download never leaves a truncated
        file that a later run would treat as already cached.  Download errors
        are propagated to the caller.
        """
        if os.path.exists(path):
            return
        part_path = path + ".part"
        try:
            request.urlretrieve(url, part_path)
        except BaseException:
            # Remove the partial file (also on Ctrl-C) before re-raising.
            if os.path.exists(part_path):
                os.remove(part_path)
            raise
        os.replace(part_path, path)

    def _download_split(self, folder, image_csv, box_csv):
        """Cache the image list, box annotations and class table under ``folder``."""
        os.makedirs(folder, exist_ok=True)
        self._fetch(image_csv, folder + "/image.csv")
        self._fetch(box_csv, folder + "/box.csv")
        self._fetch(self.classname_csv, folder + "/classname.csv")

    def download_test(self):
        """Cache the test-split metadata CSVs under ``test/``."""
        print("start download test info")
        self._download_split("test", self.test_image_csv, self.test_box_csv)
        print("download test complete")

    def download_val(self):
        """Cache the validation-split metadata CSVs under ``val/``."""
        self._download_split("val", self.val_image_csv, self.val_box_csv)
        print("download val complete")

    def download_train(self):
        """Cache the training-split metadata CSVs under ``train/``."""
        self._download_split("train", self.train_image_csv, self.train_box_csv)
        print("download train complete")

    @staticmethod
    def _select_rows(df_image, df_box, df_classname, keywords):
        """Return one row per box whose class display name equals ``keywords``.

        Each row holds the box columns plus the image columns (such as
        ``OriginalURL``).  Rows follow the order of ``df_image``, so all boxes
        of one image are adjacent.  Boxes of other classes, boxes whose image
        is missing from ``df_image`` and images without a matching box are
        dropped (an inner selection instead of right joins plus a NaN filter,
        which avoids materialising every box/image row).
        """
        label_ids = df_classname.loc[df_classname['LabelName'] == keywords, 'labelID']
        boxes = df_box[df_box['LabelName'].isin(label_ids)]
        return df_image.merge(boxes, on='ImageID', how='inner')

    def _load_jpeg(self, url, file_path):
        """Download ``url`` to ``file_path`` (unless cached) and read it back.

        Returns ``(encoded_bytes, width, height)`` for a JPEG image.  Returns
        ``None`` when the download fails, the data cannot be decoded, or the
        image is not a JPEG.  In those cases the local file is removed so that
        it is not reused on the next run.
        """
        try:
            self._fetch(url, file_path)
            with open(file_path, 'rb') as fid:
                encoded_jpg = fid.read()
            with PIL.Image.open(io.BytesIO(encoded_jpg)) as image:
                image_format = image.format
                width, height = image.size
        except OSError as err:
            # URLError/HTTPError and PIL's "cannot identify image" are OSErrors;
            # dead image links are common, so skip instead of aborting the run.
            print("skip " + file_path + ": " + str(err))
            if os.path.exists(file_path):
                os.remove(file_path)
            return None
        if image_format != 'JPEG':
            print("file format error " + file_path)
            os.remove(file_path)
            return None
        return encoded_jpg, width, height

    @staticmethod
    def _make_example(row, keywords, file_name, encoded_jpg, key, width, height):
        """Build a ``tf.train.Example`` with one image and the single box in ``row``.

        NOTE(review): the TF Object Detection API normally expects all boxes
        of an image in one Example.  One box per Example is kept here for
        compatibility with readers that parse the bbox features as scalars.
        """
        return tf.train.Example(features=tf.train.Features(feature={
            'image/height': dataset_util.int64_feature(int(height)),
            'image/width': dataset_util.int64_feature(int(width)),
            'image/filename': dataset_util.bytes_feature(file_name.encode('utf8')),
            'image/source_id': dataset_util.bytes_feature(file_name.encode('utf8')),
            'image/key/sha256': dataset_util.bytes_feature(key.encode('utf8')),
            'image/encoded': dataset_util.bytes_feature(encoded_jpg),
            'image/format': dataset_util.bytes_feature('jpeg'.encode('utf8')),
            'image/object/bbox/xmin': dataset_util.float_list_feature([float(row['XMin'])]),
            'image/object/bbox/xmax': dataset_util.float_list_feature([float(row['XMax'])]),
            'image/object/bbox/ymin': dataset_util.float_list_feature([float(row['YMin'])]),
            'image/object/bbox/ymax': dataset_util.float_list_feature([float(row['YMax'])]),
            'image/object/class/text': dataset_util.bytes_list_feature([keywords.encode('utf8')]),
            # Only one class is exported, so every box gets label id 1.
            'image/object/class/label': dataset_util.int64_list_feature([1]),
        }))

    def create_tfrecord(self, folder, keywords):
        """Write ``<keywords>/<folder>/<keywords>.tfrecord`` for one class.

        Args:
            folder: Split folder previously filled by ``download_train``,
                ``download_val`` or ``download_test`` ("train"/"val"/"test").
            keywords: Class display name, matched exactly against the class
                table; also used as the output directory and file name.

        Each image is downloaded once into the output folder, read back, and
        deleted after its examples are written.  Images that cannot be
        downloaded or decoded, or that are not JPEG, are skipped.
        """
        df_image = pd.read_csv(folder + "/image.csv")
        df_box = pd.read_csv(folder + "/box.csv")
        df_classname = pd.read_csv(folder + "/classname.csv",
                                   names=['labelID', 'LabelName'])
        data = self._select_rows(df_image, df_box, df_classname, keywords)
        if data.empty:
            print("no boxes labelled " + keywords + " in " + folder)

        folder_path = keywords + "/" + folder + "/"
        os.makedirs(folder_path, exist_ok=True)
        tfrecord_file = folder_path + keywords + ".tfrecord"

        # The context manager closes the writer even if a step below raises.
        with tf.python_io.TFRecordWriter(tfrecord_file) as writer:
            # sort=False keeps the image order of ``data``; grouping means an
            # image with several boxes of this class is downloaded only once.
            for image_id, boxes in data.groupby('ImageID', sort=False):
                file_name = image_id + ".jpg"
                file_path = folder_path + file_name
                loaded = self._load_jpeg(boxes['OriginalURL'].iloc[0], file_path)
                if loaded is None:
                    continue
                encoded_jpg, width, height = loaded
                key = hashlib.sha256(encoded_jpg).hexdigest()
                for _, row in boxes.iterrows():
                    example = self._make_example(row, keywords, file_name,
                                                 encoded_jpg, key, width, height)
                    writer.write(example.SerializeToString())
                os.remove(file_path)
                print("file " + file_path)
        print("create " + tfrecord_file + " success!")

    def create_train_tfrecord(self, keywords):
        """Download the training metadata, then build its TFRecord for ``keywords``."""
        self.download_train()
        self.create_tfrecord("train", keywords)

    def create_val_tfrecord(self, keywords):
        """Download the validation metadata, then build its TFRecord for ``keywords``."""
        self.download_val()
        self.create_tfrecord("val", keywords)

    def create_test_tfrecord(self, keywords):
        """Download the test metadata, then build its TFRecord for ``keywords``."""
        self.download_test()
        self.create_tfrecord("test", keywords)

    def create_all_tfrecord(self, keywords):
        """Build the training and validation TFRecords (the test split is not included)."""
        self.create_train_tfrecord(keywords)
        self.create_val_tfrecord(keywords)
          
# Run the downloads only when executed as a script, not when this module is
# imported (every call below performs network I/O).
if __name__ == "__main__":
    dataset = open_image_dataset()
    # Download the test split and build the TFRecord for the class "Tortoise".
    dataset.download_test()
    dataset.create_tfrecord("test", "Tortoise")
    # Validation split for "Tortoise":
    # dataset.download_val()
    # dataset.create_tfrecord("val", "Tortoise")
    # Training split for "Tortoise":
    # dataset.download_train()
    # dataset.create_tfrecord("train", "Tortoise")

    # Training + validation splits in one call:
    # dataset.create_all_tfrecord("Tortoise")


三、對生成的tfrecord進行驗證

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
import tensorflow as tf
import numpy as np
import os
import skimage.io as io
import cv2
tfrecords_filename = "Tortoise/test/Tortoise.tfrecord"

# TF1 queue-based input pipeline.  No num_epochs is given, so the producer
# cycles through the file and records repeat if it holds fewer than 20.
filename_queue = tf.train.string_input_producer([tfrecords_filename])
reader = tf.TFRecordReader()
_, serialized_example = reader.read(filename_queue)

# The bbox/class features are parsed as scalars, i.e. this expects exactly one
# box per record (which is how the generator script writes them).
features = tf.parse_single_example(serialized_example,
                                   features={
                                        'image/width': tf.FixedLenFeature([], tf.int64),
                                        'image/height': tf.FixedLenFeature([], tf.int64),
                                        'image/filename': tf.FixedLenFeature([], tf.string),
                                        'image/source_id': tf.FixedLenFeature([], tf.string),
                                        'image/key/sha256': tf.FixedLenFeature([], tf.string),
                                        'image/encoded': tf.FixedLenFeature([], tf.string),
                                        'image/format': tf.FixedLenFeature([], tf.string),
                                        'image/object/bbox/xmin': tf.FixedLenFeature([], tf.float32),
                                        'image/object/bbox/xmax': tf.FixedLenFeature([], tf.float32),
                                        'image/object/bbox/ymin': tf.FixedLenFeature([], tf.float32),
                                        'image/object/bbox/ymax': tf.FixedLenFeature([], tf.float32),
                                        'image/object/class/text': tf.FixedLenFeature([], tf.string),
                                        'image/object/class/label': tf.FixedLenFeature([], tf.int64),
                                   })

width = tf.cast(features['image/width'], tf.int32)
height = tf.cast(features['image/height'], tf.int32)
# The remaining features already have the wanted dtype; no cast is needed.
filename = features['image/filename']
image_format = features['image/format']
xmin = features['image/object/bbox/xmin']
xmax = features['image/object/bbox/xmax']
ymin = features['image/object/bbox/ymin']
ymax = features['image/object/bbox/ymax']
text = features['image/object/class/text']
label = features['image/object/class/label']

# channels=3 forces RGB output, so grayscale JPEGs also fit the
# [height, width, 3] reshape below instead of failing at run time.
image = tf.image.decode_jpeg(features['image/encoded'], channels=3)
image = tf.reshape(image, tf.stack([height, width, 3]))

with tf.Session() as sess:
    # global_variables_initializer replaces the deprecated initialize_all_variables.
    sess.run(tf.global_variables_initializer())
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(coord=coord)
    for _ in range(20):
        (width1, height1, filename1, format1, xmin1, xmax1, ymin1, ymax1,
         text1, label1, image1) = sess.run([width, height, filename, image_format,
                                            xmin, xmax, ymin, ymax, text, label, image])
        print(width1, height1, filename1, format1, xmin1, xmax1, ymin1, ymax1, text1, label1)
        # The coordinates are treated as fractions of the image size and
        # scaled to pixel positions for drawing.
        x1, y1 = int(xmin1 * width1), int(ymin1 * height1)
        x2, y2 = int(xmax1 * width1), int(ymax1 * height1)
        io.imshow(cv2.rectangle(np.array(image1), (x1, y1), (x2, y2), (0, 255, 0), 3),
                  cmap='gray', interpolation='bicubic')
        io.show()

    coord.request_stop()
    coord.join(threads)

代碼下載:GitHub

最終的結果如下:











發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章