Lire源碼解析一

LIRE(Lucene Image Retrieval)是一個以圖搜圖的 Java 開源框架。這幾天沒什麼事,就讀了點源碼,並寫了點註釋,特在這裏分享給大家。

這裏主要給出的是 BOVWBuilder.java、KMeans.java 和 Cluster.java。它們的作用就是用詞頻對特徵進行編碼,用到的是 BOF(bag of features)模型。原理是:提取 N 張圖片的特徵(比如 SIFT),把所有特徵放在一起得到一個矩陣,然後對該矩陣進行 k-means 聚類,就會得到若干個聚類中心;對於新來的一幅圖像,我們分別計算它的每個特徵點與哪個聚類中心最近,該聚類中心對應的計數就加 1,這樣就可以把圖像編碼成一個維數與聚類中心個數相等的向量。

一切都從BOVWBuilder中index函數開始...

BOVWBuilder.java(包含註釋)

package lmc.imageretrieval.imageanalysis.bovw;

import java.io.File;
import java.io.IOException;
import java.text.DecimalFormat;
import java.text.NumberFormat;
import java.util.Arrays;
import java.util.HashSet;
import java.util.Iterator;
import java.util.LinkedList;

import javax.swing.ProgressMonitor;

import lmc.imageretrieval.imageanalysis.Histogram;
import lmc.imageretrieval.imageanalysis.LireFeature;
import lmc.imageretrieval.tools.DocumentBuilder;
import lmc.imageretrieval.utils.SerializationUtils;

import org.apache.lucene.analysis.core.WhitespaceAnalyzer;
import org.apache.lucene.document.Document;
import org.apache.lucene.document.Field;
import org.apache.lucene.document.StoredField;
import org.apache.lucene.document.TextField;
import org.apache.lucene.index.DirectoryReader;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.IndexWriter;
import org.apache.lucene.index.IndexWriterConfig;
import org.apache.lucene.index.IndexWriterConfig.OpenMode;
import org.apache.lucene.index.IndexableField;
import org.apache.lucene.index.MultiFields;
import org.apache.lucene.index.Term;
import org.apache.lucene.util.Bits;
import org.apache.lucene.util.Version;

public class BOVWBuilder {
    // Reader on the index whose documents are sampled, read and re-written with visual words.
    IndexReader reader;
    // Number of documents used to build the vocabulary / clusters (a negative value selects half of the index).
    private int numDocsForVocabulary = 500;
    // Number of k-means clusters, i.e. the length of the visual word histogram.
    private int numClusters = 512;
    // The visual vocabulary; set after clustering in index() or read from clusterFile.
    private Cluster[] clusters = null;
    // Number format used for the console log output.
    DecimalFormat df = (DecimalFormat) NumberFormat.getNumberInstance();
    // Optional progress monitor, null unless set via setProgressMonitor().
    private ProgressMonitor pm = null;

    // Prototype of the local feature: provides the field/feature names and the class to instantiate.
    protected LireFeature lireFeature;
    // The following four values are derived from lireFeature in init().
    protected String localFeatureFieldName;
    protected String visualWordsFieldName;
    protected String localFeatureHistFieldName;
    protected String clusterFile;

    // If true, createVisualWords() removes the local feature fields from each document it processes.
    public static boolean DELETE_LOCAL_FEATURES = true;
    /**
     * Creates a builder on the given reader with the default vocabulary settings.
     * <p>
     * This constructor leaves {@code lireFeature} unset, while {@link #init()} dereferences it,
     * so an instance created this way cannot index unless the feature is assigned otherwise.
     *
     * @param reader the reader used to open the Lucene index
     * @deprecated use {@link #BOVWBuilder(IndexReader, LireFeature)} instead.
     */
    @Deprecated
    public BOVWBuilder(IndexReader reader) {
        this.reader = reader;
    }

    /**
     * Creates a new instance of the BOVWBuilder using the given reader. The numDocsForVocabulary
     * indicates how many documents of the index are used to build the vocabulary (clusters).
     * <p>
     * This constructor leaves {@code lireFeature} unset, while {@link #init()} dereferences it,
     * so an instance created this way cannot index unless the feature is assigned otherwise.
     *
     * @param reader               the reader used to open the Lucene index,
     * @param numDocsForVocabulary gives the number of documents for building the vocabulary (clusters).
     * @deprecated use {@link #BOVWBuilder(IndexReader, LireFeature, int)} instead.
     */
    @Deprecated
    public BOVWBuilder(IndexReader reader, int numDocsForVocabulary) {
        this.reader = reader;
        this.numDocsForVocabulary = numDocsForVocabulary;
    }

    /**
     * Creates a new instance of the BOVWBuilder using the given reader. The numDocsForVocabulary
     * indicates how many documents of the index are used to build the vocabulary (clusters). The numClusters gives
     * the number of clusters k-means should find. Note that this number should be lower than the number of features,
     * otherwise an exception will be thrown while indexing.
     * <p>
     * This constructor leaves {@code lireFeature} unset, while {@link #init()} dereferences it,
     * so an instance created this way cannot index unless the feature is assigned otherwise.
     *
     * @param reader               the index reader
     * @param numDocsForVocabulary the number of documents that should be sampled for building the visual vocabulary
     * @param numClusters          the size of the visual vocabulary
     * @deprecated use {@link #BOVWBuilder(IndexReader, LireFeature, int, int)} instead.
     */
    @Deprecated
    public BOVWBuilder(IndexReader reader, int numDocsForVocabulary, int numClusters) {
        this.numDocsForVocabulary = numDocsForVocabulary;
        this.numClusters = numClusters;
        this.reader = reader;
    }

