Hadoop/MapReduce、Spark 樸素貝葉斯分類器分類符號數據





package cjbayesclassfier;

import java.io.IOException;

import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.conf.Configured;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.LongWritable;
import org.apache.hadoop.io.NullWritable;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.mapreduce.lib.output.MultipleOutputs; 
import org.apache.hadoop.mapreduce.Job;
import org.apache.hadoop.mapreduce.Mapper;
import org.apache.hadoop.mapreduce.Reducer;
import org.apache.hadoop.mapreduce.lib.input.FileInputFormat;
import org.apache.hadoop.mapreduce.lib.input.TextInputFormat;
import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat;
import org.apache.hadoop.mapreduce.lib.output.TextOutputFormat;
import org.apache.hadoop.util.Tool;
import org.apache.hadoop.util.ToolRunner;

import edu.umd.cloud9.io.pair.PairOfStrings;
/**
 * Step 1 of the naive Bayes pipeline: splits every training line and counts
 * how often each (attribute value, class) pair and each class label occurs.
 * Class counts are written to the "CLASS" named output, all other counts to
 * the "OTHERS" named output.
 *
 * @author chenjie
 */
public class CJBayesClassfier_Step1 extends Configured implements Tool  {

    /** Marker used as the left key element for class counts, and name of their output. */
    private static final String CLASS_OUTPUT = "CLASS";
    /** Name of the output that receives the (attribute value, class) counts. */
    private static final String OTHERS_OUTPUT = "OTHERS";
    /** Input used when no command-line arguments are supplied. */
    private static final String DEFAULT_INPUT = "/media/chenjie/0009418200012FF3/ubuntu/weather.txt";
    /** Output directory used when no command-line arguments are supplied. */
    private static final String DEFAULT_OUTPUT = "/media/chenjie/0009418200012FF3/ubuntu/CJBayes";

    /**
     * Mapper.
     * Input: one comma-separated training record per line, the last field
     * being the class label, e.g. {@code Sunny,Hot,High,Weak,No}.
     * Output for that line:
     * <pre>
     * key              value
     * (Sunny,No)       1
     * (Hot,No)         1
     * (High,No)        1
     * (Weak,No)        1
     * (CLASS,No)       1
     * </pre>
     *
     * @author chenjie
     */
    public static class CJBayesClassfierMapper extends Mapper<LongWritable, Text, PairOfStrings, LongWritable>
    {
        // Reused for every record to avoid allocating a new key/value per write.
        PairOfStrings outputKey = new PairOfStrings();
        LongWritable outputValue = new LongWritable(1);

        /**
         * Emits one count per attribute value plus one count for the class itself.
         * Lines with fewer than two fields are ignored.
         */
        @Override
        protected void map(
                LongWritable key,
                Text value,
                Context context)
                throws IOException, InterruptedException {
            String[] tokens = value.toString().split(",");
            if (tokens.length < 2) {
                return;
            }
            int classIndex = tokens.length - 1;
            String classfier = tokens[classIndex];
            for (int i = 0; i < classIndex; i++) {
                outputKey.set(tokens[i], classfier);
                context.write(outputKey, outputValue);
            }
            outputKey.set(CLASS_OUTPUT, classfier);
            context.write(outputKey, outputValue);
        }
    }

    /**
     * Sums the counts of each key and writes (key, total) to the regular job output.
     * Not used by {@link #run(String[])}, which configures {@link CJBayesClassfierReducer2}.
     */
    @Deprecated
    public static class CJBayesClassfierReducer extends Reducer<PairOfStrings, LongWritable, PairOfStrings, LongWritable>
    {
        @Override
        protected void reduce(
                PairOfStrings key,
                Iterable<LongWritable> values,
                Context context)
                throws IOException, InterruptedException {
            long sum = 0L; // primitive accumulator: no boxing per iteration
            for (LongWritable time : values) {
                sum += time.get();
            }
            context.write(key, new LongWritable(sum));
        }
    }

    /**
     * Sums the counts of each key and routes the resulting "left,right,count"
     * line to the "CLASS" or "OTHERS" named output.
     */
    public static class CJBayesClassfierReducer2 extends Reducer<PairOfStrings, LongWritable, PairOfStrings, Text>
    {
        /** Writer for the named outputs. */
        private MultipleOutputs<PairOfStrings, Text> mos;

