FP-Tree思想與實現

在關聯規則挖掘領域最經典的算法法是Apriori,其致命的缺點是需要多次掃描事務數據庫。於是人們提出了各種裁剪(prune)數據集的方法以減少I/O開支,韓嘉煒老師的FP-Tree算法就是其中非常高效的一種。

支持度和置信度

嚴格地說Apriori和FP-Tree都是尋找頻繁項集的算法,頻繁項集就是所謂的“支持度”比較高的項集,下面解釋一下支持度和置信度的概念。

設事務數據庫爲:

複製代碼
A  E  F  G

A  F  G

A  B  E  F  G

E  F  G
複製代碼

則{A,F,G}的支持度數爲3,支持度爲3/4。

{F,G}的支持度數爲4,支持度爲4/4。

{A}的支持度數爲3,支持度爲3/4。

{F,G}=>{A}的置信度爲:{A,F,G}的支持度數 除以 {F,G}的支持度數,即3/4

{A}=>{F,G}的置信度爲:{A,F,G}的支持度數 除以 {A}的支持度數,即3/3

強關聯規則挖掘是在滿足一定支持度的情況下尋找置信度達到閾值的所有模式。

FP-Tree算法

我們舉個例子來詳細講解FP-Tree算法的完整實現。

事務數據庫如下,一行表示一條購物記錄:

複製代碼
牛奶,雞蛋,麪包,薯片

雞蛋,爆米花,薯片,啤酒

雞蛋,麪包,薯片

牛奶,雞蛋,麪包,爆米花,薯片,啤酒

牛奶,麪包,啤酒

雞蛋,麪包,啤酒

牛奶,麪包,薯片

牛奶,雞蛋,麪包,黃油,薯片

牛奶,雞蛋,黃油,薯片
複製代碼

我們的目的是要找出哪些商品總是相伴出現的,比如人們買薯片的時候通常也會買雞蛋,則[薯片,雞蛋]就是一條頻繁模式(frequent pattern)。

FP-Tree算法第一步:掃描事務數據庫,每項商品按頻數遞減排序,並刪除頻數小於最小支持度MinSup的商品。(第一次掃描數據庫)

薯片:7雞蛋:7麪包:7牛奶:6啤酒:4                       (這裏我們令MinSup=3)

以上結果就是頻繁1項集,記爲F1。

第二步:對於每一條購買記錄,按照F1中的順序重新排序。(第二次也是最後一次掃描數據庫)

複製代碼
薯片,雞蛋,麪包,牛奶

薯片,雞蛋,啤酒

薯片,雞蛋,麪包

薯片,雞蛋,麪包,牛奶,啤酒

麪包,牛奶,啤酒

雞蛋,麪包,啤酒

薯片,麪包,牛奶

薯片,雞蛋,麪包,牛奶

薯片,雞蛋,牛奶
複製代碼

第三步:把第二步得到的各條記錄插入到FP-Tree中。剛開始時後綴模式爲空。

插入每一條(薯片,雞蛋,麪包,牛奶)之後

插入第二條記錄(薯片,雞蛋,啤酒)

插入第三條記錄(麪包,牛奶,啤酒)

估計你也知道怎麼插了,最終生成的FP-Tree是:

上圖中左邊的那一叫做表頭項,樹中相同名稱的節點要鏈接起來,鏈表的第一個元素就是表頭項裏的元素。

如果FP-Tree爲空(只含一個虛的root節點),則FP-Growth函數返回。

此時輸出表頭項的每一項+postModel,支持度爲表頭項中對應項的計數。

第四步:從FP-Tree中找出頻繁項。

遍歷表頭項中的每一項(我們拿“牛奶:6”爲例),對於各項都執行以下(1)到(5)的操作:

(1)從FP-Tree中找到所有的“牛奶”節點,向上遍歷它的祖先節點,得到4條路徑:

複製代碼
薯片:7,雞蛋:6,牛奶:1

薯片:7,雞蛋:6,麪包:4,牛奶:3

薯片:7,麪包:1,牛奶:1

麪包:1,牛奶:1
複製代碼

對於每一條路徑上的節點,其count都設置爲牛奶的count

複製代碼
薯片:1,雞蛋:1,牛奶:1

薯片:3,雞蛋:3,麪包:3,牛奶:3

薯片:1,麪包:1,牛奶:1

麪包:1,牛奶:1
複製代碼

因爲每一項末尾都是牛奶,可以把牛奶去掉,得到條件模式基(Conditional Pattern Base,CPB),此時的後綴模式是:(牛奶)。

複製代碼
薯片:1,雞蛋:1

薯片:3,雞蛋:3,麪包:3

薯片:1,麪包:1

麪包:1
複製代碼

(2)我們把上面的結果當作原始的事務數據庫,返回到第3步,遞歸迭代運行。

沒講清楚,你可以參考這篇博客,直接看核心代碼吧:

