Regresión cuantílica: intervalos de predicción con Random Forest Python

Regresión cuantílica: intervalos de predicción con Random Forest Python

Joaquín Amat Rodrigo
Noviembre, 2020

Introducción


La predicción de una variable continua $Y$ en función de uno o varios predictores $X$ es un problema de aprendizaje supervisado que puede resolverse con múltiples métodos de Machine Learning y aprendizaje estadístico. Algunos de ellos, consideran que la relación entre $X$ e $Y$ es únicamente lineal, mientras que otros permiten incorporar relaciones no lineales o incluso interacciones entre predictores. De una forma u otra, todos ellos tratan de inferir la relación entre $X$ e $Y$.

El objetivo de la mayoría de estos algoritmos es predecir el valor promedio de $Y$ en función del valor de $X$, $E(Y|X = x)$. Si bien conocer la media condicional es de utilidad, este resultado ignora otras características de la distribución de $Y$ que pueden ser claves a la hora de tomar decisiones, por ejemplo, su dispersión.

Véase el siguiente ejemplo simulado (y muy simplificado) sobre la evolución del consumo eléctrico de todas las casas de una ciudad en función de la hora del día.

La media del consumo eléctrico es la misma durante todo el día, $\overline{consumo} = 15 Mwh$, sin embargo, su dispersión no es constante (heterocedasticidad). Véase el resultado de predecir el consumo medio en función de la hora del día con un modelo Random Forest.

El valor predicho es muy próximo a la media real, es decir, el modelo es bueno prediciendo el consumo medio esperado. Ahora, imagínese que la compañía encargada de suministrar la electricidad debe de ser capaz de provisionar, en un momento dado, con hasta un 50% de electricidad extra respecto al promedio. Esto significa un máximo de 22.5 Mwh. Estar preparado para suministrar este extra de energía implica gastos de personal y maquinaría, por lo que la compañía se pregunta si es necesario estar preparado para producir tal cantidad durante todo el día, o si, por lo contrario, podría evitarse durante algunas horas, ahorrando así gastos.

Un modelo que predice únicamente el promedio no permite responder a esta pregunta ya que, tanto para las 2h de la mañana como para las 8h, el consumo promedio predicho es en torno a 15 Mwh. Sin embargo, la probabilidad de que se alcancen consumos de 22.5 Mwh a las 2h es prácticamente nula mientras que esto ocurra a las 10h sí es razonable.

Una forma de describir la dispersión de una variable es el uso de cuantiles. El cuantil de orden $\tau$ $(0 < \tau < 1)$ de una distribución es el valor de la variable $X$ que marca un corte tal que, una proporción $\tau$ de valores de la población, es menor o igual que dicho valor. Por ejemplo, el cuantil de orden 0.36 deja un 36% de valores por debajo y el cuantil de orden 0.50 el 50% (se corresponde con la mediana de la distribución).

Dado que los datos se han simulado empleando distribuciones normales, se conoce el valor de los cuantiles teóricos para cada $X$. Se muestra de nuevo el mismo gráfico pero esta vez añadiendo los cuantiles 0.1 y 0.9.

Si como resultado del modelo, además de la predicción de la media, se predice también el valor de los cuantiles, se dispone de una caracterización mayor de la distribución de la variable respuesta, y con ello se puede responder a más preguntas. Por ejemplo, en el caso de la energía, se tendría cierta seguridad al decir que, durante los intervalos de 0h a 5h y de 15h a 17h, es poco probable que se alcancen consumos de 22.5 Mwh.

Otros casos en los que conocer la distribución de cuantiles puede ser útil son:

  • Identificación de regiones en las que la variable respuesta $Y$ tiene mayor dispersión en torno a su media.

  • Entrenar modelos que predicen la mediana (cuantil 0.5) en lugar de la media. Estos modelos son más robustos frente a outliers.

  • Detectar anomalías, identificando aquellas observaciones que están fuera de un determinado intervalo cuantílico.

