算法介绍
时间序列森林(Time Series Forest, TSF)模型将时间序列转化为子序列的均值、方差和斜率等统计特征,并使用随机森林进行分类。TSF通过使用随机森林方法(以每个间隔的统计信息作为特征)来克服间隔特征空间巨大的问题。训练一棵树时,先随机选择 √m 个区间,计算每个序列在这些随机区间上的均值、标准差和斜率,然后在所得的 3√m 个特征上创建并训练一棵树。
分类树有两个特征。首先,预先定义固定数量的评估点,而不是评估所有可能的分割点,以获取最佳信息增益。作者认为这是加快分类器速度的一种权宜之计,因为它省去了对每个案例的属性值进行排序的需要。其次,引入了一种细化的分割标准,以在具有相等信息增益的候选分割之间进行选择:该标准被定义为分割点与最接近的案例之间的距离(分割间隔)。其背后的直觉是,如果两个分割具有相等的熵增益,则应选择距离最近案例最远的那个分割。如果评估了所有可能的分割点,该度量将没有价值,因为根据定义,分割点取在相邻案例之间的等距位置。
算法代码分析
/**
 * Time Series Forest (TSF) classifier.
 *
 * <p>Each tree in the ensemble is trained on a transformed copy of the data:
 * sqrt(m) random intervals are drawn per tree, and every series is summarised
 * by the (mean, stDev, slope) of each interval, giving 3*sqrt(m) numeric
 * features on which a {@code RandomTree} is built. Prediction is by majority
 * vote over the ensemble.
 *
 * <p>Thread-safety: not thread-safe; {@code testHolder} is mutated during
 * {@link #classifyInstance(Instance)}.
 */
public class TSF extends AbstractClassifierWithTrainingData implements SaveParameterInfo, TrainAccuracyEstimate{
    /** Minimum number of points an interval spans (the original in-line comment
     *  claimed 3 but the code did not enforce it — now it does). */
    private static final int MIN_INTERVAL_LENGTH=3;
    boolean setSeed=false;      // true once a fixed random seed has been supplied
    int seed=0;                 // the fixed random seed (only meaningful if setSeed)
    RandomTree[] trees;         // one tree per ensemble member
    int numTrees=500;           // ensemble size
    int numFeatures;            // intervals per tree = floor(sqrt(m))
    int[][][] intervals;        // [tree][interval][0]=start, [1]=end (both inclusive)
    Random rand;                // RNG used for interval selection
    Instances testHolder;       // reusable single-row dataset for prediction
    boolean trainCV=false;      // estimate train accuracy by cross-validation?
    String trainCVPath="";      // file to write CV results to ("" = do not write)

    /** Creates a TSF with an unseeded RNG. */
    public TSF(){
        rand=new Random();
    }

    /**
     * Creates a TSF with a fixed random seed, for reproducible interval selection.
     * @param s the random seed
     */
    public TSF(int s){
        rand=new Random();
        seed=s;
        rand.setSeed(seed);
        setSeed=true;
    }

    /**
     * Sets (or resets) the random seed; replaces the RNG so the interval
     * stream restarts deterministically from {@code s}.
     * @param s the random seed
     */
    public void setSeed(int s){
        this.setSeed=true;
        seed=s;
        rand=new Random();
        rand.setSeed(seed);
    }

    /**
     * Requests that cross-validated train results be written to {@code train};
     * implies {@code trainCV=true}.
     * @param train output file path for the CV results
     */
    @Override
    public void writeCVTrainToFile(String train) {
        trainCVPath=train;
        trainCV=true;
    }

    /** Enables/disables the cross-validated train-accuracy estimate. */
    @Override
    public void setFindTrainAccuracyEstimate(boolean setCV){
        trainCV=setCV;
    }

    /** @return whether this classifier will estimate train accuracy via CV */
    @Override
    public boolean findsTrainAccuracyEstimate(){ return trainCV;}

