TensorFlow HashTable and index table lookup: computing mean embedding vectors and handling missing values

import tensorflow as tf

print(tf.__version__)
# Toy key/value pairs for the HashTable demo: key 9 -> 0, 8 -> 1, 6 -> 2, 5 -> 3.
list_arr = [9, 8, 6, 5]
value_arr = [0, 1, 2, 3]
tf_look_up = tf.constant(list_arr, dtype=tf.int64)
tf_value_arr = tf.constant(value_arr, dtype=tf.int64)

# Keys that are not found in the table map to the default value 0.
table = tf.contrib.lookup.HashTable(tf.contrib.lookup.KeyValueTensorInitializer(tf_look_up, tf_value_arr), 0)
ph_vals = tf.constant([8, 5], dtype=tf.int64)
ph_idx = table.lookup(ph_vals)

with tf.compat.v1.Session() as sess:
	sess.run(tf.tables_initializer())
	sess.run(tf.global_variables_initializer())
	res = sess.run(ph_idx)
	print(res)
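
For reference, the same key-to-value lookup can be written without tf.contrib. The snippet below is a minimal, self-contained sketch assuming TensorFlow 2.x, where tf.lookup.StaticHashTable replaces tf.contrib.lookup.HashTable and runs eagerly with no session or table initializer:

import tensorflow as tf

keys = tf.constant([9, 8, 6, 5], dtype=tf.int64)
values = tf.constant([0, 1, 2, 3], dtype=tf.int64)

# default_value is returned for keys that are missing from the table.
table_v2 = tf.lookup.StaticHashTable(
	tf.lookup.KeyValueTensorInitializer(keys, values), default_value=-1)

print(table_v2.lookup(tf.constant([8, 5, 7], dtype=tf.int64)).numpy())  # [ 1  3 -1]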

# Each row of `weight` carries one weight per tag in the corresponding row of `input`;
# the empty row exercises the missing-value path, and "hardenx" is an out-of-vocabulary tag.
input = [
	"harden|james|curry",
	"wrestbrook|harden|durant|hardenx",
	"paul|towns",
	""
]
weight = [
	"0.4,0.3,0.1",
	"0.4,0.3,0.1,0.1",
	"0.4,0.3",
	""
]

TAG_SET = ["harden", "james", "curry", "durant", "paul", "towns", "wrestbrook"]


class IndexValueEmbedding(object):
	default_value = "xxx-default"

	def __init__(self, field_name, category_list, embedding_size):
		self.field_name = field_name
		self.category_list = category_list
		self.embedding_size = embedding_size
		self._init_dict()
		self._init_embedding()

	def _init_embedding(self):
		with tf.variable_scope("index_value"):
			# One embedding row per tag, including row 0 for the default value.
			self.embedding_params = tf.get_variable(name=self.field_name,
			                                        initializer=tf.truncated_normal(
				                                        [len(self._tags), self.embedding_size]))

	def _init_dict(self):
		# Index 0 is reserved for the default value; unknown tags also map to it (default_value=0).
		self._tags = [self.default_value]
		self._tags.extend(self.category_list)
		self.table = tf.contrib.lookup.index_table_from_tensor(mapping=self._tags, default_value=0)

	def get_avg_embedding(self, input_indexes, input_weights, sep1=",", sep2=","):
		# Replace empty rows with the default tag/weight, then split tags and map them to indexes.
		input, weight = self._preprocess(input_indexes, input_weights)
		tags = self._sparse_from_string_array(input, sep1)

		wgt = tf.string_split(tf.string_strip(weight), sep2, skip_empty=True)
		wgt_number = tf.string_to_number(wgt.values, tf.float32)
		# Out-of-vocabulary tags map to index 0 (the default slot); give them a near-zero
		# weight so they barely contribute to the weighted mean.
		mask = tf.equal(tags.values, 0)
		wgt_number1 = tf.where(mask, tf.zeros_like(wgt_number) + tf.constant(0.000001, tf.float32), wgt_number)
		sparse_wgt = tf.SparseTensor(wgt.indices, wgt_number1, wgt.dense_shape)

		# combiner="mean" returns sum(w_i * e_i) / sum(w_i) for each row.
		embedded_tags = tf.nn.embedding_lookup_sparse(self.embedding_params, sp_ids=tags, sp_weights=sparse_wgt,
		                                              combiner="mean")
		return embedded_tags

	def _preprocess(self, input, weight):
		# Empty rows fall back to the default tag with weight "1" so every row has at least one entry.
		input = tf.map_fn(lambda x: tf.cond(tf.equal(tf.string_strip(x), ""), lambda: self.default_value, lambda: x),
		                  elems=tf.constant(input, tf.string))
		weight = tf.map_fn(lambda x: tf.cond(tf.equal(tf.string_strip(x), ""), lambda: "1", lambda: x),
		                   elems=tf.constant(weight, tf.string))
		return input, weight

	def _sparse_from_string_array(self, input_keys, sep):
		# Split each row into tags and map every tag string to its row index in the embedding table.
		input_keys_trims = tf.string_strip(input_keys)
		split_tags = tf.string_split(input_keys_trims, sep, skip_empty=False)
		return tf.SparseTensor(indices=split_tags.indices,
		                       values=self.table.lookup(split_tags.values),
		                       dense_shape=split_tags.dense_shape)


index_value_emb = IndexValueEmbedding("cate", TAG_SET, 4)
# Tags are separated by "|", weights by the default ",".
avg = index_value_emb.get_avg_embedding(input, weight, "|")
# Alternative left over from experimenting: mask the tag ids themselves instead of the
# weights (mask/tags refer to the locals inside get_avg_embedding):
# mask_res = tf.where(mask, tags.values, tf.zeros_like(tags.values))
# new_tags = tf.SparseTensor(indices=tags.indices, values=mask_res, dense_shape=tags.dense_shape)


with tf.compat.v1.Session() as sess:
	sess.run(tf.tables_initializer())
	sess.run(tf.global_variables_initializer())
	print("avg_res\n", sess.run(avg))

 
