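In my view, LibRec, the open-source recommender library from Prof. Guo Guibing's group at Northeastern University, has been a huge help to academic newbies like me. It reproduces the methods of many research papers and makes it much easier for us to reproduce top-conference papers on our own. So what should an academic newbie like me do to write my own recommendation algorithm? The answer: we can build it directly on top of this library!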
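Below, based on my own rough understanding and the official documentation, I explain how to use an existing wheel, LibRec, to implement your own recommendation algorithm.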
I. The recommendation workflow in LibRec
Before explaining how to implement your own algorithm, I want to walk through LibRec's overall recommendation workflow:
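- First, we need the data the algorithm requires. This includes user behavior data, such as the most common user-item-rating data or implicit feedback. It also includes side information, such as social network data, geolocation data and item content data.
- Next, the data must be processed. This means converting the data format (e.g., text or arff), splitting the dataset as required (e.g., ratio split, leave-one-out, k-fold, or a given training/test set), processing the side information, and so on.
- Only then can we train the recommender on the training set and use the test set to evaluate how good its output is. Algorithms that rely on similarity also need the similarity module, which offers more than ten measures such as Euclidean distance.
- Finally, the results are optionally filtered and then saved.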
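The workflow above can be sketched in plain Java with no LibRec dependency. Everything in this sketch is a hypothetical illustration, not LibRec's API: `PipelineSketch`, `Rating`, `ratioSplit`, the global-mean "model" and the RMSE evaluator. It shows one point: loading, splitting and evaluating are generic, and only the training step depends on the algorithm.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Random;

/** Hypothetical sketch of the workflow: data -> split -> train -> evaluate (not LibRec code). */
final class PipelineSketch {
    /** One user-item-rating record, the most common input format. */
    static final class Rating {
        final int user, item;
        final double value;
        Rating(int user, int item, double value) { this.user = user; this.item = item; this.value = value; }
    }

    /** Ratio split: shuffle with a fixed seed, the first `ratio` fraction becomes the training set. */
    static List<List<Rating>> ratioSplit(List<Rating> data, double ratio, long seed) {
        List<Rating> shuffled = new ArrayList<>(data);
        Collections.shuffle(shuffled, new Random(seed));
        int cut = (int) Math.round(shuffled.size() * ratio);
        List<List<Rating>> parts = new ArrayList<>();
        parts.add(new ArrayList<>(shuffled.subList(0, cut)));                 // training set
        parts.add(new ArrayList<>(shuffled.subList(cut, shuffled.size())));   // test set
        return parts;
    }

    /** The only algorithm-specific step; here a trivial global-mean baseline. */
    static double trainGlobalMean(List<Rating> train) {
        double sum = 0.0;
        for (Rating r : train) sum += r.value;
        return train.isEmpty() ? 0.0 : sum / train.size();
    }

    /** Generic evaluation: RMSE of a constant prediction on the test set. */
    static double rmse(List<Rating> test, double prediction) {
        double se = 0.0;
        for (Rating r : test) se += (r.value - prediction) * (r.value - prediction);
        return test.isEmpty() ? 0.0 : Math.sqrt(se / test.size());
    }
}
```

Swapping in another algorithm would only replace `trainGlobalMean`; the split and the evaluator stay the same.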
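The most important step in the whole workflow is training, i.e., which recommendation algorithm is used. Every other step works exactly the same for all algorithms. We don't need to reinvent those wheels: just set the corresponding configuration options and use them as they are.
So implementing a recommender mainly comes down to writing its training logic, i.e., the trainModel() method that the base class's train() calls. Pretty convenient, right?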
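Here is a minimal, hypothetical sketch of this design, the template-method pattern. The base class fixes the order setup → trainModel → cleanup inside train(), mirroring `AbstractRecommender.train(RecommenderContext)` in section III. A subclass supplies only trainModel(). The names `MiniRecommender` and `MyRecommender` are made up, and declaring train() `final` is a choice of this sketch, not of LibRec.

```java
import java.util.ArrayList;
import java.util.List;

/** Hypothetical mini version of the setup -> trainModel -> cleanup flow (not LibRec code). */
final class TemplateMethodSketch {
    /** Records which hooks ran, in order, so the flow is observable. */
    static final List<String> CALLS = new ArrayList<>();

    abstract static class MiniRecommender {
        /** Fixed pipeline shared by every algorithm. */
        final void train() {
            setup();
            trainModel();
            cleanup();
        }
        protected void setup() { CALLS.add("setup"); }       // shared: read configuration
        protected abstract void trainModel();                 // algorithm-specific part
        protected void cleanup() { CALLS.add("cleanup"); }   // shared: release resources
    }

    /** A new algorithm only fills in trainModel(). */
    static final class MyRecommender extends MiniRecommender {
        @Override
        protected void trainModel() { CALLS.add("trainModel"); }
    }
}
```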
II. The abstract base classes in LibRec
LibRec defines the following seven abstract base classes for us to extend when implementing different kinds of recommenders:
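- Abstract Recommender: the abstract base recommender
- Matrix Recommender: abstract matrix-based recommender
- Matrix Probabilistic Graphical Recommender: abstract matrix-based probabilistic graphical model recommender
- Matrix Factorization Recommender: abstract matrix factorization recommender
- Factorization Machine Recommender: abstract factorization machine recommender
- Social Recommender: abstract social recommender
- Tensor Recommender: abstract tensor recommender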
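LibRec's 70+ existing recommendation algorithms are all built on these base classes. They fall into seven categories: baselines, collaborative filtering, content-based recommendation, context-aware recommendation, deep learning, hybrid algorithms, and other extensions.
So to implement your own algorithm in LibRec, first extend the abstract class that matches your algorithm's category and implement its abstract methods as required. You can also override the abstract class's other methods as needed.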
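The sketch below shows why picking the right base class saves work. It is a hypothetical example, not LibRec's actual Matrix Factorization Recommender. The base class owns the latent factors, their initialization, and predict() as a dot product. The concrete model then only implements trainModel(), here plain SGD on squared error with L2 regularization. The class names, fixed seed, factor count and hyperparameters are all assumptions.

```java
import java.util.Random;

/** Hypothetical sketch: a matrix-factorization base class supplies factors and predict(). */
final class BaseClassSketch {
    abstract static class MiniMatrixFactorizationRecommender {
        protected double[][] userFactors, itemFactors;
        protected int numFactors = 2;

        /** Shared setup: small random latent factors (fixed seed for reproducibility). */
        protected void setup(int numUsers, int numItems) {
            Random rnd = new Random(1L);
            userFactors = new double[numUsers][numFactors];
            itemFactors = new double[numItems][numFactors];
            for (double[] row : userFactors) for (int f = 0; f < numFactors; f++) row[f] = 0.1 * rnd.nextGaussian();
            for (double[] row : itemFactors) for (int f = 0; f < numFactors; f++) row[f] = 0.1 * rnd.nextGaussian();
        }

        /** Shared prediction: dot product of user and item factors. */
        protected double predict(int u, int i) {
            double s = 0.0;
            for (int f = 0; f < numFactors; f++) s += userFactors[u][f] * itemFactors[i][f];
            return s;
        }

        /** Each concrete model only decides how to learn the factors. */
        protected abstract void trainModel(int[][] ratings);
    }

    /** SGD on squared error with L2 regularization; ratings[k] = {user, item, value}. */
    static final class MySgdMf extends MiniMatrixFactorizationRecommender {
        double learnRate = 0.05, reg = 0.01;
        int iterations = 500;

        @Override
        protected void trainModel(int[][] ratings) {
            for (int it = 0; it < iterations; it++) {
                for (int[] r : ratings) {
                    int u = r[0], i = r[1];
                    double err = r[2] - predict(u, i);
                    for (int f = 0; f < numFactors; f++) {
                        double pu = userFactors[u][f], qi = itemFactors[i][f];
                        userFactors[u][f] += learnRate * (err * qi - reg * pu);
                        itemFactors[i][f] += learnRate * (err * pu - reg * qi);
                    }
                }
            }
        }
    }
}
```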
III. Implementing your own recommender
Take extending the AbstractRecommender abstract class as an example. Here is the class's code (a quick skim is enough for now):
package net.librec.recommender;
import com.google.common.collect.BiMap;
import net.librec.common.LibrecException;
import net.librec.conf.Configuration;
import net.librec.data.DataModel;
import net.librec.job.progress.ProgressBar;
import net.librec.recommender.item.*;
import org.apache.commons.lang.StringUtils;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
/**
* Abstract Recommender Methods
*
* @author WangYuFeng and Wang Keqiang
*/
public abstract class AbstractRecommender implements Recommender {
/**
* LOG
*/
protected final Log LOG = LogFactory.getLog(this.getClass());
/**
* is ranking or rating
*/
protected boolean isRanking;
/**
* topN
*/
protected int topN;
/**
* conf
*/
protected Configuration conf;
/**
* RecommenderContext
*/
protected RecommenderContext context;
/**
* early-stop criteria
*/
protected boolean earlyStop;
/**
* verbose
*/
protected static boolean verbose = true;
/**
* objective loss
*/
protected double loss, lastLoss = 0.0d;
/**
* whether to adjust learning rate automatically
*/
protected boolean isBoldDriver;
/**
* decay of learning rate
*/
protected float decay;
/**
* report the training progress
*/
protected ProgressBar progressBar;
/**
* user Mapping Data
*/
public BiMap<String, Integer> userMappingData;
/**
* item Mapping Data
*/
public BiMap<String, Integer> itemMappingData;
/**
* setup
*
* @throws LibrecException if error occurs during setup
*/
protected void setup() throws LibrecException {
conf = context.getConf();// load all configuration options via RecommenderContext into conf; they come from
//librec-default.properties and the algorithm's own config file, e.g. sbpr-test.properties
isRanking = conf.getBoolean("rec.recommender.isranking");// whether the task is ranking (e.g., top-N) rather than rating prediction
if (isRanking) {
topN = conf.getInt("rec.recommender.ranking.topn", 10);// length of the top-N list, default 10
if (this.topN <= 0) {
throw new IndexOutOfBoundsException("rec.recommender.ranking.topn should be more than 0!");
}
}
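earlyStop = conf.getBoolean("rec.recommender.earlystop", false);// whether to use early stopping, default false
verbose = conf.getBoolean("rec.recommender.verbose", true);// whether to print progress to the console, default true
userMappingData = getDataModel().getUserMappingData();// BiMap from raw user IDs to inner indices (used by getRecommendedList to map results back)
itemMappingData = getDataModel().getItemMappingData();// BiMap from raw item IDs to inner indices (used by getRecommendedList to map results back)
if (verbose) {// in verbose mode, create the progress bar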
progressBar = new ProgressBar(100, 100);
}
}
/**
* train Model
*
* @throws LibrecException if error occurs during training model
*/
protected abstract void trainModel() throws LibrecException;
/**
* recommend
*
* @param context recommender context
* @throws LibrecException if error occurs during recommending
*/
public void train(RecommenderContext context) throws LibrecException {
this.context = context;
setup();// mostly reads configuration options
LOG.info("Job Setup completed.");
trainModel();// calls the concrete algorithm's training method
LOG.info("Job Train completed.");
cleanup();
}
/**
* cleanup
*
* @throws LibrecException if error occurs during cleanup
*/
protected void cleanup() throws LibrecException {
}
/**
* (non-Javadoc)
*
* @see net.librec.recommender.Recommender#loadModel(String)
*/
@Override
public void loadModel(String filePath) {
}
/**
* (non-Javadoc)
*
* @see net.librec.recommender.Recommender#saveModel(String)
*/
@Override
public void saveModel(String filePath) {
}
/**
* get Context
*
* @return recommender context
*/
protected RecommenderContext getContext() {
return context;
}
/**
* set Context
*
* @param context recommender context
*/
public void setContext(RecommenderContext context) {
this.context = context;
}
/**
* get Data Model
*
* @return data model
*/
public DataModel getDataModel() {
return context.getDataModel();
}
/**
* get Recommended List
*
* @return Recommended List
*/
// get the recommendation results, with inner indices mapped back to raw IDs
public List<RecommendedItem> getRecommendedList(RecommendedList recommendedList) {
if (recommendedList != null && recommendedList.size() > 0) {
List<RecommendedItem> userItemList = new ArrayList<>();
Iterator<ContextKeyValueEntry> recommendedEntryIter = recommendedList.iterator();
if (userMappingData != null && userMappingData.size() > 0 && itemMappingData != null && itemMappingData.size() > 0) {
BiMap<Integer, String> userMappingInverse = userMappingData.inverse();
BiMap<Integer, String> itemMappingInverse = itemMappingData.inverse();
while (recommendedEntryIter.hasNext()) {
ContextKeyValueEntry contextKeyValueEntry = recommendedEntryIter.next();
if (contextKeyValueEntry != null) {
String userId = userMappingInverse.get(contextKeyValueEntry.getContextIdx());
String itemId = itemMappingInverse.get(contextKeyValueEntry.getKey());
if (StringUtils.isNotBlank(userId) && StringUtils.isNotBlank(itemId)) {
userItemList.add(new GenericRecommendedItem(userId, itemId, contextKeyValueEntry.getValue()));
}
}
}
return userItemList;
}
}
return null;
}
/**
* Post each iteration, we do things:
* <ol>
* <li>print debug information</li>
* <li>check if converged</li>
* <li>if not, adjust learning rate</li>
* </ol>
*
* @param iter current iteration
* @return boolean: true if it is converged; false otherwise
* @throws LibrecException if error occurs
*/
protected boolean isConverged(int iter) throws LibrecException {
float delta_loss = (float) (lastLoss - loss);
// if verbose is true, print progress information
if (verbose) {
String recName = getClass().getSimpleName();
String info = recName + " iter " + iter + ": loss = " + loss + ", delta_loss = " + delta_loss;
LOG.info(info);
}
// check for numerical errors (NaN or Infinity)
if (Double.isNaN(loss) || Double.isInfinite(loss)) {
//LOG.error("Loss = NaN or Infinity: current settings does not fit the recommender! Change the settings and try again!");
throw new LibrecException("Loss = NaN or Infinity: current settings does not fit the recommender! Change the settings and try again!");
}
// check for convergence
return Math.abs(delta_loss) < 1e-5;
}
public void updateProgress(int currentPoint) {
if (verbose) {
conf.setInt("train.current.progress", currentPoint);
progressBar.showBarByPoint(conf.getInt("train.current.progress"));
}
}
}
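Note how isConverged() works: it only compares lastLoss with loss and never updates lastLoss itself. A subclass's training loop therefore has to compute loss on every iteration and then roll it into lastLoss. Because lastLoss starts at 0, the first check compares against 0.

Below is a self-contained sketch of that contract on a toy quadratic loss. `ConvergenceSketch` and `minimize()` are hypothetical names. The rule is copied from the listing: fail on NaN or Infinity, and treat |lastLoss - loss| < 1e-5 as converged. Two things differ from the listing: this sketch uses double where the listing casts delta_loss to float, and IllegalStateException stands in for LibrecException.

```java
/** Hypothetical sketch mirroring the contract of AbstractRecommender.isConverged(). */
final class ConvergenceSketch {
    double loss, lastLoss = 0.0;

    /** Same rule as the listing: fail on NaN/Infinity, converged when |lastLoss - loss| < 1e-5. */
    boolean isConverged(int iter) {
        double deltaLoss = lastLoss - loss;
        if (Double.isNaN(loss) || Double.isInfinite(loss)) {
            throw new IllegalStateException("Loss = NaN or Infinity at iter " + iter);
        }
        return Math.abs(deltaLoss) < 1e-5;
    }

    /** Gradient descent on loss(w) = (w - 3)^2; returns the number of iterations run. */
    int minimize(double learnRate, int maxIter) {
        double w = 0.0;
        for (int iter = 1; iter <= maxIter; iter++) {
            w -= learnRate * 2 * (w - 3);   // gradient of (w - 3)^2 is 2(w - 3)
            loss = (w - 3) * (w - 3);
            if (isConverged(iter)) {
                return iter;
            }
            lastLoss = loss;                // the training loop rolls the loss forward itself
        }
        return maxIter;
    }
}
```

With a moderate learning rate the loop stops early. With a learning rate that is too large the loss blows up to Infinity and the check throws, just as the listing's LibrecException asks you to change the settings.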
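If our recommender extends the AbstractRecommender abstract class, implementing it takes roughly the following steps:
- Override the setup method (optional). setup initializes the algorithm's member variables; for example, reading parameters from the configuration file goes here (for details, see the setup method discussed in this blog post). If you do override setup, call the parent implementation with super.setup() on the first line so that the algorithm's basic parameters are initialized.
- Implement the trainModel method. trainModel trains the model, e.g., minimizing the model's loss function with gradient descent; this is where you write your model! In the base class AbstractRecommender this method is declared abstract (it has no body), precisely so that subclasses supply it.
- Implement the predict method. predict uses the trained model to make predictions. For a rating-prediction algorithm, predict must estimate every rating in the test set: given a user index and an item index, it predicts the rating between them. Note that the AbstractRecommender listed above has no predict method. In this version of the code the prediction step goes through recommendRating() and recommendRank(), and USGRecommender below overrides exactly these, with its own helper predictScore().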
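The three steps can be sketched without LibRec as follows. Several parts are hypothetical: `MiniRecommender`, `ItemBiasRecommender`, the Map-based configuration, and the key `rec.itembias.reg`. Only the overall shape follows the description above:
- setup() calls super.setup() first, then reads parameters with defaults, as `conf.getDouble(key, default)` does in the USG code.
- trainModel() fits the model.
- predict() scores one (user, item) pair.

The model itself is simple: a global mean plus a regularized item bias.

```java
import java.util.HashMap;
import java.util.Map;

/** Hypothetical sketch of the three steps: setup(), trainModel(), predict() (not LibRec code). */
final class ThreeStepSketch {
    abstract static class MiniRecommender {
        protected final Map<String, String> conf;
        protected final int[][] trainData;          // rows of {user, item, rating}
        protected boolean verbose;

        MiniRecommender(Map<String, String> conf, int[][] trainData) {
            this.conf = conf;
            this.trainData = trainData;
        }

        /** Base setup: options every algorithm shares. */
        protected void setup() {
            verbose = Boolean.parseBoolean(conf.getOrDefault("rec.recommender.verbose", "true"));
        }

        protected double getDouble(String key, double defaultValue) {
            String v = conf.get(key);
            return v == null ? defaultValue : Double.parseDouble(v);
        }

        protected abstract void trainModel();
        protected abstract double predict(int userIdx, int itemIdx);

        final void train() { setup(); trainModel(); }
    }

    /** Global mean plus a regularized item bias. */
    static final class ItemBiasRecommender extends MiniRecommender {
        private double reg, globalMean;
        private final Map<Integer, Double> itemBias = new HashMap<>();

        ItemBiasRecommender(Map<String, String> conf, int[][] trainData) { super(conf, trainData); }

        @Override
        protected void setup() {
            super.setup();                                  // step 1: keep the base initialization
            reg = getDouble("rec.itembias.reg", 10.0);      // hypothetical key, default 10
        }

        @Override
        protected void trainModel() {                       // step 2: fit the model
            double sum = 0.0;
            for (int[] r : trainData) sum += r[2];
            globalMean = sum / trainData.length;
            Map<Integer, double[]> acc = new HashMap<>();   // item -> {sum of residuals, count}
            for (int[] r : trainData) {
                double[] a = acc.computeIfAbsent(r[1], k -> new double[2]);
                a[0] += r[2] - globalMean;
                a[1] += 1;
            }
            for (Map.Entry<Integer, double[]> e : acc.entrySet()) {
                itemBias.put(e.getKey(), e.getValue()[0] / (reg + e.getValue()[1]));
            }
        }

        @Override
        protected double predict(int userIdx, int itemIdx) { // step 3: score one pair
            return globalMean + itemBias.getOrDefault(itemIdx, 0.0);
        }
    }
}
```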
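Yes, it's as simple as it looks: write those three methods and you're done. Below is a recommender that extends AbstractRecommender directly, the USG algorithm. The code is fairly long and lightly commented, so there is no need to read it closely; the point is to see how the three steps are written in USG. Notice that it also contains other methods: some override other base-class methods, and some are helpers for the three steps.
So once you've chosen the base class to extend, be sure to read through how it is written and which methods it provides.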
package net.librec.recommender.poi;
import com.google.common.collect.BiMap;
import com.google.common.collect.HashBasedTable;
import com.google.common.collect.Table;
import com.google.common.primitives.Ints;
import net.librec.common.LibrecException;
import net.librec.data.convertor.appender.LocationDataAppender;
import net.librec.data.structure.AbstractBaseDataEntry;
import net.librec.data.structure.LibrecDataList;
import net.librec.math.algorithm.Randoms;
import net.librec.math.structure.DataSet;
import net.librec.math.structure.SequentialAccessSparseMatrix;
import net.librec.math.structure.SequentialSparseVector;
import net.librec.math.structure.Vector;
import net.librec.recommender.AbstractRecommender;
import net.librec.recommender.item.KeyValue;
import net.librec.recommender.item.RecommendedList;
import java.io.File;
import java.io.FileInputStream;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.channels.FileChannel;
import java.nio.file.*;
import java.nio.file.attribute.BasicFileAttributes;
import java.util.*;
/**
* Ye M, Yin P, Lee W C, et al. Exploiting geographical influence for collaborative point-of-interest recommendation[C]//
* International ACM SIGIR Conference on Research and Development in Information Retrieval. ACM, 2011:325-334.
* @author Yuanyuan Jin
*
* ###special notes###
* 1. prediction for all user, please set:
* data.testset.path = poi/Gowalla/checkin/Gowalla_test.txt
* and delete the para setting for "rec.limit.userNum" in usg.properties
*
* 2. prediction for small user set like userids in [0, 100],
* in usg.properties, please set:
* data.testset.path = poi/Gowalla/checkin/testDataFor101users.txt
* rec.limit.userNum = 101
* In EntropyEvaluator and NoveltyEvaluator, you also need to reset the variable "numUsers" = your limited userNum
*/
public class USGRecommender extends AbstractRecommender {
private SequentialAccessSparseMatrix socialSimilarityMatrix;
private SequentialAccessSparseMatrix userSimilarityMatrix;
private SequentialAccessSparseMatrix socialMatrix;
private SequentialAccessSparseMatrix trainMatrix;
private SequentialAccessSparseMatrix testMatrix;
/**
* weight of the social score part
*/
private double alpha;
/**
* weight of the geographical score part
*/
private double beta;
/**
* tuning parameter in social similarity
*/
private double eta;
/**
* linear coefficients for modeling the "log-log scale" power-law distribution
*/
private double w0;
private double w1;
/**
* number of pois
*/
private int numPois;
/**
* number of users
*/
private int numUsers;
/**
* for limiting test user cardinality
*/
private int limitUserNum;
private static final int BSIZE = 1024 * 1024;
private String socialPath;
private KeyValue<Double, Double>[] locationCoordinates;
@Override
protected void setup() throws LibrecException {
super.setup();
BiMap<Integer, String> userIds = this.userMappingData.inverse();
BiMap<Integer, String> itemIds = this.itemMappingData.inverse();
numPois = itemMappingData.size();
numUsers = userMappingData.size();
trainMatrix = (SequentialAccessSparseMatrix) getDataModel().getTrainDataSet();
testMatrix = (SequentialAccessSparseMatrix) getDataModel().getTestDataSet();
alpha = conf.getDouble("rec.alpha", 0.1d);
beta = conf.getDouble("rec.beta", 0.1d);
eta = conf.getDouble("rec.eta", 0.05d);
//default value is numUsers
limitUserNum = conf.getInt("rec.limit.userNum", numUsers);
locationCoordinates = ((LocationDataAppender) getDataModel().getDataAppender()).getLocationAppender();
userSimilarityMatrix = context.getSimilarity().getSimilarityMatrix().toSparseMatrix();
socialPath = conf.get("dfs.data.dir") + "/" + conf.get("data.social.path");
// for AUCEvaluator and nDCGEvaluator
int[] numDroppedItemsArray = new int[numUsers];
int maxNumTestItemsByUser = 0;
for (int userIdx = 0; userIdx < numUsers; ++userIdx) {
numDroppedItemsArray[userIdx] = numPois - trainMatrix.row(userIdx).getNumEntries();
int numTestItemsByUser = testMatrix.row(userIdx).getNumEntries();
maxNumTestItemsByUser = maxNumTestItemsByUser < numTestItemsByUser ? numTestItemsByUser : maxNumTestItemsByUser;
}
conf.setInts("rec.eval.auc.dropped.num", numDroppedItemsArray);
conf.setInt("rec.eval.key.test.max.num", maxNumTestItemsByUser);
// for EntropyEvaluator
conf.setInt("rec.eval.item.num", testMatrix.columnSize());
// for NoveltyEvaluator
int[] itemPurchasedCount = new int[numPois];
for (int itemIdx = 0; itemIdx < numPois; ++itemIdx) {
int userNum = 0;
int[] userArray = trainMatrix.column(itemIdx).getIndices();
for (int userIdx : userArray) {
if (userIdx >= 0 && userIdx < limitUserNum) {
userNum++;
}
}
userArray = testMatrix.column(itemIdx).getIndices();
for (int userIdx : userArray) {
if (userIdx >= 0 && userIdx < limitUserNum) {
userNum++;
}
}
itemPurchasedCount[itemIdx] = userNum;
}
conf.setInts("rec.eval.item.purchase.num", itemPurchasedCount);
}
@Override
protected void trainModel() throws LibrecException {
LOG.info("start building socialmatrix " + new Date());
try {
buildSocialMatrix(socialPath);
} catch (IOException e) {
e.printStackTrace();
}
LOG.info("start building socialSimilarityMatrix " + new Date());
buildSocialSimilarity();
LOG.info("start fitting the powerlaw distribution" + new Date());
fitPowerLaw();
}
public double[] predictScore(int userIdx, int itemIdx) {
//score array for three aspects: user preference, social influence and geographical influence
double[] predictScore = new double[]{0.0d, 0.0d, 0.0d};
int[] userArray = trainMatrix.column(itemIdx).getIndices();
List<Integer> userList = Ints.asList(userArray);
/*---------start user preference score calculation--------*/
//iterator to iterate other similar users for each user
Iterator<Vector.VectorEntry> userSimIter = userSimilarityMatrix.row(userIdx).iterator();
//similarities between userIdx and its neighbors
List<Double> neighborSimis = new ArrayList<>();
while (userSimIter.hasNext()) {
Vector.VectorEntry userRatingEntry = userSimIter.next();
int similarUserIdx = userRatingEntry.index();
if (!userList.contains(similarUserIdx)) {
continue;
}
neighborSimis.add(userRatingEntry.get());
}
if (neighborSimis.size() == 0) {
predictScore[0] = 0.0d;
} else {
double sum = 0.0d;
for (int i = 0; i < neighborSimis.size(); i++) {
sum += neighborSimis.get(i);
}
predictScore[0] = sum;
}
/*---------end user preference score calculation--------*/
/*---------start social influence score calculation--------*/
//social similarities between userIdx and its social neighbors
List<Double> socialNeighborSimis = new ArrayList<>();
Iterator<Vector.VectorEntry> friendIter = socialSimilarityMatrix.row(userIdx).iterator();
while (friendIter.hasNext()) {
Vector.VectorEntry userRatingEntry = friendIter.next();
int similarUserIdx = userRatingEntry.index();
if (!userList.contains(similarUserIdx)) {
continue;
}
socialNeighborSimis.add(userRatingEntry.get());
}
if (socialNeighborSimis.size() == 0) {
predictScore[1] = 0.0d;
} else {
double sum = 0.0d;
for (int i = 0; i < socialNeighborSimis.size(); i++) {
sum += socialNeighborSimis.get(i);
}
predictScore[1] = sum;
}
/*---------end social influence score calculation--------*/
/*---------start geo influence score calculation--------*/
double geoScore = 1.0d;
int[] itemList = trainMatrix.row(userIdx).getIndices();
if (itemList.length == 0) {
geoScore = 0.0d;
} else {
for (int visitedPOI : itemList) {
double distance = getDistance(locationCoordinates[visitedPOI].getKey(), locationCoordinates[visitedPOI].getValue(),
locationCoordinates[itemIdx].getKey(), locationCoordinates[itemIdx].getValue());
if (distance < 0.01) {
distance = 0.01;
}
geoScore *= w0 * Math.pow(distance, w1);
}
}
predictScore[2] = geoScore;
/*---------end geo influence score calculation--------*/
return predictScore;
}
public void buildSocialSimilarity() {
Table<Integer, Integer, Double> socialSimilarityTable = HashBasedTable.create();
for (int userIdx = 0; userIdx < numUsers; userIdx++) {
SequentialSparseVector userVector = trainMatrix.row(userIdx);
if (userVector.getNumEntries() == 0) {
continue;
}
int[] socialNeighborList = socialMatrix.column(userIdx).getIndices();
for (int socialNeighborIdx : socialNeighborList) {
if (userIdx < socialNeighborIdx) {
SequentialSparseVector socialVector = trainMatrix.row(socialNeighborIdx);
int[] friendList = socialMatrix.column(socialNeighborIdx).getIndices();
if (socialVector.getNumEntries() == 0 || friendList.length == 0) {
continue;
}
if (getCorrelation(userVector, socialVector) > 0.0 && getCorrelation(socialNeighborList, friendList) > 0.0) {
double sim = (1 - eta) * getCorrelation(userVector, socialVector) + eta * getCorrelation(socialNeighborList, friendList);
if (!Double.isNaN(sim) && sim != 0.0) {
socialSimilarityTable.put(userIdx, socialNeighborIdx, sim);
}
}
}
}
}
socialSimilarityMatrix = new SequentialAccessSparseMatrix(numUsers, numUsers, socialSimilarityTable);
}
/**
* fit the "log-log" scale power law distribution
*/
public void fitPowerLaw() {
Map<Integer, Double> distanceMap = new HashMap<>();
Map<Double, Double> logdistanceMap = new HashMap<>();
int pairNum = 0;
for (int userIdx = 0; userIdx < numUsers; userIdx++) {
int[] itemList = trainMatrix.row(userIdx).getIndices();
if (itemList.length == 0) {
continue;
}
for (int i = 0; i < itemList.length - 1; i++) {
for (int j = i + 1; j < itemList.length; j++) {
double distance = getDistance(locationCoordinates[itemList[i]].getKey(), locationCoordinates[itemList[i]].getValue(),
locationCoordinates[itemList[j]].getKey(), locationCoordinates[itemList[j]].getValue());
if ((int) distance > 0) {
int intDistance = (int) distance;
if (!distanceMap.containsKey(intDistance)) {
distanceMap.put(intDistance, 0.0d);
}
distanceMap.put(intDistance, distanceMap.get(intDistance) + 1.0d);
}
pairNum++;
}
}
}
for (Map.Entry<Integer, Double> distanceEntry : distanceMap.entrySet()) {
logdistanceMap.put(Math.log10(distanceEntry.getKey()), Math.log10(distanceEntry.getValue() * 1.0 / pairNum));
}
/*-------start gradient descent--------*/
w0 = Randoms.random();
w1 = Randoms.random();
//regularization coefficient
double reg = 0.1;
//learn rate
double lrate = 0.00001;
//max number of iterations
int maxIterations = 2000;
for (int i = 0; i < maxIterations; i++) {
//gradients of w0 and w1
double w0Gradient = 0.0d;
double w1Gradient = 0.0d;
for (Map.Entry<Double, Double> distanceEntry : logdistanceMap.entrySet()) {
double distance = distanceEntry.getKey();
double probability = distanceEntry.getValue();
w0Gradient += (w0 + w1 * distance - probability);
w1Gradient += (w0 + w1 * distance - probability) * distance;
}
w0 -= lrate * (w0Gradient + reg * w0);
w1 -= lrate * (w1Gradient + reg * w1);
}
/*-------end gradient descent--------*/
w0 = Math.pow(10, w0);
}
/**
* calculate the spherical distance between location(lat1, long1) and location (lat2, long2)
* @param lat1
* @param long1
* @param lat2
* @param long2
* @return
*/
protected double getDistance(Double lat1, Double long1, Double lat2, Double long2) {
if (Math.abs(lat1 - lat2) < 1e-6 && Math.abs(long1 - long2) < 1e-6) {
return 0.0d;
}
double degreesToRadius = Math.PI / 180.0;
double phi1 = (90.0 - lat1) * degreesToRadius;
double phi2 = (90.0 - lat2) * degreesToRadius;
double theta1 = long1 * degreesToRadius;
double theta2 = long2 * degreesToRadius;
double cos = (Math.sin(phi1) * Math.sin(phi2) * Math.cos(theta1 - theta2) +
Math.cos(phi1) * Math.cos(phi2));
double arc = Math.acos(cos);
double earthRadius = 6371;
return arc * earthRadius;
}
public double getCorrelation(SequentialSparseVector thisVector, SequentialSparseVector thatVector) {
// compute jaccard similarity
Set<Integer> elements = unionArrays(thisVector.getIndices(), thatVector.getIndices());
int numAllElements = elements.size();
int numCommonElements = thisVector.getIndices().length + thatVector.getIndices().length - numAllElements;
return (numCommonElements + 0.0) / numAllElements;
}
public Set<Integer> unionArrays(int[] arr1, int[] arr2) {
Set<Integer> set = new HashSet<>();
for (int num : arr1) {
set.add(num);
}
for (int num : arr2) {
set.add(num);
}
return set;
}
public double getCorrelation(int[] thisList, int[] thatList) {
// compute jaccard similarity
Set<Integer> elements = new HashSet<Integer>();
for (int num : thisList) {
elements.add(num);
}
for (int num : thatList) {
elements.add(num);
}
int numAllElements = elements.size();
int numCommonElements = thisList.length + thatList.length
- numAllElements;
return (numCommonElements + 0.0) / numAllElements;
}
@Override
public RecommendedList recommendRating(DataSet predictDataSet) throws LibrecException {
return null;
}
@Override
public RecommendedList recommendRating(LibrecDataList<AbstractBaseDataEntry> dataList) throws LibrecException {
return null;
}
@Override
public RecommendedList recommendRank() throws LibrecException {
LOG.info("Evaluate for users from id 0 to id\t" + (limitUserNum-1));
RecommendedList recommendedList = new RecommendedList(numUsers);
for (int userIdx = 0; userIdx < numUsers; ++userIdx) {
recommendedList.addList(new ArrayList<>());
}
List<Integer> userList = new ArrayList<>();
for (int userIdx = 0; userIdx < limitUserNum; ++userIdx) {
userList.add(userIdx);
}
userList.parallelStream().forEach((Integer userIdx) -> {
List<Integer> itemList = Ints.asList(trainMatrix.row(userIdx).getIndices());
List<KeyValue<Integer, double[]>> tempItemValueList = new ArrayList<>();
double[] maxScore = new double[]{0.0d, 0.0d, 0.0d};
for (int itemIdx = 0; itemIdx < numPois; ++itemIdx) {
if (!itemList.contains(itemIdx)) {
double[] predictRating = predictScore(userIdx, itemIdx);
if (predictRating[0] >= maxScore[0]) {
maxScore[0] = predictRating[0];
}
if (predictRating[1] >= maxScore[1]) {
maxScore[1] = predictRating[1];
}
if (predictRating[2] >= maxScore[2]) {
maxScore[2] = predictRating[2];
}
tempItemValueList.add(new KeyValue<>(itemIdx, new double[]{predictRating[0], predictRating[1], predictRating[2]}));
}
}
List<KeyValue<Integer, Double>> itemValueList = new ArrayList<>();
//normalize scores
for (KeyValue<Integer, double[]> entry : tempItemValueList) {
double[] scores = entry.getValue();
if (maxScore[0] != 0.0d) {
scores[0] = scores[0] / maxScore[0];
}
if (maxScore[1] != 0.0d) {
scores[1] = scores[1] / maxScore[1];
}
if (maxScore[2] != 0.0d) {
scores[2] = scores[2] / maxScore[2];
}
double predictRating = (1 - alpha - beta) * scores[0] + alpha * scores[1]
+ beta * scores[2];
itemValueList.add(new KeyValue<>(entry.getKey(), predictRating));
}
recommendedList.setList(userIdx, itemValueList);
recommendedList.topNRankByIndex(userIdx, topN);
});
if (recommendedList.size() == 0) {
throw new IndexOutOfBoundsException("No item is recommended, there is something error in the recommendation algorithm! Please check it!");
}
LOG.info("end recommendation");
return recommendedList;
}
@Override
public RecommendedList recommendRank(LibrecDataList<AbstractBaseDataEntry> dataList) throws LibrecException {
return null;
}
/**
* load social relation data
* @param inputDataPath
* @throws IOException
*/
private void buildSocialMatrix(String inputDataPath) throws IOException {
LOG.info("Now loading users' social relation data from " + socialPath);
Table<Integer, Integer, Double> dataTable = HashBasedTable.create();
final List<File> files = new ArrayList<File>();
final ArrayList<Long> fileSizeList = new ArrayList<Long>();
SimpleFileVisitor<Path> finder = new SimpleFileVisitor<Path>() {
@Override
public FileVisitResult visitFile(Path file, BasicFileAttributes attrs) throws IOException {
fileSizeList.add(file.toFile().length());
files.add(file.toFile());
return super.visitFile(file, attrs);
}
};
Files.walkFileTree(Paths.get(inputDataPath), finder);
long allFileSize = 0;
for (Long everyFileSize : fileSizeList) {
allFileSize = allFileSize + everyFileSize.longValue();
}
for (File dataFile : files) {
FileInputStream fis = new FileInputStream(dataFile);
FileChannel fileRead = fis.getChannel();
ByteBuffer buffer = ByteBuffer.allocate(BSIZE);
int len;
String bufferLine = new String();
byte[] bytes = new byte[BSIZE];
while ((len = fileRead.read(buffer)) != -1) {
buffer.flip();
buffer.get(bytes, 0, len);
bufferLine = bufferLine.concat(new String(bytes, 0, len)).replaceAll("\r", "\n");
String[] bufferData = bufferLine.split("(\n)+");
boolean isComplete = bufferLine.endsWith("\n");
int loopLength = isComplete ? bufferData.length : bufferData.length - 1;
for (int i = 0; i < loopLength; i++) {
String line = new String(bufferData[i]);
String[] data = line.trim().split("[ \t,]+");
String userA = data[0];
String userB = data[1];
Double rate = (data.length >= 3) ? Double.valueOf(data[2]) : 1.0;
if (this.userMappingData.containsKey(userA) && this.userMappingData.containsKey(userB)) {
int row = this.userMappingData.get(userA);
int col = this.userMappingData.get(userB);
dataTable.put(row, col, rate);
dataTable.put(col, row, rate);
}
}
if (!isComplete) {
bufferLine = bufferData[bufferData.length - 1];
}
buffer.clear();
}
fileRead.close();
fis.close();
}
int numRows = this.userMappingData.size(), numCols = this.userMappingData.size();
socialMatrix = new SequentialAccessSparseMatrix(numRows, numCols, dataTable);
dataTable = null;
LOG.info("Load users' social relation data success! " + socialPath);
}
}
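The heart of recommendRank() above is score fusion. Each of the three raw scores (user preference, social influence, geographical influence) is divided by its maximum over the user's candidate POIs. The normalized scores are then combined as (1 - alpha - beta) * u + alpha * s + beta * g, and the top N are kept.

The standalone sketch below covers just that step, using plain arrays instead of LibRec's RecommendedList. `UsgFusionSketch` and `topN()` are hypothetical names. As in the listing, the running maximum starts at 0, and a component whose maximum stays 0 is left unnormalized.

```java
import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;

/** Standalone sketch of the score fusion in USGRecommender.recommendRank() (hypothetical names). */
final class UsgFusionSketch {
    /**
     * @param rawScores rawScores[j] = {user-preference, social, geographical} score of candidate j
     * @return indices of the topN candidates by fused score, best first
     */
    static List<Integer> topN(double[][] rawScores, double alpha, double beta, int topN) {
        double[] max = new double[3];                       // starts at 0, like maxScore in the listing
        for (double[] s : rawScores) {
            for (int k = 0; k < 3; k++) max[k] = Math.max(max[k], s[k]);
        }
        double[] fused = new double[rawScores.length];
        for (int j = 0; j < rawScores.length; j++) {
            double[] n = new double[3];
            for (int k = 0; k < 3; k++) {
                // normalize each component by its maximum, skipping components whose max is 0
                n[k] = max[k] != 0.0 ? rawScores[j][k] / max[k] : rawScores[j][k];
            }
            fused[j] = (1 - alpha - beta) * n[0] + alpha * n[1] + beta * n[2];
        }
        List<Integer> order = new ArrayList<>();
        for (int j = 0; j < rawScores.length; j++) order.add(j);
        order.sort(Comparator.comparingDouble((Integer j) -> fused[j]).reversed());
        return order.subList(0, Math.min(topN, order.size()));
    }
}
```

Raising alpha shifts the ranking toward POIs that the user's friends visited. Raising beta favors POIs close to the places the user has already checked in.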
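IV. Testing your own recommender
Once you've written the parts above, how do you test how well your algorithm performs?
Taking the USG method above as an example, simply run the RecommenderJob class with the USG algorithm's configuration file to see the final experimental results.
RecommenderJob wraps the entire recommendation workflow described above, including dataset processing, splitting, training, prediction and evaluation. Pass it the corresponding configuration file, such as rec/poi/usg-test.properties (if these files are new to you, see my earlier posts). Then specify the recommender to run (here, usg), and it runs the whole experiment for you.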
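As a rough mental model of what RecommenderJob automates, here is a hypothetical mini job runner. It reads the algorithm name from a `java.util.Properties` configuration, looks up a factory in a registry, trains the model, and reports RMSE on a test set.

`MiniRecommenderJob`, its registry and the `globalmean` model are invented for illustration. The key `rec.recommender.class` is assumed to mirror how LibRec configuration files name the algorithm. The real RecommenderJob also handles data conversion, splitting, similarity computation and saving results.

```java
import java.util.HashMap;
import java.util.Map;
import java.util.Properties;
import java.util.function.Supplier;

/** Hypothetical sketch of what a job runner like RecommenderJob orchestrates (not LibRec's API). */
final class MiniRecommenderJob {
    interface Recommender {
        void train(int[][] trainData);              // rows of {user, item, rating}
        double predict(int userIdx, int itemIdx);
    }

    /** Registry from short names (as written in a properties file) to factories. */
    static final Map<String, Supplier<Recommender>> REGISTRY = new HashMap<>();
    static {
        REGISTRY.put("globalmean", () -> new Recommender() {
            private double mean;
            public void train(int[][] d) {
                double s = 0.0;
                for (int[] r : d) s += r[2];
                mean = s / d.length;
            }
            public double predict(int u, int i) { return mean; }
        });
    }

    /** Pick the algorithm from the config, train it, and report RMSE on the test rows. */
    static double run(Properties conf, int[][] train, int[][] test) {
        String name = conf.getProperty("rec.recommender.class");   // assumed key name
        Supplier<Recommender> factory = REGISTRY.get(name);
        if (factory == null) throw new IllegalArgumentException("unknown recommender: " + name);
        Recommender rec = factory.get();
        rec.train(train);
        double se = 0.0;
        for (int[] r : test) {
            double err = r[2] - rec.predict(r[0], r[1]);
            se += err * err;
        }
        return Math.sqrt(se / test.length);
    }
}
```

Adding a new algorithm to such a runner means registering one more factory; the configuration, training call and evaluation code stay untouched, which is the same convenience section I describes.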
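That may sound a bit complicated; when I have time, I'll post a walkthrough of this class's code to make it clearer.
V. A small suggestion
All the code above is in the LibRec source. If what I've written leaves you confused, read the corresponding source code alongside this post and think it through; I'm sure you'll learn faster than I did at first!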