[java] view plaincopy
  1. public void FPGrowth(List<List<String>> transRecords,  
  2.         List<String> postPattern,Context context) throws IOException, InterruptedException {  
  3.     // 構建項頭表,同時也是頻繁1項集  
  4.     ArrayList<TreeNode> HeaderTable = buildHeaderTable(transRecords);  
  5.     // 構建FP-Tree  
  6.     TreeNode treeRoot = buildFPTree(transRecords, HeaderTable);  
  7.     // 如果FP-Tree爲空則返回  
  8.     if (treeRoot.getChildren()==null || treeRoot.getChildren().size() == 0)  
  9.         return;  
  10.     //輸出項頭表的每一項+postPattern  
  11.     if(postPattern!=null){  
  12.         for (TreeNode header : HeaderTable) {  
  13.             String outStr=header.getName();  
  14.             int count=header.getCount();  
  15.             for (String ele : postPattern)  
  16.                 outStr+="\t" + ele;  
  17.             context.write(new IntWritable(count), new Text(outStr));  
  18.         }  
  19.     }  
  20.     // 找到項頭表的每一項的條件模式基,進入遞歸迭代  
  21.     for (TreeNode header : HeaderTable) {  
  22.         // 後綴模式增加一項  
  23.         List<String> newPostPattern = new LinkedList<String>();  
  24.         newPostPattern.add(header.getName());  
  25.         if (postPattern != null)  
  26.             newPostPattern.addAll(postPattern);  
  27.         // 尋找header的條件模式基CPB,放入newTransRecords中  
  28.         List<List<String>> newTransRecords = new LinkedList<List<String>>();  
  29.         TreeNode backnode = header.getNextHomonym();  
  30.         while (backnode != null) {  
  31.             int counter = backnode.getCount();  
  32.             List<String> prenodes = new ArrayList<String>();  
  33.             TreeNode parent = backnode;  
  34.             // 遍歷backnode的祖先節點,放到prenodes中  
  35.             while ((parent = parent.getParent()).getName() != null) {  
  36.                 prenodes.add(parent.getName());  
  37.             }  
  38.             while (counter-- > 0) {  
  39.                 newTransRecords.add(prenodes);  
  40.             }  
  41.             backnode = backnode.getNextHomonym();  
  42.         }  
  43.         // 遞歸迭代  
  44.         FPGrowth(newTransRecords, newPostPattern,context);  
  45.     }  
  46. }  

對於FP-Tree已經是單枝的情況,就沒有必要再遞歸調用FPGrowth了,直接輸出整條路徑上所有節點的各種組合+postModel就可了。例如當FP-Tree爲:

我們直接輸出:

3  A+postModel

3  B+postModel

3  A+B+postModel

就可以了。

如何按照上面代碼裏的做法,是先輸出:

3  A+postModel

3  B+postModel

然後把B插入到postModel的頭部,重新建立一個FP-Tree,這時Tree中只含A,於是輸出

3  A+(B+postModel)

兩種方法結果是一樣的,但畢竟重新建立FP-Tree計算量大些。

Java實現

FP樹節點定義

[java] view plaincopy
  1. package fptree;  
  2.     
  3. import java.util.ArrayList;  
  4. import java.util.List;  
  5.     
  6. public class TreeNode implements Comparable<TreeNode> {  
  7.     
  8.     private String name; // 節點名稱  
  9.     private int count; // 計數  
  10.     private TreeNode parent; // 父節點  
  11.     private List<TreeNode> children; // 子節點  
  12.     private TreeNode nextHomonym; // 下一個同名節點  
  13.     
  14.     public TreeNode() {  
  15.     
  16.     }  
  17.     
  18.     public TreeNode(String name) {  
  19.         this.name = name;  
  20.     }  
  21.     
  22.     public String getName() {  
  23.         return name;  
  24.     }  
  25.     
  26.     public void setName(String name) {  
  27.         this.name = name;  
  28.     }  
  29.     
  30.     public int getCount() {  
  31.         return count;  
  32.     }  
  33.     
  34.     public void setCount(int count) {  
  35.         this.count = count;  
  36.     }  
  37.     
  38.     public TreeNode getParent() {  
  39.         return parent;  
  40.     }  
  41.     
  42.     public void setParent(TreeNode parent) {  
  43.         this.parent = parent;  
  44.     }  
  45.     
  46.     public List<TreeNode> getChildren() {  
  47.         return children;  
  48.     }  
  49.     
  50.     public void addChild(TreeNode child) {  
  51.         if (this.getChildren() == null) {  
  52.             List<TreeNode> list = new ArrayList<TreeNode>();  
  53.             list.add(child);  
  54.             this.setChildren(list);  
  55.         } else {  
  56.             this.getChildren().add(child);  
  57.         }  
  58.     }  
  59.     
  60.     public TreeNode findChild(String name) {  
  61.         List<TreeNode> children = this.getChildren();  
  62.         if (children != null) {  
  63.             for (TreeNode child : children) {  
  64.                 if (child.getName().equals(name)) {  
  65.                     return child;  
  66.                 }  
  67.             }  
  68.         }  
  69.         return null;  
  70.     }  
  71.     
  72.     public void setChildren(List<TreeNode> children) {  
  73.         this.children = children;  
  74.     }  
  75.     
  76.     public void printChildrenName() {  
  77.         List<TreeNode> children = this.getChildren();  
  78.         if (children != null) {  
  79.             for (TreeNode child : children) {  
  80.                 System.out.print(child.getName() + " ");  
  81.             }  
  82.         } else {  
  83.             System.out.print("null");  
  84.         }  
  85.     }  
  86.     
  87.     public TreeNode getNextHomonym() {  
  88.         return nextHomonym;  
  89.     }  
  90.     
  91.     public void setNextHomonym(TreeNode nextHomonym) {  
  92.         this.nextHomonym = nextHomonym;  
  93.     }  
  94.     
  95.     public void countIncrement(int n) {  
  96.         this.count += n;  
  97.     }  
  98.     
  99.     @Override  
  100.     public int compareTo(TreeNode arg0) {  
  101.         // TODO Auto-generated method stub  
  102.         int count0 = arg0.getCount();  
  103.         // 跟默認的比較大小相反,導致調用Arrays.sort()時是按降序排列  
  104.         return count0 - this.count;  
  105.     }  
  106. }  