    /**
     * Creates a new instance of the BOVWBuilder using the given reader and local feature.
     * The number of vocabulary documents and the number of clusters keep their field defaults.
     *
     * @param reader               the index reader
     * @param lireFeature          the local feature; its class is instantiated for decoding stored
     *                             features and its names are used to derive the field names
     */
    public BOVWBuilder(IndexReader reader, LireFeature lireFeature) {
        this.reader = reader;
        this.lireFeature = lireFeature;
    }

    /**
     * Creates a builder for the given reader and local feature that samples
     * {@code numDocsForVocabulary} documents of the index to build the vocabulary (clusters).
     * The number of clusters keeps its field default.
     *
     * @param reader               the index reader
     * @param lireFeature          the local feature used
     * @param numDocsForVocabulary the number of documents that should be sampled for building the visual vocabulary
     */
    public BOVWBuilder(IndexReader reader, LireFeature lireFeature, int numDocsForVocabulary) {
        this(reader, lireFeature);
        this.numDocsForVocabulary = numDocsForVocabulary;
    }

    /**
     * Creates a builder for the given reader and local feature with an explicit sample size and
     * vocabulary size. {@code numClusters} is the number of clusters k-means should find; it has to
     * be lower than the number of local features found, otherwise {@link #index()} throws.
     *
     * @param reader               the index reader
     * @param lireFeature          the local feature used
     * @param numDocsForVocabulary the number of documents that should be sampled for building the visual vocabulary
     * @param numClusters          the size of the visual vocabulary
     */
    public BOVWBuilder(IndexReader reader, LireFeature lireFeature, int numDocsForVocabulary, int numClusters) {
        this(reader, lireFeature, numDocsForVocabulary);
        this.numClusters = numClusters;
    }

    /**
     * Derives the Lucene field names and the cluster file path from the configured feature.
     * Requires {@code lireFeature} to be set.
     */
    protected void init() {
        final String baseFieldName = lireFeature.getFieldName();
        localFeatureFieldName = baseFieldName;
        visualWordsFieldName = baseFieldName + DocumentBuilder.FIELD_NAME_BOVW;
        localFeatureHistFieldName = baseFieldName + DocumentBuilder.FIELD_NAME_BOVW_VECTOR;
        // The vocabulary is serialized next to the working directory, one file per feature type.
        clusterFile = "./clusters-bovw" + lireFeature.getFeatureName() + ".dat";
    }

