Python之網格搜索與檢查驗證-5.2

  一、網格搜索,在我們不確定超參數的時候,需要通過不斷驗證超參數,來確定最優的參數值。這個過程就是在不斷,搜索最優的參數值,這個過程也就稱爲網格搜索

  二、檢查驗證,將準備好的訓練數據進行平均拆分,分爲訓練集驗證集。訓練集和驗證集的大小差不多,總體份數通過手動設置。具體過程爲:

  

  由上圖可以得知,訓練集和驗證集是通過交叉的方式去不斷訓練,這樣的目的就是爲了獲取,更加優化的參數值。

  三、代碼演示(這裏我們通過K-近鄰的算法。來確認參數值):

# K-近鄰算法
def k_near_test():
    # 1、原始數據
    li = load_iris()
    # print(li.data)
    # print(li.DESCR)
    # 2、處理數據
    data = li.data
    target = li.target
    x_train, x_test, y_train, y_test = train_test_split(data, target, test_size=0.25)
    # 3、特徵工程
    std = StandardScaler()
    x_train = std.fit_transform(x_train, y_train)
    x_test = std.transform(x_test)
    # 4、算法
    knn = KNeighborsClassifier(n_neighbors=2)
    knn.fit(x_train, y_train)
    # 預估
    y_predict = knn.predict(x_test)
    print("預估值:", y_predict)
    # 5、評估
    source = knn.score(x_test, y_test)
    print("準確率:", source)

    """
        交叉驗證與網格搜索:
            交叉驗證:
                1、將一個訓練集分成對等的n份(cv值)
                2、將第一個作爲驗證集,其他作爲訓練集,得出準確率
                3、將第二個作爲驗證集,其他作爲訓練集,知道第n個爲驗證集,得出準確率
                4、把得出的n個準確率,求平均值,得出模型平均準確率
            網格搜索:
                1、用於參數的調整(比如,k近鄰算法中的n_neighbors值)
                2、通過不同參數傳入進行驗證(超參數),得出最優的參數值(最優n_neighbors值)
    """
    # 4、算法
    knn_gc = KNeighborsClassifier()
    # 構造值進行搜索
    param= {"n_neighbors": [2, 3, 5]}
    # 網格搜索
    gc = GridSearchCV(knn_gc, param_grid=param,cv=4)
    gc.fit(x_train, y_train)

    # 5、評估
    print("測試集的準確率:", gc.score(x_test, y_test))
    print("交叉驗證當中最好的結果:", gc.best_score_)
    print("選擇最好的模型:", gc.best_estimator_)
    print("每個超參數每次交叉驗證結果:", gc.cv_results_)

  注意:紅色部分的解釋,主要就是通過網格搜索和交叉驗證的方式來確認超參數中的最優方案。

  其中:

    # 4、算法
    knn_gc = KNeighborsClassifier()
    # 構造值進行搜索
    param= {"n_neighbors": [2, 3, 5]}
    # 網格搜索
    gc = GridSearchCV(knn_gc, param_grid=param,cv=4)
    gc.fit(x_train, y_train)

  這部分代碼就是網格搜索和交叉驗證的實現方式。cv爲具體的份數。

  四、結果:

  

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章