4-通過java調用libsvm

libsvm庫的作者主頁 :http://www.csie.ntu.edu.tw/~cjlin

好吧,都是英文的......

libsvm的分鏈接:

http://www.csie.ntu.edu.tw/~cjlin/libsvm/index.html

到通過分連接裏面找到download,下載下libsvm的壓縮包,其目錄如下圖:

111307246.png


好多東西的感腳,其實如果是java的話只要java那個文件夾裏面的東西就口以啦














下面是java文件夾下面的動動,也很多的感腳,其實也只要libsvm.jar這個包就可以啦,超喜歡,一個包導進到工程項目中就可以直接用了,SO 方便.

111307647.png


對於其它的java文件要幹嘛用呢,你可以當作例子,那應該都是作者寫的實例吧,裏面都是調用libsvm.jar包的類,看那些代碼

有助於清晰明瞭的學會libsvm.jar的各種類的使用方法,文章後部分的代碼參考svm_train,額,也不算參考吧,應爲基本都是

那個文件裏面的代碼哈。





下面說下主要要用到的類有3個:

1.svm_parameter:用來保存svm的一些設置參數

2.svm_problem:用來保存樣本的

3.svm:主要用來做svm分類的

其實還有svm_model類也是很重要的,只是暫時沒用到,因爲一開始我的目的只是爲了能夠用這個讓lissvm成功運行起來,我就哈皮啦,所以這個還沒研究。

不過後面要更靈活地使用svm就需要這個東西咯, mark先:svm_model ,保存分類模型的,具體用法還咩研究。



下面直接貼代碼了,裏面有關於參數設置以及如何將文件裏的數據存到svm_problem的實例中,以及交叉驗證的代碼,摘抄自svm_train.java

package loma;
import java.io.BufferedReader;
import java.io.FileReader;
import java.io.IOException;
import java.util.StringTokenizer;
import java.util.Vector;
import libsvm.svm;
import libsvm.svm_node;
import libsvm.svm_parameter;
import libsvm.svm_problem;
public class test{
    private static double atof(String s)
    {
        double d = Double.valueOf(s).doubleValue();
        if (Double.isNaN(d) || Double.isInfinite(d))
        {
            System.err.print("NaN or Infinity in input\n");
            System.exit(1);
        }
        return(d);
    }
    //獲取參數
    private svm_parameter getParameter(){
        svm_parameter param = new svm_parameter();
        // default values
        param.svm_type = svm_parameter.C_SVC;
        param.kernel_type = svm_parameter.RBF;
        param.degree = 3;
        param.gamma = 0;    // 1/num_features
        param.coef0 = 0;
        param.nu = 0.5;
        param.cache_size = 100;
        param.C = 1;
        param.eps = 1e-3;
        param.p = 0.1;
        param.shrinking = 1;
        param.probability = 0;
        param.nr_weight = 0;
        param.weight_label = new int[0];
        param.weight = new double[0];
        return param;
    }
    private static int atoi(String s)
    {
        return Integer.parseInt(s);
    }
    //獲取問題描述
    private svm_problem read_problem(String input_file_name,svm_parameter param) throws IOException
    {
        BufferedReader fp = new BufferedReader(new FileReader(input_file_name));
        Vector<Double> vy = new Vector<Double>();
        Vector<svm_node[]> vx = new Vector<svm_node[]>();
        int max_index = 0;
        while(true)
        {
            String line = fp.readLine();
            if(line == null) break;
            StringTokenizer st = new StringTokenizer(line," \t\n\r\f:");
            vy.addElement(atof(st.nextToken()));
            int m = st.countTokens()/2;
            svm_node[] x = new svm_node[m];
            for(int j=0;j<m;j++)
            {
                x[j] = new svm_node();
                x[j].index = atoi(st.nextToken());
                x[j].value = atof(st.nextToken());
            }
            if(m>0) max_index = Math.max(max_index, x[m-1].index);
            vx.addElement(x);
        }
        svm_problem prob = new svm_problem();
        prob.l = vy.size();
        prob.x = new svm_node[prob.l][];
        for(int i=0;i<prob.l;i++)
            prob.x[i] = vx.elementAt(i);
        prob.y = new double[prob.l];
        for(int i=0;i<prob.l;i++)
            prob.y[i] = vy.elementAt(i);
        if(param.gamma == 0 && max_index > 0)
            param.gamma = 1.0/max_index;
        if(param.kernel_type == svm_parameter.PRECOMPUTED)
            for(int i=0;i<prob.l;i++)
            {
                if (prob.x[i][0].index != 0)
                {
                    System.err.print("Wrong kernel matrix: first column must be 0:sample_serial_number\n");
                    System.exit(1);
                }
                if ((int)prob.x[i][0].value <= 0 || (int)prob.x[i][0].value > max_index)
                {
                    System.err.print("Wrong input format: sample_serial_number out of range\n");
                    System.exit(1);
                }
            }
        fp.close();
        return prob;
    }
    //交叉驗證
    private void do_cross_validation(svm_problem prob,svm_parameter param,int nr_fold)
    {
        int i;
        int total_correct = 0;
        double total_error = 0;
        double sumv = 0, sumy = 0, sumvv = 0, sumyy = 0, sumvy = 0;
        double[] target = new double[prob.l];
        svm.svm_cross_validation(prob,param,nr_fold,target);
        if(param.svm_type == svm_parameter.EPSILON_SVR ||
           param.svm_type == svm_parameter.NU_SVR)
        {
            for(i=0;i<prob.l;i++)
            {
                double y = prob.y[i];
                double v = target[i];
                total_error += (v-y)*(v-y);
                sumv += v;
                sumy += y;
                sumvv += v*v;
                sumyy += y*y;
                sumvy += v*y;
            }
            System.out.print("Cross Validation Mean squared error = "+total_error/prob.l+"\n");
            System.out.print("Cross Validation Squared correlation coefficient = "+
                ((prob.l*sumvy-sumv*sumy)*(prob.l*sumvy-sumv*sumy))/
                ((prob.l*sumvv-sumv*sumv)*(prob.l*sumyy-sumy*sumy))+"\n"
                );
        }
        else
        {
            for(i=0;i<prob.l;i++)
                if(target[i] == prob.y[i])
                    ++total_correct;
            System.out.print("Cross Validation Accuracy = "+100.0*total_correct/prob.l+"%\n");
        }
    }
                                                                                                                                                                                                                                                               
    public static void main(String[] args){
        test t = new test();
        svm_parameter param = t.getParameter();
        svm_problem prob = null;
        try{
            prob = t.read_problem("d:\\iris.scale",param);
        }catch(IOException e){
            //todo
        }
        t.do_cross_validation(prob, param, 10);
    }
}

整體步驟如下:

1.到官網下載,libsvm的壓縮包。

2.從壓縮包裏面獲取libsvm.jar並添加到java工程中。

3.創建一個簡單的測試類,可直接將上面的代碼複製粘貼過去體驗下。

上述代碼的基本思路:

首先要實例化svm_parameter,然後設置相應的參數;

其次讀取數據文件,將裏面的內容按要求存放到svm_problem對象中;

最後,通過調用svm.svm_cross_validation(prob,param,nr_fold,target);進行交叉驗證

4.從http://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/ 下載數據集進行測試,具體的數據格式可以模仿這邊下載下的文件

大概格式就是這樣的 label index1:value1 index2:value2....

5.over


發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章