Forêt aléatoire et prédiction

13

J'essaie de comprendre comment fonctionne Random Forest. J'ai une compréhension de la façon dont les arbres sont construits, mais je ne comprends pas comment Random Forest fait des prédictions sur l'échantillon hors du sac. Quelqu'un pourrait-il me donner une explication simple, s'il vous plaît? :)

user1665355
la source

Réponses:

16

Chaque arbre de la forêt est construit à partir d'un échantillon bootstrap des observations dans vos données d'entraînement. Ces observations dans l'échantillon bootstrap construisent l'arborescence, tandis que celles qui ne sont pas dans l'échantillon bootstrap forment les échantillons hors sac (ou OOB).

Il doit être clair que les mêmes variables sont disponibles pour les cas dans les données utilisées pour construire un arbre que pour les cas dans l'échantillon OOB. Pour obtenir des prédictions pour l'échantillon OOB, chacune est transmise dans l'arborescence actuelle et les règles de l'arborescence sont suivies jusqu'à son arrivée dans un nœud terminal. Cela donne les prévisions OOB pour cet arbre particulier.

Ce processus est répété un grand nombre de fois, chaque arbre étant formé sur un nouvel échantillon bootstrap à partir des données d'apprentissage et des prédictions pour les nouveaux échantillons OOB dérivés.

Au fur et à mesure que le nombre d'arbres augmente, n'importe quel échantillon sera dans les échantillons OOB plus d'une fois, ainsi la "moyenne" des prédictions sur les N arbres où un échantillon est dans l'OOB est utilisée comme prédiction OOB pour chaque échantillon d'apprentissage pour arbres 1, ..., N. Par "moyenne", nous utilisons la moyenne des prédictions pour une réponse continue, ou le vote majoritaire peut être utilisé pour une réponse catégorique (le vote majoritaire est la classe avec le plus de votes sur l'ensemble des arbres 1, ..., N).

Par exemple, supposons que nous avions les prédictions OOB suivantes pour 10 échantillons en formation sur 10 arbres

set.seed(123)
oob.p <- matrix(rpois(100, lambda = 4), ncol = 10)
colnames(oob.p) <- paste0("tree", seq_len(ncol(oob.p)))
rownames(oob.p) <- paste0("samp", seq_len(nrow(oob.p)))
oob.p[sample(length(oob.p), 50)] <- NA
oob.p

> oob.p
       tree1 tree2 tree3 tree4 tree5 tree6 tree7 tree8 tree9 tree10
samp1     NA    NA     7     8     2     1    NA     5     3      2
samp2      6    NA     5     7     3    NA    NA    NA    NA     NA
samp3      3    NA     5    NA    NA    NA     3     5    NA     NA
samp4      6    NA    10     6    NA    NA     3    NA     6     NA
samp5     NA     2    NA    NA     2    NA     6     4    NA     NA
samp6     NA     7    NA     4    NA     2     4     2    NA     NA
samp7     NA    NA    NA     5    NA    NA    NA     3     9      5
samp8      7     1     4    NA    NA     5     6    NA     7     NA
samp9      4    NA    NA     3    NA     7     6     3    NA     NA
samp10     4     8     2     2    NA    NA     4    NA    NA      4

NAsignifie que l'échantillon se trouvait dans les données d'apprentissage de cet arbre (en d'autres termes, il ne faisait pas partie de l'échantillon OOB).

La moyenne des non- NAvaleurs pour chaque ligne donne la prédiction OOB pour chaque échantillon, pour toute la forêt

> rowMeans(oob.p, na.rm = TRUE)
 samp1  samp2  samp3  samp4  samp5  samp6  samp7  samp8  samp9 samp10 
  4.00   5.25   4.00   6.20   3.50   3.80   5.50   5.00   4.60   4.00

Comme chaque arbre est ajouté à la forêt, nous pouvons calculer l'erreur OOB jusqu'à inclure cet arbre. Par exemple, voici les moyennes cumulées pour chaque échantillon:

FUN <- function(x) {
  na <- is.na(x)
  cs <- cumsum(x[!na]) / seq_len(sum(!na))
  x[!na] <- cs
  x
}
t(apply(oob.p, 1, FUN))

> print(t(apply(oob.p, 1, FUN)), digits = 3)
       tree1 tree2 tree3 tree4 tree5 tree6 tree7 tree8 tree9 tree10
samp1     NA    NA  7.00  7.50  5.67  4.50    NA   4.6  4.33    4.0
samp2      6    NA  5.50  6.00  5.25    NA    NA    NA    NA     NA
samp3      3    NA  4.00    NA    NA    NA  3.67   4.0    NA     NA
samp4      6    NA  8.00  7.33    NA    NA  6.25    NA  6.20     NA
samp5     NA     2    NA    NA  2.00    NA  3.33   3.5    NA     NA
samp6     NA     7    NA  5.50    NA  4.33  4.25   3.8    NA     NA
samp7     NA    NA    NA  5.00    NA    NA    NA   4.0  5.67    5.5
samp8      7     4  4.00    NA    NA  4.25  4.60    NA  5.00     NA
samp9      4    NA    NA  3.50    NA  4.67  5.00   4.6    NA     NA
samp10     4     6  4.67  4.00    NA    NA  4.00    NA    NA    4.0

