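Collaborative Filtering Recommendation Algorithm (Plain Java JDK Implementation, with Source Code Links)
I. Project Requirements
1. Requirement Link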
https://tianchi.aliyun.com/getStart/information.htm?raceId=231522
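2. Requirement Details
Competition Task
In real business scenarios, we often need to build a personalized recommendation model for a subset of all items. To do so, we need not only users' behavior data on that item subset but often richer user behavior data as well. Define the following symbols:
U: the set of users
I: the full set of items
P: the item subset, P ⊆ I
D: the users' behavior data over the full item set
Our goal is then to use D to build a model that recommends items in P to users in U.
Data Description
The competition provides the complete behavior data of 20,000 users and information on millions of items. The data has two parts.
The first part is the users' mobile behavior data over the full item set (D), in the table tianchi_fresh_comp_train_user_2w, with the following fields: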
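| Field | Description | Extraction note |
|---|---|---|
| user_id | User identifier | Sampled & anonymized |
| item_id | Item identifier | Anonymized |
| behavior_type | The user's behavior type on the item | Browse, favorite, add to cart, and purchase, with values 1, 2, 3, and 4 respectively |
| user_geohash | Spatial identifier of the user's location; may be empty | Generated from latitude and longitude by a confidential algorithm |
| item_category | Item category identifier | Anonymized |
| time | Behavior time | Accurate to the hour |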
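The second part is the item subset (P), in the table tianchi_fresh_comp_train_item_2w, with the following fields:
| Field | Description | Extraction note |
|---|---|---|
| item_id | Item identifier | Sampled & anonymized |
| item_geohash | Spatial identifier of the item's location; may be empty | Generated from latitude and longitude by a confidential algorithm |
| item_category | Item category identifier | Anonymized |
The training data contains the mobile behavior data (D) of a sample of users over one month (11.18 to 12.18); the scoring data is these users' purchases of items in the subset (P) on the following day (12.19). Participants must build a recommendation model from the training data and output predictions of which items in the subset each user will buy on that next day.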
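Submission Format
After predicting users' purchases of the item subset, participants write the results into a data table (non-partitioned) in the specified format: the result file must be named tianchi_mobile_recommendation_predict.csv, be encoded in UTF-8, and contain two columns, user_id and item_id (both strings), with duplicate rows removed.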
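The submission rules (UTF-8, two string columns, no duplicate rows) can be met with a small helper like the one below. It is a minimal sketch: the class name `SubmissionWriter` is hypothetical, and the `user_id,item_id` header line is an assumption based on the column names listed above.

```java
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.ArrayList;
import java.util.Collection;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;

// Hypothetical helper (not part of the original project): formats recommendations
// as the user_id,item_id submission file.
class SubmissionWriter {

    // Builds the file contents: an assumed "user_id,item_id" header, then one
    // "user_id,item_id" line per distinct pair, in first-seen order.
    static List<String> toLines(Map<String, ? extends Collection<String>> recommendations) {
        Set<String> pairs = new LinkedHashSet<>(); // removes duplicate pairs, keeps order
        for (Map.Entry<String, ? extends Collection<String>> e : recommendations.entrySet()) {
            for (String itemId : e.getValue()) {
                pairs.add(e.getKey() + "," + itemId);
            }
        }
        List<String> lines = new ArrayList<>();
        lines.add("user_id,item_id");
        lines.addAll(pairs);
        return lines;
    }

    // Writes the lines to path, encoded as UTF-8 as the rules require.
    static void write(Path path, Map<String, ? extends Collection<String>> recommendations)
            throws IOException {
        Files.write(path, toLines(recommendations), StandardCharsets.UTF_8);
    }
}
```

For example, passing the `Map<String, Set<String>>` returned by `DataProcess.predict` together with `Paths.get("tianchi_mobile_recommendation_predict.csv")` writes the whole file in one call.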
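Evaluation Metrics
The competition uses the classic precision, recall, and F1 score as evaluation metrics, computed as follows:
$$\text{precision} = \frac{|\text{PredictionSet} \cap \text{ReferenceSet}|}{|\text{PredictionSet}|}$$
$$\text{recall} = \frac{|\text{PredictionSet} \cap \text{ReferenceSet}|}{|\text{ReferenceSet}|}$$
$$F_1 = \frac{2 \times \text{precision} \times \text{recall}}{\text{precision} + \text{recall}}$$
where PredictionSet is the set of purchases predicted by the algorithm and ReferenceSet is the set of actual purchases. F1 is the sole final evaluation criterion.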
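As a cross-check of these definitions, the three metrics can be computed over sets of `user_id,item_id` pairs as below. This sketch (hypothetical class `F1Metric`) returns fractions in [0, 1] rather than the percentages printed by `DataProcess.prediction` later in this post, and it returns 0 instead of dividing by zero when a set is empty.

```java
import java.util.HashSet;
import java.util.Set;

// Sketch: precision, recall and F1 over sets of "user_id,item_id" pairs.
class F1Metric {

    // Returns {precision, recall, f1}; a metric is 0 when its denominator would be 0.
    static double[] evaluate(Set<String> predictionSet, Set<String> referenceSet) {
        Set<String> hits = new HashSet<>(predictionSet);
        hits.retainAll(referenceSet); // PredictionSet ∩ ReferenceSet
        double precision = predictionSet.isEmpty() ? 0.0 : (double) hits.size() / predictionSet.size();
        double recall = referenceSet.isEmpty() ? 0.0 : (double) hits.size() / referenceSet.size();
        double f1 = (precision + recall) == 0.0 ? 0.0 : 2 * precision * recall / (precision + recall);
        return new double[] {precision, recall, f1};
    }
}
```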
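II. Principles and Workflow of Collaborative Filtering Recommendation
1. User-Based Collaborative Filtering
User-based collaborative filtering finds neighbor users whose ratings resemble the target user's, looks at the items those neighbors like, and infers that the target user shares those preferences. The basic idea is: use the user-item rating matrix to find the current user's nearest neighbors, use the neighbors' ratings to predict the current user's rating for each item, and recommend the N items with the highest predicted ratings. Here "items" are the products the system handles. The workflow is shown in Figure 1.
Figure 1. Workflow of user-based collaborative filtering
The user-based collaborative filtering workflow is:
1). Build the user-item rating matrix
$R = \{r_1, r_2, \ldots, r_m\}^T$ is the $m \times n$ user rating matrix, where $r_i = \{r_{i,1}, r_{i,2}, \ldots, r_{i,n}\}$ is the rating vector of user $u_i$ and $r_{i,j}$ is user $u_i$'s rating of item $j$.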
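At this data size (20,000 users and millions of items) a dense m×n array is impractical, so R is usually stored sparsely; `DataProcess.makeScoreTable` later in this post builds exactly such a `Map<String, Map<String, Double>>`. The hypothetical class `RatingMatrix` is a minimal sketch of that representation, assuming an unrated cell reads as 0.

```java
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;

// Sketch: the m x n rating matrix R stored sparsely as userId -> (itemId -> r_ij).
// Only rated cells are kept.
class RatingMatrix {
    private final Map<String, Map<String, Double>> rows = new HashMap<>();

    // Sets r_ij; a later call for the same (user, item) overwrites the earlier value.
    void put(String userId, String itemId, double rating) {
        rows.computeIfAbsent(userId, k -> new HashMap<>()).put(itemId, rating);
    }

    // Returns r_ij, reading an unrated cell as 0.
    double get(String userId, String itemId) {
        Double r = row(userId).get(itemId);
        return r == null ? 0.0 : r;
    }

    // Returns the rating vector r_i of user u_i (rated items only), read-only.
    Map<String, Double> row(String userId) {
        Map<String, Double> row = rows.get(userId);
        return row == null ? Collections.<String, Double>emptyMap() : Collections.unmodifiableMap(row);
    }

    // m: the number of users with at least one rating.
    int userCount() {
        return rows.size();
    }
}
```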
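2). Compute user similarity
User-based collaborative filtering needs to find users similar to the target user. Measuring the similarity between users means comparing each user's ratings, i.e. that user's row in the rating matrix, with every other user's. Each user's ratings over the items can be viewed as an n-dimensional rating vector. The similarity sim(i, j) between the target user $u_i$ and another user $u_j$ is computed from their rating vectors; the three common measures are cosine similarity, adjusted cosine similarity, and the Pearson correlation coefficient.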
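Cosine similarity and the Pearson correlation coefficient can be written directly over sparse rating vectors. The sketch below (hypothetical class `UserSimilarity`) covers those two; adjusted cosine, which subtracts each user's mean rating before taking the cosine, is omitted. Taking the Pearson means over co-rated items only is one common variant and is assumed here. Note that the project code later in this post uses neither: `CalculateSimilarity.execute` scores users with 1/(1 + Euclidean distance).

```java
import java.util.Map;

// Sketch of two of the measures named above, on sparse rating vectors (itemId -> rating).
class UserSimilarity {

    // Cosine similarity: r_i . r_j / (|r_i| |r_j|); an item missing from a vector counts as 0.
    static double cosine(Map<String, Double> ri, Map<String, Double> rj) {
        double dot = 0, normI = 0, normJ = 0;
        for (Map.Entry<String, Double> e : ri.entrySet()) {
            normI += e.getValue() * e.getValue();
            Double other = rj.get(e.getKey());
            if (other != null) {
                dot += e.getValue() * other; // only co-rated items contribute
            }
        }
        for (double v : rj.values()) {
            normJ += v * v;
        }
        if (normI == 0 || normJ == 0) {
            return 0.0; // undefined for an all-zero vector
        }
        return dot / (Math.sqrt(normI) * Math.sqrt(normJ));
    }

    // Pearson correlation over co-rated items, each user's mean taken over those items.
    static double pearson(Map<String, Double> ri, Map<String, Double> rj) {
        int n = 0;
        double sumI = 0, sumJ = 0;
        for (Map.Entry<String, Double> e : ri.entrySet()) {
            Double other = rj.get(e.getKey());
            if (other != null) {
                n++;
                sumI += e.getValue();
                sumJ += other;
            }
        }
        if (n == 0) {
            return 0.0; // no co-rated items
        }
        double meanI = sumI / n, meanJ = sumJ / n;
        double num = 0, varI = 0, varJ = 0;
        for (Map.Entry<String, Double> e : ri.entrySet()) {
            Double other = rj.get(e.getKey());
            if (other != null) {
                double di = e.getValue() - meanI, dj = other - meanJ;
                num += di * dj;
                varI += di * di;
                varJ += dj * dj;
            }
        }
        if (varI == 0 || varJ == 0) {
            return 0.0; // a constant rating vector has no defined correlation
        }
        return num / Math.sqrt(varI * varJ);
    }
}
```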
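3). Build the nearest-neighbor set
The nearest-neighbor set Neighbor(u) contains the other users whose interests resemble the target user's. To select neighbors, we first compute the similarity sim(u, v) between the target user u and every other user v, then pick the k users with the largest similarity. User similarity can be read as the trust value or recommendation weight between two users. Typically sim(u, v) ∈ [-1, 1]. A similarity of 1 means the two users carry a large recommendation weight for each other; a similarity of -1 means their interests differ greatly, so their recommendation weight for each other is small.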
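Selecting Neighbor(u) only needs the k largest similarities, so a size-k min-heap avoids sorting every candidate. A minimal sketch with a hypothetical `NearestNeighbors.topK`, assuming the similarities to u are already in a map. It skips u itself, which matters for this project because `CalculateSimilarity.execute` also scores each user against itself (a Euclidean score of 1.0).

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.PriorityQueue;

// Sketch: pick Neighbor(u), the k users with the largest sim(u, v), excluding u itself.
class NearestNeighbors {

    // simToU: userId v -> sim(u, v). Returns at most k user ids, most similar first.
    static List<String> topK(String u, Map<String, Double> simToU, int k) {
        // min-heap on similarity: the weakest neighbor kept so far sits on top
        PriorityQueue<Map.Entry<String, Double>> heap =
                new PriorityQueue<Map.Entry<String, Double>>(
                        (a, b) -> Double.compare(a.getValue(), b.getValue()));
        for (Map.Entry<String, Double> e : simToU.entrySet()) {
            if (e.getKey().equals(u)) {
                continue; // a user is not its own neighbor
            }
            heap.offer(e);
            if (heap.size() > k) {
                heap.poll(); // drop the least similar candidate
            }
        }
        List<String> result = new ArrayList<>();
        while (!heap.isEmpty()) {
            result.add(heap.poll().getKey()); // ascending similarity
        }
        Collections.reverse(result); // most similar first
        return result;
    }
}
```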
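4). Compute the predicted rating
The predicted rating p(a, i) of user a for item i is a weighted combination of the neighbors' ratings of that item. Different neighbors influence the target user to different degrees, so each neighbor gets its own weight; we use user similarity as that weight factor.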
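One common form consistent with this description, assumed here, is the similarity-weighted average $p(a,i) = \frac{\sum_{u \in N(a)} sim(a,u)\, r_{u,i}}{\sum_{u \in N(a)} |sim(a,u)|}$, where N(a) is the set of a's neighbors that rated i (mean-centered variants also exist). A minimal sketch with a hypothetical class:

```java
import java.util.List;
import java.util.Map;

// Sketch of the assumed similarity-weighted average:
// p(a,i) = sum sim(a,u) * r_ui / sum |sim(a,u)| over neighbors u that rated item i.
class UserBasedPredictor {

    // neighbors: Neighbor(a); simToA: u -> sim(a, u); ratings: u -> (itemId -> r_ui)
    static double predict(String itemId,
                          List<String> neighbors,
                          Map<String, Double> simToA,
                          Map<String, Map<String, Double>> ratings) {
        double num = 0, den = 0;
        for (String u : neighbors) {
            Map<String, Double> row = ratings.get(u);
            if (row == null || !row.containsKey(itemId)) {
                continue; // this neighbor did not rate the item
            }
            double sim = simToA.getOrDefault(u, 0.0);
            num += sim * row.get(itemId);
            den += Math.abs(sim);
        }
        return den == 0 ? 0.0 : num / den; // 0 when no neighbor rated the item
    }
}
```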
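The implementation steps of user-based collaborative filtering are:
1). Tally each user's score for each item to produce the user-item table (i.e. build the user-item rating matrix);
2). Compute the similarity between every pair of users to produce the user-user score table, and sort it;
3). Sort each user's item set by score;
4). For the user who will receive recommendations, select the N users most similar to that user from the user-user score table, and from the user-item table take the top M items of each of those N users' sorted item sets;
5). These N*M items (fewer if some overlap) form the recommendation set for that user.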
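Steps 4 and 5 can be sketched over in-memory maps as below. The sketch follows the loop structure of `DataProcess.predict` further down (skip users with no item list, stop after N users), but the class name `NeighborItemRecommender` and its signature are hypothetical. Because a `LinkedHashSet` collapses items shared by several neighbors, the result can hold fewer than N*M items.

```java
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;

// Sketch of steps 4-5: the union of the top-M items of the target user's N most similar users.
class NeighborItemRecommender {

    // similarUsers: other users, sorted by descending similarity to the target user
    // sortedItems:  userId -> that user's items, sorted by descending score (step 3)
    static Set<String> recommend(List<String> similarUsers,
                                 Map<String, List<String>> sortedItems,
                                 int n, int m) {
        Set<String> result = new LinkedHashSet<>(); // duplicates collapse, so size <= n * m
        int usedUsers = 0;
        for (String v : similarUsers) {
            if (usedUsers == n) {
                break;
            }
            List<String> items = sortedItems.get(v);
            if (items == null) {
                continue; // no item list for this user; it does not count toward n
            }
            for (int j = 0; j < Math.min(m, items.size()); j++) {
                result.add(items.get(j));
            }
            usedUsers++;
        }
        return result;
    }
}
```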
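2. Item-Based Collaborative Filtering
Item-based collaborative filtering uses the user-item rating matrix to measure how similar items are based on their ratings, and takes the n items most similar to the target item as its nearest-neighbor set. It then predicts the rating of the current item by weighting the target item's similar neighbors, sorts the resulting predicted ratings, and recommends the N items with the highest predicted ratings to the current user. Here too, "items" are the products the system handles. The workflow is shown in Figure 2.
Figure 2. Workflow of item-based collaborative filtering
The item-based collaborative filtering workflow is:
First, read the target user's set of rated items $I_u$; then compute the similarity between item i and the other items in $I_u$, and select the k nearest neighbors; compute the predicted rating of every item in the candidate set with the similarity-weighted rating formula; finally, recommend the N items with the highest predicted ratings to the user.
In item-based collaborative filtering, the predicted rating is a weighted combination of the user's ratings of items similar to i. The user's previously rated items relate to the current item i to different degrees, so each gets its own weight: the prediction function p(u, i) uses item similarity as the weight factor.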
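The item-based analogue, again assuming the plain similarity-weighted average, is $p(u,i) = \frac{\sum_{j \in I_u \cap N(i)} sim(i,j)\, r_{u,j}}{\sum_{j \in I_u \cap N(i)} |sim(i,j)|}$, where N(i) is the set of item i's k nearest neighbor items. A minimal sketch with a hypothetical class:

```java
import java.util.Map;

// Sketch of the assumed item-based prediction:
// p(u,i) = sum sim(i,j) * r_uj / sum |sim(i,j)| over neighbor items j that user u rated.
class ItemBasedPredictor {

    // userRatings: itemId j -> r_uj for the items in I_u
    // simToI:      itemId j -> sim(i, j) for item i's k nearest neighbor items N(i)
    static double predict(Map<String, Double> userRatings, Map<String, Double> simToI) {
        double num = 0, den = 0;
        for (Map.Entry<String, Double> e : simToI.entrySet()) {
            Double r = userRatings.get(e.getKey());
            if (r == null) {
                continue; // the user has not rated this neighbor item
            }
            num += e.getValue() * r;
            den += Math.abs(e.getValue());
        }
        return den == 0 ? 0.0 : num / den; // 0 when I_u and N(i) do not overlap
    }
}
```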
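The implementation steps of item-based collaborative filtering are:
1). Tally each user's score for each item to produce the user-item table (i.e. build the user-item rating matrix);
2). Compute the similarity between every pair of items to produce the item-item score table, and sort it;
3). Sort each user's item set by score;
4). For the user who will receive recommendations, use the item-item table to select the N items most similar to the items that user has already chosen;
5). These N items form the recommendation set for that user.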
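Steps 4 and 5 leave open how to combine similarities when a candidate is close to several of the user's items. The sketch below assumes the similarities are summed and that already-chosen items are excluded; both are illustrative choices, and the class `ItemBasedTopN` is hypothetical.

```java
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;

// Sketch of item-based steps 4-5. Assumptions: a candidate's score is the sum of its
// similarities to the user's chosen items, and already-chosen items are not recommended.
class ItemBasedTopN {

    // itemSim: itemId -> (neighbor itemId -> similarity), i.e. the item-item table
    static List<String> recommend(Set<String> chosenItems,
                                  Map<String, Map<String, Double>> itemSim,
                                  int n) {
        Map<String, Double> candidateScore = new HashMap<>();
        for (String chosen : chosenItems) {
            Map<String, Double> neighbors = itemSim.get(chosen);
            if (neighbors == null) {
                continue; // no similarity row for this item
            }
            for (Map.Entry<String, Double> e : neighbors.entrySet()) {
                if (!chosenItems.contains(e.getKey())) {
                    candidateScore.merge(e.getKey(), e.getValue(), Double::sum);
                }
            }
        }
        List<Map.Entry<String, Double>> ranked = new ArrayList<>(candidateScore.entrySet());
        ranked.sort((a, b) -> Double.compare(b.getValue(), a.getValue())); // highest first
        List<String> result = new ArrayList<>();
        for (int i = 0; i < Math.min(n, ranked.size()); i++) {
            result.add(ranked.get(i).getKey());
        }
        return result;
    }
}
```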
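3. Comparison of User-Based and Item-Based Collaborative Filtering
User-based collaborative filtering:
Helps users discover new items, but requires relatively complex online computation and has to deal with the new-user problem.
Item-based collaborative filtering:
Accurate, stable, and controllable, and convenient to compute offline, but its recommendations are less diverse and seldom surprise the user.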
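III. Project Implementation
For mobile recommendation, we chose to implement user-based collaborative filtering.
1. Data Model and Entity Classes
User behavior data (user.csv):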
user_id,item_id,behavior_type,user_geohash,item_category,time
10001082,285259775,1,97lk14c,4076,2014-12-08 18
10001082,4368907,1,,5503,2014-12-12 12
10001082,4368907,1,,5503,2014-12-12 12
10001082,53616768,1,,9762,2014-12-02 15
10001082,151466952,1,,5232,2014-12-12 11
10001082,53616768,4,,9762,2014-12-02 15
10001082,290088061,1,,5503,2014-12-12 12
10001082,298397524,1,,10894,2014-12-12 12
10001082,32104252,1,,6513,2014-12-12 12
10001082,323339743,1,,10894,2014-12-12 12
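Because user_geohash is empty in most of the rows above, a parser must not assume every field is filled. A minimal sketch of parsing one row, assuming exactly six comma-separated fields; the class `UserCsvLine` is hypothetical and is not the project's `ParseTool.parseUser`, which instead switches which fields it reads through commented-out blocks.

```java
// Sketch: one parsed row of user.csv
// (user_id,item_id,behavior_type,user_geohash,item_category,time).
class UserCsvLine {
    // behavior_type values 1..4, in the order given in the field table
    static final String[] BEHAVIORS = {"browse", "favorite", "add-to-cart", "purchase"};

    final String userId;
    final String itemId;
    final int behaviorType;
    final String userGeohash; // null when the field is empty
    final String itemCategory;
    final String time;        // "yyyy-MM-dd HH", hour precision

    UserCsvLine(String line) {
        // limit -1 keeps trailing empty fields that split(",") would silently drop
        String[] f = line.split(",", -1);
        if (f.length != 6) {
            throw new IllegalArgumentException("expected 6 fields: " + line);
        }
        userId = f[0].trim();
        itemId = f[1].trim();
        behaviorType = Integer.parseInt(f[2].trim());
        if (behaviorType < 1 || behaviorType > 4) {
            throw new IllegalArgumentException("behavior_type out of range: " + line);
        }
        userGeohash = f[3].isEmpty() ? null : f[3].trim();
        itemCategory = f[4].trim();
        time = f[5].trim();
    }

    String behaviorName() {
        return BEHAVIORS[behaviorType - 1];
    }
}
```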
Item information (item.csv):
item_id,item_geohash,item_category
100002303,,3368
100003592,,7995
100006838,,12630
100008089,,7791
100012750,,9614
100014072,,1032
100014463,,9023
100019387,,3064
100023812,,6700

package entity;
public class Item {
private String itemId;
private String itemGeoHash;
private String itemCategory;
public String getItemId() {
return itemId;
}
public void setItemId(String itemId) {
this.itemId = itemId;
}
public String getItemGeoHash() {
return itemGeoHash;
}
public void setItemGeoHash(String itemGeoHash) {
this.itemGeoHash = itemGeoHash;
}
public String getItemCategory() {
return itemCategory;
}
public void setItemCategory(String itemCategory) {
this.itemCategory = itemCategory;
}
@Override
public String toString() {
return "item [itemId=" + itemId + ", itemGeoHash=" + itemGeoHash
+ ", itemCategory=" + itemCategory + "]";
}
}
package entity;
public class Score implements Comparable<Score> {
private String userId; // user identifier
private String itemId; // item identifier
private double score;
public String getUserId() {
return userId;
}
public void setUserId(String userId) {
this.userId = userId;
}
public String getItemId() {
return itemId;
}
public void setItemId(String itemId) {
this.itemId = itemId;
}
public double getScore() {
return score;
}
public void setScore(double score) {
this.score = score;
}
@Override
public String toString() {
return "Score [userId=" + userId + ", itemId=" + itemId + ", score="
+ score + "]";
}
@Override
public int compareTo(Score o) {
// descending order by score, so Collections.sort puts the highest score first
return Double.compare(o.score, this.score);
}
}
package entity;
public class User implements Comparable<User> {
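private String userId; // user identifier
private String itemId; // item identifier
private int behaviorType; // behavior type: 1 = browse, 2 = favorite, 3 = add to cart, 4 = purchase
private String userGeoHash; // spatial identifier of the user's location; may be empty
private String itemCategory;// item category identifier
private String time; // behavior time, to the hour
private int count; // number of behavior records for this (user, item) pair
private double weight; // weight (score) of this (user, item) pair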
public String getUserId() {
return userId;
}
public void setUserId(String userId) {
this.userId = userId;
}
public String getItemId() {
return itemId;
}
public void setItemId(String itemId) {
this.itemId = itemId;
}
public int getBehaviorType() {
return behaviorType;
}
public void setBehaviorType(int behaviorType) {
this.behaviorType = behaviorType;
}
public String getUserGeoHash() {
return userGeoHash;
}
public void setUserGeoHash(String userGeoHash) {
this.userGeoHash = userGeoHash;
}
public String getItemCategory() {
return itemCategory;
}
public void setItemCategory(String itemCategory) {
this.itemCategory = itemCategory;
}
public String getTime() {
return time;
}
public void setTime(String time) {
this.time = time;
}
@Override
public String toString() {
return "User [userId=" + userId + ", itemId=" + itemId
+ ", behaviorType=" + behaviorType + ", count=" + count + "]";
}
public int getCount() {
return count;
}
public void setCount(int count) {
this.count = count;
}
public double getWeight() {
return weight;
}
public void setWeight(double weight) {
this.weight = weight;
}
@Override
public int compareTo(User o) {
// descending by weight; Double.compare does not truncate fractional differences to 0
return Double.compare(o.weight, this.weight);
}
}
2. Utility Classes
File utilities:
package util;
import java.io.BufferedReader;
import java.io.File;
import java.io.FileNotFoundException;
import java.io.FileOutputStream;
import java.io.FileReader;
import java.io.IOException;
import java.io.PrintStream;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedHashMap;
import java.util.LinkedHashSet;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Set;
import entity.Item;
import entity.Score;
import entity.User;
public class FileTool {
public static FileReader fr=null;
public static BufferedReader br=null;
public static String line=null;
public static FileOutputStream fos1 = null,fos2 = null,fos3 = null;
public static PrintStream ps1 = null,ps2 = null,ps3 = null;
public static int count = 0;
/**
* Initialize the writer (single output stream)
* */
public static void initWriter1(String writePath) {
try {
fos1 = new FileOutputStream(writePath);
ps1 = new PrintStream(fos1);
} catch (FileNotFoundException e) {
e.printStackTrace();
}
}
/**
* Close the reader
* */
public static void closeRedaer() {
try {
br.close();
fr.close();
} catch (IOException e) {
e.printStackTrace();
}
}
/**
* Close the writer (single output stream)
* */
public static void closeWriter1() {
try {
ps1.close();
fos1.close();
} catch (IOException e) {
e.printStackTrace();
}
}
/**
* Initialize the writers (two output streams)
* */
public static void initWriter2(String writePath1,String writePath2) {
try {
fos1 = new FileOutputStream(writePath1);
ps1 = new PrintStream(fos1);
fos2 = new FileOutputStream(writePath2);
ps2 = new PrintStream(fos2);
} catch (FileNotFoundException e) {
e.printStackTrace();
}
}
/**
* Close the writers (two output streams)
* */
public static void closeWriter2() {
try {
ps1.close();
fos1.close();
ps2.close();
fos2.close();
} catch (IOException e) {
e.printStackTrace();
}
}
/**
* Initialize the writers (three output streams)
* */
public static void initWriter3(String writePath1,String writePath2,String writePath3) {
try {
fos1 = new FileOutputStream(writePath1);
ps1 = new PrintStream(fos1);
fos2 = new FileOutputStream(writePath2);
ps2 = new PrintStream(fos2);
fos3 = new FileOutputStream(writePath3);
ps3 = new PrintStream(fos3);
} catch (FileNotFoundException e) {
e.printStackTrace();
}
}
/**
* Close the writers (three output streams)
* */
public static void closeWriter3() {
try {
ps1.close();
fos1.close();
ps2.close();
fos2.close();
ps3.close();
fos3.close();
} catch (IOException e) {
e.printStackTrace();
}
}
public static List readFileOne(String path,boolean isTitle,String token,String pattern) throws Exception {
List<Object> ret = new ArrayList<Object>();
fr = new FileReader(path);
br = new BufferedReader(fr);
int count = 0,i = 0;
if (isTitle) {
line = br.readLine();
count++;
}
while((line = br.readLine()) != null){
String[] strArr = line.split(token);
switch (pattern) {
case "item":
ret.add(ParseTool.parseItem(strArr));
break;
case "user":
ret.add(ParseTool.parseUser(strArr));
break;
case "score":
ret.add(ParseTool.parseScore(strArr));
break; // without this, "score" falls through to default and also adds the raw line
default:
ret.add(line);
break;
}
count++;
if (count/100000 == 1) {
i++;
System.out.println(100000*i);
count = 0;
}
}
closeRedaer();
return ret;
}
public static void makeSampleData(String inputPath,boolean isTitle,String outputPath,int threshold) throws Exception {
fr = new FileReader(inputPath);
br = new BufferedReader(fr);
initWriter1(outputPath);
if (isTitle) {
line = br.readLine();
}
int count = 0;
while((line = br.readLine()) != null){
ps1.println(line);
count++;
if (count == threshold) {
break;
}
}
closeRedaer();
closeWriter1(); // flush and close the sample file
}
public static List<String> traverseFolder(String dir) {
File file = new File(dir);
String[] fileList = null;
if (file.exists()) {
fileList = file.list();
}
List<String> list = new ArrayList<String>();
if (fileList == null) {
return list; // missing or unreadable directory: nothing to traverse
}
for(String path : fileList){
list.add(path);
}
return list;
}
public static Map<String, List<Score>> loadScoreMap(String path,boolean isTitle,String token) throws Exception {
fr = new FileReader(path);
br = new BufferedReader(fr);
if (isTitle) {
line = br.readLine();
}
Map<String, List<Score>> scoreMap = new HashMap<String, List<Score>>();
while((line = br.readLine()) != null){
String[] arr = line.split(token);
Score score = ParseTool.parseScore(arr);
List<Score> temp = new ArrayList<Score>();
if (scoreMap.containsKey(score.getUserId())) {
temp = scoreMap.get(score.getUserId());
}
temp.add(score);
scoreMap.put(score.getUserId(), temp);
}
closeRedaer();
return scoreMap;
}
public static Map<String, List<String>> loadPredictData(String path,boolean isTitle,String token) throws Exception {
fr = new FileReader(path);
br = new BufferedReader(fr);
if (isTitle) {
line = br.readLine();
}
Map<String, List<String>> map = new HashMap<String, List<String>>();
while((line = br.readLine()) != null){
String[] arr = line.split(token);
String userId = arr[0];
String itemId = arr[1];
List<String> temp = new ArrayList<String>();
if (map.containsKey(userId)) {
temp = map.get(userId);
}
temp.add(itemId);
map.put(userId, temp);
count++;
}
closeRedaer();
return map;
}
public static Map<String, List<String>> loadTestData(Map<String, List<String>> predictMap, String dir, boolean isTitle, String token) throws Exception {
List<String> fileList = traverseFolder(dir);
Set<String> predictKeySet = predictMap.keySet();
Map<String, List<String>> testMap = new HashMap<String, List<String>>();
for(String predictKey : predictKeySet){
if (fileList.contains(predictKey)) {
List<String> itemList = loadTestData(dir + predictKey, isTitle, token);
testMap.put(predictKey, itemList);
}
}
return testMap;
}
public static List<String> loadTestData(String path, boolean isTitle, String token) throws Exception {
fr = new FileReader(path);
br = new BufferedReader(fr);
if (isTitle) {
line = br.readLine();
}
List<String> list = new ArrayList<String>();
Set<String> set = new HashSet<String>();
while((line = br.readLine()) != null){
String[] arr = line.split(token);
set.add(arr[1]);
}
count += set.size(); // count distinct reference items (|ReferenceSet|), not raw lines
closeRedaer();
for(String item : set){
list.add(item);
}
return list;
}
public static Map<String, Double> loadUser_ItemData(String path,boolean isTitle,String token) throws Exception {
fr = new FileReader(path);
br = new BufferedReader(fr);
if (isTitle) {
line = br.readLine();
}
Map<String, Double> map = new HashMap<String, Double>();
while((line = br.readLine()) != null){
String[] arr = line.split(token);
String itemId = arr[1];
double score = Double.valueOf(arr[2]);
if(map.containsKey(itemId)){
double temp = map.get(itemId);
if (temp > score) {
score = temp;
}
}
map.put(itemId, score);
}
closeRedaer();
return map;
}
public static Map<String, Set<String>> loadTestUser(String path,boolean isTitle,String token) throws Exception {
fr = new FileReader(path);
br = new BufferedReader(fr);
int count = 0,i = 0;
if (isTitle) {
line = br.readLine();
count++;
}
Map<String, Set<String>> map = new HashMap<String, Set<String>>();
while((line = br.readLine()) != null){
String[] arr = line.split(token);
String userId = arr[0];
String itemId = arr[1];
Set<String> set = map.get(userId);
if (set == null) {
set = new HashSet<String>();
map.put(userId, set);
}
set.add(itemId); // also record the first item seen for each user
count++;
if (count/100000 == 1) {
i++;
System.out.println(100000*i);
count = 0;
}
}
closeRedaer();
return map;
}
}
Parsing utilities:
package util;
import entity.Item;
import entity.Score;
import entity.User;
public class ParseTool {
public static boolean isNumber(String str) {
int i,n;
n = str.length();
for(i = 0;i < n;i++){
if (!Character.isDigit(str.charAt(i))) {
return false;
}
}
return true;
}
public static Item parseItem(String[] contents) {
Item item = new Item();
if (contents[0] != null && !contents[0].isEmpty()) {
item.setItemId(contents[0].trim());
}
if (contents[1] != null && !contents[1].isEmpty()) {
item.setItemGeoHash(contents[1].trim());
}
if (contents[2] != null && !contents[2].isEmpty()) {
item.setItemCategory(contents[2].trim());
}
return item;
}
public static User parseUser(String[] contents) {
User user = new User();
int n = contents.length;
if (contents[0] != null && !contents[0].isEmpty()) {
user.setUserId(contents[0].trim());
}
if (contents[1] != null && !contents[1].isEmpty()) {
user.setItemId(contents[1].trim());
}
/*
// 2. Uncomment when running CountFileTest; keep commented out otherwise
if (contents[2] != null && !contents[2].isEmpty()) {
user.setBehaviorType(Integer.valueOf(contents[2].trim()));
}
// 2. Uncomment when running CountFileTest; keep commented out otherwise
if (contents[n-1] != null && !contents[n-1].isEmpty()) {
user.setCount(Integer.valueOf(contents[n-1].trim()));
}
*/
// 3. Enable when running PredictTest; comment out otherwise
if (contents[n-1] != null && !contents[n-1].isEmpty()) {
user.setWeight(Double.valueOf(contents[n-1].trim()));
}
/*
// 1. Uncomment when running SpliteFileAndMakeScoreTable; keep commented out otherwise
if (contents[3] != null && !contents[3].isEmpty()) {
user.setUserGeoHash(contents[3].trim());
}
if (contents[4] != null && !contents[4].isEmpty()) {
user.setItemCategory(contents[4].trim());
}
if (contents[5] != null && !contents[5].isEmpty()) {
user.setTime(contents[5].trim());
}
*/
return user;
}
public static Score parseScore(String[] contents) {
Score score = new Score();
if (contents[0] != null && !contents[0].isEmpty()) {
score.setUserId(contents[0].trim());
}
if (contents[1] != null && !contents[1].isEmpty()) {
score.setItemId(contents[1].trim());
}
if (contents[2] != null && !contents[2].isEmpty()) {
score.setScore(Double.parseDouble(contents[2].trim()));
}
return score;
}
}
3. Data Processing Module:
package service;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.Map.Entry;
import util.FileTool;
import entity.Item;
import entity.Score;
import entity.User;
public class DataProcess {
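// base weight per behavior type (index = type - 1): browse = 0, favorite = 10, add to cart = 20, purchase = 30
public static final double[] w = {0,10,20,30};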
public static void output(Map<String, Map<String, List<User>>> userMap,String outputPath) {
for(Entry<String, Map<String, List<User>>> entry : userMap.entrySet()){
FileTool.initWriter1(outputPath + entry.getKey());
Map<String, List<User>> temp = entry.getValue();
for(Entry<String, List<User>> tempEntry : temp.entrySet()){
List<User> users = tempEntry.getValue();
int count = users.size();
for(User user : users){
FileTool.ps1.print(user.getUserId() + "\t");
FileTool.ps1.print(user.getItemId() + "\t");
FileTool.ps1.print(user.getBehaviorType() + "\t");
//FileTool.ps1.print(user.getUserGeoHash() + "\t");
//FileTool.ps1.print(user.getItemCategory() + "\t");
//FileTool.ps1.print(user.getTime() + "\t");
FileTool.ps1.print(count + "\n");
}
}
FileTool.closeWriter1(); // close each per-user file before the next one is opened
}
}
public static void output(Map<String, Map<String, Double>> scoreTable, String outputPath, Set<String> userSet, Set<String> itemSet, String token) {
FileTool.initWriter1(outputPath);
for(String itemId: itemSet){
FileTool.ps1.print(token + itemId);
}
FileTool.ps1.println();
for(String userId : userSet){
FileTool.ps1.print(userId + token);
Map<String, Double> itemMap = scoreTable.get(userId);
for(String itemId: itemSet){
if(itemMap.containsKey(itemId)){
FileTool.ps1.print(itemMap.get(itemId));
}else {
//FileTool.ps1.print(0);
}
FileTool.ps1.print(token);
}
FileTool.ps1.print("\n");
}
FileTool.closeWriter1();
}
public static void outputUser(List<User> userList) {
for(User user : userList){
FileTool.ps1.println(user.getUserId() + "\t" + user.getItemId() + "\t" + user.getWeight());
}
}
public static void outputScore(List<Score> scoreList) {
for(Score score : scoreList){
FileTool.ps1.println(score.getUserId() + "\t" + score.getItemId() + "\t" + score.getScore());
}
}
public static void outputRecommendList(Map<String, Set<String>> map) {
for(Entry<String, Set<String>> entry : map.entrySet()){
String userId = entry.getKey();
Set<String> itemSet = entry.getValue();
for(String itemId : itemSet){
FileTool.ps1.println(userId + "," + itemId);
}
}
}
public static void output(Map<String, Set<String>> map) {
for(Entry<String, Set<String>> entry : map.entrySet()){
String userId = entry.getKey();
Set<String> set = entry.getValue();
for(String itemId : set){
FileTool.ps1.println(userId + "\t" + itemId);
}
}
}
public static Map<String, Map<String, List<User>>> mapByUser(List<User> userList,Set<String> userSet,Set<String> itemSet) {
Map<String, Map<String, List<User>>> userMap = new HashMap<>();
for(User user: userList){
// group records as userId -> itemId -> list of behavior records
Map<String, List<User>> tempMap = userMap.get(user.getUserId());
if (tempMap == null) {
tempMap = new HashMap<String, List<User>>();
userMap.put(user.getUserId(), tempMap);
}
List<User> tempList = tempMap.get(user.getItemId());
if (tempList == null) {
tempList = new ArrayList<User>();
tempMap.put(user.getItemId(), tempList);
}
tempList.add(user);
userSet.add(user.getUserId());
itemSet.add(user.getItemId());
}
return userMap;
}
public static Map<String, Map<String, Double>> makeScoreTable(Map<String, Map<String, List<User>>> userMap) {
Map<String, Map<String, Double>> scoreTable = new HashMap<String, Map<String,Double>>();
for(Entry<String, Map<String, List<User>>> userEntry : userMap.entrySet()){
Map<String, List<User>> itemMap = userEntry.getValue();
String userId = userEntry.getKey();
Map<String, Double> itemScoreMap = new HashMap<String, Double>();
for(Entry<String, List<User>> itemEntry : itemMap.entrySet()){
String itemId = itemEntry.getKey();
List<User> users = itemEntry.getValue();
double weight = 0.0;
int maxType = 0;
for(User user : users){
if (user.getBehaviorType() > maxType) {
maxType = user.getBehaviorType();
}
}
int count = users.size();
if (maxType != 0) {
weight += w[maxType-1];
}
weight += count;
itemScoreMap.put(itemId, weight);
}
scoreTable.put(userId, itemScoreMap);
}
return scoreTable;
}
public static double calculateWeight(int behaviorType, int count) {
double weight = w[behaviorType-1] + count;
return weight;
}
public static List<User> reduceUserByItem(List<User> userList) {
// keep one record per item: the one with the highest behavior type
Map<String, User> userMap = new LinkedHashMap<String, User>();
for(User user : userList){
String itemId = user.getItemId();
User temp = userMap.get(itemId);
if (temp == null || temp.getBehaviorType() < user.getBehaviorType()) {
double weight = calculateWeight(user.getBehaviorType(), user.getCount());
user.setWeight(weight);
userMap.put(itemId, user); // replaces the lower-type record instead of adding a duplicate
}
}
return new ArrayList<User>(userMap.values());
}
public static void sortScoreMap(Map<String, List<Score>> scoreMap) {
Set<String> userSet = scoreMap.keySet();
for(String userId : userSet){
List<Score> temp = scoreMap.get(userId);
Collections.sort(temp);
scoreMap.put(userId, temp);
}
}
public static Map<String, Set<String>> predict(Map<String, List<Score>> scoreMap, List<String> fileNameList, String userDir,int topNUser,int topNItem) throws Exception {
Map<String, Set<String>> recommendList = new HashMap<String, Set<String>>();
for(Entry<String, List<Score>> entry : scoreMap.entrySet()){
String userId1 = entry.getKey();
List<Score> list = entry.getValue();
int countUser = 0;
Set<String> predictItemSet = new LinkedHashSet<String>();
for(Score score : list){
String userId2 = score.getItemId();
if(fileNameList.contains(userId2)){
List<User> userList = FileTool.readFileOne(userDir + userId2, false, "\t", "user");
int countItem = 0;
for(User user : userList){
predictItemSet.add(user.getItemId());
countItem++;
if (countItem == topNItem) {
break;
}
}
countUser++;
}
if (countUser == topNUser) {
break;
}
}
recommendList.put(userId1, predictItemSet);
}
return recommendList;
}
public static void prediction(Map<String, List<String>> predictMap,int predictN, Map<String, List<String>> referenceMap, int refN) {
int count = 0;
for(Entry<String, List<String>> predictEntity : predictMap.entrySet()){
String userId = predictEntity.getKey();
if (referenceMap.containsKey(userId)) {
List<String> predictList = predictEntity.getValue();
for(String itemId : predictList){
if (referenceMap.get(userId).contains(itemId)) {
count++;
}
}
}
}
double precision = (1.0 * count / predictN) * 100;
double recall = (1.0 * count / refN) * 100;
double f1 = (precision + recall) == 0 ? 0 : (2 * precision * recall)/(precision + recall); // avoid 0/0 when nothing matches
System.out.println("precision="+precision+",recall="+recall+",f1="+f1);
}
}
4. Computation Module
package service;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Set;
import entity.Score;
import util.FileTool;
public class CalculateSimilarity {
public static double EuclidDist(Map<String, Double> userMap1,
Map<String, Double> userMap2, Set<String> userSet,
Set<String> itemSet) {
double sum = 0;
for (String itemId : itemSet) {
double score1 = 0.0;
double score2 = 0.0;
if (userMap1.containsKey(itemId) && userMap2.containsKey(itemId)) {
score1 = userMap1.get(itemId);
score2 = userMap2.get(itemId);
} else if (userMap1.containsKey(itemId)) {
score1 = userMap1.get(itemId);
} else if (userMap2.containsKey(itemId)) {
score2 = userMap2.get(itemId);
}
double temp = Math.pow((score1 - score2), 2);
sum += temp;
}
sum = Math.sqrt(sum);
return sum;
}
public static double CosineDist(Map<String, Double> userMap1,
Map<String, Double> userMap2, Set<String> userSet,
Set<String> itemSet) {
double dist = 0;
double numerator = 0; // numerator: dot product over co-rated items
double denominator1 = 0; // denominator term: squared norm of user 1's scores
double denominator2 = 0; // denominator term: squared norm of user 2's scores
for (String itemId : itemSet) {
double score1 = 0.0;
double score2 = 0.0;
if (userMap1.containsKey(itemId) && userMap2.containsKey(itemId)) {
score1 = userMap1.get(itemId);
score2 = userMap2.get(itemId);
numerator += score1 * score2;
} else if (userMap1.containsKey(itemId)) {
score1 = userMap1.get(itemId);
} else if (userMap2.containsKey(itemId)) {
score2 = userMap2.get(itemId);
}
denominator1 += Math.pow(score1, 2);
denominator2 += Math.pow(score2, 2);
}
if (denominator1 == 0 || denominator2 == 0) {
return 0.0; // cosine similarity is undefined for an all-zero score vector
}
dist = numerator / (Math.sqrt(denominator1) * Math.sqrt(denominator2));
return dist;
}
public static double execute(Map<String, Double> userMap1,Map<String, Double> userMap2,Set<String> userSet,Set<String> itemSet) {
double dist = EuclidDist(userMap1, userMap2, userSet, itemSet);
double userScore = 1.0 / (1.0 + dist);
// double userScore = CosineDist(userMap1, userMap2, userSet, itemSet);
return userScore;
}
public static void execute(String userId,Map<String, Map<String, Double>> scoreTable,
Set<String> userSet, Set<String> itemSet) {
for (Entry<String, Map<String, Double>> userEntry : scoreTable.entrySet()) {
String userId2 = userEntry.getKey();
Map<String, Double> userMap2 = userEntry.getValue();
double dist = EuclidDist(scoreTable.get(userId), userMap2, userSet, itemSet);
double userScore = 1.0 / (1.0 + dist);
// double userScore = CosineDist(scoreTable.get(userId), userMap2, userSet, itemSet);
FileTool.ps1.println(userId + "\t" + userId2 + "\t" + userScore);
}
}
public static void execute(Map<String, Map<String, Double>> scoreTable,
Set<String> userSet, Set<String> itemSet) {
for (Entry<String, Map<String, Double>> userEntry1 : scoreTable.entrySet()) {
String userId = userEntry1.getKey();
execute(userId, scoreTable, userSet, itemSet);
}
}
}
5. Scripts
Generate userSet and itemSet:
package script;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import entity.User;
import util.FileTool;
public class MakeSet {
public static void main(String[] args) throws Exception {
String inputDir = args[0];
String outputDir = args[1];
Set<String> userSet = new HashSet<String>();
Set<String> itemSet = new HashSet<String>();
List<String> pathList = FileTool.traverseFolder(inputDir);
for(String path : pathList){
String inputPath = inputDir + path;
List<User> list = FileTool.readFileOne(inputPath, false, "\t", "user");
for(User user : list){
userSet.add(user.getUserId());
itemSet.add(user.getItemId());
}
}
FileTool.initWriter1(outputDir+"userSet");
for(String userId : userSet){
FileTool.ps1.println(userId);
}
FileTool.closeWriter1();
FileTool.initWriter1(outputDir+"itemSet");
for(String itemId : itemSet){
FileTool.ps1.println(itemId);
}
FileTool.closeWriter1();
}
}
Map step: build the user-item score matrix, then compute user-user similarities to produce the user-user score table:
package script;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import entity.Item;
import entity.Score;
import entity.User;
import service.CalculateSimilarity;
import service.DataProcess;
import util.FileTool;
public class SpliteFileAndMakeScoreTable {
public static void main(String[] args) throws Exception {
//String inputDir = "data/fresh_comp_offline/";
//String outputDir = "data/fresh_comp_offline/sample/";
//String inputDir = "data/fresh_comp_offline/sample/";
//String outputDir = "data/fresh_comp_offline/sample/out/";
String inputDir = args[0];
String outputDir = args[1];
//String userPath = inputDir + "tianchi_fresh_comp_train_user.csv";
String userPath = inputDir + args[2];
String outputPath = args[3];
//String outputPath = outputDir + "user.csv";
//FileTool.makeSampleData(userPath, true, outputPath, 10000);
//List<Object> itemList = FileTool.readFileOne(itemPath, true, ",", "item");
//List<User> userList = FileTool.readFileOne(userPath, false, ",", "user");
List<User> userList = FileTool.readFileOne(userPath, false, ",", "user");
Set<String> userSet = new HashSet<String>();
Set<String> itemSet = new HashSet<String>();
Map<String, Map<String, List<User>>> userMap = DataProcess.mapByUser(userList,userSet,itemSet);
userList.clear();
DataProcess.output(userMap, outputDir);
// build the user-to-item score table
Map<String, Map<String, Double>> scoreTable = DataProcess.makeScoreTable(userMap);
//DataProcess.output(scoreTable, outputDir + "scoreTable.csv" , userSet, itemSet, ",");
userMap.clear();
FileTool.initWriter1(outputPath);
CalculateSimilarity.execute(scoreTable, userSet, itemSet);
FileTool.closeWriter1();
}
}
Reduce step:
package script;
import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import entity.User;
import service.DataProcess;
import util.FileTool;
public class ReduceFileTest {
public static void main(String[] args) throws Exception {
//String inputDir = "data/fresh_comp_offline/";
//String outputDir = "data/fresh_comp_offline/sample/";
//String inputDir = "data/fresh_comp_offline/sample/";
//String outputDir = "data/fresh_comp_offline/sample/out/";
String inputDir = args[0];
String outputDir = args[1];
//String userPath = inputDir + "tianchi_fresh_comp_train_user.csv";
//String itemPath = inputDir + args[2];
//String userPath = inputDir + args[3];
List<String> pathList = FileTool.traverseFolder(inputDir);
for(String path : pathList){
List<User> userList = FileTool.readFileOne(inputDir+path, false, "\t", "user");
List<User> list = DataProcess.reduceUserByItem(userList);
userList.clear();
FileTool.initWriter1(outputDir + path);
Collections.sort(list);
DataProcess.outputUser(list);
FileTool.closeWriter1();
list.clear();
}
}
}
Generate recommendations for each user (the prediction list):
package script;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.Map.Entry;
import service.DataProcess;
import util.FileTool;
import entity.Score;
public class PredictTest {
public static void main(String[] args) throws Exception {
//String inputDir = "data/fresh_comp_offline/";
//String outputDir = "data/fresh_comp_offline/sample/";
//String inputDir = "data/fresh_comp_offline/sample/";
//String outputDir = "data/fresh_comp_offline/sample/out/";
String inputDir = args[0];
String outputDir = args[1];
//String userPath = inputDir + "tianchi_fresh_comp_train_user.csv";
String inputPath = inputDir + args[2];
String outputPath = inputDir + args[3];
String userDir = args[4];
Map<String, List<Score>> scoreMap = FileTool.loadScoreMap(inputPath, false, "\t");
DataProcess.sortScoreMap(scoreMap);
List<String> fileNameList = FileTool.traverseFolder(userDir);
// recommend the top 5 items of each of the user's 5 most similar users
Map<String, Set<String>> predictMap = DataProcess.predict(scoreMap, fileNameList, userDir, 5, 5);
FileTool.initWriter1(outputPath);
DataProcess.outputRecommendList(predictMap);
FileTool.closeWriter1();
scoreMap.clear();
}
}
Compute precision, recall, and F1:
package script;
import java.util.List;
import java.util.Map;
import service.DataProcess;
import util.FileTool;
public class MatchTest2 {
public static void main(String[] args) throws Exception {
String inputDir = args[0];
String inputPath1 = inputDir + args[1];
String userDir = args[2];
Map<String, List<String>> predictMap = FileTool.loadPredictData(inputPath1, false, ",");
int predictN = FileTool.count;
System.out.println(predictN);
FileTool.count = 0;
Map<String, List<String>> referenceMap = FileTool.loadTestData(predictMap, userDir, false, "\t");
int referenceN = FileTool.count;
System.out.println(referenceN);
DataProcess.prediction(predictMap, predictN, referenceMap, referenceN);
}
}
The above is the core code; the full project source is available at:
http://download.csdn.net/download/u013473512/10141066
https://github.com/Emmitte/recommendSystem