    /**
     * Uses an existing index, where each and every document should have a set of local features. A number of
     * random images (numDocsForVocabulary) is selected and clustered to get a vocabulary of visual words
     * (the cluster means). For all images a histogram on the visual words is created and added to the documents.
     * Pre-existing histograms are deleted, so this method can be used for re-indexing.
     * <p>
     * The clusters are written to {@code clusterFile}; the documents are re-written by eight worker
     * threads and the index is force-merged into a single segment at the end.
     *
     * @throws java.io.IOException           if reading from or writing to the index or cluster file fails
     * @throws UnsupportedOperationException if fewer local features than clusters were found
     */
    public void index() throws IOException {
        init();
        df.setMaximumFractionDigits(3);
        // find the documents for building the vocabulary:
        HashSet<Integer> docIDs = selectVocabularyDocs();
        KMeans k = new KMeans(numClusters);
        // fill the KMeans object; this list is a scratch buffer re-used for every document
        // (KMeans.addImage copies its content).
        LinkedList<double[]> features = new LinkedList<double[]>();
        // Needed for check whether the document is deleted.
        Bits liveDocs = MultiFields.getLiveDocs(reader);
        for (Iterator<Integer> iterator = docIDs.iterator(); iterator.hasNext(); ) {
            int nextDoc = iterator.next();
            if (reader.hasDeletions() && !liveDocs.get(nextDoc)) continue; // if it is deleted, just ignore it.
            Document d = reader.document(nextDoc);
            features.clear();
            // all stored local features of this document
            IndexableField[] fields = d.getFields(localFeatureFieldName);
            // the identifier of the image the document stands for
            String file = d.getValues(DocumentBuilder.FIELD_NAME_IDENTIFIER)[0];
            for (int j = 0; j < fields.length; j++) {
                // a fresh feature instance per descriptor, so the returned histogram array is not shared
                LireFeature f = getFeatureInstance();
                f.setByteArrayRepresentation(fields[j].binaryValue().bytes, fields[j].binaryValue().offset, fields[j].binaryValue().length);
                features.add(f.getDoubleHistogram());
            }
            k.addImage(file, features);
        }
        if (pm != null) { // set to 5 of 100 before clustering starts.
            pm.setProgress(5);
            pm.setNote("Starting clustering");
        }
        if (k.getFeatureCount() < numClusters) {
            // this cannot work. You need more data points than clusters.
            // Report the total feature count; the local list only holds the last document's features.
            throw new UnsupportedOperationException("Only " + k.getFeatureCount() + " features found to cluster in " + numClusters + ". Try to use less clusters or more images.");
        }
        // do the clustering:
        System.out.println("Number of local features: " + df.format(k.getFeatureCount()));
        System.out.println("Starting clustering ...");
        k.init();
        System.out.println("Step.");
        double time = System.currentTimeMillis();
        double laststress = k.clusteringStep();

        if (pm != null) { // set to 8 of 100 after first step.
            pm.setProgress(8);
            pm.setNote("Step 1 finished");
        }

        System.out.println(getDuration(time) + " -> Next step.");
        time = System.currentTimeMillis();
        double newStress = k.clusteringStep();

        if (pm != null) { // set to 11 of 100 after second step.
            pm.setProgress(11);
            pm.setNote("Step 2 finished");
        }

        // critical part: Give the difference in between steps as a constraint for accuracy vs. runtime trade off.
        // Iteration stops once the stress changes by no more than max(20, featureCount / 1000) between two steps ...
        double threshold = Math.max(20d, (double) k.getFeatureCount() / 1000d);
        System.out.println("Threshold = " + df.format(threshold));
        // ... or after at most 11 clustering steps in total (two are already done here).
        int cstep = 3;
        while (Math.abs(newStress - laststress) > threshold && cstep < 12) {
            System.out.println(getDuration(time) + " -> Next step. Stress difference ~ |" + (int) newStress + " - " + (int) laststress + "| = " + df.format(Math.abs(newStress - laststress)));
            time = System.currentTimeMillis();
            laststress = newStress;
            newStress = k.clusteringStep();
            if (pm != null) { // progress grows by 3 of 100 per clustering step.
                pm.setProgress(cstep * 3 + 5);
                pm.setNote("Step " + cstep + " finished");
            }
            cstep++;
        }
        // Serializing clusters to a file on the disk ...
        clusters = k.getClusters();
        Cluster.writeClusters(clusters, clusterFile);
        //  create & store histograms:
        System.out.println("Creating histograms ...");
        time = System.currentTimeMillis();
        @SuppressWarnings("deprecation")
        IndexWriterConfig conf = new IndexWriterConfig(Version.LUCENE_4_10_2,
                new WhitespaceAnalyzer(Version.LUCENE_4_10_2));
        conf.setOpenMode(OpenMode.CREATE_OR_APPEND);
        IndexWriter iw = new IndexWriter(((DirectoryReader) reader).directory(), conf);
        if (pm != null) { // set to 50 of 100 after clustering.
            pm.setProgress(50);
            pm.setNote("Clustering finished");
        }
        // parallelized indexing: the document id range is split into equally sized parts,
        // the last worker also takes the remainder and is the only one reporting progress.
        LinkedList<Thread> threads = new LinkedList<Thread>();
        int numThreads = 8;
        int step = reader.maxDoc() / numThreads;
        for (int part = 0; part < numThreads; part++) {
            Indexer indexer = null;
            if (part < numThreads - 1) indexer = new Indexer(part * step, (part + 1) * step, iw, null);
            else indexer = new Indexer(part * step, reader.maxDoc(), iw, pm);
            Thread t = new Thread(indexer);
            threads.add(t);
            t.start();
        }
        // Remember an interruption instead of losing it; the flag is restored once the writer is closed.
        boolean interrupted = false;
        for (Iterator<Thread> iterator = threads.iterator(); iterator.hasNext(); ) {
            Thread next = iterator.next();
            try {
                next.join();
            } catch (InterruptedException e) {
                interrupted = true;
                e.printStackTrace();
            }
        }
        if (pm != null) { // set to 95 of 100 after the histograms are written.
            pm.setProgress(95);
            pm.setNote("Indexing finished, optimizing index now.");
        }

        System.out.println(getDuration(time));
        iw.commit();
        // this one does the "old" commit(), it removes the deleted SURF features.
        iw.forceMerge(1);
        iw.close();
        if (pm != null) { // set to 100 of 100, everything is done.
            pm.setProgress(100);
            pm.setNote("Indexing & optimization finished");
            pm.close();
        }
        System.out.println("Finished.");
        if (interrupted) {
            Thread.currentThread().interrupt();
        }
    }

   // 此函數沒有用
    /**
     * Adds visual words to those documents of the index that do not have any yet. The vocabulary
     * is not re-computed: the clusters are read from {@code clusterFile}. Afterwards the index is
     * committed and force-merged into one segment.
     *
     * @throws IOException if reading the clusters or accessing the index fails
     */
    public void indexMissing() throws IOException {
        init();
        // Reading clusters from disk:
        clusters = Cluster.readClusters(clusterFile);
        //  create & store histograms:
        System.out.println("Creating histograms ...");
        LireFeature f = getFeatureInstance();

        // Needed to check whether a document is deleted.
        Bits liveDocs = MultiFields.getLiveDocs(reader);

        // based on a bug report from Einav Itamar
        @SuppressWarnings("deprecation")
        IndexWriterConfig conf = new IndexWriterConfig(Version.LUCENE_4_10_2,
                new WhitespaceAnalyzer(Version.LUCENE_4_10_2));
        IndexWriter iw = new IndexWriter(((DirectoryReader) reader).directory(), conf);
        for (int docId = 0; docId < reader.maxDoc(); docId++) {
            boolean deleted = reader.hasDeletions() && !liveDocs.get(docId);
            if (deleted) {
                continue; // deleted documents are ignored.
            }
            Document doc = reader.document(docId);
            String[] existingWords = doc.getValues(visualWordsFieldName);
            if (existingWords != null && existingWords.length > 0) {
                continue; // this document already has visual words.
            }
            createVisualWords(doc, f);
            // write the new version; the identifier is the key for the update.
            String identifier = doc.getValues(DocumentBuilder.FIELD_NAME_IDENTIFIER)[0];
            iw.updateDocument(new Term(DocumentBuilder.FIELD_NAME_IDENTIFIER, identifier), doc);
        }
        iw.commit();
        // permanently removes the deleted docs.
        iw.forceMerge(1);
        iw.close();
        System.out.println("Finished.");
    }