        /** Creates the named-output writer for this task. */
        @Override
        protected void setup(Context context)
                throws IOException, InterruptedException {
            mos = new MultipleOutputs<PairOfStrings, Text>(context);
        }

        /**
         * Adds up all values that share the same key and writes one line per key.
         */
        @Override
        protected void reduce(
                PairOfStrings key,
                Iterable<LongWritable> values,
                Context context)
                throws IOException, InterruptedException {
            System.out.println("key =" + key );
            long sum = 0L; // primitive accumulator: no boxing per iteration
            for (LongWritable time : values) {
                sum += time.get();
            }
            String result = key.getLeftElement() + "," + key.getRightElement() + "," + sum;
            String target = CLASS_OUTPUT.equals(key.getLeftElement()) ? CLASS_OUTPUT : OTHERS_OUTPUT;
            mos.write(target, NullWritable.get(), new Text(result));
        }

        /**
         * Closes the named-output writer; per the original author's note, nothing
         * is written to the named outputs unless it is closed.
         */
        @Override
        protected void cleanup(Context context)
                throws IOException, InterruptedException {
            mos.close();
        }
    }

    /**
     * Entry point.
     *
     * @param args {@code <input> <output>}; when fewer than two arguments are
     *             given the built-in default paths are used
     */
    public static void main(String[] args) throws Exception
    {
        // Honour caller-supplied paths; fall back to the defaults only when absent.
        if (args == null || args.length < 2) {
            args = new String[] { DEFAULT_INPUT, DEFAULT_OUTPUT };
        }
        int jobStatus = submitJob(args);
        System.exit(jobStatus);
    }

    /**
     * Runs this tool through {@link ToolRunner}.
     *
     * @param args {@code <input> <output>}
     * @return the exit status returned by {@link #run(String[])}
     */
    public static int submitJob(String[] args) throws Exception {
        return ToolRunner.run(new CJBayesClassfier_Step1(), args);
    }

    /**
     * Configures and runs the counting job.
     *
     * @param args {@code args[0]} input path, {@code args[1]} output directory
     *             (deleted first if it already exists)
     * @return 0 on success, 1 if the job failed, 2 on bad usage
     */
    @Override
    public int run(String[] args) throws Exception {
        if (args == null || args.length < 2) {
            System.err.println("Usage: CJBayesClassfier_Step1 <input> <output>");
            return 2;
        }
        Configuration conf = getConf();
        Job job = Job.getInstance(conf);
        job.setJobName("Bayes");
        job.setJarByClass(CJBayesClassfier_Step1.class);

        // The reducer writes (NullWritable, Text) records to both named outputs.
        MultipleOutputs.addNamedOutput(job, CLASS_OUTPUT, TextOutputFormat.class, NullWritable.class, Text.class);
        MultipleOutputs.addNamedOutput(job, OTHERS_OUTPUT, TextOutputFormat.class, NullWritable.class, Text.class);

        job.setInputFormatClass(TextInputFormat.class);
        job.setOutputFormatClass(TextOutputFormat.class);

        // Intermediate (map) types differ from what the reducer finally writes.
        job.setMapOutputKeyClass(PairOfStrings.class);
        job.setMapOutputValueClass(LongWritable.class);
        job.setOutputKeyClass(NullWritable.class);
        job.setOutputValueClass(Text.class);

        job.setMapperClass(CJBayesClassfierMapper.class);
        job.setReducerClass(CJBayesClassfierReducer2.class);

        Path outPath = new Path(args[1]);
        FileInputFormat.setInputPaths(job, new Path(args[0]));
        FileOutputFormat.setOutputPath(job, outPath);

        // Resolve the file system from the path itself so non-default schemes work.
        FileSystem fs = outPath.getFileSystem(conf);
        if (fs.exists(outPath)) {
            fs.delete(outPath, true);
        }

        boolean status = job.waitForCompletion(true);
        return status ? 0 : 1;
    }
}


package cjbayesclassfier;

import java.io.BufferedReader;
import java.io.FileReader;
import java.io.IOException;
import java.net.URI;
import java.util.HashMap;
import java.util.Map;

import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.conf.Configured;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.DoubleWritable;
import org.apache.hadoop.io.LongWritable;
import org.apache.hadoop.io.NullWritable;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.mapreduce.Job;
import org.apache.hadoop.mapreduce.Mapper;
import org.apache.hadoop.mapreduce.Reducer;
import org.apache.hadoop.mapreduce.lib.input.FileInputFormat;
import org.apache.hadoop.mapreduce.lib.input.TextInputFormat;
import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat;
import org.apache.hadoop.mapreduce.lib.output.TextOutputFormat;
import org.apache.hadoop.util.Tool;
import org.apache.hadoop.util.ToolRunner;

