
Some questions about RandomForest #3

@ghost


def fit(self, X_train, Y_train):
        # Generate decision tree
        for i in range(self.tree_count):
            dt_CART = decision_tree.DTreeCART()
            # Bagging data
            n, m = X_train.shape
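            sample_idx = np.random.permutation(n)  # I think this should be sample_idx = np.random.choice(n, n); only that is actually resampling
            feature_idx = np.random.permutation(m)[:int(np.sqrt(m))]  # Here feature_idx is randomly selected only once, for the root node; shouldn't the original random forest randomly select the attributes at every node split?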
            X_t_ = X_train[:, feature_idx]
            X_t_, Y_t_ = X_t_[sample_idx, :], Y_train[sample_idx]
            # Train
            dt_CART.fit(X_t_, Y_t_)
            self.tree_list.append((dt_CART, feature_idx))
            print('=' * 10 + ' %r/%r tree trained ' % (i + 1, self.tree_count) + '=' * 10)
            # print(dt_CART.visualization())
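
To make the two questions above concrete, here is a minimal sketch rather than a patch for this repository. The helper names `bootstrap_indices` and `candidate_features` are hypothetical and are not part of `decision_tree.DTreeCART`. It compares `permutation` with sampling with replacement, and shows a feature-subset helper that a tree could call at each split instead of once per tree.

```python
import numpy as np


def bootstrap_indices(n, rng):
    """Draw n row indices with replacement (a bootstrap sample)."""
    return rng.choice(n, size=n, replace=True)


def candidate_features(m, rng):
    """Pick int(sqrt(m)) distinct feature indices.

    Hypothetical helper: for per-node feature randomness, the tree would
    call this every time it searches for a split, not once before fitting.
    """
    k = max(1, int(np.sqrt(m)))
    return rng.choice(m, size=k, replace=False)


rng = np.random.default_rng(0)
n, m = 1000, 16

perm_idx = rng.permutation(n)         # what the quoted code does: only reorders the rows
boot_idx = bootstrap_indices(n, rng)  # sampling with replacement

# A permutation contains every row exactly once, so each tree sees the same data.
print(len(np.unique(perm_idx)) == n)
# A bootstrap sample repeats some rows and leaves others out.
print(len(np.unique(boot_idx)) < n)

# Two consecutive calls stand in for two different node splits.
print(candidate_features(m, rng), candidate_features(m, rng))
```

With `permutation`, every tree is trained on the full training set in a different order, so the only source of diversity between trees is the feature subset. Supporting per-node selection would presumably require the tree's split search to accept a feature sampler, which is an assumption about how `DTreeCART` could be extended.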
