ID3, C4.5, and CART Trees


C4.5 is a classic decision tree algorithm, built as a set of refinements on top of ID3.

Shortcomings of ID3:

  1. Choosing attributes by information gain is biased toward attributes with many distinct values (the sketch after this list makes the bias concrete)

  2. It cannot handle continuous attributes
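
To see why shortcoming 1 matters, here is a minimal sketch of entropy and information gain; the helper names are mine, for illustration only. An attribute with one distinct value per sample (say, an ID column) shatters the data into singleton subsets with zero entropy, so its information gain is maximal even though the split generalizes terribly:

```cpp
#include <cmath>
#include <map>
#include <vector>

// Entropy H(S) = -sum_k p_k * log2(p_k) over class frequencies in S.
double entropy(const std::vector<int>& labels) {
    std::map<int, int> counts;
    for (int label : labels) counts[label]++;
    double h = 0.0;
    for (const auto& pair : counts) {
        double p = static_cast<double>(pair.second) / labels.size();
        h -= p * std::log2(p);
    }
    return h;
}

// Information gain of partitioning `labels` into `subsets`:
// IG = H(S) - sum_i |S_i|/|S| * H(S_i).
double information_gain(const std::vector<int>& labels,
                        const std::vector<std::vector<int>>& subsets) {
    double remainder = 0.0;
    for (const auto& subset : subsets) {
        remainder += static_cast<double>(subset.size()) / labels.size() * entropy(subset);
    }
    return entropy(labels) - remainder;
}
```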

C4.5's improvements:

  1. Use the gain ratio (information gain divided by split information) as the criterion for choosing the best feature, which keeps the algorithm from favoring features with many distinct values (see the sketch after this list).

  2. Handle continuous attributes by discretizing them (note: this is done as preprocessing, partitioning the value range into intervals ahead of time), then computing the gain ratio as usual.

  3. C4.5 can also fill in missing values from statistics over the data, e.g. with the mean or the mode.

  4. ID3 tends to grow the decision tree out completely, which leads to overfitting; C4.5 restrains tree growth with pre-pruning and post-pruning.
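
A minimal sketch of the gain ratio follows; the entropy helper is repeated from the previous snippet so this one stands alone, and the function names are again mine:

```cpp
#include <cmath>
#include <cstddef>
#include <map>
#include <vector>

// Same entropy helper as in the previous sketch, repeated so this
// snippet compiles on its own.
double entropy(const std::vector<int>& labels) {
    std::map<int, int> counts;
    for (int label : labels) counts[label]++;
    double h = 0.0;
    for (const auto& pair : counts) {
        double p = static_cast<double>(pair.second) / labels.size();
        h -= p * std::log2(p);
    }
    return h;
}

// SplitInfo = -sum_i |S_i|/|S| * log2(|S_i|/|S|): the entropy of the
// partition sizes themselves, large when a split has many branches.
double split_info(const std::vector<std::vector<int>>& subsets, std::size_t total) {
    double si = 0.0;
    for (const auto& subset : subsets) {
        if (subset.empty()) continue;
        double p = static_cast<double>(subset.size()) / total;
        si -= p * std::log2(p);
    }
    return si;
}

// GainRatio = IG / SplitInfo. A many-valued attribute may score a high
// information gain, but its SplitInfo is high too, so the ratio
// counteracts the bias.
double gain_ratio(const std::vector<int>& labels,
                  const std::vector<std::vector<int>>& subsets) {
    double remainder = 0.0;
    for (const auto& subset : subsets) {
        remainder += static_cast<double>(subset.size()) / labels.size() * entropy(subset);
    }
    double ig = entropy(labels) - remainder;
    double si = split_info(subsets, labels.size());
    return si > 0.0 ? ig / si : 0.0;  // guard against a degenerate one-branch split
}
```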

CART improves on C4.5 in several ways:

  1. CART uses the Gini index as its criterion: $\mathrm{Gini}(X) = 1 - \sum_{k=1}^{K} p_k^2$, where $K$ is the number of classes and $p_k$ is the proportion of class $k$.

  2. For continuous values, C4.5 discretizes the data ahead of time, whereas CART searches the raw values directly for the best split point.

  3. CART can also handle regression, by using variance minimization as the split criterion (see the sketch below).
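
A minimal sketch of what point 3 means in practice, under the usual CART formulation: a candidate split is scored by the summed squared error of the two children (minimizing this is equivalent to minimizing the size-weighted variance), and a leaf predicts the mean of its targets. Function names are mine:

```cpp
#include <vector>

// Mean of the target values in a node; this is what a regression leaf predicts.
double mean(const std::vector<double>& y) {
    double s = 0.0;
    for (double v : y) s += v;
    return y.empty() ? 0.0 : s / y.size();
}

// Sum of squared deviations from the mean (|S| times the variance).
double sse(const std::vector<double>& y) {
    double m = mean(y), s = 0.0;
    for (double v : y) s += (v - m) * (v - m);
    return s;
}

// Split score: total squared error of the two children. CART regression
// picks the feature/threshold pair minimizing this quantity.
double split_cost(const std::vector<double>& left_y, const std::vector<double>& right_y) {
    return sse(left_y) + sse(right_y);
}
```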

Rather than droning on with theory, let's just hand-roll a CART tree (the listing uses structured bindings, so compile with C++17 or later):


```cpp
#include <algorithm>
#include <cmath>
#include <iostream>
#include <map>
#include <string>   // std::string, std::to_string, std::stoi
#include <vector>

struct TreeNode {
    int feature_index;   // feature the node splits on; -1 for leaves
    double threshold;    // samples with feature value < threshold go left
    std::string value;   // class label as text; only meaningful on leaves
    TreeNode* left = nullptr;
    TreeNode* right = nullptr;

    TreeNode(int index, double thresh, std::string val)
        : feature_index(index), threshold(thresh), value(std::move(val)) {}
};

class DecisionTree {
public:
    DecisionTree(int min_samples_split, int max_depth)
        : m_min_samples_split(min_samples_split), m_max_depth(max_depth) {}
    ~DecisionTree() { free_tree(m_root); }

    void train(const std::vector<std::vector<double>>& features, const std::vector<int>& labels);
    int predict(const std::vector<double>& sample);

private:
    TreeNode* build_tree(const std::vector<std::vector<double>>& features, const std::vector<int>& labels, int depth);
    std::pair<double, double> find_best_split(const std::vector<std::vector<double>>& features, const std::vector<int>& labels, int feature_index);
    double gini_impurity(const std::vector<int>& labels);
    double weighted_gini_impurity(const std::vector<int>& left_labels, const std::vector<int>& right_labels);
    int majority_vote(const std::vector<int>& labels);
    static void free_tree(TreeNode* node);

    TreeNode* m_root = nullptr;
    int m_min_samples_split;
    int m_max_depth;
};

void DecisionTree::train(const std::vector<std::vector<double>>& features, const std::vector<int>& labels) {
    if (features.empty() || features[0].empty() || features.size() != labels.size()) {
        std::cerr << "Error: features and labels are empty or mismatched." << std::endl;
        return;
    }
    m_root = build_tree(features, labels, 0);
}

TreeNode* DecisionTree::build_tree(const std::vector<std::vector<double>>& features, const std::vector<int>& labels, int depth) {
    // Stop when the node is too small, too deep, or already pure.
    if (static_cast<int>(labels.size()) <= m_min_samples_split ||
        depth >= m_max_depth || gini_impurity(labels) == 0.0) {
        return new TreeNode(-1, -1, std::to_string(majority_vote(labels)));
    }

    int best_feature_index = -1;
    double best_threshold = -1;
    double best_impurity = 1.0; // Gini impurity never exceeds 1

    // Try every feature; keep the split with the lowest weighted impurity.
    for (std::size_t i = 0; i < features[0].size(); ++i) {
        auto [threshold, impurity] = find_best_split(features, labels, static_cast<int>(i));
        if (impurity < best_impurity) {
            best_feature_index = static_cast<int>(i);
            best_threshold = threshold;
            best_impurity = impurity;
        }
    }

    // No usable split (e.g. all feature values identical): make a leaf.
    if (best_feature_index == -1) {
        return new TreeNode(-1, -1, std::to_string(majority_vote(labels)));
    }

    // Partition the samples by the chosen feature and threshold.
    std::vector<std::vector<double>> left_features, right_features;
    std::vector<int> left_labels, right_labels;
    for (std::size_t i = 0; i < features.size(); ++i) {
        if (features[i][best_feature_index] < best_threshold) {
            left_features.push_back(features[i]);
            left_labels.push_back(labels[i]);
        } else {
            right_features.push_back(features[i]);
            right_labels.push_back(labels[i]);
        }
    }

    TreeNode* node = new TreeNode(best_feature_index, best_threshold, "");
    node->left = build_tree(left_features, left_labels, depth + 1);
    node->right = build_tree(right_features, right_labels, depth + 1);
    return node;
}

// Scan midpoints between consecutive sorted values of one feature and
// return {best threshold, its weighted Gini impurity}. This is the CART
// treatment of continuous attributes: no up-front discretization.
std::pair<double, double> DecisionTree::find_best_split(const std::vector<std::vector<double>>& features, const std::vector<int>& labels, int feature_index) {
    std::vector<double> values;
    values.reserve(features.size());
    for (const auto& feature : features) {
        values.push_back(feature[feature_index]);
    }
    std::sort(values.begin(), values.end());

    double best_threshold = -1;
    double best_impurity = 1.0;

    for (std::size_t i = 0; i + 1 < values.size(); ++i) {
        // Equal neighbors would yield a threshold that leaves one side empty.
        if (values[i] == values[i + 1]) continue;
        double threshold = (values[i] + values[i + 1]) / 2.0;
        std::vector<int> left_labels, right_labels;
        for (std::size_t j = 0; j < features.size(); ++j) {
            if (features[j][feature_index] < threshold) {
                left_labels.push_back(labels[j]);
            } else {
                right_labels.push_back(labels[j]);
            }
        }
        double impurity = weighted_gini_impurity(left_labels, right_labels);
        if (impurity < best_impurity) {
            best_threshold = threshold;
            best_impurity = impurity;
        }
    }
    return {best_threshold, best_impurity};
}

// Gini(S) = 1 - sum_k p_k^2 over the class frequencies in S.
double DecisionTree::gini_impurity(const std::vector<int>& labels) {
    if (labels.empty()) return 0.0;
    std::map<int, int> label_counts;
    for (int label : labels) {
        label_counts[label]++;
    }
    double impurity = 1.0;
    for (const auto& pair : label_counts) {
        double p = static_cast<double>(pair.second) / labels.size();
        impurity -= p * p;
    }
    return impurity;
}

// Size-weighted average of the children's impurities: the quantity CART minimizes.
double DecisionTree::weighted_gini_impurity(const std::vector<int>& left_labels, const std::vector<int>& right_labels) {
    double total = static_cast<double>(left_labels.size() + right_labels.size());
    return (left_labels.size() * gini_impurity(left_labels) +
            right_labels.size() * gini_impurity(right_labels)) / total;
}

// Most frequent class label; used as a leaf's prediction.
int DecisionTree::majority_vote(const std::vector<int>& labels) {
    std::map<int, int> label_counts;
    for (int label : labels) {
        label_counts[label]++;
    }
    int most_common_label = 0;
    int max_count = 0;
    for (const auto& pair : label_counts) {
        if (pair.second > max_count) {
            max_count = pair.second;
            most_common_label = pair.first;
        }
    }
    return most_common_label;
}

int DecisionTree::predict(const std::vector<double>& sample) {
    if (m_root == nullptr) {
        std::cerr << "Error: predict() called before train()." << std::endl;
        return -1;
    }
    TreeNode* current_node = m_root;
    // Descend until a leaf is reached (leaves have no children).
    while (current_node->left != nullptr && current_node->right != nullptr) {
        if (sample[current_node->feature_index] < current_node->threshold) {
            current_node = current_node->left;
        } else {
            current_node = current_node->right;
        }
    }
    return std::stoi(current_node->value);
}

// Recursively release all nodes.
void DecisionTree::free_tree(TreeNode* node) {
    if (node == nullptr) return;
    free_tree(node->left);
    free_tree(node->right);
    delete node;
}

// Example usage
int main() {
    // Toy dataset: class 0 sits at x = 2.5, class 1 at x <= 1.0.
    std::vector<std::vector<double>> features = {{2.5, 0.5}, {2.5, 0.75}, {2.5, 1.0}, {2.5, 1.25}, {2.5, 1.5}, {2.5, 1.75}, {1.0, 1.0}, {1.0, 1.5}, {1.0, 2.0}, {1.0, 2.5}, {0.5, 1.0}, {0.5, 1.5}, {0.5, 2.0}, {0.5, 2.5}};
    std::vector<int> labels = {0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1};

    DecisionTree tree(2, 10);   // min_samples_split = 2, max_depth = 10
    tree.train(features, labels);

    std::vector<double> test_sample = {1.5, 1.0};
    std::cout << "Prediction: " << tree.predict(test_sample) << std::endl;

    return 0;
}
```
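
On this toy dataset the root split should land on the first feature at threshold 1.75 (the midpoint between 1.0 and 2.5), which separates the two classes perfectly, so the program should print `Prediction: 1` for the test sample {1.5, 1.0}.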