Validación de modelos predictivos (machine learning): Cross-validation, OneLeaveOut, Bootstraping


Más sobre ciencia de datos en cienciadedatos.net

Versión PDF: Github



Introducción


La finalidad última de un modelo es predecir la variable respuesta en observaciones futuras o en observaciones que el modelo no ha “visto” antes. El error mostrado por defecto tras entrenar un modelo suele ser el error de entrenamiento, el error que comete el modelo al predecir las observaciones que ya ha “visto”. Si bien estos errores son útiles para entender cómo está aprendiendo el modelo (estudio de residuos), no es una estimación realista de cómo se comporta el modelo ante nuevas observaciones (el error de entrenamiento suele ser demasiado optimista). Para conseguir una estimación más certera, se tiene que recurrir a un conjunto de test o emplear estrategias de validación basadas en resampling.

Los métodos de validación, también conocidos como resampling, son estrategias que permiten estimar la capacidad predictiva de los modelos cuando se aplican a nuevas observaciones, haciendo uso únicamente de los datos de entrenamiento. La idea en la que se basan todos ellos es la siguiente: el modelo se ajusta empleando un subconjunto de observaciones del conjunto de entrenamiento y se evalúa (calcular una métrica que mida como de bueno es el modelo, por ejemplo, accuracy) con las observaciones restantes. Este proceso se repite múltiples veces y los resultados se agregan y promedian. Gracias a las repeticiones, se compensan las posibles desviaciones que puedan surgir por el reparto aleatorio de las observaciones. La diferencia entre métodos suele ser la forma en la que se generan los subconjuntos de entrenamiento/validación.

Agunos de los términos empleados a lo largo de este documento son:

Conjunto de entrenamiento (training set): datos/observaciones con las que se entrena el modelo.

Conjunto de validación y conjunto de test (validation set y test set): datos/observaciones del mismo tipo que las que forman el conjunto de entrenamiento pero que no se han empleado en la creación del modelo. Son datos que el modelo no ha “visto”.

Error de entrenamiento (training error): error que comete el modelo al predecir observaciones que pertenecen al conjunto de entrenamiento.

Error de validación y error de test (evaluation error y test error): error que comete el modelo al predecir observaciones del conjunto de validación y del conjunto de test. En ambos casos son observaciones que el modelo no ha “visto”.



Estrategias de validación


Validación simple


El método más sencillo de validación consiste en repartir aleatoriamente las observaciones disponibles en dos grupos, uno se emplea para entrenar al modelo y otro para evaluarlo. Si bien es la opción más simple, tiene dos problemas importantes:

  • La estimación del error es altamente variable dependiendo de qué observaciones se incluyan como conunto de entrenamiento y cuáles como conjunto de validación (problema de varianza).

  • Al excluir parte de las observaciones disponibles como datos de entrenamiento (generalmente el 20%), se dispone de menos información con la que entrenar el modelo y, por lo tanto, se reduce su capacidad. Esto suele tener como consecuencia una sobrestimación del error comparado al que se obtendría si se emplearan todas las observaciones para el entrenamiento (problema de bias).

Leave One Out Cross-Validation (LOOCV)


El método LOOCV en un método iterativo que se inicia empleando como conjunto de entrenamiento todas las observaciones disponibles excepto una, que se excluye para emplearla como validación. Si se emplea una única observación para calcular el error, este varía mucho dependiendo de qué observación se haya seleccionado. Para evitarlo, el proceso se repite tantas veces como observaciones disponibles, excluyendo en cada iteración una observación distinta, ajustando el modelo con el resto y calculando el error con dicha observación. Finalmente, el error estimado por el LOOCV es el promedio de todos lo i errores calculados.

El método LOOCV permite reducir la variabilidad que se origina si se divide aleatoriamente las observaciones únicamente en dos grupos. Esto es así porque al final del proceso de LOOCV se acaban empleando todos los datos disponibles tanto como entrenamiento como validación. Al no haber una separación aleatoria de los datos, los resultados de LOOCV son totalmente reproducibles.

La principal desventaja de este método es su coste computacional. El proceso requiere que el modelo sea reajustado y validado tantas veces como observaciones disponibles (n) lo que en algunos casos puede ser muy complicado. Excepcionalmente, en la regresión por mínimos cuadrados y regresión polinomial, por sus características matemáticas, solo es necesario un ajuste, lo que agiliza mucho el proceso.