    /**
     * Builds the forest.
     *
     * <p>Optionally first runs a cross-validation on a fresh TSF to estimate
     * train accuracy, then draws {@code numTrees} sets of random intervals,
     * transforms the data to interval statistics and trains one tree per set.
     *
     * @param data the training instances (class attribute assumed last;
     *             series values occupy the first {@code numAttributes()-1} attributes)
     * @throws Exception if tree building or the CV sanity check fails
     */
    @Override
    public void buildClassifier(Instances data) throws Exception {
        long t1=System.currentTimeMillis();
        int seriesLength=data.numAttributes()-1;    // class attribute excluded
        numFeatures=(int)Math.sqrt(seriesLength);   // sqrt(m) intervals per tree
        if(trainCV){                                // estimate train accuracy first
            int numFolds=setNumberOfFolds(data);
            CrossValidator cv = new CrossValidator();
            if (setSeed)
                cv.setSeed(seed);
            cv.setNumFolds(numFolds);
            // Cross-validate a fresh copy; trainCV=false on it avoids infinite recursion.
            TSF tsf=new TSF();
            if(setSeed)
                tsf.setSeed(seed);                  // keep the estimate reproducible
            tsf.trainCV=false;
            trainResults=cv.crossValidateWithStats(tsf,data);
        }
        intervals =new int[numTrees][][];
        trees=new RandomTree[numTrees];
        // Header of the transformed dataset: 3*sqrt(m) numeric features + class.
        FastVector atts=new FastVector();
        for(int j=0;j<numFeatures*3;j++)
            atts.addElement(new Attribute("F"+j));
        Attribute target=data.attribute(data.classIndex());
        FastVector vals=new FastVector(target.numValues());
        for(int j=0;j<target.numValues();j++)
            vals.addElement(target.value(j));
        atts.addElement(new Attribute(target.name(),vals));
        Instances result = new Instances("Tree",atts,data.numInstances());
        result.setClassIndex(result.numAttributes()-1);
        // One row per training case; class values are fixed, features rewritten per tree.
        for(int i=0;i<data.numInstances();i++){
            DenseInstance in=new DenseInstance(result.numAttributes());
            in.setValue(result.numAttributes()-1,data.instance(i).classValue());
            result.add(in);
        }
        testHolder =new Instances(result,0);
        testHolder.add(new DenseInstance(result.numAttributes()));
        // Guard degenerate, very short series so nextInt never gets a non-positive bound.
        int minLen=Math.min(MIN_INTERVAL_LENGTH,seriesLength);
        for(int i=0;i<numTrees;i++){
            // 1. Draw sqrt(m) random intervals [start,end] (inclusive), each >= minLen
            //    points long. The previous code drew start over the full series and could
            //    call Random.nextInt(0) — an IllegalArgumentException — when start landed
            //    on the last attribute; bounding start so minLen always fits removes that.
            intervals[i]=new int[numFeatures][2];
            for(int j=0;j<numFeatures;j++){
                intervals[i][j][0]=rand.nextInt(seriesLength-minLen+1);
                int maxExtra=seriesLength-intervals[i][j][0]-minLen; // points available beyond minLen
                int span=minLen-1+(maxExtra>0?rand.nextInt(maxExtra+1):0);
                intervals[i][j][1]=intervals[i][j][0]+span;
            }
            // 2. Transform every training series into (mean, stDev, slope) per interval.
            for(int j=0;j<numFeatures;j++){
                for(int k=0;k<data.numInstances();k++){
                    double[] series=data.instance(k).toDoubleArray();
                    FeatureSet f= new FeatureSet();
                    f.setFeatures(series, intervals[i][j][0], intervals[i][j][1]);
                    result.instance(k).setValue(j*3, f.mean);
                    result.instance(k).setValue(j*3+1, f.stDev);
                    result.instance(k).setValue(j*3+2, f.slope);
                }
            }
            // 3. Train on all 3*sqrt(m) features: interval selection already did
            //    the per-tree feature selection.
            trees[i]=new RandomTree();
            trees[i].setKValue(numFeatures);
            trees[i].buildClassifier(result);
        }
        trainResults.buildTime=System.currentTimeMillis()-t1;
        if(!trainCVPath.isEmpty()){  // was `trainCVPath!=""`: reference comparison on Strings
            OutFile of=new OutFile(trainCVPath);
            of.writeLine(data.relationName()+",TSF,train");
            of.writeLine(getParameters());
            of.writeLine(trainResults.acc+"");
            double[] trueClassVals=trainResults.getTrueClassVals();
            double[] predClassVals=trainResults.getPredClassVals();
            for(int i=0;i<data.numInstances();i++){
                // Sanity check: CV results must line up with the training data order.
                if(data.instance(i).classValue()!=trueClassVals[i]){
                    throw new Exception("ERROR in TSF cross validation, class mismatch!");
                }
                of.writeString((int)trueClassVals[i]+","+(int)predClassVals[i]+",");
                for(double d:trainResults.getDistributionForInstance(i))
                    of.writeString(","+d);
                of.writeString("\n");
            }
        }
    }

