TensorFlow:通過 TextLineDataset 與 dataset.map 讀取數據

用這種方式讀取數據比較麻煩:map 是逐行處理的,需要自己把每行數據整理成按列的形式。改用 tf.decode_csv 可以直接把數據解析成按列的張量,簡單很多。

 

import tensorflow as tf
from tensorflow.contrib.lookup import HashTable
from tensorflow.contrib.lookup import TextFileIdTableInitializer
from tensorflow.contrib.lookup import IdTableWithHashBuckets


# Column layout of one '|'-separated input line (see the sample data:
# label | tags | weights | indexes | weights).
label_idx = 0  # column holding the label
fields_type = ["int", "tags", "weights", "indexes", "weights"]  # per-column kind
fields_idx = [1, 2, 3, 4]  # feature columns to extract
fields_count = len(fields_idx)

# Position of every feature column once the label column has been removed:
# columns before the label keep their index, later ones shift left by one.
new_fields_idx = [col if col < label_idx else col - 1 for col in fields_idx]

## feature parse ##
def input_fn(file_list, epoches=1, batch_size=2, shuffle=False):
    """Build a one-shot tf.data pipeline over '|'-separated text lines.

    Each line is split on '|'. The label column (``label_idx``) followed by
    the feature columns (``fields_idx``) are stacked into one string row.
    After batching, the second-to-last column is parsed as comma-separated
    int32 ids and the last column as comma-separated float32 weights.

    Args:
        file_list: text file path(s) passed to ``tf.data.TextLineDataset``.
        epoches: number of passes over the data (``Dataset.repeat`` count).
        batch_size: rows per batch; a trailing partial batch is dropped.
        shuffle: if True, shuffle lines with a 5000-line buffer.

    Returns:
        ``(spt_index, spt_weight, y)``: an int32 ``SparseTensor`` of ids, a
        float32 ``SparseTensor`` of weights, and the batch of raw label
        strings (the label column is not converted to a number here).
    """

    def _split_to_sparse(strings, sep, default_value, out_type):
        # Split each string on `sep`, replace whitespace-only tokens with
        # `default_value`, and convert every token to `out_type`.
        tokens = tf.string_split(strings, sep)
        raw = tokens.values
        is_blank = tf.equal(tf.string_strip(raw), "")
        # Vectorized replacement of blank tokens (no per-token map_fn/cond).
        filled = tf.where(is_blank,
                          tf.fill(tf.shape(raw), default_value),
                          raw)
        numbers = tf.string_to_number(filled, out_type)
        return tf.SparseTensor(indices=tokens.indices,
                               values=numbers,
                               dense_shape=tokens.dense_shape)

    def parse_index(indexes, sep=",", default_value="0"):
        """Parse `sep`-separated integer ids into an int32 SparseTensor."""
        return _split_to_sparse(indexes, sep, default_value, tf.int32)

    def parse_weight(weights, sep=",", default_value="1"):
        """Parse `sep`-separated weights into a float32 SparseTensor."""
        return _split_to_sparse(weights, sep, default_value, tf.float32)

    def parse_split(line):
        """Split one line on '|' into a string row [label, *fields] and the label."""
        values = tf.string_split([line], delimiter='|').values
        label = values[label_idx]
        # Stack explicitly: depending on the TF version, a Python list
        # returned from `map` is either packed into a tensor or turned into
        # a tuple. parse_feature needs a [batch, columns] tensor.
        features_values = tf.stack([label] + [values[idx] for idx in fields_idx])
        return features_values, label

    def parse_feature(f, y):
        """Parse the last two columns of a batched string matrix into sparse tensors."""
        spt_weight = parse_weight(f[:, -1])
        spt_index = parse_index(f[:, -2], ",")
        return spt_index, spt_weight, y

    # Read the input files line by line.
    dataset = tf.data.TextLineDataset(file_list)
    # Parse lines in parallel.
    dataset = dataset.map(parse_split, num_parallel_calls=2)
    if shuffle:
        dataset = dataset.shuffle(buffer_size=5000)
    dataset = dataset.repeat(count=epoches)
    # Drop the last batch if there is not enough data to fill it.
    dataset = dataset.apply(tf.contrib.data.batch_and_drop_remainder(batch_size))
    # Prefetch ahead to overlap input with compute. prefetch returns a new
    # dataset, so it must be reassigned. 12 batches == batch_size * 12 lines.
    dataset = dataset.prefetch(12)
    # One-shot iterator: the data is consumed a single time.
    iterator = dataset.make_one_shot_iterator()
    features_values, y = iterator.get_next()

    return parse_feature(features_values, y)

# https://blog.csdn.net/cjopengler/article/details/78150650
from tensorflow.python.training import coordinator
with tf.Session() as sess:
    # Build the input pipeline, which yields the feature and target tensors.
    # (The original called table_used.init.run(), but no lookup table named
    # table_used is defined anywhere, so that line raised NameError.)
    spt_index, spt_weight, y = input_fn(["./data.txt"], batch_size=3)
    # tf.data iterators do not need queue runners; kept from the original.
    coord = coordinator.Coordinator()
    threads = tf.train.start_queue_runners(sess, coord=coord)

    num_step = 1
    for step in range(num_step):
        # Fetch into new names so the graph tensors stay available for the
        # next step; rebinding them to numpy values broke num_step > 1.
        index_val, weight_val, y_val = sess.run([spt_index, spt_weight, y])
        print('featrues:', index_val, weight_val, y_val)
    coord.request_stop()
    coord.join(threads)


 

示例數據(./data.txt,字段以 | 分隔):

1|click,show,李志林,股災,演變|21.0,120.0,1,1,1|1,2,3|0.1
1|click,show,李志林,股災,演變|21.0,120.0,1,1,1|9,2,3|0.2,0.3
1|click,show,楊冪,股災,演變|21.0,120.0,1,1,1|8,2,3|0.4,0.5,0.6
1|click,show,開心,股災,演變|21.0,120.0,1,1,1|7,2,3|0.1,0.2,0.3

 

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章