Decision Trees Explained

1. What is a Decision Tree

A Decision Tree is an intuitive and powerful supervised learning algorithm used for both classification and regression tasks. The core idea is to split data through a series of yes/no questions, progressively narrowing down to a prediction at a leaf node. Think of it as an upside-down tree: the root node asks the first question, branches represent answers, and leaf nodes hold the final decisions.

Here is a simple example -- a "should I play outside?" decision tree:

```
              Weather?
             /    |    \
        Sunny  Cloudy   Rainy
          |       |       |
      Humidity?  Play!  Windy?
        /   \           /   \
       No  Play        No  Play
```

This simple tree uses only 3 features (weather, humidity, wind) to make predictions. Real-world decision trees may have hundreds of nodes, but the principle is identical: at each node, select the best feature to split on, and continue until a stopping condition is met.

2. How It Works

Building a decision tree is a recursive splitting process. Here are the core steps:

2.1 Recursive Splitting

Starting from the root, the algorithm examines all available features and possible split points, selecting the one that best reduces "impurity." This process is then repeated for each resulting subset.

2.2 Feature Selection

At each node, the algorithm must answer: which feature should we split on? The answer is the feature that produces the "purest" child nodes. Purity is measured by Information Gain (ID3), Gain Ratio (C4.5), or Gini Impurity (CART).

2.3 Leaf Nodes and Stopping Conditions

A node becomes a leaf (no further splitting) when any of these conditions are met:

1. All samples in the node belong to the same class (pure node)
2. Maximum depth limit reached (max_depth)
3. Number of samples is below the minimum split threshold (min_samples_split)
4. No remaining features to split on
5. Information gain from any possible split falls below a threshold

For classification trees, the leaf outputs the majority class among its samples. For regression trees, the leaf outputs the mean of the target values.
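Both output rules fit in a few lines. Here is a minimal illustration using only the standard library (the sample values are made up):

```python
from collections import Counter
from statistics import mean

# Classification leaf: output the majority class among its samples
leaf_labels = ['yes', 'no', 'yes', 'yes']
majority = Counter(leaf_labels).most_common(1)[0][0]
print(majority)  # prints: yes

# Regression leaf: output the mean of the target values
leaf_targets = [3.0, 5.0, 4.0]
print(mean(leaf_targets))  # prints: 4.0
```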

3. Splitting Criteria

3.1 Information Gain -- ID3 Algorithm

Entropy

H(S) = -Σ pᵢ log₂(pᵢ)

Where pᵢ is the proportion of class i in set S. Higher entropy means more disorder; entropy of 0 means perfectly pure.

Information Gain (IG)

IG(S, A) = H(S) - Σ (|Sᵥ| / |S|) * H(Sᵥ)

Information Gain = entropy before split - weighted sum of child entropies. Higher IG means the feature produces purer subsets.

Worked Example: 14 samples total -- 9 positive (+), 5 negative (-).

Parent entropy: H(S) = -(9/14)log₂(9/14) - (5/14)log₂(5/14) = 0.940

Split on Feature A into two subsets:
S₁: 6 samples (4+, 2-) → H(S₁) = -(4/6)log₂(4/6) - (2/6)log₂(2/6) = 0.918
S₂: 8 samples (5+, 3-) → H(S₂) = -(5/8)log₂(5/8) - (3/8)log₂(3/8) = 0.954

IG(A): 0.940 - (6/14)*0.918 - (8/14)*0.954 = 0.940 - 0.393 - 0.545 = 0.002

Split on Feature B into two subsets:
S₁: 7 samples (7+, 0-) → H(S₁) = 0 (pure!)
S₂: 7 samples (2+, 5-) → H(S₂) = -(2/7)log₂(2/7) - (5/7)log₂(5/7) = 0.863

IG(B): 0.940 - (7/14)*0 - (7/14)*0.863 = 0.940 - 0 - 0.431 = 0.509

Feature B's IG (0.509) is much larger than Feature A's (0.002), so we split on Feature B.
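The arithmetic above is easy to verify. This pure standard-library check recomputes all three quantities:

```python
import math

def entropy(*counts):
    """Entropy of a class distribution given raw class counts."""
    total = sum(counts)
    return -sum((c / total) * math.log2(c / total) for c in counts if c > 0)

H_S = entropy(9, 5)  # parent entropy, ~0.940

# Feature A: subsets (4+, 2-) and (5+, 3-)
IG_A = H_S - (6/14) * entropy(4, 2) - (8/14) * entropy(5, 3)
# Feature B: subsets (7+, 0-) and (2+, 5-)
IG_B = H_S - (7/14) * entropy(7, 0) - (7/14) * entropy(2, 5)

# IG(A) comes out ~0.0013 without intermediate rounding; the 0.002 above
# reflects rounding each weighted term to 3 decimals first.
print(round(H_S, 3), round(IG_A, 4), round(IG_B, 3))  # → 0.94 0.0013 0.509
```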

3.2 Gain Ratio -- C4.5 Algorithm

Split Information

SplitInfo(S, A) = -Σ (|Sᵥ| / |S|) * log₂(|Sᵥ| / |S|)

SplitInfo measures how evenly the feature splits the data. More distinct values lead to higher SplitInfo.

Gain Ratio (GR)

GR(S, A) = IG(S, A) / SplitInfo(S, A)

Gain Ratio penalizes features with many distinct values by dividing IG by SplitInfo, fixing the multi-value bias of ID3.

Worked Example: Continuing from above, Feature B splits 14 samples evenly into 7 and 7:

SplitInfo(S, B) = -(7/14)log₂(7/14) - (7/14)log₂(7/14) = 1.0
GR(S, B) = 0.509 / 1.0 = 0.509

Suppose Feature C has 5 distinct values, splitting into groups of 2-3 samples. IG=0.52 but SplitInfo=2.3:
GR(S, C) = 0.52 / 2.3 = 0.226

Although Feature C has slightly higher IG (0.52 vs 0.509), its Gain Ratio is much lower (0.226 vs 0.509). C4.5 would choose Feature B, correctly avoiding the multi-value trap.
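These figures can be reproduced with a short check. The 3/3/3/3/2 partition for Feature C below is an assumed split consistent with the "5 distinct values, groups of 2-3 samples" description:

```python
import math

def split_info(sizes):
    """SplitInfo for a partition, given the subset sizes."""
    total = sum(sizes)
    return -sum((s / total) * math.log2(s / total) for s in sizes)

# Feature B: a 7/7 split with IG = 0.509 (from the worked example above)
gr_b = 0.509 / split_info([7, 7])
# Feature C: assumed 3/3/3/3/2 partition of the 14 samples, IG = 0.52
gr_c = 0.52 / split_info([3, 3, 3, 3, 2])

print(round(gr_b, 3), round(gr_c, 3))  # → 0.509 0.226
```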

3.3 Gini Impurity -- CART Algorithm

Gini Impurity

Gini(S) = 1 - Σ pᵢ²

Gini measures the probability that two samples drawn at random (with replacement) from the set belong to different classes. Range [0, 0.5] for binary classification; 0 means perfectly pure.

Weighted Gini after Split

Gini_split = Σ (|Sᵥ| / |S|) * Gini(Sᵥ)

CART picks the split that minimizes the weighted Gini value. Note: CART always produces binary splits (exactly two children per node).

Worked Example: Same 14 samples (9+, 5-).

Parent Gini: Gini = 1 - (9/14)² - (5/14)² = 1 - 0.413 - 0.128 = 0.459

Split on Feature B (threshold = x):
Left child: 7 samples (7+, 0-) → Gini = 1 - 1² - 0² = 0.000
Right child: 7 samples (2+, 5-) → Gini = 1 - (2/7)² - (5/7)² = 1 - 0.082 - 0.510 = 0.408

Weighted Gini: (7/14)*0.000 + (7/14)*0.408 = 0.204

Gini dropped from 0.459 to 0.204 -- a significant improvement. CART selects the split with the lowest weighted Gini across all features and thresholds.
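The same kind of check works for the Gini numbers, again with no dependencies:

```python
def gini(*counts):
    """Gini impurity from raw class counts."""
    total = sum(counts)
    return 1 - sum((c / total) ** 2 for c in counts)

parent = gini(9, 5)                                   # parent impurity
weighted = (7/14) * gini(7, 0) + (7/14) * gini(2, 5)  # after the split
print(round(parent, 3), round(weighted, 3))  # → 0.459 0.204
```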

4. ID3 vs C4.5 vs CART Comparison

| Property | ID3 | C4.5 | CART |
|---|---|---|---|
| Split criterion | Information Gain | Gain Ratio | Gini Impurity |
| Tree structure | Multi-way | Multi-way | Strictly binary |
| Continuous features | Not supported | Supported (threshold binarization) | Supported (optimal split point) |
| Missing values | Not supported | Supported (weighted distribution) | Supported (surrogate splits) |
| Pruning | None | Pessimistic Error Pruning (PEP) | Cost-Complexity Pruning (CCP) |
| Task type | Classification only | Classification only | Classification + Regression |
| Bias | Favors multi-value features | Corrects multi-value bias | No notable bias |
| Year proposed | 1986 (Quinlan) | 1993 (Quinlan) | 1984 (Breiman) |
| Sklearn support | No | No | DecisionTreeClassifier / DecisionTreeRegressor |

In practice, sklearn's decision tree is an optimized CART implementation. If you use Python for ML, you are using CART by default. ID3 and C4.5 appear mainly in academic research and interviews.

5. Python Implementation from Scratch

Below is a simple entropy-based decision tree classifier in pure Python, using only the standard library (math and collections).

```python
import math
from collections import Counter

def entropy(labels):
    """Calculate information entropy."""
    n = len(labels)
    if n == 0:
        return 0
    counts = Counter(labels)
    return -sum((c / n) * math.log2(c / n) for c in counts.values())

def best_split(X, y, features):
    """Select best feature to split on (by information gain)."""
    base_ent = entropy(y)
    best_ig, best_feat = -1, None
    for feat in features:
        values = set(row[feat] for row in X)
        weighted_ent = 0
        for val in values:
            sub_y = [y[i] for i, row in enumerate(X) if row[feat] == val]
            weighted_ent += (len(sub_y) / len(y)) * entropy(sub_y)
        ig = base_ent - weighted_ent
        if ig > best_ig:
            best_ig, best_feat = ig, feat
    return best_feat

def build_tree(X, y, features, depth=0, max_depth=10):
    """Recursively build the decision tree."""
    # Stop: pure node, no features, or max depth
    if len(set(y)) == 1:
        return y[0]
    if not features or depth >= max_depth:
        return Counter(y).most_common(1)[0][0]
    feat = best_split(X, y, features)
    if feat is None:
        return Counter(y).most_common(1)[0][0]
    tree = {feat: {}}
    values = set(row[feat] for row in X)
    remaining = [f for f in features if f != feat]
    for val in values:
        sub_X = [row for row in X if row[feat] == val]
        sub_y = [y[i] for i, row in enumerate(X) if row[feat] == val]
        tree[feat][val] = build_tree(sub_X, sub_y, remaining, depth + 1, max_depth)
    return tree

def predict(tree, sample):
    """Predict label for a single sample."""
    if not isinstance(tree, dict):
        return tree
    feat = list(tree.keys())[0]
    val = sample.get(feat)
    subtree = tree[feat].get(val)
    if subtree is None:
        return None  # unseen feature value
    return predict(subtree, sample)

# Usage example
data = [
    {'outlook': 'sunny', 'humidity': 'high', 'wind': 'weak'},
    {'outlook': 'sunny', 'humidity': 'high', 'wind': 'strong'},
    {'outlook': 'overcast', 'humidity': 'high', 'wind': 'weak'},
    {'outlook': 'rain', 'humidity': 'high', 'wind': 'weak'},
    {'outlook': 'rain', 'humidity': 'normal', 'wind': 'weak'},
]
labels = ['no', 'no', 'yes', 'yes', 'yes']
features = ['outlook', 'humidity', 'wind']

tree = build_tree(data, labels, features, max_depth=5)
print(tree)
print(predict(tree, {'outlook': 'sunny', 'humidity': 'normal', 'wind': 'weak'}))
```

6. Sklearn Decision Tree

6.1 Classification (DecisionTreeClassifier)

```python
from sklearn.tree import DecisionTreeClassifier
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, classification_report
from sklearn.datasets import load_iris

# Load data
iris = load_iris()
X_train, X_test, y_train, y_test = train_test_split(
    iris.data, iris.target, test_size=0.2, random_state=42
)

# Train decision tree
clf = DecisionTreeClassifier(
    criterion='gini',        # 'gini' or 'entropy'
    max_depth=4,             # maximum depth
    min_samples_split=5,     # min samples to split internal node
    min_samples_leaf=2,      # min samples per leaf
    random_state=42
)
clf.fit(X_train, y_train)

# Evaluate
y_pred = clf.predict(X_test)
print(f"Accuracy: {accuracy_score(y_test, y_pred):.4f}")
print(classification_report(y_test, y_pred, target_names=iris.target_names))
```

6.2 Regression (DecisionTreeRegressor)

```python
from sklearn.tree import DecisionTreeRegressor
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error, r2_score
from sklearn.datasets import load_diabetes

# Regression tree -- same API, but a regression dataset and criterion
X, y = load_diabetes(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42
)

reg = DecisionTreeRegressor(
    criterion='squared_error',   # also 'friedman_mse', 'absolute_error'
    max_depth=5,
    min_samples_leaf=5,
    random_state=42
)
reg.fit(X_train, y_train)

y_pred = reg.predict(X_test)
print(f"MSE: {mean_squared_error(y_test, y_pred):.4f}")
print(f"R^2: {r2_score(y_test, y_pred):.4f}")
```

6.3 Feature Importance

```python
# Inspect feature importances
import pandas as pd

importances = clf.feature_importances_
feat_imp = pd.DataFrame({
    'feature': iris.feature_names,
    'importance': importances
}).sort_values('importance', ascending=False)
print(feat_imp)
```

7. Pruning Strategies

An unpruned decision tree will easily overfit -- it can perfectly memorize the training set (one sample per leaf), but generalizes poorly. Pruning is the key technique to control overfitting.

7.1 Pre-pruning

Stop the tree from growing during construction. Sklearn pre-pruning parameters:

max_depth: Maximum tree depth. Typically 3-10. Too small = underfit, too large = overfit.
min_samples_split: Minimum samples to split an internal node (default 2). Increase to prevent overfitting.
min_samples_leaf: Minimum samples per leaf node (default 1). Increase for more stable leaves.
max_features: Max features considered per split. 'sqrt' or 'log2' adds randomness.
max_leaf_nodes: Maximum number of leaf nodes. Directly limits tree complexity.

7.2 Post-pruning -- Cost-Complexity Pruning (CCP)

Let the tree grow fully, then progressively remove subtrees that contribute little to prediction accuracy. Controlled by ccp_alpha in sklearn.

Cost-Complexity Formula

R_alpha(T) = R(T) + alpha * |T|

R(T) is the training error, |T| is the number of leaves, alpha is the penalty coefficient. Larger alpha = heavier penalty = smaller tree.

```python
import numpy as np

# Cost-Complexity Pruning workflow
# 1. Get pruning path with alpha values
path = clf.cost_complexity_pruning_path(X_train, y_train)
ccp_alphas = path.ccp_alphas

# 2. Train a tree for each alpha
clfs = []
for alpha in ccp_alphas:
    c = DecisionTreeClassifier(ccp_alpha=alpha, random_state=42)
    c.fit(X_train, y_train)
    clfs.append(c)

# 3. Evaluate train/test accuracy, pick best alpha
train_scores = [c.score(X_train, y_train) for c in clfs]
test_scores = [c.score(X_test, y_test) for c in clfs]
best_idx = np.argmax(test_scores)
best_alpha = ccp_alphas[best_idx]
print(f"Best alpha: {best_alpha:.4f}, Test Accuracy: {test_scores[best_idx]:.4f}")
```

8. Visualization

8.1 Using sklearn plot_tree

```python
import matplotlib.pyplot as plt
from sklearn.tree import plot_tree

fig, ax = plt.subplots(figsize=(20, 10))
plot_tree(
    clf,
    feature_names=iris.feature_names,
    class_names=iris.target_names,
    filled=True,      # color by class
    rounded=True,     # rounded boxes
    fontsize=10,
    ax=ax
)
plt.tight_layout()
plt.savefig('decision_tree.png', dpi=150)
plt.show()
```

8.2 Using Graphviz

```python
from sklearn.tree import export_graphviz
import graphviz

# Export to DOT format
dot_data = export_graphviz(
    clf,
    feature_names=iris.feature_names,
    class_names=iris.target_names,
    filled=True,
    rounded=True,
    special_characters=True
)

# Render to image
graph = graphviz.Source(dot_data)
graph.render('iris_tree', format='png', cleanup=True)
graph  # displays inline in Jupyter
```

8.3 Text Output

```python
from sklearn.tree import export_text

# Print tree rules as plain text
tree_rules = export_text(clf, feature_names=iris.feature_names)
print(tree_rules)

# Output looks like:
# |--- petal length (cm) <= 2.45
# |   |--- class: setosa
# |--- petal length (cm) > 2.45
# |   |--- petal width (cm) <= 1.75
# |   |   |--- class: versicolor
# |   |--- petal width (cm) > 1.75
# |   |   |--- class: virginica
```

9. Pros and Cons

Pros:
1. Intuitive and easy to interpret, visualizable
2. No feature scaling required
3. Handles both numerical and categorical features
4. Captures non-linear relationships and feature interactions
5. Handles missing values natively (CART)
6. Fast training and prediction
7. Built-in feature importance

Cons:
1. Prone to overfitting (especially deep trees)
2. Unstable: small data changes can produce very different trees
3. Sensitive to class imbalance
4. Greedy algorithm -- no guarantee of global optimum
5. Poor extrapolation (values outside training range)
6. Weak on high-dimensional sparse data
7. Single tree accuracy typically lower than ensembles

10. When to Use Decision Trees

| Scenario | Recommended | Reason |
|---|---|---|
| Need model interpretability | Decision Tree | Directly visualize decision rules; meets compliance needs |
| Quick baseline model | Decision Tree | Fast training, no feature engineering, minimal tuning |
| Maximum accuracy | Random Forest / XGBoost | Ensembles combine many trees for higher, more stable accuracy |
| High-dim linearly separable data | SVM / Logistic Regression | Decision trees perform poorly on high-dim sparse data |
| Small dataset (<100 samples) | Decision Tree (shallow) / KNN | Simple models prevent overfitting |
| Large-scale data (1M+ samples) | XGBoost / LightGBM | Optimized algorithms, distributed computing support |
| Feature importance analysis | Decision Tree / Random Forest | Built-in feature importance output |
| Data with missing values | CART / XGBoost | Native missing value handling |

Rule of thumb: Start with a decision tree to understand your data and features, then upgrade to Random Forest or gradient boosting for better accuracy. Decision trees are the foundation of all ensemble methods -- understanding decision trees means understanding the core of Random Forest and XGBoost.

11. FAQ

Q: Does sklearn use ID3, C4.5, or CART?

Sklearn's DecisionTreeClassifier and DecisionTreeRegressor are both based on an optimized CART algorithm, always producing binary trees. You can choose 'gini' or 'entropy' via the criterion parameter, but the tree structure is always binary. For multi-way trees (ID3/C4.5), you need a custom implementation or another library.

Q: What is the difference between Decision Tree and Random Forest?

Random Forest is an ensemble of many decision trees. It reduces overfitting through two randomization mechanisms: 1) Bagging -- each tree is trained on a random subset of data; 2) Feature randomization -- each split considers only a random subset of features. The final prediction is the vote/average of all trees. Random Forest sacrifices interpretability but greatly improves accuracy and stability.
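As a minimal sketch of the contrast, the snippet below fits sklearn's RandomForestClassifier on iris with both randomization mechanisms active (bootstrap sampling is the default; max_features='sqrt' is the feature-randomization setting):

```python
from sklearn.datasets import load_iris
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split

iris = load_iris()
X_train, X_test, y_train, y_test = train_test_split(
    iris.data, iris.target, test_size=0.2, random_state=42
)

# 100 trees, each trained on a bootstrap sample of the data;
# each split considers only a random subset of features
rf = RandomForestClassifier(n_estimators=100, max_features='sqrt', random_state=42)
rf.fit(X_train, y_train)
print(f"Test accuracy: {rf.score(X_test, y_test):.4f}")
```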

Q: How do I find the optimal max_depth?

The most common approach is cross-validated grid search: define a range (e.g., 3 to 15), use GridSearchCV to evaluate each depth on validation data, and pick the depth with the best test score. Alternatively, use CCP post-pruning to automatically find the optimal complexity. In practice, depth 4-8 is sufficient for most problems.
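A minimal version of that grid search (the depth range and iris data are illustrative):

```python
from sklearn.datasets import load_iris
from sklearn.model_selection import GridSearchCV, train_test_split
from sklearn.tree import DecisionTreeClassifier

iris = load_iris()
X_train, X_test, y_train, y_test = train_test_split(
    iris.data, iris.target, test_size=0.2, random_state=42
)

# 5-fold cross-validated search over max_depth
grid = GridSearchCV(
    DecisionTreeClassifier(random_state=42),
    param_grid={'max_depth': list(range(3, 16))},
    cv=5,
)
grid.fit(X_train, y_train)
print(grid.best_params_, round(grid.best_score_, 4))
```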

Q: Do decision trees need feature scaling (standardization/normalization)?

No. Decision trees split based on feature value comparisons (e.g., "height > 170cm?"), and the result is invariant to feature scale. Whether height is in centimeters or meters, the split is equally effective. This is a major advantage of tree-based methods over distance-based algorithms like SVM and KNN.
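This invariance is easy to demonstrate: fit the same tree on raw and rescaled iris features and compare predictions (the x100 factor below is arbitrary):

```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier

X, y = load_iris(return_X_y=True)

# Same tree, same data -- one copy rescaled by a constant factor
clf_raw = DecisionTreeClassifier(random_state=0).fit(X, y)
clf_scaled = DecisionTreeClassifier(random_state=0).fit(X * 100, y)

# Only the split thresholds change, not the splits themselves
same = np.array_equal(clf_raw.predict(X), clf_scaled.predict(X * 100))
print(same)  # → True
```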

Q: Can decision trees handle multi-class and multi-output tasks?

Yes. Sklearn's decision tree natively supports multi-class classification (no OvR/OvO needed) since leaf nodes can vote among multiple classes. It also supports multi-output tasks (predicting multiple target variables simultaneously) -- simply pass y as a multi-column matrix.
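A minimal multi-output sketch on made-up data, showing y passed as a two-column matrix:

```python
import numpy as np
from sklearn.tree import DecisionTreeClassifier

# Toy multi-output task: predict two labels per sample at once
X = np.array([[0, 0], [0, 1], [1, 0], [1, 1]])
Y = np.array([[0, 1], [0, 0], [1, 0], [1, 1]])  # two target columns

clf = DecisionTreeClassifier(random_state=0).fit(X, Y)
pred = clf.predict([[1, 1]])
print(pred.shape)  # → (1, 2): one row, two predicted outputs
```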