See the code below:
# Requires TensorFlow 1.x: the tensorflow.examples.tutorials package is not shipped with TensorFlow 2.x.
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data

BATCH_SIZE = 10

# Default (one_hot=False): each label is a single integer class index.
mnist = input_data.read_data_sets('MNIST_data/')
xs, ys = mnist.train.next_batch(BATCH_SIZE)
print(ys.shape)
print(ys)

print('-------------After encoding--------------')

# one_hot=True: each label becomes a length-10 vector containing a single 1.
mnist = input_data.read_data_sets('MNIST_data/', one_hot=True)
xs, ys = mnist.train.next_batch(BATCH_SIZE)
# print(xs.shape)
print(ys.shape)
print(ys)

Output:

Extracting MNIST_data/train-images-idx3-ubyte.gz
Extracting MNIST_data/train-labels-idx1-ubyte.gz
Extracting MNIST_data/t10k-images-idx3-ubyte.gz
Extracting MNIST_data/t10k-labels-idx1-ubyte.gz
(10,)
[7 0 8 2 4 5 7 8 6 2]
-------------After encoding--------------
Extracting MNIST_data/train-images-idx3-ubyte.gz
Extracting MNIST_data/train-labels-idx1-ubyte.gz
Extracting MNIST_data/t10k-images-idx3-ubyte.gz
Extracting MNIST_data/t10k-labels-idx1-ubyte.gz
(10, 10)
[[ 0. 0. 1. 0. 0. 0. 0. 0. 0. 0.]
 [ 0. 0. 0. 0. 0. 1. 0. 0. 0. 0.]
 [ 0. 0. 0. 0. 0. 1. 0. 0. 0. 0.]
 [ 0. 0. 0. 0. 0. 0. 1. 0. 0. 0.]
 [ 0. 0. 0. 1. 0. 0. 0. 0. 0. 0.]
 [ 0. 0. 1. 0. 0. 0. 0. 0. 0. 0.]
 [ 0. 0. 0. 1. 0. 0. 0. 0. 0. 0.]
 [ 1. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
 [ 0. 0. 1. 0. 0. 0. 0. 0. 0. 0.]
 [ 0. 1. 0. 0. 0. 0. 0. 0. 0. 0.]]

Note that the two batches contain different samples, so the one-hot rows are not encodings of the integer labels printed first. Without one_hot the labels have shape (10,); with one_hot=True they have shape (10, 10), one row per sample and one column per digit class.
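To make the relationship between the two label formats concrete, here is a minimal NumPy sketch that converts integer labels to one-hot rows and back. It only illustrates the idea and is not how read_data_sets does it internally; the helper name one_hot is hypothetical, and the labels are the first batch printed in the output.

```python
import numpy as np

def one_hot(labels, num_classes):
    """Hypothetical helper: one row per label, with a single 1.0 at the label's index."""
    labels = np.asarray(labels, dtype=np.int64)
    # Row i of the identity matrix is exactly the one-hot vector for class i.
    return np.eye(num_classes, dtype=np.float64)[labels]

labels = np.array([7, 0, 8, 2, 4, 5, 7, 8, 6, 2])  # integer labels from the first batch
encoded = one_hot(labels, num_classes=10)
print(encoded.shape)  # (10, 10)

# argmax along each row recovers the original integer labels.
decoded = encoded.argmax(axis=1)
print(decoded)
```

Going back with argmax is handy when a model outputs one score per class and you want to compare predictions against integer labels.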
A nasty pitfall with one-hot encoding in pandas
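One-hot encoding comes up all the time in machine learning, and pandas already provides a function for it: get_dummies.
There is a nasty pitfall, though. In the pandas version used here, the resulting one-hot columns have dtype uint8, an unsigned 8-bit integer, so arithmetic on them can overflow. Negating a 1, for example, wraps around to 255 instead of giving -1.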
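import pandas as pd
import numpy as np

a = [1, 2, 3, 1]
one_hot = pd.get_dummies(a)  # one indicator column per distinct value
print(one_hot.dtypes)        # column dtypes: uint8 in the run shown here
print(one_hot)
print(-one_hot)              # negating unsigned 8-bit values wraps 1 around to 255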
Output:

1    uint8
2    uint8
3    uint8
dtype: object
   1  2  3
0  1  0  0
1  0  1  0
2  0  0  1
3  1  0  0
     1    2    3
0  255    0    0
1    0  255    0
2    0    0  255
3  255    0    0
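One way around the problem is to ask for a signed integer dtype before doing any arithmetic. The sketch below uses the dtype parameter of get_dummies and, as an alternative, astype on an existing result; np.int64 is just one reasonable choice. The default dtype depends on the pandas version (as far as I know, 2.0 and later return bool rather than uint8), so converting explicitly is the safer habit either way. The variable names are for illustration only.

```python
import numpy as np
import pandas as pd

a = [1, 2, 3, 1]

# Option 1: request a signed integer dtype up front.
one_hot_signed = pd.get_dummies(a, dtype=np.int64)
negated = -one_hot_signed
print(negated)

# Option 2: convert an existing result, whatever dtype it arrived in.
converted = pd.get_dummies(a).astype(np.int64)
print(converted.dtypes)

# The wrap-around itself, reproduced with plain NumPy: uint8 arithmetic is modulo 256.
zeros = np.zeros(3, dtype=np.uint8)
ones_first = np.array([1, 0, 0], dtype=np.uint8)
wrapped = zeros - ones_first
print(wrapped)
```

With a signed dtype, the negated frame contains -1 where the uint8 version showed 255.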