挖掘頻繁模式
[java] view plaincopy
  1. package fptree;  
  2.    
  3. import java.io.BufferedReader;  
  4. import java.io.FileReader;  
  5. import java.io.IOException;  
  6. import java.util.ArrayList;  
  7. import java.util.Collections;  
  8. import java.util.Comparator;  
  9. import java.util.HashMap;  
  10. import java.util.LinkedList;  
  11. import java.util.List;  
  12. import java.util.Map;  
  13. import java.util.Map.Entry;  
  14. import java.util.Set;  
  15.    
  16. public class FPTree {  
  17.    
  18.     private int minSuport;  
  19.    
  20.     public int getMinSuport() {  
  21.         return minSuport;  
  22.     }  
  23.    
  24.     public void setMinSuport(int minSuport) {  
  25.         this.minSuport = minSuport;  
  26.     }  
  27.    
  28.     // 從若干個文件中讀入Transaction Record  
  29.     public List<List<String>> readTransRocords(String... filenames) {  
  30.         List<List<String>> transaction = null;  
  31.         if (filenames.length > 0) {  
  32.             transaction = new LinkedList<List<String>>();  
  33.             for (String filename : filenames) {  
  34.                 try {  
  35.                     FileReader fr = new FileReader(filename);  
  36.                     BufferedReader br = new BufferedReader(fr);  
  37.                     try {  
  38.                         String line;  
  39.                         List<String> record;  
  40.                         while ((line = br.readLine()) != null) {  
  41.                             if(line.trim().length()>0){  
  42.                                 String str[] = line.split(",");  
  43.                                 record = new LinkedList<String>();  
  44.                                 for (String w : str)  
  45.                                     record.add(w);  
  46.                                 transaction.add(record);  
  47.                             }  
  48.                         }  
  49.                     } finally {  
  50.                         br.close();  
  51.                     }  
  52.                 } catch (IOException ex) {  
  53.                     System.out.println("Read transaction records failed."  
  54.                             + ex.getMessage());  
  55.                     System.exit(1);  
  56.                 }  
  57.             }  
  58.         }  
  59.         return transaction;  
  60.     }  
  61.    
  62.     // FP-Growth算法  
  63.     public void FPGrowth(List<List<String>> transRecords,  
  64.             List<String> postPattern) {  
  65.         // 構建項頭表,同時也是頻繁1項集  
  66.         ArrayList<TreeNode> HeaderTable = buildHeaderTable(transRecords);  
  67.         // 構建FP-Tree  
  68.         TreeNode treeRoot = buildFPTree(transRecords, HeaderTable);  
  69.         // 如果FP-Tree爲空則返回  
  70.         if (treeRoot.getChildren()==null || treeRoot.getChildren().size() == 0)  
  71.             return;  
  72.         //輸出項頭表的每一項+postPattern  
  73.         if(postPattern!=null){  
  74.             for (TreeNode header : HeaderTable) {  
  75.                 System.out.print(header.getCount() + "\t" + header.getName());  
  76.                 for (String ele : postPattern)  
  77.                     System.out.print("\t" + ele);  
  78.                 System.out.println();  
  79.             }  
  80.         }  
  81.         // 找到項頭表的每一項的條件模式基,進入遞歸迭代  
  82.         for (TreeNode header : HeaderTable) {  
  83.             // 後綴模式增加一項  
  84.             List<String> newPostPattern = new LinkedList<String>();  
  85.             newPostPattern.add(header.getName());  
  86.             if (postPattern != null)  
  87.                 newPostPattern.addAll(postPattern);  
  88.             // 尋找header的條件模式基CPB,放入newTransRecords中  
  89.             List<List<String>> newTransRecords = new LinkedList<List<String>>();  
  90.             TreeNode backnode = header.getNextHomonym();  
  91.             while (backnode != null) {  
  92.                 int counter = backnode.getCount();  
  93.                 List<String> prenodes = new ArrayList<String>();  
  94.                 TreeNode parent = backnode;  
  95.                 // 遍歷backnode的祖先節點,放到prenodes中  
  96.                 while ((parent = parent.getParent()).getName() != null) {  
  97.                     prenodes.add(parent.getName());  
  98.                 }  
  99.                 while (counter-- > 0) {  
  100.                     newTransRecords.add(prenodes);  
  101.                 }  
  102.                 backnode = backnode.getNextHomonym();  
  103.             }  
  104.             // 遞歸迭代  
  105.             FPGrowth(newTransRecords, newPostPattern);  
  106.         }  
  107.     }  
  108.    
  109.     // 構建項頭表,同時也是頻繁1項集  
  110.     public ArrayList<TreeNode> buildHeaderTable(List<List<String>> transRecords) {  
  111.         ArrayList<TreeNode> F1 = null;  
  112.         if (transRecords.size() > 0) {  
  113.             F1 = new ArrayList<TreeNode>();  
  114.             Map<String, TreeNode> map = new HashMap<String, TreeNode>();  
  115.             // 計算事務數據庫中各項的支持度  
  116.             for (List<String> record : transRecords) {  
  117.                 for (String item : record) {  
  118.                     if (!map.keySet().contains(item)) {  
  119.                         TreeNode node = new TreeNode(item);  
  120.                         node.setCount(1);  
  121.                         map.put(item, node);  
  122.                     } else {  
  123.                         map.get(item).countIncrement(1);  
  124.                     }  
  125.                 }  
  126.             }  
  127.             // 把支持度大於(或等於)minSup的項加入到F1中  
  128.             Set<String> names = map.keySet();  
  129.             for (String name : names) {  
  130.                 TreeNode tnode = map.get(name);  
  131.                 if (tnode.getCount() >= minSuport) {  
  132.                     F1.add(tnode);  
  133.                 }  
  134.             }  
  135.             Collections.sort(F1);  
  136.             return F1;  
  137.         } else {  
  138.             return null;  
  139.         }  
  140.     }  
  141.    
  142.     // 構建FP-Tree  
  143.     public TreeNode buildFPTree(List<List<String>> transRecords,  
  144.             ArrayList<TreeNode> F1) {  
  145.         TreeNode root = new TreeNode(); // 創建樹的根節點  
  146.         for (List<String> transRecord : transRecords) {  
  147.             LinkedList<String> record = sortByF1(transRecord, F1);  
  148.             TreeNode subTreeRoot = root;  
  149.             TreeNode tmpRoot = null;  
  150.             if (root.getChildren() != null) {  
  151.                 while (!record.isEmpty()  
  152.                         && (tmpRoot = subTreeRoot.findChild(record.peek())) != null) {  
  153.                     tmpRoot.countIncrement(1);  
  154.                     subTreeRoot = tmpRoot;  
  155.                     record.poll();  
  156.                 }  
  157.             }  
  158.             addNodes(subTreeRoot, record, F1);  
  159.         }  
  160.         return root;  
  161.     }  
  162.    
  163.     // 把交易記錄按項的頻繁程序降序排列  
  164.     public LinkedList<String> sortByF1(List<String> transRecord,  
  165.             ArrayList<TreeNode> F1) {  
  166.         Map<String, Integer> map = new HashMap<String, Integer>();  
  167.         for (String item : transRecord) {  
  168.             // 由於F1已經是按降序排列的,  
  169.             for (int i = 0; i < F1.size(); i++) {  
  170.                 TreeNode tnode = F1.get(i);  
  171.                 if (tnode.getName().equals(item)) {  
  172.                     map.put(item, i);  
  173.                 }  
  174.             }  
  175.         }  
  176.         ArrayList<Entry<String, Integer>> al = new ArrayList<Entry<String, Integer>>(  
  177.                 map.entrySet());  
  178.         Collections.sort(al, new Comparator<Map.Entry<String, Integer>>() {  
  179.             @Override  
  180.             public int compare(Entry<String, Integer> arg0,  
  181.                     Entry<String, Integer> arg1) {  
  182.                 // 降序排列  
  183.                 return arg0.getValue() - arg1.getValue();  
  184.             }  
  185.         });  
  186.         LinkedList<String> rest = new LinkedList<String>();  
  187.         for (Entry<String, Integer> entry : al) {  
  188.             rest.add(entry.getKey());  
  189.         }  
  190.         return rest;  
  191.     }  
  192.    
  193.     // 把record作爲ancestor的後代插入樹中  
  194.     public void addNodes(TreeNode ancestor, LinkedList<String> record,  
  195.             ArrayList<TreeNode> F1) {  
  196.         if (record.size() > 0) {  
  197.             while (record.size() > 0) {  
  198.                 String item = record.poll();  
  199.                 TreeNode leafnode = new TreeNode(item);  
  200.                 leafnode.setCount(1);  
  201.                 leafnode.setParent(ancestor);  
  202.                 ancestor.addChild(leafnode);  
  203.    
  204.                 for (TreeNode f1 : F1) {  
  205.                     if (f1.getName().equals(item)) {  
  206.                         while (f1.getNextHomonym() != null) {  
  207.                             f1 = f1.getNextHomonym();  
  208.                         }  
  209.                         f1.setNextHomonym(leafnode);  
  210.                         break;  
  211.                     }  
  212.                 }  
  213.    
  214.                 addNodes(leafnode, record, F1);  
  215.             }  
  216.         }  
  217.     }  
  218.    
  219.     public static void main(String[] args) {  
  220.         FPTree fptree = new FPTree();  
  221.         fptree.setMinSuport(3);  
  222.         List<List<String>> transRecords = fptree  
  223.                 .readTransRocords("/home/orisun/test/market");  
  224.         fptree.FPGrowth(transRecords, null);  
  225.     }  
  226. }  

