【分類】- 基於樸素貝葉斯的垃圾郵件預測系統

簡單介紹:

這裏我舉例來簡單的說明下貝葉斯算法 :

如上圖 : 假設一個班有 100 人 , 其中80%的是玩王者榮耀的 , 20%玩吃雞 ,那麼在所有人中 玩LOL的佔 10% 

其中 同時玩 王者 和 LOL 的人 有 8個人 ,佔所有玩 王者的人中 的 8 / 80 ,同樣的 同時玩 吃雞和 LOL的佔所有玩吃雞人數的 

2 / 20 根據這個概率表 , 我們可以計算出 當我們只知道玩 LOL的前提下 , 這個人可能玩 王者 , 或者 吃雞的概率

 如上圖 : 其實簡單的理解就是 用  同時 玩王者和LOL的人數 除以 所有玩 LOL的總人數 , 這樣就可以計算出 ,當一個人玩LOL 的前提下 , 這個人還有可能玩王者的概率 , 同理可以計算出 當一個人玩LOL的前提下 還有可能玩 吃雞的概率。

然後二者進行比較 , 那個概率大就說明更有可能玩哪種遊戲 。 如上圖,顯然當知道一個人玩LOL的前提下 , 這個人玩王者的概率 顯然要大於玩 吃雞的概率 , 那麼我們更傾向於 這個人可能玩 王者的概率較大。

這裏我們可以的得出一個貝葉斯公式 :P(A|B) = (PB|A)* P(A)  / P(B) 如下圖

這裏我在右邊貼出了 計算 玩LOL的前提下 玩王者的概率 計算公式 , 進行對比 比較容易理解 。其實這個就是貝葉斯算法的簡單理解 , 下面說說垃圾郵件預測的思路

垃圾郵件預測介紹

算法思路 

這裏我同樣用舉例的方式來介紹 ,那麼假設 我們有 100封郵件 , 其中80封是正常郵件(佔所有有郵件的80%)20封是正常郵件(佔所有郵件的20%) , 然後 ,在所有郵件中 有 5 封是含有 fuck 這個單詞的 (佔所有郵件的5%)那麼我們可以得出下面這個概率表

如上圖 : 同理我們可以計算出 , 當一封郵件中出現 fuck這個單詞的時候 , 這封郵件可能是垃圾郵件的概率是 80% , 有可能是非垃圾郵件的概率是 20% , 通過比較二者的概率 , 可以計算出 ,當一封郵件出現 fuck 單詞時 , 這封郵件更有可能是垃圾單詞的,當然 , 一封郵件中肯定不可能只有一個單詞, 我們可以根據出現的單詞在不同的郵件中出現的概率計算出當這些郵件中出現這些單詞時 ,這封郵件可能是垃圾郵件的概率。

至於 如何計算每個單詞在不同種類的概率 ,下面我舉例說明

如上圖 , 這裏有四封郵件 , 我們可以對郵件內容進行切分, 得到切分後的所有不重複的單詞 , 然後統計郵件中對應單詞的數量,進而計算出單詞在指定類型的郵件中出現的概率。 比如 單詞 i ,在四封郵件中出現了三次, 那麼我們可以認爲 i 在這種類型的單詞中出現的概率是 3 / 4 進而進行計算得出 一個單詞概率表 ,用於計算。

數據來源

垃圾郵件預測的最重要的一環是 構建模型的數據來源 , 因爲數據源直接影響了你最終預測的準確性, 我在網上找到的多是英文的郵件 , 中文的比較少 , 這個是我好不容易找到的 ,大概有14000+ 封郵件 , 其中一半是正常的郵件 , 一半是垃圾郵件 ,這裏我直接給出百度雲鏈接 , 如果大家有好的數據源 也可已在評論區分享下 , 感謝

百度網盤下載鏈接: https://pan.baidu.com/s/1Hsno4oREMROxWwcC_jYAOA 

密碼: qa49

 這裏有一點需要注意的是 ,這些郵件都是GBK格式的 ,如果你的項目使用的是UTF-8 ,那麼有可能會亂碼 , 這裏我是使用的是轉換流, 使用GBK格式讀取 , 使用UTF-8 寫出 , 然後就可以批量轉化 爲 UTF-8 格式 代碼如下

package com.wangt.bayes.test;

