决策树笔记

决策树的原理

决策树的原理可以参考这篇笔记《StatQuest学习笔记19——决策树》

决策树的计算过程

决策树(decision tree)用于回归时称为回归树(regression tree),用于分类时则称为分类树(classification tree)。这里我们先看一批数据,如下所示:

mark

原始数据很多,有1000多行。这个数据来自于Flake and lawrence(2002),可以在网上下载,这且数据是做什么的,我也不清楚。这个数据集有6个自变量和1个因变量,一共有1385个观测值。书上说之所以用这个数据是因为这里的所有变量都是数量变量,因此这个数据集也可以用经典线性回归来做,从而可以对各种方法进行比较。

线性回归

现在先对全部数据做一个简单的 线性回归,如下所示:

1
2
3
4
5
6
raw_dt <- read.table("https://raw.githubusercontent.com/20170505a/raw_data/master/mg.csv",header=T,sep=",")
result_dt <- lm(y~.,raw_dt)
summary(result_dt)
pairs(raw_dt)#查看不同变量之间的相关关系
# 或plot(raw_dt)
cor(raw_dt)

结果如下所示:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
> summary(result_dt)
Call:
lm(formula = y ~ ., data = raw_dt)
Residuals:
Min 1Q Median 3Q Max
-0.40866 -0.09516 0.01098 0.09645 0.40011
Coefficients:
Estimate Std. Error t value Pr(>|t|)
(Intercept) 0.95582 0.07142 13.384 < 2e-16 ***
x1 0.28512 0.02773 10.282 < 2e-16 ***
x2 0.17263 0.02655 6.502 1.10e-10 ***
x3 -0.37956 0.02769 -13.709 < 2e-16 ***
x4 -0.36756 0.02764 -13.298 < 2e-16 ***
x5 0.07165 0.02656 2.697 0.00707 **
x6 0.19003 0.02778 6.841 1.18e-11 ***
---
Signif. codes: 0 ‘***’ 0.001 ‘**’ 0.01 ‘*’ 0.05 ‘.’ 0.1 ‘ ’ 1
Residual standard error: 0.1459 on 1378 degrees of freedom
Multiple R-squared: 0.5868, Adjusted R-squared: 0.585
F-statistic: 326.2 on 6 and 1378 DF, p-value: < 2.2e-16

mark

mg数据各个变量之间的相关系数如下所示:

1
2
3
4
5
6
7
8
9
> cor(raw_dt)
y x1 x2 x3 x4 x5 x6
y 1.0000000 0.2620280 0.5441721 -0.5006887 -0.4355964 0.5510561 0.1511016
x1 0.2620280 1.0000000 -0.1465800 -0.6543690 0.4333849 0.4543362 -0.6277910
x2 0.5441721 -0.1465800 1.0000000 -0.1475594 -0.6539043 0.4367584 0.4515753
x3 -0.5006887 -0.6543690 -0.1475594 1.0000000 -0.1471907 -0.6537680 0.4397873
x4 -0.4355964 0.4333849 -0.6539043 -0.1471907 1.0000000 -0.1494223 -0.6524989
x5 0.5510561 0.4543362 0.4367584 -0.6537680 -0.1494223 1.0000000 -0.1498846
x6 0.1511016 -0.6277910 0.4515753 0.4397873 -0.6524989 -0.1498846 1.0000000

从前面的线性回归结果可以看出来,$R^2=0.5868$,这个数值并不太好,再加上两两散点图与相关系数的结果,我们看不出各个变量之间的任何模式。现在我们使用其它的几种算法建模方法,同时使用交叉验证(利用均方误差或标准化均方误差)对各种方法时行比较。

下面我们对各种方法都用10折交叉验证的方法来判断其结果的可靠性(关于10折交叉验证的一些知识,可以看这笔记《StatQuest学习笔记22——交叉验证》)。对于每种方法,按照随机建立的10个训练集通过计算建立10个模型,对测试集分别得到10个标准化均方误差(NMSE),再得出10次平均的NMSE。令$\bar{y}$为因变量均值,$\hat{y}$为从训练集得到的模型对一个测试数据集的预测值,这里NMSE的定义如下所示:

$NMSE=\bar{(y-\hat{y})^2}/=\bar{(y-\bar{y})^2}=\sum(y-\hat{y})^2/\sum(y-\bar{y})^2$

显然,如果什么模型都不用,仅仅用均值来做预测,那么NMSE应该等于1.所以,如果在回归时得到NMSE大于1,这个模型就很糟糕了,还不如没有模型。仅仅对于训练集来说,其NMSE等于$1-R^2$,这里的$R^2$为回归的确定系数,但是对于测试集来说,其NMSE与测试集回归的$R^2$没有什么关系,交叉验证主要关心测试集的NMSE。

构建训练集和测试集

现在构建为交叉验证服务的10个训练集和测试集,具体过程是:随机把下标分配给1,2,…10这10个数字,也就是把数据下标随机分成10份,然后每次提取一份作为测试集,其它9份放在一起作为训练集,用模型进行拟合,记下结果和误差,如此下去,一共做10次,最后把误差平均起来。这里所用的随机选Z折下标集的函数如下所示:

