决策树

决策树

  决策树是最经典的数据挖掘算法之一,它是用树形结构的方法来展现决策/分类过程。因此比较简单直观,方便理解。根据适用情况的不同可分为分类树回归树。当数据集的因变量为为连续型数值时,这样的树就是回归树;当数据集的因变量为离散型数值时,这样的树就是一个分类树,可以很好的解决分类问题。

1 原理


  一棵决策树的形成过程:从根节点开始,选择最能预测目标类的特征,然后通过这个特征将案例划分到不同的组中,这样形成了第一组树枝。该算法继续此过程,每次选择最佳的候选特征,形成新的结点,直到达到停止的标准为止。当所得到的生成树过大时,容易出现过拟合现象。因此,为了解决这个问题,需要对树进行剪枝操作。

  综上来看,决策树中最重要的两部分就是选择最优划分属性剪枝操作。下面将分别介绍这两部分的内容。

1.1 最优划分属性

  决策树学习的关键是如何选择最优划分属性。一般而言,有信息增益信息增益率(有的地方也叫信息增益比)和基尼指数三种方法来选择最优划分属性。

1.1.1 信息增益

  要明白信息增益首先要知道信息熵的概念。信息熵是度量样本集纯度最常用的一种指标。假设当前样本集合\(D\)中第\(k\)类样本所占的比例为\(p_k(k = 1,2,…,|y|)\),则\(D\)的信息熵定义为: \[Ent(D) = – \sum_{k=1}^{|y|}p_klog_2p_k\]

  信息增益的定义是建立在信息熵的基础上。假定离散属性\(a\)\(V\)个可能的取值,现若使用属性\(a\)来对样本集\(D\)进行划分,则会产生\(V\)个分支结点,其中第\(v\)个分支结点包含了\(D\)中所有在属性\(a\)上取值为\(a^v\)的样本,记为\(D^v\)。再给分支结点赋予权重\(|D^v|/|D|\),这样就可计算出用属性\(a\)对样本集\(D\)进行划分所获得的信息增益: \[Gain(D,a) = Ent(D)-\sum_{v=1}^{V}\frac{|D^v|}{|D|}Ent(D^v)\]

  一般而言,信息增益越大,则意味着使用属性\(a\)来进行划分所获得的“纯度提升”越大。因此,可以用信息增益来进行决策树的划分属性选择。

1.1.2 信息增益率

  信息增益准则在选择特征的时候对可取值数目较多的特征有所偏好。举个极端的例子,当将“编号”也作为数据集特征的时候,信息增益准则会倾向于选择“编号”特征对数据进行分类。但是这样的决策树显然不具有泛化能力。为了减少这种偏好可能带来的不利影响,使用信息增益率来选择最优划分属性。

  简单来讲,信息增益比 = 惩罚参数 * 信息增益,用公式定义则为: \[Gain\_ratio(D,a) = \frac{Gain(D,a)}{IV(a)}\] 其中 \[IV(a) = -\sum_{v=1}^{V}\frac{|D^v|}{|D|}log_2\frac{|D^v|}{|D|}\]

  需要注意的是,信息增益率准则对可取值数目较少的特征有所偏好,因此,当真正运用的时候并不是直接选择增益率最大的候选划分特征,而是先从候选划分特征中找出信息增益高出平均水平的属性,再从中选择增益率最高的。

1.1.3 基尼指数

  直观来说基尼指数反映了从数据集中随机抽取两个样本其类别标记不一致的概率,因此,基尼指数越小则数据集的纯度越高,采用与上边信息增益相同的符号表示,属性\(a\)的基尼指数定义为: \[Gini\_index(D,a) = \sum_{v=1}^{V}\frac{|D^v|}{|D|}Gini(D^v)\]

  于是,在候选集合\(A\)中,选择使得划分后基尼指数最小的属性作为最优划分属性。

1.2 剪枝

  当决策树的分支过多的时候会造成过拟合问题。因此,可以通过剪枝操作去掉一部分分支来降低过拟合的风险。决策树的剪枝操作分为预剪枝后剪枝

  • 预剪枝指在决策树生成的过程中对每个结点在划分前先进行估计,若当前结点的划分不能带来决策树泛化性能的提升则停止划分,并将当前结点标记为叶结点。
  • 后剪枝是先从训练集生成一棵完整的决策树,然后再从底向上地对非叶结点进行考察,若果将该叶结点对应的子树替换为叶结点能带来决策树的泛化性能提升,则将该子树替换为叶结点。

