Bayes 分类算法

Bayes分类算法

本文参考

简介

  • 概率论的公式
  • 一个小例子
  • 算法的思想
  • 呈上代码

贝叶斯公式的简介

这里写图片描述
在这里p(x | y)表示在y事件发生时,x事件发生的概率。

一个小例子

Name Gender Height Class
张三 F 1.68 Medium
李四 M 1.0 Short
王五 M 1.9 Tall
赵六 M 1.2 Short

分类算法的目的在于给出了以上面的一些例子作为训练集,按Class将每一个条目分类,训练集里的条目是分好类的,我们根据它训练出一个模型(公式),当有给定的(Name,Gender,Height)时,我们就马上可以输出他的Class结果。

例如现在我们要预测t=(陈七,F,1.67)是属于那个分类的。
为了方便,我们先把身高做一个划分,把身高分为(0,1.6], (1.6,1.7], (1.7,1.8], (1.8,1.9],(1.9,2.0],(2.0,…] 总共6个区间。
分别属于矮,中,高。
那我们现在就要预测t属于哪一类,既要计算:
p(矮 | t) p(中 | t) p(高 | t)的大小,最大的那个就是我们所要选择的。
- 根据贝叶斯公式,为了求p(矮 | t)我们先得求p(t)
- 求p(矮)
- 求p(t | 矮)
- 上述三个都可以根据已经给出的数据求出
- p(中 | t) p(高 | t)的求法类似p(矮 | t)

算法思想

我们先看上面案例的一张表
上面的属性有性别(男,女),身高(数值区间),目标属性有分类的(矮,中,高)
例如上面的第一行表示性别为男的条目分类中矮,中,高的条目数分别为1,2,3.
- 下面的算法中Attribute类表示各个属性(性别,身高),它的值域为性别(男,女),目标属性值为:矮,中,高。这个类维护一个计数表,负责计数表的更新。
- DataSet类:维护目标属性,添加新的目标属性值。
- Bayes类:计算各个分类的概率

package Bayes;

import java.util.ArrayList;

public  class Attribute {
    public ArrayList<String> range;//属性的值域
    public ArrayList<ArrayList<Double>> countMatrix;//计数表,统计各属性值在目标属性上的取值的个数,目标属性就是最终的分类属性。
    public String attrName;//属性名
    public int attrIndex;//属性在DataSet上的序号

    public Attribute(String attrName, int attrIndex)
    {
        this.attrIndex = attrIndex;
        this.attrName = attrName;
        this.range = new ArrayList<String>();
        this.countMatrix = new ArrayList<ArrayList<Double>>();
    }
    public void AddData(ArrayList<String> dataRow)
    {
        String columnValue = dataRow.get(attrIndex);
        String targetValue = dataRow.get(DataSet.attr.size());
        if(range.contains(columnValue))//如果该属性值存在,则在原来的基础上加1
        {
            int columValueIndex = range.indexOf(columnValue);
            int targetValueIndex = DataSet.targetValueRange.indexOf(targetValue);
            ArrayList<Double> matrixRow = countMatrix.get(columValueIndex);  
            if(targetValueIndex >= matrixRow.size())
            {
                int targetSize = DataSet.targetValueRange.size();
                for(int i = 0; i < (targetSize - matrixRow.size()); i++)//有新的目标属性值,将原来缺失的补齐。
                {
                    matrixRow.add(new Double(0));
                }
                matrixRow.set(targetValueIndex, matrixRow.get(targetValueIndex)+1); 
//              System.out.println("add");
            }
        }
        else//若属性值不存在,则得将属性值加进去。
        {
            this.range.add(columnValue);
            int targetValueIndex = DataSet.targetValueRange.indexOf(targetValue);
            ArrayList<Double> matrixRow = new ArrayList<Double>();
            for(int i = 0; i < DataSet.targetValueRange.size(); i++)//该属性不存在,则为它构建一行新的 
            {
                matrixRow.add(new Double(0));
            }
            matrixRow.set(targetValueIndex, new Double(1));
            this.countMatrix.add(matrixRow);
//          System.out.println(matrixRow.get(0));

        }
    }


}
package Bayes;

import java.util.ArrayList;
import Bayes.Attribute;
public class DataSet {
public static ArrayList<Attribute> attr;//属性集
public String targetAttribute;//目标属性名
public static ArrayList<String> targetValueRange;//目标属性的值域
public static ArrayList<Double> targetValueCount;//目标属性各值出现的次数
/**
 * 数据集初始化,输入一个属性集和一个目标属性名
 * @param attrSet //属性集
 * @param targetAttrbute //目标属性
 */
public DataSet(ArrayList<String> attrSet, String targetAttribute)
{
    DataSet.attr = new ArrayList<Attribute>();
    for(int i = 0; i < attrSet.size(); i++)
    {
        DataSet.attr.add(new Attribute(attrSet.get(i),i));
    }
    this.targetAttribute = targetAttribute;
    targetValueCount = new ArrayList<Double>();
    this.targetValueRange = new ArrayList<String>();
}

