算法介紹
時間序列森林(Time Series Forest, TSF)模型將時間序列轉化爲子序列的均值、方差和斜率等統計特徵,並使用隨機森林進行分類。TSF通過使用隨機森林方法(以每個間隔的統計信息作爲特徵)來克服間隔特徵空間巨大的問題。訓練一棵樹涉及選擇 √m 個隨機區間,生成每個序列在各隨機區間上的均值、標準差和斜率,然後在所得的 3√m 個特徵上創建和訓練一棵樹。
分類樹有兩個特點。首先,預先定義固定數量的評估點,而不是評估所有可能的分割點,以獲取最佳信息增益。作者認爲這是使分類器更快的權宜之計,因爲它消除了對每個案例的屬性值逐一進行評估的需要。其次,介紹了一種細化的分割標準,以在具有相等信息增益的特徵之間進行選擇,其定義爲分割點與最接近案例之間的距離(分割餘量)。這個想法背後的直覺是:如果兩個分割具有相等的熵增益,則應該選擇距離最近案例最遠的分割。如果評估了所有可能的分割點,則該度量將沒有價值,因爲根據定義,分割點被取爲相鄰案例之間的等距中點。
算法代碼分析
/**
 * Time Series Forest (TSF) classifier (Deng et al., 2013).
 *
 * <p>Each tree of the forest is trained on {@code sqrt(m)} randomly drawn
 * intervals of the series (m = series length). Every interval is summarised by
 * three statistics — mean, variance and the slope of a least-squares line fit —
 * giving {@code 3*sqrt(m)} features per tree. Prediction is by majority vote
 * over all trees.
 *
 * <p>NOTE(review): depends on Weka types ({@code Instances}, {@code RandomTree},
 * {@code FastVector}, ...) and on project classes
 * ({@code AbstractClassifierWithTrainingData}, {@code CrossValidator},
 * {@code OutFile}) declared elsewhere in the codebase.
 */
public class TSF extends AbstractClassifierWithTrainingData implements SaveParameterInfo, TrainAccuracyEstimate{
    boolean setSeed=false;   // true once a reproducible seed has been supplied
    int seed=0;              // random seed used for interval selection
    RandomTree[] trees;      // the trees making up the forest
    int numTrees=500;        // forest size
    int numFeatures;         // intervals per tree: floor(sqrt(m))
    int[][][] intervals;     // [tree][interval][0]=start, [1]=end (both inclusive)
    Random rand;             // RNG for drawing interval positions
    Instances testHolder;    // single-row scratch dataset reused for every prediction
    boolean trainCV=false;   // estimate train accuracy by cross-validation?
    String trainCVPath="";   // output path for CV results; "" means "do not write"

    public TSF(){
        rand=new Random();
    }

    /** Builds a TSF seeded with {@code s} for reproducible interval selection. */
    public TSF(int s){
        rand=new Random();
        seed=s;
        rand.setSeed(seed);
        setSeed=true;
    }

    /** Re-seeds the internal RNG so that subsequent builds are reproducible. */
    public void setSeed(int s){
        this.setSeed=true;
        seed=s;
        rand=new Random();
        rand.setSeed(seed);
    }

    /** Requests that cross-validation results be written to {@code train} during training. */
    @Override
    public void writeCVTrainToFile(String train) {
        trainCVPath=train;
        trainCV=true;
    }

    @Override
    public void setFindTrainAccuracyEstimate(boolean setCV){
        trainCV=setCV;
    }

    @Override
    public boolean findsTrainAccuracyEstimate(){ return trainCV;}

