La régression aléatoire de la forêt ne prévoit pas plus que les données d'entraînement

12

J'ai remarqué que lors de la construction de modèles de régression aléatoire des forêts, au moins dans R, la valeur prédite ne dépasse jamais la valeur maximale de la variable cible vue dans les données d'apprentissage. À titre d'exemple, consultez le code ci-dessous. Je construis un modèle de régression à prévoir en mpgfonction des mtcarsdonnées. Je construis des modèles OLS et des forêts aléatoires, et je les utilise pour prédire mpgune voiture hypothétique qui devrait avoir une très bonne économie de carburant. L'OLS prédit une mpgforêt élevée , comme prévu, mais pas une forêt aléatoire. J'ai également remarqué cela dans des modèles plus complexes. Pourquoi est-ce?

> library(datasets)
> library(randomForest)
> 
> data(mtcars)
> max(mtcars$mpg)
[1] 33.9
> 
> set.seed(2)
> fit1 <- lm(mpg~., data=mtcars) #OLS fit
> fit2 <- randomForest(mpg~., data=mtcars) #random forest fit
> 
> #Hypothetical car that should have very high mpg
> hypCar <- data.frame(cyl=4, disp=50, hp=40, drat=5.5, wt=1, qsec=24, vs=1, am=1, gear=4, carb=1)
> 
> predict(fit1, hypCar) #OLS predicts higher mpg than max(mtcars$mpg)
      1 
37.2441 
> predict(fit2, hypCar) #RF does not predict higher mpg than max(mtcars$mpg)
       1 
30.78899 
Gaurav Bansal
la source
Est-il courant que les gens appellent les régressions linéaires OLS? J'ai toujours pensé à l'OLS comme méthode.
Hao Ye
1
Je crois que l'OLS est la méthode par défaut de régression linéaire, au moins dans R.
Gaurav Bansal
Pour les arbres / forêts aléatoires, les prévisions sont la moyenne des données d'apprentissage dans le nœud correspondant. Il ne peut donc pas être supérieur aux valeurs des données d'entraînement.
Jason
1
Je suis d'accord mais au moins trois autres utilisateurs y ont répondu.
HelloWorld

Réponses:

12

Comme cela a déjà été mentionné dans les réponses précédentes, la forêt aléatoire pour les arbres de régression / régression ne produit pas de prédictions attendues pour les points de données au-delà de la portée de la plage de données de formation car ils ne peuvent pas (bien) extrapoler. Un arbre de régression se compose d'une hiérarchie de nœuds, où chaque nœud spécifie un test à effectuer sur une valeur d'attribut et chaque nœud feuille (terminal) spécifie une règle pour calculer une sortie prédite. Dans votre cas, l'observation des tests circule à travers les arbres jusqu'aux nœuds foliaires indiquant, par exemple, "si x> 335, alors y = 15", qui sont ensuite moyennés par forêt aléatoire.

Voici un script R visualisant la situation avec une forêt aléatoire et une régression linéaire. Dans le cas d'une forêt aléatoire, les prédictions sont constantes pour tester des points de données qui sont soit inférieurs à la valeur x des données d'entraînement les plus faibles, soit supérieurs à la valeur x des données d'entraînement les plus élevées.

library(datasets)
library(randomForest)
library(ggplot2)
library(ggthemes)

# Import mtcars (Motor Trend Car Road Tests) dataset
data(mtcars)

# Define training data
train_data = data.frame(
    x = mtcars$hp,  # Gross horsepower
    y = mtcars$qsec)  # 1/4 mile time

# Train random forest model for regression
random_forest <- randomForest(x = matrix(train_data$x),
                              y = matrix(train_data$y), ntree = 20)
# Train linear regression model using ordinary least squares (OLS) estimator
linear_regr <- lm(y ~ x, train_data)

# Create testing data
test_data = data.frame(x = seq(0, 400))

# Predict targets for testing data points
test_data$y_predicted_rf <- predict(random_forest, matrix(test_data$x)) 
test_data$y_predicted_linreg <- predict(linear_regr, test_data)

# Visualize
ggplot2::ggplot() + 
    # Training data points
    ggplot2::geom_point(data = train_data, size = 2,
                        ggplot2::aes(x = x, y = y, color = "Training data")) +
    # Random forest predictions
    ggplot2::geom_line(data = test_data, size = 2, alpha = 0.7,
                       ggplot2::aes(x = x, y = y_predicted_rf,
                                    color = "Predicted with random forest")) +
    # Linear regression predictions
    ggplot2::geom_line(data = test_data, size = 2, alpha = 0.7,
                       ggplot2::aes(x = x, y = y_predicted_linreg,
                                    color = "Predicted with linear regression")) +
    # Hide legend title, change legend location and add axis labels
    ggplot2::theme(legend.title = element_blank(),
                   legend.position = "bottom") + labs(y = "1/4 mile time",
                                                      x = "Gross horsepower") +
    ggthemes::scale_colour_colorblind()

Extrapolation avec forêt aléatoire et régression linéaire

tuomastik
la source
16

Il n'y a aucun moyen d'extrapoler une forêt aléatoire comme le fait un OLS. La raison est simple: les prédictions d'une forêt aléatoire sont faites en faisant la moyenne des résultats obtenus dans plusieurs arbres. Les arbres eux-mêmes produisent la valeur moyenne des échantillons dans chaque nœud terminal, les feuilles. Il est impossible que le résultat soit en dehors de la plage des données d'entraînement, car la moyenne est toujours dans la plage de ses constituants.

En d'autres termes, il est impossible qu'une moyenne soit supérieure (ou inférieure) à chaque échantillon, et les régressions des forêts aléatoires sont basées sur la moyenne.

Pyromane
la source
11

Les arbres de décision / forêts aléatoires ne peuvent pas extrapoler en dehors des données de formation. Et bien que l'OLS puisse le faire, ces prévisions doivent être examinées avec prudence; car le schéma identifié peut ne pas continuer en dehors de la plage observée.

B.Frost
la source