import spark.utils.IOUtils;

import java.io.*;

/**
 * 批量轉換文件編碼
 */
public class ConductData {

    static InputStreamReader isr = null;
    static OutputStreamWriter osr = null;
    public static void main(String[] args) throws IOException {

        File normal = new File("data/test/normal"); // 這裏讀取的是文件的目錄
        File[] normals = normal.listFiles();
        ConductData cd = new ConductData();

        for (int i = 0; i < normals.length; i++) {
            cd.changeEcoding(normals[i]);
            System.out.println("更改成功 +" + i);
        }

    }

    /**
     * 轉換文件的編碼
     * @param f
     */
    public void changeEcoding(File f){

        try {
            isr =  new InputStreamReader(new FileInputStream(f), "GBK");
            File out = new File("mydata/test/ham/" + f.getName() + ".txt");
          if (!out.exists()){
              boolean res = out.createNewFile();
              System.out.println("文件 : " +f.getName()+ (res ? "創建成功" : "創建失敗"));
            }
            osr =  new OutputStreamWriter(new FileOutputStream(out) , "UTF-8");
           
             IOUtils.copy(isr , osr); // 這個是saprk自帶的拷貝流

        } catch (UnsupportedEncodingException e) {
            e.printStackTrace();
        } catch (FileNotFoundException e) {
            e.printStackTrace();
        } catch (IOException e) {
            e.printStackTrace();
        }

        // 關閉流
        if(isr != null){

            try {
                isr.close();
            } catch (IOException e) {
                e.printStackTrace();
            }
        }

        if(osr != null){

            try {
                osr.close();
            } catch (IOException e) {
                e.printStackTrace();
            }
        }
    }

}

構建模型的流程

文件切分 

這裏我使用的是結巴分詞器 ,它可以支持中文分詞 ,代碼如下

/**
     * 對單個字符串進行切分 獲取單詞集合
     *
     * @param line 被切分的單詞
     * @return 返回存儲被切分的單詞的 集合
     */
    public static List<String> splitWord(String line) {

        // 創建結巴分詞器
        JiebaSegmenter segmenter = new JiebaSegmenter();
        // 創建存儲 分割後的單詞的結婚
        ArrayList<String> datas = new ArrayList<>();
        // 分詞
        List<SegToken> list = segmenter.process(line, JiebaSegmenter.SegMode.SEARCH);

        // 過濾掉空格
        for (SegToken segToken : list) {

            if (segToken.equals(" ")) {
                continue;
            }
            datas.add(segToken.word);
        }
        return datas;
    }

如果你不知道如何下載結巴分詞器 , 下面我提供了maven的依賴

 <!-- 結巴 分詞-->
        <!-- https://mvnrepository.com/artifact/com.huaban/jieba-analysis -->
        <dependency>
            <groupId>com.huaban</groupId>
            <artifactId>jieba-analysis</artifactId>
            <version>1.0.2</version>
        </dependency>

構建詞袋

下一步是獲取所有不重複的單詞 ,構建詞袋 ,這裏我使用的是 TreeMap , 因爲Map的key 不能重複 ,正好可以存儲單詞 , value是默認爲 0 , 爲下一步計算單詞詞頻做鋪墊

/**
     * 切分一個字符串 , 獲取不重複的單詞 , 並存儲到一個 Map 中
     * key 爲 單詞 , value 默認是 0
     *
     * @param lines 獲取重複單詞的數據
     * @return 存放不重複單詞的詞袋
     */
    public static TreeMap<String, Integer> getWordBag(String lines) {
        // 分詞
        List<String> values = splitWord(lines);

        // 獲取不重複的單詞
        TreeMap<String, Integer> wordBag = new TreeMap<>();
        for (String line : values) {
            wordBag.put(line, 0);
        }
        // 返回值
        return wordBag;
    }

統計單詞詞頻 , 並將詞頻轉化爲 double數組