1.3 其它问题

  前面只介绍了离散属性如何生成决策树,而现实中还会遇到连续属性,甚至还会碰到缺失值的问题。因此,很有必要介绍连续属性生成决策树和缺失值的解决办法。

1.3.1 连续值处理

  对待连续属性的办法无非是连续属性的离散化,而最简单的方法是二分法。二分法的思想也比较简单:当给定一个样本集\(D\)的连续属性\(a\),假设该属性有\(n\)个不同的取值,首先将这些值从小到大进行排序,记为\(\{a^1,a^2,…,a^n\}\),基于划分点\(t\)可以将样本集分为\(D_{t}^{-}\)\(D_{t}^{+}\)两个子集,其中\(D_{t}^{-}\)中包含那些在属性\(a\)上的取值不大于\(t\)的样本,而\(D_{t}^{+}\)则包含那些在属性\(a\)上取值大于\(t\)的样本。

  候选划分点的选取也很简单,选择两个数的平均数作为一个候选划分点,这样\(n\)个数就有\(n-1\)个候选划分点。举个例子:当有连续变量的取值为\(\{2,3,8,9,15\}\)时得到的候选划分集合为\(\{2.5,5.5,8.5,12\}\)

  然后就可以像离散属性值一样来考察这些划分点,选取最优的划分点进行样本集合的划分。这里也用信息增益的标准来选出最优的划分点: \[Gian(D,a)=\max_{t\in T_a}Ent(D)-\sum_{\lambda \in \{-,+\}}\frac{|D_{t}^{\lambda}|}{|D|}Ent(D_{t}^{\lambda})\]

  其中,\(Gain(D,a)\)是样本集\(D\)基于划分点\(t\)二分后的信息增益。因此,在这里就可以选择使之最大的划分点。

1.3.2 缺失值处理

  当样本的某些属性值缺失时,如果简单的放弃不完整的样本,仅使用无缺失值的样本来进行学习,显然是对数据信息的极大浪费。

  假设\(a,\rho\)表示无缺失值样本所占的比例,\(\tilde{p}_k\)表示无缺失值样本中第\(k\)类所占的比例,\(\tilde{r}_v\)则表示无缺失值样本中在属性\(a\)上取值\(a^v\)的样本所占的比例。显然,\(\sum_{k=1}^{|y|}\tilde{p}_k=1,\sum_{v=1}^{V}\tilde{r}_v=1\)。基于上述定义,我们可以将信息增益的计算式推广为: \[Gain(D,a)=\rho\times Gain(\tilde{D},a)=\rho\times(Ent(\tilde{D})-\sum_{v=1}^{V}\tilde{r}_vEnt(\tilde{D}^v))\]

其中:

\[Ent(\tilde{D})=-\sum_{k=1}^{|y|}\tilde{p}_klog_2\tilde{p}_k\]

2 算法


2.1 ID3/C4.5/C5.0

  C4.5是ID3的改进算法,两者都以熵理论和信息增益理论为基础。其算法的精髓所在,即是使用熵值或者信息增益来确定使用哪个变量作为各节点的判定变量。而C4.5是为了解决ID3只能用于离散型变量,且确定判定变量时偏向于选择取值较多的变量这两项缺陷而提出来的。因此,C4.5避免了偏袒的可能性,而且对自变量不作任何限制,但是最大的缺点是该算法的运行效率比较低。虽然目前已有了在运行效率等方面进一步完善的算法C5.0,但由于C5.0多用于商业用途,且C5.0并没有公布算法的更多细节。C4.5仍然是更为常用的决策树算法。

  这里需要明确的是:ID3算法在判定变量的时候用的信息增益,而C4.5用的是信息增益比

  C4.5是一种多叉树(即如果根节点或中间节点存在连续型的自变量,则该变量会一分为二的展开两个分支(这个过程后边会讲到);如果根节点或中间节点存在离散的自变量,则该变量会根据离散变量的水平数分开多个分支),就会导致某个变量一旦被使用,后面结点将不会再启用该变量。C4.5只能解决分类问题。

