决策树算法详解

1. 什么是决策树

决策树(Decision Tree)是一种直观且强大的监督学习算法,可用于分类和回归任务。它的核心思想是通过一系列"是/否"问题将数据逐步划分,最终到达一个叶节点得出预测结果。你可以把它想象成一个倒置的树:根节点是第一个问题,每个分支代表一个答案,叶节点是最终的决策。

以一个简单的"是否出去打球"的例子来理解:

天气如何? / | \ 晴天 阴天 雨天 / | \ 湿度高? 打球 有风? / \ / \ 不打 打球 不打 打球

这棵简单的决策树只用了3个特征(天气、湿度、风力)就能做出预测。真实世界的决策树可能有数百个节点,但基本原理完全相同:在每个节点选择最优特征进行分裂,直到满足停止条件。

2. 决策树工作原理

决策树的构建过程是一个递归分裂(Recursive Splitting)的过程,核心步骤如下:

2.1 递归分裂

从根节点开始,算法遍历所有可用特征和可能的分裂点,选择能最大程度降低"不纯度"(impurity)的特征和阈值进行分裂。对分裂后的每个子集重复此过程。

2.2 特征选择

在每个节点,算法需要回答一个关键问题:选择哪个特征来分裂? 答案是选择使得子节点"最纯"的特征。衡量"纯度"的指标有信息增益(Information Gain)、增益率(Gain Ratio)和基尼不纯度(Gini Impurity),它们分别对应 ID3、C4.5 和 CART 三种算法。

2.3 叶节点与停止条件

当满足以下条件之一时,节点变为叶节点,不再继续分裂:

1. 节点中所有样本属于同一类别(纯节点)
2. 达到最大深度限制(max_depth)
3. 节点中样本数小于最小分裂阈值(min_samples_split)
4. 没有可用特征进行分裂
5. 分裂带来的信息增益小于阈值

分类树的叶节点输出该节点中出现最多的类别(多数表决);回归树的叶节点输出该节点中目标值的均值。

3. 分裂准则

3.1 信息增益 — ID3 算法

信息熵 (Entropy)

H(S) = -Σ pᵢ log₂(pᵢ)

其中 pᵢ 是类别 i 在集合 S 中的比例。熵越大,数据越混乱;熵为 0 表示完全纯净。

信息增益 (Information Gain)

IG(S, A) = H(S) - Σ (|Sᵥ| / |S|) * H(Sᵥ)

信息增益 = 分裂前的熵 - 分裂后各子集的加权熵之和。IG 越大,说明按该特征分裂后数据变得越"纯"。

数值示例: 假设有14个样本,9个正例(+),5个负例(-)。

父节点熵: H(S) = -(9/14)log₂(9/14) - (5/14)log₂(5/14) = 0.940

按特征A分裂为两个子集:
S₁: 6个样本 (4+, 2-) → H(S₁) = -(4/6)log₂(4/6) - (2/6)log₂(2/6) = 0.918
S₂: 8个样本 (5+, 3-) → H(S₂) = -(5/8)log₂(5/8) - (3/8)log₂(3/8) = 0.954

信息增益: IG = 0.940 - (6/14)*0.918 - (8/14)*0.954 = 0.940 - 0.393 - 0.545 = 0.002

按特征B分裂为两个子集:
S₁: 7个样本 (7+, 0-) → H(S₁) = 0 (纯节点!)
S₂: 7个样本 (2+, 5-) → H(S₂) = -(2/7)log₂(2/7) - (5/7)log₂(5/7) = 0.863

信息增益: IG = 0.940 - (7/14)*0 - (7/14)*0.863 = 0.940 - 0 - 0.431 = 0.509

特征B的信息增益(0.509)远大于特征A(0.002),所以选择特征B分裂。

3.2 增益率 — C4.5 算法

分裂信息 (Split Information)

SplitInfo(S, A) = -Σ (|Sᵥ| / |S|) * log₂(|Sᵥ| / |S|)

SplitInfo 衡量按特征A分裂后子集大小的分布均匀程度。特征取值越多,SplitInfo 越大。

增益率 (Gain Ratio)

GR(S, A) = IG(S, A) / SplitInfo(S, A)

增益率通过除以 SplitInfo 来惩罚取值过多的特征,解决了 ID3 偏好多值特征的问题。

数值示例: 继续上面的例子,特征B将14个样本均匀分为7和7:

SplitInfo(S, B) = -(7/14)log₂(7/14) - (7/14)log₂(7/14) = 1.0
GR(S, B) = 0.509 / 1.0 = 0.509