import edu.umd.cloud9.io.pair.PairOfStrings;
/***
 * 第二步:計算概率
 * @author chenjie
 *
 */
public class CJBayesClassfier_Step2 extends Configured implements Tool  {
    public static class CJBayesClassfierMapper2 extends Mapper<LongWritable, Text, PairOfStrings, DoubleWritable>
    {
        PairOfStrings outputKey = new PairOfStrings();
        DoubleWritable outputValue = new DoubleWritable(1);
        private Map<String,Integer> classMap = new HashMap<String,Integer>();
        @Override
        protected void setup(Context context) throws IOException, InterruptedException {
            FileReader fr = new FileReader("CLASS");
            BufferedReader br = new BufferedReader(fr);
            String line = null;
            while((line = br.readLine()) != null)
            {
                String tokens[] = line.split(",");
                String classfier = tokens[1];
                String count = tokens[2];
                classMap.put(classfier, Integer.parseInt(count));
            }
            fr.close();
            br.close();
            int sum = 0;
            for(Map.Entry<String,Integer> entry : classMap.entrySet())
            {
                sum += entry.getValue();
            }
            for(Map.Entry<String,Integer> entry : classMap.entrySet())
            {
                double poss = entry.getValue() * 1.0 / sum;
                context.write(new PairOfStrings("CLASS", entry.getKey()), new DoubleWritable(poss));
            }
        }
        
        @Override
        protected void map(
                LongWritable key,
                Text value,
                Context context)
                throws IOException, InterruptedException {
               String tokens[] = value.toString().split(",");
               if(tokens == null || tokens.length < 3)
                   return;
               String X = tokens[0];
               String classfier = tokens[1];
               Integer count = Integer.valueOf(tokens[2]);
               outputKey.set(X, classfier);
               Integer classCount = classMap.get(classfier);
               outputValue.set(count * 1.0 / classCount);
               context.write(outputKey, outputValue);
           }
        }
    
    public static class CJBayesClassfierReducer2 extends Reducer<PairOfStrings, DoubleWritable, NullWritable, Text>
    {
        @Override
        protected void reduce(
                PairOfStrings key,
                Iterable<DoubleWritable> values,
                Context context)
                throws IOException, InterruptedException {
            for(DoubleWritable dw : values)
                context.write(NullWritable.get(), new Text(key.getLeftElement() + "," + key.getRightElement() + "," + dw));
        }
    }


    public static void main(String[] args) throws Exception
    {
        args = new String[2];
        args[0] = "/media/chenjie/0009418200012FF3/ubuntu/CJBayes/OTHERS-r-00000";
        args[1] = "/media/chenjie/0009418200012FF3/ubuntu/CJBayes2";
        int jobStatus = submitJob(args);
        System.exit(jobStatus);
    }
    
    public static int submitJob(String[] args) throws Exception {
        int jobStatus = ToolRunner.run(new CJBayesClassfier_Step2(), args);
        return jobStatus;
    }

    @SuppressWarnings("deprecation")
    @Override
    public int run(String[] args) throws Exception {
        Configuration conf = getConf();
        Job job = new Job(conf);
        job.setJobName("Bayes");

        job.addCacheArchive(new URI("/media/chenjie/0009418200012FF3/ubuntu/CJBayes/CLASS-r-00000" + "#CLASS"));
        
        job.setInputFormatClass(TextInputFormat.class);
        job.setOutputFormatClass(TextOutputFormat.class);
        
        job.setOutputKeyClass(PairOfStrings.class);       
        job.setOutputValueClass(DoubleWritable.class);      
       
        
        job.setMapperClass(CJBayesClassfierMapper2.class);
        job.setReducerClass(CJBayesClassfierReducer2.class);

        FileInputFormat.setInputPaths(job, new Path(args[0]));
        FileOutputFormat.setOutputPath(job, new Path(args[1]));
        
        FileSystem fs = FileSystem.get(conf);
        Path outPath = new Path(args[1]);
        if(fs.exists(outPath))
        {
            fs.delete(outPath, true);
        }
        
        boolean status = job.waitForCompletion(true);
        return status ? 0 : 1;
    }
    
    
}