輸入文件

複製代碼
牛奶,雞蛋,麪包,薯片
雞蛋,爆米花,薯片,啤酒
雞蛋,麪包,薯片
牛奶,雞蛋,麪包,爆米花,薯片,啤酒
牛奶,麪包,啤酒
雞蛋,麪包,啤酒
牛奶,麪包,薯片
牛奶,雞蛋,麪包,黃油,薯片
牛奶,雞蛋,黃油,薯片
複製代碼

輸出

複製代碼
6    薯片    雞蛋
5    薯片    麪包
5    雞蛋    麪包
4    薯片    雞蛋    麪包
5    薯片    牛奶
5    麪包    牛奶
4    雞蛋    牛奶
4    薯片    麪包    牛奶
4    薯片    雞蛋    牛奶
3    麪包    雞蛋    牛奶
3    薯片    麪包    雞蛋    牛奶
3    雞蛋    啤酒
3    麪包    啤酒
複製代碼

用Hadoop來實現

在上面的代碼我們把整個事務數據庫放在一個List<List<String>>裏面傳給FPGrowth,在實際中這是不可取的,因爲內存不可能容下整個事務數據庫,我們可能需要從關係關係數據庫中一條一條地讀入來建立FP-Tree。但無論如何 FP-Tree是肯定需要放在內存中的,但內存如果容不下怎麼辦?另外FPGrowth仍然是非常耗時的,你想提高速度怎麼辦?解決辦法:分而治之,並行計算。