2.2 CART

  CART算法拥有一个非常完整的体系,包括树的生长过程剪枝过程;而且可以解决回归问题。它假设决策树是二叉树,内部结点特征的取值为“是”和“否”,左分支是取值为“是”的分支,右分支是取值为“否”的分支。这样的决策树等价于递归地二分每个特征,将输入空间即特征空间划分为有限个单元,并在这些单元上确定预测的概率分布,也就是在输入给定的条件下输出的条件概率分布。

  当某个非叶节点是多个水平(2个以上)的离散变量时,该变量就有可能被多次使用。举个例子:如果年龄段可分为\(\{低级,中级,高级\}\),则其子集可以是\(\{低级,中级,高级\}\)\(\{低级,中级\}\)\(\{低级,高级\}\)\(\{中级,高级\}\)\(\{低级\}\)\(\{中级\}\)\(\{高级\}\)\(\{\}\)。其中\(\{低级,中级,高级\}\)和空集\(\{\}\)为无意义分割,所以最终有\(2^3-2=6\)种组合,形成三对对立的组合,如\(\{低级,中级\}\)\(\{高级\}\)

3 案例


  如上所述,CART算法C4.5算法比较常用。因此,下面仅用这两种算法进行案例分析。

3.1 数据准备

  接下来的案例中需要用到的数据集是来自mvpart程序包中的car.test.frame数据集。首先,需要加载mvpart程序包;另外,后面的分析还需要用到DT程序包中的datatable()函数进行数据集展示,因此在这里与mvpart程序包一起加载:

# install.packages("mvpart") # 下载mvpart程序包
# install.packages("DT")     # 下载DT程序包
library(mvpart)              # 加载mvpart程序包,使用car.test.frame数据集
library(DT)                  # 加载DT程序包,使用datatable()函数

注意:如果没有下载mvpart和DT程序包,请先使用install.packages("mvpart")install.packages("DT")两个命令下载两个程序包之后再加载。

  下面获取car.test.frame数据集并查看数据集的概况:

data(car.test.frame)      # 获取数据集
datatable(car.test.frame) # 展示数据集
str(car.test.frame)       # 查看数据集的基本结构
## 'data.frame':    60 obs. of  8 variables:
##  $ Price      : int  8895 7402 6319 6635 6599 8672 7399 7254 9599 5866 ...
##  $ Country    : Factor w/ 8 levels "France","Germany",..: 8 8 5 4 3 6 4 5 3 3 ...
##  $ Reliability: int  4 2 4 5 5 4 5 1 5 NA ...
##  $ Mileage    : int  33 33 37 32 32 26 33 28 25 34 ...
##  $ Type       : Factor w/ 6 levels "Compact","Large",..: 4 4 4 4 4 4 4 4 4 4 ...
##  $ Weight     : int  2560 2345 1845 2260 2440 2285 2275 2350 2295 1900 ...
##  $ Disp.      : int  97 114 81 91 113 97 97 98 109 73 ...
##  $ HP         : int  113 90 63 92 103 82 90 74 90 73 ...

  以上的指令可以得出大量关于数据集的信息:

  • 语句data.frame告诉我们数据集是数据框类型
  • 语句60 obs. of 8 variables告诉我们数据一共包含60个观测、8个变量
  • 左下角Price、Country等告诉我们每个变量的变量名和所对应的类型(int表示整数型、Factor表示因子型)
  • 这8个变量分别为:价格(Price)、产地(Country)、可靠性(Reliability)、英里数(Mileage)、类型(Type)、车重(Weight)、发动机功率(Disp.)、净马力(HP)
  • 数据类型后面展示了每个变量所对应的部分数值

  为了方便,这里将“英里数(Mileage)”换算成大家所熟悉的“油耗”指标:

car.test.frame$Mileage <- 100*4.546/(1.6*car.test.frame$Mileage) # 将英里数换算成油耗指标

  由于后面要用变量“油耗(Mileage)”作为目标变量构建回归树;用变量“分组油耗(GFC)”作为目标变量构建分类树。因此,在这里需要添加一列新变量“分组油耗(GFC)”。这里将“油耗(Mileage)”划分为三个组别,从而成为含有三个水平的因子型变量:

  • C:7.6 ~ 9个油
  • B:9 ~ 11.6个油
  • A:11.6 ~ 15.8个油

  划分油耗为A、B、C三个组别:

