-
Notifications
You must be signed in to change notification settings - Fork 49
Open
Description
def fit(self, X_train, Y_train):
    """Train ``self.tree_count`` CART trees on bagged views of the data.

    Each tree is fitted on a bootstrap sample of the rows (``n`` rows drawn
    with replacement) restricted to a random subset of ``int(sqrt(m))``
    columns. The fitted tree and the column indices it was trained on are
    appended to ``self.tree_list`` as a ``(tree, feature_idx)`` tuple.

    Args:
        X_train: 2-D array indexable as ``X_train[rows][:, cols]`` whose
            ``shape`` unpacks to ``(n, m)``.
        Y_train: Targets indexable by an integer index array of length
            ``n``.

    Returns:
        None. Trees are accumulated in ``self.tree_list``.
    """
    # Shape and subset size do not change between trees; compute them once.
    n, m = X_train.shape
    feature_count = int(np.sqrt(m))

    for i in range(self.tree_count):
        dt_CART = decision_tree.DTreeCART()

        # Bagging: draw n rows WITH replacement. A plain permutation would
        # merely reorder the rows, so every tree would see the identical
        # training set and the ensemble would gain no sample diversity.
        sample_idx = np.random.choice(n, size=n, replace=True)

        # Random feature subset, drawn without replacement.
        # NOTE(review): the subset is chosen once per tree, not at every
        # node split as in the original random-forest algorithm; per-split
        # sampling would presumably need support inside DTreeCART - confirm.
        feature_idx = np.random.permutation(m)[:feature_count]

        X_t_ = X_train[sample_idx][:, feature_idx]
        Y_t_ = Y_train[sample_idx]

        # Train
        dt_CART.fit(X_t_, Y_t_)
        self.tree_list.append((dt_CART, feature_idx))
        print('=' * 10 + ' %r/%r tree trained ' % (i + 1, self.tree_count) + '=' * 10)
        # Debug aid: print(dt_CART.visualization())
Metadata
Assignees
Labels
No labels