假设特征C有5个取值,每个分到约2-3个样本,IG=0.52但SplitInfo=2.3:
GR(S, C) = 0.52 / 2.3 = 0.226

虽然特征C的信息增益略高(0.52 vs 0.509),但增益率更低(0.226 vs 0.509)。C4.5会选择特征B,因为它不会过度偏好多值特征。

3.3 基尼不纯度 — CART 算法

基尼不纯度 (Gini Impurity)

Gini(S) = 1 - Σ pᵢ²

基尼系数衡量从集合中随机抽取两个样本,类别不一致的概率。范围 [0, 0.5](二分类),0 表示完全纯净。

加权基尼 (Weighted Gini after split)

Gini_split = Σ (|Sᵥ| / |S|) * Gini(Sᵥ)

CART 选择使加权基尼值最小的分裂方式。注意:CART 只做二叉分裂(每次分成两个子集)。

数值示例: 同样14个样本 (9+, 5-)。

父节点基尼: Gini = 1 - (9/14)² - (5/14)² = 1 - 0.413 - 0.128 = 0.459

按特征B分裂(阈值=x):
左子集: 7个样本 (7+, 0-) → Gini = 1 - 1² - 0² = 0.000
右子集: 7个样本 (2+, 5-) → Gini = 1 - (2/7)² - (5/7)² = 1 - 0.082 - 0.510 = 0.408

加权基尼: (7/14)*0.000 + (7/14)*0.408 = 0.204

加权基尼从0.459降到0.204,下降显著。CART会在所有可能的分裂中选择加权基尼最小的那个。

4. ID3 vs C4.5 vs CART 对比

特性ID3C4.5CART
分裂准则信息增益增益率基尼不纯度
树结构多叉树多叉树严格二叉树
连续特征不支持支持(二分法离散化)支持(选最优分裂点)
缺失值处理不支持支持(加权分配)支持(代理分裂)
剪枝策略悲观错误剪枝(PEP)代价复杂度剪枝(CCP)
任务类型仅分类仅分类分类 + 回归
偏好偏好多值特征修正了多值偏好无明显偏好
提出年份1986 (Quinlan)1993 (Quinlan)1984 (Breiman)
sklearn实现DecisionTreeClassifier

实践中,sklearn的决策树实现基于优化的CART算法。如果你使用Python做机器学习,默认就是CART。ID3和C4.5更多出现在学术研究和面试中。

5. Python 从零实现决策树

下面用纯Python实现一个基于信息熵的简单决策树分类器,不依赖任何第三方库(仅用math模块)。

import math from collections import Counter def entropy(labels): """计算信息熵""" n = len(labels) if n == 0: return 0 counts = Counter(labels) return -sum((c/n) * math.log2(c/n) for c in counts.values()) def best_split(X, y, features): """选择最佳分裂特征(基于信息增益)""" base_ent = entropy(y) best_ig, best_feat = -1, None for feat in features: values = set(row[feat] for row in X) weighted_ent = 0 for val in values: sub_y = [y[i] for i, row in enumerate(X) if row[feat] == val] weighted_ent += (len(sub_y) / len(y)) * entropy(sub_y) ig = base_ent - weighted_ent if ig > best_ig: best_ig, best_feat = ig, feat return best_feat def build_tree(X, y, features, depth=0, max_depth=10): """递归构建决策树""" # 停止条件:纯节点、无特征、达到最大深度 if len(set(y)) == 1: return y[0] if not features or depth >= max_depth: return Counter(y).most_common(1)[0][0] feat = best_split(X, y, features) if feat is None: return Counter(y).most_common(1)[0][0] tree = {feat: {}} values = set(row[feat] for row in X) remaining = [f for f in features if f != feat] for val in values: sub_X = [row for row in X if row[feat] == val] sub_y = [y[i] for i, row in enumerate(X) if row[feat] == val] tree[feat][val] = build_tree(sub_X, sub_y, remaining, depth+1, max_depth) return tree def predict(tree, sample): """用决策树预测单个样本""" if not isinstance(tree, dict): return tree feat = list(tree.keys())[0] val = sample.get(feat) subtree = tree[feat].get(val) if subtree is None: return None # 未知特征值 return predict(subtree, sample) # 使用示例 data = [ {'天气': '晴', '湿度': '高', '风力': '弱'}, {'天气': '晴', '湿度': '高', '风力': '强'}, {'天气': '阴', '湿度': '高', '风力': '弱'}, {'天气': '雨', '湿度': '高', '风力': '弱'}, {'天气': '雨', '湿度': '正常', '风力': '弱'}, ] labels = ['不打球', '不打球', '打球', '打球', '打球'] features = ['天气', '湿度', '风力'] tree = build_tree(data, labels, features, max_depth=5) print(tree) print(predict(tree, {'天气': '晴', '湿度': '正常', '风力': '弱'}))