package cjbayesclassfier;

import java.io.BufferedReader;
import java.io.FileNotFoundException;
import java.io.FileReader;
import java.io.IOException;
import java.net.URI;
import java.util.ArrayList;
import java.util.List;

import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.conf.Configured;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.DoubleWritable;
import org.apache.hadoop.io.LongWritable;
import org.apache.hadoop.io.NullWritable;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.mapreduce.Job;
import org.apache.hadoop.mapreduce.Mapper;
import org.apache.hadoop.mapreduce.Reducer;
import org.apache.hadoop.mapreduce.lib.input.FileInputFormat;
import org.apache.hadoop.mapreduce.lib.input.TextInputFormat;
import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat;
import org.apache.hadoop.mapreduce.lib.output.TextOutputFormat;
import org.apache.hadoop.util.Tool;
import org.apache.hadoop.util.ToolRunner;

import edu.umd.cloud9.io.pair.PairOfStrings;

/**
 * Step 3 of the naive Bayes pipeline: classifies each input record using the
 * probabilities computed by the previous step.
 *
 * @author chenjie
 */
public class CJBayesClassfier_Step3 extends Configured implements Tool  {

    /** Symlink name under which the class-count file is exposed to tasks. */
    private static final String CLASS_LINK = "CLASS";
    /** Symlink name under which the probability file is exposed to tasks. */
    private static final String PROBABILITY_LINK = "GL";
    /** Input used when no command-line arguments are supplied. */
    private static final String DEFAULT_INPUT = "/media/chenjie/0009418200012FF3/ubuntu/weather_predict.txt";
    /** Output directory used when no command-line arguments are supplied. */
    private static final String DEFAULT_OUTPUT = "/media/chenjie/0009418200012FF3/ubuntu/CJBayes3";
    /** Class-count file used when no third argument is supplied. */
    private static final String DEFAULT_CLASS_FILE = "/media/chenjie/0009418200012FF3/ubuntu/CJBayes/CLASS-r-00000";
    /** Probability file used when no fourth argument is supplied. */
    private static final String DEFAULT_PROBABILITY_FILE = "/media/chenjie/0009418200012FF3/ubuntu/CJBayes2/part-r-00000";

    /**
     * Forwards every non-blank input line as a key with the value 1, so that
     * identical records reach the reducer as a single key.
     */
    public static class CJBayesClassfierMapper3 extends Mapper<LongWritable, Text, Text, LongWritable>
    {
        LongWritable outputValue = new LongWritable(1);

        @Override
        protected void map(
                LongWritable key,
                Text value,
                Context context)
                throws IOException, InterruptedException {
            // A blank line carries no attributes and must not be classified.
            if (value.toString().trim().isEmpty()) {
                return;
            }
            context.write(value, outputValue);
        }
    }

    /**
     * Computes, for every record, prior * product of conditional probabilities
     * per class and reports the class with the largest value.
     */
    public static class CJBayesClassfierReducer3 extends Reducer<Text, LongWritable, Text, Text>
    {
        /** Class labels read from the cached "CLASS" file. */
        private List<String> classfications;

        /**
         * Loads the class labels and the probability table, printing both.
         */
        @Override
        protected void setup(
                Reducer<Text, LongWritable, Text, Text>.Context context)
                throws IOException, InterruptedException {
            classfications = buildClassfications();
            for (String classfication : classfications) {
                System.out.println("分類:" + classfication);
            }
            buildCJGLTable();
            CJGLTable.show();
        }

        /**
         * Reads the class labels (second field of each line) from the "CLASS" file.
         *
         * @return the labels in file order
         */
        private List<String> buildClassfications() throws IOException {
            List<String> list = new ArrayList<String>();
            BufferedReader br = new BufferedReader(new FileReader(CLASS_LINK));
            try {
                String line;
                while ((line = br.readLine()) != null) {
                    String[] tokens = line.split(",");
                    if (tokens.length < 2) {
                        continue; // blank or malformed line: no label present
                    }
                    list.add(tokens[1]);
                }
            } finally {
                // Closing the outer reader also closes the underlying FileReader.
                br.close();
            }
            return list;
        }

