Bayes分類算法
簡介
- 概率論的公式
- 一個小例子
- 算法的思想
- 呈上代碼
貝葉斯公式的簡介
在這裏p(x | y)表示在y事件發生時,x事件發生的概率。
一個小例子
Name | Gender | Height | Class |
---|---|---|---|
張三 | F | 1.68 | Medium |
李四 | M | 1.0 | Short |
王五 | M | 1.9 | Tall |
趙六 | M | 1.2 | Short |
分類算法的目的在於給出了以上面的一些例子作爲訓練集,按Class將每一個條目分類,訓練集裏的條目是分好類的,我們根據它訓練出一個模型(公式),當有給定的(Name,Gender,Height)時,我們就馬上可以輸出他的Class結果。
例如現在我們要預測t=(陳七,F,1.67)是屬於那個分類的。
爲了方便,我們先把身高做一個劃分,把身高分爲(0,1.6], (1.6,1.7], (1.7,1.8], (1.8,1.9],(1.9,2.0],(2.0,…] 總共6個區間。
分別屬於矮,中,高。
那我們現在就要預測t屬於哪一類,既要計算:
p(矮 | t) p(中 | t) p(高 | t)的大小,最大的那個就是我們所要選擇的。
- 根據貝葉斯公式,爲了求p(矮 | t)我們先得求p(t)
- 求p(矮)
- 求p(t | 矮)
- 上述三個都可以根據已經給出的數據求出
- p(中 | t) p(高 | t)的求法類似p(矮 | t)
算法思想
上面的屬性有性別(男,女),身高(數值區間),目標屬性有分類的(矮,中,高)
例如上面的第一行表示性別爲男的條目分類中矮,中,高的條目數分別爲1,2,3.
- 下面的算法中Attribute類表示各個屬性(性別,身高),它的值域爲性別(男,女),目標屬性值爲:矮,中,高。這個類維護一個計數表,負責計數表的更新。
- DataSet類:維護目標屬性,添加新的目標屬性值。
- Bayes類:計算各個分類的概率
package Bayes;
import java.util.ArrayList;
public class Attribute {
public ArrayList<String> range;//屬性的值域
public ArrayList<ArrayList<Double>> countMatrix;//計數表,統計各屬性值在目標屬性上的取值的個數,目標屬性就是最終的分類屬性。
public String attrName;//屬性名
public int attrIndex;//屬性在DataSet上的序號
public Attribute(String attrName, int attrIndex)
{
this.attrIndex = attrIndex;
this.attrName = attrName;
this.range = new ArrayList<String>();
this.countMatrix = new ArrayList<ArrayList<Double>>();
}
public void AddData(ArrayList<String> dataRow)
{
String columnValue = dataRow.get(attrIndex);
String targetValue = dataRow.get(DataSet.attr.size());
if(range.contains(columnValue))//如果該屬性值存在,則在原來的基礎上加1
{
int columValueIndex = range.indexOf(columnValue);
int targetValueIndex = DataSet.targetValueRange.indexOf(targetValue);
ArrayList<Double> matrixRow = countMatrix.get(columValueIndex);
if(targetValueIndex >= matrixRow.size())
{
int targetSize = DataSet.targetValueRange.size();
for(int i = 0; i < (targetSize - matrixRow.size()); i++)//有新的目標屬性值,將原來缺失的補齊。
{
matrixRow.add(new Double(0));
}
matrixRow.set(targetValueIndex, matrixRow.get(targetValueIndex)+1);
// System.out.println("add");
}
}
else//若屬性值不存在,則得將屬性值加進去。
{
this.range.add(columnValue);
int targetValueIndex = DataSet.targetValueRange.indexOf(targetValue);
ArrayList<Double> matrixRow = new ArrayList<Double>();
for(int i = 0; i < DataSet.targetValueRange.size(); i++)//該屬性不存在,則爲它構建一行新的
{
matrixRow.add(new Double(0));
}
matrixRow.set(targetValueIndex, new Double(1));
this.countMatrix.add(matrixRow);
// System.out.println(matrixRow.get(0));
}
}
}
package Bayes;
import java.util.ArrayList;
import Bayes.Attribute;
public class DataSet {
public static ArrayList<Attribute> attr;//屬性集
public String targetAttribute;//目標屬性名
public static ArrayList<String> targetValueRange;//目標屬性的值域
public static ArrayList<Double> targetValueCount;//目標屬性各值出現的次數
/**
* 數據集初始化,輸入一個屬性集和一個目標屬性名
* @param attrSet //屬性集
* @param targetAttrbute //目標屬性
*/
public DataSet(ArrayList<String> attrSet, String targetAttribute)
{
DataSet.attr = new ArrayList<Attribute>();
for(int i = 0; i < attrSet.size(); i++)
{
DataSet.attr.add(new Attribute(attrSet.get(i),i));
}
this.targetAttribute = targetAttribute;
targetValueCount = new ArrayList<Double>();
this.targetValueRange = new ArrayList<String>();
}
public void addRow(String... datas)
{
ArrayList<String> row = new ArrayList<String>();
for(String str : datas)
{
row.add(str);
}
String targetValue = row.get(DataSet.attr.size());
if(targetValueRange.contains(targetValue))
{
int targetIndex = this.targetValueRange.indexOf(targetValue);
targetValueCount.set(targetIndex, targetValueCount.get(targetIndex) + 1);
}
else
{
targetValueRange.add(targetValue);
targetValueCount.add(1.0);
}
for(int i = 0; i < attr.size(); i++)//更新計數表
{
Attribute att = DataSet.attr.get(i);
att.AddData(row);
}
System.out.println(targetValueRange.size());
}
}
package Bayes;
public class Bayes {
public double[] Test(String... dataRow)
{
//存放各個目標屬性的似然值
double[] likelihood = new double[DataSet.targetValueRange.size()];
//計算dataRow相對於各個目標屬性的似然值,即:P(dataRow|targetValue)*P(targetValue)
for(int i = 0; i < DataSet.targetValueRange.size(); i++)
{
String targetValue = DataSet.targetValueRange.get(i);
double probOfTarget = getProb(targetValue);
Double probOfData = null;
for(int j = 0; j < DataSet.attr.size(); j++)
{
Attribute attr = DataSet.attr.get(j);
double tempProb = getProb(attr, dataRow[j], targetValue);
if(probOfData == null)
{
probOfData = tempProb;
}
else
{
probOfData *= tempProb;
}
}
likelihood[i] = probOfTarget * probOfData;
}
double sumlikelihood = 0.0;
for(int i = 0; i < likelihood.length; i++)
{
sumlikelihood += likelihood[i];
}
double[] result = new double[DataSet.targetValueRange.size()];
for(int i = 0; i < result.length; i++)
{
result[i] = sumlikelihood == 0.0 ? 0 : (likelihood[i]/sumlikelihood);
}
return result;
}
/**
* 計算P(targetValue)的值
* @param targetValue //指定要計算的分類的值
* @return 概率
*/
private double getProb(String targetValue)
{
double sum = 0.0;//總的目標屬性值的次數
double valueCount = 0.0;//標記指定目標屬性出現的次數
for(int i = 0; i < DataSet.targetValueRange.size(); i++)
{
String value = DataSet.targetValueRange.get(i);
double count = DataSet.targetValueCount.get(i);
sum += count;
if(targetValue.equals(value))
valueCount = count;
}
return sum < 1 ? 0 : (valueCount / sum);
}
/**
* 獲取p(attrValue/targetValue)
* @param attr 屬性
* @param attrValue 屬性值
* @param targetValue 目標屬性
* @return
*/
private double getProb(Attribute attr, String attrValue, String targetValue)
{
double sum = 0.0;
double count = 0.0;
int columnValueIndex = DataSet.targetValueRange.indexOf(targetValue);
int attrIndex = attr.range.indexOf(attrValue);
for(int i = 0; i < attr.range.size(); i++)
{
double tempCount = columnValueIndex >= attr.countMatrix.get(i).size()?0:attr.countMatrix.get(i).get(columnValueIndex);
sum += tempCount;
if(attrIndex == i)
count = tempCount;
}
return sum < 1 ? 0 : (count/sum);
}
}
package Bayes;
import java.util.ArrayList;
public class main {
public static void main(String[] args) {
// TODO Auto-generated method stub
ArrayList<String> attr = new ArrayList<String>();
attr.add("Gender");
attr.add("Height");
DataSet dataSet = new DataSet(attr, "Class");
//添加數據
dataSet.addRow("F", "1", "Short");
dataSet.addRow("F", "1.5", "Medium");
dataSet.addRow("M", "1.8", "Tall");
Bayes bayes = new Bayes();
double[] result = bayes.Test("M", "1.8");
for(int i = 0; i < result.length; i++)
{
String targetValue = DataSet.targetValueRange.get(i);
System.out.println("P("+targetValue+"): " + result[i]) ;
}
}
}