/**
     * 指定一個存儲單詞的詞袋 , 切分指定字符串後得到單詞 , 獲取詞袋中對應單詞出現的次數組成double數組
     *
     * @param line    需要統計單詞的語句
     * @param wordBag 存放不重複單詞的詞袋
     * @return 存放單詞詞頻的double數組
     */
    public static double[] getWordCount(String line, TreeMap<String, Integer> wordBag) {

        // 創建存放單詞詞頻的 對象
        TreeMap<String, Integer> countWord = new TreeMap<>();
        // 將詞袋內的單詞添加進 統計詞頻的對象中
        countWord.putAll(wordBag);

        // 對被統計的單詞進行分詞
        List<String> wordDatas = splitWord(line);

        // 統計單詞詞頻
        for (String wordData : wordDatas) {
            // 如果被統計的單詞在詞袋中沒有出現 , 將該單詞丟棄
            if (!countWord.containsKey(wordData)) {
                continue;
            }
            countWord.replace(wordData, countWord.get(wordData) + 1);
        }

        // 將單詞詞頻存放到 double 數組中
        double[] values = new double[countWord.size()];
        int size = 0;
        Iterator<String> countKey = countWord.keySet().iterator();
        while (countKey.hasNext()) {
            String word = countKey.next();
            values[size] = countWord.get(word);
            size++;
        }

        // 返回存放單詞詞頻的double數組
        return values;
    }

然後是 將 詞頻數組封裝到 LabeledPoint 對象中 進行模型的構建

public static void main(String[] args) {
        // 1.獲取 SparkContext
        JavaSparkContext sc = BayesUtils.init();

        // 2.讀取樣本數據
        JavaRDD<String> lines =  sc.textFile("input/navie_bayes_data.txt");

        List<String> words = lines.take((int) lines.count());

        // 3.使用結巴分詞器進行分詞 獲取不重複的單詞
        TreeMap<String , Integer> wordBag = BayesUtils.getWordBag(words);

        // 4.對數據進行切分 將切分後的 (標籤 , 內容) 封裝到 LabeledPoint 對象中 ,並將所有的LabeledPoint存儲到 RDD中
        JavaRDD<LabeledPoint> parsedData = lines.map(new Function<String, LabeledPoint>() {
            @Override
            public LabeledPoint call(String v1) throws Exception {

                // 對數據進行按行切分
                String[] fields = v1.split(",");
                double key = 0.0;
                if(fields[0].equals("ham")){ // ham = 0 , spam = 1 這裏特別注意 是fields[0] 不是fields
                    key = 0;
                }else {
                    key = 1;
                }
                // 獲取單詞詞頻
                double[] datas = BayesUtils.getWordCount(fields[1] , wordBag);

                // 返回封裝好的值
                return new LabeledPoint(key ,new DenseVector(datas));
            }
        });

        // 6.樣本數據劃分 訓練樣本 和測試樣本
        // 一個參數是 訓練樣本和測試的比例 , 另一個是隨機數的種子
        JavaRDD<LabeledPoint>[] splits = parsedData.randomSplit(new double[]{0.9 , 0.1} , 100L);

        // 7.獲取樣本集
        // 獲取訓練樣本集
        JavaRDD<LabeledPoint> training = splits[0];
        // 獲取測試樣本集
        JavaRDD<LabeledPoint> test = splits[1];

        // 8.對數據進行訓練 , 新建貝葉斯分類模型
        // 需要將 JavaRDD 轉化爲 org.apache.spark.rdd.RDD
        // 設置訓練集 , 以及拉普拉斯估計值
        NaiveBayesModel model = NaiveBayes.train(training.rdd() , 1);



        // 9.對測試樣本進行測試 並獲取預測的值
        JavaPairRDD<Double , Double> predictResult = test.mapToPair(new PairFunction<LabeledPoint, Double, Double>() {
            @Override
            public Tuple2<Double, Double> call(LabeledPoint labeledPoint) throws Exception {
                // 獲取預測的值
                Double res = model.predict(labeledPoint.features());
                // 返回值
                return new Tuple2<>(res , labeledPoint.label());
            }
        });

        predictResult.foreach(new VoidFunction<Tuple2<Double, Double>>() {
            @Override
            public void call(Tuple2<Double, Double> value) throws Exception {
                System.out.println("預測的結果 : " + value._1 + "\t" + "實際的結果 : " + value._2);
            }
        });

        predictResult.take(100);

        // 統計預測成功的個數
        Long sucess = predictResult.filter(new Function<Tuple2<Double, Double>, Boolean>() {
            @Override
            public Boolean call(Tuple2<Double, Double> v1) throws Exception {
                return v1._1().equals(v1._2()); // 注意 這裏是 Double類型的 不能直接用 ==
            }
        }).count();

        System.out.println("sucess =  " + sucess + "\t test = " + test.count());

        double rate = (((double)sucess) / test.count() ) * 100;
        System.out.println("預測的成功率是 :" + String.format("%.2f" , rate)+"%");

        // 保存模型
        //model.save(sc.sc() , "model/my_navie_bayes_model");
        // 關閉 SparkContext
        sc.stop();
    }