6. Sklearn 决策树实战

6.1 分类树 (DecisionTreeClassifier)

from sklearn.tree import DecisionTreeClassifier from sklearn.model_selection import train_test_split from sklearn.metrics import accuracy_score, classification_report from sklearn.datasets import load_iris # 加载数据 iris = load_iris() X_train, X_test, y_train, y_test = train_test_split( iris.data, iris.target, test_size=0.2, random_state=42 ) # 训练决策树 clf = DecisionTreeClassifier( criterion='gini', # 'gini' 或 'entropy' max_depth=4, # 最大深度 min_samples_split=5, # 内部节点最少样本数 min_samples_leaf=2, # 叶节点最少样本数 random_state=42 ) clf.fit(X_train, y_train) # 预测与评估 y_pred = clf.predict(X_test) print(f"Accuracy: {accuracy_score(y_test, y_pred):.4f}") print(classification_report(y_test, y_pred, target_names=iris.target_names))

6.2 回归树 (DecisionTreeRegressor)

from sklearn.tree import DecisionTreeRegressor from sklearn.metrics import mean_squared_error, r2_score import numpy as np # 回归树用法完全类似,只是 criterion 不同 reg = DecisionTreeRegressor( criterion='squared_error', # 也支持 'friedman_mse', 'absolute_error' max_depth=5, min_samples_leaf=5, random_state=42 ) reg.fit(X_train, y_train) y_pred = reg.predict(X_test) print(f"MSE: {mean_squared_error(y_test, y_pred):.4f}") print(f"R^2: {r2_score(y_test, y_pred):.4f}")

6.3 特征重要性

# 查看各特征的重要性 import pandas as pd importances = clf.feature_importances_ feat_imp = pd.DataFrame({ 'feature': iris.feature_names, 'importance': importances }).sort_values('importance', ascending=False) print(feat_imp)

7. 剪枝策略

未剪枝的决策树容易过拟合——它可以完美拟合训练集(每个叶节点只有1个样本),但泛化能力很差。剪枝是解决过拟合的关键手段。

7.1 预剪枝 (Pre-pruning)

在树的构建过程中提前停止生长。sklearn 支持的预剪枝参数:

max_depth: 限制树的最大深度。通常设为 3-10。过小欠拟合,过大过拟合。
min_samples_split: 节点分裂所需最少样本数,默认2。增大可防止过拟合。
min_samples_leaf: 叶节点最少样本数,默认1。增大可让叶节点更稳定。
max_features: 每次分裂考虑的最大特征数。'sqrt' 或 'log2' 可增加随机性。
max_leaf_nodes: 最大叶节点数。限制树的复杂度。

7.2 后剪枝 — 代价复杂度剪枝 (Cost-Complexity Pruning, CCP)

先让树完全生长,然后从底部开始逐步"剪掉"对预测帮助不大的子树。sklearn 中通过 ccp_alpha 参数控制。

代价复杂度公式

R_alpha(T) = R(T) + alpha * |T|

R(T) 是树的训练误差,|T| 是叶节点数量,alpha 是惩罚系数。alpha 越大,惩罚越重,最终树越小。

# 使用 CCP 后剪枝 # 1. 获取不同 alpha 值对应的剪枝路径 path = clf.cost_complexity_pruning_path(X_train, y_train) ccp_alphas = path.ccp_alphas # 2. 对每个 alpha 训练决策树 clfs = [] for alpha in ccp_alphas: c = DecisionTreeClassifier(ccp_alpha=alpha, random_state=42) c.fit(X_train, y_train) clfs.append(c) # 3. 评估训练集和测试集准确率,选择最佳 alpha train_scores = [c.score(X_train, y_train) for c in clfs] test_scores = [c.score(X_test, y_test) for c in clfs] best_idx = np.argmax(test_scores) best_alpha = ccp_alphas[best_idx] print(f"Best alpha: {best_alpha:.4f}, Test Accuracy: {test_scores[best_idx]:.4f}")

8. 决策树可视化

8.1 使用 sklearn plot_tree

