Forfait R pour la forêt aléatoire pondérée? option classwt?

16

J'essaie d'utiliser Random Forest pour prédire le résultat d'un ensemble de données extrêmement déséquilibré (le taux de classe minoritaire n'est que d'environ 1% ou même moins). Étant donné que l'algorithme traditionnel de la forêt aléatoire minimise le taux d'erreur global, plutôt que d'accorder une attention particulière aux classes minoritaires, il n'est pas directement applicable aux données déséquilibrées. Je veux donc attribuer un coût élevé à une mauvaise classification de la classe minoritaire (apprentissage sensible aux coûts).

J'ai lu plusieurs sources que nous pouvons utiliser l'option classwtde randomForestla R, mais je ne sais pas comment l' utiliser. Et avons-nous d'autres alternatives à la fonction randomForest?

Metariat
la source

Réponses:

29

Ce fil se réfère à deux autres fils et à un bel article sur cette question. Il semble que la pondération par classe et le sous-échantillonnage soient tout aussi bons. J'utilise le sous-échantillonnage comme décrit ci-dessous.

N'oubliez pas que l'ensemble d'entraînement doit être volumineux car seulement 1% caractérisera la classe rare. Moins de 25 ~ 50 échantillons de cette classe seront probablement problématiques. Peu d'échantillons caractérisant la classe rendront inévitablement le modèle appris grossier et moins reproductible.

RF utilise le vote majoritaire par défaut. Les prévalences de classe de l'ensemble de formation fonctionneront comme une sorte de préalable efficace. Ainsi, à moins que la classe rare ne soit parfaitement séparable, il est peu probable que cette classe rare remporte un vote majoritaire lors de la prédiction. Au lieu d'agréger par vote majoritaire, vous pouvez agréger les fractions de vote.

L'échantillonnage stratifié peut être utilisé pour augmenter l'influence de la classe rare. Cela se fait sur le coût de sous-échantillonnage des autres classes. Les arbres cultivés deviendront moins profonds car il faudra diviser beaucoup moins d'échantillons, ce qui limitera la complexité du modèle potentiel appris. Le nombre d'arbres cultivés doit être important, par exemple 4000, de sorte que la plupart des observations participent à plusieurs arbres.

Dans l'exemple ci-dessous, j'ai simulé un ensemble de données de formation de 5000 échantillons avec 3 classes avec des prévalences de 1%, 49% et 50% respectivement. Ainsi, il y aura 50 échantillons de classe 0. La première figure montre la véritable classe de l'ensemble d'apprentissage en fonction de deux variables x1 et x2. Cette image montre le motif simulé pour apprendre

Quatre modèles ont été formés: un modèle par défaut et trois modèles stratifiés avec une stratification 1:10:10 1: 2: 2 et 1: 1: 1 des classes. Principal alors que le nombre d'échantillons inbag (y compris les redessins) dans chaque arbre sera de 5000, 1050, 250 et 150. Comme je n'utilise pas le vote majoritaire, je n'ai pas besoin de faire une stratification parfaitement équilibrée. Au lieu de cela, les votes sur les classes rares pourraient être pondérés 10 fois ou une autre règle de décision. Le coût des faux négatifs et des faux positifs devrait influencer cette règle.

La figure suivante montre comment la stratification influence les fractions de vote. Notez que les ratios de classe stratifiés sont toujours le centre de gravité des prédictions. stratification et vote

Enfin, vous pouvez utiliser une courbe ROC pour trouver une règle de vote qui vous offre un bon compromis entre spécificité et sensibilité. La ligne noire n'est pas stratifiée, rouge 1: 5: 5, verte 1: 2: 2 et bleue 1: 1: 1. Pour cet ensemble de données 1: 2: 2 ou 1: 1: 1 semble le meilleur choix. courbe roc

Soit dit en passant, les fractions de vote sont ici validées de façon croisée.

Et le code:

library(plotrix)
library(randomForest)
library(AUC)