    /**
     * Takes one single document, creates the visual words for it from the clusters stored in
     * {@code clusterFile} and adds them to the document. The same document instance is returned.
     * <p>
     * The field names and the cluster file path are derived on demand, so this method also works
     * on an instance on which neither {@link #index()} nor {@link #indexMissing()} has run.
     *
     * @param d the document to use for adding the visual words
     * @return the given document, now carrying the visual word fields
     * @throws IOException if the cluster file cannot be read
     */
    public Document getVisualWords(Document d) throws IOException {
        // clusterFile and the field names are only assigned in init(); make sure that happened.
        if (clusterFile == null) {
            init();
        }
        clusters = Cluster.readClusters(clusterFile);
        LireFeature f = getFeatureInstance();
        createVisualWords(d, f);

        return d;
    }


    /**
     * Rescales the histogram in place so that its largest bin maps to 128 and every bin is
     * floored to a whole number. Currently not called from within this class.
     *
     * @param histogram the histogram to quantize; modified in place
     */
    @SuppressWarnings("unused")
    private void quantize(double[] histogram) {
        double largest = 0;
        for (double bin : histogram) {
            largest = Math.max(largest, bin);
        }
        for (int bin = 0; bin < histogram.length; bin++) {
            double scaled = (histogram[bin] * 128d) / largest;
            histogram[bin] = (int) Math.floor(scaled);
        }
    }

    /**
     * Finds the cluster whose distance to the given feature is smallest. On ties the cluster
     * with the lowest index wins.
     *
     * @param f the feature to assign
     * @return the index of the nearest cluster.
     */
    private int clusterForFeature(Histogram f) {
        int nearest = 0;
        double nearestDistance = clusters[0].getDistance(f);
        for (int candidate = 1; candidate < clusters.length; candidate++) {
            double candidateDistance = clusters[candidate].getDistance(f);
            if (candidateDistance < nearestDistance) {
                nearestDistance = candidateDistance;
                nearest = candidate;
            }
        }
        return nearest;
    }

    /**
     * Turns a visual word histogram into a whitespace separated text: the hexadecimal index of
     * each bin is repeated as often as the (truncated) count of that bin says.
     *
     * @param hist the visual word histogram
     * @return the text representation, each word followed by a blank
     */
    private String arrayToVisualWordString(double[] hist) {
        StringBuilder words = new StringBuilder(1024);
        for (int word = 0; word < hist.length; word++) {
            String token = Integer.toHexString(word);
            for (int remaining = (int) hist[word]; remaining > 0; remaining--) {
                words.append(token).append(' ');
            }
        }
        return words.toString();
    }
        // 選擇圖片進行聚類
    /**
     * Selects the ids of the documents used for building the vocabulary.
     * <p>
     * The sample size is {@code numDocsForVocabulary}, capped at the number of documents in the
     * index; a negative {@code numDocsForVocabulary} selects half of the documents. If the sample
     * covers the whole index all ids are returned, otherwise the ids are drawn uniformly at random
     * without replacement. Deleted documents are not filtered here.
     *
     * @return the set of selected document ids
     * @throws IOException declared for interface compatibility with callers
     */
    private HashSet<Integer> selectVocabularyDocs() throws IOException {
        int maxDocs = reader.maxDoc();
        int capacity = Math.min(numDocsForVocabulary, maxDocs);
        if (capacity < 0) capacity = maxDocs / 2;   // negative means: take half of the documents.
        HashSet<Integer> result = new HashSet<Integer>(capacity);

        // as many or more documents requested than available: take them all.
        if (capacity >= maxDocs) {
            for (int i = 0; i < maxDocs; i++) {
                result.add(i);
            }
            return result;
        }

        // Partial Fisher-Yates shuffle: after step i the first i + 1 slots hold a uniform random
        // sample of distinct ids, so this always terminates after exactly 'capacity' steps.
        int[] candidates = new int[maxDocs];
        for (int i = 0; i < maxDocs; i++) {
            candidates[i] = i;
        }
        for (int i = 0; i < capacity; i++) {
            int pick = i + (int) Math.floor(Math.random() * (double) (maxDocs - i));
            int chosen = candidates[pick];
            candidates[pick] = candidates[i];
            candidates[i] = chosen;
            result.add(chosen);
        }
        return result;
    }

//    protected abstract LireFeature getFeatureInstance();

    /**
     * Creates a new, empty feature of the same class as the configured {@code lireFeature} using
     * its no-argument constructor.
     *
     * @return the new instance, or {@code null} if it could not be created (the stack trace is printed)
     */
    protected LireFeature getFeatureInstance() {
        try {
            return lireFeature.getClass().newInstance();
        } catch (InstantiationException e) {
            e.printStackTrace();
        } catch (IllegalAccessException e) {
            e.printStackTrace();
        }
        return null;
    }

    /**
     * Worker that creates the visual words for the documents with ids in {@code [start, end)}
     * and writes the updated documents through the shared index writer.
     */
    private class Indexer implements Runnable {
        int start, end;
        IndexWriter iw;
        // may be null; index() hands the monitor to the last worker only.
        ProgressMonitor pm = null;

        private Indexer(int start, int end, IndexWriter iw, ProgressMonitor pm) {
            this.start = start;
            this.end = end;
            this.iw = iw;
            this.pm = pm;
        }

        @Override
        public void run() {
            // one feature instance per worker, re-used as decoding buffer for every descriptor.
            LireFeature f = getFeatureInstance();
            // null if the index has no deletions.
            Bits liveDocs = MultiFields.getLiveDocs(reader);
            for (int i = start; i < end; i++) {
                // Skip deleted documents, like the other loops over the index do; updating
                // them would add them to the index again.
                if (liveDocs != null && !liveDocs.get(i)) continue;
                try {
                    Document d = reader.document(i);
                    createVisualWords(d, f);
                    // the identifier is the key for replacing the old version of the document.
                    iw.updateDocument(new Term(DocumentBuilder.FIELD_NAME_IDENTIFIER, d.getValues(DocumentBuilder.FIELD_NAME_IDENTIFIER)[0]), d);
                    if (pm != null) {
                        // this phase covers the range 50..95 of the overall progress.
                        double len = (double) (end - start);
                        double percent = (double) (i - start) / len * 45d + 50;
                        pm.setProgress((int) percent);
                        pm.setNote("Creating visual words, ~" + (int) percent + "% finished");
                    }
                } catch (IOException e) {
                    // best effort: a failing document is reported and the next one is processed.
                    e.printStackTrace();
                }
            }
        }
    }