import matplotlib.pyplot as plt from sklearn.tree import plot_tree fig, ax = plt.subplots(figsize=(20, 10)) plot_tree( clf, feature_names=iris.feature_names, class_names=iris.target_names, filled=True, # 按类别填充颜色 rounded=True, # 圆角方框 fontsize=10, ax=ax ) plt.tight_layout() plt.savefig('decision_tree.png', dpi=150) plt.show()

8.2 使用 Graphviz

from sklearn.tree import export_graphviz import graphviz # 导出 DOT 格式 dot_data = export_graphviz( clf, feature_names=iris.feature_names, class_names=iris.target_names, filled=True, rounded=True, special_characters=True ) # 渲染为图片 graph = graphviz.Source(dot_data) graph.render('iris_tree', format='png', cleanup=True) # 保存为 iris_tree.png graph # 在 Jupyter Notebook 中直接显示

8.3 文本形式输出

from sklearn.tree import export_text # 纯文本形式打印决策树规则 tree_rules = export_text(clf, feature_names=iris.feature_names) print(tree_rules) # 输出类似: # |--- petal length (cm) <= 2.45 # | |--- class: setosa # |--- petal length (cm) > 2.45 # | |--- petal width (cm) <= 1.75 # | | |--- class: versicolor # | |--- petal width (cm) > 1.75 # | | |--- class: virginica

9. 优点与缺点

优点:
1. 直观易解释,可可视化
2. 无需特征缩放/标准化
3. 能处理数值和类别特征
4. 能捕获非线性关系和特征交互
5. 可处理缺失值(CART)
6. 训练和预测速度快
7. 可直接输出特征重要性
缺点:
1. 容易过拟合(特别是深树)
2. 不稳定:数据小变化可能导致完全不同的树
3. 对类别不平衡敏感
4. 贪心算法,不保证全局最优
5. 外推能力差(超出训练范围的值)
6. 处理高维稀疏数据效果差
7. 单棵树精度通常不如集成方法

10. 何时使用决策树

场景推荐算法原因
需要模型可解释性决策树可直接可视化决策规则,满足合规要求
快速建立基线模型决策树训练快、无需特征工程、不易调参
追求最高精度随机森林 / XGBoost集成方法组合多棵树,精度更高更稳定
高维线性可分数据SVM / 逻辑回归决策树对高维稀疏数据效果差
小数据集(<100样本)决策树(浅树) / KNN简单模型防过拟合
大规模数据(100万+)XGBoost / LightGBM优化算法效率高,支持分布式
需要特征重要性分析决策树 / 随机森林内置特征重要性输出
数据含缺失值CART决策树 / XGBoost原生支持缺失值处理

经验法则:先用决策树理解数据和特征,再用随机森林或梯度提升树提升精度。决策树是集成方法的基础——理解了决策树,就理解了随机森林和XGBoost的核心。

12. 常见问题 (FAQ)

Q: sklearn 的决策树用的是 ID3、C4.5 还是 CART?

sklearn 的 DecisionTreeClassifier 和 DecisionTreeRegressor 均基于优化的 CART 算法,仅构建二叉树。可通过 criterion 参数选择 'gini'(基尼系数)或 'entropy'(信息熵),但树结构始终是二叉的。如需 ID3 或 C4.5 的多叉树,需自行实现或使用其他库。

Q: 决策树和随机森林有什么区别?

随机森林(Random Forest)是由多棵决策树组成的集成模型。它通过两种随机化机制减少过拟合:1) Bagging — 每棵树在随机采样的数据子集上训练;2) 特征随机 — 每次分裂只考虑随机子集的特征。最终预测是所有树的投票/平均结果。随机森林牺牲了可解释性,但大幅提升了精度和稳定性。

Q: 如何确定最佳的 max_depth?

最常用的方法是交叉验证网格搜索:设定一个候选范围(如3到15),用 GridSearchCV 测试每个深度的验证集表现,选择测试集得分最高的深度。也可以使用 CCP 后剪枝自动找到最优复杂度。经验上,大多数场景下深度 4-8 就足够了。

Q: 决策树需要做特征缩放(标准化/归一化)吗?

不需要。决策树根据特征值的大小关系进行分裂(例如"身高 > 170cm?"),分裂结果不受特征尺度的影响。无论身高用厘米还是米表示,分裂效果完全一样。这是决策树和树集成方法相比 SVM、KNN 等算法的一大优势。

Q: 决策树能做多分类和多输出任务吗?

可以。sklearn 的决策树原生支持多分类(不需要 OvR/OvO),因为叶节点可以直接投票得到多个类别之一。也支持多输出(Multi-output),即同时预测多个目标变量——只需将 y 设为多列矩阵即可。