Comment tracer un échantillon d'arbre à partir de randomForest :: getTree ()? [fermé]

62

Tout le monde a des suggestions de code ou de bibliothèque sur la manière de tracer réellement quelques échantillons d’arbres de:

getTree(rfobj, k, labelVar=TRUE)

(Oui, je sais que vous n’êtes pas censé le faire de manière opérationnelle, RF est une boîte noire, etc., etc. comment bien mes facteurs encodés fonctionnent, etc.)


Questions précédentes sans réponse décente:

Je veux en fait tracer un arbre de l'échantillon . Alors ne discutez pas avec moi à ce sujet, déjà. Je ne demande pas au sujet varImpPlot(importance variable Plot) ou partialPlotou MDSPlotou ces parcelles , j'ai déjà ceux -ci , mais ils ne sont pas un substitut à voir un arbre échantillon. Oui, je peux inspecter visuellement la sortie de getTree(...,labelVar=TRUE).

(Je suppose qu'une plot.rf.tree()contribution serait très bien reçue.)

smci
la source
6
Je ne vois pas la nécessité d'argumenter de manière préventive, surtout si vous demandez à quelqu'un de faire du bénévolat pour vous aider. ça ne passe pas bien. CV a une politique de l'étiquette - vous voudrez peut-être lire notre FAQ .
gung - Réintégrer Monica
9
@gung: toutes les questions précédentes sur ce sujet ont été écrasées par des personnes insistant sur le fait qu'il n'était pas nécessaire, ni même hérétique, de tracer un échantillon d'arbre. Lisez les citations que j'ai données. Je cherche ici un croquis sur la manière de coder un arbre RF.
smci
3
Je vois des réponses où les utilisateurs essaient d’être utiles et répondent à la question, ainsi que des commentaires remettant en question la prémisse de l’idée (qui, j’estime sincèrement, se veulent également utiles). Il est certainement possible de reconnaître que certaines personnes ne seront pas d’accord pour ne pas être téméraires.
gung - Rétablir Monica
4
Je ne vois pas de réponses où quiconque ait déjà tracé un arbre, en plus d'un an. Je cherche une réponse spécifique à cette question.
smci
1
Il est possible de tracer un seul arbre construit avec cforest(dans le paquet du parti ). Sinon, vous devrez convertir l'objet data.framerenvoyé par randomForest::getTreeen un treeobjet de type.
chl

Réponses:

44

Première solution (et la plus simple): si vous ne souhaitez pas vous en tenir à la RF classique, telle qu'implémentée dans Andy Liaw randomForest, vous pouvez essayer le paquet party qui fournit une implémentation différente de l' algorithme RF original (utilisation d'arbres conditionnels et de schémas d'agrégation). sur le poids moyen des unités). Ensuite, comme indiqué dans cette publication R-help , vous pouvez tracer un seul membre de la liste des arbres. Il semble que tout se passe bien, pour autant que je sache. Ci-dessous, un graphique d’un arbre généré par cforest(Species ~ ., data=iris, controls=cforest_control(mtry=2, mincriterion=0)).

entrez la description de l'image ici

Deuxième (presque aussi facile) solution: La plupart des techniques à base d' arbres dans R ( tree, rpart, TWIX, etc.) offre une treestructure de -comme pour l' impression / traçage d' un seul arbre. L'idée serait de convertir la sortie de randomForest::getTreeen un tel objet R, même si cela n'a aucun sens d'un point de vue statistique. En gros, il est facile d’accéder à l’arborescence à partir d’un treeobjet, comme indiqué ci-dessous. Veuillez noter que cela diffère légèrement en fonction du type de tâche (régression ou classification). Dans ce dernier cas, les probabilités spécifiques à la classe seront ajoutées à la dernière colonne de obj$frame(qui est a data.frame).

> library(tree)
> tr <- tree(Species ~ ., data=iris)
> tr
node), split, n, deviance, yval, (yprob)
      * denotes terminal node

 1) root 150 329.600 setosa ( 0.33333 0.33333 0.33333 )  
   2) Petal.Length < 2.45 50   0.000 setosa ( 1.00000 0.00000 0.00000 ) *
   3) Petal.Length > 2.45 100 138.600 versicolor ( 0.00000 0.50000 0.50000 )  
     6) Petal.Width < 1.75 54  33.320 versicolor ( 0.00000 0.90741 0.09259 )  
      12) Petal.Length < 4.95 48   9.721 versicolor ( 0.00000 0.97917 0.02083 )  
        24) Sepal.Length < 5.15 5   5.004 versicolor ( 0.00000 0.80000 0.20000 ) *
        25) Sepal.Length > 5.15 43   0.000 versicolor ( 0.00000 1.00000 0.00000 ) *
      13) Petal.Length > 4.95 6   7.638 virginica ( 0.00000 0.33333 0.66667 ) *
     7) Petal.Width > 1.75 46   9.635 virginica ( 0.00000 0.02174 0.97826 )  
      14) Petal.Length < 4.95 6   5.407 virginica ( 0.00000 0.16667 0.83333 ) *
      15) Petal.Length > 4.95 40   0.000 virginica ( 0.00000 0.00000 1.00000 ) *