LOOCV es un método de validación muy extendido ya que puede aplicarse para evaluar cualquier tipo de modelo. Sin embargo, los autores de An Introduction to Statistical Learning consideran que, al emplearse todas las observaciones como entrenamiento, se puede estar cayendo en overfitting, por lo que, aun considerándolo muy aceptable, recomiendan emplear K-Fold Cross-Validation.

K-Fold Cross-Validation


El método K-Fold Cross-Validation es también un proceso iterativo. Consiste en dividir los datos de forma aleatoria en k grupos de aproximadamente el mismo tamaño, k-1 grupos se emplean para entrenar el modelo y uno de los grupos se emplea como validación. Este proceso se repite k veces utilizando un grupo distinto como validación en cada iteración. El proceso genera k estimaciones del error cuyo promedio se emplea como estimación final.

Dos ventajas del método K-Fold Cross-Validation frente al LOOCV:

  • Requerimientos computacionales: el número de iteraciones necesarias viene determinado por el valor k escogido. Por lo general, se recomienda un k entre 5 y 10. LOOCV es un caso particular de K-Fold Cross-Validation en el que k = nº observaciones, si el data set es muy grande o el modelo muy complejo, se requiere muchas más iteraciones.

  • Balance entre bias y varianza: la principal ventaja de K-fold CV es que consigue una estimación precisa del error de test gracias a un mejor balance entre bias y varianza. LOOCV emplea n-1 observaciones para entrenar el modelo, lo que es prácticamente todo el set de datos disponible, maximizando así el ajuste del modelo a los datos disponibles y reduciendo el bias. Sin embargo, para la estimación final del error se promedian las estimaciones de n modelos entrenados con prácticamente los mismos datos (solo hay un dato de diferencia entre cada conjunto de entrenamiento), por lo que están altamente correlacionados. Esto se traduce en un mayor riesgo de overfitting y por lo tanto de varianza. En el método K-fold CV los k grupos empleados como entrenamiento son mucho menos solapantes, lo que se traduce en menor varianza al promediar las estimaciones de error.

Aunque emplea menos observaciones como entrenamiento que LOOCV, son un número suficiente como para no tener un bias excesivo, por lo que el método K-fold CV con valores de k= [5, 10] consigue un mejor balance final.

Repeated k-Fold-Cross-Validation


Es exactamente igual al método k-Fold-Cross-Validation pero repitiendo el proceso completo n veces. Por ejemplo, 10-Fold-Cross-Validation con 5 repeticiones implica a un total de 50 iteraciones ajuste-validación, pero no equivale a un 50-Fold-Cross-Validation.

Bootstrapping


Una muestra bootstrap es una muestra obtenida a partir de la muestra original por muestreo aleatorio con reposición, y del mismo tamaño que la muestra original. Muestreo aleatorio con reposición (resampling with replacement) significa que, después de que una observación sea extraída, se vuelve a poner a disposición para las siguientes extracciones. Como resultado de este tipo de muestreo, algunas observaciones aparecerán múltiples veces en la muestra bootstrap y otras ninguna. Las observaciones no seleccionadas reciben el nombre de out-of-bag (OOB). Por cada iteración de bootstrapping se genera una nueva muestra bootstrap, se ajusta el modelo con ella y se evalúa con las observaciones out-of-bag.


  1. Obtener una nueva muestra del mismo tamaño que la muestra original mediante muestro aleatorio con reposición.

  2. Ajustar el modelo empleando la nueva muestra generada en el paso 1.

  3. Calcular el error del modelo empleando aquellas observaciones de la muestra original que no se han incluido en la nueva muestra. A este error se le conoce como error de validación.

  4. Repetir el proceso n veces y calcular la media de los n errores de validación.

  5. Finalmente, y tras las n repeticiones, se ajusta el modelo final empleando todas las observaciones de entrenamiento originales.


La naturaleza del proceso de bootstrapping genera cierto bias en las estimaciones que puede ser problemático cuando el conjunto de entrenamiento es pequeño. Existen ciertas modificaciones del algoritmo original para corregir este problema, algunos de ellos son: 632 method y 632+ method. Para más información sobre bootstrapping consultar Resampling: Test de permutación, Simulación de Monte Carlo y Bootstrapping.