        /**
         * Loads every "left,right,probability" line of the "GL" file into {@link CJGLTable}.
         */
        private void buildCJGLTable() throws IOException {
            BufferedReader br = new BufferedReader(new FileReader(PROBABILITY_LINK));
            try {
                String line;
                while ((line = br.readLine()) != null) {
                    String[] tokens = line.split(",");
                    if (tokens.length < 3) {
                        continue; // blank or malformed line: no probability present
                    }
                    PairOfStrings key = new PairOfStrings(tokens[0], tokens[1]);
                    CJGLTable.add(key, Double.valueOf(tokens[2]));
                }
            } finally {
                br.close();
            }
        }

        /**
         * Classifies one record. After each class is evaluated the best class so
         * far is written, followed by one final-result line for the record.
         * A (value, class) pair absent from the table contributes probability 0.
         */
        @Override
        protected void reduce(
                Text key,
                Iterable<LongWritable> values,
                Context context)
                throws IOException, InterruptedException {
            System.out.println("key=" + key);
            System.out.println("values:");
            for (LongWritable lw : values) {
                System.out.println(lw);
            }
            String[] attributes = key.toString().split(",");
            String selectedClass = null;
            double maxPosterior = 0.0;
            for (String aClass : classfications) {
                System.out.println("對於類別:" + aClass);
                double posterior = CJGLTable.getClassGL(aClass);
                System.out.println("其概率爲:" + posterior);
                for (String attr : attributes) {
                    System.out.println("\t對於條件:"  + attr);
                    double conGL = CJGLTable.getConditionalGL(attr, aClass);
                    System.out.println("\t其概率爲:" + conGL);
                    posterior *= conGL; // reuse the value already looked up
                }

                // The first class always becomes the initial candidate.
                if (selectedClass == null || posterior > maxPosterior) {
                    selectedClass = aClass;
                    maxPosterior = posterior;
                }
                context.write(key, new Text("貝葉斯分類:" + selectedClass + ",其概率爲" + maxPosterior));
            }
            context.write(key, new Text("最終結果:貝葉斯分類爲" + selectedClass + ",其概率爲" + maxPosterior));
        }
    }

    /**
     * Entry point.
     *
     * @param args {@code <input> <output> [classCountFile] [probabilityFile]};
     *             when fewer than two arguments are given the built-in default
     *             paths are used
     */
    public static void main(String[] args) throws Exception
    {
        // Honour caller-supplied paths; fall back to the defaults only when absent.
        if (args == null || args.length < 2) {
            args = new String[] { DEFAULT_INPUT, DEFAULT_OUTPUT };
        }
        int jobStatus = submitJob(args);
        System.exit(jobStatus);
    }

    /**
     * Runs this tool through {@link ToolRunner}.
     *
     * @param args {@code <input> <output> [classCountFile] [probabilityFile]}
     * @return the exit status returned by {@link #run(String[])}
     */
    public static int submitJob(String[] args) throws Exception {
        return ToolRunner.run(new CJBayesClassfier_Step3(), args);
    }

    /**
     * Configures and runs the classification job.
     *
     * @param args {@code args[0]} input path, {@code args[1]} output directory
     *             (deleted first if it already exists), optional {@code args[2]}
     *             class-count file and {@code args[3]} probability file to
     *             distribute (both default to the outputs of the earlier steps)
     * @return 0 on success, 1 if the job failed, 2 on bad usage
     */
    @Override
    public int run(String[] args) throws Exception {
        if (args == null || args.length < 2) {
            System.err.println("Usage: CJBayesClassfier_Step3 <input> <output> [classCountFile] [probabilityFile]");
            return 2;
        }
        Configuration conf = getConf();
        Job job = Job.getInstance(conf);
        job.setJobName("Bayes");
        job.setJarByClass(CJBayesClassfier_Step3.class);

        // Both side inputs are plain text files, so register them as cache
        // files (not archives); the fragment is the symlink name tasks open.
        String classFile = args.length > 2 ? args[2] : DEFAULT_CLASS_FILE;
        String probabilityFile = args.length > 3 ? args[3] : DEFAULT_PROBABILITY_FILE;
        job.addCacheFile(new URI(classFile + "#" + CLASS_LINK));
        job.addCacheFile(new URI(probabilityFile + "#" + PROBABILITY_LINK));

        job.setInputFormatClass(TextInputFormat.class);
        job.setOutputFormatClass(TextOutputFormat.class);

        // Intermediate (map) types differ from what the reducer finally writes.
        job.setMapOutputKeyClass(Text.class);
        job.setMapOutputValueClass(LongWritable.class);
        job.setOutputKeyClass(Text.class);
        job.setOutputValueClass(Text.class);

        job.setMapperClass(CJBayesClassfierMapper3.class);
        job.setReducerClass(CJBayesClassfierReducer3.class);

        Path outPath = new Path(args[1]);
        FileInputFormat.setInputPaths(job, new Path(args[0]));
        FileOutputFormat.setOutputPath(job, outPath);

        // Resolve the file system from the path itself so non-default schemes work.
        FileSystem fs = outPath.getFileSystem(conf);
        if (fs.exists(outPath)) {
            fs.delete(outPath, true);
        }

        boolean status = job.waitForCompletion(true);
        return status ? 0 : 1;
    }
}


