from collections import Counter
from math import log

import numpy as np
def split(X, y, d, value):
    """Partition the rows of (X, y) on feature column ``d`` at ``value``.

    Rows with ``X[:, d] <= value`` go to the left side and rows with
    ``X[:, d] > value`` go to the right side. Rows for which neither
    comparison holds (e.g. NaN feature values) appear on neither side.

    Args:
        X: 2-D array indexed as ``X[row, feature]``.
        y: label array aligned with the rows of ``X``.
        d: index of the feature column to split on.
        value: threshold for the split.

    Returns:
        ``(X_left, X_right, y_left, y_right)``.
    """
    index_a = X[:, d] <= value
    index_b = X[:, d] > value
    return X[index_a], X[index_b], y[index_a], y[index_b]


def gini(y):
    """Return the Gini impurity ``1 - sum_k p_k**2`` of the labels in ``y``.

    ``p_k`` is the fraction of entries equal to label ``k``. An empty label
    set is treated as pure and yields 0.0 (previously it yielded 1.0, which
    is not a reachable impurity value).
    """
    total = len(y)
    if total == 0:
        return 0.0
    res = 1.0
    for num in Counter(y).values():
        p = num / total
        res -= p ** 2
    return res
def try_split(X, y):
    """Find the single-feature threshold split of (X, y) with the lowest Gini impurity.

    For every feature column, each midpoint between two consecutive distinct
    sorted values is tried as a candidate threshold.

    Args:
        X: 2-D array indexed as ``X[row, feature]``.
        y: label array aligned with the rows of ``X``.

    Returns:
        ``(best_g, best_d, best_v)``: the size-weighted Gini impurity of the
        best split, the feature index it uses, and its threshold. If no split
        is possible (fewer than two rows, or every column is constant), the
        result is ``(inf, -1, -1)``.
    """
    n = len(X)
    best_g = float('inf')
    best_d, best_v = -1, -1
    for d in range(X.shape[1]):
        column = X[:, d]
        sorted_index = np.argsort(column)
        for i in range(1, n):
            prev_val = column[sorted_index[i - 1]]
            cur_val = column[sorted_index[i]]
            if cur_val == prev_val:
                continue  # no threshold can separate equal values
            v = (cur_val + prev_val) / 2
            _, _, y_l, y_r = split(X, y, d, v)
            # Weight each child's impurity by its share of the samples (CART).
            # The former unweighted sum over-rewarded carving off tiny pure
            # groups and could select the wrong threshold.
            g = (len(y_l) * gini(y_l) + len(y_r) * gini(y_r)) / n
            if g < best_g:
                best_g, best_d, best_v = g, d, v
    return best_g, best_d, best_v