ID3, C4.5, and CART Trees


C4.5 is a classic decision tree algorithm, built as a set of modifications on top of ID3.

ID3's drawbacks:

  1. Selecting attributes by information gain is biased toward attributes with many distinct values (made concrete below)
  2. It cannot handle continuous attributes
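
To make drawback 1 concrete: ID3 scores an attribute $A$ by the information gain $g(D, A) = H(D) - H(D \mid A)$, where $H(D) = -\sum_{k=1}^{K} p_k \log_2 p_k$ is the entropy of the label distribution. An ID-like attribute that takes a distinct value for every sample splits the data into pure singletons, so $H(D \mid A) = 0$ and its gain is maximal, even though such a split generalizes terribly.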

C4.5's improvements:

  1. It uses gain ratio (the information gain ratio) as the criterion for selecting the best feature, which keeps the algorithm from favoring features with many distinct values; a sketch follows this list.
  2. It discretizes continuous attributes (note: this is a preprocessing step, partitioning the value range into intervals in advance) and then computes the gain ratio.
  3. C4.5 can also fill in missing values using statistics over the data, such as the mean or the mode.
  4. ID3 tends to grow a complete decision tree, which leads to overfitting; C4.5 curbs tree growth with pre-pruning and post-pruning.
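
Here is a minimal sketch of the gain ratio computation from improvement 1 (the helper names `entropy` and `gain_ratio`, and the partition-based interface, are my own illustration rather than anything prescribed by C4.5):

#include <cmath>
#include <map>
#include <vector>

// Entropy of a label set: H(D) = -sum_k p_k * log2(p_k).
double entropy(const std::vector<int>& labels) {
    std::map<int, int> counts;
    for (int label : labels) counts[label]++;
    double h = 0.0;
    for (const auto& pair : counts) {
        double p = static_cast<double>(pair.second) / labels.size();
        h -= p * std::log2(p);
    }
    return h;
}

// Gain ratio = information gain / split information. `partitions` holds the
// label subsets produced by splitting on one attribute; split information
// grows with the number of partitions, which is what penalizes attributes
// with many distinct values.
double gain_ratio(const std::vector<int>& labels,
                  const std::vector<std::vector<int>>& partitions) {
    double total = labels.size();
    double conditional = 0.0;  // H(D|A)
    double split_info = 0.0;   // H_A(D)
    for (const auto& part : partitions) {
        if (part.empty()) continue;
        double w = part.size() / total;
        conditional += w * entropy(part);
        split_info -= w * std::log2(w);
    }
    double gain = entropy(labels) - conditional;
    return split_info > 0.0 ? gain / split_info : 0.0;
}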

CART improves on C4.5 again, mainly in the following ways:

  1. CART uses the Gini index as its criterion: $Gini(X) = 1 - \sum_{k=1}^{K} p_k^2$, where $K$ is the number of classes
  2. For continuous values, C4.5 discretizes the data in advance, while CART searches the raw values directly for the optimal split point
  3. CART can also handle regression, by using variance minimization as the split criterion (sketched below this list)
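
And a minimal sketch of the regression criterion from point 3 (again, `variance` and `weighted_variance` are illustrative names, not part of the classifier below): each candidate split is scored by the weighted variance of the target values it produces, and a leaf predicts the mean of its targets.

#include <vector>

// Variance of a set of regression targets; CART regression uses this
// as the impurity measure, and a leaf predicts the mean of its targets.
double variance(const std::vector<double>& targets) {
    if (targets.empty()) return 0.0;
    double mean = 0.0;
    for (double t : targets) mean += t;
    mean /= targets.size();
    double var = 0.0;
    for (double t : targets) var += (t - mean) * (t - mean);
    return var / targets.size();
}

// Weighted variance after a binary split; the best split minimizes this,
// mirroring weighted_gini_impurity in the classifier below.
double weighted_variance(const std::vector<double>& left_targets,
                         const std::vector<double>& right_targets) {
    double total = left_targets.size() + right_targets.size();
    return (left_targets.size() * variance(left_targets)
            + right_targets.size() * variance(right_targets)) / total;
}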

Instead of rambling on with theory, let's just hand-roll a CART tree:

#include <algorithm>
#include <cmath>
#include <iostream>
#include <map>
#include <string>
#include <utility>
#include <vector>

// A node is either internal (feature_index >= 0, children set) or a leaf
// (feature_index == -1, value holds the predicted label as a string).
struct TreeNode {
    int feature_index;
    double threshold;
    std::string value;
    TreeNode* left = nullptr;
    TreeNode* right = nullptr;

    TreeNode(int index, double thresh, std::string val) : feature_index(index), threshold(thresh), value(val) {}
};

class DecisionTree {
public:
    DecisionTree(int min_samples_split, int max_depth)
        : m_min_samples_split(min_samples_split), m_max_depth(max_depth) {}

    void train(const std::vector<std::vector<double>>& features, const std::vector<int>& labels);
    int predict(const std::vector<double>& sample);

private:
    TreeNode* build_tree(const std::vector<std::vector<double>>& features, const std::vector<int>& labels, int depth);
    std::pair<double, double> find_best_split(const std::vector<std::vector<double>>& features, const std::vector<int>& labels, int feature_index);
    double gini_impurity(const std::vector<int>& labels);
    double weighted_gini_impurity(const std::vector<int>& left_labels, const std::vector<int>& right_labels);
    int majority_vote(const std::vector<int>& labels);

    TreeNode* m_root = nullptr;  // set by train()
    int m_min_samples_split;     // stop splitting at or below this many samples
    int m_max_depth;             // hard limit on tree depth
};

void DecisionTree::train(const std::vector<std::vector<double>>& features, const std::vector<int>& labels) {
    if (features.empty() || features[0].empty() || labels.size() != features.size()) {
        std::cerr << "Error: features are empty or do not match labels." << std::endl;
        return;
    }

    m_root = build_tree(features, labels, 0);
}

TreeNode* DecisionTree::build_tree(const std::vector<std::vector<double>>& features, const std::vector<int>& labels, int depth) {
    if (static_cast<int>(labels.size()) <= m_min_samples_split || depth >= m_max_depth) {
        return new TreeNode(-1, -1, std::to_string(majority_vote(labels)));
    }

    int best_feature_index = -1;
    double best_threshold = -1;
    double best_impurity = 1.0; // Maximum impurity

    for (int i = 0; i < static_cast<int>(features[0].size()); ++i) {
        auto [threshold, impurity] = find_best_split(features, labels, i);
        if (impurity < best_impurity) {
            best_feature_index = i;
            best_threshold = threshold;
            best_impurity = impurity;
        }
    }

    if (best_feature_index == -1) {
        return new TreeNode(-1, -1, std::to_string(majority_vote(labels)));
    }

    std::vector<std::vector<double>> left_features;
    std::vector<int> left_labels;
    std::vector<std::vector<double>> right_features;
    std::vector<int> right_labels;

    for (size_t i = 0; i < features.size(); ++i) {
        if (features[i][best_feature_index] < best_threshold) {
            left_features.push_back(features[i]);
            left_labels.push_back(labels[i]);
        } else {
            right_features.push_back(features[i]);
            right_labels.push_back(labels[i]);
        }
    }

    TreeNode* node = new TreeNode(best_feature_index, best_threshold, "");

    node->left = build_tree(left_features, left_labels, depth + 1);
    node->right = build_tree(right_features, right_labels, depth + 1);

    return node;
}

std::pair<double, double> DecisionTree::find_best_split(const std::vector<std::vector<double>>& features, const std::vector<int>& labels, int feature_index) {
    std::vector<double> values;
    for (const auto& feature : features) {
        values.push_back(feature[feature_index]);
    }

    std::sort(values.begin(), values.end());

    double best_threshold = -1;
    double best_impurity = 1.0; // Maximum impurity

    for (size_t i = 0; i + 1 < values.size(); ++i) {
        // Midpoints between duplicate values are degenerate thresholds; skip them.
        if (values[i] == values[i + 1]) continue;
        double threshold = (values[i] + values[i + 1]) / 2.0;
        std::vector<int> left_labels, right_labels;

        for (size_t j = 0; j < features.size(); ++j) {
            if (features[j][feature_index] < threshold) {
                left_labels.push_back(labels[j]);
            } else {
                right_labels.push_back(labels[j]);
            }
        }

        double impurity = weighted_gini_impurity(left_labels, right_labels);
        if (impurity < best_impurity) {
            best_threshold = threshold;
            best_impurity = impurity;
        }
    }

    return {best_threshold, best_impurity};
}

double DecisionTree::gini_impurity(const std::vector<int>& labels) {
    if (labels.empty()) return 0.0; // treat an empty subset as pure
    std::map<int, int> label_counts;
    for (int label : labels) {
        label_counts[label]++;
    }

    double impurity = 1.0;
    for (const auto& pair : label_counts) {
        double p = static_cast<double>(pair.second) / labels.size();
        impurity -= p * p;
    }

    return impurity;
}

double DecisionTree::weighted_gini_impurity(const std::vector<int>& left_labels, const std::vector<int>& right_labels) {
    double total = left_labels.size() + right_labels.size();
    double left_impurity = gini_impurity(left_labels);
    double right_impurity = gini_impurity(right_labels);

    return (left_labels.size() * left_impurity + right_labels.size() * right_impurity) / total;
}

int DecisionTree::majority_vote(const std::vector<int>& labels) {
    std::map<int, int> label_counts;
    for (int label : labels) {
        label_counts[label]++;
    }

    int most_common_label = 0;
    int max_count = 0;
    for (const auto& pair : label_counts) {
        if (pair.second > max_count) {
            max_count = pair.second;
            most_common_label = pair.first;
        }
    }

    return most_common_label;
}

int DecisionTree::predict(const std::vector<double>& sample) {
    TreeNode* current_node = m_root;
    if (current_node == nullptr) return -1; // no model: train() was not called or failed

    while (current_node->left != nullptr && current_node->right != nullptr) {
        if (sample[current_node->feature_index] < current_node->threshold) {
            current_node = current_node->left;
        } else {
            current_node = current_node->right;
        }
    }

    return std::stoi(current_node->value);
}

// Example usage
int main() {
    std::vector<std::vector<double>> features = {{2.5, 0.5}, {2.5, 0.75}, {2.5, 1.0}, {2.5, 1.25}, {2.5, 1.5}, {2.5, 1.75}, {1.0, 1.0}, {1.0, 1.5}, {1.0, 2.0}, {1.0, 2.5}, {0.5, 1.0}, {0.5, 1.5}, {0.5, 2.0}, {0.5, 2.5}};
    std::vector<int> labels = {0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1};

    DecisionTree tree(2, 10);
    tree.train(features, labels);

    std::vector<double> test_sample = {1.5, 1.0};
    std::cout << "Prediction: " << tree.predict(test_sample) << std::endl;

    return 0;
}