package cjbayesclassfier;

import java.util.HashMap;
import java.util.Map;

import edu.umd.cloud9.io.pair.PairOfStrings;
/**
 * Process-wide table of probabilities keyed by a pair of strings.
 * Class priors are stored under ("CLASS", label); conditional probabilities
 * under (attribute value, label).
 *
 * @author chenjie
 */
public class CJGLTable {
    private static Map<PairOfStrings,Double> map = new HashMap<PairOfStrings,Double>();

    /**
     * Stores (or replaces) the probability for the given key.
     *
     * @param key pair identifying the probability
     * @param gl  probability value
     */
    public static void add(PairOfStrings key,Double gl)
    {
        map.put(key, gl);
    }

    /**
     * @param aClass class label
     * @return the value stored under ("CLASS", aClass), or 0 when none is stored
     */
    public static double getClassGL(String aClass)
    {
        return probabilityOf("CLASS", aClass);
    }

    /**
     * @param conditional attribute value
     * @param aClass      class label
     * @return the value stored under (conditional, aClass), or 0 when none is stored
     */
    public static double getConditionalGL(String conditional,String aClass)
    {
        return probabilityOf(conditional, aClass);
    }

    /** Prints every table entry on its own line to standard output. */
    public static void show()
    {
        for (Map.Entry<PairOfStrings,Double> row : map.entrySet())
        {
            System.out.println(row);
        }
    }

    /** Shared lookup: the stored value for (left, right), or 0 when absent. */
    private static double probabilityOf(String left, String right)
    {
        Double stored = map.get(new PairOfStrings(left, right));
        if (stored == null)
        {
            return 0;
        }
        return stored;
    }
}




第一步:
輸入:weather.txt
--------------------------
Sunny,Hot,High,Weak,No
Sunny,Hot,High,Strong,No
Overcast,Hot,High,Weak,Yes
Rain,Mild,High,Weak,Yes
Rain,Cool,Normal,Weak,Yes
Rain,Cool,Normal,Strong,No
Overcast,Cool,Normal,Strong,Yes
Sunny,Mild,High,Weak,No
Sunny,Cool,Normal,Weak,Yes
Rain,Mild,Normal,Weak,Yes
Sunny,Mild,Normal,Strong,Yes
Overcast,Mild,High,Strong,Yes
Overcast,Hot,Normal,Weak,Yes
Rain,Mild,High,Strong,No


輸出:
CLASS-r-00000
----------------------
CLASS,No,5
CLASS,Yes,9


OTHERS-r-00000
--------------------------
Cool,No,1
Cool,Yes,3
High,No,4
High,Yes,3
Hot,No,2
Hot,Yes,2
Mild,No,2
Mild,Yes,4
Normal,No,1
Normal,Yes,6
Overcast,Yes,4
Rain,No,2
Rain,Yes,3
Strong,No,3
Strong,Yes,3
Sunny,No,3
Sunny,Yes,2
Weak,No,2
Weak,Yes,6


第二步:
緩存:CLASS-r-00000
-----------------------
CLASS,No,5
CLASS,Yes,9
輸入:OTHERS-r-00000
------------------------
Cool,No,1
Cool,Yes,3
High,No,4
High,Yes,3
Hot,No,2
Hot,Yes,2
Mild,No,2
Mild,Yes,4
Normal,No,1
Normal,Yes,6
Overcast,Yes,4
Rain,No,2
Rain,Yes,3
Strong,No,3
Strong,Yes,3
Sunny,No,3
Sunny,Yes,2
Weak,No,2
Weak,Yes,6