Comparación



No existe un método de validación que supere al resto en todos los escenarios, la elección debe basarse en varios factores.

  • Si el tamaño de la muestra es pequeño, se recomienda emplear repeated k-Fold-Cross-Validation, ya que consigue un buen equilibrio bias-varianza y, dado que no son muchas observaciones, el coste computacional no es excesivo.

  • Si el objetivo principal es comparar modelos mas que obtener una estimación precisa de las métricas, se recomienda bootstrapping ya que tiene menos varianza.

  • Si el tamaño muestral es muy grande, la diferencia entre métodos se reduce y toma más importancia la eficiencia computacional. En estos casos, 10-Fold-Cross-Validation simple es suficiente.

Puede encontrarse un estudio comparativo de los diferentes métodos en Comparing Different Species of Cross-Validation.

Ejemplos


En los siguientes apartados se muestran ejemplos sencillos de como aplicar las estrategias de validación. Sin embargo, para proyectos de modelado estadístico y machine learning, es conveniente utilizar librerías pensadas con este fin, ya que automatizan en gran medida los procesos de validación. Algunas de las principales en R son: tidymodels, mlr3 y caret.

Validación simple


En el siguiente ejemplo se emplea el dataset Auto del paquete ISLR para mostrar el método de validación simple consistente en dividir las observaciones de forma aleatoria en dos grupos, uno de ellos se emplea como set de entrenamiento y otro como set de validación. Se pretende generar un modelo que permita predecir el consumo de un vehículo (mpg) a partir de la potencia del motor (horse power), estimar su test error rate y el grado de flexibilidad (polinomio) más adecuado.

library(ISLR)
data(Auto)
head(Auto, n = 3)
dim(Auto)
## [1] 392   9

En primer lugar se separan aleatoriamente las observaciones en dos grupos de 196 observaciones (mitad de las observaciones para cada set).

# Se seleccionan 196 índices aleatorios que formarán el training set. 
set.seed(1)
train <- sample(x = 1:392, 196)

Se genera un modelo lineal que relacione el consumo (mpg) con la potencia (horsepower) empleando únicamente los datos de training. Se recurre a la función lm() para generar el modelo y el argumento subset para identificar las observaciones que se tienen que emplear como entrenamiento.

modelo <- lm(mpg~horsepower, data = Auto, subset = train)
summary(modelo)
## 
## Call:
## lm(formula = mpg ~ horsepower, data = Auto, subset = train)
## 
## Residuals:
##     Min      1Q  Median      3Q     Max 
## -9.3177 -3.5428 -0.5591  2.3910 14.6836 
## 
## Coefficients:
##              Estimate Std. Error t value Pr(>|t|)    
## (Intercept) 41.283548   1.044352   39.53   <2e-16 ***
## horsepower  -0.169659   0.009556  -17.75   <2e-16 ***
## ---
## Signif. codes:  0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
## 
## Residual standard error: 5.032 on 194 degrees of freedom
## Multiple R-squared:  0.619,  Adjusted R-squared:  0.6171 
## F-statistic: 315.2 on 1 and 194 DF,  p-value: < 2.2e-16

Una vez generado el modelo se emplea la función predict() para estimar el consumo de las 196 observaciones restantes no empleadas como entrenamiento.

predicciones <- predict(object = modelo, newdata = Auto[-train, ])

Dado que se conoce el valor real de consumo de los 196 coches empleados como test, se puede estimar el error de predicción del modelo. Al tratarse de una variable continua se emplea como medida de error el MSE (mean square error).

error <- mean((Auto$mpg[-train] - predicciones)^2)
error
## [1] 23.26601

La estimación de test error rate del modelo de regresión lineal creado es de .

Como se ha comentado previamente, una de las principales desventajas de la validación simple es que la estimación de error varía mucho dependiendo de cómo se hayan repartido los datos entre conjunto de entrenamiento y el conjunto de validación. A continuación, se calcula el error para 100 repeticiones de la validación, repartiendo en cada una de ellas los datos de forma aleatoria.

library(ggplot2)
library(gridExtra)