> tr$frame
            var   n        dev       yval splits.cutleft splits.cutright yprob.setosa yprob.versicolor yprob.virginica
1  Petal.Length 150 329.583687     setosa          <2.45           >2.45   0.33333333       0.33333333      0.33333333
2        <leaf>  50   0.000000     setosa                                  1.00000000       0.00000000      0.00000000
3   Petal.Width 100 138.629436 versicolor          <1.75           >1.75   0.00000000       0.50000000      0.50000000
6  Petal.Length  54  33.317509 versicolor          <4.95           >4.95   0.00000000       0.90740741      0.09259259
12 Sepal.Length  48   9.721422 versicolor          <5.15           >5.15   0.00000000       0.97916667      0.02083333
24       <leaf>   5   5.004024 versicolor                                  0.00000000       0.80000000      0.20000000
25       <leaf>  43   0.000000 versicolor                                  0.00000000       1.00000000      0.00000000
13       <leaf>   6   7.638170  virginica                                  0.00000000       0.33333333      0.66666667
7  Petal.Length  46   9.635384  virginica          <4.95           >4.95   0.00000000       0.02173913      0.97826087
14       <leaf>   6   5.406735  virginica                                  0.00000000       0.16666667      0.83333333
15       <leaf>  40   0.000000  virginica                                  0.00000000       0.00000000      1.00000000

Ensuite, il existe des méthodes permettant d’imprimer et de tracer ces objets. Les fonctions de touche sont une tree:::plot.treeméthode générique (je mets un triple :qui permet de visualiser directement le code en R) en s'appuyant sur tree:::treepl(affichage graphique) et tree:::treeco(calculer les coordonnées des nœuds). Ces fonctions attendent la obj$framereprésentation de l'arbre. Autres problèmes subtils: (1) l’argument type = c("proportional", "uniform")dans la méthode de tracé par défaut tree:::plot.tree, aide à gérer la distance verticale entre les nœuds ( proportionalsignifie qu’il est proportionnel à la déviance, uniformqu’il est fixe); (2) vous devez compléter plot(tr)par un appel à text(tr)pour ajouter des étiquettes de texte aux noeuds et aux divisions, ce qui signifie dans ce cas que vous devrez également jeter un coup d'œil tree:::text.tree.

La getTreeméthode de randomForestrenvoie une structure différente, documentée dans l'aide en ligne. Une sortie typique est illustrée ci-dessous, avec les nœuds terminaux indiqués par le statuscode (-1). (Là encore, le résultat sera différent selon le type de tâche, mais uniquement sur les colonnes statuset prediction.)

> library(randomForest)
> rf <- randomForest(Species ~ ., data=iris)
> getTree(rf, 1, labelVar=TRUE)
   left daughter right daughter    split var split point status prediction
1              2              3 Petal.Length        4.75      1       <NA>
2              4              5 Sepal.Length        5.45      1       <NA>
3              6              7  Sepal.Width        3.15      1       <NA>
4              8              9  Petal.Width        0.80      1       <NA>
5             10             11  Sepal.Width        3.60      1       <NA>
6              0              0         <NA>        0.00     -1  virginica
7             12             13  Petal.Width        1.90      1       <NA>
8              0              0         <NA>        0.00     -1     setosa
9             14             15  Petal.Width        1.55      1       <NA>
10             0              0         <NA>        0.00     -1 versicolor
11             0              0         <NA>        0.00     -1     setosa
12            16             17 Petal.Length        5.40      1       <NA>
13             0              0         <NA>        0.00     -1  virginica
14             0              0         <NA>        0.00     -1 versicolor
15             0              0         <NA>        0.00     -1  virginica
16             0              0         <NA>        0.00     -1 versicolor
17             0              0         <NA>        0.00     -1  virginica

Si vous parvenez à convertir le tableau ci-dessus en celui généré par tree, vous pourrez probablement le personnaliser tree:::treepl, tree:::treecoet tree:::text.treepour l'adapter à vos besoins, bien que je ne possède pas d'exemple de cette approche. En particulier, vous voudrez probablement vous débarrasser de l'utilisation de la déviance, des probabilités de classe, etc., qui n'ont pas de sens dans RF. Tout ce que vous voulez, c'est définir les coordonnées des nœuds et les valeurs fractionnées. Vous pourriez utiliser fixInNamespace()pour cela, mais, pour être honnête, je ne suis pas sûr que ce soit la bonne façon de faire.

Troisième solution (et certainement astucieuse): écrivez une vraie as.treefonction d’assistance qui atténuera tous les "correctifs" ci-dessus. Vous pouvez ensuite utiliser les méthodes de traçage de R ou, probablement mieux, Klimt (directement à partir de R) pour afficher des arbres individuels.

