Hadoop 實現kmeans 算法

關於kmeans說在前面:kmeans算法有一個硬性的規定就是簇的個數要提前設定。大家可能會質疑這個限制是否影響聚類效果,但是這種擔心是多餘的。在該算法誕生的這麼多年裏,該算法已被證明能夠廣泛的用於解決現實世界問題,即使簇個數k值是次優的,聚類的質量不會受到太大影響。

聚類在現實中很大應用就是對新聞報道進行聚類,以得到頂層類別,如政治、科學、體育、財經等。對此我們傾向於選擇比較小的k值,可能10-20之間。如果需要細粒度的主體,則需要更大的k值。爲了得到較好的聚類質量,首先需要對k值進行預估。一個最簡單粗暴的方法就是基於數據量和需要的簇個數估計,比如我們有100萬新聞,我們希望每個類別新聞有500篇,那就可以簡單估算k值爲1000000/500=2000。

需要明確一點就是kmeans聚類質量的決定因素是使用的距離衡量標準。

關於kmeans 算法思路可以參考:kmeans

算法原理比較簡單,現在需要做的是基於mapreduce 框架去實現這個算法。

從理論上來講用MapReduce技術實現KMeans算法是很Natural的想法:在Mapper中逐個計算樣本點離哪個中心最近,然後發出key-value(樣本點所屬的簇編號,樣本點);shuffle後在Reducer中屬於同一個質心的樣本點在一個list中,方便我們計算新的中心,然後發出新的key-value(質心編號,質心)。但是技術上的事並沒有理論層面那麼簡單。

要實現這個算法需要解決兩個問題:

1. 如何存儲每次聚類的質心。

2. 如何存儲原始聚類數據。

Hadoop中變量或者說數據共享的三種主要方式:

序號 方法
1 使用Configuration的set方法,只適合數據內容比較小的場景
2 將共享文件放在HDFS上,每次都去讀取,效率比較低
3 將共享文件放在DistributedCache裏,在setup初始化一次後,即可多次使用,缺點是不支持修改操作,僅能讀取
從上面可以知道,我們存儲原始聚類數據用DistributedCache,而存儲質心在HDFS file上。

此時我們需要2個質心文件:一個存放上一次的質心prevCenterFile,一個存放reducer更新後的質心currCenterFile。Mapper從prevCenterFile中讀取質心,Reducer把更新後有質心寫入currCenterFile。在主函數中讀入prevCenterFile和currCenterFile,比較前後兩次的質心是否相同(或足夠地接近),如果相同則停止迭代,否則就用currCenterFile覆prevCenterFile(使用fs.rename),進入下一次的迭代。(PS:其實這種方式效率也不是很高,真正使用spark 基於內存運算會效率更高)

代碼參考:kmeans 參考 

package kmeans;
 
import java.io.DataInput;
import java.io.DataOutput;
import java.io.IOException;
 
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.hadoop.io.Writable;
 
public class Sample implements Writable{
    private static final Log log=LogFactory.getLog(Sample.class);
    public static final int DIMENTION=60;
    public double arr[];
     
    public Sample(){
        arr=new double[DIMENTION];
    }
     
    public static double getEulerDist(Sample vec1,Sample vec2){
        if(!(vec1.arr.length==DIMENTION && vec2.arr.length==DIMENTION)){
            log.error("vector's dimention is not "+DIMENTION);
            System.exit(1);
        }
        double dist=0.0;
        for(int i=0;i<DIMENTION;++i){
            dist+=(vec1.arr[i]-vec2.arr[i])*(vec1.arr[i]-vec2.arr[i]);
        }
        return Math.sqrt(dist);
    }
     
    public void clear(){
        for(int i=0;i<arr.length;i++)
            arr[i]=0.0;
    }
     
    @Override
    public String toString(){
        String rect=String.valueOf(arr[0]);
        for(int i=1;i<DIMENTION;i++)
            rect+="\t"+String.valueOf(arr[i]);
        return rect;
    }
 
    @Override
    public void readFields(DataInput in) throws IOException {
        String str[]=in.readUTF().split("\\s+");
        for(int i=0;i<DIMENTION;++i)
            arr[i]=Double.parseDouble(str[i]);
    }
 