cv_MSE <- rep(NA,100)
for (i in 1:100) {
  train <- sample(x = 1:392, 196)
  modelo <- lm(mpg ~ horsepower, data = Auto, subset = train)
  predicciones <- predict(object = modelo,newdata = Auto[-train, ])
  cv_MSE[i] <- mean((Auto$mpg[-train] - predicciones)^2)
}

p1 <- ggplot(data = data.frame(cv_MSE = cv_MSE), aes(x = 1, y = cv_MSE)) +
      geom_boxplot(outlier.shape = NA) +
      geom_jitter(colour = c("firebrick3"), width = 0.1) +
      coord_flip() +
      labs(title = "Distribución del error de validación simple") +
      theme_bw() +
      theme(axis.title.x = element_blank(),
            axis.text.x = element_blank(),
            axis.ticks.x = element_blank())

p2 <- ggplot(data = data.frame(cv_MSE = cv_MSE), aes(cv_MSE)) +
      geom_histogram(colour = "firebrick3") +
      theme_bw()

grid.arrange(p1, p2, ncol = 1)

Se observa que la estimación del error de test oscila entre 20.74 y 29.07, con media de 24.45 y sd 1.84.

La representación gráfica de las observaciones y del modelo, muestra que la relación entre las variables mpg y horsepower no es del todo lineal.

ggplot(data = Auto, aes(x = horsepower, y = mpg)) +
geom_point(colour = c("firebrick3")) +
geom_smooth(method = "lm", colour = "black") +
theme_bw() +
labs(title  =  'mpg ~ horsepower') +
theme(plot.title = element_text(hjust = 0.5, face = 'bold'))

Introduciendo flexibilidad en el modelo mediante polinomios de mayor grado podría reducir el error de predicción. La validación cruzada nos permite identificar con qué grado de polinomio se consigue el mejor modelo (menor test cv-MSE).

A continuación, se ajustan 10 modelos distintos empleando polinomios de grado 1 hasta 10 y se registra el cv-MSE para cada uno. La función poly() permite determinar el grado de polinomio empleado.

cv_MSE <- rep(NA,10)
set.seed(1)
train <- sample(x = 1:392, 196)

for (i in 1:10) {
  modelo <- lm(mpg ~ poly(horsepower,i), data = Auto, subset = train)
  predicciones <- predict(object = modelo, newdata = Auto[-train, ])
  cv_MSE[i] <- mean((Auto$mpg[-train] - predicciones)^2)
}

ggplot(data = data.frame(polinomio = 1:10, cv_MSE = cv_MSE),
       aes(x = polinomio, y = cv_MSE)) +
geom_point(colour = c("firebrick3")) +
geom_path() +
scale_x_continuous(breaks = c(0:10)) +
theme_bw() + 
labs(title  =  'Test Error ~ Grado del polinomio') +
theme(plot.title = element_text(hjust = 0.5, face = 'bold'))

Mediante validación cruzada simple se identifica que un modelo que emplee una función cuadrática o cúbica de horsepower minimiza el test error, es decir, captura mejor la relación existente entre las variables y por lo tanto tiene mayor precisión en sus predicciones. Dada la mínima mejoría que consiguen los polinomios de grado 3 o superior, siguiendo el principio de parsimonia, el de grado 2 es el más adecuado.

ggplot(data = Auto, aes(x = horsepower, y = mpg)) + geom_point(colour = c("firebrick3")) + 
    stat_smooth(method = "lm", formula = y ~ poly(x, 2), color = "black") + theme_bw() + 
    labs(title = "mpg ~ horsepower^2") + theme(plot.title = element_text(hjust = 0.5, 
    face = "bold"))



LOOCV


En R, se puede realizar LOOCV de cualquier generalized linear model creado mediante glm() empleando la función cv.glm() del paquete boot. La función glm() engloba diferentes tipos de modelos lineales que se especifican mediante el argumento family. En el caso de regresión lineal, no se especifica ningún tipo de familia y es equivalente a emplear la función lm(). Para regresión logística se emplea family=binomial.

La función cv.glm() calcula el error de predicción mediante cross-validation. Si el argumento k no se especifica, se le asigna automáticamente el número de observaciones empleadas para crear el modelo, lo que equivale a leave-one-out cross-validation. La función devuelve una lista con múltiples componentes, el resultado de la validación se almacena dentro del vector delta que contiene la estimación de error con y sin corrección.

