Stanford University: Naive Bayes Exercise Code

Naive Bayes

http://openclassroom.stanford.edu/MainFolder/DocumentPage.php?course=MachineLearning&doc=exercises/ex6/ex6.html
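The feature files are read as (document index, token index, count) triplets, which is how the sparse(M(:,1), M(:,2), M(:,3), ...) call in train.m interprets the three columns. The toy sketch below builds a count matrix from a few hand-made triplets to show what train_matrix ends up holding. The triplets and the sizes numDocs and numTokens are invented for illustration and are not taken from the exercise data. It also shows why passing explicit dimensions matters: without them, sparse sizes the matrix by the largest index it happens to see.

```matlab
% Hypothetical triplets in the same layout as train-features.txt:
% column 1 = document index, column 2 = token index, column 3 = count
M = [1 1 2;
     1 3 1;
     2 2 4;
     3 1 1];

numDocs = 3;     % toy corpus size (assumption for this example)
numTokens = 5;   % toy dictionary size (assumption for this example)

% Build the document-by-token count matrix with the full dictionary width
counts = full(sparse(M(:,1), M(:,2), M(:,3), numDocs, numTokens));
disp(counts)

% Without explicit sizes, columns 4 and 5 are dropped because no triplet uses them
counts_implicit = full(sparse(M(:,1), M(:,2), M(:,3)));
disp(size(counts_implicit))
```

Row i of counts is document i and column j holds how often token j occurs in it, exactly the layout described in the train.m comments.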

% train.m
% Naive Bayes text classifier

clear all; close all; clc

% store the number of training examples
numTrainDocs = 700;

% store the dictionary size
numTokens = 2500;

% read the features matrix; each line of train-features.txt is a
% (document index, token index, count) triplet
M = dlmread('train-features.txt', ' ');
spmatrix = sparse(M(:,1), M(:,2), M(:,3), numTrainDocs, numTokens);
train_matrix = full(spmatrix);

% train_matrix now contains information about the words within the emails
% the i-th row of train_matrix represents the i-th training email
% for a particular email, the entry in the j-th column tells
% you how many times the j-th dictionary word appears in that email



% read the training labels
train_labels = dlmread('train-labels.txt');
% the i-th entry of train_labels now indicates whether document i is spam (1) or not (0)


% Find the indices for the spam and nonspam labels
spam_indices = find(train_labels);
nonspam_indices = find(train_labels == 0);

% Estimate the prior probability of spam, p(y=1)
prob_spam = length(spam_indices) / numTrainDocs;

% Sum the number of words in each email by summing along each row of
% train_matrix
email_lengths = sum(train_matrix, 2);
% Now find the total word counts of all the spam emails and nonspam emails
spam_wc = sum(email_lengths(spam_indices));
nonspam_wc = sum(email_lengths(nonspam_indices));

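% Estimate phi_(k|y=1) with Laplace (add-one) smoothing: the count of token k
% over all spam emails plus 1, divided by the total spam word count plus the
% dictionary size. sum(..., 1) keeps a row vector even if there is one spam email.
prob_tokens_spam = (sum(train_matrix(spam_indices, :), 1) + 1) ./ ...
    (spam_wc + numTokens);
% Now the k-th entry of prob_tokens_spam represents phi_(k|y=1)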

% Estimate phi_(k|y=0) the same way from the non-spam emails
prob_tokens_nonspam = (sum(train_matrix(nonspam_indices, :), 1) + 1) ./ ...
    (nonspam_wc + numTokens);
% Now the k-th entry of prob_tokens_nonspam represents phi_(k|y=0)
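The estimates in train.m follow the multinomial event model with Laplace (add-one) smoothing: phi_(k|y=1) = (occurrences of token k in spam + 1) / (total words in spam + V), where V is the dictionary size. A quick way to sanity-check that formula is to run it on a tiny hand-made count matrix. The sketch below uses four invented documents over a three-token dictionary, with X, y, V, phi_spam and phi_nonspam as hypothetical names. It confirms that each smoothed distribution sums to 1 and that a token never seen in one class still gets a nonzero probability.

```matlab
% Toy count matrix: rows are documents, columns are dictionary tokens
X = [2 0 1;
     1 1 0;
     0 3 0;
     0 1 2];
y = [1; 1; 0; 0];          % 1 = spam, 0 = nonspam (invented labels)
V = size(X, 2);            % dictionary size

prior_spam = mean(y);      % fraction of spam documents, as prob_spam in train.m

% Per-token counts within each class; sum(..., 1) keeps a row vector
% even when a class has only one document
spam_counts = sum(X(y == 1, :), 1);
nonspam_counts = sum(X(y == 0, :), 1);

% Laplace-smoothed token probabilities; sum(spam_counts) plays the role of spam_wc
phi_spam = (spam_counts + 1) ./ (sum(spam_counts) + V);
phi_nonspam = (nonspam_counts + 1) ./ (sum(nonspam_counts) + V);

disp(phi_spam)
disp(phi_nonspam)
```

Here token 1 never occurs in the non-spam documents, yet phi_nonspam(1) is 1/9 rather than 0. Without the +1, a single unseen word would drive log p(x|y=0) to -Inf at test time.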

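% test.m
% Naive Bayes text classifier
% Run this after train.m in the same session: it uses prob_spam,
% prob_tokens_spam, prob_tokens_nonspam and numTokens from train.m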

% read the test matrix in the same way we read the training matrix;
% pass the dictionary size from train.m so that test_matrix has numTokens
% columns even if the last dictionary words never occur in the test set
N = dlmread('test-features.txt', ' ');
numTestDocs = max(N(:,1));
spmatrix = sparse(N(:,1), N(:,2), N(:,3), numTestDocs, numTokens);
test_matrix = full(spmatrix);


% The output vector is a vector that will store the spam/nonspam prediction
% for the documents in our test set.
output = zeros(numTestDocs, 1);

% Calculate log p(x|y=1) + log p(y=1)
% and log p(x|y=0) + log p(y=0)
% for every document
% make your prediction based on what value is higher
% (note that this is a vectorized implementation and there are other
%  ways to calculate the prediction)
log_a = test_matrix * (log(prob_tokens_spam))' + log(prob_spam);
log_b = test_matrix * (log(prob_tokens_nonspam))' + log(1 - prob_spam);
output = log_a > log_b;


% Read the correct labels of the test set
test_labels = dlmread('test-labels.txt');

% Compute the error on the test set
% A document is misclassified if its predicted label differs from
% the actual label, so count the 1's in the element-wise exclusive "or"
numdocs_wrong = sum(xor(output, test_labels))

% Print the fraction of misclassified test documents (the test error rate)
fraction_wrong = numdocs_wrong/numTestDocs
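Working in log space is what keeps this classifier numerically usable. An email multiplies many token probabilities that are each far below 1, and the raw product can underflow to 0 in double precision for both classes, leaving nothing to compare. The toy sketch below first shows that underflow. It then applies the same vectorized rule as test.m: score each document with counts * log(phi)' + log(prior), predict spam when the spam score is larger, and count errors with xor. The two-token model parameters, test counts and labels are all invented for illustration.

```matlab
% Underflow: 500 occurrences of a token with probability 1e-3
p = 1e-3;
raw_product = p ^ 500;        % about 1e-1500, below the smallest double, so 0
log_score = 500 * log(p);     % about -3453.9, easily representable
fprintf('raw product = %g, log score = %g\n', raw_product, log_score);

% Toy model: two tokens, invented parameters
phi_spam = [0.8 0.2];
phi_nonspam = [0.3 0.7];
prior_spam = 0.4;

% Toy test documents (rows = documents, columns = token counts) and true labels
test_counts = [5 1;
               0 4;
               2 2];
true_labels = [1; 0; 1];

% Same vectorized scoring as test.m
log_a = test_counts * log(phi_spam)' + log(prior_spam);
log_b = test_counts * log(phi_nonspam)' + log(1 - prior_spam);
predicted = log_a > log_b;

% Misclassified documents are where prediction and label disagree
num_wrong = sum(xor(predicted, true_labels));
fraction_wrong = num_wrong / numel(true_labels);
disp([predicted true_labels])
```

The third toy document has equal counts of both tokens, and the larger non-spam prior tips it to non-spam, so one of the three documents is misclassified (fraction_wrong = 1/3).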