libsvm庫的作者主頁 :http://www.csie.ntu.edu.tw/~cjlin
好吧,都是英文的......
libsvm的分鏈接:
http://www.csie.ntu.edu.tw/~cjlin/libsvm/index.html
到通過分連接裏面找到download,下載下libsvm的壓縮包,其目錄如下圖:
好多東西的感腳,其實如果是java的話只要java那個文件夾裏面的東西就口以啦
下面是java文件夾下面的動動,也很多的感腳,其實也只要libsvm.jar這個包就可以啦,超喜歡,一個包導進到工程項目中就可以直接用了,SO 方便.
對於其它的java文件要幹嘛用呢,你可以當作例子,那應該都是作者寫的實例吧,裏面都是調用libsvm.jar包的類,看那些代碼
有助於清晰明瞭的學會libsvm.jar的各種類的使用方法,文章後部分的代碼參考svm_train,額,也不算參考吧,應爲基本都是
那個文件裏面的代碼哈。
下面說下主要要用到的類有3個:
1.svm_parameter:用來保存svm的一些設置參數
2.svm_problem:用來保存樣本的
3.svm:主要用來做svm分類的
其實還有svm_model類也是很重要的,只是暫時沒用到,因爲一開始我的目的只是爲了能夠用這個讓lissvm成功運行起來,我就哈皮啦,所以這個還沒研究。
不過後面要更靈活地使用svm就需要這個東西咯, mark先:svm_model ,保存分類模型的,具體用法還咩研究。
下面直接貼代碼了,裏面有關於參數設置以及如何將文件裏的數據存到svm_problem的實例中,以及交叉驗證的代碼,摘抄自svm_train.java
package loma; import java.io.BufferedReader; import java.io.FileReader; import java.io.IOException; import java.util.StringTokenizer; import java.util.Vector; import libsvm.svm; import libsvm.svm_node; import libsvm.svm_parameter; import libsvm.svm_problem; public class test{ private static double atof(String s) { double d = Double.valueOf(s).doubleValue(); if (Double.isNaN(d) || Double.isInfinite(d)) { System.err.print("NaN or Infinity in input\n"); System.exit(1); } return(d); } //獲取參數 private svm_parameter getParameter(){ svm_parameter param = new svm_parameter(); // default values param.svm_type = svm_parameter.C_SVC; param.kernel_type = svm_parameter.RBF; param.degree = 3; param.gamma = 0; // 1/num_features param.coef0 = 0; param.nu = 0.5; param.cache_size = 100; param.C = 1; param.eps = 1e-3; param.p = 0.1; param.shrinking = 1; param.probability = 0; param.nr_weight = 0; param.weight_label = new int[0]; param.weight = new double[0]; return param; } private static int atoi(String s) { return Integer.parseInt(s); } //獲取問題描述 private svm_problem read_problem(String input_file_name,svm_parameter param) throws IOException { BufferedReader fp = new BufferedReader(new FileReader(input_file_name)); Vector<Double> vy = new Vector<Double>(); Vector<svm_node[]> vx = new Vector<svm_node[]>(); int max_index = 0; while(true) { String line = fp.readLine(); if(line == null) break; StringTokenizer st = new StringTokenizer(line," \t\n\r\f:"); vy.addElement(atof(st.nextToken())); int m = st.countTokens()/2; svm_node[] x = new svm_node[m]; for(int j=0;j<m;j++) { x[j] = new svm_node(); x[j].index = atoi(st.nextToken()); x[j].value = atof(st.nextToken()); } if(m>0) max_index = Math.max(max_index, x[m-1].index); vx.addElement(x); } svm_problem prob = new svm_problem(); prob.l = vy.size(); prob.x = new svm_node[prob.l][]; for(int i=0;i<prob.l;i++) prob.x[i] = vx.elementAt(i); prob.y = new double[prob.l]; for(int i=0;i<prob.l;i++) prob.y[i] = vy.elementAt(i); if(param.gamma == 0 && max_index > 0) param.gamma = 1.0/max_index; if(param.kernel_type == svm_parameter.PRECOMPUTED) for(int i=0;i<prob.l;i++) { if (prob.x[i][0].index != 0) { System.err.print("Wrong kernel matrix: first column must be 0:sample_serial_number\n"); System.exit(1); } if ((int)prob.x[i][0].value <= 0 || (int)prob.x[i][0].value > max_index) { System.err.print("Wrong input format: sample_serial_number out of range\n"); System.exit(1); } } fp.close(); return prob; } //交叉驗證 private void do_cross_validation(svm_problem prob,svm_parameter param,int nr_fold) { int i; int total_correct = 0; double total_error = 0; double sumv = 0, sumy = 0, sumvv = 0, sumyy = 0, sumvy = 0; double[] target = new double[prob.l]; svm.svm_cross_validation(prob,param,nr_fold,target); if(param.svm_type == svm_parameter.EPSILON_SVR || param.svm_type == svm_parameter.NU_SVR) { for(i=0;i<prob.l;i++) { double y = prob.y[i]; double v = target[i]; total_error += (v-y)*(v-y); sumv += v; sumy += y; sumvv += v*v; sumyy += y*y; sumvy += v*y; } System.out.print("Cross Validation Mean squared error = "+total_error/prob.l+"\n"); System.out.print("Cross Validation Squared correlation coefficient = "+ ((prob.l*sumvy-sumv*sumy)*(prob.l*sumvy-sumv*sumy))/ ((prob.l*sumvv-sumv*sumv)*(prob.l*sumyy-sumy*sumy))+"\n" ); } else { for(i=0;i<prob.l;i++) if(target[i] == prob.y[i]) ++total_correct; System.out.print("Cross Validation Accuracy = "+100.0*total_correct/prob.l+"%\n"); } } public static void main(String[] args){ test t = new test(); svm_parameter param = t.getParameter(); svm_problem prob = null; try{ prob = t.read_problem("d:\\iris.scale",param); }catch(IOException e){ //todo } t.do_cross_validation(prob, param, 10); } }
整體步驟如下:
1.到官網下載,libsvm的壓縮包。
2.從壓縮包裏面獲取libsvm.jar並添加到java工程中。
3.創建一個簡單的測試類,可直接將上面的代碼複製粘貼過去體驗下。
上述代碼的基本思路:
首先要實例化svm_parameter,然後設置相應的參數;
其次讀取數據文件,將裏面的內容按要求存放到svm_problem對象中;
最後,通過調用svm.svm_cross_validation(prob,param,nr_fold,target);進行交叉驗證
4.從http://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/ 下載數據集進行測試,具體的數據格式可以模仿這邊下載下的文件
大概格式就是這樣的 label index1:value1 index2:value2....
5.over