然後是預測的代碼 :

/**
     * 預測郵件是否爲垃圾郵件
     *
     * @param predictData 要預測的郵件
     * @return 返回預測結果 如果是 0 代表是正常郵件 如果是 1 代表郵件
     */
    public double predictChinese(String predictData) {

        // 獲取 SparkContext 對象
        JavaSparkContext sc = BayesUtils.init();
        // 獲取模型
        String modelPath = Predicted.class.getClassLoader().getResource("model_cn/naive_model_optimize").getPath();
        NaiveBayesModel model = NaiveBayesModel.load(sc.sc() , modelPath);

        // 獲取詞袋
        // 此處加載 resources 的資源時 一定要 獲取類加載器
        String wordBagPath = Predicted.class.getClassLoader().getResource("wordbags_cn").getPath();

        JavaRDD<String> wordBagRDD = sc.textFile(wordBagPath);
        TreeMap<String, Integer> wordBags = BayesUtils.getWordBag(wordBagRDD.collect());

        System.out.println("wordBags : ==>" + wordBags.size());

        // 獲取停用單詞
        String stopwWordPath = Predicted.class.getClassLoader().getResource("stopWords.txt").getPath();

        double[] datas = BayesUtils.getWordCount(predictData, wordBags);

        // 開始預測
        // 將語句向量封裝到 LabeledPoint 對象中
        LabeledPoint lp = new LabeledPoint(-1, new DenseVector(datas));

        // 獲取預測結果
        double predictResult = model.predict(lp.features());

        // 關閉 SparkContext
        sc.stop();
        return predictResult;
    }

 

總結

maven 依賴

<dependencies>
        <dependency>
            <groupId>junit</groupId>
            <artifactId>junit</artifactId>
            <version>4.11</version>
            <scope>test</scope>
        </dependency>

        <!-- spark 核心依賴包 -->
        <dependency>
            <groupId>org.apache.spark</groupId>
            <artifactId>spark-core_2.11</artifactId>
            <version>2.1.1</version>
        </dependency>
        <!-- spark mllib 依賴-->
        <dependency>
            <groupId>org.apache.spark</groupId>
            <artifactId>spark-mllib_2.11</artifactId>
            <version>2.1.1</version>
        </dependency>

        <!-- spark javaAPI 依賴-->
        <dependency>
            <groupId>com.sparkjava</groupId>
            <artifactId>spark-core</artifactId>
            <version>2.1</version>
        </dependency>

        <!-- spark sql 依賴-->
        <dependency>
            <groupId>org.apache.spark</groupId>
            <artifactId>spark-sql_2.11</artifactId>
            <version>2.1.2</version>
        </dependency>

        <!-- jcesg 分詞 -->
        <dependency>
            <groupId>org.lionsoul</groupId>
            <artifactId>jcseg-core</artifactId>
            <version>2.4.0</version>
        </dependency>
        <dependency>
            <groupId>org.lionsoul</groupId>
            <artifactId>jcseg-analyzer</artifactId>
            <version>2.4.0</version>
        </dependency>
        <dependency>
            <groupId>org.lionsoul</groupId>
            <artifactId>jcseg-elasticsearch</artifactId>
            <version>2.4.0</version>
        </dependency>
        <!-- 結巴 分詞-->
        <!-- https://mvnrepository.com/artifact/com.huaban/jieba-analysis -->
        <dependency>
            <groupId>com.huaban</groupId>
            <artifactId>jieba-analysis</artifactId>
            <version>1.0.2</version>
        </dependency>

        <!-- scala 庫-->
        <dependency>
            <groupId>org.scala-lang</groupId>
            <artifactId>scala-library</artifactId>
            <version>2.11.8</version>
        </dependency>
        <!-- https://mvnrepository.com/artifact/org.scala-lang/scala-compiler -->
        <dependency>
            <groupId>org.scala-lang</groupId>
            <artifactId>scala-compiler</artifactId>
            <version>2.11.8</version>
        </dependency>
        <!-- https://mvnrepository.com/artifact/org.scala-lang/scala-reflect -->
        <dependency>
            <groupId>org.scala-lang</groupId>
            <artifactId>scala-reflect</artifactId>
            <version>2.11.8</version>
        </dependency>
        <dependency>
            <groupId>org.junit.jupiter</groupId>
            <artifactId>junit-jupiter-api</artifactId>
            <version>RELEASE</version>
            <scope>compile</scope>
        </dependency>
        <dependency>
            <groupId>org.junit.jupiter</groupId>
            <artifactId>junit-jupiter-api</artifactId>
            <version>RELEASE</version>
            <scope>compile</scope>
        </dependency>


    </dependencies>