    /**
     * Trains the forest on {@code data}: optionally cross-validates a fresh copy
     * to estimate train accuracy, then for each tree draws random intervals,
     * extracts mean/variance/slope features and builds a {@code RandomTree}.
     *
     * @param data training instances; last attribute is the (nominal) class
     * @throws Exception if tree construction or CV fails, or the CV sanity check trips
     */
    @Override
    public void buildClassifier(Instances data) throws Exception {
        long t1=System.currentTimeMillis();
        // Number of random intervals per tree: floor(sqrt(m)); the class attribute is excluded.
        numFeatures=(int)Math.sqrt(data.numAttributes()-1);
        if(trainCV){
            // Estimate train accuracy by cross-validating a fresh copy of this classifier.
            int numFolds=setNumberOfFolds(data);
            CrossValidator cv = new CrossValidator();
            if (setSeed)
                cv.setSeed(seed);
            cv.setNumFolds(numFolds);
            TSF tsf=new TSF();
            if(setSeed)
                tsf.setSeed(seed);   // FIX: propagate the seed so the CV estimate is reproducible
            tsf.trainCV=false;       // the copy must not recurse into CV itself
            trainResults=cv.crossValidateWithStats(tsf,data);
        }
        intervals =new int[numTrees][][];
        trees=new RandomTree[numTrees];
        // Header for the per-tree training set: 3 statistics per interval, plus the class.
        FastVector atts=new FastVector();
        for(int j=0;j<numFeatures*3;j++)
            atts.addElement(new Attribute("F"+j));
        Attribute target=data.attribute(data.classIndex());
        FastVector vals=new FastVector(target.numValues());
        for(int j=0;j<target.numValues();j++)
            vals.addElement(target.value(j));
        atts.addElement(new Attribute(target.name(),vals));
        // One reusable dataset: the feature columns are overwritten for every tree.
        Instances result=new Instances("Tree",atts,data.numInstances());
        result.setClassIndex(result.numAttributes()-1);
        for(int i=0;i<data.numInstances();i++){
            DenseInstance in=new DenseInstance(result.numAttributes());
            in.setValue(result.numAttributes()-1,data.instance(i).classValue());
            result.add(in);
        }
        // Single-row scratch set, filled in later by classifyInstance.
        testHolder=new Instances(result,0);
        testHolder.add(new DenseInstance(result.numAttributes()));
        for(int i=0;i<numTrees;i++){
            // 1. Draw numFeatures random [start,end] intervals for this tree.
            intervals[i]=new int[numFeatures][2];
            for(int j=0;j<numFeatures;j++){
                intervals[i][j][0]=rand.nextInt(data.numAttributes()-1);
                // NOTE(review): the interval length may be as small as 1. The original
                // comment claimed a minimum length of 3 but the code never enforced it;
                // confirm the intended behaviour before changing it.
                int length=rand.nextInt(data.numAttributes()-1-intervals[i][j][0]);
                intervals[i][j][1]=intervals[i][j][0]+length;
            }
            // 2. Compute mean/variance/slope of every interval for every series.
            for(int j=0;j<numFeatures;j++){
                for(int k=0;k<data.numInstances();k++){
                    double[] series=data.instance(k).toDoubleArray();
                    FeatureSet f=new FeatureSet();
                    f.setFeatures(series,intervals[i][j][0],intervals[i][j][1]);
                    result.instance(k).setValue(j*3,f.mean);
                    result.instance(k).setValue(j*3+1,f.stDev);
                    result.instance(k).setValue(j*3+2,f.slope);
                }
            }
            // 3. Build the tree on all 3*numFeatures features (feature selection has
            //    already happened through the random interval draw).
            trees[i]=new RandomTree();
            trees[i].setKValue(numFeatures);
            trees[i].buildClassifier(result);
        }
        long t2=System.currentTimeMillis();
        trainResults.buildTime=t2-t1;
        // FIX: compare string content, not references. The original "trainCVPath!=\"\""
        // only worked because of string-literal interning.
        if(!trainCVPath.isEmpty()){
            OutFile of=new OutFile(trainCVPath);
            of.writeLine(data.relationName()+",TSF,train");
            of.writeLine(getParameters());
            of.writeLine(trainResults.acc+"");
            double[] trueClassVals=trainResults.getTrueClassVals();
            double[] predClassVals=trainResults.getPredClassVals();
            for(int i=0;i<data.numInstances();i++){
                // Sanity check: CV results must preserve instance order.
                if(data.instance(i).classValue()!=trueClassVals[i])
                    throw new Exception("ERROR in TSF cross validation, class mismatch!");
                of.writeString((int)trueClassVals[i]+","+(int)predClassVals[i]+",");
                for(double d:trainResults.getDistributionForInstance(i))
                    of.writeString(","+d);
                of.writeString("\n");
            }
        }
    }

