时间序列分类算法之时间序列森林(TSF)

算法介绍

      时间序列森林(Time Series Forest, TSF)模型将时间序列转化为子序列的均值、方差和斜率等统计特征,并使用随机森林进行分类。TSF通过使用随机森林方法(以每个间隔的统计信息作为特征)来克服间隔特征空间巨大的问题。训练一棵树涉及选择根号m 个随机区间,生成每个系列的随机区间的均值,标准差和斜率,然后在所得的3根号m 个特征上创建和训练一棵树。

       分类树有两个特征。首先,预先定义固定数量的评估点,而不是评估所有可能的分割点,以获取最佳信息增益。作者认为这是使分类器更快的权宜之计,因为它消除了对每个案例进行分类的需要属性值。其次,介绍了一种细化的分割标准,以在具有相等信息增益的特征之间进行选择。这被定义为分割余量和最接近情况之间的距离。这个想法背后的直觉是,如果两个拆分具有相等的熵增益,则应该选择距离最近的情况最远的拆分。如果评估了所有可能的间隔,则该度量将没有价值,因为根据定义,分割点被视为个案之间的等距距离。

算法代码分析

	public class TSF extends AbstractClassifierWithTrainingData implements SaveParameterInfo, TrainAccuracyEstimate{
    boolean setSeed=false;
    int seed=0;//随机种子
    RandomTree[] trees;//存储森林中的每一棵树
    int numTrees=500;//树的数量
    int numFeatures;//特征个数,√m
    int[][][] intervals;//用于存储子序列位置
    Random rand;//随机函数
    Instances testHolder;
    boolean trainCV=false;  //是否进行交叉验证获取最优参数
 
    String trainCVPath="";//存储交叉验证结果路径
    public TSF(){
        rand=new Random();
    }
    public TSF(int s){
        rand=new Random();
        seed=s;
        rand.setSeed(seed);
        setSeed=true;
    }
    public void setSeed(int s){
        this.setSeed=true;
        seed=s;
        rand=new Random();
        rand.setSeed(seed);
    }
    @Override
    public void writeCVTrainToFile(String train) {
        trainCVPath=train;
        trainCV=true;
    }
 @Override
    public void setFindTrainAccuracyEstimate(boolean setCV){
        trainCV=setCV;
    }
    
    @Override
    public boolean findsTrainAccuracyEstimate(){ return trainCV;}
   
      
    @Override
    public void buildClassifier(Instances data) throws Exception {
        long t1=System.currentTimeMillis();
        //特征个数
        numFeatures=(int)Math.sqrt(data.numAttributes()-1);
        
         if(trainCV){//如果为真,则进行交叉验证
            int numFolds=setNumberOfFolds(data);
            CrossValidator cv = new CrossValidator();
            if (setSeed)
              cv.setSeed(seed);
            cv.setNumFolds(numFolds);
           //递归调用算法计算准确率
            TSF tsf=new TSF();
            tsf.trainCV=false;
            trainResults=cv.crossValidateWithStats(tsf,data);
        }
        numFeatures=(int)Math.sqrt(data.numAttributes()-1);
        intervals =new int[numTrees][][];
        trees=new RandomTree[numTrees];
        //初始化输出向量格式. 
        FastVector atts=new FastVector();
        String name;
        //最终向量长度为3√m
        for(int j=0;j<numFeatures*3;j++){
                name = "F"+j;
                atts.addElement(new Attribute(name));
        }
        //设置类属性			
        Attribute target =data.attribute(data.classIndex());

        FastVector vals=new FastVector(target.numValues());
        for(int j=0;j<target.numValues();j++)
                vals.addElement(target.value(j));
        atts.addElement(new Attribute(data.attribute(data.classIndex()).name(),vals));
//初始化空实例                
        Instances result = new Instances("Tree",atts,data.numInstances());
        result.setClassIndex(result.numAttributes()-1);
        for(int i=0;i<data.numInstances();i++){
            DenseInstance in=new DenseInstance(result.numAttributes());
            in.setValue(result.numAttributes()-1,data.instance(i).classValue());
            result.add(in);
        }
        
        testHolder =new Instances(result,0);       
        DenseInstance in=new DenseInstance(result.numAttributes());
        testHolder.add(in);
//初始化每一棵树       
        for(int i=0;i<numTrees;i++){
            intervals[i]=new int[numFeatures][2];  //开始和结束结点
            for(int j=0;j<numFeatures;j++){//随机获取没棵树的子序列位置
               intervals[i][j][0]=rand.nextInt(data.numAttributes()-1);     
               int length=rand.nextInt(data.numAttributes()-1-intervals[i][j][0]);//最小长度为3
               intervals[i][j][1]=intervals[i][j][0]+length;
            }
        //2. 生成并存储树            
            for(int j=0;j<numFeatures;j++){
                //遍历数据集中的实例
                for(int k=0;k<data.numInstances();k++){
                    //提取每个实例的子序列
                    double[] series=data.instance(k).toDoubleArray();
                    //每个子序列提取3个属性,构成属性集合
                    FeatureSet f= new FeatureSet();
                    f.setFeatures(series, intervals[i][j][0], intervals[i][j][1]);
                    result.instance(k).setValue(j*3, f.mean);
                    result.instance(k).setValue(j*3+1, f.stDev);
                    result.instance(k).setValue(j*3+2, f.slope);
                }
            }
//Set features
/*Create and build tree using all the features. Feature selection
  has already occurred
        */
            trees[i]=new RandomTree();   
            trees[i].setKValue(numFeatures);
            trees[i].buildClassifier(result);
        }
        long t2=System.currentTimeMillis();
        trainResults.buildTime=t2-t1;
        if(trainCVPath!=""){//存储交叉验证后的参数
             OutFile of=new OutFile(trainCVPath);
             of.writeLine(data.relationName()+",TSF,train");
             of.writeLine(getParameters());
            of.writeLine(trainResults.acc+"");
            double[] trueClassVals,predClassVals;
            trueClassVals=trainResults.getTrueClassVals();
            predClassVals=trainResults.getPredClassVals();
            for(int i=0;i<data.numInstances();i++){
                //Basic sanity check
                if(data.instance(i).classValue()!=trueClassVals[i]){
                    throw new Exception("ERROR in TSF cross validation, class mismatch!");
                }
                of.writeString((int)trueClassVals[i]+","+(int)predClassVals[i]+",");
                for(double d:trainResults.getDistributionForInstance(i))
                    of.writeString(","+d);
                of.writeString("\n");
            }
        }
        
    }
    
    @Override
    public double classifyInstance(Instance ins) throws Exception {
        int[] votes=new int[ins.numClasses()];
//Build instance
        double[] series=ins.toDoubleArray();
        for(int i=0;i<trees.length;i++){
            for(int j=0;j<numFeatures;j++){
                    //extract the interval
                    FeatureSet f= new FeatureSet();
                    f.setFeatures(series, intervals[i][j][0], intervals[i][j][1]);
                    testHolder.instance(0).setValue(j*3, f.mean);
                    testHolder.instance(0).setValue(j*3+1, f.stDev);
                    testHolder.instance(0).setValue(j*3+2, f.slope);
                }
            int c=(int)trees[i].classifyInstance(testHolder.instance(0));
            votes[c]++;
        }
//Return majority vote            
       int maxVote=0;
       for(int i=1;i<votes.length;i++)
           if(votes[i]>votes[maxVote])
               maxVote=i;
        return maxVote;
    }

//属性集合
    public static class FeatureSet{
        double mean;//均值
        double stDev;//方差
        double slope;//斜率,指的是用直线拟合子序列之后,该直线的斜率
        RandomForest r; 
        public void setFeatures(double[] data, int start, int end){
            double sumX=0,sumYY=0;
            double sumY=0,sumXY=0,sumXX=0;
            int length=end-start+1;
            for(int i=start;i<=end;i++){
                sumY+=data[i];
                sumYY+=data[i]*data[i];
                sumX+=(i-start);
                sumXX+=(i-start)*(i-start);
                sumXY+=data[i]*(i-start);
            }
            mean=sumY/length;
            stDev=sumYY-(sumY*sumY)/length;
            slope=(sumXY-(sumX*sumY)/length);
            if(sumXX-(sumX*sumX)/length!=0)
                slope/=sumXX-(sumX*sumX)/length;
            else
                slope=0;
            stDev/=length;
            if(stDev==0)    //Flat line
                slope=0;
//            else
//                stDev=Math.sqrt(stDev);
            if(slope==0)
                stDev=0;
        }
        public void setFeatures(double[] data){
            setFeatures(data,0,data.length-1);
        }
        @Override
        public String toString(){
            return "mean="+mean+" stdev = "+stDev+" slope ="+slope;
        }
    } 
}

 

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章