    /**
     * Computes the visual word histogram of a document from its stored local features and adds
     * it to the document, once as text and once as serialized array. Previously existing
     * visual word fields are replaced.
     *
     * @param d the document to extend; modified in place
     * @param f feature instance used as buffer for decoding the stored local features
     */
    private void createVisualWords(Document d, LireFeature f) {
        // a new array is zero-filled already, one bin per cluster.
        double[] histogram = new double[numClusters];
        IndexableField[] localFeatures = d.getFields(localFeatureFieldName);
        // drop results of an earlier run, if any.
        d.removeField(visualWordsFieldName);
        d.removeField(localFeatureHistFieldName);

        // count for each cluster how many local features are nearest to it.
        for (IndexableField localFeature : localFeatures) {
            f.setByteArrayRepresentation(localFeature.binaryValue().bytes, localFeature.binaryValue().offset, localFeature.binaryValue().length);
            histogram[clusterForFeature((Histogram) f)]++;
        }
        TextField wordsField = new TextField(visualWordsFieldName, arrayToVisualWordString(histogram), Field.Store.YES);
        d.add(wordsField);
        StoredField histogramField = new StoredField(localFeatureHistFieldName, SerializationUtils.toByteArray(histogram));
        d.add(histogramField);
        // the local features are no longer needed and are removed to save space if requested.
        if (DELETE_LOCAL_FEATURES) {
            d.removeFields(localFeatureFieldName);
        }
    }

    /**
     * Formats the time elapsed since the given start time as {@code mm:ss}.
     *
     * @param time the start time in milliseconds, as obtained from System.currentTimeMillis()
     * @return the elapsed minutes and seconds, each zero padded to two digits
     */
    private String getDuration(double time) {
        double elapsedMinutes = (System.currentTimeMillis() - time) / (1000 * 60);
        int wholeMinutes = (int) elapsedMinutes;
        int seconds = (int) ((elapsedMinutes - Math.floor(elapsedMinutes)) * 60);
        return String.format("%02d:%02d", wholeMinutes, seconds);
    }

    /**
     * Sets the progress monitor that is updated while indexing.
     *
     * @param pm the monitor to report to, or null to disable progress reporting
     */
    public void setProgressMonitor(ProgressMonitor pm) {
        this.pm = pm;
    }

}

KMeans.java(包含註釋)

package lmc.imageretrieval.imageanalysis.bovw;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import java.util.Set;

import lmc.imageretrieval.imageanalysis.Histogram;
import lmc.imageretrieval.utils.StatsUtils;

public class KMeans {
    // All images added via addImage(), each with its own copy of the feature list.
    protected List<Image> images = new LinkedList<Image>();
    // countAllFeatures: number of features added so far; numClusters: number of clusters to find.
    protected int countAllFeatures = 0, numClusters = 256;
    // Flat list of all features without NaN values, filled in init(); cluster members are indices into it.
    protected ArrayList<double[]> features = null;
    // The clusters, created in init().
    protected Cluster[] clusters = null;
    // Lazily built map from feature array to the index of its cluster, see createIndex().
    protected HashMap<double[], Integer> featureIndex = null;

    /**
     * Creates a k-means instance with the default number of clusters (256).
     */
    public KMeans() {

    }

    /**
     * Creates a k-means instance that finds the given number of clusters.
     *
     * @param numClusters the number of clusters
     */
    public KMeans(int numClusters) {
        this.numClusters = numClusters;
    }

    /**
     * Registers an image and its local features for clustering. The feature list is copied,
     * so the caller may re-use the given list afterwards.
     *
     * @param identifier the identifier of the image
     * @param features   the local features of the image
     */
    public void addImage(String identifier, List<double[]> features) {
        Image image = new Image(identifier, features);
        images.add(image);
        countAllFeatures += features.size();
    }

    /**
     * Returns the number of features added via addImage(). This count includes features
     * containing NaN values, which init() leaves out of the clustering.
     *
     * @return the number of features added so far
     */
    public int getFeatureCount() {
        return countAllFeatures;
    }