make.data = function(obs=5000,vars=6,noise.factor = .2,smallGroupFraction=.01) {
X = data.frame(replicate(vars,rnorm(obs)))
yValue = with(X,sin(X1*pi)+sin(X2*pi*2)+rnorm(obs)*noise.factor)
yQuantile = quantile(yValue,c(smallGroupFraction,.5))
yClass = apply(sapply(yQuantile,function(x) x<yValue),1,sum)
yClass = factor(yClass)
print(table(yClass)) #five classes, first class has 1% prevalence only
Data=data.frame(X=X,y=yClass)
}

plot.separation = function(rf,...) {
triax.plot(rf$votes,...,col.symbols = c("#FF0000FF",
                                       "#00FF0010",
                                       "#0000FF10")[as.numeric(rf$y)])
}

#make data set where class "0"(red circles) are rare observations
#Class 0 is somewhat separateble from class "1" and fully separateble from class "2"
Data = make.data()
par(mfrow=c(1,1))
plot(Data[,1:2],main="separation problem: identify rare red circles",
     col = c("#FF0000FF","#00FF0020","#0000FF20")[as.numeric(Data$y)])

#train default RF and with 10x 30x and 100x upsumpling by stratification
rf1 = randomForest(y~.,Data,ntree=500, sampsize=5000)
rf2 = randomForest(y~.,Data,ntree=4000,sampsize=c(50,500,500),strata=Data$y)
rf3 = randomForest(y~.,Data,ntree=4000,sampsize=c(50,100,100),strata=Data$y)
rf4 = randomForest(y~.,Data,ntree=4000,sampsize=c(50,50,50)  ,strata=Data$y)

#plot out-of-bag pluralistic predictions(vote fractions).
par(mfrow=c(2,2),mar=c(4,4,3,3))
plot.separation(rf1,main="no stratification")
plot.separation(rf2,main="1:10:10")
plot.separation(rf3,main="1:5:5")
plot.separation(rf4,main="1:1:1")

par(mfrow=c(1,1))
plot(roc(rf1$votes[,1],factor(1 * (rf1$y==0))),main="ROC curves for four models predicting class 0")
plot(roc(rf2$votes[,1],factor(1 * (rf1$y==0))),col=2,add=T)
plot(roc(rf3$votes[,1],factor(1 * (rf1$y==0))),col=3,add=T)
plot(roc(rf4$votes[,1],factor(1 * (rf1$y==0))),col=4,add=T)
Soren Havelund Welling
la source
oups une légende indique 1: 5: 5 au lieu de 1: 2: 2
Soren Havelund Welling
je vous remercie beaucoup pour votre réponse détaillée, qui m'aidera certainement beaucoup dans mon travail quotidien. Il y a une phrase que je ne comprends pas: "Principale alors que le nombre d'échantillons dans le sac (y compris les retraits) dans chaque arbre sera de 5000, 1050, 250 et 150". Pourriez-vous s'il vous plaît m'expliquer d'où viennent les chiffres?
Metariat
1
mon plaisir;) dans cet exemple la classe rare comptait 50 membres. Si la stratification 1:10:10 nous aurions besoin de spécifier sampsize = c (50,500,500). 50 + 500 + 500 = 1050. Un arbre adulte de 1050 échantillons aura 1050x2 nœuds au total.
Soren Havelund Welling
Désolé si ma question est idiote, mais quelle est la signification de la stratification 1:10:10, 1: 2: 2 et 1: 1: 1 ici? Et quand vous avez dit "les votes sur les classes rares pourraient être pondérés 10 fois". Quelle partie du code représente cela? Est-ce 1:10:10? Merci beaucoup!
Metariat
1
1:10:10 sont les ratios entre les classes. L'ensemble de données simulé a été conçu pour avoir les rapports 1:49:50. Ces ratios ont été modifiés en sous-échantillonnant les deux classes plus importantes. En choisissant par exemple sampsize = c (50,500,500) de la même manière que c (1,10,10) * 50, vous modifiez les rapports de classe dans les arbres. 50 est le nombre d'échantillons de la classe rare. Si vous définissez en outre keep.inbag = TRUE et inspectez rf $ inbag, vous verrez que les échantillons des classes rares sont inbag dans ~ 2/3 arbres alors que chaque échantillon de classe non rare est inclus dans très peu d'arbres en raison d'un échantillonnage à la baisse.
Soren Havelund Welling du