構建模型的代碼 :

package com.wangt.bayes.test;

import com.wangt.bayes.utils.BayesUtils;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.api.java.function.PairFunction;
import org.apache.spark.api.java.function.VoidFunction;
import org.apache.spark.mllib.classification.NaiveBayes;
import org.apache.spark.mllib.classification.NaiveBayesModel;
import org.apache.spark.mllib.linalg.DenseVector;
import org.apache.spark.mllib.regression.LabeledPoint;
import scala.Tuple2;

import java.util.List;
import java.util.TreeMap;

/**
 * @author wangt
 * @create 2019-03-06 20:35
 */
public class TrainModel {

   /* public static NaiveBayesModel trainModel(String trainDataPath){


    }
*/
    public static void main(String[] args) {
        // 1.獲取 SparkContext
        JavaSparkContext sc = BayesUtils.init();

        // 2.讀取樣本數據
        JavaRDD<String> lines =  sc.textFile("input/navie_bayes_data.txt");

        List<String> words = lines.take((int) lines.count());

        // 3.使用結巴分詞器進行分詞 獲取不重複的單詞
        TreeMap<String , Integer> wordBag = BayesUtils.getWordBag(words);

        // 4.對數據進行切分 將切分後的 (標籤 , 內容) 封裝到 LabeledPoint 對象中 ,並將所有的LabeledPoint存儲到 RDD中
        JavaRDD<LabeledPoint> parsedData = lines.map(new Function<String, LabeledPoint>() {
            @Override
            public LabeledPoint call(String v1) throws Exception {

                // 對數據進行按行切分
                String[] fields = v1.split(",");
                double key = 0.0;
                if(fields[0].equals("ham")){ // ham = 0 , spam = 1 這裏特別注意 是fields[0] 不是fields
                    key = 0;
                }else {
                    key = 1;
                }
                // 獲取單詞詞頻
                double[] datas = BayesUtils.getWordCount(fields[1] , wordBag);

                // 返回封裝好的值
                return new LabeledPoint(key ,new DenseVector(datas));
            }
        });

        // 6.樣本數據劃分 訓練樣本 和測試樣本
        // 一個參數是 訓練樣本和測試的比例 , 另一個是隨機數的種子
        JavaRDD<LabeledPoint>[] splits = parsedData.randomSplit(new double[]{0.9 , 0.1} , 100L);

        // 7.獲取樣本集
        // 獲取訓練樣本集
        JavaRDD<LabeledPoint> training = splits[0];
        // 獲取測試樣本集
        JavaRDD<LabeledPoint> test = splits[1];

        // 8.對數據進行訓練 , 新建貝葉斯分類模型
        // 需要將 JavaRDD 轉化爲 org.apache.spark.rdd.RDD
        // 設置訓練集 , 以及拉普拉斯估計值
        NaiveBayesModel model = NaiveBayes.train(training.rdd() , 1);



        // 9.對測試樣本進行測試 並獲取預測的值
        JavaPairRDD<Double , Double> predictResult = test.mapToPair(new PairFunction<LabeledPoint, Double, Double>() {
            @Override
            public Tuple2<Double, Double> call(LabeledPoint labeledPoint) throws Exception {
                // 獲取預測的值
                Double res = model.predict(labeledPoint.features());
                // 返回值
                return new Tuple2<>(res , labeledPoint.label());
            }
        });

        predictResult.foreach(new VoidFunction<Tuple2<Double, Double>>() {
            @Override
            public void call(Tuple2<Double, Double> value) throws Exception {
                System.out.println("預測的結果 : " + value._1 + "\t" + "實際的結果 : " + value._2);
            }
        });

        predictResult.take(100);

        // 統計預測成功的個數
        Long sucess = predictResult.filter(new Function<Tuple2<Double, Double>, Boolean>() {
            @Override
            public Boolean call(Tuple2<Double, Double> v1) throws Exception {
                return v1._1().equals(v1._2()); // 注意 這裏是 Double類型的 不能直接用 ==
            }
        }).count();

        System.out.println("sucess =  " + sucess + "\t test = " + test.count());

        double rate = (((double)sucess) / test.count() ) * 100;
        System.out.println("預測的成功率是 :" + String.format("%.2f" , rate)+"%");

        // 保存模型
        //model.save(sc.sc() , "model/my_navie_bayes_model");
        // 關閉 SparkContext
        sc.stop();
    }
}

