ID3, C4.5, and CART Trees
C4.5 is a classic decision-tree algorithm, built by revising ID3.

ID3's weaknesses:

- Selecting attributes by information gain is biased toward attributes with many distinct values.
- It cannot handle continuous attributes.
C4.5's improvements:

- It uses the gain ratio (normalized information gain) to pick the best feature, which keeps the algorithm from favoring features with many values; the definitions are spelled out after this list.
- It handles continuous attributes by discretizing them (note that this is preprocessing: the value range is partitioned into intervals up front) and then computing the gain ratio.
- C4.5 can also fill in missing values using simple statistics over the data, such as the mean or the mode.
- ID3 tends to grow a complete decision tree, which causes overfitting; C4.5 curbs tree growth with pre-pruning and post-pruning.
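To make the gain-ratio criterion concrete, the standard definitions are as follows. For a dataset $D$ with $K$ classes (class proportions $p_k$) and an attribute $A$ that partitions $D$ into subsets $D_v$:

$$H(D) = -\sum_{k=1}^{K} p_k \log_2 p_k, \qquad \mathrm{Gain}(D, A) = H(D) - \sum_{v} \frac{|D_v|}{|D|} H(D_v)$$

$$\mathrm{SplitInfo}(D, A) = -\sum_{v} \frac{|D_v|}{|D|} \log_2 \frac{|D_v|}{|D|}, \qquad \mathrm{GainRatio}(D, A) = \frac{\mathrm{Gain}(D, A)}{\mathrm{SplitInfo}(D, A)}$$

An attribute with many distinct values tends to have a large $\mathrm{SplitInfo}$, so dividing by it penalizes exactly the attributes that ID3 over-selects.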
Compared with C4.5, CART improves things further, mainly in these respects:

- CART uses the Gini index as its criterion: $\mathrm{Gini}(D) = 1 - \sum_{k=1}^{K} p_k^2$, where $K$ is the number of classes and $p_k$ is the proportion of samples in class $k$.
- For continuous values, C4.5 discretizes the data in advance, while CART searches directly for the best split point on the continuous data.
- CART can also handle regression problems, by using variance minimization as the splitting criterion; see the sketch after this list.
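The implementation below only covers the classification case. For regression, the Gini impurity would simply be swapped out for variance; here is a minimal sketch of that criterion, with illustrative helpers (`variance`, `weighted_variance`) that are not part of the class below:

```cpp
#include <numeric>
#include <vector>

// Variance of a set of regression targets; for a regression CART this
// plays the role that Gini impurity plays for classification.
double variance(const std::vector<double>& targets) {
    if (targets.empty()) return 0.0;
    double mean = std::accumulate(targets.begin(), targets.end(), 0.0) / targets.size();
    double sum_sq = 0.0;
    for (double t : targets) sum_sq += (t - mean) * (t - mean);
    return sum_sq / targets.size();
}

// Weighted variance of a candidate split: a regression tree picks the
// threshold minimizing this, exactly as the classifier below minimizes
// weighted Gini impurity.
double weighted_variance(const std::vector<double>& left, const std::vector<double>& right) {
    double total = static_cast<double>(left.size() + right.size());
    if (total == 0.0) return 0.0;
    return (left.size() * variance(left) + right.size() * variance(right)) / total;
}
```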
Enough rambling about theory; better to just hand-roll a CART tree:

```cpp
#include <iostream>
#include <vector>
#include <map>
#include <string>
#include <utility>
#include <algorithm>
#include <cmath>
// A node is either an internal split (feature_index >= 0, threshold set)
// or a leaf (feature_index == -1, value holds the predicted label as text).
struct TreeNode {
    int feature_index;
    double threshold;
    std::string value;
    TreeNode* left = nullptr;
    TreeNode* right = nullptr;
    TreeNode(int index, double thresh, std::string val)
        : feature_index(index), threshold(thresh), value(std::move(val)) {}
};
class DecisionTree {
public:
    DecisionTree(int min_samples_split, int max_depth)
        : m_min_samples_split(min_samples_split), m_max_depth(max_depth) {}
    ~DecisionTree() { destroy(m_root); }

    void train(const std::vector<std::vector<double>>& features, const std::vector<int>& labels);
    int predict(const std::vector<double>& sample);

private:
    TreeNode* build_tree(const std::vector<std::vector<double>>& features, const std::vector<int>& labels, int depth);
    std::pair<double, double> find_best_split(const std::vector<std::vector<double>>& features, const std::vector<int>& labels, int feature_index);
    double gini_impurity(const std::vector<int>& labels);
    double weighted_gini_impurity(const std::vector<int>& left_labels, const std::vector<int>& right_labels);
    int majority_vote(const std::vector<int>& labels);
    void destroy(TreeNode* node) {
        if (node == nullptr) return;
        destroy(node->left);
        destroy(node->right);
        delete node;
    }

    TreeNode* m_root = nullptr;  // set by train(); owned (and freed) by this class
    int m_min_samples_split;     // stop splitting below this many samples
    int m_max_depth;             // hard cap on tree depth (pre-pruning)
};
void DecisionTree::train(const std::vector<std::vector<double>>& features, const std::vector<int>& labels) {
    if (features.empty() || features[0].empty() || labels.size() != features.size()) {
        std::cerr << "Error: features are empty or labels do not match." << std::endl;
        return;
    }
    m_root = build_tree(features, labels, 0);
}
TreeNode* DecisionTree::build_tree(const std::vector<std::vector<double>>& features, const std::vector<int>& labels, int depth) {
    // Pre-pruning: stop on too few samples or too deep, and emit a leaf.
    if (static_cast<int>(labels.size()) <= m_min_samples_split || depth >= m_max_depth) {
        return new TreeNode(-1, -1, std::to_string(majority_vote(labels)));
    }
    // A pure node needs no further splitting.
    if (gini_impurity(labels) == 0.0) {
        return new TreeNode(-1, -1, std::to_string(majority_vote(labels)));
    }

    // Scan every feature for the split with the lowest weighted Gini impurity.
    int best_feature_index = -1;
    double best_threshold = -1;
    double best_impurity = 1.0; // Gini impurity is always < 1, so this is a safe upper bound
    for (size_t i = 0; i < features[0].size(); ++i) {
        auto [threshold, impurity] = find_best_split(features, labels, static_cast<int>(i));
        if (impurity < best_impurity) {
            best_feature_index = static_cast<int>(i);
            best_threshold = threshold;
            best_impurity = impurity;
        }
    }
    // No usable split found (e.g. all feature values identical): emit a leaf.
    if (best_feature_index == -1) {
        return new TreeNode(-1, -1, std::to_string(majority_vote(labels)));
    }

    // Partition the samples on the chosen threshold and recurse.
    std::vector<std::vector<double>> left_features, right_features;
    std::vector<int> left_labels, right_labels;
    for (size_t i = 0; i < features.size(); ++i) {
        if (features[i][best_feature_index] < best_threshold) {
            left_features.push_back(features[i]);
            left_labels.push_back(labels[i]);
        } else {
            right_features.push_back(features[i]);
            right_labels.push_back(labels[i]);
        }
    }
    TreeNode* node = new TreeNode(best_feature_index, best_threshold, "");
    node->left = build_tree(left_features, left_labels, depth + 1);
    node->right = build_tree(right_features, right_labels, depth + 1);
    return node;
}
std::pair<double, double> DecisionTree::find_best_split(const std::vector<std::vector<double>>& features, const std::vector<int>& labels, int feature_index) {
    // Candidate thresholds are midpoints between consecutive sorted values.
    std::vector<double> values;
    for (const auto& feature : features) {
        values.push_back(feature[feature_index]);
    }
    std::sort(values.begin(), values.end());

    double best_threshold = -1;
    double best_impurity = 1.0; // upper bound; any real split scores below this
    for (size_t i = 0; i + 1 < values.size(); ++i) {
        if (values[i] == values[i + 1]) continue; // identical values yield no real split
        double threshold = (values[i] + values[i + 1]) / 2.0;
        std::vector<int> left_labels, right_labels;
        for (size_t j = 0; j < features.size(); ++j) {
            if (features[j][feature_index] < threshold) {
                left_labels.push_back(labels[j]);
            } else {
                right_labels.push_back(labels[j]);
            }
        }
        double impurity = weighted_gini_impurity(left_labels, right_labels);
        if (impurity < best_impurity) {
            best_threshold = threshold;
            best_impurity = impurity;
        }
    }
    return {best_threshold, best_impurity};
}
double DecisionTree::gini_impurity(const std::vector<int>& labels) {
    // Gini(D) = 1 - sum_k p_k^2 over the class proportions p_k.
    if (labels.empty()) return 0.0;
    std::map<int, int> label_counts;
    for (int label : labels) {
        label_counts[label]++;
    }
    double impurity = 1.0;
    for (const auto& pair : label_counts) {
        double p = static_cast<double>(pair.second) / labels.size();
        impurity -= p * p;
    }
    return impurity;
}
double DecisionTree::weighted_gini_impurity(const std::vector<int>& left_labels, const std::vector<int>& right_labels) {
    // Child impurities, weighted by how many samples fall on each side.
    double total = left_labels.size() + right_labels.size();
    double left_impurity = gini_impurity(left_labels);
    double right_impurity = gini_impurity(right_labels);
    return (left_labels.size() * left_impurity + right_labels.size() * right_impurity) / total;
}
int DecisionTree::majority_vote(const std::vector<int>& labels) {
    // Leaf prediction: the most frequent label among the remaining samples.
    std::map<int, int> label_counts;
    for (int label : labels) {
        label_counts[label]++;
    }
    int most_common_label = 0;
    int max_count = 0;
    for (const auto& pair : label_counts) {
        if (pair.second > max_count) {
            max_count = pair.second;
            most_common_label = pair.first;
        }
    }
    return most_common_label;
}
int DecisionTree::predict(const std::vector<double>& sample) {
    TreeNode* current_node = m_root;
    if (current_node == nullptr) return -1; // not trained yet; -1 as an "invalid" sentinel
    // Internal nodes have both children; leaves have neither.
    while (current_node->left != nullptr && current_node->right != nullptr) {
        if (sample[current_node->feature_index] < current_node->threshold) {
            current_node = current_node->left;
        } else {
            current_node = current_node->right;
        }
    }
    return std::stoi(current_node->value); // leaves store the label as text
}
// Example usage
int main() {
    std::vector<std::vector<double>> features = {{2.5, 0.5}, {2.5, 0.75}, {2.5, 1.0}, {2.5, 1.25}, {2.5, 1.5}, {2.5, 1.75}, {1.0, 1.0}, {1.0, 1.5}, {1.0, 2.0}, {1.0, 2.5}, {0.5, 1.0}, {0.5, 1.5}, {0.5, 2.0}, {0.5, 2.5}};
    std::vector<int> labels = {0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1};
    DecisionTree tree(2, 10); // min_samples_split = 2, max_depth = 10
    tree.train(features, labels);
    std::vector<double> test_sample = {1.5, 1.0};
    std::cout << "Prediction: " << tree.predict(test_sample) << std::endl;
    return 0;
}
```
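Note that the structured bindings in `build_tree` require C++17, so build with `-std=c++17` or newer. On the toy data above, the first split lands at the midpoint 1.75 on feature 0, which separates the two classes perfectly, so the test sample `{1.5, 1.0}` should print `Prediction: 1`.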