sklearn DecisionTree 源码分析

sklearn.tree._classes.BaseDecisionTree#fit
y至少为1维(意思是可以处理multilabels数据)

y = np.atleast_1d(y)
if is_classifier(self):
    self.tree_ = Tree(self.n_features_,
                      self.n_classes_, self.n_outputs_)
else:
    self.tree_ = Tree(self.n_features_,
                      # TODO: tree should't need this in this case
                      np.array([1] * self.n_outputs_, dtype=np.intp),
                      self.n_outputs_)
self.n_outputs_ = y.shape[1]
self.n_classes_ = self.n_classes_[0]
self.n_classes_ = []
for k in range(self.n_outputs_):
    classes_k, y_encoded[:, k] = np.unique(y[:, k],
                                           return_inverse=True)
    self.classes_.append(classes_k)
    self.n_classes_.append(classes_k.shape[0])
np.unique([3,2,2,3,3,4], return_inverse=True)
Out[4]: (array([2, 3, 4]), array([1, 0, 0, 1, 1, 2]))

return_inverse类似于LabelEncode

sklearn.tree._tree.Tree

    def __cinit__(self, int n_features, np.ndarray[SIZE_t, ndim=1] n_classes,
                  int n_outputs):
  1. 特征数
  2. 类别数
  3. label维度
# Use BestFirst if max_leaf_nodes given; use DepthFirst otherwise
if max_leaf_nodes < 0:
    builder = DepthFirstTreeBuilder(splitter, min_samples_split,
                                    min_samples_leaf,
                                    min_weight_leaf,
                                    max_depth,
                                    self.min_impurity_decrease,
                                    min_impurity_split)
else:
    builder = BestFirstTreeBuilder(splitter, min_samples_split,
                                   min_samples_leaf,
                                   min_weight_leaf,
                                   max_depth,
                                   max_leaf_nodes,
                                   self.min_impurity_decrease,
                                   min_impurity_split)

scikit-learn决策树算法类库介绍

最大叶子节点数max_leaf_nodes

通过限制最大叶子节点数,可以防止过拟合,默认是"None”,即不限制最大的叶子节点数。如果加了限制,算法会建立在最大叶子节点数内最优的决策树。如果特征不多,可以不考虑这个值,但是如果特征分成多的话,可以加以限制,具体的值可以通过交叉验证得到。

sklearn.tree._tree.DepthFirstTreeBuilder#build

builder.build(self.tree_, X, y, sample_weight, X_idx_sorted)
cpdef build(self, Tree tree, object X, np.ndarray y,
            np.ndarray sample_weight=None,
            np.ndarray X_idx_sorted=None):

注意到一个现象,这里该有的参数都有,但是class_weight去哪了呢?怀疑是转化了sample_weight

if self.class_weight is not None:
    expanded_class_weight = compute_sample_weight(
        self.class_weight, y_original)
if expanded_class_weight is not None:
    if sample_weight is not None:
        sample_weight = sample_weight * expanded_class_weight
    else:
        sample_weight = expanded_class_weight

sklearn/tree/_tree.pyx:203

splitter.init(X, y, sample_weight_ptr, X_idx_sorted)
cdef SIZE_t n_node_samples = splitter.n_samples
rc = stack.push(0, n_node_samples, 0, _TREE_UNDEFINED, 0, INFINITY, 0)

rc是根节点,在分裂前含有所有的样本

StackStackRecord都是sklearn自己写的数据结构

is_leaf = (depth >= max_depth or
           n_node_samples < min_samples_split or
           n_node_samples < 2 * min_samples_leaf or
           weighted_n_node_samples < 2 * min_weight_leaf)
is_leaf = (is_leaf or (impurity <= min_impurity_split))

满足以上条件直接停止分裂

sklearn.tree._splitter.BestSplitter

sklearn.tree._splitter.BestSplitter#node_split


scikit-learn uses an optimised version of the CART algorithm; however, scikit-learn implementation does not support categorical variables for now.

在这里插入图片描述

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章