# Se genera el modelo lineal, dado que se va a emplear LOOCV no es necesario
# dividir las observaciones en dos grupos
modelo <- glm(mpg~horsepower, data = Auto)

# Se emplea la función cv.glm() para la validación LOOCV
library(boot)
cv_error <- cv.glm(data = Auto,glmfit = modelo)
cv_error$delta
## [1] 24.23151 24.23114

La estimación de error del modelo lineal mediante LOOCV es de 24.23. Se ha obtenido un valor cercano al estimado mediante validación simple pero eliminando el problema de la variabilidad. Al igual que se ha hecho en el ejemplo anterior, se puede emplear LOOCV para identificar el grado de flexibilidad que permite obtener el mejor modelo.

cv_MSE <- rep(NA,10)
for (i in 1:10) {
  modelo <- glm(mpg ~ poly(horsepower, i), data = Auto)
  cv_MSE[i] <- cv.glm(data = Auto, glmfit = modelo)$delta[1]
}

ggplot(data = data.frame(polinomio = 1:10, cv_MSE = cv_MSE),
       aes(x = polinomio, y = cv_MSE)) +
geom_point(colour = c("firebrick3")) +
geom_path() +
scale_x_continuous(breaks = c(0:10)) +
theme_bw() + 
labs(title  =  'Test Error ~ Grado del polinomio') +
theme(plot.title = element_text(hjust = 0.5, face = 'bold'))

Los resultados de la validación indican que polinomios superiores a grado 2 o 3 no aportan una mejora sustancial.



K-fold Cross-Validation


La función cv.glm() del paquete boot puede emplearse para realizar K-fold Cross-Validation de cualquier modelo creado mediante la función glm(), especificando el número de grupos en el argumento k.

# Se genera el modelo lineal
modelo <- glm(mpg ~ horsepower, data = Auto)

# Se emplea la función cv.glm() para la validación, empleando en este caso k=10
set.seed(1)
cv_error <- cv.glm(data = Auto, glmfit = modelo, K = 10)
cv_error$delta
## [1] 24.21538 24.20081

La estimación de error del modelo lineal mediante K-fold Cross-Validation k=10 es de 24.1, un valor muy próximo al estimado mediante LOOCV, pero empleando mucho menos tiempo de computación. Cuando se especifica el número de grupos K a emplear en la validación, la función cv.glm() devuelve dos resultados, uno con corrección de continuidad y otro sin ella.

Si se estudia la influencia de la flexibilidad del modelo, los resultados muestran la misma tendencia que LOOCV. Se obtiene una mejora importante empleando una función cuadrática o cubica en comparación al modelo lineal. Para polinomios de grado > 3 la mejora es mínima.

cv_MSE_k10 <- rep(NA,10)

for (i in 1:10) {
  modelo <- glm(mpg ~ poly(horsepower, i), data = Auto)
  set.seed(17)
  cv_MSE_k10[i] <- cv.glm(data = Auto, glmfit = modelo, K = 10)$delta[1]
}
ggplot(data = data.frame(polinomio = 1:10, cv_MSE = cv_MSE_k10),
       aes(x = polinomio, y = cv_MSE)) +
geom_point(colour = c("firebrick3")) +
geom_path() +
scale_x_continuous(breaks = c(0:10)) +
theme_bw() + 
labs(title  =  'Test Error ~ Grado del polinomio') +
theme(plot.title = element_text(hjust = 0.5, face = 'bold'))



Bootstrapping


En el siguiente ejemplo se emplea el método bootstrap para estudiar la variabilidad en la estimación de los coeficientes de regresión \(\beta_0\) y \(\beta_1\) (intersección y pendiente) de un modelo de regresión lineal. El modelo en cuestión emplea la variable horsepower para predecir el consumo de un vehículo mpg, los datos empleados son del data set Auto.

En el caso de regresión lineal, es posible estimar la variabilidad de los coeficientes de regresión mediante fórmulas matemáticas (esta es la forma en la que lo calcula R por defecto y que se puede ver mediante summary(modelo)). Sin embargo, para que sean válidos los resultados, los residuos se tienen que distribuir de forma normal. La estimación mediante bootstrap no requiere de ninguna condición por lo que suele ser más precisa.