    /**
     * Prepares the clustering: collects the features of all images into one list (features
     * containing NaN values are left out) and initializes each cluster mean with a copy of a
     * randomly drawn feature.
     *
     * @throws UnsupportedOperationException if there are not more usable features than clusters
     */
    public void init() {
        // create a set of all features:
        features = new ArrayList<double[]>(countAllFeatures);
        for (Image image : images) {
            if (image.features.size() > 0) {
                for (double[] histogram : image.features) {
                    if (!hasNaNs(histogram)) features.add(histogram);
                }
            } else {
                System.err.println("Image with no features: " + image.identifier);
            }
        }
        // --- check if there are (i) enough images and (ii) enough features
        if (images.size() < 500) {   // only a warning, clustering goes on.
            System.err.println("WARNING: Please note that this approach has been implemented for big data and *a lot of images*. " +
                    "You might not get appropriate results with a small number of images employed for constructing the visual vocabulary.");
        }
        if (features.size() < numClusters * 2) {   // only a warning, clustering goes on.
            System.err.println("WARNING: Please note that the number of local features, in this case " + features.size() + ", is " +
                    "smaller than the recommended minimum number, which is two times the number of visual words, in your case 2*" + numClusters +
                    ". Please adapt your data and either use images with more local features or more images for creating the visual vocabulary.");
        }
        if (features.size() < numClusters + 1) {
            System.err.println("CRITICAL: The number of features is smaller than the number of clusters. This cannot work as there has to be at least one " +
                    "feature per cluster. Aborting clustering now.");
            System.out.println("images: " + images.size());
            System.out.println("features: " + features.size());
            System.out.println("clusters: " + numClusters);
            // Fail with an exception the caller can handle instead of terminating the whole JVM.
            throw new UnsupportedOperationException("Only " + features.size() + " usable features found to cluster in "
                    + numClusters + ". Try to use less clusters or more images.");
        }
        // find first clusters:
        clusters = new Cluster[numClusters];
        Set<Integer> medians = selectInitialMedians(numClusters);
        assert (medians.size() == numClusters); // this has to be the same ...
        Iterator<Integer> mediansIterator = medians.iterator();
        for (int i = 0; i < clusters.length; i++) {
            double[] descriptor = features.get(mediansIterator.next());
            clusters[i] = new Cluster(new double[descriptor.length]);   // implicitly setting the length of the mean array.
            // copy the values, so that updating the mean does not change the feature itself.
            System.arraycopy(descriptor, 0, clusters[i].mean, 0, descriptor.length);
        }
    }

    /**
     * Selects the indices of the features used as initial cluster means by delegating to
     * StatsUtils.drawSample with the number of clusters and the number of features.
     * NOTE(review): init() asserts that the result has exactly numClusters entries, so the
     * sample is presumably drawn without repetition - confirm against StatsUtils.
     *
     * @param numClusters the number of indices to select
     * @return indices into the feature list
     */
    protected Set<Integer> selectInitialMedians(int numClusters) {
        return StatsUtils.drawSample(numClusters, features.size());
    }

    /**
     * Does one k-means step: empties all clusters, assigns every feature to its nearest cluster,
     * recomputes the cluster means and returns the overall stress. Repeat this until the stress
     * is below a threshold or does not change a lot in between two subsequent steps.
     *
     * @return the overall stress after this step
     */
    public double clusteringStep() {
        for (Cluster cluster : clusters) {
            cluster.members.clear();
        }
        reOrganizeFeatures();
        recomputeMeans();
        return overallStress();
    }

    /**
     * Checks whether the given histogram contains at least one NaN value. If so, a note is
     * written to standard error and the whole histogram is dumped to standard output.
     *
     * @param histogram the values to check
     * @return true if a NaN was found
     */
    protected boolean hasNaNs(double[] histogram) {
        boolean found = false;
        for (double value : histogram) {
            if (Double.isNaN(value)) {
                found = true;
                break;
            }
        }
        if (!found) {
            return false;
        }
        System.err.println("Found a NaN in init");
        for (double value : histogram) {
            System.out.print(value + ", ");
        }
        System.out.println("");
        return true;
    }

    /**
     * Assigns every feature to the cluster with the smallest distance to it by adding the index
     * of the feature to the members of that cluster. On ties the cluster with the lowest index wins.
     */
    protected void reOrganizeFeatures() {
        for (int featureId = 0; featureId < features.size(); featureId++) {
            double[] feature = features.get(featureId);
            int nearest = 0;
            double nearestDistance = clusters[0].getDistance(feature);
            for (int candidate = 1; candidate < clusters.length; candidate++) {
                double candidateDistance = clusters[candidate].getDistance(feature);
                if (candidateDistance < nearestDistance) {
                    nearest = candidate;
                    nearestDistance = candidateDistance;
                }
            }
            clusters[nearest].members.add(featureId);
        }
    }

    /**
     * Computes the mean per cluster (the averaged vector of its members). A cluster without
     * members gets a copy of a randomly chosen feature as its new mean; clusters with at most
     * one member are reported on standard error.
     */
    protected void recomputeMeans() {
        int dimensions = features.get(0).length;
        for (int c = 0; c < clusters.length; c++) {
            Cluster cluster = clusters[c];
            double[] mean = cluster.mean;
            int memberCount = cluster.members.size();
            // sum up all members, dimension by dimension ...
            for (int d = 0; d < dimensions; d++) {
                mean[d] = 0;
            }
            for (Integer member : cluster.members) {
                double[] feature = features.get(member);
                for (int d = 0; d < dimensions; d++) {
                    mean[d] += feature[d];
                }
            }
            // ... and divide by their number; with a single member the sum is the mean already.
            if (memberCount > 1) {
                for (int d = 0; d < dimensions; d++) {
                    mean[d] = mean[d] / (double) memberCount;
                }
            }
            if (memberCount == 1) {
                System.err.println("** There is just one member in cluster " + c);
            } else if (memberCount < 1) {
                System.err.println("** There is NO member in cluster " + c);
                // re-seed the empty cluster with a random feature.
                int index = (int) Math.floor(Math.random() * features.size());
                System.arraycopy(features.get(index), 0, clusters[c].mean, 0, clusters[c].mean.length);
            }
        }
    }

    /**
     * Computes the overall stress of the current clustering: the sum over all cluster members
     * of the absolute differences between the member and the mean of its cluster, taken per
     * dimension. The stress of a single member is accumulated in single precision.
     *
     * @return the overall stress
     */
    protected double overallStress() {
        double total = 0;
        int dimensions = features.get(0).length;
        for (Cluster cluster : clusters) {
            double[] mean = cluster.mean;
            for (Integer member : cluster.members) {
                double[] feature = features.get(member);
                float memberStress = 0;
                for (int d = 0; d < dimensions; d++) {
                    memberStress += Math.abs(mean[d] - feature[d]);
                }
                total += memberStress;
            }
        }
        return total;
    }