chl
la source
40

J'ai quatre ans de retard, mais si vous voulez vraiment vous en tenir au randomForestpaquet (et il y a de bonnes raisons de le faire) et si vous voulez réellement visualiser l'arbre, vous pouvez utiliser le paquet de reproches .

Le paquet n'est pas très bien documenté (vous pouvez trouver la documentation ici ), mais tout est assez simple. Pour installer le package, reportez-vous à initialize.R dans le référentiel, exécutez simplement les éléments suivants:

options(repos='http://cran.rstudio.org')
have.packages <- installed.packages()
cran.packages <- c('devtools','plotrix','randomForest','tree')
to.install <- setdiff(cran.packages, have.packages[,1])
if(length(to.install)>0) install.packages(to.install)

library(devtools)
if(!('reprtree' %in% installed.packages())){
  install_github('araastat/reprtree')
}
for(p in c(cran.packages, 'reprtree')) eval(substitute(library(pkg), list(pkg=p)))

Alors allez-y et faites votre modèle et votre arbre:

library(randomForest)
library(reprtree)

model <- randomForest(Species ~ ., data=iris, importance=TRUE, ntree=500, mtry = 2, do.trace=100)

reprtree:::plot.getTree(model)

Et voilà! Beau et simple.

arbre généré à partir de plot.getTree (modèle)

Vous pouvez consulter le dépôt github pour en savoir plus sur les autres méthodes du package. En fait, si vous vérifiez sur plot.getTree.R , vous remarquerez que l'auteur utilise sa propre implémentation as.tree()dont chl ♦ suggère de vous construire vous-même dans sa réponse. Cela signifie que vous pourriez faire ceci:

tree <- getTree(model, k=1, labelVar=TRUE)
realtree <- reprtree:::as.tree(tree, model)

Et puis potentiellement utiliser realtreeavec d'autres paquets de traçage d'arbres tels que tree .

jgozal
la source
Merci beaucoup, j'accepte toujours avec joie les réponses, cela semble être un domaine où les gens sont dissatistifiés avec les offres. Je suppose que la nouvelle chose à faire serait de soutenir xgboostaussi.
smci
6
aucun problème. Il m'a fallu des heures pour trouver la bibliothèque / le paquet, alors je me suis dit que si ce n'était pas utile pour vous, ce serait pour d'autres personnes qui essaient de dessiner des arbres tout en restant collées au randomForestpaquet.
jgozal
2
Bonne découverte. Remarque: Il trace l'arborescence représentative, en un sens, l'arbre de l'ensemble qui est en moyenne le "plus proche" de tous les autres arbres de l'ensemble
Chris
2
@Chris La fonction plot.getTree()trace un arbre individuel. La fonction plot.reprtree()dans ce package trace un arbre représentatif.
Chun Li
1
J'ai obtenu le modèle de caret et je veux alimenter reptree avec reprtree:::plot.getTree(mod_rf_1$finalModel), cependant, il y a une "erreur dans data.frame (var = fr $ var, splits = as.character (gTree [," point de partage "]),: les arguments impliquent différant. nombre de rangées: 2631, 0 "
HappyCoding
15

J'ai créé des fonctions pour extraire les règles d'un arbre.

#**************************
#return the rules of a tree
#**************************
getConds<-function(tree){
  #store all conditions into a list
  conds<-list()
  #start by the terminal nodes and find previous conditions
  id.leafs<-which(tree$status==-1)
	  j<-0
	  for(i in id.leafs){
		j<-j+1
		prevConds<-prevCond(tree,i)
		conds[[j]]<-prevConds$cond
		while(prevConds$id>1){
		  prevConds<-prevCond(tree,prevConds$id)
		  conds[[j]]<-paste(conds[[j]]," & ",prevConds$cond)
        }
		if(prevConds$id==1){
			conds[[j]]<-paste(conds[[j]]," => ",tree$prediction[i])
    }
    }

  }

  return(conds)
}

#**************************
#find the previous conditions in the tree
#**************************
prevCond<-function(tree,i){
  if(i %in% tree$right_daughter){
		id<-which(tree$right_daughter==i)
		cond<-paste(tree$split_var[id],">",tree$split_point[id])
	  }
	  if(i %in% tree$left_daughter){
    id<-which(tree$left_daughter==i)
		cond<-paste(tree$split_var[id],"<",tree$split_point[id])
  }

  return(list(cond=cond,id=id))
}

#remove spaces in a word
collapse<-function(x){
  x<-sub(" ","_",x)

  return(x)
}


data(iris)
require(randomForest)
mod.rf <- randomForest(Species ~ ., data=iris)
tree<-getTree(mod.rf, k=1, labelVar=TRUE)
#rename the name of the column
colnames(tree)<-sapply(colnames(tree),collapse)
rules<-getConds(tree)
print(rules)
Dalpozz
la source