封裝好的工具方法

package com.wangt.bayes.utils;

import com.huaban.analysis.jieba.JiebaSegmenter;
import com.huaban.analysis.jieba.SegToken;
import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;

import java.io.File;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.TreeMap;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

/**
 * @author wangt
 * @create 2019-03-02 16:43
 */
public class BayesUtils {

    /**
     * 初始化 獲取 SparkContext 對象
     *
     * @return
     */
    public static JavaSparkContext init() {
        // 1.創建 SparkConf 對象
        SparkConf conf = new SparkConf();
        // 2.設置 Spark 運行模式 以及 應用名稱
        conf.setMaster("local").setAppName("MyNaive_bayes");
        // 3.獲取 SparkContext 對象
        /**
         * SparkContext 是 Spark 程序的入口 相當於 Hadoop的Job
         */
        return new JavaSparkContext(conf);
    }

    /**
     * 對單個字符串進行切分 獲取單詞集合
     *
     * @param line 被切分的單詞
     * @return 返回存儲被切分的單詞的 集合
     */
    public static List<String> splitWord(String line) {

        // 創建結巴分詞器
        JiebaSegmenter segmenter = new JiebaSegmenter();
        // 創建存儲 分割後的單詞的結婚
        ArrayList<String> datas = new ArrayList<>();
        // 分詞
        List<SegToken> list = segmenter.process(line, JiebaSegmenter.SegMode.SEARCH);

        // 過濾掉空格
        for (SegToken segToken : list) {

            if (segToken.equals(" ")) {
                continue;
            }
            datas.add(segToken.word);
        }
        return datas;
    }

    /**
     * 對單箇中文字符串進行切分 獲取單詞集合
     *
     * @param line 被切分的單詞
     * @return 返回存儲被切分的單詞的 集合
     */
    public static List<String> splitChineseWord(String line) {

        // 創建結巴分詞器
        JiebaSegmenter segmenter = new JiebaSegmenter();
        // 創建存儲 分割後的單詞的結婚
        ArrayList<String> datas = new ArrayList<>();
        // 分詞
        List<SegToken> list = segmenter.process(line, JiebaSegmenter.SegMode.SEARCH);

        // 過濾掉 無意義的單詞
        for (SegToken segToken : list) {

            if (isChinese(segToken.word)) {
                datas.add(segToken.word);
            }
        }
        return datas;
    }


    /**
     * 過濾掉停用單詞
     *
     * @param words         需要過濾的單詞
     * @param stopWordsBags 停用單詞庫
     * @return
     */
    public static List<String> filterStopWords(List<String> words, List<String> stopWordsBags) {

        // 存儲 過濾後單詞
        ArrayList<String> lines = new ArrayList<>();

        for (String word : words) {
            if (!stopWordsBags.contains(word)) {
                lines.add(word);
            }
        }
        return lines;
    }