    public void addRow(String... datas)
    {
        ArrayList<String> row = new ArrayList<String>();
        for(String str : datas)
        {
            row.add(str);
        }
        String targetValue = row.get(DataSet.attr.size());
        if(targetValueRange.contains(targetValue))
        {
            int targetIndex = this.targetValueRange.indexOf(targetValue);
            targetValueCount.set(targetIndex, targetValueCount.get(targetIndex) + 1);
        }
        else
        {
            targetValueRange.add(targetValue);
            targetValueCount.add(1.0);
        }
        for(int i = 0; i < attr.size(); i++)//更新计数表
        {
            Attribute att = DataSet.attr.get(i);
            att.AddData(row);
        }
        System.out.println(targetValueRange.size());

    }

}
package Bayes;

public class Bayes {
    public double[] Test(String... dataRow)
    {

        //存放各个目标属性的似然值
        double[] likelihood = new double[DataSet.targetValueRange.size()];
        //计算dataRow相对于各个目标属性的似然值,即:P(dataRow|targetValue)*P(targetValue)
        for(int i = 0; i < DataSet.targetValueRange.size(); i++)
        {
            String targetValue = DataSet.targetValueRange.get(i);
            double probOfTarget = getProb(targetValue);
            Double probOfData = null;
            for(int j = 0; j < DataSet.attr.size(); j++)
            {
                Attribute attr = DataSet.attr.get(j);
                double tempProb = getProb(attr, dataRow[j], targetValue);
                if(probOfData == null)
                {
                    probOfData = tempProb;
                }
                else
                {
                    probOfData *= tempProb;
                }
            }
            likelihood[i] = probOfTarget * probOfData;
        }
        double sumlikelihood = 0.0;
        for(int i = 0; i < likelihood.length; i++)
        {
            sumlikelihood += likelihood[i];
        }
        double[] result = new double[DataSet.targetValueRange.size()];
        for(int i = 0; i < result.length; i++)
        {
            result[i] = sumlikelihood == 0.0 ? 0 : (likelihood[i]/sumlikelihood);
        }
        return result;
    }
    /**
     * 计算P(targetValue)的值
     * @param targetValue //指定要计算的分类的值
     * @return 概率
     */
    private double getProb(String targetValue)
    {
        double sum = 0.0;//总的目标属性值的次数
        double valueCount = 0.0;//标记指定目标属性出现的次数
        for(int i = 0; i < DataSet.targetValueRange.size(); i++)
        {
            String value = DataSet.targetValueRange.get(i);
            double count =  DataSet.targetValueCount.get(i);
            sum += count;
            if(targetValue.equals(value))
                valueCount = count;
        }
        return sum < 1 ? 0 : (valueCount / sum);
    }
    /**
     * 获取p(attrValue/targetValue)
     * @param attr 属性
     * @param attrValue 属性值
     * @param targetValue 目标属性
     * @return
     */
    private double getProb(Attribute attr, String attrValue, String targetValue)
    {
        double sum = 0.0;
        double count = 0.0;
        int columnValueIndex = DataSet.targetValueRange.indexOf(targetValue);
        int attrIndex = attr.range.indexOf(attrValue);
        for(int i = 0; i < attr.range.size(); i++)
        {
            double tempCount = columnValueIndex >= attr.countMatrix.get(i).size()?0:attr.countMatrix.get(i).get(columnValueIndex);
            sum += tempCount;
            if(attrIndex == i)
                count = tempCount;
        }
        return sum < 1 ? 0 : (count/sum);
    }

}
package Bayes;

import java.util.ArrayList;

public class main {

    public static void main(String[] args) {
        // TODO Auto-generated method stub

        ArrayList<String> attr = new ArrayList<String>();
        attr.add("Gender");
        attr.add("Height");
        DataSet dataSet = new DataSet(attr, "Class");
        //添加数据
        dataSet.addRow("F", "1", "Short");
        dataSet.addRow("F", "1.5", "Medium");
        dataSet.addRow("M", "1.8", "Tall");

        Bayes bayes = new Bayes();
        double[] result = bayes.Test("M", "1.8");
        for(int i = 0; i < result.length; i++)
        {
            String targetValue = DataSet.targetValueRange.get(i);
            System.out.println("P("+targetValue+"): " + result[i]) ;
        }

    }

}
发布了36 篇原创文章 · 获赞 12 · 访问量 3万+
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章