    /**
     * Classifies one series by majority vote over the forest: for each tree the
     * stored intervals are re-extracted into the shared scratch row and the tree
     * votes for a class. Ties are broken in favour of the lowest class index.
     *
     * @param ins the series to classify
     * @return the winning class index as a double
     * @throws Exception if an underlying tree fails to classify
     */
    @Override
    public double classifyInstance(Instance ins) throws Exception {
        int[] votes=new int[ins.numClasses()];
        double[] series=ins.toDoubleArray();
        for(int i=0;i<trees.length;i++){
            // Rebuild this tree's interval features into the shared scratch row.
            for(int j=0;j<numFeatures;j++){
                FeatureSet f=new FeatureSet();
                f.setFeatures(series,intervals[i][j][0],intervals[i][j][1]);
                testHolder.instance(0).setValue(j*3,f.mean);
                testHolder.instance(0).setValue(j*3+1,f.stDev);
                testHolder.instance(0).setValue(j*3+2,f.slope);
            }
            int c=(int)trees[i].classifyInstance(testHolder.instance(0));
            votes[c]++;
        }
        // Majority vote; ties go to the lowest class index.
        int maxVote=0;
        for(int i=1;i<votes.length;i++)
            if(votes[i]>votes[maxVote])
                maxVote=i;
        return maxVote;
    }

    /**
     * Summary statistics of one interval of a series: mean, variance and the
     * slope of a least-squares line fit.
     */
    public static class FeatureSet{
        double mean;    // arithmetic mean over the interval
        double stDev;   // NOTE: actually the (biased) variance — the sqrt is commented out below
        double slope;   // slope of the least-squares line fitted over the interval
        RandomForest r; // unused; retained for source compatibility — TODO remove once confirmed dead

        /**
         * Computes mean, variance and least-squares slope of data[start..end],
         * both endpoints inclusive, in a single pass.
         *
         * @param data  full series
         * @param start first index of the interval (inclusive)
         * @param end   last index of the interval (inclusive)
         */
        public void setFeatures(double[] data, int start, int end){
            double sumX=0,sumYY=0;
            double sumY=0,sumXY=0,sumXX=0;
            int length=end-start+1;
            for(int i=start;i<=end;i++){
                sumY+=data[i];
                sumYY+=data[i]*data[i];
                sumX+=(i-start);
                sumXX+=(i-start)*(i-start);
                sumXY+=data[i]*(i-start);
            }
            mean=sumY/length;
            stDev=sumYY-(sumY*sumY)/length;     // raw sum of squared deviations, scaled below
            slope=(sumXY-(sumX*sumY)/length);
            if(sumXX-(sumX*sumX)/length!=0)
                slope/=sumXX-(sumX*sumX)/length;
            else
                slope=0;                        // degenerate x-variance (single-point interval)
            stDev/=length;                      // biased variance; sqrt deliberately skipped
            if(stDev==0)                        // flat line: slope carries no information
                slope=0;
            // else
            //     stDev=Math.sqrt(stDev);
            if(slope==0)
                stDev=0;
        }

        /** Convenience overload: statistics over the whole series. */
        public void setFeatures(double[] data){
            setFeatures(data,0,data.length-1);
        }

        @Override
        public String toString(){
            return "mean="+mean+" stdev = "+stDev+" slope ="+slope;
        }
    }
}