Gráficos ICE para interpretar modelos predictivos


Versión PDF: Github

Introducción


Los gráficos Individual Conditional Expectation (ICE) muestran la variación de las predicciones de un modelo de machine learning en función del valor que toma alguno de sus predictores. Además de ser muy útiles para entender la relación entre la variable respuesta y los predictores aprendida por el modelo, permiten diferenciar cuándo, dicha relación, es aditiva o está afectada por interacciones con otros predictores. También permiten entender cómo se comporta un modelo cuando se extrapola a regiones para las que no se dispone de observaciones.

Los gráficos ICE pueden considerarse una extensión de los gráficos de dependencia parcial Partial Dependence Plots (PDP). La diferencia entre ambos reside en que, los PDP, muestran, con una única curva, cómo varía en promedio la predicción de la variable respuesta a medida que se modifica uno de los predictores. Los ICE, muestran cómo varía la predicción para cada una de las observaciones (una curva distinta por cada observación).

A lo largo de este documento se muestran ejemplos de cómo se pueden obtener gráficos ICE y de qué información se puede extraer de ellos.

Paquete ICEbox


El paquete ICEbox contiene funciones que permiten calcular, explorar y representar gráficos ICE para cualquier tipo de modelo predictivo. A continuación, se muestra un ejemplo introductorio de cómo utilizarlo.

Curvas ICE


El set de datos Boston contiene información sobre la mediana del precio de las viviendas de la ciudad de Boston junto con variables relacionadas con las características de las casas y la zona donde se encuentran.

library(MASS)
datos <- Boston
head(datos)

Se entrena un modelo predictivo de tipo Random Forest con el objetivo de predecir el precio de la vivienda (medv) en función de todas las demás variables disponibles.

library(randomForest)
modelo_rf <- randomForest(formula = medv ~ .,
                          data = datos,
                          ntree = 500)

Una vez entrenado el modelo, con la función ice() se obtiene el gráfico ICE de cualquiera los predictores. Los principales argumentos de esta función son:

  • object: modelo del cual se quieren obtener las curvas ICE.

  • X: valor de los predictores con los que se ha entrenado el modelo.

  • y: valor de la variable respuesta con la que se ha entrenado el modelo. Se emplea para identificar el rango del eje y.

  • predictor: nombre o posición del predictor para el que se quiere obtener el gráfico ICE.

  • frac_to_build: fracción de observaciones de entrenamiento que se incluyen en el gráfico ICE. Por defecto se emplean todas (frac_to_build = 1) pero, si el set de datos es muy grande, se recomienda reducirlo. La selección se hace de forma que se incluya aproximadamente todo el rango de valores observado en el entrenamiento.

  • indices_to_build: índices de las observaciones que se incluyen en el gráfico ICE. Es una alternativa no aleatoria a frac_to_build. No pueden emplearse ambos argumentos a la vez.

  • num_grid_pts: número de puntos dentro del rango del predictor empleados para construir la curva ICE. Por defecto, se utilizan todos los valores del predictor observados en los datos de entrenamiento del modelo.

  • predictfcn: función opcional que acepta dos argumentos, un modelo (object) y un conjunto de datos (newdata), y devuelve un vector con las predicciones. Gracias a esta función se pueden obtener los gráficos ICE de cualquier modelo. Si este argumento no se especifica, se intenta encontrar automáticamente la función predict() correspondiente a la clase del modelo pasado a la función ice().

A continuación, se explora influencia que tiene la antigüedad de la vivienda (age) sobre el precio de la vivienda (medv).

library(ICEbox)

# Se separan los predictores de la variable respuesta.
datos_x      <- datos
datos_x$medv <- NULL
datos_y      <- datos$medv

ice_age <- ice(object = modelo_rf,
               X = datos_x,
               y = datos_y,
               predictor = "age",
               frac_to_build = 1,
               verbose = FALSE)
