R 学习笔记: Cross validation

在实际分析数据的时候, 并没有足够多的数据去建模并对模型进行测试. 而测试误差 (test error) 和训练误差 (training error) 是有差距的. 为了能够估计测试误差, 可以使用交叉验证的方法.

所谓交叉验证, 简单讲就是从数据中抽取一部分的数据作为训练数据, 然后另一部分数据作为测试数据, 然后再抽取一部分的数据作为训练数据, 另一部分作为测试数据, 循环往复. 如此就能得到训练误差和测试误差, 从而能够去评判模型的好坏.

常用的交叉验证方法有留下一个交叉验证 (leave one out cross validation) 和 k 倍 交叉验证 (k fold cross validation). LOOCV 是每次取一个数据作为测试数据, 而其他数据作为训练数据. k-fold CV 是把数据分为 k 份, 然后每次取其中一份作为测试数据, 其他作为测试数据. 实际上 LOOCV 可以看作是 k 为样本数的 k-fold CV.

R 例子

使用 ISLR 包中的 Auto 数据来举例.

> library(ISLR)
> names(mgp)
[1] "mpg"          "cylinders"    "displacement" "horsepower"   "weight"
[6] "acceleration" "year"         "origin"       "name"

先做简单的 generalized linear model.

> glm.fit <- glm(mpg~horsepower, data=Auto)
> coef(glm.fit)
(Intercept)  horsepower
 39.9358610  -0.1578447

可以随机抽取一部分数据作为训练数据, 另一部分数据是测试数据.

> set.seed(1)
> training <- sample(c(TRUE, FALSE), dim(Auto)[1], replace=TRUE, prob=c(0.9, 0.1))
> test <- !training

然后我们再做 glm 并计算 Mean Squared Error.

> glm.model <- glm(mpg~horsepower, data=Auto, subset=training)
> glm.pred <- predict(glm.model, newdata=Auto[test,])
> mean((glm.pred - Auto[test, 'mpg'])^2)
[1] 31.52817

接下来我们可以使用 10-fold CV.

> groups <- sample(1:10, dim(Auto[1]), replace=TRUE)
> cv.error <- rep(NA, 10)
> for (i in seq(10)) {
+     glm.model <- glm(mpg~horsepower, data=Auto, subset=(groups!=i))
+     glm.pred <- predict(glm.model, newdata=Auto[groups==i, ])
+     cv.error[i] <- mean((glm.pred - Auto[groups==i,'mpg'])^2)
+ }
> mean(cv.error)
[1] 23.99952

在 boot 包中, 提供了 cv.glm 函数, 可以自动完成 CV.

> library(boot)
> glm.model <- glm(mpg~horsepower, data=Auto)
> cv.err <- cv.glm(Auto, glm.model, K=10)
> cv.err$delta
[1] 24.17549 24.16324

cv.glm 结果中的 delta 是平均的 MSE 和校正过的 MSE

如果 mpg 和 horsepower 不是线性的关系, 则可以利用 CV 来确定具体的多项式.

> cv.error <- rep(NA, 5)
> for (i in seq(5)) {
+     glm.model <- glm(mpg~poly(horsepower, i), data=Auto)
+     cv.error[i] <- cv.glm(Auto, glm.model)$delta[1]
+ }
> cv.error
[1] 24.23151 19.24821 19.33498 19.42443 19.03321

当 i = 2 时, CV MSE 最小, 由此可见多项式分布的项数应该为 2, mpg 和 horsepower 的关系应该为 mpg = a0 + a1 * horsepower + a2 * horsepower ^ 2.

Reference

  • G. James et al., An Introduction to Statistical Learning: with Applications in R, Springer Texts in Statistics, DOI 10.1007/978-1-4614-7138-7 5
By @Wolfson Liu in
Tags : #r, #data,