car.test.frame$GFC <- NA # 新添加一列"分组油耗(GFC)"并用缺失值NA来填充

## 分组并设置相应分组油耗等级:
car.test.frame[which(car.test.frame$Mileage >= 7.6 & car.test.frame$Mileage < 9), "GFC"] = "C" # C组
car.test.frame[which(car.test.frame$Mileage >= 9 & car.test.frame$Mileage < 11.6), "GFC"] = "B" # B组
car.test.frame[which(car.test.frame$Mileage >= 11.6 & car.test.frame$Mileage <= 15.8), "GFC"] = "A" # A组

head(car.test.frame) # 展示处理后的前6行数据
##                  Price   Country Reliability   Mileage  Type Weight Disp.
## Eagle Summit 4    8895       USA           4  8.609848 Small   2560    97
## Ford Escort   4   7402       USA           2  8.609848 Small   2345   114
## Ford Festiva 4    6319     Korea           4  7.679054 Small   1845    81
## Honda Civic 4     6635 Japan/USA           5  8.878906 Small   2260    91
## Mazda Protege 4   6599     Japan           5  8.878906 Small   2440   113
## Mercury Tracer 4  8672    Mexico           4 10.927885 Small   2285    97
##                   HP GFC
## Eagle Summit 4   113   C
## Ford Escort   4   90   C
## Ford Festiva 4    63   C
## Honda Civic 4     92   C
## Mazda Protege 4  103   C
## Mercury Tracer 4  82   B

  在建模之前还有一项重要工作就是划分训练集和测试集,在这里采用分层抽样法将数据集按照3:1的比例划分为训练集(train)和测试集(test)。为了保持数据集原有的分布这里采用分层抽样的方法对数据集进行划分。

# install.packages("caret")
library(caret)

set.seed(111) # 设置随机数种子,保持随机抽样的一致性

# 按照分组油耗(GFC)进行分层抽样,设置训练集的占比为75%
index <- createDataPartition(car.test.frame$GFC, p = 0.75, list = FALSE) 
train <- car.test.frame[index, ] # 抽取75%的数据作为训练集
test <- car.test.frame[-index, ] # 抽取25%的数据作为训练集

  到这里数据的准备工作就完成了,下面正式进行建模操作。

3.2 CART算法实现

  在CART部分将分别使用变量“油耗(Mileage)”建立回归树,变量“分组油耗(GFC)”建立分类树。建模之前需要加载相关的程序包“rpart”:

# install.packages("rpart")
library(rpart)

3.2.1 CART回归树

  以除了“分组油耗(GFC)”以外的所有变量来对“油耗(Mileage)”变量建立回归决策树:

  在rpart()函数中,主要使用以下5个参数:

rpart(formula, data, subset, na.action = na.rpart, method, …)

  • formula:建模所需要的公式
  • data:一个数据框格式的数据集即训练集
  • na.action:缺失数据的处理方法,默认的是删除y缺失的所有观测值,但会保留缺少一个或多个预测变量的那些观测值
  • method:根据树末端的数据类型选择相应变量分割方法,可选值有anova(连续型)、poisson(计数型)、class(离散型)或者exp(生存分析型)中的一个
# 设置好模型需要用到的公式
formula.rgr <- Mileage ~ Price + Country + Reliability + Type + Weight + Disp. + HP

# 设置随机数种子,这里的作用与上面的相同
set.seed(123)

# 构建回归决策树模型,其中method = "anova"表示回归树
model.cart.regression <- rpart(formula.rgr, train, method = "anova")

model.cart.regression # 输出建模结果
## n= 46 
## 
## node), split, n, deviance, yval
##       * denotes terminal node
## 
##  1) root 46 208.384300 11.911310  
##    2) Disp.< 134 19  29.734120  9.889969 *
##    3) Disp.>=134 27  46.391150 13.333730  
##      6) Type=Compact,Medium,Sporty 20  21.053470 12.795730  
##       12) HP< 141.5 12   7.866862 12.275490 *
##       13) HP>=141.5 8   5.066979 13.576100 *
##      7) Type=Large,Van 7   3.009215 14.870870 *

  上面的输出结果中:

  • 1)表示根节点的样本有46个,即全部训练集的样本
  • 2)和3)以发动机功率(Disp.)变量为节点,且以134为分割点划分为两支(节点信息后有“*“号的代表叶节点),以此类推
  • 不同的缩进量表示节点的层次不同

  通过printcp()函数还可以导出回归树的cp表格:

