斯坦福大學機器學習作業題Problem Set #2 Spam classication下篇

               

最終得出最代表垃圾郵件的五個詞爲gun,moral,israel,jew,faith

將上一篇的main函數替換爲這個

def main():
    trainMatrix, tokenlist, trainCategory = readMatrix('MATRIX.TRAIN')
    testMatrix, tokenlist, testCategory = readMatrix('MATRIX.TEST')
    state0, state1, proportion_state0, proportion_state1 = nb_train(trainMatrix,tokenlist,trainCategory)
    proportion_p1_p0=[]
    for i in range(len(state0)):
        proportion_p1_p0.append((state0[i]/state1[i]))
    largest_five=heapq.nlargest(5, proportion_p1_p0)
    location=[]
    for i in range(len(largest_five)):
        j=proportion_p1_p0.index(largest_five[i])
        location.append(j)
        print tokenlist[j]
    return



隨着數據量的增大,誤差不斷減小

將上一篇的main函數替換爲這個

def main():
    trainfile=['MATRIX.TRAIN.50','MATRIX.TRAIN.100','MATRIX.TRAIN.200','MATRIX.TRAIN.400','MATRIX.TRAIN.800','MATRIX.TRAIN.1400']
    error= np.zeros(len(trainfile))
    x=[50,100,200,400,800,1400]
    for i in range(len(trainfile)):
        trainMatrix, tokenlist, trainCategory = readMatrix(trainfile[i])
        testMatrix, tokenlist, testCategory = readMatrix('MATRIX.TEST')
        state0, state1, proportion_state0, proportion_state1 = nb_train(trainMatrix,tokenlist,trainCategory)
        output = nb_test(testMatrix,state0, state1, proportion_state0, proportion_state1)
        error[i]=evaluate(output, testCategory)
    plt.xlabel('Data quantity')
    plt.ylabel('Error')
    plt.plot(x,error)
    plt.show()
    return


svm的誤差更小一些




發佈了32 篇原創文章 · 獲贊 3 · 訪問量 1萬+
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章