ice_age
## ice object generated on data with n = 506 for predictor "age"
## predictor considered continuous, logodds off

El objeto devuelto por ice() puede graficarse empleando la función plot().

plot(ice_age,
     x_quantile = FALSE,
     plot_pdp = TRUE,
     plot_orig_pts_preds = TRUE,
     main = "ICE plot")

Cada curva del grafico anterior (curva ICE) muestra el valor predicho de la variable respuesta para cada observación con forme se va aumentando el valor de age y manteniendo constantes el resto de predictores en su valor observado. La curva resaltada en amarillo se corresponde con la curva PDP, que es la variación promedio de todas las observaciones. Además, el gráfico incluye puntos que representan el verdadero valor de age de cada observación.

La gran mayoría de las curvas son planas, lo que apunta a que, en la mayor parte de los casos, la antigüedad de la vivienda apenas influye en su precio. Sin embargo, puede apreciarse que, unas pocas observaciones, presentan una ligera tendencia de subida o bajada.

Curvas ICE centradas


Cuando los valores observados de la variable respuesta se acumulan en una región pequeña, el solapamiento de las curvas puede hacer difícil distinguir qué observaciones que se escapan de la tendencia general. Para evitar este problema, se puede recurrir a los gráficos ICE centrados (c-ICE). Los gráficos c-ICE se obtienen igual que los gráficos ICE con la única diferencia de que, a cada una de las curvas, se les resta un valor de referencia, normalmente el valor predicho para el mínimo observado del predictor. De esta forma, se consigue que todas las curvas tengan su origen en el 0.

plot(ice_age,
     x_quantile = FALSE,
     plot_pdp = TRUE,
     plot_orig_pts_preds = TRUE,
     centered = TRUE,
     main = "c-ICE plot")

Con esta nueva representación puede observarse con más claridad que, aunque la mayoría de observaciones se mantienen constantes, algunas tienen un claro patrón divergente (para algunas el precio incrementa con la antigüedad y en otras disminuye). Tal y como se describe más adelante, esto suele ser un indicativo de que el predictor age interacciona con otros predictores. El eje vertical de la izquierda muestra el \(\%\) de desviación respecto al rango de \(y\).

Derivada de las curvas ICE


Si la relación existente entre la variable respuesta y el predictor estudiado es independiente del resto de predictores del modelo, entonces, las curvas del gráfico ICE comparten una misma forma y son paralelas las unas a las otras (la única diferencia es un desplazamiento en el eje vertical). Este comportamiento puede resultar complicado de validar visualmente cuando las curvas se superponen. Una forma de facilitar la identificación de interacciones entre predictores es representando las derivadas parciales de las curvas ICE (d-ICE). Si no existe ninguna interacción, todas las curvas son aproximadamente paralelas, sus derivadas aproximadamente iguales y, por lo tanto, el gráfico de derivadas muestra una única recta. Si existen interacciones, entonces, la representación de las derivadas parciales es heterogénea.

dice_age <- dice(ice_obj = ice_age)
plot(dice_age,
     plot_sd = TRUE,
     plot_orig_pts_deriv = TRUE,
     plot_dpdp = TRUE,
     main = "d-ICE plot")

## NULL

El gráfico sugiere que, cuando la antigüedad de la vivienda es inferior a 60 años, las derivadas parciales son \(\simeq 0\), por lo que no hay interacciones. Superados los 60 años, hay observaciones cuyas derivadas parciales se desvían sustancialmente de 0, indicando que, a partir de este valor, el predictor age interacciona con otros predictores.

En la zona inferior del gráfico se muestra la desviación estándar de las derivadas parciales en cada punto, lo que facilita encontrar regiones de alta heterogeneidad (regiones de interacción).

Colorear curvas ICE