printcp(model.cart.regression) # 输出建模结果的cp值等信息
## 
## Regression tree:
## rpart(formula = formula.rgr, data = train, method = "anova")
## 
## Variables actually used in tree construction:
## [1] Disp. HP    Type 
## 
## Root node error: 208.38/46 = 4.5301
## 
## n= 46 
## 
##         CP nsplit rel error  xerror     xstd
## 1 0.634688      0   1.00000 1.07274 0.172822
## 2 0.107150      1   0.36531 0.41860 0.066355
## 3 0.038965      2   0.25816 0.36532 0.052610
## 4 0.010000      3   0.21920 0.36532 0.052610

  从cp表格可以看出在建树的过程中要用到的变量有:发动机功率(Disp.)、价格(Price)、类型(Type),且各个节点的CP值、节点序号nsplit、错误率rel error、交互验证错误率xerror等也被列出,需要注意的是其中的CP值对于选择控制树的复杂程度十分重要。

  更详细的信息可以通过summary()函数进行查看:

summary(model.cart.regression) # 输出建模结果的具体信息
## Call:
## rpart(formula = formula.rgr, data = train, method = "anova")
##   n= 46 
## 
##           CP nsplit rel error    xerror       xstd
## 1 0.63468814      0 1.0000000 1.0727419 0.17282197
## 2 0.10715039      1 0.3653119 0.4186040 0.06635450
## 3 0.03896468      2 0.2581615 0.3653226 0.05260971
## 4 0.01000000      3 0.2191968 0.3653226 0.05260971
## 
## Variable importance
##   Disp.  Weight    Type      HP Country   Price 
##      24      20      16      15      13      12 
## 
## Node number 1: 46 observations,    complexity param=0.6346881
##   mean=11.91131, MSE=4.530094 
##   left son=2 (19 obs) right son=3 (27 obs)
##   Primary splits:
##       Disp.  < 134     to the left,  improve=0.6346881, (0 missing)
##       Weight < 3087.5  to the left,  improve=0.5466584, (0 missing)
##       HP     < 104.5   to the left,  improve=0.5334127, (0 missing)
##       Price  < 11484.5 to the left,  improve=0.5021378, (0 missing)
##       Type   splits as  RRRLRR,      improve=0.4920939, (0 missing)
##   Surrogate splits:
##       Weight  < 2722.5  to the left,  agree=0.870, adj=0.684, (0 split)
##       HP      < 104.5   to the left,  agree=0.826, adj=0.579, (0 split)
##       Country splits as  LLRLLLRR,    agree=0.804, adj=0.526, (0 split)
##       Type    splits as  RRRLRR,      agree=0.804, adj=0.526, (0 split)
##       Price   < 9115.5  to the left,  agree=0.783, adj=0.474, (0 split)
## 
## Node number 2: 19 observations
##   mean=9.889969, MSE=1.564954 
## 
## Node number 3: 27 observations,    complexity param=0.1071504
##   mean=13.33373, MSE=1.718191 
##   left son=6 (20 obs) right son=7 (7 obs)
##   Primary splits:
##       Type   splits as  LRL-LR,      improve=0.4813087, (0 missing)
##       Weight < 3165    to the left,  improve=0.4560014, (0 missing)
##       Price  < 11854.5 to the left,  improve=0.2893911, (0 missing)
##       HP     < 141.5   to the left,  improve=0.2816829, (0 missing)
##       Disp.  < 181.5   to the left,  improve=0.1960822, (0 missing)
##   Surrogate splits:
##       Weight < 3637.5  to the left,  agree=0.926, adj=0.714, (0 split)
## 
## Node number 6: 20 observations,    complexity param=0.03896468
##   mean=12.79573, MSE=1.052673 
##   left son=12 (12 obs) right son=13 (8 obs)
##   Primary splits:
##       HP     < 141.5   to the left,  improve=0.3856670, (0 missing)
##       Weight < 3087.5  to the left,  improve=0.3228469, (0 missing)
##       Price  < 11854.5 to the left,  improve=0.2187263, (0 missing)
##       Disp.  < 158     to the left,  improve=0.1813912, (0 missing)
##       Type   splits as  L-R-R-,      improve=0.1223695, (0 missing)
##   Surrogate splits:
##       Weight      < 3087.5  to the left,  agree=0.85, adj=0.625, (0 split)
##       Price       < 15165   to the left,  agree=0.80, adj=0.500, (0 split)
##       Disp.       < 152     to the left,  agree=0.80, adj=0.500, (0 split)
##       Country     splits as  --RL--LL,    agree=0.75, adj=0.375, (0 split)
##       Reliability < 3.5     to the left,  agree=0.65, adj=0.125, (0 split)
## 
## Node number 7: 7 observations
##   mean=14.87087, MSE=0.4298878 
## 
## Node number 12: 12 observations
##   mean=12.27549, MSE=0.6555718 
## 
## Node number 13: 8 observations
##   mean=13.5761, MSE=0.6333724

  不难看出summary()函数提供的结果中有一部分与上面printcp()函数相同,除此之外还增加了变量重要程度(Variable importance)、每个分支变量对生成树的提升程度(improve)等信息。

  以上是建立模型的部分,模型的调参这里就不进行细讲,如果有需要可以借助R中的帮助文档去了解各个参数的具体作用,然后进行相应的调参操作。

  为了建模结果的清晰直观,我们用rpart.plot程序包中的rpart.plot()函数以图形的方式将模型结果表示出来:

# install.packages("rpart.plot")    
library(rpart.plot)                

rpart.plot(model.cart.regression) # 图形展示

  参考之前的model.cart.regression的结果很容易发现,这里的结果与之相同:

  • 对于发动机功率(Disp.)小于134公升的车油耗(Mileage)被预测为9.9升/百公里
  • 对于发动机功率(Disp.)大于等于134公升、车型不为Cmp、Mdm、Spr的车油耗(Mileage)被预测为15升/百公里
  • 发动机功率(Disp.)大于等于134公升、车型为Cmp、Mdm、Spr且净马力(HP)小于142的车油耗(Mileage)被预测为12升/百公里
  • 发动机功率(Disp.)大于等于134公升、车型为Cmp、Mdm、Spr且净马力(HP)大于等于142的车油耗(Mileage)被预测为14升/百公里。

  同样的,rpart.plot()函数中也有很多参数,可以通过控制参数来达到各种目的,比如更改所绘制树状图的类型,那么可以修改参数type来达到目的。具体的参数信息请自行参考R的帮助文档。

3.2.2 CART分类树

  下面对变量分组油耗(GFC)构建分类树,即以除了“油耗(Mileage)”以外的所有变量来对“分组油耗(GFC)”变量建立分类决策树:

formula.cla <- GFC ~ Price + Country + Reliability + Type + Weight + Disp. + HP # 设置模型公式

set.seed(123)                                                                   # 设置随机数种子

model.cart.classification <- rpart(formula.cla, train, method = "class")        # 构建分类决策树模型,其中method = "class"表示分类树

model.cart.classification                                                       # 输出模型结果
## n= 46 
## 
## node), split, n, loss, yval, (yprob)
##       * denotes terminal node
## 
## 1) root 46 19 A (0.58695652 0.26086957 0.15217391)  
##   2) Disp.>=134 27  2 A (0.92592593 0.07407407 0.00000000) *
##   3) Disp.< 134 19  9 B (0.10526316 0.52631579 0.36842105) *
rpart.plot(model.cart.classification)                                           # 图形展示

  观察图可以知道,这里在建树的过程中用到的变量只有发动机功率(Disp.),且小于134的车油耗分组为A组,否则为B组。

  接下来,我们使用分类树来对测试集test中的“分组油耗(GFC)”变量进行预测,并对预测结果进行评价:

pre.cart.cla <- predict(model.cart.classification, test, type = "class") # 对测试集进行预测