    @Override
    public void write(DataOutput out) throws IOException {
        out.writeUTF(this.toString());
    }
}

package kmeans;
 
import java.io.BufferedReader;
import java.io.FileReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.util.Vector;
 
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.conf.Configured;
import org.apache.hadoop.filecache.DistributedCache;
import org.apache.hadoop.fs.FSDataInputStream;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.io.LongWritable;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.mapreduce.Job;
import org.apache.hadoop.mapreduce.Mapper;
import org.apache.hadoop.mapreduce.Reducer;
import org.apache.hadoop.mapreduce.lib.input.FileInputFormat;
import org.apache.hadoop.mapreduce.lib.input.TextInputFormat;
import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat;
import org.apache.hadoop.mapreduce.lib.output.TextOutputFormat;
import org.apache.hadoop.util.Tool;
import org.apache.hadoop.util.ToolRunner;
 
public class KMeans extends Configured implements Tool{
    private static final Log log = LogFactory.getLog(KMeans2.class);
 
    private static final int K = 10;
    private static final int MAXITERATIONS = 300;
    private static final double THRESHOLD = 0.01;
     
    public static boolean stopIteration(Configuration conf) throws IOException{
        FileSystem fs=FileSystem.get(conf);
        Path pervCenterFile=new Path("/user/orisun/input/centers");
        Path currentCenterFile=new Path("/user/orisun/output/part-r-00000");
        if(!(fs.exists(pervCenterFile) && fs.exists(currentCenterFile))){
            log.info("兩個質心文件需要同時存在");
            System.exit(1);
        }
        //比較前後兩次質心的變化是否小於閾值,決定迭代是否繼續
        boolean stop=true;
        String line1,line2;
        FSDataInputStream in1=fs.open(pervCenterFile);
        FSDataInputStream in2=fs.open(currentCenterFile);
        InputStreamReader isr1=new InputStreamReader(in1);
        InputStreamReader isr2=new InputStreamReader(in2);
        BufferedReader br1=new BufferedReader(isr1);
        BufferedReader br2=new BufferedReader(isr2);
        Sample prevCenter,currCenter;
        while((line1=br1.readLine())!=null && (line2=br2.readLine())!=null){
            prevCenter=new Sample();
            currCenter=new Sample();
            String []str1=line1.split("\\s+");
            String []str2=line2.split("\\s+");
            assert(str1[0].equals(str2[0]));
            for(int i=1;i<=Sample.DIMENTION;i++){
                prevCenter.arr[i-1]=Double.parseDouble(str1[i]);
                currCenter.arr[i-1]=Double.parseDouble(str2[i]);
            }
            if(Sample.getEulerDist(prevCenter, currCenter)>THRESHOLD){
                stop=false;
                break;
            }
        }
        //如果還要進行下一次迭代,就用當前質心替代上一次的質心
        if(stop==false){
            fs.delete(pervCenterFile,true);
            if(fs.rename(currentCenterFile, pervCenterFile)==false){
                log.error("質心文件替換失敗");
                System.exit(1);
            }
        }
        return stop;
    }
     
    public static class ClusterMapper extends Mapper<LongWritable, Text, IntWritable, Sample> {
        Vector<Sample> centers = new Vector<Sample>();
        @Override
        //清空centers
        public void setup(Context context){
            for (int i = 0; i < K; i++) {
                centers.add(new Sample());
            }
        }
        @Override
        //從輸入文件讀入centers
        public void map(LongWritable key, Text value, Context context)
                throws IOException, InterruptedException {
            String []str=value.toString().split("\\s+");
            if(str.length!=Sample.DIMENTION+1){
                log.error("讀入centers時維度不對");
                System.exit(1);
            }
            int index=Integer.parseInt(str[0]);
            for(int i=1;i<str.length;i++)
                centers.get(index).arr[i-1]=Double.parseDouble(str[i]);
        }
        @Override
        //找到每個數據點離哪個質心最近
        public void cleanup(Context context) throws IOException,InterruptedException {
            Path []caches=DistributedCache.getLocalCacheFiles(context.getConfiguration());
            if(caches==null || caches.length<=0){
                log.error("data文件不存在");
                System.exit(1);
            }
            BufferedReader br=new BufferedReader(new FileReader(caches[0].toString()));
            Sample sample;
            String line;
            while((line=br.readLine())!=null){
                sample=new Sample();
                String []str=line.split("\\s+");
                for(int i=0;i<Sample.DIMENTION;i++)
                    sample.arr[i]=Double.parseDouble(str[i]);
                 
                int index=-1;
                double minDist=Double.MAX_VALUE;
                for(int i=0;i<K;i++){
                    double dist=Sample.getEulerDist(sample, centers.get(i));
                    if(dist<minDist){
                        minDist=dist;
                        index=i;
                    }
                }
                context.write(new IntWritable(index), sample);
            }
        }
    }
     