A lo largo de este documento, se describe el algoritmo Quantile Regression Forest, una adaptación de Random Forest capaz de aprender la distribución de cuantiles y con ello, generar intervalos de predicción.



Quantile Random Forest


Un modelo Random Forest está formado por un conjunto (ensemble) de árboles de decisión individuales, cada uno entrenado con una muestra aleatoria extraída de los datos de entrenamiento originales mediante bootstrapping). En cada árbol individual, las observaciones se van distribuyendo por bifurcaciones (nodos) generando la estructura del árbol hasta alcanzar un nodo terminal. La predicción de una nueva observación se obtiene agregando, normalmente con la media, las predicciones de todos los árboles individuales que forman el modelo. El algoritmo de Quantile Regression Forest propuesto por Meinshausen, 2006 sigue exactamente la misma estrategia para crear el modelo, la diferencia reside en cómo se calculan las predicciones.

En lugar de promediar el valor de las observaciones en los nodos terminales de cada árbol y luego promediar el conjunto de árboles, se almacena el valor individual de todas las observaciones. De esta forma, se obtiene una estimación de la distribución de predicciones. Una vez obtenida esta distribución, se pueden calcular los cuantiles o cualquier otro estadístico.


  1. Entrenar $k$ árboles, al igual que se hace en el algoritmo de random forest, pero almacenando la información de qué observaciones forman parte de cada nodo terminal, no solo el promedio.

  2. Para una nueva observación ($X$), recorrer cada árbol del modelo hasta llegar a un nodo terminal.

  3. Extraer de cada árbol las observaciones que forman parte del mismo nodo terminal que la observación predicha $X$.

  4. Calcular los cuantiles de todos los valores extraídos en el paso 3.


La implementación de random forest de scikitlearn no permite acceder de forma fácil al valor de las observaciones de entrenamiento con las que se forma cada nodo terminal en cada árbol. Sin hacer modificaciones, no se pueden calcular los cuantiles tal y como propone Meinshausen, 2006. Esta característica sí está disponible en la librería skranger, una implementación muy rápida y completa de random forest que, además, es compatible con scikitlearn.

Ejemplo regresión

Librerías

In [33]:
# Librerías
# ==============================================================================
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from scipy.stats import norm
from skranger.ensemble import RangerForestRegressor
from sklearn.model_selection import GridSearchCV
from sklearn.model_selection import ParameterGrid
from sklearn.model_selection import RepeatedKFold
import multiprocessing

Datos


Se simulan datos a partir de una distribución normal con una varianza distinta dependiendo de la hora del día.

In [34]:
# Datos simulados
# ==============================================================================
# Simulación ligeramente modificada del ejemplo publicado en
# XGBoostLSS – An extension of XGBoost to probabilistic forecasting Alexander März

np.random.seed(seed=1234)
n = 3000

# Distribución normal con varianza cambiante
x = np.linspace(start=0, stop= 24, num=n)
y = np.random.normal(
        loc   = 15,
        scale = 1 + 1.5*((4.8 < x) & (x < 7.2)) + 4*((7.2 < x) & (x < 12)) \
                + 1.5*((12 < x) & (x < 14.4)) + 2*(x > 16.8)
    )

# Cálculo del cuantil 0.1 y 0.9 para cada posición de x simulada.
cuantil_10 = norm.ppf(
                q = 0.1,
                loc   = 15,
                scale = 1 + 1.5*((4.8 < x) & (x < 7.2)) + 4*((7.2 < x) & (x < 12)) \
                        + 1.5*((12 < x) & (x < 14.4)) + 2*(x > 16.8)
             )

cuantil_90 = norm.ppf(
                q = 0.9,
                loc   = 15,
                scale = 1 + 1.5*((4.8 < x) & (x < 7.2)) + 4*((7.2 < x) & (x < 12)) \
                        + 1.5*((12 < x) & (x < 14.4)) + 2*(x > 16.8)
             )

Modelo

Se entrena un modelo random forest utilizando validación cruzada para encontrar los hiperparámetros óptimos. Para conocer más detalles sobre el entrenamiento de un modelo random forest visitar Random Forest con Python