    /**
     * 指定一個存儲單詞的詞袋 , 切分指定字符串後得到單詞 , 獲取詞袋中對應單詞出現的次數組成double數組
     *
     * @param line    需要統計單詞的語句
     * @param wordBag 存放不重複單詞的詞袋
     * @return 存放單詞詞頻的double數組
     */
    public static double[] getWordCount(String line, TreeMap<String, Integer> wordBag) {

        // 創建存放單詞詞頻的 對象
        TreeMap<String, Integer> countWord = new TreeMap<>();
        // 將詞袋內的單詞添加進 統計詞頻的對象中
        countWord.putAll(wordBag);

        // 對被統計的單詞進行分詞
        List<String> wordDatas = splitWord(line);

        // 統計單詞詞頻
        for (String wordData : wordDatas) {
            // 如果被統計的單詞在詞袋中沒有出現 , 將該單詞丟棄
            if (!countWord.containsKey(wordData)) {
                continue;
            }
            countWord.replace(wordData, countWord.get(wordData) + 1);
        }

        // 將單詞詞頻存放到 double 數組中
        double[] values = new double[countWord.size()];
        int size = 0;
        Iterator<String> countKey = countWord.keySet().iterator();
        while (countKey.hasNext()) {
            String word = countKey.next();
            values[size] = countWord.get(word);
            size++;
        }

        // 返回存放單詞詞頻的double數組
        return values;
    }

    /**
     * 統計 一個list 內的單詞在對應詞袋中出現的頻率
     *
     * @param lines   存儲一條語句切分後的單詞
     * @param wordBag 存放不重複單詞的詞袋
     * @return 存放單詞詞頻的double數組
     */
    public static double[] getWordCount(List<String> lines, TreeMap<String, Integer> wordBag) {

        // 創建存放單詞詞頻的 對象
        TreeMap<String, Integer> countWord = new TreeMap<>();
        // 將詞袋內的單詞添加進 統計詞頻的對象中
        countWord.putAll(wordBag);

        for (String line : lines) {
            // 統計單詞詞頻
            // 如果被統計的單詞在詞袋中沒有出現 , 將該單詞丟棄
            if (countWord.containsKey(line)) {
                countWord.replace(line, countWord.get(line) + 1);
            }
        }

        // 將單詞詞頻存放到 double 數組中
        double[] values = new double[countWord.size()];
        int size = 0;
        Iterator<String> countKey = countWord.keySet().iterator();
        while (countKey.hasNext()) {
            String word = countKey.next();
            values[size] = countWord.get(word);
            size++;
        }

        // 返回存放單詞詞頻的double數組
        return values;
    }

    /**
     * 切分一個字符串 , 獲取不重複的單詞 , 並存儲到一個 Map 中
     * key 爲 單詞 , value 默認是 0
     *
     * @param lines 獲取重複單詞的數據
     * @return 存放不重複單詞的詞袋
     */
    public static TreeMap<String, Integer> getWordBag(String lines) {
        // 分詞
        List<String> values = splitWord(lines);

        // 獲取不重複的單詞
        TreeMap<String, Integer> wordBag = new TreeMap<>();
        for (String line : values) {
            wordBag.put(line, 0);
        }
        // 返回值
        return wordBag;
    }

    /**
     * 從一個存儲字符串的 list集合中讀取字符串 , 獲取不重複的單詞 存儲到 map 中
     * key 爲 不重複的單詞 , value 默認爲 0
     *
     * @param lines 源數據
     * @return 存放不重複單詞的詞袋
     */
    public static TreeMap<String, Integer> getWordBag(List<String> lines) {

        // 分詞
        List<String> values = new ArrayList<String>();
        for (String line : lines) {

            // 同時兼容可以切分 帶有 ham,內容 的 和不帶 ham的
            String[] s = line.split(",");
            String words = s.length > 1 ? s[1] : s[0];
            values.addAll(splitWord(words));
        }
        // 過濾重複單詞
        TreeMap<String, Integer> wordBag = new TreeMap<>();
        for (String value : values) {
            wordBag.put(value, 0);
        }

        // 返回值
        return wordBag;
    }

    /**
     * 將從詞袋文件中讀取的單詞轉化爲 更易使用的 Map
     *
     * @param wordBags
     * @return
     */
    public static TreeMap<String, Integer> getWordBagToMap(List<String> wordBags) {

        TreeMap<String, Integer> wordBag = new TreeMap<>();
        for (String value : wordBags) {
            wordBag.put(value, 0);
        }
        // 返回值
        return wordBag;
    }

