問題
爲了方便對比,我們仍然拿手寫識別Mnist這個數據集作爲我們的實驗的數據集。Mnist數據集[1,2] 中包含60000張手寫數字圖片,10,000 張測試圖片。每張圖片的大小爲28*28,包含一個手寫數字。下面是一些樣本舉例:
我們希望實現這樣一個分類器: 給定一張手寫圖片,分類器給出改數字屬於哪個分類。(0-9共10個分類)
貝葉斯公式
(公式1)
給定B時A的概率等於給定A時B的概率乘以A的概率 除以 B的概率。
舉例:
假設下雨的概率 P(A) = 0.1 不下雨的概率爲 P(not A) = 1-P(A) = 0.9 ,假設P(B)表示帶傘的概率。
下雨時我帶傘的概率爲 P(B|A) = 0.8 ,下雨不帶傘的概率爲P(not B|A) = 1-P(B|A) = 0.2
不下雨時我帶傘的概率爲 P(B| not A) = 0.05 ,不下雨不帶傘不帶傘的概率爲P(not B |not A) = 0.95
那麼有一天,我帶傘了,問那天下雨嗎,或者下雨的概率是多少?
根據公式,我們還需要知道帶傘的概率P(B),
可以用全概率公式來求
帶傘時,下雨的概率爲:
那麼 P(not A | B) = 1 - P(A|B) = 0.36 < P(A|B)
說明, 我帶了傘更有可能是那天下雨。
另外,我們其實沒有必要計算P(B),因爲對於分類問題來說,我們只需要比較 P(A|B) 和 P(not A|B)的大小即可。即: P(B|A)P(A) vs P(B|not A)P(not A)
上述方法實際上是使用了最大後驗概率(MAP)的方法,即認爲後驗概率最大的那個情況(分類)是最優可能的情況(分類)。
P(A)的值是一個先驗概率,是通過已有經驗得到的,比如本年度下雨天數和總天數的比值得到。
如果樣本是多維數據,
對Mnist應用樸素貝葉斯分類
在數據集中,每個圖片作爲一個樣本,它的分類結果共有10類,分別爲0,1,2,3,…9 。
數據集中每張圖片的的每個像素採用灰度值,我們爲了方便下面處理將它變成二值圖像。即將非0的點置爲1。這樣處理後,我們可以認爲一個像素是否爲1變成一個0-1分佈。
我們計算這樣一個概率值:
(公式3)
爲了表述簡單,我們令
(公式4)
我們求解的目標可以寫爲:
(公式5)
簡單來說就是計算
那麼下面的任務就是要求解
(公式6)
我們通俗的理解一下這個公式, 如果從屬於第0類的圖片在像素20上是1 個概率較高。那麼如果發現像素20爲1,則屬於第0類的概率較高。在看分母,如果對於所有圖片,在像素20上的1概率較高,說明這個像素對於分類的區分能力低,所以分母這個概率越高,則總的概率越低。
下面我們來逐個獲得公式中需要的量。第i個樣本第k個像素爲1的概率,我們通過統計所有樣本得到。
根據0-1分佈公式
對上式求對數
爲了讓公式看起來簡練,令
使用matlab實現
function model = bc_train(x,y,J)
%x : 樣本 y:標籤 J分類數
[K,N] = size(x); %K爲維度, N爲圖像個數
%計算 P(x_k^i)
px = sum(x,2);
px = px/N;
%p(y^i=j)計算屬於每一類的圖像個數
py= zeros(J,1);
for j = 1:J
py(j) = sum(y == j);
end
%p(x|y)
pxy = zeros(J,K);
for j=1:J
xj = x(:,y == j); %屬於J的圖片
pxy(j,:) = sum(xj,2) / size(xj,2);
end
model.px = px;
model.py = py;
model.pxy = pxy;
model.J = J;
end
function yp = bc_predit(x,y,model)
px =model.px + 1e-10;
pxy = model.pxy;
py =model.py;
J= model.J;
[K,N] = size(x); %K爲維度, N爲待分類圖像個數
% log_pyx 應爲 J*N的矩陣, X爲K*N的矩陣 pxy 爲 J*K的矩陣
log_pxy = log(pxy+1e-10);
% log_pxik 應爲 N*K 表示樣本i的第k個像素的對數概率 px爲K*1
rpx = repmat(px',N,1); % N*K
log_pxik = x'.*log(rpx) + (1-x').*(1-log(rpx));
t = sum(log_pxik,2);
t =repmat(t',J,1);
%log_pyx 應爲 J*N的矩陣
log_pyx = log_pxy*x + log(1-pxy)*(1-x) + repmat(log(py),1,N) - t;
[m,I] = max(log_pyx);
yp = I;
end
function bc_show_model(model)
pa = [];
pxy = model.pxy;
for i = 1:model.J
p = pxy(i,:);
p = reshape(p,28,28);
pa = [pa p];
end
pa = imresize(pa,10,'nearest');
imshow(pa);
end
測試腳本
labels = loadMNISTLabels('mnist/train-labels.idx1-ubyte');
images = loadMNISTImages('mnist/train-images.idx3-ubyte');
%二值化圖像
images(images>0) = 1;
%label +1
labels = labels+1;
m = bc_train(images,labels,10);
labels = loadMNISTLabels('mnist/t10k-labels.idx1-ubyte')+1;
images = loadMNISTImages('mnist/t10k-images.idx3-ubyte');
images(images>0) = 1;
yp = bc_predit(images,labels,m);
acc = sum(yp' == labels)/length(labels)
最後輸出結果 準確率 : 0.8418
使用bc_show_model 將