confusionMatrix(pre.cart.cla, test[, 9])                                 #建立混淆矩阵
## Confusion Matrix and Statistics
## 
##           Reference
## Prediction A B C
##          A 8 0 0
##          B 0 4 2
##          C 0 0 0
## 
## Overall Statistics
##                                           
##                Accuracy : 0.8571          
##                  95% CI : (0.5719, 0.9822)
##     No Information Rate : 0.5714          
##     P-Value [Acc > NIR] : 0.02481         
##                                           
##                   Kappa : 0.7407          
##  Mcnemar's Test P-Value : NA              
## 
## Statistics by Class:
## 
##                      Class: A Class: B Class: C
## Sensitivity            1.0000   1.0000   0.0000
## Specificity            1.0000   0.8000   1.0000
## Pos Pred Value         1.0000   0.6667      NaN
## Neg Pred Value         1.0000   1.0000   0.8571
## Prevalence             0.5714   0.2857   0.1429
## Detection Rate         0.5714   0.2857   0.0000
## Detection Prevalence   0.5714   0.4286   0.0000
## Balanced Accuracy      1.0000   0.9000   0.5000

  可以看到,预测的正确率为85.71%,说明决策树模型虽然原理简单易懂,但是真正用起来确实是比较得心应手的。

需要说明的是,这里的预测仅仅是为了说明如何使用已经建立的决策树来对未知样本的目标变量进行预测。实质上,这种小样本计算错误率并没有太大意义。

3.3 C4.5算法实现

  在前面部分,我们见识了CART算法的分类树模型,这里将继续沿用上面的car.test.frame数据集中的“分组油耗(GFC)”进行C4.5算法的分类树模型建立。建模之前还是需要加载需要用到的程序包RWeka:

# install.packages("Rweka")  
library(RWeka)    

  为了使J48()函数可以识别,需要将分组油耗(GFC)的变量类型变为因子型:

train$GFC <- as.factor(train$GFC) # 将变量转换为因子型

  建立模型:

  在J48函数中,主要使用以下5个参数:

J48(formula, data, subset, na.action,…)

  • formula:用来建模的公式
  • data:待训练的数据集,即通常所说的训练集
  • subset:可以选择出data中的若干行样本来建立模型
  • na.action:用来处理缺失值,其默认选择是na.rpart
formula.C4.5 <- GFC ~ Price + Country + Reliability + Type + Weight + Disp. + HP # 设置建模公式

set.seed(111)                                                                   # 设置随机数种子

model.C4.5.classification <- J48(formula.C4.5, train)                           # 建立分类树模型

summary(model.C4.5.classification)                                              # 查看模型的概况
## 
## === Summary ===
## 
## Correctly Classified Instances          35               92.1053 %
## Incorrectly Classified Instances         3                7.8947 %
## Kappa statistic                          0.8546
## Mean absolute error                      0.095 
## Root mean squared error                  0.218 
## Relative absolute error                 25.41   %
## Root relative squared error             50.7594 %
## Total Number of Instances               38     
## 
## === Confusion Matrix ===
## 
##   a  b  c   <-- classified as
##  22  1  0 |  a = A
##   2  7  0 |  b = B
##   0  0  6 |  c = C

  上面的结果中,Correctly Classified Instances(35)表示被正确分类的样本数为35;相应的,被错误归类的样本数为3。以及它们各自占样本总数的百分比(92.1.53%, 7.8947%),当然还有样本总数Total Number of Instances(38),混淆矩阵Confusion Matrix也给了出来。为了更简单直观,下面将以图形的方式展示树形结果:

plot(model.C4.5.classification) # 模型的图形展示

  由输出结果我们知道,这是一个含有3个叶节点,共计5个节点的分类树。从图中很容易发现叶节点比较特殊,它是以一个一个小的图形展示出来的,就拿第二个小图来说,它表示价格(Price)大于7402美元且发动机功率(Disp.)小于等于133的车的油耗有大约90%的可能属于B范围内、10%的可能属于A范围内。

4 总结


  以上从决策树的分类以及原理开始讲解,还介绍了各种决策树算法,最后用案例来实现算法。相信这一套流程下来,大家对决策树会有个很深刻的认识。—–待完善—–

5 参考文献


[1] 周志华. 机器学习 : = Machine learning[M]. 清华大学出版社, 2016.

滚动至顶部