輸出:
part-r-00000
----------------------------------
CLASS,No,0.35714285714285715
CLASS,Yes,0.6428571428571429
Cool,No,0.2
Cool,Yes,0.3333333333333333
High,No,0.8
High,Yes,0.3333333333333333
Hot,No,0.4
Hot,Yes,0.2222222222222222
Mild,No,0.4
Mild,Yes,0.4444444444444444
Normal,No,0.2
Normal,Yes,0.6666666666666666
Overcast,Yes,0.4444444444444444
Rain,No,0.4
Rain,Yes,0.3333333333333333
Strong,No,0.6
Strong,Yes,0.3333333333333333
Sunny,No,0.6
Sunny,Yes,0.2222222222222222
Weak,No,0.4
Weak,Yes,0.6666666666666666


第三步:
緩存:CLASS-r-00000
-------------------------------
CLASS,No,5
CLASS,Yes,9

緩存:part-r-00000
------------------------------------
CLASS,No,0.35714285714285715
CLASS,Yes,0.6428571428571429
Cool,No,0.2
Cool,Yes,0.3333333333333333
High,No,0.8
High,Yes,0.3333333333333333
Hot,No,0.4
Hot,Yes,0.2222222222222222
Mild,No,0.4
Mild,Yes,0.4444444444444444
Normal,No,0.2
Normal,Yes,0.6666666666666666
Overcast,Yes,0.4444444444444444
Rain,No,0.4
Rain,Yes,0.3333333333333333
Strong,No,0.6
Strong,Yes,0.3333333333333333
Sunny,No,0.6
Sunny,Yes,0.2222222222222222
Weak,No,0.4
Weak,Yes,0.6666666666666666

輸入:weather_predict.txt
---------------------------------
Overcast,Hot,High,Strong


過程:
---------------------------------------------
分類:No
分類:Yes
(High, No)=0.8
(Strong, No)=0.6
(Normal, No)=0.2
(Normal, Yes)=0.6666666666666666
(Strong, Yes)=0.3333333333333333
(CLASS, No)=0.35714285714285715
(CLASS, Yes)=0.6428571428571429
(Cool, No)=0.2
(High, Yes)=0.3333333333333333
(Hot, No)=0.4
(Sunny, No)=0.6
(Weak, No)=0.4
(Cool, Yes)=0.3333333333333333
(Mild, No)=0.4
(Overcast, Yes)=0.4444444444444444
(Rain, No)=0.4
(Rain, Yes)=0.3333333333333333
(Weak, Yes)=0.6666666666666666
(Hot, Yes)=0.2222222222222222
(Sunny, Yes)=0.2222222222222222
(Mild, Yes)=0.4444444444444444
key=Overcast,Hot,High,Strong
values:
1
對於類別:No
其概率爲:0.35714285714285715
   對於條件:Overcast
   其概率爲:0.0
   對於條件:Hot
   其概率爲:0.4
   對於條件:High
   其概率爲:0.8
   對於條件:Strong
   其概率爲:0.6
對於類別:Yes
其概率爲:0.6428571428571429
   對於條件:Overcast
   其概率爲:0.4444444444444444
   對於條件:Hot
   其概率爲:0.2222222222222222
   對於條件:High
   其概率爲:0.3333333333333333
   對於條件:Strong
   其概率爲:0.3333333333333333

輸出:
Overcast,Hot,High,Strong   貝葉斯分類:No,其概率爲0.0
Overcast,Hot,High,Strong   貝葉斯分類:Yes,其概率爲0.007054673721340388
Overcast,Hot,High,Strong   最終結果:貝葉斯分類爲Yes,其概率爲0.007054673721340388



使用Spark(原生API)

import org.apache.spark.{SparkConf, SparkContext}

import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer

/**
 * Naive Bayes classifier for symbolic data using the core Spark RDD API.
 *
 * Training records are comma-separated lines whose last field is the class
 * label; records to classify are comma-separated attribute values.
 */