    public static class UpdateCenterReducer extends Reducer<IntWritable, Sample, IntWritable, Sample> {
        int prev=-1;
        Sample center=new Sample();;
        int count=0;
        @Override
        //更新每個質心(除最後一個)
        public void reduce(IntWritable key,Iterable<Sample> values,Context context) throws IOException,InterruptedException{
            while(values.iterator().hasNext()){
                Sample value=values.iterator().next();
                if(key.get()!=prev){
                    if(prev!=-1){
                        for(int i=0;i<center.arr.length;i++)
                            center.arr[i]/=count;       
                        context.write(new IntWritable(prev), center);
                    }
                    center.clear();
                    prev=key.get();
                    count=0;
                }
                for(int i=0;i<Sample.DIMENTION;i++)
                    center.arr[i]+=value.arr[i];
                count++;
            }
        }
        @Override
        //更新最後一個質心
        public void cleanup(Context context) throws IOException,InterruptedException{
            for(int i=0;i<center.arr.length;i++)
                center.arr[i]/=count;
            context.write(new IntWritable(prev), center);
        }
    }
 
    @Override
    public int run(String[] args) throws Exception {
        Configuration conf=getConf();
        FileSystem fs=FileSystem.get(conf);
        Job job=new Job(conf);
        job.setJarByClass(KMeans.class);
         
        //質心文件每行的第一個數字是索引
        FileInputFormat.setInputPaths(job, "/user/orisun/input/centers");
        Path outDir=new Path("/user/orisun/output");
        fs.delete(outDir,true);
        FileOutputFormat.setOutputPath(job, outDir);
         
        job.setInputFormatClass(TextInputFormat.class);
        job.setOutputFormatClass(TextOutputFormat.class);
        job.setMapperClass(ClusterMapper.class);
        job.setReducerClass(UpdateCenterReducer.class);
        job.setOutputKeyClass(IntWritable.class);
        job.setOutputValueClass(Sample.class);
         
        return job.waitForCompletion(true)?0:1;
    }
    public static void main(String[] args) throws Exception {
        Configuration conf = new Configuration();
        FileSystem fs=FileSystem.get(conf);
         
        //樣本數據文件中每個樣本不需要標記索引
        Path dataFile=new Path("/user/orisun/input/data");
        DistributedCache.addCacheFile(dataFile.toUri(), conf);
 
        int iteration = 0;
        int success = 1;
        do {
            success ^= ToolRunner.run(conf, new KMeans(), args);
            log.info("iteration "+iteration+" end");
        } while (success == 1 && iteration++ < MAXITERATIONS
                && (!stopIteration(conf)));
        log.info("Success.Iteration=" + iteration);
         
        //迭代完成後再執行一次mapper,輸出每個樣本點所屬的分類--在/user/orisun/output2/part-m-00000中
        //質心文件保存在/user/orisun/input/centers中
        Job job=new Job(conf);
        job.setJarByClass(KMeans.class);
         
        FileInputFormat.setInputPaths(job, "/user/orisun/input/centers");
        Path outDir=new Path("/user/orisun/output2");
        fs.delete(outDir,true);
        FileOutputFormat.setOutputPath(job, outDir);
         
        job.setInputFormatClass(TextInputFormat.class);
        job.setOutputFormatClass(TextOutputFormat.class);
        job.setMapperClass(ClusterMapper.class);
        job.setNumReduceTasks(0);
        job.setOutputKeyClass(IntWritable.class);
        job.setOutputValueClass(Sample.class);
         
        job.waitForCompletion(true);
    }
}


發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章