In [35]:
# Grid de hiperparámetros evaluados
# ==============================================================================
param_grid = {'n_estimators': [1000, 2000, 5000],
              'min_node_size': [50, 100]
             }

# Búsqueda por grid search con validación cruzada
# ==============================================================================
grid = GridSearchCV(
        estimator  = RangerForestRegressor(seed=123, quantiles=True),
        param_grid = param_grid,
        scoring    = 'neg_root_mean_squared_error',
        n_jobs     = multiprocessing.cpu_count() - 1,
        cv         = RepeatedKFold(n_splits=5, n_repeats=1, random_state=123), 
        refit      = True,
        verbose    = 0,
        return_train_score = True
       )

grid.fit(X = x.reshape(-1, 1), y = y)

# Resultados
# ==============================================================================
resultados = pd.DataFrame(grid.cv_results_)
resultados.filter(regex = '(param.*|mean_t|std_t)') \
    .drop(columns = 'params') \
    .sort_values('mean_test_score', ascending = False) \
    .head(4)
Out[35]:
param_min_node_size param_n_estimators mean_test_score std_test_score mean_train_score std_train_score
3 100 1000 -3.167555 0.164495 -2.734589 0.034723
4 100 2000 -3.167831 0.162977 -2.734691 0.034647
5 100 5000 -3.168181 0.163076 -2.734557 0.034604
0 50 1000 -3.223868 0.169073 -2.604583 0.036241
In [36]:
# Mejores hiperparámetros por validación cruzada
# ==============================================================================
print("----------------------------------------")
print("Mejores hiperparámetros encontrados (cv)")
print("----------------------------------------")
print(grid.best_params_, ":", grid.best_score_, grid.scoring)

modelo_final = grid.best_estimator_
----------------------------------------
Mejores hiperparámetros encontrados (cv)
----------------------------------------
{'min_node_size': 100, 'n_estimators': 1000} : -3.1675551112349245 neg_root_mean_squared_error

Predicción

In [37]:
# Se predice todo el rango de X para representar los cuantiles y la predicción media
# ==============================================================================
pred_cuantiles = modelo_final.predict_quantiles(X=x.reshape(-1, 1), quantiles=[0.1, 0.9])
pred_cuantiles = pd.DataFrame(pred_cuantiles.transpose(), columns=['q_01', 'q_90'])
pred_media     = modelo_final.predict(X=x.reshape(-1, 1))
In [38]:
# Gráfico
# ==============================================================================
fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(8, 4))
ax.scatter(x, y, alpha = 0.2, c = "#333")
ax.plot(x, cuantil_10, c = "black")
ax.plot(x, cuantil_90, c = "black")
ax.plot(x, pred_cuantiles.q_01, c = "blue", label='intevalo cuantilico predicho')
ax.plot(x, pred_cuantiles.q_90, c = "blue")
ax.fill_between(x, cuantil_10, cuantil_90, alpha=0.2, color='red',
                label='intevalo cuantilico real 0.1-0.9')
ax.set_xticks(range(0,25))
ax.set_title('Evolución del consumo eléctrico a lo largo del día',
             fontdict={'fontsize':15})
ax.set_xlabel('Hora del día')
ax.set_ylabel('Consumo eléctrico (MWh)')
plt.legend();

Covertura


Una de las métricas empleadas para evaluar intervalos es la cobertura (coverage). Su valor se corresponde con el porcentaje de observaciones que caen dentro del intervalo estimado. Idealmente, la cobertura debe de ser igual a la confianza del intervalo, por lo que, en la práctica, cuanto más próximo sea su valor, mejor.

En este ejemplo, se han calculado los cuantiles 0.1 y 0.9, así que, el intervalo, tiene una confianza del 80%. Si el intervalo predicho es correcto, aproximadamente el 80% de las observaciones estarán dentro.