object CJBayes {
  /**
   * Trains on weather.txt, prints the probability table, then prints
   * "record,selectedClass:posterior" for every record in weather_predict.txt.
   */
  def main(args: Array[String]): Unit = {
    val sparkConf = new SparkConf().setAppName("cjbayes").setMaster("local")
    val sc = new SparkContext(sparkConf)
    val input = "file:///media/chenjie/0009418200012FF3/ubuntu/weather.txt"
    val predictFile = "file:///media/chenjie/0009418200012FF3/ubuntu/weather_predict.txt"
    try {
      // Keep only real records (at least one attribute plus a label); the RDD
      // is traversed twice (count + flatMap), so cache it.
      val records = sc.textFile(input).map(_.split(",")).filter(_.length >= 2).cache()
      val trainDataSize = records.count()

      // ((attribute value | "CLASS", label), occurrences)
      val countsMap = records.flatMap { tokens =>
        val classfier = tokens.last
        val classCount = (("CLASS", classfier), 1)
        tokens.init.toSeq.map(attr => ((attr, classfier), 1)) :+ classCount
      }.reduceByKey(_ + _).collectAsMap()

      // Priors are count / number of records; conditionals are count / class count.
      val PT: Map[(String, String), Double] = countsMap.map { case ((condition, classfication), count) =>
        val probability =
          if (condition == "CLASS") count.toDouble / trainDataSize
          else countsMap.get(("CLASS", classfication)).map(count.toDouble / _).getOrElse(0.0)
        ((condition, classfication), probability)
      }.toMap
      // Sorted so that ties between classes are broken deterministically.
      val CLASSFICATIONS: List[String] =
        countsMap.keys.collect { case ("CLASS", classfication) => classfication }.toList.sorted
      PT.foreach(println)

      // Ship the read-only model to the tasks once instead of per closure.
      val ptBroadcast = sc.broadcast(PT)
      val classesBroadcast = sc.broadcast(CLASSFICATIONS)

      sc.textFile(predictFile).filter(_.trim.nonEmpty).map { line =>
        val attributes = line.split(",")
        var selectedClass: String = null
        var maxPosterior = 0.0
        for (aClass <- classesBroadcast.value) {
          println("對於類:" + aClass)
          var posterior = ptBroadcast.value.getOrElse(("CLASS", aClass), 0.0)
          println("其概率爲:" + posterior)
          for (attr <- attributes) {
            println("\t對於條件:" + attr)
            val probability = ptBroadcast.value.getOrElse((attr, aClass), 0.0)
            println("\t其概率爲:" + probability)
            posterior *= probability
          }
          // Compare only the complete product, once per class; the first class
          // always becomes the initial candidate.
          if (selectedClass == null || posterior > maxPosterior) {
            selectedClass = aClass
            maxPosterior = posterior
          }
        }
        line + "," + selectedClass + ":" + maxPosterior
      }.foreach(println)
    } finally {
      sc.stop()
    }
  }
}

使用Spark(mllib機器學習庫)
import org.apache.spark.mllib.classification.NaiveBayes
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.{SparkConf, SparkContext}

import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer

/**
 * Naive Bayes classification with the Spark mllib library.
 *
 * Each input line is "f1 f2 ... fn,label": space-separated numeric features,
 * a comma, then the numeric label.
 */
object CJBayes {
  /**
   * Trains a model on weather1.txt, reports accuracy on the held-out split
   * (when there is one) and prints the prediction for a sample vector.
   */
  def main(args: Array[String]): Unit = {
    val sparkConf = new SparkConf().setAppName("cjbayes").setMaster("local")
    val sc = new SparkContext(sparkConf)
    val input = "file:///media/chenjie/0009418200012FF3/ubuntu/weather1.txt"

    try {
      val data = sc.textFile(input)
      // Blank lines cannot be parsed into a labeled point, so drop them first.
      val parsedData = data.filter(_.trim.nonEmpty).map { line =>
        val parts = line.split(',')
        LabeledPoint(parts(1).toDouble, Vectors.dense(parts(0).split(' ').map(_.toDouble)))
      }
      // 100% of the data is used for training and 0% for testing.
      val splits = parsedData.randomSplit(Array(1.0, 0.0), seed = 11L)
      val training = splits(0)
      val test = splits(1)

      // First argument is the data, second the smoothing parameter (1.0 here).
      val model = NaiveBayes.train(training, lambda = 1.0)

      // Accuracy is only defined when the test split holds at least one record;
      // with the weights above it is empty and 0/0 would print NaN.
      val testCount = test.count()
      if (testCount == 0) {
        println("accuracy-->N/A (test set is empty)")
      } else {
        val predictionAndLabel = test.map(p => (model.predict(p.features), p.label))
        val accuracy = 1.0 * predictionAndLabel.filter(x => x._1 == x._2).count() / testCount
        println("accuracy-->" + accuracy)
      }
      println("Predictionof (2.0,1.0,1.0,2.0):" + model.predict(Vectors.dense(2.0, 1.0, 1.0, 2.0)))
    } finally {
      sc.stop()
    }
  }
}

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章