一文上手决策树:从理论到实战
connygpt 2024-12-20 11:47 3 浏览
一、基础概念
决策树是一类极为常用的机器学习方法,尤其是在分类场景。决策树通过树形结构来递归地将样本分割到不同的叶子结点中去,并根据每个叶子结点中的样本构成对该结点中的样本进行分类。
我们可以从两个视角来理解决策树模型。
- 第一种视角是将决策树视为一组规则的集合。对一棵完整的决策树来说,从根节点到每一个叶子结点都对应了一条规则,不同的规则之间互斥且完备。
- 第二种视角是从条件概率的角度来理解决策树。我们对每个叶子结点的分类,都是依据该结点包含的样本集合分属于不同分类的概率来决定的。从这个角度来看,决策树本质上也是一种概率模型。
与其他机器学习方法一样,使用决策树进行预测时,我们的目标是尽可能地在新样本上预测得更准确。那么,一方面我们要在训练集上得到尽可能高的预测精度,另一方面,我们要通过正则化参数来保证模型没有过拟合。
决策树学习的目标就是最小化上述函数,该函数无法使用常规的梯度下降来直接求解,因此我们一般使用启发式的方法来寻找最优决策树,具体来说,就是递归地选择最优特征来分割数据集。如果某次分割后的子集可以完全正确划分到某一类,那么该子集可以归到一个叶子结点;否则,继续从这些子集中选择最优特征进行下一次划分,直到所有子集都能被正确分类。
以上思路会构建一棵完整的树,但是正如前文所述,我们还需要保证模型没有过拟合,因此我们需要对决策树进行剪枝。决策树剪枝通常有预剪枝和后剪枝两种方法。
总的来说,完整的决策树包含特征选择、决策树构建和决策树剪枝三大方面。
二、特征选择
为了构建一棵性能良好的决策树,我们需要从训练集中不断选取最具有区分度(分类能力)的特征。一般来说,我们通过三个指标来实现这一目标。
1. 信息增益
为了说明信息增益,我们需要引入信息熵的概念。在信息论和概率论中,熵是一种描述随机变量不确定性的度量方式,也可以用来描述样本集合的不纯度。熵越低,样本的不确定性就越低,纯度则越高。
信息增益是指在得到了某个特征X的信息之后,使得类Y的信息不确定性减少的程度。或者说,信息增益代表了某特征带来的分类确定性的增量,特征的信息增益越大,目标分类的确定性也就越大。
构建决策树时可以使用信息增益进行特征选择,特征的信息增益越大,代表了其分类能力越强,ID3算法就是基于信息增益做特征选择的。
我们举一个例子来演示信息增益的计算。
例1:假设有20位同学,其中有10位喜欢篮球,10位不喜欢篮球。在20位同学中有12位男同学,其中9位喜欢篮球,3位不喜欢篮球;有8位女同学,其中1位喜欢篮球,7位不喜欢篮球。那么性别(男/女)的信息增益是多少?
import numpy as np
def entropy(freq: list) -> float:
"""计算信息熵
"""
freq = np.array([i for i in freq if i > 0])
proba = freq / freq.sum()
entropy = - (proba * np.log2(proba)).sum()
return entropy
if __name__ == '__main__':
# 原始数据
like_basketball = [10, 10]
male_like_basketball = [9, 3]
female_like_basketball = [1, 7]
# 经验熵
entropy_init = entropy(like_basketball)
# 条件熵
entropy_cond = 10 / 20 * entropy(male_like_basketball) + \
10 / 20 * entropy(female_like_basketball)
# 信息增益
info_gain = entropy_init - entropy_cond
print('经验熵:{0}\n条件熵:{1}\n信息增益:{2}'.format(
entropy_init, entropy_cond, info_gain))
结果为:
经验熵:1.0
条件熵:0.6774212838293646
信息增益:0.3225787161706354
2. 信息增益率
信息增益存在一个问题:当某个特征分类取值较多时,该特征的信息增益计算结果会放大。取极端情况,如有一个特征为编号,每个样本对应了唯一的一个编号,这种情况下的信息纯度很高,那么基于这个特征得到的信息增益就很大。
gender_cnt = [12, 8]
entropy_gender = entropy(gender_cnt)
gain_rate = info_gain / entropy_gender
print('信息增益率:{0}'.format(gain_rate))
结果为:
信息增益率:0.33222979419649123
3. 基尼系数
仍以例1来演示基尼系数的计算。
def gini(freq: list) -> float:
"""计算基尼系数
"""
freq = np.array([i for i in freq if i > 0])
proba = freq / freq.sum()
g = 1 - (proba ** 2).sum()
return g
gini_male = 12 / 20 * gini(male_like_basketball) + 8 / 20 * gini(female_like_basketball)
print('基尼系数:{0}'.format(gini_male))
结果为:
基尼系数:0.3125
三、决策树模型
三大经典决策树模型分别为ID3、C4.5、CART,它们都是通过递归地选择最优特征来构建决策树。如前文所述,在评估最优特征时,它们分别使用了信息增益、信息增益率和基尼系数三个指标。
ID3和C4.5算法仅有决策树的生成,不包含决策树剪枝的部分,因此容易过拟合。CART算法除了用于分类外,还可用于回归,也包含决策树剪枝,因此现在应用更为广泛。
1. ID3
ID3算法的全称为Iterative Dichotomiser 3,即迭代二叉树。其核心是基于信息增益递归地选择最优特征构造决策树。
简单来阐述,ID3算法的思路为:
- 首先预设一个决策树根节点,然后对所有特征计算信息增益;
- 选择一个信息增益最大的特征作为最优特征,根据该特征的不同取值建立子结点;
- 接着对每个子结点递归地调用上述方法,直到信息增益很小或者没有特征可选时,将这些子结点作为叶子结点,并以该叶子结点上的多数类作为预测类。
2. C4.5
C4.5算法实际上是对ID3算法的改进。
- ID3算法使用信息增益做特征选择,倾向于选择取值水平较多的特征。针对这一问题,C4.5算法改为使用信息增益率。
- ID3算法不可以处理缺失值,C4.5算法可以。
- ID3算法不支持连续值特征,C4.5算法支持。
- ID3算法不支持后剪枝,C4.5算法支持后剪枝。
3. CART
CART算法的全称为分类与回归树(classification and regression tree),它既可用于分类,又可用于回归,这是它与ID3/C4.5之间的主要区别之一,此处我们仅讨论CART算法用于分类的场景。此外,CART算法中的特征选择使用的是基尼系数。最后,CART算法不仅包含了决策树的生成算法,还包括了决策树的剪枝算法。
CART生成的决策树为二叉树,内部结点取值为“是”和“否”,这种方法等价于递归地二分每个特征,将特征空间划分为有限个子空间,并在这些子空间上确定预测的概率分布,即前述的预测条件概率分布。
其算法流程为:
4. 对比
决策树 | 模型分类 | 树结构 | 特征选择 | 连续值处理 | 缺失值处理 | 剪枝处理 |
ID3 | 分类 | 多叉树 | 信息增益 | 不可以 | 不可以 | 不可以 |
C4.5 | 分类 | 多叉树 | 信息增益率 | 可以 | 可以 | 可以 |
CART | 分类 | 二叉树 | 基尼系数 | 可以 | 可以 | 可以 |
四、决策树剪枝
决策树剪枝一般包含两种方法:预剪枝(pre-pruning)和后剪枝(post-pruning)。
1. 预剪枝
预剪枝,是指在决策树生成过程中提前停止树的增长的一种剪枝算法。其主要思路有:
- 提前设定决策树的深度,当达到这一深度时,停止生长。
- 当某结点的所有样本属于同一类别,停止生长。
- 提前设定某个阈值,当某结点的样本数小于该阈值时,停止生长。
- 提前设定某个阈值,当分裂带来的性能提升小于该阈值时,停止生长。
预剪枝方法直接、简单高效,适用于大规模求解问题。目前在主流的集成学习模型中,很多算法用到了预剪枝的思想。但因为决策树的构建使用的是启发式方法,具有局部最优的问题,预剪枝提前停止树的生长,存在一定的欠拟合风险。
2. 后剪枝
主流的后剪枝方法有四种:悲观错误剪枝(Pessimistic Error Pruning,PEP),最小错误剪枝(Minimum Error Pruning,MEP),代价复杂度剪枝(Cost-Complexity Pruning,CCP)和基于错误的剪枝(Error-Based Pruning,EBP)。C4.5采用悲观错误剪枝,CART采用代价复杂度剪枝。
后剪枝主要通过极小化决策树整体损失函数来实现。前文我们提到,决策树学习的目标是最小化如下损失函数:
CART算法使用的正是后剪枝方法。CART后剪枝首先通过计算子树的损失函数来实现剪枝并得到一个子树序列,然后通过交叉验证的方法从子树序列中选取最优子树。
五、优缺点
1. 优点
- 简单直观,生成的决策树很直观。
- 基本不需要预处理,不需要提前归一化和处理缺失值。
- 使用决策树预测的代价是O(log2N)?。N?为样本数。
- 既可以处理离散值也可以处理连续值。很多算法只是专注于离散值或者连续值。
- 可以处理多维度输出的分类问题。
- 相比于神经网络之类的黑盒分类模型,决策树在逻辑上可以很好解释。
- 可以交叉验证的剪枝来选择模型,从而提高泛化能力。
- 对于异常点的容错能力好,健壮性高。
2. 缺点
- 决策树算法非常容易过拟合,导致泛化能力不强。可以通过设置节点最少样本数量和限制决策树深度来改进。
- 决策树会因为样本发生一点的改动,导致树结构的剧烈改变。这个可以通过集成学习之类的方法解决。
- 寻找最优的决策树是一个NP难题,我们一般是通过启发式方法,容易陷入局部最优。可以通过集成学习的方法来改善。
- 有些比较复杂的关系,决策树很难学习,比如异或。这个就没有办法了,一般这种关系可以换神经网络分类方法来解决。
- 如果某些特征的样本比例过大,生成决策树容易偏向于这些特征。这个可以通过调节样本权重来改善。
六、代码实战
1. sklearn
在sklearn中,使用决策树进行分类预测非常简单,下面是一个来自官方文档的例子。
from sklearn.tree import DecisionTreeClassifier
X = [[0, 0], [1, 1]]
Y = [0, 1]
clf = DecisionTreeClassifier()
clf = clf.fit(X, Y)
# 预测
print(clf.predict([[2, 2]])
# 预测概率
print(clf.predict_proba([[2, 2]])
我们还可以将决策树通过可视化的方式呈现出来。
from sklearn.datasets import load_iris
from sklearn import tree
import matplotlib.pyplot as plt
# 以iris数据为例
iris = load_iris()
clf = tree.DecisionTreeClassifier()
clf = clf.fit(iris.data, iris.target)
# 可视化
plt.figure(figsize=(36, 24))
tree.plot_tree(clf, feature_names=iris.feature_names,
filled=True, proportion=True, fontsize=14)
?
2. PySpark
在PySpark中使用决策树模型稍显复杂。
import numpy as np
import pandas as pd
from pyspark.sql import SparkSession
from pyspark.ml import Pipeline
from pyspark.ml.feature import VectorAssembler
from pyspark.ml.classification import DecisionTreeClassifier, RandomForestClassifier
from pyspark.ml.tuning import ParamGridBuilder, CrossValidator, TrainValidationSplit
from pyspark.ml.evaluation import BinaryClassificationEvaluator
spark = SparkSession.builder.appName('test').getOrCreate()
# 准备数据
# 数据可以从Hive中读取,或从pandas.DataFrame格式创建等。
# 此处假设一份用于二分类预测模型训练的数据已准备好,
data = YOUR_PYSPARK_DATAFRAME
features = YOUR_FEATURE_COLUMN_NAMES
label_col = YOUR_LABEL_COLUMNS
# 数据集分割
traindf, testdf = data.randomSplit([0.8, 0.2], seed=1)
# 特征向量化
vec_assembler = VectorAssembler(inputCols=features, outputCol='features')
# 决策树
dtree = DecisionTreeClassifier(
seed=1,
labelCol=label_col,
featuresCol='features',
predictionCol='pred',
probabilityCol='proba',
maxDepth=5,
minInstancesPerNode=3,
impurity='gini',
maxBins=10
)
# 训练模型
pipeline = Pipeline(stages=[vec_assembler, dtree])
model = pipeline.fit(traindf)
# 特征重要性
feat_importances = list(zip(features, model.stages[1].featureImportances))
df_importances = pd.DataFrame(sorted(feat_importances, key=lambda x: x[1], reverse=True),
columns=['feature', 'importances'])
df_importances.head()
# 预测
df_pred = model.transform(testdf)
to_array = F.udf(lambda x: x.toArray().tolist(), ArrayType(DoubleType()))
df_pred = df_pred.withColumn('proba_score', to_array('proba')[1])
?
我们还可以在PySpark中使用网格搜索来确定最佳参数。
# 特征向量化
vec_assembler = VectorAssembler(inputCols=features, outputCol='features')
# 随机森林
dtree = DecisionTreeClassifier(
seed=1,
labelCol=label_col,
featuresCol='features',
predictionCol='pred',
probabilityCol='proba',
impurity='gini',
# maxDepth=5,
# minInstancesPerNode=3,
# maxBins=10
)
# 流水线
pipeline = Pipeline(stages=[vec_assembler, dtree])
# 设置网格参数
param_grid = ParamGridBuilder() \
.baseOn({dtree.labelCol:'label'}) \
.baseOn({dtree.featuresCol: 'features'}) \
.baseOn({dtree.predictionCol: 'pred'}) \
.baseOn({dtree.probabilityCol: 'proba'}) \
.addGrid(dtree.minInstancesPerNode, [3, 5, 7]) \
.addGrid(dtree.maxDepth, [10, 12, 15, 20]) \
.addGrid(dtree.maxBins, [5, 10, 15]) \
.build()
# 模型评估
evaluator = BinaryClassificationEvaluator()
# 交叉验证
cv = CrossValidator(
estimator=pipeline,
estimatorParamMaps=param_grid,
evaluator=evaluator,
numFolds=5,
seed=1024
)
# 开始执行
a = time.time()
cvModel = cv.fit(traindf)
b = time.time()
print(b - a)
# 打印最佳参数
params = cvModel.getEstimatorParamMaps()
avg_metrics = cvModel.avgMetrics
all_params = list(zip(params, avg_metrics))
best_param = sorted(all_params, key=lambda x: x[1], reverse=True)[0]
for p, v in best_param[0].items():
print("{}: {}".format(p.name, v))
?
- 上一篇:3分钟让你的项目支持AI问答模块,完全开源!
- 下一篇:统计学习方法-决策树
相关推荐
- 自学Python,写一个挨打的游戏代码来初识While循环
-
自学Python的第11天。旋转~跳跃~,我~闭着眼!学完循环,沐浴着while的光芒,闲来无事和同事一起扯皮,我说:“编程语言好神奇,一个小小的循环,竟然在生活中也可以找到原理和例子”,同事也...
- 常用的 Python 工具与资源,你知道几个?
-
最近几年你会发现,越来越多的人开始学习Python,工欲善其事必先利其器,今天纬软小编就跟大家分享一些常用的Python工具与资源,记得收藏哦!不然下次就找不到我了。1、PycharmPychar...
- 一张思维导图概括Python的基本语法, 一周的学习成果都在里面了
-
一周总结不知不觉已经自学Python一周的时间了,这一周,从认识Python到安装Python,再到基本语法和基本数据类型,对于小白的我来说无比艰辛的,充满坎坷。最主要的是每天学习时间有限。只...
- 三日速成python?打工人,小心钱包,别当韭菜
-
随着人工智能的热度越来越高,许多非计算机专业的同学们也都纷纷投入到学习编程的道路上来。而Python,作为一种相对比较容易上手的语言,也越来越受欢迎。网络上各类网课层出不穷,各式广告令人眼花缭乱。某些...
- Python自动化软件测试怎么学?路线和方法都在这里了
-
Python自动化测试是指使用Python编程语言和相关工具,对软件系统进行自动化测试的过程。学习Python自动化测试需要掌握以下技术:Python编程语言:学习Python自动化测试需要先掌握Py...
- Python从放弃到入门:公众号历史文章爬取为例谈快速学习技能
-
这篇文章不谈江流所专研的营销与运营,而聊一聊技能学习之路,聊一聊Python这门最简单的编程语言该如何学习,我完成的第一个Python项目,将任意公众号的所有历史文章导出成PDF电子书。或许我这个Py...
- 【黑客必会】python学习计划
-
阅读Python文档从Python官方网站上下载并阅读Python最新版本的文档(中文版),这是学习Python的最好方式。对于每个新概念和想法,请尝试运行一些代码片段,并检查生成的输出。这将帮助您更...
- 公布了!2025CDA考试安排
-
CDA数据分析师报考流程数据分析师是指在不同行业中专门从事行业数据搜集、整理、分析依据数据作出行业研究评估的专业人员CDA证书分为1-3级,中英文双证就业面广,含金量高!!?报考条件:满18...
- 一文搞懂全排列、组合、子集问题(经典回溯递归)
-
原创公众号:【bigsai】头条号:程序员bigsai前言Hello,大家好,我是bigsai,longtimenosee!在刷题和面试过程中,我们经常遇到一些排列组合类的问题,而全排列、组合...
- 「西法带你学算法」一次搞定前缀和
-
我花了几天时间,从力扣中精选了五道相同思想的题目,来帮助大家解套,如果觉得文章对你有用,记得点赞分享,让我看到你的认可,有动力继续做下去。467.环绕字符串中唯一的子字符串[1](中等)795.区...
- 平均数的5种方法,你用过几种方法?
-
平均数,看似很简单的东西,其实里面包含着很多学问。今天,分享5种经常会用到的平均数方法。1.算术平均法用到最多的莫过于算术平均法,考试平均分、平均工资等等,都是用到这个。=AVERAGE(B2:B11...
- 【干货收藏】如何最简单、通俗地理解决策树分类算法?
-
决策树(Decisiontree)是基于已知各种情况(特征取值)的基础上,通过构建树型决策结构来进行分析的一种方式,是常用的有监督的分类算法。决策树算法是机器学习中的一种经典算法,它通过一系列的规则...
- 面试必备:回溯算法详解
-
我们刷leetcode的时候,经常会遇到回溯算法类型题目。回溯算法是五大基本算法之一,一般大厂也喜欢问。今天跟大家一起来学习回溯算法的套路,文章如果有不正确的地方,欢迎大家指出哈,感谢感谢~什么是回溯...
- 「机器学习」决策树——ID3、C4.5、CART(非常详细)
-
决策树是一个非常常见并且优秀的机器学习算法,它易于理解、可解释性强,其可作为分类算法,也可用于回归模型。本文将分三篇介绍决策树,第一篇介绍基本树(包括ID3、C4.5、CART),第二篇介绍Ran...
- 大话AI算法: 决策树
-
所谓的决策树算法,通俗的说就是建立一个树形的结构,通过这个结构去一层一层的筛选判断问题是否好坏的算法。比如判断一个西瓜是否好瓜,有20条西瓜的样本提供给你,让你根据这20条(通过机器学习)建立起...
- 一周热门
- 最近发表
- 标签列表
-
- kubectlsetimage (56)
- mysqlinsertoverwrite (53)
- addcolumn (54)
- helmpackage (54)
- varchar最长多少 (61)
- 类型断言 (53)
- protoc安装 (56)
- jdk20安装教程 (60)
- rpm2cpio (52)
- 控制台打印 (63)
- 401unauthorized (51)
- vuexstore (68)
- druiddatasource (60)
- 企业微信开发文档 (51)
- rendertexture (51)
- speedphp (52)
- gitcommit-am (68)
- bashecho (64)
- str_to_date函数 (58)
- yum下载包及依赖到本地 (72)
- jstree中文api文档 (59)
- mvnw文件 (58)
- rancher安装 (63)
- nginx开机自启 (53)
- .netcore教程 (53)