In [39]:
# Cobertura del intervalo predicho
# ==============================================================================
dentro_intervalo = np.where((y >= cuantil_10) & (y <= cuantil_90), True, False)
cobertura = dentro_intervalo.mean()
In [40]:
print(f"Cobertura del intervalo predicho: {100 * cobertura}")
Cobertura del intervalo predicho: 80.80000000000001

Detección de anomalías (Outliers)


Conocer los cuantiles condicionales de la variable respuesta permite identificar observaciones que se alejan atípicamente por encima o por debajo del valor esperado, para un determinado valor de los predictores. Véase el siguiente ejemplo en el que se trata de identificar precios anómalos de diamantes.

Librerías

In [41]:
# Librerías
# ==============================================================================
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from scipy.stats import norm
import seaborn as sns
from skranger.ensemble import RangerForestRegressor
from sklearn.model_selection import GridSearchCV
from sklearn.model_selection import ParameterGrid
from sklearn.model_selection import RepeatedKFold
from sklearn.preprocessing import OneHotEncoder
import multiprocessing

Datos

In [42]:
# Datos
# ==============================================================================
datos = sns.load_dataset("diamonds")
datos = datos[datos.cut.isin(["Fair", "Good"])]
datos['cut'] = datos['cut'].cat.remove_unused_categories()
datos['price'] = np.sqrt(datos['price'])
#datos = datos[['price', 'cut', 'color', 'carat']]
datos.info()
<class 'pandas.core.frame.DataFrame'>
Int64Index: 6516 entries, 2 to 53936
Data columns (total 10 columns):
 #   Column   Non-Null Count  Dtype   
---  ------   --------------  -----   
 0   carat    6516 non-null   float64 
 1   cut      6516 non-null   category
 2   color    6516 non-null   category
 3   clarity  6516 non-null   category
 4   depth    6516 non-null   float64 
 5   table    6516 non-null   float64 
 6   price    6516 non-null   float64 
 7   x        6516 non-null   float64 
 8   y        6516 non-null   float64 
 9   z        6516 non-null   float64 
dtypes: category(3), float64(7)
memory usage: 427.2 KB

El siguiente gráfico muestra la distribución del precio de los diamantes en función de su peso, calidad del corte y color.

In [43]:
fig, ax = plt.subplots(nrows=1, ncols=2, figsize=(9, 4), sharex=True,sharey=True)

for i, cut in enumerate(set(datos['cut'])):
    sns.scatterplot(
        x         = "carat",
        y         = "price",
        hue       = "color",
        palette   = "viridis",
        linewidth = 0,
        alpha     = 0.5,
        data      = datos[datos.cut==cut],
        ax        = ax[i]
    )
    ax[i].set_title(cut, fontsize  = 'x-large')
    ax[i].axvline(x=1, linestyle   ='--',  color='black')
    ax[i].axvline(x=1.5, linestyle ='--',  color='black')
    ax[i].axhline(y=75, linestyle  ='--',  color='black')
    
fig.tight_layout()
plt.subplots_adjust(top = 0.85)
fig.suptitle('Distribución del precio de los diamantes', fontsize = 12);

Puede verse que, dependiendo del peso, calidad del corte y color, el precio varía notablemente. Por ejemplo, apenas hay ningún diamante con un peso entre 1 y 1.5 unidades, de color J que tenga un precio de más de 7000$, pero sí los hay de este peso y precio con otros colores.

Se añaden varias anomalías simuladas en cada uno de los grupos.

In [44]:
# Anomalías simuladas
anomalias_1 = datos.sample(frac=0.005, random_state=1234).copy()
anomalias_1['price'] = anomalias_1['price']*2.5

anomalias_2 = datos.sample(frac=0.005, random_state=1234).copy()
anomalias_2['price'] = anomalias_2['price']/2.5
In [45]:
fig, ax = plt.subplots(nrows=1, ncols=2, figsize=(9, 4), sharex=True,sharey=True)

