Computing Text Similarity with TF-IDF and Cosine Similarity

TF-IDF (term frequency–inverse document frequency) is a weighting scheme widely used in information retrieval and data mining. TF stands for term frequency; IDF stands for inverse document frequency.

The idea: segment each text into words, use TF-IDF to turn each text into a term-weight vector, then compute the similarity of the two vectors with the cosine formula.
Required jars: je-analysis-1.5.3.jar and lucene-core-2.4.1.jar (Lucene versions above 4 conflict with je-analysis).
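
Concretely, both classes below weight each term as tf × idf, with tf(t, d) = (occurrences of t in d) / (total words in d) and idf(t) = log10((1 + |D|) / |Dt|), where |D| is the number of documents and |Dt| is the number of documents containing t. A quick standalone sanity check of that weighting (all numbers are made up for illustration):

public class TfIdfSketch {
    public static void main(String[] args) {
        // Made-up numbers: the term appears 3 times in a 100-word document,
        // and 2 of the 10 documents in the corpus contain it.
        double tf = 3.0 / 100;                   // tf = count / document length
        double idf = Math.log10((1 + 10.0) / 2); // idf = log10((1 + |D|) / |Dt|)
        System.out.println("tf-idf weight: " + (tf * idf)); // ~= 0.0222
    }
}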

import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

import jeasy.analysis.MMAnalyzer;

/**
 * Match two texts directly against each other
 * 
 * @author rock
 *
 */
public class GetText {
    private static List<String> fileList = new ArrayList<String>();
    private static HashMap<String, HashMap<String, Double>> allTheTf = new HashMap<String, HashMap<String, Double>>();
    private static HashMap<String, HashMap<String, Integer>> allTheNormalTF = new HashMap<String, HashMap<String, Integer>>();
    private static LinkedHashMap<String, Double[]> vectorMap = new LinkedHashMap<String, Double[]>();

    /**
     * Word segmentation
     * 
     * @author create by rock
     */
    public static String[] TextcutWord(String text) throws IOException {
        // MMAnalyzer (from je-analysis) segments Chinese text, joining words with spaces.
        MMAnalyzer analyzer = new MMAnalyzer();
        String tempCutWordResult = analyzer.segment(text, " ");
        return tempCutWordResult.split(" ");
    }

    public static Map<String, HashMap<String, Integer>> NormalTFOfAll(String key1, String key2, String text1,
            String text2) throws IOException {
        // Compute raw term counts for each text once and cache them by key.
        if (allTheNormalTF.get(key1) == null) {
            allTheNormalTF.put(key1, normalTF(TextcutWord(text1)));
        }
        if (allTheNormalTF.get(key2) == null) {
            allTheNormalTF.put(key2, normalTF(TextcutWord(text2)));
        }
        return allTheNormalTF;
    }

    public static Map<String, HashMap<String, Double>> tfOfAll(String key1, String key2, String text1, String text2)
            throws IOException {
        // Recompute normalized term frequencies for both texts on every call.
        allTheTf.clear();
        allTheTf.put(key1, tf(TextcutWord(text1)));
        allTheTf.put(key2, tf(TextcutWord(text2)));
        return allTheTf;
    }

    /**
     * Compute term frequency (normalized by document length)
     * 
     * @author create by rock
     */
    public static HashMap<String, Double> tf(String[] cutWordResult) {
        HashMap<String, Double> tf = new HashMap<String, Double>(); // normalized
        int wordNum = cutWordResult.length;
        int wordtf = 0;
        for (int i = 0; i < wordNum; i++) {
            wordtf = 0;
            if (!" ".equals(cutWordResult[i])) {
                // Count duplicates of word i; blank out each duplicate so it is not counted again.
                for (int j = 0; j < wordNum; j++) {
                    if (i != j && cutWordResult[i].equals(cutWordResult[j])) {
                        cutWordResult[j] = " ";
                        wordtf++;
                    }
                }
                tf.put(cutWordResult[i], (double) (++wordtf) / wordNum); // count / document length
                cutWordResult[i] = " ";
            }
        }
        return tf;
    }

    public static HashMap<String, Integer> normalTF(String[] cutWordResult) {
        HashMap<String, Integer> tfNormal = new HashMap<String, Integer>(); // raw counts, not normalized
        int wordNum = cutWordResult.length;
        int wordtf = 0;
        for (int i = 0; i < wordNum; i++) {
            wordtf = 0;
            if (!" ".equals(cutWordResult[i])) {
                for (int j = 0; j < wordNum; j++) {
                    if (i != j && cutWordResult[i].equals(cutWordResult[j])) {
                        cutWordResult[j] = " ";
                        wordtf++;
                    }
                }
                tfNormal.put(cutWordResult[i], ++wordtf);
                cutWordResult[i] = " ";
            }
        }
        return tfNormal;
    }

    /**
     * IDF = log10((1 + |D|) / |Dt|), where |D| is the total number of documents
     * and |Dt| is the number of documents that contain term t.
     */
    public static Map<String, Double> idf(String key1, String key2, String text1, String text2)
            throws IOException {
        Map<String, Double> idf = new HashMap<String, Double>();
        List<String> located = new ArrayList<String>(); // words whose idf has already been computed

        NormalTFOfAll(key1, key2, text1, text2);

        float Dt = 1;
        float D = allTheNormalTF.size(); // total number of documents
        List<String> key = fileList; // keys of all documents seen so far

        String[] keyarr = { key1, key2 };
        for (String item : keyarr) {
            if (!fileList.contains(item)) {
                fileList.add(item);
            }
        }

        Map<String, HashMap<String, Integer>> tfInIdf = allTheNormalTF; // term counts per document

        for (int i = 0; i < D; i++) {
            HashMap<String, Integer> temp = tfInIdf.get(key.get(i));
            for (String word : temp.keySet()) {
                Dt = 1; // document i itself contains the word
                if (!located.contains(word)) {
                    located.add(word);
                    for (int k = 0; k < D; k++) {
                        if (k != i && tfInIdf.get(key.get(k)).containsKey(word)) {
                            Dt = Dt + 1;
                        }
                    }
                    idf.put(word, Math.log10((1 + D) / Dt));
                }
            }
        }
        return idf;
    }
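
    // Worked example (made-up case): with only the two input texts, D = 2.
    // A word present in both texts has Dt = 2, so idf = log10(3 / 2) ~= 0.176;
    // a word present in just one text has Dt = 1, so idf = log10(3 / 1) ~= 0.477.
    // Words shared by both documents are therefore down-weighted relative to distinctive ones.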

    public static Map<String, HashMap<String, Double>> tfidf(String key1, String key2, String text1, String text2)
            throws IOException {
        Map<String, Double> idf = idf(key1, key2, text1, text2);
        tfOfAll(key1, key2, text1, text2);
        for (String key : allTheTf.keySet()) {
            Map<String, Double> singelFile = allTheTf.get(key);
            Double[] arr = new Double[idf.size()];
            int index = 0;

            // Weight each term: tf * idf.
            for (String word : singelFile.keySet()) {
                singelFile.put(word, idf.get(word) * singelFile.get(word));
            }

            // Lay the weights out in a fixed order (the idf key order) so the
            // two vectors are aligned dimension by dimension.
            for (String word : idf.keySet()) {
                arr[index] = singelFile.get(word) != null ? singelFile.get(word) : 0d;
                index++;
            }
            vectorMap.put(key, arr);
        }
        return allTheTf;
    }

    /* Once the weight vectors are built, compare them with cosine similarity:
       cos = dot(v1, v2) / (|v1| * |v2|) */
    public static Double sim(String key1, String key2) {
        Double[] arr1 = vectorMap.get(key1);
        Double[] arr2 = vectorMap.get(key2);
        int length = arr1.length;
        Double result1 = 0.00; // squared magnitude of vector 1
        Double result2 = 0.00; // squared magnitude of vector 2
        Double sum = 0d; // dot product
        if (length == 0) {
            return 0d;
        }
        for (int i = 0; i < length; i++) {
            result1 += arr1[i] * arr1[i];
            result2 += arr2[i] * arr2[i];
            sum += arr1[i] * arr2[i];
        }
        Double result = Math.sqrt(result1 * result2);
        System.out.println("Similarity of " + key1 + " and " + key2 + ": " + sum / result);

        return sum / result;
    }

}
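
A minimal driver for the class above; the two sample strings and the keys "doc1"/"doc2" are made up for illustration. Note that tfidf() must run before sim(), since sim() reads the vectors that tfidf() stores in vectorMap:

import java.io.IOException;

public class GetTextDemo {
    public static void main(String[] args) throws IOException {
        String text1 = "今天天氣很好,我們去公園散步"; // made-up sample text
        String text2 = "今天天氣不錯,我們去公園跑步"; // made-up sample text
        GetText.tfidf("doc1", "doc2", text1, text2); // build the tf-idf vectors first
        GetText.sim("doc1", "doc2");                 // then compare them
    }
}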

Matching multiple files

import java.io.BufferedReader;
import java.io.File;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStreamReader;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

import jeasy.analysis.MMAnalyzer;

/**
 * Match documents from a corpus directory
 * @author rock
 *
 */
public class ReadFiles {

    private static List<String> fileList = new ArrayList<String>();
    private static HashMap<String, HashMap<String, Float>> allTheTf = new HashMap<String, HashMap<String, Float>>();
    private static HashMap<String, HashMap<String, Integer>> allTheNormalTF = new HashMap<String, HashMap<String, Integer>>();
    private static LinkedHashMap<String, Float[]> vectorMap = new LinkedHashMap<String, Float[]>();
    /**
     * Read the corpus directory, recursively collecting the paths of all files in it
     * @author create by rock
     */
    public static List<String> readDirs(String filepath) throws IOException {
        File file = new File(filepath);
        if (!file.isDirectory()) {
            System.out.println("The argument should be a directory name");
            System.out.println("filepath: " + file.getAbsolutePath());
        } else {
            String[] filelist = file.list();
            for (int i = 0; i < filelist.length; i++) {
                File readfile = new File(filepath + File.separator + filelist[i]);
                if (!readfile.isDirectory()) {
                    fileList.add(readfile.getAbsolutePath());
                } else {
                    readDirs(filepath + File.separator + filelist[i]);
                }
            }
        }
        return fileList;
    }

    /**
     * Read a text file as UTF-8
     * @author create by rock
     */
    public static String readFiles(String file) throws IOException {
        StringBuilder sb = new StringBuilder();
        InputStreamReader is = new InputStreamReader(new FileInputStream(file), "utf-8");
        BufferedReader br = new BufferedReader(is);
        String line = br.readLine();
        while (line != null) {
            sb.append(line).append("\r\n");
            line = br.readLine();
        }
        br.close();
        return sb.toString();
    }

    /**
     * Word segmentation
     * @author create by rock
     */
    public static String[] cutWord(String file) throws IOException {
        String text = ReadFiles.readFiles(file);
        MMAnalyzer analyzer = new MMAnalyzer();
        String tempCutWordResult = analyzer.segment(text, " ");
        return tempCutWordResult.split(" ");
    }



    /**
     * Compute term frequency (normalized by document length)
     * @author create by rock
     */
    public static HashMap<String, Float> tf(String[] cutWordResult) {
        HashMap<String, Float> tf = new HashMap<String, Float>(); // normalized
        int wordNum = cutWordResult.length;
        int wordtf = 0;
        for (int i = 0; i < wordNum; i++) {
            wordtf = 0;
            if (!" ".equals(cutWordResult[i])) {
                // Count duplicates of word i; blank out each duplicate so it is not counted again.
                for (int j = 0; j < wordNum; j++) {
                    if (i != j && cutWordResult[i].equals(cutWordResult[j])) {
                        cutWordResult[j] = " ";
                        wordtf++;
                    }
                }
                tf.put(cutWordResult[i], (float) (++wordtf) / wordNum); // count / document length
                cutWordResult[i] = " ";
            }
        }
        return tf;
    }


    public static HashMap<String, Integer> normalTF(String[] cutWordResult) {
        HashMap<String, Integer> tfNormal = new HashMap<String, Integer>(); // raw counts, not normalized
        int wordNum = cutWordResult.length;
        int wordtf = 0;
        for (int i = 0; i < wordNum; i++) {
            wordtf = 0;
            if (!" ".equals(cutWordResult[i])) {
                for (int j = 0; j < wordNum; j++) {
                    if (i != j && cutWordResult[i].equals(cutWordResult[j])) {
                        cutWordResult[j] = " ";
                        wordtf++;
                    }
                }
                tfNormal.put(cutWordResult[i], ++wordtf);
                cutWordResult[i] = " ";
            }
        }
        return tfNormal;
    }

    public static Map<String, HashMap<String, Float>> tfOfAll(String dir) throws IOException {
        List<String> fileList = ReadFiles.readDirs(dir);
        for (String file : fileList) {
            allTheTf.put(file, ReadFiles.tf(ReadFiles.cutWord(file)));
        }
        return allTheTf;
    }

    /**
     * Use a caller-supplied list of files instead of scanning a directory
     * @author create by rock
     */
    public static Map<String, HashMap<String, Float>> tfOfAll(String[] files) throws IOException {
        for (String file : files) {
            allTheTf.put(file, ReadFiles.tf(ReadFiles.cutWord(file)));
        }
        return allTheTf;
    }


    public static Map<String, HashMap<String, Integer>> NormalTFOfAll(String dir) throws IOException {
        List<String> fileList = ReadFiles.readDirs(dir);
        for (int i = 0; i < fileList.size(); i++) {
            allTheNormalTF.put(fileList.get(i), ReadFiles.normalTF(ReadFiles.cutWord(fileList.get(i))));
        }
        return allTheNormalTF;
    }

    /**
     * IDF = log10((1 + |D|) / |Dt|), where |D| is the total number of documents
     * and |Dt| is the number of documents that contain term t.
     */
    public static Map<String, Float> idf(String dir) throws IOException {
        Map<String, Float> idf = new HashMap<String, Float>();
        List<String> located = new ArrayList<String>(); // words whose idf has already been computed
        NormalTFOfAll(dir);

        float Dt = 1;
        float D = allTheNormalTF.size(); // total number of documents
        List<String> key = fileList; // list of document paths
        Map<String, HashMap<String, Integer>> tfInIdf = allTheNormalTF; // term counts per document

        for (int i = 0; i < D; i++) {
            HashMap<String, Integer> temp = tfInIdf.get(key.get(i));
            for (String word : temp.keySet()) {
                Dt = 1; // document i itself contains the word
                if (!located.contains(word)) {
                    located.add(word);
                    for (int k = 0; k < D; k++) {
                        if (k != i && tfInIdf.get(key.get(k)).containsKey(word)) {
                            Dt = Dt + 1;
                        }
                    }
                    idf.put(word, (float) Math.log10((1 + D) / Dt));
                }
            }
        }
        return idf;
    }

    public static Map<String, HashMap<String, Float>> tfidf(String dir) throws IOException {
        Map<String, Float> idf = ReadFiles.idf(dir);
        Map<String, HashMap<String, Float>> tf = ReadFiles.tfOfAll(dir);
        for (String file : tf.keySet()) {
            Map<String, Float> singelFile = tf.get(file);
            Float[] arr = new Float[idf.size()];
            int index = 0;
            // Weight each term: tf * idf.
            for (String word : singelFile.keySet()) {
                singelFile.put(word, idf.get(word) * singelFile.get(word));
            }
            // Lay the weights out in a fixed order (the idf key order) so all
            // document vectors are aligned dimension by dimension.
            for (String word : idf.keySet()) {
                arr[index] = singelFile.get(word) != null ? singelFile.get(word) : 0f;
                index++;
            }
            vectorMap.put(file, arr);
        }
        return tf;
    }



    /* Cosine similarity between two document vectors: cos = dot(v1, v2) / (|v1| * |v2|) */
    public static double sim(String file1, String file2) {
        Float[] arr1 = vectorMap.get(file1);
        Float[] arr2 = vectorMap.get(file2);
        int length = arr1.length;
        double result1 = 0.00; // squared magnitude of vector 1
        double result2 = 0.00; // squared magnitude of vector 2
        double sum = 0d; // dot product

        for (int i = 0; i < length; i++) {
            result1 += arr1[i] * arr1[i];
            result2 += arr2[i] * arr2[i];
            sum += arr1[i] * arr2[i];
        }
        double result = Math.sqrt(result1 * result2);
        System.out.println("Similarity of " + file1 + " and " + file2 + ": " + sum / result);
        return sum / result;
    }

}
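
A minimal driver for the corpus version; the directory and file paths below are placeholders, so substitute your own. As with GetText, tfidf(dir) must run first so that vectorMap holds a vector for every file; sim() is then keyed by the absolute paths that readDirs() collected:

import java.io.IOException;

public class ReadFilesDemo {
    public static void main(String[] args) throws IOException {
        String dir = "D:\\corpus"; // placeholder corpus directory
        ReadFiles.tfidf(dir); // build a tf-idf vector for every file under dir
        // Compare any two files by the absolute paths readDirs() stored:
        ReadFiles.sim("D:\\corpus\\a.txt", "D:\\corpus\\b.txt"); // placeholder paths
    }
}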