    /**
     * Returns the internal array of clusters; it is null until init() has been called.
     *
     * @return the clusters
     */
    public Cluster[] getClusters() {
        return clusters;
    }

    /**
     * Returns the internal list of the images added so far.
     *
     * @return the images
     */
    public List<Image> getImages() {
        return images;
    }

    /**
     * Returns the number of desired clusters.
     *
     * @return the number of clusters
     */
    public int getNumClusters() {
        return numClusters;
    }

    /**
     * Sets the number of desired clusters. The value is read when init() creates the clusters.
     *
     * @param numClusters the number of clusters
     */
    public void setNumClusters(int numClusters) {
        this.numClusters = numClusters;
    }

    /**
     * Builds the map from each clustered feature array to the index of the cluster it currently
     * belongs to and stores it in {@code featureIndex}.
     *
     * @return the newly created index
     */
    private HashMap<double[], Integer> createIndex() {
        featureIndex = new HashMap<double[], Integer>(features.size());
        for (int clusterId = 0; clusterId < clusters.length; clusterId++) {
            for (Integer featureId : clusters[clusterId].members) {
                featureIndex.put(features.get(featureId), clusterId);
            }
        }
        return featureIndex;
    }

    /**
     * Used to find the cluster of a feature actually used in the clustering process (so
     * it is known by the k-means class). The lookup index is built on the first call and is
     * not refreshed by later clustering steps.
     * <p>
     * NOTE(review): the index is keyed by the {@code double[]} feature arrays, while the lookup
     * is done with the {@code Histogram} object itself, so it looks like no feature can ever
     * be found. Confirm how a Histogram exposes its array and look that up instead.
     *
     * @param f the feature to search for
     * @return the index of the Cluster
     * @throws IllegalArgumentException if the feature is not part of the index
     */
    public int getClusterOfFeature(Histogram f) {
        if (featureIndex == null) createIndex();
        Integer clusterIndex = featureIndex.get(f);
        if (clusterIndex == null) {
            // fail with a clear message instead of a NullPointerException from unboxing.
            throw new IllegalArgumentException("The given feature is not known to this k-means instance.");
        }
        return clusterIndex;
    }
}

/**
 * Holds the local feature vectors of one image, its identifier and a float histogram
 * that is sized by {@link #initHistogram(int)} and can be rescaled with
 * {@link #normalizeFeatureHistogram()}.
 */
class Image {
    /** Value the largest histogram bin is scaled to by {@link #normalizeFeatureHistogram()}. */
    private static final int QUANT_MAX_HISTOGRAM = 256;

    public List<double[]> features;
    public String identifier;
    public float[] localFeatureHistogram = null;

    /**
     * Creates an image entry with its own list of feature vectors.
     *
     * @param identifier the identifier of the image
     * @param features   the feature vectors; the list is copied, the arrays are shared
     */
    Image(String identifier, List<double[]> features) {
        this.features = new LinkedList<double[]>(features);
        this.identifier = identifier;
    }

    /**
     * @return the histogram array itself, or {@code null} if none has been set or initialized
     */
    public float[] getLocalFeatureHistogram() {
        return localFeatureHistogram;
    }

    /**
     * @param localFeatureHistogram the histogram array to use from now on (stored as is)
     */
    public void setLocalFeatureHistogram(float[] localFeatureHistogram) {
        this.localFeatureHistogram = localFeatureHistogram;
    }

    /**
     * Replaces the histogram with a new one in which every bin is zero.
     *
     * @param bins the number of bins of the new histogram
     */
    public void initHistogram(int bins) {
        // A freshly allocated float[] is already zero-filled.
        localFeatureHistogram = new float[bins];
    }

    /**
     * Rescales the histogram in place so that its largest bin becomes
     * {@code QUANT_MAX_HISTOGRAM} (256). A histogram whose largest bin is zero is left
     * untouched, because dividing by zero would fill it with NaN.
     */
    public void normalizeFeatureHistogram() {
        float max = 0;
        for (float value : localFeatureHistogram) {
            max = Math.max(value, max);
        }
        if (max == 0) {
            return;
        }
        for (int i = 0; i < localFeatureHistogram.length; i++) {
            localFeatureHistogram[i] = (localFeatureHistogram[i] * QUANT_MAX_HISTOGRAM) / max;
        }
    }

    /**
     * Prints all bins to standard output on one line, each followed by a space.
     */
    public void printHistogram() {
        for (float value : localFeatureHistogram) {
            System.out.print(value + " ");
        }
        System.out.println("");
    }
}

Cluster.java(包含註釋)

package lmc.imageretrieval.imageanalysis.bovw;

import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.IOException;
import java.util.Arrays;
import java.util.HashSet;

import lmc.imageretrieval.imageanalysis.Histogram;
import lmc.imageretrieval.utils.MetricsUtils;
import lmc.imageretrieval.utils.SerializationUtils;

/**
 * A single cluster: its mean vector, the ids of its members and a stress value.
 * Also provides static helpers to write an array of clusters to a file and read it back.
 */
public class Cluster implements Comparable<Object> {
    /**
     * Length of the mean vector created by the no-arg constructor (4 * 4 * 8 = 128).
     * NOTE(review): presumably the size of a SIFT descriptor; confirm.
     */
    private static final int DEFAULT_DIMENSIONS = 4 * 4 * 8;
    /** Size of each of the two header fields in a cluster file. */
    private static final int BYTES_PER_INT = 4;
    /** Number of bytes the reader expects per component of a serialized mean. */
    private static final int BYTES_PER_DOUBLE = 8;