1
2
3
4
5
6
7
8
9
CV=function(n,Z=10,seed=1000){
z=rep(1:Z,ceiling(n/Z))[1:n]
set.seed(seed)
z=sample(z,n)
mm=list()
for(i in 1:Z)
mm[[i]]=(1:n)[z==i]
return(mm)
}

这里的n是样本量,Z为折的数目(也就是训练集加测试集的数目,这里默认是10),seed为随机种子,这里设为了1000,转出为Z个下标集,mm[[i]]为第i(i=1,2,…,Z)个下标集,根据这个函数对这个mg数据集求10个下标集的代码如下所示:

1
2
3
4
n<-nrow(raw_dt)
Z=10
mm=CV(n,Z)
D=1

在这段代码中,mm存储了10个下标集,而D说明因变量是第一个变量,这些值在后面对每种方法都要用。

为了和其它方法进行地比较,先对数据mg求简单的线性回归的10折交叉验证的测试集的NMSE,代码如下所示:

1
2

结果如下所示:

1
2
> mean(MSE)
[1] 0.4222689

结果表明,测试集的NMSE为0.4222689

决策树回归

决策树(decision tree)用于回归时称为回归树(regression tree),用于分类时则称为分类树(classification tree),现在我们先用数据集mg做一棵回归树,然后解释其意义。这里要使用到rpart包中的rpart()函数,如下所示:

1
2
3
4
library(rpart)
library(rpart.plot)
a=rpart(y~.,raw_dt)
rpart.plot(a,type=2,faclen=0)

结果如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
> a
n= 1385
node), split, n, deviance, yval
* denotes terminal node
1) root 1385 70.9545100 0.9301440
2) x2< 0.9894983 736 31.9699900 0.8017888
4) x3>=0.8423859 551 16.4573000 0.7288982
8) x2>=0.5178963 483 13.3868900 0.7068909
16) x2< 0.7999875 263 4.7064590 0.6422662
32) x3< 1.092533 185 3.3567710 0.6073434 *
33) x3>=1.092533 78 0.5889238 0.7250958 *
17) x2>=0.7999875 220 6.2689860 0.7841468 *
9) x2< 0.5178963 68 1.1749050 0.8852150 *
5) x3< 0.8423859 185 3.8660570 1.0188850
10) x4>=1.062353 71 1.8511830 0.9239636 *
11) x4< 1.062353 114 0.9767496 1.0780020 *
3) x2>=0.9894983 649 13.1077400 1.0757060
6) x3>=1.056185 206 3.5797200 0.9644841
12) x4>=0.9107852 23 0.5746155 0.7823564 *
13) x4< 0.9107852 183 2.1462980 0.9873744 *
7) x3< 1.056185 443 5.7947760 1.1274250 *

mark

计算的结果有2个,一个是文本描述的结果,一个是图形结果,这两个结果是等价的,我们可以了解这个决策树的构造过程。现在解释一下这个决策树。

决策树就像一棵倒长的树,有很多分叉,分叉点叫节点(node),其中1号节点为根节点,也就是上图中最高的那个节点,输出的文本信息则是全部数据的信息,其中在1号节点上,n=1385,偏差(deviance)=70.9545100,在该节点的因变量的均值为0.9301440。

这里的偏差是批在每个节点的偏差,它的定义为:在该节点的数据中的因变量与其均值之差的平方和,即$\sum_{i}(y_{i}-\bar{y})^2$这里回归树用偏差大小来选择变量,当然也可以用其它诸如残差平方和作为选择变量的准则。

上述的这些信息显示在上图中(不过只显示了均值),然后计算机程度在每个自变量中选择一个数值分割点,使得在该节点的整个数据在这个分割点比在其它分割点分成两部分之后的偏差之和都小,对每个自变量都选择这样的一个分割点,并对不同自变量的分割结果做比较,得到一个使得分割后总偏差最小的变量,这个变量就是该节点的分割变量(或称为拆分变量)。

根据输出的文本信息,在这个案例中,根节点的分割变量为x2,其分割点为0.9894983,x2小于该值的观测值分到左边(节点2),大于该值的则分到右边(节点3)。在每个节点都展示出观测值数目、偏差及因变量均值这三项内容(不过决策树的图形中只有均值),这样数据就分成了两个部分。然后,对每一部分数据,选择分割变量及其分割点的程序重新开始。如此下去,根据一些按照拟合程度和避免决策树无限制长下去的准则,到一定节点就不再分割,这些节点就是终节点(也称为叶节点),也就是文本信息中带有星号(*)的为叶节点。

决策树的使用

使用决策树很简单,如果有新的(只知道自变量的)观测值,就可以根据决策树以及各个变量大小,从根节点一直走到一个终节点,那里因变量的均值就是该观测值的因变量预测值。决策树本身受数据影响较大,但许多决策树结果起来就可以构成很好的方法,例如boosting,bagging以及随机森林等组合方法就是以决策树作为基本决策单元构造出来的。

参考资料

  1. 复杂数据统计方法:基于R的应用.第三版.吴喜之