我們把原始事務數據庫分成N部分,在N個節點上並行地進行FPGrowth挖掘,最後把關聯規則彙總到一起就可以了。關鍵問題是怎麼“劃分”纔會不遺露任何一條關聯規則呢?參見這篇博客。這裏爲了達到並行計算的目的,採用了一種“冗餘”的劃分方法,即各部分的並集大於原來的集合。這種方法最終求出來的關聯規則也是有冗餘的,比如在節點1上得到一條規則(6:啤酒,尿布),在節點2上得到一條規則(3:尿布,啤酒),顯然節點2上的這條規則是冗餘的,需要採用後續步驟把冗餘的規則去掉。

代碼:

Record.java

[java] view plaincopy
  1. package fptree;  
  2.    
  3. import java.io.DataInput;  
  4. import java.io.DataOutput;  
  5. import java.io.IOException;  
  6. import java.util.Collections;  
  7. import java.util.LinkedList;  
  8.    
  9. import org.apache.hadoop.io.WritableComparable;  
  10.    
  11. public class Record implements WritableComparable<Record>{  
  12.        
  13.     LinkedList<String> list;  
  14.        
  15.     public Record(){  
  16.         list=new LinkedList<String>();  
  17.     }  
  18.        
  19.     public Record(String[] arr){  
  20.         list=new LinkedList<String>();  
  21.         for(int i=0;i<arr.length;i++)  
  22.             list.add(arr[i]);  
  23.     }  
  24.        
  25.     @Override  
  26.     public String toString(){  
  27.         String str=list.get(0);  
  28.         for(int i=1;i<list.size();i++)  
  29.             str+="\t"+list.get(i);  
  30.         return str;  
  31.     }  
  32.    
  33.     @Override  
  34.     public void readFields(DataInput in) throws IOException {  
  35.         list.clear();  
  36.         String line=in.readUTF();  
  37.         String []arr=line.split("\\s+");  
  38.         for(int i=0;i<arr.length;i++)  
  39.             list.add(arr[i]);  
  40.     }  
  41.    
  42.     @Override  
  43.     public void write(DataOutput out) throws IOException {  
  44.         out.writeUTF(this.toString());  
  45.     }  
  46.    
  47.     @Override  
  48.     public int compareTo(Record obj) {  
  49.         Collections.sort(list);  
  50.         Collections.sort(obj.list);  
  51.         return this.toString().compareTo(obj.toString());  
  52.     }  
  53.    
  54. }  

DC_FPTree.java