    /**
     * Classifies one series by majority vote over the forest; ties are broken
     * towards the lower class index.
     * @param ins the series to classify
     * @return the winning class index as a double
     * @throws Exception if an underlying tree fails to classify
     */
    @Override
    public double classifyInstance(Instance ins) throws Exception {
        int[] votes=new int[ins.numClasses()];
        double[] series=ins.toDoubleArray();
        for(int i=0;i<trees.length;i++){
            // Rebuild this tree's interval features into the shared test holder.
            for(int j=0;j<numFeatures;j++){
                FeatureSet f= new FeatureSet();
                f.setFeatures(series, intervals[i][j][0], intervals[i][j][1]);
                testHolder.instance(0).setValue(j*3, f.mean);
                testHolder.instance(0).setValue(j*3+1, f.stDev);
                testHolder.instance(0).setValue(j*3+2, f.slope);
            }
            int c=(int)trees[i].classifyInstance(testHolder.instance(0));
            votes[c]++;
        }
        // Return the majority vote.
        int maxVote=0;
        for(int i=1;i<votes.length;i++)
            if(votes[i]>votes[maxVote])
                maxVote=i;
        return maxVote;
    }

    /**
     * The three summary statistics of one interval of a series.
     *
     * <p>NOTE: despite its name, {@code stDev} holds the (population) variance —
     * the square root is deliberately commented out below. The cross-zeroing
     * (flat line ⇒ slope 0; zero slope ⇒ stDev 0) is likewise intentional and
     * preserved from the original implementation.
     */
    public static class FeatureSet{
        double mean;    // interval mean
        double stDev;   // interval variance (sqrt intentionally not applied)
        double slope;   // slope of the least-squares line fitted to the interval
        RandomForest r; // unused; retained for source compatibility
        /**
         * Computes mean, variance and least-squares slope of data[start..end].
         * @param data  the full series (may include a trailing class value,
         *              which callers exclude via the end index)
         * @param start first index of the interval (inclusive)
         * @param end   last index of the interval (inclusive)
         */
        public void setFeatures(double[] data, int start, int end){
            double sumX=0,sumYY=0;
            double sumY=0,sumXY=0,sumXX=0;
            int length=end-start+1;
            for(int i=start;i<=end;i++){
                sumY+=data[i];
                sumYY+=data[i]*data[i];
                sumX+=(i-start);
                sumXX+=(i-start)*(i-start);
                sumXY+=data[i]*(i-start);
            }
            mean=sumY/length;
            stDev=sumYY-(sumY*sumY)/length;
            slope=(sumXY-(sumX*sumY)/length);
            if(sumXX-(sumX*sumX)/length!=0)
                slope/=sumXX-(sumX*sumX)/length;
            else
                slope=0;        // single point / zero x-variance: undefined slope
            stDev/=length;
            if(stDev==0)        // flat line
                slope=0;
            //    else
            //        stDev=Math.sqrt(stDev);
            if(slope==0)
                stDev=0;
        }
        /** Convenience overload: statistics of the whole array. */
        public void setFeatures(double[] data){
            setFeatures(data,0,data.length-1);
        }
        @Override
        public String toString(){
            return "mean="+mean+" stdev = "+stDev+" slope ="+slope;
        }
    }
}