for i, cut in enumerate(set(datos['cut'])):
    sns.scatterplot(
        x         = "carat",
        y         = "price",
        hue       = "color",
        palette   = "viridis",
        linewidth = 0,
        alpha     = 0.5,
        data      = datos[datos.cut==cut],
        legend    = False,
        ax        = ax[i]
    )
    sns.scatterplot(
        x         = "carat",
        y         = "price",
        color     = "red",
        linewidth = 0,
        alpha     = 1,
        data      = anomalias_1[anomalias_1.cut==cut],
        ax        = ax[i]
    )
    
    sns.scatterplot(
        x         = "carat",
        y         = "price",
        color     = "red",
        linewidth = 0,
        alpha     = 1,
        data      = anomalias_2[anomalias_2.cut==cut],
        label     = 'anomalía',
        ax        = ax[i]
    )
    ax[i].set_title(cut, fontsize = 'x-large')
    
fig.tight_layout()
plt.subplots_adjust(top = 0.85)
fig.suptitle('Distribución del precio de los diamantes', fontsize = 12);
In [46]:
# Se añaden anomalías al set de datos
datos["anomalia"]       = False
anomalias_1["anomalia"] = True
anomalias_2["anomalia"] = True
datos = pd.concat([datos, anomalias_1, anomalias_2]).reset_index(drop=True)
datos.head()
Out[46]:
carat cut color clarity depth table price x y z anomalia
0 0.23 Good E VS1 56.9 65.0 18.083141 4.05 4.07 2.31 False
1 0.31 Good J SI2 63.3 58.0 18.303005 4.34 4.35 2.75 False
2 0.22 Fair E VS2 65.1 61.0 18.357560 3.87 3.78 2.49 False
3 0.30 Good J SI1 64.0 55.0 18.411953 4.25 4.28 2.73 False
4 0.30 Good J SI1 63.4 54.0 18.734994 4.23 4.29 2.70 False

Modelo

In [47]:
from sklearn.compose import ColumnTransformer
from sklearn.preprocessing import OneHotEncoder
from sklearn.compose import make_column_selector

# Se hace one-hot-encoding de las columnas cualitativas. Para mantener las
# columnas a las que no se les aplica ninguna transformación se tiene que indicar
# remainder='passthrough'.
cat_cols = datos.select_dtypes(include=['object', 'category']).columns.to_list()

preprocessor = ColumnTransformer(
                   [('onehot', OneHotEncoder(handle_unknown='ignore'), cat_cols)],
                   remainder='passthrough'
               )

X_train_prep = preprocessor.fit_transform(
                   datos.drop(columns=['price', 'anomalia'])
               )
In [48]:
modelo = RangerForestRegressor(
    n_estimators  = 5000,
    min_node_size = 300,
    max_depth     = 3,
    quantiles     = True,
)
In [49]:
modelo.fit(
    X=X_train_prep,
    y=datos.price
)
Out[49]:
RangerForestRegressor(max_depth=3, min_node_size=300, n_estimators=5000,
                      quantiles=True, respect_categorical_features='ignore')

Anomalías

Se identifica como anomalías aquellos diamantes con un precio por debajo del percentil del 2% o por encima del percentil 98% predichos por el modelo.

In [50]:
# Se predicen los cuantiles para cada valor observado
# ==============================================================================
pred_cuantiles = modelo.predict_quantiles(X=X_train_prep, quantiles=[0.01, 0.99])
pred_cuantiles = pd.DataFrame(pred_cuantiles.transpose(), columns=['q_01', 'q_99'])

# Se identifican las observaciones cuyo valor real está fuera del intervalo
# ==============================================================================
datos = pd.concat([datos.reset_index(), pred_cuantiles], axis=1)
datos['fuera_intervalo'] = ~datos.price.between(datos.q_01, datos.q_99)
In [51]:
fig, ax = plt.subplots(nrows=1, ncols=2, figsize=(9, 4), sharex=True,sharey=True)