[java] view plaincopy
  1. package fptree;  
  2.    
  3. import java.io.BufferedReader;  
  4. import java.io.IOException;  
  5. import java.io.InputStreamReader;  
  6. import java.util.ArrayList;  
  7. import java.util.BitSet;  
  8. import java.util.Collections;  
  9. import java.util.Comparator;  
  10. import java.util.HashMap;  
  11. import java.util.LinkedList;  
  12. import java.util.List;  
  13. import java.util.Map;  
  14. import java.util.Map.Entry;  
  15. import java.util.Set;  
  16.    
  17. import org.apache.hadoop.conf.Configuration;  
  18. import org.apache.hadoop.conf.Configured;  
  19. import org.apache.hadoop.fs.FSDataInputStream;  
  20. import org.apache.hadoop.fs.FileSystem;  
  21. import org.apache.hadoop.fs.Path;  
  22. import org.apache.hadoop.io.IntWritable;  
  23. import org.apache.hadoop.io.LongWritable;  
  24. import org.apache.hadoop.io.Text;  
  25. import org.apache.hadoop.mapreduce.Job;  
  26. import org.apache.hadoop.mapreduce.Mapper;  
  27. import org.apache.hadoop.mapreduce.Reducer;  
  28. import org.apache.hadoop.mapreduce.lib.input.FileInputFormat;  
  29. import org.apache.hadoop.mapreduce.lib.input.TextInputFormat;  
  30. import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat;  
  31. import org.apache.hadoop.mapreduce.lib.output.TextOutputFormat;  
  32. import org.apache.hadoop.util.Tool;  
  33. import org.apache.hadoop.util.ToolRunner;  
  34.    
  35. public class DC_FPTree extends Configured implements Tool {  
  36.    
  37.     private static final int GroupNum = 10;  
  38.     private static final int minSuport=6;  
  39.    
  40.     public static class GroupMapper extends  
  41.             Mapper<LongWritable, Text, IntWritable, Record> {  
  42.         List<String> freq = new LinkedList<String>(); // 頻繁1項集  
  43.         List<List<String>> freq_group = new LinkedList<List<String>>(); // 分組後的頻繁1項集  
  44.    
  45.         @Override  
  46.         public void setup(Context context) throws IOException {  
  47.             // 從文件讀入頻繁1項集  
  48.             FileSystem fs = FileSystem.get(context.getConfiguration());  
  49.             Path freqFile = new Path("/user/orisun/input/F1");  
  50.             FSDataInputStream in = fs.open(freqFile);  
  51.             InputStreamReader isr = new InputStreamReader(in);  
  52.             BufferedReader br = new BufferedReader(isr);  
  53.             try {  
  54.                 String line;  
  55.                 while ((line = br.readLine()) != null) {  
  56.                     String[] str = line.split("\\s+");  
  57.                     String word = str[0];  
  58.                     freq.add(word);  
  59.                 }  
  60.             } finally {  
  61.                 br.close();  
  62.             }  
  63.             // 對頻繁1項集進行分組  
  64.             Collections.shuffle(freq); // 打亂順序  
  65.             int cap = freq.size() / GroupNum; // 每段分爲一組  
  66.             for (int i = 0; i < GroupNum; i++) {  
  67.                 List<String> list = new LinkedList<String>();  
  68.                 for (int j = 0; j < cap; j++) {  
  69.                     list.add(freq.get(i * cap + j));  
  70.                 }  
  71.                 freq_group.add(list);  
  72.             }  
  73.             int remainder = freq.size() % GroupNum;  
  74.             int base = GroupNum * cap;  
  75.             for (int i = 0; i < remainder; i++) {  
  76.                 freq_group.get(i).add(freq.get(base + i));  
  77.             }  
  78.         }  
  79.    
  80.         @Override  
  81.         public void map(LongWritable key, Text value, Context context)  
  82.                 throws IOException, InterruptedException {  
  83.             String[] arr = value.toString().split("\\s+");  
  84.             Record record = new Record(arr);  
  85.             LinkedList<String> list = record.list;  
  86.             BitSet bs=new BitSet(freq_group.size());  
  87.             bs.clear();  
  88.             while (record.list.size() > 0) {  
  89.                 String item = list.peekLast(); // 取出record的最後一項  
  90.                 int i=0;  
  91.                 for (; i < freq_group.size(); i++) {  
  92.                     if(bs.get(i))  
  93.                         continue;  
  94.                     if (freq_group.get(i).contains(item)) {  
  95.                         bs.set(i);  
  96.                         break;  
  97.                     }  
  98.                 }  
  99.                 if(i<freq_group.size()){     //找到了  
  100.                     context.write(new IntWritable(i), record);    
  101.                 }  
  102.                 record.list.pollLast();  
  103.             }  
  104.         }  
  105.     }  
  106.        
  107.     public static class FPReducer extends Reducer<IntWritable,Record,IntWritable,Text>{  
  108.         public void reduce(IntWritable key,Iterable<Record> values,Context context)throws IOException,InterruptedException{  
  109.             List<List<String>> trans=new LinkedList<List<String>>();  
  110.             while(values.iterator().hasNext()){  
  111.                 Record record=values.iterator().next();  
  112.                 LinkedList<String> list=new LinkedList<String>();  
  113.                 for(String ele:record.list)  
  114.                     list.add(ele);  
  115.                 trans.add(list);  
  116.             }  
  117.             FPGrowth(trans, null,context);  
  118.         }  
  119.         // FP-Growth算法  
  120.     public void FPGrowth(List<List<String>> transRecords,  
  121.             List<String> postPattern,Context context) throws IOException, InterruptedException {  
  122.         // 構建項頭表,同時也是頻繁1項集  
  123.         ArrayList<TreeNode> HeaderTable = buildHeaderTable(transRecords);  
  124.         // 構建FP-Tree  
  125.         TreeNode treeRoot = buildFPTree(transRecords, HeaderTable);  
  126.         // 如果FP-Tree爲空則返回  
  127.         if (treeRoot.getChildren()==null || treeRoot.getChildren().size() == 0)  
  128.             return;  
  129.         //輸出項頭表的每一項+postPattern  
  130.         if(postPattern!=null){  
  131.             for (TreeNode header : HeaderTable) {  
  132.                 String outStr=header.getName();  
  133.                 int count=header.getCount();  
  134.                 for (String ele : postPattern)  
  135.                     outStr+="\t" + ele;  
  136.                 context.write(new IntWritable(count), new Text(outStr));  
  137.             }  
  138.         }  
  139.         // 找到項頭表的每一項的條件模式基,進入遞歸迭代  
  140.         for (TreeNode header : HeaderTable) {  
  141.             // 後綴模式增加一項  
  142.             List<String> newPostPattern = new LinkedList<String>();  
  143.             newPostPattern.add(header.getName());  
  144.             if (postPattern != null)  
  145.                 newPostPattern.addAll(postPattern);  
  146.             // 尋找header的條件模式基CPB,放入newTransRecords中  
  147.             List<List<String>> newTransRecords = new LinkedList<List<String>>();  
  148.             TreeNode backnode = header.getNextHomonym();  
  149.             while (backnode != null) {  
  150.                 int counter = backnode.getCount();  
  151.                 List<String> prenodes = new ArrayList<String>();  
  152.                 TreeNode parent = backnode;  
  153.                 // 遍歷backnode的祖先節點,放到prenodes中  
  154.                 while ((parent = parent.getParent()).getName() != null) {  
  155.                     prenodes.add(parent.getName());  
  156.                 }  
  157.                 while (counter-- > 0) {  
  158.                     newTransRecords.add(prenodes);  
  159.                 }  
  160.                 backnode = backnode.getNextHomonym();  
  161.             }  
  162.             // 遞歸迭代  
  163.             FPGrowth(newTransRecords, newPostPattern,context);  
  164.         }  
  165.     }  
  166.    
  167.         // 構建項頭表,同時也是頻繁1項集  
  168.         public ArrayList<TreeNode> buildHeaderTable(List<List<String>> transRecords) {  
  169.             ArrayList<TreeNode> F1 = null;  
  170.             if (transRecords.size() > 0) {  
  171.                 F1 = new ArrayList<TreeNode>();  
  172.                 Map<String, TreeNode> map = new HashMap<String, TreeNode>();  
  173.                 // 計算事務數據庫中各項的支持度  
  174.                 for (List<String> record : transRecords) {  
  175.                     for (String item : record) {  
  176.                         if (!map.keySet().contains(item)) {  
  177.                             TreeNode node = new TreeNode(item);  
  178.                             node.setCount(1);  
  179.                             map.put(item, node);  
  180.                         } else {  
  181.                             map.get(item).countIncrement(1);  
  182.                         }  
  183.                     }  
  184.                 }  
  185.                 // 把支持度大於(或等於)minSup的項加入到F1中  
  186.                 Set<String> names = map.keySet();  
  187.                 for (String name : names) {  
  188.                     TreeNode tnode = map.get(name);  
  189.                     if (tnode.getCount() >= minSuport) {  
  190.                         F1.add(tnode);  
  191.                     }  
  192.                 }  
  193.                 Collections.sort(F1);  
  194.                 return F1;  
  195.             } else {  
  196.                 return null;  
  197.             }  
  198.         }  
  199.    
  200.         // 構建FP-Tree  
  201.         public TreeNode buildFPTree(List<List<String>> transRecords,  
  202.                 ArrayList<TreeNode> F1) {  
  203.             TreeNode root = new TreeNode(); // 創建樹的根節點  
  204.             for (List<String> transRecord : transRecords) {  
  205.                 LinkedList<String> record = sortByF1(transRecord, F1);  
  206.                 TreeNode subTreeRoot = root;  
  207.                 TreeNode tmpRoot = null;  
  208.                 if (root.getChildren() != null) {  
  209.                     while (!record.isEmpty()  
  210.                             && (tmpRoot = subTreeRoot.findChild(record.peek())) != null) {  
  211.                         tmpRoot.countIncrement(1);  
  212.                         subTreeRoot = tmpRoot;  
  213.                         record.poll();  
  214.                     }  
  215.                 }  
  216.                 addNodes(subTreeRoot, record, F1);  
  217.             }  
  218.             return root;  
  219.         }  
  220.    
  221.         // 把交易記錄按項的頻繁程序降序排列  
  222.         public LinkedList<String> sortByF1(List<String> transRecord,  
  223.                 ArrayList<TreeNode> F1) {  
  224.             Map<String, Integer> map = new HashMap<String, Integer>();  
  225.             for (String item : transRecord) {  
  226.                 // 由於F1已經是按降序排列的,  
  227.                 for (int i = 0; i < F1.size(); i++) {  
  228.                     TreeNode tnode = F1.get(i);  
  229.                     if (tnode.getName().equals(item)) {  
  230.                         map.put(item, i);  
  231.                     }  
  232.                 }  
  233.             }  
  234.             ArrayList<Entry<String, Integer>> al = new ArrayList<Entry<String, Integer>>(  
  235.                     map.entrySet());  
  236.             Collections.sort(al, new Comparator<Map.Entry<String, Integer>>() {  
  237.                 @Override  
  238.                 public int compare(Entry<String, Integer> arg0,  
  239.                         Entry<String, Integer> arg1) {  
  240.                     // 降序排列  
  241.                     return arg0.getValue() - arg1.getValue();  
  242.                 }  
  243.             });  
  244.             LinkedList<String> rest = new LinkedList<String>();  
  245.             for (Entry<String, Integer> entry : al) {  
  246.                 rest.add(entry.getKey());  
  247.             }  
  248.             return rest;  
  249.         }  
  250.    
  251.         // 把record作爲ancestor的後代插入樹中  
  252.         public void addNodes(TreeNode ancestor, LinkedList<String> record,  
  253.                 ArrayList<TreeNode> F1) {  
  254.             if (record.size() > 0) {  
  255.                 while (record.size() > 0) {  
  256.                     String item = record.poll();  
  257.                     TreeNode leafnode = new TreeNode(item);  
  258.                     leafnode.setCount(1);  
  259.                     leafnode.setParent(ancestor);  
  260.                     ancestor.addChild(leafnode);  
  261.    
  262.                     for (TreeNode f1 : F1) {  
  263.                         if (f1.getName().equals(item)) {  
  264.                             while (f1.getNextHomonym() != null) {  
  265.                                 f1 = f1.getNextHomonym();  
  266.                             }  
  267.                             f1.setNextHomonym(leafnode);  
  268.                             break;  
  269.                         }  
  270.                     }  
  271.    
  272.                     addNodes(leafnode, record, F1);  
  273.                 }  
  274.             }  
  275.         }  
  276.     }  
  277.        
  278.     public static class InverseMapper extends  
  279.             Mapper<LongWritable, Text, Record, IntWritable> {  
  280.         @Override  
  281.         public void map(LongWritable key, Text value, Context context)  
  282.                 throws IOException, InterruptedException {  
  283.             String []arr=value.toString().split("\\s+");  
  284.             int count=Integer.parseInt(arr[0]);  
  285.             Record record=new Record();  
  286.             for(int i=1;i<arr.length;i++){  
  287.                 record.list.add(arr[i]);  
  288.             }  
  289.             context.write(record, new IntWritable(count));  
  290.         }  
  291.     }  
  292.        
  293.     public static class MaxReducer extends Reducer<Record,IntWritable,IntWritable,Record>{  
  294.         public void reduce(Record key,Iterable<IntWritable> values,Context context)throws IOException,InterruptedException{  
  295.             int max=-1;  
  296.             for(IntWritable value:values){  
  297.                 int i=value.get();  
  298.                 if(i>max)  
  299.                     max=i;  
  300.             }  
  301.             context.write(new IntWritable(max), key);  
  302.         }  
  303.     }  
  304.    
  305.    
  306.     @Override  
  307.     public int run(String[] arg0) throws Exception {  
  308.         Configuration conf=getConf();  
  309.         conf.set("mapred.task.timeout""6000000");  
  310.         Job job=new Job(conf);  
  311.         job.setJarByClass(DC_FPTree.class);  
  312.         FileSystem fs=FileSystem.get(getConf());  
  313.            
  314.         FileInputFormat.setInputPaths(job, "/user/orisun/input/data");  
  315.         Path outDir=new Path("/user/orisun/output");  
  316.         fs.delete(outDir,true);  
  317.         FileOutputFormat.setOutputPath(job, outDir);  
  318.            
  319.         job.setMapperClass(GroupMapper.class);  
  320.         job.setReducerClass(FPReducer.class);  
  321.            
  322.         job.setInputFormatClass(TextInputFormat.class);  
  323.         job.setOutputFormatClass(TextOutputFormat.class);  
  324.         job.setMapOutputKeyClass(IntWritable.class);  
  325.         job.setMapOutputValueClass(Record.class);  
  326.         job.setOutputKeyClass(IntWritable.class);  
  327.         job.setOutputKeyClass(Text.class);  
  328.            
  329.         boolean success=job.waitForCompletion(true);  
  330.            
  331.         job=new Job(conf);  
  332.         job.setJarByClass(DC_FPTree.class);  
  333.            
  334.         FileInputFormat.setInputPaths(job, "/user/orisun/output/part-r-*");  
  335.         Path outDir2=new Path("/user/orisun/output2");  
  336.         fs.delete(outDir2,true);  
  337.         FileOutputFormat.setOutputPath(job, outDir2);  
  338.            
  339.         job.setMapperClass(InverseMapper.class);  
  340.         job.setReducerClass(MaxReducer.class);  
  341.         //job.setNumReduceTasks(0);  
  342.            
  343.         job.setInputFormatClass(TextInputFormat.class);  
  344.         job.setOutputFormatClass(TextOutputFormat.class);  
  345.         job.setMapOutputKeyClass(Record.class);  
  346.         job.setMapOutputValueClass(IntWritable.class);  
  347.         job.setOutputKeyClass(IntWritable.class);  
  348.         job.setOutputKeyClass(Record.class);  
  349.            
  350.         success |= job.waitForCompletion(true);  
  351.            
  352.         return success?0:1;  
  353.     }  
  354.    
  355.     public static void main(String[] args) throws Exception{  
  356.         int res=ToolRunner.run(new Configuration(), new DC_FPTree(), args);  
  357.         System.exit(res);  
  358.     }  
  359. }  

結束語

在實踐中,關聯規則挖掘可能並不像人們期望的那麼有用。一方面是因爲支持度置信度框架會產生過多的規則,並不是每一個規則都是有用的。另一方面大部分的關聯規則並不像“啤酒與尿布”這種經典故事這麼普遍。關聯規則分析是需要技巧的,有時需要用更嚴格的統計學知識來控制規則的增殖。 

原文來自:博客園(華夏35度)http://www.cnblogs.com/zhangchaoyang 作者:Orisun


發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章