    double[] mean;
    HashSet<Integer> members = new HashSet<Integer>();

    private double stress = 0;

    /**
     * Creates a cluster with an all-zero mean of the default length and no members.
     */
    public Cluster() {
        // A freshly allocated double[] is already zero-filled.
        this.mean = new double[DEFAULT_DIMENSIONS];
    }

    /**
     * Creates a cluster with the given mean and no members.
     *
     * @param mean the mean vector; the array is stored as is, not copied
     */
    public Cluster(double[] mean) {
        this.mean = mean;
    }

    /**
     * @return all member ids, each followed by {@code ", "}, then all mean components,
     *         each followed by {@code ';'}
     */
    @Override
    public String toString() {
        StringBuilder sb = new StringBuilder(512);
        for (Integer member : members) {
            sb.append(member).append(", ");
        }
        for (double component : mean) {
            sb.append(component).append(';');
        }
        return sb.toString();
    }

    /**
     * Orders clusters by descending number of members.
     *
     * @param o the other cluster; must be a {@code Cluster}
     * @return the other cluster's member count minus this one's
     * @throws ClassCastException if {@code o} is not a {@code Cluster}
     */
    @Override
    public int compareTo(Object o) {
        // Both sizes are non-negative, so the subtraction cannot overflow.
        return ((Cluster) o).members.size() - members.size();
    }

    /**
     * @param f the feature whose double histogram is compared with the mean
     * @return the distance between the mean and {@code f.getDoubleHistogram()}
     */
    public double getDistance(Histogram f) {
        return getDistance(f.getDoubleHistogram());
    }

    /**
     * @param f the vector to compare with the mean
     * @return the distance as computed by {@code MetricsUtils.distL2(mean, f)}
     */
    public double getDistance(double[] f) {
        return MetricsUtils.distL2(mean, f);
    }

    /**
     * Creates a byte array representation from the clusters mean.
     *
     * @return the clusters mean as byte array.
     */
    public byte[] getByteRepresentation() {
        return SerializationUtils.toByteArray(mean);
    }

    /**
     * Replaces the mean with the vector decoded from the given bytes.
     *
     * @param data the serialized mean
     */
    public void setByteRepresentation(byte[] data) {
        mean = SerializationUtils.toDoubleArray(data);
    }

    /**
     * Writes the cluster means to a file. The layout is: the number of clusters, the
     * length of the first cluster's mean (0 if there are no clusters), then the byte
     * representation of every cluster's mean.
     *
     * @param clusters the clusters to write
     * @param file     the path of the file to create or overwrite
     * @throws IOException if the file cannot be opened or written
     */
    public static void writeClusters(Cluster[] clusters, String file) throws IOException {
        FileOutputStream fout = new FileOutputStream(file);
        try {
            fout.write(SerializationUtils.toBytes(clusters.length));
            // An empty array has no first cluster to take the length from.
            int dimensions = clusters.length == 0 ? 0 : clusters[0].getMean().length;
            fout.write(SerializationUtils.toBytes(dimensions));
            for (Cluster cluster : clusters) {
                fout.write(cluster.getByteRepresentation());
            }
        } finally {
            // Release the file handle even if a write fails.
            fout.close();
        }
    }

    /**
     * Reads cluster means from a file in the layout produced by
     * {@link #writeClusters(Cluster[], String)}. The length of the mean vectors is
     * taken from the file header.
     *
     * @param file the path of the file to read
     * @return one cluster per stored mean, each without members
     * @throws IOException if the file cannot be opened or read, or ends before all
     *                     announced data has been read
     */
    public static Cluster[] readClusters(String file) throws IOException {
        FileInputStream fin = new FileInputStream(file);
        try {
            byte[] header = new byte[BYTES_PER_INT];
            readFully(fin, header);
            Cluster[] result = new Cluster[SerializationUtils.toInt(header)];
            readFully(fin, header);
            int size = SerializationUtils.toInt(header);
            byte[] tmp = new byte[size * BYTES_PER_DOUBLE];
            for (int i = 0; i < result.length; i++) {
                readFully(fin, tmp);
                result[i] = new Cluster(SerializationUtils.toDoubleArray(tmp));
            }
            return result;
        } finally {
            // Release the file handle even if reading fails.
            fin.close();
        }
    }

    /**
     * Fills the whole buffer from the stream, repeating reads that return fewer bytes
     * than requested.
     */
    private static void readFully(FileInputStream in, byte[] buffer) throws IOException {
        int offset = 0;
        while (offset < buffer.length) {
            int count = in.read(buffer, offset, buffer.length - offset);
            if (count < 0) {
                throw new IOException("Unexpected end of cluster file: needed "
                        + buffer.length + " bytes but only " + offset + " were available.");
            }
            offset += count;
        }
    }

    /**
     * @return the stress value last stored with {@link #setStress(double)}, initially 0
     */
    public double getStress() {
        return stress;
    }

    /**
     * @param stress the stress value to store
     */
    public void setStress(double stress) {
        this.stress = stress;
    }

    /**
     * @return the set of member ids itself (no copy is made)
     */
    public HashSet<Integer> getMembers() {
        return members;
    }

    /**
     * @param members the set of member ids to use from now on (stored as is)
     */
    public void setMembers(HashSet<Integer> members) {
        this.members = members;
    }

    /**
     * Returns the cluster mean
     *
     * @return the cluster mean vector
     */
    public double[] getMean() {
        return mean;
    }
}


發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章