for i, cut in enumerate(set(datos['cut'])):
    sns.scatterplot(
        x         = "carat",
        y         = "price",
        hue       = "color",
        palette   = "viridis",
        linewidth = 0,
        alpha     = 0.5,
        data      = datos[datos.cut==cut],
        legend    = False,
        ax        = ax[i]
    )
    sns.scatterplot(
        x         = "carat",
        y         = "price",
        color     = "red",
        linewidth = 0,
        alpha     = 1,
        data      = anomalias_1[anomalias_1.cut==cut],
        ax        = ax[i]
    )
    
    sns.scatterplot(
        x         = "carat",
        y         = "price",
        color     = "red",
        linewidth = 0,
        alpha     = 1,
        data      = anomalias_2[anomalias_2.cut==cut],
        label     = 'anomalía',
        ax        = ax[i]
    )
    
    sns.scatterplot(
        x         = "carat",
        y         = "price",
        edgecolor = 'black',
        facecolor = 'none',
        s         = 100,
        alpha     = 1,
        data      = datos[(datos.cut==cut) & (datos.fuera_intervalo==True)],
        label     = 'anomalía detectada',
        ax        = ax[i]
    )
    
    ax[i].set_title(cut, fontsize = 'x-large')
    
fig.tight_layout()
plt.subplots_adjust(top = 0.85)
fig.suptitle('Anomálias detectadas', fontsize = 12);
In [52]:
# Anomalías reales vs anomalías detectadas
# ==============================================================================
pd.crosstab(
    index   = datos["anomalia"], 
    columns = datos["fuera_intervalo"],
    margins = False
)
Out[52]:
fuera_intervalo False True
anomalia
False 6496 20
True 25 41

Consideraciones prácticas


Uno de los hiperparámetros de quantile regression forest es el número mínimo de observaciones en los nodos terminales. Cómo el algoritmo únicamente emplea aquellas observaciones para las que $Y≤y$, si son muy pocas las observaciones en los nodos terminales, apenas las habrá que cumplan la condición. Esto se traduce en que quantile regression forest es notablemente sensible a overfitting por valores bajos de este hiperparámetro. Véase cómo afecta pasar de 10 a 200 observaciones mínimas por nodo al predecir los cuantiles.

In [53]:
# min_node_size=10
# ==============================================================================
modelo = RangerForestRegressor(n_estimators=1000, min_node_size=10, quantiles=True)
modelo.fit(X=x.reshape(-1, 1), y=y)
pred_min_node_size_10 = modelo.predict_quantiles(
                            X=x.reshape(-1, 1),
                            quantiles=[0.1, 0.9]
                        )
pred_min_node_size_10 = pd.DataFrame(
                            pred_min_node_size_10.transpose(),
                            columns=['q_01', 'q_90']
                        )
In [54]:
# min_node_size=200
# ==============================================================================
modelo = RangerForestRegressor(n_estimators=1000, min_node_size=200, quantiles=True)
modelo.fit(X=x.reshape(-1, 1), y=y)
pred_min_node_size_200 = modelo.predict_quantiles(
                            X=x.reshape(-1, 1),
                            quantiles=[0.1, 0.9]
                         )
pred_min_node_size_200 = pd.DataFrame(
                            pred_min_node_size_200.transpose(),
                            columns=['q_01', 'q_90']
                         )
In [55]:
fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(8, 4))
ax.scatter(x, y, alpha = 0.1, c = "#333")
ax.plot(x, cuantil_10, c = "black")
ax.plot(x, cuantil_90, c = "black")
ax.plot(x, pred_min_node_size_10.q_01, c = "#cd5700", alpha = 0.7, label = 'min_node_size=10')
ax.plot(x, pred_min_node_size_10.q_90, c = "#cd5700", alpha = 0.7)
ax.plot(x, pred_min_node_size_200.q_01, c = "blue", label = 'min_node_size=200')
ax.plot(x, pred_min_node_size_200.q_90, c = "blue")
ax.set_xticks(range(0,25))
ax.set_title('Evolución del consumo eléctrico a lo largo del día',
             fontdict={'fontsize':15})
ax.set_xlabel('Hora del día')
ax.set_ylabel('Consumo eléctrico (MWh)')
plt.legend();