Como se ha visto en los apartados anteriores, algunas observaciones pueden alejarse de la tendencia general del modelo. Una forma de conseguir información extra que permita comprender las razones de estos patrones divergentes es colorear las curvas de cada observación en función de otro factor. Por ejemplo, en el modelo de predicción del valor medio, se crea una nueva variable binaria que indique si el número de habitaciones de la vivienda está por encima o por debajo de la mediana.

# Si la variable no es uno de los predictores originales con los que se entrenó 
# el modelo, hay que añadirla en el objeto $Xice.
mediana_habitaciones <- median(x = ice_age$Xice$rm)
ice_age$Xice$supera_mediana <- ifelse(ice_age$Xice$rm > mediana_habitaciones,
                                      "si", "no")
plot(ice_age,
     x_quantile = FALSE,
     plot_pdp = TRUE,
     plot_orig_pts_preds = TRUE,
     centered = TRUE,
     color_by = "supera_mediana",
     main = "c-ICE plot")
## ICE Plot Color Legend
##  supera_mediana       color
##              no  firebrick3
##              si dodgerblue3

Gracias a los colores puede verse claramente que, para viviendas con un número de habitaciones por encima de la mediana (azul), la antigüedad de la vivienda está asociada positivamente con el precio. Para viviendas con un número de habitaciones inferior a la media, ocurre lo contrario.

Interacción entre predictores


En la introducción de documento, se menciona la diferencia entre los gráficos PDP y los ICE. La ventaja de los gráficos ICE queda patente cuando existe interacción entre predictores o cuando no todas las observaciones siguen una misma tendencia. Véase el siguiente ejemplo ilustrativo.

Se simula un set de datos siguiendo la siguiente ecuación:

\[Y = 0.2 X_1 - 5X_2 + 10 X_2 \mathbf{1}_{X_3 \geq 0} + \epsilon\] o lo que es equivalente

\[Y=\begin{cases} 0.2X_1 - 5X_2 + 10X_2 + \epsilon & \text{ si } X_3 \geq0 \\ 0.2X_1 - 5X_2 \epsilon & \text{ si } X_3 < 0 \end{cases}\]

\[\epsilon \sim N(0,1) \ \ \ \ X_1,X_2,X_3 \sim U(-1,1)\]

library(ggplot2)
set.seed(123)
x1 <- runif(n = 1000, min = -1, max = 1)
x2 <- runif(n = 1000, min = -1, max = 1)
x3 <- runif(n = 1000, min = -1, max = 1)
e  <- rnorm(n = 1000, mean = 0, sd = 1)
y <- 0.2*x1 - 5*x2 + 10*x2*I(x3 >= 0) + e

datos <- data.frame(x1, x2, x3, y)

ggplot(data = datos, aes(x = x2,y = y)) + 
  geom_point() +
  theme_bw()

Se entrena un modelo GBM para predecir \(y\) en función de las 3 variables disponibles.

library(gbm)
set.seed(123)
modelo_gbm <- gbm(formula = y ~ .,
                  data = datos,
                  n.tree = 500,
                  interaction.depth = 3,
                  shrinkage = 0.1,
                  distribution = "gaussian",
                  cv.folds = 5,
                  verbose = FALSE)
# Se separan los predictores de la variable respuesta.
datos_x      <- datos
datos_x$medv <- NULL
datos_y      <- datos$medv

# Aunque existe una función predict.gbm(), a modo ilustrativo, se indica una
# función propia en el argumento predictfcn.
ice_gbm_x3 <- ice(object = modelo_gbm,
              X = datos_x,
              y = datos_y,
              predictor = "x3", 
                    predictfcn = function(object, newdata){
                                   predict.gbm(object = object,
                                               newdata = newdata,
                                               n.trees = 435)
                                 },
                    frac_to_build = 1,
                    verbose = FALSE)
# Se grafican únicamente el 1% de las curvas.
plot(ice_gbm_x3,
     x_quantile = FALSE,
     plot_pdp = TRUE,
     frac_to_plot = 0.1)