De cette façon, nous voyons comment la prédiction est accumulée sur les N arbres dans la forêt jusqu'à une itération donnée. Si vous lisez les lignes, la non- NAvaleur la plus à droite est celle que je montre ci-dessus pour la prédiction OOB. C'est ainsi que des traces de performances OOB peuvent être faites - un RMSEP peut être calculé pour les échantillons OOB sur la base des prévisions OOB accumulées cumulativement sur les N arbres.

Notez que le code R affiché n'est pas tiré des internes du code randomForest dans le package randomForest pour R - Je viens de créer un code simple afin que vous puissiez suivre ce qui se passe une fois que les prédictions de chaque arbre sont déterminées.

C'est parce que chaque arbre est construit à partir d'un échantillon bootstrap et qu'il existe un grand nombre d'arbres dans une forêt aléatoire, de sorte que chaque observation d'ensemble d'apprentissage se trouve dans l'échantillon OOB pour un ou plusieurs arbres, que les prévisions OOB peuvent être fournies pour tous échantillons dans les données de formation.

J'ai ignoré des problèmes tels que des données manquantes pour certains cas OOB, etc., mais ces problèmes concernent également une seule régression ou un seul arbre de classification. Notez également que chaque arbre dans une forêt utilise uniquement mtrydes variables sélectionnées au hasard.

Réintégrer Monica - G. Simpson
la source
Excellente réponse Gavin! Lorsque vous écrivez "To get predictions for the OOB sample, each one is passed down the current tree and the rules for the tree followed until it arrives in a terminal node", avez-vous une explication simple de ce que rules for the treec'est? Et est-ce que je comprends samplecorrectement une ligne si je comprends que les échantillons sont groupsdes observations dans lesquelles les arbres divisent les données?
user1665355
@ user1665355 Je suppose que vous avez compris comment les arbres de régression ou de classification ont été construits? Les arbres en RF ne sont pas différents (sauf peut-être dans les règles d'arrêt). Chaque arbre divise les données d'apprentissage en groupes d'échantillons avec des "valeurs" similaires pour la réponse. L'emplacement variable et fractionné (par exemple pH> 4,5) qui prédit le mieux (c'est-à-dire minimise "l'erreur") forme la première division ou règle de l'arbre. Chaque branche de cette division est ensuite considérée à son tour et de nouvelles divisions / règles sont identifiées qui minimisent "l'erreur" de l'arbre. Il s'agit de l'algorithme de partitionnement récursif binaire. Les scissions sont les règles.
Rétablir Monica - G. Simpson
@ user1665355 Oui, désolé je viens d'un champ où un échantillon est une observation, une ligne dans l'ensemble de données. Mais lorsque vous commencez à parler d'un échantillon bootstrap, c'est un ensemble de N observations, tirées avec remplacement à partir des données d'apprentissage et a donc N lignes ou observations. J'essaierai de nettoyer ma terminologie plus tard.
Rétablir Monica - G. Simpson
Merci! Je suis très nouveau sur RF, donc désolé pour des questions peut-être stupides :) Je pense que je comprends presque tout ce que vous avez écrit, très bonne explication! Je me demande simplement à propos de l'emplacement variable et fractionné (par exemple, pH> 4,5) qui prédit le mieux (c.-à-d. Minimise "l'erreur") la première division ou règle de l'arbre ... Je ne comprends pas quelle est l'erreur. : / Je lis et j'essaie de comprendre http://www.ime.unicamp.br/~ra109078/PED/Data%20Minig%20with%20R/Data%20Mining%20with%20R.pdf. À la page 115-116, les auteurs utilisent RF pour choisir variable importancedes indicateurs techniques.
user1665355
L '"erreur" dépend du type d'arbre à installer. La déviance est la mesure habituelle des réponses continues (gaussiennes). Dans le package rpart, le coefficient de Gini est la valeur par défaut pour les réponses catégoriques, mais il en existe d'autres pour différents modèles, etc. Vous devriez vous prévaloir d'un bon livre sur les arbres et les RF si vous souhaitez le déployer avec succès. Les mesures de l'amélioration de la variable sont quelque chose de différent - elles mesurent "l'importance" de chaque variable dans l'ensemble de données en voyant à quel point quelque chose change lorsque cette variable est utilisée pour s'adapter à un arbre et lorsque cette variable n'est pas utilisée.
Rétablir Monica - G. Simpson