時間序列分類算法之時間序列森林(TSF)

算法介紹

      時間序列森林(Time Series Forest, TSF)模型將時間序列轉化爲子序列的均值、方差和斜率等統計特徵,並使用隨機森林進行分類。TSF通過使用隨機森林方法(以每個間隔的統計信息作爲特徵)來克服間隔特徵空間巨大的問題。訓練一棵樹涉及選擇根號m 個隨機區間,生成每個系列的隨機區間的均值,標準差和斜率,然後在所得的3根號m 個特徵上創建和訓練一棵樹。

       分類樹有兩個特徵。首先,預先定義固定數量的評估點,而不是評估所有可能的分割點,以獲取最佳信息增益。作者認爲這是使分類器更快的權宜之計,因爲它消除了對每個案例進行分類的需要屬性值。其次,介紹了一種細化的分割標準,以在具有相等信息增益的特徵之間進行選擇。這被定義爲分割餘量和最接近情況之間的距離。這個想法背後的直覺是,如果兩個拆分具有相等的熵增益,則應該選擇距離最近的情況最遠的拆分。如果評估了所有可能的間隔,則該度量將沒有價值,因爲根據定義,分割點被視爲個案之間的等距距離。

算法代碼分析

	public class TSF extends AbstractClassifierWithTrainingData implements SaveParameterInfo, TrainAccuracyEstimate{
    boolean setSeed=false;
    int seed=0;//隨機種子
    RandomTree[] trees;//存儲森林中的每一棵樹
    int numTrees=500;//樹的數量
    int numFeatures;//特徵個數,√m
    int[][][] intervals;//用於存儲子序列位置
    Random rand;//隨機函數
    Instances testHolder;
    boolean trainCV=false;  //是否進行交叉驗證獲取最優參數
 
    String trainCVPath="";//存儲交叉驗證結果路徑
    public TSF(){
        rand=new Random();
    }
    public TSF(int s){
        rand=new Random();
        seed=s;
        rand.setSeed(seed);
        setSeed=true;
    }
    public void setSeed(int s){
        this.setSeed=true;
        seed=s;
        rand=new Random();
        rand.setSeed(seed);
    }
    @Override
    public void writeCVTrainToFile(String train) {
        trainCVPath=train;
        trainCV=true;
    }
 @Override
    public void setFindTrainAccuracyEstimate(boolean setCV){
        trainCV=setCV;
    }
    
    @Override
    public boolean findsTrainAccuracyEstimate(){ return trainCV;}
   
      
    @Override
    public void buildClassifier(Instances data) throws Exception {
        long t1=System.currentTimeMillis();
        //特徵個數
        numFeatures=(int)Math.sqrt(data.numAttributes()-1);
        
         if(trainCV){//如果爲真,則進行交叉驗證
            int numFolds=setNumberOfFolds(data);
            CrossValidator cv = new CrossValidator();
            if (setSeed)
              cv.setSeed(seed);
            cv.setNumFolds(numFolds);
           //遞歸調用算法計算準確率
            TSF tsf=new TSF();
            tsf.trainCV=false;
            trainResults=cv.crossValidateWithStats(tsf,data);
        }
        numFeatures=(int)Math.sqrt(data.numAttributes()-1);
        intervals =new int[numTrees][][];
        trees=new RandomTree[numTrees];
        //初始化輸出向量格式. 
        FastVector atts=new FastVector();
        String name;
        //最終向量長度爲3√m
        for(int j=0;j<numFeatures*3;j++){
                name = "F"+j;
                atts.addElement(new Attribute(name));
        }
        //設置類屬性			
        Attribute target =data.attribute(data.classIndex());

        FastVector vals=new FastVector(target.numValues());
        for(int j=0;j<target.numValues();j++)
                vals.addElement(target.value(j));
        atts.addElement(new Attribute(data.attribute(data.classIndex()).name(),vals));
//初始化空實例                
        Instances result = new Instances("Tree",atts,data.numInstances());
        result.setClassIndex(result.numAttributes()-1);
        for(int i=0;i<data.numInstances();i++){
            DenseInstance in=new DenseInstance(result.numAttributes());
            in.setValue(result.numAttributes()-1,data.instance(i).classValue());
            result.add(in);
        }
        
        testHolder =new Instances(result,0);       
        DenseInstance in=new DenseInstance(result.numAttributes());
        testHolder.add(in);
//初始化每一棵樹       
        for(int i=0;i<numTrees;i++){
            intervals[i]=new int[numFeatures][2];  //開始和結束結點
            for(int j=0;j<numFeatures;j++){//隨機獲取沒棵樹的子序列位置
               intervals[i][j][0]=rand.nextInt(data.numAttributes()-1);     
               int length=rand.nextInt(data.numAttributes()-1-intervals[i][j][0]);//最小長度爲3
               intervals[i][j][1]=intervals[i][j][0]+length;
            }
        //2. 生成並存儲樹            
            for(int j=0;j<numFeatures;j++){
                //遍歷數據集中的實例
                for(int k=0;k<data.numInstances();k++){
                    //提取每個實例的子序列
                    double[] series=data.instance(k).toDoubleArray();
                    //每個子序列提取3個屬性,構成屬性集合
                    FeatureSet f= new FeatureSet();
                    f.setFeatures(series, intervals[i][j][0], intervals[i][j][1]);
                    result.instance(k).setValue(j*3, f.mean);
                    result.instance(k).setValue(j*3+1, f.stDev);
                    result.instance(k).setValue(j*3+2, f.slope);
                }
            }
//Set features
/*Create and build tree using all the features. Feature selection
  has already occurred
        */
            trees[i]=new RandomTree();   
            trees[i].setKValue(numFeatures);
            trees[i].buildClassifier(result);
        }
        long t2=System.currentTimeMillis();
        trainResults.buildTime=t2-t1;
        if(trainCVPath!=""){//存儲交叉驗證後的參數
             OutFile of=new OutFile(trainCVPath);
             of.writeLine(data.relationName()+",TSF,train");
             of.writeLine(getParameters());
            of.writeLine(trainResults.acc+"");
            double[] trueClassVals,predClassVals;
            trueClassVals=trainResults.getTrueClassVals();
            predClassVals=trainResults.getPredClassVals();
            for(int i=0;i<data.numInstances();i++){
                //Basic sanity check
                if(data.instance(i).classValue()!=trueClassVals[i]){
                    throw new Exception("ERROR in TSF cross validation, class mismatch!");
                }
                of.writeString((int)trueClassVals[i]+","+(int)predClassVals[i]+",");
                for(double d:trainResults.getDistributionForInstance(i))
                    of.writeString(","+d);
                of.writeString("\n");
            }
        }
        
    }
    
    @Override
    public double classifyInstance(Instance ins) throws Exception {
        int[] votes=new int[ins.numClasses()];
//Build instance
        double[] series=ins.toDoubleArray();
        for(int i=0;i<trees.length;i++){
            for(int j=0;j<numFeatures;j++){
                    //extract the interval
                    FeatureSet f= new FeatureSet();
                    f.setFeatures(series, intervals[i][j][0], intervals[i][j][1]);
                    testHolder.instance(0).setValue(j*3, f.mean);
                    testHolder.instance(0).setValue(j*3+1, f.stDev);
                    testHolder.instance(0).setValue(j*3+2, f.slope);
                }
            int c=(int)trees[i].classifyInstance(testHolder.instance(0));
            votes[c]++;
        }
//Return majority vote            
       int maxVote=0;
       for(int i=1;i<votes.length;i++)
           if(votes[i]>votes[maxVote])
               maxVote=i;
        return maxVote;
    }

//屬性集合
    public static class FeatureSet{
        double mean;//均值
        double stDev;//方差
        double slope;//斜率,指的是用直線擬合子序列之後,該直線的斜率
        RandomForest r; 
        public void setFeatures(double[] data, int start, int end){
            double sumX=0,sumYY=0;
            double sumY=0,sumXY=0,sumXX=0;
            int length=end-start+1;
            for(int i=start;i<=end;i++){
                sumY+=data[i];
                sumYY+=data[i]*data[i];
                sumX+=(i-start);
                sumXX+=(i-start)*(i-start);
                sumXY+=data[i]*(i-start);
            }
            mean=sumY/length;
            stDev=sumYY-(sumY*sumY)/length;
            slope=(sumXY-(sumX*sumY)/length);
            if(sumXX-(sumX*sumX)/length!=0)
                slope/=sumXX-(sumX*sumX)/length;
            else
                slope=0;
            stDev/=length;
            if(stDev==0)    //Flat line
                slope=0;
//            else
//                stDev=Math.sqrt(stDev);
            if(slope==0)
                stDev=0;
        }
        public void setFeatures(double[] data){
            setFeatures(data,0,data.length-1);
        }
        @Override
        public String toString(){
            return "mean="+mean+" stdev = "+stDev+" slope ="+slope;
        }
    } 
}

 

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章