4-通过java调用libsvm

libsvm库的作者主页 :http://www.csie.ntu.edu.tw/~cjlin

好吧,都是英文的......

libsvm的分链接:

http://www.csie.ntu.edu.tw/~cjlin/libsvm/index.html

到通过分连接里面找到download,下载下libsvm的压缩包,其目录如下图:

111307246.png


好多东西的感脚,其实如果是java的话只要java那个文件夹里面的东西就口以啦














下面是java文件夹下面的动动,也很多的感脚,其实也只要libsvm.jar这个包就可以啦,超喜欢,一个包导进到工程项目中就可以直接用了,SO 方便.

111307647.png


对于其它的java文件要干嘛用呢,你可以当作例子,那应该都是作者写的实例吧,里面都是调用libsvm.jar包的类,看那些代码

有助于清晰明了的学会libsvm.jar的各种类的使用方法,文章后部分的代码参考svm_train,额,也不算参考吧,应为基本都是

那个文件里面的代码哈。





下面说下主要要用到的类有3个:

1.svm_parameter:用来保存svm的一些设置参数

2.svm_problem:用来保存样本的

3.svm:主要用来做svm分类的

其实还有svm_model类也是很重要的,只是暂时没用到,因为一开始我的目的只是为了能够用这个让lissvm成功运行起来,我就哈皮啦,所以这个还没研究。

不过后面要更灵活地使用svm就需要这个东西咯, mark先:svm_model ,保存分类模型的,具体用法还咩研究。



下面直接贴代码了,里面有关于参数设置以及如何将文件里的数据存到svm_problem的实例中,以及交叉验证的代码,摘抄自svm_train.java

package loma;
import java.io.BufferedReader;
import java.io.FileReader;
import java.io.IOException;
import java.util.StringTokenizer;
import java.util.Vector;
import libsvm.svm;
import libsvm.svm_node;
import libsvm.svm_parameter;
import libsvm.svm_problem;
public class test{
    private static double atof(String s)
    {
        double d = Double.valueOf(s).doubleValue();
        if (Double.isNaN(d) || Double.isInfinite(d))
        {
            System.err.print("NaN or Infinity in input\n");
            System.exit(1);
        }
        return(d);
    }
    //获取参数
    private svm_parameter getParameter(){
        svm_parameter param = new svm_parameter();
        // default values
        param.svm_type = svm_parameter.C_SVC;
        param.kernel_type = svm_parameter.RBF;
        param.degree = 3;
        param.gamma = 0;    // 1/num_features
        param.coef0 = 0;
        param.nu = 0.5;
        param.cache_size = 100;
        param.C = 1;
        param.eps = 1e-3;
        param.p = 0.1;
        param.shrinking = 1;
        param.probability = 0;
        param.nr_weight = 0;
        param.weight_label = new int[0];
        param.weight = new double[0];
        return param;
    }
    private static int atoi(String s)
    {
        return Integer.parseInt(s);
    }
    //获取问题描述
    private svm_problem read_problem(String input_file_name,svm_parameter param) throws IOException
    {
        BufferedReader fp = new BufferedReader(new FileReader(input_file_name));
        Vector<Double> vy = new Vector<Double>();
        Vector<svm_node[]> vx = new Vector<svm_node[]>();
        int max_index = 0;
        while(true)
        {
            String line = fp.readLine();
            if(line == null) break;
            StringTokenizer st = new StringTokenizer(line," \t\n\r\f:");
            vy.addElement(atof(st.nextToken()));
            int m = st.countTokens()/2;
            svm_node[] x = new svm_node[m];
            for(int j=0;j<m;j++)
            {
                x[j] = new svm_node();
                x[j].index = atoi(st.nextToken());
                x[j].value = atof(st.nextToken());
            }
            if(m>0) max_index = Math.max(max_index, x[m-1].index);
            vx.addElement(x);
        }
        svm_problem prob = new svm_problem();
        prob.l = vy.size();
        prob.x = new svm_node[prob.l][];
        for(int i=0;i<prob.l;i++)
            prob.x[i] = vx.elementAt(i);
        prob.y = new double[prob.l];
        for(int i=0;i<prob.l;i++)
            prob.y[i] = vy.elementAt(i);
        if(param.gamma == 0 && max_index > 0)
            param.gamma = 1.0/max_index;
        if(param.kernel_type == svm_parameter.PRECOMPUTED)
            for(int i=0;i<prob.l;i++)
            {
                if (prob.x[i][0].index != 0)
                {
                    System.err.print("Wrong kernel matrix: first column must be 0:sample_serial_number\n");
                    System.exit(1);
                }
                if ((int)prob.x[i][0].value <= 0 || (int)prob.x[i][0].value > max_index)
                {
                    System.err.print("Wrong input format: sample_serial_number out of range\n");
                    System.exit(1);
                }
            }
        fp.close();
        return prob;
    }
    //交叉验证
    private void do_cross_validation(svm_problem prob,svm_parameter param,int nr_fold)
    {
        int i;
        int total_correct = 0;
        double total_error = 0;
        double sumv = 0, sumy = 0, sumvv = 0, sumyy = 0, sumvy = 0;
        double[] target = new double[prob.l];
        svm.svm_cross_validation(prob,param,nr_fold,target);
        if(param.svm_type == svm_parameter.EPSILON_SVR ||
           param.svm_type == svm_parameter.NU_SVR)
        {
            for(i=0;i<prob.l;i++)
            {
                double y = prob.y[i];
                double v = target[i];
                total_error += (v-y)*(v-y);
                sumv += v;
                sumy += y;
                sumvv += v*v;
                sumyy += y*y;
                sumvy += v*y;
            }
            System.out.print("Cross Validation Mean squared error = "+total_error/prob.l+"\n");
            System.out.print("Cross Validation Squared correlation coefficient = "+
                ((prob.l*sumvy-sumv*sumy)*(prob.l*sumvy-sumv*sumy))/
                ((prob.l*sumvv-sumv*sumv)*(prob.l*sumyy-sumy*sumy))+"\n"
                );
        }
        else
        {
            for(i=0;i<prob.l;i++)
                if(target[i] == prob.y[i])
                    ++total_correct;
            System.out.print("Cross Validation Accuracy = "+100.0*total_correct/prob.l+"%\n");
        }
    }
                                                                                                                                                                                                                                                               
    public static void main(String[] args){
        test t = new test();
        svm_parameter param = t.getParameter();
        svm_problem prob = null;
        try{
            prob = t.read_problem("d:\\iris.scale",param);
        }catch(IOException e){
            //todo
        }
        t.do_cross_validation(prob, param, 10);
    }
}

整体步骤如下:

1.到官网下载,libsvm的压缩包。

2.从压缩包里面获取libsvm.jar并添加到java工程中。

3.创建一个简单的测试类,可直接将上面的代码复制粘贴过去体验下。

上述代码的基本思路:

首先要实例化svm_parameter,然后设置相应的参数;

其次读取数据文件,将里面的内容按要求存放到svm_problem对象中;

最后,通过调用svm.svm_cross_validation(prob,param,nr_fold,target);进行交叉验证

4.从http://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/ 下载数据集进行测试,具体的数据格式可以模仿这边下载下的文件

大概格式就是这样的 label index1:value1 index2:value2....

5.over


發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章