El proceso de bootstraping consiste en generar de forma iterativa diferentes modelos lineales, empleando en cada caso una bootstrap-sample creada mediante resampling del mismo tamaño que la muestra inicial. Para cada modelo ajustado se registran los valores de los coeficientes \(\beta_0\) y \(\beta_1\) y finalmente se estudia su distribución.

# Se define la función que devuelve el estadístico de interés, los coeficientes
# de regresión
fun_coeficientes <- function(data, index){
    return(coef(lm(mpg ~ horsepower, data = data, subset = index)))
}

# Se implementa un bucle que genere los modelos de forma iterativa y almacene
# los coeficientes. El data frame Auto tiene 392 observaciones
beta_0 <- rep(NA,9999)
beta_1 <- rep(NA,9999)
for(i in 1:9999) {
    coeficientes <- fun_coeficientes(data = Auto,
                                     index = sample(1:392, 392, replace = TRUE))
    beta_0[i] <- coeficientes[1]
    beta_1[i] <- coeficientes[2]
}
# Se muestra la distribución de los coeficientes
p5 <- ggplot(data = data.frame(beta_0 = beta_0), aes(beta_0)) +
      geom_histogram(colour = "firebrick3") + 
      theme_bw()
p6 <- ggplot(data = data.frame(beta_1 = beta_1), aes(beta_1)) +
      geom_histogram(colour = "firebrick3") +
      theme_bw()

grid.arrange(p5,p6, ncol = 2,
             top = "Bootstrap distribution de los coeficientes")

El valor del estadístico estimado mediante bootstrapping es la media de la bootstrap-distribution y la desviación estándar del estadístico la desviación estándar de la distribución:

  • \(\hat{\beta_0} =\) mean(beta_0) =39.9650243
  • \(SE(\hat{\beta_0}) =\) sd(beta_0) = 0.8687025
  • \(\hat{\beta_1} =\) mean(beta_1) = -0.1581849
  • \(SE(\hat{\beta_1}) =\) sd(beta_1) = 0.0074936

Para este caso, si se comparan los valores estimados mediante bootstrapping con los devueltos por la función lm() (obtenidos por t-test), se observa una diferencia muy pequeña.

summary(lm(mpg~horsepower,data = Auto))$coef
##               Estimate  Std. Error   t value      Pr(>|t|)
## (Intercept) 39.9358610 0.717498656  55.65984 1.220362e-187
## horsepower  -0.1578447 0.006445501 -24.48914  7.031989e-81

El mismo resultado se puede obtener empleando la función boot().

boot(data = Auto, statistic = fun_coeficientes, R = 9999)
## 
## ORDINARY NONPARAMETRIC BOOTSTRAP
## 
## 
## Call:
## boot(data = Auto, statistic = fun_coeficientes, R = 9999)
## 
## 
## Bootstrap Statistics :
##       original        bias    std. error
## t1* 39.9358610  0.0379396334 0.866345788
## t2* -0.1578447 -0.0004041619 0.007507923



Información sesión


sesion_info <- devtools::session_info()
dplyr::select(
  tibble::as_tibble(sesion_info$packages),
  c(package, loadedversion, source)
)



Bibliografía


Applied Predictive Modeling by Max Kuhn and Kjell Johnson

An Introduction to Statistical Learning: with Applications in R (Springer Texts in Statistics)

Linear Models with R by Julian J.Faraway

Points of Significance: Sampling distributions and the bootstrap by Anthony Kulesa, Martin Krzywinski, Paul Blainey & Naomi Altman



¿Cómo citar este documento?

Validación de modelos predictivos: Cross-validation, OneLeaveOut, Bootstraping por Joaquín Amat Rodrigo, disponible con licencia CC BY-NC-SA 4.0 en https://www.cienciadedatos.net/documentos/30_cross-validation_oneleaveout_bootstrap


¿Te ha gustado el artículo? Tu ayuda es importante

Mantener un sitio web tiene unos costes elevados, tu contribución me ayudará a seguir generando contenido divulgativo gratuito. ¡Muchísimas gracias! 😊


Creative Commons Licence
Este material, creado por Joaquín Amat Rodrigo, tiene licencia Attribution-NonCommercial-ShareAlike 4.0 International.