    /**
     * 指定一個文件的路徑 , 讀取文件中的單詞 並且獲取不重複的單詞存儲到一個 map中
     * 其中 key 是 單詞 , value 默認爲 0
     *
     * @param sc   SparkContext對象
     * @param path 提取不重複單詞的路徑
     * @return 詞袋
     */
    public static TreeMap<String, Integer> getWordBag(JavaSparkContext sc, String path) {

        // 文件不存在時拋出異常
        if (!new File(path).exists()) {
            new RuntimeException("file is not exist");
        }

        // 讀取文件
        JavaRDD<String> lines = sc.textFile(path);
        // 將 RDD 轉化爲 list
        List<String> datas = lines.take((int) lines.count());

        // 獲取詞袋 併發揮
        return getWordBag(datas);
    }


    /**
     * 判斷 一個字符串是否包含中文
     *
     * @param word 需要判斷的單詞
     * @return 如果是中文 則返回 true , 否則返回 false
     */
    public static boolean isChinese(String word) {

        // 正則匹配
        // String parm="[\\u4e00-\\u9fa5]+"; // 表示一個或者多箇中文
        String parm = ".*[\\u4e00-\\u9faf].*"; // 表示一個或者多箇中文
        // 編譯
        Pattern pattern = Pattern.compile(parm);
        Matcher m = pattern.matcher(word);
        return m.matches();
    }

    /**
     * 過濾一個字符串中不是中文的部分 然後將所有中文詞存儲到一個集合中返回
     *
     * @param lines 被過濾的集合
     * @return 存儲過濾後的數據的集合
     */
    public static List<String> filterNotChineseWords(List<String> lines) {

        // 存儲過濾後的數據
        ArrayList<String> newLines = new ArrayList<>();

        // 對數據過濾
        for (String line : lines) {
            // 判斷是否是中文
            if (isChinese(line)) {
                newLines.add(line);
            }
        }
        return newLines;
    }

    /**
     * 去除掉字符串中不是中文的部分 , 然後返回剩下是中文的字符串
     *
     * @param line 被過濾的字符串
     * @return 存儲過濾後的字符串
     */
    public static String filterNotChineseWords(String line) {

        // 對字符串進行切分
        List<String> lines = BayesUtils.splitWord(line);

        // 存儲切分後的字符串
        StringBuilder words = new StringBuilder();

        // 對數據過濾
        for (String s : lines) {
            // 判斷是否是中文
            if (isChinese(s)) {
                words.append(s); // 認真寫代碼
            }
        }

        return words.toString();
    }

    public static void main(String[] args) {

        List<String> stopwords = new ArrayList<>();
        stopwords.add("java");
        stopwords.add("c");

        List<String> words = new ArrayList<>();
        words.add("hello");
        words.add("java");
        words.add("c");
        words.add("jk");

        words = filterStopWords(words, stopwords);

        for (String word : words) {
            System.out.println(word);
        }
    }
}

預測的方法

 /**
     * 預測郵件是否爲垃圾郵件
     *
     * @param predictData 要預測的郵件
     * @return 返回預測結果 如果是 0 代表是正常郵件 如果是 1 代表郵件
     */
    public double predictChinese(String predictData) {

        // 獲取 SparkContext 對象
        JavaSparkContext sc = BayesUtils.init();
        // 獲取模型
        String modelPath = Predicted.class.getClassLoader().getResource("model_cn/naive_model_optimize").getPath();
        NaiveBayesModel model = NaiveBayesModel.load(sc.sc() , modelPath);

        // 獲取詞袋
        // 此處加載 resources 的資源時 一定要 獲取類加載器
        String wordBagPath = Predicted.class.getClassLoader().getResource("wordbags_cn").getPath();

        JavaRDD<String> wordBagRDD = sc.textFile(wordBagPath);
        TreeMap<String, Integer> wordBags = BayesUtils.getWordBag(wordBagRDD.collect());

        System.out.println("wordBags : ==>" + wordBags.size());

        // 獲取停用單詞
        String stopwWordPath = Predicted.class.getClassLoader().getResource("stopWords.txt").getPath();

        double[] datas = BayesUtils.getWordCount(predictData, wordBags);

        // 開始預測
        // 將語句向量封裝到 LabeledPoint 對象中
        LabeledPoint lp = new LabeledPoint(-1, new DenseVector(datas));

        // 獲取預測結果
        double predictResult = model.predict(lp.features());

        // 關閉 SparkContext
        sc.stop();
        return predictResult;
    }

 

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章